# Source code for amici.de_model_components

"""Objects for AMICI's internal differential equation model representation"""

import abc
import numbers
from typing import SupportsFloat

import sympy as sp

from .constants import SymbolId
from .import_utils import (
    RESERVED_SYMBOLS,
    ObservableTransformation,
    amici_time_symbol,
    cast_to_sym,
    contains_periodic_subexpression,
    generate_measurement_symbol,
    generate_regularization_symbol,
)

# Public API of this module. The helper base classes ``Sigma`` and
# ``LogLikelihood`` refuse direct instantiation and are not exported;
# use their concrete subclasses listed below instead.
__all__ = [
    "ConservationLaw",
    "Constant",
    "Event",
    "Expression",
    "LogLikelihoodY",
    "LogLikelihoodZ",
    "LogLikelihoodRZ",
    "ModelQuantity",
    "NoiseParameter",
    "Observable",
    "ObservableParameter",
    "Parameter",
    "SigmaY",
    "SigmaZ",
    "DifferentialState",
    "EventObservable",
    "AlgebraicState",
    "AlgebraicEquation",
    "State",
]


class ModelQuantity:
    """
    Base class for model components
    """

    def __init__(
        self,
        symbol: sp.Symbol,
        name: str,
        value: SupportsFloat | numbers.Number | sp.Expr,
    ):
        """
        Create a new ModelQuantity instance.

        :param symbol:
            Symbol of the quantity with unique identifier.

        :param name:
            individual name of the quantity (does not need to be unique)

        :param value:
            either formula, numeric value or initial value; converted via
            ``cast_to_sym``

        :raises TypeError:
            if ``symbol`` is not a :class:`sympy.Symbol` or ``name`` is not
            a ``str``

        :raises ValueError:
            if the identifier of ``symbol`` is contained in
            ``RESERVED_SYMBOLS``
        """
        if not isinstance(symbol, sp.Symbol):
            raise TypeError(f"symbol must be sympy.Symbol, was {type(symbol)}")

        if str(symbol) in RESERVED_SYMBOLS or (
            hasattr(symbol, "name") and symbol.name in RESERVED_SYMBOLS
        ):
            # The reservation check is performed on the *symbol*, so report
            # the symbol identifier; ``name`` may differ from it and need not
            # be reserved itself.
            raise ValueError(
                f'Cannot add model quantity with reserved name "{symbol}" '
                f'(name "{name}"), please rename.'
            )
        self._symbol: sp.Symbol = symbol

        if not isinstance(name, str):
            raise TypeError(f"name must be str, was {type(name)}")
        self._name: str = name

        self._value: sp.Expr = cast_to_sym(value, "value")

    def __repr__(self) -> str:
        """
        Representation of the ModelQuantity object

        :return:
            string representation of the ModelQuantity (its symbol)
        """
        return str(self._symbol)

    def get_sym(self) -> sp.Symbol:
        """
        ModelQuantity symbol

        :return:
            Symbol of the ModelQuantity
        """
        return self._symbol

    def get_id(self) -> str:
        """
        ModelQuantity identifier

        :return:
            identifier of the ModelQuantity, i.e. the name of its symbol
        """
        return (
            self._symbol.name
            if hasattr(self._symbol, "name")
            else str(self._symbol)
        )

    def get_name(self) -> str:
        """
        ModelQuantity name

        :return:
            name of the ModelQuantity
        """
        return self._name

    def get_val(self) -> sp.Expr:
        """
        ModelQuantity value

        :return:
            value of the ModelQuantity
        """
        return self._value

    def set_val(self, val: sp.Expr):
        """
        Set ModelQuantity value

        :param val:
            new value (formula or number); converted via ``cast_to_sym``
        """
        self._value = cast_to_sym(val, "value")
