# Source code for galaxy_dive.analyze_data.simulation_data

#!/usr/bin/env python
'''Class for analyzing simulation data.

@author: Zach Hafen
@contact: zachary.h.hafen@gmail.com
@status: Development
'''

# Base python imports
import copy
import numpy as np
import numpy.testing as npt
import pandas as pd
import scipy
import scipy.signal as signal

# Imports from my own stuff
import galaxy_dive.analyze_data.ahf as analyze_ahf
import galaxy_dive.read_data.ahf as read_ahf
import galaxy_dive.utils.astro as astro
import galaxy_dive.utils.constants as constants
import galaxy_dive.utils.data_operations as data_operations
import galaxy_dive.analyze_data.generic_data as generic_data
import galaxy_dive.utils.utilities as utilities

########################################################################
########################################################################


class SimulationData( generic_data.GenericData ):
    '''Class for handling simulation data.

    Args:
        data_dir (str) :
             Directory the simulation is contained in.

        halo_data_dir (str) :
             Directory simulation halo catalogs are contained in.
            Defaults to data_dir

        ahf_index (str) :
             What to index the snapshots by.
             Should be the last snapshot in the simulation *if*
             AHF was run backwards from the last snapshot.
             Required to put in manually to avoid easy mistakes.

        averaging_frac (float) :
            There are some averaged quantities (e.g. COM velocity, total
            angular momentum) that we consider. This is the fraction of
            length_scale_used within which we calculate these.

        length_scale_used (str) :
             What length scale to use for the simulation.
             Will be used to put lengths in fractions.

        z_sun (float) :
            Used mass fraction for solar metallicity.

        halo_data_retrieved (bool) :
            Whether or not we retrieved relevant values from the AHF halo data.

        centered (bool):
            Whether or not the coordinates are centered on the galaxy of
            choice at the start.

        vel_centered (bool) :
            Whether or not the velocities are relative to the galaxy of
            choice at the start.

        hubble_corrected (bool) :
            Whether or not the velocities have had the Hubble flow added
            (velocities must be centered).

        ahf_tag (str) :
            Identifying tag for the ahf merger tree halo files,
            looks for ahf files of type 'halo_00000_{}.dat'.format( tag ).

        main_halo_id (int) :
             What is the halo ID of the main galaxy in the simulation?

        center_method (str or np.array) :
            How to center the coordinates. Options are...

            'halo' (default) : Centers the dataset on the main halo
                (main_halo_id) using AHF halo data.

            np.array : Array of coordinates on which to center the data

        vel_center_method (str or np.array of size 3) :
            How to center the velocity coordinates, i.e. what the
            velocity is relative to. Options are...

            'halo' (default) : Sets velocity relative to the main halo
                (main_halo_id) using AHF halo data.

            np.array of size 3 :
                Centers the dataset on this coordinate.

    Keyword Args:
        function_args (dict):
            Dictionary of args used to specify an arbitrary function with
            which to generate data.
    '''

    @utilities.store_parameters
    def __init__(
        self,
        data_dir = None,
        halo_data_dir = None,
        ahf_index = None,
        averaging_frac = 4.,
        length_scale_used = 'Rstar0.5',
        z_sun = constants.Z_MASSFRAC_SUN,
        halo_data_retrieved = False,
        centered = False,
        vel_centered = False,
        hubble_corrected = False,
        ahf_tag = 'smooth',
        main_halo_id = 0,
        center_method = 'halo',
        vel_center_method = 'halo',
        store_ahf_reader = False,
        **kwargs
    ):
        '''Validate the instance attributes and initialize the parent class.

        The loop below inspects the attributes already present on the
        instance (presumably placed there by utilities.store_parameters --
        confirm against that decorator) and rejects any that are still None,
        with two exceptions: halo_data_dir falls back to data_dir, and
        ahf_index may stay None.

        Raises:
            Exception : If any other attribute is None.
        '''

        for attr in vars( self ).keys():

            # Nothing to check for the catch-all kwargs or for anything with a value.
            if attr == 'kwargs' or getattr( self, attr ) is not None:
                continue

            # The AHF index is allowed to remain unspecified.
            if attr == 'ahf_index':
                continue

            # Look for the halo catalogs alongside the simulation data by default.
            if attr == 'halo_data_dir':
                self.halo_data_dir = self.data_dir
                continue

            raise Exception( '{} not specified'.format( attr ) )

        # A freshly created instance has, by definition, not retrieved any halo data yet.
        self.halo_data_retrieved = False

        super( SimulationData, self ).__init__( z_sun=z_sun, **kwargs )

    ########################################################################
    # Properties
    ########################################################################

    @property
    def length_scale( self ):
        '''Property for fiducial simulation length scale.'''

        if self.length_scale_used == 'R_vir':
            return self.r_vir
        elif self.length_scale_used == 'r_scale':
            return self.r_scale
        else:
            return self.halo_data.get_mt_data(
                self.length_scale_used,
                snums = self.snum,
                return_values_only = False, # This is because we're only getting one value out
                a_power = 1.,
            ) / self.data_attrs['hubble']

    ########################################################################

    @property
    def velocity_scale( self ):
        '''Property for fiducial simulation velocity scale.'''

        return self.v_c

    ########################################################################

    @property
    def base_data_shape( self ):
        '''Property for simulation redshift.'''

        if not hasattr( self, '_base_data_shape' ):

            # Use Density as the default data we assume will usually be there.
            if 'Den' in self.data:
                self._base_data_shape = self.data['Den'].shape
            # If it doesn't have density, it might have mass
            elif 'M' in self.data:
                self._base_data_shape = self.data['M'].shape
            else:
                raise Exception( "No data key to base shape off of." )

        return self._base_data_shape

    @base_data_shape.setter
    def base_data_shape( self, value ):
        '''Setting function for simulation redshift property.'''

        # If we try to set it, make sure that if it already exists we don't change it.
        if hasattr( self, '_base_data_shape' ):
            assert self._base_data_shape == value

        else:
            self._base_data_shape = value

    ########################################################################

    @property
    def halo_data( self ):
        '''Halo data used.
        '''

        if not hasattr( self, '_halo_data' ):

            self._halo_data = analyze_ahf.HaloData(
                self.halo_data_dir,
                tag = self.ahf_tag,
                index = self.ahf_index
            )

        return self._halo_data

    ########################################################################

    @property
    def redshift( self ):
        '''Property for simulation redshift.'''

        if not hasattr( self, '_redshift' ):

            # Try to get it from the attributes.
            if 'redshift' in self.data_attrs:
                self._redshift = self.data_attrs['redshift']
            elif hasattr( self, 'data' ):
                if 'redshift' in self.data:
                    self._redshift = self.data['redshift']
            # If not, retrieve halo data, which should set it.
            # In fact, if we call self.retrieve_halo_data() somewhere else and we already set redshift by getting it from
            # the attributes, it will check that it matches.
            else:
                self.retrieve_halo_data()

        return self._redshift

    @redshift.setter
    def redshift( self, value ):
        '''Setting function for simulation redshift property.'''

        # If we try to set it, make sure that if it already exists we don't change it.
        if hasattr( self, '_redshift' ):

            if isinstance( value, np.ndarray ) or isinstance( self._redshift, np.ndarray ):

                is_nan = np.any( [ np.isnan( value ), np.isnan( self._redshift ) ], axis=1 )
                not_nan_inds = np.where( np.invert( is_nan ) )[0]

                test_value = np.array(value)[not_nan_inds]  # Cast as np.ndarray because Pandas arrays can cause trouble.
                test_existing_value = np.array(self._redshift)[not_nan_inds]
                npt.assert_allclose( test_value, test_existing_value, atol=1e-5 )

                self._redshift = value

            else:
                npt.assert_allclose( value, self._redshift, atol=1e-5 )

        else:
            self._redshift = value

    ########################################################################

    @property
    def r_vir( self ):
        '''Property for virial radius.'''

        if not hasattr( self, '_r_vir' ):
            self.retrieve_halo_data()

        return self._r_vir

    @r_vir.setter
    def r_vir( self, value ):

        # If we try to set it, make sure that if it already exists we don't change it.
        if hasattr( self, '_r_vir' ):
            npt.assert_allclose( value, self._r_vir )

        else:
            self._r_vir = value

    ########################################################################

    @property
    def r_scale( self ):
        '''Property for scale radius.'''

        if not hasattr( self, '_r_scale' ):
            self.retrieve_halo_data()

        return self._r_scale

    @r_scale.setter
    def r_scale( self, value ):

        # If we try to set it, make sure that if it already exists we don't change it.
        if hasattr( self, '_r_scale' ):
            npt.assert_allclose( value, self._r_scale )

        else:
            self._r_scale = value

    ########################################################################

    @property
    def v_c( self ):
        '''Property for circular velocity.'''

        if not hasattr( self, '_v_c' ):
            self.retrieve_halo_data()

        return self._v_c

    @v_c.setter
    def v_c( self, value ):

        # If we try to set it, make sure that if it already exists we don't change it.
        if hasattr( self, '_v_c' ):
            npt.assert_allclose( value, self._v_c )

        else:
            self._v_c = value

    ########################################################################

    @property
    def hubble_z( self ):
        '''Property for the hubble function at specified redshift.'''

        if not hasattr( self, '_hubble_z' ):
            self._hubble_z = astro.hubble_parameter(
                self.redshift,
                h=self.data_attrs['hubble'],
                omega_matter=self.data_attrs['omega_matter'],
                omega_lambda=self.data_attrs['omega_lambda'],
                units='km/s/kpc'
            )

        return self._hubble_z

    ########################################################################
    # Overall changes to the data
    ########################################################################

    def center_coords( self ):
        '''Change the location of the origin, if the data isn't already centered.

        Modifies:
            self.data['P'] : Shifts the coordinates to the center.
        '''

        if self.centered:
            return

        if isinstance( self.center_method, np.ndarray ):
            self.origin = copy.copy( self.center_method )

        elif self.center_method == 'halo':
            self.retrieve_halo_data()
            self.origin = copy.copy( self.halo_coords )

        else:
            raise KeyError( "Unrecognized center_method, {}".format( self.center_method ) )

        # Do it like this because we don't know the shape of self.data['P'][0]
        for i in range( 3 ):
            self.data['P'][i] -= self.origin[i]

        # Note that we're now centered
        self.centered = True

    ########################################################################

    def center_vel_coords( self ):
        '''Get velocity coordinates to center on the main halo.

        Modifies:
            self.data['V'] : Makes all velocities relative to self.vel_origin
        '''

        if self.vel_centered:
            return

        if isinstance( self.vel_center_method, np.ndarray ):
            self.vel_origin = copy.copy( self.vel_center_method )

        elif self.vel_center_method == 'halo':
            self.retrieve_halo_data()
            self.vel_origin = copy.copy( self.halo_velocity )

        else:
            raise KeyError( "Unrecognized vel_center_method, {}".format( self.vel_center_method ) )

        # Do it like this because we don't know the shape of self.data['V'][0]
        for i in range( 3 ):
            self.data['V'][i] -= self.vel_origin[i]

        self.vel_centered = True

    ########################################################################

    def add_hubble_flow( self ):
        '''Correct for hubble flow movement.

        Modifies:
            self.data['V'] : Accounts for hubble flow, relative to origin
        '''

        if self.hubble_corrected:
            return

        self.center_vel_coords()

        # Handle weird formatting that happens when using data frames.
        if isinstance( self.hubble_z, pd.Series ):
            self.data['V'] += self.get_data( 'P' ) * self.hubble_z.values
        else:
            self.data['V'] += self.get_data( 'P' ) * self.hubble_z

        self.hubble_corrected = True

    ########################################################################
    # Get Data
    ########################################################################

    def get_data( self, data_key, sl=None ):
        '''Get the data from within the class. Only for getting the data. No post-processing or changing the data
        (putting it in particular units, etc.) The idea is to calculate necessary quantities as the need arises,
        hence a whole function for getting the data.

        Args:
            data_key (str) : Key in the data dictionary for the key we want to get
            sl (slice) : Slice of the data, if requested.

        Returns:
            data (np.ndarray) : Requested data.
        '''

        # Loop through, handling issues
        n_tries = 10
        for i in range( n_tries ):
            try:

                # Positions
                if self.key_parser.is_position_key( data_key ):
                    data = self.get_position_data( data_key )

                # Velocities
                elif self.key_parser.is_velocity_key( data_key ):

                    data = self.get_velocity_data( data_key )

                # Arbitrary functions of the data
                elif data_key == 'Function':

                    raise Exception( "TODO: Test this" )

                    # Use the keys to get the data
                    function_data_keys = self.kwargs['function_args']['function_data_keys']
                    input_data = [ self.get_data(function_data_key) for function_data_key in function_data_keys ]

                    # Apply the function
                    data = self.kwargs['function_args']['function']( input_data )

                # Other
                else:
                    data = self.data[data_key]

            # Calculate missing data
            except KeyError as e:
                self.handle_data_key_error( data_key )
                continue

            break

        if 'data' not in locals().keys():
            raise KeyError( "After {} tries, unable to find or create data_key, {}".format( i + 1, data_key ) )

        if sl is not None:
            return data[sl]

        return data

    ########################################################################

    def get_position_data( self, data_key ):
        '''Get position data (assuming the data starts with an 'R')

        Args:
            data_key (str) : Key in the data dictionary for the key we want to get

        Returns:
            data (np.ndarray) : Requested data.
        '''

        self.center_coords()

        # Transpose in order to account for when the data isn't regularly shaped
        if data_key == 'Rx':
            data = self.data['P'][0, :]
        elif data_key == 'Ry':
            data = self.data['P'][1, :]
        elif data_key == 'Rz':
            data = self.data['P'][2, :]
        else:
            data = self.data[data_key]

        return data

    ########################################################################

    def get_velocity_data( self, data_key ):
        '''Get position data (assuming the data starts with a 'V')

        Args:
            data_key (str) : Key in the data dictionary for the key we want to get

        Returns:
            data (np.ndarray) : Requested data.
        '''

        self.center_vel_coords()
        self.add_hubble_flow()

        # Get data
        if data_key == 'Vx':
            data = self.data['V'][0, :]
        elif data_key == 'Vy':
            data = self.data['V'][1, :]
        elif data_key == 'Vz':
            data = self.data['V'][2, :]
        else:
            data = self.data[data_key]

        return data

    ########################################################################

    def handle_data_key_error( self, data_key ):
        '''When get_data() fails to data_key in self.data, it passes the data_key to try and generate that data.

        Args:
            data_key (str) : Key to try and generate data for

        Modifies:
            self.data[data_key] (np.array) : If it finds a function to generate the data, it will do so
        '''

        if self.verbose:
            print( 'Data key {} not found in data. Attempting to calculate.'.format( data_key ) )

        method_str = 'calc_{}'.format( data_key )
        if hasattr( self, method_str ):
            getattr( self, method_str )()

        # SimulationData methods
        elif data_key == 'R':
            self.calc_radial_distance()
        elif data_key == 'Vmag':
            self.calc_velocity_magnitude()
        elif data_key == 'Vr':
            self.calc_radial_velocity()
        elif data_key == 'Vtan':
            self.calc_tangential_velocity()
        elif data_key == 'ind':
            self.calc_inds()
        elif data_key == 'L':
            self.calc_ang_momentum()
        elif data_key == 'Phi':
            self.calc_phi()
        elif data_key == 'AbsPhi':
            self.calc_abs_phi()
        elif data_key == 'NumDen':
            self.calc_num_den()
        elif data_key == 'HDen':
            self.calc_H_den()
        elif data_key == 'HIDen':
            self.calc_HI_den()

        # TODO: Move these to the subclasses somehow.
        # Subclass methods
        elif data_key == 'T':
            # TODO: This is a hacky fix, that should be changed...
            if hasattr( self, 'calc_temp' ):
                self.calc_temp()
        elif data_key == 'Pressure':
            self.calc_pressure()

        else:
            raise KeyError( 'NULL data_key, data_key = {}'.format( data_key ) )

    ########################################################################

    def get_distance_to_point( self, point ):
        '''This is *not* unit tested.

        Args:
            point (array-like of shape (3,)) : The point you want to find the distance to for each particle.
        '''

        d_to_point = scipy.spatial.distance.cdist( self.get_data( 'P' ).transpose(), point[np.newaxis, :] )

        return d_to_point.flatten()

    ########################################################################

    def get_potential( self, point ):
        '''Total potential at a point. This is *not* unit tested.

        Sums -1. * constants.UNITG_UNIV * M / d over all particles, where d
        is the distance from each particle to the point.

        Args:
            point (array-like of shape (3,)) : The point you want to find the potential at

        Returns:
            total_potential : The summed potential.
        '''

        distances = self.get_distance_to_point( point )
        masses = self.get_data( 'M' )

        per_particle = -1. * constants.UNITG_UNIV * masses / distances

        return per_particle.sum()

    ########################################################################
    # Full calculations based on the data
    ########################################################################

    def calc_radial_distance( self ):
        '''Calculate the distance from the origin for a given particle.'''

        self.data['R'] = np.sqrt(
            self.get_data( 'Rx' )**2. + \
            self.get_data( 'Ry' )**2. + \
            self.get_data( 'Rz' )**2.
        )

        return self.data['R']

    ########################################################################

    def calc_velocity_magnitude( self ):
        '''Calculate the velocity relative to an origin velocity
        for a given particle.'''

        self.data['Vmag'] = np.sqrt(
            self.get_data( 'Vx' )**2. + \
            self.get_data( 'Vy' )**2. + \
            self.get_data( 'Vz' )**2.
        )

        return self.data['Vmag']

    ########################################################################

    def calc_rho_xy( self ):
        '''Calculate impact parameter in the xy-plane.
        '''

        self.data['rho_xy'] = np.sqrt(
            self.get_data( 'Rx' )**2. + \
            self.get_data( 'Ry' )**2.
        )

        return self.data['rho_xy']

    ########################################################################

    def calc_rho_xz( self ):
        '''Calculate impact parameter in the xz-plane.
        '''

        self.data['rho_xz'] = np.sqrt(
            self.get_data( 'Rx' )**2. + \
            self.get_data( 'Rz' )**2.
        )

        return self.data['rho_xz']

    ########################################################################

    def calc_rho_yz( self ):
        '''Calculate impact parameter in the yz-plane.
        '''

        self.data['rho_yz'] = np.sqrt(
            self.get_data( 'Ry' )**2. + \
            self.get_data( 'Rz' )**2.
        )

        return self.data['rho_yz']

