Program Listing for File backwardproblem.h

Return to documentation for file (include/amici/backwardproblem.h)

#ifndef AMICI_BACKWARDPROBLEM_H
#define AMICI_BACKWARDPROBLEM_H

#include "amici/defines.h"
#include "amici/forwardproblem.h"
#include "amici/vector.h"

#include <optional>
#include <vector>

namespace amici {
class ExpData;
class Solver;
class Model;
class ForwardProblem;
class SteadyStateProblem;

/**
 * @brief Working storage for the backward (adjoint) simulation.
 *
 * BackwardProblem exposes `xB_` as its adjoint state and `xQB_` as its
 * adjoint quadrature. The remaining members are state used while
 * integrating backwards.
 */
struct BwdSimWorkspace {
    /**
     * @brief Set up the workspace for the given model and solver.
     * @param model Model instance; must not be null.
     * @param solver Solver instance; must not be null.
     */
    BwdSimWorkspace(
        gsl::not_null<Model*> model, gsl::not_null<Solver const*> solver
    );

    /**
     * @brief Model this workspace was created for.
     *
     * Defaults to nullptr, matching the pointer members of
     * SteadyStateBackwardProblem. The pointer therefore never holds an
     * indeterminate value, whatever the out-of-line constructor initializes.
     */
    Model* model_{nullptr};

    /** @brief Adjoint state. */
    AmiVector xB_;

    /**
     * @brief Time derivative of the adjoint state.
     * NOTE(review): presumably only meaningful for DAE models — confirm.
     */
    AmiVector dxB_;

    /** @brief Adjoint quadrature. */
    AmiVector xQB_;

    /**
     * @brief Root counters.
     * NOTE(review): presumably the number of roots found per event — confirm
     * against the constructor / simulator implementation.
     */
    std::vector<int> nroots_;

    /** @brief Discontinuities handled by the backward simulation. */
    std::vector<Discontinuity> discs_;

    /**
     * @brief Backward-problem index.
     * NOTE(review): presumably the identifier the solver assigns to this
     * backward problem — confirm.
     */
    int which = 0;
};

class EventHandlingBwdSimulator {
  public:
    EventHandlingBwdSimulator(
        gsl::not_null<Model*> model, gsl::not_null<Solver const*> solver,
        gsl::not_null<BwdSimWorkspace*> ws
    )
        : model_(model)
        , solver_(solver)
        , ws_(ws) {}

    void
    run(realtype t_start, realtype t_end, realtype it,
        std::vector<realtype> const& timepoints,
        std::vector<realtype> const* dJydx, std::vector<realtype> const* dJzdx);

  private:
    void handle_event_b(
        Discontinuity const& disc, std::vector<realtype> const* dJzdx
    );

    void handle_datapoint_b(int it, std::vector<realtype> const* dJydx);

    realtype get_next_t(int it);

    Model* model_;

    Solver const* solver_;

    gsl::not_null<BwdSimWorkspace*> ws_;

    realtype t_{0};
};

/**
 * @brief Backward (adjoint) computation for a steady-state segment
 * (pre- or post-equilibration).
 *
 * Judging by the private helper names, the steady-state quadrature is
 * obtained either by a linear solve or by simulation.
 * NOTE(review): confirm the selection criterion in the implementation.
 */
class SteadyStateBackwardProblem {
  public:
    /**
     * @brief Set up the backward steady-state problem.
     * @param solver Solver to use.
     * @param model Model to use.
     * @param final_state Solution state of the corresponding forward
     *        steady-state problem. It is presumably bound to the reference
     *        member `final_state_`, so it must outlive this object.
     * @param ws Backward-simulation workspace; must not be null.
     */
    SteadyStateBackwardProblem(
        Solver const& solver, Model& model, SolutionState& final_state,
        gsl::not_null<BwdSimWorkspace*> ws
    );

    /**
     * @brief Run the backward steady-state computation.
     * @param t0 Time point passed on to the quadrature computation.
     */
    void run(realtype t0);

    /** @brief CPU time spent in the backward computation. */
    [[nodiscard]] double get_cpu_time_b() const {
        return cpu_time_b_;
    }

    /** @brief Number of backward steps taken. */
    [[nodiscard]] int get_num_steps_b() const {
        return num_steps_b_;
    }

    /** @brief Adjoint state after run(). */
    [[nodiscard]] AmiVector const& get_adjoint_state() const;

    /** @brief Adjoint quadrature after run(). */
    [[nodiscard]] AmiVector const& get_adjoint_quadrature() const;

    /** @brief Whether a quadrature has been computed. */
    [[nodiscard]] bool has_quadrature() const {
        return has_quadrature_;
    }

  private:
    /** @brief Integrate the backward problem with @p solver. */
    void run_simulation(Solver const& solver);

    /** @brief Compute the quadrature at steady state, starting from @p t0. */
    void compute_steady_state_quadrature(realtype t0);

    /** @brief Quadrature via a linear solve. */
    void compute_quadrature_by_lin_solve();

    /** @brief Quadrature via backward simulation, starting from @p t0. */
    void compute_quadrature_by_simulation(realtype t0);

    /** @brief Accumulated CPU time of the backward computation. */
    double cpu_time_b_ = 0.0;

    /** @brief Set once a quadrature is available. */
    bool has_quadrature_ = false;

    /** @brief Accumulated number of backward steps. */
    int num_steps_b_ = 0;

