# Source code for (module name missing in this copy)

# -*- coding: utf-8 -*-
# ELEKTRONN2 Toolkit
# Copyright (c) 2015 Marius Killinger
# All rights reserved
from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, super, zip
# TODO: Python 3 compatibility

# Public names exported by ``from <module> import *``.
# NOTE(review): this list literal is truncated in this copy (it has no closing
# bracket), and 'trace_zyx2xyz' / 'trace_to_kzip' are not defined in the
# visible part of the file; restore the complete list from upstream.
__all__ = ['trace_zyx2xyz', 'trace_to_kzip', 'SkeletonMFK',

import os
import sys
from subprocess import check_call
import logging
import getpass
from collections import OrderedDict

import numba
from scipy import interpolate
from scipy import sparse
from scipy.sparse import csgraph
import numpy as np

from .. import utils

from ..config import config
from . import transformations

# Module logger, plus a separate logger for verbose output that is only
# emitted when ``config.inspection`` is enabled. A generator expression is
# used (not a list comprehension) so no loop variable leaks into the module
# namespace on Python 2.
logger, inspection_logger = (logging.getLogger(log_name) for log_name in
                             ('elektronn2log', 'elektronn2log-inspection'))

if sys.version_info[:2] != (2, 7):
    raise ImportError(
        '\nSorry, this module only supports Python 2.7.'
        '\nYour current Python version is {}\n'.format(sys.version)
    )

# knossos_utils is a hard requirement: log an installation hint, then fail.
try:
    from knossos_utils import skeleton as knossos_skeleton
except ImportError:
    logger.error('\nFor using the tracing_utils module, you will need to'
                 ' install the knossos_utils module'
                 ' (https://github.com/knossos-project/knossos_utils)\n')
    # A bare ``raise`` keeps the original traceback (``raise e`` resets it on
    # Python 2).
    raise

user_name = getpass.getuser()

with open(os.devnull, 'w') as devnull:
    # mayavi is too dumb to raise an exception and instead crashes the whole
    # script when no X server is available, so probe for X before importing.
    try:
        # "xset q" will always succeed to run if an X server is currently running
        check_call(['xset', 'q'], stdout=devnull, stderr=devnull)
        import mayavi.mlab as mlab
        # Don't set backend explicitly, use system default...
    # if "xset q" fails, conclude that X is not running
    except Exception:  # (OSError, ImportError, CalledProcessError, ValueError)
        # Exception rather than a bare except: Ctrl-C / SystemExit must still
        # propagate.
        logger.warning("No mayavi imported, cannot plot skeletons")
        mlab = None


# Constants for scaling of radius: the zoom factor changes in steps of BASE.
BASE = 1.3
BASE_I = BASE ** -1
# Hysteresis exponent for zoom changes: 0.5 means no memory (switch at the
# geometric midpoint between two zoom levels), 1.0 means completely
# non-overlapping switching thresholds.
HYST = 0.75
assert 0.5 <= HYST <= 1.0
# Switching thresholds used by SkeletonMFK.get_scale_factor (which references
# BASE_H, BASE_IH and REF_RADIUS). NOTE(review): these definitions were missing
# from this copy; the values are inferred from HYST's documented meaning and
# from the 20.0 used in get_scale_factor's inspection message. Confirm against
# upstream. Any later module-level definition overrides these at call time.
BASE_H = BASE ** HYST
BASE_IH = BASE_I ** HYST
REF_RADIUS = 20.0

def insert(cube, coords, i, off):
    """Write the value ``i`` into ``cube`` at every coordinate in ``coords``.

    Parameters
    ----------
    cube: np.ndarray
        3D array, modified in place.
    coords: np.ndarray
        (n, 3) integer coordinates. An empty array is a no-op.
    i: scalar
        Value written at every coordinate.
    off: array-like
        Length-3 offset subtracted from every coordinate before indexing,
        i.e. ``cube[0, 0, 0]`` corresponds to coordinate ``off``.
    """
    coords = np.asarray(coords)
    if coords.size == 0:
        return
    local = coords - np.asarray(off)
    # One vectorised fancy-index assignment instead of a Python-level loop.
    cube[local[:, 0], local[:, 1], local[:, 2]] = i

def insert_vec(cube, coords, vec, off):
    """Write vector ``vec[i]`` into the last axis of ``cube`` at ``coords[i] - off``.

    ``cube`` is indexed as ``cube[x, y, z, j]`` (4D) and modified in place;
    ``vec`` is indexed as ``vec[i, j]``. When a conflict is detected, the cell
    is set to NaN ("in case of doubt, don't train here") and counted.

    Returns
    -------
    double_inserts: int
        Number of (voxel, component) cells where a conflict was detected.

    NOTE(review): this copy is damaged. The ``if`` condition is cut off after
    the first index expression, and the ``else:`` that presumably preceded the
    plain ``= vec[i, j]`` assignment is missing. Presumably the test is
    whether the cell already holds a value; restore from upstream.
    """
    n = len(coords)
    m = len(vec[0])
    double_inserts = 0
    for i in np.arange(n):
        for j in np.arange(m):
            # NOTE(review): truncated condition below (see docstring).
            if abs(cube[coords[i,0]-off[0],
                double_inserts += 1
                cube[coords[i,0]-off[0], coords[i,1]-off[1], coords[i,2]-off[2],j] = np.nan # in case of doubt, don't train here...
                cube[coords[i,0]-off[0], coords[i,1]-off[1], coords[i,2]-off[2],j] = vec[i, j]

    return double_inserts

def ray_cast(max_dists, hull_points, hull_dist, ray_steps, hull_cube, off):
    """Measure, for every hull point, how far a ray travels inside the hull.

    For hull point ``i`` the ray starts at ``hull_points[i] - off`` (a position
    relative to ``off``, the origin of ``hull_cube``) and advances by
    ``s * ray_steps[i]`` per iteration, with step length ``s = 0.9``.
    ``dist`` starts at ``hull_dist[i] + 1e-5`` and grows by ``s`` at the end
    of each loop iteration. The final ``dist`` is written to ``max_dists[i]``
    in place; nothing is returned. The loop is also meant to stop after 200
    iterations (``count>200``).

    NOTE(review): this copy is damaged. The loop-exit statements (presumably
    ``break``) after the two bounds checks, ``if not found:`` and
    ``if count>200:`` are missing. The coordinate operands of the bounds
    checks (presumably x, y, z) are missing. The index expressions of every
    ``hull_cube`` neighbourhood lookup are missing, and the last line of that
    expression ends in a continuation, so at least one more term is missing.
    Restore from the upstream source before use.
    """
    s = np.float32(0.9) # step length
    sh = hull_cube.shape
    for i in np.arange(len(hull_points)): # take hull point
        # initialise dist and position
        dist = hull_dist[i] + 1e-5
        x = hull_points[i, 0] - off[0]
        y = hull_points[i, 1] - off[1]
        z = hull_points[i, 2] - off[2]
        found = False
        count = 0
        while True:
            count += 1
            x = x + s * ray_steps[i,0]
            y = y + s * ray_steps[i,1]
            z = z + s * ray_steps[i,2]
            # NOTE(review): bounds checks below lost their operands and bodies.
            if<0.0 or<0.0 or<0.0:
            if>=sh[0] or>=sh[1] or>=sh[2]:
            # search if hull is True in neighbourhood of x,y,z
            # NOTE(review): all neighbourhood indices were lost (see docstring).
            found = hull_cube[,,] or \
                    hull_cube[,,] or \
                    hull_cube[,,] or \
                    hull_cube[,,] or \
                    hull_cube[,,] or \
                    hull_cube[,,] or \
                    hull_cube[,,] or \
            if not found:
            if count>200:

            dist = dist + s

        max_dists[i] = dist

@numba.jit(nopython=True, cache=True)
def find_peaks_helper(padded_cube, peak_cube):
    """Mark the local maxima of a padded 3D cube.

    A voxel is a peak if its value is ``>=`` every value in its full
    26-neighbourhood. For a peak at padded position ``(z, x, y)`` its value is
    written to ``peak_cube[z-1, x-1, y-1]``; all other entries of
    ``peak_cube`` are left untouched.

    The previous hand-written comparison list contained two duplicates and
    never checked the neighbours (z-1, x+1, y) and (z+1, x, y-1). Looping over
    the 3x3x3 block covers every neighbour.

    Parameters
    ----------
    padded_cube: np.ndarray
        3D array padded by one voxel on every side (see ``find_peaks``).
    peak_cube: np.ndarray
        Output array with the unpadded shape, modified in place.
    """
    sh = padded_cube.shape
    # np.arange rather than range: with ``from builtins import range`` on
    # Python 2, ``range`` is future's newrange, which numba cannot compile.
    for z in np.arange(1, sh[0] - 1):
        for x in np.arange(1, sh[1] - 1):
            for y in np.arange(1, sh[2] - 1):
                center = padded_cube[z, x, y]
                is_peak = True
                # Compare against the whole 3x3x3 block. The centre passes its
                # own test unless it is NaN. ``not (a >= b)`` keeps the
                # original NaN semantics: any NaN involved prevents a peak.
                for dz in np.arange(-1, 2):
                    for dx in np.arange(-1, 2):
                        for dy in np.arange(-1, 2):
                            if not center >= padded_cube[z + dz, x + dx, y + dy]:
                                is_peak = False
                                break
                        if not is_peak:
                            break
                    if not is_peak:
                        break
                if is_peak:
                    peak_cube[z - 1, x - 1, y - 1] = center

def find_peaks(cube):
    """Locate local maxima of a 3D array.

    ``cube`` is zero-padded by one voxel, so border voxels are compared with
    0 outside the array. Peaks whose value is exactly 0 are not reported,
    because the helper's output is zero there as well.

    Returns
    -------
    indices: np.ndarray
        Flat indices into ``cube`` of the peaks, ordered by ascending value.
    maxima: np.ndarray
        The corresponding peak values, ascending.
    """
    peak_map = np.zeros_like(cube)
    find_peaks_helper(np.pad(cube, 1, mode='constant'), peak_map)
    peak_indices = np.flatnonzero(peak_map)
    peak_values = cube.ravel()[peak_indices]
    order = np.argsort(peak_values)
    return peak_indices[order], peak_values[order]

# WARNING / NOTE: skeleton objects are in xyz-order
class SkeletonMFK(object):
    """
    Joints: all branches and end points / node terminations (nodes not of deg 2)
    Branches: Joints of degree >= 3
    """
[docs] @staticmethod def find_joints(node_list): joints = {} branches = {} for node in node_list: if > 2: # branching point joints[node.ID] = node branches[node.ID] = node if joints[node.ID] = node # end point return joints, branches
def __init__(self, aniso_scale=2, name=None, skel_num=None): self.aniso_scale = np.array([[1,1,aniso_scale]], dtype=np.float32) self.bones = dict() self.edges = list() self.branches = dict() self.joints = dict() self.all_nodes = None self.hull_points = None self.hull_skel = dict() self.hull_branch = dict() = name self.skel_num = skel_num self.radii = dict() self.all_radii = None self.joint_radii = None self.props = dict() self.all_props = None self.joint_props = None self.joint_id2joint_index = dict() # For training self.kdt_hull = None self.linked_data = None self.lost_track = False self.position_s = None self.position_l = None self.direction_il = None self.start_new_training = True self.prev_batch = None self.trafo = None self.prev_scale = 1.0 self.prev_gamma = 0.0 self.training_traces = [] self.background_processes = False self._hull_point_bg = dict() self.cnn_grid = None # Old for training self.debug_traces = [] self.debug_traces_current = [] self.debug_grads = [] self.debug_grads_current = []
# SkeletonMFK.init_from_annotation: read a skeleton annotation (getNodes,
# getNeighbors, getCoordinate, getDataElem), split it into bones between
# joints (see find_joints), and convert bones, radii and optional
# axoness/spiness props to float32/int16 arrays. If interpolation_resolution
# is not None, each bone is resampled with interpolate_bone/interpolate_prop.
# all_nodes / all_radii / all_props are stacked with the joints first.
# min_radius (if truthy) is a lower bound for all_radii.
# NOTE(review): this method is collapsed onto two physical lines, and the
# bone-termination test ``if > 2 or`` lost its operands (presumably a
# neighbour count of current_node). Restore from upstream.
# NOTE(review): only the first node ``d`` of each bone is marked visited, so
# it looks like a bone with 2+ interior nodes is walked again from its other
# joint, appending the same edge to self.edges twice. Confirm.
# NOTE(review): in the fallback (except) path ``bone`` is de-duplicated with
# keep_index, but self.props[edge] is not re-indexed, so props may be
# misaligned. The bare ``except:`` clauses also hide unrelated errors.
[docs] def init_from_annotation(self, skeleton_annotatation, min_radius=None, interpolation_resolution=0.5, interpolation_order=1): # Read annotation data structures and convert to dicts and np.ndarrays #print(len(skeleton_annotatation.getNodes())) self.joints, self.branches = self.find_joints(skeleton_annotatation.getNodes()) #print(len(self.joints), len(self.branches)) visited = {n: False for n in skeleton_annotatation.getNodes()} for joint_id, joint in self.joints.items(): directions = joint.getNeighbors() for d in directions: if visited[d]: # we have visited this bone already continue visited[d] = True # mark as visited bone = OrderedDict() # create new bone bone[joint] = True # start the bone at the joint current_node = d # next go to the node in the selected direction while True: bone[current_node] = True if > 2 or # At new branch or end point, the bone ends here # add edge between starting joint and this branch if joint_id < current_node.ID: edge = (joint_id, current_node.ID) else: edge = (current_node.ID, joint_id) self.edges.append(edge) break else: # The node has 2 neighbours, one from which we come and # another one to which we go nb = list(current_node.getNeighbors()) assert len(nb) == 2 # Test which node we visit next if nb[0] in bone: assert nb[1] not in bone current_node = nb[1] if nb[1] in bone: assert nb[0] not in bone current_node = nb[0] self.bones[edge] = list(bone.keys()) # Convert bones to arrays for edge, bone in self.bones.items(): self.bones[edge] = np.array([x.getCoordinate() for x in bone], dtype=np.float32) self.radii[edge] = np.array([x.getDataElem('radius') for x in bone], dtype=np.float32) try: axoness_pred = np.array([x.getDataElem('axoness_pred') for x in bone], dtype=np.int16) spiness_pred = np.array([x.getDataElem('spiness_pred') for x in bone], dtype=np.int16) props = np.concatenate([axoness_pred[:,None], spiness_pred[:,None]], axis=1) self.props[edge] = props except KeyError: pass # convert joints to arrays self.joint_radii = 
np.array([x.getDataElem('radius') for x in self.joints.values()], dtype=np.float32) try: axoness_pred = np.array([x.getDataElem('axoness_pred') for x in self.joints.values()], dtype=np.int16) spiness_pred = np.array([x.getDataElem('spiness_pred') for x in self.joints.values()], dtype=np.int16) self.joint_props = np.concatenate([axoness_pred[:,None], spiness_pred[:,None]], axis=1) except KeyError: pass self.joint_id2joint_index = dict(zip(self.joints.keys(), range(len(self.joints)))) self.joints = np.array([x.getCoordinate() for x in self.joints.values()], dtype=np.float32) # convert branches to arrays self.branches = np.array([x.getCoordinate() for x in self.branches.values()], dtype=np.float32) if interpolation_resolution is not None: for edge, bone in self.bones.items(): if len(bone)<=1: continue try: new_bone = self.interpolate_bone(bone,max_k=interpolation_order, resolution=interpolation_resolution) self.radii[edge] = self.interpolate_prop(bone, self.radii[edge], new_bone) except: bone, keep_index = utils.unique_rows(bone) new_bone = self.interpolate_bone(bone,max_k=interpolation_order, resolution=interpolation_resolution) self.radii[edge] = self.interpolate_prop(bone, self.radii[edge][keep_index], new_bone) try: self.props[edge] = self.interpolate_prop(bone, self.props[edge],new_bone, discrete=True) except: pass self.bones[edge] = new_bone self.all_nodes = np.vstack([self.joints,] + list(self.bones.values())) self.all_radii = np.hstack([self.joint_radii,] + list(self.radii.values())) try: self.all_props = np.vstack([self.joint_props, ] + list(self.props.values())) except: pass if min_radius: self.all_radii = np.maximum(self.all_radii, min_radius)
[docs] def save(self, fname): utils.picklesave(self, fname)
[docs] def interpolate_bone(self, bone, max_k=1, resolution=0.5): bone_iso = bone * self.aniso_scale linear_distances = np.linalg.norm(np.diff(bone_iso, axis=0), axis=1) total_dist = linear_distances.sum() k = min(max_k, bone_iso.shape[0]-1) tck, u = interpolate.splprep(bone_iso.T, k=k) n = max(2, int(float(total_dist) / resolution)) new = interpolate.splev(np.linspace(0,1,n), tck) new = np.array(new).T / self.aniso_scale return new
# SkeletonMFK.interpolate_prop: transfer per-node properties ``old_prop`` of
# ``old_bone`` onto the resampled ``new_bone``. The method walks along both
# polylines in the aniso-scaled metric and keeps a bracketing pair
# (start_i, stop_i) of old nodes. Continuous props (float32) are blended
# linearly by distance. Discrete props (``discrete=True``, int16) take the
# value of the nearer bracketing node.
# NOTE(review): collapsed onto one physical line; restore from upstream.
# Requires len(old_bone) >= 2 (stop_i starts at 1).
[docs] def interpolate_prop(self, old_bone, old_prop, new_bone, discrete=False): dtype = np.int16 if discrete else np.float32 new_prop = np.zeros((len(new_bone),)+old_prop.shape[1:], dtype=dtype) old_bone_iso = old_bone * self.aniso_scale new_bone_iso = new_bone * self.aniso_scale start_i = 0 stop_i = 1 min_dist = np.linalg.norm(new_bone_iso[0] - old_bone_iso[stop_i]) for i in range(len(new_bone)): dist_start = np.linalg.norm(new_bone_iso[i] - old_bone_iso[start_i]) dist_stop = np.linalg.norm(new_bone_iso[i] - old_bone_iso[stop_i]) min_dist = min(min_dist, dist_stop) if (min_dist < dist_stop) and stop_i+1<len(old_bone): stop_i += 1 start_i += 1 dist_start = dist_stop dist_stop = np.linalg.norm(new_bone_iso[i] - old_bone_iso[stop_i]) min_dist = dist_stop if discrete: if dist_stop > dist_start: new_prop[i] = old_prop[start_i] else: new_prop[i] = old_prop[stop_i] else: d = dist_start + dist_stop new_prop[i] = dist_stop/d * old_prop[start_i] + dist_start/d * old_prop[stop_i] return new_prop
[docs] @utils.cache() def get_kdtree(self, static_points, k=1, jobs=-1): kdt = utils.KDT(n_neighbors=k, n_jobs=jobs, algorithm='kd_tree', leaf_size=20) * self.aniso_scale) # change metric) #assert np.all(kdt._fit_X / self.aniso_scale == static_points) return kdt
[docs] @utils.cache() def get_knn(self, kdt, query_points, k=None): if k is not None: pass #assert k==kdt.n_neighbors else: k = kdt.n_neighbors distances, indices = kdt.kneighbors(query_points * self.aniso_scale, n_neighbors=k) # change metric) static_points = kdt._fit_X.astype(np.float32) # Attention those still have the aniso scale in [:,2] if k==1: indices = indices[:,0] distances = distances[:,0].astype(np.float32) coordinates = static_points[indices] / self.aniso_scale # change to pixel coordinates else: distances = distances.astype(np.float32) coordinates = static_points[indices] / self.aniso_scale # change to pixel coordinates assert coordinates.shape[1] == k return distances, indices, coordinates
[docs] def get_closest_node(self, position_s): kdt = self.get_kdtree(self.all_nodes, k=1, jobs=1) dist, ind, nearest_s = self.get_knn(kdt, position_s) if position_s.ndim==1: dist = dist[0] ind = ind[0] nearest_s = nearest_s[0] return dist.astype(np.float32), ind, nearest_s
### Sampling routines for getting training data ###
[docs] def sample_skel_point(self, rng, joint_ratio=None): n = len(self.all_nodes) if joint_ratio: if rng.rand() < joint_ratio: n = len(self.joints) i = rng.randint(n) node = self.all_nodes[i] return node, i
# SkeletonMFK.sample_tube_point: see docstring. If hull points are mapped,
# proposals whose nearest hull point is >= 1.5 away are rejected. After
# max_count / 2 attempts the sampling radius is halved once; after max_count
# attempts the skeleton node itself is returned. Raises RuntimeError if hull
# points exist but kdt_hull has not been initialised.
# NOTE(review): collapsed onto one physical line; restore from upstream.
[docs] def sample_tube_point(self, rng, r_max_scale=0.9, joint_ratio=None): """ This is skeleton node based sampling: Go to a random node, sample a random orthogonal direction go a random distance into direction (uniform over the [0, r_max_scale * local maximal radius]) """ # tt = utils.Timer() if self.hull_points is None: kdt = None else: if self.kdt_hull is None: raise RuntimeError("Hull kdts must be pre initialised") kdt = self.kdt_hull node, node_i = self.sample_skel_point(rng, joint_ratio) direc_iso = self.sample_local_direction_iso(node) local_r = self.all_radii[node_i] * r_max_scale count = 0 max_count = 30 proposal = node clipped = False while True: r = rng.rand() * local_r phi = rng.rand() * 2 * np.pi cos_theta = rng.rand() * 2 - 1 sin_theta = np.sqrt(1 - cos_theta ** 2) x = np.cos(phi) * sin_theta y = np.sin(phi) * sin_theta z = cos_theta rand_vec = np.array([x, y, z]) orthogonal_vec_iso = np.cross(direc_iso, rand_vec) orthogonal_vec_iso /= np.linalg.norm(orthogonal_vec_iso) orthogonal_vec = orthogonal_vec_iso / self.aniso_scale[0] proposal = node + orthogonal_vec * r if kdt is None: return proposal dist, ind, coord = self.get_knn(kdt, proposal) dist = dist[0] if dist < 1.5: # we are within hull: break if count >= max_count / 2 and not clipped: local_r *= 0.5 clipped = True logger.debug("Sample hull point: clipped r") if count >= max_count: logger.debug( "Sample hull point: max count %i reached" % max_count) proposal = node break # tt.check("\tdouble_check") count += 1 return proposal
[docs] def sample_local_direction_iso(self, point, n_neighbors=6): """ For a point gives the local skeleton direction/orientation by fitting a line through the nearest neighbours, sign is randomly assigned """ kdt = self.get_kdtree(self.all_nodes, k=n_neighbors, jobs=1) dist, ind, coord = self.get_knn(kdt, point) dist = dist[0] ind = ind[0] coord = coord[0] # maybe use dist as weights for svd? neibs_iso = coord * self.aniso_scale # transform to iso space uu, dd, vv = np.linalg.svd(neibs_iso - neibs_iso.mean(axis=0)) direc_iso = vv[0] # take largest eigenvector direc_iso /= np.linalg.norm(direc_iso, axis=0) # normalise return direc_iso
[docs] def sample_tracing_direction_iso(self, rng, local_direction_iso, c=0.5): """ Sample a direction close to the local direction there is a prior so that the normalised (0,1) angle of deviation a has this distribution: p(a) = 1/N * (1-c*a), where N= 1 - c/2, tmp is the inverse cdf of this. """ if rng.rand() > 0.5: # the sign is undefined, choose randomly local_direction_iso *= -1 u = rng.rand() tmp = (1 - np.sqrt(1-(2*c - c**2)*u)) / c # theta scaled between 0 and 1 # theta scaled between 0 and 90 deg in rad i.e. 0 and pi/2 theta = tmp * 0.5 * np.pi max_count = 1000 count = 0 proposal = local_direction_iso while True: proposal = rng.rand(3) * 2 - 1 proposal /= np.linalg.norm(proposal, axis=0) # normalise cos_alpha =, local_direction_iso) if cos_alpha < 0: # flip to next best within +/- 90 deg cos_alpha *= -1 proposal *= -1 alpha = np.arccos(cos_alpha) if alpha < theta + 0.01: break count += 1 if count>max_count: logger.debug("Sample tracing directions: max count reached") break return proposal
### Loss and loss gradient for Theano Graph ###
# SkeletonMFK.get_loss_and_gradient: zoned loss/gradient of a new position
# with respect to the inner hull (see docstring). It also updates
# self.lost_track and appends the position and gradient to
# debug_traces_current / debug_grads_current. Returns (loss as float32 array
# of shape (1,), grad_s).
# NOTE(review): collapsed onto one physical line, and the call after
# ``if config.inspection:`` lost its callee (presumably a logger method such
# as inspection_logger.info). Restore from upstream. Also relies on
# get_hull_points_inner, which must be defined elsewhere in this class.
[docs] def get_loss_and_gradient(self, new_position_s, cutoff_inner=1.0/3, rise_factor=0.1): """ prediction_c (zxy) Zoned error surface: flat in inner hull (selected at cutoff_inner) constant gradient in "outer" hull towards nearest inner hull voxel gradient increasing with distance (scaled by rise_factor) for predictions outside hull """ inner_hull, indices = self.get_hull_points_inner(cutoff_inner, return_indices=True) kdt = self.get_kdtree(inner_hull, k=1, jobs=1) dist, ind, nearest_s = self.get_knn(kdt, new_position_s) dist = dist[0] ind = ind[0] nearest_s = nearest_s[0] if config.inspection:"nearest_s: %s"% (nearest_s.tolist())) if dist<1.5: # we are within inner hull. The maximal distance if # within hull is exactly: np.linalg.norm(np.multiply( # [0.5, 0.5, 0.6], [1, 1, 2])) = 1.22... --> add some margin loss = 0.0 grad_s = np.zeros((3,), dtype=np.float32) self.lost_track = False else: loss = dist # max dist of closest node max_dist = self.hull_skel['max_dist'][indices[ind]] # pointing from nearest to new position unit_grad = (new_position_s - nearest_s) unit_grad /= np.linalg.norm(unit_grad * self.aniso_scale[0], axis=0) if max_dist > dist: # we are in hull but not in inner tube grad_s = unit_grad * 1.0 self.lost_track = False else: # we are outside hull self.lost_track = True factor = rise_factor * (dist - max_dist) grad_s = unit_grad * (1 + factor) self.debug_traces_current.append(new_position_s) self.debug_grads_current.append(grad_s) loss = np.array([loss,], dtype=np.float32) return loss, grad_s
# SkeletonMFK._new_training_trace: archive the current trace (when more than
# 20 are stored, only the last 2 are kept before appending), start a new
# Trace, sample a start position in the tube and a tracing direction, and
# reset position/direction state. position_l / direction_il are the
# [::-1]-reversed xyz vectors.
# Reads get_batch_kwargs['r_max_scale'], ['tracing_dir_prior_c'] and
# optionally ['joint_ratio'].
# NOTE(review): collapsed onto one physical line; the call after
# ``if config.inspection:`` lost its callee. self.current_trace is read
# before being assigned, so it must already exist on the instance (e.g.
# initialised to None in __init__). Trace must be defined elsewhere.
def _new_training_trace(self, **get_batch_kwargs): """ Prepare skeleton for a new training (sample location/direction, reset stuff) Parameters ---------- get_batch_kwargs """ #tt = utils.Timer() if self.current_trace: if len(self.training_traces)>20: self.training_traces = self.training_traces[-2:] self.training_traces.append(self.current_trace) self.current_trace = Trace(linked_skel=self) r_max_scale = get_batch_kwargs['r_max_scale'] tracing_dir_prior_c = get_batch_kwargs['tracing_dir_prior_c'] joint_ratio = get_batch_kwargs.get('joint_ratio', None) position_s = self.sample_tube_point(self.linked_data.rng, r_max_scale=r_max_scale, joint_ratio=joint_ratio) if config.inspection:"Start new training") local_direc_is = self.sample_local_direction_iso(position_s, n_neighbors=6) tracing_direc_is = self.sample_tracing_direction_iso(self.linked_data.rng, local_direc_is, c=tracing_dir_prior_c) self.position_s = position_s self.position_l = position_s[::-1] # from lab2data (xyz)->(zxy) self.direction_il = tracing_direc_is[::-1] # from lab2data (xyz)->(zxy) self.current_trace.append(position_s, coord_cnn=[0,]*3, grad=[0,]*3, features=[0,]*7) self.lost_track = False self.trafo = None #tt.check("final")
# SkeletonMFK.get_scale_factor: choose the next zoom factor from a predicted
# radius, with hysteresis. The target factor REF_RADIUS / (radius /
# old_factor) is clipped to [BASE_I**(4*s) - 1e-3, BASE**(2*s) + 1e-3]
# (s = scale_strenght). The returned factor is old_factor multiplied by
# BASE, BASE_I or 1, depending on how the ratio target/old compares with
# BASE_H / BASE_IH.
# NOTE(review): collapsed onto one physical line; the call after
# ``if config.inspection:`` lost its callee. REF_RADIUS, BASE_H and BASE_IH
# are module-level names that must be defined elsewhere in this module.
[docs] @staticmethod def get_scale_factor(radius, old_factor, scale_strenght): """ Parameters ---------- radius: predicted radius (not the true radius) old_factor: factor by which the radius prediction and the image was scaled scale_strenght: limits the maximal scale factor Returns ------- new_factor """ # if old was large (zoom in), radius is smaller hi = BASE ** (scale_strenght * 2) + 1e-3 # e.g 1.69 for 1.3**2 lo = BASE_I ** (scale_strenght * 4) - 1e-3 # e.g. 0.35 for 1/1.3 ** 4 radius_true = radius / old_factor new_factor = REF_RADIUS / radius_true new_factor = np.clip(new_factor, lo, hi) change = new_factor / old_factor if new_factor > 1.0: # left side if change >= BASE_H: # growing new_factor = old_factor * BASE elif change < BASE_IH: new_factor = old_factor * BASE_I else: new_factor = old_factor elif new_factor < 1.0: # right side if change <= BASE_IH: # zoom out new_factor = old_factor * BASE_I elif change > BASE_H: # zoom in new_factor = old_factor * BASE else: new_factor = old_factor else: new_factor = old_factor if config.inspection:"SCALE: %.2f -> %.2f, factor0: %.2f, factor: %.2f" % (radius, radius_true, 20.0 / radius_true, new_factor)) return new_factor
@staticmethod @utils.cache def make_grid(t_grid_sh, z_shift): """ Parameters ---------- t_grid_sh: tagged shape (pixel shape + strides) z_shift: shift of center (positive means more look ahead) Returns ------- points: coordinate list zyx order zz,yy,xx: coordinate meshgrid """ sh = np.array(t_grid_sh.spatial_shape) st = np.array(t_grid_sh.strides) lim = (sh-1) * st + 1 lim //= 2 zz,yy,xx = np.mgrid[-lim[0]:lim[0]:1j * sh[0], -lim[1]:lim[1]:1j * sh[1], -lim[2]:lim[2]:1j * sh[2]] zz += z_shift points = np.hstack([zz.ravel()[:,None], yy.ravel()[:,None], xx.ravel()[:,None]]).astype(np.float32) return points, zz,yy,xx
[docs] @staticmethod def point_potential(r, margin_scale, size, repulsion=None): if repulsion is None: repulsion = 1.0 left = margin_scale * size x = (r - left)/(size - left) v = 1.0 - (x**3*(x*(x*6 - 15) + 10)) # soft step function v = np.minimum(np.maximum(v, 0.0), 1.0) return v * repulsion
# SkeletonMFK.getbatch: advance the training trace from the previous
# prediction (or start a new trace / reuse the previous scale when the
# prediction is all zeros), choose the zoom via get_scale_factor, and cut an
# example with self.linked_data.get_newslice. Returns
# (img, target_img, target_grid, target_node). target_node[0] is the scaled
# radius of the nearest node, and one-hot classes are set at
# 1 + classes[0] and 4 + classes[1] (from all_props). WarpingOOBError from
# the data source is re-raised as WarpingOOBError("Batch OOB").
# NOTE(review): collapsed onto two physical lines; several calls after
# ``if config.inspection:`` lost their callee. The grid branch raises
# RuntimeError unconditionally, so the grid-target code after it appears to
# be unreachable (its message also lacks a space between "must" and "be").
# all_props must not be None here.
[docs] def getbatch(self, prediction, scale_strenght, **get_batch_kwargs): """ Parameters ---------- prediction: [[new_position_c, radius, ]] scale_strenght: limits the maximal scale factor for zoom get_batch_kwargs Returns ------- batch: img, target_img, target_grid, target_node """ get_batch_kwargs = dict(get_batch_kwargs) # copy because we destroy it if self.start_new_training: self._new_training_trace(**get_batch_kwargs) self.start_new_training = False scale = 1.0 self.prev_scale = 1.0 self.prev_gamma = np.random.rand() * 2 * np.pi elif np.allclose(prediction, 0): scale = self.prev_scale if config.inspection: inspection_logger.warning("getbatch with no feedback: either " "training on same skel or error") else: prediction = prediction[0] new_position_c = prediction[:3] radius = prediction[3] # this is just the predicted val, not the true new_position_l, tracing_direc_il = self.trafo.cnn_pred2lab_position(new_position_c) new_position_s = new_position_l[::-1] self.position_s = new_position_s self.position_l = new_position_s[::-1] # from lab2data (xyz)->(zxy) self.direction_il = tracing_direc_il scale = self.get_scale_factor(radius, self.prev_scale, scale_strenght) self.prev_scale = scale grid = get_batch_kwargs.pop('grid', False) t_grid_sh = get_batch_kwargs.pop('t_grid_sh', None) z_shift = get_batch_kwargs['z_shift'] get_batch_kwargs.pop('joint_ratio', None) try: if config.inspection:"Getslice from position_l %s in " "direction_il %s, SCALE %.2f"%(np.array_str( self.position_l, precision=1, suppress_small=True), self.direction_il, scale)) get_batch_kwargs['gamma'] = self.prev_gamma data_batch = self.linked_data.get_newslice(self.position_l, self.direction_il, scale=scale, **get_batch_kwargs) img, target_img, trafo = data_batch[:3] if grid: raise RuntimeError("The creation of the grid target must" "be testet for spatial coherence again") if not self.cnn_grid: self.cnn_grid = self.make_grid(t_grid_sh, z_shift) grid_coords_c, zz, yy, xx = self.cnn_grid #dir_point_s 
= self.position_s + self.direction_il[::-1]/self.aniso_scale[0] #dir_momentum_s = dir_point_s - self.current_trace.coords[-4:].mean(0) #dir_momentum_ci = trafo.lab_coord2cnn_coord(dir_momentum_s[::-1])*[2,1,1] #directions_ci = grid_coords_c*[2,1,1] #direction_difference = cdist(directions_ci, dir_momentum_ci[None], 'cosine') # 0..2, 45deg thresh: > 1.7 #direction_difference = (direction_difference[:,0] - 1.7).astype(np.float32) #direction_difference[direction_difference<0.0] = 0.0 #direction_difference[np.isnan(direction_difference)] = 0.0 # center if even is NULL #repulsion = 1.0 - direction_difference * 2 # * strength, without factor ~ -25% ### TODO might also make repulsion depending on skel_node instead of grid_position. No WHY? ### TODO repulsion is not smooth enough repulsion = 1.0 grid_coords = trafo.cnn_coord2lab_coord(grid_coords_c,add_offset_l=True) dist, ind, nearest_s = self.get_closest_node(grid_coords[:,::-1]) radii = self.all_radii[ind] target_grid = self.point_potential(dist, 0.1, radii, repulsion) target_grid = target_grid.reshape(zz.shape)[None] # add channel if np.allclose(target_grid, 0.0): logger.warning("WTF") self.debug_store = [img, target_grid] self.debug_store2 = [nearest_s, radii] else: target_grid = np.ones((1,1,1,1), dtype=np.float32) # Get bio labels/classes dist, ind, nearest_s = self.get_closest_node(self.position_s) classes = self.all_props[ind] target_node = np.zeros(7, dtype=np.float32) target_node[classes[0]+1] = 1 target_node[classes[1]+4] = 1 target_node[0] = self.all_radii[ind] * scale if config.inspection:"target_node %s, (true r: %.1f)" %(target_node, self.all_radii[ind])) batch = (img, target_img, target_grid, target_node) self.trafo = trafo return batch except transformations.WarpingOOBError: if config.inspection:"OOB in getbatch") raise transformations.WarpingOOBError("Batch OOB")
# SkeletonMFK.step_feedback: loss and gradient of a new position with
# respect to the inner hull (same zoning as get_loss_and_gradient). The step
# is appended to current_trace and position_s / position_l / direction_il
# are updated. Returns (loss as float32 array of shape (1,), grad_s,
# nearest_s).
# NOTE(review): collapsed onto one physical line; the call after
# ``if config.inspection:`` lost its callee. Unlike get_loss_and_gradient,
# this never resets lost_track to False inside the hull. Confirm that this
# is intended.
[docs] def step_feedback(self, new_position_s, new_direction_is, pred_c, pred_features, cutoff_inner=1.0/3, rise_factor=0.1): inner_hull, indices = self.get_hull_points_inner(cutoff_inner, return_indices=True) kdt = self.get_kdtree(inner_hull, k=1, jobs=1) dist, ind, nearest_s = self.get_knn(kdt, new_position_s) dist = dist[0] ind = ind[0] nearest_s = nearest_s[0] # we are within inner hull. The maximal distance if within hull is 1.2... if dist < 1.5: loss = 0.0 grad_s = np.array([0, 0, 0], dtype=np.float32) else: loss = dist max_dist = self.hull_skel['max_dist'][indices[ind]] # max dist of closest node unit_grad = (new_position_s - nearest_s) # pointing from nearest to new position unit_grad /= np.linalg.norm(unit_grad * self.aniso_scale[0], axis=0) # normalise grad if max_dist > dist: # we are in hull but not in inner tube grad_s = unit_grad * 1.0 else: # we are outside hull self.lost_track = True if config.inspection:"Lost track") factor = rise_factor * (dist - max_dist) grad_s = unit_grad * (1 + factor) self.current_trace.append(new_position_s, coord_cnn=pred_c, grad=grad_s, features=pred_features) # Actually the new positions should be set in getbach, but we need to # set them here to because sometimes getbatch might be called without # "start_new_training" and with only zeros as prediction self.position_s = new_position_s self.position_l = new_position_s[::-1] # from lab to data frame (xyz) -> (zxy) self.direction_il = new_direction_is[::-1] # from lab to data frame (xyz) -> (zxy) loss = np.array([loss,], dtype=np.float32) return loss, grad_s, nearest_s
# SkeletonMFK.step_grid_update: take the highest peak of the predicted grid
# (find_peaks on grid[0, 0], sorted ascending, so the last one) as the next
# CNN-frame position, falling back to [2, 0, 0] if there is no peak. The
# position is converted with self.trafo, appended to current_trace, and
# position_s / position_l / direction_il are updated. With
# config.inspection > 1, a debug pickle is written to
# /tmp/<user>_debug_skel_<skel_num> using self.debug_store.
# NOTE(review): collapsed onto one physical line; the logging call after
# ``if config.inspection:`` lost its callee, and the second message string
# has no call at all.
[docs] def step_grid_update(self, grid, radius, bio): pred_features = np.hstack([radius, bio]) flat_indices, scores = find_peaks(grid[0,0]) grid_coords_c, zz, yy, xx = self.cnn_grid preds_c = grid_coords_c[flat_indices] #preds_l = self.trafo.cnn_coord2lab_coord(preds_c,add_offset_l=True) if len(scores): new_position_c = grid_coords_c[flat_indices[-1]] preds_c = new_position_c[None] else: new_position_c = np.array([2,0,0], dtype=np.float32) new_position_l, tracing_direc_il = self.trafo.cnn_pred2lab_position(new_position_c) new_position_s = new_position_l[::-1] new_direction_is = tracing_direc_il[::-1] if config.inspection:"GridUpdate, node pred %s" % ( np.array_str(pred_features, precision=2, suppress_small=True),)) "GridUpdate, new_position_c: %s, new_position_l: %s" % ( new_position_c, np.array_str(new_position_l, precision=1, suppress_small=True))) if config.inspection>1: img, grid_t = self.debug_store utils.picklesave( [img[0,0], grid_t[0], grid[0,0]], '/tmp/{}_debug_skel_{}'.format(user_name, self.skel_num)) self.current_trace.append(new_position_s, coord_cnn=new_position_c, features=pred_features) # Actually the new positions should be set in getbach, but we need to # set them here to because sometimes getbatch might be called without # "start_new_training" and with only zeros as prediction self.position_s = new_position_s self.position_l = new_position_s[::-1] # from lab to data frame (xyz) -> (zxy) self.direction_il = new_direction_is[::-1] # from lab to data frame (xyz) -> (zxy) return new_position_c[None], preds_c, scores
### Plotting ###
[docs] def plot_skel(self, fig=None): if fig is None: fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400)) x = self.all_nodes[:,0] y = self.all_nodes[:,1] z = self.all_nodes[:,2]*self.aniso_scale[0,2] mlab.points3d(x,y,z, scale_factor=0.8, color=(1,0,0), figure=fig) for bone in self.bones.values(): x = bone[:,0] y = bone[:,1] z = bone[:,2]*self.aniso_scale[0,2] mlab.plot3d(x,y,z,tube_radius=0.4, color=(0.3,0.3,0.3), figure=fig) self._plot_joints(fig=fig) return fig
[docs] def plot_debug_traces(self, grads=True, fig=None): if fig is None: fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400)) traces = np.array(self.debug_traces) for trace in traces: x = trace[:, 0] y = trace[:, 1] z = trace[:, 2] * self.aniso_scale[0, 2] mlab.plot3d(x, y, z, tube_radius=0.2, color=(0.3, 0.3, 0.3), figure=fig) if grads: grads = np.array(self.debug_grads) for grad, trace in zip(grads, traces): x = trace[:, 0] y = trace[:, 1] z = trace[:, 2] * self.aniso_scale[0, 2] gx = -grad[:, 0] gy = -grad[:, 1] gz = -grad[:, 2] * self.aniso_scale[0, 2] mlab.quiver3d(x,y,z, gx, gy, gz, figure=fig, color=(0,0.6,0.2), scale_factor=3) return fig
[docs] def plot_radii(self, fig=None): if fig is None: fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600,400)) x = self.all_nodes[:,0] y = self.all_nodes[:,1] z = self.all_nodes[:,2]*self.aniso_scale[0,2] r = self.all_radii mlab.points3d(x,y,z,r, scale_mode='scalar', scale_factor=1, color=(0,0.5,0.5), mode='sphere', opacity=0.1, figure=fig) return fig
def _plot_joints(self, fig=None):
    """
    Plot ``self.joints`` as large yellow points.

    :param fig: mayavi figure to draw into; a new one is created if None
    :return: the figure
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
    joints = self.joints
    mlab.points3d(joints[:, 0], joints[:, 1], joints[:, 2] * self.aniso_scale[0, 2],
                  scale_factor=3, color=(1, 1, 0), figure=fig)
    return fig

### Hull methods ###
def calc_max_dist_to_skels(self):
    """
    Cast a ray from every hull point away from its closest skeleton node.

    The hull is rasterised into a dense boolean cube and ``ray_cast`` fills
    ``max_dist`` for each hull point. Points whose ray step is not finite
    get ``max_dist = 1.0``.

    :return: ``(max_dist, rel_dist)`` with ``rel_dist = dist / max_dist``,
        where ``dist`` is ``self.hull_skel['dist']``
    """
    hull = self.hull_points              # (n, 3)
    direc = self.hull_skel['direc']      # (n, 3)
    dist = self.hull_skel['dist']        # (n,) true distances

    max_dist = np.zeros(len(hull), dtype=np.float32)
    # This ray has unit magnitude in the true (anisotropy-corrected) metric
    norms = np.linalg.norm(direc * self.aniso_scale, axis=1)
    ray_steps = -direc / (norms[:, None] + 1e-5)

    # Create a dense cube just large enough for the hull and insert it
    off = np.min(hull, 0)
    sh = np.max(hull, 0) + 1 - off
    hull_cube = np.zeros(sh, dtype=bool)  # np.bool was removed in NumPy 1.24
    insert(hull_cube, hull, True, off)

    # Cast rays through the dense cube (fills max_dist in place)
    ray_cast(max_dist, hull, dist, ray_steps, hull_cube, off)

    # NOTE(review): the original comment says direc has magnitude 0 here;
    # with the +1e-5 guard a zero direc gives a finite zero step, so this
    # only catches non-finite direc values — confirm intent.
    max_dist[np.any(~np.isfinite(ray_steps), axis=1)] = 1.0
    rel_dist = dist / max_dist
    return max_dist, rel_dist
def map_hull(self, hull_points):
    """
    Map hull points to their nearest skeleton node and nearest branch point.

    Distances already take the anisotropy in z into account (i.e. they are
    true distances), but all coordinates for hulls and vectors are still
    pixel coordinates.

    Fills ``self.hull_skel`` (``dist``, ``ind``, ``direc``, ``max_dist``,
    ``rel_dist``) and ``self.hull_branch`` (``dist``, ``ind``, ``direc``),
    and stores a kd-tree over the hull in ``self.kdt_hull``.

    :param hull_points: (n, 3) hull coordinates
    :raises ValueError: if any skeleton or branch distance is not finite
    """
    self.hull_points = hull_points.astype(np.int16)
    hull_points = hull_points.astype(np.float32)

    kdt_skel = self.get_kdtree(self.all_nodes)
    dist_skel, ind_skel, coord_skel = self.get_knn(kdt_skel, hull_points)
    self.hull_skel['dist'] = dist_skel
    self.hull_skel['ind'] = ind_skel
    self.hull_skel['direc'] = coord_skel - hull_points

    max_dist, rel_dist = self.calc_max_dist_to_skels()
    self.hull_skel['max_dist'] = max_dist
    self.hull_skel['rel_dist'] = rel_dist

    n = len(hull_points)
    if len(self.branches):
        kdt_branch = self.get_kdtree(self.branches)
        dist_branch, ind_branch, coord_branch = self.get_knn(kdt_branch, hull_points)
        self.hull_branch['dist'] = dist_branch
        self.hull_branch['ind'] = ind_branch
        self.hull_branch['direc'] = coord_branch - hull_points
    else:
        # No branch points: zero distances/vectors so downstream masks work
        self.hull_branch['dist'] = np.zeros(n, dtype=np.float32)
        self.hull_branch['ind'] = None
        self.hull_branch['direc'] = np.zeros((n, 3), dtype=np.float32)

    if not (np.all(np.isfinite(dist_skel)) and
            np.all(np.isfinite(self.hull_branch['dist']))):
        raise ValueError("InfiniteValue")

    self.kdt_hull = self.get_kdtree(self.hull_points, k=1, jobs=1)  # store for later use
@utils.cache()
def get_hull_points_inner(self, cutoff=1.0/3, return_indices=False):
    """
    Hull points whose relative skeleton distance is below ``cutoff``.

    :param cutoff: threshold on ``self.hull_skel['rel_dist']``
    :param return_indices: also return the indices of the selected points
    :return: selected points, or ``(points, indices)``
    """
    mask = self.hull_skel['rel_dist'] < cutoff
    inner = self.hull_points[mask]
    if return_indices:
        return inner, mask.nonzero()[0]
    return inner
@utils.cache()
def get_hull_branch_direc_cutoff(self, cutoff=25, normalise=False):
    """
    Direction vectors to the nearest branch point, zeroed beyond ``cutoff``.

    :param cutoff: vectors of hull points with branch distance >= cutoff
        are set to zero
    :param normalise: divide the vectors by the branch distance (+1e-5)
    :return: (n, 3) array
    """
    mask = self.hull_branch['dist'] < cutoff
    # mask is (n,) and direc is (n, 3): add an axis so each row is masked
    # (plain broadcasting would fail for n != 3)
    ret = self.hull_branch['direc'] * mask[:, None]
    if normalise:
        ret /= (self.hull_branch['dist'][:, None] + 1e-5)
    return ret
@utils.cache()
def get_hull_branch_dist_cutoff(self, cutoff=25, normalise=True):
    """
    Distances to the nearest branch point, zeroed beyond ``cutoff``.

    :param cutoff: distances >= cutoff are set to zero
    :param normalise: return a boolean "within cutoff and > 0" mask instead
    :return: (n,) array
    """
    mask = self.hull_branch['dist'] < cutoff
    ret = self.hull_branch['dist'] * mask
    if normalise:
        ret = (ret > 0)
    return ret
@utils.cache()
def get_hull_skel_direc_rel(self):
    """Skeleton direction vectors divided by each point's ``max_dist``."""
    return self.hull_skel['direc'] / self.hull_skel['max_dist'][:, None]