########################################################################
########################################################################


class SnapshotData( SimulationData ):
    '''Class for analysis of a single snapshot of data.'''

    def __init__( self, snum, *args, **kwargs ):
        '''Store the snapshot number, then initialize as SimulationData.

        Args:
            snum (int or array of ints) : Snapshot or snapshots to inspect.

            *args, **kwargs : Passed on unchanged to SimulationData.__init__.
        '''

        # The snapshot number is stored before the parent class is initialized.
        self.snum = snum

        super( SnapshotData, self ).__init__( *args, **kwargs )

    ########################################################################
    # Get Additional Data
    ########################################################################

    def retrieve_halo_data( self ):

        if self.halo_data_retrieved:
            return

        # Load the AHF data
        ahf_reader = read_ahf.AHFReader( self.halo_data_dir )
        ahf_reader.get_mtree_halos( index=self.ahf_index, tag=self.ahf_tag )

        # Select the main halo at the right redshift
        mtree_halo = ahf_reader.mtree_halos[self.main_halo_id].loc[self.snum]

        # Add the halo data to the class.
        self.redshift = mtree_halo['redshift']
        halo_coords_comoving = np.array( [ mtree_halo['Xc'], mtree_halo['Yc'], mtree_halo['Zc'] ] )
        self.halo_coords = halo_coords_comoving / (1. + self.redshift) / self.data_attrs['hubble']
        self.halo_velocity = np.array( [ mtree_halo['VXc'], mtree_halo['VYc'], mtree_halo['VZc'] ] )
        self.r_vir = mtree_halo['Rvir'] / (1. + self.redshift) / self.data_attrs['hubble']
        self.r_scale = self.r_vir / mtree_halo['cAnalytic']
        self.m_vir = mtree_halo['Mvir'] / self.data_attrs['hubble']
        self.m_gas = mtree_halo['M_gas'] / self.data_attrs['hubble']
        self.m_star = mtree_halo['M_star'] / self.data_attrs['hubble']

        # Calculate the circular velocity
        self.v_c = astro.circular_velocity( self.r_vir, self.m_vir )

        self.halo_data_retrieved = True

        if self.store_ahf_reader:
            self.ahf_reader = ahf_reader

    ########################################################################
    # Properties
    ########################################################################

    @property
    def central_mask( self ):
        '''This mask is used when, for example, finding the velocity of the
        center of mass.
        '''

        if not hasattr( self, '_central_mask' ):

            self._central_mask  = self.data_masker.mask_data(
                'Rf',
                0.,
                self.averaging_frac,
                return_or_store = 'return',
            )

        return self._central_mask

    ########################################################################

    @property
    def v_com( self ):
        '''Property for the velocity of the center of mass.'''

        if not hasattr( self, '_v_com' ):

            m_ma = self.get_selected_data( 'M', self.central_mask )
            v_ma = self.get_selected_data( 'V', self.central_mask )

            self._v_com = ( v_ma * m_ma ).sum( 1 ) / m_ma.sum()

        return self._v_com

    ########################################################################

    @property
    def total_ang_momentum(self):
        '''Calculate the total angular momentum vector.'''

        # Exit early if already calculated.
        if not hasattr( self, '_total_ang_momentum' ):

            # Adapt for application to 'l', which is a multidimensional array
            inner_mask = np.array([ self.central_mask ] * 3)

            # Apply masks
            ang_momentum = self.get_data('L')
            l_ma = np.ma.masked_array(ang_momentum, mask=inner_mask)

            # Get the total angular momentum
            self._total_ang_momentum = np.zeros(3)
            for i in range(3):
                self._total_ang_momentum[i] = l_ma[i].sum()

        return self._total_ang_momentum

    ########################################################################

    @property
    def dN_halo(self, R_vir=None, time_units='abs_length'):
        '''Calculate dN_halo/dX/dlog10Mh or dN_halo/dz/dlog10Mh. X is absorption path length (see for example Ribaudo+11)

        Not implemented: accessing this property always raises.

        time_units: 'abs_length'- dN_halo/dX/dlog10Mh
                                'redshift' - dN_halo/dz/dlog10Mh
        R_vir:      None - Default, assumes given.
                                'BN' Calculates R_vir from the mass and redshift

        Raises:
            Exception : Always.
        '''

        # NOTE(review): this is a property, so R_vir and time_units can never be passed by a caller
        # and will always take their default values. A regular method is probably what is wanted
        # once the code below is revived.
        raise Exception( "TODO: Test this" )

        # Previous code that needs to be cleaned up below.

        # # Choose cosmology
        # cosmo = Cosmology.setCosmology('WMAP9')

        # # Calculate the virial radius, if necessary
        # if R_vir is not None:
        #     if R_vir == 'BN':
        #         # Calculate R_vir off of the cosmocode def, and convert to proper kpc.
        #         self.R_vir = cosmo.virialRadius(M, z) * 10.**3.

        # # Make h easier to use (don't have to write the whole thing out...)
        # h = self.data_attrs['hubble']

        # Mh = h * self.M_vir  # Convert the halo mass to Msun/h, so as to feed it into the HMF.
        # dldz = np.abs(cosmo.line_elt(self.redshift))  # Cosmological line element in Mpc.
        # dldz_kpc = dldz * 10.**3.
        # dndlog10M = cosmo.HMF(Mh, self.redshift) * self.M_vir * np.log(10) * h**4.  # HMF in 1/Mpc^3
        # dndlog10M_kpc = dndlog10M * 10.**-9.
        # dN_halo = dldz_kpc * dndlog10M_kpc * np.pi * self.R_vir**2.

        # # Convert from per redshift to per absorption path length.
        # if time_units == 'abs_length':
        #     dN_halo /= cosmo.dXdz(z)
        # elif time_units == 'redshift':
        #     pass

        # return dN_halo

    ########################################################################
    # Full calculations of the data
    ########################################################################

    def calc_velocity_magnitude(self):
        '''Calculate the radial velocity.'''

        # Center velocity and radius
        self.center_coords()
        self.center_vel_coords()

        # Calculate the radial velocity
        self.data['Vmag'] = np.sqrt(
            self.get_data( 'Vx' )**2. + \
            self.get_data( 'Vy' )**2. + \
            self.get_data( 'Vz' )**2.
        )

    ########################################################################

    def calc_radial_velocity(self):
        '''Calculate the radial velocity.'''

        # Center velocity and radius
        self.center_coords()
        self.center_vel_coords()

        # Calculate the radial velocity
        self.data['Vr'] = ( self.get_data( 'V' ) * self.get_data('P')).sum(0) / self.get_data('R')

    ########################################################################

    def calc_tangential_velocity(self):
        '''Calculate the radial velocity.'''

        # Center velocity and radius
        self.center_coords()
        self.center_vel_coords()

        # Calculate the radial velocity
        self.data['Vtan'] = np.sqrt( self.get_data( 'Vmag' )**2. - self.get_data( 'Vr' )**2. )

    ########################################################################

    def calc_inds(self):
        '''Calculate the indices the data are located at, prior to any masks.

        Not usable yet: always raises before doing anything. The code after
        the raise is untested and kept for reference only.

        Raises:
            Exception : Always.
        '''

        raise Exception( "TODO: Test this" )

        # Flattened index array, one index per element of the density data
        flat_inds = np.arange(self.get_data('Den').size)

        # Put into a multidimensional array with the same shape as the density data
        self.data['ind'] = flat_inds.reshape(self.get_data('Den').shape)

    ########################################################################

    def calc_ang_momentum( self ):
        '''Calculate the angular momentum.'''

        m_mult = np.array( [ self.get_data('M'), ] * 3 )

        p = self.get_data('P')
        v = self.get_data('V')

        l = np.cross( p, v, 0, 0).transpose()
        l *= m_mult

        self.data['L'] = l

    ########################################################################

    def calc_phi( self, normal_vector='total ang momentum' ):
        '''Calculate the angle (in degrees) from some vector.
        By default the vector is the total angular momentum.
        '''

        raise Exception( "TODO: Test this" )

        if vector == 'total ang momentum':
            # Calculate the total angular momentum vector, if it's not calculated yet
            self.normal_vector = self.calc_total_ang_momentum()
        else:
            self.normal_vector = normal_vector

        # Get the dot product
        P = self.get_data('P')
        dot_product = np.zeros(P[0, :].shape)
        for i in range(3):
            dot_product += self.v[i] * P[i, :]

        # Isolate for the cosine
        cos_phi = dot_product / self.get_data('R') / np.linalg.norm(self.v)

        # Get the angle (in degrees)
        self.data['Phi'] = np.arccos(cos_phi) * 180. / np.pi

    ########################################################################

    def calc_abs_phi(self, vector='total gas ang momentum'):
        '''The angle (in degrees) from some vector, with angles of 90 degrees
        and above folded back below 90 degrees (e.g. 135 -> 45 degrees,
        180 -> 0 degrees).

        Not usable yet: always raises before doing anything. The code after
        the raise is untested and kept for reference only.

        Args:
            vector : Passed on to self.calc_phi().

        Raises:
            Exception : Always.
        '''

        raise Exception( "TODO: Test this" )

        # Get the original Phi
        # NOTE(review): the default 'total gas ang momentum' is not the string calc_phi() checks for
        # ('total ang momentum'), so the default would be treated as the vector itself -- confirm.
        self.calc_phi(vector)

        # Keep angles below 90 degrees as they are, and replace the others by their distance to 180 degrees
        self.data['AbsPhi'] = np.where(self.data['Phi'] < 90., self.data['Phi'], np.absolute(self.data['Phi'] - 180.))

    ########################################################################

    def calc_num_den(self):
        '''Calculate the number density (it's just a simple conversion...).

        Modifies:
            self.data['NumDen'] : self.data['Den'] times
                constants.UNITDENSITY_IN_NUMDEN.
        '''

        density = self.data['Den']

        self.data['NumDen'] = density * constants.UNITDENSITY_IN_NUMDEN

    ########################################################################

    def calc_H_den(self):
        '''Calculate the H density in cgs (cm^-3). Assume the H fraction is ~0.75

        Modifies:
            self.data['HDen'] : 0.75 times self.data['Den'] times
                constants.UNITDENSITY_IN_NUMDEN.
        '''

        # Assume Hydrogen makes up 75% of the gas
        hydrogen_fraction = 0.75

        hydrogen_density = hydrogen_fraction * self.data['Den']

        self.data['HDen'] = hydrogen_density * constants.UNITDENSITY_IN_NUMDEN

    ########################################################################

    def calc_HI_den(self):
        '''Calculate the HI density in cgs (cm^-3).

        Not usable yet: always raises before doing anything. The code after
        the raise is untested and kept for reference only.

        Raises:
            Exception : Always.
        '''

        raise Exception( "TODO: Test this" )

        # Assume Hydrogen makes up 75% of the gas
        X_H = 0.75

        # Calculate the hydrogen density
        # NOTE(review): calc_H_den() does this conversion with constants.UNITDENSITY_IN_NUMDEN;
        # confirm that constants.gas_den_to_nb still exists before enabling this.
        HDen = X_H * self.data['Den'] * constants.gas_den_to_nb

        # presumably self.data['nHI'] is the neutral hydrogen fraction -- TODO confirm
        self.data['HIDen'] = self.data['nHI'] * HDen

    ########################################################################
    # Non-Altering Calculations
    ########################################################################

    def dist_to_point(self, point, units='default'):
        '''Calculate the distance to a point for all particles.

        Not usable yet: always raises before doing anything. The code after
        the raise is untested and kept for reference only.

        Args:
            point : np array that gives the point
            units (str) : 'default' to leave the distances as they are, or
                'h' to divide them by self.get_data('h').

        Raises:
            Exception : Always.
        '''

        # NOTE(review): SimulationData.get_distance_to_point() already computes distances with
        # scipy's cdist; consider whether this method is still needed.
        raise Exception( "TODO: Test this/update to use scipy cdist" )

        # Calculate the distance to the point
        relative_positions = self.get_data('P').transpose() - np.array(point)
        d_mag = np.linalg.norm(relative_positions, axis=1)

        # Put in different units if necessary
        if units == 'default':
            pass
        elif units == 'h':
            # presumably 'h' is a per-particle length (e.g. a smoothing length) -- TODO confirm
            d_mag /= self.get_data('h')
        else:
            raise Exception('Null units, units = {}'.format(units))

        return d_mag

    ########################################################################

    def calc_mu(self):
        '''Calculate the mean molecular weight. '''

        y_helium = self.data['Z_Species'][:, 0]  # Get the mass fraction of helium
        mu = 1. / (1. - 0.75 * y_helium + self.data['ne'])

        return mu

