Program Listing for File model.h

Return to documentation for file (include/amici/model.h)

#ifndef AMICI_MODEL_H
#define AMICI_MODEL_H

#include "amici/abstract_model.h"
#include "amici/defines.h"
#include "amici/event.h"
#include "amici/model_dimensions.h"
#include "amici/model_state.h"
#include "amici/simulation_parameters.h"
#include "amici/splinefunctions.h"
#include "amici/sundials_matrix_wrapper.h"

#include <map>
#include <vector>

// Forward declarations of amici classes, so that they can be named by
// pointer or reference below without including their full definitions.
namespace amici {
class AmiVector;
class AmiVectorArray;
class ExpData;
class Logger;
class Model;
class Solver;
} // namespace amici

// Declaration of the Boost.Serialization free-function template.
// amici::Model befriends this function, and a qualified friend declaration
// requires the function to have been declared beforehand.
namespace boost {
namespace serialization {
template <class Archive>
void serialize(Archive& ar, amici::Model& m, unsigned int version);
} // namespace serialization
} // namespace boost

namespace amici {

/**
 * Identifiers for model quantities.
 *
 * Used as the key type of model_quantity_to_str and as the `model_quantity`
 * argument of Model::check_finite.
 *
 * Enumerators take implicit consecutive values in declaration order
 * (J == 0, ..., drzdx == 52). Reordering the list changes those values.
 */
enum class ModelQuantity {
    J, JB, Jv, JvB, JDiag,
    sx, sy, sz, srz, ssigmay, ssigmaz,
    xdot, sxdot, xBdot,
    x0_rdata, x0, x_rdata, x,
    dwdw, dwdx, dwdp,
    y, dydp, dydx,
    w, root,
    qBdot, qBdot_ss, xBdot_ss, JSparseB_ss,
    deltax, deltasx, deltaxB,
    k, p, ts,
    dJydy, deltaqB,
    dsigmaydp, dsigmaydy, dsigmazdp,
    dJydsigma, dJydx,
    dzdx, dzdp,
    dJrzdsigma, dJrzdz, dJrzdx,
    dJzdsigma, dJzdz, dJzdx,
    drzdp, drzdx,
};

/// Maps each ModelQuantity to a string. NOTE(review): presumably its
/// printable name (e.g. for error messages); confirm. Defined in another
/// translation unit.
extern const std::map<ModelQuantity, std::string> model_quantity_to_str;

/**
 * Base class for AMICI models.
 *
 * Inherits the model-function interface from AbstractModel and the model
 * sizes from ModelDimensions. NOTE(review): `ne` / `ne_solver`, which the
 * inline members read but this class does not declare, presumably come from
 * ModelDimensions.
 *
 * The class is abstract: clone() is pure virtual and is supplied by each
 * concrete model. Copy assignment is deleted. No copy constructor is
 * declared, so the implicitly generated one is still available.
 * NOTE(review): clone() implementations presumably rely on it; confirm.
 *
 * Declaration order of data members is significant (layout, initialization
 * order); do not reorder.
 */
class Model : public AbstractModel, public ModelDimensions {
  public:
    /** Default constructor; members use their in-class initializers. */
    Model() = default;

    /**
     * Construct a model from its dimensions and configuration.
     * Defined out of line; `events` defaults to an empty list.
     */
    Model(
        ModelDimensions const& model_dimensions,
        SimulationParameters simulation_parameters, SecondOrderMode o2mode,
        std::vector<realtype> idlist, std::vector<int> z2event,
        std::vector<Event> events = {}
    );

    /** Virtual destructor (overrides the base-class destructor). */
    ~Model() override = default;

    /** Copy assignment is disabled. */
    Model& operator=(Model const& other) = delete;

    /**
     * Polymorphic copy; pure virtual, implemented by concrete models.
     * @return a new copy as a raw pointer. NOTE(review): ownership presumably
     *         passes to the caller; confirm.
     */
    [[nodiscard]] virtual Model* clone() const = 0;

    /** Gives Boost.Serialization access to non-public members. */
    template <class Archive>
    friend void boost::serialization::serialize(
        Archive& ar, Model& m, unsigned int version
    );

