# Source code for prism._internal

# -*- coding: utf-8 -*-

"""
Internal
========
Contains a collection of support classes/functions for the *PRISM* package.

"""


# %% IMPORTS
# Built-in imports
import logging
import logging.config
import os
from os import path
import platform
import shutil
from struct import calcsize
from tempfile import mkstemp
from textwrap import dedent

# Package imports
import e13tools as e13
import h5py
from mpi4pyd import MPI
import numpy as np
from pkg_resources import get_distribution

# PRISM imports
from prism.__version__ import __version__, compat_version

# Public names exported by 'from prism._internal import *'
__all__ = ['CFilter', 'FeatureWarning', 'PRISM_Logger', 'RFilter',
           'RequestError', 'RequestWarning', 'check_compatibility',
           'check_vals', 'get_bibtex', 'get_PRISM_File', 'get_formatter',
           'get_handler', 'get_info', 'getCLogger', 'getLogger', 'getRLogger',
           'move_logger', 'np_array', 'set_base_logger']

# Determine the number of processes in MPI.COMM_WORLD ('size') and the rank of
# this process within it ('rank'). Both are evaluated once, at import time.
# 'size' is read by RFilter and 'rank' is passed to the filters that
# PRISM_Logger creates.
size = MPI.COMM_WORLD.Get_size()
rank = MPI.COMM_WORLD.Get_rank()


# %% CLASS DEFINITIONS
# Make a custom Filter class that only allows the controller to log messages
class CFilter(logging.Filter):
    """
    Custom :class:`~logging.Filter` class that only allows the controller
    rank to log messages to the logfile.
    Records produced on worker ranks are dropped.

    Parameters
    ----------
    MPI_rank : int
        MPI rank of the process that owns this filter. A falsy value (such as
        rank 0) marks the process as the controller.

    """

    def __init__(self, MPI_rank):
        # Store the verdict as an int (1 = keep, 0 = drop), so that filter()
        # can hand it back directly
        self.is_controller = int(not MPI_rank)
        super().__init__('CFilter')

    def filter(self, record):
        # The verdict only depends on the rank, never on the record itself
        return self.is_controller
# Define Warning class for when an experimental feature is being used
class FeatureWarning(FutureWarning):
    """
    Generic warning raised for experimental features in *PRISM*.

    This general purpose warning is raised whenever a feature is used that
    should be considered experimental. Such a feature may change its behavior
    or API, or may be removed entirely, without any deprecation period.

    """
# Define custom Logger class that allows for filters to be easily used
class PRISM_Logger(logging.Logger):
    """
    Special :class:`~logging.Logger` class that allows for special filters to
    be set more easily.

    Every instance owns a dict `PRISM_filters` of ready-made filters, which
    can be attached by name through :meth:`~set_filters`.

    """

    def __init__(self, *args, **kwargs):
        # Let logging.Logger set up the logger itself first
        super().__init__(*args, **kwargs)

        # Then create this instance's named custom filters
        self.initialize_filters()

    def initialize_filters(self):
        """
        Creates a :class:`~CFilter` and an :class:`~RFilter` for the MPI rank
        of this process and stores them in `PRISM_filters`, keyed by their
        class names.

        """
        self.PRISM_filters = {
            'CFilter': CFilter(rank),
            'RFilter': RFilter(rank)}

    def set_filters(self, filters):
        """
        Adds the filters named in `filters` to this logger.

        Parameters
        ----------
        filters : iterable of str or None
            Keys into `PRISM_filters`. If *None*, no filters are added.

        """
        # Nothing requested, nothing to add
        if filters is None:
            return

        for filter_name in filters:
            self.addFilter(self.PRISM_filters[filter_name])
# Make a custom Filter class that logs the rank of the process that calls it
class RFilter(logging.Filter):
    """
    Custom :class:`~logging.Filter` class that prepends the world rank of the
    MPI process that calls it to the logging message.

    If the size of :obj:`MPI.COMM_WORLD` is 1, this filter does nothing.

    Parameters
    ----------
    MPI_rank : int
        The MPI rank that is written in front of every logged message.

    """

    def __init__(self, MPI_rank):
        # Only prefix messages when there are multiple MPI processes, as the
        # rank carries no information otherwise
        if(size > 1):
            self.prefix = "Rank %i: " % (MPI_rank)
        else:
            self.prefix = ""
        super().__init__('RFilter')

    def filter(self, record):
        # Without a prefix, leave the record (and the type of its msg) as is
        if self.prefix:
            # record.msg does not have to be a string (e.g., when calling
            # logger.info(some_object)), so convert it before prepending.
            # LogRecord.getMessage() applies str() to msg anyway, so the
            # final logged message is unaffected by this conversion.
            record.msg = self.prefix + str(record.msg)

        # Records are never dropped by this filter
        return(1)
