3D toy model¶

  • generate a segmented 3d image of cells
  • modify image to create explicit membrane compartments
  • combine cells, combine membranes
  • create a simple sme model using this geometry
  • do an example simulation

Utility functions¶

In [1]:
import numpy as np
from itertools import cycle

import matplotlib.pyplot as plt
from matplotlib import animation
import matplotlib.colors as mcolors
from mpl_toolkits.mplot3d import Axes3D

from IPython.display import Image, display, HTML, Video

import imageio.v3 as iio
import tifffile
import skimage
from scipy import ndimage as ndi

import pyvista as pv

pv.global_theme.axes.show = True
pv.global_theme.interactive = True
plt.rcParams["figure.figsize"] = (8, 8)

Select the 'trame' Jupyter backend below to run the notebook locally and interact with the plots. See the pyvista documentation for other backends.

In [2]:
pv.set_jupyter_backend("static")

# pv.set_jupyter_backend("trame") # for interactive plots
In [3]:
def sphere_mask(grid_shape, center, radius, deformation):
    # generate a boolean mask for a sphere with given center, radius and deformation
    Z, Y, X = grid_shape
    z0, y0, x0 = center
    dz, dy, dx = deformation
    z, y, x = np.ogrid[:Z, :Y, :X]
    return dx * (x - x0) ** 2 + dy * (y - y0) ** 2 + dz * (z - z0) ** 2 <= radius**2
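
As a quick sanity check of the mask construction, the sketch below repeats the same inequality on a small grid. `ellipsoid_mask` is a stand-in name used only for this illustration and assumes the same (z, y, x) axis ordering as `sphere_mask`. Because each deformation factor multiplies a squared distance, a factor larger than 1 shrinks the shape along that axis.

```python
import numpy as np


def ellipsoid_mask(grid_shape, center, radius, deformation):
    # same inequality as sphere_mask; np.ogrid gives broadcastable index grids
    Z, Y, X = grid_shape
    z0, y0, x0 = center
    dz, dy, dx = deformation
    z, y, x = np.ogrid[:Z, :Y, :X]
    return dx * (x - x0) ** 2 + dy * (y - y0) ** 2 + dz * (z - z0) ** 2 <= radius**2


# undeformed sphere of radius 3 in an 11x11x11 grid
round_mask = ellipsoid_mask((11, 11, 11), (5, 5, 5), 3, (1.0, 1.0, 1.0))
# same sphere with the z term weighted by 4, i.e. squashed along z
squashed_mask = ellipsoid_mask((11, 11, 11), (5, 5, 5), 3, (4.0, 1.0, 1.0))

# number of voxels along the z axis through the center
z_extent_round = int(round_mask[:, 5, 5].sum())
z_extent_squashed = int(squashed_mask[:, 5, 5].sum())
# number of voxels along the y axis through the center (not deformed)
y_extent_squashed = int(squashed_mask[5, :, 5].sum())
print(z_extent_round, z_extent_squashed, y_extent_squashed)
```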
In [4]:
def spheres(n_pixels, n_spheres, max_radius, max_deform):
    # generate a segmented image containing n_spheres randomly distributed, sized and deformed
    voxels = np.zeros((n_pixels, n_pixels, n_pixels), dtype=np.uint16)
    for n_sphere in range(1, n_spheres + 1):
        center = np.random.randint(2, n_pixels - 2, 3)
        # integer upper bound for randint: radius is drawn from [1, max_radius // 2)
        radius = np.random.randint(1, max_radius // 2)
        deformation = np.random.uniform(1 / max_deform, max_deform, 3)
        # later spheres overwrite earlier ones where they overlap
        voxels[sphere_mask(voxels.shape, center, radius, deformation)] = n_sphere
    return voxels

3D plotting with pyvista(vtk)¶

In [5]:
def make_discrete_colormap(
    cmap: str = "tab10", values: np.ndarray | None = None
) -> list[tuple[float, float, float, float]]:
    """Create a discrete colormap of potentially repeating colors with one entry per element of `values`.

    The first entry is black; the remaining entries cycle through the colors of `cmap`.

    Args:
        cmap (str, optional): matplotlib colormap name. Defaults to "tab10".
        values (np.ndarray, optional): values to be mapped to colors. Defaults to None.

    Returns:
        list[tuple]: list of colors in RGBA format.
    """
    n_values = 0 if values is None else len(values)
    cm = [(0, 0, 0, 1)]
    for c in cycle(plt.get_cmap(cmap).colors):
        cm.append(mcolors.to_rgba(c))
        if len(cm) >= n_values:
            break
    return cm
In [6]:
def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
    """Convert an array of RGB values to scalar values.

    This function is necessary because pyvista does not support RGB values directly as mesh data.
    Each distinct RGB triple is mapped to an integer index.

    Args:
        img (np.ndarray): data to be converted, of shape (..., 3)

    Returns:
        np.ndarray: data converted to scalar values, of shape img.shape[:-1]
    """
    reshaped = img.reshape(-1, 3, copy=True)
    unique_rgb, ridx = np.unique(reshaped, axis=0, return_inverse=True)

    values = np.arange(len(unique_rgb))
    return values[ridx].reshape(img.shape[:-1])
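
The mapping from RGB triples to scalar indices can be illustrated on a tiny image. This is a minimal sketch using only NumPy; the 2x2 image and the variable names are made up for the example. `np.unique` with `axis=0` sorts the distinct rows, so the index assigned to each colour follows that sorted order.

```python
import numpy as np

# 2x2 RGB image: red, green on the first row; red, blue on the second
img = np.array(
    [
        [[255, 0, 0], [0, 255, 0]],
        [[255, 0, 0], [0, 0, 255]],
    ],
    dtype=np.uint8,
)

# flatten to a list of RGB rows and find the distinct colours
flat = img.reshape(-1, 3)
unique_rgb, inverse = np.unique(flat, axis=0, return_inverse=True)

# one integer per pixel, with the spatial shape of the input restored
labels = inverse.reshape(img.shape[:-1])
print(unique_rgb.tolist())
print(labels.tolist())
```

Indexing `unique_rgb` with `labels` recovers the original image, so no colour information is lost by the conversion.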
In [7]:
def plot3D(
    data: np.ndarray,
    title: str | list[str],
    threshold_value: int | list[int] = [1, 0],
    cmap: str | list[str] = "tab10",
    with_swap: bool = True,
    with_aux: bool = True,
    with_cbar: bool = False,
    mesh_kwargs: dict = {},
) -> pv.Plotter:
    """Plot a 3D image with an optional auxiliary image that can show a differently thresholded version of the same mesh.

    Args:
        data (np.ndarray): Data to plot
        title (str | list[str]): Title for each plot
        threshold_value (int | list[int], optional): Threshold value for each plot. Values below the threshold will not be shown. Defaults to [1, 0].
        cmap (str | list[str], optional): Name of a matplotlib colormap or a list of colors in RGBA or hex format. Defaults to "tab10".
        with_swap (bool, optional): Whether axes 0 and 2 should be swapped. Defaults to True.
        with_aux (bool, optional): Enable second plot. Defaults to True.
        with_cbar (bool, optional): Show colorbar. Defaults to False.
        mesh_kwargs (dict, optional): Additional keyword arguments passed to the pv.ImageData constructor. Defaults to {}.

    Raises:
        ValueError: When the input data is not 3D or 4D (for RGB values)
        ValueError: When the title is not a list of two strings when with_aux is True
        ValueError: When the threshold_value is not a list of two integers when with_aux is True

    Returns:
        pv.Plotter: pyvista plotter object. Call plotter.show() to display the plot
    """
    if data.ndim not in [3, 4]:
        raise ValueError("Image must be 3D or 4D (for rgb values)")

    _data = data

    plotter = pv.Plotter(shape=(1, 2 if with_aux else 1), border=False, notebook=True)

    # check the type first so that len() is never called on an int
    if with_aux and (isinstance(title, str) or len(title) != 2):
        raise ValueError("Two titles must be provided for the two subplots")

    if with_aux and (isinstance(threshold_value, int) or len(threshold_value) != 2):
        raise ValueError("Two threshold values must be provided for the two subplots")

    if with_swap:
        _data = np.swapaxes(data, 0, 2).copy()

    if len(_data.shape) == 4:
        _data = rgb_to_scalar(_data)

    if isinstance(threshold_value, int):
        threshold_value = [threshold_value, threshold_value]

    img_data = pv.ImageData(dimensions=_data.shape, **mesh_kwargs)
    img_data.point_data["Data"] = _data.flatten()
    img_data = img_data.points_to_cells(scalars="Data")
    plotter.subplot(0, 0)
    plotter.add_text(title[0] if isinstance(title, list) else title)
    plotter.add_mesh(
        img_data.threshold(threshold_value[0]),
        show_edges=True,
        show_scalar_bar=with_cbar,
        cmap=cmap,
    )

    if with_aux:
        plotter.subplot(0, 1)
        plotter.add_text(title[1] if isinstance(title, list) else title)
        plotter.add_mesh(
            img_data.threshold(threshold_value[1]),
            show_edges=True,
            show_scalar_bar=with_cbar,
            cmap=cmap,
        )
        plotter.link_views()
    return plotter
In [8]:
def animate(
    data: list[np.ndarray] | np.ndarray,
    title: list[str],
    with_swap: bool = True,
    cmap: str | list[str] = "viridis",
    with_cbar: bool = False,
    threshold_value: int = 0,
    filename: str = "tmp.mp4",
) -> Video:
    """Animate a list of 3D images. This is done by creating an mp4 video file that is then displayed in the notebook.

    Args:
        data (list[np.ndarray] | np.ndarray): list of 3D or 4D data that is to be animated
        title (list[str]): Title of each frame, e.g., a timestep
        with_swap (bool, optional): Whether axes 0 and 2 should be swapped. Defaults to True.
        cmap (str | list[str], optional): Colormap. Name of a matplotlib colormap or a list of colors in RGBA or hex format. Defaults to "viridis".
        with_cbar (bool, optional): Show colorbar in the movie. Defaults to False.
        threshold_value (int, optional): Threshold below which data will not be shown. Defaults to 0.
        filename (str, optional): Path where the movie should be stored. Defaults to "tmp.mp4".

    Returns:
        Video: IPython Video object with the movie embedded
    """
    plotter = pv.Plotter()

    plotter.open_movie(filename)

    def process_data(d):
        if with_swap:
            d = np.swapaxes(d, 0, 2).copy()
        if len(d.shape) == 4:
            d = rgb_to_scalar(d)
        return d

    _data = process_data(data[0])
    img_data = pv.ImageData(dimensions=_data.shape)
    img_data.point_data["Data"] = _data.flatten()
    img_data = img_data.points_to_cells(scalars="Data")
    plotter.add_text(title[0], name="time-label")
    actor = plotter.add_mesh(
        img_data.threshold(threshold_value),
        show_edges=True,
        show_scalar_bar=with_cbar,
        cmap=cmap,
    )

    plotter.write_frame()
    for i in range(len(data)):
        _data = process_data(data[i])
        img_data = pv.ImageData(dimensions=_data.shape)
        img_data.point_data["Data"] = _data.flatten()
        img_data = img_data.points_to_cells(scalars="Data")

        plotter.add_text(title[i], name="time-label")
        actor.mapper.dataset = img_data.threshold(threshold_value)
        plotter.mapper.scalar_range = (np.min(_data), np.max(_data))
        plotter.write_frame()

    plotter.close()

    return Video(filename, embed=True)

Generate segmented input data¶

  • construct a 40x40x40 3d image with 300 randomly distributed, sized and deformed spheres
  • each voxel has an index which identifies which sphere (if any) it belongs to
In [9]:
img_indexed = spheres(n_pixels=40, n_spheres=300, max_radius=14, max_deform=1.5)
  • make a colormap that stays the same over all plots
In [10]:
values = np.unique(img_indexed)
lt = pv.LookupTable(
    values=np.array(make_discrete_colormap(cmap="tab10", values=values)) * 255,
    scalar_range=(0, len(values)),
    n_values=len(values),
)
In [11]:
plotter = plot3D(
    img_indexed,
    ["Segmented image (voxel)", "Segmented image (surface)"],
    threshold_value=[1, 0],
    cmap=lt,
)
plotter.show()

Generate explicit membranes by dilating each cell¶

  • We want to add explicit membrane compartments around each cell.
  • To do this we take a mask of each cell individually, dilate it, and select the pixels that differ from the original mask
  • Repeating this over all cells and combining the results gives us a mask of membrane compartment pixels
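
The dilate-and-compare step can be checked on a single toy cell. This is a minimal sketch, assuming the default structuring element of `scipy.ndimage.binary_dilation` (face neighbours only); the 7x7x7 label array is made up for the example. Since the dilated mask always contains the original mask, the voxels that differ are exactly the one-voxel shell around the cell.

```python
import numpy as np
from scipy import ndimage as ndi

# toy label image with a single 3x3x3 cube as "cell" number 1
labels = np.zeros((7, 7, 7), dtype=np.uint16)
labels[2:5, 2:5, 2:5] = 1

cell = labels == 1
# default structuring element: each voxel grows into its 6 face neighbours
dilated = ndi.binary_dilation(cell)
# voxels added by the dilation form the membrane shell
shell = dilated & ~cell

n_cell = int(cell.sum())
n_shell = int(shell.sum())
print(n_cell, n_shell)
```

With face-neighbour dilation the cube gains one 3x3 layer on each of its six faces, so the shell does not include the edges or corners of the enlarged cube.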
In [12]:
img_membrane_mask = np.zeros(img_indexed.shape, dtype=bool)
# 6-connected structuring element (the same as the binary_dilation default)
kernel = ndi.generate_binary_structure(rank=3, connectivity=1)
# loop over the cell labels 1..max; label 0 is the background
for index in range(1, int(img_indexed.max()) + 1):
    img = img_indexed == index
    img_membrane_mask |= ndi.binary_dilation(img, structure=kernel) != img
In [13]:
plot3D(
    img_membrane_mask,
    ["Membrane voxels (voxel)", "Membrane voxels (surface)"],
    threshold_value=[1, 0],
    cmap=lt,
).show()

Define cells as any segmented pixel excluding membrane pixels¶

  • Now we select all pixels that were identified as cells
  • Then we exclude pixels that are part of the membrane mask to leave a cell mask
In [14]:
# segmented voxels that are not part of the membrane mask
img_cell_mask = (img_indexed != 0) & ~img_membrane_mask
In [15]:
plot3D(
    img_cell_mask,
    ["Cell voxels (voxel)", "Cell voxels (surface)"],
    threshold_value=[1, 0],
    cmap=lt,
).show()

Construct segmented geometry image for sme¶

  • From these masks we can construct a segmented geometry image for sme
  • Each colour in this image can then be assigned to a compartment in the model
  • We export the result as a 3d tiff to be imported into sme
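
The way two boolean masks combine into one label image can be sketched on a small 2D example. The masks below are made up for illustration, and the labels 1 (cell) and 2 (membrane) are just one possible choice. Assignments are applied in order, so a pixel present in both masks ends up with the label assigned last. Counting the pixels per label is a simple check that every intended compartment colour is present before export.

```python
import numpy as np

# toy masks; the pixel at [0, 1] is set in both
cell_mask = np.array([[False, True, True], [False, True, False]])
membrane_mask = np.array([[True, True, False], [False, False, False]])

geom = np.zeros(cell_mask.shape, dtype=np.uint8)  # 0: outside
geom[cell_mask] = 1  # 1: cell
geom[membrane_mask] = 2  # 2: membrane, overrides cell where both masks are set

# distinct labels and number of pixels carrying each of them
labels, counts = np.unique(geom, return_counts=True)
print(geom.tolist())
print(dict(zip(labels.tolist(), counts.tolist())))
```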
In [16]:
img = np.zeros(img_cell_mask.shape, dtype=np.uint8)
img[img_cell_mask] = 1
img[img_membrane_mask] = 2
tifffile.imwrite("geom3d.tiff", img)
In [17]:
plot3D(
    img,
    ["Geometry image (voxel)", "Geometry image (surface)"],
    threshold_value=[1, 0],
    cmap=lt,
).show()
In [18]:
def geometry_image(n_pixels):
    # generate a segmented image containing one randomly placed, sized and deformed cell (label 2) enclosing a nucleus (label 1)
    max_radius = n_pixels / 3
    max_deform = 1.2
    voxels = np.zeros((n_pixels, n_pixels, n_pixels), dtype=np.uint16)
    center = np.random.randint(2, n_pixels - 2, 3)
    nuclear_radius = np.random.randint(1, max_radius / 2)
    cell_radius = np.random.randint(1.5 * nuclear_radius, max_radius)
    deformation = np.random.uniform(1 / max_deform, max_deform, 3)
    voxels[sphere_mask(voxels.shape, center, cell_radius, deformation)] = 2
    voxels[sphere_mask(voxels.shape, center, nuclear_radius, deformation)] = 1
    return voxels
In [19]:
img = geometry_image(30)
plot3D(
    img, ["Geometry image (voxel)", "Geometry image (surface)"], threshold_value=[1, 0]
).show()

Create sme model¶

  • The model was created in the GUI, starting from the 2d toy model and importing a 3d geometry image
  • As in the 2d case: one species in each compartment, initially non-zero only in the outside compartment
  • Reactions: outside <-> membrane and membrane <-> cell
  • Here we open this model, then import the new geometry image generated above
In [20]:
import sme
In [21]:
model = sme.Model("3d-toy-model.xml")
In [22]:
model.import_geometry_from_image("geom3d.tiff")
In [23]:
plotter = plot3D(
    model.compartment_image,
    ["Model geometry (voxel)", "Model geometry (surface)"],
    threshold_value=[0, 0],
    cmap=lt,
    with_aux=False,
)

plotter.show()

Simulate model¶

  • simulate for 60s, storing the results every 2s
  • this might take a few minutes
In [24]:
simulation_results = model.simulate(60, 2)
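
Assuming the two positional arguments of `model.simulate(60, 2)` are the total simulation time and the interval between stored results, and that the first stored result is at t=0, the list indices of the results can be related to time points as sketched below. The variable names are illustrative only and nothing here calls sme.

```python
import numpy as np

simulation_time = 60  # assumed meaning of the first argument
image_interval = 2  # assumed meaning of the second argument

# expected time points of the stored results, including t=0 and t=60
time_points = np.arange(0, simulation_time + image_interval, image_interval)
n_results = len(time_points)

# list index of the result stored at t=30 under these assumptions
index_of_t30 = int(np.searchsorted(time_points, 30))
print(n_results, index_of_t30, int(time_points[-1]))
```

Under these assumptions, indices 0, 15 and 30 correspond to the start, the midpoint and the end of the simulation.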

Simulation results¶

In [25]:
plot3D(
    data=simulation_results[0].concentration_image,
    title=f"Concentrations at t={simulation_results[0].time_point}",
    cmap="viridis",
    with_aux=False,
    threshold_value=0,
    with_cbar=True,
).show()
In [26]:
plot3D(
    simulation_results[15].concentration_image / 255.0,
    f"Concentrations at t={simulation_results[15].time_point}",
    cmap="viridis",
    with_aux=False,
    threshold_value=0,
).show()
In [27]:
plot3D(
    simulation_results[30].concentration_image / 255.0,
    f"Concentrations at t={simulation_results[30].time_point}",
    cmap="viridis",
    with_aux=False,
    threshold_value=0,
).show()
In [28]:
animate(
    [r.concentration_image for r in simulation_results],
    [f"Concentrations at t={r.time_point}" for r in simulation_results],
    with_swap=False,
    cmap="viridis",
)