# Source code for paprica.stitcher

"""
Submodule containing classes and functions relative to **stitching**.

With this submodule the user can stitch a previously parsed dataset, typically the autofluorescence channel:

>>> import paprica
>>> tiles_autofluo = paprica.parser.tileParser(path_to_autofluo, frame_size=1024, overlap=25)
>>> stitcher = paprica.stitcher.tileStitcher(tiles_autofluo)
>>> stitcher.compute_registration_fast()

Other channels can then easily be stitched using the previous one as reference:

>>> tiles_signal = paprica.parser.tileParser(path_to_data, frame_size=1024, overlap=25)
>>> stitcher_channel = paprica.stitcher.channelStitcher(stitcher, tiles_autofluo, tiles_signal)
>>> stitcher_channel.compute_rigid_registration()

By doing so, each tile in the second dataset is registered to the corresponding autofluorescence tile and
its spatial position is then adjusted accordingly.

WARNING: when stitching, the expected overlap must be HIGHER than the real one. To enforce this, a margin of 20% is
automatically added (this margin can be set lower by the user for speed improvement). The best stitching quality
requires a good estimate of the overlap, which is why only the overlapping area, and not the full volume, is considered.

This submodule also contains a class for merging and reconstructing the data. It is intended to be used at lower
resolution for atlasing. The generated data can quickly become very large, so use it with caution!

By using this code you agree to the terms of the software license agreement.

© Copyright 2020 Wyss Center for Bio and Neuro Engineering – All rights reserved
"""

import os
import warnings
from pathlib import Path

import cv2 as cv
import dill
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyapr
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
# from skimage.registration import phase_cross_correlation
from scipy.signal import correlate
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree, depth_first_order
from skimage.color import label2rgb, hsv2rgb
from skimage.exposure import equalize_adapthist, rescale_intensity
from skimage.filters import gaussian
from skimage.metrics import normalized_root_mse
from skimage.transform import warp, AffineTransform, downscale_local_mean
from tqdm import tqdm
import napari


def max_sum_over_single_max(reference_image, moving_image, d):
    """
    Reliability metric for a registration, which works well for sparse data.

    The moving image is shifted by ``d`` (with wrap-around) and the metric is twice the 99th percentile of the
    reference image divided by the 99th percentile of the sum of the reference and shifted images. When the
    bright structures of both images coincide, the denominator is close to twice the reference percentile and the
    metric is close to 1; for sparse data that do not coincide, the metric tends to grow towards 2.

    Parameters
    ----------
    reference_image: ndarray
        2D array of the reference image
    moving_image: ndarray
        2D array of the moving image
    d: array_like
        registration parameters; d[1] is used as the column (x) translation and d[0] as the row (y) translation
        of the skimage ``AffineTransform`` applied to the moving image.

    Returns
    -------
    e: float
        error estimation of the registration. The lower e, the more reliable the registration.
    """
    shifted_image = warp(moving_image, AffineTransform(translation=[d[1], d[0]]), mode='wrap', preserve_range=True)
    e = (2*np.percentile(reference_image, 99))/np.percentile(reference_image+shifted_image, 99)
    return e
def mse(reference_image, moving_image, d):
    """
    Normalized root mean square error between the reference and the shifted moving image.

    The moving image is translated by ``d`` (wrap-around boundary) before the comparison, and the error is
    normalized by the mean of the reference image.

    Parameters
    ----------
    reference_image: ndarray
        2D array of the reference image
    moving_image: ndarray
        2D array of the moving image
    d: array_like
        registration parameters; d[1] is the column (x) and d[0] the row (y) translation.

    Returns
    -------
    _: float
        error estimation of the registration. The lower the value, the more reliable the registration.
    """
    translation = AffineTransform(translation=[d[1], d[0]])
    moved = warp(moving_image, translation, mode='wrap', preserve_range=True)
    return normalized_root_mse(reference_image, moved, normalization='mean')
def phase_cross_correlation(reference_image, moving_image, upsample_factor=1, return_error=True):
    """
    Phase cross correlation.

    The skimage function computes the NORMAL cross correlation to estimate the shift; this version computes the
    TRUE phase cross correlation, as per the standard definition: the cross-power spectrum is normalized to unit
    magnitude before the inverse FFT.

    Parameters
    ----------
    reference_image : array
        Reference image.
    moving_image : array
        Image to register. Must be same dimensionality as ``reference_image``.
    upsample_factor : int, optional
        Only 1 (no sub-pixel refinement) is supported; any other value raises a ValueError.
    return_error : bool, optional
        Returns error and phase difference if True, otherwise only shifts are returned.

    Returns
    -------
    shifts : ndarray
        Shift vector (in pixels) required to register ``moving_image`` with ``reference_image``. Axis ordering is
        consistent with numpy (e.g. Z, Y, X)
    error : float
        sqrt(|1 - |CCmax|^2|), where CCmax is the peak of the phase correlation (only if ``return_error``).
    phasediff : float
        Global phase difference between the two images (only if ``return_error``).

    Raises
    ------
    ValueError
        if the images do not have the same shape or if ``upsample_factor != 1``.
    """
    # images must be the same shape
    if reference_image.shape != moving_image.shape:
        raise ValueError("images must be same shape")
    # Sub-pixel refinement is not implemented: fail before doing any FFT work
    if upsample_factor != 1:
        raise ValueError('Error: upsampled phase cross corrrelation not implemented here, use skimage.')

    src_freq = np.fft.fftn(reference_image)
    target_freq = np.fft.fftn(moving_image)

    # Whole-pixel shift - normalized cross-power spectrum (eps avoids a division by zero), then IFFT
    shape = src_freq.shape
    image_product = src_freq * target_freq.conj()
    eps = np.finfo(image_product.real.dtype).eps
    image_product /= (np.abs(image_product) + eps)
    cross_correlation = np.fft.ifftn(image_product)

    # Locate maximum
    maxima = np.unravel_index(np.argmax(np.abs(cross_correlation)), cross_correlation.shape)
    midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape])
    shifts = np.stack(maxima).astype(np.float64)
    # Peaks past the midpoint correspond to negative shifts (FFT wrap-around)
    shifts[shifts > midpoints] -= np.array(shape)[shifts > midpoints]

    # If its only one row or column the shift along that dimension has no effect. We set to zero.
    for dim in range(src_freq.ndim):
        if shape[dim] == 1:
            shifts[dim] = 0

    if not return_error:
        return shifts

    CCmax = cross_correlation[maxima]
    error = np.real(1.0 - CCmax * CCmax.conj())
    phase_diff = np.arctan2(CCmax.imag, CCmax.real)
    return shifts, np.sqrt(np.abs(error)), phase_diff
def phase_cross_correlation_cv(reference_image, moving_image):
    """
    Compute the phase cross correlation with OpenCV.

    It is around 16 times faster than the implementation using numpy FFT (same as skimage). OpenCV's
    ``phaseCorrelate`` returns an (x, y) sub-pixel shift; it is reordered to (y, x), rounded to integers and
    negated. The peak response returned by OpenCV is discarded.

    Parameters
    ----------
    reference_image : array
        Reference image (2D), converted to float32.
    moving_image : array
        Image to register. Must be same dimensionality as ``reference_image``.

    Returns
    -------
    shifts : list
        [dy, dx] integer shift (in pixels) required to register ``moving_image`` with ``reference_image``.
    """
    (dx, dy), _response = cv.phaseCorrelate(reference_image.astype(np.float32), moving_image.astype(np.float32))
    return [-np.round(dy).astype(int), -np.round(dx).astype(int)]
def _compute_shift(reference_image, moving_image):
    """
    Compute the registration between two images and its error, used for the global optimisation.

    Experienced users can replace this function to plug in their own registration and error estimation.

    Parameters
    ----------
    reference_image : array
        Reference image.
    moving_image : array
        Image to register. Must be same dimensionality as ``reference_image``.

    Returns
    -------
    d: array_like
        registration parameters found (from ``phase_cross_correlation_cv``)
    e: float
        error estimation for the registration (the higher the error the higher the registration uncertainty);
        ``max_sum_over_single_max`` divided by the square root of the reference mean intensity, times 10.
    """
    shift = phase_cross_correlation_cv(reference_image, moving_image)
    raw_error = max_sum_over_single_max(reference_image, moving_image, shift)
    # Brighter references give more reliable estimates: weight the metric by their mean intensity
    scaled_error = raw_error / np.sqrt(np.mean(reference_image)) * 10
    return shift, scaled_error
def _get_max_proj_apr(apr, parts, patch, patch_yx=None, plot=False):
    """
    Compute maximum intensity projections of 3D APR data along each of the three dimensions.

    Parameters
    ----------
    apr: pyapr.APR
        apr tree
    parts: pyapr.ParticleData
        apr particles
    patch: pyapr.ReconPatch
        patch restricting the projection to the overlapping area.
    patch_yx: pyapr.ReconPatch, optional
        if given, patch used for the dim=2 (YX) projection only; dim=0 and dim=1 still use ``patch``.
    plot: bool
        control data plotting

    Returns
    -------
    _: tuple of ndarray
        the three projections (dim=0, dim=1, dim=2), plotted with the titles ZY, ZX and YX.
    """
    if patch_yx is None:
        projections = [pyapr.transform.maximum_projection(apr, parts, dim=axis, patch=patch, method='auto')
                       for axis in range(3)]
    else:
        project = pyapr.transform.projection.maximum_projection
        projections = [project(apr, parts, dim=0, patch=patch, method='auto'),
                       project(apr, parts, dim=1, patch=patch, method='auto'),
                       project(apr, parts, dim=2, patch=patch_yx, method='auto')]

    if plot:
        _, axes = plt.subplots(1, 3)
        for axis_handle, image, title in zip(axes, projections, ('ZY', 'ZX', 'YX')):
            axis_handle.imshow(image, cmap='gray')
            axis_handle.set_title(title)

    return tuple(projections)
def _get_proj_shifts(proj1, proj2):
    """
    Compute the shifts between two tiles from their max-projections on the overlapping areas.

    A shift and an error are estimated on each of the ZY, ZX and YX projections with ``_compute_shift``. Each
    axis appears in two projections; for each axis, the estimate with the lowest error is kept (on ties, the
    second candidate in the comments below is used).

    Parameters
    ----------
    proj1: list[ndarray]
        (ZY, ZX, YX) max-projections for tile 1
    proj2: list[ndarray]
        (ZY, ZX, YX) max-projections for tile 2

    Returns
    -------
    shifts: ndarray
        shifts in (z, y, x)
    errors: ndarray
        error associated to each shift, as returned by ``_compute_shift`` (the lower, the more reliable).
        Zero errors are replaced by 1e-6.
    """
    # Compute phase cross-correlation to extract shifts on each projection
    dzy, error_zy = _compute_shift(proj1[0], proj2[0])
    dzx, error_zx = _compute_shift(proj1[1], proj2[1])
    dyx, error_yx = _compute_shift(proj1[2], proj2[2])

    # Replace error == 0 with 1e-6, otherwise the minimum spanning tree considers that the vertices are not connected
    if error_zy == 0:
        error_zy = 1e-6
    if error_zx == 0:
        error_zx = 1e-6
    if error_yx == 0:
        error_yx = 1e-6

    # Keep only the most reliable registration for each axis
    # D/z: estimated in the ZX and ZY projections
    if error_zx < error_zy:
        dz, rz = dzx[0], error_zx
    else:
        dz, rz = dzy[0], error_zy
    # H/x: estimated in the ZX and YX projections
    if error_zx < error_yx:
        dx, rx = dzx[1], error_zx
    else:
        dx, rx = dyx[1], error_yx
    # V/y: estimated in the YX and ZY projections
    if error_yx < error_zy:
        dy, ry = dyx[0], error_yx
    else:
        dy, ry = dzy[1], error_zy

    return np.array([dz, dy, dx]), np.array([rz, ry, rx])