class ConservationLaw(ModelQuantity):
    """
    A conservation law defines the total amount of a (weighted) sum of
    states, which remains constant over time.
    """

    def __init__(
        self,
        symbol: sp.Symbol,
        name: str,
        value: sp.Expr,
        coefficients: dict[sp.Symbol, sp.Expr],
        state_id: sp.Symbol,
    ):
        """
        Create a new ConservationLaw instance.

        :param symbol:
            unique symbol of the ConservationLaw

        :param name:
            individual name of the ConservationLaw (does not need to be
            unique)

        :param value: formula (sum of states)

        :param coefficients:
            coefficients of the states in the sum

        :param state_id:
            Symbol of the state that this conservation law replaces

        :raises KeyError:
            if ``state_id`` has no entry in ``coefficients``
        """
        # Validate symbol/name/value first so that invalid input is reported
        # by ModelQuantity before any symbolic arithmetic is attempted.
        super().__init__(symbol, name, value)
        # x_rdata of the replaced state: total minus the remaining terms of
        # the sum (``value - state_id``).
        self._state_expr: sp.Expr = symbol - (value - state_id)
        self._coefficients: dict[sp.Symbol, sp.Expr] = coefficients
        # coefficient of the replaced state, used for normalization
        self._ncoeff: sp.Expr = coefficients[state_id]

    def get_ncoeff(self, state_sym: sp.Symbol) -> sp.Expr | int | float:
        """
        Computes the normalized coefficient a_i/a_j where i is the index of
        the provided ``state_sym`` and j is the index of the state that is
        replaced by this conservation law. This can be used to compute both
        dtotal_cl/dx_rdata (=ncoeff) and dx_rdata/dx_solver (=-ncoeff).

        :param state_sym:
            Symbol of the state

        :return:
            normalized coefficient of the state; ``0.0 / a_j`` if the state
            does not appear in the conservation law
        """
        return self._coefficients.get(state_sym, 0.0) / self._ncoeff

    def get_x_rdata(self):
        """
        Returns the expression that allows computation of x_rdata for the
        state that this conservation law replaces.

        :return: x_rdata expression
        """
        return self._state_expr
class AlgebraicEquation(ModelQuantity):
    """
    An AlgebraicEquation defines an algebraic equation.
    """

    def __init__(self, symbol: sp.Symbol, value: sp.Expr):
        """
        Create a new AlgebraicEquation instance.

        :param symbol:
            unique symbol of the AlgebraicEquation; its name is also used as
            the equation's name

        :param value:
            Formula of the algebraic equation, the solution is given by
            ``formula == 0``

        :raises TypeError:
            if ``symbol`` is not a :class:`sympy.Symbol`
        """
        # Only access ``.name`` on real symbols, so invalid input surfaces as
        # the TypeError raised by ModelQuantity.__init__ rather than as an
        # AttributeError raised here.
        name = symbol.name if isinstance(symbol, sp.Symbol) else str(symbol)
        super().__init__(symbol, name, value)

    def get_free_symbols(self):
        """
        Gets the free symbols of the equation's formula.

        :return:
            set of free symbols of the formula
        """
        return self._value.free_symbols

    def __repr__(self):
        """
        Representation of the equation as its formula.

        :return:
            string representation of the formula
        """
        return str(self._value)
class State(ModelQuantity):
    """
    Common base for the differential and algebraic states of a model.

    :ivar _conservation_law:
        conservation law that replaces this state, or ``None``
    """

    _conservation_law: ConservationLaw | None = None

    def get_x_rdata(self):
        """
        Expression from which ``x_rdata`` of this state is computed.

        Without a conservation law this is the state symbol itself;
        otherwise the conservation law supplies the expression.

        :return: x_rdata expression
        """
        law = self._conservation_law
        return self.get_sym() if law is None else law.get_x_rdata()

    def get_dx_rdata_dx_solver(self, state_id):
        """
        Expression for ``dx_rdata_dx_solver`` of this state with respect to
        the given solver state, taking conservation laws into account.

        :param state_id:
            symbol of the solver state to differentiate with respect to

        :return: dx_rdata_dx_solver expression
        """
        law = self._conservation_law
        if law is not None:
            # dx_rdata/dx_solver equals -ncoeff, see ConservationLaw.get_ncoeff
            return -law.get_ncoeff(state_id)
        # identity mapping: 1 for this state's own symbol, 0 otherwise
        return sp.Integer(self._symbol == state_id)

    # NOTE(review): State does not derive from abc.ABC, so this decorator is
    # not enforced when instantiating State directly.
    @abc.abstractmethod
    def has_conservation_law(self):
        """
        Tells whether a conservation law has been assigned to this state.

        :return:
            True if assigned, False otherwise
        """
        ...
class AlgebraicState(State):
    """
    A model state whose value is determined algebraically rather than by a
    time derivative.
    """

    def __init__(self, symbol: sp.Symbol, name: str, init: sp.Expr):
        """
        Create a new AlgebraicState instance.

        :param symbol:
            unique symbol of the AlgebraicState

        :param name:
            individual name of the AlgebraicState (does not need to be
            unique)

        :param init:
            initial value of the AlgebraicState
        """
        super().__init__(symbol=symbol, name=name, value=init)

    def has_conservation_law(self) -> bool:
        """
        Algebraic states never carry a conservation law.

        :return:
            always False
        """
        return False

    def get_free_symbols(self):
        """
        Free symbols occurring in the initial value.

        :return:
            set of free symbols of the initial value
        """
        return self._value.free_symbols

    def get_x_rdata(self):
        """
        Expression from which ``x_rdata`` of this state is computed; for an
        algebraic state this is simply its own symbol.

        :return:
            symbol of the state
        """
        return self._symbol
class DifferentialState(State):
    """
    A State variable defines an entity that evolves with time according to
    the provided time derivative, abbreviated by ``x``.

    :ivar _conservation_law:
        algebraic formula that allows computation of this
        state according to a conservation law

    :ivar _dt:
        algebraic formula that defines the temporal derivative of this state
    """

    def __init__(
        self, symbol: sp.Symbol, name: str, init: sp.Expr, dt: sp.Expr
    ):
        """
        Create a new DifferentialState instance. Extends
        :meth:`ModelQuantity.__init__` by ``dt``.

        :param symbol:
            unique symbol of the state

        :param name:
            individual name of the state (does not need to be unique)

        :param init:
            initial value

        :param dt:
            time derivative; converted via ``cast_to_sym``
        """
        super().__init__(symbol, name, init)
        self._dt = cast_to_sym(dt, "dt")
        # no conservation law until set_conservation_law is called
        self._conservation_law: ConservationLaw | None = None

    def set_conservation_law(self, law: ConservationLaw) -> None:
        """
        Attach a conservation law to this state. Once attached, the state is
        replaced by the algebraic formula defined by that law.

        :param law:
            linear sum of states that if added to this state remain
            constant over time

        :raises TypeError:
            if ``law`` is not a :class:`ConservationLaw`
        """
        if isinstance(law, ConservationLaw):
            self._conservation_law = law
            return
        raise TypeError(
            f"conservation law must have type ConservationLaw, was {type(law)}"
        )

    def set_dt(self, dt: sp.Expr) -> None:
        """
        Replace the time derivative.

        :param dt:
            time derivative; converted via ``cast_to_sym``
        """
        self._dt = cast_to_sym(dt, "dt")

    def get_dt(self) -> sp.Expr:
        """
        Current time derivative.

        :return:
            time derivative
        """
        return self._dt

    def get_free_symbols(self) -> set[sp.Basic]:
        """
        Free symbols of the time derivative together with those of the
        initial condition.

        :return:
            free symbols
        """
        dt_symbols = self._dt.free_symbols
        init_symbols = self._value.free_symbols
        return dt_symbols | init_symbols

    def has_conservation_law(self):
        """
        Tells whether a conservation law has been attached to this state.

        :return:
            True if assigned, False otherwise
        """
        return self._conservation_law is not None