def plot_hull(self, fig=None):
    """
    Plot all hull points as translucent white cubes.

    :param fig: mayavi figure to draw into; a new one is created if None
    :return: the figure
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
    hull = self.hull_points
    mlab.points3d(hull[:, 0], hull[:, 1], hull[:, 2] * self.aniso_scale[0, 2],
                  scale_factor=1, color=(1, 1, 1), mode='cube', opacity=0.1,
                  figure=fig)
    return fig
def plot_hull_inner(self, cutoff, fig=None):
    """
    Plot the inner hull points (see ``get_hull_points_inner``).

    :param cutoff: relative-distance cutoff passed to ``get_hull_points_inner``
    :param fig: mayavi figure to draw into; a new one is created if None
    :return: the figure
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
    inner_hull = self.get_hull_points_inner(cutoff)
    mlab.points3d(inner_hull[:, 0], inner_hull[:, 1],
                  inner_hull[:, 2] * self.aniso_scale[0, 2],
                  scale_factor=1, color=(0.8, 0.8, 1), mode='cube', opacity=0.1,
                  figure=fig)
    return fig
def plot_vec(self, substep=15, dict_name='skel', key='direc', vec=None, fig=None):
    """
    Plot a vector field on every ``substep``-th hull point.

    :param substep: plot only every n-th hull point
    :param dict_name: ``'skel'`` selects ``self.hull_skel``, anything else
        ``self.hull_branch``
    :param key: which entry of the selected dict to plot (if ``vec`` is None)
    :param vec: explicit (n, 3) vectors, overrides ``dict_name``/``key``
    :param fig: mayavi figure to draw into; a new one is created if None
    :return: the figure
    """
    if fig is None:
        fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
    z_scale = self.aniso_scale[0, 2]

    if vec is None:
        dict_ = self.hull_skel if dict_name == 'skel' else self.hull_branch
        vec = dict_[key]

    hull = self.hull_points[::substep]
    vec = vec[::substep]
    mlab.quiver3d(hull[:, 0], hull[:, 1], hull[:, 2] * z_scale,
                  vec[:, 0], vec[:, 1], vec[:, 2] * z_scale, figure=fig)
    return fig
