#!/usr/bin/env python
"""This module contains utility functions for plotting:

    :func:`get_fig_axes`
        Helper function to check arguments of plotting functions. Retrieve figure
        and axes from `axes`. If `axes` is None, create figure and axes, and
        return those.

    :func:`split_axes`
        Split a :class:`matplotlib.axes.Axes` into one or more panels, with tied
        x and y axes. Also hides overlapping tick labels.

    :func:`clean_invalid`
        Remove pairs of values from two arrays of data if either element in the
        pair is `nan` or `inf`, optionally truncating values to given bounds.
        Used to prepare data for histogram or violin plots, as even masked
        `nan` and `inf` values are not handled by these plotting functions.

    :func:`get_kde`
        Evaluate a kernel density estimate over observed data in linear or
        log-transformed space (e.g. for making violin plots in log space,
        but having kernels appropriately scaled).
"""
import numpy
import scipy.stats
import matplotlib
import matplotlib.pyplot as plt

def get_fig_axes(axes=None):
    """Retrieve figure and axes from `axes`. If `axes` is None, create a new
    figure and axes, and return both.

    Used as a helper function for replotting atop existing axes, by functions
    defined in :mod:`plastid.plotting.plots`.

    Parameters
    ----------
    axes : :class:`matplotlib.axes.Axes` or `None`
        Axes in which to place plot. If `None`, a new figure is generated.

    Returns
    -------
    :class:`matplotlib.figure.Figure`
        Parent figure of axes

    :class:`matplotlib.axes.Axes`
        Axes containing plot
    """
    if axes is None:
        fig = plt.figure()
        # plt.figure() makes the new figure current, so gca() returns (or
        # creates) an axes belonging to `fig`.
        ax = plt.gca()
    else:
        ax = axes
        fig = ax.figure
    return fig, ax
def _hide_ticklabel(axis, index):
    """Hide the tick label at position `index` of `axis`, if any labels exist."""
    labels = axis.get_ticklabels()
    if labels:
        labels[index].set_visible(False)


def split_axes(
    ax,
    top_height=0,
    left_width=0,
    right_width=0,
    bottom_height=0,
    main_ax_kwargs=None,
    other_ax_kwargs=None,
):
    """Split the space taken by one axes into one or more panes, setting the
    original axes invisible.

    Panel geometry is computed from the ``figure.subplot.*`` entries of
    :data:`matplotlib.rcParams`, not from the current position of `ax`.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to split

    top_height, left_width, right_width, bottom_height : float, optional
        If greater than zero, a panel on the corresponding side of `ax` will
        be created, using whatever fraction is specified (e.g. 0.1 to use 10%
        of total height). (Default: 0, no panel)

    main_ax_kwargs : dict, optional
        Dictionary of keyword arguments for the central pane, passed to
        :meth:`matplotlib.figure.Figure.add_axes`. Not modified.

    other_ax_kwargs : dict, optional
        Dictionary of keyword arguments for peripheral panes, passed to
        :meth:`matplotlib.figure.Figure.add_axes`. Not modified. Any
        `sharex`/`sharey` entries are overridden so that left/right panes
        share the y-axis, and top/bottom panes share the x-axis, of the
        central pane.

    Returns
    -------
    dict
        Dictionary of axes. `'orig'` refers to `ax`. The central panel is
        `'main'`. Other panels will be mapped to `'top'`, `'left'`, et c.,
        if they are created.
    """
    if main_ax_kwargs is None:
        main_ax_kwargs = {}
    if other_ax_kwargs is None:
        other_ax_kwargs = {}

    fig = ax.figure
    ax.set_visible(False)
    axes = {"orig": ax}

    # Margins around the plotting area, in figure-fraction coordinates
    mplrc = matplotlib.rcParams
    buf_left = mplrc["figure.subplot.left"]
    buf_bot = mplrc["figure.subplot.bottom"]
    buf_right = 1.0 - mplrc["figure.subplot.right"]
    buf_top = 1.0 - mplrc["figure.subplot.top"]

    # Convert panel fractions into figure-fraction sizes
    hscale = 1.0 - buf_top - buf_bot
    wscale = 1.0 - buf_left - buf_right

    main_height = (1.0 - bottom_height - top_height) * hscale
    main_width = (1.0 - left_width - right_width) * wscale

    bottom_height *= hscale
    top_height *= hscale
    left_width *= wscale
    right_width *= wscale

    main_left = buf_left + left_width
    main_right = main_left + main_width
    main_bot = buf_bot + bottom_height
    main_top = main_bot + main_height

    # each rect is (left, bottom, width, height)
    rects = {"main": [main_left, main_bot, main_width, main_height]}
    if left_width > 0:
        rects["left"] = [buf_left, main_bot, left_width, main_height]
    if right_width > 0:
        rects["right"] = [main_right, main_bot, right_width, main_height]
    if bottom_height > 0:
        # bottom panel sits in the margin below the main panel
        rects["bottom"] = [main_left, buf_bot, main_width, bottom_height]
    if top_height > 0:
        rects["top"] = [main_left, main_top, main_width, top_height]

    axes["main"] = fig.add_axes(rects["main"], zorder=100, **main_ax_kwargs)

    for axes_name, rect in rects.items():
        if axes_name == "main":
            continue

        # Copy so neither the caller's dict nor a shared default is mutated
        ax_kwargs = dict(other_ax_kwargs)
        if axes_name in ("right", "left"):
            ax_kwargs["sharey"] = axes["main"]
            ax_kwargs.pop("sharex", None)
        if axes_name in ("top", "bottom"):
            ax_kwargs["sharex"] = axes["main"]
            ax_kwargs.pop("sharey", None)

        axes[axes_name] = fig.add_axes(rect, zorder=50, **ax_kwargs)

    if "top" in axes:
        axes["top"].xaxis.tick_top()
        axes["main"].xaxis.tick_bottom()
        # prevent tick collisions
        _hide_ticklabel(axes["top"].yaxis, 0)

    if "bottom" in axes:
        axes["bottom"].xaxis.tick_bottom()
        # the bottom panel shares x with main, so main's x labels are redundant
        for t in axes["main"].xaxis.get_ticklabels():
            t.set_visible(False)

    if "left" in axes:
        axes["left"].yaxis.tick_left()
        _hide_ticklabel(axes["left"].xaxis, -1)
        # the left panel shares y with main, so main's y labels are redundant
        for t in axes["main"].yaxis.get_ticklabels():
            t.set_visible(False)

    if "right" in axes:
        axes["right"].yaxis.tick_right()
        axes["main"].yaxis.tick_left()
        _hide_ticklabel(axes["right"].xaxis, 0)

    return axes