def _get_masked_proj_shifts(proj1, proj2, threshold, upsample_factor=1):
    """
    Compute shifts from max-projections on overlapping areas, ignoring the brightest pixels.

    In each projection, pixels at or above the ``threshold`` percentile are masked out (they are likely bubbles
    or artefacts) and the shift is estimated with scikit-image's masked phase cross-correlation. For each axis,
    the estimate coming from the projection with the lowest error is kept.

    Parameters
    ----------
    proj1: list[ndarray]
        (ZY, ZX, YX) max-projections for tile 1
    proj2: list[ndarray]
        (ZY, ZX, YX) max-projections for tile 2
    threshold: float
        percentile in [0, 100] used for the mask (e.g. 95 removes the 5% brightest pixels).
    upsample_factor: int
        kept for backward compatibility; scikit-image does not use it when masks are provided.

    Returns
    -------
    shifts: ndarray
        shifts in (z, y, x)
    errors: ndarray
        error measure for each shift, sqrt(1 - peak(correlation)^2 / (sum(p1^2) * sum(p2^2)))
        (0=reliable, 1=not reliable). Zero errors are replaced by 1e-6.
    """
    # Local import: the module-level ``phase_cross_correlation`` is an unmasked re-implementation that does not
    # accept ``reference_mask``/``moving_mask``.
    from skimage.registration import phase_cross_correlation as masked_phase_cross_correlation

    shifts = []
    errors = []
    for i in range(3):
        ref = np.asarray(proj1[i], dtype=np.float64)
        mov = np.asarray(proj2[i], dtype=np.float64)

        # Discard very bright areas that are likely bubbles or artefacts
        mask_ref = ref < np.percentile(ref, threshold)
        mask_mov = mov < np.percentile(mov, threshold)
        result = masked_phase_cross_correlation(ref, mov, reference_mask=mask_ref, moving_mask=mask_mov)
        # Depending on the scikit-image version, the masked mode returns the shifts alone or a (shifts, ...) tuple
        shifts.append(result[0] if isinstance(result, tuple) else result)

        # Error from the normalized cross-correlation peak. Working in float64 avoids the overflow that squaring
        # integer (e.g. uint16) projections would cause; the clamp guards against round-off below zero.
        energy = np.sum(ref ** 2) * np.sum(mov ** 2)
        if energy > 0:
            error = np.sqrt(max(0.0, 1.0 - correlate(ref, mov).max() ** 2 / energy))
        else:
            error = 1.0
        # Replace error == 0 with 1e-6 otherwise the minimum spanning tree considers that vertex are not connected
        errors.append(error if error != 0 else 1e-6)

    dzy, dzx, dyx = shifts
    error_zy, error_zx, error_yx = errors

    # Keep only the most reliable registration for each axis
    # D/z: estimated in the ZX and ZY projections
    if error_zx < error_zy:
        dz, rz = dzx[0], error_zx
    else:
        dz, rz = dzy[0], error_zy
    # H/x: estimated in the ZX and YX projections
    if error_zx < error_yx:
        dx, rx = dzx[1], error_zx
    else:
        dx, rx = dyx[1], error_yx
    # V/y: estimated in the YX and ZY projections
    if error_yx < error_zy:
        dy, ry = dyx[0], error_yx
    else:
        dy, ry = dzy[1], error_zy

    return np.array([dz, dy, dx]), np.array([rz, ry, rx])
class baseStitcher():
    """
    Base class for stitching multi-tile data.

    It stores the tiling geometry read from the tileParser (rows, columns, number of tiles and edges, frame size),
    the expected overlaps and the enlarged overlaps used for registration, the mask/segmentation options and the
    regularization thresholds. It also provides methods to save/load the registration database and to reconstruct
    slices of the stitched sample from the absolute tile positions stored in ``self.database`` (columns ``ABS_H``,
    ``ABS_V`` and ``ABS_D``).

    NOTE(review): ``self.database`` and ``self.projs`` are not created in the visible part of this class; they are
    presumably set by subclasses (the error message in ``save_database`` refers to ``build_database()``) — confirm.
    """
