Program Listing for File solver.h

Return to documentation for file (include/amici/solver.h)

#ifndef AMICI_SOLVER_H
#define AMICI_SOLVER_H

#include "amici/defines.h"
#include "amici/logging.h"
#include "amici/misc.h"
#include "amici/model_state.h"
#include "amici/sundials_linsol_wrapper.h"
#include "amici/vector.h"

#include <chrono>
#include <cmath>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace amici {

// Forward declarations: these classes can be named through pointers and
// references here without pulling in their headers. Solver in particular must
// be declared before the boost::serialization::serialize declaration below it.
class BackwardProblem;
class ForwardProblem;
class Model;
class ReturnData;
class Solver;

} // namespace amici

// Forward declaration of the Boost.Serialization hook. It has to be visible
// before class Solver so that Solver can declare it a friend.
namespace boost {
namespace serialization {

template <typename Archive>
void serialize(Archive& ar, amici::Solver& s, unsigned int version);

} // namespace serialization
} // namespace boost

namespace amici {

/*
 * NOTE: Any changes in data members here must be propagated to the copy
 * constructor, the equality operator, the serialization functions in
 * serialization.h, and amici::hdf5::(read/write)SolverSettings(From/To)HDF5
 * in hdf5.cpp.
 */

/**
 * @brief Abstract base class for AMICI solver back-ends.
 *
 * Stores user-configurable settings (tolerances, step limits, linear and
 * nonlinear solver selection, sensitivity configuration) and declares the
 * pure-virtual hooks that a concrete back-end must implement. Many members
 * are `mutable`, which allows the `const` member functions declared here
 * (e.g. setup(), run()) to modify solver memory, work vectors and
 * diagnostic counters.
 *
 * The declaration order of data members matters: several AmiVector members
 * are initialized from sunctx_, which therefore has to be declared first.
 */
class Solver {
  public:
    /// Pair of model and solver pointers stored in user_data_ (presumably
    /// handed to the back-end callbacks via set_user_data() -- confirm).
    using user_data_type = std::pair<Model*, Solver const*>;

    /// Deleter type for the type-erased back-end memory handles
    /// solver_memory_ and solver_memory_B_.
    using free_solver_ptr = std::function<void(void*)>;

    Solver() = default;

    /// Copy constructor. Must be kept in sync with the data members (see the
    /// NOTE above this class).
    Solver(Solver const& other);

    /// Copy assignment is not supported. The implicitly generated operator
    /// would already be deleted, because the std::unique_ptr members are not
    /// copy-assignable. It is declared explicitly to make that visible, as
    /// the rule of five recommends when a copy constructor and destructor
    /// are user-declared.
    Solver& operator=(Solver const& other) = delete;

    virtual ~Solver() = default;

    /// Polymorphic copy. Returns a raw pointer; presumably the caller takes
    /// ownership -- confirm in the concrete back-ends.
    virtual Solver* clone() const = 0;

    SUNContext get_sun_context() const;

    virtual std::string get_class_name() const = 0;

    // ------------------------------------------------------------------
    // Integration / setup
    // ------------------------------------------------------------------

    int run(realtype tout) const;

    int step(realtype tout) const;

    void run_b(realtype tout) const;

    void setup(
        realtype t0, Model* model, AmiVector const& x0, AmiVector const& dx0,
        AmiVectorArray const& sx0, AmiVectorArray const& sdx0
    ) const;

    void setup_b(
        int* which, realtype tf, Model* model, AmiVector const& xB0,
        AmiVector const& dxB0, AmiVector const& xQB0
    ) const;

    void setup_steady_state(
        realtype t0, Model* model, AmiVector const& x0, AmiVector const& dx0,
        AmiVector const& xB0, AmiVector const& dxB0, AmiVector const& xQ0
    ) const;

    void update_and_reinit_states_and_sensitivities(Model* model) const;

    virtual void get_root_info(int* rootsfound) const = 0;