class Trace(object):
    """
    A traced path: node coordinates plus per-node bookkeeping arrays.

    Unless otherwise stated, all coordinates are in the skeleton system (xyz)
    with an anisotropic z-axis, and all distances are in pixels
    (conversion to mu: 1/100).
    """
    # Per-node arrays copied verbatim (reversed / sliced) when a new trace is
    # derived from this one. ``runlengths`` is handled separately because it
    # must be re-based.
    _PER_NODE_ATTRS = ('coords', 'seg_length', 'dist_self', 'dist_skel',
                       'uturn_mask', 'coords_cnn', 'grads', 'features')

    def __init__(self, linked_skel=None, aniso_scale=2, max_cutoff=200,
                 uturn_detection_k=40, uturn_detection_thresh=0.45,
                 uturn_detection_hold=10, feature_count=7):
        """
        :param linked_skel: optional ground-truth skeleton; if set,
            ``dist_skel`` is filled in ``append``
        :param aniso_scale: scale factor of the z-axis
        :param max_cutoff: skeleton distance above which the trace counts
            as having lost track
        :param uturn_detection_k: number of neighbours for the self-distance
        :param uturn_detection_thresh: normalised self-distance below which a
            node counts towards a u-turn
        :param uturn_detection_hold: how many consecutive nodes must be below
            the threshold
        :param feature_count: ``right_shape`` of the feature array
        """
        self.aniso_scale = np.array([[1, 1, aniso_scale]], dtype=np.float32)
        self.skel = linked_skel
        self.lost_track = False
        self.uturn_occurred = False
        self.coords = utils.AccumulationArray(right_shape=3, n_init=500)
        self.seg_length = utils.AccumulationArray(n_init=500)
        self.runlengths = utils.AccumulationArray(n_init=500)
        self.dist_self = utils.AccumulationArray(right_shape=2, n_init=500)
        self.dist_skel = utils.AccumulationArray(n_init=500)
        self.uturn_mask = utils.AccumulationArray(n_init=500, dtype=bool)
        self.coords_cnn = utils.AccumulationArray(right_shape=3, n_init=500)
        self.grads = utils.AccumulationArray(right_shape=3, n_init=500)
        self.features = utils.AccumulationArray(right_shape=feature_count, n_init=500)
        self.max_cutoff = max_cutoff
        self.uturn_detection_k = uturn_detection_k
        self.uturn_detection_thresh = uturn_detection_thresh
        self.uturn_detection_hold = uturn_detection_hold
        self.kdt = utils.DynamicKDT(k=uturn_detection_k, n_jobs=1,
                                    aniso_scale=self.aniso_scale)
        self.root = 0
        self.comment = ""

    def _empty_like(self):
        """Empty Trace with this trace's configuration and feature shape."""
        return Trace(self.skel, self.aniso_scale[0, 2], self.max_cutoff,
                     self.uturn_detection_k, self.uturn_detection_thresh,
                     self.uturn_detection_hold, self.features[:].shape[1:])

    def _build_kdt(self, coords):
        """Build a DynamicKDT over ``coords`` for the self-distance queries."""
        k = self.uturn_detection_k
        if len(coords) <= k:
            # NOTE(review): presumably DynamicKDT cannot be built from fewer
            # than k points, hence the incremental path — confirm.
            kdt = utils.DynamicKDT(k=k, n_jobs=1, aniso_scale=self.aniso_scale)
            for c in coords:
                kdt.append(c)
            return kdt
        return utils.DynamicKDT(coords, k=k, n_jobs=1, aniso_scale=self.aniso_scale)

    def _ensure_comment(self):
        """Return ``self.comment``, creating it if the instance lacks it."""
        # Instances created before ``comment`` existed (e.g. unpickled) may
        # not have the attribute.
        if not hasattr(self, 'comment'):
            self.comment = ""
        return self.comment

    def new_reverted_trace(self):
        """Copy of this trace with the node order reversed (root = old end)."""
        new_trace = self._empty_like()
        for name in self._PER_NODE_ATTRS:
            setattr(new_trace, name,
                    utils.AccumulationArray(data=getattr(self, name)[::-1]))
        new_trace.runlengths = utils.AccumulationArray(
            data=self.runlengths[-1] - self.runlengths[::-1])
        new_trace.kdt = self._build_kdt(new_trace.coords[:])
        new_trace.root = len(self) - 1
        new_trace.comment = self._ensure_comment() + " R"
        return new_trace

    def new_cut_trace(self, start, stop):
        """
        Copy of the nodes ``start:stop``.

        The root is shifted accordingly, or set to None if it falls outside
        the cut (or was already None).
        """
        new_trace = self._empty_like()
        for name in self._PER_NODE_ATTRS:
            setattr(new_trace, name,
                    utils.AccumulationArray(data=getattr(self, name)[start:stop]))
        new_trace.runlengths = utils.AccumulationArray(
            data=self.runlengths[start:stop] - self.runlengths[start])
        new_trace.kdt = self._build_kdt(new_trace.coords[:])

        # self.root may be None for traces that were cut before
        root = None if self.root is None else self.root - start
        if root is not None and 0 <= root < len(new_trace):
            new_trace.root = root
        else:
            new_trace.root = None
        new_trace.comment = self._ensure_comment() + "C%i-%i" % (start, stop)
        return new_trace

    def __len__(self):
        return len(self.coords)

    def save(self, fname):
        """Pickle this trace to ``fname``."""
        utils.picklesave(self, fname)

    def save_to_kzip(self, fname):
        """Write this trace as a Knossos annotation (see ``trace_to_kzip``)."""
        trace_to_kzip(self, fname)

    def add_offset(self, off):
        """Shift all coordinates by ``off`` and rebuild the kd-tree."""
        off = np.atleast_2d(off)
        self.coords.add_offset(off)
        self.kdt = self._build_kdt(self.coords[:])

    def append(self, coord, coord_cnn=None, grad=None, features=None):
        """
        Append a node and update all bookkeeping.

        Updates segment length, runlength, self-distance (mean distance to the
        ``uturn_detection_k`` nearest previous nodes, raw and normalised),
        skeleton distance (if a skeleton is linked), the u-turn mask and the
        ``uturn_occurred`` / ``lost_track`` markers.
        """
        self.coords.append(coord)
        if len(self) > 1:
            diff = np.linalg.norm((coord - self.coords[-2]) * self.aniso_scale[0])
        else:
            diff = 5  # just guess
        self.seg_length.append(diff)
        self.runlengths.append(self.runlength)

        # Expected mean neighbour distance along a straight line
        normalisation = self.seg_length.ema * float(self.uturn_detection_k + 1) / 2
        if len(self) > self.uturn_detection_k + 1:
            distances, indices, coordinates = self.kdt.get_knn(
                coord, k=self.uturn_detection_k)
            dist = distances.mean()
        else:
            dist = normalisation
        self.dist_self.append([dist, dist / normalisation])
        self.kdt.append(coord)

        if self.skel:
            dist, index, node = self.skel.get_closest_node(coord)
            self.dist_skel.append(dist)
        if grad is not None:
            self.grads.append(grad)
        if features is not None:
            self.features.append(features)
        if coord_cnn is not None:
            self.coords_cnn.append(coord_cnn)

        # U-turn: the last `hold` normalised self-distances are all small
        last_dist = self.dist_self[-self.uturn_detection_hold:, 1]
        uturn = np.all(last_dist < self.uturn_detection_thresh)
        self.uturn_mask.append(uturn)
        if not self.uturn_occurred and uturn:  # register the first u-turn
            self.uturn_occurred = (len(self), self.runlength)

        # Without a linked skeleton dist_skel stays empty; skip the check
        if not self.lost_track and len(self.dist_skel):
            if self.dist_skel.max() > self.max_cutoff:
                self.lost_track = (len(self), self.runlength)

    def append_serial(self, *args):
        """Append node-wise: ``args`` are parallel sequences of ``append`` args."""
        for arg in zip(*args):
            self.append(*arg)

    @property
    def avg_seg_length(self):
        return self.seg_length.mean()

    @property
    def runlength(self):
        return self.seg_length.sum()

    @property
    def avg_dist_skel(self):
        return self.dist_skel.mean()

    @property
    def max_dist_skel(self):
        return self.dist_skel.max()

    @property
    def avg_dist_self(self):
        return self.dist_self.mean()

    @property
    def min_dist_self(self):
        return self.dist_self.min()[0]

    @property
    def min_normed_dist_self(self):
        return self.dist_self.min()[1]

    def tortuosity(self, start=None, end=None):
        """Arc length divided by chord length between nodes start and end-1."""
        if start is None:
            start = 0
        if end is None:
            end = len(self)
        arc = self.runlengths[end - 1] - self.runlengths[start]
        chord = np.linalg.norm(
            (self.coords[end - 1] - self.coords[start]) * self.aniso_scale[0])
        return arc / chord

    def plot(self, grads=True, skel=True, rand_color=False, fig=None):
        """
        Plot the trace (and optionally the linked skeleton and gradients).

        :return: the mayavi figure
        """
        if fig is None:
            fig = mlab.figure(bgcolor=(1.0, 0.8, 0.4), size=(600, 400))
        if skel and self.skel:
            fig = self.skel.plot_skel(fig=fig)
        z_scale = self.aniso_scale[0, 2]
        x = self.coords[:, 0]
        y = self.coords[:, 1]
        z = self.coords[:, 2] * z_scale
        line_c = tuple(np.random.rand(3)) if rand_color else (0, 0, 0.7)
        point_c = line_c if rand_color else (0.6, 0.7, 0.9)
        mlab.plot3d(x, y, z, tube_radius=0.2, color=line_c, figure=fig)
        mlab.points3d(x, y, z, scale_factor=0.8, color=point_c, figure=fig)
        if grads and self.grads.length:
            mlab.quiver3d(x, y, z,
                          -self.grads[:, 0], -self.grads[:, 1],
                          -self.grads[:, 2] * z_scale,
                          figure=fig, color=(0, 0.6, 0.2), scale_factor=3)
        return fig

    def split_uturns(self, return_accum_pathlength=False, print_stat=False):
        """
        Split the trace into pieces that exclude the u-turn regions.

        :param return_accum_pathlength: additionally return the runlengths
            and skeleton distances accumulated over all pieces
        :param print_stat: print the cut positions
        :return: list of new traces, or
            ``(traces, accum_pathlengths, accum_dist_skel)``
        """
        hold = self.uturn_detection_hold
        transitions = np.nonzero(np.diff(self.uturn_mask, axis=0, n=1))[0]
        transitions[0::2] -= hold  # if add: end segment closer to uturn
        transitions[1::2] -= hold  # if subtract: start new segment closer to uturn
        transitions = np.minimum(np.maximum(0, transitions), len(self))
        transitions = np.hstack((0, transitions, len(self)))

        new_traces = []
        accum_pathlengths = []
        accum_dist_skel = []
        accum_runlength = 0
        for i in range(0, len(transitions) - 1, 2):
            start, stop = transitions[i], transitions[i + 1]
            if print_stat:
                print("cutting between %i and %i " % (start, stop))
            if start >= stop:  # some transitions are too short
                continue

            new = self.__class__(self.skel, self.aniso_scale[0, 2],
                                 self.max_cutoff, self.uturn_detection_k,
                                 self.uturn_detection_thresh,
                                 self.uturn_detection_hold)
            new.append_serial(self.coords[start:stop])
            new_traces.append(new)

            # Accumulate pathlengths and dist to skel over splits
            runlengths = self.runlengths[start:stop]
            runlengths = runlengths + accum_runlength - runlengths[0]  # shift
            accum_pathlengths.append(runlengths)
            accum_runlength = runlengths[-1]

            if self.skel:
                dist_skel = self.dist_skel[start:stop].copy()
                prev_stop = transitions[i - 1]
                if i > 0 and prev_stop < start:
                    # This makes the eval stop if the trace deviated from the
                    # skeleton too much during the uturn
                    max_dist_in_uturn = np.max(self.dist_skel[prev_stop:start])
                    dist_skel[0] = np.maximum(dist_skel[0], max_dist_in_uturn)
                accum_dist_skel.append(dist_skel)
            else:
                accum_dist_skel.append([])

        if return_accum_pathlength:
            return (new_traces, np.hstack(accum_pathlengths),
                    np.hstack(accum_dist_skel))
        return new_traces
