Source code for elektronn2.data.image

# -*- coding: utf-8 -*-
# ELEKTRONN2 Toolkit
# Copyright (c) 2015 Marius Killinger
# All rights reserved
from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, super, zip


__all__ = ['make_affinities', 'downsample_xy', 'ids2barriers', 'smearbarriers',
           'center_cubes', ]


import logging
import multiprocessing
from functools import reduce

import numba
from scipy import ndimage
import scipy.ndimage.filters as filters
from skimage.morphology import watershed
import numpy as np

from .. import malis
from .. import utils

logger = logging.getLogger('elektronn2log')

rand_index = malis.compute_V_rand_N2  # small is better

def make_affinities(labels, nhood=None, size_thresh=1):
    """
    Construct an affinity graph from a segmentation (IDs).

    Segments with ID 0 are regarded as disconnected. The spatial shape of
    the affinity graph equals that of ``labels``, so some edges are
    undefined and therefore treated as disconnected (for positive offsets
    in ``nhood`` these are the edges with the largest spatial index).
    Connected components is run on the affinity graph to relabel the IDs
    locally.

    Parameters
    ----------
    labels: 4d np.ndarray, int (any precision)
        Volumes of segmentation IDs, batch along the first axis.
    nhood: 2d np.ndarray, int
        Neighbourhood pattern specifying the edges in the affinity graph,
        shape (#edges, ndim); ``nhood[i]`` holds the displacement of edge i.
        Defaults to the 3d unit displacements (``np.eye(3)``).
    size_thresh: int
        Size filter for connected components; smaller objects map to BG.

    Returns
    -------
    aff: 5d np.ndarray int16
        Affinity graph, shape (bs, #edges, ...); 1: connected, 0: disconnected.
    seg_gt: 4d np.ndarray int16
        Relabelled components, same spatial shape as ``labels``.
    """
    if nhood is None:
        nhood = np.eye(3, dtype=np.int32)

    n_batch = labels.shape[0]
    aff_shape = [n_batch, nhood.shape[0]] + list(labels.shape[1:])
    out_aff = np.zeros(aff_shape, dtype=np.int16)
    out_seg = np.zeros(labels.shape, dtype=np.int16)

    for k in range(n_batch):
        out_aff[k] = malis.seg_to_affgraph(labels[k], nhood)
        # The segment sizes returned by affgraph_to_seg are discarded.
        out_seg[k], _ = malis.affgraph_to_seg(out_aff[k], nhood, size_thresh)

    return out_aff, out_seg
def make_nhood_targets(target, nhood):
    """
    Expand a single-channel target volume into one channel per
    neighbourhood edge by shifting the labels along each edge offset.

    ``nhood`` has shape (#edges, displacements), e.g. (5, 3). The input
    must be 5d of shape (b, 1, d0, d1, d2); the output has shape
    (b, len(nhood), d0, d1, d2). Voxels whose shifted source would fall
    outside the volume are filled with -1.
    """
    if target.ndim == 4:
        raise NotImplementedError
    assert target.ndim == 5
    assert target.shape[1] == 1

    out_shape = list(target.shape)
    out_shape[1] = len(nhood)
    out = -1 * np.ones(out_shape, dtype=target.dtype)
    spatial = target.shape[2:]

    for ch, off in enumerate(nhood):
        # Source region: shifted against the offset, clipped to the volume.
        src = tuple(slice(max(0, -o), s - max(0, o))
                    for o, s in zip(off, spatial))
        # Destination region: shifted along the offset.
        dst = tuple(slice(max(0, o), s - max(0, -o))
                    for o, s in zip(off, spatial))
        shifted = target[(slice(None), 0) + src]
        if shifted.size:
            out[(slice(None), ch) + dst] = shifted
    return out
