# Minimal demonstration of using astropy and specutils to read and plot reduced
# spectra from Liverpool Telescope SPRAT level 2 reduced FITS files
#
# The script reads FITS extension 1, which contains the reduced version of the
# entire long slit, rather than the other extensions, which hold the pipeline's
# guess at which object on the slit you wanted. Nor does it use the primary
# image, which is the original CCD array.
#
# Examples:
#
# Co-add pixels 89,90,91 from FITS extension v_s_20200809_99_1_0_2.fits[1] 
#
#   python sprat_splot_lss.py v_s_20200809_99_1_0_2.fits 89 91
#
# Co-add pixels 89,90,91 from FITS extension v_s_20200809_99_1_0_2.fits[1] with sky subtraction
# using the mean of the 41 pixel rows 150-190 (inclusive) near the top of the slit.
#
#   python sprat_splot_lss.py v_s_20200809_99_1_0_2.fits 89 91  -s 150 190

# Tested using
#   astropy                       4.0.2
#   specutils                     1.1             
#   matplotlib                    3.3.4              
#   numpy                         1.17.4             
# Specific version requirements are not well defined

import argparse
import numpy as np
from astropy.io import fits
from astropy import units as u
from astropy.visualization import quantity_support
from matplotlib import pyplot as plt
import astropy.wcs as fitswcs
import specutils as sp

def _parse_args(argv=None):
    """Parse and validate the command line.

    All pixel numbers use the FITS convention: rows count from 1 at the
    bottom of the slit and ranges are inclusive at both ends.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    :returns: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Plot a spectrum for a contiguous section of the SPRAT slit')

    parser.add_argument('infile', action='store', help='Reduced SPRAT FITS to display. Filename ends _2.fits, mandatory')
    parser.add_argument('firstPix', action='store', type=int, help='First pixel to include in output. FITS convention: count from 1 as bottom pixel. Mandatory')
    parser.add_argument('lastPix', action='store', type=int, help='Last pixel to include in output. FITS convention: count from 1 as bottom pixel. Mandatory')
    parser.add_argument('-s', dest='skyPix', action='store', type=int, nargs=2, help='First and last pixels of block to use for sky. FITS convention: count from 1 as bottom pixel. Optional. Example: -s 100 110')
    parser.add_argument('-v', dest='verbose', action='store_true', help='Turn on verbose mode (default: Off)')

    args = parser.parse_args(argv)

    # A reversed or zero/negative range would otherwise select nothing or,
    # through negative python indices, rows from the top of the slit.
    if args.firstPix < 1 or args.lastPix < args.firstPix:
        parser.error('pixel range must satisfy 1 <= firstPix <= lastPix')
    if args.skyPix is not None and not 1 <= args.skyPix[0] <= args.skyPix[1]:
        parser.error('sky range (-s) must satisfy 1 <= first <= last')
    return args


def _spectral_header(header):
    """Return a copy of *header* that describes only the 1-D spectral WCS.

    specutils does not work well with the pixel-based WCS defined on axis 2
    of the default SPRAT headers, which is a pixel image coordinate system,
    not really a world coordinate system. All axis-2 keywords are therefore
    dropped. Keywords that are already absent are skipped instead of raising
    KeyError. The caller's header is left untouched.
    """
    header = header.copy()
    header['NAXIS'] = 1
    for key in ('NAXIS2', 'CTYPE2', 'CUNIT2', 'CRVAL2', 'CRPIX2', 'CDELT2'):
        header.remove(key, ignore_missing=True)

    # SPRAT header should be 'Angstrom', not 'Angstroms'
    header['CUNIT1'] = 'Angstrom'

    # The WCS standard replaced RADECSYS with RADESYSa after the LT FITS
    # header convention was created. The trailing "a" in the standard's name
    # is a placeholder for the alternate-WCS letter (blank for the primary
    # WCS). astropy upper-cases keywords, so writing 'RADESYSa' would create
    # RADESYSA, which belongs to alternate WCS "A", not the primary one.
    if 'RADECSYS' in header:
        header['RADESYS'] = header['RADECSYS']
        del header['RADECSYS']
    return header


def main(argv=None):
    """Plot selected slit rows, their co-add and optionally a sky-subtracted co-add.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    # Let matplotlib plot astropy Quantity objects directly.
    quantity_support()

    args = _parse_args(argv)
    if args.verbose:
        print(args)

    # Read FITS extension 1 [LSS_SS], the reduced version of the whole slit.
    with fits.open(args.infile) as hdul:
        specarray = hdul[1].data
        specheader = hdul[1].header

    # Reject ranges running past the top of the slit. Python slicing would
    # silently truncate them, leaving fewer rows than requested.
    n_rows = specarray.shape[0]
    if args.lastPix > n_rows:
        raise SystemExit('lastPix %d is beyond the %d rows of the slit' % (args.lastPix, n_rows))
    if args.skyPix is not None and args.skyPix[1] > n_rows:
        raise SystemExit('sky range ends at %d, beyond the %d rows of the slit' % (args.skyPix[1], n_rows))

    my_wcs = fitswcs.WCS(_spectral_header(specheader))
    adu = u.Unit("adu")

    # Convert one-based inclusive FITS rows to a zero-based python slice:
    # FITS pixels 100 - 103 inclusive are python rows [99:103].
    first_row, stop_row = args.firstPix - 1, args.lastPix
    count = stop_row - first_row

    # Mean sky spectrum of a single row, averaged over the whole -s block.
    # Without -s it is zero, carrying the same unit as the data.
    if args.skyPix is not None:
        first_sky, last_sky = args.skyPix
        sky_flux = np.mean(specarray[first_sky - 1:last_sky], axis=0) * adu
    else:
        sky_flux = np.zeros(specarray.shape[1]) * adu

    # Spectrum1D holding the entire slit; flux[i] is zero-based data row i.
    slit = sp.Spectrum1D(flux=specarray * adu, wcs=my_wcs)

    # Plot each requested row minus the per-row sky estimate.
    plt.figure(figsize=(12, 4), dpi=80, facecolor='w', edgecolor='k')
    for row in range(first_row, stop_row):
        # Label with the one-based FITS pixel number used on the command line.
        plt.plot(slit.spectral_axis, slit.flux[row] - sky_flux, label='Pixel row ' + str(row + 1))

    # Co-add the requested rows into one total and plot the result.
    coadded = sp.Spectrum1D(flux=np.sum(specarray[first_row:stop_row], axis=0) * adu, wcs=my_wcs)
    plt.plot(coadded.spectral_axis, coadded.flux, label='Coadded')

    if args.skyPix is not None:
        # The co-add sums `count` rows, so subtract `count` rows' worth of sky.
        plt.plot(coadded.spectral_axis, coadded.flux - count * sky_flux, label='Sky subtracted')

    plt.title(str(specheader.get('OBJECT', '')) + " " + args.infile)
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()

