Source code for prism.modellink.utils

# -*- coding: utf-8 -*-

"""
Utilities
=========
Provides several utility functions concerning the
:class:`~prism.modellink.ModelLink` class.

"""


# %% IMPORTS
# Built-in imports
from inspect import _VAR_KEYWORD, _VAR_POSITIONAL, _empty, isclass, signature
from os import path
import warnings

# Package imports
import e13tools as e13
from mpi4pyd.MPI import get_HybridComm_obj
import numpy as np
from numpy.random import rand
from sortedcontainers import SortedDict as sdict

# PRISM imports
from prism._internal import RequestWarning, check_vals, np_array

# All declaration
__all__ = ['convert_data', 'convert_parameters', 'test_subclass']


# %% UTILITY FUNCTIONS
# This function converts provided model data into format used by PRISM
[docs]def convert_data(model_data): """ Converts the provided `model_data` into a full data dict, taking into account all formatting options, and returns it. This function can be used externally to check how the provided `model_data` would be interpreted when provided to the :class:`~prism.modellink.ModelLink` subclass. Its output can be used for the 'model_data' input argument. Parameters ---------- model_data : array_like, dict or str Anything that can be converted to a dict that provides model data information. Returns ------- data_dict : dict Dict with the provided `model_data` converted to its full format. """ # If a data file is given if isinstance(model_data, str): # Obtain absolute path to given file data_file = path.abspath(model_data) # Read the data file in as a string data_points = np.genfromtxt(data_file, dtype=(str), delimiter=':', autostrip=True) # Make sure that data_points is 2D data_points = np_array(data_points, ndmin=2) # Convert read-in data to dict model_data = dict(data_points) # If a data dict is given elif isinstance(model_data, dict): model_data = dict(model_data) # If anything else is given else: # Check if it can be converted to a dict try: model_data = dict(model_data) except Exception: raise TypeError("Input model data cannot be converted to type " "'dict'!") # Make empty data_dict data_dict = dict() # Loop over all items in model_data for key, value in model_data.items(): # Convert key to an actual data_idx idx = e13.split_seq(key) # Check if tmp_idx is not empty if not idx: raise e13.InputError("Model data contains a data point with no " "identifier!") # Convert value to an actual data point data = e13.split_seq(value) # Check if provided data value is valid val = check_vals(data[0], 'data_val%s' % (idx), 'float') # Extract data error and space # If length is two, centered error and no data space were given if(len(data) == 2): err = [check_vals(data[1], 'data_err%s' % (idx), 'float', 'pos')]*2 spc = 'lin' # If length is three, there are two possibilities elif(len(data) == 3): # If the third column contains a string, it is the data space if isinstance(data[2], str): err = [check_vals(data[1], 'data_err%s' % (idx), 'float', 'pos')]*2 spc = data[2] # If the third column contains no string, it is error interval else: err = check_vals(data[1:3], 'data_err%s' % (idx), 'float', 'pos') spc = 'lin' # If length is four+, error interval and data space were given else: err = check_vals(data[1:3], 'data_err%s' % (idx), 'float', 'pos') spc = data[3] # Check if valid data space has been provided spc = str(spc).replace("'", '').replace('"', '') if spc.lower() in ('lin', 'linear'): spc = 'lin' elif spc.lower() in ('log', 'log10', 'log_10'): spc = 'log10' elif spc.lower() in ('ln', 'loge', 'log_e'): spc = 'ln' else: raise ValueError("Input argument 'data_spc%s' is invalid (%r)!" % (idx, spc)) # Save data identifier as tuple or single element if(len(idx) == 1): idx = idx[0] else: idx = tuple(idx) # Add entire data point to data_dict data_dict[idx] = [val, *err, spc] # Return data_dict return(data_dict)
# This function converts provided model parameters into format used by PRISM
[docs]def convert_parameters(model_parameters): """ Converts the provided `model_parameters` into a full parameters dict, taking into account all formatting options, and returns it. This function can be used externally to check how the provided `model_parameters` would be interpreted when provided to the :class:`~prism.modellink.ModelLink` subclass. Its output can be used for the 'model_parameters' input argument. Parameters ---------- model_parameters : array_like, dict or str Anything that can be converted to a dict that provides model parameters information. Returns ------- par_dict : dict Dict with the provided `model_parameters` converted to its full format. """ # If a parameter file is given if isinstance(model_parameters, str): # Obtain absolute path to given file par_file = path.abspath(model_parameters) # Read the parameter file in as a string pars = np.genfromtxt(par_file, dtype=(str), delimiter=':', autostrip=True) # Make sure that pars is 2D pars = np_array(pars, ndmin=2) # Convert read-in parameters to dict model_parameters = sdict(pars) # If a parameter dict is given elif isinstance(model_parameters, dict): model_parameters = sdict(model_parameters) # If anything else is given else: # Check if it can be converted to a dict try: model_parameters = sdict(model_parameters) except Exception: raise TypeError("Input model parameters cannot be converted to" " type 'dict'!") # Initialize empty par_dict par_dict = sdict() # Loop over all items in model_parameters for name, values_str in model_parameters.items(): # Convert values_str to values values = e13.split_seq(values_str) # Check if provided name is a string name = check_vals(name, 'par_name[%s]' % (name), 'str') # Check if provided range consists of two floats par_rng = check_vals(values[:2], 'par_rng[%s]' % (name), 'float') # Check if provided lower bound is lower than the upper bound if(par_rng[0] >= par_rng[1]): raise ValueError("Input argument 'par_rng[%s]' does not define a " "valid parameter range (%f !< %f)!" % (name, par_rng[0], par_rng[1])) # Check if a float parameter estimate was provided try: est = check_vals(values[2], 'par_est[%s]' % (name), 'float') # If no estimate was provided, save it as None except IndexError: est = None # If no float was provided, check if it was None except TypeError as error: # If it is None, save it as such if(str(values[2]).lower() == 'none'): est = None # If it is not None, reraise the previous error else: raise error # If a float was provided, check if it is within parameter range else: if not(values[0] <= est <= values[1]): raise ValueError("Input argument 'par_est[%s]' is outside " "of defined parameter range!" % (name)) # Add parameter to par_dict par_dict[name] = [*par_rng, est] # Return par_dict return(par_dict)
# This function tests a given ModelLink subclass # TODO: Are there any more tests that can be done here?
[docs]def test_subclass(subclass, *args, **kwargs): """ Tests a provided :class:`~prism.modellink.ModelLink` `subclass` by initializing it with the given `args` and `kwargs` and checking if all required methods can be properly called. This function needs to be called by all MPI ranks. Parameters ---------- subclass : :class:`~prism.modellink.ModelLink` subclass The :class:`~prism.modellink.ModelLink` subclass that requires testing. args : positional arguments Positional arguments that need to be provided to the constructor of the `subclass`. kwargs : keyword arguments Keyword arguments that need to be provided to the constructor of the `subclass`. Returns ------- modellink_obj : :obj:`~prism.modellink.ModelLink` object Instance of the provided `subclass` if all tests pass successfully. Specific exceptions are raised if a test fails. Note ---- Depending on the complexity of the model wrapped in the given `subclass`, this function may take a while to execute. """ # Import ModelLink class from prism.modellink import ModelLink # Check if provided subclass is a class if not isclass(subclass): raise e13.InputError("Input argument 'subclass' must be a class!") # Check if provided subclass is a subclass of ModelLink if not issubclass(subclass, ModelLink): raise TypeError("Input argument 'subclass' must be a subclass of the " "ModelLink class!") # Try to initialize provided subclass try: modellink_obj = subclass(*args, **kwargs) except Exception as error: raise e13.InputError("Input argument 'subclass' cannot be initialized!" " (%s)" % (error)) # Check if modellink_obj was initialized properly if not e13.check_instance(modellink_obj, ModelLink): obj_name = modellink_obj.__class__.__name__ raise e13.InputError("Provided ModelLink subclass %r was not " "initialized properly! Make sure that %r calls " "the super constructor during initialization!" % (obj_name, obj_name)) # Obtain list of arguments call_model should take call_model_args = list(signature(ModelLink.call_model).parameters) call_model_args.remove('self') # Check if call_model takes the correct arguments obj_call_model_args = dict(signature(modellink_obj.call_model).parameters) for arg in call_model_args: if arg not in obj_call_model_args.keys(): raise e13.InputError("The 'call_model()'-method in provided " "ModelLink subclass %r does not take required" " input argument %r!" % (modellink_obj._name, arg)) else: obj_call_model_args.pop(arg) # Check if call_model takes any other arguments for arg, par in obj_call_model_args.items(): # If this parameter has no default value and is not *args or **kwargs if(par.default == _empty and par.kind != _VAR_POSITIONAL and par.kind != _VAR_KEYWORD): # Raise error raise e13.InputError("The 'call_model()'-method in provided " "ModelLink subclass %r takes an unknown " "non-optional input argument %r!" % (modellink_obj._name, arg)) # Obtain list of arguments get_md_var should take get_md_var_args = list(signature(ModelLink.get_md_var).parameters) get_md_var_args.remove('self') # Check if get_md_var takes the correct arguments obj_get_md_var_args = dict(signature(modellink_obj.get_md_var).parameters) for arg in get_md_var_args: if arg not in obj_get_md_var_args.keys(): raise e13.InputError("The 'get_md_var()'-method in provided " "ModelLink subclass %r does not take required" " input argument %r!" % (modellink_obj._name, arg)) else: obj_get_md_var_args.pop(arg) # Check if get_md_var takes any other arguments for arg, par in obj_get_md_var_args.items(): # If this parameter has no default value and is not *args or **kwargs if(par.default == _empty and par.kind != _VAR_POSITIONAL and par.kind != _VAR_KEYWORD): # Raise an error raise e13.InputError("The 'get_md_var()'-method in provided " "ModelLink subclass %r takes an unknown " "non-optional input argument %r!" % (modellink_obj._name, arg)) # Set MPI intra-communicator comm = get_HybridComm_obj() # Obtain random sam_set on controller if not comm._rank: sam_set = modellink_obj._to_par_space(rand(1, modellink_obj._n_par)) # Workers get dummy sam_set else: sam_set = [] # Broadcast random sam_set to workers sam_set = comm.bcast(sam_set, 0) # Try to evaluate sam_set in the model try: # Check who needs to call the model if not comm._rank or modellink_obj._MPI_call: # Do multi-call if modellink_obj._multi_call: mod_set = modellink_obj.call_model( emul_i=0, par_set=modellink_obj._get_sam_dict(sam_set), data_idx=modellink_obj._data_idx) # Single-call else: # Initialize mod_set mod_set = np.zeros([sam_set.shape[0], modellink_obj._n_data]) # Loop over all samples in sam_set for i, par_set in enumerate(sam_set): mod_set[i] = modellink_obj.call_model( emul_i=0, par_set=modellink_obj._get_sam_dict(par_set), data_idx=modellink_obj._data_idx) # If call_model was not overridden, catch NotImplementedError except NotImplementedError: raise NotImplementedError("Provided ModelLink subclass %r has no " "user-written 'call_model()'-method!" % (modellink_obj._name)) # If successful, check if obtained mod_set has correct shape if not comm._rank: mod_set = modellink_obj._check_mod_set(mod_set, 'mod_set') # Check if the model discrepancy variance can be obtained try: md_var = modellink_obj.get_md_var( emul_i=0, par_set=modellink_obj._get_sam_dict(sam_set[0]), data_idx=modellink_obj._data_idx) # If get_md_var was not overridden, catch NotImplementedError except NotImplementedError: warn_msg = ("Provided ModelLink subclass %r has no user-written " "'get_md_var()'-method! Default model discrepancy variance" " description would be used instead!" % (modellink_obj._name)) warnings.warn(warn_msg, RequestWarning, stacklevel=2) # If successful, check if obtained md_var has correct shape else: md_var = modellink_obj._check_md_var(md_var, 'md_var') # Return modellink_obj return(modellink_obj)