    /** Equality; also declared at namespace scope after the class. */
    friend bool operator==(Model const& a, Model const& b);

    // Overloaded base class methods. Model declares its own overloads of
    // several of these names (e.g. fx0, fy, fw, fsigmay further down). Without
    // these using-declarations, those would hide every AbstractModel overload
    // of the same name.
    using AbstractModel::fdeltaqB;
    using AbstractModel::fdeltasx;
    using AbstractModel::fdeltax;
    using AbstractModel::fdeltaxB;
    using AbstractModel::fdJrzdsigma;
    using AbstractModel::fdJrzdz;
    using AbstractModel::fdJydsigma;
    using AbstractModel::fdJydy;
    using AbstractModel::fdJydy_colptrs;
    using AbstractModel::fdJydy_rowvals;
    using AbstractModel::fdJzdsigma;
    using AbstractModel::fdJzdz;
    using AbstractModel::fdrzdp;
    using AbstractModel::fdrzdx;
    using AbstractModel::fdsigmaydp;
    using AbstractModel::fdsigmaydy;
    using AbstractModel::fdsigmazdp;
    using AbstractModel::fdtotal_cldp;
    using AbstractModel::fdtotal_cldx_rdata;
    using AbstractModel::fdtotal_cldx_rdata_colptrs;
    using AbstractModel::fdtotal_cldx_rdata_rowvals;
    using AbstractModel::fdwdp;
    using AbstractModel::fdwdp_colptrs;
    using AbstractModel::fdwdp_rowvals;
    using AbstractModel::fdwdw;
    using AbstractModel::fdwdw_colptrs;
    using AbstractModel::fdwdw_rowvals;
    using AbstractModel::fdwdx;
    using AbstractModel::fdwdx_colptrs;
    using AbstractModel::fdwdx_rowvals;
    using AbstractModel::fdx_rdatadp;
    using AbstractModel::fdx_rdatadtcl;
    using AbstractModel::fdx_rdatadtcl_colptrs;
    using AbstractModel::fdx_rdatadtcl_rowvals;
    using AbstractModel::fdx_rdatadx_solver;
    using AbstractModel::fdx_rdatadx_solver_colptrs;
    using AbstractModel::fdx_rdatadx_solver_rowvals;
    using AbstractModel::fdydp;
    using AbstractModel::fdydx;
    using AbstractModel::fdzdp;
    using AbstractModel::fdzdx;
    using AbstractModel::fJrz;
    using AbstractModel::fJy;
    using AbstractModel::fJz;
    using AbstractModel::frz;
    using AbstractModel::fsigmay;
    using AbstractModel::fsigmaz;
    using AbstractModel::fsrz;
    using AbstractModel::fstau;
    using AbstractModel::fsx0;
    using AbstractModel::fsx0_fixedParameters;
    using AbstractModel::fsz;
    using AbstractModel::fw;
    using AbstractModel::fx0;
    using AbstractModel::fx0_fixedParameters;
    using AbstractModel::fy;
    using AbstractModel::fz;

    // ----------------------------------------------------------------------
    // (Re-)initialization of states, sensitivities, splines and events.
    // All defined out of line. The non-const AmiVector / AmiVectorArray and
    // `roots_found` arguments are output parameters.
    // ----------------------------------------------------------------------

    void initialize(
        realtype t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx,
        AmiVectorArray& sdx, bool computeSensitivities,
        std::vector<int>& roots_found
    );

    void reinitialize(
        realtype t, AmiVector& x, AmiVectorArray& sx, bool computeSensitivities
    );

    void initialize_b(
        AmiVector& xB, AmiVector& dxB, AmiVector& xQB, bool posteq
    ) const;

    void initialize_state(realtype t, AmiVector& x);

    void initialize_state_sensitivities(
        realtype t, AmiVectorArray& sx, AmiVector const& x
    );

    void initialize_splines();

    void initialize_spline_sensitivities();

    void initialize_events(
        realtype t, AmiVector const& x, AmiVector const& dx,
        std::vector<int>& roots_found
    );