def downsample_xy(d, l, factor):
    """
    Downsample data and label volumes by average pooling in x and y.

    Parameters
    ----------
    d: 4d np.ndarray
        Data (b, ch, y, x).
    l: 4d np.ndarray
        Label (b, ch, y, x), same spatial shape as ``d``.
    factor: int
        Downsampling factor; trailing rows/columns not divisible by
        ``factor`` are cropped off.

    Returns
    -------
    d: 4d np.ndarray float32
        Averaged data, spatial extent reduced by ``factor``.
    l: np.ndarray
        Averaged label (dtype follows from dividing ``l.dtype`` by f**2).
    """
    f = int(factor)
    l_sh = l.shape
    cut = np.mod(l_sh, f)
    # Crop so the spatial extent is divisible by f.
    d = d[:, :, :l_sh[-2] - cut[-2], :l_sh[-1] - cut[-1]]
    sh = d[:, :, ::f, ::f].shape
    new_d = np.zeros(sh, dtype=np.float32)
    # BUGFIX: the original used an integer index (``l_sh[-2]-cut[-2]``
    # without the leading colon), dropping a dimension of ``l``, and then
    # sliced ``l`` inconsistently (3 dims vs 4). Mirror the ``d`` handling.
    l = l[:, :, :l_sh[-2] - cut[-2], :l_sh[-1] - cut[-1]]
    sh = l[:, :, ::f, ::f].shape
    new_l = np.zeros(sh, dtype=l.dtype)
    for i in range(f):
        for j in range(f):
            new_d += d[:, :, i::f, j::f]
            new_l += l[:, :, i::f, j::f]
    d = new_d / f**2
    l = new_l / f**2
    return d, l
@utils.timeit
@numba.jit(nopython=True)
def _ids2barriers(ids, barriers, dilute, connectivity):
    """
    Draw a 2 or 4 pix barrier where label IDs are different.

    Writes 1 into ``barriers`` (in place) wherever two neighbouring voxels
    of ``ids`` carry different labels; with ``dilute`` the barrier is
    widened from 2 to 4 voxels along that axis.

    :param ids: (x,y,z) label volume
    :param barriers: (x,y,z) int array, modified in place
    :param dilute: per-axis bool flags, e.g. [False, True, True]
    :param connectivity: per-axis bool flags, e.g. [True, True, True]
    :return: None (result is written into ``barriers``)
    """
    nx = ids.shape[0]
    ny = ids.shape[1]
    nz = ids.shape[2]
    # NOTE: the loops stop one voxel short of each border, so the last
    # plane along each axis is not checked here; the caller runs a second
    # pass on reversed views to cover the boundary.
    for x in np.arange(nx-1):
        for y in np.arange(ny-1):
            for z in np.arange(nz-1):
                if connectivity[0]:
                    # Compare with the +x neighbour
                    if ids[x,y,z]!=ids[x+1,y,z]:
                        barriers[x,y,z] = 1
                        barriers[x+1,y,z] = 1
                        if dilute[0]:
                            # Widen the barrier to 4 voxels along x
                            if x>0:
                                barriers[x-1,y,z] = 1
                            if x<nx-2:
                                barriers[x+2,y,z] = 1
                if connectivity[1]:
                    # Compare with the +y neighbour
                    if ids[x,y,z]!=ids[x,y+1,z]:
                        barriers[x,y,z] = 1
                        barriers[x,y+1,z] = 1
                        if dilute[1]:
                            # Widen the barrier to 4 voxels along y
                            if y>0:
                                barriers[x,y-1,z] = 1
                            if y<ny-2:
                                barriers[x,y+2,z] = 1
                if connectivity[2]:
                    # Compare with the +z neighbour
                    if ids[x,y,z]!=ids[x,y,z+1]:
                        barriers[x,y,z] = 1
                        barriers[x,y,z+1] = 1
                        if dilute[2]:
                            # Widen the barrier to 4 voxels along z
                            if z>0:
                                barriers[x,y,z-1] = 1
                            if z<nz-2:
                                barriers[x,y,z+2] = 1
