Multiple model parameter fitting¶
Starting point¶
- $N$ segmented geometry images $G_i$, each with a target concentration image $T_i$ at some time $t_i$
Navigation¶
- Press Space to show the next page
- Press Shift+Space to show the previous page
- Press Escape to zoom out
- a model using one of these geometry images with reaction parameters $k^{(j)}$ to be fitted
Strategy¶
- make N identical copies $M_i$ of the initial sme model and set the parameters to the same input values for all models
- update the geometry image for model $M_i$ to $G_i$
- simulate each model $M_i$ for time $t_i$
- calculate the sum of squared differences between the simulated concentrations of model $M_i$ and the target concentration $T_i$
- sum this difference over all models $M_i$ to return a total cost function value for this set of parameters
Then feed this cost function into an optimisation algorithm to fit the parameters $k^{(j)}$
Utility functions¶
These are just used for making the geometry images and plotting.
In [1]:
import numpy as np
from itertools import cycle
import logging
import matplotlib.pyplot as plt
from matplotlib import animation
import matplotlib.colors as mcolors
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import Image, display, HTML, Video
import imageio.v3 as iio
import tifffile
import skimage
from scipy import ndimage as ndi
import pyvista as pv
# Notebook-wide plotting defaults, set through pyvista's global theme and matplotlib's rcParams.
pv.global_theme.axes.show = True
pv.global_theme.interactive = True
# default size of matplotlib figures
plt.rcParams["figure.figsize"] = (4, 4)
Select the 'trame' Jupyter backend below to run the notebook locally and interact with the plots. See the pyvista documentation for other backends.
In [2]:
# Choose the pyvista backend used inside the notebook; "static" is the one selected here.
pv.set_jupyter_backend("static")
# pv.set_jupyter_backend("trame") # for interactive plots
In [3]:
def sphere_mask(grid_shape, center, radius, deformation):
    """Return a boolean mask of the voxels inside a deformed sphere.

    Args:
        grid_shape: (Z, Y, X) extent of the voxel grid.
        center: (z0, y0, x0) voxel coordinates of the sphere center.
        radius: radius of the sphere, in voxels.
        deformation: (dz, dy, dx) weights that multiply the squared offset
            along each axis before it is compared with ``radius**2``.

    Returns:
        Boolean array of shape ``grid_shape`` that is True wherever the
        weighted squared distance to ``center`` is at most ``radius**2``.
    """
    depth, height, width = grid_shape
    cz, cy, cx = center
    wz, wy, wx = deformation
    # open (broadcastable) index grids, one per axis
    zz, yy, xx = np.ogrid[:depth, :height, :width]
    weighted_sq_dist = wx * (xx - cx) ** 2 + wy * (yy - cy) ** 2 + wz * (zz - cz) ** 2
    return weighted_sq_dist <= radius**2
In [4]:
def generate_geometry_image(n_pixels):
    """Generate a random segmented 3D image of one cell containing a nucleus.

    The image is a cube of ``n_pixels`` voxels per side. Voxels are labelled
    0 outside the cell, 2 inside the cell and 1 inside the nucleus. Cell and
    nucleus share the same randomly chosen center and axis deformation; only
    their radii differ.

    Args:
        n_pixels: edge length of the cubic image in voxels. Must be at least
            12, otherwise the random radius ranges below are empty.

    Returns:
        np.ndarray of dtype uint16 and shape (n_pixels, n_pixels, n_pixels).

    Raises:
        ValueError: if ``n_pixels`` is smaller than 12.
    """
    min_pixels = 12
    if n_pixels < min_pixels:
        raise ValueError(
            f"n_pixels must be at least {min_pixels} to fit a cell and a nucleus, got {n_pixels}"
        )
    max_radius = n_pixels / 3
    max_deform = 1.2
    voxels = np.zeros((n_pixels, n_pixels, n_pixels), dtype=np.uint16)
    center = np.random.randint(2, n_pixels - 2, 3)
    # randint expects integer bounds: truncate explicitly instead of relying
    # on numpy's implicit float -> int conversion
    nuclear_radius = np.random.randint(1, int(max_radius / 2))
    # the cell must be strictly larger than the nucleus: int(1.5 * 1) == 1
    # would otherwise allow a cell that is completely overwritten by the nucleus
    min_cell_radius = max(int(1.5 * nuclear_radius), nuclear_radius + 1)
    cell_radius = np.random.randint(min_cell_radius, int(max_radius))
    deformation = np.random.uniform(1 / max_deform, max_deform, 3)
    # paint the cell first, then the nucleus on top of it
    voxels[sphere_mask(voxels.shape, center, cell_radius, deformation)] = 2
    voxels[sphere_mask(voxels.shape, center, nuclear_radius, deformation)] = 1
    return voxels
