# Minimal demonstration of using AstroPy specutils to read and plot reduced
# spectra from Liverpool Telescope SPRAT level 2 reduced FITS files.

import argparse
from astropy.io import fits
from astropy import units as u
from astropy.visualization import quantity_support
from matplotlib import pyplot as plt

import astropy.wcs as fitswcs
import specutils as sp

# Map every accepted -x value to (FITS extension number, flux unit string).
# Insertion order matches the original argparse `choices` order so --help is
# unchanged. Relative flux normalized to 5500A (NORMFLUX) is dimensionless,
# hence the empty unit string.
_EXTENSIONS = {
    '2': (2, "adu"),
    'SPEC_NONSS': (2, "adu"),
    '3': (3, "adu"),
    'SPEC_SS': (3, "adu"),
    '4': (4, ""),
    'NORMFLUX': (4, ""),
    '5': (5, "erg/(s*cm**2*Angstrom)"),
    'FLUX': (5, "erg/(s*cm**2*Angstrom)"),
}

# WCS keywords describing the degenerate second axis (NAXIS2 = 1) of the
# stored spectrum; they are dropped once the data is squeezed to 1D.
_AXIS2_KEYWORDS = ('NAXIS2', 'CTYPE2', 'CUNIT2', 'CRVAL2', 'CRPIX2', 'CDELT2')


def _read_spectrum(path, extension):
    """Read one extension of a SPRAT level 2 FITS file.

    Returns a (data, header) tuple. ``data`` is row 0 of the stored array,
    i.e. the spectrum as a 1D vector. ``header`` is a copy of that
    extension's header.
    """
    with fits.open(path) as hdul:
        hdu = hdul[extension]
        # Spectra are in the FITS as 2D (NAXIS1 x 1) arrays; row 0 is the
        # 1D (NAXIS1) vector. copy() detaches it from the (possibly
        # memory-mapped) file before the file is closed.
        data = hdu.data[0].copy()
        header = hdu.header.copy()
    return data, header


def _fix_sprat_header(header):
    """Patch a SPRAT spectrum header in place so astropy.wcs reads it as 1D.

    Three categories of change are made:
    1. The data was converted from a 2D array with NAXIS2=1 into a true 1D
       vector, so the WCS associated with axis 2 is removed.
    2. The WCS standard has changed since the LT FITS header convention was
       created: RADECSYS is renamed to RADESYS.
    3. A true error in the original SPRAT header is fixed: CUNIT1 uses the
       singular "Angstrom".

    Keywords that are already absent are skipped rather than raising.
    Works on an astropy ``Header`` or any dict-like object.
    """
    header['NAXIS'] = 1
    for key in _AXIS2_KEYWORDS:
        header.pop(key, None)
    # The standard spells this keyword RADESYSa, where "a" is the
    # alternate-WCS letter and is blank for the primary WCS. Writing
    # "RADESYSa" literally would create RADESYSA, which belongs to
    # alternate WCS "A".
    if 'RADECSYS' in header:
        header['RADESYS'] = header.pop('RADECSYS')
    # The SPRAT FITS header uses "Angstroms"; unit strings must not be plural.
    header['CUNIT1'] = "Angstrom"


def main(argv=None):
    """Parse the command line, load the chosen SPRAT extension and plot it.

    argv: list of arguments to parse; None (the default) means sys.argv[1:].
    """
    parser = argparse.ArgumentParser(description='Plot a SPRAT spectrum')
    parser.add_argument(
        'infile', action='store',
        help='Reduced SPRAT FITS to display. Filename ends _2.fits, mandatory')
    parser.add_argument(
        '-x', dest='ext_str', choices=list(_EXTENSIONS), default='3',
        help='Extension to display (default: 3, SPEC_SS)')
    parser.add_argument(
        '-v', dest='verbose', action='store_true',
        help='Turn on verbose mode (default: Off)')
    args = parser.parse_args(argv)

    # Translate the extension name or number into (extension, unit).
    extension, unit_name = _EXTENSIONS[args.ext_str]

    if args.verbose:
        print(args)

    specdata, specheader = _read_spectrum(args.infile, extension)
    _fix_sprat_header(specheader)

    # u.Unit("") is dimensionless, which is what NORMFLUX needs.
    flux = specdata * u.Unit(unit_name)

    # Build the wavelength calibration from the patched header.
    my_wcs = fitswcs.WCS(specheader)
    if args.verbose:
        print(my_wcs)

    sp1d = sp.Spectrum1D(flux=flux, wcs=my_wcs)

    # Let matplotlib plot astropy Quantities and label axes with their units.
    quantity_support()
    plt.plot(sp1d.spectral_axis, sp1d.flux)
    plt.title(f"{specheader.get('OBJECT', '')} {args.infile}".strip())
    plt.show()


if __name__ == '__main__':
    main()

