# Source code for amici.jax.nn

from pathlib import Path

import equinox as eqx
import jax.numpy as jnp

from amici import amiciModulePath
from amici._codegen.template import apply_template


class Flatten(eqx.Module):
    """Custom implementation of a `torch.flatten` layer for Equinox.

    Collapses the contiguous run of axes from ``start_dim`` through
    ``end_dim`` into a single axis. As in ``torch.flatten``, both bounds are
    inclusive.
    """

    # First axis of the range that is flattened (inclusive).
    start_dim: int
    # Last axis of the range that is flattened (inclusive); negative values
    # count from the last axis.
    end_dim: int

    def __init__(self, start_dim: int, end_dim: int):
        """
        Store the axis range to flatten.

        :param start_dim: first axis to flatten (inclusive)
        :param end_dim: last axis to flatten (inclusive), ``-1`` for the
            last axis
        """
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def __call__(self, x):
        """
        Flatten axes ``start_dim`` .. ``end_dim`` of ``x`` into one axis.

        :param x: array to reshape
        :return: reshaped array; axes before ``start_dim`` and after
            ``end_dim`` keep their sizes
        """
        end_dim = self.end_dim
        if end_dim < 0:
            # Normalise so that ``end_dim + 1`` below is a valid slice start;
            # for -1 this yields an empty tail, i.e. flatten to the end.
            end_dim += len(x.shape)
        # ``end_dim`` itself belongs to the flattened range, so the axes that
        # are kept start one position after it.
        return jnp.reshape(
            x,
            x.shape[: self.start_dim] + (-1,) + x.shape[end_dim + 1 :],
        )
def tanhshrink(x: jnp.ndarray) -> jnp.ndarray:
    """
    Apply the Tanhshrink activation, ``x - tanh(x)``, to ``x``.

    JAX counterpart of ``torch.nn.Tanhshrink``.

    :param x: input array
    :return: ``x`` minus its hyperbolic tangent
    """
    squashed = jnp.tanh(x)
    return x - squashed
def cat(tensors, axis: int = 0):
    """
    Join arrays along an axis, mirroring ``torch.cat``.

    When every entry is zero-dimensional the entries are stacked, because
    concatenation is not used for scalars here; otherwise they are
    concatenated.

    :param tensors: List of arrays to join
    :param axis: Axis along which to join (default: 0)
    :return: Joined array
    """
    scalars_only = True
    for tensor in tensors:
        if jnp.ndim(tensor) != 0:
            scalars_only = False
            break

    # 0D inputs have no axis to concatenate along, so they are stacked.
    join = jnp.stack if scalars_only else jnp.concatenate
    return join(tensors, axis=axis)
def generate_equinox(
    nn_model: "NNModel",  # noqa: F821
    filename: Path | str,
    frozen_layers: dict[str, bool] | None = None,
) -> None:
    """
    Write an Equinox model file for a petab_sciml neural network object.

    The file is rendered from the ``nn.template.py`` template in the ``jax``
    directory of the amici module path. Missing parent directories of
    ``filename`` are created.

    :param nn_model:
        Neural network model in petab_sciml format
    :param filename:
        output filename for generated Equinox model
    :param frozen_layers:
        mapping keyed by layer name that is passed on to the forward-pass
        code generation to freeze layers during training; ``None`` is
        treated as an empty mapping
    """
    # TODO: move to top level import and replace forward type definitions
    from petab_sciml import Layer

    frozen = {} if frozen_layers is None else frozen_layers
    target_path = Path(filename)
    layer_indent = 12
    node_indent = 8

    layers_by_id = {layer.layer_id: layer for layer in nn_model.layers}

    # All placeholder nodes are unpacked from ``input`` in one statement; for
    # a single placeholder this degenerates to a plain assignment.
    placeholder_names = [
        node.name for node in nn_model.forward if node.op == "placeholder"
    ]
    input_unpack = f"{', '.join(placeholder_names)} = input"

    # One generated line per forward node; placeholder nodes yield an empty
    # string and are dropped.
    forward_lines = []
    for node in nn_model.forward:
        layer_type = layers_by_id.get(
            node.target,
            Layer(layer_id="dummy", layer_type="Linear"),
        ).layer_type
        line = _generate_forward(node, node_indent, frozen, layer_type)
        if line:
            forward_lines.append(line)

    # The first line carries no indentation: the template value starts right
    # at the unpacking statement.
    forward_body = input_unpack + "\n" + "\n".join(forward_lines)

    model_id = nn_model.nn_model_id
    layer_defs = ",\n".join(
        _generate_layer(layer, layer_indent, index)
        for index, layer in enumerate(nn_model.layers)
    )
    input_ids = ", ".join(f"'{inp.input_id}'" for inp in nn_model.inputs)
    output_node = next(
        node for node in nn_model.forward if node.op == "output"
    )
    output_ids = ", ".join(f"'{arg}'" for arg in output_node.args)

    tpl_data = {
        "MODEL_ID": model_id,
        # strip the indentation of the first layer entry
        "LAYERS": layer_defs[layer_indent:],
        "FORWARD": forward_body,
        "INPUT": input_ids,
        "OUTPUT": output_ids,
        "N_LAYERS": len(nn_model.layers),
    }

    target_path.parent.mkdir(parents=True, exist_ok=True)

    apply_template(
        Path(amiciModulePath) / "jax" / "nn.template.py",
        target_path,
        tpl_data,
    )
def _process_argval(v): """ Process argument value for layer instantiation string """ if isinstance(v, str): return f"'{v}'" if isinstance(v, bool): return str(v) return str(v) def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F821 """ Generate layer definition string for a given layer :param layer: petab_sciml Layer object :param indent: indentation level for generated string :param ilayer: layer index for key generation :return: string defining the layer in equinox syntax """ if layer.layer_type.startswith( ("BatchNorm", "AlphaDropout", "InstanceNorm") ): raise NotImplementedError( f"{layer.layer_type} layers currently not supported" ) if layer.layer_type.startswith("MaxPool") and "dilation" in layer.args: raise NotImplementedError("MaxPool layers with dilation not supported") if layer.layer_type.startswith("Dropout") and "inplace" in layer.args: raise NotImplementedError("Dropout layers with inplace not supported") if layer.layer_type == "Bilinear": raise NotImplementedError("Bilinear layers not supported") # mapping of layer names in sciml yaml format to equinox/custom amici implementations layer_map = { "Dropout1d": "eqx.nn.Dropout", "Dropout2d": "eqx.nn.Dropout", "Flatten": "amici.jax.Flatten", } # mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations kwarg_map = { "Linear": { "bias": "use_bias", }, "Conv1d": { "bias": "use_bias", }, "Conv2d": { "bias": "use_bias", }, "LayerNorm": { "elementwise_affine": "use_bias", # Deprecation warning - replace LayerNorm(elementwise_affine) with LayerNorm(use_bias) "normalized_shape": "shape", }, } # list of keyword arguments to ignore when generating layer, as they are not supported in equinox (see above) kwarg_ignore = { "Dropout1d": ("inplace",), "Dropout2d": ("inplace",), } # construct argument string for layer instantiation kwargs = [ f"{kwarg_map.get(layer.layer_type, {}).get(k, k)}={_process_argval(v)}" for k, v in layer.args.items() if k not in 
kwarg_ignore.get(layer.layer_type, ()) ] # add key for initialization if layer.layer_type in ( "Linear", "Conv1d", "Conv2d", "Conv3d", "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d", ): kwargs += [f"key=keys[{ilayer}]"] type_str = layer_map.get(layer.layer_type, f"eqx.nn.{layer.layer_type}") layer_str = f"{type_str}({', '.join(kwargs)})" return f"{' ' * indent}'{layer.layer_id}': {layer_str}" def _format_function_call( var_name: str, fun_str: str, args: list, kwargs: list[str], indent: int ) -> str: """ Utility function to format a function call assignment string. :param var_name: name of the variable to assign the result to :param fun_str: string representation of the function to call :param args: list of positional arguments :param kwargs: list of keyword arguments as strings :param indent: indentation level for generated string :return: formatted string representing the function call assignment """ args_str = ", ".join([f"{arg}" for arg in args]) kwargs_str = ", ".join(kwargs) all_args = ", ".join(filter(None, [args_str, kwargs_str])) return f"{' ' * indent}{var_name} = {fun_str}({all_args})" def _process_layer_call( node: "Node", # noqa: F821 layer_type: str, frozen_layers: dict[str, bool], ) -> tuple[str, str]: """ Process a layer (call_module) node and return function string and optional tree string. 
:param node: petab sciml Node object representing a layer call :param layer_type: petab sciml layer type of the node :param frozen_layers: dict of layer names to boolean indicating whether layer is frozen :return: tuple of (function_string, tree_string) where tree_string is empty if no tree is needed """ fun_str = f"self.layers['{node.target}']" tree_string = "" # Handle frozen layers if node.name in frozen_layers: if frozen_layers[node.name]: arr_attr = frozen_layers[node.name] get_lambda = f"lambda layer: getattr(layer, '{arr_attr}')" replacer = "replace_fn = lambda arr: jax.lax.stop_gradient(arr)" tree_string = f"tree_{node.name} = eqx.tree_at({get_lambda}, {fun_str}, {replacer})" fun_str = f"tree_{node.name}" else: fun_str = f"jax.lax.stop_gradient({fun_str})" # Handle vmap for certain layer types if layer_type.startswith(("Conv", "Linear", "LayerNorm")): if layer_type in ("LayerNorm",): dims = f"len({fun_str}.shape)+1" elif layer_type == "Linear": dims = 2 elif layer_type.endswith("1d"): dims = 3 elif layer_type.endswith("2d"): dims = 4 elif layer_type.endswith("3d"): dims = 5 fun_str = f"(jax.vmap({fun_str}) if len({node.args[0]}.shape) == {dims} else {fun_str})" return fun_str, tree_string def _process_activation_call(node: "Node") -> str: # noqa: F821 """ Process an activation function (call_function/call_method) node and return function string. 
:param node: petab sciml Node object representing an activation function call :return: string representation of the activation function """ # Mapping of function names in sciml yaml format to equinox/custom amici implementations activation_map = { "hardtanh": "jax.nn.hard_tanh", "hardsigmoid": "jax.nn.hard_sigmoid", "hardswish": "jax.nn.hard_swish", "tanhshrink": "amici.jax.tanhshrink", "softsign": "jax.nn.soft_sign", "cat": "amici.jax.cat", } # Validate hardtanh parameters if node.target == "hardtanh": if node.kwargs.pop("min_val", -1.0) != -1.0: raise NotImplementedError( "min_val != -1.0 not supported for hardtanh" ) if node.kwargs.pop("max_val", 1.0) != 1.0: raise NotImplementedError( "max_val != 1.0 not supported for hardtanh" ) # Handle kwarg aliasing for cat (dim -> axis) if node.target == "cat": if "dim" in node.kwargs: node.kwargs["axis"] = node.kwargs.pop("dim") # Convert list of variable names to proper bracket-enclosed list if isinstance(node.args[0], list): # node.args[0] is a list like ['net_input1', 'net_input2'] # We need to convert it to a single string representing the list: [net_input1, net_input2] node.args = tuple( ["[" + ", ".join(node.args[0]) + "]"] + list(node.args[1:]) ) return activation_map.get(node.target, f"jax.nn.{node.target}") def _generate_forward( node: "Node", # noqa: F821 indent, frozen_layers: dict[str, bool] | None = None, layer_type: str = "", ) -> str: """ Generate forward pass line for a given node :param node: petab sciml Node object representing a step in the forward pass :param indent: indentation level for generated string :param frozen_layers: dict of layer names to boolean indicating whether layer is frozen :param layer_type: petab sciml layer type of the node (only relevant for call_module nodes) :return: string defining the forward pass implementation for the given node in equinox syntax """ if frozen_layers is None: frozen_layers = {} # Handle placeholder nodes - skip individual processing, handled collectively in 
generate_equinox if node.op == "placeholder": return "" # Handle output nodes if node.op == "output": args_str = ", ".join([f"{arg}" for arg in node.args]) return f"{' ' * indent}{node.target} = {args_str}" # Process layer calls tree_string = "" if node.op == "call_module": fun_str, tree_string = _process_layer_call( node, layer_type, frozen_layers ) # Process activation function calls if node.op in ("call_function", "call_method"): fun_str = _process_activation_call(node) # Build kwargs list, filtering out unsupported arguments kwargs = [ f"{k}={item}" for k, item in node.kwargs.items() if k not in ("inplace",) ] # Add key parameter for Dropout layers if layer_type.startswith("Dropout"): kwargs += ["key=key"] # Format the function call if node.op in ("call_module", "call_function", "call_method"): result = _format_function_call( node.name, fun_str, node.args, kwargs, indent ) # Prepend tree_string if needed for frozen layers if tree_string: return f"{' ' * indent}{tree_string}\n{result}" return result raise NotImplementedError(f"Operation {node.op} not supported")