In [5]:
def make_discrete_colormap(
    cmap: str = "tab10", values: np.ndarray | list | tuple = ()
) -> list[tuple[float, float, float, float]]:
    """Create a discrete colormap with exactly one color per entry of `values`.

    The first color is always opaque black; the remaining ones are taken from
    the matplotlib colormap, repeating from its start when it has fewer colors
    than needed.

    Args:
        cmap (str, optional): matplotlib colormap name. The colormap must
            expose a ``colors`` attribute. Defaults to "tab10".
        values (np.ndarray | list | tuple, optional): values to be mapped to
            colors; only the length is used. Defaults to ().

    Returns:
        list[tuple]: ``len(values)`` colors in RGBA format (empty if `values`
        is empty).
    """
    n_colors = len(values)
    palette = plt.get_cmap(cmap).colors
    colors = [(0, 0, 0, 1)]
    # zip() stops as soon as the range is exhausted, which bounds the endless cycle
    colors.extend(
        mcolors.to_rgba(c) for _, c in zip(range(n_colors - 1), cycle(palette))
    )
    # trim the leading black entry away again when `values` is empty
    return colors[:n_colors]
In [6]:
def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
    """Replace every distinct RGB triple of an image by an integer label.

    The notebook uses this because pyvista does not take RGB values directly
    as mesh data.

    Args:
        img (np.ndarray): color data of shape (..., 3); the last axis holds
            the RGB components.

    Returns:
        np.ndarray: integer labels of shape ``img.shape[:-1]``. Labels index
        the sorted unique RGB triples found in `img`.
    """
    spatial_shape = img.shape[:-1]
    pixels = img.reshape(-1, 3)
    # `inverse` gives, for every pixel, the row of `palette` it equals
    palette, inverse = np.unique(pixels, axis=0, return_inverse=True)
    labels = np.arange(len(palette))
    return labels[inverse].reshape(spatial_shape)
In [7]:
def plot3D(
    data: np.ndarray,
    title: str | list[str],
    threshold_value: int | list[int] = (1, 0),
    cmap: str | list[str] = "tab10",
    with_swap: bool = True,
    with_cbar: bool = False,
    mesh_kwargs: dict | None = None,
) -> pv.Plotter:
    """Plot a thresholded 3D image with pyvista.

    Args:
        data (np.ndarray): Data to plot, either 3D scalar data or 4D data whose
            last axis holds RGB values.
        title (str | list[str]): Title text, passed unchanged to
            ``plotter.add_text``.
        threshold_value (int | list[int], optional): Threshold applied to the
            data before plotting; presumably values below it are not shown
            (see pyvista's ``threshold`` filter). Only the first entry of a
            sequence is used. Defaults to (1, 0).
        cmap (str | list[str], optional): Colormap handed to
            ``plotter.add_mesh``: name of a matplotlib colormap or a list of
            colors in RGBA or hex format. Defaults to "tab10".
        with_swap (bool, optional): Whether axes 0 and 2 should be swapped.
            Defaults to True.
        with_cbar (bool, optional): Show colorbar. Defaults to False.
        mesh_kwargs (dict | None, optional): Extra keyword arguments forwarded
            to the ``pv.ImageData`` constructor. Defaults to None (no extras).

    Raises:
        ValueError: When the input data is not 3D or 4D (for RGB values)

    Returns:
        pv.Plotter: pyvista plotter object. Call plotter.show() to display the plot
    """
    if data.ndim not in (3, 4):
        raise ValueError("Image must be 3D or 4D (for rgb values)")
    plotter = pv.Plotter(border=False, notebook=True)
    _data = np.swapaxes(data, 0, 2).copy() if with_swap else data
    if _data.ndim == 4:
        # collapse the RGB axis into one scalar label per voxel
        _data = rgb_to_scalar(_data)
    if isinstance(threshold_value, int):
        threshold_value = [threshold_value, threshold_value]
    img_data = pv.ImageData(dimensions=_data.shape, **(mesh_kwargs or {}))
    img_data.point_data["Data"] = _data.flatten()
    img_data = img_data.points_to_cells(scalars="Data")
    plotter.subplot(0, 0)
    plotter.add_text(title)
    plotter.add_mesh(
        img_data.threshold(threshold_value[0]),
        show_edges=True,
        show_scalar_bar=with_cbar,
        cmap=cmap,
    )
    return plotter