def normalised_min_dist(tr, point):
    """
    Distance from ``point`` to the nearest node of trace ``tr``.

    :return: ``(dist, dist / radius)`` with the radius taken from
        ``tr.features[ind, 0]`` of the nearest node
    """
    dist, ind, coord = tr.kdt.get_knn(point, k=1)
    radius = tr.features[ind, 0]
    return dist, dist / radius


def simple_stats(a):
    """Column-wise mean, std, min and max of ``a``, stacked as rows."""
    return np.array([np.mean(a, axis=0), np.std(a, axis=0),
                     np.min(a, axis=0), np.max(a, axis=0)])


def radius_hist(r):
    """Density histogram of radii over fixed bins (0..200)."""
    bins = np.array([0, 8, 14, 23, 35, 50, 80, 200])
    counts, _ = np.histogram(r, bins=bins, density=True)
    return counts


def _trace_shape_features(tr, slice_small, slice_large):
    """Local shape/feature descriptor of ``tr``; returns (features, pc_dir)."""
    points = tr.coords[slice_small] * tr.aniso_scale
    _, pc, pc_dir = np.linalg.svd(points - points.mean(0))
    feat_small = simple_stats(tr.features[slice_small])
    feat_large = simple_stats(tr.features[slice_large])
    hist = radius_hist(tr.features[slice_large, 0])
    tortuosity = tr.tortuosity()
    features = np.hstack([pc, pc_dir.ravel(), feat_small.ravel(),
                          feat_large.ravel(), hist, tortuosity])
    return features, pc_dir