# Define Exception class for when a requested action is not possible
class RequestError(Exception):
    """
    Generic exception raised for invalid action requests in the *PRISM*
    pipeline.

    This general purpose exception is raised whenever a requested action
    cannot be executed, because it is not allowed or not possible in the
    current state of the :obj:`~prism.Pipeline` instance.

    """
# Define Warning class for when a (future) requested action may not be useful
class RequestWarning(UserWarning):
    """
    Generic warning raised for (future) action requests in the *PRISM*
    pipeline that may not be useful.

    This general purpose warning is raised whenever a requested action may
    not produce appropriate or expected results, given the current state of
    the :obj:`~prism.Pipeline` instance. It is also raised when an obtained
    result can lead to such an action in the future.

    """
# %% FUNCTION DEFINITIONS
# Function for checking if emulator system is compatible with PRISM version
def check_compatibility(emul_version):
    """
    Checks if the provided `emul_version` is compatible with the current
    version of *PRISM*.
    Raises a :class:`~RequestError` if *False* and indicates which version of
    *PRISM* still supports the provided `emul_version`.

    Parameters
    ----------
    emul_version : str
        Version of *PRISM* with which the emulator was constructed.

    Notes
    -----
    If `emul_version` is older than v1.1.0, only a :class:`~RequestWarning`
    is raised and no further checks are performed.

    """
    log = getCLogger('COMP_CHECK')
    log.info("Performing version compatibility check.")

    # Emulators made with v1.0.x are unmaintained: warn and stop checking
    is_maintained = e13.compare_versions(emul_version, '1.1.0')
    if not is_maintained:
        warn_msg = ("The provided emulator was constructed with an "
                    "unmaintained version of PRISM (v%s). Compatibility with "
                    "the current version of PRISM cannot be guaranteed."
                    % (emul_version))
        e13.raise_warning(warn_msg, RequestWarning, log, 2)
        return

    # A compatibility-breaking version that is the same as or newer than
    # emul_version makes the emulator unusable with this version of PRISM
    for breaking_version in compat_version:
        if not e13.compare_versions(breaking_version, emul_version):
            continue
        err_msg = ("The provided emulator is incompatible with the current"
                   " version of PRISM (v%s). The last compatible version "
                   "is v%s." % (__version__, breaking_version))
        e13.raise_error(err_msg, RequestError, log)

    # Emulators built with a newer PRISM than the running one are rejected
    if e13.compare_versions(__version__, emul_version):
        log.info("Version compatibility check was successful.")
    else:
        err_msg = ("The provided emulator was constructed with a version later"
                   " than the current version of PRISM (v%s). Use v%s or later"
                   " to use this emulator." % (__version__, emul_version))
        e13.raise_error(err_msg, RequestError, log)