def ids2barriers(ids, dilute=(True, True, True),
                 connectivity=(True, True, True),
                 ecs_as_barr=True,
                 smoothen=False):
    """
    Convert a label volume into a barrier map.

    Parameters
    ----------
    ids: 3d np.ndarray
        Segmentation IDs (x, y, z); ID 0 is ECS/background.
    dilute: sequence of 3 bool
        Per-axis flag: widen barriers from 2 to 4 voxels along that axis.
    connectivity: sequence of 3 bool
        Per-axis flag: compare neighbours along this axis at all.
    ecs_as_barr: bool or 'new_class'
        True: ECS voxels (ID 0) become barrier (1). 'new_class': ECS voxels
        that are not already barrier get their own label 2.
    smoothen: bool
        Smooth the barrier map by thresholding a fixed 3x3x3 convolution.

    Returns
    -------
    np.ndarray int16
        Barrier map, same shape as ``ids``.
    """
    # BUGFIX(idiom): immutable tuple defaults instead of shared mutable
    # list defaults (they were never mutated, but the latent hazard and
    # the lint warning are removed).
    dilute = np.array(dilute)
    connectivity = np.array(connectivity)
    barriers = np.zeros_like(ids, dtype=np.int16)
    _ids2barriers(ids, barriers, dilute, connectivity)
    # Second pass on reversed views as a lazy hack to fix the boundary:
    # the forward pass stops one voxel short of each border; writes go
    # through the reversed views directly into ``barriers``.
    _ids2barriers(ids[::-1, ::-1, ::-1], barriers[::-1, ::-1, ::-1],
                  dilute, connectivity)

    if smoothen:
        kernel = np.array([[[0.1, 0.2, 0.1],
                            [0.2, 0.3, 0.2],
                            [0.1, 0.2, 0.1]],

                           [[0.3, 0.5, 0.3],
                            [0.5, 1.0, 0.5],
                            [0.3, 0.5, 0.3]],

                           [[0.1, 0.2, 0.1],
                            [0.2, 0.3, 0.2],
                            [0.1, 0.2, 0.1]]])
        barriers_s = filters.convolve(barriers.astype(np.float32),
                                      kernel.astype(np.float32))
        barriers = (barriers_s > 4).astype(np.int16)  # (old - new).mean() ~ 0

    if ecs_as_barr == 'new_class':
        # ECS gets its own class (2) unless already marked as barrier
        ecs = np.logical_and((ids == 0), (barriers != 1))
        barriers[ecs] = 2
    elif ecs_as_barr:
        ecs = (ids == 0).astype(np.int16)
        barriers = np.maximum(ecs, barriers)

    return barriers
def blob(sizes): """ Return Gaussian blob filter """ grids = [np.linspace(-2.2,2.2,size) for size in sizes] grids = np.meshgrid(*grids, indexing='ij') ret = np.exp(-0.5*(reduce(np.add, list(map(np.square, grids))))) ret = ret / np.square(ret).sum() return ret def _smearbarriers(barriers, kernel): # Note: this is good but makes holes to small, # besides we must raise/lower all confidences in GT barriers = barriers.astype(np.float32) if kernel is None: kernel = np.array([ [[ 0., 0., 0., 0., 0.], [ 0., 0., 0.1, 0., 0.], [ 0., 0.1, 0.2, 0.1, 0.], [ 0., 0., 0.1, 0., 0.], [ 0., 0., 0., 0., 0.]], [[ 0., 0., 0.1, 0., 0.], [ 0., 0.2, 0.4, 0.2, 0.], [ 0.1, 0.4, 1., 0.4, 0.1], [ 0., 0.2, 0.4, 0.2, 0.], [ 0., 0., 0.1, 0., 0.]], [[ 0., 0., 0., 0., 0.], [ 0., 0., 0.1, 0., 0.], [ 0., 0.1, 0.2, 0.1, 0.], [ 0., 0., 0.1, 0., 0.], [ 0., 0., 0., 0., 0.]], ]).T else: sizes = kernel kernel = blob(sizes) index = np.subtract(sizes, 1) index = np.divide(index, 2) kernel[tuple(index)] = 1.0 # set center to 1 barriers = filters.convolve(barriers, kernel) barriers = np.minimum(barriers, 1.0) return barriers
def smearbarriers(barriers, kernel=None):
    """
    Symmetrically smear a barrier map.

    barriers: 3d volume (z,x,y)

    Averages the smeared barriers with the inverted smear of their
    complement, so the result stays centred around the original values.
    """
    positive = _smearbarriers(barriers, kernel)
    negative = 1.0 - _smearbarriers(1.0 - barriers, kernel)
    return 0.5 * (positive + negative)
