Program Listing for File rdata.h

Return to documentation for file (include/amici/rdata.h)

#ifndef AMICI_RDATA_H
#define AMICI_RDATA_H

#include "amici/defines.h"
#include "amici/logging.h"
#include "amici/misc.h"
#include "amici/model.h"
#include "amici/sundials_matrix_wrapper.h"
#include "amici/vector.h"

#include <cmath>
#include <limits>
#include <string>
#include <vector>

// Forward declarations of the types this header only needs by name
// (they appear below as pointer/reference parameters or in declarations).
namespace amici {
class BackwardProblem;
class ExpData;
class ForwardProblem;
class ReturnData;
class Solver;
class SteadyStateBackwardProblem;
class SteadyStateProblem;
} // namespace amici

namespace boost::serialization {
/**
 * @brief Serialization hook for amici::ReturnData (declaration only).
 *
 * Declared ahead of the class so that ReturnData can name this function
 * template as a friend. The definition is not part of this header.
 *
 * @param ar Archive to read from / write to
 * @param r Object to (de)serialize
 * @param version Class version number supplied by the archive
 */
template <typename Archive>
void serialize(Archive& ar, amici::ReturnData& r, unsigned int version);
} // namespace boost::serialization

namespace amici {

/**
 * @brief Stores the results and diagnostics of a simulation.
 *
 * Derives from ModelDimensions, so the dimension fields (e.g. `nx_solver`)
 * are directly accessible as members. Most results are kept as flat
 * `std::vector` buffers.
 *
 * NOTE(review): the documentation comments of the individual data members
 * were not available when this header was reviewed. Where a member's
 * meaning, shape or unit is not evident from this header, it is either
 * left undocumented or marked as an assumption to confirm against the
 * implementation file.
 */
class ReturnData : public ModelDimensions {
  public:
    /** Default constructor; members declared here keep their in-class
     * initializers. */
    ReturnData() = default;

    /**
     * @brief Constructs the object from explicit settings.
     *
     * Several arguments share their name (minus the trailing underscore)
     * with a data member of this class; presumably they initialize those
     * members — TODO confirm against the definition.
     *
     * @param ts_ Time points
     * @param model_dimensions_ Model dimensions
     * @param nmaxevent_ See #nmaxevent
     * @param newton_maxsteps_ See #newton_maxsteps
     * @param plist_ See #plist
     * @param pscale_ See #pscale
     * @param o2mode_ See #o2mode
     * @param sensi_ See #sensi
     * @param sensi_meth_ See #sensi_meth
     * @param rdrm_ Reporting mode, see #rdata_reporting
     * @param quadratic_llh_ Forwarded to the likelihood-reporting setup
     * (presumably — TODO confirm)
     * @param sigma_res_ See #sigma_res
     * @param sigma_offset_ See #sigma_offset
     */
    ReturnData(
        std::vector<realtype> ts_, ModelDimensions const& model_dimensions_,
        int nmaxevent_, int newton_maxsteps_, std::vector<int> plist_,
        std::vector<ParameterScaling> pscale_, SecondOrderMode o2mode_,
        SensitivityOrder sensi_, SensitivityMethod sensi_meth_,
        RDataReporting rdrm_, bool quadratic_llh_, bool sigma_res_,
        realtype sigma_offset_
    );

    /**
     * @brief Constructs the object from a solver and a model.
     * @param solver Solver instance (read-only)
     * @param model Model instance (read-only)
     */
    ReturnData(Solver const& solver, Model const& model);

    // No user-declared destructor on purpose (Rule of Zero): every member
    // releases its own storage, and leaving the destructor implicit keeps
    // the implicitly generated move constructor / move assignment available.
    // A user-declared destructor (even `= default`) would suppress them and
    // silently turn every move of this vector-heavy object into a deep copy.

    /**
     * @brief Fills this object from the objects produced by a simulation.
     * @param fwd Forward problem; pointer, so it may be absent
     * @param bwd Backward problem; pointer, so it may be absent
     * @param model Model instance (non-const: may be modified)
     * @param solver Solver instance
     * @param edata Experimental data; pointer, so it may be absent
     */
    void process_simulation_objects(
        ForwardProblem const* fwd, BackwardProblem const* bwd, Model& model,
        Solver const& solver, ExpData const* edata
    );

    /** Identifier string (empty unless set). */
    std::string id;

    /** Time points. */
    std::vector<realtype> ts;

    /**
     * Right-hand side evaluated at the final simulation state; written by
     * store_jacobian_and_derivative() when non-empty.
     */
    std::vector<realtype> xdot;

    /**
     * Jacobian evaluated at the final simulation state, stored row-major as
     * an `nx_solver x nx_solver` array; written by
     * store_jacobian_and_derivative() when non-empty.
     */
    std::vector<realtype> J;

    // --- Flat result buffers ---------------------------------------------
    // NOTE(review): meaning and shape of the following buffers are not
    // visible in this header; confirm against the implementation.

    std::vector<realtype> w;

    std::vector<realtype> z;

    std::vector<realtype> sigmaz;

    std::vector<realtype> sz;

    std::vector<realtype> ssigmaz;

    std::vector<realtype> rz;

    std::vector<realtype> srz;

    std::vector<realtype> s2rz;

    std::vector<realtype> x;

    std::vector<realtype> sx;

    std::vector<realtype> y;

    std::vector<realtype> sigmay;

    std::vector<realtype> sy;

    std::vector<realtype> ssigmay;

    std::vector<realtype> res;

    std::vector<realtype> sres;

    std::vector<realtype> FIM;

    // --- Solver diagnostics ----------------------------------------------
    // NOTE(review): the `_b` suffix presumably marks the backward-problem
    // counterpart of each counter — TODO confirm.

    std::vector<int> numsteps;

    std::vector<int> numsteps_b;

    std::vector<int> num_rhs_evals;

    std::vector<int> num_rhs_evals_b;

    std::vector<int> num_err_test_fails;

    std::vector<int> num_err_test_fails_b;

    std::vector<int> num_non_lin_solv_conv_fails;

    std::vector<int> num_non_lin_solv_conv_fails_b;

    std::vector<int> order;

    // --- Timing (all initialised to 0.0; unit not visible here) ----------

    double cpu_time = 0.0;

    double cpu_time_b = 0.0;

    double cpu_time_total = 0.0;

    // --- Pre-/post-equilibration diagnostics -----------------------------

    std::vector<SteadyStateStatus> preeq_status;

    double preeq_cpu_time = 0.0;

    double preeq_cpu_time_b = 0.0;

    std::vector<SteadyStateStatus> posteq_status;

    double posteq_cpu_time = 0.0;

    double posteq_cpu_time_b = 0.0;

    std::vector<int> preeq_numsteps;

    int preeq_numsteps_b = 0;

    std::vector<int> posteq_numsteps;

    int posteq_numsteps_b = 0;

    /** NaN until set. */
    realtype preeq_t = NAN;

    /** NaN until set. */
    realtype preeq_wrms = NAN;

    /** NaN until set. */
    realtype posteq_t = NAN;

    /** NaN until set. */
    realtype posteq_wrms = NAN;

    // --- Initial / steady states -----------------------------------------

    std::vector<realtype> x0;

    std::vector<realtype> x_ss;

    std::vector<realtype> sx0;

    std::vector<realtype> sx_ss;

    // --- Objective function values ---------------------------------------

    realtype llh = 0.0;

    realtype chi2 = 0.0;

    std::vector<realtype> sllh;

    std::vector<realtype> s2llh;

    /** Status code; AMICI_NOT_RUN until a simulation sets it. */
    int status = AMICI_NOT_RUN;

    // --- Sizes and settings ----------------------------------------------

    int nplist{0};

    int nmaxevent{0};

    int nt{0};

    int newton_maxsteps{0};

    std::vector<ParameterScaling> pscale;

    SecondOrderMode o2mode{SecondOrderMode::none};

    SensitivityOrder sensi{SensitivityOrder::none};

    SensitivityMethod sensi_meth{SensitivityMethod::none};

    RDataReporting rdata_reporting{RDataReporting::full};

    /** Grants the Boost.Serialization hook access to non-public members. */
    template <class Archive>
    friend void boost::serialization::serialize(
        Archive& ar, ReturnData& r, unsigned int version
    );

    bool sigma_res{false};

    /** Collected log items. */
    std::vector<LogItem> messages;

    /** NaN until set. */
    realtype t_last{std::numeric_limits<realtype>::quiet_NaN()};

    std::vector<int> plist;

  protected:
    realtype sigma_offset{0.0};

    std::vector<int> nroots_;

    // --- Buffer setup helpers (defined elsewhere) ------------------------

    void initialize_likelihood_reporting(bool quadratic_llh);

    void initializeObservablesLikelihoodReporting(bool quadratic_llh);

    void initialize_residual_reporting(bool enable_res);

    void initialize_full_reporting(bool enable_fim);

    void initialize_objective_function(bool enable_chi2);

    // --- Processing of the individual simulation stages ------------------

    void process_pre_equilibration(
        SteadyStateProblem const& preeq,
        SteadyStateBackwardProblem const* preeq_bwd, Model& model
    );

    void process_post_equilibration(
        SteadyStateProblem const& posteq,
        SteadyStateBackwardProblem const* posteq_bwd, Model& model,
        ExpData const* edata
    );

    void process_forward_problem(
        ForwardProblem const& fwd, Model& model, ExpData const* edata
    );

    void process_backward_problem(
        ForwardProblem const& fwd, BackwardProblem const& bwd,
        SteadyStateProblem const* preeq,
        SteadyStateBackwardProblem const* preeq_bwd, Model& model
    );

    void process_solver(Solver const& solver);

    /**
     * @brief Evaluates the right-hand side and the Jacobian at the final
     * simulation state of `problem` and stores them in #xdot and #J.
     *
     * The model state is always set to the problem's final state. The
     * right-hand side is only evaluated if #xdot or #J is non-empty, and
     * each of the two buffers is only written if it is non-empty.
     *
     * @tparam T Any type providing `get_final_simulation_state()`, whose
     * result exposes `mod` (model state) and `sol` (solution state).
     * @param problem Problem to take the final simulation state from
     * @param model Model used for the evaluation; its state is overwritten
     */
    template <class T>
    void store_jacobian_and_derivative(T const& problem, Model& model) {
        auto const& simulation_state = problem.get_final_simulation_state();
        model.set_model_state(simulation_state.mod);

        bool const want_xdot = !this->xdot.empty();
        bool const want_jacobian = !this->J.empty();
        // Nothing requested: skip creating the SUNDIALS context and the
        // scratch vector altogether.
        if (!want_xdot && !want_jacobian)
            return;

        auto const& sol = simulation_state.sol;
        // Use the model's dimension for allocation *and* indexing, so the
        // scratch buffers always match what the model writes into them.
        auto const n = model.nx_solver;

        sundials::Context sunctx;
        // The right-hand side is needed for both outputs (it is an input of
        // the Jacobian evaluation).
        AmiVector xdot_final(n, sunctx);
        model.fxdot(sol.t, sol.x, sol.dx, xdot_final);

        if (want_xdot)
            write_slice(xdot_final, this->xdot);

        if (want_jacobian) {
            SUNMatrixWrapper jacobian(n, n, sunctx);
            model.fJ(sol.t, 0.0, sol.x, sol.dx, xdot_final, jacobian);
            auto const colmajor = jacobian.data();
            // CVODES uses colmajor, so we need to transform to rowmajor
            for (int ix = 0; ix < n; ix++)
                for (int jx = 0; jx < n; jx++)
                    this->J.at(ix * n + jx) = colmajor[ix + n * jx];
        }
    }

    // --- Per-timepoint objective function contributions ------------------
    // `it` is presumably the index into #ts — TODO confirm.

    void
    fres(int it, Model& model, SolutionState const& sol, ExpData const& edata);

    void fchi2(int it, ExpData const& edata);

    void
    fsres(int it, Model& model, SolutionState const& sol, ExpData const& edata);

    void
    fFIM(int it, Model& model, SolutionState const& sol, ExpData const& edata);

    // --- Invalidation of results -----------------------------------------

    void invalidate(int it_start);

    void invalidate_llh();

    void invalidate_sllh();

    void apply_chain_rule_factor_to_simulation_results(Model const& model);

    /**
     * @brief Tells whether forward sensitivities are being computed.
     * @return `true` iff #sensi_meth is SensitivityMethod::forward and
     * #sensi is at least SensitivityOrder::first
     */
    [[nodiscard]] bool computing_fsa() const {
        return (
            sensi_meth == SensitivityMethod::forward
            && sensi >= SensitivityOrder::first
        );
    }

    // --- Extraction of outputs from solution states ----------------------

    void get_data_output(
        int it, Model& model, SolutionState const& sol, ExpData const* edata
    );

    void get_data_sensis_fsa(
        int it, Model& model, SolutionState const& sol, ExpData const* edata
    );

    void get_event_output(
        std::vector<int> const& rootidx, Model& model, SolutionState const& sol,
        ExpData const* edata
    );

    void get_event_sensis_fsa(
        int ie, Model& model, SolutionState const& sol, ExpData const* edata
    );

    // --- Initial-state sensitivity handling ------------------------------
    // Both are const members; results are returned through `llhS0`.

    void handle_sx0_backward(
        Model const& model, AmiVectorArray const& sx0, AmiVector const& xB,
        std::vector<realtype>& llhS0
    ) const;

    void handle_sx0_forward(
        Model const& model, SolutionState const& sol,
        std::vector<realtype>& llhS0, AmiVector const& xB
    ) const;
};

/**
 * @brief Context manager tied to a Model instance.
 *
 * Holds a non-owning pointer to the model together with a ModelState
 * snapshot. NOTE(review): presumably the snapshot is taken at construction
 * and written back by restore() and/or the destructor — the implementation
 * is not visible in this header; confirm there.
 *
 * Instances are neither copy-constructible nor copy-assignable: a copy
 * would refer to the same model and carry the same snapshot.
 */
class ModelContext : public ContextManager {
  public:
    /**
     * @brief Creates a context for the given model.
     * @param model Model to manage; not owned by this object
     */
    explicit ModelContext(Model* model);

    // Copying is disabled for construction as well as assignment, so the
    // two operations are consistent with each other.
    ModelContext(ModelContext const& other) = delete;

    ModelContext& operator=(ModelContext const& other) = delete;

    /** Destructor; declared `noexcept(false)`, i.e. it is allowed to throw. */
    ~ModelContext() noexcept(false);

    /** Restores the model (see the class-level note on assumed semantics). */
    void restore();

  private:
    /** Managed model; non-owning, may be null. */
    Model* model_{nullptr};
    /** Stored model state. */
    ModelState original_state_;
};

} // namespace amici

#endif // AMICI_RDATA_H