# -*- coding: utf-8 -*-
# ELEKTRONN2 - Neural Network Toolkit
#
# Copyright (c) 2014 - now
# Max-Planck-Institute for Medical Research, Heidelberg, Germany
# Authors: Marius Killinger
from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, \
super, zip
import os
from collections import OrderedDict
from matplotlib import pyplot as plt
from scipy import stats
import numpy as np
import seaborn as sns
from .locking import FileLock
from ..config import config
import logging
logger = logging.getLogger('elektronn2log')
def _scroll_plot1(image, name, init_z):
    """
    Show one 3d volume image in a figure with a single axes.

    Scrolling changes the displayed slice.

    Parameters
    ----------
    image: array of shape (z,x,y) or (z,x,y,RGB)
    name: string
        Name for the image, handed to the ``Scroller``
    init_z:
        Handed unchanged to the ``Scroller``
        (presumably the slice index shown first - confirm in ``Scroller``)

    Returns
    -------
    scroller: Scroller
        The object whose ``onscroll`` method is connected to the figure's
        ``'scroll_event'``

    Usage
    -----
    For the scroll interaction to work, the "scroller" object
    must be kept in the calling scope

    >>> scroller = _scroll_plot1(image, name, init_z)
    """
    figure = plt.figure(figsize=(12, 12))
    axes = [figure.add_subplot(111)]
    scroller = Scroller(axes, [image, ], [name, ], init_z)
    figure.canvas.mpl_connect('scroll_event', scroller.onscroll)
    figure.tight_layout()
    return scroller
def _scroll_plot2(images, names, init_z):
    """
    Show two 3d volume images side by side (1x2 grid).

    Both axes share their x and y axis with the first one.
    Scrolling changes the displayed slices.

    Parameters
    ----------
    images: list of 2 arrays
        Each array of shape (z,y,x) or (z,y,x,RGB)
    names: list of 2 strings
        Names for each image
    init_z:
        Handed unchanged to the ``Scroller``
        (presumably the slice index shown first - confirm in ``Scroller``)

    Returns
    -------
    scroller: Scroller
        The object whose ``onscroll`` method is connected to the figure's
        ``'scroll_event'``

    Usage
    -----
    For the scroll interaction to work, the "scroller" object
    must be kept in the calling scope

    >>> scroller = _scroll_plot2(images, names, init_z)
    """
    figure = plt.figure(figsize=(12, 6))
    left = figure.add_subplot(121)
    right = figure.add_subplot(122, sharex=left, sharey=left)
    scroller = Scroller([left, right], images, names, init_z)
    figure.canvas.mpl_connect('scroll_event', scroller.onscroll)
    figure.tight_layout()
    return scroller
def _scroll_plot4(images, names, init_z):
    """
    Show four 3d volume images in a 2x2 grid.

    All axes share their x and y axis with the first one.
    Scrolling changes the displayed slices.

    Parameters
    ----------
    images: list of 4 arrays
        Each array of shape (z,y,x) or (z,y,x,RGB)
    names: list of 4 strings
        Names for each image
    init_z:
        Handed unchanged to the ``Scroller``
        (presumably the slice index shown first - confirm in ``Scroller``)

    Returns
    -------
    scroller: Scroller
        The object whose ``onscroll`` method is connected to the figure's
        ``'scroll_event'``

    Usage
    -----
    For the scroll interaction to work, the "scroller" object
    must be kept in the calling scope

    >>> scroller = _scroll_plot4(images, names, init_z)
    """
    figure = plt.figure(figsize=(12, 12))
    reference = figure.add_subplot(221)
    axes = [reference]
    # The remaining three panels follow the pan/zoom of the first one.
    for position in (222, 223, 224):
        axes.append(
            figure.add_subplot(position, sharex=reference, sharey=reference))
    scroller = Scroller(axes, images, names, init_z)
    figure.canvas.mpl_connect('scroll_event', scroller.onscroll)
    figure.tight_layout()
    return scroller