class Observable(ModelQuantity):
    """
    An Observable links model simulations to experimental measurements,
    abbreviated by ``y``.

    :ivar _measurement_symbol:
        sympy symbol used in the objective function to represent
        measurements to this observable

    :ivar _regularization_symbol:
        sympy symbol used for the regularization term of this observable;
        created lazily by :meth:`get_regularization_symbol`

    :ivar trafo:
        observable transformation, only applies when evaluating objective
        function or residuals
    """

    _measurement_symbol: sp.Symbol | None = None
    _regularization_symbol: sp.Symbol | None = None

    def __init__(
        self,
        symbol: sp.Symbol,
        name: str,
        value: sp.Expr,
        measurement_symbol: sp.Symbol | None = None,
        transformation: None
        | ObservableTransformation = ObservableTransformation.LIN,
    ):
        """
        Create a new Observable instance.

        :param symbol:
            unique symbol of the Observable

        :param name:
            individual name of the Observable (does not need to be unique)

        :param value:
            formula

        :param measurement_symbol:
            symbol representing measurements of this observable. If
            ``None``, a symbol is generated from ``symbol`` on first access
            via :meth:`get_measurement_symbol`.

        :param transformation:
            observable transformation, only applies when evaluating objective
            function or residuals
        """
        super().__init__(symbol, name, value)
        self._measurement_symbol = measurement_symbol
        self._regularization_symbol = None
        self.trafo = transformation

    def get_measurement_symbol(self) -> sp.Symbol:
        """
        Symbol representing measurements of this observable.

        Generated from the observable symbol on first access if none was
        provided at construction; the result is cached.

        :return:
            measurement symbol
        """
        if self._measurement_symbol is None:
            self._measurement_symbol = generate_measurement_symbol(
                self.get_sym()
            )
        return self._measurement_symbol

    def get_regularization_symbol(self) -> sp.Symbol:
        """
        Symbol representing the regularization term of this observable.

        Generated from the observable symbol on first access; the result is
        cached.

        :return:
            regularization symbol
        """
        if self._regularization_symbol is None:
            self._regularization_symbol = generate_regularization_symbol(
                self.get_sym()
            )
        return self._regularization_symbol
class EventObservable(Observable):
    """
    An Event Observable links model simulations to event related experimental
    measurements, abbreviated by ``z``.

    :ivar _event:
        symbolic event identifier
    """

    def __init__(
        self,
        symbol: sp.Symbol,
        name: str,
        value: sp.Expr,
        event: sp.Symbol,
        measurement_symbol: sp.Symbol | None = None,
        transformation: ObservableTransformation
        | None = ObservableTransformation.LIN,
    ):
        """
        Create a new EventObservable instance.

        :param symbol:
            See :py:meth:`Observable.__init__`.

        :param name:
            See :py:meth:`Observable.__init__`.

        :param value:
            See :py:meth:`Observable.__init__`.

        :param event:
            Symbolic identifier of the corresponding event.

        :param measurement_symbol:
            See :py:meth:`Observable.__init__`.

        :param transformation:
            See :py:meth:`Observable.__init__`. Defaults to
            ``ObservableTransformation.LIN``, the same default as
            :class:`Observable` (previously the plain string ``"lin"``).
        """
        super().__init__(
            symbol, name, value, measurement_symbol, transformation
        )
        self._event: sp.Symbol = event

    def get_event(self) -> sp.Symbol:
        """
        Get the symbolic identifier of the corresponding event.

        :return: symbolic identifier
        """
        return self._event
class Sigma(ModelQuantity):
    """
    Standard deviation used to rescale the distance between simulations and
    measurements when computing residuals or objective functions,
    abbreviated by ``sigma{y,z}``. Only concrete subclasses may be
    instantiated.
    """

    def __init__(self, symbol: sp.Symbol, name: str, value: sp.Expr):
        """
        Create a new Standard Deviation instance.

        :param symbol:
            unique symbol of the Standard Deviation

        :param name:
            individual name of the Standard Deviation (does not need to
            be unique)

        :param value:
            formula

        :raises RuntimeError:
            if instantiated directly rather than through a subclass
        """
        is_base_class = type(self).__name__ == "Sigma"
        if is_base_class:
            raise RuntimeError(
                "This class is meant to be sub-classed, not used directly."
            )
        super().__init__(symbol=symbol, name=name, value=value)