# This function checks if the input values meet all given criteria
def check_vals(values, name, *args):
    """
    Checks if all values in provided input argument `values` with `name` meet
    all criteria given in `args`. If no criteria are given, it is checked if
    all values are finite.
    Returns `values` (0 or 1 in case of bool) if *True* and raises a
    :class:`~ValueError` or :class:`~TypeError` if *False*.

    Parameters
    ----------
    values : array_like of {bool; complex; float; int; str}
        The values to be checked against all given criteria in `args`.
        It must be possible to convert `values` to a :obj:`~numpy.ndarray`
        object.
    name : str
        The name of the input argument, which is used in the error message if
        a criterion is not met.
    args : positional arguments in {'bool'; 'complex'; 'float'; 'int'; 'neg'; \
                                    'nneg'; 'normal'; 'npos'; 'nzero'; 'pos'; \
                                    'str'}
        Sequence of strings determining the criteria that `values` must meet.
        If `args` is empty, it is checked if `values` are finite.

    Returns
    -------
    return_values : array_like of {complex; float; int; str}
        If `args` contained 'bool', returns 0s or 1s. Else, returns `values`.
        Tuples, lists and scalars are returned as tuple, list and Python
        scalar, respectively. Any other input is returned as a NumPy array.

    Raises
    ------
    e13tools.InputError
        If `values` is not array_like, cannot be converted to a NumPy array or
        is empty.
    TypeError
        If `values` converts to an object dtype, does not have a type
        requested in `args`, or (when the finiteness check is performed) has
        a dtype for which finiteness cannot be determined (e.g., strings).
    ValueError
        If `values` does not meet a value criterion in `args`, if `args`
        contains an unknown criterion, or if `values` is not finite.

    Notes
    -----
    If `values` contains integers, but `args` contains 'float',
    `return_values` will be cast as float.

    The 'bool' and 'str' criteria end the check: no finiteness check is
    performed afterwards.

    """
    # Define logger
    logger = getRLogger('CHECK')

    # Convert args to a list, such that satisfied criteria can be removed
    args = list(args)

    # Check type of values, such that it can be converted back at the end
    if isinstance(values, tuple):
        arr_type = 'tuple'
    elif isinstance(values, list):
        arr_type = 'list'
    elif isinstance(values, np.ndarray):
        arr_type = 'ndarray'
    elif np.isscalar(values) or values is None:
        arr_type = 'scalar'
    else:
        err_msg = "Input argument %r is not array_like!" % (name)
        e13.raise_error(err_msg, e13.InputError, logger)

    # Convert values to a NumPy array
    try:
        values = np.asanyarray(values)
    except Exception as error:  # pragma: no cover
        err_msg = ("Input argument %r cannot be converted to a NumPy array! "
                   "(%s)" % (name, error))
        e13.raise_error(err_msg, e13.InputError, logger)
    else:
        # Since NumPy v1.16.0, anything can be converted to NumPy arrays
        # So, check if the dtype is not np.object_
        if issubclass(values.dtype.type, np.object_):
            err_msg = ("Input argument %r cannot be converted to a NumPy "
                       "dtype!" % (name))
            e13.raise_error(err_msg, TypeError, logger)

    # Check if values is not empty and raise error if so
    if not values.size:
        err_msg = "Input argument %r is empty!" % (name)
        e13.raise_error(err_msg, e13.InputError, logger)

    # Loop over all criteria
    while args:
        # Check for bool
        if 'bool' in args:
            # Convert values to (lower-case) str
            values = np.char.lower(np.asanyarray(values, dtype=str))

            # Check if values available are accepted as bools, mapping
            # 'false'/'true' onto '0'/'1' along the way
            check_list = np.zeros_like(values, dtype=int, subok=False)
            check_list[values == '0'] = 1
            check_list[values == 'false'] = 1
            values[values == 'false'] = '0'
            check_list[values == '1'] = 1
            check_list[values == 'true'] = 1
            values[values == 'true'] = '1'

            # Check if check_list solely contains 1s
            if not check_list.all():
                # If not, raise error pointing at the first invalid element
                index = np.unravel_index(np.argmin(check_list), values.shape)
                err_msg = ("Input argument '%s%s' is not of type 'bool'!"
                           % (name, list(index) if values.ndim != 0 else ''))
                e13.raise_error(err_msg, TypeError, logger)
            else:
                # If so, convert values to integers and break the loop
                values = np.asanyarray(values, dtype=int)
                break

        # Check for string
        elif 'str' in args:
            # Check if str is provided and break if so
            if issubclass(values.dtype.type, str):
                break
            else:
                err_msg = "Input argument %r is not of type 'str'!" % (name)
                e13.raise_error(err_msg, TypeError, logger)

        # Check for complex
        elif 'complex' in args:
            # Check if complex is provided and continue if so
            if issubclass(values.dtype.type,
                          (np.integer, np.floating, np.complexfloating)):
                # Remove 'complex' from args and check it again
                args.remove('complex')
                values = np.asanyarray(values, dtype=complex)
                continue
            else:
                err_msg = ("Input argument %r is not of type 'complex'!"
                           % (name))
                e13.raise_error(err_msg, TypeError, logger)

        # Check for float
        elif 'float' in args:
            # Check if float is provided and continue if so
            if issubclass(values.dtype.type, (np.integer, np.floating)):
                # Remove 'float' from args and check it again
                args.remove('float')
                values = np.asanyarray(values, dtype=float)
                continue
            else:
                err_msg = "Input argument %r is not of type 'float'!" % (name)
                e13.raise_error(err_msg, TypeError, logger)

        # Check for integer
        elif 'int' in args:
            # Check if int is provided and continue if so
            if issubclass(values.dtype.type, np.integer):
                # Remove 'int' from args and check it again
                args.remove('int')
                continue
            else:
                err_msg = "Input argument %r is not of type 'int'!" % (name)
                e13.raise_error(err_msg, TypeError, logger)

        # Check for negative value
        elif 'neg' in args:
            # Check if value is negative and continue if so
            try:
                index = list(np.argwhere(values >= 0)[0])
            except IndexError:
                args.remove('neg')
                continue
            else:
                err_msg = ("Input argument '%s%s' is not negative!"
                           % (name, index if values.ndim != 0 else ''))
                e13.raise_error(err_msg, ValueError, logger)

        # Check for non-negative value
        elif 'nneg' in args:
            # Check if value is non-negative and continue if so
            try:
                index = list(np.argwhere(values < 0)[0])
            except IndexError:
                args.remove('nneg')
                continue
            else:
                err_msg = ("Input argument '%s%s' is not non-negative!"
                           % (name, index if values.ndim != 0 else ''))
                e13.raise_error(err_msg, ValueError, logger)

        # Check for normalized value [-1, 1]
        elif 'normal' in args:
            # Check if value is normal and continue if so
            try:
                index = list(np.argwhere(abs(values) > 1)[0])
            except IndexError:
                args.remove('normal')
                continue
            else:
                err_msg = ("Input argument '%s%s' is not normalized!"
                           % (name, index if values.ndim != 0 else ''))
                e13.raise_error(err_msg, ValueError, logger)

        # Check for non-positive value
        elif 'npos' in args:
            # Check if value is non-positive and continue if so
            try:
                index = list(np.argwhere(values > 0)[0])
            except IndexError:
                args.remove('npos')
                continue
            else:
                err_msg = ("Input argument '%s%s' is not non-positive!"
                           % (name, index if values.ndim != 0 else ''))
                e13.raise_error(err_msg, ValueError, logger)

        # Check for non-zero value
        elif 'nzero' in args:
            # Check if value is non-zero and continue if so
            try:
                index = list(np.argwhere(values == 0)[0])
            except IndexError:
                args.remove('nzero')
                continue
            else:
                err_msg = ("Input argument '%s%s' is not non-zero!"
                           % (name, index if values.ndim != 0 else ''))
                e13.raise_error(err_msg, ValueError, logger)

        # Check for positive value
        elif 'pos' in args:
            # Check if value is positive and continue if so
            try:
                index = list(np.argwhere(values <= 0)[0])
            except IndexError:
                args.remove('pos')
                continue
            else:
                err_msg = ("Input argument '%s%s' is not positive!"
                           % (name, index if values.ndim != 0 else ''))
                e13.raise_error(err_msg, ValueError, logger)

        # If none of the criteria is found, the criteria are invalid
        else:
            err_msg = ("Input argument 'args' contains invalid elements (%s)!"
                       % (args))
            e13.raise_error(err_msg, ValueError, logger)

    # If no criteria are left (loop not left via 'break'), it must be a finite
    # value
    else:
        # Check if value is finite and continue if so
        try:
            index = list(np.argwhere(~np.isfinite(values))[0])
        except IndexError:
            pass
        except TypeError:
            # np.isfinite() does not support this dtype (e.g., strings), so
            # there is no offending element to point at and 'index' is not
            # bound here. Report the argument as a whole instead.
            err_msg = ("Input argument '%s' is not of type 'int' or 'float'!"
                       % (name))
            e13.raise_error(err_msg, TypeError, logger)
        else:
            err_msg = ("Input argument '%s%s' is not finite!"
                       % (name, index if values.ndim != 0 else ''))
            e13.raise_error(err_msg, ValueError, logger)

    # Convert values back to its original type
    if(arr_type == 'tuple'):
        values = tuple(values.tolist())
    elif(arr_type == 'list'):
        values = values.tolist()
    elif(arr_type == 'scalar'):
        values = values.item()

    # Return values
    return(values)
