{ "cells": [ { "cell_type": "markdown", "id": "d4d2bc5c", "metadata": {}, "source": [ "# AMICI & JAX\n", "\n", "## Overview\n", "\n", "The purpose of this guide is to showcase how AMICI can be combined with differentiable programming in [JAX](https://jax.readthedocs.io/en/latest/index.html). We will do so by reimplementing the parameter transformations available in AMICI in JAX and comparing the result to the native implementation." ] }, { "cell_type": "code", "execution_count": 1, "id": "b0a66e18", "metadata": { "ExecuteTime": { "end_time": "2024-10-18T20:35:50.848939Z", "start_time": "2024-10-18T20:35:43.411198Z" } }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp" ] }, { "cell_type": "markdown", "id": "fb2fe897", "metadata": {}, "source": [ "## Preparation\n", "\n", "To get started, we will import a model using [petab](https://petab.readthedocs.io). To this end, we will use the [benchmark collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which features a variety of different models. For more details about petab import, see the respective petab [notebook](https://amici.readthedocs.io/en/latest/petab.html)." ] }, { "cell_type": "markdown", "id": "8552f123", "metadata": {}, "source": [ "From the benchmark collection, we now import the Böhm model." 
] }, { "cell_type": "code", "execution_count": 2, "id": "9166e3bf", "metadata": { "ExecuteTime": { "end_time": "2024-10-18T20:36:51.490345Z", "start_time": "2024-10-18T20:35:50.853467Z" } }, "outputs": [], "source": [ "import petab.v1 as petab\n", "\n", "model_name = \"Boehm_JProteomeRes2014\"\n", "yaml_file = f\"https://raw.githubusercontent.com/Benchmarking-Initiative/Benchmark-Models-PEtab/master/Benchmark-Models/{model_name}/{model_name}.yaml\"\n", "petab_problem = petab.Problem.from_yaml(yaml_file)" ] }, { "cell_type": "markdown", "id": "d4038fc4", "metadata": {}, "source": [ "The petab problem includes information about parameter scaling in its parameter table. For the Böhm model, all estimated parameters (`petab.ESTIMATE` column equal to `1`) use `petab.LOG10` as parameter scaling." ] }, { "cell_type": "code", "execution_count": 3, "id": "b04ca561", "metadata": { "ExecuteTime": { "end_time": "2024-10-18T20:36:51.740603Z", "start_time": "2024-10-18T20:36:51.725877Z" } }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>parameterName</th><th>parameterScale</th><th>lowerBound</th><th>upperBound</th><th>nominalValue</th><th>estimate</th></tr>\n",
 "    <tr><th>parameterId</th><th></th><th></th><th></th><th></th><th></th><th></th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>Epo_degradation_BaF3</th><td>EPO_{degradation,BaF3}</td><td>log10</td><td>0.00001</td><td>100000</td><td>0.026983</td><td>1</td></tr>\n",
 "    <tr><th>k_exp_hetero</th><td>k_{exp,hetero}</td><td>log10</td><td>0.00001</td><td>100000</td><td>0.000010</td><td>1</td></tr>\n",
 "    <tr><th>k_exp_homo</th><td>k_{exp,homo}</td><td>log10</td><td>0.00001</td><td>100000</td><td>0.006170</td><td>1</td></tr>\n",
 "    <tr><th>k_imp_hetero</th><td>k_{imp,hetero}</td><td>log10</td><td>0.00001</td><td>100000</td><td>0.016368</td><td>1</td></tr>\n",
 "    <tr><th>k_imp_homo</th><td>k_{imp,homo}</td><td>log10</td><td>0.00001</td><td>100000</td><td>97749.379402</td><td>1</td></tr>\n",
 "    <tr><th>k_phos</th><td>k_{phos}</td><td>log10</td><td>0.00001</td><td>100000</td><td>15766.507020</td><td>1</td></tr>\n",
 "    <tr><th>ratio</th><td>ratio</td><td>lin</td><td>0.00000</td><td>5</td><td>0.693000</td><td>0</td></tr>\n",
 "    <tr><th>sd_pSTAT5A_rel</th><td>\\sigma_{pSTAT5A,rel}</td><td>log10</td><td>0.00001</td><td>100000</td><td>3.852612</td><td>1</td></tr>\n",
 "    <tr><th>sd_pSTAT5B_rel</th><td>\\sigma_{pSTAT5B,rel}</td><td>log10</td><td>0.00001</td><td>100000</td><td>6.591478</td><td>1</td></tr>\n",
 "    <tr><th>sd_rSTAT5A_rel</th><td>\\sigma_{rSTAT5A,rel}</td><td>log10</td><td>0.00001</td><td>100000</td><td>3.152713</td><td>1</td></tr>\n",
 "    <tr><th>specC17</th><td>specC17</td><td>lin</td><td>-5.00000</td><td>5</td><td>0.107000</td><td>0</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "</div>
" ], "text/plain": [ " parameterName parameterScale lowerBound \\\n", "parameterId \n", "Epo_degradation_BaF3 EPO_{degradation,BaF3} log10 0.00001 \n", "k_exp_hetero k_{exp,hetero} log10 0.00001 \n", "k_exp_homo k_{exp,homo} log10 0.00001 \n", "k_imp_hetero k_{imp,hetero} log10 0.00001 \n", "k_imp_homo k_{imp,homo} log10 0.00001 \n", "k_phos k_{phos} log10 0.00001 \n", "ratio ratio lin 0.00000 \n", "sd_pSTAT5A_rel \\sigma_{pSTAT5A,rel} log10 0.00001 \n", "sd_pSTAT5B_rel \\sigma_{pSTAT5B,rel} log10 0.00001 \n", "sd_rSTAT5A_rel \\sigma_{rSTAT5A,rel} log10 0.00001 \n", "specC17 specC17 lin -5.00000 \n", "\n", " upperBound nominalValue estimate \n", "parameterId \n", "Epo_degradation_BaF3 100000 0.026983 1 \n", "k_exp_hetero 100000 0.000010 1 \n", "k_exp_homo 100000 0.006170 1 \n", "k_imp_hetero 100000 0.016368 1 \n", "k_imp_homo 100000 97749.379402 1 \n", "k_phos 100000 15766.507020 1 \n", "ratio 5 0.693000 0 \n", "sd_pSTAT5A_rel 100000 3.852612 1 \n", "sd_pSTAT5B_rel 100000 6.591478 1 \n", "sd_rSTAT5A_rel 100000 3.152713 1 \n", "specC17 5 0.107000 0 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "petab_problem.parameter_df" ] }, { "cell_type": "markdown", "id": "8914a18d", "metadata": {}, "source": [ "We now import the petab problem using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem)." 
] }, { "cell_type": "code", "execution_count": 4, "id": "6ada3fb8", "metadata": { "ExecuteTime": { "start_time": "2024-10-18T20:36:51.765461Z" }, "jupyter": { "is_executing": true } }, "outputs": [], "source": [ "from amici.petab.petab_import import import_petab_problem\n", "\n", "amici_model = import_petab_problem(petab_problem, compile_=True, verbose=False)" ] }, { "cell_type": "markdown", "id": "e2ef051a", "metadata": {}, "source": [ "## JAX implementation\n", "\n", "For full jax support, we would have to implement a new [primitive](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), which would require quite a bit of engineering and, in the end, wouldn't add much benefit since AMICI can't run on GPUs. Instead, we will interface AMICI using the jax function [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html)." ] }, { "cell_type": "markdown", "id": "6bbf2f06", "metadata": {}, "source": [ "To do so, we define a base function that takes a single argument (the parameters) and runs a petab simulation via [simulate_petab](https://amici.readthedocs.io/en/latest/generated/amici.petab_objective.html#amici.petab_objective.simulate_petab). To enable gradient computation later on, we create a solver object, set its sensitivity order to first order, and pass it to `simulate_petab`. Moreover, `simulate_petab` expects a dictionary of parameters, so we create one using the free parameter ids from the petab problem." 
] }, { "cell_type": "code", "execution_count": 5, "id": "72053647", "metadata": {}, "outputs": [], "source": [ "import amici\n", "import numpy as np\n", "from amici.petab.simulations import simulate_petab\n", "\n", "amici_solver = amici_model.create_solver()\n", "amici_solver.set_sensitivity_order(amici.SensitivityOrder.first)\n", "\n", "\n", "def amici_callback_base(parameters: jnp.array):\n", "    ret = simulate_petab(\n", "        petab_problem,\n", "        amici_model,\n", "        problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)),\n", "        solver=amici_solver,\n", "    )\n", "    llh = np.array(ret[\"llh\"])\n", "    sllh = np.array(\n", "        tuple(ret[\"sllh\"][par_id] for par_id in petab_problem.x_free_ids)\n", "    )\n", "    return llh, sllh" ] }, { "cell_type": "markdown", "id": "6f6201e8", "metadata": {}, "source": [ "Now we can use this base function to create two separate functions: one that returns the log-likelihood (`llh`), and one that returns a tuple of the log-likelihood and its gradient (`sllh`). Both functions use [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html) such that they can be called by other jax functions. Note that, as we are using the same base function here, the log-likelihood computation will also run with sensitivities, which is unnecessary and adds some overhead. This is done only for convenience and should be fixed in an application where efficiency is important." 
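, "\n", "\n", "As a minimal, AMICI-free sketch of this mechanism, the following wraps a plain NumPy function with `jax.pure_callback`. The names `host_fun` and `device_fun` are hypothetical stand-ins for the simulator call and its wrapper; the only thing JAX needs to know ahead of time is the shape and dtype of the result, which is what `jax.ShapeDtypeStruct` declares.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_fun(x):
    # Hypothetical stand-in for an external simulator; runs in NumPy on the host.
    x = np.asarray(x)
    return np.asarray(np.sum(np.sin(x)), dtype=x.dtype)


def device_fun(x):
    # Declare the result as a scalar with the same dtype as the input.
    return jax.pure_callback(host_fun, jax.ShapeDtypeStruct((), x.dtype), x)


x = jnp.linspace(0.0, 1.0, 5)
# The wrapped function can be traced and compiled like any other jax function.
y = jax.jit(device_fun)(x)
```

Because the callback is opaque to JAX, calling `jax.grad(device_fun)` would fail at this point; derivatives have to be supplied separately."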
] }, { "cell_type": "code", "execution_count": 6, "id": "2dd50b53", "metadata": {}, "outputs": [], "source": [ "def device_fun_llh(x: jnp.array):\n", "    return jax.pure_callback(\n", "        lambda x: amici_callback_base(x)[0],\n", "        jax.ShapeDtypeStruct((), x.dtype),\n", "        x,\n", "    )\n", "\n", "\n", "def device_fun_llh_sllh(x: jnp.array):\n", "    return jax.pure_callback(\n", "        amici_callback_base,\n", "        (\n", "            jax.ShapeDtypeStruct((), x.dtype),\n", "            jax.ShapeDtypeStruct(\n", "                x.shape,\n", "                x.dtype,\n", "            ),\n", "        ),\n", "        x,\n", "    )" ] }, { "cell_type": "markdown", "id": "98e819bd", "metadata": {}, "source": [ "Even though the two functions that we just defined are valid jax functions, they can't compute derivatives yet. To support derivative computation, we define a new function with the `jax.custom_jvp` decorator, which declares that we will supply a custom Jacobian-vector product (JVP), and then define the corresponding JVP using the `@jax_objective.defjvp` decorator. More details about custom Jacobian-vector product functions can be found in the [JAX documentation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)." ] }, { "cell_type": "code", "execution_count": 7, "id": "6e1f4f43", "metadata": {}, "outputs": [], "source": [ "@jax.custom_jvp\n", "def jax_objective(parameters: jnp.array):\n", "    return device_fun_llh(parameters)\n", "\n", "\n", "@jax_objective.defjvp\n", "def jax_objective_jvp(primals: jnp.array, tangents: jnp.array):\n", "    (parameters,) = primals\n", "    (x_dot,) = tangents\n", "    llh, sllh = device_fun_llh_sllh(parameters)\n", "    return llh, sllh @ x_dot" ] }, { "cell_type": "markdown", "id": "379485ca", "metadata": {}, "source": [ "As a last step, we implement the parameter transformation in jax. This effectively just extracts the parameter scales from the petab problem, maps the scaled parameters back to linear scale in jax, and then passes them to the objective function we previously defined. We add the `jax.value_and_grad` decorator such that the generated jax function returns both function value and function gradient in a tuple. Moreover, we add the `jax.jit` decorator such that the function is [just-in-time compiled](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) upon the first function call." ] }, { "cell_type": "code", "execution_count": 8, "id": "3ab8fde9", "metadata": {}, "outputs": [], "source": [ "parameter_scales = petab_problem.parameter_df.loc[\n", "    petab_problem.x_free_ids, petab.PARAMETER_SCALE\n", "].values\n", "\n", "\n", "@jax.jit\n", "@jax.value_and_grad\n", "def jax_objective_with_parameter_transform(parameters: jnp.array):\n", "    par_scaled = jnp.asarray(\n", "        tuple(\n", "            value\n", "            if scale == petab.LIN\n", "            else jnp.exp(value)\n", "            if scale == petab.LOG\n", "            else jnp.power(10, value)\n", "            for value, scale in zip(parameters, parameter_scales)\n", "        )\n", "    )\n", "    return jax_objective(par_scaled)" ] }, { "cell_type": "markdown", "id": "293e29fb", "metadata": {}, "source": [ "## Testing\n", "\n", "We can now run the function to compute the log-likelihood and the gradient." ] }, { "cell_type": "code", "execution_count": 9, "id": "b7e9ff3b", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "parameters = dict(zip(petab_problem.x_free_ids, petab_problem.x_nominal_free))\n", "scaled_parameters = petab_problem.scale_parameters(parameters)\n", "scaled_parameters_np = np.asarray(list(scaled_parameters.values()))" ] }, { "cell_type": "code", "execution_count": 10, "id": "fb3085a8", "metadata": {}, "outputs": [], "source": [ "llh_jax, sllh_jax = jax_objective_with_parameter_transform(\n", "    jnp.array(scaled_parameters_np)\n", ")" ] }, { "cell_type": "markdown", "id": "6aa4a5f7", "metadata": {}, "source": [ "As a sanity check, we compare the computed values to the native parameter transformation in AMICI." 
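, "\n", "\n", "The comparison only makes sense if JAX applies the chain rule through `jnp.power(10, value)` correctly. A minimal sketch of that check in isolation, assuming a toy quadratic `toy_objective` as a hypothetical stand-in for the AMICI log-likelihood: for a log10-scaled parameter x with theta = 10**x, the gradient with respect to x equals the gradient with respect to theta multiplied by theta * ln(10).

```python
import jax
import jax.numpy as jnp


def toy_objective(theta):
    # Hypothetical stand-in for the log-likelihood on linear scale.
    return -jnp.sum((theta - 2.0) ** 2)


def toy_objective_log10(x):
    # Same transformation as for log10-scaled parameters: unscale, then evaluate.
    return toy_objective(jnp.power(10.0, x))


x = jnp.array([0.1, 0.5])
theta = jnp.power(10.0, x)

# Gradient with respect to the log10-scaled parameters via autodiff.
grad_scaled = jax.grad(toy_objective_log10)(x)
# Chain rule by hand: d/dx = d/dtheta * theta * ln(10).
grad_manual = jax.grad(toy_objective)(theta) * theta * jnp.log(10.0)
```

Both arrays should agree up to floating-point round-off."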
] }, { "cell_type": "code", "execution_count": 11, "id": "48451b0e", "metadata": {}, "outputs": [], "source": [ "r = simulate_petab(\n", "    petab_problem,\n", "    amici_model,\n", "    solver=amici_solver,\n", "    scaled_parameters=True,\n", "    scaled_gradients=True,\n", "    problem_parameters=scaled_parameters,\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "id": "2628db12", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>amici</th><th>jax</th><th>rel_diff</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>llh</th><td>-138.221997</td><td>-138.222</td><td>-2.135248e-08</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "</div>" ], "text/plain": [ " amici jax rel_diff\n", "llh -138.221997 -138.222 -2.135248e-08" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "pd.DataFrame(\n", "    dict(\n", "        amici=r[\"llh\"],\n", "        jax=float(llh_jax),\n", "        rel_diff=(r[\"llh\"] - float(llh_jax)) / r[\"llh\"],\n", "    ),\n", "    index=(\"llh\",),\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "0846523f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>amici</th><th>jax</th><th>rel_diff</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>Epo_degradation_BaF3</th><td>-0.022045</td><td>-0.022034</td><td>4.645833e-04</td></tr>\n",
 "    <tr><th>k_exp_hetero</th><td>-0.055323</td><td>-0.055323</td><td>8.646725e-08</td></tr>\n",
 "    <tr><th>k_exp_homo</th><td>-0.005789</td><td>-0.005801</td><td>-2.013520e-03</td></tr>\n",
 "    <tr><th>k_imp_hetero</th><td>-0.005414</td><td>-0.005403</td><td>1.973604e-03</td></tr>\n",
 "    <tr><th>k_imp_homo</th><td>0.000045</td><td>0.000045</td><td>1.119566e-06</td></tr>\n",
 "    <tr><th>k_phos</th><td>-0.007907</td><td>-0.007794</td><td>1.447768e-02</td></tr>\n",
 "    <tr><th>sd_pSTAT5A_rel</th><td>-0.010784</td><td>-0.010800</td><td>-1.469518e-03</td></tr>\n",
 "    <tr><th>sd_pSTAT5B_rel</th><td>-0.024037</td><td>-0.024037</td><td>-8.729860e-06</td></tr>\n",
 "    <tr><th>sd_rSTAT5A_rel</th><td>-0.019191</td><td>-0.019186</td><td>2.829431e-04</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "</div>" ], "text/plain": [ " amici jax rel_diff\n", "Epo_degradation_BaF3 -0.022045 -0.022034 4.645833e-04\n", "k_exp_hetero -0.055323 -0.055323 8.646725e-08\n", "k_exp_homo -0.005789 -0.005801 -2.013520e-03\n", "k_imp_hetero -0.005414 -0.005403 1.973604e-03\n", "k_imp_homo 0.000045 0.000045 1.119566e-06\n", "k_phos -0.007907 -0.007794 1.447768e-02\n", "sd_pSTAT5A_rel -0.010784 -0.010800 -1.469518e-03\n", "sd_pSTAT5B_rel -0.024037 -0.024037 -8.729860e-06\n", "sd_rSTAT5A_rel -0.019191 -0.019186 2.829431e-04" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grad_amici = np.asarray(list(r[\"sllh\"].values()))\n", "grad_jax = np.asarray(sllh_jax)\n", "rel_diff = (grad_amici - grad_jax) / grad_jax\n", "pd.DataFrame(\n", "    index=r[\"sllh\"].keys(),\n", "    data=dict(amici=grad_amici, jax=grad_jax, rel_diff=rel_diff),\n", ")" ] }, { "cell_type": "markdown", "id": "4b00dcb2", "metadata": {}, "source": [ "We see noticeable differences in the gradient calculation, with more than 1% relative error for `k_phos`. The primary reason is that JAX, in its default configuration, uses float32 precision both for the parameters that are passed to AMICI (which uses float64) and for the derivative of the parameter transformation.\n", "As the AMICI simulations, which run on the CPU, are by far the most expensive operation, there is barely any performance cost to using float64 instead of float32 in JAX. Therefore, we configure JAX to use float64 and rerun the simulations." ] }, { "cell_type": "code", "execution_count": 14, "id": "5f81c693", "metadata": {}, "outputs": [], "source": [ "jax.config.update(\"jax_enable_x64\", True)\n", "llh_jax, sllh_jax = jax_objective_with_parameter_transform(\n", "    scaled_parameters_np\n", ")" ] }, { "cell_type": "markdown", "id": "ab39311d", "metadata": {}, "source": [ "We can now evaluate the results again and see that the differences between the pure AMICI and AMICI/JAX implementations have disappeared." 
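, "\n", "\n", "For reference, the size of the float32 rounding that the earlier discrepancies were attributed to can be estimated in isolation. The following sketch stores the log10 of the nominal value of `k_imp_homo` (taken from the parameter table) in float32, transforms it back to linear scale, and compares with the float64 result. It only illustrates the size of the input perturbation; how strongly that perturbation is amplified in the gradient depends on the model and is not shown here.

```python
import jax.numpy as jnp
import numpy as np

# log10 of the nominal value of k_imp_homo from the parameter table
x64 = np.log10(97749.379402)
# Explicitly request float32, the precision JAX uses by default.
x32 = jnp.asarray(x64, dtype=jnp.float32)

# Transform both values back to linear scale in double precision.
theta64 = 10.0 ** x64
theta32 = 10.0 ** float(x32)

# Relative perturbation of the parameter caused by the float32 round trip.
rel_err = abs(theta32 - theta64) / theta64
```

The resulting `rel_err` is small but well above float64 round-off."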
] }, { "cell_type": "code", "execution_count": 15, "id": "25e8b301", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>amici</th><th>jax</th><th>rel_diff</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>llh</th><td>-138.221997</td><td>-138.221997</td><td>-0.0</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "</div>" ], "text/plain": [ " amici jax rel_diff\n", "llh -138.221997 -138.221997 -0.0" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(\n", "    dict(\n", "        amici=r[\"llh\"],\n", "        jax=float(llh_jax),\n", "        rel_diff=(r[\"llh\"] - float(llh_jax)) / r[\"llh\"],\n", "    ),\n", "    index=(\"llh\",),\n", ")" ] }, { "cell_type": "code", "execution_count": 16, "id": "f31a3927", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>amici</th><th>jax</th><th>rel_diff</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>Epo_degradation_BaF3</th><td>-0.022045</td><td>-0.022045</td><td>-0.0</td></tr>\n",
 "    <tr><th>k_exp_hetero</th><td>-0.055323</td><td>-0.055323</td><td>-0.0</td></tr>\n",
 "    <tr><th>k_exp_homo</th><td>-0.005789</td><td>-0.005789</td><td>-0.0</td></tr>\n",
 "    <tr><th>k_imp_hetero</th><td>-0.005414</td><td>-0.005414</td><td>-0.0</td></tr>\n",
 "    <tr><th>k_imp_homo</th><td>0.000045</td><td>0.000045</td><td>0.0</td></tr>\n",
 "    <tr><th>k_phos</th><td>-0.007907</td><td>-0.007907</td><td>-0.0</td></tr>\n",
 "    <tr><th>sd_pSTAT5A_rel</th><td>-0.010784</td><td>-0.010784</td><td>-0.0</td></tr>\n",
 "    <tr><th>sd_pSTAT5B_rel</th><td>-0.024037</td><td>-0.024037</td><td>-0.0</td></tr>\n",
 "    <tr><th>sd_rSTAT5A_rel</th><td>-0.019191</td><td>-0.019191</td><td>-0.0</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "</div>" ], "text/plain": [ " amici jax rel_diff\n", "Epo_degradation_BaF3 -0.022045 -0.022045 -0.0\n", "k_exp_hetero -0.055323 -0.055323 -0.0\n", "k_exp_homo -0.005789 -0.005789 -0.0\n", "k_imp_hetero -0.005414 -0.005414 -0.0\n", "k_imp_homo 0.000045 0.000045 0.0\n", "k_phos -0.007907 -0.007907 -0.0\n", "sd_pSTAT5A_rel -0.010784 -0.010784 -0.0\n", "sd_pSTAT5B_rel -0.024037 -0.024037 -0.0\n", "sd_rSTAT5A_rel -0.019191 -0.019191 -0.0" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grad_amici = np.asarray(list(r[\"sllh\"].values()))\n", "grad_jax = np.asarray(sllh_jax)\n", "rel_diff = (grad_amici - grad_jax) / grad_jax\n", "pd.DataFrame(\n", "    index=r[\"sllh\"].keys(),\n", "    data=dict(amici=grad_amici, jax=grad_jax, rel_diff=rel_diff),\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.0" } }, "nbformat": 4, "nbformat_minor": 5 }