[docs] def __init__(self, tiles, overlap_h: (int, float), overlap_v: (int, float)): """ Constructor for the baseStitcher class. Parameters ---------- tiles: tileParser tileParser object containing the dataset to stitch. overlap_h: float expected horizontal overlap in % overlap_v: float expected vertical overlap in % Returns ------- None """ self.tiles = tiles self.ncol = tiles.ncol self.nrow = tiles.nrow self.n_vertex = tiles.n_tiles self.n_edges = tiles.n_edges self.frame_size = tiles.frame_size self.expected_overlap_h = int(overlap_h/100*self.frame_size) self.expected_overlap_v = int(overlap_v/100*self.frame_size) self.overlap_h = int(self.expected_overlap_h*1.2) if self.expected_overlap_h > self.frame_size: self.expected_overlap_h = self.frame_size self.overlap_v = int(self.expected_overlap_v*1.2) if self.expected_overlap_v > self.frame_size: self.expected_overlap_v = self.frame_size self.mask = False self.threshold = None self.segment = False self.segmenter = None self.reg_x = int(self.frame_size*0.05) self.reg_y = int(self.frame_size*0.05) self.reg_z = 20 self.z_begin = None self.z_end = None
[docs] def activate_mask(self, threshold): """ Activate the masked cross-correlation for the displacement estimation. Pixels above threshold are not taken into account. Parameters ---------- threshold: int threshold for the cross-correlation mask as a percentage of pixel to keep (e.g. 95 will create a mask removing the 5% brightest pixels). """ self.mask = True self.threshold = threshold
[docs] def deactivate_mask(self): """ Deactivate the masked cross-correlation and uses a classical cross correlation. """ self.mask = False self.threshold = None
[docs] def save_database(self, path=None): """ Save database at the given path. The database must be built before calling this method. Parameters ---------- path: string path to save the database. """ if self.database is None: raise TypeError('Error: database can''t be saved because it was not created. ' 'Please call build_database() first.') if path is None: path = os.path.join(self.tiles.path, 'registration_results.csv') self.database.to_csv(path)
[docs] def load_database(self, path=None, force=False): """ Save database at the given path. The database must be built before calling this method. Parameters ---------- path: string path to save the database. """ if self.database is not None and force is False: raise TypeError('Error: database can''t be read because a database is already exist and force is set' 'to False.') if path is None: path = os.path.join(self.tiles.path, 'registration_results.csv') self.database = pd.read_csv(path)
[docs] def activate_segmentation(self, segmenter): """ Activate the segmentation. When a tile is loaded it is segmented before the stitching is done. Parameters ---------- segmenter: tileSegmenter segmenter object for segmenting each tile. """ self.segment = True self.segmenter = segmenter
[docs] def deactivate_segmentation(self): """ Deactivate tile segmentation. """ self.segment = False
[docs] def reconstruct_slice(self, loc=None, n_proj=0, dim=0, downsample=1, color=False, debug=False, plot=True, seg=False, progress_bar=True): """ Reconstruct whole sample 2D section at the given location and in a given dimension. This function can also reconstruct a maximum intensity projection if `n_proj>0`. Parameters ---------- loc: int (default: middle of the sample) Position of the plane where the reconstruction should be done. The location varies depending on the downsample parameter and should be adapted. n_proj: int (default: 0) Number of planes to perform the maximum intensity projection. dim: int (default: 0) Dimension of the reconstruction, e.g. 0 will be [y, x] plane (orthogonal to z). downsample: int (default: 1) Downsample factor for the reconstruction. Must be in [1, 2, 4, 8, 16, 32]. color: bool (default: False) Option to reconstruct with checkerboard color pattern. Useful to identify doubling artifacts. debug: bool (default: False) Option to add a white square for each tile, making it easy to see overlapping areas. plot: bool (default: True) Define if the function plots the results with Matplotlib or just returns an array. seg: bool (default: False) Option to also reconstruct the segmentation. Only works with `dim=0` Returns ------- _: ndarray Array containing the reconstructed data. """ if dim == 0: return self._reconstruct_z_slice(z=loc, n_proj=n_proj, downsample=downsample, color=color, debug=debug, plot=plot, seg=seg, progress_bar=progress_bar) elif dim == 1: return self._reconstruct_y_slice(y=loc, n_proj=n_proj, downsample=downsample, color=color, debug=debug, plot=plot, progress_bar=progress_bar) elif dim == 2: return self._reconstruct_x_slice(x=loc, n_proj=n_proj, downsample=downsample, color=color, debug=debug, plot=plot, progress_bar=progress_bar) else: raise ValueError('dim should be in [1, 2, 3], got dim = {}'.format(dim))
[docs] def set_regularization(self, reg_x, reg_y, reg_z): """ Set the regularization for the stitching to prevent aberrant displacements. Parameters ---------- reg_x: int if the horizontal displacement computed in the pairwise registration for any tile is greater than reg_x (in pixel unit) then the expected displacement (from motor position) is taken. reg_y: int if the horizontal displacement computed in the pairwise registration for any tile is greater than reg_z (in pixel unit) then the expected displacement (from motor position) is taken. reg_z: int if the horizontal displacement computed in the pairwise registration for any tile is greater than reg_z (in pixel unit) then the expected displacement (from motor position) is taken. Returns ------- None """ self.reg_x = reg_x self.reg_y = reg_y self.reg_z = reg_z
[docs] def reconstruct_z_color(self, z=None, n_proj=10, downsample=1, debug=False, plot=True, progress_bar=True): """ Reconstruct and merge the sample at a given depth z. Parameters ---------- z: int reconstruction depth downsample: int downsample for reconstruction (must be a power of 2) debug: bool if true the border of each tile will be highlighted Returns ------- merged_data: ndarray Merged frame at depth z. """ level_delta = int(-np.sign(downsample) * np.log2(np.abs(downsample))) tile = self.tiles[0] tile.lazy_load_tile(level_delta=level_delta) if z is None: z = int(tile.lazy_data.shape[0] / 2) if z > tile.lazy_data.shape[0]: raise ValueError('Error: z is too large ({}), maximum depth at this downsample is {}.'.format(z, tile.lazy_data.shape[0])) frame_size = tile.lazy_data.shape[1:] x_pos = self.database['ABS_H'].to_numpy() nx = int(np.ceil((x_pos.max() - x_pos.min()) / downsample + frame_size[1])) y_pos = self.database['ABS_V'].to_numpy() ny = int(np.ceil((y_pos.max() - y_pos.min()) / downsample + frame_size[0])) H = np.zeros((ny, nx), dtype='uint16') V = np.zeros((ny, nx), dtype='uint16') H_pos = (x_pos - x_pos.min()) / downsample V_pos = (y_pos - y_pos.min()) / downsample for i, tile in enumerate(tqdm(self.tiles, desc='Merging', disable=not progress_bar)): tile.lazy_load_tile(level_delta=level_delta) zf = min(z+n_proj, tile.lazy_data.shape[0]) data = tile.lazy_data[z:zf] v = data.max(axis=0) h = np.argmax(data, axis=0) # In debug mode we highlight each tile edge to see where it was if debug: v[self.overlap_v, :] = 2**16-1 v[-self.overlap_v, :] = 2**16-1 v[:, self.overlap_h] = 2**16-1 v[:, -self.overlap_h] = 2**16-1 x1 = int(H_pos[i]) x2 = int(H_pos[i] + v.shape[1]) y1 = int(V_pos[i]) y2 = int(V_pos[i] + v.shape[0]) V[y1:y2, x1:x2] = np.maximum(V[y1:y2, x1:x2], v) H[y1:y2, x1:x2] = np.maximum(H[y1:y2, x1:x2], h) H = rescale_intensity(gaussian(H, sigma=2/downsample), out_range=np.float64)*0.66 V = np.log(V + 200) vmin, vmax = np.percentile(V[V > 
np.log(100)], (1, 99.9)) V = rescale_intensity(V, in_range=(vmin, vmax), out_range=np.float64) S = rescale_intensity(V**1.5, out_range=np.float64)*0.66 rgb = hsv2rgb(np.dstack((H, S, V))) rgb = rescale_intensity(rgb, out_range='uint8') if plot: fig, ax = plt.subplots(1, 1) h = ax.imshow(rgb, cmap='turbo', vmin=0, vmax=n_proj*downsample) divider = make_axes_locatable(ax) cax = divider.append_axes('right', size='5%', pad=0.05) fig.colorbar(h, cax=cax, orientation='vertical', label='Depth [pixel]') return rgb
[docs] def _reconstruct_z_slice(self, z=None, n_proj=0, downsample=1, color=False, debug=False, plot=True, seg=False, progress_bar=True): """ Reconstruct and merge the sample at a given depth z. Parameters ---------- z: int reconstruction depth (vary with downsample) n_proj: int (default: 0) Number of planes to perform the maximum intensity projection. dim: int (default: 0) Dimension of the reconstruction, e.g. 0 will be [y, x] plane (orthogonal to z). downsample: int (default: 1) Downsample factor for the reconstruction. Must be in [1, 2, 4, 8, 16, 32]. color: bool (default: False) Option to reconstruct with checkerboard color pattern. Useful to identify doubling artifacts. debug: bool (default: False) Option to add a white square for each tile, making it easy to see overlapping areas. plot: bool (default: True) Define if the function plots the results with Matplotlib or just returns an array. seg: bool (default: False) Option to also reconstruct the segmentation. Only works with `dim=0` Returns ------- merged_data: ndarray Merged frame at depth z. 
""" level_delta = int(-np.sign(downsample) * np.log2(np.abs(downsample))) tile = self.tiles[0] tile.lazy_load_tile(level_delta=level_delta) if seg: tile.lazy_load_segmentation(level_delta=level_delta) if z is None: z = int(tile.lazy_data.shape[0] / 2) if z > tile.lazy_data.shape[0]: raise ValueError('Error: z is too large ({}), maximum depth at this downsample is {}.'.format(z, tile.lazy_data.shape[0])) frame_size = tile.lazy_data.shape[1:] x_pos = self.database['ABS_H'].to_numpy() nx = int(np.ceil((x_pos.max() - x_pos.min()) / downsample + frame_size[1])) y_pos = self.database['ABS_V'].to_numpy() ny = int(np.ceil((y_pos.max() - y_pos.min()) / downsample + frame_size[0])) if color: merged_data = np.ones((ny, nx, 3), dtype='uint16') merged_data[:, :, 2] = 0 else: merged_data = np.zeros((ny, nx), dtype='uint16') if seg: merged_seg = np.zeros((ny, nx), dtype='float32') H_pos = (x_pos - x_pos.min()) / downsample V_pos = (y_pos - y_pos.min()) / downsample for i, tile in enumerate(tqdm(self.tiles, desc='Merging', disable=not progress_bar)): tile.lazy_load_tile(level_delta=level_delta) if seg: tile.lazy_load_segmentation(level_delta=level_delta) zf = min(z+n_proj, tile.lazy_data.shape[0]) if zf > z: data = tile.lazy_data[z:zf].max(axis=0) if seg: cc = tile.lazy_segmentation[z:zf].max(axis=0) else: data = tile.lazy_data[z] if seg: cc = tile.lazy_segmentation[z] # In debug mode we highlight each tile edge to see where it was if debug: xv = int(self.expected_overlap_v/downsample) xh = int(self.expected_overlap_h/downsample) data[xv, xh:-xh] = 2**16-1 data[-xv, xh:-xh] = 2**16-1 data[xv:-xv, xh] = 2**16-1 data[xv:-xv, -xh] = 2**16-1 x1 = int(H_pos[i]) x2 = int(H_pos[i] + data.shape[1]) y1 = int(V_pos[i]) y2 = int(V_pos[i] + data.shape[0]) if color: if tile.col % 2: if tile.row % 2: merged_data[y1:y2, x1:x2, 0] = np.maximum(merged_data[y1:y2, x1:x2, 1], data) else: merged_data[y1:y2, x1:x2, 1] = np.maximum(merged_data[y1:y2, x1:x2, 0], data) else: if tile.row % 2: 
merged_data[y1:y2, x1:x2, 1] = np.maximum(merged_data[y1:y2, x1:x2, 1], data) else: merged_data[y1:y2, x1:x2, 0] = np.maximum(merged_data[y1:y2, x1:x2, 0], data) else: merged_data[y1:y2, x1:x2] = np.maximum(merged_data[y1:y2, x1:x2], data) if seg: merged_seg[y1:y2, x1:x2] = np.maximum(merged_seg[y1:y2, x1:x2], cc) if plot: viewer = napari.Viewer() if color: viewer.add_image(self._process_RGB_for_display(merged_data), name='Z plane') else: viewer.add_image(self._process_GRAY_for_display(merged_data), name='Z plane', contrast_limits=[0, 2**16-1]) if seg: viewer.add_labels(merged_seg, name='Segmentation labels.') if not seg: return merged_data else: return merged_data, merged_seg
[docs] def _reconstruct_y_slice(self, y=None, n_proj=0, downsample=1, color=False, debug=False, plot=True, progress_bar=True): """ Reconstruct and merge the sample at a given position y. Parameters ---------- y: int reconstruction location in y n_proj: int (default: 0) Number of planes to perform the maximum intensity projection. dim: int (default: 0) Dimension of the reconstruction, e.g. 0 will be [y, x] plane (orthogonal to z). downsample: int (default: 1) Downsample factor for the reconstruction. Must be in [1, 2, 4, 8, 16, 32]. color: bool (default: False) Option to reconstruct with checkerboard color pattern. Useful to identify doubling artifacts. debug: bool (default: False) Option to add a white square for each tile, making it easy to see overlapping areas. plot: bool (default: True) Define if the function plots the results with Matplotlib or just returns an array. Returns ------- merged_data: ndarray Merged frame at position y. """ level_delta = int(-np.sign(downsample) * np.log2(np.abs(downsample))) tile = self.tiles[0] tile.lazy_load_tile(level_delta=level_delta) tile_shape = tile.lazy_data.shape if y is None: y = int(tile_shape[1]*self.tiles.nrow/2) if y > tile.lazy_data.shape[1]*self.tiles.nrow: raise ValueError('Error: y is too large ({}), maximum depth at this downsample is {}.' 
.format(y, tile.lazy_data.shape[1]*self.tiles.nrow)) x_pos = self.database['ABS_H'].to_numpy() nx = int(np.ceil((x_pos.max() - x_pos.min()) / downsample + tile_shape[2])) y_pos = self.database['ABS_V'].to_numpy() ny = int(np.ceil((y_pos.max() - y_pos.min()) / downsample + tile_shape[1])) z_pos = self.database['ABS_D'].to_numpy() nz = int(np.ceil((z_pos.max() - z_pos.min()) / downsample + tile_shape[0])) # Determine tiles to load tiles_to_load = [] tiles_pos = [] for x_loc, y_loc, z_loc, tile in zip(x_pos/downsample, y_pos/downsample, z_pos/downsample, self.tiles): if (y > y_loc) and (y < y_loc+tile_shape[1]): tiles_to_load.append(tile) tiles_pos.append([z_loc, y_loc, x_loc]) tiles_pos = np.array(tiles_pos).astype('uint64') if color: merged_data = np.ones((nz, nx, 3), dtype='uint16') merged_data[:, :, 2] = 0 else: merged_data = np.zeros((nz, nx), dtype='uint16') for i, tile in enumerate(tqdm(tiles_to_load, desc='Merging', disable=not progress_bar)): tile.lazy_load_tile(level_delta=level_delta) y_tile = int(y - tiles_pos[i, 1]) yf = min(y_tile+n_proj, tiles_pos[i, 1]+tile.lazy_data.shape[1]) if yf > y: data = tile.lazy_data[:, y_tile:yf, :].max(axis=1) else: data = tile.lazy_data[:, y_tile, :] # In debug mode we highlight each tile edge to see where it was if debug: xv = int(self.expected_overlap_v/downsample) xh = int(self.expected_overlap_h/downsample) data[xv, xh:-xh] = 2**16-1 data[-xv, xh:-xh] = 2**16-1 data[xv:-xv, xh] = 2**16-1 data[xv:-xv, -xh] = 2**16-1 x1 = int(tiles_pos[i, 2]) x2 = int(tiles_pos[i, 2] + data.shape[1]) z1 = int(tiles_pos[i, 0]) z2 = int(tiles_pos[i, 0] + data.shape[0]) if color: if tile.col % 2: if tile.row % 2: merged_data[z1:z2, x1:x2, 0] = np.maximum(merged_data[z1:z2, x1:x2, 1], data) else: merged_data[z1:z2, x1:x2, 1] = np.maximum(merged_data[z1:z2, x1:x2, 0], data) else: if tile.row % 2: merged_data[z1:z2, x1:x2, 1] = np.maximum(merged_data[z1:z2, x1:x2, 1], data) else: merged_data[z1:z2, x1:x2, 0] = np.maximum(merged_data[z1:z2, 
x1:x2, 0], data) else: merged_data[z1:z2, x1:x2] = np.maximum(merged_data[z1:z2, x1:x2], data) if plot: viewer = napari.Viewer() if color: viewer.add_image(self._process_RGB_for_display(merged_data), name='Y plane') else: viewer.add_image(self._process_GRAY_for_display(merged_data), name='Y plane') return merged_data
def _reconstruct_x_slice(self, x=None, n_proj=0, downsample=1, color=False, debug=False, plot=True,
                         progress_bar=True):
    """
    Reconstruct and merge the sample at a given position x.

    Parameters
    ----------
    x: int
        reconstruction location in x (in downsampled pixels). If None, the middle of the sample is used.
    n_proj: int (default: 0)
        Number of planes (along x, starting at x) to perform the maximum intensity projection. 0 disables it.
    downsample: int (default: 1)
        Downsample factor for the reconstruction. Must be in [1, 2, 4, 8, 16, 32].
    color: bool (default: False)
        Option to reconstruct with checkerboard color pattern. Useful to identify doubling artifacts.
    debug: bool (default: False)
        Option to add a white square for each tile, making it easy to see overlapping areas.
    plot: bool (default: True)
        Define if the function displays the result with napari or just returns an array.
    progress_bar: bool (default: True)
        Display a progress bar while merging the tiles.

    Returns
    -------
    merged_data: ndarray
        Merged frame at position x, indexed [z, y] (or [z, y, channel] when color is True).
    """
    level_delta = int(-np.sign(downsample) * np.log2(np.abs(downsample)))

    tile = self.tiles[0]
    tile.lazy_load_tile(level_delta=level_delta)
    tile_shape = tile.lazy_data.shape

    x_max = tile_shape[2] * self.tiles.ncol
    if x is None:
        x = int(x_max / 2)
    if x > x_max:
        raise ValueError('Error: x is too large ({}), maximum width at this downsample is {}.'.format(x, x_max))

    y_abs = self.database['ABS_V'].to_numpy()
    z_abs = self.database['ABS_D'].to_numpy()
    ny = int(np.ceil((y_abs.max() - y_abs.min()) / downsample + tile_shape[1]))
    nz = int(np.ceil((z_abs.max() - z_abs.min()) / downsample + tile_shape[0]))

    # Integer tile origins in downsampled pixels. Truncating *before* selecting the tiles keeps the selection test
    # and the in-tile index consistent (otherwise x - origin could reach tile_shape[2] and index past the tile).
    x_org = np.floor(self.database['ABS_H'].to_numpy() / downsample).astype('int64')
    y_org = np.floor(y_abs / downsample).astype('int64')
    z_org = np.floor(z_abs / downsample).astype('int64')

    # Determine tiles to load: those whose x extent contains x
    tiles_to_load = []
    tiles_pos = []
    for x0, y0, z0, t in zip(x_org, y_org, z_org, self.tiles):
        if x0 <= x < x0 + tile_shape[2]:
            tiles_to_load.append(t)
            tiles_pos.append((int(z0), int(y0), int(x0)))

    if color:
        merged_data = np.ones((nz, ny, 3), dtype='uint16')
        merged_data[:, :, 2] = 0
    else:
        merged_data = np.zeros((nz, ny), dtype='uint16')

    for i, tile in enumerate(tqdm(tiles_to_load, desc='Merging', disable=not progress_bar)):
        tile.lazy_load_tile(level_delta=level_delta)
        z1, y1, x0 = tiles_pos[i]

        # Both x_tile and xf are in tile-local coordinates; xf is clipped to the tile's last plane.
        x_tile = int(x - x0)
        xf = min(x_tile + n_proj, tile.lazy_data.shape[2])
        if xf > x_tile:
            data = tile.lazy_data[:, :, x_tile:xf].max(axis=2)
        else:
            data = tile.lazy_data[:, :, x_tile]

        # In debug mode we highlight each tile edge to see where it was
        if debug:
            # Work on a copy so the tile data itself is never modified.
            data = np.array(data, copy=True)
            xv = int(self.expected_overlap_v / downsample)
            xh = int(self.expected_overlap_h / downsample)
            data[xv, xh:-xh] = 2**16 - 1
            data[-xv, xh:-xh] = 2**16 - 1
            data[xv:-xv, xh] = 2**16 - 1
            data[xv:-xv, -xh] = 2**16 - 1

        y2 = y1 + data.shape[1]
        z2 = z1 + data.shape[0]

        if color:
            # Checkerboard: tiles whose row and col have the same parity go to channel 0, the others to channel 1.
            channel = 0 if (tile.col % 2) == (tile.row % 2) else 1
            merged_data[z1:z2, y1:y2, channel] = np.maximum(merged_data[z1:z2, y1:y2, channel], data)
        else:
            merged_data[z1:z2, y1:y2] = np.maximum(merged_data[z1:z2, y1:y2], data)

    if plot:
        viewer = napari.Viewer()
        if color:
            viewer.add_image(self._process_RGB_for_display(merged_data), name='X plane')
        else:
            viewer.add_image(self._process_GRAY_for_display(merged_data), name='X plane')

    return merged_data
