Source code for prism.modellink._modellink

# -*- coding: utf-8 -*-

"""
ModelLink
=========
Provides the definition of the :class:`~ModelLink` abstract base class.

"""


# %% IMPORTS
# Future imports
from __future__ import absolute_import, division, print_function

# Built-in imports
import abc
from inspect import isclass
from os import path
import warnings

# Package imports
from e13tools import InputError, ShapeError
import numpy as np
from numpy.random import rand
from six import string_types, with_metaclass
from sortedcontainers import SortedDict as sdict, SortedSet as sset

# PRISM imports
from prism._docstrings import std_emul_i_doc
from prism._internal import (PRISM_Comm, RequestWarning, check_instance,
                             check_vals, convert_str_seq, docstring_substitute,
                             getCLogger, np_array, raise_error)

# All declaration
__all__ = ['ModelLink', 'test_subclass']


# %% MODELLINK CLASS DEFINITION
# TODO: Allow for inter-process methods?
# E.g., a method that is called before/after construction.



# %% UTILITY FUNCTIONS
# This function tests a given ModelLink subclass
# TODO: Are there any more tests that can be done here?
def test_subclass(subclass, *args, **kwargs):
    """
    Tests a provided :class:`~ModelLink` `subclass` by initializing it with
    the given `args` and `kwargs` and checking if all required methods can be
    properly called.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    subclass : :class:`~ModelLink` subclass
        The :class:`~ModelLink` subclass that requires testing.
    args : tuple
        Positional arguments that need to be provided to the constructor of
        the `subclass`.
    kwargs : dict
        Keyword arguments that need to be provided to the constructor of the
        `subclass`.

    Returns
    -------
    modellink_obj : :obj:`~ModelLink` object
        Instance of the provided `subclass` if all tests pass successfully.
        Specific exceptions are raised if a test fails.

    Raises
    ------
    InputError
        If `subclass` is not a class, if it cannot be initialized with the
        given `args` and `kwargs`, or if the resulting object does not pass
        the :class:`~ModelLink` instance check (e.g., because its constructor
        does not call the super constructor).
    TypeError
        If `subclass` is not a subclass of :class:`~ModelLink`.
    NotImplementedError
        If calling :meth:`~ModelLink.call_model` raises a
        :class:`NotImplementedError` (i.e., it has no user-written version).

    Any exception raised while validating the obtained model output and model
    discrepancy variance is propagated as-is.

    Warns
    -----
    RequestWarning
        If calling :meth:`~ModelLink.get_md_var` raises a
        :class:`NotImplementedError`, in which case the default model
        discrepancy variance description would be used instead.

    Note
    ----
    Depending on the complexity of the model wrapped in the given `subclass`,
    this function may take a while to execute.

    """

    # Check if provided subclass is a class
    if not isclass(subclass):
        raise InputError("Input argument 'subclass' must be a class!")

    # Check if provided subclass is a subclass of ModelLink
    if not issubclass(subclass, ModelLink):
        raise TypeError("Input argument 'subclass' must be a subclass of the "
                        "ModelLink class!")

    # Try to initialize provided subclass
    try:
        modellink_obj = subclass(*args, **kwargs)
    except Exception as error:
        raise InputError("Input argument 'subclass' cannot be initialized! "
                         "(%s)" % (error))

    # Check if modellink_obj was initialized properly
    if not check_instance(modellink_obj, ModelLink):
        obj_name = modellink_obj.__class__.__name__
        raise InputError("Provided ModelLink subclass %r was not "
                         "initialized properly! Make sure that %r calls "
                         "the super constructor during initialization!"
                         % (obj_name, obj_name))

    # Set MPI intra-communicator
    comm = PRISM_Comm()

    # Draw a single random sample (converted with _to_par_space) on the
    # controller; workers hold a dummy that the broadcast below replaces
    if not comm._rank:
        sam_set = modellink_obj._to_par_space(rand(1, modellink_obj._n_par))
    else:
        sam_set = []

    # Broadcast random sam_set to workers.
    # This is a collective call, which is why every rank must run this function
    sam_set = comm.bcast(sam_set, 0)

    # Try to evaluate sam_set in the model
    try:
        # The controller always calls the model; workers only do so when the
        # ModelLink object has _MPI_call set
        if not comm._rank or modellink_obj._MPI_call:
            # Multi-call: a single call where each parameter name is mapped to
            # its column of sam_set
            if modellink_obj._multi_call:
                mod_set = modellink_obj.call_model(
                    emul_i=0,
                    par_set=sdict(zip(modellink_obj._par_name, sam_set.T)),
                    data_idx=modellink_obj._data_idx)

            # Single-call: evaluate the model one sample (row) at a time
            else:
                mod_set = np.zeros([sam_set.shape[0], modellink_obj._n_data])
                for i, par_set in enumerate(sam_set):
                    mod_set[i] = modellink_obj.call_model(
                        emul_i=0,
                        par_set=sdict(zip(modellink_obj._par_name, par_set)),
                        data_idx=modellink_obj._data_idx)

    # If call_model was not overridden, catch NotImplementedError
    except NotImplementedError:
        raise NotImplementedError("Provided ModelLink subclass %r has no "
                                  "user-written 'call_model()'-method!"
                                  % (modellink_obj._name))

    # If successful, check if obtained mod_set has correct shape.
    # Only the controller is guaranteed to have called the model, so only it
    # validates mod_set (the validated result itself is not needed here)
    if not comm._rank:
        modellink_obj._check_mod_set(mod_set, 'mod_set')

    # Check if the model discrepancy variance can be obtained (on all ranks)
    try:
        md_var = modellink_obj.get_md_var(
            emul_i=0,
            par_set=sdict(zip(modellink_obj._par_name, sam_set[0])),
            data_idx=modellink_obj._data_idx)

    # If get_md_var was not overridden, catch NotImplementedError
    except NotImplementedError:
        warn_msg = ("Provided ModelLink subclass %r has no user-written "
                    "get_md_var()-method! Default model discrepancy variance "
                    "description would be used instead!"
                    % (modellink_obj._name))
        warnings.warn(warn_msg, RequestWarning, stacklevel=2)

    # If successful, check if obtained md_var has correct shape.
    # NOTE(review): md_var is validated with the same helper as mod_set;
    # confirm that _check_mod_set (and not a dedicated md_var check) is the
    # intended validator for the shapes get_md_var() may return.
    else:
        modellink_obj._check_mod_set(md_var, 'md_var')

    # Return modellink_obj
    return modellink_obj