    virtual void calc_ic(realtype tout1) const = 0;

    virtual void calc_ic_b(int which, realtype tout1) const = 0;

    virtual void solve_b(realtype tBout, int itaskB) const = 0;

    virtual void turn_off_root_finding() const = 0;

    // ------------------------------------------------------------------
    // Sensitivity configuration
    // ------------------------------------------------------------------

    SensitivityMethod get_sensitivity_method() const;
    void set_sensitivity_method(SensitivityMethod sensi_meth);

    SensitivityMethod get_sensitivity_method_pre_equilibration() const;
    void set_sensitivity_method_pre_equilibration(
        SensitivityMethod sensi_meth_preeq
    );

    void switch_forward_sensis_off() const;

    // ------------------------------------------------------------------
    // Newton solver settings
    // ------------------------------------------------------------------

    // NOTE(review): the backing member newton_maxsteps_ is `long int`,
    // but this accessor pair uses `int`.
    int get_newton_max_steps() const;
    void set_newton_max_steps(int newton_maxsteps);

    NewtonDampingFactorMode get_newton_damping_factor_mode() const;
    void
    set_newton_damping_factor_mode(NewtonDampingFactorMode dampingFactorMode);

    double get_newton_damping_factor_lower_bound() const;
    void set_newton_damping_factor_lower_bound(double dampingFactorLowerBound);

    SensitivityOrder get_sensitivity_order() const;
    void set_sensitivity_order(SensitivityOrder sensi);

    // ------------------------------------------------------------------
    // Tolerances
    // ------------------------------------------------------------------

    double get_relative_tolerance() const;
    void set_relative_tolerance(double rtol);

    double get_absolute_tolerance() const;
    void set_absolute_tolerance(double atol);

    double get_relative_tolerance_fsa() const;
    void set_relative_tolerance_fsa(double rtol);

    double get_absolute_tolerance_fsa() const;
    void set_absolute_tolerance_fsa(double atol);

    double get_relative_tolerance_b() const;
    void set_relative_tolerance_b(double rtol);

    double get_absolute_tolerance_b() const;
    void set_absolute_tolerance_b(double atol);

    double get_relative_tolerance_quadratures() const;
    void set_relative_tolerance_quadratures(double rtol);

    double get_absolute_tolerance_quadratures() const;
    void set_absolute_tolerance_quadratures(double atol);

    double get_steady_state_tolerance_factor() const;
    void set_steady_state_tolerance_factor(double factor);

    double get_relative_tolerance_steady_state() const;
    void set_relative_tolerance_steady_state(double rtol);

    double get_absolute_tolerance_steady_state() const;
    void set_absolute_tolerance_steady_state(double atol);

    double get_steady_state_sensi_tolerance_factor() const;
    void set_steady_state_sensi_tolerance_factor(double factor);

    double get_relative_tolerance_steady_state_sensi() const;
    void set_relative_tolerance_steady_state_sensi(double rtol);

    double get_absolute_tolerance_steady_state_sensi() const;
    void set_absolute_tolerance_steady_state_sensi(double atol);

    // ------------------------------------------------------------------
    // Step and wall-time limits
    // ------------------------------------------------------------------

    long int get_max_steps() const;
    void set_max_steps(long int maxsteps);

    double get_max_time() const;
    void set_max_time(double maxtime);

    void start_timer() const;

    bool time_exceeded(int interval = 1) const;

    long int get_max_steps_backward_problem() const;
    void set_max_steps_backward_problem(long int maxsteps);

    // ------------------------------------------------------------------
    // Integrator / linear solver options
    // ------------------------------------------------------------------

    LinearMultistepMethod get_linear_multistep_method() const;
    void set_linear_multistep_method(LinearMultistepMethod lmm);

    NonlinearSolverIteration get_non_linear_solver_iteration() const;
    void set_non_linear_solver_iteration(NonlinearSolverIteration iter);

