#!/usr/bin/env python
'''Class for analyzing simulation data.
@author: Zach Hafen
@contact: zachary.h.hafen@gmail.com
@status: Development
'''
# Base python imports
import copy
from functools import wraps
import h5py
import numpy as np
import numpy.testing as npt
import warnings
import galaxy_dive.utils.constants as constants
import galaxy_dive.utils.utilities as utilities
########################################################################
# For default values

# Sentinel object, compared with ``is``, used as a default argument value
# so that None remains a valid explicit argument, distinct from
# "argument not provided".
default = object()
########################################################################
########################################################################
class GenericData( object ):
    '''Very generic data class, with getting and masking functionality.

    Args:
        key_parser (object) : KeyParser instance to use to interpret data keys.
            If None, a default DataKeyParser is created.
        data_masker (object) : DataMasker instance to use to filter and mask data.
            If None, a default DataMasker is created.
        verbose (bool) : Print out additional information.
        z_sun (float) : Used mass fraction for solar metallicity.
        **kwargs : Additional options (e.g. 'function_args', 'shift'),
            stored on the instance by the decorator.
    '''

    @utilities.store_parameters
    def __init__( self,
        key_parser = None,
        data_masker = None,
        verbose = False,
        z_sun = constants.Z_MASSFRAC_SUN,
        **kwargs ):
        # NOTE: store_parameters stores every argument as an attribute on
        # self, so explicitly-passed maskers/parsers are already set here;
        # we only need to fill in the defaults.

        # For storing and creating masks to pass the data
        if data_masker is None:
            self.data_masker = DataMasker( self )

        # Setup a data key parser
        if key_parser is None:
            self.key_parser = DataKeyParser()

    ########################################################################
    # Properties
    ########################################################################

    @property
    def length_scale( self ):
        '''Property for fiducial length scale. By default, is 1.
        However, a more advanced subclass might set this differently, or this
        might change in the future.
        '''
        # TODO: Address this.
        raise Exception( "Current thinking: this should not be called." )
        return 1.

    ########################################################################

    @property
    def velocity_scale( self ):
        '''Property for fiducial velocity scale. By default is 1.
        However, a more advanced subclass might set this differently, or this
        might change in the future.
        '''
        # TODO: Address this.
        raise Exception( "Current thinking: this should not be called." )
        return 1.

    ########################################################################

    @property
    def metallicity_scale( self ):
        '''Property for fiducial metallicity scale. By default is z_sun.
        However, a more advanced subclass might set this differently, or this
        might change in the future.
        '''
        return self.z_sun

    ########################################################################

    @property
    def base_data_shape( self ):
        '''Property for the common shape of the stored data arrays.

        Taken from an arbitrary entry of self.data, so it assumes every
        stored array shares the same shape.
        '''
        if not hasattr( self, '_base_data_shape' ):
            # next( iter( ... ) ) works on Python 3, where dict.values()
            # returns a non-subscriptable view ( values()[0] was Py2-only ).
            self._base_data_shape = next( iter( self.data.values() ) ).shape

        return self._base_data_shape

    @base_data_shape.setter
    def base_data_shape( self, value ):
        '''Setting function for the base data shape property.'''
        # If we try to set it, make sure that if it already exists we don't
        # change it.
        if hasattr( self, '_base_data_shape' ):
            assert self._base_data_shape == value
        else:
            self._base_data_shape = value

    ########################################################################
    # Data Retrieval
    ########################################################################

    def get_data( self, data_key, sl=None, data_storage=default ):
        '''Get the data from within the class. Only for getting the data. No
        post-processing or changing the data (putting it in particular units,
        etc.). Missing quantities are calculated as the need arises, via
        handle_data_key_error().

        Args:
            data_key (str) : Key in the data dictionary for the key we want to get.
            sl (slice) : Slice of the data, if requested.
            data_storage (dict) : Storage to look the data up in.
                Defaults to self.data.

        Returns:
            data (np.ndarray) : Requested data.

        Raises:
            KeyError : If the data can be neither found nor created.
        '''
        if data_storage is default:
            data_storage = self.data

        # Loop through, handling issues: each failed lookup gets one chance
        # to generate the missing data before retrying.
        n_tries = 10
        for _ in range( n_tries ):
            try:
                # Arbitrary functions of the data
                if data_key == 'Function':
                    raise Exception( "TODO: Test this" )

                    # Use the keys to get the data
                    function_data_keys = self.kwargs['function_args']['function_data_keys']
                    input_data = [ self.get_data( function_data_key ) for function_data_key in function_data_keys ]

                    # Apply the function
                    data = self.kwargs['function_args']['function']( input_data )

                # Other
                else:
                    # Look up in the requested storage. ( The original
                    # hard-coded self.data, silently ignoring data_storage. )
                    data = data_storage[data_key]

            # Calculate missing data
            except KeyError:
                self.handle_data_key_error( data_key )
                continue

            break
        else:
            # for/else: the loop exhausted all tries without a successful
            # lookup ( replaces the previous check of locals() ).
            raise KeyError( "After {} tries, unable to find or create data_key, {}".format( n_tries, data_key ) )

        if sl is not None:
            return data[sl]

        return data

    ########################################################################

    def get_processed_data( self, data_key, data_method=default, *args, **kwargs ):
        '''Get post-processed data. (Accounting for fractions, log-space, etc.).

        Args:
            data_key (str) : What data to get.
            data_method (str) : What method to use for getting the data itself.
                Defaults to using self.get_data.
            *args, **kwargs : Passed to the data-getting method.

        Returns:
            processed_data (np.ndarray) : Requested data, including formatting.
        '''
        # Account for fractional data keys
        data_key, fraction_flag = self.key_parser.is_fraction_key( data_key )

        # Account for logarithmic data
        data_key, log_flag = self.key_parser.is_log_key( data_key )

        # Choose what method we're using for getting data.
        if data_method is default:
            get_data_method = self.get_data
        else:
            get_data_method = getattr( self, data_method )

        # Get the data and make a copy to avoid altering the stored arrays.
        data = copy.deepcopy( get_data_method( data_key, *args, **kwargs ) )

        # Actually calculate the fractional data
        if fraction_flag:

            # Put distances in units of the virial radius
            if self.key_parser.is_position_key( data_key ):
                data /= self.length_scale

            # Put velocities in units of the circular velocity
            elif self.key_parser.is_velocity_key( data_key ):
                data /= self.velocity_scale

            # Put the metallicity in solar units
            elif data_key == 'Z':
                data /= self.metallicity_scale

            else:
                raise Exception( 'Fraction type not recognized' )

        # Make appropriate units into log
        if log_flag:
            data = np.log10( data )

        return data

    ########################################################################

    def get_selected_data( self, *args, **kwargs ):
        '''Wrapper for getting masked data (see DataMasker.get_selected_data).'''
        return self.data_masker.get_selected_data( *args, **kwargs )

    def mask_data( self, *args, **kwargs ):
        '''Wrapper for masking data (see DataMasker.mask_data).'''
        return self.data_masker.mask_data( *args, **kwargs )

    ########################################################################

    def shift( self, data, data_key ):
        '''Shift or multiply the data by some amount. Note that this is
        applied after logarithms are applied.

        Args:
            data : Data to be shifted.
            data_key : Data key for the parameters to be shifted.

        Parameters are a subdictionary of self.kwargs['shift']:
            'data_key' : What data key the data is shifted for.
        '''
        raise Exception( "TODO: Test this" )

        shift_p = self.kwargs['shift']

        # Exit early if not the right data key
        if shift_p['data_key'] != data_key:
            return 0

        # Shift by the mass metallicity relation.
        if 'MZR' in shift_p:
            # Gas-phase ISM MZR fit from Ma+2015. Set to be no shift for gas
            # with the mean metallicity of LLS at z=0 ( value of -0.45, as
            # last calculated ).
            # NOTE(review): relies on self.redshift, which is not defined in
            # this class -- confirm the subclass provides it before enabling.
            log_shift = 0.93 * ( np.exp( -0.43 * self.redshift ) - 1. ) - 0.45

            data -= log_shift

    ########################################################################

    def handle_data_key_error( self, data_key ):
        '''Method for attempting to generate data on the fly.
        Subclasses must override this hook.

        Args:
            data_key (str) : Type of data to attempt to generate data for.
        '''
        raise Exception( "This method should be replaced in the subclass!" )

    ########################################################################
    # Meta Methods (for doing generic things with class methods)
    ########################################################################

    def iterate_over_method( self, method_str, iter_arg, iter_values, method_args ):
        '''Iterate over a specified method, and get the results out.

        Args:
            method_str (str) :
                Which method to use.
            iter_arg (str) :
                Which argument of the method to iterate over.
            iter_values (list of values) :
                Which values to change.
            method_args (dict) :
                Default args to pass to the method. Not modified.

        Returns:
            results (list) : [ method( **used_args ) for used_args in all_variations_of_used_args ]
        '''
        method = getattr( self, method_str )

        # Copy so we don't mutate the caller's argument dictionary
        # ( the original wrote the iterated value back into method_args ).
        used_args = dict( method_args )

        results = []
        for iter_value in iter_values:
            used_args[iter_arg] = iter_value
            results.append( method( **used_args ) )

        return results
