# -*- coding: utf-8 -*-
"""
MCMC
====
Provides several functions that allow for *PRISM* to be connected more easily
to MCMC routines.
"""
# %% IMPORTS
# Built-in imports
from inspect import isfunction
import warnings
# Package imports
import e13tools as e13
import numpy as np
from numpy.random import multivariate_normal
from sortedcontainers import SortedDict as sdict
from tqdm import tqdm
# PRISM imports
from prism._docstrings import user_emul_i_doc
from prism._internal import RequestError, check_vals, np_array
from prism._pipeline import Pipeline
# Public API of this module: the names exported by a star-import
__all__ = ['get_hybrid_lnpost_fn', 'get_walkers']
# %% FUNCTION DEFINITIONS
# This function returns a hybrid version of the lnpost function
@e13.docstring_substitute(emul_i=user_emul_i_doc)
def get_hybrid_lnpost_fn(lnpost_fn, pipeline_obj, *, emul_i=None,
                         unit_space=False, impl_prior=True, par_dict=False):
    """
    Returns a function definition ``hybrid_lnpost(par_set, *args, **kwargs)``.

    This `hybrid_lnpost()` function can be used to calculate the natural
    logarithm of the posterior probability, which analyzes a given `par_set`
    first in the provided `pipeline_obj` at iteration `emul_i` and passes it to
    `lnpost_fn` if it is plausible.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    lnpost_fn : callable
        Callable that needs to be called if the provided `par_set` is plausible
        in iteration `emul_i` of `pipeline_obj`. The used call signature is
        ``lnpost_fn(par_set, *args, **kwargs)``. Any callable object (function,
        bound method, partial, etc.) is accepted. All MPI ranks will call this
        function unless called within the :attr:`~prism.Pipeline.worker_mode`
        context manager.
    pipeline_obj : :obj:`~prism.Pipeline` object
        The instance of the :class:`~prism.Pipeline` class that needs
        to be used for determining the validity of the proposed sampling step.

    Optional
    --------
    %(emul_i)s
    unit_space : bool. Default: False
        Bool determining whether or not `par_set` will be given in unit space.
    impl_prior : bool. Default: True
        Bool determining whether or not the `hybrid_lnpost()` function should
        use the implausibility values of a given `par_set` as an additional
        prior.
    par_dict : bool. Default: False
        Bool determining whether or not `par_set` will be an array_like
        (*False*) or a dict (*True*).

    Returns
    -------
    hybrid_lnpost : function
        Definition of the function ``hybrid_lnpost(par_set, *args, **kwargs)``.

    Raises
    ------
    :class:`~e13tools.InputError`
        If `lnpost_fn` is not callable or if `pipeline_obj` does not use a
        default emulator.
    TypeError
        If `pipeline_obj` is not an instance of :class:`~prism.Pipeline`.

    See also
    --------
    :func:`~get_walkers`
        Analyzes proposed `init_walkers` and returns valid `p0_walkers`.

    :attr:`~prism.Pipeline.worker_mode`
        Special context manager within which all code is executed in worker
        mode.

    Note
    ----
    The input arguments `unit_space` and `par_dict` state in what form
    `par_set` will be provided to the `hybrid_lnpost()` function, such that it
    can be properly converted to the format used in :class:`~prism.Pipeline`.
    The `par_set` that is passed to `lnpost_fn` is unchanged.

    Warning
    -------
    Calling this function factory will disable all regular logging in
    `pipeline_obj` (:attr:`~prism.Pipeline.do_logging` set to *False*), in
    order to avoid having the same message being logged every time
    `hybrid_lnpost()` is called.

    """

    # Check if lnpost_fn can be called; it is only ever invoked, so any
    # callable (function, bound method, partial, ...) is acceptable
    if not callable(lnpost_fn):
        raise e13.InputError("Input argument 'lnpost_fn' is not a callable "
                             "function definition!")

    # Make abbreviation for pipeline_obj
    pipe = pipeline_obj

    # Check if provided pipeline_obj is an instance of the Pipeline class
    if not isinstance(pipe, Pipeline):
        raise TypeError("Input argument 'pipeline_obj' must be an instance of "
                        "the Pipeline class!")

    # Check if the provided pipeline_obj uses a default emulator
    if(pipe._emulator._emul_type != 'default'):
        raise e13.InputError("Input argument 'pipeline_obj' does not use a "
                             "default emulator!")

    # Get emulator iteration
    emul_i = pipe._emulator._get_emul_i(emul_i)

    # Check if unit_space is a bool
    unit_space = check_vals(unit_space, 'unit_space', 'bool')

    # Check if impl_prior is a bool
    impl_prior = check_vals(impl_prior, 'impl_prior', 'bool')

    # Check if par_dict is a bool
    par_dict = check_vals(par_dict, 'par_dict', 'bool')

    # Disable PRISM logging
    pipe.do_logging = False

    # Define hybrid_lnpost function
    def hybrid_lnpost(par_set, *args, **kwargs):
        """
        Calculates the natural logarithm of the posterior probability of
        `par_set` using the provided function `lnpost_fn`, in addition to
        constraining it first with the emulator defined in the `pipeline_obj`.

        This function needs to be called by all MPI ranks unless called within
        the :attr:`~prism.Pipeline.worker_mode` context manager.

        Parameters
        ----------
        par_set : 1D array_like or dict
            Sample to calculate the posterior probability for. This sample is
            first analyzed in `pipeline_obj` and only given to `lnpost_fn` if
            it is plausible. If `par_dict` is *True*, this is a dict.
        args : positional arguments
            Positional arguments that need to be passed to `lnpost_fn`.
        kwargs : keyword arguments
            Keyword arguments that need to be passed to `lnpost_fn`.

        Returns
        -------
        lnp : float
            The natural logarithm of the posterior probability of `par_set`, as
            determined by `lnpost_fn` if `par_set` is plausible. If
            `impl_prior` is *True*, `lnp` is calculated as `lnprior` +
            `lnpost_fn()`, with `lnprior` the natural logarithm of the first
            implausibility cut-off value of `par_set` scaled with its maximum.
            If `par_set` lies outside parameter space or is implausible, this
            is negative infinity.

        """

        # If par_dict is True, convert par_set to a NumPy array
        # (sorting the dict orders the values by parameter name)
        if par_dict:
            sam = np_array(sdict(par_set).values(), ndmin=2)
        else:
            sam = np_array(par_set, ndmin=2)

        # If unit_space is True, convert par_set to par_space
        if unit_space:
            sam = pipe._modellink._to_par_space(sam)

        # Check if par_set is within parameter space and return -inf if not
        par_rng = pipe._modellink._par_rng
        if not ((par_rng[:, 0] <= sam[0])*(sam[0] <= par_rng[:, 1])).all():
            # np.inf is used as the np.infty alias was removed in NumPy 2.0
            return(-np.inf)

        # Check what sampling is requested and analyze par_set
        if impl_prior:
            impl_sam, lnprior = pipe._make_call('_evaluate_sam_set', emul_i,
                                                sam, 'hybrid')
        else:
            impl_sam = pipe._make_call('_evaluate_sam_set', emul_i, sam,
                                       'analyze')
            lnprior = 0

        # If par_set is plausible, call lnpost_fn with the unmodified par_set
        if len(impl_sam):
            return(lnprior+lnpost_fn(par_set, *args, **kwargs))

        # If par_set is not plausible, return -inf
        else:
            return(-np.inf)

    # Check if model in ModelLink can be single-called, raise warning if not
    if pipe._is_controller and not pipe._modellink._single_call:
        warn_msg = ("ModelLink bound to provided Pipeline object solely "
                    "requests multi-calls. Using MCMC may not be possible.")
        warnings.warn(warn_msg, UserWarning, stacklevel=2)

    # Return hybrid_lnpost function definition
    return(hybrid_lnpost)