def clean_invalid(x, y, min_x=-numpy.inf, min_y=-numpy.inf, max_x=numpy.inf, max_y=numpy.inf):
    """Drop paired entries of `x` and `y` where either value is `nan` or `inf`,
    after optionally truncating each array to a [minimum, maximum] range.

    Parameters
    ----------
    x, y : :class:`numpy.ndarray` or list
        Paired arrays or lists of corresponding numbers. Both are copied to
        float arrays; the inputs are not modified.

    min_x, min_y, max_x, max_y : number, optional
        Truncation bounds: entries of `x` below `min_x` are set to `min_x`,
        entries above `max_x` are set to `max_x`, and likewise for `y`.
        Truncation is applied *before* invalid values are removed, so an
        infinite value is replaced by a finite bound (and kept) when such a
        bound is supplied. `nan` entries are never truncated.

    Returns
    -------
    :class:`numpy.ndarray`
        A shortened version of `x`, excluding invalid values

    :class:`numpy.ndarray`
        A shortened version of `y`, excluding invalid values
    """
    def _truncated(values, low, high):
        # astype() always returns a copy, so in-place edits below are safe
        arr = numpy.array(values).astype(float)
        arr[arr < low] = low
        arr[arr > high] = high
        return arr

    xs = _truncated(x, min_x, max_x)
    ys = _truncated(y, min_y, max_y)

    # a pair survives only if both of its members are finite
    keep = numpy.isfinite(xs) & numpy.isfinite(ys)
    return xs[keep], ys[keep]
def get_kde(data, log=False, base=2, points=100, bw_method="scott"):
    """Estimate a kernel density (kde) over `data`

    Parameters
    ----------
    data : :class:`numpy.ndarray` or list
        Data to build kde over. Converted to a float array.

    log : bool, optional
        If `True`, `data` is log-transformed before the kde is estimated,
        and the kde is evaluated at points evenly spaced in log space. The
        returned x-values are converted back to non-log space afterwards.
        Non-positive values produce `-inf`/`nan` after the transform, which
        :class:`scipy.stats.gaussian_kde` will presumably reject.

    base : 2, 10, or :obj:`numpy.e`, optional
        If `log` is `True`, this serves as the base of the log space.
        If `log` is `False`, this is ignored. (Default: 2)

    points : int
        Number of points over which to evaluate kde. (Default: 100)

    bw_method : str
        Bandwith estimation method. See documentation for
        :obj:`scipy.stats.gaussian_kde`. (Default: "scott")

    Returns
    -------
    :class:`numpy.ndarray`
        Points over which kde is evaluated (x-values), in non-log space

    :class:`numpy.ndarray`
        Value of kde (y-values) at each point. If `log` is `True`, this is
        the density of the log-transformed data (it is not rescaled back
        to non-log space).

    Raises
    ------
    ValueError
        If `log` is `True` and `base` is not 2, 10, or :obj:`numpy.e`
    """
    data = numpy.asarray(data, dtype=float)

    if log:
        if base == 2:
            func = numpy.log2
        elif base == 10:
            func = numpy.log10
        elif base == numpy.e:
            func = numpy.log
        else:
            raise ValueError("kde: Base must be 2, 10, or numpy.e")
        data = func(data)

    # Evenly spaced evaluation points across the (possibly log-transformed)
    # data range, i.e. evenly spaced in log space when `log` is True.
    domain = numpy.linspace(data.min(), data.max(), points)
    kde = scipy.stats.gaussian_kde(data, bw_method=bw_method)
    curve = kde.evaluate(domain)

    if log:
        domain = base**domain

    return domain, curve