    InterpolationType get_interpolation_type() const;
    void set_interpolation_type(InterpolationType interpType);

    int get_state_ordering() const;
    void set_state_ordering(int ordering);

    bool get_stability_limit_flag() const;
    void set_stability_limit_flag(bool stldet);

    LinearSolver get_linear_solver() const;
    void set_linear_solver(LinearSolver linsol);

    InternalSensitivityMethod get_internal_sensitivity_method() const;
    void set_internal_sensitivity_method(InternalSensitivityMethod ism);

    RDataReporting get_return_data_reporting_mode() const;
    void set_return_data_reporting_mode(RDataReporting rdrm);

    // ------------------------------------------------------------------
    // Solution access
    // ------------------------------------------------------------------

    void write_solution(
        realtype& t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx,
        AmiVector& xQ
    ) const;

    void write_solution(
        realtype& t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx
    ) const;

    void write_solution(SolutionState& sol) const;

    void write_solution(realtype t, SolutionState& sol) const;

    void write_solution_b(
        realtype& t, AmiVector& xB, AmiVector& dxB, AmiVector& xQB, int which
    ) const;

    AmiVector const& get_state(realtype t) const;

    AmiVector const& get_derivative_state(realtype t) const;

    AmiVectorArray const& get_state_sensitivity(realtype t) const;

    AmiVector const& get_adjoint_state(int which, realtype t) const;

    AmiVector const& get_adjoint_derivative_state(int which, realtype t) const;

    AmiVector const& get_adjoint_quadrature(int which, realtype t) const;

    AmiVector const& get_quadrature(realtype t) const;

    // ------------------------------------------------------------------
    // Re-initialization (implemented by the back-end)
    // ------------------------------------------------------------------

    virtual void
    reinit(realtype t0, AmiVector const& yy0, AmiVector const& yp0) const
        = 0;

    virtual void
    sens_reinit(AmiVectorArray const& yyS0, AmiVectorArray const& ypS0) const
        = 0;

    virtual void sens_toggle_off() const = 0;

    virtual void reinit_b(
        int which, realtype tB0, AmiVector const& yyB0, AmiVector const& ypB0
    ) const
        = 0;

    virtual void reinit_quad_b(int which, AmiVector const& yQB0) const = 0;

    // ------------------------------------------------------------------
    // State queries
    // ------------------------------------------------------------------

    realtype get_t() const;

    realtype get_cpu_time() const;

    realtype get_cpu_time_b() const;

    int nx() const;

    int nplist() const;

    int nquad() const;

    /// True iff the sensitivity order is at least first, the method is
    /// forward, and there is at least one parameter (nplist() > 0) and at
    /// least one state (nx() > 0).
    bool computing_fsa() const {
        return get_sensitivity_order() >= SensitivityOrder::first
               && get_sensitivity_method() == SensitivityMethod::forward
               && nplist() > 0 && nx() > 0;
    }

    /// True iff the sensitivity order is at least first, the method is
    /// adjoint, and there is at least one parameter (nplist() > 0).
    /// Unlike computing_fsa(), nx() is not checked here.
    bool computing_asa() const {
        return get_sensitivity_order() >= SensitivityOrder::first
               && get_sensitivity_method() == SensitivityMethod::adjoint
               && nplist() > 0;
    }

    // ------------------------------------------------------------------
    // Diagnostics
    // ------------------------------------------------------------------

    void reset_diagnosis() const;

    void store_diagnosis() const;

    void store_diagnosis_b(int which) const;

    // The accessors below return references to the counters filled by
    // store_diagnosis()/store_diagnosis_b() (presumably one entry per call
    // -- confirm in solver.cpp).

    std::vector<int> const& get_num_steps() const { return ns_; }

    std::vector<int> const& get_num_steps_b() const { return nsB_; }

    std::vector<int> const& get_num_rhs_evals() const { return nrhs_; }

