# Source code for ACID_code_v2.result

from time import time
import numpy as np
import matplotlib.pyplot as plt
import corner, sys, os, pickle, warnings, contextlib, functools, inspect
from emcee import EnsembleSampler
import emcee.backends.backend as emceebackend
from beartype import beartype
from scipy.interpolate import interp1d
from numpy import integer as npint
from numpy.polynomial import polynomial as P
from .lsd import LSD
from . import mcmc
from . import utils
from .data import Data
from .data import Config

# NOTE(review): importing this module installs a process-wide "ignore" filter for every warning
# category (filterwarnings with default arguments), so warnings emitted by numpy, emcee,
# matplotlib, etc. are hidden for the whole program, not just for this module.
warnings.filterwarnings("ignore")

def _require_all_frames(method):
    """Decorate a Result method so that ``all_frames`` is populated before it runs.

    When ``self.all_frames`` is still None, the wrapper calls ``self.process_results()`` first,
    provided both ``self.sampler`` and ``self.data`` are set (printing a note when
    ``self.config.verbose > 0``). If either is missing, a ValueError naming the decorated
    method is raised instead.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # Fast path: results already processed.
        if self.all_frames is not None:
            return method(self, *args, **kwargs)

        qualname = method.__qualname__
        can_process = self.sampler is not None and self.data is not None
        if not can_process:
            raise ValueError(
                f"Cannot call {qualname}. The all_frames attribute is not available, and no "
                "sampler and data objects are available to process results. Please pass an Acid "
                "object after running ACID to the results init."
            )

        if self.config.verbose > 0:
            print(
                f"Note: The Result object was created without all_frames processed. "
                f"Running {qualname} requires all results to be processed, "
                "so process_results() will be called automatically..."
            )
        self.process_results()
        return method(self, *args, **kwargs)

    return wrapper

def _require_data(method):
    """Decorate a Result method so it refuses to run without a Data object.

    Raises ValueError (naming the decorated method) when ``self.data`` is None; otherwise the
    wrapped method is called unchanged.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.data is not None:
            return method(self, *args, **kwargs)
        raise ValueError(
            f"Cannot call {method.__qualname__}. The Data object is not available in this "
            "Result instance. This can occur if Data was set to None to allow for pickling in the "
            "case that multiple orders or frames were used."
        )

    return wrapper

def _require_sampler(method):
    """Decorate a Result method so a sampler is initialised before it runs.

    The call's arguments are bound against the method signature. If the method declares a
    ``store_sampler`` parameter and it was explicitly passed as False, the method runs
    directly (``save_result`` can then work without a sampler). Otherwise
    ``self.initiate_sampler`` is called with the caller's ``sampler`` argument when the
    method declares one, or with None; on Result, ``initiate_sampler`` raises ValueError when
    no sampler is available at all.
    """
    signature = inspect.signature(method)
    has_store_flag = "store_sampler" in signature.parameters
    has_sampler_arg = "sampler" in signature.parameters

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        supplied = signature.bind_partial(self, *args, **kwargs).arguments

        # Carve-out for save_result(store_sampler=False): no sampler needed.
        if has_store_flag and supplied.get("store_sampler", True) is False:
            return method(self, *args, **kwargs)

        candidate = supplied.get("sampler", None) if has_sampler_arg else None
        self.initiate_sampler(candidate)
        return method(self, *args, **kwargs)

    return wrapper