########################################################################
########################################################################
class DataKeyParser( object ):
    '''Class for parsing data_keys provided to SimulationData.'''

    ########################################################################

    def is_position_key( self, data_key ):
        '''Checks if the data key deals with position primarily.

        Args:
            data_key (str) : Data key to check.

        Returns:
            is_position_key (bool) : True if deals with position.
        '''
        # startswith is safe for empty keys ( data_key[0] would IndexError ).
        return data_key.startswith( 'R' ) or ( data_key == 'P' )

    ########################################################################

    def is_velocity_key( self, data_key ):
        '''Checks if the data key deals with velocity primarily.

        Args:
            data_key (str) : Data key to check.

        Returns:
            is_velocity_key (bool) : True if deals with velocity.
        '''
        return data_key.startswith( 'V' )

    ########################################################################

    def is_fraction_key( self, data_key ):
        '''Check if the data should be some fraction of its relevant scale.

        Args:
            data_key (str) : Data key to check.

        Returns:
            data_key (str) : The data key, with any trailing 'f' stripped.
            fraction_flag (bool) : True if the data should be scaled as a
                fraction of the relevant scale.
        '''
        fraction_flag = data_key.endswith( 'f' )
        if fraction_flag:
            data_key = data_key[:-1]

        return data_key, fraction_flag

    ########################################################################

    def is_log_key( self, data_key ):
        '''Check if the data key indicates the data should be put in log scale.

        Args:
            data_key (str) : Data key to check.

        Returns:
            data_key (str) : The data key, with any leading 'log' stripped.
            log_flag (bool) : True if the data should be taken log10 of.
        '''
        log_flag = data_key.startswith( 'log' )
        if log_flag:
            data_key = data_key[3:]

        return data_key, log_flag