@numba.jit(nopython=True)
def _grow_seg(seg, grow, mask):
    """
    One dilation step: copy labels from ``seg`` into zero (background)
    neighbours, writing the result into ``grow`` (in place). ``mask`` is a
    per-axis bool switch. Border voxels are left untouched.
    """
    nx = seg.shape[0]
    ny = seg.shape[1]
    nz = seg.shape[2]
    for x in range(1, nx-1):
        for y in range(1, ny-1):
            for z in range(1, nz-1):
                if mask[0] and (seg[x,y,z]!=0) and (seg[x-1,y,z]==0):
                    grow[x-1,y,z] = seg[x,y,z]
                if mask[0] and (seg[x,y,z]!=0) and (seg[x+1,y,z]==0):
                    grow[x+1,y,z] = seg[x,y,z]
                if mask[1] and (seg[x,y,z]!=0) and (seg[x,y-1,z]==0):
                    grow[x,y-1,z] = seg[x,y,z]
                if mask[1] and (seg[x,y,z]!=0) and (seg[x,y+1,z]==0):
                    grow[x,y+1,z] = seg[x,y,z]
                if mask[2] and (seg[x,y,z]!=0) and (seg[x,y,z-1]==0):
                    grow[x,y,z-1] = seg[x,y,z]
                if mask[2] and (seg[x,y,z]!=0) and (seg[x,y,z+1]==0):
                    grow[x,y,z+1] = seg[x,y,z]


def grow_seg(seg, pixel=(1, 3, 3)):
    """
    Grow segmentation labels into ECS/background (ID 0) by n pixels.

    Parameters
    ----------
    seg: 3d np.ndarray
        Segmentation volume; 0 is background.
    pixel: int or sequence of 3 ints
        Number of dilation steps per axis (a scalar applies to all axes).

    Returns
    -------
    np.ndarray
        Grown segmentation (input is not modified).
    """
    # BUGFIX(idiom): immutable tuple default instead of a shared mutable
    # list default.
    if isinstance(pixel, (list, tuple, np.ndarray)):
        n = np.max(pixel)
    else:
        n = pixel
        pixel = [n, ] * 3
    if n == 0:
        return seg
    grow = seg.copy()
    for i in range(n):
        # Axes whose remaining step count has reached 0 are masked out.
        mask = np.greater(pixel, 0)
        _grow_seg(seg, grow, mask)
        seg = grow.copy()
        pixel = np.subtract(pixel, 1)
    return seg
def center_cubes(cube1, cube2, crop=True):
    """
    Align two cubes of shape (ch,x,y,z) or (x,y,z) to equal spatial extent,
    centred on each other.

    The per-axis shape difference must be even. If ``crop`` the larger
    cube is cropped symmetrically, otherwise the smaller one is
    zero-padded symmetrically.

    Returns
    -------
    (cube1, cube2)
        Arrays with equal spatial shapes; 3d inputs are returned 3d.
    """
    is_3d = [False, False]
    if cube1.ndim == 3:
        cube1 = cube1[None]
        is_3d[0] = True
    if cube2.ndim == 3:
        cube2 = cube2[None]
        is_3d[1] = True

    diffs = np.subtract(cube1.shape, cube2.shape)[1:]
    assert np.all(diffs % 2 == 0)
    diffs //= 2

    slices1 = [slice(None)]
    pad1 = [(0, 0)]
    slices2 = [slice(None)]
    pad2 = [(0, 0)]
    for d in diffs:
        if d > 0:  # cube1 is larger than cube2 along this axis
            if crop:
                slices1.append(slice(d, -d))
                pad2.append((0, 0))
            else:
                slices1.append(slice(None))
                pad2.append((d, d))
            slices2.append(slice(None))
            pad1.append((0, 0))
        elif d < 0:  # cube2 is larger than cube1
            if crop:
                slices2.append(slice(-d, d))
                pad1.append((0, 0))
            else:
                slices2.append(slice(None))
                pad1.append((-d, -d))
            slices1.append(slice(None))
            pad2.append((0, 0))
        else:
            slices1.append(slice(None))
            slices2.append(slice(None))
            pad1.append((0, 0))
            pad2.append((0, 0))

    # BUGFIX: index with tuples of slices -- indexing an ndarray with a
    # *list* of slices is an error in modern NumPy.
    cube1 = cube1[tuple(slices1)]
    cube2 = cube2[tuple(slices2)]
    cube1 = np.pad(cube1, pad1, 'constant')
    cube2 = np.pad(cube2, pad2, 'constant')

    if is_3d[0]:
        cube1 = cube1[0]
    if is_3d[1]:
        cube2 = cube2[0]
    return cube1, cube2
### Segmentation ##############################################################

