#!/usr/bin/env python
'''Class for plotting simulation data.
@author: Zach Hafen
@contact: zachary.h.hafen@gmail.com
@status: Development
'''
# Base python imports
import numpy as np
import os
import scipy.stats
import scipy.signal as signal
import warnings
import verdict
import matplotlib
matplotlib.use('PDF')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as plt_colors
import matplotlib.gridspec as gridspec
import matplotlib.patheffects as path_effects
import matplotlib.transforms as transforms
import galaxy_dive.utils.mp_utils as mp_utils
import galaxy_dive.utils.utilities as utilities
import galaxy_dive.plot_data.plotting as gen_plot
import galaxy_dive.plot_data.pu_colormaps as pu_cm
########################################################################
########################################################################
class GenericPlotter( object ):
@utilities.store_parameters
def __init__(self, data_object, label=None, color='black'):
    '''Create a plotter that wraps a data container.

    The body is deliberately empty; all handling of the arguments is left to
    the utilities.store_parameters decorator. The other methods of this class
    read self.data_object, self.label and self.color, so the decorator
    presumably stores each argument under its own name -- confirm against
    utilities.store_parameters.

    Args:
        data_object ( generic_data.GenericData object or subclass of such ) :
            The data container to use.

        label (str) :
            Name for this data set. Used by the plotting methods as the
            default plot label and line label.

        color (str) :
            Colour associated with this data set. Used by histogram() when it
            is called with color=None.
    '''
########################################################################
# Alternate inherent methods
########################################################################
def __getattr__(self, attr):
    '''Fall back to the data object for attributes the plotter lacks.

    Python only calls __getattr__ after normal attribute lookup has failed,
    so none of the plotter's own attributes are shadowed by this.

    Args:
        attr (str) : Name of the attribute that was not found on the plotter.

    Returns:
        The attribute of the same name on self.data_object.

    Raises:
        AttributeError : If the plotter has no data_object yet, or if the
            data object does not have the attribute either.
    '''

    # Reaching this point for 'data_object' means it has not been set on the
    # instance (e.g. while the object is being unpickled or copied, before
    # its __dict__ is restored). Evaluating self.data_object below would then
    # re-enter this method without end, so fail the lookup cleanly instead.
    if attr == 'data_object':
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(
                type(self).__name__,
                attr,
            )
        )

    print("Attribute {} not found in plotting object. Checking data object.".format(attr))

    return getattr(self.data_object, attr)
