elektronn2.training.trainer module

class elektronn2.training.trainer.Trainer(exp_config)

Bases: object

debug_getcnnbatch()

Executes getbatch, but with un-strided labels and with info always returned. The first example in the batch is plotted, and the whole batch is returned for inspection.
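The difference between strided and un-strided labels can be sketched in plain NumPy. This assumes that, for a CNN whose output is subsampled by pooling or strided layers, the training labels are subsampled with the same per-axis strides so that each label lines up with one output pixel. `stride_labels` is a hypothetical helper for illustration, not an ELEKTRONN2 function.

```python
import numpy as np

def stride_labels(labels, strides, offsets=(0, 0, 0)):
    """Subsample a (z, y, x) label cube with per-axis strides.

    Hypothetical helper: mimics how labels could be reduced to the sparse
    output grid of a strided CNN. Not an ELEKTRONN2 function.
    """
    slices = tuple(slice(o, None, s) for o, s in zip(offsets, strides))
    return labels[slices]

# Un-strided labels: one label per input voxel.
labels = np.arange(2 * 6 * 6).reshape(2, 6, 6)

# Strided labels for an assumed output stride of (1, 2, 2).
strided = stride_labels(labels, strides=(1, 2, 2))
print(labels.shape, strided.shape)  # (2, 6, 6) (2, 3, 3)
```

With un-strided labels every input voxel keeps its label, which can make a visual comparison with the raw input batch easier.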

predict_and_write(pred_node, raw_img, number=0, export_class='all', block_name='', z_thick=5)

Predicts and saves a slice as a preview image.

Parameters:
  • raw_img (np.ndarray) – raw data in the format (ch, x, y, z)
  • number (int/float) – consecutive number for the save name (e.g. hours, iterations)
  • export_class (str or int) – ‘all’ writes images of all classes; otherwise, only the class with index export_class (int) is saved.
  • block_name (str) – name or number to distinguish different raw images
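The export_class semantics can be illustrated with a small NumPy sketch. It assumes the prediction for a slice is an array of shape (n_classes, y, x); `select_export_classes` and `preview_filename` are hypothetical helpers, and the file-naming scheme is an illustration only, not the one predict_and_write actually uses.

```python
import numpy as np

def select_export_classes(pred, export_class='all'):
    """Return {class_index: 2D image} for the classes to be saved.

    Hypothetical helper: assumes `pred` has shape (n_classes, y, x).
    'all' selects every class; an int selects only that class index.
    """
    if export_class == 'all':
        indices = range(pred.shape[0])
    else:
        indices = [int(export_class)]
    return {i: pred[i] for i in indices}

def preview_filename(block_name, number, class_index):
    """Build a save name from block_name, number and class (hypothetical scheme)."""
    return "%s_%s_c%i.png" % (block_name, number, class_index)

pred = np.random.rand(3, 4, 4)
print(sorted(select_export_classes(pred, 'all')))  # [0, 1, 2]
print(sorted(select_export_classes(pred, 1)))      # [1]
print(preview_filename('blockA', 12, 1))           # blockA_12_c1.png
```

Using number (e.g. training hours or iterations) in the name keeps successive previews of the same block from overwriting each other.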

preview_slice(number=0, export_class='all', max_z_pred=5)

Predicts and saves data from a separately loaded file as a preview.

Parameters:
  • number (int/float) – consecutive number for the save name (e.g. hours, iterations)
  • export_class (str or int) – ‘all’ writes images of all classes; otherwise, only the class with index export_class (int) is saved.
  • max_z_pred (int) – approximate maximum number of z-slices to produce (the exact number depends on the CNN architecture)
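Why max_z_pred can only be approximate is illustrated by valid-convolution arithmetic. The sketch below assumes the network acts along z like a single window of size `fov` moved with step `stride`, which is a simplification of a real layer stack; `output_slices` and `input_slices_for` are hypothetical helpers, and the field-of-view values are made up for the example.

```python
def output_slices(n_in, fov, stride=1):
    """Number of output positions along one axis for a 'valid' CNN.

    Assumes the network behaves like a single window of size `fov`
    moved with step `stride` (a simplification of a real layer stack).
    """
    if n_in < fov:
        return 0
    return (n_in - fov) // stride + 1

def input_slices_for(n_out, fov, stride=1):
    """Smallest input thickness that yields `n_out` output slices."""
    return fov + (n_out - 1) * stride

# Example: assumed z field of view of 9 slices, no striding in z.
n_in = input_slices_for(5, fov=9)
print(n_in, output_slices(n_in, fov=9))  # 13 5
```

With a stride greater than 1, several input thicknesses yield the same number of output slices (e.g. 17 and 18 slices both give 5 outputs for fov=9, stride=2), so the produced count need not match a requested value exactly.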

preview_slice_from_traindata(cube_i=0, off=(0, 0, 0), sh=(10, 400, 400), number=0, export_class='all')

Predicts and saves a selected slice from the training data as a preview.

Parameters:
  • cube_i (int) – index of source cube in CNNData
  • off (3-tuple of int) – start index of slice to cut from cube (z,y,x)
  • sh (3-tuple of int) – shape of cube to cut (z,y,x)
  • number (int) – consecutive number for the save name (e.g. hours, iterations)
  • export_class (str or int) – ‘all’ writes images of all classes; otherwise, only the class with index export_class (int) is saved.
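Cutting the preview region described by off and sh amounts to basic array slicing. The sketch assumes the last three axes of the cube are (z, y, x), with any leading axes (such as channels) kept whole; `cut_slice` is a hypothetical helper, and the explicit bounds check is a design choice of this sketch, since NumPy slicing would otherwise silently return a smaller array.

```python
import numpy as np

def cut_slice(cube, off=(0, 0, 0), sh=(10, 400, 400)):
    """Cut a (z, y, x) sub-volume starting at `off` with shape `sh`.

    Hypothetical helper mirroring the documented off/sh arguments.
    The last three axes of `cube` are taken as (z, y, x); leading axes
    such as channels are kept whole via the Ellipsis.
    """
    if any(o + s > n for o, s, n in zip(off, sh, cube.shape[-3:])):
        raise ValueError("slice exceeds cube bounds")
    z, y, x = off
    dz, dy, dx = sh
    return cube[..., z:z + dz, y:y + dy, x:x + dx]

cube = np.zeros((20, 500, 500), dtype=np.float32)
sub = cut_slice(cube, off=(5, 50, 50), sh=(10, 400, 400))
print(sub.shape)  # (10, 400, 400)
```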

run()

test_model(data_source)

Computes the loss and the error/accuracy on a batch of size monitor_batch_size.

Parameters: data_source (str) – ‘train’ or ‘valid’
Returns: loss, error
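For a classification model, the loss and error reported here can be sketched as the mean cross-entropy and the fraction of misclassified examples. This is an assumption: the loss actually computed depends on the model's configured loss, and `batch_loss_and_error` is a hypothetical helper operating on a (batch, n_classes) array of softmax outputs.

```python
import numpy as np

def batch_loss_and_error(probs, targets):
    """Mean negative log-likelihood and classification error of a batch.

    probs:   (batch, n_classes) softmax outputs
    targets: (batch,) integer class labels
    Assumption: cross-entropy loss; the real loss depends on the model.
    """
    eps = 1e-12  # avoid log(0)
    picked = probs[np.arange(len(targets)), targets]
    loss = -np.mean(np.log(picked + eps))
    error = np.mean(np.argmax(probs, axis=1) != targets)
    return loss, error

probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
targets = np.array([0, 1, 1, 1])
loss, error = batch_loss_and_error(probs, targets)
print(round(error, 2))  # 0.25
```

The accuracy is then simply 1 - error.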

class elektronn2.training.trainer.TracingTrainer(exp_config)

Bases: elektronn2.training.trainer.Trainer

debug_getcnnbatch(extended=False)

Executes getbatch, but with un-strided labels and with info always returned. The first example in the batch is plotted, and the whole batch is returned for inspection.

run()

static save_batch(img, lab, k, lab_img=None)

test_model(data_source)

Computes the loss and the error/accuracy on a batch of size monitor_batch_size.

Parameters: data_source (str) – ‘train’ or ‘valid’
Returns: loss, error

class elektronn2.training.trainer.TracingTrainerRNN(exp_config)

Bases: elektronn2.training.trainer.TracingTrainer

run()

test_model(data_source)

Computes the loss and the error/accuracy on a batch of size monitor_batch_size.

Parameters: data_source (str) – ‘train’ or ‘valid’
Returns: loss, error