# Source code for marvin.tools.image

#!/usr/bin/env python
# encoding: utf-8


# Created by Brian Cherinka on 2016-05-10 20:17:52
# Licensed under a 3-clause BSD license.

# Revision History:
#     Initial Version: 2016-05-10 20:17:52 by Brian Cherinka
#     Last Modified On: 2016-05-10 20:17:52 by Brian


from __future__ import absolute_import, division, print_function

import os
import sys
import warnings
import random
import itertools
import requests
import PIL
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import EllipseCollection

from astropy.io import fits
import marvin
from marvin.tools.mixins import MMAMixIn
from marvin.core.exceptions import MarvinError, MarvinUserWarning
from marvin.utils.general import (getWCSFromPng, Bundle, Cutout, target_is_mastar,
                                  get_plates, check_versions)
# sdss_access is treated as optional: if it cannot be imported, HttpAccess
# is set to None so this module can still be imported.
try:
    from sdss_access import HttpAccess
except ImportError:
    HttpAccess = None

# Bind an in-memory buffer class to the common alias ``stringio``:
# cStringIO.StringIO on Python 2, io.BytesIO otherwise.
if sys.version_info.major == 2:
    from cStringIO import StringIO as stringio
else:
    from io import BytesIO as stringio

# Names exported by ``from marvin.tools.image import *``.
__all__ = ['Image']