@beartype
class Result:
    """Class to handle the results from the Acid MCMC sampling, and results processing.

    Fundamentally, this class requires two objects to run, the Sampler object and the Data
    object, both of which can be obtained from the Acid object. If one or the other is not
    provided, some methods will not work.
    """

    def __init__(
        self,
        Acid_or_Data_or_Sampler,
        sampler=None,
        process_results: bool = True,
        ACID_HARPS: bool = False,
        verbose: int | bool | None = None,
    ):
        """Initiate Result class

        Parameters
        ----------
        Acid_or_Data_or_Sampler : Acid | Data | emcee.EnsembleSampler
            An Acid object, Data object (contained in Acid class), or sampler object. If an
            Acid object is provided, all other arguments are taken from there. If a Data object
            is provided, a sampler can be provided in the second argument. If a sampler object
            is provided, it will be used as the sampler, but all other attributes will need to
            be set manually for the Result object to be fully functional.
        sampler : emcee.EnsembleSampler | None, optional
            A sampler object to use if the Data object was provided. If an Acid object was
            provided, the sampler will be taken from there. If a sampler object was provided in
            the first argument, this will be ignored (with a warning), by default None
        process_results : bool, optional
            Whether to process the results upon initialisation, by default True. If False, the
            all_frames attribute will not be available until Result.process_results() is
            called. process_results() runs LSD, which can be skipped to save time when only
            methods that do not require all_frames are needed, such as continue_sampling() or
            plot_walkers(). Processing requires a Data object with the necessary attributes
            and a sampler, or an Acid object with both already set.
        ACID_HARPS : bool, optional
            Whether the ACID_HARPS function was used, by default False
        verbose : int | bool | None, optional
            Verbosity level, works exactly the same as Acid verbosity. If None, the verbosity
            of the provided Acid/Data config is used, or 2 when only a sampler is provided. If
            given, it is written onto the config object, which is shared (not copied) with the
            provided Acid/Data object.

        Raises
        ------
        ValueError
            If the first argument is not an Acid, Data or emcee.EnsembleSampler object, or if
            the required data/sampler for the requested processing is missing.
        """
        obj = Acid_or_Data_or_Sampler
        self.sampler = None
        self.data = None
        self.all_frames = None
        # Default config (used when only a sampler is given); replaced below for Acid/Data input.
        self.config = Config()
        self.config.verbose = 2 if verbose is None else verbose

        # ACID_HARPS outputs (None unless ACID_HARPS was run). Set before the branches below so
        # the sampler-only path, which returns early, still defines them.
        self.ACID_HARPS = ACID_HARPS
        self.BJDs = getattr(obj, "BJDs", None)
        self.profiles = getattr(obj, "profiles", None)
        self.errors = getattr(obj, "errors", None)

        if hasattr(obj, "data") and hasattr(obj, "config") and hasattr(obj, "sampler"):
            # Only an Acid object sets all three at initialisation; Data and samplers do not.
            acid = obj
            self.initiate_data(acid.data)
            self.config = acid.config
            if verbose is not None:
                self.config.verbose = verbose
            self.initiate_sampler(acid.sampler)
        elif isinstance(obj, Data):
            self.initiate_data(obj)
            self.config = obj.config
            if verbose is not None:
                self.config.verbose = verbose
            if sampler is not None:
                self.initiate_sampler(sampler)
        elif isinstance(obj, EnsembleSampler):
            if sampler is not None and self.config.verbose > 0:
                print("Warning: A sampler was provided as the first argument, so the 'sampler' "
                      "argument is ignored.")
            self.initiate_sampler(obj)
            if self.config.verbose > 0:
                print("Warning: Data object not provided. Result object will not be fully functional.")
            return
        else:
            raise ValueError("First argument must be an Acid object, Data object, or emcee.EnsembleSampler object. "
                             f"Got {type(obj)} instead.")

        # From here a Data instance is available; the sampler may or may not be. all_frames
        # lives on Result for legacy reasons and Data.all_frames is pointed at it once built.
        if process_results:
            self.process_results()  # sets self.all_frames and self.data.all_frames
        elif self.config.verbose > 0:
            print("Warning: Results not processed. all_frames attribute will not be available until "
                  "Result.process_results() is called.")

    @_require_data
    @_require_sampler
    def process_results(self):
        """Summarise the posterior and rebuild the per-frame LSD profiles.

        Takes the 16th/50th/84th percentiles of the flattened chain (after discarding
        ``self.burnin`` steps and thinning by ``self.thin``) to set ``self.profile``,
        ``self.profile_err``, ``self.poly_cos`` and ``self.poly_cos_err``. Then, for every
        input frame, divides out the fitted continuum (propagating its posterior spread into
        the errors), inflates the errors of residual-masked and negative-flux pixels to 1e12,
        runs LSD, and stores ``[profile, profile_errors]`` in
        ``self.all_frames[frame, self.config.order]``. ``self.data.all_frames`` is pointed at
        the same array, and timing totals are written onto ``self.data``.
        """
        t0 = time()

        # Post-burnin, thinned samples, flattened over walkers.
        flat_samples = self.sampler.get_chain(discard=self.burnin, thin=self.thin, flat=True)

        # With a deterministic profile the chain contains only continuum coefficients.
        nvel = len(self.data.velocities) if self.config.deterministic_profile is False else 0
        quartiles = np.percentile(flat_samples, [16, 50, 84], axis=0)
        # Symmetrised error: the larger of (median - 16th) and (84th - median).
        errors = np.max(np.diff(quartiles, axis=0), axis=0)
        self.profile = quartiles[1, :nvel]
        self.profile_err = errors[:nvel]
        self.poly_cos = quartiles[1, nvel:]
        self.poly_cos_err = errors[nvel:]

        if self.config.verbose > 1:
            print('Getting the final profiles...')

        # Continuum uncertainty: spread of the continuum polynomial evaluated for every sample.
        norm_wl = self.data.wavelengths["combined_normalized"]
        coeffs = flat_samples[:, nvel:]
        powers = np.vander(norm_wl, N=coeffs.shape[1], increasing=True)
        continuum_error = np.std(coeffs @ powers.T, axis=0)

        for counter in range(len(self.data.flux["input"])):
            flux = np.copy(self.data.flux["input"][counter])
            error = np.copy(self.data.errors["input"][counter])
            wavelengths = np.copy(self.data.wavelengths["input"][counter])
            sn = np.copy(self.data.sn["input"][counter])

            flux = flux[self.data.nanmask]
            error = error[self.data.nanmask]
            wavelengths = wavelengths[self.data.nanmask]

            # Continuum model for this frame, evaluated on normalised wavelengths.
            a, b = utils.get_normalisation_coeffs(wavelengths)
            norm_wavelengths = (a * wavelengths) + b
            mdl1 = P.polyval(norm_wavelengths, self.poly_cos)

            # Residual mask defined on the grid of the frame with the largest S/N (presumably
            # sn["input"] holds one value per frame), interpolated onto this frame's grid.
            reference_wave = self.data.wavelengths["input"][np.nanargmax(self.data.sn["input"])]
            reference_wave = reference_wave[self.data.nanmask]
            mask_pos = np.ones(reference_wave.shape)
            mask_pos[self.data.residual_masks] = 1e12
            f2 = interp1d(reference_wave, mask_pos, bounds_error=False, fill_value=np.nan)
            interp_mask_pos = f2(wavelengths)
            error[interp_mask_pos >= 1e12] = 1e12

            # Divide out the continuum and add its uncertainty in quadrature.
            error = np.sqrt((error / mdl1) ** 2 + (continuum_error / mdl1) ** 2)
            flux /= mdl1
            remove = flux < 0
            flux[remove] = 1.
            error[remove] = 1e12

            LSD_profiles = LSD(self.data)
            LSD_profiles.run_LSD(wavelengths, flux, error, sn=sn)
            self.all_frames[counter, self.config.order] = [
                LSD_profiles.profile_F,
                LSD_profiles.profile_errors_F,
            ]

        # Point Data.all_frames at Result.all_frames to keep them in sync.
        self.data.all_frames = self.all_frames
        self.data.get_profiles_time = time() - t0
        self.data.full_run_time = self.data.initialisation_time + self.data.mcmc_time + self.data.get_profiles_time
        return

    @_require_all_frames
    def __getitem__(self, item):
        """Allows indexing into the all_frames array directly from the Result object.

        When ACID_HARPS was used, returns the ``(BJDs, profiles, errors)`` entries at ``item``.
        """
        if self.ACID_HARPS:
            return self.BJDs[item], self.profiles[item], self.errors[item]
        return self.all_frames[item]

    @_require_all_frames
    def __iter__(self):
        """Allows iteration over the BJDs, profiles, and errors if ACID_HARPS was used.

        Raises
        ------
        TypeError
            If ACID_HARPS was not used.
        """
        if self.ACID_HARPS:
            return iter((self.BJDs, self.profiles, self.errors))
        # Acid is not subscriptable normally, only when ACID_HARPS was called
        raise TypeError("Result is not iterable unless ACID_HARPS=True")

    def __repr__(self):
        # Only summarise sampler/data and all_frames availability, to avoid printing large arrays
        return f"Result object with sampler={self.sampler}, data={self.data}, all_frames={'available' if self.all_frames is not None else 'not available'}"

    @_require_data
    @_require_sampler
    def continue_sampling(self, process_results: bool = True, sampler=None, **kwargs):
        """Continue MCMC sampling for additional steps.

        Passes the stored sampler into an Acid instance built from the saved data. See
        Acid.continue_sampling() for more details on the parameters that can be passed.

        Parameters
        ----------
        nsteps : int | None, optional
            Number of additional MCMC steps to run. Passed to Acid.continue_sampling with the
            stored sampler.
        max_steps : int | None, optional
            Maximum number of MCMC steps to run, by default None. Passed to
            Acid.continue_sampling with the stored sampler.
        max_steps_kwards : dict, optional
            Additional keyword arguments to be passed to the run_mcmc_until_converged function
            if max_steps is specified, by default None. The kwargs description can be found in
            Acid.ACID(), they are the 4 kwargs appearing after max_steps. Typos for kwargs are
            silently ignored. Passed to Acid.continue_sampling with the stored sampler.
        process_results : bool, optional
            Whether to process the results after continuing sampling, by default True. If
            False, the all_frames attribute will not be updated until
            Result.process_results() is called.
        sampler : emcee.EnsembleSampler | None, optional
            Optionally provide a different sampler to continue sampling from, otherwise, takes
            the sampler from the Result object, by default None

        Raises
        ------
        ValueError
            If process_results is an int (most likely a positional nsteps by mistake).
        """
        if type(process_results) is int:
            raise ValueError("The process_results attribute must be a boolean, not an integer. Did you mean to set nsteps? If so, specify nsteps=nsteps.")
        # Imported here rather than at module level (presumably to avoid a circular import).
        from .acid import Acid
        acid = Acid(data=self.data)  # includes config data
        self.sampler = acid.continue_sampling(self.sampler, **kwargs)
        self.initiate_sampler(self.sampler)  # update burnin/thin/params for the extended chain
        if process_results:
            self.process_results()  # update all_frames
        elif self.config.verbose > 0:
            print("Warning: Results not processed. all_frames attribute will not be available until " \
                "Result.process_results() is called.")

    @_require_sampler
    def plot_walkers(self, sampler=None, burnin: int | npint | None = None, thin: int | npint | None = None, return_fig: bool = False):
        """Plots the MCMC walker traces for the default parameters (self.default_params).

        Parameters
        ----------
        sampler : emcee.EnsembleSampler | None, optional
            Optionally provide a different sampler to plot from, otherwise, takes the sampler
            from the Result object, by default None
        burnin : int | None, optional
            Number of burn-in steps to shade, by default self.burnin
        thin : int | None, optional
            Thinning applied to the plotted chain, by default self.thin
        return_fig : bool, optional
            Whether to return the figure and axis objects instead of showing the plot, by
            default False

        Returns
        -------
        tuple | None
            (fig, ax) when return_fig is True, where ax is a 1-D array of Axes; otherwise None.
        """
        burnin = burnin if burnin is not None else self.burnin
        thin = thin if thin is not None else self.thin
        samples = self.sampler.get_chain(thin=int(thin))
        steps = np.arange(samples.shape[0]) * thin

        naxes = len(self.default_params)
        # squeeze=False keeps ax indexable even when only one parameter is plotted.
        fig, ax = plt.subplots(naxes, 1, figsize=(10, 20), sharex=True, squeeze=False)
        ax = ax[:, 0]
        for i in range(naxes):
            ax[i].plot(steps, samples[:, :, self.default_params[i]], "k", alpha=0.3)
            ax[i].axvspan(0, burnin, color="red", alpha=0.1, label="burn-in")
            ax[i].set_ylabel(self.default_param_labels[i])
        ax[-1].legend()
        ax[-1].set_xlabel("Step number")
        ax[-1].set_xlim(0, self.nsteps)
        ax[0].set_title('MCMC Walkers')
        plt.subplots_adjust(hspace=0.05)
        if return_fig:
            return fig, ax
        plt.show()

    @_require_sampler
    def plot_corner(self, sampler=None, return_fig: bool = False, **kwargs):
        """Creates a corner plot of the default parameters (self.default_params).

        Parameters
        ----------
        sampler : emcee.EnsembleSampler | None, optional
            Optionally provide a different sampler to plot from, otherwise, takes the sampler
            from the Result object, by default None
        return_fig : bool, optional
            Whether to return the figure object instead of showing the plot, by default False
        **kwargs:
            Additional keyword arguments to pass to corner.corner(). These override the
            defaults (labels, show_titles, title_fmt, title_kwargs).
        """
        samples = self.sampler.get_chain(discard=self.burnin, flat=True, thin=self.thin)[:, self.default_params]
        corner_kwargs = {
            "labels": self.default_param_labels,
            "show_titles": True,  # corner's keyword is plural; "show_title" was silently ignored
            "title_fmt": ".3f",
            "title_kwargs": {"fontsize": 16},
        }
        corner_kwargs.update(kwargs)
        fig = corner.corner(samples, **corner_kwargs)
        plt.suptitle('MCMC Corner Plot')
        if return_fig:
            return fig
        plt.show()

    @_require_all_frames
    @_require_data
    def plot_profiles(
        self,
        grid: bool = True,
        labels: dict | None = None,
        return_fig: bool = False,
        subplot_kwargs: dict | None = None,
        errorbar_kwargs: dict | None = None,
        fig_ax=None,
        **kwargs,
    ):
        """Plots the LSD profile result from Acid.

        Each profile is drawn as ``profile - 1`` against ``self.data.velocities``. When there
        are several frames and several orders, only the first frame is plotted and each order
        gets its own label.

        Parameters
        ----------
        grid : bool, optional
            Show or hide grid, by default True
        labels : dict | None, optional
            Keys: 'xlabel', 'ylabel', and 'title'. Allows label overrides, by default None
        return_fig : bool, optional
            Whether to return the figure and axis objects instead of showing the plot, by
            default False
        subplot_kwargs : dict | None, optional
            Keyword arguments to be passed to plt.subplots(), by default None
        errorbar_kwargs : dict | None, optional
            Keyword arguments to be passed to ax.errorbar(), by default None
        fig_ax : tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] | None, optional
            Optionally provide an existing fig/axis tuple to plot on, by default None
        """
        errorbar_defaults = {
            "fmt": ".-",
            "ecolor": "red",
            "linewidth": 1,
        }
        errorbar_kwargs = utils.set_dict_defaults(errorbar_kwargs, errorbar_defaults)
        subplot_kwargs = utils.set_dict_defaults(subplot_kwargs, {"figsize": (10, 6)})
        default_labels = {
            "title": "Acid Profile",
            "xlabel": "Velocity (km/s)",
            "ylabel": "Normalised Flux",
        }
        labels = utils.set_dict_defaults(labels, default_labels)

        nframes = len(self.all_frames)
        norders = len(self.all_frames[0])
        frames = np.copy(self.all_frames)

        if fig_ax is None:
            fig, ax = plt.subplots(**subplot_kwargs)
        else:
            fig, ax = fig_ax

        label_each_line = nframes > 1 and norders > 1
        if label_each_line:
            if self.config.verbose > 0:
                print("Warning: Multiple frames and orders detected. Only plotting the first frame for each order")
            frames = frames[:1, :, :, :]  # Take first frame only

        for f, frame in enumerate(frames):
            for o, order in enumerate(frame):
                x, y, yerr = self.data.velocities, order[0], order[1]
                # TODO: Make Order a function of self.order_range, which needs to be configured in
                # Acid so that order_range is done automatically if multiple orders are manually
                # put (and not just using ACID_HARPS)
                label_default = f"Frame {f+1}, Order {o+1}" if label_each_line else None
                # Fresh dict per line: user-supplied kwargs win, but the default label is not
                # carried over from one line to the next.
                line_kwargs = {"label": label_default, **errorbar_kwargs}
                ax.errorbar(x, y - 1, yerr=yerr, **line_kwargs)

        ax.set_title(labels["title"])
        ax.set_xlabel(labels["xlabel"])
        ax.set_ylabel(labels["ylabel"])
        ax.axhline(0, color='black', linestyle='--', linewidth=1)
        if ax.get_legend_handles_labels()[1]:
            ax.legend()
        ax.grid(grid)
        if return_fig:
            return fig, ax
        plt.show()

    @_require_all_frames
    @_require_data
    @_require_sampler
    def plot_forward_model(
        self,
        input_version: str = "masked",
        grid: bool = True,
        labels: dict | None = None,
        return_fig: bool = False,
        subplot_kwargs: dict | None = None,
        **kwargs,  # for testing
    ):
        """Plots the forward model fit to the observed spectrum.

        The model is evaluated at the posterior median of the post-burnin, thinned chain.

        Parameters
        ----------
        input_version : str, optional
            Which input spectrum to use; must be a key of self.data.wavelengths (e.g.
            'combined', 'input', 'masked'), by default 'masked'
        grid : bool, optional
            Show or hide grid, by default True
        labels : dict | None, optional
            Keys: 'xlabel', 'ylabel', 'title', and 'residuals_ylabel'. Allows label overrides,
            by default None
        return_fig : bool, optional
            Whether to return the figure and axis objects instead of showing the plot, by
            default False
        subplot_kwargs : dict | None, optional
            Keyword arguments passed to plt.subplots(), merged over the defaults
            (figsize=(10, 8), sharex=True, height ratios 3:1), by default None

        Raises
        ------
        ValueError
            If input_version is not a key of self.data.wavelengths.
        """
        input_version = input_version.lower()
        if input_version not in self.data.wavelengths.keys():
            raise ValueError(f"input_version must be one of {list(self.data.wavelengths.keys())}")

        default_labels = {
            "title": "Forward Model Fit to Observed Spectrum",
            "xlabel": "Wavelength (Angstroms)",
            "ylabel": "Normalised Flux",
            "residuals_ylabel": "Residuals",
        }
        labels = utils.set_dict_defaults(labels, default_labels)

        # User-supplied subplot kwargs override these defaults (previously they were discarded).
        subplot_defaults = {
            "figsize": (10, 8),
            "sharex": True,
            "gridspec_kw": {'height_ratios': [3, 1]},
        }
        subplot_kwargs = utils.set_dict_defaults(subplot_kwargs, subplot_defaults)

        input_wavelengths = self.data.wavelengths[input_version]
        input_flux = self.data.flux[input_version]

        samples = self.sampler.get_chain(discard=self.burnin, thin=self.thin, flat=True)
        theta_median = np.median(samples, axis=0)
        model_flux, _ = self.model(theta_median)

        fig, ax = plt.subplots(2, 1, **subplot_kwargs)
        ax[0].plot(input_wavelengths, input_flux, color='black', linewidth=1, label='Observed Spectrum')
        ax[0].plot(input_wavelengths, model_flux, color='C0', linewidth=1, label='Forward Model Fit')
        ax[1].plot(input_wavelengths, input_flux - model_flux, color='C0', linewidth=1, label='Residuals')
        ax[0].set_title(labels["title"])
        ax[1].set_xlabel(labels["xlabel"])
        ax[0].set_ylabel(labels["ylabel"])
        ax[1].set_ylabel(labels["residuals_ylabel"])
        ax[1].axhline(0, color='black', linestyle='--', linewidth=1)
        ax[0].legend()
        ax[1].legend()
        ax[0].grid(grid)
        ax[1].grid(grid)
        plt.subplots_adjust(hspace=0.05)
        if return_fig:
            return fig, ax
        plt.show()

    @_require_sampler
    def plot_autocorrelation(
        self,
        sampler=None,
        burnin: int | None = None,
        thin: int | None = None,
        n_grid: int = 12,
        c: float = 5.0,
        return_fig: bool = False,
        subplot_kwargs: dict | None = None,
        min_steps: int = 100,
    ):
        """Plot estimated integrated autocorrelation time as a function of chain length.

        Following the emcee docs: for several prefixes of the chain, estimate tau with Sokal
        windowing, and plot tau(N) together with the reference line tau = N/50.

        Parameters
        ----------
        sampler : emcee.EnsembleSampler | None, optional
            Optionally provide a different sampler to plot from, otherwise, takes the sampler
            from the Result object, by default None
        burnin, thin : int | None, optional
            Steps to discard and thinning factor applied to the chain before estimation. By
            default (None) the full, unthinned chain is used.
        n_grid : int, optional
            Number of N values (prefix lengths) to evaluate, by default 12.
        c : float, optional
            Sokal window constant (usually 5), by default 5.0.
        return_fig : bool, optional
            Whether to return the figure and axes objects, by default False
        subplot_kwargs : dict | None, optional
            Keyword arguments to be passed to plt.subplots(), by default None
        min_steps : int, optional
            Minimum number of samples per walker required to attempt autocorrelation
            estimation, by default 100. If you decrease this, you may get unreliable estimates
            or errors from the autocorrelation time estimation.

        Returns
        -------
        If return_fig is True, returns a tuple (fig, ax) of the figure and axes objects
        containing the plot. Otherwise, displays the plot and returns None.

        Raises
        ------
        ValueError
            If fewer than min_steps samples per walker are available.
        """
        discard = 0 if burnin is None else int(burnin)
        step = 1 if thin is None else int(thin)
        chain = self.sampler.get_chain(discard=discard, thin=step)  # (nsteps, nwalkers, ndim)
        nsteps = chain.shape[0]
        if nsteps < min_steps:
            raise ValueError("Not enough post-burnin samples to estimate autocorrelation reliably.")

        # Log-spaced prefix lengths between min_steps and the full chain length.
        Ns = np.unique(np.exp(np.linspace(np.log(min_steps), np.log(nsteps), n_grid)).astype(int))
        Ns = Ns[Ns >= min_steps]

        tau_estimates = {p: np.full(len(Ns), np.nan, dtype=float) for p in self.default_params}
        for i, n in enumerate(Ns):
            for p in self.default_params:
                y = chain[:n, :, p].T  # (nwalkers, n)
                tau_estimates[p][i] = utils.autocorr_new(y, c=c)

        subplot_kwargs = {} if subplot_kwargs is None else dict(subplot_kwargs)
        subplot_kwargs = utils.set_dict_defaults(subplot_kwargs, {"figsize": (10, 6)})
        fig, ax = plt.subplots(**subplot_kwargs)
        for label, p in zip(self.default_param_labels, self.default_params):
            ax.loglog(Ns, tau_estimates[p], "o-", label=f"{label}")

        # Reference line tau = N/50
        ax.loglog(Ns, Ns / 50.0, "--", label=r"$\tau = N/50$")
        ax.set_xlabel("number of post-burnin samples per walker (N)")
        ax.set_ylabel(r"estimated integrated autocorrelation time $\tau$")
        ax.set_title("Autocorrelation time estimates vs chain length")
        ax.legend()
        ax.grid(True, which="both")
        if return_fig:
            return fig, ax
        plt.show()
        return

    @_require_sampler
    def plot_acf(
        self,
        sampler=None,
        max_lag: int | None = None,
        return_fig: bool = False,
        subplot_kwargs: dict | None = None,
    ):
        """Plot the autocorrelation function (ACF) for each parameter, averaged across walkers.

        Parameters
        ----------
        sampler : emcee.EnsembleSampler | None, optional
            Optionally provide a different sampler to plot from, otherwise, takes the sampler
            from the Result object, by default None
        max_lag : int | None, optional
            Maximum lag to plot, by default None (plots up to min(5000, nsteps-1)). Values
            larger than nsteps-1 are clamped to nsteps-1.
        return_fig : bool, optional
            Whether to return the figure and axes objects, by default False
        subplot_kwargs : dict | None, optional
            Keyword arguments to be passed to plt.subplots(), by default None

        Returns
        -------
        If return_fig is True, returns a tuple (fig, ax) of the figure and axes objects
        containing the plot. Otherwise, displays the plot and returns None.
        """
        chain = self.sampler.get_chain()
        nsteps, nwalkers, _ = chain.shape
        # Lags beyond nsteps - 1 do not exist; clamping keeps x and y the same length.
        if max_lag is None:
            max_lag = min(5_000, nsteps - 1)
        else:
            max_lag = min(int(max_lag), nsteps - 1)

        subplot_kwargs = {} if subplot_kwargs is None else dict(subplot_kwargs)
        subplot_kwargs = utils.set_dict_defaults(subplot_kwargs, {"figsize": (10, 5)})
        fig, ax = plt.subplots(**subplot_kwargs)

        lags = np.arange(max_lag + 1)
        for param, label in zip(self.default_params, self.default_param_labels):
            y = chain[:, :, param].T  # (nwalkers, nsteps)
            f = np.zeros(nsteps)
            for walker_trace in y:
                f += utils.autocorr_func_1d(walker_trace, norm=True)
            f /= nwalkers
            ax.plot(lags, f[: max_lag + 1], label=f"{label}")

        ax.set_xlabel("Lag (steps)")
        ax.set_ylabel("Autocorrelation")
        ax.set_title("Mean ACF across walkers")
        ax.set_xscale("log")
        ax.grid(True)
        ax.axhline(0, color="black", linestyle="--", linewidth=1)
        ax.legend()
        if return_fig:
            return fig, ax
        plt.show()

    def initiate_sampler(self, sampler):
        """Initiates the sampler attribute from an external sampler.

        With a non-None sampler, stores it and derives ndim, nwalkers, nsteps, the
        autocorrelation times (tau), a convergence flag (nsteps >= 50 * max(tau)), burnin,
        thin, and the default parameters/labels used by the plotting methods. With None,
        keeps the already-stored sampler and returns without recomputing anything.

        Parameters
        ----------
        sampler : emcee.EnsembleSampler | None
            An emcee EnsembleSampler object to set as the sampler attribute.

        Raises
        ------
        ValueError
            If sampler is None and no sampler is stored yet.
        """
        self.sampler = sampler if sampler is not None else self.sampler
        if self.sampler is None:
            raise ValueError("A sampler must be provided in initialisation or in method call")
        if sampler is None:
            return  # sampler already initiated earlier, nothing to recompute

        self.ndim = self.sampler.ndim
        self.nwalkers = self.sampler.nwalkers
        self.nsteps = self.sampler.get_chain().shape[0]

        # Autocorrelation time, with emcee's output suppressed.
        with open(os.devnull, "w") as devnull, \
                contextlib.redirect_stdout(devnull), \
                contextlib.redirect_stderr(devnull):
            self.tau = self.sampler.get_autocorr_time(quiet=True)

        max_tau = np.max(self.tau)
        self.converged = True
        if self.nsteps < 50 * max_tau:
            self.converged = False
            if self.config.verbose > 1:
                print("The number of MCMC steps is less than 50 times the maximum autocorrelation " \
                    "time.\n The sampler may not have converged. Consider running more steps or checking " \
                    f"the walker plots.\n The max autocorrelation time is {max_tau:.2f}, therefore " \
                    f"the minimum number of steps should be roughly {int(50 * max_tau)}.\n Disabling burnin " \
                    f"from autocorrelation time, instead using burnin=steps-1000")

        try:
            # thin must be >= 1: emcee slices the chain with step=thin, and 0 is invalid.
            self.thin = max(1, int(np.min(self.tau) / 5))
            if self.converged:
                self.burnin = int(2 * max_tau)
            else:
                self.burnin = max(0, self.nsteps - 1000)  # just the last 1000 steps
        except (ValueError, OverflowError):
            # int() of a NaN/inf tau
            if self.config.verbose > 0:
                print(f"Warning: Could not compute autocorrelation time for burnin and thinning.\n This is likely" \
                    f" due to all posterior samples being rejected by prior constraints.\n The resulting profile is likely" \
                    f" wrong. Setting burnin=nsteps-1000, and thin=1.")
            self.burnin = max(0, self.nsteps - 1000)  # just the last 1000 steps
            self.thin = 1

        if getattr(self, "data", None) is not None:
            deterministic = self.config.deterministic_profile
            n_poly_params = self.data.config.poly_ord + 1
        else:
            # No Data object (sampler-only Result): make a best guess from ndim.
            if self.ndim > 6:
                deterministic = False
                n_poly_params = 4
            else:
                deterministic = True
                n_poly_params = self.ndim

        # Continuum coefficients occupy the last n_poly_params columns of the chain.
        poly_params = list(range(-1, -n_poly_params - 1, -1))
        poly_labels = [chr(ord('a') + i) for i in range(n_poly_params)]

        if not deterministic:
            samples = self.sampler.get_chain(thin=int(self.thin), discard=int(self.burnin))
            max_profile_idx = np.argmax(samples[:, :, :-n_poly_params].mean(axis=(0, 1)))
            # -(n_poly_params + 1) is the last profile point; it was hard-coded as -5, which is
            # only correct for 4 continuum coefficients.
            poly_params.extend([-(n_poly_params + 1), max_profile_idx, 1])
            poly_labels.extend(["$Z_{-1}$", "$Z_{max}$", "$Z_0$"])

        self.default_params = poly_params
        self.default_param_labels = poly_labels

    def initiate_data(self, data):
        """Initiates the data attribute from an external Data object.

        With a non-None Data object, also copies its all_frames and nsteps and binds
        ``self.model`` to an MCMC model function built from it. With None, keeps the
        already-stored Data object.

        Parameters
        ----------
        data : Data | None
            A Data object to set as the data attribute.

        Raises
        ------
        ValueError
            If data is None and no Data object is stored yet.
        """
        self.data = data if data is not None else getattr(self, "data", None)
        if self.data is None:
            raise ValueError("A Data object must be provided in initialisation or in method call")
        if data is None:
            return  # data already initiated earlier

        self.all_frames = self.data.all_frames
        self.nsteps = self.data.nsteps

        # For convenience, let the user call the model without needing to input all required args
        MCMC_class = mcmc.MCMC(self.data)
        self.model = MCMC_class.run_model_function

    @_require_data
    @_require_sampler
    def save_result(self, filename: str = "result.pkl", store_sampler: bool = True):
        """Saves the Result object to a pickle file.

        The instance state is pickled with ``data`` converted via Data.to_dict(), the sampler
        backend stored as a plain dict (or None), and ``model``/``sampler`` set to None;
        load_result() rebuilds those.

        Parameters
        ----------
        filename : str, optional
            Name of the file to save the Result object to, by default "result.pkl"
        store_sampler : bool, optional
            Whether to store the sampler backend in the pickle file. If False, the sampler will
            not be stored, and the loaded Result object will not be able to continue sampling
            or plot walkers/corner plots.
        """
        state = dict(self.__dict__)
        state["data"] = self.data.to_dict()
        state["backend"] = dict(self.sampler.backend.__dict__) if store_sampler else None
        # Bound model functions and live samplers are rebuilt on load instead of pickled.
        state["model"] = None
        state["sampler"] = None
        with open(filename, "wb") as f:
            pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)
        if getattr(self, "config", None) is not None and self.config.verbose > 1:
            print(f"Result object saved to {filename}")

    @classmethod
    def load_result(cls, result_object: str | object = "result.pkl"):
        """Loads a Result object from a pickle file or from an object with the same attributes
        as a saved Result object.

        Warning: loading a pickle can execute arbitrary code; only load files you trust.

        Parameters
        ----------
        result_object : str | object, optional
            A pickle file name or an object with the same attributes as a saved Result object,
            by default "result.pkl"

        Returns
        -------
        Result
            A Result object loaded from the pickle file or from the provided object.
        """
        if isinstance(result_object, str):
            # SECURITY: pickle.load runs arbitrary code from the file; never use untrusted input.
            with open(result_object, "rb") as f:
                obj = pickle.load(f)
        else:
            obj = dict(result_object.__dict__)

        res = cls.__new__(cls)
        res.__dict__.update(obj)

        # initiate_sampler() below reads res.config, so ensure it exists first.
        if getattr(res, "config", None) is None:
            res.config = Config()

        # Reconstruct data
        data = Data()
        data.from_dict(res.data)
        res.data = data

        # Reconstruct the sampler from the stored backend, if one was saved.
        backend_state = getattr(res, "backend", None)
        if backend_state is not None:
            backend = emceebackend.Backend(dtype=np.float64)
            backend.__dict__.update(backend_state)
            log_prob = mcmc.MCMC(res.data)
            # Dummy sampler that just holds the backend (shape is (nwalkers, ndim)).
            res.sampler = EnsembleSampler(*backend.shape, log_prob, backend=backend)
        res.backend = None  # the backend now lives in the sampler (or was never saved)

        # Rebuild convenience attributes that are not pickled.
        cls.initiate_data(res, res.data)  # sets all_frames, nsteps and model
        if getattr(res, "sampler", None) is not None:
            cls.initiate_sampler(res, res.sampler)  # sets burnin, thin, default params/labels

        if res.config.verbose > 1:
            print("Result object loaded")
        return res