# Source code for prism.utils.mcmc

# -*- coding: utf-8 -*-

"""
MCMC
====
Provides several functions that allow for *PRISM* to be connected more easily
to MCMC routines.

"""


# %% IMPORTS
# Built-in imports
from inspect import isfunction
import warnings

# Package imports
from e13tools import InputError
from e13tools.sampling import lhd
from e13tools.utils import docstring_substitute
import numpy as np
from sortedcontainers import SortedDict as sdict

# PRISM imports
from prism._docstrings import user_emul_i_doc
from prism._internal import RequestError, check_vals, np_array
from prism._pipeline import Pipeline

# All declaration: the public API of this module. 'get_lnpost_fn' is the
# deprecated predecessor of 'get_hybrid_lnpost_fn', kept for compatibility
__all__ = ['get_hybrid_lnpost_fn', 'get_walkers', 'get_lnpost_fn']


# %% FUNCTION DEFINITIONS
# This function returns a hybrid version of the lnpost function
@docstring_substitute(emul_i=user_emul_i_doc)
def get_hybrid_lnpost_fn(lnpost_fn, pipeline_obj, *, emul_i=None,
                         unit_space=False, impl_prior=True, par_dict=False):
    """
    Returns a function definition ``hybrid_lnpost(par_set, *args,
    **kwargs)``.

    This `hybrid_lnpost()` function can be used to calculate the natural
    logarithm of the posterior probability, which analyzes a given `par_set`
    first in the provided `pipeline_obj` at iteration `emul_i` and passes it
    to `lnpost_fn` if it is plausible.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    lnpost_fn : function
        Function definition that needs to be called if the provided `par_set`
        is plausible in iteration `emul_i` of `pipeline_obj`. The used call
        signature is ``lnpost_fn(par_set, *args, **kwargs)``. All MPI ranks
        will call this function unless called within the
        :attr:`~prism.Pipeline.worker_mode` context manager.
    pipeline_obj : :obj:`~prism.Pipeline` object
        The instance of the :class:`~prism.Pipeline` class that needs to be
        used for determining the validity of the proposed sampling step.

    Optional
    --------
    %(emul_i)s
    unit_space : bool. Default: False
        Bool determining whether or not `par_set` will be given in unit space.
    impl_prior : bool. Default: True
        Bool determining whether or not the `hybrid_lnpost()` function should
        use the implausibility values of a given `par_set` as an additional
        prior.
    par_dict : bool. Default: False
        Bool determining whether or not `par_set` will be an array_like
        (*False*) or a dict (*True*).

    Returns
    -------
    hybrid_lnpost : function
        Definition of the function ``hybrid_lnpost(par_set, *args,
        **kwargs)``.

    Raises
    ------
    InputError
        If `lnpost_fn` is not a function definition, or if `pipeline_obj`
        does not use a default emulator.
    TypeError
        If `pipeline_obj` is not an instance of the :class:`~prism.Pipeline`
        class.

    See also
    --------
    :func:`~get_walkers`
        Analyzes proposed `init_walkers` and returns valid `p0_walkers`.
    :attr:`~prism.Pipeline.worker_mode`
        Special context manager within which all code is executed in worker
        mode.

    Note
    ----
    The input arguments `unit_space` and `par_dict` state in what form
    `par_set` will be provided to the `hybrid_lnpost()` function, such that
    it can be properly converted to the format used in
    :class:`~prism.Pipeline`. The `par_set` that is passed to `lnpost_fn` is
    unchanged.

    Warning
    -------
    Calling this function factory will disable all regular logging in
    `pipeline_obj` (:attr:`~prism.Pipeline.do_logging` set to *False*), in
    order to avoid having the same message being logged every time
    `hybrid_lnpost()` is called.

    """

    # Check if lnpost_fn is a function
    if not isfunction(lnpost_fn):
        raise InputError("Input argument 'lnpost_fn' is not a callable "
                         "function definition!")

    # Make abbreviation for pipeline_obj
    pipe = pipeline_obj

    # Check if provided pipeline_obj is an instance of the Pipeline class
    if not isinstance(pipe, Pipeline):
        raise TypeError("Input argument 'pipeline_obj' must be an instance of "
                        "the Pipeline class!")

    # Check if the provided pipeline_obj uses a default emulator
    if pipe._emulator._emul_type != 'default':
        raise InputError("Input argument 'pipeline_obj' does not use a default"
                         " emulator!")

    # Get emulator iteration
    emul_i = pipe._emulator._get_emul_i(emul_i, True)

    # Check if the provided flags are all bools
    unit_space = check_vals(unit_space, 'unit_space', 'bool')
    impl_prior = check_vals(impl_prior, 'impl_prior', 'bool')
    par_dict = check_vals(par_dict, 'par_dict', 'bool')

    # Disable PRISM logging, as hybrid_lnpost() would otherwise log the same
    # messages on every single call
    pipe.do_logging = False

    # Define hybrid_lnpost function
    def hybrid_lnpost(par_set, *args, **kwargs):
        """
        Calculates the natural logarithm of the posterior probability of
        `par_set` using the provided function `lnpost_fn`, in addition to
        constraining it first with the emulator defined in the
        `pipeline_obj`.

        This function needs to be called by all MPI ranks unless called
        within the :attr:`~prism.Pipeline.worker_mode` context manager.

        Parameters
        ----------
        par_set : 1D array_like or dict
            Sample to calculate the posterior probability for. This sample is
            first analyzed in `pipeline_obj` and only given to `lnpost_fn` if
            it is plausible. If `par_dict` is *True*, this is a dict.
        args : tuple
            Positional arguments that need to be passed to `lnpost_fn`.
        kwargs : dict
            Keyword arguments that need to be passed to `lnpost_fn`.

        Returns
        -------
        lnp : float
            The natural logarithm of the posterior probability of `par_set`,
            as determined by `lnpost_fn` if `par_set` is plausible. If
            `impl_prior` is *True*, `lnp` is calculated as `lnprior` +
            `lnpost_fn()`, with `lnprior` the natural logarithm of the first
            implausibility cut-off value of `par_set` scaled with its maximum.
            If `par_set` lies outside of the parameter space or is not
            plausible, `lnp` is `-inf` and `lnpost_fn` is not called.

        """

        # Convert par_set to a 2D NumPy array. A dict is wrapped in a
        # SortedDict first, so its values are taken in sorted key order
        if par_dict:
            sam = np_array(sdict(par_set).values(), ndmin=2)
        else:
            sam = np_array(par_set, ndmin=2)

        # If unit_space is True, convert par_set to par_space
        if unit_space:
            sam = pipe._modellink._to_par_space(sam)

        # Check if par_set is within parameter space and return -inf if not
        par_rng = pipe._modellink._par_rng
        if not ((par_rng[:, 0] <= sam[0]) & (sam[0] <= par_rng[:, 1])).all():
            # NOTE: np.inf instead of np.infty, as the latter alias was
            # removed in NumPy 2.0
            return(-np.inf)

        # Check what sampling is requested and analyze par_set
        if impl_prior:
            impl_sam, lnprior = pipe._make_call('_evaluate_sam_set', emul_i,
                                                sam, 'hybrid')
        else:
            impl_sam = pipe._make_call('_evaluate_sam_set', emul_i, sam,
                                       'analyze')
            lnprior = 0

        # If par_set is plausible, call lnpost_fn with the unchanged par_set
        if len(impl_sam):
            return(lnprior+lnpost_fn(par_set, *args, **kwargs))

        # If par_set is not plausible, return -inf
        return(-np.inf)

    # Check if model in ModelLink can be single-called, raise warning if not
    if pipe._is_controller and not pipe._modellink._single_call:
        warn_msg = ("ModelLink bound to provided Pipeline object solely "
                    "requests multi-calls. Using MCMC may not be possible.")
        warnings.warn(warn_msg, UserWarning, stacklevel=2)

    # Return hybrid_lnpost function definition
    return(hybrid_lnpost)