########################################################################
########################################################################

class TimeData( SimulationData ):
    '''Class for analysis of a time series data, e.g. the worldlines of a number of particles.'''

    def __init__(
        self,
        data_masker = None,
        *args, **kwargs
    ):
        '''Set up the time-series data object.

        Args:
            data_masker :
                Masker to use for the data. When None, a TimeDataMasker
                bound to this instance is created.

            *args, **kwargs :
                Passed unchanged to the parent class constructor.
        '''

        # Default to the masker written for worldline data
        masker = TimeDataMasker( self ) if data_masker is None else data_masker

        super( TimeData, self ).__init__(
            data_masker = masker,
            *args,
            **kwargs
        )

    ########################################################################

    @property
    def n_particles( self ):
        '''Number of particles, i.e. the length of the first axis of
        self.base_data_shape. The value is cached after the first access.'''

        try:
            return self._n_particles
        except AttributeError:
            # First access: read it off the base data shape and cache it
            self._n_particles = self.base_data_shape[0]

        return self._n_particles

    ########################################################################

    @property
    def n_snaps( self ):
        '''Number of snapshots, i.e. the length of the second axis of
        self.base_data_shape (the time axis). The value is cached after
        the first access.'''

        try:
            return self._n_snaps
        except AttributeError:
            # First access: read it off the base data shape and cache it
            self._n_snaps = self.base_data_shape[1]

        return self._n_snaps

    ########################################################################

    def retrieve_halo_data( self ):
        '''Load the merger-tree data of the main halo and store the
        quantities used elsewhere as attributes.

        Does nothing if self.halo_data_retrieved is already True.

        Modifies:
            self.redshift, self.halo_coords, self.halo_velocity, self.r_vir,
            self.r_scale, self.m_vir, self.m_gas, self.m_star, self.v_c :
                Main halo properties at the snapshots in self.snums.
                Positions and Rvir are multiplied by 1/(1+z)/h, masses are
                divided by h, and the velocity is stored as read.

            self.halo_data_retrieved (bool) : Set to True.

            self.ahf_reader :
                The reader used, only kept if self.store_ahf_reader is set.
        '''

        # Nothing to do when the halo data was already loaded
        if self.halo_data_retrieved:
            return

        # Read the merger tree halo files
        reader = read_ahf.AHFReader( self.halo_data_dir )
        reader.get_mtree_halos( index=self.ahf_index, tag=self.ahf_tag )

        # Rows of the main halo at the snapshots we hold data for
        main_halo = reader.mtree_halos[self.main_halo_id].loc[self.snums]

        h = self.hubble_param

        self.redshift = main_halo['redshift']
        scale_factor_and_hinv = 1. / (1. + self.redshift) / h

        # Halo center and velocity, one row per spatial component
        comoving_coords = np.array(
            [ main_halo[column] for column in ( 'Xc', 'Yc', 'Zc' ) ]
        )
        self.halo_coords = comoving_coords * scale_factor_and_hinv[np.newaxis, :]
        self.halo_velocity = np.array(
            [ main_halo[column] for column in ( 'VXc', 'VYc', 'VZc' ) ]
        )

        # Length scales
        self.r_vir = main_halo['Rvir'] * scale_factor_and_hinv
        self.r_scale = self.r_vir / main_halo['cAnalytic']

        # Masses
        self.m_vir = main_halo['Mvir'] / h
        self.m_gas = main_halo['M_gas'] / h
        self.m_star = main_halo['M_star'] / h

        # Circular velocity from the virial radius and mass
        self.v_c = astro.circular_velocity( self.r_vir, self.m_vir )

        self.halo_data_retrieved = True

        # Optionally hold on to the reader for later use
        if self.store_ahf_reader:
            self.ahf_reader = reader

    ########################################################################

    def calc_radial_velocity( self ):
        '''Calculate the radial velocity of every particle at every snapshot.

        The radial velocity is the velocity component along the position
        vector, v_r = ( v . p ) / r, evaluated for all particles and
        snapshots in one vectorized operation.

        Modifies:
            self.data['Vr'] (np.ndarray) :
                Radial velocity, with the same shape as self.get_data('R'),
                i.e. (n_particles, n_snaps).

        Raises:
            AssertionError : If self.hubble_corrected is not set.
        '''

        # 'V' and 'P' are indexed as (component, particle, snapshot) and
        # 'R' as (particle, snapshot)
        v_all = self.get_data( 'V' )
        p_all = self.get_data( 'P' )
        r_all = self.get_data( 'R' )

        # Per the original author's note, the hubble flow does not need to
        # be added here because that is already done when the velocity data
        # is retrieved; this only checks that it has happened.
        assert self.hubble_corrected

        # Dot product of velocity and position over the component axis,
        # projected onto the radial direction
        v_dot_p = ( v_all * p_all ).sum( axis=0 )

        self.data['Vr'] = v_dot_p / r_all

    ########################################################################

    def handle_data_key_error( self, data_key ):
        '''Try to generate data for a key that is missing from the data.

        The options are tried in this order, stopping at the first one
        that applies:
            1. A method named 'calc_<data_key>' on this object.
            2. self.calc_time_as_classification( data_key )
            3. self.calc_time_until_not_classification( data_key )
            4. The parent class handle_data_key_error.

        Args:
            data_key (str) : The key that could not be found.
        '''

        # A dedicated calc_* method takes precedence
        calc_method_name = 'calc_{}'.format( data_key )
        if hasattr( self, calc_method_name ):
            getattr( self, calc_method_name )()
            return

        # The key-parsing calculations report whether they handled the key
        if self.calc_time_as_classification( data_key ):
            return

        if self.calc_time_until_not_classification( data_key ):
            return

        # Nothing here applies, defer to the parent class
        super( TimeData, self ).handle_data_key_error( data_key )

    ########################################################################

    def get_processed_data(
        self,
        data_key,
        sl = None,
        smooth_data = False,
        smoothing_window_length = 9,
        smoothing_polyorder = 3,
        a_power = None,
        scale_key = None,
        scale_a_power = None,
        scale_h_power = None,
        tile_data = False,
        tile_dim = 'auto',
        *args, **kwargs
    ):
        '''Modified method for getting processed data. For the most part it is
        equivalent to calling the method of the parent class, but it is also
        capable of smoothing the data, scaling the retrieved data by a column
        from the halo data, and tiling the data to a 2D shape.

        The steps are applied in this order: retrieve, smooth, scale, tile.

        Args:
            data_key (str) : What to get out?

            sl (slice) :
                Slice of the data, if requested. When tile_data is True the
                data is retrieved unsliced and the slice is applied at the
                very end, after tiling.

            smooth_data (bool) :
                If True, smooth the data with scipy.signal.savgol_filter.

            smoothing_window_length, smoothing_polyorder (int) :
                Arguments for how to smooth the data, passed to savgol_filter
                as window_length and polyorder.

            a_power :
                Accepted, but not used anywhere in this method and not
                passed on to the parent class.
                NOTE(review): confirm whether this was meant to be forwarded.

            scale_key (str) :
                Halo data entry by which to divide the data by. Retrieved
                with self.halo_data.get_mt_data for the main halo at the
                snapshots in self.snums.

            scale_a_power (float) :
                The halo data that we are scaling processed_data by will be
                multiplied by a to this power (forwarded to get_mt_data as
                a_power).
                Useful for data in cosmological units (as is often normal).

            scale_h_power (float) :
                The halo data that we are scaling processed_data by will be
                multiplied by the hubble parameter to this power. This is
                implemented by dividing processed_data by
                self.hubble_param**scale_h_power.
                Useful for data in cosmological units (as is often normal).

            tile_data (bool) :
                If True, tile data along a given direction. This is usually for
                data formatting purposes.

            tile_dim (str) :
                If the data is tiled, what dimension of the data should match?
                Any value other than the options below leaves the data
                untiled.
                Options:
                    'auto' :
                        Tiles according to data size: shape (n_particles,)
                        is treated as 'match_snaps', shape (n_snaps,) as
                        'match_particles', and data that already has
                        self.base_data_shape is left alone. Any other shape
                        raises an Exception.
                    'match_snaps' :
                        The data is repeated self.n_snaps times and
                        transposed, so 1D data ends up with shape
                        (data_size, self.n_snaps).
                    'match_particles' :
                        The data is repeated self.n_particles times, so 1D
                        data ends up with shape
                        (self.n_particles, data_size).

            *args, **kwargs :
                Passed to SimulationData.get_processed_data.

        Returns:
            processed_data (array-like) : Requested data array.

        Raises:
            Exception :
                If tile_dim is 'auto' and the data shape is not recognized.
            AssertionError :
                If the scaling data cannot be indexed with sl and sl[0] is
                not slice(None).
        '''

        # When tiling, the slice can only be applied to the tiled result,
        # so retrieve the data unsliced in that case.
        if ( sl is not None ) and tile_data:
            used_sl = None
        else:
            used_sl = sl

        processed_data = super( TimeData, self ).get_processed_data(
            data_key,
            sl = used_sl,
            *args, **kwargs
        )

        # Smoothing uses savgol_filter's default axis.
        if smooth_data:
            processed_data = signal.savgol_filter(
                processed_data,
                window_length = smoothing_window_length,
                polyorder = smoothing_polyorder,
            )

        if scale_key is not None:

            # Get the data
            data_to_div_by = self.halo_data.get_mt_data(
                scale_key,
                mt_halo_id = self.main_halo_id,
                a_power = scale_a_power,
                snums = self.snums
            )

            # Dividing the data by h**power is equivalent to multiplying the
            # halo data we divide by with h**power.
            # Note that both divisions in this branch are done in place.
            if scale_h_power is not None:
                processed_data /= self.hubble_param**scale_h_power

            # The original slice (not used_sl) is applied to the scaling
            # data, even when the data itself was retrieved unsliced.
            if sl is not None:
                try:
                    data_to_div_by = data_to_div_by[sl]
                # For when we're getting a slice of a single snapshot
                except IndexError:
                    assert sl[0] == slice(None)
                    data_to_div_by = data_to_div_by[sl[1]]

            processed_data /= data_to_div_by

        if tile_data:

            # Work out the tiling direction from the shape of the data.
            # The (n_particles,) check comes first, so it wins if
            # n_particles == n_snaps.
            if tile_dim == 'auto':
                if processed_data.shape == ( self.n_particles, ):
                    tile_dim = 'match_snaps'
                elif processed_data.shape == ( self.n_snaps, ):
                    tile_dim = 'match_particles'
                elif processed_data.shape == self.base_data_shape:
                    tile_dim = None
                else:
                    raise Exception(
                        "Unrecognized data shape, {}".format(
                            processed_data.shape
                        )
                    )

            # Repeat along a new snapshot axis, then put that axis last
            if tile_dim == 'match_snaps':
                processed_data = np.tile(
                    processed_data,
                    ( self.n_snaps, 1),
                ).transpose()

            # Repeat along a new particle axis (first axis)
            elif tile_dim == 'match_particles':
                processed_data = np.tile(
                    processed_data,
                    ( self.n_particles, 1),
                )

            # Now that the data is tiled the requested slice can be applied
            if sl is not None:
                processed_data = processed_data[sl]

        return processed_data

    ########################################################################

    def get_selected_data_over_time( self, *args, **kwargs ):
        '''Wrapper for getting masked data as a function of time.

        All arguments are forwarded to
        self.data_masker.get_selected_data_over_time, and its result is
        returned unchanged.
        '''

        masker = self.data_masker

        return masker.get_selected_data_over_time( *args, **kwargs )

    ########################################################################

    @property
    def snums( self ):
        '''Snapshot numbers of the data, read from self.data['snum'].'''

        # TODO: This is a workable structure for now, but it's not ideal.
        # This may not always be how we get the snum
        snapshot_numbers = self.data['snum']

        return snapshot_numbers

    ########################################################################

    @property
    def hubble_param( self ):
        '''Hubble parameter, read from self.data_attrs['hubble'].'''

        # TODO: This is a workable structure for now, but it's not ideal.
        # The hubble parameter may not always be here.
        hubble = self.data_attrs['hubble']

        return hubble

    ########################################################################

    def calc_ang_momentum( self ):
        '''The angular momentum (in the standard coordinates).

        Calculated as L = m * ( p x v ) for every particle at every snapshot
        in one vectorized operation.

        Modifies:
            self.data['L'] (np.ndarray):
                Angular momentum of each resolution element, as a float
                array indexed as (component, particle, snapshot).
        '''

        # 'M' is indexed as (particle, snapshot); 'P' and 'V' as
        # (component, particle, snapshot)
        mass = np.asarray( self.get_data('M') )
        p_all = self.get_data('P')
        v_all = self.get_data('V')

        # Cross product along the component axis for all particles and
        # snapshots at once. With axis=0 the result keeps the components on
        # the first axis.
        specific_l = np.cross( p_all, v_all, axis=0 )

        # The mass broadcasts over the component axis
        self.data['L'] = np.asarray(
            specific_l * mass[np.newaxis, ...],
            dtype = float,
        )

    ########################################################################

    def calc_phi( self, normal_vector='total ang momentum' ):
        '''Calculate the angle (in degrees) from some vector.
        By default the vector is the total angular momentum.

        The result is reused if 'Phi' has already been calculated with the
        same vector.

        Args:
            normal_vector (str or array-like):
                Vector that represents the vertical. The string
                'total ang momentum' selects self.total_ang_momentum;
                anything else is used as the vector itself.

        Modifies:
            self.normal_vector :
                The vector the angle was calculated from.

            self.data['Phi'] (np.ndarray):
                Angle between each position vector and the normal vector,
                in degrees, with the same shape as self.get_data('R').
        '''

        # Resolve the requested vector before comparing it to anything.
        # isinstance is needed because comparing an array to a string with
        # == is an elementwise operation.
        use_default = (
            isinstance( normal_vector, str )
            and normal_vector == 'total ang momentum'
        )
        if use_default:
            new_normal_vector = self.total_ang_momentum
        else:
            new_normal_vector = normal_vector

        # Exit early if already calculated with the same normal vector.
        # The stored vector may not exist yet if 'Phi' came from elsewhere.
        if 'Phi' in self.data and hasattr( self, 'normal_vector' ):
            previous = np.asarray( self.normal_vector )
            requested = np.asarray( new_normal_vector )
            if (
                previous.shape == requested.shape
                and np.allclose( previous, requested )
            ):
                return

        self.normal_vector = new_normal_vector

        # Get all the data we need.
        # 'P' is indexed as (component, particle, snapshot) and
        # 'R' as (particle, snapshot)
        p_all = self.get_data('P')
        r_all = self.get_data('R')

        normal = np.asarray( self.normal_vector )

        # Dot product over the three spatial components, for all particles
        # and snapshots at once
        dot_product = (
            normal[:3, np.newaxis, np.newaxis] * p_all[:3]
        ).sum( axis=0 )

        # Isolate for the cosine
        cos_phi = dot_product / r_all / np.linalg.norm( normal )

        # Get the angle (in degrees)
        self.data['Phi'] = np.arccos( cos_phi ) * 180. / np.pi

    ########################################################################

    def calc_abs_phi(self, normal_vector='total ang momentum' ):
        '''Calculate the angle (in degrees) from some vector, folded about
        90 degrees so that e.g. 135 -> 45 degrees and 180 -> 0 degrees.
        This is useful when there's symmetry above and below 90.

        Args:
            normal_vector (str or array-like):
                Vector that represents the vertical. Passed straight to
                self.calc_phi.

        Modifies:
            self.data['AbsPhi'] (array-like):
                The folded angle. Values of self.data['Phi'] below 90 are
                kept, all others are replaced by |Phi - 180|.
        '''

        # Make sure Phi is available for the requested vector
        self.calc_phi( normal_vector )
        angle = self.data['Phi']

        # Mirror everything that is not below 90 degrees
        below_ninety = angle < 90.
        mirrored_angle = np.absolute( angle - 180. )

        self.data['AbsPhi'] = np.where( below_ninety, angle, mirrored_angle )

    ########################################################################

    def calc_time_as_classification( self, data_key ):
        '''Calculate the time spent as a classification, for data keys of
        the form 'time_as_<classification>'.

        For every contiguous region in which 'is_<classification>' holds,
        the stored value at each index is the sum of self.get_data('dt')
        from that index to the end of the region.

        Args:
            data_key (str) : The data key to parse and calculate.

        Modifies:
            self.data[data_key] (np.ndarray) :
                Result, with the same shape as the classification data.
                Zero wherever the classification does not hold.

        Returns:
            bool : False if data_key does not start with 'time_as' (nothing
                is calculated in that case), True otherwise.
        '''

        # Only handle data keys that match the format we parse
        if not data_key.startswith( 'time_as' ):
            return False

        # The classification name follows the 'time_as_' prefix
        classification = self.get_data( 'is_' + data_key[8:] )

        # Time intervals
        dt = self.get_data( 'dt' )

        # Fill in the result one row of the classification at a time
        time_as = np.zeros( classification.shape )
        for time_as_row, row in zip( time_as, classification ):

            for start, end in data_operations.contiguous_regions( row ):

                # Reverse cumulative sum, because (per the original
                # author's note) t=0 is at j=-1 in the data format
                reversed_dt = dt[start:end][::-1]
                time_as_row[start:end] = np.cumsum( reversed_dt )[::-1]

        self.data[data_key] = time_as

        return True

    ########################################################################

    def calc_time_until_not_classification( self, data_key ):
        '''Calculate the time until a classification no longer holds, for
        data keys of the form 'time_until_not_<classification>'.

        For every contiguous region in which 'is_<classification>' holds,
        the stored value at each index is the sum of the time intervals
        from the start of the region up to and including that index, using
        self.get_data('dt') with a zero inserted at the front.

        Args:
            data_key (str) : The data key to parse and calculate.

        Modifies:
            self.data[data_key] (np.ndarray) :
                Result, with the same shape as the classification data.
                Zero wherever the classification does not hold.

        Returns:
            bool : False if data_key does not start with 'time_until_not'
                (nothing is calculated in that case), True otherwise.

        Raises:
            ValueError :
                If a region extends past the available time intervals, so
                that the summed times do not fit the region. This used to
                drop into the debugger instead.
        '''

        # Check if we should be running this function (does the provided
        # data_key even match the format we want to parse?)
        if data_key[:14] != 'time_until_not':
            return False

        # Get the data key for the classification
        classification_data_key = 'is_{}'.format( data_key[15:] )

        # Get out the classification data itself
        classification = self.get_data( classification_data_key )

        # Get out the time intervals
        dt = self.get_data( 'dt' )

        # dt[i] = time difference between i and i+1, so time until is 0
        # if the particle changes the next snapshot
        dt = np.insert( dt, 0, 0. )

        # Fill in the array row by row
        time_un_classification = np.zeros( classification.shape )
        for i, row in enumerate( classification ):

            # Identify regions of contiguous classification
            contiguous_regions = data_operations.contiguous_regions( row )

            # Find the cumulative time in specified regions
            for start, end in contiguous_regions:

                # Forward cumulative sum over the region. A length mismatch
                # here is a real error, so let it propagate.
                time_un_classification[i, start:end] = np.cumsum(
                    dt[start:end]
                )

        self.data[data_key] = time_un_classification

        return True

    ########################################################################

    def calc_metal_mass( self ):
        '''Calculate the metal mass held by each resolution element, as the
        product of the 'M' data, the 'Z' data and self.z_sun.

        Modifies:
            self.data['metal_mass'] (array-like):
                Metal mass of each resolution element.
        '''

        mass = self.get_data( 'M' )
        metallicity = self.get_data( 'Z' )

        self.data['metal_mass'] = mass * metallicity * self.z_sun

    ########################################################################

    def calc_enriched_metal_mass( self ):
        '''Calculate the metal mass that comes from enrichment for each
        resolution element, i.e. not counting the metals at the metallicity
        floor. The floor is taken to be the smallest non-NaN value of the
        'Z' data, which assumes that at least one resolution element is
        still at the floor.

        Modifies:
            self.data['enriched_metal_mass'] (array-like):
                Metal mass from enrichment.
        '''

        metallicity = self.get_data( 'Z' )

        # Metallicity above the floor, converted with z_sun
        metallicity_floor = np.nanmin( metallicity )
        enrichment_fraction = metallicity - metallicity_floor
        enrichment_fraction *= self.z_sun

        self.data['enriched_metal_mass'] = (
            self.get_data( 'M' ) * enrichment_fraction
        )