def get_merge_features(main_tr, main_node, sub_tr, sub_node, end_match):
    """
    Features for classifying a merge between two traces.

    :param end_match: True if the sub trace joins with its end (features are
        taken from the nodes before ``sub_node``), else with its start
    :return: ``(main_features, sub_features, joint_features)``
    """
    m_slice_small = slice(max(0, main_node - 5), main_node + 5)
    m_slice_large = slice(max(0, main_node - 25), main_node + 25)
    main_features, pc_dir_m = _trace_shape_features(main_tr, m_slice_small,
                                                    m_slice_large)

    if end_match:
        s_slice_small = slice(max(0, sub_node - 10), sub_node)
        s_slice_large = slice(max(0, sub_node - 50), sub_node)
    else:
        s_slice_small = slice(sub_node, sub_node + 10)
        s_slice_large = slice(sub_node, sub_node + 50)
    sub_features, pc_dir_s = _trace_shape_features(sub_tr, s_slice_small,
                                                   s_slice_large)

    dist = np.linalg.norm(
        (main_tr.coords[main_node] - sub_tr.coords[sub_node]) * sub_tr.aniso_scale)
    r_m = main_tr.features[main_node, 0]
    r_s = sub_tr.features[sub_node, 0]
    # |cosine| between all pairs of principal directions
    pc_dir_similarity = np.abs(np.dot(pc_dir_m, pc_dir_s.T))
    joint_features = np.hstack([dist, dist / r_m, dist / r_s,
                                2 * dist / (r_m + r_s),
                                pc_dir_similarity.ravel()])
    return main_features, sub_features, joint_features


