# -*- coding: utf-8 -*-
# ELEKTRONN2 Toolkit
# Copyright (c) 2015 Martin Drawitsch, Marius Killinger
# All rights reserved
from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, super, zip
__all__ = ['warp_slice', 'get_tracing_slice', 'WarpingOOBError',
'Transform', 'trafo_from_array', 'get_warped_slice']
import itertools
import logging
from functools import reduce
import numpy as np
import numba
from .. import utils
logger = logging.getLogger('elektronn2log')
inspection_logger = logging.getLogger('elektronn2log-inspection')
# @numba.jit(nopython=True, cache=True)
# def map_coordinates_nearest(src, coords, lo, dest):
# sh = coords.shape
# for z in np.arange(sh[0]):
# for x in np.arange(sh[1]):
# for y in np.arange(sh[2]):
# u = np.int32(np.round(coords[z,x,y,0] - lo[0]))
# v = np.int32(np.round(coords[z,x,y,1] - lo[1]))
# w = np.int32(np.round(coords[z,x,y,2] - lo[2]))
# or
# u = coords[z, x, y, 0] - lo[0]
# v = coords[z, x, y, 1] - lo[1]
# w = coords[z, x, y, 2] - lo[2]
# dest[z,x,y] = src[u,v,w]
@numba.guvectorize(['void(float32[:,:,:], float32[:], float32[:], float32[:,],)'],
'(x,y,z),(i),(i)->()', target='parallel', nopython=True)
def map_coordinates_nearest(src, coords, lo, dest):
u = np.int32(np.round(coords[0] - lo[0]))
v = np.int32(np.round(coords[1] - lo[1]))
w = np.int32(np.round(coords[2] - lo[2]))
dest[0] = src[u,v,w]
@numba.guvectorize(['void(float32[:,:,:], float32[:], float32[:], float32[:,],)'],
'(x,y,z),(i),(i)->()', target='parallel', nopython=True)
def map_coordinates_linear(src, coords, lo, dest):
u = coords[0] - lo[0]
v = coords[1] - lo[1]
w = coords[2] - lo[2]
u0 = np.int32(u)
u1 = u0 + 1
du = u - u0
v0 = np.int32(v)
v1 = v0 + 1
dv = v - v0
w0 = np.int32(w)
w1 = w0 + 1
dw = w - w0
val = src[u0, v0, w0] * (1-du) * (1-dv) * (1-dw) +\
src[u1, v0, w0] * du * (1-dv) * (1-dw) +\
src[u0, v1, w0] * (1-du) * dv * (1-dw) +\
src[u0, v0, w1] * (1-du) * (1-dv) * dw +\
src[u1, v0, w1] * du * (1-dv) * dw +\
src[u0, v1, w1] * (1-du) * dv * dw +\
src[u1, v1, w0] * du * dv * (1-dw) +\
src[u1, v1, w1] * du * dv * dw
dest[0] = val
@numba.jit(nopython=True, cache=True)
def map_coordinates_max_kernel(src, coords, lo, k, dest):
k = np.float32(k)
kz = min(0.5, k/2)
sh = coords.shape
sh_src = src.shape
for z in np.arange(sh[0]):
for x in np.arange(sh[1]):
for y in np.arange(sh[2]):
u = coords[z,x,y,0] - lo[0]
v = coords[z,x,y,1] - lo[1]
w = coords[z,x,y,2] - lo[2]
u0 = np.int32(np.round(max(0, min(sh_src[0], u - kz))))
u1 = np.int32(np.round(max(0, min(sh_src[0], u + kz))))
v0 = np.int32(np.round(max(0, min(sh_src[1], v - k))))
v1 = np.int32(np.round(max(0, min(sh_src[1], v + k))))
w0 = np.int32(np.round(max(0, min(sh_src[2], w - k))))
w1 = np.int32(np.round(max(0, min(sh_src[2], w + k))))
val = src[u0:u1, v0:v1, w0:w1].max()
dest[z, x, y] = val
def identity():
return np.eye(4, dtype=np.float32)
def translate(dz, dy, dx):
return np.array([
[1.0, 0.0, 0.0, dz],
[0.0, 1.0, 0.0, dy],
[0.0, 0.0, 1.0, dx],
[0.0, 0.0, 0.0, 1.0]
], dtype=np.float32)
def rotate_z(a):
return np.array([
[1.0, 0.0, 0.0, 0.0],
[0.0, np.cos(a), -np.sin(a), 0.0],
[0.0, np.sin(a), np.cos(a), 0.0],
[0.0, 0.0, 0.0, 1.0]
], dtype=np.float32)
def rotate_y(a):
return np.array([
[np.cos(a), -np.sin(a), 0.0, 0.0],
[np.sin(a), np.cos(a), 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]
], dtype=np.float32)
def rotate_x(a):
return np.array([
[np.cos(a), 0.0, np.sin(a), 0.0],
[0.0, 1.0, 0.0, 0.0],
[-np.sin(a), 0.0, np.cos(a), 0.0],
[0.0, 0.0, 0.0, 1.0]
], dtype=np.float32)
def scale_inv(mz, my, mx):
return np.array([
[1/mz, 0.0, 0.0, 0.0],
[0.0, 1/my, 0.0, 0.0],
[0.0, 0.0, 1/mx, 0.0],
[0.0, 0.0, 0.0, 1.0]
], dtype=np.float32)
def scale(mz, my, mx):
return np.array([
[mz, 0.0, 0.0, 0.0],
[0.0, my, 0.0, 0.0],
[0.0, 0.0, mx, 0.0],
[0.0, 0.0, 0.0, 1.0]
], dtype=np.float32)
def chain_matrices(mat_list):
return reduce(np.dot, mat_list, identity())
def get_euler_angles(direc, gamma):
"""
tracing_dir (z, x, y) normalised
angle3 rotation around z in dest frame
phi is the rotation about the 1-2 axis
theta is the rotation about the 0'-1' axis
"""
assert abs(np.linalg.norm(direc) - 1) < 1e-3
phi = np.arctan2(direc[2], direc[1])
theta = np.arccos(direc[0])
return phi, theta, gamma
def get_rotmat_from_direc(direc, gamma=None, rng=None):
if gamma is None:
gamma=0.0
elif gamma=='rand':
if rng is None:
gamma = np.random.rand() * 2 * np.pi
else:
gamma = rng.rand() * 2 * np.pi
phi, theta, gamma = get_euler_angles(direc, gamma)
R1 = rotate_z(-phi)
R2 = rotate_y(-theta)
R3 = rotate_z(gamma)
R = chain_matrices([R3, R2, R1])
return R
def get_random_rotmat(lock_z=False, amount=1.0, rng=None):
rng = np.random.RandomState() if rng is None else rng
gamma = rng.rand() * 2 * np.pi * amount
if lock_z:
return rotate_z(gamma)
phi = rng.rand() * 2 * np.pi * amount
theta = np.arcsin(rng.rand()) * amount
R1 = rotate_z(-phi)
R2 = rotate_y(-theta)
R3 = rotate_z(gamma)
R = chain_matrices([R3, R2, R1])
return R
def get_random_flipmat(no_x_flip=False, rng=None):
rng = np.random.RandomState() if rng is None else rng
F = np.eye(4, dtype=np.float32)
flips = rng.binomial(1, 0.5, 4) * 2 - 1
flips[3] = 1 # don't flip homogeneous dimension
if no_x_flip:
flips[2] = 1
np.fill_diagonal(F, flips)
return F
def get_random_swapmat(lock_z=False, rng=None):
rng = np.random.RandomState() if rng is None else rng
S = np.eye(4, dtype=np.float32)
if lock_z:
swaps = [[0, 1, 2, 3],
[0, 2, 1, 3]]
else:
swaps = [[0, 1, 2, 3],
[0, 2, 1, 3],
[1, 0, 2, 3],
[1, 2, 0, 3],
[2, 0, 1, 3],
[2, 1, 0, 3]]
i = rng.randint(0, len(swaps))
S = S[swaps[i]]
return S
def get_random_warpmat(lock_z=False, perspective=False, amount=1.0, rng=None):
W = np.eye(4, dtype=np.float32)
amount *= 0.1
perturb = np.random.uniform(-amount, amount, (4, 4))
perturb[3,3] = 0
if lock_z:
perturb[0] = 0
perturb[:,0] = 0
if not perspective:
perturb[3] = 0
perturb[3,:3] *= 0.05 # perspective parameters need to be very small
np.clip(perturb[3,:3], -3e-3, 3e-3, out=perturb[3,:3])
return W + perturb
@utils.cache()
def make_dest_coords(sh):
"""
Make coordinate list for destination array of shape sh
"""
zz,xx,yy = np.mgrid[0:sh[0], 0:sh[1], 0:sh[2]]
hh = np.ones(sh, dtype=np.int)
coords = np.concatenate([zz[...,None], xx[...,None],
yy[...,None], hh[...,None]], axis=-1)
return coords.astype(np.float32)
def make_dest_corners(sh):
"""
Make coordinate list of the corners of destination array of shape sh
"""
corners = list(itertools.product(*([0,1],)*3))
sh = np.subtract(sh, 1) # 0-based indices
corners = np.multiply(sh, corners)
corners = np.hstack((corners, np.ones((8,1)))) # homogeneous coords
return corners
[docs]class WarpingOOBError(ValueError):
def __init__(self, *args, **kwargs):
super(WarpingOOBError, self).__init__( *args, **kwargs)
[docs]def trafo_from_array(a):
M = a[:16].reshape((4,4))
offset_l = a[16:19]
aniso_factor = a[19]
return Transform(M, offset_l, aniso_factor)
[docs]def warp_slice(img, ps, M, target=None, target_ps=None,
target_vec_ix=None, target_discrete_ix=None,
last_ch_max_interp=False, ksize=0.5):
"""
Cuts a warped slice out of the input image and out of the target image.
Warping is applied by multiplying the original source coordinates with
the inverse of the homogeneous (forward) transformation matrix ``M``.
"Source coordinates" (``src_coords``) signify the coordinates of voxels in
``img`` and ``target`` that are used to compose their respective warped
versions. The idea here is that not the images themselves, but the
coordinates from where they are read are warped. his allows for much higher
efficiency for large image volumes because we don't have to calculate the
expensive warping transform for the whole image, but only for the voxels
that we eventually want to use for the new warped image.
The transformed coordinates usually don't align to the discrete
voxel grids of the original images (meaning they are not integers), so the
new voxel values are obtained by linear interpolation.
Parameters
----------
img: np.ndarray
Image array in shape ``(f, z, x, y)``
ps: tuple
(spatial only) Patch size ``(z, x, y)``
(spatial shape of the neural network's input node)
M: np.ndarray
Forward warping tansformation matrix (4x4).
Must contain translations in source and target array.
target: np.ndarray or None
Optional target array to be extracted in the same way.
target_ps: tuple
Patch size for the ``target`` array.
target_vec_ix: list
List of triples that denote vector value parts in the target array.
E.g. [(0,1,2), (4,5,6)] denotes two vector fields, separated by a
scalar field in channel 3.
last_ch_max_interp: bool
ksize: float
Returns
-------
img_new: np.ndarray
Warped input image slice
target_new: np.ndarray or None
Warped target image slice
or ``None``, if ``target is None``.
"""
#T = utils.Timer(silent_all=True)
ps = tuple(ps)
if len(img.shape)==3: # For single knossos_array
n_f = 1
sh = img.shape
elif len(img.shape)==4:
n_f = img.shape[0]
sh = img.shape[1:]
else:
raise ValueError('img wrong dim/shape')
M_inv = np.linalg.inv(M.astype(np.float64)).astype(np.float32) # stability...
dest_corners = make_dest_corners(ps)
src_corners = np.dot(M_inv, dest_corners.T).T
if np.any(M[3,:3] != 0): # homogeneous divide
src_corners /= src_corners[:,3][:,None]
# check corners
src_corners = src_corners[:,:3]
lo = np.min(np.floor(src_corners), 0).astype(np.int)
hi = np.max(np.ceil(src_corners + 1), 0).astype(np.int) # add 1 because linear interp
if np.any(lo < 0) or np.any(hi >= sh):
raise WarpingOOBError("Out of bounds")
#T.check("corners")
# compute/transform dense coords
dest_coords = make_dest_coords(ps)
src_coords = np.tensordot(dest_coords, M_inv, axes=[[-1],[1]])
if np.any(M[3,:3] != 0): # homogeneous divide
src_coords /= src_coords[...,3][...,None]
#T.check("tensordot")
# cut patch
src_coords = src_coords[...,:3]
img_cut = img[:, lo[0]:hi[0]+1, #add 1 to include this coordinate!
lo[1]:hi[1]+1,
lo[2]:hi[2]+1,]
img_cut = np.ascontiguousarray(img_cut, dtype=np.float32)
#T.check("cut")
img_new = np.zeros((n_f,)+ps, dtype=np.float32)
#T.check("alloc")
lo = lo.astype(np.float32)
for k in range(n_f):
if (ksize>0.5) and last_ch_max_interp and (k==n_f-1):
map_coordinates_max_kernel(img_cut[k], src_coords, lo, ksize, img_new[k])
else:
map_coordinates_linear(img_cut[k], src_coords, lo, img_new[k])
#T.check("map img")
if np.isnan(img_new).sum():
logger.warning("img_new has {} NaN values".format(np.isnan(img_new).sum()))
if target is not None:
target_ps = tuple(target_ps)
n_f_t = target.shape[0]
off = np.subtract(sh, target.shape[1:])
if np.any(np.mod(off, 2)):
raise ValueError("targets must be centered w.r.t. images")
off //= 2
off_ps = np.subtract(ps, target_ps)
if np.any(np.mod(off_ps, 2)):
raise ValueError("targets must be centered w.r.t. images")
off_ps //= 2
src_coords_target = src_coords[off_ps[0]:off_ps[0]+target_ps[0],
off_ps[1]:off_ps[1]+target_ps[1],
off_ps[2]:off_ps[2]+target_ps[2]]
# shift coords to be w.r.t. to origin of target array
lo_targ = np.floor(src_coords_target.min(2).min(1).min(0) - off).astype(np.int)
# add 1 because linear interp
hi_targ = np.ceil(src_coords_target.max(2).max(1).max(0) - off + 1).astype(np.int)
if np.any(lo_targ < 0) or np.any(hi_targ >= target.shape[1:]):
raise WarpingOOBError("Out of bounds for target")
#T.check("target_check")
target_cut = target[:, lo_targ[0]:hi_targ[0]+1, #add 1 to include this coordinate!
lo_targ[1]:hi_targ[1]+1,
lo_targ[2]:hi_targ[2]+1]
target_cut = np.ascontiguousarray(target_cut, dtype=np.float32)
src_coords_target = np.ascontiguousarray(src_coords_target, dtype=np.float32)
#T.check("target_cut")
target_new = np.zeros((n_f_t,)+target_ps, dtype=np.float32)
#T.check("target_alloc")
lo_targ = (lo_targ + off).astype(np.float32)
if target_discrete_ix is None:
target_discrete_ix = [True for i in range(n_f_t)]
else:
target_discrete_ix = [i in target_discrete_ix for i in range(n_f_t)]
for k, discr in enumerate(target_discrete_ix):
if discr:
map_coordinates_nearest(target_cut[k], src_coords_target, lo_targ, target_new[k])
else:
map_coordinates_linear(target_cut[k], src_coords_target, lo_targ, target_new[k])
#T.check("target_map")
if target_vec_ix is not None: # Vectors must be transformed again
assert np.allclose(M[3,:3], 0.0) # no projective transform
M_lin = M[:3,:3]
for ix in target_vec_ix:
assert len(ix)==3
target_new[ix] = np.tensordot(M_lin, target_new[ix], axes=[[1],[0]])
#T.check("target_vec")
else:
target_new = None
#T.summary()
return img_new, target_new
[docs]def get_tracing_slice(img, ps, pos, z_shift=0, aniso_factor=2,
sample_aniso=True, gamma=0, scale_factor=1.0, direction_iso=None,
target=None, target_ps=None, target_vec_ix=None,
target_discrete_ix=None, rng=None, last_ch_max_interp=False):
# positive z_shift --> see more slices in positive z-direction w.r.t. pos
# scale_factor > 1 zooms into image / magnifies
rng = np.random.RandomState() if rng is None else rng
dest_center = np.array(ps, dtype=np.float)/2
dest_center[0] -= z_shift
R = get_rotmat_from_direc(direction_iso, gamma, rng)
T_src = translate(-pos[0], -pos[1], -pos[2])
S_src = scale(aniso_factor, 1, 1)
S_zoom = scale(scale_factor, scale_factor, scale_factor)
if sample_aniso:
S_dest = scale_inv(aniso_factor, 1, 1)
else:
S_dest = identity()
T_dest = translate(dest_center[0], dest_center[1], dest_center[2])
M = chain_matrices([T_dest, S_zoom, S_dest, R, S_src, T_src])
ksize = min(0.5, 0.5/scale_factor)
img_new, target_new = warp_slice(img, ps, M,
target=target,
target_ps=target_ps,
target_vec_ix=target_vec_ix,
target_discrete_ix=target_discrete_ix,
last_ch_max_interp=last_ch_max_interp,
ksize=ksize)
return img_new, target_new, M
[docs]def get_warped_slice(img, ps, aniso_factor=2, sample_aniso=True,
warp_amount=1.0, lock_z=True, no_x_flip=False, perspective=False,
target=None, target_ps=None, target_vec_ix=None,
target_discrete_ix=None, rng=None):
"""
(Wraps :py:meth:`elektronn2.data.transformations.warp_slice()`)
Generates the warping transformation parameters and composes them into a
single 4D homogeneous transformation matrix ``M``.
Then this transformation is applied to ``img`` and ``target`` in the
``warp_slice()`` function and the transformed input and target image are
returned.
Parameters
----------
img: np.array
Input image
ps: np.array
Patch size (spatial shape of the neural network's input node)
aniso_factor: float
Anisotropy factor that determines an additional scaling in ``z``
direction.
sample_aniso: bool
Scale coordinates by ``1 / aniso_factor`` while warping.
warp_amount: float
Strength of the random warping transformation. A lower ``warp_amount``
will lead to less distorted images.
lock_z: bool
Exclude ``z`` coordinates from the random warping transformations.
no_x_flip: bool
Don't flip ``x`` axis during random warping.
perspective: bool
Apply perspective transformations (in addition to affine ones).
target: np.array
Target image
target_ps: np.array
Target patch size
target_vec_ix
target_discrete_ix
rng: np.random.mtrand.RandomState
Random number generator state (obtainable by
``np.random.RandomState()``). Passing a known state makes the random
transformations reproducible.
Returns
-------
img_new: np.ndarray
(Warped) input image slice
target_new: np.ndarray
(Warped) target slice
"""
rng = np.random.RandomState() if rng is None else rng
strip_2d = False
if len(ps)==2:
strip_2d = True
ps = np.array([1]+list(ps))
if target is not None:
target_ps = np.array([1]+list(target_ps))
dest_center = np.array(ps, dtype=np.float) / 2
src_remainder = np.array(np.mod(ps, 2), dtype=np.float) / 2
if target_ps is not None:
t_center = np.array(target_ps, dtype=np.float) / 2
off = np.subtract(img.shape[1:], target.shape[1:])
off //= 2
lo_pos = np.maximum(dest_center, t_center+off)
hi_pos = np.minimum(img.shape[1:] - dest_center, target.shape[1:] - t_center + off)
else:
lo_pos = dest_center
hi_pos = img.shape[1:] - dest_center
z = rng.randint(lo_pos[0], hi_pos[0]) + src_remainder[0]
y = rng.randint(lo_pos[1], hi_pos[1]) + src_remainder[1]
x = rng.randint(lo_pos[2], hi_pos[2]) + src_remainder[2]
F = get_random_flipmat(no_x_flip, rng)
if no_x_flip:
S = np.eye(4, dtype=np.float32)
else:
S = get_random_swapmat(lock_z, rng)
if np.isclose(warp_amount, 0):
R = np.eye(4, dtype=np.float32)
W = np.eye(4, dtype=np.float32)
else:
R = get_random_rotmat(lock_z, warp_amount, rng)
W = get_random_warpmat(lock_z, perspective, warp_amount, rng)
T_src = translate(-z, -y, -x)
S_src = scale(aniso_factor, 1, 1)
if sample_aniso:
S_dest = scale(1.0 / aniso_factor, 1, 1)
else:
S_dest = identity()
T_dest = translate(dest_center[0], dest_center[1], dest_center[2])
M = chain_matrices([T_dest, S_dest, R, W, F, S, S_src, T_src])
img_new, target_new = warp_slice(img, ps, M,
target=target,
target_ps=target_ps,
target_vec_ix=target_vec_ix,
target_discrete_ix=target_discrete_ix)
if strip_2d:
img_new = img_new[:,0]
if target is not None:
target_new = target_new[:,0]
return img_new, target_new
if __name__ =='__main__':
# pos = [120,175,175]
# import time
# from elektronn2 import utils
# from elektronn2.utils.plotting import scroll_plot
#
# raw = np.random.rand(250,500,500).astype(np.float32)
# if True: # insert stripes
# raw[::10,:,:] = 0
# raw[:,::20,:] = 0
# raw[:,:,::20] = 0
#
# img = raw[None]
# R = chain_matrices([rotate_y(0.2), rotate_z(0.1)])
# ps = [46,171,171]
# z_shift = 8
# aniso_factor = 2
# sample_aniso=True
#
# dest_center = np.array(ps, dtype=np.float32)/2
# dest_center[0] -= z_shift # sign?
# T_src = translate(-pos[0], -pos[1], -pos[2])
# S_src = scale(aniso_factor, 1, 1)
# if sample_aniso:
# S_dest = scale(1.0/aniso_factor, 1, 1)
# else:
# S_dest = identity()
# T_dest = translate(dest_center[0], dest_center[1], dest_center[2])
#
# M = chain_matrices([T_dest, S_dest, R, S_src, T_src])
#
# warp_time = utils.timeit(warp_slice)
# for i in range(10):
# img_new = warp_time(img, ps, M)[0]
# f = scroll_plot(img_new[0])
pos = [150,300,300]
from elektronn2.utils.plotting import scroll_plot
from elektronn2.data.image import center_cubes
warp_time = utils.timeit(warp_slice)
raw = utils.h5load('~/lustre/mkilling/BirdGT/j0126_old/v2_old_0-raw-zyx.h5')
if True: # insert stripes
raw[:,::10,:,:] = 0
raw[:,:,::20,:] = 0
raw[:,:,:,::20] = 0
img = raw
targ = utils.h5load('~/lustre/mkilling/BirdGT/j0126_old/v2_old_0-combo-sparse-zyx.h5')
targ[:,::10,:,:] = 0
targ[:,:,::20,:] = 0
targ[:,:,:,::20] = 0
targ = targ
# --------------------------------
R = chain_matrices([rotate_y(0.05), rotate_z(0.2)])
#direc = np.array([1,0.15, 0.05])
#direc /= np.linalg.norm(direc)
#direc = [1,0,0]
#R = get_rotmat_from_direc(direc, 0.2)
#R[3,1] = 0.002
ps = [31,131,131]
z_shift = 0
aniso_factor = 2
sample_aniso=True
dest_center = np.array(ps, dtype=np.float32)/2
dest_center[0] -= z_shift # sign?
T_src = translate(-pos[0], -pos[1], -pos[2])
S_src = scale(aniso_factor, 1, 1)
if sample_aniso:
S_dest = scale(1.0/aniso_factor, 1, 1)
else:
S_dest = identity()
T_dest = translate(dest_center[0], dest_center[1], dest_center[2])
M = chain_matrices([T_dest, S_dest, R, S_src, T_src])
for i in range(5):
img_new, targ_new = warp_time(img, ps, M, target=targ, target_ps=[11,121,121], target_discrete_ix=[0,1,2])
#f0 = scroll_plot(img_new[0])
i, t = center_cubes(img_new, targ_new)
f1 = scroll_plot([i[0], t[0]])
def xyz2zxy(vol):
"""
Swaps axes to ELEKTRONN convention ([X, Y, Z] -> [Z, X, Y]). If additional
channel axis is provided: [X, Y, Z, CH] -> [Z, X, Y, CH]
Parameters
----------
vol : np.array [X, Y, Z]
Returns
-------
np.array [Z, X, Y]
"""
assert vol.ndim >= 3
vol = vol.swapaxes(1, 0) # y x z
vol = vol.swapaxes(0, 2) # z x y
return vol