class SigmaY(Sigma):
    """
    Standard deviation belonging to an observable (``y``).
    """
class SigmaZ(Sigma):
    """
    Standard deviation belonging to an event observable (``z``).
    """
class Expression(ModelQuantity):
    """
    An Expression is a recurring element in symbolic formulas. Specifying it
    may yield more compact expressions, which may lead to substantially
    shorter model compilation times and may also reduce model simulation
    time. Abbreviated by ``w``.
    """

    def __init__(self, symbol: sp.Symbol, name: str, value: sp.Expr):
        """
        Create a new Expression instance.

        :param symbol:
            unique symbol of the Expression

        :param name:
            individual name of the Expression (does not need to be unique)

        :param value:
            formula
        """
        super().__init__(symbol=symbol, name=name, value=value)
class Parameter(ModelQuantity):
    """
    A Parameter is a free variable in the model with respect to which
    sensitivities may be computed, abbreviated by ``p``.
    """

    def __init__(self, symbol: sp.Symbol, name: str, value: numbers.Number):
        """
        Create a new Parameter instance.

        :param symbol:
            unique symbol of the Parameter

        :param name:
            individual name of the Parameter (does not need to be unique)

        :param value:
            numeric value
        """
        super().__init__(symbol, name, value)
class Constant(ModelQuantity):
    """
    A Constant is a fixed variable in the model with respect to which
    sensitivities cannot be computed, abbreviated by ``k``.
    """

    def __init__(self, symbol: sp.Symbol, name: str, value: numbers.Number):
        """
        Create a new Constant instance.

        :param symbol:
            unique symbol of the Constant

        :param name:
            individual name of the Constant (does not need to be unique)

        :param value:
            numeric value
        """
        super().__init__(symbol, name, value)
class NoiseParameter(ModelQuantity):
    """
    A NoiseParameter is an input variable for the computation of ``sigma``
    that can be specified in a data-point specific manner, abbreviated by
    ``np``.

    Only used for jax models.
    """

    def __init__(self, symbol: sp.Symbol, name: str):
        """
        Create a new NoiseParameter instance.

        The value is initialized to ``0.0``; the actual values are provided
        per data point.

        :param symbol:
            unique symbol of the NoiseParameter

        :param name:
            individual name of the NoiseParameter (does not need to be
            unique)
        """
        super().__init__(symbol, name, 0.0)
class ObservableParameter(ModelQuantity):
    """
    An ObservableParameter is an input variable for the computation of ``y``
    that can be specified in a data-point specific manner, abbreviated by
    ``op``.

    Only used for jax models.
    """

    def __init__(self, symbol: sp.Symbol, name: str):
        """
        Create a new ObservableParameter instance.

        The value is initialized to ``0.0``; the actual values are provided
        per data point.

        :param symbol:
            unique symbol of the ObservableParameter

        :param name:
            individual name of the ObservableParameter (does not need to be
            unique)
        """
        super().__init__(symbol, name, 0.0)
class LogLikelihood(ModelQuantity):
    """
    A LogLikelihood defines the distance between measurements and
    simulations for a particular observable. The final LogLikelihood value
    in the simulation will be the sum of all specified LogLikelihood
    instances evaluated at all timepoints, abbreviated by ``Jy``.

    Only concrete subclasses may be instantiated.
    """

    def __init__(self, symbol: sp.Symbol, name: str, value: sp.Expr):
        """
        Create a new LogLikelihood instance.

        :param symbol:
            unique symbol of the LogLikelihood

        :param name:
            individual name of the LogLikelihood (does not need to be
            unique)

        :param value:
            formula

        :raises RuntimeError:
            if instantiated directly rather than through a subclass
        """
        if self.__class__.__name__ == "LogLikelihood":
            raise RuntimeError(
                "This class is meant to be sub-classed, not used directly."
            )
        super().__init__(symbol, name, value)
class LogLikelihoodY(LogLikelihood):
    """
    Log-likelihood contribution of observables (``y``).
    """