def split_tree_components(tracetree, cut=False):
    """
    Split a TraceTree into one TraceTree per connected component.

    :param cut: apply ``tracetree.trace_cuts`` to the traces first
    :return: list of TraceTree
    """
    if tracetree.num_components == 1:
        return [tracetree, ]
    new_trees = [list() for _ in range(tracetree.num_components)]
    for tr_i in tracetree.traces:
        c = tracetree.tr_i2comp_i[tr_i]
        tr = tracetree.traces[tr_i]
        cuts = tracetree.trace_cuts.get(tr_i, None)
        if cuts and cut:
            tr = tr.new_cut_trace(*cuts)
        new_trees[c].append(tr)
    return [TraceTree(traces, tracetree.spine_thresh, tracetree.endpoint_thresh)
            for traces in new_trees]


class TraceTree(object):
    """Collection of traces that are pruned and merged into components."""

    def __init__(self, traces, spine_thresh=1.5, endpoint_thresh=0.8):
        """
        :param traces: dict or list of Traces (a list is keyed 0..n-1)
        :param spine_thresh: float
            How large the maximal relative distance needs to be for a loop
            to be retained as a spine branch
        :param endpoint_thresh: float
            Threshold of relative distance between endpoint and other trace
            to count as a connection
        """
        if not isinstance(traces, dict):
            traces = dict(zip(range(len(traces)), traces))
        self.traces = traces
        self.trace_cuts = dict()
        self.pruned_traces = []
        self.edge_candidates = dict()
        self.edges = []
        self.tr_i2comp_i = None
        self.num_components = 1
        self.aniso = np.array([[1, 1, 2]])
        self.spine_thresh = spine_thresh
        self.endpoint_thresh = endpoint_thresh
        self.joined_kdt = None
        self.joined_coords = None
        self.joined_radii = None

    def build_joined_features(self):
        """Stack coordinates and radii of all traces and build one kd-tree."""
        self.joined_coords = np.vstack([tr.coords for tr in self.traces.values()])
        self.joined_radii = np.hstack([tr.features[:, 0]
                                       for tr in self.traces.values()])
        self.joined_kdt = utils.DynamicKDT(self.joined_coords, n_jobs=-1,
                                           aniso_scale=[1, 1, 2], k=1)

    def cut_traces_inplace(self):
        """Replace every trace that has registered cuts by its cut version."""
        for tr_i in self.traces:
            cuts = self.trace_cuts.get(tr_i, None)
            if cuts:
                self.traces[tr_i] = self.traces[tr_i].new_cut_trace(*cuts)

    def to_kzip(self, fname, save_loops=False, save_edge_candiates=False,
                add_edges=False, save_edges=False):
        """
        Write one ``.k.zip`` per component, plus optional edge/loop files.

        :param fname: output path prefix (``~`` is expanded)
        :param save_loops: write the pruned loop traces
        :param save_edge_candiates: write all edge candidates
        :param add_edges: connect merged traces in the component files
        :param save_edges: write the MST edges
        """
        fname = os.path.expanduser(fname)
        fpath, comment_name = os.path.split(fname)

        skel_objs = []
        component_annos = []
        for c in range(self.num_components):
            skel_obj = knossos_skeleton.Skeleton()
            anno_ = knossos_skeleton.SkeletonAnnotation()
            anno_.scaling = (9.0, 9.0, 20.0)
            anno_.setComment(comment_name + "-c%i" % c)
            skel_obj.add_annotation(anno_)
            skel_objs.append(skel_obj)
            component_annos.append(anno_)

        # Save all cut traces to the anno-obj of their component
        node_mappings = dict()
        for tr_i in self.traces:
            c = self.tr_i2comp_i[tr_i] if self.tr_i2comp_i is not None else 0
            tr = self.traces[tr_i]
            cuts = self.trace_cuts.get(tr_i, None)
            if cuts:
                tr = tr.new_cut_trace(*cuts)
            _, node_mapping = trace_to_anno(tr, fname, component_annos[c])
            node_mappings[tr_i] = node_mapping

        # Save all edges (between cut points) to an edge annotation
        if len(self.edges) and save_edges:
            edge_anno = knossos_skeleton.SkeletonAnnotation()
            edge_anno.scaling = (9.0, 9.0, 20.0)
            edge_anno.setComment("Edges-" + comment_name)
            skel_obj_edges = knossos_skeleton.Skeleton()
            skel_obj_edges.add_annotation(edge_anno)
            for e in self.edges:
                try:
                    main, sub = e
                    main_node_i, sub_node_i = self.edge_candidates[tuple(e)][1]
                except KeyError:  # MST might turn around edge order
                    main_node_i, sub_node_i = self.edge_candidates[tuple(e)[::-1]][1]
                    main, sub = e[::-1]

                main_node = knossos_skeleton.SkeletonNode()
                x, y, z = np.round(self.traces[main].coords[main_node_i]).astype(np.int16)
                main_node.from_scratch(edge_anno, x, y, z)
                edge_anno.addNode(main_node)

                sub_node = knossos_skeleton.SkeletonNode()
                x, y, z = np.round(self.traces[sub].coords[sub_node_i]).astype(np.int16)
                sub_node.from_scratch(edge_anno, x, y, z)
                edge_anno.addNode(sub_node)
                edge_anno.addEdge(main_node, sub_node)

                if add_edges:
                    # Node indices must be shifted by the cut start
                    main_cut = self.trace_cuts.get(main, [0, None])[0]
                    sub_cut = self.trace_cuts.get(sub, [0, None])[0]
                    try:
                        n_main = node_mappings[main][main_node_i - main_cut]
                        n_sub = node_mappings[sub][sub_node_i - sub_cut]
                        n_main.annotation.addEdge(n_main, n_sub)
                    except Exception:  # best effort, as before
                        logger.debug("Could not add edge %s-%s", main, sub,
                                     exc_info=True)

            skel_obj_edges.to_kzip(fpath + "/edges-" + comment_name + '.k.zip')

        # As Node==Edge==Node in one skeleton Tree (for making GT in knossos)
        if save_edge_candiates and self.edge_candidates:
            edge_candiate_anno = knossos_skeleton.SkeletonAnnotation()
            edge_candiate_anno.scaling = (9.0, 9.0, 20.0)
            edge_candiate_anno.setComment(comment_name + "-Edge-Candiates")
            skel_obj_candidates = knossos_skeleton.Skeleton()
            skel_obj_candidates.add_annotation(edge_candiate_anno)
            for (main, sub), candidate in self.edge_candidates.items():
                main_node_i, sub_node_i = candidate[1]
                tag = comment_name + "-M%i_%i-S%i_%i" % (main, main_node_i,
                                                          sub, sub_node_i)

                main_node = knossos_skeleton.SkeletonNode()
                x, y, z = np.round(self.traces[main].coords[main_node_i]).astype(np.int16)
                main_node.from_scratch(edge_candiate_anno, x, y, z)
                main_node.setComment(tag + "-main")
                edge_candiate_anno.addNode(main_node)

                sub_node = knossos_skeleton.SkeletonNode()
                x, y, z = np.round(self.traces[sub].coords[sub_node_i]).astype(np.int16)
                sub_node.from_scratch(edge_candiate_anno, x, y, z)
                sub_node.setComment(tag + "-sub")
                edge_candiate_anno.addNode(sub_node)
                edge_candiate_anno.addEdge(main_node, sub_node)

            skel_obj_candidates.to_kzip(fpath + "/candidates-" + comment_name + '.k.zip')

        if save_loops:
            # simplify() stores a dict {tr_i: trace}; the initial value is a list
            loops = self.pruned_traces
            if isinstance(loops, dict):
                loops = list(loops.values())
            skel_obj_loops = knossos_skeleton.Skeleton()
            for i, t in enumerate(loops):
                anno_loop, _ = trace_to_anno(t, comment_name + '-loop%i' % i)
                skel_obj_loops.add_annotation(anno_loop)
            skel_obj_loops.to_kzip(fname + '-loops.k.zip')

        for i, skel_obj in enumerate(skel_objs):
            skel_obj.to_kzip(fname + '-c%i.k.zip' % i)

    def is_loop(self, trace, traces):
        """
        Test whether ``trace`` is a loop on another trace.

        Both end points must be close (relative to the other trace's radius)
        to one other trace. Such a trace is retained (not a loop) if some of
        its nodes reach further than ``spine_thresh`` away, in which case cut
        indices for the protruding part are returned.

        :param trace: Trace test candidate
        :param traces: dict of Traces
        :return: ``(is_loop, cuts)``; ``cuts`` is None or ``(start, stop)``
        """
        if not len(traces):
            return False, None

        # Average normalised distance of the two end points to all other traces
        end_points = trace.coords[[0, -1]]
        keys = list(traces.keys())
        relative_distances = np.ones(len(keys)) * np.inf
        for i, tr_i in enumerate(keys):
            tr = traces[tr_i]
            if tr is trace:  # don't compare trace to itself, would be loop always
                continue
            dist, ind, coord = tr.kdt.get_knn(end_points, k=1)
            radii = tr.features[ind, 0]
            relative_distances[i] = (dist / radii).mean()

        k = relative_distances.argmin()
        if relative_distances[k] >= self.endpoint_thresh:
            return False, None

        # Check if there exists a point that is farther away from main
        dist, rel_dist = normalised_min_dist(traces[keys[k]], trace.coords[:])
        max_point = rel_dist.argmax()
        if rel_dist[max_point] < self.spine_thresh:
            return True, None

        cut_a = 0 if rel_dist[0] < rel_dist[-1] else len(rel_dist)
        cut_b = max_point
        cut_0, cut_1 = min(cut_a, cut_b), max(cut_a, cut_b)
        assert cut_0 < cut_1
        return False, (cut_0, cut_1)

    def closest_approach(self, tr_a, tr_b):
        """
        Find where one trace's end point approaches the other trace.

        :param tr_a: Trace
        :param tr_b: Trace
        :return: None if the traces are farther apart than ``spine_thresh``,
            else ``(geometric_dist, a_is_main, (main_node, sub_node),
            (cut_start, cut_end), features)``. ``features`` is None if the
            merge features could not be computed.
        """
        b0toa = normalised_min_dist(tr_a, tr_b.coords[0])[1]
        b1toa = normalised_min_dist(tr_a, tr_b.coords[-1])[1]
        a0tob = normalised_min_dist(tr_b, tr_a.coords[0])[1]
        a1tob = normalised_min_dist(tr_b, tr_a.coords[-1])[1]
        candidates = [b0toa, b1toa, a0tob, a1tob]
        case = np.argmin(candidates)
        geometrict_dist = np.min(candidates)
        if geometrict_dist > self.spine_thresh:
            return None

        end_match = case in [1, 3]
        a_is_main = case in [0, 1]
        main_tr = tr_a if a_is_main else tr_b
        sub_tr = tr_b if a_is_main else tr_a

        slice_20 = slice(-20, None) if end_match else slice(None, 20)
        sub_coords = sub_tr.coords[slice_20]
        distances, indices, coordinates = main_tr.kdt.get_knn(sub_coords, k=1)
        max_seg_length = sub_tr.seg_length[slice_20].max()
        sub_merge_candidates = (distances < max_seg_length).nonzero()[0]
        if len(sub_merge_candidates):
            sub_node = sub_merge_candidates[0] if end_match else sub_merge_candidates[-1]
        else:
            sub_node = distances.argmin()
        main_node = indices[sub_node]  # the main node which was found in knn
        if end_match:
            # Shift to full-trace indices (sub_coords were the last 20 nodes)
            sub_node += len(sub_tr) - np.minimum(20, len(sub_tr))

        cut_start = 0 if end_match else sub_node
        cut_end = sub_node + 1 if end_match else len(sub_tr)
        assert cut_end - cut_start > 0
        # Check for cases where there is a spine loop
        if cut_start == sub_node and end_match:
            assert cut_start == 0
            cut_end = len(sub_tr)
            end_match = not end_match
        elif cut_end == sub_node and not end_match:
            assert cut_end == len(sub_tr)
            cut_start = 0
            end_match = not end_match

        cuts = (cut_start, cut_end)
        nodes = (main_node, sub_node)
        try:
            feat = get_merge_features(main_tr, main_node, sub_tr, sub_node, end_match)
        except Exception:
            # Previously swallowed and then failed on the unbound `feat`
            logger.debug("Could not compute merge features", exc_info=True)
            feat = None
        return geometrict_dist, a_is_main, nodes, cuts, feat

    def make_merge_graph(self):
        """
        Build a minimum-spanning forest over pairwise closest approaches.

        Sets ``edge_candidates``, ``edges``, ``trace_cuts``, ``tr_i2comp_i``
        and ``num_components``.
        """
        keys = list(self.traces.keys())
        n = len(keys)
        msd_list = []
        # For all pairwise traces find closest approach, node/cuts indices and
        # features for the edge classifier, collect edges in "msd_list"
        for s in range(n):
            for t in range(s + 1, n):
                tr_a = keys[s]
                tr_b = keys[t]
                tmp = self.closest_approach(self.traces[tr_a], self.traces[tr_b])
                if tmp is None:
                    # If components are disconnected still add them
                    msd_list.append([tr_a, tr_a, 0])
                    msd_list.append([tr_b, tr_b, 0])
                    continue
                geometrict_dist, a_is_main, nodes, cuts, feat = tmp
                if np.isclose(geometrict_dist, 0.0):
                    # otherwise connected components will consider this as split
                    geometrict_dist = 0.1
                main, sub = (tr_a, tr_b) if a_is_main else (tr_b, tr_a)
                coords = (self.traces[main].coords[nodes[0]],
                          self.traces[sub].coords[nodes[1]])
                self.edge_candidates[(main, sub)] = [geometrict_dist, nodes,
                                                     cuts, feat, coords]
                msd_list.append([main, sub, geometrict_dist])

        if len(msd_list) == 0:
            return

        # Create MST from edge graph (actually MST-Forest)
        main, sub, dist = np.array(msd_list).T
        # np.array of mixed int/float rows is float; sparse indices must be int
        main = main.astype(np.int64)
        sub = sub.astype(np.int64)
        a = np.hstack([main, sub])
        b = np.hstack([sub, main])
        values = np.hstack([dist, dist])
        adj_mat = sparse.csr_matrix(sparse.coo_matrix((values, (a, b))))
        mst = csgraph.minimum_spanning_tree(adj_mat)
        edges_mst = np.array(mst.nonzero()).T
        self.edges = edges_mst

        # For all MST-edges update the cuts (cut as much as possible to cover
        # merge positions)
        for edge_mst in edges_mst:
            try:
                sub = edge_mst[1]
                tmp = self.edge_candidates[tuple(edge_mst)]
            except KeyError:  # MST might turn around edge order
                sub = edge_mst[0]
                tmp = self.edge_candidates[tuple(edge_mst)[::-1]]
            old_cuts = self.trace_cuts.get(sub, None)
            if old_cuts is None:
                new_cuts = tmp[2]
            else:
                new_cuts = (np.minimum(old_cuts[0], tmp[2][0]),
                            np.maximum(old_cuts[1], tmp[2][1]))
            self.trace_cuts[sub] = new_cuts

        # If edge candidates were classified negative, the components of the
        # MST-forest must be split.
        # Returns unconnected nodes (empty slots) as component too!
        num, labels = csgraph.connected_components(mst, directed=False)
        # Therefore select only the stuff which is in keys
        comp_names, components = np.unique(labels[keys], return_inverse=True)
        self.tr_i2comp_i = dict(zip(keys, components))
        self.num_components = comp_names.size

    def simplify(self, profile=False):
        """
        Prune loop traces (shortest first), then build the merge graph.

        If every trace is pruned, the longest one is kept.
        """
        if profile:
            tt = utils.Timer()
        keep_traces = {}
        pruned_traces = {}
        traces = dict(self.traces)
        keys = np.array(list(traces.keys()))
        trace_lengths = np.array([traces[tr_i].runlength for tr_i in keys])
        keys_sorted = keys[np.argsort(trace_lengths)]
        for tr_i in keys_sorted:
            tr = traces[tr_i]
            is_loop, cuts = self.is_loop(tr, traces)
            if is_loop:
                pruned_traces[tr_i] = traces.pop(tr_i)
            else:
                # Don't cut traces here, it might mess up connection parts
                keep_traces[tr_i] = tr

        self.pruned_traces = pruned_traces
        self.traces = keep_traces
        if profile:
            tt.check(name='prune loops')

        # If all traces are loops with another trace take the largest single trace
        if len(self.traces) == 0 and pruned_traces:
            tr_lengths = np.array([(tr_i, tr.runlength)
                                   for tr_i, tr in pruned_traces.items()])
            i = int(tr_lengths[np.argmax(tr_lengths[:, 1]), 0])
            self.traces[i] = pruned_traces.pop(i)

        self.make_merge_graph()
        if profile:
            tt.check("merge graph")


