Source code for paprica.viewer

"""
Module containing classes and functions relative to Viewing.


By using this code you agree to the terms of the software license agreement.

© Copyright 2020 Wyss Center for Bio and Neuro Engineering – All rights reserved
"""

import os
from glob import glob

import matplotlib.pyplot as plt
import napari
import numpy as np
import pandas as pd
import pyapr
from matplotlib.colors import LogNorm
from napari.layers import Image, Labels, Points
from skimage.color import hsv2rgb
from skimage.exposure import rescale_intensity
from skimage.filters import gaussian
from skimage.io import imread
from skimage.transform import resize

import paprica


[docs]def display_apr_from_path(path, **kwargs): """ Display an APR using Napari from a filepath. Parameters ---------- path: string path to APR to be displayed kwargs: dict optional parameters for Napari Returns ------- None """ apr = pyapr.APR() parts = pyapr.ShortParticles() pyapr.io.read(path, apr, parts) layer = apr_to_napari_Image(apr, parts) display_layers_pyramidal([layer], level_delta=0, **kwargs)
[docs]def display_apr(apr, parts, **kwargs): """ Display an APR using Napari from previously loaded data. Parameters ---------- apr : pyapr.APR Input APR data structure parts : pyapr.FloatParticles, pyapr.ShortParticles Input particle intensities kwargs: dict optional parameters for Napari Returns ------- None """ l = apr_to_napari_Image(apr, parts, **kwargs) display_layers_pyramidal([l], level_delta=0)
[docs]def apr_to_napari_Image(apr: pyapr.APR, parts: (pyapr.ShortParticles, pyapr.FloatParticles), mode: str = 'constant', level_delta: int = 0, **kwargs): """ Construct a napari 'Image' layer from an APR. Pixel values are reconstructed on the fly via the APRSlicer class. Parameters ---------- apr : pyapr.APR Input APR data structure parts : pyapr.FloatParticles or pyapr.ShortParticles Input particle intensities mode: str (default: 'constant') Interpolation mode to reconstruct pixel values. Supported values are constant: piecewise constant interpolation smooth: smooth interpolation (via level-adaptive separable smoothing). Note: significantly slower than constant. level: interpolate the particle levels to the pixels level_delta: int Sets the resolution of the reconstruction. The size of the image domain is multiplied by a factor of 2**level_delta. Thus, a value of 0 corresponds to the original pixel image resolution, -1 halves the resolution and +1 doubles it. (default: 0) Returns ------- out : napari.layers.Image An Image layer of the APR that can be viewed in napari. """ if 'contrast_limits' in kwargs: contrast_limits = kwargs.get('contrast_limits') del kwargs['contrast_limits'] else: cmin = apr.level_min() if mode == 'level' else parts.min() cmax = apr.level_max() if mode == 'level' else parts.max() contrast_limits = [cmin, cmax] if 'tree_mode' in kwargs: tree_mode = kwargs.get('tree_mode') del kwargs['tree_mode'] else: tree_mode = 'mean' par = apr.get_parameters() return Image(data=pyapr.reconstruction.APRSlicer(apr, parts, mode=mode, level_delta=level_delta, tree_mode=tree_mode), rgb=False, multiscale=False, contrast_limits=contrast_limits, scale=[par.dz, par.dx, par.dy], **kwargs)
[docs]def apr_to_napari_Labels(apr: pyapr.APR, parts: pyapr.ShortParticles, mode: str = 'constant', level_delta: int = 0, **kwargs): """ Construct a napari 'Layers' layer from an APR. Pixel values are reconstructed on the fly via the APRSlicer class. Parameters ---------- apr : pyapr.APR Input APR data structure parts : pyapr.FloatParticles or pyapr.ShortParticles Input particle intensities mode: str (default: 'constant') Interpolation mode to reconstruct pixel values. Supported values are constant: piecewise constant interpolation smooth: smooth interpolation (via level-adaptive separable smoothing). Note: significantly slower than constant. level: interpolate the particle levels to the pixels level_delta: int Sets the resolution of the reconstruction. The size of the image domain is multiplied by a factor of 2**level_delta. Thus, a value of 0 corresponds to the original pixel image resolution, -1 halves the resolution and +1 doubles it. (default: 0) Returns ------- out : napari.layers.Image A Labels layer of the APR that can be viewed in napari. """ if 'contrast_limits' in kwargs: del kwargs['contrast_limits'] par = apr.get_parameters() return Labels(data=pyapr.reconstruction.APRSlicer(apr, parts, mode=mode, level_delta=level_delta, tree_mode='max'), multiscale=False, scale=[par.dz, par.dx, par.dy], **kwargs)
# Define a callback that will take the value of the slider and the viewer
[docs]def resolution_callback(viewer, value): for l in viewer.layers: if isinstance(l.data, pyapr.reconstruction.APRSlicer): old_value = -l.data.patch.level_delta l.data.set_level_delta(-value) l.translate = l.translate/2**(value-old_value) viewer.dims.set_point(axis=0, value=viewer.dims.point[0] / 2 ** (value-old_value)) viewer.status = str(value) viewer._update_layers() viewer.reset_view()
[docs]def display_layers(layers): """ Display a list of layers using Napari. Parameters ---------- layers: list[napari.Layer] list of layers to display Returns ------- viewer: napari.Viewer napari viewer. """ viewer = napari.Viewer() for layer in layers: viewer.add_layer(layer) napari.run() return viewer
[docs]def display_layers_pyramidal(layers, level_delta): """ Display a list of layers using Napari. Parameters ---------- layers: list[napari.Layer] list of layers to display Returns ------- viewer: napari.Viewer napari viewer. """ viewer = napari.Viewer() for layer in layers: viewer.add_layer(layer) from qtpy.QtCore import Qt from qtpy.QtWidgets import QSlider my_slider = QSlider(Qt.Horizontal) my_slider.setMinimum(0) layers_apr = [l for l in layers if isinstance(l.data, pyapr.reconstruction.APRSlicer)] l_max = np.min([l.data.apr.level_max() for l in layers_apr]) l_min = 5 if l_max > 5 else 1 my_slider.setMaximum(l_max-l_min) my_slider.setSingleStep(1) my_slider.setValue(-level_delta) # Connect your slider to your callback function my_slider.valueChanged[int].connect( lambda value=my_slider: resolution_callback(viewer, value) ) viewer.window.add_dock_widget(my_slider, name='Downsampling', area='left') napari.run() return viewer
[docs]def display_segmentation(apr, parts, mask, pyramidal=True, **kwargs): """ This function displays an image and its associated segmentation map. It uses napari to lazily generate the pixel data from APR on the fly. Parameters ---------- apr: pyapr.APR apr object parts: pyapr.ParticleData particle object representing the image mask: pyapr.ParticleData particle object representing the segmentation mask/connected component Returns ------- None """ layers = [] layers.append(apr_to_napari_Image(apr, parts, name='APR', **kwargs)) layers.append(apr_to_napari_Labels(apr, mask, name='Segmentation', opacity=0.3, **kwargs)) if pyramidal: display_layers_pyramidal(layers, level_delta=0) else: display_layers(layers)
[docs]def display_heatmap(heatmap, atlas=None, data=None, log=False): """ Display a heatmap (e.g. cell density) that can be overlaid on intensity data and atlas. Parameters ---------- heatmap: ndarray array containing the heatmap to be displayed atlas: ndarray array containing the atlas which will be automatically scaled to the heatmap data: ndarray array containing the data. log: bool plot in logscale (only used for 2D). Returns ------- None """ # If u is 2D then use matplotlib so we have a scale bar if heatmap.ndim == 2: fig, ax = plt.subplots() if log: h = ax.imshow(heatmap, norm=LogNorm(), cmap='jet') else: h = ax.imshow(heatmap, cmap='jet') cbar = fig.colorbar(h, ax=ax) cbar.set_label('Number of detected cells') ax.set_xticks([]) ax.set_yticks([]) # If u is 3D then use napari but no colorbar for now elif heatmap.ndim == 3: with napari.gui_qt(): viewer = napari.Viewer() viewer.add_image(heatmap, colormap='inferno', name='Heatmap', blending='additive', opacity=0.7) if atlas is not None: viewer.add_labels(atlas, name='Atlas regions', opacity=0.7) if data is not None: viewer.add_image(data, name='Intensity data', blending='additive', scale=np.array(heatmap.shape)/np.array(data.shape), opacity=0.7)
[docs]def compare_stitching(stitcher1, stitcher2, loc=None, n_proj=0, dim=0, downsample=2, color=False, rel_map=False): """ Compare two stitching at a given position `loc` for a given dimension `dim`. Parameters ---------- stitcher1: tileStitcher stitcher object 1 stitcher2: tileStitcher stitcher object 2 loc: int position in the given dimension dim: int dimension to use for comparison n_proj: int number of plane to perform the max-projection downsample: int downsampling factor for the reconstruction color: bool option to display in color rel_map: bool overlay reliability map on the reconstructed data Returns ------- None """ u1 = stitcher1.reconstruct_slice(loc=loc, n_proj=n_proj, dim=dim, downsample=downsample, color=color, plot=False) u2 = stitcher2.reconstruct_slice(loc=loc, n_proj=n_proj, dim=dim, downsample=downsample, color=color, plot=False) if color: fig, ax = plt.subplots(1, 2, sharex=True, sharey=True) data_to_display = np.ones_like(u1, dtype='uint8') for i in range(2): tmp = np.log(u1[:, :, i] + 200) vmin, vmax = np.percentile(tmp[tmp > np.log(1 + 200)], (1, 99.9)) data_to_display[:, :, i] = rescale_intensity(tmp, in_range=(vmin, vmax), out_range='uint8') ax[0].imshow(data_to_display) data_to_display = np.ones_like(u2, dtype='uint8') for i in range(2): tmp = np.log(u2[:, :, i] + 200) vmin, vmax = np.percentile(tmp[tmp > np.log(1 + 200)], (1, 99.9)) data_to_display[:, :, i] = rescale_intensity(tmp, in_range=(vmin, vmax), out_range='uint8') ax[1].imshow(data_to_display) else: fig, ax = plt.subplots(1, 2, sharex=True, sharey=True) ax[0].imshow(np.log(u1), cmap='gray') if rel_map: try: rel_map = resize(np.mean(stitcher1.plot_stitching_info(), axis=0), u1.shape, order=1) ax[0].imshow(rel_map, cmap='turbo', alpha=0.5) except: pass ax[1].imshow(np.log(u2), cmap='gray') if rel_map: try: rel_map = resize(np.mean(stitcher2.plot_stitching_info(), axis=0), u1.shape, order=1) ax[1].imshow(rel_map, cmap='turbo', alpha=0.5) except: pass
[docs]def reconstruct_colored_projection(apr, parts, loc=None, dim=0, n_proj=0, downsample=1, threshold=None, plot=True): """ Compare two stitching at a given position `loc` for a given dimension `dim`. Parameters ---------- apr: pyapr.APR apr tree object parts: pyapr.ParticleData apr particles loc: int position in the given dimension dim: int dimension to use for comparison n_proj: int number of plane to perform the max-projection downsample: int downsampling factor for the reconstruction color: bool option to display in color rel_map: bool overlay reliability map on the reconstructed data Returns ------- None """ level_delta = int(-np.sign(downsample) * np.log2(np.abs(downsample))) if loc is None: apr_shape = apr.shape() loc = int(apr_shape[dim] / 2) if loc > apr_shape[dim]: raise ValueError('Error: loc is too large ({}), maximum loc at this downsample is {}.'.format(loc, apr_shape[dim])) locf = min(loc+n_proj, apr_shape[dim]) patch = pyapr.ReconPatch() if dim == 0: patch.z_begin = loc patch.z_end = locf if dim == 1: patch.y_begin = loc patch.y_end = locf if dim == 2: patch.x_begin = loc patch.x_end = locf data = pyapr.reconstruction.reconstruct_constant(apr, parts, patch=patch) V = data.max(axis=dim) S = np.ones_like(V) * 0.7 if threshold is not None: S[V<threshold] = 0 H = np.argmax(data, axis=dim) H = rescale_intensity(gaussian(H, sigma=5), out_range=np.float64)*0.66 V = np.log(V + 200) vmin, vmax = np.percentile(V[V > np.log(100)], (1, 99.9)) V = rescale_intensity(V, in_range=(vmin, vmax), out_range=np.float64) S = S * V rgb = hsv2rgb(np.dstack((H,S,V))) rescale_intensity(rgb, out_range='uint8') if plot: plt.figure() plt.imshow(rgb) return rgb
[docs]class tileViewer(): """ Class to display the registration and segmentation using Napari. """
[docs] def __init__(self, tiles, database, segmentation: bool=False, cells=None, atlaser=None): """ Parameters ---------- tiles: tileParser tileParser object containing the dataset to be displayed. database: (pd.Dataframe, string, tileStitcher) database containing the tile positions. segmentation: bool option to also display the segmentation (connected component) data. cells: ndarray cells center to be displayed. atlaser: tileAtlaser tileAtlaser object containing the Atlas to be displayed. """ self.tiles = tiles if isinstance(database, paprica.stitcher.tileStitcher): self.database = database.database elif isinstance(database, pd.DataFrame): self.database = database elif isinstance(database, str): self.database = pd.read_csv(database) else: raise TypeError('Error: unknown type for database.') self.nrow = tiles.nrow self.ncol = tiles.ncol self.loaded_ind = [] self.loaded_tiles = {} self.segmentation = segmentation self.loaded_segmentation = {} self.cells = cells self.atlaser = atlaser
[docs] def get_layers_all_tiles(self, downsample=1, **kwargs): """ Display all parsed tiles. Parameters ---------- downsample: int downsampling parameter for APRSlicer (1: full resolution, 2: 2x downsampling, 4: 4x downsampling..etc) kwargs: dict dictionary passed to Napari for custom option Returns ------- layers: list[napari.Layer] list of layers to be displayed by Napari """ # Compute layers to be displayed by Napari layers = [] # Convert downsample to level delta level_delta = int(-np.sign(downsample)*np.log2(np.abs(downsample))) for tile in self.tiles: # Load tile if not loaded, else use cached tile ind = np.ravel_multi_index((tile.row, tile.col), dims=(self.nrow, self.ncol)) if self._is_tile_loaded(tile.row, tile.col): apr, parts = self.loaded_tiles[ind] if self.segmentation: cc = self.loaded_segmentation[ind] else: tile.load_tile() apr, parts = tile.apr, tile.parts self.loaded_ind.append(ind) self.loaded_tiles[ind] = apr, parts if self.segmentation: tile.load_segmentation() cc = tile.parts_cc self.loaded_segmentation[ind] = cc position = self._get_tile_position(tile.row, tile.col) if level_delta != 0: position = [x/downsample for x in position] layers.append(apr_to_napari_Image(apr, parts, mode='constant', name='Tile [{}, {}]'.format(tile.row, tile.col), translate=position, opacity=0.7, level_delta=level_delta, **kwargs)) if self.segmentation: layers.append(apr_to_napari_Labels(apr, cc, mode='constant', name='Segmentation [{}, {}]'.format(tile.row, tile.col), translate=position, level_delta=level_delta, opacity=0.7)) if self.cells is not None: par = apr.get_parameters() layers.append(Points(self.cells, opacity=0.7, name='Cells center', scale=[par.dz/downsample, par.dx/downsample, par.dy/downsample])) if self.atlaser is not None: layers.append(Labels(self.atlaser.atlas, opacity=0.7, name='Atlas', scale=[self.atlaser.z_downsample/downsample, self.atlaser.y_downsample/downsample, self.atlaser.x_downsample/downsample])) return layers
[docs] def display_all_tiles(self, pyramidal=True, downsample=1, color=False, **kwargs): """ Display all parsed tiles. Parameters ---------- pyramidal: bool option to have a slider that controls the displayed resolution downsample: int downsampling parameter for APRSlicer (1: full resolution, 2: 2x downsampling, 4: 4x downsampling..etc) kwargs: dict dictionary passed to Napari for custom option Returns ------- None """ # Compute layers to be displayed by Napari layers = [] # Convert downsample to level delta level_delta = int(-np.sign(downsample)*np.log2(np.abs(downsample))) for tile in self.tiles: # Load tile if not loaded, else use cached tile ind = np.ravel_multi_index((tile.row, tile.col), dims=(self.nrow, self.ncol)) if self._is_tile_loaded(tile.row, tile.col): apr, parts = self.loaded_tiles[ind] if self.segmentation: cc = self.loaded_segmentation[ind] else: tile.load_tile() apr, parts = tile.apr, tile.parts self.loaded_ind.append(ind) self.loaded_tiles[ind] = apr, parts if self.segmentation: tile.load_segmentation() cc = tile.parts_cc self.loaded_segmentation[ind] = cc position = self._get_tile_position(tile.row, tile.col) if color: blending = 'additive' if tile.col % 2: if tile.row % 2: cmap = 'red' else: cmap = 'green' else: if tile.row % 2: cmap = 'green' else: cmap = 'red' else: cmap = 'gray' blending = 'translucent' if level_delta != 0: position = [x/downsample for x in position] layers.append(apr_to_napari_Image(apr, parts, mode='constant', name='Tile [{}, {}]'.format(tile.row, tile.col), translate=position, opacity=0.7, level_delta=level_delta, **kwargs)) if self.segmentation: layers.append(apr_to_napari_Labels(apr, cc, mode='constant', name='Segmentation [{}, {}]'.format(tile.row, tile.col), translate=position, level_delta=level_delta, blending=blending, opacity=0.7)) if self.cells is not None: par = apr.get_parameters() layers.append(Points(self.cells, opacity=0.7, name='Cells center', scale=[par.dz/downsample, par.dx/downsample, par.dy/downsample])) if self.atlaser is not None: layers.append(Labels(self.atlaser.atlas, opacity=0.7, name='Atlas', scale=[self.atlaser.z_downsample/downsample, self.atlaser.y_downsample/downsample, self.atlaser.x_downsample/downsample])) # Display layers if pyramidal: display_layers_pyramidal(layers, level_delta) else: display_layers(layers)
[docs] def display_tiles(self, coords, pyramidal=True, downsample=1, color=False, **kwargs): """ Display tiles at position coords. Parameters ---------- coords: list list of tuples (row, col) containing the tile coordinate to display. downsample: int downsampling parameter for APRSlicer (1: full resolution, 2: 2x downsampling, 4: 4x downsampling..etc) kwargs: dict dictionary passed to Napari for custom option color: bool option to display in color Returns ------- None """ # Compute layers to be displayed by Napari layers = [] # Convert downsample to level delta level_delta = int(-np.sign(downsample) * np.log2(np.abs(downsample))) for tile in self.tiles: if (tile.row, tile.col) in coords: # Load tile if not loaded, else use cached tile ind = np.ravel_multi_index((tile.row, tile.col), dims=(self.nrow, self.ncol)) if self._is_tile_loaded(tile.row, tile.col): apr, parts = self.loaded_tiles[ind] if self.segmentation: cc = self.loaded_segmentation[ind] else: tile.load_tile() apr, parts = tile.apr, tile.parts self.loaded_ind.append(ind) self.loaded_tiles[ind] = apr, parts if self.segmentation: tile.load_segmentation() cc = tile.parts_cc self.loaded_segmentation[ind] = cc position = self._get_tile_position(tile.row, tile.col) if level_delta != 0: position = [x / downsample for x in position] if color: blending = 'additive' if tile.col % 2: if tile.row % 2: cmap = 'red' else: cmap = 'green' else: if tile.row % 2: cmap = 'green' else: cmap = 'red' else: cmap = 'gray' blending = 'translucent' layers.append(apr_to_napari_Image(apr, parts, mode='constant', name='Tile [{}, {}]'.format(tile.row, tile.col), translate=position, opacity=0.7, level_delta=level_delta, colormap=cmap, blending=blending, **kwargs)) if self.segmentation: layers.append(apr_to_napari_Labels(apr, cc, mode='constant', name='Segmentation [{}, {}]'.format(tile.row, tile.col), translate=position, level_delta=level_delta, opacity=0.7)) if self.cells is not None: par = apr.get_parameters() layers.append(Points(self.cells, opacity=0.7, name='Cells center', scale=[par.dz / downsample, par.dx / downsample, par.dy / downsample])) if self.atlaser is not None: layers.append(Labels(self.atlaser.atlas, opacity=0.7, name='Atlas', scale=[self.atlaser.z_downsample / downsample, self.atlaser.y_downsample / downsample, self.atlaser.x_downsample / downsample])) # Display layers if pyramidal: display_layers_pyramidal(layers, level_delta) else: display_layers(layers)
[docs] def check_stitching(self, downsample=8, color=False, **kwargs): """ Function to display the stitched dataset using napari. Parameters ---------- downsample: int downsampling parameter for APRSlicer (1: full resolution, 2: 2x downsampling, 4: 4x downsampling..etc) color: bool option to display in color kwargs: dict dictionary passed to Napari for custom option Returns ------- None """ # Compute layers to be displayed by Napari layers = [] # Convert downsample to level delta level_delta = int(-np.sign(downsample)*np.log2(np.abs(downsample))) for tile in self.tiles: tile.lazy_load_tile(level_delta=level_delta) position = self._get_tile_position(tile.row, tile.col) if level_delta != 0: position = [x/downsample for x in position] if color: blending = 'additive' if tile.col % 2: if tile.row % 2: cmap = 'red' else: cmap = 'green' else: if tile.row % 2: cmap = 'green' else: cmap = 'red' else: cmap = 'gray' blending = 'translucent' layers.append(Image(tile.lazy_data, name='Tile [{}, {}]'.format(tile.row, tile.col), translate=position, opacity=0.7, colormap=cmap, blending=blending, **kwargs)) display_layers(layers)
[docs] def _is_tile_loaded(self, row, col): """ Returns True is tile is loaded, False otherwise. """ ind = np.ravel_multi_index((row, col), dims=(self.nrow, self.ncol)) return ind in self.loaded_ind
[docs] def _load_tile(self, row, col): """ Load the tile at position [row, col]. """ df = self.database path = df[(df['row'] == row) & (df['col'] == col)]['path'].values[0] if self.tiles.type == 'tiff2D': files = glob(os.path.join(path, '*.tif')) im = imread(files[0]) u = np.zeros((len(files), *im.shape)) u[0] = im files.pop(0) for i, file in enumerate(files): u[i+1] = imread(file) return self._get_apr(u) elif self.tiles.type == 'tiff3D': u = imread(path) return self._get_apr(u) elif self.tiles.type == 'apr': apr = pyapr.APR() parts = pyapr.ShortParticles() pyapr.io.read(path, apr, parts) u = (apr, parts) return u else: raise TypeError('Error: image type {} not supported.'.format(self.type))
[docs] def _get_tile_position(self, row, col): """ Parse tile position in the database. """ df = self.database tile_df = df[(df['row'] == row) & (df['col'] == col)] px = tile_df['ABS_H'].values[0] py = tile_df['ABS_V'].values[0] pz = tile_df['ABS_D'].values[0] return [pz, py, px]