def _embed3d2d(a, border_width=1, normalize=False, output_ratio=1.5, ):
"""
Embed an 3d array into an 2d matrix by tiling.
The last two dimensions of ``a`` are assumed to be spatial, the first is tiled.
"""
sh = a.shape
assert len(sh)==3
n = sh[0]
nhor = int(np.ceil(np.sqrt(n * output_ratio))) # aim: ratio 16:9
nvert = int(np.ceil(float(n) / nhor)) # warning: too big: nvert*nhor >= n
if normalize:
maxs = [np.max(a[i, :, :]) + 1e-8 for i in range(n)]
mins = [np.min(a[i, :, :]) for i in range(n)]
else:
maxs = [1] * n
mins = [0] * n
ret = np.zeros(
(nvert * (border_width + sh[1]), nhor * (border_width + sh[2])),
dtype=np.float32)
for j in range(nvert):
for i in range(nhor):
if i + j * nhor >= n:
return ret
ret[j*(border_width+sh[1]):j*(border_width+sh[1])+sh[1],
i*(border_width+sh[2]):i*(border_width+sh[2])+sh[2]] = \
(a[i+j*nhor,:,:]-mins[i+j*nhor])/(maxs[i+j*nhor]-mins[i+j*nhor])
return ret
def embedfilters(filters, border_width=1, normalize=False, output_ratio=1.0,
                 rgb_axis=None):
    """
    Embed an nd array into an 2d matrix by tiling.

    The last two dimensions of ``filters`` are assumed to be spatial,
    the others are tiled recursively.

    Parameters
    ----------
    filters: nd array (at least 2d)
    border_width: int
        Passed on to the tiling of every level
    normalize: bool
        Passed on to the tiling of every level
    output_ratio: float
        Passed on to the tiling of every level
    rgb_axis: int or None
        If given, this axis must have length 3. Each of its 3 entries is
        embedded separately and the results are stacked along a new last
        axis.

    Returns
    -------
    2d array, or 3d array of shape (h, w, 3) if ``rgb_axis`` is given

    Raises
    ------
    ValueError
        If ``filters`` has fewer than 2 dimensions or ``rgb_axis`` does not
        have length 3
    """
    if rgb_axis is not None:
        if filters.shape[rgb_axis] != 3:
            raise ValueError("rgb_axis must have length 3, got %s"
                             % (filters.shape[rgb_axis],))
        channels = []
        for i in range(3):
            # Index with a tuple: a list index would be interpreted as
            # fancy indexing by numpy.
            index = [slice(None), ] * filters.ndim
            index[rgb_axis] = i
            channel = filters[tuple(index)]
            channels.append(
                embedfilters(channel, border_width, normalize, output_ratio))

        return np.dstack(channels)

    if filters.ndim < 2:
        raise ValueError("filters must have at least 2 dimensions, got %i"
                         % filters.ndim)

    if filters.ndim==2:
        # A single image is treated as a stack containing one tile.
        filters = filters[None, ...]

    if filters.ndim==3:
        return _embed3d2d(filters, border_width, normalize, output_ratio)

    # More than 3 dimensions: embed every sub-array along the first axis to
    # a 2d tile, then tile those results once more.
    parts = [embedfilters(f, border_width, normalize, output_ratio)
             for f in filters]
    parts = np.concatenate([x[None, ...] for x in parts])
    return embedfilters(parts, border_width, normalize, output_ratio)
def sma(c, n):
    """
    Simple moving average of ``c`` with box length ``n``.

    The result has the same length as ``c``. From index ``n`` on every
    entry is the mean of the last ``n`` values; the first ``n`` entries
    are the running mean of all values seen so far.
    For ``n == 0`` the input is returned unchanged.
    """
    if n==0:
        return c

    window = np.cumsum(c, dtype=float)
    warmup = min(n, len(c))
    # Difference of cumulative sums n apart = sum over the last n values
    full_boxes = (window[n:] - window[:-n]) / n
    window[n:] = full_boxes
    # The leading entries still hold plain cumulative sums -> running mean
    window[:n] = window[:n] / np.arange(1, warmup + 1)
    return window