In [8]:
def plot_geometry(img_indexed: np.ndarray, title: str):
    """Show a segmented (indexed) 3D image with one color per distinct label.

    Args:
        img_indexed (np.ndarray): image whose voxel values are labels.
        title (str): title text passed on to ``plot3D``.
    """
    labels = np.unique(img_indexed)
    n_labels = len(labels)
    # the colormap helper returns RGBA values, which are scaled by 255 here
    rgba = np.array(make_discrete_colormap(cmap="tab10", values=labels)) * 255
    lookup = pv.LookupTable(
        values=rgba,
        scalar_range=(0, n_labels),
        n_values=n_labels,
    )
    plot3D(img_indexed, title, threshold_value=[1, 0], cmap=lookup).show()
Model¶
- the model has 3 compartments, with a species in each compartment
- initially the species in the outside compartment has concentration 1, all others are zero
- the model has two reaction rate parameters: `k1` and `k2`
  - `k1` controls the rate at which stuff flows from the outside to the membrane
- each species has a diffusion constant
  - e.g. the species `A_membrane` in the membrane has a diffusion constant of 1.0
In [9]:
import sme
# Load the model from its XML file and print the parameter named "k1".
model = sme.Model("3d-model-parameter-fitting.xml")
print(model.parameters["k1"])
<sme.Parameter> - name: 'k1' - expression: '1'
In [10]:
# Print the species "A_membrane" of the "membrane" compartment.
print(model.compartments["membrane"].species["A_membrane"])
<sme.Species> - name: 'A_membrane' - diffusion_constant: 1
Generate segmented input data¶
- construct N 40x40x40 3d images, each with a single randomly distributed, sized and deformed cell / nucleus
In [11]:
# Number of geometry images (and therefore models) to generate.
N = 3
# One random 40x40x40 segmented image per model.
geometry_images = [generate_geometry_image(n_pixels=40) for _ in range(N)]
Apply geometry images to models¶
- Reactions: `outside <-> membrane` and `membrane <-> cell`
- Here we open this model, then import the new geometry image generated above
In [12]:
from sme_contrib.optimize import minimize
In [13]:
# Build one model per geometry image. The geometry is imported from an image
# file, so each array is first written to "geom3d.tiff" (overwritten on every
# iteration) and then loaded into a freshly opened copy of the model.
models = []
for img in geometry_images:
    tifffile.imwrite("geom3d.tiff", img)
    m = sme.Model("3d-model-parameter-fitting.xml")
    m.import_geometry_from_image("geom3d.tiff")
    models.append(m)
In [14]:
# Show the compartment image of every model; the titles are numbered so the
# plots can be told apart.
for i, model in enumerate(models):
    plot_geometry(model.compartment_image, f"Geometry Model {i}")
Cost function¶
- Takes a list of parameter values, in this case `[k1, diffusion-constant]`
- Sets the parameters to the same values in all the models
- Simulates each model for time t
- Calculates the sum of squares of differences between the target concentration and the simulated concentration over all voxels and models
In [15]:
def diff(conc: np.ndarray, target: np.ndarray, mask: np.ndarray):
    """Return the sum of squared differences between target and conc.

    Only the entries selected by the boolean `mask` contribute.
    """
    residual = (target - conc)[mask]
    return np.sum(np.power(residual, 2))
In [16]:
# The target image is simply 1 (a high concentration for this model) at every voxel.
# The cost function only compares voxels inside the cell compartment, so a uniform array is enough.
# NOTE: np.ones_like keeps the dtype of the geometry image, which generate_geometry_image creates as uint16.
target = np.ones_like(geometry_images[0])
def cost_function(params: list[float], verbose=False):
    """Return the total cost of one parameter set, summed over all models.

    Args:
        params: ``[k1, diffusion_constant]``. ``params[0]`` becomes the value
            of parameter "k1" and ``params[1]`` the diffusion constant of
            species "A_membrane" in the "membrane" compartment, in every model
            of the module-level ``models`` list.
        verbose: when True, log the cost of each model and the total via
            ``logging.info``.

    Returns:
        Sum over all models of ``diff`` between the final "A_cell"
        concentration and the module-level ``target``, restricted to the
        geometry mask of the "cell" compartment.
    """
    sim_time = 10
    # apply the candidate parameters to every model before any simulation runs
    for model in models:
        model.parameters["k1"].value = f"{params[0]}"
        membrane_species = model.compartments["membrane"].species["A_membrane"]
        membrane_species.diffusion_constant = params[1]
    cost = 0.0
    for i, model in enumerate(models):
        # presumably simulate(simulation_time, image_interval); only the last result is used
        final_result = model.simulate(sim_time, sim_time)[-1]
        cell_mask = model.compartments["cell"].geometry_mask
        result_cost = diff(
            final_result.species_concentration["A_cell"], target, cell_mask
        )
        if verbose:
            logging.info(
                f"model {i}, k1={params[0]}, diffusion-constant={params[1]} -> cost {result_cost}"
            )
        cost += result_cost
    if verbose:
        logging.info(
            f"All models, k1={params[0]}, diffusion-constant={params[1]} -> cost {cost}"
        )
    return cost