########################################################################
# Specific Generic Plots
########################################################################
def histogram(
    self,
    data_key,
    provided_data=None,
    provided_hist=None,
    weight_key=None,
    slices=None,
    ax=None,
    fix_invalid=False,
    mask_zeros=False,
    invalid_fix_method=None,
    bins=32,
    normed=True,
    norm_type='probability',
    scaling=None,
    smooth=False,
    smoothing_window_length=9,
    smoothing_polyorder=3,
    histogram_style='step',
    color='black',
    linestyle='-',
    linewidth=3.5,
    alpha=1.,
    x_range=None, y_range=None,
    x_label=None, y_label=None,
    add_x_label=True, add_y_label=True,
    add_plot_label=True,
    plot_label=None,
    line_label=None,
    label_fontsize=24,
    x_scale='linear', y_scale='linear',
    cdf=False,
    vertical_line=None,
    vertical_line_kwargs={'linestyle': '--', 'linewidth': 3, 'color': 'k'},
    return_dist=False,
    assert_contains_all_data=True,
    data_kwargs={},
    *args, **kwargs
):
    '''Make a histogram of the data. Extra arguments are passed to
    self.data_object.get_selected_data.

    Args:
        data_key (str) :
            Data key to plot. Also the default x-axis label.

        provided_data (array-like) :
            If given, histogram a copy of this instead of retrieving
            data_key from the data object.

        provided_hist (array-like) :
            If given, skip all data retrieval and binning and plot these
            values directly. In that case bins must be the bin edges.

        weight_key (str) :
            Data key for data to use as a weight. By None, no weight.
            Weights are retrieved with **kwargs only (not data_kwargs).

        slices (int or tuple of slices) :
            How to slice the data. An int i is expanded to
            ( slice(None), i ); anything else is passed on unchanged.

        ax (axis) :
            What axis to use. By default creates a figure and places the
            axis on it.

        fix_invalid (bool) :
            Handle invalid (non-finite) values?

        mask_zeros (bool) :
            If True, mask histogram values below 1e-14 before plotting.

        invalid_fix_method (float or int) :
            How to handle invalid values. By None throw them away (together
            with their weights). Providing a value to this argument instead
            replaces them with that value.

        bins (int or array-like) :
            bins argument to be passed to np.histogram

        normed (bool) :
            Normalize the histogram?

        norm_type (str) :
            How to normalize. 'probability' divides by (sum * first bin
            width), 'bin_width' by the first bin width, 'outer_edge' by the
            last bin and 'max_value' by the largest bin.

        scaling (float) :
            If given, multiply the histogram by this value.

        smooth (bool) :
            If True, smooth the histogram with a Savitzky-Golay filter using
            smoothing_window_length and smoothing_polyorder.

        histogram_style (str) :
            'step' for a step plot over the bin edges, 'line' for a line
            through the bin centers.

        color (str) :
            Color of histogram. By None, uses self.color.

        linestyle (str) :
            Linestyle of histogram.

        linewidth (float) :
            Linewidth for histogram.

        alpha (float) :
            Alpha value for histogram

        x_range, y_range ( [min,max] ) :
            What are the minimum and maximum x- and y- values to include?
            Defaults to matplotlib's automatic choices

        x_label, ylabel (str) :
            Axes labels. Defaults to the data_key for the x-axis and
            "Normalized Histogram" (or "CDF") for the y-axis.

        add_x_label, add_y_label (bool) :
            Include axes labels?

        add_plot_label (bool) :
            Include the plot label?

        plot_label (str or dict) :
            What to label the plot with. By None, uses self.label. A dict is
            used as the full set of annotate arguments, with the text under
            the key 's' or 'text'.

        line_label (str) :
            What label to give the line. By None, uses self.label.

        label_fontsize (int) :
            Fontsize for the labels.

        x_scale, y_scale (str) :
            What scales to use for the x and y axes.

        cdf (bool) :
            Plot a CDF instead.

        vertical_line (float) :
            Plot a vertical line at this value on the x-axis, if given.

        vertical_line_kwargs (dict) :
            Arguments to pass to ax.plot for the vertical line.

        return_dist (bool) :
            If True, return the histogram values and the bin edges.

        assert_contains_all_data (bool) :
            If True, make sure that the histogram bins all selected data.

        data_kwargs (dict) :
            Keyword arguments used only when retrieving data_key. They are
            merged with **kwargs.

        *args, **kwargs :
            Extra arguments to pass to self.data_object.get_selected_data()

    Returns:
        hist, edges if return_dist is True, otherwise None.
    '''

    print("Plotting histogram for {}".format(data_key))

    if provided_hist is None:

        if isinstance(slices, int):
            sl = (slice(None), slices)
        else:
            sl = slices

        data_kwargs = utilities.merge_two_dicts(data_kwargs, kwargs)

        if provided_data is None:
            data = self.data_object.get_selected_data(
                data_key,
                sl=sl,
                *args, **data_kwargs
            ).copy()
        else:
            data = provided_data.copy()

        if weight_key is None:
            weights = None
        else:
            if 'scale_key' in kwargs:
                warnings.warn(
                    "Scaling weights by {}. Is this correct?".format(
                        kwargs['scale_key']
                    )
                )
            weights = self.data_object.get_selected_data(
                weight_key,
                sl=sl,
                *args, **kwargs
            )

        if fix_invalid:
            fixed_data = np.ma.fix_invalid(data)
            if invalid_fix_method is None:
                if weights is not None:
                    # Throw away the weights of the discarded values too, so
                    # that np.histogram receives arrays of equal length.
                    weights = np.ma.masked_array(
                        weights,
                        mask=np.ma.getmaskarray(fixed_data),
                    ).compressed()
                data = fixed_data.compressed()
            else:
                fixed_data.fill_value = invalid_fix_method
                data = fixed_data.filled()

        # Make the histogram itself
        hist, edges = np.histogram(data, bins=bins, weights=weights)

        # Make sure we have all the data in the histogram
        if assert_contains_all_data:
            if weights is None:
                n_binned = hist.sum()
            else:
                # A weighted histogram sums weights rather than counts, so
                # count the binned values separately on the same edges.
                n_binned = np.histogram(data, bins=edges)[0].sum()
            assert data.size == n_binned, \
                "Histogram bins do not contain all of the selected data."

        if normed:
            bin_width = edges[1] - edges[0]
            if norm_type == 'probability':
                hist = hist.astype(float) / (hist.sum() * bin_width)
            elif norm_type == 'bin_width':
                hist = hist.astype(float) / bin_width
            elif norm_type == 'outer_edge':
                hist = hist.astype(float) / hist[-1]
            elif norm_type == 'max_value':
                hist = hist.astype(float) / hist.max()
            else:
                raise Exception(
                    "Unrecognized norm_type, {}".format(norm_type)
                )

        if scaling is not None:
            # Not in-place: unnormalized counts are integers, and an in-place
            # multiply by a float cannot be cast back into an integer array.
            hist = hist * scaling

        if cdf:
            hist = np.cumsum(hist) * (edges[1] - edges[0])

    else:
        hist = provided_hist
        edges = bins

    if mask_zeros:
        hist = np.ma.masked_where(
            hist < 1e-14,
            hist,
        )

    if smooth:
        hist = signal.savgol_filter(
            hist,
            window_length=smoothing_window_length,
            polyorder=smoothing_polyorder,
        )

    if ax is None:
        plt.figure(figsize=(11, 5), facecolor='white')
        ax = plt.gca()

    if line_label is None:
        line_label = self.label

    if color is None:
        color = self.color

    # Inserting a 0 at the beginning allows plotting a numpy histogram with a
    # step plot
    if histogram_style == 'step':
        ax.step(
            edges,
            np.insert(hist, 0, 0.),
            color=color,
            linestyle=linestyle,
            linewidth=linewidth,
            label=line_label,
            alpha=alpha,
        )
    elif histogram_style == 'line':
        x_values = 0.5 * (edges[:-1] + edges[1:])
        ax.plot(
            x_values,
            hist,
            color=color,
            linestyle=linestyle,
            linewidth=linewidth,
            label=line_label,
            alpha=alpha,
        )
    else:
        raise KeyError(
            "Unrecognized histogram_style, {}".format(histogram_style)
        )

    # Plot a vertical line?
    if vertical_line is not None:
        # x in data coordinates, y in axes coordinates: the line spans the
        # full height of the axes whatever the y limits are.
        trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
        ax.plot(
            [vertical_line, ] * 2,
            [0., 1., ],
            transform=trans,
            **vertical_line_kwargs
        )

    # Plot label. The text is always handed to annotate positionally, because
    # the name of that keyword differs between matplotlib versions.
    if add_plot_label:
        if plot_label is None or isinstance(plot_label, str):
            if plot_label is None:
                label_text = self.label
            else:
                label_text = plot_label
            ax.annotate(
                label_text,
                xy=(0., 1.0),
                va='bottom',
                xycoords='axes fraction',
                fontsize=label_fontsize,
            )
        elif isinstance(plot_label, dict):
            annotate_kwargs = dict(plot_label)
            annotate_args = []
            for text_key in ('s', 'text'):
                if text_key in annotate_kwargs:
                    annotate_args.append(annotate_kwargs.pop(text_key))
                    break
            ax.annotate(*annotate_args, **annotate_kwargs)
        else:
            raise Exception('Unrecognized plot_label arguments, {}'.format(plot_label))

    # Add axis labels
    if add_x_label:
        if x_label is None:
            x_label = data_key
        ax.set_xlabel(x_label, fontsize=label_fontsize)
    if add_y_label:
        if y_label is None:
            if not cdf:
                y_label = r'Normalized Histogram'
            else:
                y_label = r'CDF'
        ax.set_ylabel(y_label, fontsize=label_fontsize)

    if x_range is not None:
        ax.set_xlim(x_range)
    if y_range is not None:
        ax.set_ylim(y_range)

    ax.set_xscale(x_scale)
    ax.set_yscale(y_scale)

    if return_dist:
        return hist, edges
