3D toy model¶
- generate a segmented 3d image of cells
- modify image to create explicit membrane compartments
- combine cells, combine membranes
- create a simple sme model using this geometry
- do an example simulation
Navigation¶
- Press Space to show the next page
- Press Shift+Space to show the previous page
- Press Escape to zoom out
Utility functions¶
In [1]:
import numpy as np
from itertools import cycle
import matplotlib.pyplot as plt
from matplotlib import animation
import matplotlib.colors as mcolors
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import Image, display, HTML, Video
import imageio.v3 as iio
import tifffile
import skimage
from scipy import ndimage as ndi
import pyvista as pv
# Global PyVista theme: draw axes in every plot and set the interactive flag
_pv_theme = pv.global_theme
_pv_theme.axes.show = True
_pv_theme.interactive = True
# Default matplotlib figure size (in inches)
plt.rcParams.update({"figure.figsize": (8, 8)})
Select the 'trame' Jupyter backend below to run the notebook locally with interactive plots. See the PyVista documentation for other backends.
In [2]:
# Render PyVista plots as static images; switch to the "trame" backend (commented
# line below) for interactive plots when running the notebook locally.
pv.set_jupyter_backend("static")
# pv.set_jupyter_backend("trame") # for interactive plots
In [3]:
def sphere_mask(grid_shape, center, radius, deformation):
    """Boolean mask of a deformed sphere on a 3D voxel grid.

    Args:
        grid_shape: (Z, Y, X) size of the grid.
        center: (z0, y0, x0) centre of the sphere.
        radius: sphere radius.
        deformation: (dz, dy, dx) weights applied to the squared distance along
            each axis.

    Returns:
        Boolean array of shape ``grid_shape`` that is True where
        ``dx*(x-x0)**2 + dy*(y-y0)**2 + dz*(z-z0)**2 <= radius**2``.
    """
    zz, yy, xx = np.indices(grid_shape, sparse=True)
    cz, cy, cx = center
    wz, wy, wx = deformation
    # keep the x, y, z summation order so boundary voxels round identically
    dist2 = wx * (xx - cx) ** 2 + wy * (yy - cy) ** 2 + wz * (zz - cz) ** 2
    return dist2 <= radius**2