def add_timeticks(ax, times, steps, time_str='mins', num=5):
    """
    Label the x-axis of ``ax`` with round runtime values.

    Roughly ``num`` tick labels are spread over ``0 .. times[-1]`` with a
    spacing that is a multiple of a power of ten. The tick positions are
    obtained by interpolating the runtime labels onto ``steps``.

    Parameters
    ----------
    ax: matplotlib axes
        Axes whose x ticks, x tick labels, x limits and x label are set
    times: array
        Runtime at every point (must be increasing for the interpolation)
    steps: array
        x position belonging to every entry of ``times``
    time_str: string
        Unit shown in the axis label
    num: int
        Approximate number of ticks
    """
    N = int(times[-1])
    k = max(N / num, 1)
    k = int(np.log10(k))  # 10-base of locators
    m = int(np.round(float(N) / (num * 10 ** k)))  # multiple of base
    s = max(m * 10 ** k, 1)
    # ``np.int`` was removed from NumPy (1.24); ``np.int_`` is the
    # default integer type and exists in all versions.
    x_labs = np.arange(0, N, s, dtype=np.int_)
    x_ticks = np.interp(x_labs, times, steps)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labs)
    ax.set_xlim(0, steps[-1])
    ax.set_xlabel('Runtime [%s]' % time_str)  # (%s)'%("{0:,d}".format(N)))
def _add_runtime_ticks(timeline, x_timeline):
    """Add a twin x-axis with runtime ticks (hours above 2h, else minutes)."""
    ax = plt.twiny()
    if timeline['time'][-1] > 120 * 60:
        add_timeticks(ax, timeline['time'] / 3600, x_timeline,
                      time_str='hours')
    else:
        add_timeticks(ax, timeline['time'] / 60, x_timeline,
                      time_str='mins')