def make_segment_lenghts(bone, aniso_scale=(1, 1, 2)):
    """
    Path length attributed to each node of a polyline.

    Each node gets half of each adjacent segment, so the result sums to the
    total path length.

    :param bone: (n, 3) node coordinates
    :param aniso_scale: per-axis scale applied to coordinate differences
    :return: (n,) float array
    """
    bone = np.asarray(bone)
    n = len(bone)
    if n < 2:
        return np.zeros(n)
    seg = np.linalg.norm(np.diff(bone, n=1, axis=0) *
                         np.asarray(aniso_scale)[None, :], axis=1)
    node_lengths = np.empty(n)
    node_lengths[0] = 0.5 * seg[0]
    node_lengths[-1] = 0.5 * seg[-1]
    # Interior nodes: half of both adjacent segments (all of them, including
    # the second-to-last node which was previously skipped)
    node_lengths[1:-1] = 0.5 * (seg[:-1] + seg[1:])
    return node_lengths


def runlength_metric(path_lengths, distances, cut_start=10, cut_max=200, num=50):
    """
    Mean correctly-traced length as a function of a distance cutoff.

    For each trace, the correct length is the path length up to the node
    before the first node whose distance is >= cutoff (0 if that is the first
    node, the full length if none exceeds it).

    :param path_lengths: list of per-node accumulated path lengths
    :param distances: list of per-node distances (same structure)
    :return: ``(mean_correct_lengths, cutoffs)``
    """
    cutoffs = np.linspace(cut_start, cut_max, num=num)
    runlengths = []
    for cutoff in cutoffs:
        correct_lengths = []
        for path_length, dist in zip(path_lengths, distances):
            larger = np.nonzero(np.asarray(dist) >= cutoff)[0]
            if not len(larger):
                correct_lengths.append(path_length[-1])
            elif larger[0] == 0:
                # Off from the first node; index -1 would wrap to full length
                correct_lengths.append(0.0)
            else:
                correct_lengths.append(path_length[larger[0] - 1])
        runlengths.append(np.mean(correct_lengths))
    return np.asarray(runlengths), cutoffs