def seg_old(pred=None, thresh=10.0, hi_thresh=140, lo_thresh=18, grow_it=4,
            slack_dt=True, scale=(20, 9, 9)):
    """
    Watershed segmentation of a membrane probability map.

    Parameters
    ----------
    pred: 3d np.ndarray
        int prob map (z,x,y), values presumably in 0..255 (thresholds are
        on that scale).
    thresh: float
        Distance-transform threshold for seed generation.
    hi_thresh: int
        High membrane threshold -> more holes in the membrane mask.
    lo_thresh: int
        Low membrane threshold -> fewer holes, closes small stuff.
    grow_it: int or sequence of 3 ints
        Pixels by which final segments are grown (see ``grow_seg``).
    slack_dt: bool
        If True, add the distance transform of the membrane itself to the
        watershed landscape.
    scale: sequence of 3 numbers
        Anisotropic voxel size, used as ``sampling`` for the EDTs.

    Returns
    -------
    np.ndarray int16
        Segmentation; 0 is the slack/background label.
    """
    # High threshold -> more holes in the membrane mask
    mem_high = np.invert(((pred > hi_thresh) * 255).astype(np.uint8))
    dt_objects_ws = -ndimage.distance_transform_edt(mem_high, sampling=scale)
    # Low threshold -> fewer holes, closes small stuff
    mem_low = np.invert(((pred > lo_thresh) * 255).astype(np.uint8))
    dt_objects_labels = -ndimage.distance_transform_edt(mem_low,
                                                        sampling=scale)
    # Both distance transforms are 0 on the "hard membrane" and take large
    # negative values inside segments.
    if slack_dt:
        smem = ((pred > hi_thresh) * 255).astype(np.uint8)
        dt_slack = -ndimage.distance_transform_edt(smem, sampling=scale)
        dt_comb = dt_slack + dt_objects_ws
        del dt_slack
    else:
        dt_comb = dt_objects_ws

    # Create a slack label (ID 1) for regions where membrane/background/
    # ecs is thick.
    slack_label = ndimage.morphology.binary_erosion(pred > 150,
                                                    iterations=3) * 1
    seeds, num = ndimage.measurements.label(dt_objects_labels < -thresh)  # CC
    seeds[seeds != 0] += 1  # "make space" for the slack label seed with ID 1
    seeds[slack_label == 1] = 1
    ws = watershed(dt_comb, seeds) - 1  # shift back: slack label becomes 0
    ws = ws.astype(np.int16)
    ws = grow_seg(ws, pixel=grow_it)
    return ws


def seg_proc(kwargs):
    """Worker: run ``seg_old`` on one parameter dict and return its RI."""
    gt = kwargs.pop('gt')
    seg = seg_old(**kwargs)
    kwargs.pop('pred')
    ri, ri_split, ri_merge = rand_index(gt, seg)
    print("RI=%.4f\tthresh=%i\thi_thresh=%i\tlo_thresh=%i\tgrow=%s"
          % (ri, kwargs['thresh'], kwargs['hi_thresh'],
             kwargs['lo_thresh'], kwargs['grow_it']))
    return ri


def optimise_segmentation(gt, pred, save_name, n_proc=2):
    """
    Grid-search ``seg_old`` parameters, minimising the rand index
    (small is better), write a report and save the best segmentation to
    '<save_name>_seg.h5'.

    Returns
    -------
    (rand_indices, best_ri, seg)
    """
    threshs = [37, 40, 43, 56]
    hi_threshs = [235, 240, 245]
    lo_threshs = [1, 2, 3]
    grow_its = [(0, 0, 0), (1, 3, 3), (2, 6, 6), (3, 6, 6)]
    args = []
    for thresh in threshs:
        for grow_it in grow_its:
            for hi_thresh in hi_threshs:
                for lo_thresh in lo_threshs:
                    args.append(dict(gt=gt, pred=pred, thresh=thresh,
                                     hi_thresh=hi_thresh,
                                     lo_thresh=lo_thresh,
                                     grow_it=grow_it))
    print("Scanning for best SEGMENTATION parameters")
    pool = multiprocessing.Pool(n_proc)
    try:
        rand_indices = list(pool.map(seg_proc, args))
    finally:
        # BUGFIX: shut the worker pool down; the original leaked it.
        pool.close()
        pool.join()
    rand_indices = np.array(rand_indices)
    # The workers operate on pickled copies, so the parent's dicts still
    # contain 'gt' and 'pred'; strip 'gt' before re-running the best config.
    for i in range(len(args)):
        args[i].pop('gt')
    min_i = np.argmin(rand_indices)
    seg = seg_old(**args[min_i])
    best_ri = rand_indices[min_i]
    for i in range(len(args)):
        args[i].pop('pred')
    report_str = "%s Evaluation\n" % (save_name)
    report_str += "BEST: RI=%.4f " % rand_indices[min_i] \
                  + str(args[min_i]) + '\n'
    for ri_val, config in zip(rand_indices, args):
        report_str += "RI=%.4f " % ri_val + str(config) + '\n'
    with open("%s-Seg_Params.txt" % (save_name), 'w') as f:
        f.write(report_str)
    utils.h5save(randomise_colours(seg), '%s_seg.h5' % (save_name,), 'seg')
    print("Best RI=%.4f" % (best_ri))
    return rand_indices, best_ri, seg