def plot_hist(timeline, history, save_name, loss_smoothing_length=200,
              autoscale=True):
    """
    Plot graphical info during Training.

    Two images are written: ``save_name + ".timeline.png"`` (loss per update
    and loss vs. ``batch_char``) and ``save_name + ".history.png"`` (train /
    valid loss, loss gain with LR / momentum / gradnetrate, and errors).
    Both figures are closed again after saving.

    Parameters
    ----------
    timeline: array with the fields 'time', 'loss' and 'batch_char'
        One entry per update step (presumably a numpy record array -
        confirm against the caller). 'time' is interpreted as seconds.
    history: array with the fields 'steps', 'train_loss', 'valid_loss',
        'loss_gain', 'lr', 'mom', 'gradnetrate', 'train_err', 'valid_err'
    save_name: string
        Path prefix of the written images
    loss_smoothing_length: int
        Box length of the moving average drawn over the update loss
    autoscale: bool
        If True the axis limits are derived from the recent values,
        otherwise matplotlib's defaults are kept

    Notes
    -----
    ``ValueError`` and ``IndexError`` (raised e.g. for empty arrays) are
    caught and logged as a warning; nothing is raised to the caller.
    """
    plt.ioff()
    try:
        # Subsample points for plotting
        N = len(timeline)
        x_timeline = np.arange(N)
        s = max((len(timeline) // 2000), 1)
        x_timeline = x_timeline[::s]
        timeline = timeline[::s]
        s = max((len(history) // 2000), 1)
        history = history[::s]

        if timeline['time'][-1] < 120 * 60:
            runtime = str(int(timeline['time'][-1] / 60)) + ' mins'
        else:
            runtime = "%.1f hours" % (timeline['time'][-1] / 3600)

        # check if valid data is available
        if not np.any(np.isnan(history['valid_loss'])):
            recent = history['valid_loss'][-10:]
        else:
            recent = history['train_loss'][-10:]

        loss_cap = recent.mean() + 2 * recent.std()
        lt = timeline['loss'][-200:]
        lt_m = lt.mean()
        lt_s = lt.std()
        loss_cap_t = lt_m + 2 * lt_s
        loss_cap = np.maximum(loss_cap, loss_cap_t)

        if np.all(timeline['loss'] >= 0):
            loss_floor = 0.0
        else:
            loss_floor = lt_m - 2 * lt_s

        ### Timeline, Loss ###
        fig = plt.figure(figsize=(16, 12))
        try:
            plt.subplot(211)
            plt.plot(x_timeline, timeline['loss'], 'b-', alpha=0.5,
                     label='Update Loss')
            loss_smooth = sma(timeline['loss'], loss_smoothing_length)
            plt.plot(x_timeline, loss_smooth, 'k-',
                     label='Smooth update Loss', linewidth=3)
            if autoscale:
                plt.ylim(loss_floor, loss_cap)
                plt.xlim(0, N)

            plt.legend(loc=0)
            plt.hlines(lt_m, 0, N, linestyle='dashed', colors='r',
                       linewidth=2)
            plt.hlines(lt_m + lt_s, 0, N, linestyle='dotted', colors='r',
                       linewidth=1)
            plt.hlines(lt_m - lt_s, 0, N, linestyle='dotted', colors='r',
                       linewidth=1)
            plt.xlabel('Update steps %s, total runtime %s' % (N - 1, runtime))
            _add_runtime_ticks(timeline, x_timeline)

            ### Loss vs Prevalence ###
            plt.subplot(212)
            # Older points get a larger value on the gray colormap
            c = 1.0 - (timeline['time'] / timeline['time'].max())
            plt.scatter(timeline['batch_char'], timeline['loss'], c=c,
                        marker='.', s=80, cmap='gray', edgecolors='face')
            if autoscale:
                bc = timeline['batch_char'][-200:]
                bc_m = bc.mean()
                bc_s = bc.std()
                bc_cap = bc_m + 2 * bc_s
                if np.all(bc >= 0):
                    bc_floor = -0.01
                else:
                    bc_floor = bc_m - 2 * bc_s
                plt.ylim(loss_floor, loss_cap)
                plt.xlim(bc_floor, bc_cap)

            plt.xlabel('Mean target of batch')
            plt.ylabel('Loss')
            plt.tight_layout()
            with FileLock('plotting'):
                plt.savefig(save_name + ".timeline.png", bbox_inches='tight')
        finally:
            # This function is called repeatedly during training; without
            # closing, every call would leave a figure alive in pyplot.
            plt.close(fig)

        ###################################################################
        ### History Loss ###
        fig = plt.figure(figsize=(16, 12))
        try:
            plt.subplot(311)
            plt.plot(history['steps'], history['train_loss'], 'g-',
                     label='Train Loss', linewidth=3)
            plt.plot(history['steps'], history['valid_loss'], 'r-',
                     label='Valid Loss', linewidth=3)
            if autoscale:
                plt.ylim(loss_floor, loss_cap)
                plt.xlim(0, history['steps'][-1])

            plt.legend(loc=0)
            # plt.xlabel('Update steps %s, total runtime %s'%(N-1, runtime))
            _add_runtime_ticks(timeline, x_timeline)

            ### History Loss gains ###
            plt.subplot(312)
            plt.plot(history['steps'], history['loss_gain'], 'b-',
                     label='Loss Gain at update', linewidth=3)
            plt.hlines(0, 0, history['steps'][-1], linestyles='dotted')
            plt.plot(history['steps'], history['lr'], 'r-', label='LR',
                     linewidth=3)
            # plt.xlabel('Update steps %s, total runtime %s'%(N-1, runtime))
            plt.legend(loc=3)
            std = history['loss_gain'][:5].std() * 2 if len(history) > 6 else 1.0
            if autoscale:
                # add epsilon to suppress matplotlib warning in case of CG
                plt.ylim(-std, std + 1e-10)
                plt.xlim(0, history['steps'][-1])

            ax2 = plt.twinx()
            ax2.plot(history['steps'], history['mom'], 'r-', label='MOM')
            ax2.plot(history['steps'], history['gradnetrate'], 'r-',
                     label='GradNetRate')
            ax2.set_ylim(-1, 1)
            if autoscale:
                ax2.set_xlim(0, history['steps'][-1])

            ax2.legend(loc=4)

            ### Errors ###
            plt.subplot(313)
            cutoff = 2
            if len(history) > (cutoff + 1):
                history = history[cutoff:]
            # check if valid data is available
            if not np.any(np.isnan(history['valid_err'])):
                e = history['valid_err'][-200:]
            else:
                e = history['train_err'][-200:]
            e_m = e.mean()
            e_s = e.std()
            err_cap = e_m + 2 * e_s
            if np.all(e > 0):
                err_floor = 0.0
            else:
                err_floor = e_m - 2 * e_s

            plt.plot(history['steps'], history['train_err'], 'g--',
                     label='Train error', linewidth=1)
            plt.plot(history['steps'], history['valid_err'], 'r--',
                     label='Valid Error', linewidth=1)
            plt.plot(history['steps'], sma(history['train_err'], 8), 'g-',
                     label='Smooth train error', linewidth=3)
            smooth_valid_err = sma(history['valid_err'], 8)
            if not np.any(np.isnan(smooth_valid_err)):
                plt.plot(history['steps'], smooth_valid_err, 'r-',
                         label='Smooth valid Error', linewidth=3)
            if autoscale:
                plt.ylim(err_floor, err_cap)
                plt.xlim(0, history['steps'][-1])

            plt.grid()
            plt.legend(loc=0)
            plt.xlabel('Update steps %s, total runtime %s' % (N - 1, runtime))
            plt.tight_layout()
            with FileLock('plotting'):
                plt.savefig(save_name + ".history.png", bbox_inches='tight')
        finally:
            plt.close(fig)

    except (ValueError, IndexError):
        # When arrays are empty (indexing [-1] raises IndexError,
        # reductions like max() raise ValueError)
        logger.warning("An error occurred during plotting.", exc_info=True)
def plot_var(var, save_name):
    """
    Plot NLL and concentration statistics and save them as two images.

    ``save_name + ".Beta1.png"`` shows the NLL and the concentration over the
    first column, each with a moving average (box length 100) and a
    +/- std band. ``save_name + ".Beta2.png"`` shows pairwise scatter plots
    of the four statistics. Both figures are closed after saving.

    Parameters
    ----------
    var: 2d array
        Columns: [i, nll, nll.std, conc.mean, conc.std]
    save_name: string
        Path prefix of the written images
    """
    # [i, nll, nll.std, conc.mean, conc.std,]
    fig = plt.figure(figsize=(16, 12))
    try:
        plt.subplot(211)
        plt.plot(var[:, 0], var[:, 1], 'b-', alpha=0.6)
        plt.plot(var[:, 0], sma(var[:, 1], 100), 'g-', linewidth=3)
        plt.plot(var[:, 0], sma(var[:, 1] + var[:, 2], 100), 'r:', linewidth=2)
        plt.plot(var[:, 0], sma(var[:, 1] - var[:, 2], 100), 'r:', linewidth=2)
        plt.title("NLL")
        plt.subplot(212)
        plt.plot(var[:, 0], var[:, 3], 'b-', alpha=0.6)
        plt.plot(var[:, 0], sma(var[:, 3], 100), 'g-', linewidth=3)
        plt.plot(var[:, 0], sma(var[:, 3] + var[:, 4], 100), 'r:', linewidth=2)
        plt.plot(var[:, 0], sma(var[:, 3] - var[:, 4], 100), 'r:', linewidth=2)
        plt.title("Concentration")
        with FileLock('plotting'):
            plt.savefig(save_name + ".Beta1.png", bbox_inches='tight')
    finally:
        # Do not leave the figure alive in pyplot after it has been saved
        plt.close(fig)

    fig = plt.figure(figsize=(12, 12))
    try:
        # Earlier rows get a larger value on the gray colormap
        c = 1.0 - ((var[:, 0]).astype(np.float32) / var[-1, 0])
        panels = [
            (221, 1, 3, "Concentration vs. NLL"),
            (222, 2, 3, "Concentration vs. NLL.std"),
            (223, 3, 4, "Concentration vs. Concentration.std"),
            (224, 1, 2, "NLL vs. NLL.std"),
        ]
        for position, x_col, y_col, title in panels:
            plt.subplot(position)
            plt.scatter(var[:, x_col], var[:, y_col], c=c, marker='.', s=80,
                        cmap='gray', edgecolors='face')
            plt.title(title)

        with FileLock('plotting'):
            plt.savefig(save_name + ".Beta2.png", bbox_inches='tight')
    finally:
        plt.close(fig)
def plot_debug(var, debug_output_names, save_name):
    """
    Plot smoothed debug outputs together with the total loss.

    The image is written to ``save_name + ".Debug.png"`` and the figure is
    closed afterwards.

    Parameters
    ----------
    var: 2d array
        Columns: [i, nll, other....]; column ``k + 2`` belongs to
        ``debug_output_names[k]``. At most every ``len(var) // 2000``-th
        row is plotted.
    debug_output_names: list of strings
        Legend labels of the debug outputs
    save_name: string
        Path prefix of the written image
    """
    # [i, nll, other....]
    s = max((len(var) // 2000), 1)
    var = var[::s]
    colors = ['gold', 'b', 'darkblue', 'crimson', 'navajowhite', 'deepskyblue',
              'darkgray', 'maroon', 'palevioletred', 'forestgreen', ] * 2
    n = len(colors) // 2
    marker = ['-', ] * n + [':'] * n
    lw_s = [2, ] * n + [3, ] * n
    n_styles = len(colors)

    fig = plt.figure(figsize=(16, 12))
    try:
        maxima = []
        minima = []
        total = sma(var[:, 1], 70)
        maxima.append(total[-100:].max())
        minima.append(total[-100:].min())
        plt.plot(var[:, 0], total, 'k-', linewidth=4, label='total loss')
        ###TODO automatic std intervals
        for i, name in enumerate(debug_output_names):
            smooth = sma(var[:, i + 2], 70)
            # Wrap around so more outputs than styles do not raise IndexError
            style = i % n_styles
            plt.plot(var[:, 0], smooth, color=colors[style],
                     linestyle=marker[style], linewidth=lw_s[style],
                     label=name)
            maxima.append(smooth[-100:].max())
            minima.append(smooth[-100:].min())

        plt.title("Debug Outputs")
        finite_maxima = [x for x in maxima if np.isfinite(x)]
        finite_minima = [x for x in minima if np.isfinite(x)]
        # np.max / np.min of an empty list raise ValueError, so the limits
        # are only set if there is at least one finite extremum.
        if finite_maxima and finite_minima:
            cap_hi = np.max(finite_maxima) * 1.5
            cap_lo = np.min(finite_minima)
            plt.ylim(cap_lo, cap_hi)

        plt.legend(loc=0)
        plt.hlines(0, var[0, 0], var[-1, 0], linewidth=1)
        plt.grid()
        with FileLock('plotting'):
            plt.savefig(save_name + ".Debug.png", bbox_inches='tight')
    finally:
        # Do not leave the figure alive in pyplot after it has been saved
        plt.close(fig)
def plot_regression(pred, target, save_name, loss_smoothing_length=200,
                    autoscale=True):
    """
    Scatter-plot predictions against targets and save the image.

    The image is written to ``save_name + ".regression.png"`` and the figure
    is closed afterwards. A dotted red diagonal marks ``pred == target``.

    Parameters
    ----------
    pred: array
        Predictions; at most every ``len(pred) // 2000``-th entry along the
        first axis is used, the result is flattened
    target: array
        Targets, subsampled and flattened like ``pred``
    save_name: string
        Path prefix of the written image
    loss_smoothing_length: int
        Unused, kept for interface compatibility
    autoscale: bool
        Unused, kept for interface compatibility

    Notes
    -----
    ``ValueError`` (raised e.g. for empty arrays) is caught and logged as a
    warning.
    """
    try:
        # Subsample points for plotting
        s = max((len(pred) // 2000), 1)
        pred = pred[::s].ravel()
        target = target[::s].ravel()
        N = len(pred)
        # Earlier points get a larger value on the gray colormap
        c = N - np.arange(N)

        fig = plt.figure(figsize=(8, 8))
        try:
            ### Prediction vs Target ###
            plt.scatter(pred, target, c=c, marker='.', s=80, cmap='gray',
                        edgecolors='face')
            m = np.minimum(pred.min(), target.min())
            M = np.maximum(pred.max(), target.max())
            plt.plot([m, M], [m, M], 'r:')
            plt.ylim(m, M)
            plt.xlim(m, M)
            plt.xlabel('Prediction')
            plt.ylabel('Target')
            plt.tight_layout()
            with FileLock('plotting'):
                plt.savefig(save_name + ".regression.png",
                            bbox_inches='tight')
        finally:
            # Do not leave the figure alive in pyplot, also on errors
            plt.close(fig)

    except ValueError:
        # When arrays are empty
        logger.warning("An error occurred during regression plotting.",
                       exc_info=True)
def plot_kde(pred, target, save_name, limit=90, scale='same', grid=50,
             take_last=4000):
    """
    Plot a gaussian kernel density estimate of (prediction, target) pairs.

    The image is written to ``save_name + ".regression_kde.png"`` and the
    figure is closed afterwards.

    Parameters
    ----------
    pred: array
        Predictions (flattened before use)
    target: array
        Targets (flattened before use)
    save_name: string
        Path prefix of the written image
    limit: 'max' or number
        'max': the plotted range spans min..max of the data. Otherwise the
        range spans the ``100 - limit`` to ``limit`` percentiles.
    scale: string
        'same': both axes use the common range of predictions and targets;
        any other value keeps separate ranges per axis
    grid: int or sequence of 2 ints
        Number of grid points per axis on which the density is evaluated
    take_last: int
        If truthy only the last ``take_last`` entries along the first axis
        are used

    Notes
    -----
    ``ValueError`` (raised e.g. for empty arrays) is caught and logged as a
    warning.
    """
    try:
        if take_last:
            pred = pred[-take_last:]
            target = target[-take_last:]
        # gaussian_kde needs one 1d sample vector per dimension
        pred = pred.ravel()
        target = target.ravel()

        if limit=='max':
            mp, mt = pred.min(), target.min()
            Mp, Mt = pred.max(), target.max()
        else:
            lo = 100 - limit
            mp, mt = np.percentile(pred, lo), np.percentile(target, lo)
            Mp, Mt = np.percentile(pred, limit), np.percentile(target, limit)
        if scale=='same':
            mp = np.minimum(mp, mt)
            Mp = np.maximum(Mp, Mt)
            mt = mp
            Mt = Mp

        if isinstance(grid, int):
            grid = [grid, grid]

        # The target axis must span mt..Mt (not mt..Mp), otherwise the
        # density is evaluated on the wrong range when scale != 'same'.
        pg, tg = np.mgrid[mp:Mp:grid[0] * 1j, mt:Mt:grid[1] * 1j]
        positions = np.vstack([pg.ravel(), tg.ravel()])
        values = np.vstack([pred, target])
        kernel = stats.gaussian_kde(values)
        f = np.reshape(kernel(positions).T, pg.shape)

        fig = plt.figure()
        try:
            plt.xlim(mp, Mp)
            plt.ylim(mt, Mt)
            plt.xlabel("Prediction")
            plt.ylabel("Target")
            plt.imshow(np.rot90(f), cmap=plt.cm.gist_earth_r,
                       extent=[mp, Mp, mt, Mt])
            plt.contour(pg, tg, f)
            plt.plot([mt, Mt], [mt, Mt], 'r:')
            plt.tight_layout()
            with FileLock('plotting'):
                plt.savefig(save_name + ".regression_kde.png",
                            bbox_inches='tight')
        finally:
            # Do not leave the figure alive in pyplot, also on errors
            plt.close(fig)

    except ValueError:
        # When arrays are empty
        logger.warning("An error occurred during regression kde plotting.",
                       exc_info=True)
def my_quiver(x, y, img=None, c=None):
    """
    Draw a vector field, optionally on top of a faint gray image.

    first dim of x,y changes along vertical axis
    second dim changes along horizontal axis

    Parameters
    ----------
    x: 2d array
        vertical vector component
    y: 2d array
        horizontal vector component
    img: 2d array or None
        If given it is drawn below the arrows (alpha 0.22, gray colormap)
    c: array or None
        Values that colour the arrows via the 'spring' colormap; if None no
        colour array is passed to matplotlib

    Returns
    -------
    figure: the created matplotlib figure (left open for the caller)
    """
    figure = plt.figure(figsize=(7, 7))
    if img is not None:
        plt.imshow(img, interpolation='none', alpha=0.22, cmap='gray')

    # A positional ``None`` would be taken as the colour array by quiver,
    # so the third argument is only passed if colours were given.
    vectors = (x, y) if c is None else (x, y, c)
    plt.quiver(*vectors, angles='xy', units='xy', cmap='spring',
               pivot='middle', scale=0.5)

    return figure
def plot_trainingtarget(img, lab, stride=1):
    """
    Plots raw image vs target to check if valid batches are produced.
    Raw data is also shown overlaid with targets

    If ``len(lab) * stride`` differs from ``len(img)``, the target is
    zero-padded symmetrically by ``(len(img) - stride * len(lab)) // 2 //
    stride`` pixels on every side before plotting. The overlay panel is only
    drawn if image and (padded) target have the same shape. The figure is
    shown if ``config.gui_plot`` is set.

    Parameters
    ----------
    img: 2d array
        raw image from batch
    lab: 2d array
        targets (a 3d array with 3 channels in the last axis is accepted as
        well; the image is then repeated to 3 channels)
    stride: int
        strides of targets

    Returns
    -------
    ``img - lab``, computed after padding / channel repetition
    """
    if len(lab) * stride!=len(img):
        off = (len(img) - stride * len(lab)) // 2 // stride
        rows = lab.shape[0]
        cols = lab.shape[1]
        # Explicit end indices: ``off:-off`` would be an empty slice for
        # off == 0 and make the assignment fail.
        if lab.ndim==3:
            assert lab.shape[2]==3
            new_t = np.zeros((rows + 2 * off, cols + 2 * off, 3))
            new_t[off:off + rows, off:off + cols, :] = lab
        else:
            new_t = np.zeros((rows + 2 * off, cols + 2 * off))
            new_t[off:off + rows, off:off + cols] = lab

        lab = new_t

    if lab.ndim==3:
        assert lab.shape[2]==3
        # Repeat the gray image to 3 channels so it matches the target
        img = img[:, :, None]
        img = np.repeat(img, 3, axis=2)

    plt.figure(figsize=(18, 6))
    plt.subplot(131)
    plt.imshow(img, interpolation='none', cmap=plt.get_cmap('gray'))
    plt.title('data')
    plt.subplot(132)
    plt.imshow(lab, interpolation='none', cmap=plt.get_cmap('gray'))
    plt.title('target')
    if img.shape==lab.shape:
        overlay = 0.75 * img + 0.25 * (1 - lab)
        plt.subplot(133)
        plt.imshow(overlay, interpolation='none', cmap=plt.get_cmap('gray'))
        plt.title('overlay')

    if config.gui_plot:
        plt.show()

    return img - lab
def plot_exectimes(exectimes, save_path='~/exectimes.png', max_items=32):
    """
    Plot model execution time dict obtained from
    elektronn2.neuromancer.model.Model.measure_exectimes()

    The bar plot is saved to ``save_path`` (``~`` is expanded) and the
    figure is closed afterwards. The seaborn style is set to "whitegrid".

    :param exectimes: OrderedDict of execution times
        (output of Model.measure_exectimes())
    :param save_path: Where to save the plot
    :param max_items: Only the max_items largest execution times are given
        names and are plotted independently.
        Everything else is grouped under '(other nodes)'.
    """
    grouped = len(exectimes) > max_items

    thresh_val = 0
    if grouped:
        # Value of the max_items-th largest entry; entries that tie with it
        # are all kept.
        thresh_val = sorted(exectimes.values())[-max_items]

    filt_rtimes = OrderedDict(
        (key, val) for key, val in exectimes.items() if val >= thresh_val)
    node_names = list(filt_rtimes.keys())
    node_exectimes = list(filt_rtimes.values())
    if grouped:
        other = sum(exectimes.values()) - sum(filt_rtimes.values())
        node_names += ['(other nodes)']
        node_exectimes += [other]

    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(13, 12))
    try:
        plt.title('Node execution times')
        plt.ylabel('Node')
        plt.xlabel('Time (in ms)')
        ax = sns.barplot(y=node_names, x=node_exectimes)
        with FileLock('plotting'):
            ax.get_figure().savefig(os.path.expanduser(save_path),
                                    bbox_inches='tight')
    finally:
        # Do not leave the figure alive in pyplot after it has been saved
        plt.close(fig)