    void reinit_events(
        realtype t, AmiVector const& x, AmiVector const& dx,
        std::vector<realtype> const& h_old, std::vector<int>& roots_found
    );

    void reinit_explicit_roots();

    // ----------------------------------------------------------------------
    // Sizes and simple configuration. NOTE(review): k() presumably points to
    // the fixed-parameter values (the `k` argument of the model functions);
    // confirm.
    // ----------------------------------------------------------------------

    int nplist() const;

    int np() const;

    int nk() const;

    int ncl() const;

    int nx_reinit() const;

    double const* k() const;

    int n_max_event() const;

    void set_n_max_event(int nmaxevent);

    int nt() const;

    // ----------------------------------------------------------------------
    // Parameters and parameter scaling. Vector getters return const
    // references into the model. The map-taking setters accept
    // `ignoreErrors` (default false). The `*_regex` setters return an int;
    // NOTE(review): presumably the number of matched entries; confirm.
    // ----------------------------------------------------------------------

    std::vector<ParameterScaling> const& get_parameter_scale() const;

    void set_parameter_scale(ParameterScaling pscale);

    void set_parameter_scale(std::vector<ParameterScaling> const& pscaleVec);

    std::vector<realtype> const& get_unscaled_parameters() const;

    std::vector<realtype> const& get_parameters() const;

    realtype get_parameter_by_id(std::string const& par_id) const;

    realtype get_parameter_by_name(std::string const& par_name) const;

    void set_parameters(std::vector<realtype> const& p);

    void set_parameter_by_id(
        std::map<std::string, realtype> const& p, bool ignoreErrors = false
    );

    void set_parameter_by_id(std::string const& par_id, realtype value);

    int
    set_parameters_by_id_regex(std::string const& par_id_regex, realtype value);

    void set_parameter_by_name(std::string const& par_name, realtype value);

    void set_parameter_by_name(
        std::map<std::string, realtype> const& p, bool ignoreErrors = false
    );

    int set_parameters_by_name_regex(
        std::string const& par_name_regex, realtype value
    );

    // Fixed parameters: same access patterns as the parameters above.

    std::vector<realtype> const& get_fixed_parameters() const;

    realtype get_fixed_parameter_by_id(std::string const& par_id) const;

    realtype get_fixed_parameter_by_name(std::string const& par_name) const;

    void set_fixed_parameters(std::vector<realtype> const& k);

    void set_fixed_parameter_by_id(std::string const& par_id, realtype value);

    int set_fixed_parameters_by_id_regex(
        std::string const& par_id_regex, realtype value
    );

    void
    set_fixed_parameter_by_name(std::string const& par_name, realtype value);

    int set_fixed_parameters_by_name_regex(
        std::string const& par_name_regex, realtype value
    );

    // ----------------------------------------------------------------------
    // Model metadata (names and ids). All virtual, so concrete models may
    // override them; base behavior is defined out of line.
    // ----------------------------------------------------------------------

    virtual std::string get_name() const;

    virtual bool has_parameter_names() const;

    virtual std::vector<std::string> get_parameter_names() const;

    virtual bool has_state_names() const;

    virtual std::vector<std::string> get_state_names() const;

    virtual std::vector<std::string> get_state_names_solver() const;

    virtual bool has_fixed_parameter_names() const;

    virtual std::vector<std::string> get_fixed_parameter_names() const;

    virtual bool has_observable_names() const;

    virtual std::vector<std::string> get_observable_names() const;

    virtual bool has_expression_names() const;

    virtual std::vector<std::string> get_expression_names() const;

    virtual bool has_parameter_ids() const;

    virtual std::vector<std::string> get_parameter_ids() const;

    virtual bool has_state_ids() const;

    virtual std::vector<std::string> get_state_ids() const;

    virtual std::vector<std::string> get_state_ids_solver() const;

    virtual bool has_fixed_parameter_ids() const;

    virtual std::vector<std::string> get_fixed_parameter_ids() const;

    virtual bool has_observable_ids() const;

    virtual std::vector<std::string> get_observable_ids() const;

    virtual bool has_expression_ids() const;

    virtual std::vector<std::string> get_expression_ids() const;

    virtual bool has_quadratic_llh() const;

    // ----------------------------------------------------------------------
    // Timepoints, initial times and non-negativity settings.
    // ----------------------------------------------------------------------

    std::vector<realtype> const& get_timepoints() const;

    realtype get_timepoint(int it) const;

    void set_timepoints(std::vector<realtype> const& ts);

    double t0() const;

    void set_t0(double t0);

    double t0_preeq() const;

    void set_t0_preeq(double t0_preeq);

    std::vector<bool> const& get_state_is_non_negative() const;

    void set_state_is_non_negative(std::vector<bool> const& stateIsNonNegative);

    void set_all_states_non_negative();

    /** @return the current model state (const reference to `state_`). */
    ModelState const& get_model_state() const { return state_; }

    /**
     * Replace the model state after checking that its vector sizes match
     * this model.
     *
     * Checked sizes: `unscaled_parameters` == np(), `fixed_parameters` == nk(),
     * `h` == ne, `total_cl` == ncl(), `stotal_cl` == ncl() * np().
     *
     * @param state new state; copied into `state_` only if every check passes.
     * @throws AmiException on the first size mismatch (`state_` unchanged).
     */
    void set_model_state(ModelState const& state) {
        if (gsl::narrow<int>(state.unscaled_parameters.size()) != np())
            throw AmiException("Mismatch in parameter size");
        if (gsl::narrow<int>(state.fixed_parameters.size()) != nk())
            throw AmiException("Mismatch in fixed parameter size");
        if (gsl::narrow<int>(state.h.size()) != ne)
            throw AmiException("Mismatch in Heaviside size");
        if (gsl::narrow<int>(state.total_cl.size()) != ncl())
            throw AmiException("Mismatch in conservation law size");
        if (gsl::narrow<int>(state.stotal_cl.size()) != ncl() * np())
            throw AmiException("Mismatch in conservation law sensitivity size");
        state_ = state;
    }

    /** Store the minimum sigma used for sigma residuals (not validated). */
    void set_minimum_sigma_residuals(double const min_sigma) {
        min_sigma_ = min_sigma;
    }

    /** @return the value stored by set_minimum_sigma_residuals (default 50). */
    [[nodiscard]] realtype get_minimum_sigma_residuals() const {
        return min_sigma_;
    }

    /** Enable or disable sigma residuals (default: disabled). */
    void set_add_sigma_residuals(bool const sigma_res) {
        sigma_res_ = sigma_res;
    }

    /** @return whether sigma residuals are enabled. */
    [[nodiscard]] bool get_add_sigma_residuals() const { return sigma_res_; }

    // ----------------------------------------------------------------------
    // Sensitivity parameter list (plist).
    // ----------------------------------------------------------------------

    std::vector<int> const& get_parameter_list() const;

    int plist(int pos) const;

    void set_parameter_list(std::vector<int> const& plist);

    // ----------------------------------------------------------------------
    // Initial states and initial-state sensitivities. The argument-less
    // overloads evaluate at the model's t0().
    // ----------------------------------------------------------------------

    std::vector<realtype> get_initial_state(realtype t0);

    /** Shorthand for get_initial_state(t0()). */
    std::vector<realtype> get_initial_state() {
        return get_initial_state(t0());
    }

    void set_initial_state(std::vector<realtype> const& x0);

    [[nodiscard]] bool has_custom_initial_state() const;

    /** Shorthand for get_initial_state_sensitivities(t0()). */
    std::vector<realtype> get_initial_state_sensitivities() {
        return get_initial_state_sensitivities(t0());
    }

    std::vector<realtype> get_initial_state_sensitivities(realtype t0);

    void set_initial_state_sensitivities(std::vector<realtype> const& sx0);

    bool has_custom_initial_state_sensitivities() const;

    void
    set_unscaled_initial_state_sensitivities(std::vector<realtype> const& sx0);

    // ----------------------------------------------------------------------
    // Steady-state settings.
    // ----------------------------------------------------------------------

    void set_steady_state_computation_mode(SteadyStateComputationMode mode);

    SteadyStateComputationMode get_steady_state_computation_mode() const;

    void set_steady_state_sensitivity_mode(SteadyStateSensitivityMode mode);

    [[nodiscard]] SteadyStateSensitivityMode
    get_steady_state_sensitivity_mode() const;

    void set_reinitialize_fixed_parameter_initial_states(bool flag);

    [[nodiscard]] bool get_reinitialize_fixed_parameter_initial_states() const;

    void require_sensitivities_for_all_parameters();

    // ----------------------------------------------------------------------
    // Expressions, observables and their objective contributions. Results are
    // written into caller-provided spans / vectors. `edata` is passed by
    // pointer where it appears optional and by reference where it is required.
    // ----------------------------------------------------------------------

    void get_expression(gsl::span<realtype> w, realtype t, AmiVector const& x);

    void get_observable(gsl::span<realtype> y, realtype t, AmiVector const& x);

    [[nodiscard]] virtual ObservableScaling
    get_observable_scaling(int iy) const;

    void get_observable_sensitivity(
        gsl::span<realtype> sy, realtype t, AmiVector const& x,
        AmiVectorArray const& sx
    );

    void get_observable_sigma(
        gsl::span<realtype> sigmay, int it, ExpData const* edata
    );

    void get_observable_sigma_sensitivity(
        gsl::span<realtype> ssigmay, gsl::span<realtype const> sy, int it,
        ExpData const* edata
    );

    void add_observable_objective(
        realtype& Jy, int it, AmiVector const& x, ExpData const& edata
    );

    void add_observable_objective_sensitivity(
        std::vector<realtype>& sllh, std::vector<realtype>& s2llh, int it,
        AmiVector const& x, AmiVectorArray const& sx, ExpData const& edata
    );

    void add_partial_observable_objective_sensitivity(
        std::vector<realtype>& sllh, std::vector<realtype>& s2llh, int it,
        AmiVector const& x, ExpData const& edata
    );

    void get_adjoint_state_observable_update(
        gsl::span<realtype> dJydx, int it, AmiVector const& x,
        ExpData const& edata
    );

    // ----------------------------------------------------------------------
    // Events: outputs, regularization, sigmas and objective contributions
    // for event `ie` (`nroots` indexes the occurrence).
    // ----------------------------------------------------------------------

    void
    get_event(gsl::span<realtype> z, int ie, realtype t, AmiVector const& x);
    void get_event_sensitivity(
        gsl::span<realtype> sz, int ie, realtype t, AmiVector const& x,
        AmiVectorArray const& sx
    );

    void get_unobserved_event_sensitivity(gsl::span<realtype> sz, int ie);

    void get_event_regularization(
        gsl::span<realtype> rz, int ie, realtype t, AmiVector const& x
    );

    void get_event_regularization_sensitivity(
        gsl::span<realtype> srz, int ie, realtype t, AmiVector const& x,
        AmiVectorArray const& sx
    );
    void get_event_sigma(
        gsl::span<realtype> sigmaz, int ie, int nroots, realtype t,
        ExpData const* edata
    );

    void get_event_sigma_sensitivity(
        gsl::span<realtype> ssigmaz, int ie, int nroots, realtype t,
        ExpData const* edata
    );

    void add_event_objective(
        realtype& Jz, int ie, int nroots, realtype t, AmiVector const& x,
        ExpData const& edata
    );

    void add_event_objective_regularization(
        realtype& Jrz, int ie, int nroots, realtype t, AmiVector const& x,
        ExpData const& edata
    );

    void add_event_objective_sensitivity(
        std::vector<realtype>& sllh, std::vector<realtype>& s2llh, int ie,
        int nroots, realtype t, AmiVector const& x, AmiVectorArray const& sx,
        ExpData const& edata
    );

    void add_partial_event_objective_sensitivity(
        std::vector<realtype>& sllh, std::vector<realtype>& s2llh, int ie,
        int nroots, realtype t, AmiVector const& x, ExpData const& edata
    );

    void get_adjoint_state_event_update(
        gsl::span<realtype> dJzdx, int ie, int nroots, realtype t,
        AmiVector const& x, ExpData const& edata
    );

    void get_event_time_sensitivity(
        std::vector<realtype>& stau, realtype t, int ie, AmiVector const& x,
        AmiVectorArray const& sx, AmiVector const& dx
    );

    // In-place updates applied when event `ie` fires (first argument is
    // modified).

    void add_state_event_update(
        AmiVector& x, int ie, realtype t, AmiVector const& xdot,
        AmiVector const& xdot_old, AmiVector const& x_old,
        ModelState const& state
    );

    void add_state_sensitivity_event_update(
        AmiVectorArray& sx, int ie, realtype t, AmiVector const& x,
        AmiVector const& x_old, AmiVector const& xdot,
        AmiVector const& xdot_old, AmiVectorArray const& sx_old,
        std::vector<realtype> const& stau
    );

    void add_adjoint_state_event_update(
        AmiVector& xB, int ie, realtype t, AmiVector const& x,
        AmiVector const& xdot, AmiVector const& xdot_old,
        AmiVector const& x_old, AmiVector const& dx
    );

    void add_adjoint_quadrature_event_update(
        AmiVector& xQB, int ie, realtype t, AmiVector const& x,
        AmiVector const& xB, AmiVector const& xdot, AmiVector const& xdot_old,
        AmiVector const& x_old, AmiVector const& dx
    );

    void update_heaviside(std::vector<int> const& rootsfound);

    // ----------------------------------------------------------------------
    // Finiteness checks for spans (flat or with `num_cols` columns) and
    // SUNMatrix. NOTE(review): the int result is presumably an AMICI status
    // code; confirm.
    // ----------------------------------------------------------------------

    int check_finite(
        gsl::span<realtype const> array, ModelQuantity model_quantity,
        realtype t
    ) const;
    int check_finite(
        gsl::span<realtype const> array, ModelQuantity model_quantity,
        size_t num_cols, realtype t
    ) const;

    int
    check_finite(SUNMatrix m, ModelQuantity model_quantity, realtype t) const;

    /** Defaults to true when NDEBUG is not defined, false otherwise. */
    void set_always_check_finite(bool alwaysCheck);

    [[nodiscard]] bool get_always_check_finite() const;

    // ----------------------------------------------------------------------
    // Initial-state evaluation and conversion between solver and rdata
    // state representations (AmiVector-based wrappers).
    // ----------------------------------------------------------------------

    void fx0(realtype t, AmiVector& x);

    void fx0_fixedParameters(realtype t, AmiVector& x);

    void fsx0(realtype t, AmiVectorArray& sx, AmiVector const& x);

    void
    fsx0_fixedParameters(realtype t, AmiVectorArray& sx, AmiVector const& x);

    virtual void fsdx0();

    void fx_rdata(gsl::span<realtype> x_rdata, AmiVector const& x_solver);

    void fsx_rdata(
        gsl::span<realtype> sx_rdata, AmiVectorArray const& sx_solver,
        AmiVector const& x_solver
    );

    void set_reinitialization_state_idxs(std::vector<int> const& idxs);

    [[nodiscard]] std::vector<int> const&
    get_reinitialization_state_idxs() const;

    [[nodiscard]] SUNMatrixWrapper const& get_dxdotdp_full() const;

    [[nodiscard]] virtual std::vector<double> get_trigger_timepoints() const;

    /**
     * @return the stored steady-state mask.
     *
     * Returned by const reference, like the other vector getters (e.g.
     * get_id_list()), so calls no longer copy the vector. Callers that need
     * an independent copy can still write `auto m = get_steadystate_mask();`.
     */
    [[nodiscard]] std::vector<realtype> const& get_steadystate_mask() const {
        return steadystate_mask_;
    }

    void set_steadystate_mask(std::vector<realtype> const& mask);

    /**
     * @param ie event index.
     * @return the event at index `ie`.
     * @throws std::out_of_range if `ie` is not a valid index (uses .at()).
     */
    [[nodiscard]] Event const& get_event(int const ie) const {
        return events_.at(ie);
    }

    /** @return the `any_state_non_negative_` flag. */
    [[nodiscard]] bool get_any_state_nonnegative() const {
        return any_state_non_negative_;
    }

    /**
     * Default implementation for models without explicit roots.
     *
     * @return an empty list if `ne == ne_solver`.
     * @throws AmiException if `ne != ne_solver`. Such a model needs an
     *         override of this function, which it did not provide.
     */
    [[nodiscard]] std::vector<std::vector<realtype>> fexplicit_roots(
        [[maybe_unused]] realtype const* p, [[maybe_unused]] realtype const* k
    ) override {
        if (ne != ne_solver) {
            throw AmiException(
                "ne!=ne_solver, but 'fexplicit_roots' is not implemented for "
                "this model."
            );
        }
        return {};
    }

    /** @return the stored id list (set at construction). */
    std::vector<realtype> const& get_id_list() const { return id_list_; }

    /** @return the second-order mode (set at construction). */
    SecondOrderMode get_second_order_mode() const { return o2_mode_; }

    /** @return the stored logger pointer (may be nullptr). */
    Logger* get_logger() const { return logger_; }

    /**
     * Store a logger pointer. Only the pointer is stored. NOTE(review): the
     * Model presumably does not own the logger; confirm its lifetime.
     */
    void set_logger(Logger* logger) { logger_ = logger; }

    /** @return the explicit-roots map (time -> list of ints). */
    std::map<realtype, std::vector<int>> const& get_explicit_roots() const {
        return explicit_roots_;
    }

  protected:
    // ----------------------------------------------------------------------
    // Buffer helpers.
    // ----------------------------------------------------------------------

    void write_slice_event(
        gsl::span<realtype const> slice, gsl::span<realtype> buffer, int ie
    );

    void write_sensitivity_slice_event(
        gsl::span<realtype const> slice, gsl::span<realtype> buffer, int ie
    );

    void write_llh_sensitivity_slice(
        std::vector<realtype> const& dLLhdp, std::vector<realtype>& sllh,
        std::vector<realtype>& s2llh
    );

    void check_llh_buffer_size(
        std::vector<realtype> const& sllh, std::vector<realtype> const& s2llh
    ) const;

    void initialize_vectors();

    // ----------------------------------------------------------------------
    // Model-quantity evaluators (return void). NOTE(review): results
    // presumably land in `derived_state_`; confirm in model.cpp.
    // ----------------------------------------------------------------------

    void fy(realtype t, AmiVector const& x);

    void fdydp(realtype t, AmiVector const& x);

    void fdydx(realtype t, AmiVector const& x);

    void fsigmay(int it, ExpData const* edata);

    void fdsigmaydp(int it, ExpData const* edata);

    void fdsigmaydy(int it, ExpData const* edata);

    void fdJydy(int it, AmiVector const& x, ExpData const& edata);

    void fdJydsigma(int it, AmiVector const& x, ExpData const& edata);

    void fdJydp(int it, AmiVector const& x, ExpData const& edata);

    void fdJydx(int it, AmiVector const& x, ExpData const& edata);

    void fz(int ie, realtype t, AmiVector const& x);

    void fdzdp(int ie, realtype t, AmiVector const& x);

    void fdzdx(int ie, realtype t, AmiVector const& x);

    void frz(int ie, realtype t, AmiVector const& x);

    void fdrzdp(int ie, realtype t, AmiVector const& x);

    void fdrzdx(int ie, realtype t, AmiVector const& x);

    void fsigmaz(int ie, int nroots, realtype t, ExpData const* edata);

    void fdsigmazdp(int ie, int nroots, realtype t, ExpData const* edata);

    void fdJzdz(
        int ie, int nroots, realtype t, AmiVector const& x, ExpData const& edata
    );

    void fdJzdsigma(
        int ie, int nroots, realtype t, AmiVector const& x, ExpData const& edata
    );

    void fdJzdp(
        int ie, int nroots, realtype t, AmiVector const& x, ExpData const& edata
    );

    void fdJzdx(
        int ie, int nroots, realtype t, AmiVector const& x, ExpData const& edata
    );

    void fdJrzdz(
        int ie, int nroots, realtype t, AmiVector const& x, ExpData const& edata
    );

