from marvin.tools import Maps
from matplotlib import pyplot as plt
from marvin.utils.dap.bpt import kewley_sf_nii, kewley_comp_nii
import numpy as np

# Load the DAP maps for plate-IFU 8485-1901.
mm = Maps('8485-1901')

# get_bpt returns the classification masks together with a figure and axes.
# Only the masks are needed here because the diagram is rebuilt from scratch
# below, so close the figure get_bpt produced instead of leaving it open.
masks, bpt_fig, _ = mm.get_bpt(show_plot=False)
plt.close(bpt_fig)

# Gets the masks for NII/Halpha
sf = masks['sf']['nii']
comp = masks['comp']['nii']
agn = masks['agn']['nii']

# Gets the necessary emission-line flux maps
ha = mm['emline_gflux_ha_6564']
hb = mm['emline_gflux_hb_4862']
nii = mm['emline_gflux_nii_6585']
oiii = mm['emline_gflux_oiii_5008']

# Calculates log(NII/Ha) and log(OIII/Hb).
# Spaxels with zero flux make the raw ndarray division produce inf/nan, and
# non-positive ratios are outside log10's domain; np.ma.log10 masks all of
# those results, so the floating-point warnings from the division are noise.
with np.errstate(divide='ignore', invalid='ignore'):
    log_nii_ha = np.ma.log10(nii.value / ha.value)
    log_oiii_hb = np.ma.log10(oiii.value / hb.value)

# Creates figure and axes
fig, ax = plt.subplots()

# Plots SF, composite, and AGN spaxels using the masks; labelled so the
# colour coding is explained by the legend.
ax.scatter(log_nii_ha[sf], log_oiii_hb[sf], c='b', label='Star-forming')
ax.scatter(log_nii_ha[comp], log_oiii_hb[comp], c='g', label='Composite')
ax.scatter(log_nii_ha[agn], log_oiii_hb[agn], c='r', label='AGN')

# x samples for the classification lines. The upper limits deliberately stop
# short of each curve's vertical asymptote (presumably near
# log([NII]/Ha) = 0.05 for kewley_sf_nii and 0.47 for kewley_comp_nii --
# NOTE(review): confirm against marvin.utils.dap.bpt).
xx_sf_nii = np.linspace(-2, 0.045, 10000)
xx_comp_nii = np.linspace(-2, 0.4, 10000)

# Uses kewley_sf_nii and kewley_comp_nii to plot the classification lines
ax.plot(xx_sf_nii, kewley_sf_nii(xx_sf_nii), 'k-')
ax.plot(xx_comp_nii, kewley_comp_nii(xx_comp_nii), 'r-')

ax.set_xlim(-2, 1)
ax.set_ylim(-1.5, 1.6)

ax.set_xlabel(r'log([NII]/H$\alpha$)')
ax.set_ylabel(r'log([OIII]/H$\beta$)')
ax.legend()