def randomise_colours(a, size_filter=1500):
    """
    Randomly permute the IDs of sufficiently large segments; segments with
    at most ``size_filter`` voxels (and any ID not among the kept colours)
    map to 0.
    """
    out = np.zeros_like(a)
    bc = np.bincount(a.ravel())
    colors = (bc > size_filter).nonzero()[0]
    new_cols = np.random.permutation(len(colors) - 1)
    new_cols = np.hstack([0, new_cols])
    # PERF: enumerate instead of the original O(n^2) argmax lookup per
    # colour (``colors`` is sorted and unique, so the index is just i).
    for i, c in enumerate(colors):
        out[a == c] = new_cols[i]
    return out


def billig_seg(gt, pred, thresh, ecs_thresh):
    """
    Cheap baseline segmentation: watershed on the raw probability map with
    connected-component seeds, ECS voxels (pred > ecs_thresh) set to 0.

    Returns
    -------
    (ri, seg)
        Rand index against ``gt`` and the segmentation volume.
    """
    seeds, num = ndimage.measurements.label((pred < thresh))
    ws = watershed(pred, seeds)
    ws[pred > ecs_thresh] = 0
    seg = ws
    ri, ri_split, ri_merge = rand_index(gt, seg)
    return ri, seg


def billig_seg_proc(kwargs):
    """Worker: evaluate one (thresh, ecs_thresh) configuration."""
    gt = kwargs['gt']
    pred = kwargs['pred']
    thresh = kwargs['thresh']
    ecs_thresh = kwargs['ecs_thresh']
    ri, seg = billig_seg(gt, pred, thresh, ecs_thresh)
    print("RI=%.4f\tthresh=%i\tecs_thresh=%i" % (ri, thresh, ecs_thresh))
    return ri


def optimise_billig_segmentation(gt, pred, save_name, n_proc=2):
    """
    Grid-search ``billig_seg`` parameters, minimising the rand index,
    write a report and save the best segmentation to '<save_name>_seg.h5'.

    Returns
    -------
    (best_ri, seg)
    """
    threshs = np.linspace(100, 240, 6)
    ecs_threshs = [253, ]  # np.linspace(230, 255, 5)
    args = []
    for thresh in threshs:
        for ecs_thresh in ecs_threshs:
            args.append(dict(gt=gt, pred=pred, thresh=thresh,
                             ecs_thresh=ecs_thresh))
    print("Scanning for best SEGMENTATION parameters")
    pool = multiprocessing.Pool(n_proc)
    try:
        rand_indices = list(pool.map(billig_seg_proc, args))
    finally:
        # BUGFIX: shut the worker pool down; the original leaked it.
        pool.close()
        pool.join()
    rand_indices = np.array(rand_indices)
    min_i = np.argmin(rand_indices)
    a = args[min_i]
    ri, seg = billig_seg(a['gt'], a['pred'], a['thresh'], a['ecs_thresh'])
    for config in args:
        config.pop('pred', None)
        config.pop('gt', None)
    report_str = "%s Evaluation\n" % (save_name)
    report_str += "BEST: RI=%.4f " % rand_indices[min_i] \
                  + str(args[min_i]) + '\n'
    for ri_val, config in zip(rand_indices, args):
        report_str += "RI=%.4f " % ri_val + str(config) + '\n'
    with open("%s-Seg_Params.txt" % (save_name), 'w') as f:
        f.write(report_str)
    utils.h5save(randomise_colours(seg), '%s_seg.h5' % (save_name,), 'seg')
    print("Best RI=%.4f" % (rand_indices[min_i]))
    return rand_indices[min_i], seg