    std::vector<int> const& get_num_rhs_evals_b() const { return nrhsB_; }

    std::vector<int> const& get_num_err_test_fails() const { return netf_; }

    std::vector<int> const& get_num_err_test_fails_b() const { return netfB_; }

    std::vector<int> const& get_num_non_lin_solv_conv_fails() const {
        return nnlscf_;
    }

    std::vector<int> const& get_num_non_lin_solv_conv_fails_b() const {
        return nnlscfB_;
    }

    std::vector<int> const& get_last_order() const { return order_; }

    // ------------------------------------------------------------------
    // Steady-state convergence checks
    // (defaults: Newton-step check off, sensitivity check on)
    // ------------------------------------------------------------------

    bool get_newton_step_steady_state_check() const {
        return newton_step_steadystate_conv_;
    }

    bool get_sensi_steady_state_check() const {
        return check_sensi_steadystate_conv_;
    }

    void set_newton_step_steady_state_check(bool const flag) {
        newton_step_steadystate_conv_ = flag;
    }

    void set_sensi_steady_state_check(bool const flag) {
        check_sensi_steadystate_conv_ = flag;
    }

    // ------------------------------------------------------------------
    // Nonlinear-solver limits, constraints, step size
    // ------------------------------------------------------------------

    void set_max_nonlin_iters(int max_nonlin_iters);
    int get_max_nonlin_iters() const;

    void set_max_conv_fails(int max_conv_fails);
    int get_max_conv_fails() const;

    void set_constraints(std::vector<realtype> const& constraints);

    /// Returns a copy of the stored constraint values.
    std::vector<realtype> get_constraints() const {
        return constraints_.get_vector();
    }

    void set_max_step_size(realtype max_step_size);
    realtype get_max_step_size() const;

    // ------------------------------------------------------------------
    // Friends and logging
    // ------------------------------------------------------------------

    template <class Archive>
    friend void boost::serialization::serialize(
        Archive& ar, Solver& s, unsigned int version
    );

    friend bool operator==(Solver const& a, Solver const& b);

    /// Non-owning logger pointer; nullptr unless set_logger() was called.
    Logger* get_logger() const { return logger_; }

    void set_logger(Logger* logger) { logger_ = logger; }

  protected:
    // ------------------------------------------------------------------
    // Back-end hooks: integration
    // ------------------------------------------------------------------

    virtual void set_stop_time(realtype tstop) const = 0;

    virtual int solve(realtype tout, int itask) const = 0;

    virtual int solve_f(realtype tout, int itask, int* ncheckPtr) const = 0;

    virtual void reinit_post_process_f(realtype tnext) const = 0;

    virtual void reinit_post_process_b(realtype tnext) const = 0;

    virtual void get_sens() const = 0;

    virtual void get_b(int which) const = 0;

    virtual void get_quad_b(int which) const = 0;

    virtual void get_quad(realtype& t) const = 0;

    // ------------------------------------------------------------------
    // Back-end hooks: initialization
    // ------------------------------------------------------------------

    virtual void
    init(realtype t0, AmiVector const& x0, AmiVector const& dx0) const
        = 0;

    virtual void init_steady_state(
        realtype t0, AmiVector const& x0, AmiVector const& dx0
    ) const
        = 0;

    virtual void
    sens_init_1(AmiVectorArray const& sx0, AmiVectorArray const& sdx0) const
        = 0;

    virtual void b_init(
        int which, realtype tf, AmiVector const& xB0, AmiVector const& dxB0
    ) const
        = 0;

    virtual void qb_init(int which, AmiVector const& xQB0) const = 0;

    virtual void root_init(int ne) const = 0;

    void initialize_non_linear_solver_sens(Model const* model) const;

    // ------------------------------------------------------------------
    // Back-end hooks: Jacobian callbacks
    // ------------------------------------------------------------------

    virtual void set_dense_jac_fn() const = 0;

    virtual void set_sparse_jac_fn() const = 0;

    virtual void set_band_jac_fn() const = 0;

    virtual void set_jac_times_vec_fn() const = 0;

    virtual void set_dense_jac_fn_b(int which) const = 0;

    virtual void set_sparse_jac_fn_b(int which) const = 0;

    virtual void set_band_jac_fn_b(int which) const = 0;

    virtual void set_jac_times_vec_fn_b(int which) const = 0;

    virtual void set_sparse_jac_fn_ss() const = 0;

    // ------------------------------------------------------------------
    // Back-end hooks: allocation and options
    // ------------------------------------------------------------------

    virtual void allocate_solver() const = 0;

    virtual void set_ss_tolerances(double rtol, double atol) const = 0;

    virtual void set_sens_ss_tolerances(double rtol, double const* atol) const
        = 0;

    virtual void set_sens_err_con(bool error_corr) const = 0;

    virtual void set_quad_err_con_b(int which, bool flag) const = 0;

    virtual void set_quad_err_con(bool flag) const = 0;

    /// Not pure virtual: the base class provides an implementation.
    virtual void set_err_handler_fn() const;

    virtual void set_user_data() const = 0;

    virtual void set_user_data_b(int which) const = 0;

    virtual void set_max_num_steps(long int mxsteps) const = 0;

    virtual void set_max_num_steps_b(int which, long int mxstepsB) const = 0;

    virtual void set_stab_lim_det(int stldet) const = 0;

    virtual void set_stab_lim_det_b(int which, int stldet) const = 0;

    virtual void set_id(Model const* model) const = 0;

    virtual void set_suppress_alg(bool flag) const = 0;

    virtual void set_sens_params(
        realtype const* p, realtype const* pbar, int const* plist
    ) const
        = 0;

    // ------------------------------------------------------------------
    // Back-end hooks: dense output
    // ------------------------------------------------------------------

    virtual void get_dky(realtype t, int k) const = 0;

    virtual void get_dky_b(realtype t, int k, int which) const = 0;

    virtual void get_sens_dky(realtype t, int k) const = 0;

    virtual void get_quad_dky_b(realtype t, int k, int which) const = 0;

    virtual void get_quad_dky(realtype t, int k) const = 0;

    // ------------------------------------------------------------------
    // Back-end hooks: adjoint / quadrature setup
    // ------------------------------------------------------------------

    virtual void adj_init() const = 0;

    virtual void quad_init(AmiVector const& xQ0) const = 0;

    virtual void allocate_solver_b(int* which) const = 0;

    virtual void
    set_ss_tolerances_b(int which, realtype relTolB, realtype absTolB) const
        = 0;

    virtual void
    quad_ss_tolerances_b(int which, realtype reltolQB, realtype abstolQB) const
        = 0;

    virtual void quad_ss_tolerances(realtype reltolQB, realtype abstolQB) const
        = 0;

    // ------------------------------------------------------------------
    // Back-end hooks: statistics (these overload the public accessors of the
    // same name)
    // ------------------------------------------------------------------

    virtual void get_num_steps(void const* ami_mem, long int* numsteps) const
        = 0;

    virtual void
    get_num_rhs_evals(void const* ami_mem, long int* numrhsevals) const
        = 0;

    virtual void
    get_num_err_test_fails(void const* ami_mem, long int* numerrtestfails) const
        = 0;

    virtual void get_num_non_lin_solv_conv_fails(
        void const* ami_mem, long int* numnonlinsolvconvfails
    ) const
        = 0;

    virtual void get_last_order(void const* ami_mem, int* order) const = 0;

    // ------------------------------------------------------------------
    // Linear / nonlinear solver wiring
    // ------------------------------------------------------------------

    void initialize_linear_solver(Model const* model) const;

    void initialize_non_linear_solver() const;

    virtual void set_linear_solver() const = 0;