[docs] def _process_RGB_for_display(self, u): """ Process RGB data for correctly displaying it. Parameters ---------- u: ndarray RGB data Returns ------- data_to_display: ndarray RGB data displayable with correct contrast and colors. """ data_to_display = np.zeros_like(u, dtype='uint8') for i in range(2): tmp = np.log(u[:, :, i] + 200) vmin, vmax = np.percentile(tmp[tmp > np.log(1 + 200)], (1, 99.9)) data_to_display[:, :, i] = rescale_intensity(tmp, in_range=(vmin, vmax), out_range='uint8') return data_to_display
[docs] def _process_GRAY_for_display(self, u): """ Process RGB data for correctly displaying it. Parameters ---------- u: ndarray RGB data Returns ------- data_to_display: ndarray RGB data displayable with correct contrast and colors. """ u = np.log(u+200) vmin, vmax = np.percentile(u[u > np.log(1 + 200)], (1, 99.9)) data_to_display = rescale_intensity(u, in_range=(vmin, vmax), out_range='uint16') return data_to_display
[docs] def _regularize(self, reg, rel): """ Remove too large displacement and replace them with expected one with a large uncertainty. """ if np.abs(reg[2] - (self.overlap_h - self.expected_overlap_h)) > self.reg_x: reg[2] = (self.overlap_h - self.expected_overlap_h) rel[2] = 2 if np.abs(reg[1] - (self.overlap_v - self.expected_overlap_v)) > self.reg_y: reg[1] = (self.overlap_v - self.expected_overlap_v) rel[1] = 2 if np.abs(reg[0]) > self.reg_z: reg[0] = 0 rel[0] = 2 return reg, rel
[docs] def _save_max_projs(self): """ Save the computed maximum intensity projection on persistent memory. This is useful to recompute the registration directly from the max. proj. but only works if the overlaps are kept the same. Returns ------- None """ # Safely create folder to save max-projs Path(self.tiles.folder_max_projs).mkdir(parents=True, exist_ok=True) for row in range(self.nrow): for col in range(self.ncol): proj = self.projs[row, col] if proj is not None: for loc in proj.keys(): for i, d in enumerate(['zy', 'zx', 'yx']): data = proj[loc] np.save(os.path.join(self.tiles.folder_max_projs, '{}_{}_{}_{}.npy'.format(row, col, loc, d)), data[i])
[docs] def _load_max_projs(self, path): """ Load the maximum intensity projection previously stored. Parameters ---------- path: str path to load the maximum intensity projection from. If None then default to `max_projs` folder in the acquisition folder. Returns ------- None """ projs = np.empty((self.nrow, self.ncol), dtype=object) if path is None: folder_max_projs = self.tiles.folder_max_projs else: folder_max_projs = path for tile in self.tiles: proj = {} if tile.col + 1 < self.tiles.ncol: if self.tiles.tiles_pattern[tile.row, tile.col + 1] == 1: # EAST 1 tmp = [] for i, d in enumerate(['zy', 'zx', 'yx']): tmp.append(np.load(os.path.join(folder_max_projs, '{}_{}_east_{}.npy'.format(tile.row, tile.col, d)))) proj['east'] = tmp if tile.col - 1 >= 0: if self.tiles.tiles_pattern[tile.row, tile.col - 1] == 1: # EAST 2 tmp = [] for i, d in enumerate(['zy', 'zx', 'yx']): tmp.append(np.load(os.path.join(folder_max_projs, '{}_{}_west_{}.npy'.format(tile.row, tile.col, d)))) proj['west'] = tmp if tile.row + 1 < self.tiles.nrow: if self.tiles.tiles_pattern[tile.row + 1, tile.col] == 1: # SOUTH 1 tmp = [] for i, d in enumerate(['zy', 'zx', 'yx']): tmp.append(np.load(os.path.join(folder_max_projs, '{}_{}_south_{}.npy'.format(tile.row, tile.col, d)))) proj['south'] = tmp if tile.row - 1 >= 0: if self.tiles.tiles_pattern[tile.row - 1, tile.col] == 1: # SOUTH 2 tmp = [] for i, d in enumerate(['zy', 'zx', 'yx']): tmp.append(np.load(os.path.join(folder_max_projs, '{}_{}_north_{}.npy'.format(tile.row, tile.col, d)))) proj['north'] = tmp projs[tile.row, tile.col] = proj return projs
def _precompute_max_projs(self, progress_bar=True):
    """
    Precompute max-projections for loading the data only once during the stitching.

    For every tile and every existing neighbor (according to tiles_pattern), the max-projections of the
    overlapping border are computed and stored in self.projs[row, col][loc], with loc in
    {'east', 'west', 'south', 'north'}. If self.segment is True, each tile is also segmented.

    Parameters
    ----------
    progress_bar: bool
        display a progress bar

    Returns
    -------
    None
    """
    # (border name, neighbor row offset, neighbor col offset, ReconPatch attribute, attribute value)
    borders = (
        ('east', 0, 1, 'y_begin', self.frame_size - self.overlap_h),
        ('west', 0, -1, 'y_end', self.overlap_h),
        ('south', 1, 0, 'x_begin', self.frame_size - self.overlap_v),
        ('north', -1, 0, 'x_end', self.overlap_v),
    )

    def make_patch(attr, value, with_z_range):
        patch = pyapr.ReconPatch()
        setattr(patch, attr, value)
        if with_z_range:
            patch.z_begin = self.z_begin
            patch.z_end = self.z_end
        return patch

    all_projs = np.empty((self.nrow, self.ncol), dtype=object)
    for tile in tqdm(self.tiles, desc='Computing max. proj.', disable=not progress_bar):
        tile.load_tile()
        tile_projs = {}
        for loc, d_row, d_col, attr, value in borders:
            n_row, n_col = tile.row + d_row, tile.col + d_col
            if not (0 <= n_row < self.tiles.nrow and 0 <= n_col < self.tiles.ncol):
                continue
            if self.tiles.tiles_pattern[n_row, n_col] != 1:
                continue
            patch = make_patch(attr, value, with_z_range=False)
            if self.z_begin is None:
                tile_projs[loc] = _get_max_proj_apr(tile.apr, tile.parts, patch, plot=False)
            else:
                patch_yx = make_patch(attr, value, with_z_range=True)
                tile_projs[loc] = _get_max_proj_apr(tile.apr, tile.parts, patch=patch, patch_yx=patch_yx,
                                                    plot=False)
        all_projs[tile.row, tile.col] = tile_projs

        if self.segment:
            self.segmenter.compute_segmentation(tile)

    self.projs = all_projs
class tileStitcher(baseStitcher):
    """
    Class used to perform the stitching. The stitching is performed in 4 steps:

    1. The pairwise registration parameters of each pair of neighboring tiles are computed on the max-projections.
    2. A sparse graph (vertices = tiles and edges = registration between neighboring tiles) is constructed to
       store the registration parameters (displacements and reliability).
    3. The sparse graph is optimized to satisfy the constraints (every loop in the graph should sum to 0) using the
       minimum spanning tree on the reliability estimation (an estimated error: lower is better).
    4. The minimum spanning tree is parsed to extract the optimal tile positions.

    The beauty of this method is that it scales well with increasing dataset sizes, because the final
    optimization is very fast and does not require reloading the data.
    """
def __init__(self, tiles, overlap_h: (int, float), overlap_v: (int, float)):
    """
    Constructor for the tileStitcher class.

    Parameters
    ----------
    tiles: tileParser
        tileParser object containing the dataset to stitch.
    overlap_h: float
        expected horizontal overlap in %
    overlap_v: float
        expected vertical overlap in %
    """
    super().__init__(tiles, overlap_h, overlap_v)

    # Pair-wise registration results: one entry per neighboring tile pair (graph edge).
    self.cgraph_from, self.cgraph_to = [], []
    self.dH, self.dV, self.dD = [], [], []
    self.relia_H, self.relia_V, self.relia_D = [], [], []

    # Attributes below are set when the corresponding methods are called.
    self.graph_relia_H = self.graph_relia_V = self.graph_relia_D = None
    self.min_tree_H = self.min_tree_V = self.min_tree_D = None
    self.ctree_from_H = self.ctree_from_V = self.ctree_from_D = None
    self.ctree_to_H = self.ctree_to_V = self.ctree_to_D = None
    self.registration_map_rel = None
    self.registration_map_abs = None
    self.database = None
    self.projs = None