########################################################################
########################################################################

class TimeDataMasker( generic_data.DataMasker ):
    '''Data masker for worldline data.'''

    def __init__( self, time_data ):
        '''Create a masker for a time-series data object.

        Args:
            time_data : Data object passed to the parent class constructor.
        '''

        parent = super( TimeDataMasker, self )
        parent.__init__( time_data )

    ########################################################################

    def get_selected_data_over_time(
        self,
        data_key,
        snum,
        mask = 'total',
        optional_masks = None,
        n_samples = None,
        seed = None,
        *args, **kwargs
    ):
        '''Get data over the full time history, based on its mask at
        one time.

        Args:
            data_key (str):
                Data to get. The special key 'particle_ind' returns a
                running index ( 0, 1, 2, ... ) for the selected particles
                instead of retrieving data.

            snum (int):
                Snapshot at which the mask is evaluated.

            mask (np.ndarray, bool or str):
                An explicit mask array (True means masked out), False for
                no masking at all, or 'total' to use self.get_total_mask.

            optional_masks:
                Passed to self.get_total_mask when mask is 'total'.

            n_samples (int):
                If given, the number of particles to randomly sample
                (without replacement) from the selected particles. If it is
                not smaller than the number of selected particles, all of
                them are returned.

            seed (int):
                If given, np.random is seeded with this before sampling.

            *args, **kwargs:
                Passed to self.data_object.get_processed_data. Must not
                contain 'sl'.

        Returns:
            np.ndarray: The requested data for the selected particles, with
                the particles along the first axis.

        Raises:
            AssertionError: If 'sl' is passed as a keyword argument.
            Exception: If mask is True, i.e. all data is masked.
            KeyError: If the mask is of an unrecognized type.
            IndexError: If snum is not one of the snapshots of the data.
        '''

        # Make sure we don't try to pass a slice to any keyword arguments
        assert 'sl' not in kwargs, 'Taking slices of the original data' + \
                ' should not be done when using get_selected_data_over_time'

        # Get the appropriate mask
        if isinstance( mask, np.ndarray ):
            used_mask = mask
        elif isinstance( mask, ( bool, np.bool_ ) ):
            if mask:
                raise Exception( "All data is masked." )
            used_mask = np.zeros(
                self.data_object.base_data_shape
            ).astype( bool )
        elif mask == 'total':
            used_mask = self.get_total_mask( optional_masks=optional_masks )
        else:
            raise KeyError( "Unrecognized type of mask, {}".format( mask ) )

        # Get the index along the time axis that corresponds to snum
        correct_snum = self.data_object.snums == snum
        ind = np.arange( self.data_object.n_snaps )[correct_snum][0]

        # Get the boolean for the selected data (True means selected)
        sl = np.invert( used_mask[:,ind] )

        # If we want to get the particle ind, we do something different
        if data_key == 'particle_ind':

            n_particles_mask = sl.sum()

            # Only compare against n_samples when one was actually given;
            # comparing None to a number raises a TypeError.
            if n_samples is None:
                n_particles_selected = n_particles_mask
            elif n_samples >= n_particles_mask:
                print( "n_samples > n_particles_selected, not sampling." )
                n_particles_selected = n_particles_mask
            else:
                n_particles_selected = n_samples

            return np.tile(
                np.arange( n_particles_selected ),
                ( self.data_object.n_snaps, 1 ),
            ).transpose()

        # Get the masked data (masked via the slice)
        masked_data = self.data_object.get_processed_data(
            data_key,
            sl = sl,
            *args, **kwargs
        )

        # Without a sample size we are done
        if n_samples is None:
            return masked_data

        # If given a seed for sampling, use it.
        # Note that this seeds the global numpy random state.
        if seed is not None:
            np.random.seed( seed )

        n_particles_selected = masked_data.shape[0]
        if n_samples >= n_particles_selected:
            print( "n_samples > n_particles_selected, not sampling." )
            return masked_data

        # Sample a subset of the selected particles
        sampled_inds = np.random.choice(
            np.arange( n_particles_selected ),
            n_samples,
            replace = False,
        )

        return masked_data[sampled_inds,:]