# This function returns a set of valid MCMC walkers
@e13.docstring_substitute(emul_i=user_emul_i_doc)
def get_walkers(pipeline_obj, *, emul_i=None, init_walkers=None,
                req_n_walkers=None, unit_space=False, lnpost_fn=None,
                **kwargs):
    """
    Analyzes proposed `init_walkers` and returns plausible `p0_walkers`.

    Analyzes sample set `init_walkers` in the provided `pipeline_obj` at
    iteration `emul_i` and returns all samples that are plausible to be used as
    starting positions for MCMC walkers. The provided samples and returned
    walkers should be/are given in unit space if `unit_space` is *True*.

    If `init_walkers` is *None*, returns :attr:`~prism.Pipeline.impl_sam`
    instead if it is available.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    pipeline_obj : :obj:`~prism.Pipeline` object
        The instance of the :class:`~prism.Pipeline` class that needs to be
        used for determining the plausibility of the proposed starting
        positions.

    Optional
    --------
    %(emul_i)s
    init_walkers : 2D array_like, dict, int or None. Default: None
        Sample set of proposed initial MCMC walker positions. All plausible
        samples in `init_walkers` will be returned.
        If int, generate an LHD of provided size and return all plausible
        samples.
        If *None*, return :attr:`~prism.Pipeline.impl_sam` corresponding to
        iteration `emul_i` instead.
    req_n_walkers : int or None. Default: None
        The minimum required number of plausible starting positions that should
        be returned. If *None*, all plausible starting positions in
        `init_walkers` are returned instead.

        .. versionadded:: 1.2.0
    unit_space : bool. Default: False
        Bool determining whether or not the provided samples and returned
        walkers are given in unit space.
    lnpost_fn : function or None. Default: None
        If function, call :func:`~get_hybrid_lnpost_fn` using `lnpost_fn` and
        the same values for `pipeline_obj`, `emul_i` and `unit_space`, and
        return the resulting function definition `hybrid_lnpost()`. Any
        additionally provided `kwargs` are also passed to it.

    Returns
    -------
    n_walkers : int
        Number of returned MCMC walkers. Note that this number can be higher
        than `req_n_walkers` if not *None*.
    p0_walkers : 2D :obj:`~numpy.ndarray` object or dict
        Array containing plausible starting positions of valid MCMC walkers.
        If `init_walkers` was provided as a dict, `p0_walkers` will be a dict.
    hybrid_lnpost : function (if `lnpost_fn` is a function)
        The function returned by :func:`~get_hybrid_lnpost_fn` using
        `lnpost_fn`, `pipeline_obj`, `emul_i`, `unit_space` and `kwargs` as the
        input values.

    See also
    --------
    :func:`~get_hybrid_lnpost_fn`
        Returns a function definition ``hybrid_lnpost(par_set, *args,
        **kwargs)``.

    :attr:`~prism.Pipeline.worker_mode`
        Special context manager within which all code is executed in worker
        mode.

    Notes
    -----
    If `init_walkers` is *None* and emulator iteration `emul_i` has not been
    analyzed yet, a :class:`~prism._internal.RequestError` will be raised.

    If `req_n_walkers` is not *None*, a custom Metropolis-Hastings sampling
    algorithm is used to generate the required number of starting positions.
    All plausible samples in `init_walkers` are used as the start of every MCMC
    chain. Note that if the number of plausible samples in `init_walkers` is
    small, it is possible that the returned `p0_walkers` are not spread out
    properly over parameter space.

    """

    # Make abbreviation for pipeline_obj
    pipe = pipeline_obj

    # Check if provided pipeline_obj is an instance of the Pipeline class
    if not isinstance(pipe, Pipeline):
        raise TypeError("Input argument 'pipeline_obj' must be an instance of "
                        "the Pipeline class!")

    # Check if the provided pipeline_obj uses a default emulator
    if(pipe._emulator._emul_type != 'default'):
        raise e13.InputError("Input argument 'pipeline_obj' does not use a "
                             "default emulator!")

    # Get emulator iteration
    emul_i = pipe._emulator._get_emul_i(emul_i)

    # If req_n_walkers is not None, check if it is an integer
    if req_n_walkers is not None:
        req_n_walkers = check_vals(req_n_walkers, 'req_n_walkers', 'int',
                                   'pos')

    # Check if unit_space is a bool
    unit_space = check_vals(unit_space, 'unit_space', 'bool')

    # Walkers are returned as a dict if init_walkers was provided as one.
    # This is determined on every MPI rank (before init_walkers is converted
    # or overwritten by the broadcast below), such that all ranks return
    # p0_walkers in the same format
    walker_dict = isinstance(init_walkers, dict)

    # Check if lnpost_fn is None and try to get hybrid_lnpost function if not
    if lnpost_fn is not None:
        try:
            hybrid_lnpost =\
                get_hybrid_lnpost_fn(lnpost_fn, pipe, emul_i=emul_i,
                                     unit_space=unit_space, **kwargs)
        except e13.InputError as error:
            # Keep the original error attached as the direct cause
            raise e13.InputError("Input argument 'lnpost_fn' is invalid!"
                                 ) from error

    # If init_walkers is None, use impl_sam of emul_i
    if init_walkers is None:
        # Controller checking if emul_i has already been analyzed
        if pipe._is_controller:
            # If iteration has not been analyzed, raise error
            if not pipe._n_eval_sam[emul_i]:
                raise RequestError("Emulator iteration %i has not been "
                                   "analyzed yet!" % (emul_i))
            # If iteration is last iteration, init_walkers is current impl_sam
            elif(emul_i == pipe._emulator._emul_i):
                init_walkers = pipe._impl_sam
            # If iteration is not last, init_walkers is previous impl_sam
            else:
                init_walkers = pipe._emulator._sam_set[emul_i+1]

            # Make sure to make a copy of init_walkers to avoid modifications
            init_walkers = init_walkers.copy()

        # Broadcast init_walkers to workers as p0_walkers
        p0_walkers = pipe._comm.bcast(init_walkers, 0)

    # If init_walkers is not None, use provided samples or LHD size
    else:
        # Controller checking if init_walkers is valid
        if pipe._is_controller:
            # If init_walkers is an int, create LHD of provided size
            if isinstance(init_walkers, int):
                # Check if provided integer is positive
                n_sam = check_vals(init_walkers, 'init_walkers', 'pos')

                # Obtain the par_space to sample in
                par_space = pipe._get_impl_space(emul_i)

                # If par_space is None, use the corresponding emul_space
                if par_space is None:
                    par_space = pipe._emulator._emul_space[emul_i]

                # Create LHD of provided size
                init_walkers = e13.lhd(n_sam, pipe._modellink._n_par,
                                       par_space, 'center', pipe._criterion,
                                       100)

            # If init_walkers is not an int, it must be array_like or dict
            else:
                # If init_walkers is provided as a dict, convert it
                if walker_dict:
                    # Make sure that init_walkers is a SortedDict
                    init_walkers = sdict(init_walkers)

                    # Convert it to normal
                    init_walkers = np_array(init_walkers.values()).T

                # Make sure that init_walkers is a NumPy array
                init_walkers = np_array(init_walkers, ndmin=2)

                # If unit_space is True, convert init_walkers to par_space
                if unit_space:
                    init_walkers = pipe._modellink._to_par_space(init_walkers)

                # Check if init_walkers is valid
                init_walkers = pipe._modellink._check_sam_set(init_walkers,
                                                              'init_walkers')

        # Broadcast init_walkers to workers
        init_walkers = pipe._comm.bcast(init_walkers, 0)

        # Analyze init_walkers and save them as p0_walkers
        p0_walkers = pipe._evaluate_sam_set(emul_i, init_walkers, 'analyze')

    # Check if init_walkers is not empty and raise error if it is
    if not p0_walkers.shape[0]:
        raise e13.InputError("Input argument 'init_walkers' contains no "
                             "plausible samples!")

    # If req_n_walkers is not None, use MH MCMC to find all required walkers
    if req_n_walkers is not None:
        n_walkers, p0_walkers = _do_mh_walkers(pipe, p0_walkers, req_n_walkers)
    # Otherwise, return all unique plausible samples
    else:
        p0_walkers = np.unique(p0_walkers, axis=0)
        n_walkers = p0_walkers.shape[0]

    # Check if p0_walkers needs to be converted
    if unit_space:
        p0_walkers = pipe._modellink._to_unit_space(p0_walkers)

    # Check if p0_walkers needs to be returned as a dict
    if walker_dict:
        p0_walkers = pipe._modellink._get_sam_dict(p0_walkers)

    # Check if hybrid_lnpost was requested and return it as well if so
    if lnpost_fn is not None:
        return(n_walkers, p0_walkers, hybrid_lnpost)
    else:
        return(n_walkers, p0_walkers)