def _compute_registration_old(self):
    """
    Compute the pair-wise registration for all tiles.

    This implementation loads the data twice and is therefore not efficient. Results of any previous
    registration are discarded before computing the new one.

    Returns
    -------
    None
    """
    # Start from empty accumulators: running a registration twice must not duplicate graph edges
    # (csr_matrix sums duplicated entries, which would corrupt the reliability graph).
    self.cgraph_from, self.cgraph_to = [], []
    self.dH, self.dV, self.dD = [], [], []
    self.relia_H, self.relia_V, self.relia_D = [], [], []

    for tile in tqdm(self.tiles, desc='Computing stitching'):
        tile.load_tile()
        tile.load_neighbors()

        if self.segment:
            self.segmenter.compute_segmentation(tile)

        for apr, parts, coords in zip(tile.apr_neighbors, tile.parts_neighbors, tile.neighbors):
            if tile.row == coords[0] and tile.col < coords[1]:
                # EAST
                reg, rel = self._compute_east_registration(tile.apr, tile.parts, apr, parts)
            elif tile.col == coords[1] and tile.row < coords[0]:
                # SOUTH
                reg, rel = self._compute_south_registration(tile.apr, tile.parts, apr, parts)
            else:
                raise TypeError("Error: couldn't determine registration to perform.")

            # Regularize in case of aberrant displacements
            reg, rel = self._regularize(reg, rel)

            self.cgraph_from.append(np.ravel_multi_index([tile.row, tile.col], dims=(self.nrow, self.ncol)))
            self.cgraph_to.append(np.ravel_multi_index([coords[0], coords[1]], dims=(self.nrow, self.ncol)))
            # H=x, V=y, D=z
            self.dH.append(reg[2])
            self.dV.append(reg[1])
            self.dD.append(reg[0])
            self.relia_H.append(rel[2])
            self.relia_V.append(rel[1])
            self.relia_D.append(rel[0])

    self._build_sparse_graphs()
    self._optimize_sparse_graphs()
    self._produce_registration_map()
    self._build_database()
    self._print_info()
def compute_registration(self, on_disk=False, progress_bar=True):
    """
    Compute the pair-wise registration for all tiles.

    This implementation loads the data once by precomputing the max-proj and is therefore efficient. Results
    of any previous registration are discarded before computing the new one.

    Parameters
    ----------
    on_disk: bool
        if True, the max-projections are also saved on disk (see _save_max_projs) so that the registration can be
        recomputed later directly from them.
    progress_bar: bool
        display progress bars

    Returns
    -------
    None
    """
    # Pre-compute the max-projections (kept in memory) and optionally save them on disk.
    self._precompute_max_projs(progress_bar=progress_bar)
    if on_disk:
        self._save_max_projs()

    # Start from empty accumulators: running a registration twice must not duplicate graph edges
    # (csr_matrix sums duplicated entries, which would corrupt the reliability graph).
    self.cgraph_from, self.cgraph_to = [], []
    self.dH, self.dV, self.dD = [], [], []
    self.relia_H, self.relia_V, self.relia_D = [], [], []

    # Then we loop again through the tiles but now we have access to the max-proj
    for tile in tqdm(self.tiles, desc='Computing cross-correlations', disable=not progress_bar):
        proj1 = self.projs[tile.row, tile.col]
        for coords in tile.neighbors:
            proj2 = self.projs[coords[0], coords[1]]
            if tile.row == coords[0] and tile.col < coords[1]:
                # EAST: east border of the tile vs west border of its neighbor
                border1, border2 = proj1['east'], proj2['west']
            elif tile.col == coords[1] and tile.row < coords[0]:
                # SOUTH: south border of the tile vs north border of its neighbor
                border1, border2 = proj1['south'], proj2['north']
            else:
                raise TypeError("Error: couldn't determine registration to perform.")

            if self.mask:
                reg, rel = _get_masked_proj_shifts(border1, border2, threshold=self.threshold)
            else:
                reg, rel = _get_proj_shifts(border1, border2)

            self.cgraph_from.append(np.ravel_multi_index([tile.row, tile.col], dims=(self.nrow, self.ncol)))
            self.cgraph_to.append(np.ravel_multi_index([coords[0], coords[1]], dims=(self.nrow, self.ncol)))

            # Regularize in case of aberrant displacements
            reg, rel = self._regularize(reg, rel)

            # H=x, V=y, D=z
            self.dH.append(reg[2])
            self.dV.append(reg[1])
            self.dD.append(reg[0])
            self.relia_H.append(rel[2])
            self.relia_V.append(rel[1])
            self.relia_D.append(rel[0])

    self._build_sparse_graphs()
    self._optimize_sparse_graphs()
    self._produce_registration_map()
    self._build_database()
    self._print_info()
def compute_registration_from_max_projs(self, path=None, progress_bar=True):
    """
    Compute the registration directly from the max-projections.

    Max-projections must have been computed and saved on disk before. Results of any previous registration
    are discarded before computing the new one.

    Parameters
    ----------
    path: str
        folder containing the max-projections. If None, self.tiles.folder_max_projs is used.
    progress_bar: bool
        display a progress bar

    Returns
    -------
    None
    """
    projs = self._load_max_projs(path=path)

    # Start from empty accumulators: running a registration twice must not duplicate graph edges
    # (csr_matrix sums duplicated entries, which would corrupt the reliability graph).
    self.cgraph_from, self.cgraph_to = [], []
    self.dH, self.dV, self.dD = [], [], []
    self.relia_H, self.relia_V, self.relia_D = [], [], []

    for tile in tqdm(self.tiles, desc='Compute cross-correlation', disable=not progress_bar):
        proj1 = projs[tile.row, tile.col]
        for coords in tile.neighbors:
            proj2 = projs[coords[0], coords[1]]
            if tile.row == coords[0] and tile.col < coords[1]:
                # EAST: east border of the tile vs west border of its neighbor
                border1, border2 = proj1['east'], proj2['west']
            elif tile.col == coords[1] and tile.row < coords[0]:
                # SOUTH: south border of the tile vs north border of its neighbor
                border1, border2 = proj1['south'], proj2['north']
            else:
                raise TypeError("Error: couldn't determine registration to perform.")

            if self.mask:
                reg, rel = _get_masked_proj_shifts(border1, border2, threshold=self.threshold)
            else:
                reg, rel = _get_proj_shifts(border1, border2)

            self.cgraph_from.append(np.ravel_multi_index([tile.row, tile.col], dims=(self.nrow, self.ncol)))
            self.cgraph_to.append(np.ravel_multi_index([coords[0], coords[1]], dims=(self.nrow, self.ncol)))

            # Regularize in case of aberrant displacements
            reg, rel = self._regularize(reg, rel)

            # H=x, V=y, D=z
            self.dH.append(reg[2])
            self.dV.append(reg[1])
            self.dD.append(reg[0])
            self.relia_H.append(rel[2])
            self.relia_V.append(rel[1])
            self.relia_D.append(rel[0])

    self._build_sparse_graphs()
    self._optimize_sparse_graphs()
    self._produce_registration_map()
    self._build_database()
    self._print_info()
def compute_expected_registration(self):
    """
    Compute the expected registration if the expected overlaps are correct.

    All relative displacements are 0 and the tiles are placed on a regular grid with a step of
    (frame_size - expected_overlap) pixels in each direction. The database is then rebuilt.

    Returns
    -------
    None
    """
    self.registration_map_rel = np.zeros((3, self.nrow, self.ncol))
    rel_map = self.registration_map_rel

    abs_map = np.zeros_like(rel_map)
    # H: one offset per column, broadcast over the rows
    abs_map[0] = rel_map[0] + np.arange(self.ncol)[np.newaxis, :] * (self.frame_size - self.expected_overlap_h)
    # V: one offset per row, broadcast over the columns
    abs_map[1] = rel_map[1] + np.arange(self.nrow)[:, np.newaxis] * (self.frame_size - self.expected_overlap_v)
    # D: no expected offset
    abs_map[2] = rel_map[2]

    self.registration_map_abs = abs_map
    self._build_database()
def plot_graph(self, annotate=False):
    """
    Plot the graph for each direction (H, V, D).

    This method needs to be called after the sparse graphs have been built.

    Parameters
    ----------
    annotate: bool
        control if annotations (estimated error and displacement of each pair) are drawn on the graph

    Returns
    -------
    fig, ax:
        matplotlib figure and axes (one axis per direction)
    """
    if self.graph_relia_H is None:
        raise TypeError('Error: graph not build yet, please use build_sparse_graph() '
                        'before trying to plot the graph.')

    # The edge end points (tile positions) are the same for the three directions.
    row, col = np.unravel_index(self.cgraph_from, shape=(self.nrow, self.ncol))
    V1 = np.vstack((row, col)).T
    row, col = np.unravel_index(self.cgraph_to, shape=(self.nrow, self.ncol))
    V2 = np.vstack((row, col)).T

    fig, ax = plt.subplots(1, 3)
    for i, d in enumerate(['H', 'V', 'D']):
        rel = getattr(self, 'relia_' + d)
        dX = getattr(self, 'd' + d)
        for ii in range(V1.shape[0]):
            ax[i].plot([V1[ii, 1], V2[ii, 1]], [V1[ii, 0], V2[ii, 0]], 'ko', markerfacecolor='r')
            if annotate:
                # Rotate the label along the edge direction (computed in display coordinates).
                p1 = ax[i].transData.transform_point([V1[ii, 1], V1[ii, 0]])
                p2 = ax[i].transData.transform_point([V2[ii, 1], V2[ii, 0]])
                rot = np.degrees(np.arctan2(p2[1] - p1[1], p2[0] - p1[0]))
                if rel[ii] < 0.15:
                    color = 'g'
                elif rel[ii] < 0.30:
                    color = 'orange'
                else:
                    color = 'r'
                ax[i].annotate(text='err={:.2f} d{}={:.2f}'.format(rel[ii], d, dX[ii]),
                               xy=((V1[ii, 1] + V2[ii, 1]) / 2, (V1[ii, 0] + V2[ii, 0]) / 2),
                               ha='center', va='center', fontsize=8, rotation=rot, backgroundcolor='w',
                               color=color)
        ax[i].set_title(d + ' tree')
        ax[i].invert_yaxis()

    return fig, ax
def plot_min_trees(self, annotate=False):
    """
    Plot the minimum spanning tree for each direction (H, V, D) on top of the graph drawn by plot_graph().

    This method needs to be called after the graph optimization.

    Parameters
    ----------
    annotate: bool
        control if annotations (displacement of each tree edge) are drawn on the trees

    Returns
    -------
    fig, ax:
        matplotlib figure and axes (one axis per direction)
    """
    if self.min_tree_H is None:
        raise TypeError('Error: minimum spanning tree not computed yet, please use optimize_sparse_graph() '
                        'before trying to plot the trees.')

    fig, ax = self.plot_graph(annotate=False)

    for i, d in enumerate(['H', 'V', 'D']):
        ind_from = getattr(self, 'ctree_from_' + d)
        row, col = np.unravel_index(ind_from, shape=(self.nrow, self.ncol))
        V1 = np.vstack((row, col)).T
        ind_to = getattr(self, 'ctree_to_' + d)
        row, col = np.unravel_index(ind_to, shape=(self.nrow, self.ncol))
        V2 = np.vstack((row, col)).T
        dX = getattr(self, 'd' + d)
        for ii in range(V1.shape[0]):
            ax[i].plot([V1[ii, 1], V2[ii, 1]], [V1[ii, 0], V2[ii, 0]], 'ko-', markerfacecolor='r', linewidth=2)
            if annotate:
                # Rotate the label along the edge direction (computed in display coordinates).
                p1 = ax[i].transData.transform_point([V1[ii, 1], V1[ii, 0]])
                p2 = ax[i].transData.transform_point([V2[ii, 1], V2[ii, 0]])
                rot = np.degrees(np.arctan2(p2[1] - p1[1], p2[0] - p1[0]))
                # Tree edges are looked up in the registration graph to get their displacement.
                ax[i].annotate(text='{:.2f}'.format(dX[self._get_ind(ind_from[ii], ind_to[ii])]),
                               xy=((V1[ii, 1] + V2[ii, 1]) / 2, (V1[ii, 0] + V2[ii, 0]) / 2),
                               ha='center', va='center', fontsize=8, rotation=rot, backgroundcolor='w',
                               color='r')
        ax[i].set_title(d + ' tree')

    return fig, ax