# Define class factory that returns a specialized h5py.File class
def get_PRISM_File(prism_hdf5_file):
    """
    Returns a class definition ``PRISM_File(mode, emul_s=None, **kwargs)``.

    This class definition is a specialized version of the :class:`~h5py.File`
    class with the filename automatically set to `prism_hdf5_file` and added
    logging to the constructor and destructor methods.

    Parameters
    ----------
    prism_hdf5_file : str
        Absolute path to the master HDF5-file that is used in a
        :obj:`~prism.Pipeline` instance.

    Returns
    -------
    PRISM_File : class
        Definition of the class ``PRISM_File(mode, emul_s=None, **kwargs)``.

    """
    # Split the master file path into root and extension, such that the
    # emulator system file names ('<root>_<emul_s><ext>') can be derived
    root, ext = path.splitext(prism_hdf5_file)

    class PRISM_File(h5py.File):
        """
        Custom :class:`~h5py.File` class that automatically knows where all
        *PRISM* HDF5-files are located when created by the
        :func:`~get_PRISM_File` class factory.

        Additionally, certain keyword arguments have default values and the
        opening/closing of an HDF5-file is logged.

        """

        def __init__(self, mode, emul_s=None, **kwargs):
            """
            Opens the master HDF5-file `prism_hdf5_file` in `mode` according
            to some set of default parameters.

            Parameters
            ----------
            mode : {'r'; 'r+'; 'w'; 'w-'/'x'; 'a'}
                String indicating how the HDF5-file needs to be opened.

            Optional
            --------
            emul_s : int or None. Default: None
                If int, number indicating the requested emulator system file
                to open.
                If *None*, the master HDF5-file itself is opened.
            kwargs : dict. Default: ``{'driver': None, 'libver': 'earliest'}``
                Other keyword arguments that need to be given to the
                :func:`~h5py.File` function.

            """
            # Remember which file this is, as __exit__() logs based on it
            self.emul_s = emul_s

            # Start from the defaults and let the caller's kwargs win
            hdf5_kwargs = {'driver': None, 'libver': 'earliest'}

            if self.emul_s is None:
                # Only the controller opens the master file for writing, so
                # log through a CLogger
                filename = root + ext
                logger = getCLogger('M-HDF5')
            else:
                filename = '%s_%i%s' % (root, self.emul_s, ext)
                logger = getRLogger('S-HDF5')

            hdf5_kwargs.update(kwargs)

            if self.emul_s is None:
                logger.info("Opening master HDF5-file (mode: %r)." % (mode))
            else:
                logger.info("Opening system HDF5-file %i (mode: %r)."
                            % (self.emul_s, mode))

            super().__init__(filename, mode, **hdf5_kwargs)

        def __exit__(self, *args):
            # Log the closing of the file before h5py actually closes it
            if self.emul_s is None:
                getCLogger('M-HDF5').info("Closing master HDF5-file.")
            else:
                getRLogger('S-HDF5').info("Closing system HDF5-file %i."
                                          % (self.emul_s))

            super().__exit__(*args)

    return PRISM_File
