from __future__ import annotations

import os
import pickle
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Optional

import numpy as np

from . import utils
class Config:
    """A simple class to store the configuration of the ACID run.

    Settings are supplied as keyword arguments. Keywords matching a ``@property`` of this
    class (``verbose`` and ``telluric_lines``) are validated by the property setters; any
    other keyword is stored as a plain attribute. Properties that were never set fall back
    to the values in ``defaults``.
    """

    defaults = {
        "verbose": 2,
        "order_range": [1],
        "telluric_lines": [
            3820.33, 3933.66, 3968.47, 4327.74, 4307.90, 4383.55, 4861.34,
            5183.62, 5270.39, 5889.95, 5895.92, 6562.81, 7593.70, 8226.96,
        ],
    }

    # Accepted (case-insensitive, whitespace-trimmed) string spellings of each verbosity level.
    _VERBOSE_ALIASES = {
        "none": 0, "no": 0, "false": 0, "0": 0,
        "low": 1, "1": 1,
        "medium": 2, "med": 2, "2": 2,
        "high": 3, "3": 3,
    }

    def __init__(self, **kwargs: Any) -> None:
        """Create a configuration from keyword arguments.

        Parameters
        ----------
        **kwargs
            Configuration values. ``None`` values are skipped so the defaults apply;
            unknown keys become plain attributes.

        Raises
        ------
        ValueError, TypeError
            If a property setter rejects its value.
        """
        # Initialise every property's backing field to None; the setters only write
        # when the current backing value is None.
        self.property_names = self.get_property_names()
        for name in self.property_names:
            setattr(self, f"_{name}", None)
        # property_names is derived from the class itself; ignore any stale copy carried
        # by payloads written by older versions of to_dict.
        kwargs.pop("property_names", None)
        self.update_hipri(**kwargs)
        # Fall back to the default only when the caller did not supply order_range, and
        # copy it so instances never share (and mutate) the class-level default list.
        if getattr(self, "order_range", None) is None:
            self.order_range = list(self.defaults["order_range"])

    # --- Update methods ---

    def update_hipri(self, **kwargs: Any) -> None:
        """Set values, overwriting existing ones (high priority).

        ``None`` values are skipped. For properties the backing field is cleared before
        calling the setter; if the setter raises, the previous value is restored and the
        exception is re-raised.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if not self.is_property(k):
                setattr(self, k, v)
                continue
            backing = f"_{k}"
            old = getattr(self, backing, None)
            setattr(self, backing, None)
            try:
                setattr(self, k, v)
            except Exception:
                setattr(self, backing, old)
                raise

    def update_lowpri(self, **kwargs: Any) -> None:
        """Set values only where nothing has been set yet (low priority)."""
        for k, v in kwargs.items():
            if self.is_property(k):
                setattr(self, k, v)  # property setters already only write when unset
            elif getattr(self, k, None) is None:
                setattr(self, k, v)

    def to_dict(self) -> dict[str, Any]:
        """Return the public configuration as a plain dict (suitable for ``Config(**d)``).

        Private attributes and the internal ``property_names`` bookkeeping are excluded;
        properties are included with their effective (possibly default) values.
        """
        d = {
            k: v
            for k, v in self.__dict__.items()
            if not k.startswith("_") and k != "property_names"
        }
        for name in self.property_names:
            d[name] = getattr(self, name)
        return d

    @classmethod
    def get_property_names(cls) -> set[str]:
        """Collect the names of all ``@property`` attributes on the class and its bases."""
        names: set[str] = set()
        for klass in cls.mro():
            for name, attr in klass.__dict__.items():
                if isinstance(attr, property):
                    names.add(name)
        return names

    def is_property(self, name: str) -> bool:
        """Return True if ``name`` is a validated property of this config."""
        return name in self.property_names

    # --- Properties ---

    @property
    def verbose(self) -> int:
        """Verbosity level 0-3; ``defaults["verbose"]`` if never set."""
        if self._verbose is None:
            return self.defaults["verbose"]
        return self._verbose

    @verbose.setter
    def verbose(self, value) -> None:
        # Normalise to an int in [0, 3]. Only writes if no value has been set yet.
        if self._verbose is not None:
            return
        if value is None or value is True:
            value = self.defaults["verbose"]
        elif value is False:
            value = 0
        elif isinstance(value, (int, np.integer)):
            value = int(value)
            if value < 0 or value > 3:
                raise ValueError("verbose must be an integer between 0 and 3")
        elif isinstance(value, str):
            key = value.strip().lower()
            if key not in self._VERBOSE_ALIASES:
                raise ValueError("verbose string not recognised, must be one of 'none', 'low', 'medium', 'high' or their common variants")
            value = self._VERBOSE_ALIASES[key]
        else:
            raise ValueError("verbose must be an integer between 0 and 3, a boolean, or a string indicating the verbosity level")
        self._verbose = value

    @property
    def telluric_lines(self) -> np.ndarray:
        """1-D array of telluric line positions to mask; a fresh copy of the defaults if never set."""
        if self._telluric_lines is None:
            # Return a new array so callers cannot mutate the class-level defaults.
            return np.array(self.defaults["telluric_lines"])
        return self._telluric_lines

    @telluric_lines.setter
    def telluric_lines(self, lines) -> None:
        # Check the backing field, not the property: the getter never returns None, so
        # checking the property would make this setter a permanent no-op.
        if self._telluric_lines is not None:
            return
        if lines is None:
            lines = self.defaults["telluric_lines"]
        if not isinstance(lines, (list, tuple, np.ndarray)):
            raise TypeError("telluric_lines must be a list, tuple or numpy array of telluric lines to "
                            "mask in angstroms (may be single-valued)")
        arr = np.array(lines)
        if arr.ndim != 1 or arr.size == 0:
            raise ValueError("telluric_lines must be a non-empty one-dimensional array or list")
        self._telluric_lines = arr
@dataclass(slots=True)
class Data:
    """Stores necessary data for the Acid class which can be conveniently updated and saved.

    Allows ACID to handle data that has already been computed to avoid recalculation. This class
    is designed to be lightweight in memory and hence does not store the sampler as an object.
    """

    # Standard necessary inputs, stored in dictionaries so their state can be kept at several
    # different stages of the calculations in Acid
    wavelengths: Dict[str, np.ndarray] = field(default_factory=dict)
    flux: Dict[str, np.ndarray] = field(default_factory=dict)
    errors: Dict[str, np.ndarray] = field(default_factory=dict)
    sn: Dict[str, np.ndarray] = field(default_factory=dict)
    # Cached products that are expensive or useful for resuming
    alpha: Optional[np.ndarray] = None  # the alpha vector used in the linear model, used for solving the linear system in MCMC
    c_factor: Optional[tuple] = None  # Cholesky factor tuple (presumably from scipy.linalg.cho_factor), used for solving the linear system in MCMC
    residual_masks: Optional[np.ndarray] = None  # boolean 1D mask on "combined" grid, used in final process_results step
    nanmask: Optional[np.ndarray] = None  # boolean 1D mask on "combined" grid, used to mask out NaN values in combined spectra
    velocities: Optional[np.ndarray] = None  # velocities array, used throughout Acid and Results
    initial_profile: Optional[np.ndarray] = None  # initial profile generated in residual masking
    initial_profile_errors: Optional[np.ndarray] = None  # corresponding errors
    poly_inputs: Optional[np.ndarray] = None  # polynomial inputs for just the continuum model
    model_inputs: Optional[np.ndarray] = None  # the concatenated array of initial profile and poly coefficients, used as input to emcee
    initial_state: Optional[np.ndarray] = None  # the initial state of the MCMC walkers, used for resuming and debugging
    # Small cached products needed for MCMC if doing reruns
    nwalkers: Optional[int] = None
    ndim: Optional[int] = None
    # Data required/calculated in results/after MCMC sampling
    all_frames: Optional[np.ndarray] = None  # 4-D results array; see initiate_all_frames
    nsteps: Optional[int] = 0
    max_steps: Optional[int] = None
    # Other useful data:
    initialisation_time: Optional[float] = None  # time taken for initialisation
    mcmc_time: Optional[float] = None  # time taken for MCMC sampling
    get_profiles_time: Optional[float] = None  # time taken to get profiles
    full_run_time: Optional[float] = None  # total time for the full run
    # Config data for convenience; it is very light in memory so it is also stored here.
    _config: Config = field(default_factory=Config)  # stored as a Config object, converted to a dict on save
    _linelist: Optional[Dict[str, np.ndarray]] = None  # {"wavelengths": ..., "depths": ...}, set via set_linelist

    def initiate_all_frames(self, all_frames: np.ndarray) -> None:
        """Initiate ``self.all_frames``, which stores the results of the MCMC sampling.

        This is used to update ``all_frames`` after each sampling step, allowing for resuming and
        avoiding recalculation of profiles if the user wishes to continue sampling.

        Parameters
        ----------
        all_frames : np.ndarray, Result, None or "default"
            The 4-D array to store. A ``Result`` object is unwrapped to its ``all_frames``
            attribute. ``None`` or the legacy string ``"default"`` keeps an existing array, or
            else allocates zeros of shape
            ``(len(flux["input"]), len(config.order_range), 2, len(velocities))``
            (axis 2 presumably holds the profile and its error).

        Raises
        ------
        TypeError
            If the resulting value is not a numpy array.
        ValueError
            If the array is not 4-dimensional.
        """
        if isinstance(all_frames, str) and all_frames == "default":
            all_frames = None  # legacy behaviour

        if all_frames is not None:
            candidate = all_frames
        elif self.all_frames is not None:
            candidate = self.all_frames  # keep existing frames, but re-validate them
        else:
            # By default order_range is [1], so len(self.config.order_range) == 1, which matches
            # the original code behaviour. This allows self.config.order_range to be used in ACID_HARPS.
            candidate = np.zeros((len(self.flux["input"]), len(self.config.order_range), 2, len(self.velocities)))

        if not isinstance(candidate, np.ndarray):
            from .result import Result  # imported lazily, only when a non-array is supplied
            if isinstance(candidate, Result):
                candidate = candidate.all_frames
        if not isinstance(candidate, np.ndarray):
            raise TypeError("'all_frames' must be a numpy array")
        if candidate.ndim != 4:
            raise ValueError("'all_frames' must be a 4-dimensional numpy array, see docstring for details")
        # Assign only after validation so a bad input never leaves self.all_frames corrupted.
        self.all_frames = candidate

    def save(self, filename: str = "data.pkl") -> None:
        """Save the data object to a file using pickling.

        Only the dictionary produced by ``to_dict`` is stored, not the class itself; ``load``
        then initialises a new Data object from that dictionary.

        Parameters
        ----------
        filename : str
            The name of the file to save the data object to. This should be a .pkl file.
        """
        payload = self.to_dict()
        with open(filename, "wb") as f:
            pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, filename: str) -> Data:
        """Load a Data object previously written by ``save``.

        SECURITY: this unpickles the file, which can execute arbitrary code. Only load files
        from trusted sources.
        """
        with open(filename, "rb") as f:
            payload = pickle.load(f)
        obj = cls()
        obj.from_dict(payload)
        return obj

    def to_dict(self) -> dict[str, Any]:
        """Convert the data object to a dictionary payload for saving.

        Used internally by ``save``, but also useful for debugging. The config is stored under
        the key ``"config"`` as a plain dict; every other field is stored under its own name.
        """
        payload: dict[str, Any] = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if f.name == "_config":
                payload["config"] = value.to_dict()
            else:
                payload[f.name] = value
        return payload

    def from_dict(self, payload: dict[str, Any]) -> None:
        """Update the data object from a dictionary payload produced by ``to_dict``.

        Used internally by ``load``. Fields missing from the payload are left unchanged; the
        config is rebuilt as a Config object from ``payload["config"]``, or defaults if absent.
        """
        for f in fields(self):
            name = f.name
            if name == "_config":
                self._config = Config(**payload.get("config", {}))
            elif name in payload:
                setattr(self, name, payload[name])

    @property
    def config(self) -> Config:
        """Returns the internally stored config object, which contains the configuration of the ACID run."""
        return self._config

    @config.setter
    def config(self, value: Config) -> None:
        self._config = value

    @property
    def linelist(self) -> Optional[Linelist]:
        """Returns the stored linelist wrapped in a Linelist (keys "wavelengths"/"depths" or index 0/1), or None if unset."""
        return Linelist(self._linelist) if self._linelist is not None else None

    def set_linelist(self, linelist_path=None, linelist_wl=None, linelist_depths=None) -> None:
        """Set the linelist once; later calls are ignored if a linelist is already stored.

        Parameters
        ----------
        linelist_path : str, os.PathLike, Linelist, dict, list, tuple or np.ndarray, optional
            A path to a VALD linelist (4 header lines skipped, comma-delimited, zero-based
            columns 1 and 9 read as wavelengths and depths), a Linelist, a dict with keys
            "wavelengths" and "depths", or a length-2 sequence (wavelengths, depths).
        linelist_wl, linelist_depths : array-like, optional
            Used when ``linelist_path`` is None; both must then be provided.

        Raises
        ------
        ValueError
            If the inputs are missing, of an unsupported type, or have mismatched shapes.
        """
        if self._linelist is not None:  # linelist already set, do not overwrite
            return

        if linelist_path is None:
            if linelist_wl is None and linelist_depths is None:
                raise ValueError("One of ('linelist_wl' and 'linelist_depths') or 'linelist_path' must be provided.")
            if linelist_wl is None or linelist_depths is None:
                raise ValueError("If 'linelist_path' is not provided, both 'linelist_wl' and 'linelist_depths' must be provided.")
            # Both arrays were supplied directly; use them as given.
        elif isinstance(linelist_path, (str, os.PathLike)):
            # VALD linelist code, will add more linelist formats in the future or if requested.
            # atleast_2d keeps single-line files indexable as (n_lines, 2).
            full_linelist = np.atleast_2d(np.genfromtxt(os.fspath(linelist_path), skip_header=4, delimiter=',',
                                                        usecols=(1, 9), invalid_raise=False))
            linelist_wl = full_linelist[:, 0]
            linelist_depths = full_linelist[:, 1]
        elif isinstance(linelist_path, Linelist):
            linelist_wl = linelist_path[0]
            linelist_depths = linelist_path[1]
        elif isinstance(linelist_path, dict):
            if "wavelengths" not in linelist_path or "depths" not in linelist_path:
                raise ValueError("If 'linelist_path' is a dict, it must contain keys 'wavelengths' and 'depths'")
            linelist_wl = linelist_path["wavelengths"]
            linelist_depths = linelist_path["depths"]
        elif isinstance(linelist_path, (list, tuple, np.ndarray)):
            if len(linelist_path) != 2:
                raise ValueError("If 'linelist_path' is a list or array, it must have length 2, with index 0 being wavelengths and index 1 being depths")
            linelist_wl = linelist_path[0]
            linelist_depths = linelist_path[1]
        else:
            raise ValueError("'linelist_path' must be a string path to a VALD linelist, a dictionary with keys 'wavelengths' and 'depths', " \
                             "a Linelist object, or a list/array indexed such that 0 is wavelengths and 1 is depths.")

        linelist_wl = np.array(linelist_wl, dtype=float)
        linelist_depths = np.array(linelist_depths, dtype=float)
        # Validate shapes before masking: mismatched lengths would otherwise fail (or broadcast)
        # inside drop_NaNs with a confusing error.
        Linelist.validate_dimensions(linelist_wl, linelist_depths)
        linelist_wl, linelist_depths = Linelist.drop_NaNs(linelist_wl, linelist_depths, verbose=self.config.verbose)
        self._linelist = {"wavelengths": linelist_wl, "depths": linelist_depths}
class Linelist:
    """A lightweight, read-only wrapper exposing the linelist stored in Data.

    The wrapped dict must have keys ``"wavelengths"`` and ``"depths"``. Items can be accessed by
    index (0 -> wavelengths, 1 -> depths), by key, or by unpacking:
    ``wavelengths, depths = linelist``.
    """

    __slots__ = ("ll",)  # the only thing stored in this class is the linelist dict

    def __init__(self, ll: dict):
        self.ll = ll

    def __getitem__(self, k):
        if k == 0:
            return self.ll["wavelengths"]
        if k == 1:
            return self.ll["depths"]
        if isinstance(k, (int, np.integer)):
            raise IndexError("Linelist only has keys 0 and 1, or 'wavelengths' and 'depths'")
        return self.ll[k]  # allow "wavelengths"/"depths"

    def __iter__(self):
        yield self.ll["wavelengths"]
        yield self.ll["depths"]

    def __len__(self) -> int:
        """Always 2: (wavelengths, depths)."""
        return 2

    @staticmethod
    def validate_dimensions(wavelengths, depths):
        """Raise ValueError unless both inputs are 1-D arrays/lists of identical shape."""
        wavelengths = np.asarray(wavelengths)
        depths = np.asarray(depths)
        if wavelengths.ndim != 1 or depths.ndim != 1:
            raise ValueError("'wavelengths' and 'depths' must be a one-dimensional array or list")
        if wavelengths.shape != depths.shape:
            raise ValueError("'wavelengths' and 'depths' must have the same length and shape")

    @staticmethod
    def drop_NaNs(wavelengths, depths, return_mask=False, verbose=0):
        """Drop entries where either value is non-finite or not strictly positive.

        Parameters
        ----------
        wavelengths, depths : array-like
            Same-shaped arrays or lists.
        return_mask : bool
            If True, also return the boolean mask of kept entries.
        verbose : int
            If > 0, print how many entries were removed and why.

        Returns
        -------
        (wavelengths, depths) or (wavelengths, depths, mask)
            Numpy arrays containing only the kept entries.
        """
        wavelengths = np.asarray(wavelengths)
        depths = np.asarray(depths)
        finite = np.isfinite(wavelengths) & np.isfinite(depths)
        # NaN comparisons are simply False here; silence any floating-point warning about them.
        with np.errstate(invalid="ignore"):
            positive = (wavelengths > 0) & (depths > 0)
        mask = finite & positive
        count_dropped = np.count_nonzero(~finite)
        count_nonpositive = np.count_nonzero(finite & ~positive)
        if verbose > 0:
            if count_dropped > 0:
                print(f"Your linelist includes {count_dropped} non-finite/nan values, these will be removed, but it is still recommended to check your linelist.")
            if count_nonpositive > 0:
                print(f"Your linelist includes {count_nonpositive} non-positive wavelength/depth values, these will be removed, but it is still recommended to check your linelist.")
        if return_mask:
            return wavelengths[mask], depths[mask], mask
        return wavelengths[mask], depths[mask]