def plot_stitching_info(self):
    """
    Plot pair-wise registration error for each axis [H, V, D].

    Three figures are drawn: the estimated error map per axis, the mean estimated error map, and the graph with
    its minimum spanning trees overlaid on the relative registration map.

    Returns
    -------
    rel_map: array
        error matrix of shape (3, nrow, ncol); rel_map[d, row, col] is the largest estimated error among the
        tree edges touching the tile.
    d_map: array
        shift matrix (the relative registration map, self.registration_map_rel)
    """
    if self.min_tree_H is None:
        raise TypeError('Error: minimum spanning tree not computed yet, please use optimize_sparse_graph() '
                        'before trying to plot stitching info.')

    rel_map = np.zeros((3, self.nrow, self.ncol))
    for i, d in enumerate(['H', 'V', 'D']):
        ind_from = getattr(self, 'ctree_from_' + d)
        ind_to = getattr(self, 'ctree_to_' + d)
        graph = getattr(self, 'graph_relia_' + d)
        # Each tree edge contributes its error to both of its end tiles; keep the largest one per tile.
        for nodes in (ind_to, ind_from):
            rows, cols = np.unravel_index(nodes, shape=(self.nrow, self.ncol))
            for row, col, i1, i2 in zip(rows, cols, ind_from, ind_to):
                rel_map[i, row, col] = np.max((rel_map[i, row, col], graph[i1, i2]))

    fig, ax = plt.subplots(1, 3, sharex=True, sharey=True)
    for i, d in enumerate(['H', 'V', 'D']):
        h = ax[i].imshow(rel_map[i], cmap='turbo', vmin=0, vmax=2)
        ax[i].set_title('Registration {}'.format(d))
        divider = make_axes_locatable(ax[i])
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(h, cax=cax, label='Estimated error [a.u.]')

    plt.figure()
    plt.imshow(np.mean(rel_map, axis=0), cmap='turbo')
    plt.colorbar(label='Total estimated error [a.u.]')

    if self.graph_relia_H is None:
        raise TypeError('Error: graph not build yet, please use build_sparse_graph() '
                        'before trying to plot the graph.')

    # Full graph edge end points (same for the three directions)
    row, col = np.unravel_index(self.cgraph_from, shape=(self.nrow, self.ncol))
    G1 = np.vstack((row, col)).T + 0.25
    row, col = np.unravel_index(self.cgraph_to, shape=(self.nrow, self.ncol))
    G2 = np.vstack((row, col)).T + 0.25

    fig, ax = plt.subplots(1, 3)
    for i, d in enumerate(['H', 'V', 'D']):
        for ii in range(G1.shape[0]):
            ax[i].plot([G1[ii, 1], G2[ii, 1]], [G1[ii, 0], G2[ii, 0]], 'ko', markerfacecolor='r')
        ax[i].set_title(d + ' tree')
        ax[i].invert_yaxis()

        ind_from = getattr(self, 'ctree_from_' + d)
        row, col = np.unravel_index(ind_from, shape=(self.nrow, self.ncol))
        V1 = np.vstack((row, col)).T + 0.25
        ind_to = getattr(self, 'ctree_to_' + d)
        row, col = np.unravel_index(ind_to, shape=(self.nrow, self.ncol))
        V2 = np.vstack((row, col)).T + 0.25
        dX = getattr(self, 'd' + d)
        for ii in range(V1.shape[0]):
            ax[i].plot([V1[ii, 1], V2[ii, 1]], [V1[ii, 0], V2[ii, 0]], 'ko-', markerfacecolor='r', linewidth=2)
            p1 = ax[i].transData.transform_point([V1[ii, 1], V1[ii, 0]])
            p2 = ax[i].transData.transform_point([V2[ii, 1], V2[ii, 0]])
            rot = np.degrees(np.arctan2(p2[1] - p1[1], p2[0] - p1[0]))
            ax[i].annotate(text='{:.2f}'.format(dX[self._get_ind(ind_from[ii], ind_to[ii])]),
                           xy=((V1[ii, 1] + V2[ii, 1]) / 2, (V1[ii, 0] + V2[ii, 0]) / 2),
                           ha='center', va='center', fontsize=8, rotation=rot, backgroundcolor='w', color='r')
        sns.heatmap(self.registration_map_rel[i], annot=True, fmt='4.0f', ax=ax[i], cbar=False)

    return rel_map, self.registration_map_rel
def plot_registration_map(self):
    """
    Display the relative and absolute registration maps using matplotlib.

    Returns
    -------
    fig, ax:
        matplotlib figure and 2x3 axes (rows: relative/absolute, columns: H, V, D)
    """
    if self.registration_map_abs is None:
        raise TypeError('Error: registration map not computed yet, please use produce_registration_map() '
                        'before trying to display the registration map.')

    fig, ax = plt.subplots(2, 3)
    for i, d in enumerate(['H', 'V', 'D']):
        ax[0, i].imshow(self.registration_map_rel[i], cmap='gray')
        ax[0, i].set_title('Rel reg. map ' + d)
        ax[1, i].imshow(self.registration_map_abs[i], cmap='gray')
        ax[1, i].set_title('Abs reg. map ' + d)
    return fig, ax
def dump_stitcher(self, path):
    """
    Use dill to store the stitcher object.

    Parameters
    ----------
    path: str or os.PathLike
        path to save the stitcher. The '.pkl' extension is appended if missing.

    Returns
    -------
    None
    """
    # os.fspath accepts both str and pathlib.Path (slicing a Path object would fail).
    path = os.fspath(path)
    if not path.endswith('.pkl'):
        path = path + '.pkl'

    with open(path, 'wb') as f:
        dill.dump(self, f)
def set_overlap_margin(self, margin):
    """
    Modify the overlapping area size.

    If the overlapping area is smaller than the true one, the stitching can't be performed properly. If the
    overlapping area is more than twice the size of the true one, it will also fail (due to the circular FFT in
    the phase cross-correlation).

    Parameters
    ----------
    margin: float
        safety margin in % to take the overlapping area, must be in [1, 45].

    Returns
    -------
    None

    Raises
    ------
    ValueError
        if margin is outside [1, 45].
    """
    if margin > 45:
        raise ValueError('Error: overlap margin is too big and will make the stitching fail.')
    if margin < 1:
        raise ValueError('Error: overlap margin is too small and may make the stitching fail.')

    # An overlap can never exceed the tile itself.
    self.expected_overlap_h = min(self.expected_overlap_h, self.frame_size)
    self.expected_overlap_v = min(self.expected_overlap_v, self.frame_size)

    # The area used for the registration is the expected overlap enlarged by the margin, clipped to the tile
    # size (the margin can push it past frame_size).
    self.overlap_h = min(int(self.expected_overlap_h * (1 + margin / 100)), self.frame_size)
    self.overlap_v = min(int(self.expected_overlap_v * (1 + margin / 100)), self.frame_size)
def set_z_range(self, z_begin, z_end):
    """
    Set the range of depth used for computing the stitching.

    Parameters
    ----------
    z_begin: int
        first depth to be included in the max-proj
    z_end: int
        last depth to be included in the max-proj

    Returns
    -------
    None
    """
    self.z_begin, self.z_end = z_begin, z_end