# Define function that prints a string with the BibTeX entry to PRISM's paper
def get_bibtex():
    """
    Prints a string that gives the BibTeX entry for citing the *PRISM* paper
    (Van der Velden et al. 2019, ApJS, 242, 22).

    """
    # Raw string, as the entry contains LaTeX macros such as '\apjs'
    entry = r"""
        @ARTICLE{2019ApJS..242...22V,
            author = {{van der Velden}, E. and {Duffy}, A.~R. and {Croton}, D. and {Mutch}, S.~J. and {Sinha}, M.},
            title = "{Model dispersion with PRISM; an alternative to MCMC for rapid analysis of models}",
            journal = {\apjs},
            keywords = {Astrophysics - Instrumentation and Methods for Astrophysics, Physics - Computational Physics},
            year = "2019",
            month = "Jun",
            volume = {242},
            number = {2},
            eid = {22},
            pages = {22},
            doi = {10.3847/1538-4365/ab1f7d},
            archivePrefix = {arXiv},
            eprint = {1901.08725},
            primaryClass = {astro-ph.IM},
            adsurl = {https://ui.adsabs.harvard.edu/abs/2019ApJS..242...22V},
            adsnote = {Provided by the SAO/NASA Astrophysics Data System}
        }
        """

    # Remove the common indentation and the surrounding blank lines
    print(dedent(entry).strip())
# This function returns a logging.Formatter used for PRISM logging
def get_formatter():
    """
    Returns a :obj:`~logging.Formatter` object containing the default logging
    formatting used by *PRISM*: a timestamp, the padded level name, the padded
    logger name and the message.

    """
    return logging.Formatter(
        fmt="[%(asctime)s][%(levelname)-4s] %(name)-10s \t%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S")
# This function returns a logging.Handler used for PRISM logging
def get_handler(filename):
    """
    Returns a :obj:`~logging.Handler` object containing the default logging
    handling settings.

    Parameters
    ----------
    filename : str
        Path to the log-file, which is opened in append mode with UTF-8
        encoding.

    Returns
    -------
    handler : :obj:`~logging.FileHandler` object
        Handler named 'prism_base', with level DEBUG and the formatter
        returned by :func:`~get_formatter`.

    """
    file_handler = logging.FileHandler(filename, mode='a', encoding='utf-8')

    # The name allows this handler to be recognized as PRISM's base handler
    file_handler.set_name('prism_base')

    # Let everything through; filtering happens at the logger level
    file_handler.setLevel('DEBUG')
    file_handler.setFormatter(get_formatter())

    return file_handler
# Define function that prints a string with all PRISM package information
def get_info():
    """
    Prints a string that gives an overview of all information relevant to the
    *PRISM* package distribution: the platform, the Python and *PRISM*
    versions, and the installed version of every requirement of *PRISM*.

    """
    # General configuration of this installation
    lines = [
        dedent("""
            Configuration
            -------------"""),
        "Platform: %s %i-bit" % (platform.system(), calcsize('P')*8),
        "Python: %s" % (platform.python_version()),
        "Version: %s" % (__version__)]

    # Requirements of the installed PRISM distribution, alphabetically
    prism_dist = get_distribution('prism')
    requirements = sorted(req.name for req in prism_dist.requires())

    lines.append(dedent("""
        Requirements
        ------------"""))
    for requirement in requirements:
        installed = get_distribution(requirement)
        lines.append("%s: %s" % (requirement, installed.version))

    # Print everything, without leading/trailing whitespace
    print('\n'.join(lines).strip())
# Define custom getLogger function that adds the CFilter
def getCLogger(name=None):
    """
    Creates a :obj:`~PRISM_Logger` instance with `name`, adds the
    :class:`~CFilter` to it and returns it.

    As with :func:`~getLogger`, the filter is only added when the logger did
    not exist yet.

    """
    controller_only = ['CFilter']
    return getLogger(name, controller_only)
# Define custom getLogger function that automatically names loggers correctly
def getLogger(name=None, filters=None):
    """
    Creates a :obj:`~PRISM_Logger` instance with `name` and adds the provided
    `filters` to it.

    The returned :obj:`~PRISM_Logger` instance is a child of the base
    :class:`~PRISM_Logger` created with :func:`~set_base_logger`, but has its
    name changed (such that the parent name does not show up in the
    log-file).

    Optional
    --------
    name : str or None. Default: None
        The name of the :obj:`~PRISM_Logger` instance to create.
        If *None*, initialize the base :class:`~PRISM_Logger` instead.
    filters : list of str or None. Default: None
        List of strings naming the filters that must be applied to the
        created :obj:`~PRISM_Logger` instance.
        If *None* or the :obj:`~PRISM_Logger` instance already existed, no
        filters will be applied.

    Returns
    -------
    logger : :obj:`~PRISM_Logger` object
        The created :obj:`~PRISM_Logger` instance.

    """
    # Set Logger name prefix
    prefix = 'prism'

    # Check what the provided name is
    if name is None:
        # The base logger is registered as 'prism' but shown as 'PRISM_ROOT'
        child_name = prefix
        name = 'PRISM_ROOT'
    else:
        child_name = ".".join([prefix, name])

    # Temporarily make PRISM_Logger the class that logging instantiates, so
    # that a newly created logger is a PRISM_Logger. Afterwards, restore the
    # previously registered class (instead of unconditionally resetting it
    # to logging.Logger, which would clobber a custom logger class set by
    # another package), even if logging.getLogger() raises.
    previous_logger_class = logging.getLoggerClass()
    logging.setLoggerClass(PRISM_Logger)
    try:
        logger = logging.getLogger(child_name)
    finally:
        logging.setLoggerClass(previous_logger_class)

    # A newly created logger still carries its prefixed name, while an
    # existing one was already renamed; only configure new loggers
    if(logger.name != name):
        # Remove prefix from the name of the PRISM_Logger instance
        logger.name = name

        # Set the requested filter(s)
        logger.set_filters(filters)

    # Return it
    return(logger)
# Define custom getLogger function that adds the RFilter
def getRLogger(name=None):
    """
    Creates a :obj:`~PRISM_Logger` instance with `name`, adds the
    :class:`~RFilter` to it and returns it.

    As with :func:`~getLogger`, the filter is only added when the logger did
    not exist yet.

    """
    rank_prefixing = ['RFilter']
    return getLogger(name, rank_prefixing)
# Define function that can move the logging file of PRISM and restart logging
def move_logger(working_dir):
    """
    Moves the base :class:`~PRISM_Logger` from the current working directory
    to the given `working_dir`, and then restarts it again.

    If a 'prism_log.log' file already exists in `working_dir`, the current
    log-file is appended to it instead. If that file already is the current
    log-file, it is left untouched and logging is simply restarted.

    Parameters
    ----------
    working_dir : str
        String containing the directory the log-file needs to be moved to.

    """
    # Shut down logging process to allow the log-file to be moved
    logging.shutdown()

    # Get source and destination paths
    source = logging.getLogger('prism').handlers[0].baseFilename
    destination = path.join(working_dir, 'prism_log.log')

    # Check if file already exists and either combine files or move the file
    if path.isfile(destination):
        # When the log-file already is the destination (e.g., it was moved
        # to this working_dir before), appending it to itself and then
        # removing the source would destroy the log, so skip that
        if not path.samefile(source, destination):
            # The base handler (see get_handler()) writes UTF-8, so read and
            # write the files as such instead of in the locale encoding
            with open(source, 'r', encoding='utf-8') as src:
                with open(destination, 'a', encoding='utf-8') as dest:
                    shutil.copyfileobj(src, dest)
            os.remove(source)
    else:
        shutil.move(source, destination)

    # Restart the logger
    set_base_logger(filename=destination)
# This function automatically does not make a copy of a NumPy array
def np_array(obj, *args, **kwargs):
    """
    Returns ``np.array(obj, *args, copy=False, **kwargs)``, i.e., an array
    that only copies `obj` if that is required.

    In NumPy >= 2.0, ``copy=False`` means "never copy" and raises a
    :class:`~ValueError` whenever a copy is needed (e.g., when `obj` is a
    list). The NumPy 1.x meaning of ``copy=False`` ("copy only if needed") is
    spelled ``copy=None`` there, so that value is used instead.

    """
    copy = kwargs.pop('copy', False)

    # Translate the NumPy 1.x copy=False semantics to NumPy >= 2.0
    if copy is False and np.lib.NumpyVersion(np.__version__) >= '2.0.0':
        copy = None

    # Return NumPy array, only copying if needed (unless copy=True was given)
    return(np.array(obj, *args, copy=copy, **kwargs))
# This function sets the base PRISM logger
# TODO: Make base logger unique to Pipeline instance
# This requires a lot of rewriting and many functions to be moved to Pipeline
def set_base_logger(filename=None):
    """
    Initializes the base :class:`~PRISM_Logger`, from which all other
    :obj:`~PRISM_Logger` instances are derived.

    Any handlers the base logger already had are removed and closed first.

    Optional
    --------
    filename : str or None. Default: None
        String containing the name of the log-file that is opened.
        If *None*, a new log-file will be created in the current directory.

    """
    # If filename is not defined, make a new one
    if filename is None:
        fd, filename = mkstemp('.log', 'prism_', '.')
        # Only the name is needed, as the handler opens the file itself
        os.close(fd)

    # Initialize base_logger
    base_logger = getLogger()

    # Detach and close any previous handlers. Simply replacing the handlers
    # list would leave their log-files open. Do this before creating the new
    # handler, as old and new handlers share the name 'prism_base'.
    for old_handler in list(base_logger.handlers):
        base_logger.removeHandler(old_handler)
        old_handler.close()

    # Initialize base handler and add it to base_logger
    handler = get_handler(filename)
    base_logger.addHandler(handler)

    # Set logLevel to the same as the logLevel of the handler
    base_logger.setLevel(handler.level)

    # Make sure that the base_logger does not propagate logging messages
    base_logger.propagate = False