# Source code for amici.petab.petab_import

"""
PEtab Import
------------
Import a model in the :mod:`petab` (https://github.com/PEtab-dev/PEtab) format
into AMICI.
"""

import logging
import os
import re
import shutil
from pathlib import Path

import pandas as pd
import petab.v1 as petab
from petab.v1.models import MODEL_TYPE_PYSB, MODEL_TYPE_SBML

import amici

from ..logging import get_logger
from .import_helpers import (
    _can_import_model,
    _create_model_name,
    _get_package_name_and_path,
    check_model,
)
from .sbml_import import import_model_sbml

try:
    from .pysb_import import import_model_pysb
except ModuleNotFoundError:
    # pysb not available
    import_model_pysb = None


# Public API of this module: only the PEtab import entry point is exported.
__all__ = ["import_petab_problem"]

# Module-level logger. The second argument presumably sets its default level
# to WARNING -- NOTE(review): confirm against ``..logging.get_logger``.
logger = get_logger(__name__, logging.WARNING)


def import_petab_problem(
    petab_problem: petab.Problem,
    model_output_dir: str | Path | None = None,
    model_name: str | None = None,
    compile_: bool | None = None,
    non_estimated_parameters_as_constants=True,
    jax=False,
    **kwargs,
) -> "amici.Model | amici.jax.JAXProblem":
    """
    Create an AMICI model for a PEtab problem.

    :param petab_problem:
        A petab problem containing all relevant information on the model.

    :param model_output_dir:
        Directory to write the model code to. It will be created if it doesn't
        exist. Defaults to :func:`amici.get_model_dir`.

    :param model_name:
        Name of the generated model module. Defaults to the ID of the PEtab
        model; for PySB models without an ID, to the name of the PySB model;
        otherwise it is derived from ``model_output_dir``.

    :param compile_:
        If ``True``, the model will be compiled. If ``False``, the model will
        not be compiled. If ``None``, the model will be compiled if it cannot
        be imported.

    :param non_estimated_parameters_as_constants:
        Whether parameters marked as non-estimated in PEtab should be
        considered constant in AMICI. Setting this to ``True`` will reduce
        model size and simulation times. If sensitivities with respect to those
        parameters are required, this should be set to ``False``.
        Only used for SBML models.

    :param jax:
        Whether to create a JAX-based problem. If ``True``, returns a
        :class:`amici.jax.JAXProblem` instance. If ``False``, returns a
        standard AMICI model.

    :param kwargs:
        Additional keyword arguments to be passed to
        :meth:`amici.sbml_import.SbmlImporter.sbml2amici` or
        :func:`amici.pysb_import.pysb2amici`, depending on the model type.

    :return:
        The imported model (if ``jax=False``) or JAX problem (if ``jax=True``).

    :raises NotImplementedError:
        If the model type is neither SBML nor PySB, or if the model has to be
        compiled for a problem using the PEtab SciML extension that is not an
        SBML model imported with ``jax=True``.
    :raises ModuleNotFoundError:
        If a PySB model has to be compiled but PySB import support could not
        be loaded.
    :raises ValueError:
        If ``compile_`` is ``None``, the model cannot be imported, and
        ``model_output_dir`` is not empty.
    """
    if petab_problem.model.type_id not in (MODEL_TYPE_SBML, MODEL_TYPE_PYSB):
        raise NotImplementedError(
            "Unsupported model type " + petab_problem.model.type_id
        )

    model_name = model_name or petab_problem.model.model_id

    if petab_problem.model.type_id == MODEL_TYPE_PYSB and model_name is None:
        model_name = petab_problem.pysb_model.name
    elif model_name is None and model_output_dir:
        model_name = _create_model_name(model_output_dir)

    # generate folder and model name if necessary
    if model_output_dir is None:
        model_output_dir = amici.get_model_dir(model_name, jax=jax).absolute()
    else:
        model_output_dir = Path(model_output_dir).absolute()

    model_output_dir.mkdir(parents=True, exist_ok=True)

    # check if compilation necessary
    if compile_ or (
        compile_ is None
        and not _can_import_model(model_name, model_output_dir, jax)
    ):
        # check if folder exists
        if os.listdir(model_output_dir) and not compile_:
            raise ValueError(
                f"Cannot compile to {model_output_dir}: not empty. "
                "Please assign a different target or set `compile_` to `True`."
            )

        # Reject unsupported setups *before* wiping ``model_output_dir`` and
        # before loading any hybridization data, so a request that is bound
        # to fail does not delete a previously generated model.
        uses_sciml = "sciml" in petab_problem.extensions_config
        if uses_sciml and (
            not jax or petab_problem.model.type_id != MODEL_TYPE_SBML
        ):
            raise NotImplementedError(
                "petab_sciml extension is currently only supported for sbml "
                "models imported with `jax=True`"
            )
        if (
            petab_problem.model.type_id == MODEL_TYPE_PYSB
            and import_model_pysb is None
        ):
            # ``import_model_pysb`` is set to None at module import time when
            # the PySB import module could not be loaded.
            raise ModuleNotFoundError(
                "Importing PySB models requires the `pysb` package, "
                "which could not be imported."
            )

        # remove folder if exists
        if not jax and os.path.exists(model_output_dir):
            shutil.rmtree(model_output_dir)

        logger.info(f"Compiling model {model_name} to {model_output_dir}.")

        hybridization = (
            _build_hybridization(petab_problem) if uses_sciml else None
        )

        # compile the model
        if petab_problem.model.type_id == MODEL_TYPE_PYSB:
            import_model_pysb(
                petab_problem,
                model_name=model_name,
                model_output_dir=model_output_dir,
                jax=jax,
                **kwargs,
            )
        else:
            import_model_sbml(
                petab_problem=petab_problem,
                model_name=model_name,
                model_output_dir=model_output_dir,
                non_estimated_parameters_as_constants=non_estimated_parameters_as_constants,
                hybridization=hybridization,
                jax=jax,
                **kwargs,
            )

    # import model
    model_module = amici.import_model_module(
        *_get_package_name_and_path(model_name, model_output_dir, jax=jax)
    )

    if jax:
        from amici.jax import JAXProblem

        model = model_module.Model()

        logger.info(
            f"Successfully loaded jax model {model_name} "
            f"from {model_output_dir}."
        )

        # Create the JAXProblem first so the success message is only logged
        # once creation has actually succeeded.
        jax_problem = JAXProblem(model, petab_problem)
        logger.info(f"Successfully created JAXProblem for {model_name}.")
        return jax_problem

    model = model_module.get_model()
    check_model(amici_model=model, petab_problem=petab_problem)

    logger.info(
        f"Successfully loaded model {model_name} from {model_output_dir}."
    )

    return model


def _entity_kind(model_id: str) -> str:
    """
    Return the second ``.``-separated component of a model entity ID.

    Returns ``""`` for IDs without a ``.`` instead of raising
    :class:`IndexError`, so such entries simply match no input/output filter.
    """
    parts = model_id.split(".")
    return parts[1] if len(parts) > 1 else ""


def _build_hybridization(petab_problem: petab.Problem) -> dict:
    """
    Assemble the neural-network hybridization spec for the PEtab SciML extension.

    :param petab_problem:
        PEtab problem whose ``extensions_config`` contains a ``"sciml"`` entry
        with ``"hybridization_files"`` and ``"neural_nets"``.
    :return:
        Mapping from network ID to a dict holding the loaded network
        (``"model"``), ``"input_vars"``, ``"output_vars"``,
        ``"observable_vars"``, ``"frozen_layers"``, and all entries of that
        network's configuration.
    """
    from petab_sciml.standard import NNModelStandard

    config = petab_problem.extensions_config["sciml"]
    # TODO: only accept YAML format for now
    hybridization_table = pd.concat(
        [pd.read_csv(hf, sep="\t") for hf in config["hybridization_files"]]
    )
    # targetId -> targetValue
    input_mapping = dict(
        zip(
            hybridization_table["targetId"],
            hybridization_table["targetValue"],
        )
    )
    # targetValue -> targetId
    output_mapping = dict(
        zip(
            hybridization_table["targetValue"],
            hybridization_table["targetId"],
        )
    )
    # observableFormula -> observable ID
    observable_mapping = dict(
        zip(
            petab_problem.observable_df["observableFormula"],
            petab_problem.observable_df.index,
        )
    )

    mapping_df = petab_problem.mapping_df
    parameter_df = petab_problem.parameter_df
    # Network a model entity belongs to: the part of its ID before the first
    # ".". Computed once rather than once per network and per field.
    entity_net_ids = (
        mapping_df[petab.MODEL_ENTITY_ID].str.split(".").str[0]
    )

    hybridization = {}
    for net_id, net_config in config["neural_nets"].items():
        # PEtab entity ID -> model entity ID, restricted to this network
        net_entities = mapping_df.loc[
            entity_net_ids == net_id, petab.MODEL_ENTITY_ID
        ].to_dict()

        hybridization[net_id] = {
            "model": NNModelStandard.load_data(Path(net_config["location"])),
            "input_vars": [
                input_mapping[petab_id]
                for petab_id, model_id in net_entities.items()
                if _entity_kind(model_id).startswith("input")
                and petab_id in input_mapping
            ],
            "output_vars": {
                output_mapping[petab_id]: _get_net_index(model_id)
                for petab_id, model_id in net_entities.items()
                if _entity_kind(model_id).startswith("output")
                and petab_id in output_mapping
            },
            "observable_vars": {
                observable_mapping[petab_id]: _get_net_index(model_id)
                for petab_id, model_id in net_entities.items()
                if _entity_kind(model_id).startswith("output")
                and petab_id in observable_mapping
            },
            # Entries mapped to PEtab parameters that are not estimated.
            "frozen_layers": dict(
                _get_frozen_layers(model_id)
                for petab_id, model_id in net_entities.items()
                if petab_id in parameter_df.index
                and parameter_df.loc[petab_id, petab.ESTIMATE] == 0
            ),
            **net_config,
        }
    return hybridization
def _get_net_index(model_id: str): matches = re.findall(r"\[(\d+)\]", model_id) if matches: return int(matches[-1]) def _get_frozen_layers(model_id): layers = re.findall(r"\[(.*?)\]", model_id) array_attr = model_id.split(".")[-1] layer_id = layers[0] if len(layers) else None array_attr = array_attr if array_attr in ("weight", "bias") else None return layer_id, array_attr