Source code for elektronn2.utils.utils_basic

# -*- coding: utf-8 -*-
# ELEKTRONN2 Toolkit
# Copyright (c) 2015 Marius Killinger
# All rights reserved

from __future__ import absolute_import, division, print_function
from builtins import filter, hex, input, int, map, next, oct, pow, range, \
    super, zip


import os
import re
import time
import logging
from functools import reduce

try:
    import cPickle as pkl
except ImportError:
    import pickle as pkl


import numba
import psutil
from multiprocessing import Pool
import gzip
import h5py
import numpy as np
import sys
import importlib
import functools
from scipy.spatial.distance import cdist
import sklearn
from sklearn.neighbors import NearestNeighbors as NearestNeighbors_


__all__ = ['get_free_cpu_count', 'parallel_accum',
           'timeit', 'cache', 'CircularBuffer', 'AccumulationArray', 'KDT',
           'DynamicKDT', 'import_variable_from_file', 'pickleload', 'picklesave',
           'h5save', 'h5load', 'pretty_string_ops', 'pretty_string_time',
           'makeversiondir', 'Timer', 'unique_rows', 'as_list']

logger = logging.getLogger('elektronn2log')


def get_free_cpu_count():
    m = psutil.cpu_count()
    if m<=2:
        return 1
    else:
        load = float(psutil.cpu_percent(interval=0.4)) / 100
        free = 1 - load
        n = min(max(3, m * free * 0.8), m - 1)
        return int(n)

def parallel_accum(func, n_ret, var_args, const_args, proc=-1, debug=False):
    if proc==-1:
        proc = get_free_cpu_count()
    args = []
    for a in var_args:
        try:
            v = tuple(a)
        except TypeError:
            v = (a,)
        arg = v + tuple(const_args)
        args.append(arg)
    assert len(args)>0

    if not debug:
        p = Pool(proc)
        try:
            ret = p.map(func, args)
            p.close()
            p.join()
        except KeyboardInterrupt:
            p.terminate()
            p.join()
            raise KeyboardInterrupt
    else:
        ret = map(func, args)

    accums = [list() for i in range(n_ret)]
    for tmp in ret:
        if n_ret==1:
            try:
                accums[0].extend(tmp)
            except TypeError:
                accums[0].append(tmp)
        else:
            for i in range(n_ret):
                x = tmp[i]
                try:
                    accums[i].extend(x)
                except TypeError:
                    accums[i].append(x)

    if n_ret==1:
        return accums[0]
    else:
        return tuple(accums)
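
# Minimal usage sketch for ``parallel_accum`` (illustrative only, not used
# elsewhere in the package). The worker is a module-level function so that it
# could also be pickled for the multiprocessing path; ``debug=True`` runs it
# sequentially in-process.
def _parallel_accum_example_worker(arg):
    x, offset = arg
    return [x + offset]  # one list-valued return per call -> n_ret=1


def _parallel_accum_example():
    # Each element of var_args is combined with const_args=(10,) and the
    # per-call result lists are concatenated into one flat list.
    return parallel_accum(_parallel_accum_example_worker, n_ret=1,
                          var_args=[1, 2, 3], const_args=(10,), debug=True)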

### Decorator Collection ###

class DecoratorBase(object):
    """
    If used as ``@DecoratorBase`` this initialiser receives only the function
    to be wrapped (no wrapper args). ``__call__`` then receives the arguments
    for the underlying function.

    Alternatively, if used as ``@DecoratorBase(wrapper_print=True, n_times=10)``
    this initialiser receives the wrapper args, the function is passed to
    ``__call__`` and ``__call__`` returns a wrapped function.

    This base class completely ignores all wrapper arguments.
    """
    def __init__(self, *args, **kwargs):
        self.func = None
        self.dec_args = None
        self.dec_kwargs = None
        if len(args)==1 and not len(kwargs):
            assert hasattr(args[0], '__call__')
            func = args[0]
            self.func = func
            self.__call__.__func__.__doc__ = func.__doc__
            self.__call__.__func__.__name__ = func.__name__
        else:
            self.dec_args = args
            self.dec_kwargs = kwargs

    def __call__(self, *args, **kwargs):
        # The decorator was initialised with the func; it now has to apply
        # the decoration itself.
        if self.func is not None:
            # do something with args
            ret = self.func(*args, **kwargs)
            # do something with kwargs
            return ret
        # The decorator was initialised with args; it now returns a wrapped
        # function.
        elif len(args)==1 and not len(kwargs):
            assert hasattr(args[0], '__call__')
            func = args[0]

            @functools.wraps(func)
            def decorated(*args0, **kwargs0):
                # do something with args0, read the decorator arguments
                # print(self.dec_args)
                # print(self.dec_kwargs)
                ret = func(*args0, **kwargs0)
                # do something with ret
                return ret

            return decorated
        else:
            raise ValueError()
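
def _decoratorbase_example():
    # Minimal sketch of the two decoration modes described in the docstring
    # above (illustrative only; ``option=True`` is an arbitrary wrapper
    # argument that the base class ignores).
    @DecoratorBase
    def plus_one(x):  # bare usage: __init__ receives plus_one directly
        return x + 1

    @DecoratorBase(option=True)
    def times_two(x):  # arg usage: __call__ receives times_two and wraps it
        return x * 2

    return plus_one(1), times_two(3)  # -> (2, 6)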

class timeit(DecoratorBase):
    def __call__(self, *args, **kwargs):
        # No args were given to the decorator --> n=1
        if self.func is not None:
            t0 = time.time()
            ret = self.func(*args, **kwargs)
            t = time.time() - t0
            print("Function <%s> took %.5g s" % (self.func.__name__, t))
            return ret
        # The decorator was initialised with args; it now returns a wrapped
        # function.
        elif len(args)==1 and not len(kwargs):
            assert hasattr(args[0], '__call__')
            func = args[0]
            n = self.dec_kwargs.get('n', 1)

            @functools.wraps(func)
            def decorated(*args0, **kwargs0):
                t0 = time.time()
                if n>1:
                    for i in range(n - 1):
                        func(*args0, **kwargs0)
                ret = func(*args0, **kwargs0)
                t = time.time() - t0
                print("Function <%s> took %.5g s averaged over %i execs" % (
                    func.__name__, t / n, n))
                return ret

            return decorated
        else:
            raise ValueError()
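
def _timeit_example():
    # Illustrative sketch: bare use times a single call, the ``n`` keyword
    # averages the runtime over several executions.
    @timeit
    def single():
        return sum(range(10000))

    @timeit(n=5)
    def averaged():
        return sum(range(10000))

    return single(), averaged()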

class cache(DecoratorBase):
    def __init__(self, *args, **kwargs):
        super(cache, self).__init__(*args, **kwargs)
        self.memo = {}
        self.default = None

    @staticmethod
    def hash_args(args):
        tmp = []
        for arg in args:
            if isinstance(arg, np.ndarray):
                tmp.append(hash(arg.tostring()))
            elif isinstance(arg, (list, tuple)):
                tmp.append(reduce(lambda x, y: x + hash(y), arg, 0))
            else:
                tmp.append(hash(arg))

        return reduce(lambda x, y: x + y, tmp, 0)

    def __call__(self, *args, **kwargs):
        # The decorator was initialised with the func; look up / fill the memo.
        if self.func is not None:
            if len(args)==0 and len(kwargs)==0:
                # Cache the result of an argument-less call separately.
                if self.default is None:
                    self.default = self.func()
                return self.default
            else:
                key1 = self.hash_args(args)
                key2 = self.hash_args(kwargs.values())
                key = key1 + key2
                if key not in self.memo:
                    self.memo[key] = self.func(*args, **kwargs)
                return self.memo[key]
        # The decorator was initialised with args; it now returns a wrapped
        # function.
        elif len(args)==1 and not len(kwargs):
            assert hasattr(args[0], '__call__')
            func = args[0]

            @functools.wraps(func)
            def decorated(*args0, **kwargs0):
                if len(args0)==0 and len(kwargs0)==0:
                    # Cache the result of an argument-less call separately.
                    if self.default is None:
                        self.default = func()
                    return self.default
                else:
                    key1 = self.hash_args(args0)
                    key2 = self.hash_args(kwargs0.values())
                    key = key1 + key2
                    if key not in self.memo:
                        self.memo[key] = func(*args0, **kwargs0)
                    return self.memo[key]

            return decorated
        else:
            raise ValueError()
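
def _cache_example():
    # Illustrative sketch: results are memoised by a hash of the arguments,
    # so the second call with the same argument is served from the memo dict
    # without recomputing. ndarray, list/tuple and hashable args are supported.
    @cache
    def slow_square(x):
        time.sleep(0.01)  # stand-in for an expensive computation
        return x * x

    a = slow_square(4)  # computed and stored
    b = slow_square(4)  # looked up
    return a, b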
### Custom Data Structures ###

class CircularBuffer(object):
    def __init__(self, buffer_len):
        self.length = 0
        self._buffer = np.zeros(buffer_len)

    def append(self, data):
        self._buffer = np.roll(self._buffer, 1)
        self._buffer[0] = data
        self.length += 1

    @property
    def data(self):
        return self._buffer[:self.length]

    def mean(self):
        if self.length:
            return self.data.mean()
        else:
            return 0.0

    def setvals(self, val):
        self._buffer[:] = val

    def __len__(self):
        return self.length

    def __getitem__(self, slc):
        return self._buffer[:self.length][slc]

    def __repr__(self):
        return repr(self.data)
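
def _circular_buffer_example():
    # Illustrative sketch: the newest value is always at index 0 and values
    # older than ``buffer_len`` appends are overwritten.
    buf = CircularBuffer(buffer_len=3)
    for value in [1.0, 2.0, 3.0, 4.0]:
        buf.append(value)
    return buf.data, buf.mean()  # data holds the last 3 values, newest first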

class AccumulationArray(object):
    def __init__(self, right_shape=(), dtype=np.float32, n_init=100, data=None,
                 ema_factor=0.95):
        if isinstance(dtype, dict) and right_shape!=():
            raise ValueError("If a dict is used as dtype, right_shape must be "
                             "unchanged (i.e. it is 1d)")
        if data is not None and len(data):
            n_init += len(data)
            right_shape = data.shape[1:]
            dtype = data.dtype

        self._n_init = n_init
        if isinstance(right_shape, int):
            self._right_shape = (right_shape,)
        else:
            self._right_shape = tuple(right_shape)
        self.dtype = dtype
        self.length = 0
        self._buffer = self._alloc(n_init)
        self._min = +np.inf
        self._max = -np.inf
        self._sum = 0
        self._ema = None
        self._ema_factor = ema_factor

        if data is not None and len(data):
            self.length = len(data)
            self._buffer[:self.length] = data
            self._min = data.min(0)
            self._max = data.max(0)
            self._sum = data.sum(0)

    def __repr__(self):
        return repr(self.data)

    def _alloc(self, n):
        if isinstance(self._right_shape, (tuple, list, np.ndarray)):
            ret = np.zeros((n,) + tuple(self._right_shape), dtype=self.dtype)
        elif isinstance(self.dtype, dict):  # rec array
            ret = np.zeros(n, dtype=self.dtype)
        else:
            raise ValueError("dtype not understood")
        return ret

    def append(self, data):
        # data = self.normalise_data(data)
        if len(self._buffer)==self.length:
            tmp = self._alloc(len(self._buffer) * 2)
            tmp[:self.length] = self._buffer
            self._buffer = tmp

        if isinstance(self.dtype, dict):
            for k, val in enumerate(data):
                self._buffer[self.length][k] = data[k]
        else:
            self._buffer[self.length] = data

        if self._ema is None:
            self._ema = self._buffer[self.length]
        else:
            f = self._ema_factor
            fc = 1 - f
            self._ema = self._ema * f + self._buffer[self.length] * fc

        self.length += 1
        self._min = np.minimum(data, self._min)
        self._max = np.maximum(data, self._max)
        self._sum = self._sum + np.asanyarray(data)

    def add_offset(self, off):
        self.data[:] += off
        if off.ndim>np.ndim(self._sum):
            off = off[0]
        self._min += off
        self._max += off
        self._sum += off * self.length

    def clear(self):
        self.length = 0
        self._min = +np.inf
        self._max = -np.inf
        self._sum = 0

    def mean(self):
        return np.asarray(self._sum, dtype=np.float32) / self.length

    def sum(self):
        return self._sum

    def max(self):
        return self._max

    def min(self):
        return self._min

    def __len__(self):
        return self.length

    @property
    def data(self):
        return self._buffer[:self.length]

    @property
    def ema(self):
        return self._ema

    def __getitem__(self, slc):
        return self._buffer[:self.length][slc]
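
def _accumulation_array_example():
    # Illustrative sketch: the buffer grows automatically and running
    # min/max/sum/mean plus an exponential moving average are maintained.
    acc = AccumulationArray(right_shape=(2,), n_init=4)
    for row in np.arange(10, dtype=np.float32).reshape(5, 2):
        acc.append(row)
    return acc.data.shape, acc.mean(), acc.min(), acc.max(), acc.ema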

class KDT(NearestNeighbors_):
    warning_shown = False

    @functools.wraps(NearestNeighbors_.__init__)
    def __init__(self, n_neighbors=5, radius=1.0, algorithm='auto',
                 leaf_size=30, metric='minkowski', p=2, metric_params=None,
                 n_jobs=1, **kwargs):
        if sklearn.__version__=="0.16.1":
            if not KDT.warning_shown:
                logger.warning("sklearn version does not support MP, "
                               "try to upgrade it.")
                KDT.warning_shown = True
            if "n_jobs" in kwargs:
                kwargs.pop("n_jobs")
        else:
            kwargs['n_jobs'] = n_jobs

        super(KDT, self).__init__(n_neighbors=n_neighbors, radius=radius,
                                  algorithm=algorithm, leaf_size=leaf_size,
                                  metric=metric, p=p,
                                  metric_params=metric_params, **kwargs)

    __init__.__doc__ = NearestNeighbors_.__init__.__doc__

@numba.jit(nopython=True, looplift=True, cache=True)
def _merge(distances, indices, coordinates, pairwise_dist, sort_ix,
           new_points, k, query_points):
    q = len(query_points)
    dim = query_points.shape[1:]
    distances_new = np.zeros((q, k), dtype=np.float32)           # (q,k)
    indices_new = np.zeros((q, k), dtype=np.int64)               # (q,k)
    coordinates_new = np.zeros((q, k,) + dim, dtype=np.float32)  # (q,k,2/3)
    kdt_pointer = np.zeros(q, dtype=np.int64)  # (q,) should be at most k+1
    new_pointer = np.zeros(q, dtype=np.int64)  # (q,) should be at most m+1
    for p in range(q):  # over query points
        for c in range(k):  # over #NNs
            new_ix = sort_ix[p, new_pointer[p]]
            d_new = pairwise_dist[p, new_ix]
            d_kdt = distances[p, kdt_pointer[p]]
            if d_kdt>d_new:
                distances_new[p, c] = d_new
                indices_new[p, c] = -666
                coordinates_new[p, c] = new_points[new_ix]
                new_pointer[p] += 1
            else:
                distances_new[p, c] = d_kdt
                indices_new[p, c] = indices[p, kdt_pointer[p]]
                coordinates_new[p, c] = coordinates[p, kdt_pointer[p]]
                kdt_pointer[p] += 1

    return distances_new, indices_new, coordinates_new

class DynamicKDT(object):
    def __init__(self, points=None, k=1, n_jobs=-1, rebuild_thresh=100,
                 aniso_scale=None):
        self._kdt = None
        self._new_points = []
        self._static_points = []
        self._k = k
        self._jobs = n_jobs
        self._rebuild_thresh = rebuild_thresh
        self.aniso_scale = 1
        if aniso_scale is not None:
            if isinstance(aniso_scale, (list, tuple)):
                self.aniso_scale = np.atleast_2d(np.array(aniso_scale))
            elif isinstance(aniso_scale, np.ndarray):
                self.aniso_scale = np.atleast_2d(aniso_scale)
            else:
                raise ValueError("aniso_scale not understood")

        if points is not None:
            if len(points)<=k:
                raise ValueError("points must be longer than k")
            self._kdt = KDT(n_neighbors=k, n_jobs=n_jobs, algorithm='kd_tree',
                            leaf_size=20)
            self._kdt.fit(points * self.aniso_scale)
            self._static_points = points

    def append(self, point):
        point = np.asarray(point)
        if self._new_points==[]:
            self._new_points = AccumulationArray(right_shape=point.shape,
                                                 n_init=self._rebuild_thresh)
        if len(self._new_points)==self._rebuild_thresh:
            if self._static_points==[]:
                self._static_points = self._new_points.data.copy()
            else:
                self._static_points = np.concatenate(
                    [self._static_points, self._new_points.data], axis=0)
            self._new_points.clear()
            self._kdt = KDT(n_neighbors=self._k, n_jobs=self._jobs,
                            algorithm='kd_tree', leaf_size=20)
            self._kdt.fit(self._static_points * self.aniso_scale)

        self._new_points.append(point)

    def get_knn(self, query_points, k=None):
        if k is None:
            k = self._k
        if k>(len(self._new_points) + len(self._static_points)):
            raise ValueError("The requested number of neighbours is larger "
                             "than the number of stored points")
        if query_points.ndim==1:
            query_points = query_points[None]
        q = len(query_points)

        if len(self._static_points):
            # assert k==self._kdt.n_neighbors
            distances, indices = self._kdt.kneighbors(
                query_points * self.aniso_scale, n_neighbors=k)
            # Add inf for stopping
            distances = np.hstack([distances, np.ones((q, 1)) * np.inf])
        else:
            distances = np.ones((q, 1)) * np.inf
            indices = np.zeros((q, 1), dtype=np.int)

        if len(self._new_points):
            new_points = self._new_points.data
            pairwise_dist = cdist(query_points * self.aniso_scale,
                                  new_points * self.aniso_scale, p=2)
            # Add inf for stopping
            pairwise_dist = np.hstack(
                [pairwise_dist, np.ones((q, 1)) * np.inf])
        else:
            new_points = np.zeros((0, 1), dtype=query_points.dtype)
            pairwise_dist = np.ones((q, 1)) * np.inf  # (q,1)

        if k==1:
            indices = indices[:, 0]
            distances = distances[:, 0].astype(np.float32)
            if len(self._static_points):
                coordinates = self._static_points[indices]
            else:
                coordinates = new_points[indices]
            # Override found neighbours if a closer neighbour is in new_points
            replace_by_new = pairwise_dist.min(axis=1)<distances
            new_index = pairwise_dist.argmin(axis=1)
            if np.any(replace_by_new):
                coordinates[replace_by_new] = new_points[
                    new_index[replace_by_new]]
                # Like -666, this is just a dummy; atm indices is not used
                # anyway.
                indices[replace_by_new] = new_index[replace_by_new] + len(
                    self._static_points)
                distances[replace_by_new] = pairwise_dist[
                    replace_by_new, new_index[replace_by_new]]
        else:
            if len(self._static_points):
                coordinates = self._static_points[indices]
            else:
                coordinates = np.zeros_like(new_points)
            sort_ix = pairwise_dist.argsort(axis=1)  # (q,n), ascending
            distances, indices, coordinates = _merge(distances, indices,
                                                     coordinates,
                                                     pairwise_dist, sort_ix,
                                                     new_points, k,
                                                     query_points)

        if q==1:
            distances = distances[0]
            indices = indices[0]
            coordinates = coordinates[0]

        return distances, indices, coordinates

    def get_radius_nn(self, query_points, radius):
        raise NotImplementedError()
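
def _dynamic_kdt_example():
    # Illustrative sketch: start from a static point set, add a point on the
    # fly and query the two nearest neighbours of a new location. Exact
    # behaviour depends on the installed sklearn/scipy/numba versions.
    points = np.array([[0., 0.], [1., 0.], [0., 1.], [5., 5.]])
    kdt = DynamicKDT(points=points, k=2, n_jobs=1)
    kdt.append([0.1, 0.1])  # dynamically added point
    distances, indices, coordinates = kdt.get_knn(np.array([0., 0.2]))
    return distances, coordinates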
### Various Simple Utils ###

def import_variable_from_file(file_path, class_name):
    directory = os.path.dirname(file_path)
    sys.path.append(directory)
    mod_name = os.path.split(file_path)[1]
    if mod_name[-3:]=='.py':
        mod_name = mod_name[:-3]
    mod = importlib.import_module(mod_name)
    sys.path.pop(-1)
    cls = getattr(mod, class_name)
    return cls

def picklesave(data, file_name):
    """
    Writes one or many objects to a pickle file.

    data:
      single object to save or iterable of objects to save.
      For an iterable, all objects are written to the file in this order.
    file_name: string
      path/name of destination file
    """
    file_name = os.path.expanduser(file_name)
    with open(file_name, 'wb') as f:
        pkl.dump(data, f, protocol=2)

def pickleload(file_name):
    """
    Loads all objects that are saved in the pickle file.
    Multiple objects are returned as a list.
    """
    file_name = os.path.expanduser(file_name)
    ret = []
    try:
        with open(file_name, 'rb') as f:
            try:
                while True:
                    # Python 3 needs explicit encoding specification,
                    # which Python 2 lacks:
                    if sys.version_info.major>=3:
                        ret.append(pkl.load(f, encoding='latin1'))
                    else:
                        ret.append(pkl.load(f))
            except EOFError:
                pass

        if len(ret)==1:
            return ret[0]
        else:
            return ret
    except pkl.UnpicklingError:
        with gzip.open(file_name, 'rb') as f:
            try:
                while True:
                    # Python 3 needs explicit encoding specification,
                    # which Python 2 lacks:
                    if sys.version_info.major>=3:
                        ret.append(pkl.load(f, encoding='latin1'))
                    else:
                        ret.append(pkl.load(f))
            except EOFError:
                pass

        if len(ret)==1:
            return ret[0]
        else:
            return ret
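
def _pickle_example():
    # Illustrative round-trip sketch: save a single object and load it back.
    # The temporary file name is used only for demonstration.
    import tempfile
    data = {'a': np.arange(3), 'b': 'text'}
    fname = os.path.join(tempfile.gettempdir(), 'pickle_example.pkl')
    picklesave(data, fname)
    restored = pickleload(fname)
    os.remove(fname)
    return restored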

def h5save(data, file_name, keys=None, compress=True):
    """
    Writes one or many arrays to an h5 file.

    data:
      single array to save or iterable of arrays to save.
      For an iterable, all arrays are written to the file.
    file_name: string
      path/name of destination file
    keys: string / list thereof
      For a single array this is a single string which is used as a name
      for the data set.
      For multiple arrays each dataset is named by the corresponding key.
      If keys is ``None``, the datasets are named by enumeration
      (``'0'``, ``'1'``, ...), or ``'data'`` for a single array.
    compress: Bool
      Whether to use lzf compression, defaults to ``True``. Most useful for
      label arrays.
    """
    file_name = os.path.expanduser(file_name)
    compr = 'lzf' if compress else None
    f = h5py.File(file_name, "w")
    if isinstance(data, (list, tuple)):
        if keys is not None:
            assert len(keys)==len(data)
        for i, d in enumerate(data):
            if keys is None:
                f.create_dataset(str(i), data=d, compression=compr)
            else:
                f.create_dataset(keys[i], data=d, compression=compr)
    else:
        if keys is None:
            f.create_dataset('data', data=data, compression=compr)
        else:
            f.create_dataset(keys, data=data, compression=compr)
    f.close()

def h5load(file_name, keys=None):
    """
    Loads data sets from an h5 file.

    file_name: string
      path/name of the file to load
    keys: string / list thereof
      Load only the data sets specified in keys and return them as a list in
      the order of ``keys``.
      For a single key the data is returned directly - not as a list.
      If keys is ``None``, all datasets that are listed in the keys-attribute
      of the h5 file are loaded.
    """
    file_name = os.path.expanduser(file_name)
    ret = []
    try:
        f = h5py.File(file_name, "r")
    except IOError:
        raise IOError("Could not open h5-File %s" % (file_name))

    if keys is not None:
        try:
            if isinstance(keys, str):
                ret.append(f[keys].value)
            else:
                for k in keys:
                    ret.append(f[k].value)
        except KeyError:
            raise KeyError("Could not read h5-dataset named %s. Available "
                           "datasets: %s" % (keys, list(f.keys())))
    else:
        for k in f.keys():
            ret.append(f[k].value)
    f.close()

    if len(ret)==1:
        return ret[0]
    else:
        return ret
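
def _h5_example():
    # Illustrative round-trip sketch: two arrays are stored under explicit
    # dataset names and loaded back in that order. The temporary file name is
    # only for demonstration; the ``.value`` accessor above assumes h5py 2.x.
    import tempfile
    fname = os.path.join(tempfile.gettempdir(), 'h5_example.h5')
    h5save([np.arange(4), np.eye(2)], fname, keys=['raw', 'eye'])
    raw, eye = h5load(fname, keys=['raw', 'eye'])
    os.remove(fname)
    return raw, eye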

def pretty_string_ops(n):
    """
    Return a humanized string representation of a large number.
    """
    abbrevs = [(1000000000000, 'Tera Ops'),
               (1000000000, 'Giga Ops'),
               (1000000, 'Mega Ops'),
               (1000, 'kilo Ops')]
    for factor, suffix in abbrevs:
        if n>=factor:
            break
    return "%.1f %s" % (float(n) / factor, suffix)

def makeversiondir(path, dir_name=None, cd=False):
    path = os.path.expanduser(path)
    if dir_name:
        path = os.path.join(path, dir_name)
    while True:
        if os.path.exists(path):
            try:
                num = re.findall(r"-v(\d+)$", path)[0]
                num = int(num)
                i = 2 + np.int(np.log10(num) + 1)
                num = "-v" + str(int(num) + 1)
                path = path[:-i] + num
            except IndexError:  # no version suffix present yet
                path = path + "-v1"
        else:
            break
    os.makedirs(path, mode=0o755)
    if cd:
        os.chdir(path)
    return path

class Timer(object):
    def __init__(self, silent_all=False):
        self.last_t = time.time()
        self.total = 0
        self.checktimes = []
        self.checknames = []
        self.accumulator = {}
        self.silent_all = silent_all

    def check(self, name=None, silent=False):
        t = time.time()
        dt = t - self.last_t
        self.total += dt
        self.last_t = t
        name = name if name is not None else ""
        if not silent and not self.silent_all:
            print("%s\tdt=%.3g s,\tt=%.3g s" % (name, dt, self.total))
        self.checknames.append(name)
        self.checktimes.append(dt)
        if name is not None:
            accum = self.accumulator.get(name, 0)
            self.accumulator[name] = accum + dt

    def plot(self, accum=False):
        # I don't want this import every time utils is used
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        if accum:
            times = list(self.accumulator.values())
            names = list(self.accumulator.keys())
        else:
            times = self.checktimes
            names = self.checknames
        ind = np.arange(len(times))
        ax.bar(ind, times)
        ax.set_xticks(ind + 0.5)
        ax.set_xticklabels(names)
        plt.show()

    def summary(self, silent=False, print_func=None):
        s = "Total t: %.3gs\n" % self.total
        if len(self.accumulator):
            names = list(self.accumulator.keys())
            times = list(self.accumulator.values())
            ix = np.argsort(times)
            for i in ix:
                name, dt = names[i], times[i]
                s += "%s:\t\t%.3gs\t%.3g%%\n" \
                     % (name, dt, dt / self.total * 100.0)
        else:
            ix = np.argsort(self.checktimes)
            for i in ix:
                name, dt = self.checknames[i], self.checktimes[i]
                s += "%s:\t\t%.3gs\t%.3g%%\n" \
                     % (name, dt, dt / self.total * 100.0)
        if silent:
            return s
        else:
            if print_func:
                print_func(s)
            else:
                print(s)
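
def _timer_example():
    # Illustrative sketch: ``check`` records the time since the previous
    # check-point and ``summary`` aggregates the named check-points.
    timer = Timer(silent_all=True)
    time.sleep(0.01)
    timer.check('setup')
    time.sleep(0.02)
    timer.check('work')
    return timer.summary(silent=True)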

def pretty_string_time(t):
    """Custom printing of elapsed time"""
    if t>4000:
        s = 't=%.1fh' % (t / 3600)
    elif t>300:
        s = 't=%.0fm' % (t / 60)
    else:
        s = 't=%.0fs' % (t)
    return s

def unique_rows(a):
    # removes duplicates from a
    a = np.ascontiguousarray(a)
    unique_a, index = np.unique(a.view([('', a.dtype)] * a.shape[1]),
                                return_index=True)
    ret = unique_a.view(a.dtype).reshape(
        (unique_a.shape[0], a.shape[1])), index
    print("Removed %i of %i (new %i)" % (
        len(a) - len(index), len(a), len(index)))
    return ret
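
def _unique_rows_example():
    # Illustrative sketch: duplicate coordinate rows are removed; the second
    # return value holds the indices of the kept rows in the original array.
    a = np.array([[0, 1], [2, 3], [0, 1]])
    uniq, index = unique_rows(a)
    return uniq, index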

def as_list(var):
    if var is None:
        return var
    elif isinstance(var, (list, tuple)):
        return list(var)
    else:
        return [var, ]