    /** @brief Quadrature storage. */
    AmiVector xQ_;

    /** @brief Forward steady-state solution (referenced, not owned). */
    SolutionState& final_state_;

    /** @brief Newton solver used by this problem. */
    NewtonSolver newton_solver_;

    /** @brief Newton-step convergence flag. */
    bool newton_step_conv_ = false;

    /** @brief Model in use. */
    Model* model_ = nullptr;
    /** @brief Solver in use. */
    Solver const* solver_ = nullptr;
    /** @brief Shared backward-simulation workspace. */
    BwdSimWorkspace* ws_ = nullptr;
};



/**
 * @brief Adjoint (backward) problem belonging to a ForwardProblem.
 *
 * Constructed from a ForwardProblem and solved by workBackwardProblem().
 * - get_adjoint_state() / get_adjoint_quadrature() return the workspace's
 *   current `xB_` / `xQB_`.
 * - The `*_pre_preeq` getters return separately stored vectors.
 *   NOTE(review): presumably captured before the pre-equilibration backward
 *   step — confirm.
 *
 * The class is neither copyable nor movable; see the deleted special members.
 */
class BackwardProblem {
  public:
    /**
     * @brief Set up the backward problem from a forward problem.
     * @param fwd Forward problem to take the model, solver, data and
     *        intermediate results from.
     */
    explicit BackwardProblem(ForwardProblem& fwd);

    // This object owns `ws_`, and `simulator_` stores a BwdSimWorkspace*
    // (EventHandlingBwdSimulator::ws_). If that pointer refers to this
    // object's own `ws_`, a compiler-generated member-wise copy or move would
    // leave the new object's simulator (and any steady-state backward
    // problems) pointing into the source object. Forbid both.
    BackwardProblem(BackwardProblem const&) = delete;
    BackwardProblem& operator=(BackwardProblem const&) = delete;
    BackwardProblem(BackwardProblem&&) = delete;
    BackwardProblem& operator=(BackwardProblem&&) = delete;

    /** @brief Solve the complete backward problem. */
    void workBackwardProblem();

    /** @brief Adjoint state stored as the "pre pre-equilibration" value. */
    [[nodiscard]] AmiVector const& get_adjoint_state_pre_preeq() const {
        return xB_pre_preeq_;
    }

    /** @brief Adjoint quadrature stored as the "pre pre-equilibration" value. */
    [[nodiscard]] AmiVector const& get_adjoint_quadrature_pre_preeq() const {
        return xQB_pre_preeq_;
    }

    /** @brief Current adjoint state held in the workspace. */
    [[nodiscard]] AmiVector const& get_adjoint_state() const { return ws_.xB_; }

    /** @brief Current adjoint quadrature held in the workspace. */
    [[nodiscard]] AmiVector const& get_adjoint_quadrature() const {
        return ws_.xQB_;
    }

    /**
     * @brief Backward post-equilibration problem.
     * @return Pointer to it, or nullptr if none was created.
     */
    [[nodiscard]] SteadyStateBackwardProblem const*
    get_posteq_bwd_problem() const {
        if (posteq_problem_bwd_.has_value())
            return &*posteq_problem_bwd_;
        return nullptr;
    }

    /**
     * @brief Backward pre-equilibration problem.
     * @return Pointer to it, or nullptr if none was created.
     */
    [[nodiscard]] SteadyStateBackwardProblem const*
    get_preeq_bwd_problem() const {
        if (preeq_problem_bwd_.has_value())
            return &*preeq_problem_bwd_;
        return nullptr;
    }

  private:
    /** @brief Handle the backward post-equilibration step. */
    void handle_postequilibration();

    // Pointer members and t_ get defaults so they are never indeterminate,
    // regardless of what the out-of-line constructor initializes.
    Model* model_{nullptr};
    Solver* solver_{nullptr};
    ExpData const* edata_{nullptr};

    /** @brief Current time of the backward problem. */
    realtype t_{0.0};

    /** @brief Discontinuities of the main simulation. */
    std::vector<Discontinuity> discs_main_;

    /** @brief Data-likelihood state derivatives. */
    std::vector<realtype> dJydx_;
    /** @brief Event-likelihood state derivatives (fixed after construction). */
    std::vector<realtype> const dJzdx_;

    /** @brief Forward pre-equilibration problem (nullptr if absent). */
    SteadyStateProblem* preeq_problem_{nullptr};

    /** @brief Forward post-equilibration problem (nullptr if absent). */
    SteadyStateProblem* posteq_problem_{nullptr};

    /** @brief Result of the pre-simulation period. */
    PeriodResult presim_result;

    // Declared before simulator_ so that it is constructed first.
    BwdSimWorkspace ws_;

    /** @brief Event-aware backward simulator. */
    EventHandlingBwdSimulator simulator_;

    /** @brief Backward pre-equilibration problem, if any. */
    std::optional<SteadyStateBackwardProblem> preeq_problem_bwd_;

    /** @brief Backward post-equilibration problem, if any. */
    std::optional<SteadyStateBackwardProblem> posteq_problem_bwd_;

    /** @brief Stored adjoint state (see get_adjoint_state_pre_preeq()). */
    AmiVector xB_pre_preeq_;

    /** @brief Stored adjoint quadrature (see get_adjoint_quadrature_pre_preeq()). */
    AmiVector xQB_pre_preeq_;
};

} // namespace amici

#endif // AMICI_BACKWARDPROBLEM_H