elektronn2.training package

Submodules

elektronn2.training.parallelisation module

class elektronn2.training.parallelisation.BackgroundProc(target, dtypes=None, shapes=None, n_proc=1, target_args=(), target_kwargs={}, profile=False)

Bases: elektronn2.training.parallelisation.SharedMem

Data structure to manage repeated background tasks by reusing a fixed number of initially created background processes, which are called with the same arguments every time (e.g. to retrieve augmented batches). Remember to call BackgroundProc.shutdown after use to avoid zombie processes and RAM clutter.

Parameters:
  • dtypes – list of dtypes of the target return values
  • shapes – list of shapes of the target return values
  • n_proc (int) – number of background procs to use
  • target (callable) – target function for the background procs. Can even be a method of an object if the object's data is read-only (then the data is not copied in RAM and the new process stays lean). If several procs use random number generators, new seeds must be created inside target, because all procs start with the same random state.
  • target_args (tuple) – Proc args (constant)
  • target_kwargs (dict) – Proc kwargs (constant)
  • profile (bool) – Whether to print timing results to stdout

Examples

Use case to retrieve batches from a data structure D:

>>> data, label = D.getbatch(2, strided=False, flip=True, grey_augment_channels=[0])
>>> kwargs = {'strided': False, 'flip': True, 'grey_augment_channels': [0]}
>>> bg = BackgroundProc([np.float32, np.int16], [data.shape, label.shape],
...                     D.getbatch, n_proc=2, target_args=(2,),
...                     target_kwargs=kwargs, profile=False)
>>> for i in range(100):
...     data, label = bg.get()
>>> bg.shutdown()

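The reuse pattern behind BackgroundProc (a fixed number of workers, all running the same target with the same constant arguments, with results handed out in order) can be sketched with the standard library's concurrent.futures. PrefetchSketch is a hypothetical class, not ELEKTRONN2's implementation; it uses a thread pool so the demo runs anywhere, whereas BackgroundProc uses separate processes and shared memory.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor


class PrefetchSketch(object):
    """Hypothetical sketch: keep n_proc calls of target(*target_args, **target_kwargs)
    in flight and return their results in submission order."""

    def __init__(self, target, n_proc=1, target_args=(), target_kwargs=None, executor=None):
        self._target = target
        self._args = tuple(target_args)
        self._kwargs = dict(target_kwargs or {})
        # Threads keep the demo portable; BackgroundProc itself uses processes.
        self._pool = executor if executor is not None else ThreadPoolExecutor(max_workers=n_proc)
        # Start all jobs up front with the same constant arguments.
        self._pending = deque(self._submit() for _ in range(n_proc))

    def _submit(self):
        return self._pool.submit(self._target, *self._args, **self._kwargs)

    def get(self, timeout=None):
        """Block until the oldest job has finished, then queue a replacement job."""
        result = self._pending.popleft().result(timeout=timeout)
        self._pending.append(self._submit())
        return result

    def shutdown(self):
        """Cancel jobs that have not started yet and release the workers."""
        for future in self._pending:
            future.cancel()
        self._pending.clear()
        self._pool.shutdown(wait=True)


bg = PrefetchSketch(lambda n, scale=1: [scale] * n, n_proc=2,
                    target_args=(3,), target_kwargs={'scale': 7})
print(bg.get())  # [7, 7, 7]
bg.shutdown()
```

Passing executor=concurrent.futures.ProcessPoolExecutor(n_proc) moves the work into separate processes; the target and its arguments then have to be picklable (e.g. a module-level function rather than a lambda).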
get(timeout=False)

This gets the next result from a background process and blocks until the corresponding proc has finished.

reset()

Should be called after an exception has been raised (e.g. by pressing Ctrl+C).

shutdown()

Must be called to free memory when the background tasks are no longer needed.

class elektronn2.training.parallelisation.SharedQ(n_proc=0, profile=False)

Bases: elektronn2.training.parallelisation.SharedMem

FIFO queue to process np.ndarrays in the background (e.g. to pre-load data from disk).

Target functions must accept a list of mp.Array objects and turn them into np.ndarray items using SharedQ.shm2ndarray; for this, the shapes are required as well. The target must have the signature:

>>> target(mp_arrays, shapes, *args, **kwargs)

where mp_arrays and shapes are added automatically internally.
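A target with the required signature might look like the sketch below. It assumes that each mp.Array holds float32 values and that shm2ndarray behaves like np.frombuffer on the array's underlying buffer; shm_to_ndarray and fill_target are hypothetical stand-ins, not the ELEKTRONN2 functions. The target is called in the current process only to show the data flow; SharedQ would run it in a background process.

```python
import multiprocessing as mp

import numpy as np


def shm_to_ndarray(mp_array, shape, dtype=np.float32):
    """Hypothetical stand-in for SharedQ.shm2ndarray: view an mp.Array as an ndarray.

    np.frombuffer does not copy, so writes to the view land in shared memory.
    """
    return np.frombuffer(mp_array.get_obj(), dtype=dtype).reshape(shape)


def fill_target(mp_arrays, shapes, value, scale=1.0):
    """Example target with the signature target(mp_arrays, shapes, *args, **kwargs)."""
    for mp_array, shape in zip(mp_arrays, shapes):
        view = shm_to_ndarray(mp_array, shape)
        view[:] = value * scale  # result is written in place, nothing is returned


# Run in the current process only to show the data flow.
shape = (2, 3)
shared = mp.Array('f', int(np.prod(shape)))  # 'f' is a C float, i.e. float32
fill_target([shared], [shape], 2.0, scale=0.5)
print(shm_to_ndarray(shared, shape))  # every entry is now 1.0
```

Because results are written into preallocated shared arrays rather than returned, no pickling of large arrays is needed when a background process hands its data back.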

All parameters are optional:

Parameters:
  • n_proc (int) – If larger than 0, a message is printed if too few processes are running
  • profile (bool) – Whether to print timing results to the terminal

Examples

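Automatic use (dtypes, shapes, target, args and kwargs are placeholders for user-supplied values):

>>> Q = SharedQ(n_proc=2)
>>> Q.startproc(dtypes, shapes, target, target_args=args, target_kwargs=kwargs)
>>> Q.startproc(dtypes, shapes, target, target_args=args, target_kwargs=kwargs)
>>> for i in range(5):
...     Q.startproc(dtypes, shapes, target, target_args=args, target_kwargs=kwargs)  # keep n_proc jobs queued
...     item = Q.get()  # blocks until the oldest job has finished
...     dosomethingelse(item)  # meanwhile, the background procs pre-fetch data for the next iterations

get()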

This gets the first result in the queue and blocks until the corresponding proc has finished. If an n_proc value is defined, new procs must be started before calling get to avoid a warning message.

startproc(dtypes, shapes, target, target_args=(), target_kwargs={})

Starts a new process.

Target functions must accept a list of mp.Array objects and turn them into np.ndarray items using SharedQ.shm2ndarray; for this, the shapes are required as well. The target must have the signature:

target(mp_arrays, shapes, *args, **kwargs)

where mp_arrays and shapes are added automatically internally.

exception elektronn2.training.parallelisation.TimeoutError(*args, **kwargs)

Bases: exceptions.RuntimeError

elektronn2.training.trainer module

class elektronn2.training.trainer.Trainer(exp_config)

Bases: object

debug_getcnnbatch()

Executes getbatch but with un-strided labels and always returning info. The first batch example is plotted and the whole batch is returned for inspection.

predict_and_write(pred_node, raw_img, number=0, export_class='all', block_name='', z_thick=5)

Predict and save a slice as a preview image.

Parameters:
  • raw_img (np.ndarray) – raw data in the format (ch, x, y, z)
  • number (int/float) – consecutive number for the save name (e.g. hours, iterations)
  • export_class (str or int) – ‘all’ writes images of all classes, otherwise only the class with index export_class (int) is saved.
  • block_name (str) – Name/number to distinguish different raw_imgs

preview_slice(number=0, export_class='all', max_z_pred=5)

Predict and save data from a separately loaded file as a preview.

Parameters:
  • number (int/float) – consecutive number for the save name (e.g. hours, iterations)
  • export_class (str or int) – ‘all’ writes images of all classes, otherwise only the class with index export_class (int) is saved.
  • max_z_pred (int) – approximate maximum number of z-slices to produce (depends on the CNN architecture)

preview_slice_from_traindata(cube_i=0, off=(0, 0, 0), sh=(10, 400, 400), number=0, export_class='all')

Predict and save a selected slice from the training data as a preview.

Parameters:
  • cube_i (int) – index of the source cube in CNNData
  • off (3-tuple of int) – start index of the slice to cut from the cube (z,y,x)
  • sh (3-tuple of int) – shape of the cube to cut (z,y,x)
  • number (int) – consecutive number for the save name (e.g. hours, iterations)
  • export_class (str or int) – ‘all’ writes images of all classes, otherwise only the class with index export_class (int) is saved.
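The off and sh parameters describe a box in (z,y,x) order. A minimal sketch of such a cut, assuming the cube stores z, y, x as its last three axes with an optional leading channel axis (cut_preview_block is a hypothetical helper, not part of the Trainer API):

```python
import numpy as np


def cut_preview_block(cube, off=(0, 0, 0), sh=(10, 400, 400)):
    """Cut the (z, y, x) box of shape sh starting at off from the last three axes of cube."""
    box = tuple(slice(o, o + s) for o, s in zip(off, sh))
    block = cube[(Ellipsis,) + box]  # leading axes (e.g. channels) are kept whole
    if block.shape[-3:] != tuple(sh):
        # NumPy silently truncates out-of-range slices, so check explicitly.
        raise ValueError('off + sh exceeds the cube bounds')
    return block


cube = np.zeros((1, 20, 500, 500), dtype=np.float32)  # (ch, z, y, x)
print(cut_preview_block(cube).shape)  # (1, 10, 400, 400)
```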

run()

test_model(data_source)

Computes the loss and error/accuracy on a batch of size monitor_batch_size.

Parameters: data_source (string) – ‘train’ or ‘valid’
Returns: loss, error

class elektronn2.training.trainer.TracingTrainer(exp_config)

Bases: elektronn2.training.trainer.Trainer

debug_getcnnbatch(extended=False)

Executes getbatch but with un-strided labels and always returning info. The first batch example is plotted and the whole batch is returned for inspection.

run()

static save_batch(img, lab, k, lab_img=None)

test_model(data_source)

Computes the loss and error/accuracy on a batch of size monitor_batch_size.

Parameters: data_source (string) – ‘train’ or ‘valid’
Returns: loss, error

class elektronn2.training.trainer.TracingTrainerRNN(exp_config)

Bases: elektronn2.training.trainer.TracingTrainer

run()

test_model(data_source)

Computes the loss and error/accuracy on a batch of size monitor_batch_size.

Parameters: data_source (string) – ‘train’ or ‘valid’
Returns: loss, error

elektronn2.training.trainutils module

class elektronn2.training.trainutils.ExperimentConfig(exp_file, host_script_file=None, use_existing_dir=False)

Bases: object

check_config()

classmethod levenshtein(s1, s2)

Computes the Levenshtein distance between the strings s1 and s2. Taken from: http://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#Python
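For reference, the two-row dynamic-programming formulation of the Levenshtein distance, in the spirit of the cited Wikibooks page. This is a standalone sketch, not the ExperimentConfig code:

```python
def levenshtein(s1, s2):
    """Minimum number of single-character insertions, deletions and substitutions."""
    if len(s1) < len(s2):
        s1, s2 = s2, s1  # keep the DP rows as short as possible
    previous = list(range(len(s2) + 1))  # distances from "" to each prefix of s2
    for i, c1 in enumerate(s1, start=1):
        current = [i]  # distance from s1[:i] to ""
        for j, c2 in enumerate(s2, start=1):
            insert = current[j - 1] + 1
            delete = previous[j] + 1
            substitute = previous[j - 1] + (c1 != c2)
            current.append(min(insert, delete, substitute))
        previous = current
    return previous[-1]


print(levenshtein('kitten', 'sitting'))  # 3
```

Keeping only the previous row reduces memory from O(len(s1) * len(s2)) to O(min(len(s1), len(s2))).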

make_dir()

Saves all Python files into the folder specified by self.save_path. Also changes the working directory to the save_path directory.

read_user_config()

class elektronn2.training.trainutils.HistoryTracker

Bases: object

load(file_name)
plot(save_name=None, autoscale=True, close=True)
register_debug_output_names(names)
save(save_name)
update_debug_outputs(vals)
update_history(vals)
update_regression(pred, target)
update_timeline(vals)

class elektronn2.training.trainutils.Schedule(**kwargs)

Bases: object

Creates a schedule for a parameter or property.

Examples

>>> lr_schedule = Schedule(dec=0.95) # decay by factor 0.95 every 1000 steps (i.e. decreasing by 5%)
>>> wd_schedule = Schedule(lindec=[4000, 0.001]) # from 0.001 to 0 in 4000 steps
>>> mom_schedule = Schedule(updates=[(500,0.8), (1000,0.7), (1500,0.9), (2000, 0.2)])
>>> dropout_schedule = Schedule(updates=[(1000,[0.2, 0.2])]) # set rates per layer
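The dec and updates modes shown in the examples can be sketched as follows. ScheduleSketch is a hypothetical class: it stores the value itself instead of binding to a variable via bind_variable, assumes update is called with the current iteration number, and leaves out lindec.

```python
class ScheduleSketch(object):
    """Hypothetical re-implementation of two documented Schedule modes (not ELEKTRONN2 code)."""

    def __init__(self, initial, dec=None, updates=None, dec_freq=1000):
        if (dec is None) == (updates is None):
            raise ValueError('pass exactly one of dec / updates')
        self.value = initial
        self.dec = dec
        self.dec_freq = dec_freq
        self.updates = dict(updates or [])  # iteration -> new value

    def update(self, iteration):
        if self.dec is not None:
            # Multiplicative decay every dec_freq steps.
            if iteration > 0 and iteration % self.dec_freq == 0:
                self.value *= self.dec
        elif iteration in self.updates:
            # Piecewise-constant schedule: jump to the listed value.
            self.value = self.updates[iteration]
        return self.value


lr = ScheduleSketch(0.1, dec=0.95)
for it in range(1, 2001):
    lr.update(it)
print(round(lr.value, 6))  # 0.09025 (two decays: at 1000 and 2000)
```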
bind_variable(variable_param=None, obj=None, prop_name=None)
update(iteration)

elektronn2.training.trainutils.binary_nll(pred, gt)

elektronn2.training.trainutils.confusion_table(labs, preds)

Gives the counts of all binary classification outcomes.

Parameters:
  • labs – correct labels (-1 for ignore)
  • preds – 0 for negative, 1 for positive (class probabilities must be thresholded first)
Returns: counts of (true positives, true negatives, false positives, false negatives)
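A NumPy sketch of these counts, assuming labs and preds are arrays of the same shape with values in {0, 1} (and -1 in labs for ignored entries); confusion_counts is a hypothetical name, not the library function:

```python
import numpy as np


def confusion_counts(labs, preds):
    """Return (tp, tn, fp, fn) for binary labels, skipping entries labelled -1."""
    labs = np.asarray(labs)
    preds = np.asarray(preds)
    valid = labs != -1  # drop ignored entries before counting
    labs, preds = labs[valid], preds[valid]
    tp = int(np.sum((labs == 1) & (preds == 1)))
    tn = int(np.sum((labs == 0) & (preds == 0)))
    fp = int(np.sum((labs == 0) & (preds == 1)))
    fn = int(np.sum((labs == 1) & (preds == 0)))
    return tp, tn, fp, fn


probs = np.array([0.9, 0.2, 0.4, 0.7, 0.8, 0.6])
preds = (probs > 0.5).astype(int)  # threshold the class probabilities first
print(confusion_counts([1, 0, 1, 0, -1, 1], preds))  # (2, 1, 1, 1)
```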

elektronn2.training.trainutils.error_hist(gt, preds, save_name, thresh=0.42)

preds: predicted probability of class ‘1’. Saves the plot to file.

elektronn2.training.trainutils.eval_thresh(args)

Calculates various performance measures at a given threshold.

Parameters: args – thresh, labs, preds
Returns: tpr, fpr, precision, recall, bal_accur, accur, f1

elektronn2.training.trainutils.evaluate(gt, preds, save_name, thresh=None, n_proc=None)

Evaluates the predictions w.r.t. the ground truth (GT) and saves a plot to file.

Parameters:
  • save_name – name for the saved plot
  • gt – ground truth
  • preds – predictions from 0.0 to 1.0
  • thresh – if given (e.g. from tuning on the validation set), some performance measures are shown at this threshold
Returns: perf, roc-area, threshs

elektronn2.training.trainutils.evaluate_model_binary(model, name, data=None, valid_d=None, valid_l=None, train_d=None, train_l=None, n_proc=2, betaloss=False, fudgeysoft=False)

elektronn2.training.trainutils.find_nearest(array, value)

elektronn2.training.trainutils.loadhistorytracker(file_name)

elektronn2.training.trainutils.performance_measure(tp, tn, fp, fn)

Computes various performance measures from the output of confusion_table.

Returns: tpr, fpr, precision, recall, balanced accuracy, accuracy, f1-score
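The measures follow the standard definitions. A sketch in terms of the four counts (performance_from_counts is hypothetical, and returning NaN for an empty denominator is an assumption, not documented ELEKTRONN2 behavior):

```python
import math


def performance_from_counts(tp, tn, fp, fn):
    """Return tpr, fpr, precision, recall, balanced accuracy, accuracy, f1."""
    def ratio(num, den):
        # NaN instead of ZeroDivisionError when a denominator is 0 (assumption).
        return num / den if den else math.nan

    tpr = ratio(tp, tp + fn)  # sensitivity; identical to recall
    fpr = ratio(fp, fp + tn)
    tnr = ratio(tn, tn + fp)  # specificity
    precision = ratio(tp, tp + fp)
    recall = tpr
    balanced_accuracy = (tpr + tnr) / 2
    accuracy = ratio(tp + tn, tp + tn + fp + fn)
    f1 = ratio(2 * precision * recall, precision + recall)  # harmonic mean
    return tpr, fpr, precision, recall, balanced_accuracy, accuracy, f1
```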

elektronn2.training.trainutils.rescale_fudge(pred, fudge=0.15)

elektronn2.training.trainutils.roc_area(tpr, fpr)

Integrates the ROC curve.

Parameters: tpr, fpr – points of the ROC curve
Returns: area
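One common way to integrate a ROC curve is the trapezoidal rule over the points sorted by false-positive rate. Whether roc_area uses exactly this rule is an assumption, and roc_area_trapezoid is a hypothetical name:

```python
import numpy as np


def roc_area_trapezoid(tpr, fpr):
    """Area under the ROC curve given matching sequences of true- and false-positive rates."""
    tpr = np.asarray(tpr, dtype=float)
    fpr = np.asarray(fpr, dtype=float)
    # Sort by fpr, breaking ties by tpr, so vertical segments are traversed upwards.
    order = np.lexsort((tpr, fpr))
    x, y = fpr[order], tpr[order]
    # Sum of trapezoids: width * mean height of neighbouring points.
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0))


print(roc_area_trapezoid([0.0, 1.0], [0.0, 1.0]))  # 0.5 (random guessing)
```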
elektronn2.training.trainutils.user_input(local_vars)

Module contents