def runlength_metric_GT(trace, skel=None, cut_start=10, cut_max=200, num=20):
    """
    Skeleton length covered by the trace as a function of a distance cutoff.

    A skeleton node counts as traced if some trace point lies within the
    cutoff (anisotropy (1, 1, 2) applied).

    :return: ``(cutoffs, covered_lengths)``
    """
    if skel is None:
        skel = trace.skel
    aniso = np.array([[1, 1, 2]])
    cutoffs = np.linspace(cut_start, cut_max, num=num)
    runlengths = np.zeros(num)
    trace_points = trace.coords[:]
    trace_kdt = utils.KDT(radius=cut_max, n_jobs=-1)
    trace_kdt.fit(trace_points * aniso)
    for edge, bone in skel.bones.items():
        segment_lengths = make_segment_lenghts(bone)
        dist, ind = trace_kdt.radius_neighbors(bone * aniso)
        # Nearest trace point within cut_max per bone node (inf if none);
        # computed once instead of per cutoff
        min_dist = np.array([d.min() if len(d) else np.inf for d in dist])
        for i, cutoff in enumerate(cutoffs):
            runlengths[i] += segment_lengths[min_dist <= cutoff].sum()
    return cutoffs, runlengths


# Knossos data elements stored per node, from feature columns 1..6
_FEATURE_DATA_ELEMS = ("axoness_proba0", "axoness_proba1", "axoness_proba2",
                       "spiness_proba0", "spiness_proba1", "spiness_proba2")


def _set_node_features(node, features, k):
    """Write feature columns 1..6 of row ``k`` as data elements of ``node``."""
    for col, key in enumerate(_FEATURE_DATA_ELEMS, 1):
        node.setDataElem(key, features[k, col])


def trace_to_anno(trace_xyz, name, anno=None, root=None):
    """
    Convert a trace into a chain of Knossos nodes.

    :param trace_xyz: Trace (or object with ``coords`` and ``features``)
    :param name: used as comment if a new annotation is created
    :param anno: annotation to add the nodes to; a new one if None
    :param root: node index to mark as root; defaults to ``trace_xyz.root``
    :return: ``(anno, node_mapping)`` with node_mapping {node index: node}
    """
    if isinstance(trace_xyz, Trace):
        feature_avail = len(trace_xyz.features) == len(trace_xyz)
    else:
        feature_avail = True
    radius = 1.0
    if anno is None:
        anno = knossos_skeleton.SkeletonAnnotation()
        anno.scaling = (9.0, 9.0, 20.0)
        anno.setComment(os.path.split(name)[1])

    node_mapping = dict()
    trace_coords = np.round(trace_xyz.coords).astype(np.int16)
    last_node = None
    for k, coord in enumerate(trace_coords):
        if feature_avail:
            radius = trace_xyz.features[k, 0]
        node = knossos_skeleton.SkeletonNode()
        node.from_scratch(anno, coord[0], coord[1], coord[2], radius=radius)
        if feature_avail:
            # Features of node k belong to node k (previously written onto
            # the preceding node)
            _set_node_features(node, trace_xyz.features, k)
        node_mapping[k] = node
        anno.addNode(node)
        if last_node is not None:
            last_node.addChild(node)
        last_node = node

    if root is None and isinstance(trace_xyz, Trace):
        root = trace_xyz.root
    if root is not None and node_mapping:
        node_mapping[root].setRoot()
    return anno, node_mapping
def trace_to_kzip(trace_xyz, fname):
    """
    Write a single trace to a Knossos ``.k.zip`` file.

    :param trace_xyz: Trace to export (see ``trace_to_anno``)
    :param fname: output path; ``.k.zip`` is appended unless already present
    """
    skel_obj = knossos_skeleton.Skeleton()
    anno, _ = trace_to_anno(trace_xyz, fname)
    skel_obj.add_annotation(anno)
    outfile = fname if fname.endswith('.k.zip') else fname + '.k.zip'
    skel_obj.to_kzip(outfile)
def trace_to_kzip_multi(traces, fname):
    """
    Write several traces as separate annotations into one ``.k.zip`` file.

    :param traces: list/dict of Traces or (n, 3) coordinate arrays
    :param fname: output path (``~`` expanded); ``.k.zip`` is appended
        unless already present
    """
    if isinstance(traces, dict):
        traces = traces.values()
    skel_obj = knossos_skeleton.Skeleton()
    for i, trace_xyz in enumerate(traces):
        if not isinstance(trace_xyz, np.ndarray):
            anno, _ = trace_to_anno(trace_xyz, fname + "_%i" % i)
        else:
            coords = np.round(trace_xyz).astype(np.int16)
            anno = knossos_skeleton.SkeletonAnnotation()
            anno.scaling = (9.0, 9.0, 20.0)
            anno.setComment(os.path.split(fname)[1] + "_%i" % i)
            last_node = None
            for coord in coords:
                node = knossos_skeleton.SkeletonNode()
                node.from_scratch(anno, coord[0], coord[1], coord[2])
                if last_node is None:
                    node.setRoot()  # first node is the root
                anno.addNode(node)
                if last_node is not None:
                    last_node.addChild(node)
                last_node = node
        skel_obj.add_annotation(anno)
    outfile = fname if fname.endswith('.k.zip') else fname + '.k.zip'
    skel_obj.to_kzip(os.path.expanduser(outfile))


def bbox_cube_anno(off_xyz, sz_xyz, comment="?", cross_edges=False):
    """
    Knossos annotation outlining an axis-aligned box by its 12 edges.

    :param off_xyz: box offset (xyz)
    :param sz_xyz: box size (xyz)
    :param comment: prefix of the annotation comment
    :param cross_edges: additionally connect every node to every node
    :return: the annotation
    """
    off_xyz = np.array(off_xyz)
    sz_xyz = np.array(sz_xyz)
    unit_corners = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
                             [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]])
    corners = off_xyz + sz_xyz * unit_corners

    anno = knossos_skeleton.SkeletonAnnotation()
    anno.scaling = (9.0, 9.0, 20.0)
    anno.setComment("%s: %s - %s" % (comment, off_xyz, sz_xyz))
    nodes = []
    for x, y, z in corners:
        node = knossos_skeleton.SkeletonNode()
        node.from_scratch(anno, x, y, z)
        anno.addNode(node)
        nodes.append(node)

    # Indices refer to unit_corners
    box_edges = [(0, 1), (0, 3), (0, 4), (1, 2), (1, 5), (3, 2), (3, 7),
                 (4, 5), (4, 7), (6, 7), (6, 5), (6, 2)]
    for n1, n2 in box_edges:
        nodes[n1].addChild(nodes[n2])

    if cross_edges:
        for n1 in anno.nodes:
            for n2 in anno.nodes:
                n1.addChild(n2)
    return anno