"""Objects for AMICI's internal differential equation model representation"""
import abc
import numbers
from typing import SupportsFloat
import sympy as sp
from .constants import SymbolId
from .import_utils import (
RESERVED_SYMBOLS,
ObservableTransformation,
amici_time_symbol,
cast_to_sym,
contains_periodic_subexpression,
generate_measurement_symbol,
generate_regularization_symbol,
)
__all__ = [
"ConservationLaw",
"Constant",
"Event",
"Expression",
"LogLikelihoodY",
"LogLikelihoodZ",
"LogLikelihoodRZ",
"ModelQuantity",
"NoiseParameter",
"Observable",
"ObservableParameter",
"Parameter",
"SigmaY",
"SigmaZ",
"DifferentialState",
"EventObservable",
"AlgebraicState",
"AlgebraicEquation",
"State",
]
[docs]
class ModelQuantity:
"""
Base class for model components
"""
[docs]
def __init__(
self,
symbol: sp.Symbol,
name: str,
value: SupportsFloat | numbers.Number | sp.Expr,
):
"""
Create a new ModelQuantity instance.
:param symbol:
Symbol of the quantity with unique identifier.
:param name:
individual name of the quantity (does not need to be unique)
:param value:
either formula, numeric value or initial value
"""
if not isinstance(symbol, sp.Symbol):
raise TypeError(f"symbol must be sympy.Symbol, was {type(symbol)}")
if str(symbol) in RESERVED_SYMBOLS or (
hasattr(symbol, "name") and symbol.name in RESERVED_SYMBOLS
):
raise ValueError(
f'Cannot add model quantity with reserved name "{name}", '
"please rename."
)
self._symbol: sp.Symbol = symbol
if not isinstance(name, str):
raise TypeError(f"name must be str, was {type(name)}")
self._name: str = name
self._value: sp.Expr = cast_to_sym(value, "value")
def __repr__(self) -> str:
"""
Representation of the ModelQuantity object
:return:
string representation of the ModelQuantity
"""
return str(self._symbol)
[docs]
def get_sym(self) -> sp.Symbol:
"""
ModelQuantity symbol
:return:
Symbol of the ModelQuantity
"""
return self._symbol
[docs]
def get_id(self) -> str:
"""
ModelQuantity identifier
:return:
identifier of the ModelQuantity
"""
return (
self._symbol.name
if hasattr(self._symbol, "name")
else str(self._symbol)
)
[docs]
def get_name(self) -> str:
"""
ModelQuantity name
:return:
name of the ModelQuantity
"""
return self._name
[docs]
def get_val(self) -> sp.Expr:
"""
ModelQuantity value
:return:
value of the ModelQuantity
"""
return self._value
[docs]
def set_val(self, val: sp.Expr):
"""
Set ModelQuantity value
:return:
value of the ModelQuantity
"""
self._value = cast_to_sym(val, "value")
[docs]
class ConservationLaw(ModelQuantity):
"""
A conservation law defines the absolute the total amount of a
(weighted) sum of states
"""
[docs]
def __init__(
self,
symbol: sp.Symbol,
name: str,
value: sp.Expr,
coefficients: dict[sp.Symbol, sp.Expr],
state_id: sp.Symbol,
):
"""
Create a new ConservationLaw instance.
:param symbol:
unique symbol of the ConservationLaw
:param name:
individual name of the ConservationLaw (does not need to be
unique)
:param value: formula (sum of states)
:param coefficients:
coefficients of the states in the sum
:param state_id:
Symbol of the state that this conservation law replaces
"""
self._state_expr: sp.Expr = symbol - (value - state_id)
self._coefficients: dict[sp.Symbol, sp.Expr] = coefficients
self._ncoeff: sp.Expr = coefficients[state_id]
super().__init__(symbol, name, value)
[docs]
def get_ncoeff(self, state_sym: sp.Symbol) -> sp.Expr | int | float:
"""
Computes the normalized coefficient a_i/a_j where i is the index of
the provided state_id and j is the index of the state that is
replaced by this conservation law. This can be used to compute both
dtotal_cl/dx_rdata (=ncoeff) and dx_rdata/dx_solver (=-ncoeff).
:param state_sym:
Symbol of the state
:return: normalized coefficient of the state
"""
return self._coefficients.get(state_sym, 0.0) / self._ncoeff
[docs]
def get_x_rdata(self):
"""
Returns the expression that allows computation of x_rdata for the state
that this conservation law replaces.
:return: x_rdata expression
"""
return self._state_expr
[docs]
class AlgebraicEquation(ModelQuantity):
"""
An AlgebraicEquation defines an algebraic equation.
"""
[docs]
def __init__(self, symbol: sp.Symbol, value: sp.Expr):
"""
Create a new AlgebraicEquation instance.
:param value:
Formula of the algebraic equation, the solution is given by
``formula == 0``
"""
super().__init__(symbol, symbol.name, value)
[docs]
def get_free_symbols(self):
return self._value.free_symbols
def __repr__(self):
return str(self._value)
[docs]
class State(ModelQuantity):
"""
Base class for differential and algebraic model states
"""
_conservation_law: ConservationLaw | None = None
[docs]
def get_x_rdata(self):
"""
Returns the expression that allows computation of x_rdata for this
state, accounting for conservation laws.
:return: x_rdata expression
"""
if self._conservation_law is None:
return self.get_sym()
else:
return self._conservation_law.get_x_rdata()
[docs]
def get_dx_rdata_dx_solver(self, state_id):
"""
Returns the expression that allows computation of
``dx_rdata_dx_solver`` for this state, accounting for conservation
laws.
:return: dx_rdata_dx_solver expression
"""
if self._conservation_law is None:
return sp.Integer(self._symbol == state_id)
else:
return -self._conservation_law.get_ncoeff(state_id)
[docs]
@abc.abstractmethod
def has_conservation_law(self):
"""
Checks whether this state has a conservation law assigned.
:return: True if assigned, False otherwise
"""
...
[docs]
class AlgebraicState(State):
"""
An AlgebraicState defines an entity that is algebraically determined
"""
[docs]
def __init__(self, symbol: sp.Symbol, name: str, init: sp.Expr):
"""
Create a new AlgebraicState instance.
:param symbol:
unique symbol of the AlgebraicState
:param name:
individual name of the AlgebraicState (does not need to be unique)
:param init:
initial value of the AlgebraicState
"""
super().__init__(symbol, name, init)
[docs]
def has_conservation_law(self) -> bool:
"""
Checks whether this state has a conservation law assigned.
:return: True if assigned, False otherwise
"""
return False
[docs]
def get_free_symbols(self):
return self._value.free_symbols
[docs]
def get_x_rdata(self):
return self._symbol
[docs]
class DifferentialState(State):
"""
A State variable defines an entity that evolves with time according to
the provided time derivative, abbreviated by ``x``.
:ivar _conservation_law:
algebraic formula that allows computation of this
state according to a conservation law
:ivar _dt:
algebraic formula that defines the temporal derivative of this state
"""
[docs]
def __init__(
self, symbol: sp.Symbol, name: str, init: sp.Expr, dt: sp.Expr
):
"""
Create a new State instance. Extends :meth:`ModelQuantity.__init__`
by ``dt``
:param symbol:
unique symbol of the state
:param name:
individual name of the state (does not need to be unique)
:param init:
initial value
:param dt:
time derivative
"""
super().__init__(symbol, name, init)
self._dt = cast_to_sym(dt, "dt")
self._conservation_law: ConservationLaw | None = None
[docs]
def set_conservation_law(self, law: ConservationLaw) -> None:
"""
Sets the conservation law of a state.
If a conservation law is set, the respective state will be replaced by
an algebraic formula according to the respective conservation law.
:param law:
linear sum of states that if added to this state remain
constant over time
"""
if not isinstance(law, ConservationLaw):
raise TypeError(
f"conservation law must have type ConservationLaw"
f", was {type(law)}"
)
self._conservation_law = law
[docs]
def set_dt(self, dt: sp.Expr) -> None:
"""
Sets the time derivative
:param dt:
time derivative
"""
self._dt = cast_to_sym(dt, "dt")
[docs]
def get_dt(self) -> sp.Expr:
"""
Gets the time derivative
:return:
time derivative
"""
return self._dt
[docs]
def get_free_symbols(self) -> set[sp.Basic]:
"""
Gets the set of free symbols in time derivative and initial conditions
:return:
free symbols
"""
return self._dt.free_symbols.union(self._value.free_symbols)
[docs]
def has_conservation_law(self):
"""
Checks whether this state has a conservation law assigned.
:return: True if assigned, False otherwise
"""
return self._conservation_law is not None
[docs]
class Observable(ModelQuantity):
"""
An Observable links model simulations to experimental measurements,
abbreviated by ``y``.
:ivar _measurement_symbol:
sympy symbol used in the objective function to represent
measurements to this observable
:ivar trafo:
observable transformation, only applies when evaluating objective
function or residuals
"""
_measurement_symbol: sp.Symbol | None = None
[docs]
def __init__(
self,
symbol: sp.Symbol,
name: str,
value: sp.Expr,
measurement_symbol: sp.Symbol | None = None,
transformation: None
| ObservableTransformation = ObservableTransformation.LIN,
):
"""
Create a new Observable instance.
:param symbol:
unique symbol of the Observable
:param name:
individual name of the Observable (does not need to be unique)
:param value:
formula
:param transformation:
observable transformation, only applies when evaluating objective
function or residuals
"""
super().__init__(symbol, name, value)
self._measurement_symbol = measurement_symbol
self._regularization_symbol = None
self.trafo = transformation
[docs]
def get_measurement_symbol(self) -> sp.Symbol:
if self._measurement_symbol is None:
self._measurement_symbol = generate_measurement_symbol(
self.get_sym()
)
return self._measurement_symbol
[docs]
def get_regularization_symbol(self) -> sp.Symbol:
if self._regularization_symbol is None:
self._regularization_symbol = generate_regularization_symbol(
self.get_sym()
)
return self._regularization_symbol
[docs]
class EventObservable(Observable):
"""
An Event Observable links model simulations to event related experimental
measurements, abbreviated by ``z``.
:ivar _event:
symbolic event identifier
"""
[docs]
def __init__(
self,
symbol: sp.Symbol,
name: str,
value: sp.Expr,
event: sp.Symbol,
measurement_symbol: sp.Symbol | None = None,
transformation: ObservableTransformation | None = "lin",
):
"""
Create a new EventObservable instance.
:param symbol:
See :py:meth:`Observable.__init__`.
:param name:
See :py:meth:`Observable.__init__`.
:param value:
See :py:meth:`Observable.__init__`.
:param transformation:
See :py:meth:`Observable.__init__`.
:param event:
Symbolic identifier of the corresponding event.
"""
super().__init__(
symbol, name, value, measurement_symbol, transformation
)
self._event: sp.Symbol = event
[docs]
def get_event(self) -> sp.Symbol:
"""
Get the symbolic identifier of the corresponding event.
:return: symbolic identifier
"""
return self._event
class Sigma(ModelQuantity):
"""
A Standard Deviation Sigma rescales the distance between simulations
and measurements when computing residuals or objective functions,
abbreviated by ``sigma{y,z}``.
"""
def __init__(self, symbol: sp.Symbol, name: str, value: sp.Expr):
"""
Create a new Standard Deviation instance.
:param symbol:
unique symbol of the Standard Deviation
:param name:
individual name of the Standard Deviation (does not need to
be unique)
:param value:
formula
"""
if self.__class__.__name__ == "Sigma":
raise RuntimeError(
"This class is meant to be sub-classed, not used directly."
)
super().__init__(symbol, name, value)
[docs]
class SigmaY(Sigma):
"""
Standard deviation for observables
"""
[docs]
class SigmaZ(Sigma):
"""
Standard deviation for event observables
"""
[docs]
class Expression(ModelQuantity):
"""
An Expression is a recurring elements in symbolic formulas. Specifying
this may yield more compact expression which may lead to substantially
shorter model compilation times, but may also reduce model simulation time.
Abbreviated by ``w``.
"""
[docs]
def __init__(self, symbol: sp.Symbol, name: str, value: sp.Expr):
"""
Create a new Expression instance.
:param symbol:
unique symbol of the Expression
:param name:
individual name of the Expression (does not need to be unique)
:param value:
formula
"""
super().__init__(symbol, name, value)
[docs]
class Parameter(ModelQuantity):
"""
A Parameter is a free variable in the model with respect to which
sensitivities may be computed, abbreviated by ``p``.
"""
[docs]
def __init__(self, symbol: sp.Symbol, name: str, value: numbers.Number):
"""
Create a new Expression instance.
:param symbol:
unique symbol of the Parameter
:param name:
individual name of the Parameter (does not need to be
unique)
:param value:
numeric value
"""
super().__init__(symbol, name, value)
[docs]
class Constant(ModelQuantity):
"""
A Constant is a fixed variable in the model with respect to which
sensitivities cannot be computed, abbreviated by ``k``.
"""
[docs]
def __init__(self, symbol: sp.Symbol, name: str, value: numbers.Number):
"""
Create a new Expression instance.
:param symbol:
unique symbol of the Constant
:param name:
individual name of the Constant (does not need to be unique)
:param value:
numeric value
"""
super().__init__(symbol, name, value)
[docs]
class NoiseParameter(ModelQuantity):
"""
A NoiseParameter is an input variable for the computation of ``sigma`` that can be specified in a data-point
specific manner, abbreviated by ``np``. Only used for jax models.
"""
[docs]
def __init__(self, symbol: sp.Symbol, name: str):
"""
Create a new Expression instance.
:param symbol:
unique symbol of the NoiseParameter
:param name:
individual name of the NoiseParameter (does not need to be
unique)
"""
super().__init__(symbol, name, 0.0)
[docs]
class ObservableParameter(ModelQuantity):
"""
A NoiseParameter is an input variable for the computation of ``y`` that can be specified in a data-point specific
manner, abbreviated by ``op``. Only used for jax models.
"""
[docs]
def __init__(self, symbol: sp.Symbol, name: str):
"""
Create a new Expression instance.
:param symbol:
unique symbol of the ObservableParameter
:param name:
individual name of the ObservableParameter (does not need to be
unique)
"""
super().__init__(symbol, name, 0.0)
class LogLikelihood(ModelQuantity):
"""
A LogLikelihood defines the distance between measurements and
experiments for a particular observable. The final LogLikelihood value
in the simulation will be the sum of all specified LogLikelihood
instances evaluated at all timepoints, abbreviated by ``Jy``.
"""
def __init__(self, symbol: sp.Symbol, name: str, value: sp.Expr):
"""
Create a new Expression instance.
:param symbol:
unique symbol of the LogLikelihood
:param name:
individual name of the LogLikelihood (does not need to be
unique)
:param value:
formula
"""
if self.__class__.__name__ == "LogLikelihood":
raise RuntimeError(
"This class is meant to be sub-classed, not used directly."
)
super().__init__(symbol, name, value)
[docs]
class LogLikelihoodY(LogLikelihood):
"""
Loglikelihood for observables
"""
[docs]
class LogLikelihoodZ(LogLikelihood):
"""
Loglikelihood for event observables
"""
[docs]
class LogLikelihoodRZ(LogLikelihood):
"""
Loglikelihood for event observables regularization
"""
[docs]
class Event(ModelQuantity):
"""
An Event defines either a SBML event or a root of the argument of a
Heaviside function. The Heaviside functions will be tracked via the
vector ``h`` during simulation and are needed to inform the solver
about a discontinuity in either the right-hand side or the states
themselves, causing a reinitialization of the solver.
"""
[docs]
def __init__(
self,
symbol: sp.Symbol,
name: str,
value: sp.Expr,
use_values_from_trigger_time: bool,
assignments: dict[sp.Symbol, sp.Expr] | None = None,
initial_value: bool | None = True,
priority: sp.Basic | None = None,
):
"""
Create a new Event instance.
:param symbol:
unique symbol of the Event
:param name:
individual name of the Event (does not need to be unique)
:param value:
formula for the root / trigger function
:param assignments:
Dictionary of event assignments: state symbol -> new value.
:param initial_value:
initial boolean value of the trigger function at t0. If set to
`False`, events may trigger at ``t==t0``, otherwise not.
:param priority: The priority of the event assignment.
:param use_values_from_trigger_time:
Whether the event assignment is evaluated using the state from
the time point at which the event triggered (True), or at the time
point at which the event assignment is evaluated (False).
"""
super().__init__(symbol, name, value)
# add the Event specific components
self._assignments = assignments if assignments is not None else {}
self._initial_value = initial_value
if priority is not None and not priority.is_Number:
raise NotImplementedError(
"Currently, only numeric values are supported as event priority."
)
self._priority = priority
self._use_values_from_trigger_time = use_values_from_trigger_time
# expression(s) for the timepoint(s) at which the event triggers
self._t_root = []
if not contains_periodic_subexpression(
self.get_val(), amici_time_symbol
):
# `solve` will solve, e.g., sin(t), but will only return [0, pi],
# so we better skip any periodic expressions here
try:
self._t_root = sp.solve(self.get_val(), amici_time_symbol)
except NotImplementedError:
# the trigger can't be solved for `t`
pass
[docs]
def get_state_update(
self, x: sp.Matrix, x_old: sp.Matrix
) -> sp.Matrix | None:
"""
Get the state update (bolus) expression for the event assignment.
:param x: The current state vector.
:param x_old: The previous state vector.
If ``use_values_from_trigger_time=True``, this is equal to `x`.
:return: State-update matrix or ``None`` if no state update is defined.
"""
if len(self._assignments) == 0:
return None
x_to_x_old = dict(zip(x, x_old, strict=True))
def get_bolus(x_i: sp.Symbol) -> sp.Expr:
"""
Get the bolus expression for a state variable.
:param x_i: state variable symbol
:return: bolus expression
"""
if (assignment := self._assignments.get(x_i)) is not None:
return assignment.subs(x_to_x_old) - x_i
else:
return sp.Float(0.0)
return sp.Matrix([get_bolus(x_i) for x_i in x])
[docs]
def get_initial_value(self) -> bool:
"""
Return the initial value for the root function.
:return:
initial value formula
"""
return self._initial_value
[docs]
def get_priority(self) -> sp.Basic | None:
"""Return the priority of the event assignment."""
return self._priority
def __eq__(self, other):
"""
Check equality of events at the level of trigger/root functions, as we
need to collect unique root functions for ``roots.cpp``
"""
return self.get_val() == other.get_val() and (
self.get_initial_value() == other.get_initial_value()
)
[docs]
def triggers_at_fixed_timepoint(self) -> bool:
"""Check whether the event triggers at a (single) fixed time-point."""
if len(self._t_root) != 1:
return False
return self._t_root[0].is_Number
[docs]
def get_trigger_time(self) -> sp.Float:
"""Get the time at which the event triggers.
Only for events that trigger at a single fixed time-point.
"""
if not self.triggers_at_fixed_timepoint():
raise NotImplementedError(
"This event does not trigger at a fixed timepoint."
)
return self._t_root[0]
[docs]
def has_explicit_trigger_times(
self, allowed_symbols: set[sp.Symbol] | None = None
) -> bool:
"""Check whether the event has explicit trigger times.
Explicit trigger times do not require root finding to determine
the time points at which the event triggers.
:param allowed_symbols:
The set of symbols that are allowed in the trigger time
expressions. If `None`, any symbols are allowed.
If empty, only numeric values are allowed.
"""
if allowed_symbols is None:
return len(self._t_root) > 0
return len(self._t_root) > 0 and all(
t.is_Number or t.free_symbols.issubset(allowed_symbols)
for t in self._t_root
)
[docs]
def get_trigger_times(self) -> set[sp.Expr]:
"""Get the time points at which the event triggers.
Returns a set of expressions, which may contain multiple time points
for events that trigger at multiple time points.
If the return value is empty, the trigger function cannot be solved
for `t`. I.e., the event does not explicitly depend on time,
or sympy is unable to solve the trigger function for `t`.
If the return value is non-empty, it contains expressions for *all*
time points at which the event triggers.
"""
return set(self._t_root)
@property
def uses_values_from_trigger_time(self) -> bool:
"""Whether the event assignment is evaluated using the state from
the time point at which the event triggered (True), or at the time
point at which the event assignment is evaluated (False).
"""
return self._use_values_from_trigger_time
@property
def updates_state(self) -> bool:
"""Whether the event assignment updates the model state."""
return bool(self._assignments)
# defines the type of some attributes in DEModel
symbol_to_type = {
SymbolId.SPECIES: DifferentialState,
SymbolId.ALGEBRAIC_STATE: AlgebraicState,
SymbolId.ALGEBRAIC_EQUATION: AlgebraicEquation,
SymbolId.PARAMETER: Parameter,
SymbolId.FIXED_PARAMETER: Constant,
SymbolId.OBSERVABLE: Observable,
SymbolId.EVENT_OBSERVABLE: EventObservable,
SymbolId.SIGMAY: SigmaY,
SymbolId.SIGMAZ: SigmaZ,
SymbolId.LLHY: LogLikelihoodY,
SymbolId.LLHZ: LogLikelihoodZ,
SymbolId.LLHRZ: LogLikelihoodRZ,
SymbolId.EXPRESSION: Expression,
SymbolId.EVENT: Event,
SymbolId.NOISE_PARAMETER: NoiseParameter,
SymbolId.OBSERVABLE_PARAMETER: ObservableParameter,
}