class LogLikelihoodZ(LogLikelihood):
    """
    Log-likelihood contribution of event observables (``z``).
    """
class LogLikelihoodRZ(LogLikelihood):
    """
    Log-likelihood contribution of the regularization of event observables.
    """
class Event(ModelQuantity):
    """
    An Event defines either a SBML event or a root of the argument of a
    Heaviside function. The Heaviside functions will be tracked via the
    vector ``h`` during simulation and are needed to inform the solver
    about a discontinuity in either the right-hand side or the states
    themselves, causing a reinitialization of the solver.
    """

    def __init__(
        self,
        symbol: sp.Symbol,
        name: str,
        value: sp.Expr,
        use_values_from_trigger_time: bool,
        assignments: dict[sp.Symbol, sp.Expr] | None = None,
        initial_value: bool | None = True,
        priority: sp.Basic | numbers.Number | None = None,
    ):
        """
        Create a new Event instance.

        :param symbol:
            unique symbol of the Event

        :param name:
            individual name of the Event (does not need to be unique)

        :param value:
            formula for the root / trigger function

        :param use_values_from_trigger_time:
            Whether the event assignment is evaluated using the state from
            the time point at which the event triggered (True), or at the
            time point at which the event assignment is evaluated (False).

        :param assignments:
            Dictionary of event assignments: state symbol -> new value.

        :param initial_value:
            initial boolean value of the trigger function at t0. If set to
            `False`, events may trigger at ``t==t0``, otherwise not.

        :param priority:
            The priority of the event assignment. Plain Python numbers are
            converted to sympy numbers.

        :raises NotImplementedError:
            if ``priority`` is given but is not numeric
        """
        super().__init__(symbol, name, value)
        # add the Event specific components
        self._assignments = assignments if assignments is not None else {}
        self._initial_value = initial_value

        # Accept plain Python numbers (which lack ``.is_Number``) by
        # converting them to sympy numbers first.
        if isinstance(priority, numbers.Number):
            priority = sp.sympify(priority)
        if priority is not None and not priority.is_Number:
            raise NotImplementedError(
                "Currently, only numeric values are supported as event priority."
            )
        self._priority = priority

        self._use_values_from_trigger_time = use_values_from_trigger_time

        # expression(s) for the timepoint(s) at which the event triggers
        self._t_root = []

        if not contains_periodic_subexpression(
            self.get_val(), amici_time_symbol
        ):
            # `solve` will solve, e.g., sin(t), but will only return [0, pi],
            # so we better skip any periodic expressions here
            try:
                self._t_root = sp.solve(self.get_val(), amici_time_symbol)
            except NotImplementedError:
                # the trigger can't be solved for `t`
                pass

    def get_state_update(
        self, x: sp.Matrix, x_old: sp.Matrix
    ) -> sp.Matrix | None:
        """
        Get the state update (bolus) expression for the event assignment.

        :param x: The current state vector.
        :param x_old:
            The previous state vector.
            If ``use_values_from_trigger_time=True``, this is equal to `x`.
        :return:
            State-update matrix or ``None`` if no state update is defined.
        """
        if not self._assignments:
            return None

        # assignments are evaluated on the previous state
        x_to_x_old = dict(zip(x, x_old, strict=True))

        def get_bolus(x_i: sp.Symbol) -> sp.Expr:
            """
            Get the bolus expression for a state variable.

            :param x_i: state variable symbol
            :return: bolus expression (new value minus current value, or 0
                for states without an assignment)
            """
            if (assignment := self._assignments.get(x_i)) is not None:
                return assignment.subs(x_to_x_old) - x_i
            else:
                return sp.Float(0.0)

        return sp.Matrix([get_bolus(x_i) for x_i in x])

    def get_initial_value(self) -> bool:
        """
        Return the initial value for the root function.

        :return:
            initial boolean value of the trigger function at t0
        """
        return self._initial_value

    def get_priority(self) -> sp.Basic | None:
        """Return the priority of the event assignment."""
        return self._priority

    def __eq__(self, other):
        """
        Check equality of events at the level of trigger/root functions,
        as we need to collect unique root functions for ``roots.cpp``.

        Comparison with a non-Event returns ``NotImplemented`` so that
        Python falls back to its default handling instead of raising
        ``AttributeError``.
        """
        if not isinstance(other, Event):
            return NotImplemented
        return self.get_val() == other.get_val() and (
            self.get_initial_value() == other.get_initial_value()
        )

    # Defining __eq__ already disables hashing implicitly; made explicit
    # here because events are mutable (see ``set_val``).
    __hash__ = None

    def triggers_at_fixed_timepoint(self) -> bool:
        """Check whether the event triggers at a (single) fixed time-point."""
        if len(self._t_root) != 1:
            return False
        return self._t_root[0].is_Number

    def get_trigger_time(self) -> sp.Expr:
        """Get the time at which the event triggers.

        Only for events that trigger at a single fixed time-point.

        :return:
            the trigger time as a sympy number (e.g. Integer, Rational or
            Float, depending on the trigger function)

        :raises NotImplementedError:
            if the event does not trigger at a single fixed time-point
        """
        if not self.triggers_at_fixed_timepoint():
            raise NotImplementedError(
                "This event does not trigger at a fixed timepoint."
            )
        return self._t_root[0]

    def has_explicit_trigger_times(
        self, allowed_symbols: set[sp.Symbol] | None = None
    ) -> bool:
        """Check whether the event has explicit trigger times.

        Explicit trigger times do not require root finding to determine the
        time points at which the event triggers.

        :param allowed_symbols:
            The set of symbols that are allowed in the trigger time
            expressions. If `None`, any symbols are allowed.
            If empty, only numeric values are allowed.
        """
        if allowed_symbols is None:
            return len(self._t_root) > 0

        return len(self._t_root) > 0 and all(
            t.is_Number or t.free_symbols.issubset(allowed_symbols)
            for t in self._t_root
        )

    def get_trigger_times(self) -> set[sp.Expr]:
        """Get the time points at which the event triggers.

        Returns a set of expressions, which may contain multiple time points
        for events that trigger at multiple time points.

        If the return value is empty, the trigger function cannot be solved
        for `t`. I.e., the event does not explicitly depend on time,
        or sympy is unable to solve the trigger function for `t`.

        If the return value is non-empty, it contains expressions for *all*
        time points at which the event triggers.
        """
        return set(self._t_root)

    @property
    def uses_values_from_trigger_time(self) -> bool:
        """Whether the event assignment is evaluated using the state from
        the time point at which the event triggered (True), or at the time
        point at which the event assignment is evaluated (False).
        """
        return self._use_values_from_trigger_time

    @property
    def updates_state(self) -> bool:
        """Whether the event assignment updates the model state."""
        return bool(self._assignments)
# Maps each symbol category to the ModelQuantity subclass that DEModel uses
# to represent entries of that category.
symbol_to_type: dict[SymbolId, type[ModelQuantity]] = {
    SymbolId.SPECIES: DifferentialState,
    SymbolId.ALGEBRAIC_STATE: AlgebraicState,
    SymbolId.ALGEBRAIC_EQUATION: AlgebraicEquation,
    SymbolId.PARAMETER: Parameter,
    SymbolId.FIXED_PARAMETER: Constant,
    SymbolId.OBSERVABLE: Observable,
    SymbolId.EVENT_OBSERVABLE: EventObservable,
    SymbolId.SIGMAY: SigmaY,
    SymbolId.SIGMAZ: SigmaZ,
    SymbolId.LLHY: LogLikelihoodY,
    SymbolId.LLHZ: LogLikelihoodZ,
    SymbolId.LLHRZ: LogLikelihoodRZ,
    SymbolId.EXPRESSION: Expression,
    SymbolId.EVENT: Event,
    SymbolId.NOISE_PARAMETER: NoiseParameter,
    SymbolId.OBSERVABLE_PARAMETER: ObservableParameter,
}