class Image(MMAMixIn):
    '''A class to interface with MaNGA images.

    This class represents a MaNGA image object initialised either
    from a file, or remotely via the Marvin API.

    TODO: what kinds of images should this handle? optical, maps, nsa preimaging?
    TODO: should this be subclasses into different kinds of images? DRPImage, MapImage, NSAImage?

    Attributes:
        data:
            The image opened with PIL, or None if it could not be opened.
        header (`astropy.io.fits.Header`):
            Header built from the image metadata (``data.info``), or None
            if no header could be built.
        ra,dec (float):
            Coordinates of the target, read from the ``RA`` and ``DEC``
            header keywords. None when they are not available.
        wcs (`astropy.wcs.WCS`):
            The WCS solution returned by ``getWCSFromPng``, or None.
        bundle (object):
            A Bundle of fibers associated with the IFU. None when the
            image has no RA/Dec.
    '''

    def __init__(self, input=None, filename=None, mangaid=None, plateifu=None,
                 mode=None, data=None, release=None, download=None):

        MMAMixIn.__init__(self, input=input, filename=filename, mangaid=mangaid,
                          plateifu=plateifu, mode=mode, data=data, release=release,
                          download=download, ignore_db=True)

        if self.data_origin == 'file':
            self._load_image_from_file()
        elif self.data_origin == 'db':
            raise MarvinError('Images cannot currently be accessed from the db')
        elif self.data_origin == 'api':
            self._load_image_from_api()

        # initialize attributes
        self._init_attributes()

        # create the hex bundle.  Compare against None explicitly: a coordinate
        # of exactly 0.0 is valid but falsy.  The attribute is always defined so
        # the overlay methods can report a missing bundle cleanly.
        self.bundle = None
        if self.ra is not None and self.dec is not None:
            self.bundle = Bundle(self.ra, self.dec, plateifu=self.plateifu,
                                 size=int(str(self.ifu)[:-2]))

    def __repr__(self):
        '''Image representation.'''
        return '<Marvin Image (plateifu={0}, mode={1}, data-origin={2})>'.format(
            repr(self.plateifu), repr(self.mode), repr(self.data_origin))

    def _init_attributes(self):
        ''' Initialize the header, ra, dec and wcs attributes from the image data '''

        self.ra = None
        self.dec = None

        # try to create a header
        try:
            self.header = fits.header.Header(self.data.info)
        except Exception:
            warnings.warn('No proper header found image', MarvinUserWarning)
            self.header = None

        # try to set the RA, Dec
        if self.header:
            self.ra = float(self.header["RA"]) if 'RA' in self.header else None
            self.dec = float(self.header["DEC"]) if 'DEC' in self.header else None

        # try to set the WCS
        try:
            self.wcs = getWCSFromPng(image=self.data)
        except MarvinError:
            self.wcs = None
            # use the same warning category as the other warnings in this class
            warnings.warn('No proper WCS info for this image', MarvinUserWarning)

    def _get_image_dir(self):
        ''' Gets the correct images directory by release and mastar target '''

        # all images in MPL-4 are in stack dirs
        if self.release == 'MPL-4':
            return 'stack'

        # get the appropriate image directory
        is_mastar = target_is_mastar(self.plateifu, drpver=self._drpver)
        image_dir = 'mastar' if is_mastar else 'stack'
        return image_dir

    def _getFullPath(self):
        """Returns the full path of the file in the tree.

        Returns None when the image has no plateifu.
        """

        if not self.plateifu:
            return None

        plate, ifu = self.plateifu.split('-')
        dir3d = self._get_image_dir()

        name = 'mangaimage'

        return super(Image, self)._getFullPath(name, ifu=ifu, dir3d=dir3d,
                                               drpver=self._drpver, plate=plate)

    def download(self):
        """Downloads the image using sdss_access - Rsync.

        Returns None without downloading when the image has no plateifu.
        """

        if not self.plateifu:
            return None

        plate, ifu = self.plateifu.split('-')
        dir3d = self._get_image_dir()

        name = 'mangaimage'

        return super(Image, self).download(name, ifu=ifu, dir3d=dir3d,
                                           drpver=self._drpver, plate=plate)

    def __del__(self):
        # ``data`` may be missing if __init__ raised before it was assigned
        data = getattr(self, 'data', None)
        if data:
            data.close()

    def _load_image_from_file(self):
        ''' Load an image from a local file

        Raises:
            MarvinError: if no local path can be built or it does not exist.
        '''

        filepath = self._getFullPath()
        # _getFullPath returns None without a plateifu; os.path.exists(None)
        # would raise a TypeError instead of the intended MarvinError
        if filepath and os.path.exists(filepath):
            self._filepath = filepath
            self.data = self._open_image(filepath)
        else:
            raise MarvinError('Error: local filepath {0} does not exist. '.format(filepath))

    def _load_image_from_api(self):
        ''' Load an image from a remote location

        Raises:
            MarvinError: if the HTTP request for the image is not successful.
        '''

        filepath = self._getFullPath()
        url = self.url
        response = requests.get(url)
        if not response.ok:
            raise MarvinError('Error: remote filepath {0} does not exist'.format(filepath))

        # wrap the downloaded bytes in an in-memory file object for PIL
        fileobj = stringio(response.content)
        self.data = self._open_image(fileobj, filepath=url)

    @property
    def url(self):
        ''' The remote URL of the image, built with sdss_access HttpAccess

        Raises:
            MarvinError: if sdss_access is not installed.
        '''
        if not HttpAccess:
            raise MarvinError('sdss_access not installed')

        filepath = self._getFullPath()
        http = HttpAccess(verbose=False)
        url = http.url("", full=filepath)
        return url

    @staticmethod
    def _open_image(fileobj, filepath=None):
        ''' Open the Image using PIL

        Parameters:
            fileobj:
                A path or file-like object passed to PIL.Image.open
            filepath (str):
                The name to record on the image as ``filename``.
                Falls back to ``fileobj`` when not given.

        Returns:
            The opened PIL image, or None (with a warning) if PIL
            raised an IOError.
        '''

        try:
            image = PIL.Image.open(fileobj)
        except IOError:
            warnings.warn('Error: cannot open image', MarvinUserWarning)
            image = None
        else:
            image.filename = filepath or fileobj

        return image

    def show(self):
        ''' Show the image '''
        if self.data:
            self.data.show()

    def close(self):
        ''' Close the image '''
        if self.data:
            self.data.close()

    def save(self, filename, filetype='png', **kwargs):
        ''' Save the image to a file

        This only saves the original image.  To save the Matplotlib plot,
        use the savefig method on the matplotlib.pyplot.figure object

        Parameters:
            filename (str):
                The filename of the output image
            filetype (str):
                The filetype, e.g. png.  Passed to PIL as ``format``.
            kwargs:
                Additional keyword arguments to the PIL.Image.save method

        Raises:
            AssertionError: if neither a filetype nor a file extension is given.
        '''

        __, fileext = os.path.splitext(filename)
        # raised explicitly (rather than via ``assert``) so that the check is
        # not stripped when Python runs with -O
        if not (filetype or fileext):
            raise AssertionError('Filename must have an extension or specify the filetype')

        if self.data:
            self.data.save(filename, format=filetype, **kwargs)

    def plot(self, return_figure=True, dpi=100, with_axes=None, fibers=None,
             skies=None, **kwargs):
        ''' Creates a Matplotlib plot the image

        Parameters:
            fibers (bool):
                If True, overlays the fiber positions. Default is False.
            skies (bool):
                If True, overlays the sky fibers if possible. Default is False.
                Only used when ``fibers`` is also True.
            return_figure (bool):
                If True, returns the figure axis object.  Default is True
            dpi (int):
                The dots per inch for the matplotlib figure
            with_axes (bool):
                If True, plots the image with axes
            kwargs:
                Keyword arguments for overlay_fibers and overlay_skies

        Returns:
            The matplotlib axis object if ``return_figure`` is True.
        '''

        pix_size = np.array(self.data.size)
        figsize = np.ceil(pix_size / dpi)
        if with_axes:
            plt.figure()
            ax = plt.subplot(projection=self.wcs)
            # NOTE(review): assumes the first WCS axis (plot x) is RA and the
            # second (plot y) is Dec, matching the (ra, dec) ordering passed to
            # wcs_world2pix elsewhere in this class.
            ax.set_xlabel('Right Ascension')
            ax.set_ylabel('Declination')
            aspect = None
        else:
            # frameless figure sized so that the image fills it completely
            fig = plt.figure(figsize=figsize)
            ax = fig.add_axes([0., 0., 1., 1.], projection=self.wcs)
            ax.set_axis_off()
            plt.axis('off')
            aspect = 'auto'

        ax.imshow(self.data, origin='lower', aspect=aspect)

        # overlay the IFU fibers
        if fibers:
            self.overlay_fibers(ax, return_figure=return_figure, skies=skies, **kwargs)

        if return_figure:
            return ax

    def _check_can_overlay(self, what):
        ''' Raise a MarvinError if the WCS or fiber bundle needed for an overlay is missing '''
        if self.wcs is None:
            raise MarvinError('No WCS found. Cannot overlay {0}.'.format(what))
        if getattr(self, 'bundle', None) is None:
            raise MarvinError('No fiber bundle found. Cannot overlay {0}.'.format(what))

    def _diameter_to_pixels(self, diameter=None):
        ''' Convert a fiber diameter in arcsec (default 2") to pixels using the header SCALE '''
        # raised explicitly so the check survives python -O
        if diameter and not isinstance(diameter, (float, int)):
            raise AssertionError('diameter must be a number')
        if not self.header or 'SCALE' not in self.header:
            raise MarvinError('No SCALE keyword found in the image header. '
                              'Cannot convert the fiber diameter to pixels.')
        return (diameter or 2.0) / float(self.header['SCALE'])

    def overlay_fibers(self, ax, diameter=None, skies=None, return_figure=True, **kwargs):
        """ Overlay the individual fibers within an IFU on a plot.

        Parameters:
            ax (Axis):
                The matplotlib axis object
            diameter (float):
                The fiber diameter in arcsec. Default is 2".
            skies (bool):
                Set to True to additionally overlay the sky fibers. Default if False
            return_figure (bool):
                If True, returns the figure axis object.  Default is True
            kwargs:
                Any keyword arguments accepted by Matplotlib EllipseCollection

        Returns:
            The matplotlib axis object if ``return_figure`` is True.

        Raises:
            MarvinError: if the image has no WCS, no fiber bundle, or no
                SCALE header keyword.
            AssertionError: if ``diameter`` is not a number.
        """

        self._check_can_overlay('fibers')

        # fiber diameter in pixels; ``diameter`` itself stays in arcsec
        diameter_pix = self._diameter_to_pixels(diameter)

        # get the fiber pixel coordinates
        fibers = self.bundle.fibers[:, [1, 2]]
        fiber_pix = self.wcs.wcs_world2pix(fibers, 1)

        # some matplotlib kwargs
        kwargs.setdefault('edgecolor', 'Orange')
        kwargs.setdefault('facecolor', 'none')
        kwargs.setdefault('linewidth', 0.4)

        # NOTE(review): newer matplotlib releases rename ``transOffset`` to
        # ``offset_transform`` - confirm against the supported matplotlib version.
        ec = EllipseCollection(diameter_pix, diameter_pix, 0.0, units='xy',
                               offsets=fiber_pix, transOffset=ax.transData,
                               **kwargs)
        ax.add_collection(ec)

        # overlay the sky fibers.  Pass the diameter in arcsec: overlay_skies
        # does its own arcsec -> pixel conversion, so handing it the pixel
        # value would apply the SCALE factor twice.
        if skies:
            self.overlay_skies(ax, diameter=diameter, return_figure=return_figure, **kwargs)

        if return_figure:
            return ax

    def overlay_hexagon(self, ax, return_figure=True, **kwargs):
        """ Overlay the IFU hexagon on a plot

        Parameters:
            ax (Axis):
                The matplotlib axis object
            return_figure (bool):
                If True, returns the figure axis object.  Default is True
            kwargs:
                Any keyword arguments accepted by Matplotlib plot

        Returns:
            The matplotlib axis object if ``return_figure`` is True.

        Raises:
            MarvinError: if the image has no WCS or no fiber bundle.
        """

        self._check_can_overlay('hexagon')

        # get IFU hexagon pixel coordinates
        hexagon_pix = self.wcs.wcs_world2pix(self.bundle.hexagon, 1)
        # reconnect the last point to the first point.
        hexagon_pix = np.concatenate((hexagon_pix, [hexagon_pix[0]]), axis=0)

        # some matplotlib kwargs
        kwargs.setdefault('color', 'magenta')
        kwargs.setdefault('linestyle', 'solid')
        kwargs.setdefault('linewidth', 0.8)

        ax.plot(hexagon_pix[:, 0], hexagon_pix[:, 1], **kwargs)

        if return_figure:
            return ax

    def overlay_skies(self, ax, diameter=None, return_figure=True, **kwargs):
        """ Overlay the sky fibers on a plot

        Parameters:
            ax (Axis):
                The matplotlib axis object
            diameter (float):
                The fiber diameter in arcsec. Default is 2".
            return_figure (bool):
                If True, returns the figure axis object.  Default is True
            kwargs:
                Any keyword arguments accepted by Matplotlib EllipseCollection

        Returns:
            The matplotlib axis object if ``return_figure`` is True.

        Raises:
            MarvinError: if the image has no WCS, no fiber bundle, no SCALE
                header keyword, or is too small to contain the sky fibers.
            AssertionError: if ``diameter`` is not a number.
        """

        self._check_can_overlay('sky fibers')

        # check for sky coordinates
        if self.bundle.skies is None:
            self.bundle.get_sky_coordinates()

        # fiber diameter in pixels
        diameter_pix = self._diameter_to_pixels(diameter)

        # get sky fiber pixel positions
        fiber_pix = self.wcs.wcs_world2pix(self.bundle.skies, 1)

        # Compare x against the image width and y against the image height.
        # PIL's ``size`` is (width, height).
        # NOTE(review): assumes fiber_pix is an (N, 2) array of (x, y) pixels.
        image_size = np.array(self.data.size)
        outside_range = ((fiber_pix < 0) | (fiber_pix > image_size)).any()
        if outside_range:
            raise MarvinError('Cannot overlay sky fibers.  Image is too small.  '
                              'Please retrieve a bigger image cutout')

        # some matplotlib kwargs
        kwargs.setdefault('edgecolor', 'Orange')
        kwargs.setdefault('facecolor', 'none')
        kwargs.setdefault('linewidth', 0.7)

        # draw the sky fibers
        ec = EllipseCollection(diameter_pix, diameter_pix, 0.0, units='xy',
                               offsets=fiber_pix, transOffset=ax.transData,
                               **kwargs)
        ax.add_collection(ec)

        # Add a larger circle to help identify the sky fiber locations in
        # large images.
        if (self.data.size[0] > 1000) or (self.data.size[1] > 1000):
            ec = EllipseCollection(diameter_pix * 5, diameter_pix * 5, 0.0, units='xy',
                                   offsets=fiber_pix, transOffset=ax.transData,
                                   **kwargs)
            ax.add_collection(ec)

        if return_figure:
            return ax

    def get_new_cutout(self, width, height, scale=None, **kwargs):
        ''' Get a new Image Cutout using Skyserver

        Replaces the current Image with a new image cutout.  The
        data, header, and wcs attributes are updated accordingly.

        Parameters:
            width (int):
                Cutout image width in arcsec
            height (int):
                Cutout image height in arcsec
            scale (float):
                arcsec/pixel scale of the image
            kwargs:
                Any additional keywords for Cutout
        '''

        cutout = Cutout(self.ra, self.dec, width, height, scale=scale, **kwargs)
        self.data = cutout.image
        self._init_attributes()

    @classmethod
    def from_list(cls, values, release=None):
        ''' Generate a list of Marvin Image objects

        Class method to generate a list of Marvin Images from an input list
        of targets

        Parameters:
            values (list):
                A list of target ids (i.e. plateifus, mangaids, or filenames)
            release (str):
                The release of Images to get

        Returns:
            a list of Marvin Image objects

        Example:
            >>> from marvin.tools.image import Image
            >>> targets = ['8485-1901', '7443-1201']
            >>> images = Image.from_list(targets)
        '''

        return [cls(item, release=release) for item in values]

    @classmethod
    def by_plate(cls, plateid, minis=None, release=None):
        ''' Generate a list of Marvin Images by plate

        Class method to generate a list of Marvin Images from a single plateid.

        Parameters:
            plateid (int):
                The plate id to grab
            minis (bool):
                If True, includes the mini-bundles
            release (str):
                The release of Images to get

        Returns:
            a list of Marvin Image objects

        Example:
            >>> from marvin.tools.image import Image
            >>> images = Image.by_plate(8485)
        '''

        ifus = cls._get_ifus(minis=minis)
        plateifus = ['{0}-{1}'.format(plateid, i) for i in ifus]
        images = cls.from_list(plateifus, release=release)
        return images

    @classmethod
    def get_random(cls, num=5, minis=None, release=None):
        ''' Generate a set of random Marvin Images

        Class method to generate a random list of Marvin Images

        Parameters:
            num (int):
                The number to grab.  Default is 5
            minis (bool):
                If True, includes the mini-bundles
            release (str):
                The release of Images to get

        Returns:
            a list of Marvin Image objects

        Example:
            >>> from marvin.tools.image import Image
            >>> images = Image.get_random(5)
        '''

        ifus = cls._get_ifus(minis=minis)
        plates = get_plates(release=release)
        # sample plate/ifu pairs without replacement from every combination
        rand_samp = random.sample(list(itertools.product(map(str, plates), ifus)), num)
        plateifus = ['-'.join(r) for r in rand_samp]
        images = cls.from_list(plateifus, release=release)
        return images

    def getCube(self):
        """Returns the :class:`~marvin.tools.cube.Cube` for this Image."""

        return marvin.tools.cube.Cube(plateifu=self.plateifu, release=self.release)

    def getMaps(self, **kwargs):
        """Retrieves the DAP :class:`~marvin.tools.maps.Maps` for this Image.

        Unless a ``filename`` is given in ``kwargs``, the ``plateifu`` and
        release of this :class:`~marvin.tools.image.Image` are added to
        ``kwargs`` (replacing any values passed for those two keys).  The
        ``kwargs`` are then passed when initialising the
        :class:`~marvin.tools.maps.Maps`.

        """

        # an empty kwargs never contains 'filename', so one test is enough
        if 'filename' not in kwargs:
            kwargs.update({'plateifu': self.plateifu, 'release': self._release})

        maps = marvin.tools.maps.Maps(**kwargs)
        return maps