    virtual void set_linear_solver_b(int which) const = 0;

    virtual void set_non_linear_solver() const = 0;

    virtual void set_non_linear_solver_b(int which) const = 0;

    virtual void set_non_linear_solver_sens() const = 0;

    void initialize_linear_solver_b(Model const* model, int which) const;

    void initialize_non_linear_solver_b(int which) const;

    virtual Model const* get_model() const = 0;

    // ------------------------------------------------------------------
    // Initialization-state queries
    // ------------------------------------------------------------------

    bool get_init_done() const;

    bool get_sens_init_done() const;

    bool get_adj_init_done() const;

    bool get_init_done_b(int which) const;

    bool get_quad_init_done_b(int which) const;

    bool get_quad_init_done() const;

    virtual void diag() const = 0;

    virtual void diag_b(int which) const = 0;

    void reset_mutable_memory(int nx, int nplist, int nquad) const;

    virtual void* get_adj_b_mem(void* ami_mem, int which) const = 0;

    // ------------------------------------------------------------------
    // Applying stored settings to the back-end
    // ------------------------------------------------------------------

    void apply_tolerances() const;

    void apply_tolerances_fsa() const;

    void apply_tolerances_asa(int which) const;

    void apply_quad_tolerances_asa(int which) const;

    void apply_quad_tolerances() const;

    void apply_sensitivity_tolerances() const;

    virtual void apply_constraints() const;

    // ------------------------------------------------------------------
    // Protected state. Declaration order matters: sunctx_ must precede the
    // AmiVector members that are initialized from it.
    // ------------------------------------------------------------------

    sundials::Context sunctx_;

    /// Forward-problem back-end memory, released through free_solver_ptr.
    mutable std::unique_ptr<void, free_solver_ptr> solver_memory_;

    /// Backward-problem back-end memories, released through free_solver_ptr.
    mutable std::vector<std::unique_ptr<void, free_solver_ptr>>
        solver_memory_B_;

    mutable user_data_type user_data_;

    InternalSensitivityMethod ism_{InternalSensitivityMethod::simultaneous};

    LinearMultistepMethod lmm_{LinearMultistepMethod::BDF};

    NonlinearSolverIteration iter_{NonlinearSolverIteration::newton};

    InterpolationType interp_type_{InterpolationType::polynomial};

    /// Default 10000.
    long int maxsteps_{10000};

    /// Wall-time limit in seconds; default 0 (presumably "no limit" --
    /// confirm in time_exceeded()).
    std::chrono::duration<double, std::ratio<1>> maxtime_{0};

    mutable CpuTimer simulation_timer_;

    mutable std::unique_ptr<SUNLinSolWrapper> linear_solver_;

    mutable std::unique_ptr<SUNLinSolWrapper> linear_solver_B_;

    mutable std::unique_ptr<SUNNonLinSolWrapper> non_linear_solver_;

    mutable std::unique_ptr<SUNNonLinSolWrapper> non_linear_solver_B_;

    mutable std::unique_ptr<SUNNonLinSolWrapper> non_linear_solver_sens_;

    mutable bool solver_was_called_F_{false};

    mutable bool solver_was_called_B_{false};

    void set_init_done() const;

    void set_sens_init_done() const;

    void set_adj_init_done() const;

    void set_init_done_b(int which) const;

    void set_quad_init_done_b(int which) const;

    void set_quad_init_done() const;

    void check_sensitivity_method(
        SensitivityMethod sensi_meth, bool preequilibration
    ) const;

    virtual void apply_max_nonlin_iters() const = 0;

    virtual void apply_max_conv_fails() const = 0;

    virtual void apply_max_step_size() const = 0;

    // Work vectors, all created with size 0 on sunctx_.

    mutable AmiVector x_{0, sunctx_};

    mutable AmiVector dky_{0, sunctx_};

    mutable AmiVector dx_{0, sunctx_};