[docs] def _print_info(self): """ Display stitching result information. """ overlap = np.median(np.diff(np.median(self.registration_map_abs[0], axis=0))) self.effective_overlap_h = (self.frame_size-overlap)/self.frame_size*100 print('Effective horizontal overlap: {:0.2f}%'.format(self.effective_overlap_h)) overlap = np.median(np.diff(np.median(self.registration_map_abs[1], axis=1))) self.effective_overlap_v = (self.frame_size-overlap)/self.frame_size*100 print('Effective vertical overlap: {:0.2f}%'.format(self.effective_overlap_v)) if np.abs(self.effective_overlap_v*self.frame_size/100-self.expected_overlap_v)>0.2*self.expected_overlap_v: warnings.warn('Expected vertical overlap is very different from the computed one, the registration ' 'might be wrong.') if np.abs(self.effective_overlap_h*self.frame_size/100-self.expected_overlap_h)>0.2*self.expected_overlap_h: warnings.warn('Expected horizontal overlap is very different from the computed one, the registration ' 'might be wrong.')
[docs] def _build_sparse_graphs(self): """ Build the sparse graph from the reliability and (row, col). This method needs to be called after the pair-wise registration has been performed for all neighbors pair. Returns ------- None """ csr_matrix_size = self.ncol*self.nrow self.graph_relia_H = csr_matrix((self.relia_H, (self.cgraph_from, self.cgraph_to)), shape=(csr_matrix_size, csr_matrix_size)) self.graph_relia_V = csr_matrix((self.relia_V, (self.cgraph_from, self.cgraph_to)), shape=(csr_matrix_size, csr_matrix_size)) self.graph_relia_D = csr_matrix((self.relia_D, (self.cgraph_from, self.cgraph_to)), shape=(csr_matrix_size, csr_matrix_size))
[docs] def _optimize_sparse_graphs(self): """ Optimize the sparse graph by computing the minimum spanning tree for each direction (H, D, V). This method needs to be called after the sparse graphs have been built. Returns ------- None """ if self.graph_relia_H is None: raise TypeError('Error: sparse graph not build yet, please use build_sparse_graph() before trying to' 'perform the optimization.') for g in ['graph_relia_H', 'graph_relia_V', 'graph_relia_D']: graph = getattr(self, g) # Minimum spanning tree min_tree = minimum_spanning_tree(graph) # Get the "true" neighbors min_tree = min_tree.tocoo() setattr(self, 'min_tree_' + g[-1], min_tree) ctree_from = min_tree.row setattr(self, 'ctree_from_' + g[-1], ctree_from) ctree_to = min_tree.col setattr(self, 'ctree_to_' + g[-1], ctree_to)
[docs] def _produce_registration_map(self): """ Produce the registration map where reg_rel_map[d, row, col] (d = H,V,D) is the relative tile position in pixel from the expected one. This method needs to be called after the optimization has been done. Returns ------- None """ if self.min_tree_H is None: raise TypeError('Error: minimum spanning tree not computed yet, please use optimize_sparse_graph()' 'before trying to compute the registration map.') # Relative registration # Initialize relative registration map reg_rel_map = np.zeros((3, self.nrow, self.ncol)) # H, V, D for i, min_tree in enumerate(['min_tree_H', 'min_tree_V', 'min_tree_D']): # Fill it by following the tree and getting the corresponding registration parameters node_array = depth_first_order(getattr(self, min_tree), i_start=self.cgraph_from[0], directed=False, return_predecessors=False) node_visited = [node_array[0]] tree = getattr(self, min_tree) row = tree.row col = tree.col for node_to in zip(node_array[1:]): # The previous node in the MST is a visited node with an edge to the current node neighbors = [] for r, c in zip(row, col): if r == node_to: neighbors.append(c) if c == node_to: neighbors.append(r) node_from = [x for x in neighbors if x in node_visited] node_visited.append(node_to) # Get the previous neighbor local reg parameter ind1, ind2 = np.unravel_index(node_from, shape=(self.nrow, self.ncol)) d_neighbor = reg_rel_map[i, ind1, ind2] # Get the current 2D tile position ind1, ind2 = np.unravel_index(node_to, shape=(self.nrow, self.ncol)) # Get the associated ind position in the registration graph (as opposed to the reliability min_tree) ind_graph = self._get_ind(node_from, node_to) # Get the corresponding reg parameter d = getattr(self, 'd' + min_tree[-1])[ind_graph] # Get the corresponding relia and print a warning if it was regularized: relia = getattr(self, 'relia_' + min_tree[-1])[ind_graph] if relia == 2: print('Aberrant pair-wise registration remaining after global optimization 
between tile ({},{}) ' 'and tile ({},{})'.format(*np.unravel_index(node_from, shape=(self.nrow, self.ncol)), *np.unravel_index(node_to, shape=(self.nrow, self.ncol)))) # Update the local reg parameter in the 2D matrix if node_to > node_from[0]: reg_rel_map[i, ind1, ind2] = d_neighbor + d else: reg_rel_map[i, ind1, ind2] = d_neighbor - d self.registration_map_rel = reg_rel_map reg_abs_map = np.zeros_like(reg_rel_map) # H for x in range(reg_abs_map.shape[2]): reg_abs_map[0, :, x] = reg_rel_map[0, :, x] + x * (self.frame_size-self.overlap_h) # V for x in range(reg_abs_map.shape[1]): reg_abs_map[1, x, :] = reg_rel_map[1, x, :] + x * (self.frame_size-self.overlap_v) # D reg_abs_map[2] = reg_rel_map[2] self.registration_map_abs = reg_abs_map return reg_rel_map, reg_abs_map
[docs] def _build_database(self): """ Build the database for storing the registration parameters. This method needs to be called after the registration map has been produced. Returns ------- None """ if self.registration_map_rel is None: raise TypeError('Error: database can''t be build if the registration map has not been computed.' ' Please use produce_registration_map() method first.') database_dict = {} for i in range(self.n_vertex): row = self.tiles[i].row col = self.tiles[i].col database_dict[i] = {'path': self.tiles[i].path, 'row': row, 'col': col, 'dH': self.registration_map_rel[0, row, col], 'dV': self.registration_map_rel[1, row, col], 'dD': self.registration_map_rel[2, row, col], 'ABS_H': self.registration_map_abs[0, row, col], 'ABS_V': self.registration_map_abs[1, row, col], 'ABS_D': self.registration_map_abs[2, row, col]} self.database = pd.DataFrame.from_dict(database_dict, orient='index') # Finally set the origin so that tile on the edge have coordinate 0 (rather than negative): for i, d in enumerate(['ABS_D', 'ABS_V', 'ABS_H']): self.database[d] = self.database[d] - self.database[d].min()
[docs] def _get_ind(self, ind_from, ind_to): """ Returns the ind in the original graph which corresponds to (ind_from, ind_to) in the minimum spanning tree. Parameters ---------- ind_from: int starting node in the directed graph ind_to: int ending node in the directed graph Returns ---------- ind: int corresponding ind in the original graph """ ind = None for i, f in enumerate(self.cgraph_from): if f == ind_from: if self.cgraph_to[i] == ind_to: ind = i if ind is None: for i, f in enumerate(self.cgraph_to): if f == ind_from: if self.cgraph_from[i] == ind_to: ind = i if ind is None: raise ValueError('Error: can''t find matching vertex pair.') return ind
def _compute_east_registration(self, apr_1, parts_1, apr_2, parts_2):
    """
    Compute the registration between the current tile and its eastern neighbor.

    Max-projections are computed on the last `overlap_h` pixels of the current tile (ReconPatch.y_begin) and on
    the first `overlap_h` pixels of the neighbor (ReconPatch.y_end), then their shifts are estimated.

    Parameters
    ----------
    apr_1, parts_1:
        APR and particles of the current tile
    apr_2, parts_2:
        APR and particles of the eastern neighbor

    Returns
    -------
    reg, rel:
        displacement and reliability, as returned by _get_proj_shifts (or _get_masked_proj_shifts if self.mask)
    """
    patch = pyapr.ReconPatch()
    patch.y_begin = self.frame_size - self.overlap_h
    proj_zy1, proj_zx1, proj_yx1 = _get_max_proj_apr(apr_1, parts_1, patch, plot=False)

    patch = pyapr.ReconPatch()
    patch.y_end = self.overlap_h
    proj_zy2, proj_zx2, proj_yx2 = _get_max_proj_apr(apr_2, parts_2, patch, plot=False)

    proj1 = [proj_zy1, proj_zx1, proj_yx1]
    proj2 = [proj_zy2, proj_zx2, proj_yx2]
    if self.mask:
        return _get_masked_proj_shifts(proj1, proj2, threshold=self.threshold)
    return _get_proj_shifts(proj1, proj2)
def _compute_south_registration(self, apr_1, parts_1, apr_2, parts_2):
    """
    Compute the registration between the current tile and its southern neighbor.

    Max-projections are computed on the last `overlap_v` pixels of the current tile (ReconPatch.x_begin) and on
    the first `overlap_v` pixels of the neighbor (ReconPatch.x_end), then their shifts are estimated.

    Parameters
    ----------
    apr_1, parts_1:
        APR and particles of the current tile
    apr_2, parts_2:
        APR and particles of the southern neighbor

    Returns
    -------
    reg, rel:
        displacement and reliability, as returned by _get_proj_shifts (or _get_masked_proj_shifts if self.mask)
    """
    patch = pyapr.ReconPatch()
    patch.x_begin = self.frame_size - self.overlap_v
    proj_zy1, proj_zx1, proj_yx1 = _get_max_proj_apr(apr_1, parts_1, patch, plot=False)

    patch = pyapr.ReconPatch()
    patch.x_end = self.overlap_v
    proj_zy2, proj_zx2, proj_yx2 = _get_max_proj_apr(apr_2, parts_2, patch, plot=False)

    proj1 = [proj_zy1, proj_zx1, proj_yx1]
    proj2 = [proj_zy2, proj_zx2, proj_yx2]
    if self.mask:
        return _get_masked_proj_shifts(proj1, proj2, threshold=self.threshold)
    return _get_proj_shifts(proj1, proj2)
class channelStitcher(baseStitcher):
    """
    Class used to perform the stitching between different channels. The registration must be performed first on
    a single channel (typically auto-fluorescence).

    The stitching is performed between each corresponding tile, and the relative displacement is added to the
    previously determined stitching parameters.

    The number and position of tiles should match for both datasets.
    """

    def __init__(self, stitcher, ref, moving):
        """
        Constructor for the channelStitcher class.

        Parameters
        ----------
        stitcher: tileStitcher
            tileStitcher object with the multitile registration parameters
        ref: tileParser
            tiles corresponding to the stitcher (reference channel)
        moving: tileParser
            tiles to be registered to ref
        """
        # NOTE(review): stitcher.overlap_h/v are forwarded as this stitcher's overlaps; confirm they are in the
        # unit expected by baseStitcher.__init__.
        super().__init__(moving, stitcher.overlap_h, stitcher.overlap_v)

        self.stitcher = stitcher
        self.tiles_ref = ref
        self.database = stitcher.database.copy()
        # Change tiles path for the channel tiles
        self.database['path'] = self.tiles.path_list
        self.segment = False
        self.segmentation_verbose = None
        # Region used for the max-projections (whole tile unless restricted with set_lim)
        self.patch = pyapr.ReconPatch()

    def compute_rigid_registration(self, progress_bar=True):
        """
        Compute the rigid registration between each pair of tiles across different channels.

        The displacement of each moving tile with respect to its reference tile is added to the database.

        Parameters
        ----------
        progress_bar: bool
            display a progress bar

        Returns
        -------
        None
        """
        tiles_ref = tqdm(self.tiles_ref, desc='Computing rigid registration', disable=not progress_bar)
        for tile1, tile2 in zip(tiles_ref, self.tiles):
            tile1.load_tile()
            tile2.load_tile()
            if self.segment:
                self.segmenter.compute_segmentation(tile2)

            proj1 = _get_max_proj_apr(tile1.apr, tile1.parts, self.patch)
            proj2 = _get_max_proj_apr(tile2.apr, tile2.parts, self.patch)
            if self.mask:
                # Same keyword as every other call site of _get_masked_proj_shifts.
                reg, rel = _get_masked_proj_shifts(proj1, proj2, threshold=self.threshold)
            else:
                reg, rel = _get_proj_shifts(proj1, proj2)

            # TODO: add regularization to avoid aberrant shifts.
            self._update_database(tile2.row, tile2.col, reg)

    def set_lim(self, x_begin=None, x_end=None, y_begin=None, y_end=None, z_begin=None, z_end=None):
        """
        Define spatial limits to compute the maximum intensity projection.

        Only the limits that are not None are updated.

        Parameters
        ----------
        x_begin: int
        x_end: int
        y_begin: int
        y_end: int
        z_begin: int
        z_end: int

        Returns
        -------
        None
        """
        if x_begin is not None:
            self.patch.x_begin = x_begin
        if x_end is not None:
            self.patch.x_end = x_end
        if y_begin is not None:
            self.patch.y_begin = y_begin
        if y_end is not None:
            self.patch.y_end = y_end
        if z_begin is not None:
            self.patch.z_begin = z_begin
        if z_end is not None:
            self.patch.z_end = z_end

    def _update_database(self, row, col, d):
        """
        Update database after the registration.

        Parameters
        ----------
        row: int
            row number
        col: int
            col number
        d: array_like
            computed displacement [z, y, x]. It is added both to the relative (dD, dV, dH) and to the absolute
            (ABS_D, ABS_V, ABS_H) positions of the tile.

        Returns
        -------
        None
        """
        d = np.concatenate([d, d])
        df = self.database
        is_tile = (df['row'] == row) & (df['col'] == col)
        for loc, value in zip(['dD', 'dV', 'dH', 'ABS_D', 'ABS_V', 'ABS_H'], d):
            df.loc[is_tile, loc] += value
class tileMerger:
    """
    Class to merge tiles and create a stitched volume. Typically used at a lower resolution for
    registering the sample to an Atlas.
    """

    def __init__(self, tiles, database):
        """
        Constructor for the tileMerger class.

        Parameters
        ----------
        tiles: tileParser
            tileParser object containing the dataset to merge.
        database: (pd.DataFrame, string, os.PathLike)
            database or path to the database (csv) containing the registered tile position

        Returns
        -------
        None
        """
        if isinstance(database, (str, os.PathLike)):
            self.database = pd.read_csv(database)
        else:
            self.database = database

        self.tiles = tiles
        self.lazy = self._find_if_lazy()
        self.type = tiles.type
        self.frame_size = tiles.frame_size
        self.n_planes = self._get_n_planes()
        self.n_tiles = tiles.n_tiles
        self.n_row = tiles.nrow
        self.n_col = tiles.ncol

        # Size of the merged array (to be defined when the merged array is initialized).
        self.nx = None
        self.ny = None
        self.nz = None

        self.downsample = 1
        self.level_delta = 0
        self.merged_data = None
        self.merged_segmentation = None

    def merge_additive(self, reconstruction_mode='constant', tree_mode='mean', progress_bar=True):
        """
        Perform merging by summing the tiles on overlapping areas.

        The sum saturates at 65535 instead of wrapping around. Maximum merging should still be
        preferred to avoid saturation and higher signals on the overlapping areas.

        Parameters
        ----------
        reconstruction_mode: string
            APR reconstruction type among ('constant', 'smooth', 'level')
        tree_mode: string
            APR tree reconstruction mode passed to pyapr's APRSlicer
        progress_bar: bool
            display a progress bar.

        Returns
        -------
        None
        """
        if self.merged_data is None:
            self._initialize_merged_array()

        H_pos, V_pos, D_pos = self._get_tile_positions()
        for i, tile in enumerate(tqdm(self.tiles, desc='Merging', disable=not progress_bar)):
            data = self._load_tile_data(tile, reconstruction_mode, tree_mode)
            region = self._get_region(H_pos[i], V_pos[i], D_pos[i], data.shape)
            self.merged_data[region] = self._saturating_add(self.merged_data[region], data)

        # No copy when the array is already uint16 (the usual case).
        self.merged_data = self.merged_data.astype('uint16', copy=False)

    def merge_max(self, reconstruction_mode='constant', tree_mode='mean', debug=False, progress_bar=True):
        """
        Perform merging with a maximum algorithm for overlapping areas.

        Parameters
        ----------
        reconstruction_mode: string
            APR reconstruction type among ('constant', 'smooth', 'level')
        tree_mode: string
            APR tree reconstruction mode passed to pyapr's APRSlicer
        debug: bool
            add white border on the edge of each tile to see where it was overlapping.
        progress_bar: bool
            display a progress bar.

        Returns
        -------
        None
        """
        if self.merged_data is None:
            self._initialize_merged_array()

        H_pos, V_pos, D_pos = self._get_tile_positions()
        for i, tile in enumerate(tqdm(self.tiles, desc='Merging', disable=not progress_bar)):
            data = self._load_tile_data(tile, reconstruction_mode, tree_mode)

            # In debug mode we highlight each tile edge to see where it was
            if debug:
                self._highlight_edges(data)

            region = self._get_region(H_pos[i], V_pos[i], D_pos[i], data.shape)
            self.merged_data[region] = np.maximum(self.merged_data[region], data)

    def merge_segmentation(self, reconstruction_mode='constant', tree_mode='max', debug=False, progress_bar=True):
        """
        Merge the tile segmentations (connected components) with a maximum algorithm for overlapping areas.

        Only available for APR tiles.

        Parameters
        ----------
        reconstruction_mode: string
            APR reconstruction type among ('constant', 'smooth', 'level')
        tree_mode: string
            APR tree reconstruction mode passed to pyapr's APRSlicer
        debug: bool
            add white border on the edge of each tile to see where it was overlapping.
        progress_bar: bool
            display a progress bar.

        Returns
        -------
        None

        Raises
        ------
        TypeError
            if the tiles are not APR tiles.
        """
        if self.type != 'apr':
            raise TypeError('Error: segmentation merging is only available for APR tiles.')

        if self.merged_segmentation is None:
            self._initialize_merged_segmentation()

        H_pos, V_pos, D_pos = self._get_tile_positions()
        for i, tile in enumerate(tqdm(self.tiles, desc='Merging', disable=not progress_bar)):
            if self.lazy:
                tile.lazy_load_segmentation(level_delta=self.level_delta)
                data = tile.lazy_segmentation[:, :, :]
            else:
                tile.load_segmentation()
                slicer = pyapr.reconstruction.APRSlicer(tile.apr, tile.parts_cc, level_delta=self.level_delta,
                                                        mode=reconstruction_mode, tree_mode=tree_mode)
                data = slicer[:, :, :]

            # In debug mode we highlight each tile edge to see where it was
            if debug:
                self._highlight_edges(data)

            region = self._get_region(H_pos[i], V_pos[i], D_pos[i], data.shape)
            self.merged_segmentation[region] = np.maximum(self.merged_segmentation[region], data)

    def crop(self, background=0, xlim=None, ylim=None, zlim=None):
        """
        Add a black mask around the brain (rather than really cropping which makes the overlays
        complicated in a later stage).

        Parameters
        ----------
        background: int
            constant value to replace the cropped area with.
        xlim: array_like
            x limits for cropping
        ylim: array_like
            y limits for cropping
        zlim: array_like
            z limits for cropping

        Returns
        -------
        None
        """
        if self.merged_data is None:
            raise TypeError('Error: please merge data before cropping.')

        if xlim is not None:
            if xlim[0] != 0:
                self.merged_data[:, :, :xlim[0]] = background
            if xlim[1] != self.merged_data.shape[2]:
                self.merged_data[:, :, xlim[1]:] = background
        if ylim is not None:
            if ylim[0] != 0:
                self.merged_data[:, :ylim[0], :] = background
            if ylim[1] != self.merged_data.shape[1]:
                self.merged_data[:, ylim[1]:, :] = background
        if zlim is not None:
            if zlim[0] != 0:
                self.merged_data[:zlim[0], :, :] = background
            if zlim[1] != self.merged_data.shape[0]:
                self.merged_data[zlim[1]:, :, :] = background

    def equalize_hist(self, method='opencv'):
        """
        Perform histogram equalization to improve the contrast on merged data.

        Both OpenCV (only 2D, applied plane by plane) and Skimage (3D but 10 times slower) are available.
        Note: skimage's `equalize_adapthist` returns floating point values, so `merged_data` is no longer
        uint16 after using the 'skimage' method.

        Parameters
        ----------
        method: string
            method for performing histogram equalization among 'skimage' and 'opencv'.

        Returns
        -------
        None
        """
        if self.merged_data is None:
            raise TypeError('Error: please merge data before equalizing histogram.')

        if method == 'opencv':
            clahe = cv.createCLAHE(tileGridSize=(8, 8))
            for i in range(self.merged_data.shape[0]):
                self.merged_data[i] = clahe.apply(self.merged_data[i])
        elif method == 'skimage':
            self.merged_data = equalize_adapthist(self.merged_data)
        else:
            raise ValueError('Error: unknown method for adaptive histogram normalization.')

    def set_downsample(self, downsample):
        """
        Set the downsampling value for the merging reconstruction.

        The APR level delta is set accordingly to -log2(downsample).

        Parameters
        ----------
        downsample: int
            downsample factor

        Returns
        -------
        None
        """
        # TODO: find a more rigorous way of enforcing this. (Probably requires that the APR is loaded).
        if downsample not in [1, 2, 4, 8, 16, 32]:
            raise ValueError('Error: downsample value should be compatible with APR levels.')

        self.downsample = downsample
        self.level_delta = int(-np.log2(self.downsample))

    def _initialize_merged_array(self):
        """
        Initialize the merged array (uint16 zeros) in accordance with the asked downsampling.

        Returns
        -------
        None
        """
        self.merged_data = np.zeros(self._compute_merged_shape(), dtype='uint16')

    def _initialize_merged_segmentation(self):
        """
        Initialize the merged segmentation array (uint16 zeros) in accordance with the asked downsampling.

        Returns
        -------
        None
        """
        self.merged_segmentation = np.zeros(self._compute_merged_shape(), dtype='uint16')

    def _compute_merged_shape(self):
        """
        Compute and store (nx, ny, nz) for the current downsampling.

        Returns
        -------
        _: tuple
            (nz, ny, nx) shape of the merged array.
        """
        self.nx = int(np.ceil(self._get_nx() / self.downsample))
        self.ny = int(np.ceil(self._get_ny() / self.downsample))
        self.nz = int(np.ceil(self._get_nz() / self.downsample))
        return self.nz, self.ny, self.nx

    def _get_tile_positions(self):
        """
        Compute the tile positions in the merged array.

        Returns
        -------
        _: tuple of np.ndarray
            (H, V, D) positions relative to the smallest registered position, divided by the downsampling.
        """
        positions = []
        for key in ('ABS_H', 'ABS_V', 'ABS_D'):
            pos = self.database[key].to_numpy()
            positions.append((pos - pos.min()) / self.downsample)
        return tuple(positions)

    @staticmethod
    def _get_region(h, v, d, shape):
        """
        Return the (z, y, x) slices covered by a tile of the given (z, y, x) `shape` whose corner is at (h, v, d).
        """
        z_slice = slice(int(d), int(d + shape[0]))
        y_slice = slice(int(v), int(v + shape[1]))
        x_slice = slice(int(h), int(h + shape[2]))
        return z_slice, y_slice, x_slice

    def _load_tile_data(self, tile, reconstruction_mode, tree_mode):
        """
        Load `tile` and return its pixel data at the current downsampling.

        For APR tiles the downsampling is obtained through `self.level_delta`. For other tiles,
        `downscale_local_mean` is used.
        """
        if self.type == 'apr':
            if self.lazy:
                tile.lazy_load_tile(level_delta=self.level_delta)
                return tile.lazy_data[:, :, :]
            tile.load_tile()
            slicer = pyapr.reconstruction.APRSlicer(tile.apr, tile.parts, level_delta=self.level_delta,
                                                    mode=reconstruction_mode, tree_mode=tree_mode)
            return slicer[:, :, :]

        tile.load_tile()
        return downscale_local_mean(tile.data, factors=(self.downsample, self.downsample, self.downsample))

    @staticmethod
    def _saturating_add(current, data):
        """
        Return `current + data`, capped at the uint16 maximum instead of wrapping around.
        """
        max_value = np.iinfo(np.uint16).max
        # Never add more than the remaining headroom, so uint16 arithmetic cannot overflow.
        headroom = max_value - current
        return current + np.minimum(data, headroom)

    @staticmethod
    def _highlight_edges(data, value=2 ** 16 - 1):
        """
        Set the six faces of the 3D array `data` to `value` in place (debug visualisation).
        """
        data[0, :, :] = value
        data[-1, :, :] = value
        data[:, 0, :] = value
        data[:, -1, :] = value
        data[:, :, 0] = value
        data[:, :, -1] = value

    def _get_nx(self):
        """
        Compute the merged array size for x dimension (at full resolution).

        Returns
        -------
        _: int
            x size for merged array
        """
        x_pos = self.database['ABS_H'].to_numpy()
        return x_pos.max() - x_pos.min() + self.frame_size

    def _get_ny(self):
        """
        Compute the merged array size for y dimension (at full resolution).

        Returns
        -------
        _: int
            y size for merged array
        """
        y_pos = self.database['ABS_V'].to_numpy()
        return y_pos.max() - y_pos.min() + self.frame_size

    def _get_nz(self):
        """
        Compute the merged array size for z dimension (at full resolution).

        Returns
        -------
        _: int
            z size for merged array
        """
        z_pos = self.database['ABS_D'].to_numpy()
        return z_pos.max() - z_pos.min() + self.n_planes

    def _get_n_planes(self):
        """
        Load a tile and check the number of planes per tile.

        Returns
        -------
        _: int
            Number of planes per tile.
        """
        tile = self.tiles[0]
        if self.type == 'apr':
            if self.lazy:
                tile.lazy_load_tile()
                return tile.lazy_data.shape[0]
            tile.load_tile()
            return tile.apr.shape()[0]

        tile.load_tile()
        return tile.data.shape[0]

    def _find_if_lazy(self):
        """
        Function to test if all tile can be lazy loaded.

        Returns
        -------
        _: bool
            True if the tile can be lazy loaded, false if not.
        """
        try:
            for tile in self.tiles:
                tile.lazy_load_tile()
        except Exception:
            # Best effort: any loading failure means lazy loading is not available.
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            return False
        return True