########################################################################
def histogram2d(
    self,
    x_key, y_key,
    x_data=None, y_data=None,
    weight_key=None,
    x_data_args={}, y_data_args={},
    weight_data_args={},
    slices=None,
    ax=None,
    x_range=None, y_range=None,
    x_scale='linear', y_scale='linear', z_scale='log',
    n_bins=128,
    average=False,
    normed=False,
    hist_div_arr=None,
    conditional_y=False,
    y_div_function=None,
    vmin=None, vmax=None,
    min_bin_value_displayed=None,
    zorder=0,
    add_colorbar=True,
    cmap=pu_cm.magma,
    colorbar_args=None,
    x_label=None, y_label=None,
    add_x_label=True, add_y_label=True,
    plot_label=None,
    outline_plot_label=False,
    label_galaxy_cut=False,
    label_redshift=False,
    label_fontsize=24,
    tick_param_args=None,
    out_dir=None,
    save_file=None,
    close_plot_after_saving=True,
    fix_invalid=True,
    line_slope=None,
    cdf=False,
    horizontal_line=None, vertical_line=None,
    horizontal_line_kwargs={'linestyle': '--', 'linewidth': 5, 'color': '#337DB8'},
    vertical_line_kwargs={'linestyle': '--', 'linewidth': 5, 'color': '#337DB8'},
    return_dist=False,
    *args, **kwargs
):
    '''Make a 2D histogram of the data. Extra arguments are passed to get_selected_data.

    Args:
        x_key, y_key (str) : Data keys to plot. Also the default axis labels.
        x_data, y_data (array-like) : If given, use these instead of retrieving x_key / y_key.
        weight_key (str) : Data key for data to use as a weight. By None, no weight.
        x_data_args, y_data_args, weight_data_args (dicts) :
            Keyword arguments to be passed only when retrieving x, y or the weights.
        slices (int or tuple of slices) : How to slice the data. An int i is expanded to ( slice(None), i ).
        ax (axis) : What axis to use. By None creates a figure and places the axis on it.
        x_range, y_range ( (float, float) ) : Histogram edges. If None, all data is enclosed. If list, set manually.
            If float, is +- x_range*self.data_object.length_scale[slices].
        x_scale, y_scale (str) : Axis scales. 'log' also makes the bin edges logarithmically spaced.
        z_scale (str) : Color scale, 'linear' or 'log'.
        n_bins (int) : Number of bin edges along each axis.
        average (bool) : If True, divide the (weighted) histogram by the unweighted one.
        normed (bool) : Passed to numpy as the density argument of the histograms.
        hist_div_arr (array-like) : If given, divide the histogram by this array.
        conditional_y (bool) : If True, divide each column by the 1D histogram of the x data.
        y_div_function (function) : If given, divide the y data by y_div_function( x_data ).
        vmin, vmax (float) : Limits for the colorbar.
        min_bin_value_displayed (float) : If given, mask bins with values below this.
        zorder (int) : zorder of the plotted mesh.
        add_colorbar (bool) : If True, add a colorbar.
        cmap : Colormap for the mesh.
        colorbar_args (dict) : Keyword arguments for gen_plot.add_colorbar. By None, the colorbar is attached to ax.
        x_label, ylabel (str) : Axes labels.
        add_x_label, add_y_label (bool) : Include axes labels?
        plot_label (str or dict) : What to label the plot with. By None, no plot label is drawn.
            Can also pass a dict of full annotate args, with the text under the key 's' or 'text'.
        outline_plot_label (bool) : If True, add an outline around the plot label.
        label_galaxy_cut (bool) : If true, add a label that indicates how the galaxy was defined.
        label_redshift (bool) : If True, add a label indicating the redshift.
        label_fontsize (int) : Fontsize for the labels.
        tick_param_args (args) : Arguments to pass to ax.tick_params. By None, don't change inherent defaults.
        out_dir (str) : If given, where to save the file.
        save_file (str) : File name to save as. By None, built from self.label and the snapshot number.
        close_plot_after_saving (bool) : If True, close the current figure after saving.
        fix_invalid (bool) : If True, drop points where either x or y is invalid.
        line_slope (float) : If given, draw a line through the origin with the given slope.
        cdf (bool) : Not implemented; raises if True.
        horizontal_line, vertical_line (float) : If given, draw a line across the axes at this value.
        horizontal_line_kwargs, vertical_line_kwargs (dict) : Arguments passed to ax.plot for those lines.
        return_dist (bool) : If True, return the histogram and the bin edges.

    Returns:
        hist2d, x_edges, y_edges if return_dist is True, otherwise None.
    '''

    if isinstance(slices, int):
        sl = (slice(None), slices)
    else:
        sl = slices

    varying_kwargs = {
        'x': x_data_args,
        'y': y_data_args,
        'weight': weight_data_args
    }
    data_kwargs = utilities.dict_from_defaults_and_variations(kwargs, varying_kwargs)

    # Get data
    if x_data is None:
        x_data = self.data_object.get_selected_data(x_key, sl=sl, *args, **data_kwargs['x']).copy()
    if y_data is None:
        y_data = self.data_object.get_selected_data(y_key, sl=sl, *args, **data_kwargs['y']).copy()

    if y_div_function is not None:
        y_div_values = y_div_function(x_data)
        y_data /= y_div_values

    # Fix NaNs. A point is dropped when either coordinate is invalid.
    if fix_invalid:
        x_mask = np.ma.fix_invalid(x_data).mask
        y_mask = np.ma.fix_invalid(y_data).mask
        mask = np.ma.mask_or(x_mask, y_mask)

        x_data = np.ma.masked_array(x_data, mask=mask).compressed()
        y_data = np.ma.masked_array(y_data, mask=mask).compressed()

    if weight_key is None:
        weights = None
    else:
        weights = self.data_object.get_selected_data(
            weight_key,
            sl=sl,
            *args,
            **data_kwargs['weight']
        ).flatten()

        if fix_invalid:
            weights = np.ma.masked_array(weights, mask=mask).compressed()

    if x_range is None:
        x_range = [x_data.min(), x_data.max()]
    elif isinstance(x_range, float):
        x_range = np.array([-x_range, x_range]) * self.data_object.length_scale[slices]
    if y_range is None:
        y_range = [y_data.min(), y_data.max()]
    elif isinstance(y_range, float):
        y_range = np.array([-y_range, y_range]) * self.data_object.length_scale[slices]

    if x_scale == 'log':
        x_edges = np.logspace(np.log10(x_range[0]), np.log10(x_range[1]), n_bins)
    else:
        x_edges = np.linspace(x_range[0], x_range[1], n_bins)
    if y_scale == 'log':
        y_edges = np.logspace(np.log10(y_range[0]), np.log10(y_range[1]), n_bins)
    else:
        y_edges = np.linspace(y_range[0], y_range[1], n_bins)

    # Make the histogram. numpy's keyword for normalizing is density; the
    # older normed spelling is no longer accepted by recent numpy.
    hist2d, x_edges, y_edges = np.histogram2d(
        x_data,
        y_data,
        [x_edges, y_edges],
        weights=weights,
        density=normed,
    )

    # If doing an average, divide by the number in each bin
    if average:
        average_hist2d, x_edges, y_edges = np.histogram2d(
            x_data,
            y_data,
            [x_edges, y_edges],
            density=normed,
        )
        hist2d /= average_hist2d

    # If making the y-axis conditional, divide by the distribution of data for the x-axis.
    if conditional_y:
        hist_x, x_edges = np.histogram(x_data, x_edges, density=normed)
        hist2d /= hist_x[:, np.newaxis]

    # Divide the histogram bins by this array
    if hist_div_arr is not None:
        hist2d /= hist_div_arr

    # Mask bins below a specified value
    if min_bin_value_displayed is not None:
        hist2d = np.ma.masked_where(
            hist2d < min_bin_value_displayed,
            hist2d,
        )

    if cdf:
        raise Exception(
            "Not implemented yet. When implementing, use utilities.cumsum2d"
        )

    # Plot
    if ax is None:
        plt.figure(figsize=(10, 9), facecolor='white')
        ax = plt.gca()

    # The color limits go on the norm itself; recent matplotlib refuses a
    # norm passed together with separate vmin/vmax arguments.
    if z_scale == 'linear':
        norm = plt_colors.Normalize(vmin=vmin, vmax=vmax)
    elif z_scale == 'log':
        norm = plt_colors.LogNorm(vmin=vmin, vmax=vmax)
    else:
        raise ValueError("Unrecognized z_scale, {}".format(z_scale))

    im = ax.pcolormesh(
        x_edges,
        y_edges,
        hist2d.transpose(),
        cmap=cmap,
        norm=norm,
        zorder=zorder,
    )

    # Add a colorbar
    if add_colorbar:
        if colorbar_args is None:
            cbar = gen_plot.add_colorbar(ax, im, method='ax')
        else:
            # Work on a copy so the caller's dictionary is left untouched.
            used_colorbar_args = dict(colorbar_args)
            used_colorbar_args['color_object'] = im
            cbar = gen_plot.add_colorbar(**used_colorbar_args)
        cbar.ax.tick_params(labelsize=20)

    # Plot Line for easier visual interpretation
    if line_slope is not None:
        line_x = np.array([x_data.min(), x_data.max()])
        line_y = line_slope * line_x
        ax.plot(line_x, line_y, linewidth=3, linestyle='dashed', )

    if horizontal_line is not None:
        trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
        ax.plot([0., 1.], [horizontal_line, ] * 2, transform=trans, **horizontal_line_kwargs)
    if vertical_line is not None:
        trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
        ax.plot([vertical_line, ] * 2, [0., 1.], transform=trans, **vertical_line_kwargs)

    # Plot label. The text is handed to annotate positionally, because the
    # name of that keyword differs between matplotlib versions.
    plt_label = None
    if plot_label is not None:
        if isinstance(plot_label, str):
            plt_label = ax.annotate(
                plot_label,
                xy=(0., 1.0),
                va='bottom',
                xycoords='axes fraction',
                fontsize=label_fontsize,
            )
        elif isinstance(plot_label, dict):
            annotate_kwargs = dict(plot_label)
            annotate_args = []
            for text_key in ('s', 'text'):
                if text_key in annotate_kwargs:
                    annotate_args.append(annotate_kwargs.pop(text_key))
                    break
            plt_label = ax.annotate(*annotate_args, **annotate_kwargs)
        else:
            raise Exception('Unrecognized plot_label arguments, {}'.format(plot_label))
    # Only outline a label that was actually drawn.
    if outline_plot_label and plt_label is not None:
        plt_label.set_path_effects([
            path_effects.Stroke(linewidth=3, foreground='black'),
            path_effects.Normal(),
        ])

    # Upper right label (info label)
    info_label = ''
    if label_galaxy_cut:
        info_label = r'$r_{ \rm cut } = ' + '{:.3g}'.format(self.data_object.galids.parameters['galaxy_cut']) + 'r_{ s}$'
    if label_redshift:
        try:
            info_label = r'$z=' + '{:.3f}'.format(self.data_object.redshift) + '$' + info_label
        except (ValueError, TypeError):
            # The redshift could not be formatted as a single number, so pick
            # out the entry for the requested slice instead.
            info_label = r'$z=' + '{:.3f}'.format(self.data_object.redshift.values[sl[1]]) + '$' + info_label
    if label_galaxy_cut or label_redshift:
        ax.annotate(
            info_label,
            xy=(1., 1.0225),
            xycoords='axes fraction',
            fontsize=label_fontsize,
            ha='right',
        )

    # Add axis labels
    if add_x_label:
        if x_label is None:
            x_label = x_key
        ax.set_xlabel(x_label, fontsize=label_fontsize)
    if add_y_label:
        if y_label is None:
            y_label = y_key
        ax.set_ylabel(y_label, fontsize=label_fontsize)

    # Limits
    ax.set_xlim(x_range)
    ax.set_ylim(y_range)

    # Scale
    ax.set_xscale(x_scale)
    ax.set_yscale(y_scale)

    # Set tick parameters
    if tick_param_args is not None:
        ax.tick_params(**tick_param_args)

    # Save the file
    if out_dir is not None:
        if save_file is None:
            save_file = '{}_{:03d}.png'.format(self.label, self.data_object.ptracks.snum[slices])

        # Ask the axis for its figure so saving also works when the caller
        # supplied ax and no figure was created here.
        gen_plot.save_fig(out_dir, save_file, fig=ax.get_figure(), dpi=75)

        if close_plot_after_saving:
            plt.close()

    # Return?
    if return_dist:
        return hist2d, x_edges, y_edges
########################################################################
def statistic_and_interval(
    self,
    x_key, y_key,
    x_data=None, y_data=None,
    weights=None,
    statistic='median',
    lower_percentile=16,
    upper_percentile=84,
    plot_interval=True,
    x_data_args={}, y_data_args={},
    ax=None,
    slices=None,
    fix_invalid=False,
    bins=64,
    linewidth=3,
    linestyle='-',
    color='k',
    label=None,
    zorder=100,
    alpha=0.5,
    plot_label=None,
    add_plot_label=True,
    plot_label_kwargs={
        'xy': (0., 1.0),
        'va': 'bottom',
        'xycoords': 'axes fraction',
        'fontsize': 22,
    },
    return_values=False,
    *args, **kwargs
):
    '''Plot a statistic of y in bins of x, with a shaded percentile interval.
    Extra arguments are passed to get_selected_data.

    Args:
        x_key, y_key (str) : Data keys for the binning variable and the binned variable.
        x_data, y_data (array-like) : If given, use these instead of retrieving x_key / y_key.
        weights (array-like) : Weights for statistic='weighted_mean'. Must be None for any other statistic.
        statistic (str or function) : 'weighted_mean', or anything accepted as the
            statistic argument of scipy.stats.binned_statistic.
        lower_percentile, upper_percentile (float) : Percentiles of y bounding the interval.
        plot_interval (bool) : If True, shade the region between the two percentiles.
        x_data_args, y_data_args (dicts) : Keyword arguments to be passed only when retrieving x or y.
        ax (axis) : What axis to use. By None creates a figure and places the axis on it.
        slices (int or tuple of slices) : How to slice the data. An int i is expanded to ( slice(None), i ).
        fix_invalid (bool) : If True, drop points (and their weights) where either x or y is invalid.
        bins (int or array-like) : bins argument passed to scipy.stats.binned_statistic.
        linewidth, linestyle, color, label, zorder : Style of the statistic line. color is also used for the interval.
        alpha (float) : Alpha value of the shaded interval.
        plot_label (str) : What to label the plot with. By None, uses self.label.
        add_plot_label (bool) : Include the plot label?
        plot_label_kwargs (dict) : Arguments passed to ax.annotate for the plot label.
        return_values (bool) : If True, return the calculated values.

    Returns:
        stat, low_p, high_p, bin_edges if return_values is True, otherwise None.
    '''

    if isinstance(slices, int):
        sl = (slice(None), slices)
    else:
        sl = slices

    varying_kwargs = {
        'x': x_data_args,
        'y': y_data_args,
    }
    data_kwargs = utilities.dict_from_defaults_and_variations(kwargs, varying_kwargs)

    # Get data
    if x_data is None:
        x_data = self.data_object.get_selected_data(x_key, sl=sl, *args, **data_kwargs['x']).copy()
    if y_data is None:
        y_data = self.data_object.get_selected_data(y_key, sl=sl, *args, **data_kwargs['y']).copy()

    # Fix NaNs
    if fix_invalid:
        x_mask = np.ma.fix_invalid(x_data).mask
        y_mask = np.ma.fix_invalid(y_data).mask
        mask = np.ma.mask_or(x_mask, y_mask)

        x_data = np.ma.masked_array(x_data, mask=mask).compressed()
        y_data = np.ma.masked_array(y_data, mask=mask).compressed()

        if weights is not None:
            # Keep the weights lined up with the points that survive.
            weights = np.ma.masked_array(weights, mask=mask).compressed()

    # Calculate the statistic
    if statistic == 'weighted_mean':
        assert weights is not None, "Need to provide weights."

        # Weighted mean per bin = sum( w * y ) / sum( w )
        weighted_sum, bin_edges, binnumber = scipy.stats.binned_statistic(
            x=x_data,
            values=y_data * weights,
            statistic='sum',
            bins=bins,
        )
        weights_sum, bin_edges, binnumber = scipy.stats.binned_statistic(
            x=x_data,
            values=weights,
            statistic='sum',
            bins=bins,
        )
        stat = weighted_sum / weights_sum
    else:
        assert weights is None, "weights only works with weighted_mean"

        # Usual statistic
        stat, bin_edges, binnumber = scipy.stats.binned_statistic(
            x=x_data,
            values=y_data,
            statistic=statistic,
            bins=bins,
        )

    # Calculate the percentiles
    def get_lower_percentile(data):
        return np.percentile(data, lower_percentile)

    def get_upper_percentile(data):
        return np.percentile(data, upper_percentile)

    low_p, bin_edges, binnumber = scipy.stats.binned_statistic(
        x=x_data,
        values=y_data,
        statistic=get_lower_percentile,
        bins=bins,
    )
    high_p, bin_edges, binnumber = scipy.stats.binned_statistic(
        x=x_data,
        values=y_data,
        statistic=get_upper_percentile,
        bins=bins,
    )

    # Get plotting axis
    if ax is None:
        plt.figure(figsize=(10, 9), facecolor='white')
        ax = plt.gca()

    # X values for the plot: the center of each bin. Taking the midpoint of
    # neighbouring edges stays correct when bins is an array of unevenly
    # spaced edges.
    x_values = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    # Plot statistic
    ax.plot(
        x_values,
        stat,
        linewidth=linewidth,
        linestyle=linestyle,
        color=color,
        zorder=zorder,
        label=label,
    )

    # Plot interval
    if plot_interval:
        ax.fill_between(
            x_values,
            low_p,
            high_p,
            color=color,
            alpha=alpha,
        )

    # Add plot label
    if add_plot_label:
        if plot_label is None:
            plot_label = self.label
        if plot_label is not None:
            # Text passed positionally: the name of this keyword differs
            # between matplotlib versions.
            ax.annotate(
                plot_label,
                **plot_label_kwargs
            )

    if return_values:
        return stat, low_p, high_p, bin_edges
########################################################################
def scatter(
    self,
    x_key, y_key,
    slices=None,
    n_subsample=None,
    ax=None,
    marker_size=100,
    color='k',
    marker='.',
    zorder=-100,
    x_range=None, y_range=None,
    x_label=None, y_label=None,
    add_x_label=True, add_y_label=True,
    plot_label=None,
    outline_plot_label=False,
    label_galaxy_cut=False,
    label_redshift=False,
    label_fontsize=24,
    tick_param_args=None,
    out_dir=None,
    fix_invalid=True,
    line_slope=None,
    *args, **kwargs
):
    '''Make a 2D scatter plot of the data. Extra arguments are passed to get_selected_data.

    Args:
        x_key, y_key (str) : Data keys to plot. Also the default axis labels.
        slices (int or tuple of slices) : How to slice the data. An int i is expanded to ( slice(None), i ).
        n_subsample (int) : If given, plot only this many points, drawn at random with np.random.randint.
        ax (axis) : What axis to use. By None creates a figure and places the axis on it.
        marker_size, color, marker : Passed to ax.scatter as s, color and marker.
        zorder (int) : zorder of the scattered points.
        x_range, y_range ( (float, float) ) : Axis limits. If None, all data is enclosed. If list, set manually.
            If float, is +- x_range*self.data_object.ptracks.length_scale.iloc[slices].
        x_label, ylabel (str) : Axes labels.
        add_x_label, add_y_label (bool) : Include axes labels?
        plot_label (str or dict) : What to label the plot with. By None, uses self.label.
            Can also pass a dict of full annotate args, with the text under the key 's' or 'text'.
        outline_plot_label (bool) : If True, add an outline around the plot label.
        label_galaxy_cut (bool) : If true, add a label that indicates how the galaxy was defined.
        label_redshift (bool) : If True, add a label indicating the redshift.
        label_fontsize (int) : Fontsize for the labels.
        tick_param_args (args) : Arguments to pass to ax.tick_params. By None, don't change inherent defaults.
        out_dir (str) : If given, where to save the file. The current figure is closed after saving.
        fix_invalid (bool) : If True, drop points where either x or y is invalid.
        line_slope (float) : If given, draw a line through the origin with the given slope.
    '''

    if isinstance(slices, int):
        sl = (slice(None), slices)
    else:
        sl = slices

    # Get data
    x_data = self.data_object.get_selected_data(x_key, sl=sl, *args, **kwargs)
    y_data = self.data_object.get_selected_data(y_key, sl=sl, *args, **kwargs)

    # Fix NaNs
    if fix_invalid:
        x_mask = np.ma.fix_invalid(x_data).mask
        y_mask = np.ma.fix_invalid(y_data).mask
        mask = np.ma.mask_or(x_mask, y_mask)

        x_data = np.ma.masked_array(x_data, mask=mask).compressed()
        y_data = np.ma.masked_array(y_data, mask=mask).compressed()

    # Subsample
    if n_subsample is not None:
        sampled_inds = np.random.randint(0, x_data.size, n_subsample)
        x_data = x_data[sampled_inds]
        y_data = y_data[sampled_inds]

    if x_range is None:
        x_range = [x_data.min(), x_data.max()]
    elif isinstance(x_range, float):
        x_range = np.array([-x_range, x_range]) * self.data_object.ptracks.length_scale.iloc[slices]
    if y_range is None:
        y_range = [y_data.min(), y_data.max()]
    elif isinstance(y_range, float):
        y_range = np.array([-y_range, y_range]) * self.data_object.ptracks.length_scale.iloc[slices]

    # Plot
    if ax is None:
        plt.figure(figsize=(10, 9), facecolor='white')
        ax = plt.gca()

    points = ax.scatter(x_data, y_data, s=marker_size, color=color, marker=marker)

    # Change the z order
    points.set_zorder(zorder)

    # Guide line
    if line_slope is not None:
        line_x = np.array([x_data.min(), x_data.max()])
        line_y = line_slope * line_x
        ax.plot(line_x, line_y, linewidth=3, linestyle='dashed', )

    # Plot label. The text is handed to annotate positionally, because the
    # name of that keyword differs between matplotlib versions.
    if plot_label is None or isinstance(plot_label, str):
        if plot_label is None:
            label_text = self.label
        else:
            label_text = plot_label
        plt_label = ax.annotate(
            label_text,
            xy=(0., 1.0225),
            xycoords='axes fraction',
            fontsize=label_fontsize,
        )
    elif isinstance(plot_label, dict):
        annotate_kwargs = dict(plot_label)
        annotate_args = []
        for text_key in ('s', 'text'):
            if text_key in annotate_kwargs:
                annotate_args.append(annotate_kwargs.pop(text_key))
                break
        plt_label = ax.annotate(*annotate_args, **annotate_kwargs)
    else:
        raise Exception('Unrecognized plot_label arguments, {}'.format(plot_label))
    if outline_plot_label:
        plt_label.set_path_effects([
            path_effects.Stroke(linewidth=3, foreground='black'),
            path_effects.Normal(),
        ])

    # Upper right label (info label)
    info_label = ''
    if label_galaxy_cut:
        info_label = r'$r_{ \rm cut } = ' + '{:.3g}'.format(self.data_object.galids.parameters['galaxy_cut']) + 'r_{ s}$'
    if label_redshift:
        info_label = r'$z=' + '{:.3f}'.format(self.data_object.ptracks.redshift.iloc[slices]) + '$, ' + info_label
    if label_galaxy_cut or label_redshift:
        ax.annotate(
            info_label,
            xy=(1., 1.0225),
            xycoords='axes fraction',
            fontsize=label_fontsize,
            ha='right',
        )

    # Add axis labels
    if add_x_label:
        if x_label is None:
            x_label = x_key
        ax.set_xlabel(x_label, fontsize=label_fontsize)
    if add_y_label:
        if y_label is None:
            y_label = y_key
        ax.set_ylabel(y_label, fontsize=label_fontsize)

    # Limits
    ax.set_xlim(x_range)
    ax.set_ylim(y_range)

    # Set tick parameters
    if tick_param_args is not None:
        ax.tick_params(**tick_param_args)

    # Save the file
    if out_dir is not None:
        save_file = '{}_{:03d}.png'.format(self.label, self.data_object.ptracks.snum[slices])

        # Ask the axis for its figure so saving also works when the caller
        # supplied ax and no figure was created here.
        gen_plot.save_fig(out_dir, save_file, fig=ax.get_figure(), dpi=75)

        plt.close()
########################################################################
def plot_stacked_data(
    self,
    x_key,
    y_keys,
    colors,
    ax=None,
    *args, **kwargs
):
    '''Plot several data sets stacked on top of each other as filled bands.
    Extra arguments are passed to get_selected_data.

    Each band starts where the previous one ended, so the top of the last
    band is the sum of all the y data.

    Args:
        x_key (str) : Data key for the x values.
        y_keys (list of strs) : Data keys for the bands, from bottom to top.
            The keys are also used as the legend labels.
        colors (dict or list) : Color of each band, either a dict indexed by
            the entries of y_keys or a sequence in the same order as y_keys.
        ax (axis) : What axis to use. By None creates a figure and places the axis on it.
    '''

    if ax is None:
        plt.figure(figsize=(11, 5), facecolor='white')
        ax = plt.gca()

    x_data = self.data_object.get_selected_data(
        x_key,
        *args, **kwargs
    )

    y_datas = []
    for y_key in y_keys:
        y_data = self.data_object.get_selected_data(
            y_key,
            *args, **kwargs
        ).copy()
        y_datas.append(y_data)

    color_objects = []
    labels = []
    y_prev = None
    for i, y_key in enumerate(y_keys):

        if isinstance(colors, dict):
            band_color = colors[y_key]
        else:
            band_color = colors[i]

        # The first band sits on zero.
        if y_prev is None:
            y_prev = np.zeros(shape=y_datas[i].shape)

        y_next = y_prev + y_datas[i]

        ax.fill_between(
            x_data,
            y_prev,
            y_next,
            color=band_color,
        )

        # Make virtual artists to allow a legend to appear
        color_object = matplotlib.patches.Rectangle(
            (0, 0),
            1,
            1,
            fc=band_color,
            ec=band_color,
        )
        color_objects.append(color_object)
        labels.append(y_key)

        # The next band is stacked on top of this one.
        y_prev = y_next

    ax.annotate(
        self.label,
        xy=(0., 1.0225),
        xycoords='axes fraction',
        fontsize=22,
    )

    ax.legend(
        color_objects,
        labels,
        prop={'size': 14.5},
        ncol=5,
        loc=(0., -0.28),
        fontsize=20
    )
########################################################################
def plot_time_dependent_data(
    self,
    ax=None,
    x_range=[0., np.log10(8.)], y_range=None,
    y_scale='log',
    x_label=None, y_label=None,
):
    '''Make a plot like the top panel of Fig. 3 in Angles-Alcazar+17

    One line is drawn per classification, with log10( 1 + redshift ) on the
    x-axis and the values from
    self.data_object.get_categories_galaxy_mass() on the y-axis.

    Args:
        ax (axis object) :
            What axis to put the plot on. By None, create a new one on a separate figure.

        x_range, y_range (list-like) :
            [ x_min, x_max ] or [ y_min, y_max ] for the displayed range.

        y_scale (str) :
            Scale of the y axis.

        x_label, y_label (str) :
            Labels for axis. By None, 'z' and the stellar mass, respectively.
    '''

    if ax is None:
        plt.figure(figsize=(11, 5), facecolor='white')
        ax = plt.gca()

    x_data = np.log10(1. + self.data_object.get_data('redshift'))

    y_datas = self.data_object.get_categories_galaxy_mass()

    # NOTE(review): p_constants is not imported in the part of this module
    # visible here; confirm where it is meant to come from.
    for key in p_constants.CLASSIFICATION_LIST_A[::-1]:

        ax.plot(
            x_data,
            y_datas[key],
            linewidth=3,
            color=p_constants.CLASSIFICATION_COLORS[key],
            label=p_constants.CLASSIFICATION_LABELS[key],
        )

    if x_range is not None:
        ax.set_xlim(x_range)

    if y_range is not None:
        ax.set_ylim(y_range)

    ax.set_yscale(y_scale)

    # Ticks are placed at log10( 1 + z ) but labelled with z itself. They are
    # set on ax directly, so the right axis is changed even when it is not
    # pyplot's current one.
    tick_redshifts = np.array([0.25, 0.5, 1, 2, 3, 4, 5, 6, 7, ])
    x_tick_values = np.log10(1. + tick_redshifts)
    ax.set_xticks(x_tick_values)
    ax.set_xticklabels(tick_redshifts)

    if x_label is None:
        x_label = r'z'
    if y_label is None:
        y_label = r'$M_{\star} (M_{\odot})$'
    ax.set_xlabel(x_label, fontsize=22, )
    ax.set_ylabel(y_label, fontsize=22, )

    ax.annotate(self.label, xy=(0., 1.0225), xycoords='axes fraction', fontsize=22, )

    ax.legend(prop={'size': 14.5}, ncol=5, loc=(0., -0.28), fontsize=20)
########################################################################
# Generic Plotting Methods
########################################################################
def same_axis_plot(
    self,
    axis_plotting_method_str,
    variations,
    ax=None,
    figsize=(11, 5),
    out_dir=None,
    add_line_label=False,
    legend_args={'prop': {'size': 16.5}, 'loc': 'upper right', 'fontsize': 20},
    *args, **kwargs
):
    '''Call one plotting method several times on the same axis.

    Args:
        axis_plotting_method_str (str) : Name of the method of this object to call.
        variations (dict of dicts) : Differences in plotting arguments per call.
            One call is made per key; **kwargs are the shared defaults.
        ax (axis) : What axis to use. By None creates a figure and places the axis on it.
        figsize (tuple) : Size of the figure, if one is created.
        out_dir (str) : If given, where to save the file. Requires 'slices' in **kwargs,
            which is used to look up the snapshot number for the file name.
        add_line_label (bool) : If True, pass each key of variations to the
            plotting method as line_label.
        legend_args (dict) : Arguments passed to ax.legend.
        *args, **kwargs : Passed on to the plotting method.
    '''

    if ax is None:
        fig = plt.figure(figsize=figsize, facecolor='white', )
        ax = plt.gca()
    else:
        # Needed for saving when the caller supplied the axis.
        fig = ax.get_figure()

    all_plotting_kwargs = utilities.dict_from_defaults_and_variations(kwargs, variations)

    axis_plotting_method = getattr(self, axis_plotting_method_str)

    for key, plotting_kwargs in all_plotting_kwargs.items():

        # Every call draws on the shared axis.
        plotting_kwargs['ax'] = ax

        if add_line_label:
            plotting_kwargs['line_label'] = key

        axis_plotting_method(*args, **plotting_kwargs)

    ax.legend(**legend_args)

    # Save the file
    if out_dir is not None:
        save_file = '{}_{:03d}.png'.format(self.label, self.data_object.ptracks.snum[kwargs['slices']])
        gen_plot.save_fig(out_dir, save_file, fig=fig, dpi=75)

        plt.close()
########################################################################
def panel_plot(
    self,
    panel_plotting_method_str,
    defaults,
    variations,
    slices=None,
    n_rows=2,
    n_columns=2,
    plot_locations=[(0, 0), (0, 1), (1, 0), (1, 1)],
    figsize=(10, 9),
    plot_label=None,
    outline_plot_label=False,
    label_galaxy_cut=False,
    label_redshift=True,
    label_fontsize=24,
    subplot_label_args={'xy': (0.075, 0.88), 'xycoords': 'axes fraction', 'fontsize': 20, 'color': 'k', },
    subplot_spacing_args={'hspace': 0.0001, 'wspace': 0.0001, },
    out_dir=None,
):
    '''
    Make a multi panel plot of the type of your choosing.

    Args:
        panel_plotting_method_str (str) : What type of plot to make.
        defaults (dict) : Default arguments to pass to panel_plotting_method. Not modified.
        variations (dict of dicts) : Differences in plotting arguments per subplot.
        slices (slice) : What slices to select. By None, this doesn't pass any slices argument to panel_plotting_method
        n_rows, n_columns (int) : Shape of the grid of panels.
        plot_locations (list of tuples) : (row, column) of each panel, in the order the panels are made.
        figsize (tuple) : Size of the figure.
        plot_label (str or dict) : What to label the plot with. By None, uses self.label.
            Can also pass a dict of full annotate args, with the text under the key 's' or 'text'.
        outline_plot_label (bool) : If True, add an outline around the plot label.
        label_galaxy_cut (bool) : If true, add a label that indicates how the galaxy was defined.
        label_redshift (bool) : If True, add a label indicating the redshift.
            Requires a 'slices' entry, from slices or from defaults.
        label_fontsize (int) : Fontsize for the labels.
        subplot_label_args (dict) : Label arguments to pass to each subplot for the label for the subplot.
            The actual label string itself corresponds to the keys in variations.
        subplot_spacing_args (dict) : How to space the subplots.
        out_dir (str) : If given, where to save the file.
    '''

    # No axis is requested from pyplot here: that would create an extra
    # axis covering the whole figure underneath the panels.
    fig = plt.figure(figsize=figsize, facecolor='white', )

    fig.subplots_adjust(**subplot_spacing_args)

    # Work on a copy so the caller's defaults are left untouched.
    defaults = dict(defaults)
    if slices is not None:
        defaults['slices'] = slices

    plotting_kwargs = utilities.dict_from_defaults_and_variations(defaults, variations)

    # Setup axes
    gs = gridspec.GridSpec(n_rows, n_columns)
    axs = []
    for plot_location in plot_locations:
        axs.append(plt.subplot(gs[plot_location]))

    # Setup arguments further
    for i, key in enumerate(plotting_kwargs):
        ax_kwargs = plotting_kwargs[key]

        ax_kwargs['ax'] = axs[i]

        # Subplot label args. The text stays under the 's' key, as before, so
        # the panel plotting method receives the same dictionary layout.
        this_subplot_label_args = subplot_label_args.copy()
        this_subplot_label_args['s'] = key
        ax_kwargs['plot_label'] = this_subplot_label_args

        # Panels that draw a colorbar all put it at the same place on the
        # figure, to the right of the grid.
        if ax_kwargs.get('add_colorbar', False):
            ax_kwargs['colorbar_args'] = {'fig_or_ax': fig, 'ax_location': [0.9, 0.125, 0.03, 0.775], }

        # Clean up interior axes
        ax_tick_parm_args = dict(ax_kwargs.get('tick_param_args') or {})
        plot_location = plot_locations[i]
        # Hide repetitive x labels
        if plot_location[0] != n_rows - 1:
            ax_kwargs['add_x_label'] = False
            ax_tick_parm_args['labelbottom'] = False
        # Hide repetitive y labels
        if plot_location[1] != 0:
            ax_kwargs['add_y_label'] = False
            ax_tick_parm_args['labelleft'] = False
        ax_kwargs['tick_param_args'] = ax_tick_parm_args

    # Actual panel plots
    panel_plotting_method = getattr(self, panel_plotting_method_str)
    for key in plotting_kwargs:
        panel_plotting_method(**plotting_kwargs[key])

    # Main axes labels
    # Plot label. The text is handed to annotate positionally, because the
    # name of that keyword differs between matplotlib versions.
    if plot_label is None or isinstance(plot_label, str):
        if plot_label is None:
            label_text = self.label
        else:
            label_text = plot_label
        plt_label = axs[0].annotate(
            label_text,
            xy=(0., 1.0225),
            xycoords='axes fraction',
            fontsize=label_fontsize,
        )
    elif isinstance(plot_label, dict):
        annotate_kwargs = dict(plot_label)
        annotate_args = []
        for text_key in ('s', 'text'):
            if text_key in annotate_kwargs:
                annotate_args.append(annotate_kwargs.pop(text_key))
                break
        plt_label = axs[0].annotate(*annotate_args, **annotate_kwargs)
    else:
        raise Exception('Unrecognized plot_label arguments, {}'.format(plot_label))
    if outline_plot_label:
        plt_label.set_path_effects([
            path_effects.Stroke(linewidth=3, foreground='white', background='white'),
            path_effects.Normal()
        ])

    # Upper right label (info label)
    info_label = ''
    if label_galaxy_cut:
        info_label = r'$r_{ \rm cut } = ' + '{:.3g}'.format(self.data_object.galids.parameters['galaxy_cut']) + 'r_{ s}$'
    if label_redshift:
        ind = defaults['slices']
        info_label = r'$z=' + '{:.3f}'.format(self.data_object.ptracks.redshift.iloc[ind]) + '$' + info_label
    if label_galaxy_cut or label_redshift:
        # The info label goes above the top right panel.
        label_ax = plt.subplot(gs[0, n_columns - 1, ])
        label_ax.annotate(
            info_label,
            xy=(1., 1.0225),
            xycoords='axes fraction',
            fontsize=label_fontsize,
            ha='right'
        )

    # Save the file
    if out_dir is not None:
        save_file = '{}_{:03d}.png'.format(self.label, self.data_object.ptracks.snum[slices])
        gen_plot.save_fig(out_dir, save_file, fig=fig)

        plt.close()