########################################################################
########################################################################
class DataMasker( object ):

    def __init__( self, data_object ):
        '''Class for masking data.

        Args:
            data_object (GenericData object) : Used for getting data to find mask ranges.
        '''
        self.data_object = data_object

        # Always-applied masks ( list of dicts describing each mask ).
        self.masks = []
        # Masks applied only on request, keyed by name.
        self.optional_masks = {}

    ########################################################################

    def mask_data( self,
        data_key,
        data_min = default,
        data_max = default,
        data_value = default,
        custom_mask = default,
        return_or_store = 'store',
        optional_mask = False,
        mask_name = default,
        *args, **kwargs ):
        '''Get only the particle data within a certain range. Note that it
        retrieves the processed data.

        Args:
            data_key (str) :
                Data key to base the mask off of.
            data_min (float) :
                Everything below data_min will be masked.
            data_max (float) :
                Everything above data_max will be masked.
            data_value (float) :
                Everything except for data_value will be masked.
            custom_mask (np.array of bools) :
                If provided, take in a custom mask instead, using data_key as
                the label for the mask.
            return_or_store (str) :
                Whether to 'store' the mask as part of the masks dictionary,
                or to 'return' it.
            optional_mask (bool) :
                If True, store in the dictionary self.optional_masks instead.
            mask_name (str) :
                What name to associate with this mask? Currently only relevant
                if optional_mask is True. By default uses the data_key as the name.
            *args, **kwargs :
                Passed to self.data_object.get_processed_data()

        Returns:
            data_mask (np.array of bools) :
                If requested, the mask for data in that range.

        Modifies:
            self.masks (list of dicts) :
                Appends a dictionary describing the mask.
        '''
        # Process what type of mask to get; exactly one kind must be chosen.
        mask_outside = ( data_min is not default ) and ( data_max is not default )
        mask_discrete = data_value is not default
        mask_custom = custom_mask is not default
        assert ( mask_outside + mask_discrete + mask_custom ) == 1, "Bad combination of masks!"

        if not mask_custom:
            data = self.data_object.get_processed_data( data_key, *args, **kwargs )
            broadcast_shape = data.shape
        else:
            # No data is retrieved for custom masks, so broadcast scalar
            # masks using the data object's fiducial shape instead.
            # ( The original referenced data.shape here -- a NameError. )
            broadcast_shape = self.data_object.base_data_shape

        # Get the mask
        if mask_outside:
            mask = np.ma.masked_outside( data, data_min, data_max ).mask
        elif mask_discrete:
            mask = np.invert( data == data_value )
        elif mask_custom:
            mask = custom_mask
        else:
            raise NameError( "Unspecified combination of data masking." )

        # Handle the case where an entire array is masked or none of it is
        # masked. ( Make into an array for easier combination with other
        # masks. ) np.size also copes with plain Python bools.
        if np.size( mask ) == 1:
            mask = mask * np.ones( shape=broadcast_shape, dtype=bool )

        if return_or_store == 'store':

            if mask_outside:
                mask_dict = { 'data_key': data_key, 'data_min': data_min, 'data_max': data_max, 'mask': mask }
            elif mask_discrete:
                mask_dict = { 'data_key': data_key, 'data_value': data_value, 'mask': mask }
            elif mask_custom:
                mask_dict = { 'data_key': data_key, 'custom_mask': True, 'mask': mask }

            if optional_mask:
                if mask_name is default:
                    mask_name = data_key
                assert mask_name not in self.optional_masks, "A mask with that name already exists!"
                self.optional_masks[mask_name] = mask_dict
            else:
                self.masks.append( mask_dict )

        elif return_or_store == 'return':
            return mask

        else:
            raise Exception( 'NULL return_or_store' )

    ########################################################################

    def get_total_mask( self, optional_masks=None ):
        '''Get the result of combining all masks in the data.

        Args:
            optional_masks (list-like) : List of names of optional masks to
                use (must be found in self.optional_masks).

        Returns:
            total_mask (np.array of bools) : Result of all masks, combined
                with a logical OR ( masked anywhere any mask is True ).
        '''
        # Compile masks
        all_masks = [ mask_dict['mask'] for mask_dict in self.masks ]

        # Get any requested optional masks
        if optional_masks is not None:
            for optional_mask in optional_masks:
                all_masks.append( self.optional_masks[optional_mask]['mask'] )

        # Combine masks
        return np.any( all_masks, axis=0, keepdims=True )[0]

    ########################################################################

    def get_selected_data( self,
        data_key,
        mask = 'total',
        optional_masks = None,
        sl = None,
        apply_slice_to_mask = True,
        fix_invalid = False,
        compress = True,
        mask_multidim_data = True,
        *args, **kwargs ):
        '''Get all the data that doesn't have some sort of mask applied to
        it. Uses the processed data.

        Args:
            data_key (str) : Data key to get the data for.
            mask (str or np.array of bools) : Mask to apply. If 'total', use the total mask.
            optional_masks (list-like) : List of names of optional masks to use (must be found in self.optional_masks).
            sl (slice) : Slice to apply to the data.
            apply_slice_to_mask (bool) : Whether or not to apply the same slice you applied to the data to the mask.
            fix_invalid (bool) : Whether or not to also mask invalid data.
            compress (bool) : Whether or not to return compressed data.
            mask_multidim_data (bool) : Whether or not to change the mask to fit multidimensional data.
            *args, **kwargs : Passed to get_processed_data.

        Returns:
            data_ma (np.array) : Compressed masked data. Because it's
                compressed it may not have the same shape as the original data.
        '''
        data = self.data_object.get_processed_data( data_key, sl=sl, *args, **kwargs )

        # Get the appropriate mask
        if isinstance( mask, np.ndarray ):
            used_mask = mask
        elif isinstance( mask, ( bool, np.bool_ ) ):
            if not mask:
                # Nothing is masked; hand back the (possibly cleaned) data.
                if fix_invalid:
                    return np.ma.fix_invalid( data ).compressed()
                return data
            raise Exception( "All data is masked." )
        elif mask == 'total':
            used_mask = self.get_total_mask( optional_masks=optional_masks )
        else:
            raise KeyError( "Unrecognized type of mask, {}".format( mask ) )

        if ( sl is not None ) and apply_slice_to_mask:
            used_mask = used_mask[sl]

        # Choose how raw arrays are converted to masked arrays.
        if fix_invalid:
            array_to_ma_array_fn = np.ma.fix_invalid
        else:
            array_to_ma_array_fn = np.ma.array

        # Test for if the data fits the mask, or if it's multi-dimensional.
        if mask_multidim_data and ( len( data.shape ) > len( self.data_object.base_data_shape ) ):
            # Apply the shared mask to each component and compress it.
            # ( The original additionally called .compressed() on the plain
            # ndarray built from the parts when compress=True, which was an
            # AttributeError; multidimensional output is always per-part
            # compressed. )
            data_ma = [ array_to_ma_array_fn( data_part, mask=used_mask ).compressed() for data_part in data ]
            data_ma = np.array( data_ma )
        else:
            data_ma = array_to_ma_array_fn( data, mask=used_mask )
            if compress:
                data_ma = data_ma.compressed()

        return data_ma

    ########################################################################

    def clear_masks( self, clear_optional_masks=False ):
        '''Reset the masks in total to nothing.

        Args:
            clear_optional_masks (bool) : If True, also empty self.optional_masks.

        Modifies:
            self.masks (list) : Sets to empty.
            self.optional_masks (dict) : Sets to empty, if requested.
        '''
        self.masks = []

        if clear_optional_masks:
            self.optional_masks = {}