# %% HIDDEN FUNCTION DEFINITIONS
# This function uses MH sampling to find req_n_walkers initial positions
def _do_mh_walkers(pipeline_obj, p0_walkers, req_n_walkers):
    """
    Generates `req_n_walkers` unique starting positions for MCMC walkers using
    Metropolis-Hastings sampling and the provided `pipeline_obj` and
    `p0_walkers`.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    pipeline_obj : :obj:`~prism.Pipeline` object
        The instance of the :class:`~prism.Pipeline` class that needs to be
        used for determining the validity of a proposed sampling step.
    p0_walkers : 2D :obj:`~numpy.ndarray` object
        Sample set of starting positions for the MCMC chains. This array is
        modified in-place: accepted proposals overwrite the corresponding
        rows.
    req_n_walkers : int
        The minimum required number of unique MCMC walker positions that must
        be returned.

    Returns
    -------
    n_walkers : int
        Number of unique MCMC walker positions that are returned.
    walkers : 2D :obj:`~numpy.ndarray` object
        Array containing all unique MCMC walker positions.

    Note
    ----
    Executing this function will temporarily disable all regular logging in the
    provided :obj:`~prism.Pipeline` object. If logging was enabled before this
    function was executed, it will be enabled again afterward, even if an
    exception is raised while searching for walkers.

    """

    # Make abbreviation for pipeline_obj
    pipe = pipeline_obj

    # Define function to check if proposed sam_set should be accepted
    def advance_chain(sam_set):
        # Make sure that sam_set is 2D
        sam_set = np_array(sam_set, ndmin=2)

        # Check if sam_set is within parameter space and reject if not
        par_rng = pipe._modellink._par_rng
        accept = ((par_rng[:, 0] <= sam_set)*(sam_set <= par_rng[:, 1])).all(1)

        # Evaluate all non-rejected samples and accept if plausible
        # NOTE(review): this always uses the last emulator iteration, not the
        # iteration that was requested by the caller - confirm this is intended
        emul_i = pipe._emulator._emul_i
        accept[accept] = pipe._make_call('_evaluate_sam_set', emul_i,
                                         sam_set[accept], 'project')[0]

        # Return which samples should be accepted or rejected
        return(accept)

    # Initialize array of final walkers
    # Its size covers the worst case: req_n_walkers-1 walkers found so far,
    # followed by an iteration in which all proposals are accepted
    n_walkers = p0_walkers.shape[0]
    walkers = np.empty([req_n_walkers+n_walkers-1, pipe._modellink._n_par])
    walkers[:n_walkers] = p0_walkers

    # Check if logging is currently turned on
    was_logging = bool(pipe.do_logging)

    # Make sure that logging is turned off
    pipe.do_logging = False

    try:
        # Use worker mode
        with pipe.worker_mode:
            if pipe._is_controller:
                # Initialize progress bar
                pbar = tqdm(desc="Finding walkers", total=req_n_walkers,
                            initial=n_walkers, disable=not was_logging,
                            bar_format=("{l_bar}{bar}| {n_fmt}/{total_fmt} "
                                        "[Time: {elapsed}]"))

                # Keep searching parameter space until req_n_walkers are found
                while(n_walkers < req_n_walkers):
                    # Calculate the covariance matrix of all walkers
                    # np.cov returns a 0D array for a single parameter, while
                    # multivariate_normal requires a 2D covariance matrix
                    # NOTE(review): with a single walker, np.cov presumably
                    # yields NaNs and no proposal can be accepted - verify
                    cov = np.atleast_2d(np.cov(walkers[:n_walkers].T))

                    # Create set of proposed walkers
                    new_walkers = np.apply_along_axis(multivariate_normal, 1,
                                                      p0_walkers, cov)

                    # Check which proposed walkers should be accepted
                    accept = advance_chain(new_walkers)
                    acc_walkers = new_walkers[accept]
                    n_accepted = np.count_nonzero(accept)

                    # Replace current walkers with accepted walkers
                    p0_walkers[accept] = acc_walkers

                    # Update final walkers array
                    walkers[n_walkers:n_walkers+n_accepted] = acc_walkers
                    n_walkers += n_accepted

                    # Update progress bar
                    pbar.update(min(req_n_walkers, n_walkers)-pbar.n)

                # Close progress bar
                pbar.close()
    finally:
        # Turn logging back on if it used to be on
        pipe.do_logging = was_logging

    # Broadcast walkers to all workers
    walkers = pipe._comm.bcast(np.unique(walkers[:req_n_walkers], axis=0), 0)
    n_walkers = walkers.shape[0]

    # Return n_walkers, walkers
    return(n_walkers, walkers)