    void fdJrzdsigma(
        int ie, int nroots, realtype t, AmiVector const& x, ExpData const& edata
    );

    void fspl(realtype t);

    void fsspl(realtype t);

    // `include_static` defaults to true for the w-related evaluators.

    void fw(realtype t, realtype const* x, bool include_static = true);

    void fdwdp(realtype t, realtype const* x, bool include_static = true);

    void fdwdx(realtype t, realtype const* x, bool include_static = true);

    void fdwdw(realtype t, realtype const* x, bool include_static = true);

    // ----------------------------------------------------------------------
    // Virtual raw-pointer hooks for state / conservation-law conversions.
    // Defaults are defined out of line; concrete models may override them.
    // ----------------------------------------------------------------------

    virtual void fx_rdata(
        realtype* x_rdata, realtype const* x_solver, realtype const* tcl,
        realtype const* p, realtype const* k
    );

    virtual void fsx_rdata(
        realtype* sx_rdata, realtype const* sx_solver, realtype const* stcl,
        realtype const* p, realtype const* k, realtype const* x_solver,
        realtype const* tcl, int ip
    );

    virtual void fx_solver(realtype* x_solver, realtype const* x_rdata);

    virtual void fsx_solver(realtype* sx_solver, realtype const* sx_rdata);

    virtual void ftotal_cl(
        realtype* total_cl, realtype const* x_rdata, realtype const* p,
        realtype const* k
    );

    virtual void fstotal_cl(
        realtype* stotal_cl, realtype const* sx_rdata, int ip,
        realtype const* x_rdata, realtype const* p, realtype const* k,
        realtype const* tcl
    );

    // NOTE(review): presumably return `x` with entries flagged in
    // state_is_non_negative_ clipped to >= 0; confirm.
    const_N_Vector compute_x_pos(const_N_Vector x);

    realtype const* compute_x_pos(AmiVector const& x);

    // ----------------------------------------------------------------------
    // Data members. Order is significant; do not reorder.
    // ----------------------------------------------------------------------

    /** Model state; exposed via get_model_state / set_model_state. */
    ModelState state_;

    ModelStateDerived derived_state_;

    std::vector<HermiteSpline> splines_;

    /** Set at construction from the `z2event` argument. */
    std::vector<int> z2event_;

    std::vector<realtype> x0data_;

    std::vector<realtype> sx0data_;

    std::vector<bool> state_is_non_negative_;

    /** Returned by get_any_state_nonnegative(). */
    bool any_state_non_negative_{false};

    int nmaxevent_{10};

    SteadyStateComputationMode steadystate_computation_mode_{
        SteadyStateComputationMode::integrationOnly
    };

    SteadyStateSensitivityMode steadystate_sensitivity_mode_{
        SteadyStateSensitivityMode::integrationOnly
    };

    // Finiteness checks are on by default in builds without NDEBUG.
#ifdef NDEBUG
    bool always_check_finite_{false};
#else
    bool always_check_finite_{true};
#endif

    /** Backing flag for set/get_add_sigma_residuals. */
    bool sigma_res_{false};

    /** Backing value for set/get_minimum_sigma_residuals. */
    realtype min_sigma_{50.0};

  private:
    /** Returned by get_second_order_mode(). */
    SecondOrderMode o2_mode_{SecondOrderMode::none};

    /** Returned by get_id_list(). */
    std::vector<realtype> id_list_;

    SimulationParameters simulation_parameters_;

    /** Returned by get_steadystate_mask(). */
    std::vector<realtype> steadystate_mask_;

    /** Accessed (bounds-checked) via get_event(int). */
    std::vector<Event> events_;

    /** Non-null only after set_logger(); see get_logger(). */
    Logger* logger_ = nullptr;

    /** Returned by get_explicit_roots(). */
    std::map<realtype, std::vector<int>> explicit_roots_ = {};
};

// Namespace-scope equality operators, defined out of line. The Model overload
// is the one befriended inside class Model.
bool operator==(ModelDimensions const& a, ModelDimensions const& b);
bool operator==(Model const& a, Model const& b);

} // namespace amici

#endif // AMICI_MODEL_H