    mutable AmiVectorArray sx_{0, 0, sunctx_};

    mutable AmiVectorArray sdx_{0, 0, sunctx_};

    mutable AmiVector xB_{0, sunctx_};

    mutable AmiVector dxB_{0, sunctx_};

    mutable AmiVector xQB_{0, sunctx_};

    mutable AmiVector xQ_{0, sunctx_};

    /// Current time; NaN until set.
    mutable realtype t_{std::nan("")};

    mutable bool force_reinit_postprocess_F_{false};

    mutable bool force_reinit_postprocess_B_{false};

    mutable bool sens_initialized_{false};

    mutable AmiVector constraints_;

  private:
    void apply_max_num_steps() const;

    void apply_max_num_steps_B() const;

    SensitivityMethod sensi_meth_{SensitivityMethod::forward};

    SensitivityMethod sensi_meth_preeq_{SensitivityMethod::forward};

    sunbooleantype stldet_{SUNTRUE};

    int ordering_{static_cast<int>(SUNLinSolKLU::StateOrdering::AMD)};

    long int newton_maxsteps_{0L};

    NewtonDampingFactorMode newton_damping_factor_mode_{
        NewtonDampingFactorMode::on
    };

    realtype newton_damping_factor_lower_bound_{1e-8};

    LinearSolver linsol_{LinearSolver::KLU};

    // Tolerances. The NAN defaults presumably mean "derive from the base
    // atol_/rtol_ (or a factor)" -- confirm in solver.cpp.

    realtype atol_{1e-16};

    realtype rtol_{1e-8};

    realtype atol_fsa_{NAN};

    realtype rtol_fsa_{NAN};

    realtype atolB_{NAN};

    realtype rtolB_{NAN};

    realtype quad_atol_{1e-12};

    realtype quad_rtol_{1e-8};

    realtype ss_tol_factor_{1e2};

    realtype ss_atol_{NAN};

    realtype ss_rtol_{NAN};

    realtype ss_tol_sensi_factor_{1e2};

    realtype ss_atol_sensi_{NAN};

    realtype ss_rtol_sensi_{NAN};

    RDataReporting rdata_mode_{RDataReporting::full};

    bool newton_step_steadystate_conv_{false};

    bool check_sensi_steadystate_conv_{true};

    int max_nonlin_iters_{3};

    int max_conv_fails_{10};

    realtype max_step_size_{0.0};

    mutable realtype cpu_time_{0.0};

    mutable realtype cpu_time_b_{0.0};

    long int maxstepsB_{0L};

    SensitivityOrder sensi_{SensitivityOrder::none};

    mutable bool initialized_{false};

    mutable bool adj_initialized_{false};

    mutable bool quad_initialized_{false};

    // One element, `false`, initially. This is spelled out explicitly because
    // the previous `{false}` list-initialization produced the same
    // one-element vector but read like "empty" or a size. The vector is
    // presumably resized per backward problem in reset_mutable_memory() --
    // confirm.
    mutable std::vector<bool> initializedB_ = std::vector<bool>(1, false);

    // Same initialization as initializedB_: one `false` entry.
    mutable std::vector<bool> initializedQB_ = std::vector<bool>(1, false);

    mutable int ncheckPtr_{0};

    // Diagnostic counters exposed through the public get_num_*()/
    // get_last_order() accessors.

    mutable std::vector<int> ns_;

    mutable std::vector<int> nsB_;

    mutable std::vector<int> nrhs_;

    mutable std::vector<int> nrhsB_;

    mutable std::vector<int> netf_;

    mutable std::vector<int> netfB_;

    mutable std::vector<int> nnlscf_;

    mutable std::vector<int> nnlscfB_;

    mutable std::vector<int> order_;

    /// Non-owning; nullptr by default.
    Logger* logger_ = nullptr;
};

bool operator==(Solver const& a, Solver const& b);

} // namespace amici

#endif // AMICI_SOLVER_H