Source code for marvin.utils.plot.scatter

# !usr/bin/env python
# -*- coding: utf-8 -*-
#
# Licensed under a 3-clause BSD license.
#
# @Author: Brian Cherinka
# @Date:   2017-08-21 17:11:22
# @Last modified by:   Brian Cherinka
# @Last Modified time: 2018-11-08 16:21:30

from __future__ import print_function, division, absolute_import
from marvin import config
from marvin.utils.datamodel.dap import datamodel
from marvin.core.exceptions import MarvinUserWarning
from marvin.utils.general import invalidArgs, isCallableWithArgs
from matplotlib.gridspec import GridSpec
from collections import defaultdict, OrderedDict
from astropy.visualization import hist as ahist
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import six
import pandas as pd
import itertools
import warnings

try:
    import mpl_scatter_density as msd
except ImportError as e:
    msd = None
    msderr = ('mpl-scatter-density is required to plot large results and was not found.  '
              'To use this feature, please install the python package!')


[docs]def compute_stats(data): ''' Compute some statistics given a data array Computes some basic statistics given a data array, excluding NaN values. Computes and returns the following Numpy statistics: mean, standard deviation, median, and the 10th, 25th, 75th, and 90th percentiles. Parameters: data (list|ndarray): A list or Numpy array of data Returns: A dictionary of statistics values ''' stats = {'mean': np.nanmean(data), 'std': np.nanstd(data), 'median': np.nanmedian(data), 'per10': np.nanpercentile(data, 10), 'per25': np.nanpercentile(data, 25), 'per75': np.nanpercentile(data, 75), 'per90': np.nanpercentile(data, 90)} return stats
def _make_masked(data, mask=None): ''' Makes a masked array ''' arr_data = data if not isinstance(data, np.ma.MaskedArray): # mask out NaN values if a mask not provided warnings.warn("Masking out NaN values!", MarvinUserWarning) mask = mask if mask else np.isnan(data) # create array arr_data = np.ma.MaskedArray(data, mask=mask) return arr_data def _create_figure(hist=None, hist_axes_visible=None, use_density=None): ''' Create a generic figure and axis ''' # use a scatter density projection or not projection = 'scatter_density' if use_density else None # check if mpl-scatter-density if installed if not msd: raise ImportError(msderr) # create the figure fig = plt.figure() ax_hist_x = None ax_hist_y = None # create axes with or without histogram if hist: if hist is True: gs = GridSpec(4, 4) ax_scat = fig.add_subplot(gs[1:4, 0:3], projection=projection) ax_hist_x = fig.add_subplot(gs[0, 0:3]) ax_hist_y = fig.add_subplot(gs[1:4, 3]) elif hist == 'x': gs = GridSpec(2, 1, height_ratios=[1, 2]) ax_scat = fig.add_subplot(gs[1], projection=projection) ax_hist_x = fig.add_subplot(gs[0]) elif hist == 'y': gs = GridSpec(1, 2, width_ratios=[2, 1]) ax_scat = fig.add_subplot(gs[0], projection=projection) ax_hist_y = fig.add_subplot(gs[1]) else: ax_scat = fig.add_subplot(1, 1, 1, projection=projection) # turn off histogram axes if ax_hist_x: plt.setp(ax_hist_x.get_xticklabels(), visible=hist_axes_visible) if ax_hist_y: plt.setp(ax_hist_y.get_yticklabels(), visible=hist_axes_visible) return fig, ax_scat, ax_hist_x, ax_hist_y def _create_hist_title(data): ''' create a title for the histogram ''' stats = compute_stats(data) hist_title = 'Stats: $\\mu={mean:.3f}, \\sigma={std:.3f}$'.format(**stats) return hist_title def _get_dap_datamodel_property_label(quantity): ''' Format a DAP datamodel property string label ''' return '{0} [{1}]'.format(quantity.to_string('latex'), quantity.unit.to_string('latex')) def _get_axis_label(column, axis=''): ''' Create an axis label ''' from marvin.utils.datamodel.query.base import QueryParameter from marvin.utils.datamodel.dap.base import Property if isinstance(column, QueryParameter): if hasattr(column, 'property') and column.property: label = _get_dap_datamodel_property_label(column.property) else: label = column.display elif isinstance(column, Property): label = _get_dap_datamodel_property_label(column) elif isinstance(column, six.string_types): label = column else: # label = '{0} axis'.format(axis).strip() label = '' return label def _set_options(): ''' Set some default Matplotlib options ''' mpl.rcParams['axes.axisbelow'] = True mpl.rcParams['grid.color'] = 'gray' mpl.rcParams['grid.linestyle'] = 'dashed' mpl.rcParams['grid.alpha'] = 0.8 def _set_limits(column, lim=None, sigma_cutoff=50, percent_clip=1): ''' Set an axis limit Determines whether to apply percentile clipping or not if any data has a zscore value above the sigma_cutoff value. Applies percentile clipping centered around the mean. Parameters: column: The array of data to get limits of lim (list|tuple): A user provided range sigma_cutoff (int): The number of sigma away from the mean to cutoff percent_clip (int|tuple): The percent to clip off the data array. Input values are taken as percentages. Can either be integer value (halved for lo,hi) or a tuple specifying lo,hi values. Default is 1%. Returns: A list of axis range values to use ''' if lim is not None: assert len(lim) == 2, 'range must be a list or tuple of 2' else: # get percent clips if isinstance(percent_clip, (list, tuple)): lo, hi = percent_clip else: lo = percent_clip / 2. hi = 100 - lo zscore = stats.zscore(column) # use percentile limits if the max zscore is > 50 sigma away from mean/stdev if np.max(zscore) > sigma_cutoff: lim = [np.percentile(column, lo), np.percentile(column, hi)] else: pass return lim def _check_input_data(coldim, col, data=None): ''' Check the input data Parameters: coldim (str): Name of the dimension col (str|array): The list or array of values. If data keyword is specified, col is a string name data (Pandas.DataFrame) A Pandas dataframe Returns: The column of data ''' # check data assert col is not None, 'Must provide an {0} column'.format(coldim) if data is not None: assert isinstance(col, str), '{0} must be a string name if Dataframe provided'.format(coldim) assert isinstance(data, pd.core.frame.DataFrame), 'data must be Pandas dataframe' assert col in data.columns, '{0} must be a specified column name in Pandas dataframe'.format(coldim) col = data[col] else: assert isinstance(col, (list, np.ndarray, pd.core.series.Series)), '{0} data must be a list, Pandas Series, or Numpy array'.format(coldim) return col def _format_hist_kwargs(axis, **kwargs): ''' Format the histogram kwargs from plot ''' kwargs['color'] = kwargs.get('hist_color', 'lightblue') if axis == 'x': kwargs['ylabel'] = kwargs.get('xhist_label', 'Counts') kwargs['title'] = kwargs.get('xhist_title', None) elif axis == 'y': kwargs['ylabel'] = kwargs.get('yhist_label', 'Counts') kwargs['title'] = kwargs.get('yhist_title', None) kwargs['color'] = kwargs.get('hist_color', 'lightblue') kwargs['edgecolor'] = kwargs.get('edgecolors', None) return kwargs def _prep_func_kwargs(func, kwargs): ''' Prepare the keyword arguments for the proper function input Checks an input dictionary against allowed keyword arguments for a given function. Returns only those usable in that function. Parameters: func: The name of the function to check keywords against kwargs (dict): A dictionary of keyword arguments to test Returns: A new dictionary of usable keyword arguments ''' invalid = invalidArgs(func, kwargs) new_kwargs = kwargs.copy() for key in invalid: __ = new_kwargs.pop(key) if isCallableWithArgs(func, new_kwargs): return new_kwargs else: raise MarvinUserWarning('Cannot call func {0} with current kwargs {1}. Check your inputs'.format(func, new_kwargs))
[docs]def plot(x, y, **kwargs): ''' Create a scatter plot given two columns of data Creates a Matplotlib plot using two input arrays of data. Creates either a Matplotlib scatter plot, hexbin plot, or scatter density plot depending on the size of the input data. For data with < 1000 values, creates a scatter plot. For data with values between 1000 and 500,000, creates a hexbin plot. For data with > 500,000 values, creates a scatter density plot. By default, will also create and display histograms for the x and y data. This can be disabled setting the "with_hist" keyword to False, or "x", or "y" for displaying only that column. Accepts all the same keyword arguments as matplotlib scatter, hexbin, and hist methods. See `scatter-density <https://github.com/astrofrog/mpl-scatter-density>`_ See `matplotlib.pyplot.scatter <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`_ See `matplotlib.pyplot.hexbin <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.hexbin>`_ Parameters: x (str|list|ndarray): The x array of data y (str|list|ndarray): The y array of data data (Pandas dataframe): Optional Pandas Dataframe. x, y specify string column names in the dataframe xmask (ndarray): A mask to apply to the x-array of data ymask (ndarray): A mask to apply to the y-array of data with_hist (bool|str): If True, creates the plot with both x,y histograms. False, disables it. If 'x' or 'y', only creates that histogram. Default is True. hist_axes_visible (bool): If True, disables the x-axis ticks for each histogram. Default is True. xlim (tuple): A tuple limited the range of the x-axis ylim (tuple): A tuple limited the range of the y-axis xlabel (str|Marvin column): The x axis label or a Marvin DataModel Property or QueryParameter to use for display ylabel (str|Marvin column): The y axis label or a Marvin DataModel Property or QueryParameter to use for display bins (int|tuple): A number or tuple specifying the number of bins to use in the histogram. Default is 50. An integer number is adopted for both x and y bins. A tuple is used to customize per axis. return_figure (bool): If True, return the figure and axis object. Default is True. kwargs (dict): Any other keyword arguments to be passed to `matplotlib.pyplot.scatter <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.scatter>`_ or `matplotlib.pyplot.hist <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.hist>`_ or `matplotlib.pyplot.hexbin <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.hexbin>`_. Returns: A tuple of the matplotlib figure, axes, and histogram data (if returned) Example: >>> # create a scatter plot >>> import numpy as np >>> from marvin.utils.scatter import plot >>> x = np.random.random(100) >>> y = np.random.random(100) >>> plot(x, y) ''' # check the input data data = kwargs.pop('data', None) x = _check_input_data('x', x, data=data) y = _check_input_data('y', y, data=data) # general keyword arguments use_datamodel = kwargs.pop('usemodel', None) xmask = kwargs.pop('xmask', None) ymask = kwargs.pop('ymask', None) return_figure = kwargs.pop('return_figure', True) # scatterplot keyword arguments xlim = kwargs.pop('xlim', None) ylim = kwargs.pop('ylim', None) xlabel = kwargs.pop('xlabel', None) ylabel = kwargs.pop('ylabel', None) color = kwargs.pop('color', None) size = kwargs.pop('size', 20) marker = kwargs.pop('marker', 'o') edgecolors = kwargs.pop('edgecolors', 'black') # hexbin keywords gridsize = kwargs.pop('gridsize', 50) # histogram keywords with_hist = kwargs.pop('with_hist', True) bins = kwargs.pop('bins', ['scott', 'scott']) hist_axes_visible = kwargs.pop('hist_axes_visible', False) # convert to numpy masked arrays x = _make_masked(x, mask=xmask) y = _make_masked(y, mask=ymask) count = len(x) use_density = True if count > 500000 else False # create figure and axes objects with plt.style.context('seaborn-darkgrid'): fig, ax_scat, ax_hist_x, ax_hist_y = _create_figure(hist=with_hist, use_density=use_density, hist_axes_visible=hist_axes_visible) # create the hexbin or scatter plot kind = kwargs.get('kind', None) assert kind in ['hex', 'scatter', 'density', 'joint', None], 'plot kind must be either scatter, hex, density, or joint' if count > 1000 and count <= 500000: scat_kwargs = _prep_func_kwargs(plt.hexbin, kwargs) main = ax_scat.hexbin(x, y, gridsize=gridsize, mincnt=1, cmap='inferno', **scat_kwargs) cb = fig.colorbar(main, ax=ax_scat, label='Counts') #ax_scat.grid(color='gray', linestyle='dashed', alpha=0.8) elif count > 500000: # abort if mpl-scatter-density is not installed if not msd: raise ImportError(msderr) scat_kwargs = _prep_func_kwargs(plt.imshow, kwargs) main = ax_scat.scatter_density(x, y, cmap='inferno', **scat_kwargs) cb = fig.colorbar(main, ax=ax_scat, label='Number of points per pixel') ax_scat.grid(color='gray', linestyle='dashed', alpha=0.8) else: # create the scatter plot scat_kwargs = _prep_func_kwargs(plt.scatter, kwargs) main = ax_scat.scatter(x, y, c=color, s=size, marker=marker, edgecolors=edgecolors, **scat_kwargs) cb = None #ax_scat.grid(color='gray', linestyle='dashed', alpha=0.8) # set limits xlim = _set_limits(x, lim=xlim) ylim = _set_limits(y, lim=ylim) if xlim: ax_scat.set_xlim(xlim) if ylim: ax_scat.set_ylim(ylim) # set display names xlabel = _get_axis_label(xlabel, axis='x') ylabel = _get_axis_label(ylabel, axis='y') ax_scat.set_xlabel(xlabel) ax_scat.set_ylabel(ylabel) # set axes object axes = [ax_scat] # create histogram dictionary if with_hist: hist_data = {} xbin, ybin = bins if isinstance(bins, list) else (bins, bins) # set x-histogram if ax_hist_x: xhist_kwargs = _format_hist_kwargs('x', **kwargs) #xrange = ax_scat.get_xlim() xhist, fig, ax_hist_x = hist(x, bins=xbin, fig=fig, ax=ax_hist_x, **xhist_kwargs) axes.append(ax_hist_x) hist_data['xhist'] = xhist if cb is not None: ocb = fig.colorbar(main, ax=ax_hist_x) ocb.remove() # set y-histogram if ax_hist_y: yhist_kwargs = _format_hist_kwargs('y', **kwargs) yhist, fig, ax_hist_y = hist(y, bins=ybin, fig=fig, ax=ax_hist_y, orientation='horizontal', rotate_title=True, **yhist_kwargs) axes.append(ax_hist_y) hist_data['yhist'] = yhist if return_figure: output = (fig, axes, hist_data) if with_hist else (fig, axes) else: output = hist_data if with_hist else None return output
[docs]def hist(arr, mask=None, fig=None, ax=None, bins=None, **kwargs): ''' Create a histogram of an array Plots a histogram of an input column of data. Input can be a list or a Numpy array. Converts the input into a Numpy MaskedArray, applying the optional mask. If no mask is supplied, it masks any NaN values. This uses `Astropy's enhanced hist <http://docs.astropy.org/en/stable/api/astropy.visualization.hist.html#astropy.visualization.hist>`_ function under the hood. Accepts all the same keyword arguments as matplotlib hist method. Parameters: arr (list|ndarray): An array of data to plot with. Required. mask (ndarray): A mask to use on the data, applied to the data in a Numpy Masked Array. fig (plt.fig): An optional matplotlib figure object ax (plt.ax): An optional matplotlib axis object bins (int): The number of bins to use. Default is a `scott <http://docs.astropy.org/en/stable/visualization/histogram.html>`_ binning scheme. xlabel (str|Marvin Column): The x axis label or a Marvin DataModel Property or QueryParameter to use for display ylabel (str): The y axis label title (str): The plot title rotate_title (bool): If True, moves the title text to the right y-axis during a horizontal histogram. Default is False. return_figure (bool): If True, return the figure and axis object. Default is True. kwargs (dict): Any other keyword arguments to be passed to `matplotlib.pyplot.hist <http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.hist>`_. Returns: tuple: histogram data, matplotlib figure, and axis objects. The histogram data returned is a dictionary containing:: { 'bins': The number of bins used, 'counts': A list of the count of objects within each bin, 'binedges': A list of the left binedge used in defining each bin, 'binids': An array of the same shape as input data, containing the binid of each element, 'indices': A dictionary of a list of array indices within each bin } Example: >>> # histogram some random data >>> from marvin.utils.plot.scatter import hist >>> import numpy as np >>> x = np.random.random(100) >>> hist_data, fig, ax = hist(x) ''' # check the input data data = kwargs.pop('data', None) arr = _check_input_data('column', arr, data=data) arr = _make_masked(arr, mask=mask) # general keywords xlabel = kwargs.pop('xlabel', None) ylabel = kwargs.pop('ylabel', 'Counts') title = kwargs.pop('title', None) rotate_title = kwargs.pop('rotate_title', False) return_figure = kwargs.pop('return_figure', True) # histogram keywords bins = bins if bins else 'scott' color = kwargs.pop('color', None) edgecolor = kwargs.pop('edgecolor', None) hrange = kwargs.pop('range', None) orientation = kwargs.pop('orientation', 'vertical') # create a figure and axis if they don't exist with plt.style.context('seaborn-darkgrid'): if fig is None and ax is None: fig, ax = plt.subplots() elif fig is None: fig = plt.figure() # set labels xlabel = _get_axis_label(xlabel, axis='x') ax.set_ylabel(ylabel) if orientation == 'vertical' else ax.set_ylabel(xlabel) ax.set_xlabel(xlabel) if orientation == 'vertical' else ax.set_xlabel(ylabel) # reset the label positions ax.yaxis.set_label_position('left') ax.xaxis.set_label_position('bottom') # set limits hrange = _set_limits(arr, lim=hrange) # set title title = title if title else _create_hist_title(arr) ax.set_title(title) if rotate_title: ax.set_title('') ax.yaxis.set_label_position('right') ax.yaxis.label.set_fontsize(12.0) ax.set_ylabel(title, rotation=270, verticalalignment='bottom') # create histogram hist_kwargs = _prep_func_kwargs(ahist, kwargs) counts, binedges, patches = ahist(arr[~arr.mask], bins=bins, color=color, orientation=orientation, edgecolor=edgecolor, range=hrange, ax=ax, **hist_kwargs) # compute a dictionary of the binids containing a list of the array indices in each bin binids = np.digitize(arr, binedges) inds = np.where(binids)[0] indices = defaultdict(list) tmp = list(map(lambda i, x: indices[x].append(i), inds, binids)) hist_data = {'counts': counts, 'binedges': binedges, 'bins': bins, 'binids': binids, 'indices': indices} output = (hist_data, fig, ax) if return_figure else hist_data return output