# Source code for buildcal

#!/usr/bin/env python

from __future__ import print_function
import glob
import re
import os
import time
import numpy as np
from scipy import ndimage, interpolate
try:
    import primitives
    import utr
    from image import Image
    from parallel import Task, Consumer
except:
    import charis.primitives as primitives
    import charis.utr as utr
    from charis.image import Image
    from charis.parallel import Task, Consumer
import logging
from astropy.io import fits
import multiprocessing
import copy
import shutil
import sys
import pkg_resources

log = logging.getLogger('main')


def buildcalibrations(inImage, inLam, mask, indir, outdir="./", order=3,
                      lam1=1150, lam2=2400, R=25, trans=None,
                      upsample=True, header=None,
                      ncpus=multiprocessing.cpu_count()):
    """
    Build the calibration files needed to extract data cubes from
    sequences of CHARIS reads.

    Inputs:
    1. inImage:  Image class, should include count rate and ivar for
                 a narrow-band flatfield calibration image.
    2. inLam:    wavelength in nm of inImage
    3. mask:     bad pixel mask, =0 for bad pixels.  Accepted for
                 interface compatibility; this routine does not read it.
    4. indir:    directory where master calibration files live

    Optional inputs:
    1. outdir:   directory in which to place the calibration files.
                 Default "./"
    2. order:    int, order of polynomial fit to position(lambda).
                 Default 3 (strongly recommended).  The shift/rotation
                 bookkeeping below uses the fixed coefficient indices
                 0, 1, 4, 10, 11, 14 and the pixel solution is
                 generated with order=3.
    3. lam1:     minimum wavelength (in nm) of the bandpass
                 Default 1150
    4. lam2:     maximum wavelength (in nm) of the bandpass
                 Default 2400
    5. R:        spectral resolution of the PSFlet templates
                 Default 25
    6. trans:    ndarray, trans[:, 0] = wavelength in nm, trans[:, 1]
                 is fractional transmission through the filter and
                 atmosphere.  Passed through to
                 primitives.make_polychrome.  Default None
    7. upsample: bool, if True build templates oversampled by a factor
                 of 5 in x (so a subpixel shift can be fitted during
                 cube extraction) and also save the full oversampled
                 array.  Default True
    8. header:   FITS header, to which will be appended the shifts
                 and rotation angle between the stored and the fitted
                 wavelength solutions (keys cal_dx, cal_dy, cal_dphi).
                 Default None.
    9. ncpus:    number of worker processes to start.  Values below 1
                 are raised to 1.
                 Default multiprocessing.cpu_count()

    Returns None.  Writes to outdir: the pixel solution (through
    psftool.savepixsol), PSFwidths.fits, polychromeR<R>.fits,
    polychromekeyR<R>.fits and, when upsample is True,
    polychromefullR<R>.npy.

    Raises ValueError if indir contains no hires_psflets_lam????.fits
    files.
    """

    tstart = time.time()

    # With no workers nothing would ever service the task queue, and the
    # results.get() call further down would block forever.
    ncpus = max(int(ncpus), 1)

    #################################################################
    # Fit the PSFlet positions on the input image, compute the shift
    # in the mean position (coefficients 0 and 10) and in the linear
    # component of the fit (coefficients 1, 4, 11, and 14).  The
    # comparison point is the location solution for this wavelength in
    # the existing calibration files.
    #################################################################

    lamsol_path = os.path.join(indir, "lamsol.dat")
    log.info("Loading wavelength solution from %s", lamsol_path)
    # Parse the table once; column 0 is wavelength, the rest coefficients.
    lamsol = np.loadtxt(lamsol_path)
    lam = lamsol[:, 0]
    allcoef = lamsol[:, 1:]
    psftool = primitives.PSFLets()
    oldcoef = psftool.monochrome_coef(inLam, lam, allcoef, order=order)

    print('Generating new wavelength solution')
    # Only the fitted coefficients are needed here; the spot positions and
    # the goodness flags returned alongside them are discarded.
    _, _, _, newcoef = primitives.locatePSFlets(inImage, polyorder=order,
                                                coef=oldcoef)

    psftool.geninterparray(lam, allcoef, order=order)
    dcoef = newcoef - oldcoef

    indx = np.asarray([0, 1, 4, 10, 11, 14])
    psftool.interp_arr[0][indx] += dcoef[indx]
    # NOTE(review): order is fixed at 3 here although geninterparray above
    # receives `order`; left as-is because the coefficient indices used in
    # this function are hard-coded as well -- confirm before generalizing.
    psftool.genpixsol(lam, allcoef, order=3,
                      lam1=lam1/1.05, lam2=lam2*1.05)
    psftool.savepixsol(outdir=outdir)

    #################################################################
    # Record the shift in the spot locations.
    #################################################################

    phi1 = np.mean([np.arctan2(oldcoef[4], oldcoef[1]),
                    np.arctan2(-oldcoef[11], oldcoef[14])])
    phi2 = np.mean([np.arctan2(newcoef[4], newcoef[1]),
                    np.arctan2(-newcoef[11], newcoef[14])])
    shift_x, shift_y, dphi = dcoef[0], dcoef[10], phi2 - phi1
    if header is not None:
        header['cal_dx'] = (shift_x,
                            'x-shift from archival spot positions (pixels)')
        header['cal_dy'] = (shift_y,
                            'y-shift from archival spot positions (pixels)')
        header['cal_dphi'] = (dphi,
                              'Rotation from archival spot positions (radians)')

    #################################################################
    # Load the high-resolution PSFlet images and associated
    # wavelengths.
    #################################################################

    hires_list = sorted(glob.glob(os.path.join(indir,
                                               'hires_psflets_lam????.fits')))
    if not hires_list:
        raise ValueError("No hires_psflets_lam????.fits files found in "
                         + str(indir))
    # fits.getdata closes the file once the primary HDU data is read,
    # instead of leaving one open handle per PSFlet file.
    hires_arrs = [fits.getdata(filename, ext=0) for filename in hires_list]
    lam_hires = [int(re.sub(r'.*lam', '', re.sub(r'\.fits$', '', filename)))
                 for filename in hires_list]
    psflet_res = 9  # Oversampling of high-resolution PSFlet images

    #################################################################
    # Width of high-resolution PSFlets, in pixels.  First compute the
    # width from the images perpendicular to the dispersion direction
    # at the central pixel along the dispersion direction.
    #################################################################

    shape = hires_arrs[0].shape
    sigarr = np.zeros((len(hires_list), shape[0], shape[1]))
    _x = np.arange(shape[3])/float(psflet_res)
    _x -= _x[_x.shape[0]//2]

    for i, arr in enumerate(hires_arrs):
        # Central row of every PSFlet: shape (shape[0], shape[1], shape[3]).
        profile = arr[:shape[0], :shape[1], shape[2]//2]
        # Root of the intensity-weighted second moment along the row.
        sigarr[i] = np.sqrt(np.sum(profile*_x**2, axis=-1)
                            / np.sum(profile, axis=-1))

    #################################################################
    # Now interpolate the width at the locations and wavelengths of
    # the microspectra for optimal extraction.  First interpolate in
    # location, then interpolate in wavelength for each lenslet.
    #################################################################

    mean_x = psftool.xindx[:, :, psftool.xindx.shape[-1]//2]
    mean_y = psftool.yindx[:, :, psftool.yindx.shape[-1]//2]

    longsigarr = np.zeros((len(lam_hires), mean_x.shape[0], mean_x.shape[1]))

    ix = mean_x*hires_arrs[0].shape[1]/2048. - 0.5
    iy = mean_y*hires_arrs[0].shape[0]/2048. - 0.5

    for i in range(sigarr.shape[0]):
        longsigarr[i] = ndimage.map_coordinates(sigarr[i], [iy, ix],
                                                order=3, mode='nearest')

    fullsigarr = np.zeros((psftool.xindx.shape))
    hires_wavelengths = np.asarray(lam_hires)
    for i in range(mean_x.shape[0]):
        for j in range(mean_x.shape[1]):
            fit = interpolate.interp1d(hires_wavelengths,
                                       longsigarr[:, i, j],
                                       bounds_error=False,
                                       fill_value='extrapolate')
            fullsigarr[i, j] = fit(psftool.lam_indx[i, j])

    # NOTE(review): `clobber` is kept for compatibility with the astropy
    # release this file was written against; newer releases spell the
    # keyword `overwrite`.
    out = fits.HDUList(fits.PrimaryHDU(fullsigarr.astype(np.float32)))
    # Written into outdir together with every other calibration product.
    out.writeto(os.path.join(outdir, 'PSFwidths.fits'), clobber=True)

    #################################################################
    # Wavelengths at which to return the PSFlet templates
    #################################################################

    Nspec = int(np.log(lam2*1./lam1)*R + 1.5)
    loglam_endpts = np.linspace(np.log(lam1), np.log(lam2), Nspec)
    loglam_midpts = (loglam_endpts[1:] + loglam_endpts[:-1])/2
    lam_endpts = np.exp(loglam_endpts)
    lam_midpts = np.exp(loglam_midpts)

    #################################################################
    # Compute the PSFlets integrated over small ranges in wavelength,
    # accounting for atmospheric+filter transmission.  Do this
    # calculation in parallel.
    #################################################################

    xindx = np.arange(-100, 101)
    xindx, yindx = np.meshgrid(xindx, xindx)

    #################################################################
    # Oversampling in x in final calibration frame.  If >1, fitting a
    # subpixel shift is possible in cube extraction.
    #################################################################

    upsamp = 5 if upsample else 1
    nlam = 10  # Number of monochromatic PSFlets per integrated PSFlet
    ntasks = upsamp*(Nspec - 1)

    tasks = multiprocessing.Queue()
    results = multiprocessing.Queue()
    consumers = [Consumer(tasks, results) for _ in range(ncpus)]
    for worker in consumers:
        worker.start()

    for i in range(ntasks):
        ilam = i//upsamp
        subpixel_shift = (i % upsamp)*1./upsamp
        # Each task gets its own copy of the solution, offset in
        # coefficient [0, 0] by the subpixel shift of this template.
        tool = copy.deepcopy(psftool)
        tool.interp_arr[0, 0] -= subpixel_shift
        tasks.put(Task(i, primitives.make_polychrome,
                       (lam_endpts[ilam], lam_endpts[ilam + 1],
                        hires_arrs, lam_hires, tool, allcoef,
                        xindx, yindx, psflet_res, nlam, trans)))

    # One None per worker is queued after the real tasks.
    for _ in range(ncpus):
        tasks.put(None)

    polyimage = np.empty((Nspec - 1, 2048, 2048*upsamp), np.float32)

    print('Generating narrowband template images')
    for i in range(ntasks):
        frac_complete = (i + 1)*1./ntasks
        N = int(frac_complete*40)
        print('-'*N + '>' + ' '*(40 - N) +
              ' %3d%% complete\r' % (int(100*frac_complete)), end='')

        # Results arrive in completion order; the task index says where
        # each one belongs.
        index, result = results.get()
        ilam = index//upsamp
        column_offset = index % upsamp
        polyimage[ilam, :, column_offset::upsamp] = result

    print('')

    #################################################################
    # Save the positions of the PSFlet centers to cut out the
    # appropriate regions in the least-squares extraction
    #################################################################

    xpos = []
    ypos = []
    good = []
    for i in range(Nspec - 1):
        _x, _y = psftool.return_locations(lam_midpts[i], allcoef,
                                          xindx, yindx)
        _good = (_x > 8)*(_x < 2040)*(_y > 8)*(_y < 2040)
        xpos.append(_x)
        ypos.append(_y)
        good.append(_good)

    # os.path.join keeps these paths valid whether or not outdir ends with
    # a path separator.
    if upsamp > 1:
        np.save(os.path.join(outdir, 'polychromefullR%d.npy' % (R)),
                polyimage)
    out = fits.HDUList(
        fits.PrimaryHDU(polyimage[:, :, ::upsamp].astype(np.float32)))
    out.writeto(os.path.join(outdir, 'polychromeR%d.fits' % (R)),
                clobber=True)

    outkey = fits.HDUList(fits.PrimaryHDU(lam_midpts))
    outkey.append(fits.PrimaryHDU(np.asarray(xpos)))
    outkey.append(fits.PrimaryHDU(np.asarray(ypos)))
    outkey.append(fits.PrimaryHDU(np.asarray(good).astype(np.uint8)))
    outkey.writeto(os.path.join(outdir, 'polychromekeyR%d.fits' % (R)),
                   clobber=True)

    print("Total time elapsed: %.0f seconds" % (time.time() - tstart))
    return None
# Wavelength limits in nm for each observing mode.
_BAND_LIMITS = {
    'J': (1155, 1340),
    'H': (1470, 1800),
    'K': (2005, 2380),
    'lowres': (1140, 2410),
}


def _read_reply(prompt):
    """Show `prompt` and return the line typed by the user as a string."""
    try:
        reader = raw_input  # Python 2
    except NameError:
        reader = input  # Python 3, where raw_input no longer exists
    return reader(prompt)


def _ask_yes_no(prompt):
    """Ask until the reply is '', 'y' or 'Y' (True) or 'n' or 'N' (False)."""
    while True:
        reply = _read_reply(prompt)
        if reply in ('', 'y', 'Y'):
            return True
        if reply in ('n', 'N'):
            return False
        print("Invalid input.")


def _parse_thread_count(reply, ncpus):
    """Turn a typed thread count into (nthreads, None) or (None, message).

    An empty reply selects `ncpus`.  Accepted values are 1..ncpus: zero is
    rejected because no worker would exist to process the queued tasks.
    """
    if reply == '':
        return ncpus, None
    try:
        nthreads = int(reply)
    except ValueError:
        return None, "Invalid input."
    if nthreads < 1 or nthreads > ncpus:
        return None, "Must use between 1 and %d threads." % (ncpus)
    return nthreads, None


def _weighted_ramp_sums(filenames, mask, nthreads):
    """Accumulate sum(data*ivar) and sum(ivar) over the ramps of `filenames`.

    The denominator starts at 1e-100 so that dividing by it never divides
    by an exact zero.
    """
    num = 0
    denom = 1e-100
    for filename in filenames:
        ramp = utr.calcramp(filename=filename, mask=mask, maxcpus=nthreads)
        num = num + ramp.data*ramp.ivar
        denom = denom + ramp.ivar
    return num, denom


def main(argv=None):
    """Interactive command-line driver that builds CHARIS calibrations.

    argv defaults to sys.argv and must hold, after the program name:
      1. the path (or glob pattern) of the narrow-band flatfield image
      2. the wavelength, in nm, of that image
      3. the band/filter: 'J', 'H', 'K', or 'lowres'
      4+ optional glob patterns of darks taken with the same setup

    Band and wavelength are validated before any question is asked.  The
    user is then prompted for oversampling, the number of threads and a
    final confirmation.  Results are written to the current directory:
    the files produced by buildcalibrations, cal_params.fits,
    background.fits (only when darks were given) and copies of
    lensletflat.fits, mask.fits and pixelflat.fits.

    Raises ValueError for an unknown band, a wavelength outside the band,
    a non-integer wavelength, or when no flatfield file matches.
    Exits through sys.exit() on a usage error or when the user declines.
    """
    if argv is None:
        argv = sys.argv

    if len(argv) < 4:
        print("Must call buildcal with at least three arguments:")
        print(" 1: The path to the narrow-band flatfield image")
        print(" 2: The wavelength, in nm, of the narrow-band image")
        print(" 3: The band/filter: 'J', 'H', 'K', or 'lowres'")
        print("Example: buildcal CRSA00000000.fits 1550 lowres")
        print("Optional additional arguments: filenames of darks")
        print(" taken with the same observing setup.")
        print("Example: buildcal CRSA00000000.fits 1550 lowres darks/CRSA*.fits")
        sys.exit()

    infile = argv[1]
    lam = int(argv[2])
    band = argv[3]

    ###############################################################
    # Wavelength limits in nm.  Checked up front so that a bad
    # argument is reported before the interactive questions.
    ###############################################################

    if band not in _BAND_LIMITS:
        raise ValueError('Band must be one of: J, H, K, lowres')
    lam1, lam2 = _BAND_LIMITS[band]

    if lam < lam1 or lam > lam2:
        raise ValueError("Error: wavelength " + str(lam) + " outside range (" +
                         str(lam1) + ", " + str(lam2) + ") of mode " + band)

    bgfiles = []
    for pattern in argv[4:]:
        bgfiles += glob.glob(pattern)

    print("\n" + "*"*60)
    print("Oversample PSFlet templates to enable fitting a subpixel offset in cube")
    print("extraction? Cost is a factor of ~2-4 in the time to build calibrations.")
    print("*"*60)
    upsample = _ask_yes_no(" Oversample? [Y/n]: ")

    ncpus = multiprocessing.cpu_count()
    print("\n" + "*"*60)
    print("How many threads would you like to use? %d threads detected." % (ncpus))
    print("*"*60)
    while True:
        reply = _read_reply(" Number of threads to use [%d]: " % (ncpus))
        nthreads, problem = _parse_thread_count(reply, ncpus)
        if problem is None:
            break
        print(problem)

    print("\n" + "*"*60)
    print("Building calibration files, placing results in current directory:")
    print(os.path.abspath('.'))
    print("\nSettings:\n")
    print("Using %d threads" % (nthreads))
    print("Narrow-band flatfield image: " + infile)
    print("Wavelength:", lam, "nm")
    print("Observing mode: " + band)
    print("Upsample PSFlet templates? ", upsample)
    if bgfiles:
        print("Background count rates will be computed.")
    else:
        print("No background will be computed.")
    print("*"*60)
    if not _ask_yes_no(" Continue with these settings? [Y/n]: "):
        sys.exit()

    prefix = pkg_resources.resource_filename('charis', 'calibrations')

    ###############################################################
    # Spectral resolutions for the final calibration files
    ###############################################################

    if band in ('J', 'H', 'K'):
        indir = os.path.join(prefix, "highres_" + band)
        R = 100
    else:
        indir = os.path.join(prefix, "lowres")
        R = 30

    # fits.getdata closes the file after reading the primary HDU.
    mask = fits.getdata(os.path.join(prefix, 'mask.fits'), ext=0)

    hdr = fits.PrimaryHDU().header
    hdr.clear()
    infilelist = glob.glob(infile)
    if not infilelist:
        raise ValueError("No CHARIS file found for calibration.")

    hdr['calfname'] = (os.path.basename(infilelist[0]),
                       'Monochromatic image used for calibration')
    try:
        cal_date = fits.getheader(infilelist[0], ext=0)['mjd']
    except Exception:
        # Best effort: the date is informational only, so an unreadable
        # file or a missing keyword must not stop the calibration.
        # KeyboardInterrupt and SystemExit still propagate.
        cal_date = 'unavailable'
    hdr['cal_date'] = (cal_date, 'MJD date of calibration image')
    hdr['cal_lam'] = (lam, 'Wavelength of calibration image (nm)')
    hdr['cal_band'] = (band, 'Band/mode of calibration image (J/H/K/lowres)')

    ###############################################################
    # Mean background count rate, weighted by inverse variance
    ###############################################################

    print('Computing ramps from sequences of raw reads')
    if bgfiles:
        num, denom = _weighted_ramp_sums(bgfiles, mask, nthreads)
        for ibg, filename in enumerate(bgfiles, 1):
            hdr['bkgnd%03d' % (ibg)] = (os.path.basename(filename),
                                        'Dark(s) used for background subtraction')
        background = Image(data=num/denom, ivar=1./denom)
        background.write('background.fits')
    else:
        hdr['bkgnd001'] = ('None', 'Dark(s) used for background subtraction')

    ###############################################################
    # Monochromatic flatfield image
    ###############################################################

    num, denom = _weighted_ramp_sums(infilelist, mask, nthreads)
    # Multiplying by the mask zeroes the inverse variance of bad pixels.
    inImage = Image(data=num/denom, ivar=mask*1./denom)

    trans = np.loadtxt(os.path.join(indir, band + '_tottrans.dat'))

    buildcalibrations(inImage, lam, mask, indir, lam1=lam1, lam2=lam2,
                      upsample=upsample, R=R, order=3, trans=trans,
                      header=hdr, ncpus=nthreads)

    # NOTE(review): `clobber` is kept for compatibility with the astropy
    # release this file was written against; newer releases spell the
    # keyword `overwrite`.
    out = fits.HDUList(fits.PrimaryHDU(None, hdr))
    out.writeto('cal_params.fits', clobber=True)

    shutil.copy(os.path.join(indir, 'lensletflat.fits'), './lensletflat.fits')
    for filename in ['mask.fits', 'pixelflat.fits']:
        shutil.copy(os.path.join(prefix, filename), './' + filename)


if __name__ == "__main__":
    main()