########################################################################
def make_multiple_plots( self,
    plotting_method_str,
    iter_args_key,
    iter_args,
    n_processors = 1,
    out_dir = None,
    make_out_dir_subdir = True,
    make_movie = False,
    clear_data = False,
    *args, **kwargs ):
    '''Make multiple plots of a selected type, one per value in iter_args.
    *args and **kwargs are passed to plotting_method_str.

    Args:
        plotting_method_str (str) : Name of the method of this object used to make each plot.
        iter_args_key (str) : The name of the keyword argument to iterate over.
        iter_args (iterable) : Values to use for iter_args_key, one plot per value.
            May be a one-shot iterator; it is only traversed once.
        n_processors (int) : Number of processors to use. Should only be used when saving the data.
        out_dir (str) : Where to save the data. Passed to the plotting method as out_dir.
        make_out_dir_subdir (bool) : If True and out_dir is given, save into a subdirectory
            of out_dir named after self.label.
        make_movie (bool) : Make a movie out of the plots, if True.
        clear_data (bool) : If True, clear memory of the data after making the plots.
    '''

    plotting_method = getattr( self, plotting_method_str )

    if ( out_dir is not None ) and make_out_dir_subdir:
        out_dir = os.path.join( out_dir, self.label )

    def plotting_method_wrapper( process_args ):
        '''Unpack one ( out_dir, args, kwargs ) tuple and make the corresponding plot.'''
        used_out_dir, used_args, used_kwargs = process_args
        plotting_method( *used_args, out_dir=used_out_dir, **used_kwargs )

    # Build the full argument list up front. Each plot gets its own copy of kwargs,
    # so setting the iterated argument for one plot never affects another.
    all_process_args = []
    for iter_arg in iter_args:
        process_kwargs = dict( kwargs )
        process_kwargs[iter_args_key] = iter_arg
        all_process_args.append( ( out_dir, args, process_kwargs ) )

    if n_processors > 1:
        # For safety, make sure we've loaded the data already
        self.data_object.ptracks, self.data_object.galids, self.data_object.classifications

        mp_utils.parmap( plotting_method_wrapper, all_process_args, n_processors=n_processors, return_values=False )
    else:
        # Loop over the stored arguments rather than iter_args itself: iter_args was
        # already consumed above, so a generator would otherwise produce no plots here.
        for process_args in all_process_args:
            plotting_method_wrapper( process_args )

    if make_movie:
        gen_plot.make_movie( out_dir, '{}_*.png'.format( self.label ), '{}.mp4'.format( self.label ), )

    if clear_data:
        del self.data_object.ptracks
        del self.data_object.galids
        del self.data_object.classifications
########################################################################
########################################################################
class PlotterSet( verdict.Dict ):
    '''Enhanced dictionary holding one plotter per labelled variation of the data arguments.
    '''

    def __init__( self, data_object_cls, plotter_object_cls, defaults, variations ):
        '''Create a data object for every variation and store the resulting plotters by label.

        Args:
            data_object_cls (object) : Class for the data object. Instantiated once per variation.
            plotter_object_cls (object) : Class for the plotter object.
            defaults (dict) : Arguments shared by every data object.
            variations (dict of dicts) : For each label, the arguments that override or extend defaults.
        '''

        # One entry per label, holding the arguments for that label's plotter.
        plotter_args = {}
        for label, overrides in variations.items():

            # Work on a copy of the defaults so one variation never leaks into another.
            data_kwargs = dict( defaults )
            data_kwargs.update( overrides )

            plotter_args[label] = {
                'data_object': data_object_cls( **data_kwargs ),
                'label': label,
            }

        # NOTE(review): presumably this instantiates plotter_object_cls with each entry's
        # arguments -- confirm against utilities.SmartDict.from_class_and_args.
        plotters = utilities.SmartDict.from_class_and_args( plotter_object_cls, plotter_args )

        super( PlotterSet, self ).__init__( plotters )