# This function returns a set of valid MCMC walkers
@docstring_substitute(emul_i=user_emul_i_doc)
def get_walkers(pipeline_obj, *, emul_i=None, init_walkers=None,
                unit_space=False, lnpost_fn=None, **kwargs):
    """
    Analyzes proposed `init_walkers` and returns plausible `p0_walkers`.

    Analyzes sample set `init_walkers` in the provided `pipeline_obj` at
    iteration `emul_i` and returns all samples that are plausible to be used
    as MCMC walkers. The provided samples and returned walkers should be/are
    given in unit space if `unit_space` is *True*.

    If `init_walkers` is *None*, returns :attr:`~prism.Pipeline.impl_sam`
    instead if it is available.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    pipeline_obj : :obj:`~prism.Pipeline` object
        The instance of the :class:`~prism.Pipeline` class that needs to be
        used for determining the plausibility of the proposed walkers.

    Optional
    --------
    %(emul_i)s
    init_walkers : 2D array_like, dict, int or None. Default: None
        Sample set of proposed initial MCMC walker positions. All plausible
        samples in `init_walkers` will be returned.
        If int, generate an LHD of provided size and return all plausible
        samples.
        If *None*, return :attr:`~prism.Pipeline.impl_sam` corresponding to
        iteration `emul_i` instead.
    unit_space : bool. Default: False
        Bool determining whether or not the provided samples and returned
        walkers are given in unit space.
    lnpost_fn : function or None. Default: None
        If function, call :func:`~get_hybrid_lnpost_fn` using `lnpost_fn` and
        the same values for `pipeline_obj`, `emul_i` and `unit_space`, and
        return the resulting function definition `hybrid_lnpost()`. Any
        additionally provided `kwargs` are also passed to it.

    Returns
    -------
    n_walkers : int
        Number of returned MCMC walkers.
    p0_walkers : 2D :obj:`~numpy.ndarray` object or dict
        Array containing plausible starting positions of valid MCMC walkers.
        If `init_walkers` was provided as a dict, `p0_walkers` will be a dict.
    hybrid_lnpost : function (if `lnpost_fn` is a function)
        The function returned by :func:`~get_hybrid_lnpost_fn` using
        `lnpost_fn`, `pipeline_obj`, `emul_i`, `unit_space` and `kwargs` as
        the input values.

    Raises
    ------
    TypeError
        If `pipeline_obj` is not an instance of the :class:`~prism.Pipeline`
        class.
    InputError
        If `pipeline_obj` does not use a default emulator, if `lnpost_fn` is
        invalid, or if `init_walkers` contains no plausible samples.

    See also
    --------
    :func:`~get_hybrid_lnpost_fn`
        Returns a function definition ``hybrid_lnpost(par_set, *args,
        **kwargs)``.
    :attr:`~prism.Pipeline.worker_mode`
        Special context manager within which all code is executed in worker
        mode.

    Notes
    -----
    If `init_walkers` is *None* and emulator iteration `emul_i` has not been
    analyzed yet, a :class:`~prism._internal.RequestError` will be raised.

    The provided `init_walkers` is only processed on the controller rank.
    Whether `p0_walkers` is returned as a dict is also decided there and
    shared with all other ranks, such that every rank returns the same
    format.

    """

    # Make abbreviation for pipeline_obj
    pipe = pipeline_obj

    # Check if provided pipeline_obj is an instance of the Pipeline class
    if not isinstance(pipe, Pipeline):
        raise TypeError("Input argument 'pipeline_obj' must be an instance of "
                        "the Pipeline class!")

    # Check if the provided pipeline_obj uses a default emulator
    if pipe._emulator._emul_type != 'default':
        raise InputError("Input argument 'pipeline_obj' does not use a default"
                         " emulator!")

    # Get emulator iteration
    emul_i = pipe._emulator._get_emul_i(emul_i, True)

    # Check if unit_space is a bool
    unit_space = check_vals(unit_space, 'unit_space', 'bool')

    # Assume that walkers are not to be returned as a dict
    walker_dict = False

    # Check if lnpost_fn is None and try to get hybrid_lnpost function if not
    if lnpost_fn is not None:
        try:
            hybrid_lnpost =\
                get_hybrid_lnpost_fn(lnpost_fn, pipe, emul_i=emul_i,
                                     unit_space=unit_space, **kwargs)
        except InputError as error:
            # Keep the original reason attached to the re-raised error
            raise InputError("Input argument 'lnpost_fn' is invalid!") \
                from error

    # If init_walkers is None, use impl_sam of emul_i
    if init_walkers is None:
        # Controller checking if emul_i has already been analyzed
        if pipe._is_controller:
            # If iteration has not been analyzed, raise error
            if not pipe._n_eval_sam[emul_i]:
                raise RequestError("Emulator iteration %i has not been "
                                   "analyzed yet!" % (emul_i))
            # If iteration is last iteration, init_walkers is current impl_sam
            elif(emul_i == pipe._emulator._emul_i):
                init_walkers = pipe._impl_sam
            # If iteration is not last, init_walkers is previous impl_sam
            else:
                init_walkers = pipe._emulator._sam_set[emul_i+1]

            # Make sure to make a copy of init_walkers to avoid modifications
            init_walkers = init_walkers.copy()

        # Broadcast init_walkers to workers as p0_walkers
        p0_walkers = pipe._comm.bcast(init_walkers, 0)

    # If init_walkers is not None, use provided samples or LHD size
    else:
        # Controller checking if init_walkers is valid
        if pipe._is_controller:
            # If init_walkers is an int, create LHD of provided size
            if isinstance(init_walkers, int):
                # Check if provided integer is positive
                n_sam = check_vals(init_walkers, 'init_walkers', 'pos')

                # Create LHD of provided size
                init_walkers = lhd(n_sam, pipe._modellink._n_par,
                                   pipe._modellink._par_rng, 'center',
                                   pipe._criterion, 100)

            # If init_walkers is not an int, it must be array_like or dict
            else:
                # If init_walkers is provided as a dict, convert it
                if isinstance(init_walkers, dict):
                    # Make sure that init_walkers is a SortedDict
                    init_walkers = sdict(init_walkers)

                    # Convert it to normal
                    init_walkers = np_array(init_walkers.values()).T

                    # Return p0_walkers as a dict
                    walker_dict = True

                # Make sure that init_walkers is a NumPy array
                init_walkers = np_array(init_walkers, ndmin=2)

                # If unit_space is True, convert init_walkers to par_space
                if unit_space:
                    init_walkers = pipe._modellink._to_par_space(init_walkers)

                # Check if init_walkers is valid
                init_walkers = pipe._modellink._check_sam_set(init_walkers,
                                                              'init_walkers')

        # Broadcast init_walkers to workers, together with walker_dict.
        # walker_dict is only determined on the controller, so without this
        # the workers would return an array where the controller returns a
        # dict. Both are sent in a single bcast to keep the number of
        # collective calls unchanged.
        init_walkers, walker_dict = pipe._comm.bcast(
            (init_walkers, walker_dict), 0)

        # Analyze init_walkers and save them as p0_walkers
        p0_walkers = pipe._evaluate_sam_set(emul_i, init_walkers, 'analyze')

    # Calculate n_walkers
    n_walkers = len(p0_walkers)

    # Check if p0_walkers is not empty
    if not n_walkers:
        raise InputError("Input argument 'init_walkers' contains no plausible "
                         "samples!")

    # Check if p0_walkers needs to be converted
    if unit_space:
        p0_walkers = pipe._modellink._to_unit_space(p0_walkers)

    # Check if p0_walkers needs to be returned as a dict
    if walker_dict:
        p0_walkers = sdict(zip(pipe._modellink._par_name, p0_walkers.T))

    # Check if hybrid_lnpost was requested and return it as well if so
    if lnpost_fn is not None:
        return(n_walkers, p0_walkers, hybrid_lnpost)
    else:
        return(n_walkers, p0_walkers)