In [4]:
def spheres(n_pixels, n_spheres, max_radius, max_deform, seed=None):
    """Generate a segmented 3D image of randomly placed, sized and deformed spheres.

    Sphere ``k`` (for ``k`` in ``1..n_spheres``) is painted with label ``k``; spheres
    drawn later overwrite earlier ones where they overlap, so some labels may be
    partly or completely hidden. Background voxels are 0.

    Args:
        n_pixels: edge length of the cubic image. Sphere centres are drawn from
            ``[2, n_pixels - 2)`` along each axis, so it must be >= 5 when
            ``n_spheres > 0``.
        n_spheres: number of spheres to draw.
        max_radius: radii are integers drawn from ``[1, int(max_radius / 2))``, so it
            must be >= 4 when ``n_spheres > 0``.
        max_deform: each of the three deformation weights passed to ``sphere_mask``
            is drawn uniformly from ``[1 / max_deform, max_deform)``.
        seed: optional integer seed for a private ``np.random.RandomState``. When
            None (default), the global ``np.random`` state is used.

    Returns:
        np.ndarray of shape ``(n_pixels, n_pixels, n_pixels)``. The dtype is uint16,
        or uint32 when ``n_spheres`` does not fit in uint16 (avoids wrapping labels).

    Raises:
        ValueError: if ``n_pixels`` or ``max_radius`` is too small to draw a sphere.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    radius_high = int(max_radius / 2)  # exclusive integer upper bound for radii
    if n_spheres > 0 and n_pixels < 5:
        raise ValueError("n_pixels must be at least 5 to place sphere centres")
    if n_spheres > 0 and radius_high < 2:
        raise ValueError("max_radius must be at least 4")
    label_dtype = np.uint16 if n_spheres <= np.iinfo(np.uint16).max else np.uint32
    voxels = np.zeros((n_pixels, n_pixels, n_pixels), dtype=label_dtype)
    for label in range(1, n_spheres + 1):
        center = rng.randint(2, n_pixels - 2, 3)
        radius = rng.randint(1, radius_high)
        deformation = rng.uniform(1 / max_deform, max_deform, 3)
        voxels[sphere_mask(voxels.shape, center, radius, deformation)] = label
    return voxels
3D plotting with PyVista (VTK)¶
In [5]:
def make_discrete_colormap(cmap: str = "tab10", values: np.array = []) -> list[int]:
"""Create a discrete colormap of potentially repeating colors of the same size as the `values` array.
Args:
cmap (str, optional): matplotlib colormap name. Defaults to "tab10".
values (np.array, optional): values to be mapped to colors. Defaults to [].
Returns:
list[int]: list of color in rgba format.
"""
cm = [(0, 0, 0, 1)]
i = 0
for c in cycle(plt.get_cmap(cmap).colors):
cm.append(mcolors.to_rgba(c))
if len(cm) >= len(values):
break
i += 1
return cm
In [6]:
def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
    """Convert an array of colour vectors to integer labels, one label per unique colour.

    This is needed because pyvista does not support RGB values directly as mesh data.
    Each distinct colour (a vector along the last axis) is replaced by its index in
    the lexicographically sorted list of unique colours, so equal colours get equal
    labels.

    Args:
        img (np.ndarray): data of shape (..., C), e.g. (n, m, 3) or (z, y, x, 3).
            Any channel count C is supported (RGB, RGBA, ...).

    Returns:
        np.ndarray: integer labels of shape ``img.shape[:-1]``.
    """
    # One row per pixel; reshape copies only if img is not contiguous.
    pixels = img.reshape(-1, img.shape[-1])
    _, inverse = np.unique(pixels, axis=0, return_inverse=True)
    # inverse[i] is the index of pixel i's colour among the sorted unique colours
    return inverse.reshape(img.shape[:-1])
In [7]:
def plot3D(
    data: np.ndarray,
    title: str | list[str],
    threshold_value: int | float | list[int] | tuple[int, ...] = (1, 0),
    cmap: str | list[str] = "tab10",
    with_swap: bool = True,
    with_aux: bool = True,
    with_cbar: bool = False,
    mesh_kwargs: dict | None = None,
) -> pv.Plotter:
    """Plot a 3D image, optionally next to a second, differently thresholded view of the same mesh.

    The image becomes a ``pv.ImageData`` grid whose point values are converted to
    cell values, so every voxel is drawn as one cell.

    Args:
        data (np.ndarray): 3D array of scalars, or a 4D array whose last axis holds
            colour channels (e.g. RGB). Colours are mapped to integer labels with
            ``rgb_to_scalar``.
        title (str | list[str]): Title of the plot, or one title per subplot. Must be
            a list of two titles when ``with_aux`` is True.
        threshold_value (int | float | list[int], optional): Threshold of each plot;
            cells with values below it are not shown. A scalar is only accepted
            when ``with_aux`` is False. Defaults to (1, 0).
        cmap (str | list[str], optional): Name of a matplotlib colormap, a list of
            colors in RGBA or hex format, or a ``pv.LookupTable``. Defaults to "tab10".
        with_swap (bool, optional): Swap axes 0 and 2 before plotting, i.e. treat
            ``data`` as indexed (z, y, x) so that its last axis is drawn along x.
            Defaults to True.
        with_aux (bool, optional): Add the second (auxiliary) subplot. Defaults to True.
        with_cbar (bool, optional): Show a colorbar. Defaults to False.
        mesh_kwargs (dict, optional): Extra keyword arguments passed to the
            ``pv.ImageData`` constructor (e.g. ``spacing`` or ``origin``).
            Defaults to None (no extra arguments).

    Raises:
        ValueError: When the input data is not 3D or 4D (for RGB values)
        ValueError: When the title is not a list of two strings when with_aux is True
        ValueError: When the threshold_value is not a list of two values when with_aux is True

    Returns:
        pv.Plotter: pyvista plotter object. Call ``plotter.show()`` to display the plot.
    """
    if data.ndim not in (3, 4):
        raise ValueError("Image must be 3D or 4D (for rgb values)")
    scalar_threshold = np.isscalar(threshold_value)
    # Test the type before len(): len() of a scalar threshold would raise TypeError.
    if with_aux and (isinstance(title, str) or len(title) != 2):
        raise ValueError("Two title must be provided for the two subplots")
    if with_aux and (scalar_threshold or len(threshold_value) != 2):
        raise ValueError("Two threshold values must be provided for the two subplots")

    titles = [title, title] if isinstance(title, str) else list(title)
    thresholds = [threshold_value] * 2 if scalar_threshold else list(threshold_value)

    volume = np.swapaxes(data, 0, 2) if with_swap else data
    if volume.ndim == 4:
        volume = rgb_to_scalar(volume)

    img_data = pv.ImageData(dimensions=volume.shape, **(mesh_kwargs or {}))
    # VTK orders points with axis 0 (x) varying fastest, so flatten in Fortran
    # order; a C-order flatten scrambles non-cubic volumes. flatten() also copies,
    # so the grid never shares memory with the caller's array.
    img_data.point_data["Data"] = volume.flatten(order="F")
    img_data = img_data.points_to_cells(scalars="Data")

    n_views = 2 if with_aux else 1
    plotter = pv.Plotter(shape=(1, n_views), border=False, notebook=True)
    for col in range(n_views):
        plotter.subplot(0, col)
        plotter.add_text(titles[col])
        plotter.add_mesh(
            img_data.threshold(thresholds[col]),
            show_edges=True,
            show_scalar_bar=with_cbar,
            cmap=cmap,
        )
    plotter.link_views()
    return plotter
In [8]:
def animate(
    data: list[np.ndarray] | np.ndarray,
    title: list[str],
    with_swap: bool = True,
    cmap: str | list[str] = "viridis",
    with_cbar: bool = False,
    threshold_value: int = 0,
    filename="tmp.mp4",
):
    """Animate a sequence of 3D images as an mp4 movie that is then displayed in the notebook.

    Each frame is written exactly once, in order. For 4D (colour) frames the colours
    are converted to integer labels with ``rgb_to_scalar`` independently per frame,
    so the same label can stand for different colours in different frames.

    Args:
        data (list[np.ndarray] | np.ndarray): sequence of 3D or 4D (colour channel
            last) arrays, one per frame.
        title (list[str]): Title of each frame, e.g. its time point; needs at least
            ``len(data)`` entries.
        with_swap (bool, optional): Swap axes 0 and 2 of each frame before plotting.
            Defaults to True.
        cmap (str | list[str], optional): Name of a matplotlib colormap or a list of
            colors in RGBA or hex format. Defaults to "viridis".
        with_cbar (bool, optional): Show a colorbar in the movie. Defaults to False.
        threshold_value (int, optional): Cells with values below this are not shown.
            Defaults to 0.
        filename (str, optional): Path of the movie file to write. Defaults to "tmp.mp4".

    Raises:
        ValueError: if ``data`` is empty or fewer titles than frames are given.

    Returns:
        IPython.display.Video: the written movie, embedded in the notebook.
    """
    if len(data) == 0:
        raise ValueError("data must contain at least one frame")
    if len(title) < len(data):
        raise ValueError("one title per frame must be provided")

    def to_cells(frame):
        # Convert one frame to a cell-data grid; also return the plotted values.
        volume = np.asarray(frame)
        if with_swap:
            volume = np.swapaxes(volume, 0, 2)
        if volume.ndim == 4:
            volume = rgb_to_scalar(volume)
        grid = pv.ImageData(dimensions=volume.shape)
        # VTK orders points with axis 0 (x) varying fastest -> Fortran order.
        grid.point_data["Data"] = volume.flatten(order="F")
        return grid.points_to_cells(scalars="Data"), volume

    plotter = pv.Plotter()
    plotter.open_movie(filename)
    try:
        actor = None
        for frame, label in zip(data, title):
            grid, volume = to_cells(frame)
            mesh = grid.threshold(threshold_value)
            if actor is None:
                actor = plotter.add_mesh(
                    mesh,
                    show_edges=True,
                    show_scalar_bar=with_cbar,
                    cmap=cmap,
                )
            else:
                actor.mapper.dataset = mesh
            # Colour range follows the values of the current frame.
            actor.mapper.scalar_range = (np.min(volume), np.max(volume))
            plotter.add_text(label, name="time-label")
            plotter.write_frame()
    finally:
        # Closing the plotter also closes the movie writer, even on error.
        plotter.close()
    return Video(filename, embed=True)
Generate segmented input data¶
- construct a 40x40x40 3d image with 300 randomly distributed, sized and deformed spheres
- each voxel has an index which identifies which sphere (if any) it belongs to
In [9]:
# Labelled 40x40x40 image: sphere k gets label k (1..300, later spheres overwrite
# earlier ones), background is 0. NOTE: not seeded, so it differs on every run.
img_indexed = spheres(n_pixels=40, n_spheres=300, max_radius=14, max_deform=1.5)
- make a colormap that stays the same over all plots
In [10]:
values = np.unique(img_indexed)
n_labels = len(values)
# One RGBA colour per distinct label (the first one is black), scaled from [0, 1] to [0, 255]
label_colors = np.asarray(make_discrete_colormap(cmap="tab10", values=values))
lt = pv.LookupTable(
    values=label_colors * 255,
    scalar_range=(0, n_labels),
    n_values=n_labels,
)
In [11]:
# Left: labelled voxels only (threshold 1 hides background label 0); right: all voxels.
segmented_titles = ["Segmented image (voxel)", "Segmented image (surface)"]
plotter = plot3D(img_indexed, segmented_titles, threshold_value=[1, 0], cmap=lt)
plotter.show()
Generate explicit membranes by dilating each cell¶
- We want to add explicit membrane compartments around each cell.
- To do this we take a mask of each cell individually, dilate it, and select the pixels that differ from the original mask
- Repeating this over all cells and combining the results gives us a mask of membrane compartment pixels
In [12]:
# Membrane = voxels just outside each cell: dilate every cell's mask by one voxel
# (6-connectivity) and keep only the voxels the dilation added.
structure = ndi.generate_binary_structure(rank=3, connectivity=1)
img_membrane_mask = np.zeros(img_indexed.shape, dtype=bool)
labels = np.unique(img_indexed)
# Iterate over the cell labels actually present: skip the background (0) and
# include the highest label (range(img_indexed.max()) did the opposite).
for label in labels[labels != 0]:
    cell = img_indexed == label
    img_membrane_mask |= ndi.binary_dilation(cell, structure=structure) & ~cell
In [13]:
membrane_titles = ["Membrane voxels (voxel)", "Membrane voxels (surface)"]
membrane_plotter = plot3D(
    img_membrane_mask, membrane_titles, threshold_value=[1, 0], cmap=lt
)
membrane_plotter.show()
Define cells as any segmented pixel excluding membrane pixels¶
- Now we select all pixels that were identified as cells
- Then we exclude pixels that are part of the membrane mask to leave a cell mask
In [14]:
# Cell voxels: labelled (non-zero) voxels that are not part of the membrane
img_cell_mask = (img_indexed != 0) & ~img_membrane_mask
In [15]:
cell_titles = ["Cell voxels (voxel)", "Cell voxels (surface)"]
cell_plotter = plot3D(img_cell_mask, cell_titles, threshold_value=[1, 0], cmap=lt)
cell_plotter.show()
Construct segmented geometry image for sme¶
- From these masks we can construct a segmented geometry image for sme
- Each colour in this image can then be assigned to a compartment in the model
- We export the result as a 3d tiff to be imported into sme
In [16]:
# Geometry image for sme: 0 = outside, 1 = cell, 2 = membrane (membrane wins on overlap)
img = np.select([img_membrane_mask, img_cell_mask], [2, 1], default=0).astype(np.uint8)
tifffile.imwrite("geom3d.tiff", img)
In [17]:
geometry_image_titles = ["Geometry image (voxel)", "Geometry image (surface)"]
geometry_plotter = plot3D(
    img, geometry_image_titles, threshold_value=[1, 0], cmap=lt
)
geometry_plotter.show()
Create sme model¶
- This was done using the GUI, starting from the 2d toy model & importing a 3d geometry image
- As in the 2d case: one species in each compartment, initially non-zero only in the outside compartment
- Reactions:
  - outside <-> membrane
  - membrane <-> cell
- Here we open this model, then import the new geometry image generated above
In [18]:
import sme
In [19]:
# Load the model that was created in the sme GUI
model = sme.Model("3d-toy-model.xml")
In [20]:
# Use the geometry image written to geom3d.tiff as the model geometry
model.import_geometry_from_image("geom3d.tiff")
In [21]:
# Single view of the geometry as stored in the model (threshold 0 keeps every voxel)
geometry_titles = ["Model geometry (voxel)", "Model geometry (surface)"]
plotter = plot3D(
    model.compartment_image,
    geometry_titles,
    cmap=lt,
    threshold_value=[0, 0],
    with_aux=False,
)
plotter.show()
Simulate model¶
- simulate for 60s, storing the results every 2s
- this might take a few minutes
In [22]:
# Presumably: simulate for 60 time units, storing a result every 2 time units
# (the indices 15 and 30 used below suggest results at t = 0, 2, ..., 60) — TODO confirm.
simulation_results = model.simulate(60, 2)
Simulation results¶
In [23]:
initial_result = simulation_results[0]
plot3D(
    data=initial_result.concentration_image,
    title=f"Concentrations at t={initial_result.time_point}",
    threshold_value=0,
    cmap="viridis",
    with_aux=False,
    with_cbar=True,
).show()
In [24]:
# Concentrations at result index 15. The image is plotted unscaled, as for t=0:
# assuming concentration_image is an RGB image, plot3D maps each distinct colour
# to an integer label, so the former division by 255 had no effect.
plot3D(
    simulation_results[15].concentration_image,
    f"Concentrations at t={simulation_results[15].time_point}",
    cmap="viridis",
    with_aux=False,
    threshold_value=0,
).show()
In [25]:
# Concentrations at result index 30. The image is plotted unscaled, as for t=0:
# assuming concentration_image is an RGB image, plot3D maps each distinct colour
# to an integer label, so the former division by 255 had no effect.
plot3D(
    simulation_results[30].concentration_image,
    f"Concentrations at t={simulation_results[30].time_point}",
    cmap="viridis",
    with_aux=False,
    threshold_value=0,
).show()
In [26]:
# Animate all stored results. Uses the default with_swap=True, the same axis
# convention as the static concentration plots, so the movie is not rendered
# with x and z exchanged relative to them.
animate(
    [r.concentration_image for r in simulation_results],
    [f"Concentrations at t={r.time_point}" for r in simulation_results],
    cmap="viridis",
)
Out[26]: