Multiple model parameter fitting

Starting point

  • $N$ segmented geometry images $G_i$, each with a target concentration image $T_i$ at some time $t_i$
  • a model using one of these geometry images with reaction parameters $k^{(j)}$ to be fitted

Strategy

  • make N identical copies $M_i$ of the initial sme model and set the parameters to the same input values for all models
  • update the geometry image for model $M_i$ to $G_i$
  • simulate each model $M_i$ for time $t_i$
  • calculate the rms difference between the simulated concentrations of model $M_i$ and the target concentration $T_i$
  • sum this difference over all models $M_i$ to return a total cost function value for this set of parameters

Then feed this cost function into an optimisation algorithm to fit the parameters $k^{(j)}$
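A minimal sketch of this strategy, independent of sme, assuming a hypothetical toy simulator `toy_simulate` in place of an sme simulation: each "model" is a boolean geometry mask $G_i$ in which every voxel follows $c(t) = \frac{k_1}{k_1+k_2}\left(1 - e^{-(k_1+k_2)t}\right)$. The toy kinetics are an assumption chosen only because they have a closed form; what carries over is the structure of the total cost: one simulation per geometry and time $t_i$, compared with $T_i$ and summed over all models (here with a sum of squared differences, as in the cost function used later in this notebook).

```python
import numpy as np


def toy_simulate(params, geometry, t):
    """Hypothetical stand-in for an sme simulation (NOT the real model).

    Every voxel inside `geometry` follows c(t) = k1/(k1+k2) * (1 - exp(-(k1+k2) t)),
    the solution of dc/dt = k1 (1 - c) - k2 c with c(0) = 0; voxels outside stay at 0.
    """
    k1, k2 = params
    conc = np.zeros(geometry.shape)
    conc[geometry] = k1 / (k1 + k2) * (1.0 - np.exp(-(k1 + k2) * t))
    return conc


def total_cost(params, geometries, targets, times):
    """Sum over all models i of the squared differences to the target T_i inside G_i."""
    cost = 0.0
    for geometry, target, t in zip(geometries, targets, times):
        conc = toy_simulate(params, geometry, t)  # simulate model M_i for time t_i
        cost += np.sum((conc - target)[geometry] ** 2)
    return cost


# three random geometries G_i, with synthetic targets T_i generated from known parameters
rng = np.random.default_rng(0)
geometries = [rng.random((10, 10, 10)) < 0.3 for _ in range(3)]
times = [0.5, 1.0, 2.0]
true_params = [1.2, 0.7]
targets = [toy_simulate(true_params, g, t) for g, t in zip(geometries, times)]

print(total_cost(true_params, geometries, targets, times))  # 0.0: the true parameters fit exactly
print(total_cost([0.5, 0.5], geometries, targets, times) > 0)  # True
```

In this toy, a single time point cannot distinguish between $(k_1, k_2)$ pairs that give the same $c(t)$; combining models with different $t_i$ is what makes both toy parameters identifiable.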

Utility functions

These helper functions are only used to generate the synthetic geometry images and to plot them.

In [1]:
import numpy as np
from itertools import cycle
import logging

import matplotlib.pyplot as plt
from matplotlib import animation
import matplotlib.colors as mcolors
from mpl_toolkits.mplot3d import Axes3D

from IPython.display import Image, display, HTML, Video

import imageio.v3 as iio
import tifffile
import skimage
from scipy import ndimage as ndi

import pyvista as pv

pv.global_theme.axes.show = True
pv.global_theme.interactive = True
plt.rcParams["figure.figsize"] = (4, 4)

If you run this notebook locally, select the 'trame' Jupyter backend below to be able to interact with the plots. See the pyvista documentation for other backends.

In [2]:
pv.set_jupyter_backend("static")

# pv.set_jupyter_backend("trame") # for interactive plots
In [3]:
def sphere_mask(grid_shape, center, radius, deformation):
    # generate a boolean mask for a sphere with given center, radius and deformation
    Z, Y, X = grid_shape
    z0, y0, x0 = center
    dz, dy, dx = deformation
    z, y, x = np.ogrid[:Z, :Y, :X]
    return dx * (x - x0) ** 2 + dy * (y - y0) ** 2 + dz * (z - z0) ** 2 <= radius**2
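`sphere_mask` keeps the voxels satisfying $d_x(x-x_0)^2 + d_y(y-y_0)^2 + d_z(z-z_0)^2 \le r^2$, i.e. an axis-aligned ellipsoid with semi-axis $r/\sqrt{d}$ along each axis, so a deformation factor above 1 squashes the shape along that axis. A small check on a 5x5x5 grid (the function is copied so that the snippet runs on its own):

```python
import numpy as np


def sphere_mask(grid_shape, center, radius, deformation):
    # same as above: ellipsoid dx*(x-x0)^2 + dy*(y-y0)^2 + dz*(z-z0)^2 <= radius^2
    Z, Y, X = grid_shape
    z0, y0, x0 = center
    dz, dy, dx = deformation
    z, y, x = np.ogrid[:Z, :Y, :X]
    return dx * (x - x0) ** 2 + dy * (y - y0) ** 2 + dz * (z - z0) ** 2 <= radius**2


# undeformed, radius 1: the center voxel plus its 6 face neighbours
ball = sphere_mask((5, 5, 5), (2, 2, 2), 1, (1, 1, 1))
print(ball.sum())  # 7

# dz = 4 shrinks the z semi-axis to 1/sqrt(4) = 0.5 voxels, so the two z neighbours drop out
squashed = sphere_mask((5, 5, 5), (2, 2, 2), 1, (4, 1, 1))
print(squashed.sum())  # 5
```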
In [4]:
def generate_geometry_image(n_pixels):
    # generate a segmented image containing one randomly distributed, sized and deformed sphere
    max_radius = n_pixels / 3
    max_deform = 1.2
    voxels = np.zeros((n_pixels, n_pixels, n_pixels), dtype=np.uint16)
    center = np.random.randint(2, n_pixels - 2, 3)
    nuclear_radius = np.random.randint(1, max_radius / 2)
    cell_radius = np.random.randint(1.5 * nuclear_radius, max_radius)
    deformation = np.random.uniform(1 / max_deform, max_deform, 3)
    voxels[sphere_mask(voxels.shape, center, cell_radius, deformation)] = 2
    voxels[sphere_mask(voxels.shape, center, nuclear_radius, deformation)] = 1
    return voxels
In [5]:
def make_discrete_colormap(
    cmap: str = "tab10", values: np.ndarray = ()
) -> list[tuple[float, float, float, float]]:
    """Create a discrete colormap with one color per entry of `values`.

    The first color is opaque black (for the first value, typically the background);
    the others cycle through the colors of `cmap`, repeating them if there are more values than colors.

    Args:
        cmap (str, optional): name of a matplotlib colormap with a discrete list of colors, e.g. "tab10". Defaults to "tab10".
        values (np.ndarray, optional): values to be mapped to colors. Defaults to ().

    Returns:
        list[tuple[float, float, float, float]]: list of colors in RGBA format.
    """
    cm = [(0, 0, 0, 1)]
    for c in cycle(plt.get_cmap(cmap).colors):
        if len(cm) >= len(values):
            break
        cm.append(mcolors.to_rgba(c))
    return cm
In [6]:
def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
    """Convert an array of RGB values to integer labels, one label per unique color.

    This is used so that RGB images can be stored and thresholded as scalar mesh data in pyvista.

    Args:
        img (np.ndarray): data to be converted, of shape (..., 3), e.g. (n, m, 3) or (z, y, x, 3)

    Returns:
        np.ndarray: integer labels of shape img.shape[:-1]; colors are numbered in the sorted order returned by np.unique
    """
    reshaped = img.reshape(-1, 3)
    # ridx[i] is the index (into the sorted unique colors) of the color of pixel i
    _, ridx = np.unique(reshaped, axis=0, return_inverse=True)
    return ridx.reshape(img.shape[:-1])
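A small NumPy-only check of what `rgb_to_scalar` produces (the function is repeated so the snippet is self-contained). Note that the labels follow the lexicographic order in which `np.unique` sorts the colors (black `[0, 0, 0]` < blue `[0, 0, 255]` < red `[255, 0, 0]`), not the order in which colors appear in the image.

```python
import numpy as np


def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
    # label each unique RGB triple with an integer, keeping the spatial shape
    reshaped = img.reshape(-1, 3)
    _, ridx = np.unique(reshaped, axis=0, return_inverse=True)
    return ridx.reshape(img.shape[:-1])


img = np.array(
    [
        [[255, 0, 0], [0, 0, 255]],  # red, blue
        [[255, 0, 0], [0, 0, 0]],  # red, black
    ]
)
print(rgb_to_scalar(img))
# [[2 1]
#  [2 0]]
```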
In [7]:
def plot3D(
    data: np.ndarray,
    title: str,
    threshold_value: int | list[int] = 1,
    cmap: str | list | pv.LookupTable = "tab10",
    with_swap: bool = True,
    with_cbar: bool = False,
    mesh_kwargs: dict | None = None,
) -> pv.Plotter:
    """Plot a 3D image as voxels, hiding the voxels below a threshold.

    Args:
        data (np.ndarray): Data to plot: a 3D array of scalars or a 4D array with RGB values in the last axis
        title (str): Title of the plot
        threshold_value (int | list[int], optional): Voxels with values below this threshold are not shown.
            If a list is given, only its first entry is used. Defaults to 1.
        cmap (str | list | pv.LookupTable, optional): Name of a matplotlib colormap, a list of colors in RGBA or hex
            format, or a pyvista LookupTable. Defaults to "tab10".
        with_swap (bool, optional): Whether axes 0 and 2 should be swapped. Defaults to True.
        with_cbar (bool, optional): Show colorbar. Defaults to False.
        mesh_kwargs (dict | None, optional): Extra keyword arguments for the pv.ImageData constructor. Defaults to None.

    Raises:
        ValueError: When the input data is not 3D or 4D (for RGB values)

    Returns:
        pv.Plotter: pyvista plotter object. Call plotter.show() to display the plot
    """
    if data.ndim not in [3, 4]:
        raise ValueError("Image must be 3D or 4D (for rgb values)")

    # avoid mutable default arguments
    if mesh_kwargs is None:
        mesh_kwargs = {}
    if isinstance(threshold_value, int):
        threshold_value = [threshold_value]

    _data = data
    if with_swap:
        _data = np.swapaxes(data, 0, 2).copy()

    if _data.ndim == 4:
        # map each unique RGB color to an integer label
        _data = rgb_to_scalar(_data)

    # store the voxel values as point data, then convert them to cell data so that each voxel is drawn as a cube
    img_data = pv.ImageData(dimensions=_data.shape, **mesh_kwargs)
    img_data.point_data["Data"] = _data.flatten()
    img_data = img_data.points_to_cells(scalars="Data")

    plotter = pv.Plotter(border=False, notebook=True)
    plotter.add_text(title)
    plotter.add_mesh(
        img_data.threshold(threshold_value[0]),
        show_edges=True,
        show_scalar_bar=with_cbar,
        cmap=cmap,
    )
    return plotter
In [8]:
def plot_geometry(img_indexed: np.ndarray, title: str):
    values = np.unique(img_indexed)
    lt = pv.LookupTable(
        values=np.array(make_discrete_colormap(cmap="tab10", values=values)) * 255,
        scalar_range=(0, len(values)),
        n_values=len(values),
    )
    plotter = plot3D(
        img_indexed,
        title,
        threshold_value=[1, 0],
        cmap=lt,
    )
    plotter.show()

Model

  • the model has 3 compartments, with a species in each compartment
    • initially the species in the outside compartment has concentration 1, all others are zero
  • the model has two reaction rate parameters: k1 and k2
    • k1 controls the rate of flow from the outside compartment to the membrane
  • each species has a diffusion constant
    • e.g. the species A_membrane in the membrane has a diffusion constant of 1.0
In [9]:
import sme

model = sme.Model("3d-model-parameter-fitting.xml")
print(model.parameters["k1"])
<sme.Parameter>
  - name: 'k1'
  - expression: '1'

In [10]:
print(model.compartments["membrane"].species["A_membrane"])
<sme.Species>
  - name: 'A_membrane'
  - diffusion_constant: 1

Generate segmented input data

  • construct N 40x40x40 3D images, each containing a single cell with a nucleus, with random position, size and deformation
In [11]:
N = 3
geometry_images = [generate_geometry_image(n_pixels=40) for _ in range(N)]
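The images above are drawn from NumPy's global random state, so they (and all the costs and fitted parameters below) change on every run. A minimal sketch of a reproducible variant, assuming a hypothetical `generate_geometry_image_seeded` that takes a `numpy.random.Generator` and otherwise follows the same recipe. The integer bounds are rounded explicitly, and `math.ceil` keeps the cell radius strictly larger than the nucleus radius, a small deviation from the original function.

```python
import math

import numpy as np


def sphere_mask(grid_shape, center, radius, deformation):
    # ellipsoid mask, as defined earlier in the notebook
    Z, Y, X = grid_shape
    z0, y0, x0 = center
    dz, dy, dx = deformation
    z, y, x = np.ogrid[:Z, :Y, :X]
    return dx * (x - x0) ** 2 + dy * (y - y0) ** 2 + dz * (z - z0) ** 2 <= radius**2


def generate_geometry_image_seeded(n_pixels: int, rng: np.random.Generator) -> np.ndarray:
    # hypothetical seeded variant of generate_geometry_image (requires n_pixels >= 12)
    max_radius = n_pixels / 3
    max_deform = 1.2
    voxels = np.zeros((n_pixels, n_pixels, n_pixels), dtype=np.uint16)
    center = rng.integers(2, n_pixels - 2, 3)
    nuclear_radius = int(rng.integers(1, int(max_radius / 2)))
    cell_radius = int(rng.integers(math.ceil(1.5 * nuclear_radius), int(max_radius)))
    deformation = rng.uniform(1 / max_deform, max_deform, 3)
    voxels[sphere_mask(voxels.shape, center, cell_radius, deformation)] = 2  # cell
    voxels[sphere_mask(voxels.shape, center, nuclear_radius, deformation)] = 1  # nucleus
    return voxels


N = 3
rng = np.random.default_rng(seed=42)
geometry_images = [generate_geometry_image_seeded(40, rng) for _ in range(N)]

# a new generator with the same seed reproduces exactly the same images
rng2 = np.random.default_rng(seed=42)
assert all(np.array_equal(a, generate_geometry_image_seeded(40, rng2)) for a in geometry_images)
```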

Apply geometry images to models

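  • The reactions (outside <-> membrane and membrane <-> cell) are the same in every model; only the geometry differs
  • For each geometry image we open a fresh copy of the model, write the image to a TIFF file and import it as the new model geometry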
In [12]:
from sme_contrib.optimize import minimize
In [13]:
models = []
for img in geometry_images:
    m = sme.Model("3d-model-parameter-fitting.xml")
    tifffile.imwrite("geom3d.tiff", img)
    m.import_geometry_from_image("geom3d.tiff")
    models.append(m)
In [14]:
for i, m in enumerate(models):
    plot_geometry(m.compartment_image, f"Geometry of model {i}")
(Output: a 3D voxel plot of the geometry of each of the three models.)

Cost function

  • Takes a list of parameter values, in this case [k1, diffusion-constant]
  • Sets the parameters to the same values in all the models
  • Simulates each model for time t
  • Calculates the sum of squared differences between the target concentration and the simulated concentration of A_cell, over all voxels of the cell compartment and all models
In [15]:
def diff(conc: np.ndarray, target: np.ndarray, mask: np.ndarray):
    # sum of squared differences between target and conc, over the voxels where mask is True
    return np.sum(np.power((target - conc)[mask], 2))
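Only the voxels where the mask is `True` contribute to `diff`. A tiny 1D example (the function is repeated so the snippet runs on its own; the same works for 3D arrays as long as `conc`, `target` and `mask` have the same shape):

```python
import numpy as np


def diff(conc: np.ndarray, target: np.ndarray, mask: np.ndarray):
    # sum of squared differences, counting only the voxels where mask is True
    return np.sum(np.power((target - conc)[mask], 2))


conc = np.array([0.5, 1.0, 0.0])
target = np.ones(3)
mask = np.array([True, False, True])
print(diff(conc, target, mask))  # (1 - 0.5)**2 + (1 - 0.0)**2 = 1.25
```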
In [16]:
# here the target image is just 1 (high concentration for this model) for every voxel:
# (in the cost function we only apply this to voxels inside the cell compartment, so here we just set 1 everywhere).
target = np.ones_like(geometry_images[0])


def cost_function(params: list[float], verbose=False):
    t = 10
    cost = 0.0

    # apply parameters
    for model in models:
        model.parameters["k1"].value = f"{params[0]}"
        model.compartments["membrane"].species["A_membrane"].diffusion_constant = (
            params[1]
        )

    # do simulation and calculate cost for each model: which is difference to target for A_cell species concentration in the cell compartment
    for i, model in enumerate(models):
        results = model.simulate(t, t)
        result_cost = diff(
            results[-1].species_concentration["A_cell"],
            target,
            model.compartments["cell"].geometry_mask,
        )
        if verbose:
            logging.info(
                f"model {i}, k1={params[0]}, diffusion-constant={params[1]} -> cost {result_cost}"
            )
        cost += result_cost

    if verbose:
        logging.info(
            f"All models, k1={params[0]}, diffusion-constant={params[1]} -> cost {cost}"
        )
    return cost
In [17]:
# this simulates all models with `k1=1.2`, `diffusion-constant=0.7` and returns the total cost function:
cost_function([1.2, 0.7], verbose=True)
2025-04-15 13:42:12,383 - root - INFO - model 0, k1=1.2, diffusion-constant=0.7 -> cost 530.1865510237294
2025-04-15 13:42:13,252 - root - INFO - model 1, k1=1.2, diffusion-constant=0.7 -> cost 538.7947379220191
2025-04-15 13:42:14,157 - root - INFO - model 2, k1=1.2, diffusion-constant=0.7 -> cost 4.98128496152307
2025-04-15 13:42:14,158 - root - INFO - All models, k1=1.2, diffusion-constant=0.7 -> cost 1073.9625739072715
Out[17]:
np.float64(1073.9625739072715)

Parameter optimization

  • Sets lower and upper bounds for each parameter
  • Uses particle swarm optimisation to minimise the cost function

Since the target concentration is high everywhere in the cell, the optimal parameters will just be the largest values of k1 and of the A_membrane diffusion constant allowed by the bounds that we set.
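For intuition about what a particle swarm optimiser does, here is a minimal global-best particle swarm sketch in plain NumPy. It is not the sme_contrib or pyswarms implementation: it only reuses the coefficients reported in the pyswarms log of the optimisation run (`w=0.9`, `c1=0.5`, `c2=0.3`), and it handles the bounds by simply clipping positions, which may differ from what pyswarms does. It is applied to a toy quadratic cost (an assumption for illustration) with its minimum at (1.0, 0.5), inside the same bounds of 0.1 to 2.0 per parameter.

```python
import numpy as np


def global_best_pso(cost, lower, upper, particles=20, iterations=200, w=0.9, c1=0.5, c2=0.3, seed=0):
    """Minimal global-best particle swarm optimiser (illustrative sketch, not pyswarms).

    Each particle is pulled towards its own best position (weight c1) and towards the
    best position found by the whole swarm (weight c2); w is the inertia of its velocity.
    """
    rng = np.random.default_rng(seed)
    lower, upper = np.asarray(lower, float), np.asarray(upper, float)
    x = rng.uniform(lower, upper, size=(particles, lower.size))  # positions
    v = np.zeros_like(x)  # velocities
    p_best = x.copy()  # best position seen by each particle
    p_cost = np.array([cost(xi) for xi in x])
    g_best = p_best[np.argmin(p_cost)].copy()  # best position seen by the swarm
    for _ in range(iterations):
        r1, r2 = rng.random(x.shape), rng.random(x.shape)
        v = w * v + c1 * r1 * (p_best - x) + c2 * r2 * (g_best - x)
        x = np.clip(x + v, lower, upper)  # keep particles inside the bounds
        costs = np.array([cost(xi) for xi in x])
        improved = costs < p_cost
        p_best[improved], p_cost[improved] = x[improved], costs[improved]
        g_best = p_best[np.argmin(p_cost)].copy()
    return p_cost.min(), g_best


# toy cost with its minimum at (1.0, 0.5)
best_cost, best_params = global_best_pso(
    lambda p: (p[0] - 1.0) ** 2 + (p[1] - 0.5) ** 2, [0.1, 0.1], [2.0, 2.0]
)
print(best_cost, best_params)
```

In this sketch the cost function is evaluated once per particle per iteration (plus once per particle at the start); with the notebook's cost function every evaluation runs N simulations, which is why so few particles and iterations are used below.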

In [18]:
lower_bounds = [0.1, 0.1]
upper_bounds = [2.0, 2.0]
In [19]:
# particle swarm parameter optimization with 4 particles, 10 iterations, using 4 cpu cores:
# (more particles, e.g. 20, and more iterations, e.g. 100, would be needed to get decent results)
best_cost, best_params, opt = minimize(
    cost_function, lower_bounds, upper_bounds, particles=4, iterations=10, processes=4
)
2025-04-15 13:42:14,172 - pyswarms.single.global_best - INFO - Optimize for 10 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|10/10, best_cost=1.07e+3
2025-04-15 13:42:59,123 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 1070.634623077984, best pos: [1.78428871 1.80187756]
In [20]:
from pyswarms.utils.plotters import plot_cost_history

plot_cost_history(cost_history=opt.cost_history)
plt.show()
(Output: plot of the best cost at each iteration of the optimisation.)
In [21]:
best_params
Out[21]:
array([1.78428871, 1.80187756])