In [17]:
# Example evaluation: simulate all models with `k1=1.2`, `diffusion-constant=0.7`
# and return the total cost; verbose=True also logs the cost of each model.
cost_function([1.2, 0.7], verbose=True)
2025-04-15 13:42:12,383 - root - INFO - model 0, k1=1.2, diffusion-constant=0.7 -> cost 530.1865510237294
2025-04-15 13:42:13,252 - root - INFO - model 1, k1=1.2, diffusion-constant=0.7 -> cost 538.7947379220191
2025-04-15 13:42:14,157 - root - INFO - model 2, k1=1.2, diffusion-constant=0.7 -> cost 4.98128496152307
2025-04-15 13:42:14,158 - root - INFO - All models, k1=1.2, diffusion-constant=0.7 -> cost 1073.9625739072715
Out[17]:
np.float64(1073.9625739072715)
Parameter optimization¶
- Sets lower and upper bounds for each parameter
- Use particle swarm to try to minimise the cost function
Since the target image has a high concentration in the cell, the optimal parameters will just be `k1` and the diffusion constant as large as possible (within the bounds that we set)
In [18]:
# Bounds for the two fitted parameters, in the order cost_function reads them: [k1, diffusion constant].
lower_bounds = [0.1, 0.1]
upper_bounds = [2.0, 2.0]
In [19]:
# particle swarm parameter optimization with 4 particles, 10 iterations, using 4 processes:
# (more particles e.g. 20 & iterations e.g. 100, would be needed to get decent results)
best_cost, best_params, opt = minimize(
    cost_function, lower_bounds, upper_bounds, particles=4, iterations=10, processes=4
)
2025-04-15 13:42:14,172 - pyswarms.single.global_best - INFO - Optimize for 10 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 0%| |0/10
pyswarms.single.global_best: 0%| |0/10, best_cost=1.07e+3
pyswarms.single.global_best: 10%|█ |1/10, best_cost=1.07e+3
pyswarms.single.global_best: 10%|█ |1/10, best_cost=1.07e+3
pyswarms.single.global_best: 20%|██ |2/10, best_cost=1.07e+3
pyswarms.single.global_best: 20%|██ |2/10, best_cost=1.07e+3
pyswarms.single.global_best: 30%|███ |3/10, best_cost=1.07e+3
pyswarms.single.global_best: 30%|███ |3/10, best_cost=1.07e+3
pyswarms.single.global_best: 40%|████ |4/10, best_cost=1.07e+3
pyswarms.single.global_best: 40%|████ |4/10, best_cost=1.07e+3
pyswarms.single.global_best: 50%|█████ |5/10, best_cost=1.07e+3
pyswarms.single.global_best: 50%|█████ |5/10, best_cost=1.07e+3
pyswarms.single.global_best: 60%|██████ |6/10, best_cost=1.07e+3
pyswarms.single.global_best: 60%|██████ |6/10, best_cost=1.07e+3
pyswarms.single.global_best: 70%|███████ |7/10, best_cost=1.07e+3
pyswarms.single.global_best: 70%|███████ |7/10, best_cost=1.07e+3
pyswarms.single.global_best: 80%|████████ |8/10, best_cost=1.07e+3
pyswarms.single.global_best: 80%|████████ |8/10, best_cost=1.07e+3
pyswarms.single.global_best: 90%|█████████ |9/10, best_cost=1.07e+3
pyswarms.single.global_best: 90%|█████████ |9/10, best_cost=1.07e+3
pyswarms.single.global_best: 100%|██████████|10/10, best_cost=1.07e+3
pyswarms.single.global_best: 100%|██████████|10/10, best_cost=1.07e+3
2025-04-15 13:42:59,123 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 1070.634623077984, best pos: [1.78428871 1.80187756]
In [20]:
from pyswarms.utils.plotters import plot_cost_history
# Plot the cost history recorded by the optimizer object returned from minimize.
plot_cost_history(cost_history=opt.cost_history)
plt.show()
In [21]:
# Display the best parameter values found, in the order [k1, diffusion constant].
best_params
Out[21]:
array([1.78428871, 1.80187756])
In [ ]: