# Source code for prism.utils.mcmc

# -*- coding: utf-8 -*-

"""
MCMC
====
Provides several functions that allow for *PRISM* to be connected more easily
to MCMC routines.

"""


# %% IMPORTS
# Future imports
from __future__ import absolute_import, division, print_function

# Built-in imports
from inspect import isfunction
import warnings

# Package imports
from e13tools import InputError
from e13tools.sampling import lhd
import numpy as np
from six import integer_types

# PRISM imports
from prism._docstrings import user_emul_i_doc
from prism._internal import (RequestError, check_vals, docstring_substitute,
                             np_array)
from prism._pipeline import Pipeline

# Public API of this module: the names exported by 'from ... import *'
__all__ = ['get_lnpost_fn', 'get_walkers']


# %% FUNCTION DEFINITIONS
# This function returns a specialized version of the lnpost function
@docstring_substitute(emul_i=user_emul_i_doc)
def get_lnpost_fn(ext_lnpost, pipeline_obj, emul_i=None, unit_space=True,
                  hybrid=True):
    """
    Returns a function definition ``get_lnpost(par_set, *args, **kwargs)``.

    This `get_lnpost` function can be used to calculate the natural logarithm
    of the posterior probability, which analyzes a given `par_set` first in
    the provided `pipeline_obj` at iteration `emul_i` and passes it to the
    `ext_lnpost` function if it is plausible.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    ext_lnpost : callable
        Function definition (or any other callable, like a bound method or a
        :func:`functools.partial` object) that needs to be called if the
        provided `par_set` is plausible in iteration `emul_i` of
        `pipeline_obj`. The used call signature is
        ``ext_lnpost(par_set, *args, **kwargs)``. All MPI ranks will call this
        function unless called within the :attr:`~prism.Pipeline.worker_mode`
        context manager.
    pipeline_obj : :obj:`~prism.Pipeline` object
        The instance of the :class:`~prism.Pipeline` class that needs to be
        used for determining the validity of the proposed sampling step.

    Optional
    --------
    %(emul_i)s
    unit_space : bool. Default: True
        Bool determining whether or not the provided sample will be given in
        unit space.
    hybrid : bool. Default: True
        Bool determining whether or not the `get_lnpost` function should use
        the implausibility values of a given `par_set` as an additional prior.

    Returns
    -------
    get_lnpost : function
        Definition of the function ``get_lnpost(par_set, *args, **kwargs)``.

    Raises
    ------
    InputError
        If `ext_lnpost` is not callable.
    TypeError
        If `pipeline_obj` is not an instance of the :class:`~prism.Pipeline`
        class.

    See also
    --------
    :func:`~get_walkers`: Analyzes proposed `init_walkers` and returns valid \
        `p0_walkers`.
    :attr:`~prism.Pipeline.worker_mode`: Special context manager within which \
        all code is executed in worker mode.

    Warning
    -------
    Calling this function factory will disable all regular logging in
    `pipeline_obj` (:attr:`~prism.Pipeline.do_logging` set to *False*), in
    order to avoid having the same message being logged every time
    `get_lnpost` is called.

    """

    # Check if ext_lnpost is callable.
    # callable() is used instead of inspect.isfunction(), since the latter
    # rejects perfectly usable callables like bound methods,
    # functools.partial objects and instances defining __call__.
    if not callable(ext_lnpost):
        raise InputError("Input argument 'ext_lnpost' is not a callable "
                         "function definition!")

    # Make abbreviation for pipeline_obj
    pipe = pipeline_obj

    # Check if provided pipeline_obj is an instance of the Pipeline class
    if not isinstance(pipe, Pipeline):
        raise TypeError("Input argument 'pipeline_obj' must be an instance of "
                        "the Pipeline class!")

    # Get emulator iteration
    emul_i = pipe._emulator._get_emul_i(emul_i, True)

    # Check if unit_space and hybrid are bools
    unit_space = check_vals(unit_space, 'unit_space', 'bool')
    hybrid = check_vals(hybrid, 'hybrid', 'bool')

    # Disable PRISM logging, as get_lnpost will be called very often
    pipe.do_logging = False

    # Define get_lnpost function
    def get_lnpost(par_set, *args, **kwargs):
        """
        Calculates the natural logarithm of the posterior probability of
        `par_set` using the provided function `ext_lnpost`, in addition to
        constraining it first with the emulator defined in the `pipeline_obj`.

        This function needs to be called by all MPI ranks unless called within
        the :attr:`~prism.Pipeline.worker_mode` context manager.

        Parameters
        ----------
        par_set : 1D array_like
            Sample to calculate the posterior probability for. This sample is
            first analyzed in `pipeline_obj` and only given to `ext_lnpost` if
            it is plausible.
        args : tuple
            Positional arguments that need to be passed to the `ext_lnpost`
            function.
        kwargs : dict
            Keyword arguments that need to be passed to the `ext_lnpost`
            function.

        Returns
        -------
        lnp : float
            The natural logarithm of the posterior probability of `par_set`,
            as determined by the `ext_lnpost` function if `par_set` is
            plausible. If `hybrid` is *True*, `lnp` is calculated as
            `lnprior` + `ext_lnpost()`, with `lnprior` the natural logarithm
            of the first implausibility cut-off value of `par_set` scaled with
            its maximum. Returns -inf if `par_set` lies outside of the
            parameter space or is not plausible.

        """

        # If unit_space is True, convert par_set to par_space
        if unit_space:
            sam = pipe._modellink._to_par_space(par_set)
        else:
            sam = par_set

        # Check if par_set is within parameter space and return -inf if not
        # (np.inf is used, as the np.infty alias was removed in NumPy 2.0)
        par_rng = pipe._modellink._par_rng
        if not ((par_rng[:, 0] <= sam) & (sam <= par_rng[:, 1])).all():
            return(-np.inf)

        # Evaluate par_set (as a 2D sample set of one sample) in the emulator
        sam_set = np_array(sam, ndmin=2)
        if hybrid:
            impl_sam, lnprior = pipe._make_call('_evaluate_sam_set', emul_i,
                                                sam_set, 'hybrid')
        else:
            impl_sam = pipe._make_call('_evaluate_sam_set', emul_i, sam_set,
                                       'analyze')
            lnprior = 0

        # If par_set is plausible, call ext_lnpost.
        # Note that ext_lnpost receives par_set exactly as it was given to
        # get_lnpost (so in unit space if unit_space is True)
        if len(impl_sam):
            return(lnprior+ext_lnpost(par_set, *args, **kwargs))

        # If par_set is not plausible, return -inf
        return(-np.inf)

    # Check if model in ModelLink can be single-called, raise warning if not
    if pipe._is_controller and not pipe._modellink._single_call:
        warn_msg = ("ModelLink bound to provided Pipeline object solely "
                    "requests multi-calls. Using MCMC may not be possible.")
        warnings.warn(warn_msg, UserWarning, stacklevel=2)

    # Return get_lnpost function definition
    return(get_lnpost)
# This function returns a set of valid MCMC walkers
@docstring_substitute(emul_i=user_emul_i_doc)
def get_walkers(pipeline_obj, emul_i=None, init_walkers=None, unit_space=True,
                ext_lnpost=None, **kwargs):
    """
    Analyzes proposed `init_walkers` and returns valid `p0_walkers`.

    Analyzes sample set `init_walkers` in the provided `pipeline_obj` at
    iteration `emul_i` and returns all samples that are plausible to be used
    as MCMC walkers. The provided samples and returned walkers should be/are
    given in unit space if `unit_space` is *True*.

    If `init_walkers` is *None*, returns :attr:`~prism.Pipeline.impl_sam`
    instead if it is available.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    pipeline_obj : :obj:`~prism.Pipeline` object
        The instance of the :class:`~prism.Pipeline` class that needs to be
        used for determining the validity of the proposed walkers.

    Optional
    --------
    %(emul_i)s
    init_walkers : 2D array_like, int or None. Default: None
        Sample set of proposed initial MCMC walker positions. All plausible
        samples in `init_walkers` will be returned.
        If int (including NumPy integer scalars), generate an LHD of provided
        size and return all plausible samples.
        If *None*, return :attr:`~prism.Pipeline.impl_sam` corresponding to
        iteration `emul_i` instead.
    unit_space : bool. Default: True
        Bool determining whether or not the provided samples and returned
        walkers are given in unit space.
    ext_lnpost : function or None. Default: None
        If function, call :func:`~get_lnpost_fn` function factory using
        `ext_lnpost` and the same values for `pipeline_obj`, `emul_i` and
        `unit_space`, and return the resulting function definition
        `get_lnpost`. Any additionally provided `kwargs` are also passed to it.

    Returns
    -------
    n_walkers : int
        Number of returned MCMC walkers.
    p0_walkers : 2D :obj:`~numpy.ndarray` object
        Array containing starting positions of valid MCMC walkers.
    get_lnpost : function (if `ext_lnpost` is a function)
        The function returned by :func:`~get_lnpost_fn` function factory using
        `ext_lnpost`, `pipeline_obj`, `emul_i`, `unit_space` and `kwargs` as
        the input values.

    Raises
    ------
    TypeError
        If `pipeline_obj` is not an instance of the :class:`~prism.Pipeline`
        class.
    InputError
        If `ext_lnpost` is invalid, or if `init_walkers` contains no
        plausible samples.

    See also
    --------
    :func:`~get_lnpost_fn`: Returns a function definition \
        ``get_lnpost(par_set, *args, **kwargs)``.
    :attr:`~prism.Pipeline.worker_mode`: Special context manager within which \
        all code is executed in worker mode.

    Notes
    -----
    If `init_walkers` is *None* and emulator iteration `emul_i` has not been
    analyzed yet, a :class:`~prism._internal.RequestError` will be raised.

    """

    # Make abbreviation for pipeline_obj
    pipe = pipeline_obj

    # Check if provided pipeline_obj is an instance of the Pipeline class
    if not isinstance(pipe, Pipeline):
        raise TypeError("Input argument 'pipeline_obj' must be an instance of "
                        "the Pipeline class!")

    # Get emulator iteration
    emul_i = pipe._emulator._get_emul_i(emul_i, True)

    # Check if unit_space is a bool
    unit_space = check_vals(unit_space, 'unit_space', 'bool')

    # Check if ext_lnpost is None and try to obtain lnpost function if not
    if ext_lnpost is not None:
        try:
            lnpost_fn = get_lnpost_fn(ext_lnpost, pipe, emul_i, unit_space,
                                      **kwargs)
        except InputError:
            raise InputError("Input argument 'ext_lnpost' is invalid!")

    # If init_walkers is None, use impl_sam of emul_i
    if init_walkers is None:
        # Controller checking if emul_i has already been analyzed
        # NOTE(review): this error is only raised on the controller, while
        # the worker ranks continue into the bcast below; confirm that this
        # cannot leave the workers waiting when running with MPI.
        if pipe._is_controller:
            # If iteration has not been analyzed, raise error
            if not pipe._n_eval_sam[emul_i]:
                raise RequestError("Emulator iteration %i has not been "
                                   "analyzed yet!" % (emul_i))
            # If iteration is last iteration, init_walkers is current impl_sam
            elif(emul_i == pipe._emulator._emul_i):
                init_walkers = pipe._impl_sam
            # If iteration is not last, init_walkers is previous impl_sam
            else:
                init_walkers = pipe._emulator._sam_set[emul_i+1]

        # Broadcast init_walkers to workers as p0_walkers
        p0_walkers = pipe._comm.bcast(init_walkers, 0)

    # If init_walkers is not None, use provided samples or LHD size
    else:
        # Controller checking if init_walkers is valid
        if pipe._is_controller:
            # NumPy integer scalars are not part of six.integer_types on
            # Python 3, so convert them to a Python int to make sure they are
            # interpreted as an LHD size rather than as a sample set
            if isinstance(init_walkers, np.integer):
                init_walkers = int(init_walkers)

            # If init_walkers is an int, create LHD of provided size
            if isinstance(init_walkers, integer_types):
                # Check if provided integer is positive
                n_sam = check_vals(init_walkers, 'init_walkers', 'pos')

                # Create LHD of provided size (directly in parameter space)
                init_walkers = lhd(n_sam, pipe._modellink._n_par,
                                   pipe._modellink._par_rng, 'center',
                                   pipe._criterion, 100)

            # If init_walkers is not an int, it must be array_like
            else:
                # Make sure that init_walkers is a NumPy array
                init_walkers = np_array(init_walkers, ndmin=2)

                # If unit_space is True, convert init_walkers to par_space
                if unit_space:
                    init_walkers = pipe._modellink._to_par_space(init_walkers)

                # Check if init_walkers is valid
                init_walkers = pipe._modellink._check_sam_set(init_walkers,
                                                              'init_walkers')

        # Broadcast init_walkers to workers
        init_walkers = pipe._comm.bcast(init_walkers, 0)

        # Analyze init_walkers and save them as p0_walkers
        p0_walkers = pipe._evaluate_sam_set(emul_i, init_walkers, 'analyze')

    # Calculate n_walkers
    n_walkers = len(p0_walkers)

    # Check if p0_walkers is not empty
    if not n_walkers:
        raise InputError("Input argument 'init_walkers' contains no plausible "
                         "samples!")

    # Check if p0_walkers needs to be converted
    if unit_space:
        p0_walkers = pipe._modellink._to_unit_space(p0_walkers)

    # Check if lnpost_fn was requested and return it as well if so
    if ext_lnpost is not None:
        return(n_walkers, p0_walkers, lnpost_fn)
    else:
        return(n_walkers, p0_walkers)
% (emul_i)) # If iteration is last iteration, init_walkers is current impl_sam elif(emul_i == pipe._emulator._emul_i): init_walkers = pipe._impl_sam # If iteration is not last, init_walkers is previous impl_sam else: init_walkers = pipe._emulator._sam_set[emul_i+1] # Broadcast init_walkers to workers as p0_walkers p0_walkers = pipe._comm.bcast(init_walkers, 0) # If init_walkers is not None, use provided samples or LHD size else: # Controller checking if init_walkers is valid if pipe._is_controller: # If init_walkers is an int, create LHD of provided size if isinstance(init_walkers, integer_types): # Check if provided integer is positive n_sam = check_vals(init_walkers, 'init_walkers', 'pos') # Create LHD of provided size init_walkers = lhd(n_sam, pipe._modellink._n_par, pipe._modellink._par_rng, 'center', pipe._criterion, 100) # If init_walkers is not an int, it must be array_like else: # Make sure that init_walkers is a NumPy array init_walkers = np_array(init_walkers, ndmin=2) # If unit_space is True, convert init_walkers to par_space if unit_space: init_walkers = pipe._modellink._to_par_space(init_walkers) # Check if init_walkers is valid init_walkers = pipe._modellink._check_sam_set(init_walkers, 'init_walkers') # Broadcast init_walkers to workers init_walkers = pipe._comm.bcast(init_walkers, 0) # Analyze init_walkers and save them as p0_walkers p0_walkers = pipe._evaluate_sam_set(emul_i, init_walkers, 'analyze') # Calculate n_walkers n_walkers = len(p0_walkers) # Check if p0_walkers is not empty if not n_walkers: raise InputError("Input argument 'init_walkers' contains no plausible " "samples!") # Check if p0_walkers needs to be converted if unit_space: p0_walkers = pipe._modellink._to_unit_space(p0_walkers) # Check if lnpost_fn was requested and return it as well if so if ext_lnpost is not None: return(n_walkers, p0_walkers, lnpost_fn) else: return(n_walkers, p0_walkers)