# %% DEPRECATED FUNCTION DEFINITIONS
# This function returns a specialized version of the lnpost function
@docstring_substitute(emul_i=user_emul_i_doc)
def get_lnpost_fn(ext_lnpost, pipeline_obj, *, emul_i=None, unit_space=True,
                  hybrid=True, par_dict=False):     # pragma: no cover
    """
    .. deprecated:: 1.1.3
        Use :func:`~get_hybrid_lnpost_fn` instead.

    Returns a function definition ``get_lnpost(par_set, *args, **kwargs)``.

    The returned `get_lnpost` function computes the natural logarithm of the
    posterior probability of a `par_set`. It first evaluates `par_set` in the
    given `pipeline_obj` at iteration `emul_i`, and only hands it to
    `ext_lnpost` when it turns out to be plausible.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    ext_lnpost : function
        Function definition to call when the provided `par_set` is plausible
        in iteration `emul_i` of `pipeline_obj`, using the call signature
        ``ext_lnpost(par_set, *args, **kwargs)``. Unless called within the
        :attr:`~prism.Pipeline.worker_mode` context manager, every MPI rank
        calls this function.
    pipeline_obj : :obj:`~prism.Pipeline` object
        The :class:`~prism.Pipeline` instance used to decide whether a
        proposed sampling step is valid.

    Optional
    --------
    %(emul_i)s
    unit_space : bool. Default: True
        Whether `par_set` will be given in unit space. Note that this default
        differs from the one used by :func:`~get_hybrid_lnpost_fn`.
    hybrid : bool. Default: True
        Whether the `get_lnpost` function should use the implausibility
        values of a given `par_set` as an additional prior. Forwarded as
        `impl_prior` to :func:`~get_hybrid_lnpost_fn`.
    par_dict : bool. Default: False
        Whether `par_set` will be an array_like (*False*) or a dict (*True*).

    Returns
    -------
    get_lnpost : function
        Definition of the function ``get_lnpost(par_set, *args, **kwargs)``.

    See also
    --------
    :func:`~get_walkers`
        Analyzes proposed `init_walkers` and returns valid `p0_walkers`.
    :attr:`~prism.Pipeline.worker_mode`
        Special context manager within which all code is executed in worker
        mode.

    Note
    ----
    `unit_space` and `par_dict` describe the form in which `par_set` is
    handed to ``get_lnpost()``, so that it can be converted to the format
    used by :class:`~prism.Pipeline`. The `par_set` handed on to
    ``ext_lnpost()`` is left unchanged.

    Warning
    -------
    Calling this function factory switches off all regular logging in
    `pipeline_obj` (:attr:`~prism.Pipeline.do_logging` becomes *False*), so
    that the same message is not logged on every call to `get_lnpost`.

    """

    # Tell the caller that this compatibility definition is going away
    warnings.warn(FutureWarning(
        "This function factory was remade into 'get_hybrid_lnpost_fn' "
        "in v1.1.3. This compatibility definition will be removed "
        "entirely in v1.2.0."), stacklevel=2)

    # Map the old argument names onto the new function factory
    forwarded = {'emul_i': emul_i,
                 'unit_space': unit_space,
                 'impl_prior': hybrid,
                 'par_dict': par_dict}
    return(get_hybrid_lnpost_fn(ext_lnpost, pipeline_obj, **forwarded))