TFEngine

TensorFlow engine

The basic engine for the TensorFlow backend is implemented here, i.e. the high-level logic to train: looping over epochs, holding the network instance, creating the TensorFlow session, managing the data pipeline, etc.

See Technological overview for how it all fits together.

class TFEngine.Engine(config=None)[source]
Parameters:config (Config.Config|None) –
analyze(data, statistics)[source]
Returns:nothing, will print everything to log.v1

check_last_epoch()[source]
check_uninitialized_vars()[source]

All vars in TF which are controlled by us should also have been initialized by us. We also take care about the optimizer slot variables. However, TF can still create other vars which we do not know about. E.g. the Adam optimizer creates the beta1_power/beta2_power vars (which are no slot vars). Here, we find all remaining uninitialized vars, report about them and initialize them.
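
The underlying bookkeeping, independent of the TF specifics (TF 1.x itself offers tf.report_uninitialized_variables() for the lookup step), can be sketched in plain Python. Variable names and the initializer here are hypothetical stand-ins, not actual TF calls:

```python
def init_remaining_vars(all_vars, initialized, initializer):
    """Find vars not yet initialized by us, report them, and initialize them.
    Mirrors the logic described above: only the remainder is touched; vars we
    already initialized (incl. optimizer slot vars) are left alone."""
    remaining = [v for v in all_vars if v not in initialized]
    for v in remaining:
        print("Uninitialized var: %s" % v)  # report, as the engine logs them
        initializer(v)
        initialized.add(v)
    return remaining

# Hypothetical example: Adam created beta1_power/beta2_power behind our back.
inited = {"W", "b"}
extra = init_remaining_vars(["W", "b", "beta1_power", "beta2_power"], inited,
                            initializer=lambda v: None)
```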

config_get_final_epoch(config)[source]
epoch_model_filename(model_filename, epoch, is_pretrain)[source]
Return type:str
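
A plausible sketch of the naming scheme: base filename, an optional pretrain marker, and a zero-padded epoch suffix. The exact format is an assumption for illustration; check the source for the real one:

```python
def epoch_model_filename(model_filename, epoch, is_pretrain):
    # Assumed scheme: base name, optional ".pretrain" marker, zero-padded epoch.
    return model_filename + (".pretrain" if is_pretrain else "") + ".%03d" % epoch

print(epoch_model_filename("net-model/network", 5, False))  # net-model/network.005
```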
eval_model()[source]
eval_single(dataset, seq_idx, output_dict)[source]
Parameters:
  • dataset (Dataset.Dataset) –
  • seq_idx (int) –
  • output_dict (dict[str,tf.Tensor]) – key -> tf.Tensor
Returns:

output_dict but values evaluated

Return type:

dict[str,numpy.ndarray]
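
The mapping can be illustrated abstractly: the returned dict keeps the same keys, with each tensor replaced by its evaluated value. Here plain callables stand in for tf.Tensor and tf.Session.run:

```python
def eval_fetches(output_dict, run):
    # Same keys as output_dict, values replaced by their evaluated results,
    # analogous to eval_single() running the fetches in the TF session.
    return {key: run(tensor) for key, tensor in output_dict.items()}

# Mock "tensor" as a thunk; a real session maps tf.Tensor -> numpy.ndarray.
results = eval_fetches({"output": lambda: [[0.1, 0.9]]}, run=lambda t: t())
```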

finalize()[source]
format_score(score)[source]
forward_single(dataset, seq_idx, output_layer_name=None)[source]
Parameters:
  • dataset (Dataset.Dataset) –
  • seq_idx (int) –
  • output_layer_name (str|None) – e.g. “output”. if not set, will read from config “forward_output_layer”
Returns:

numpy array, output in time major format (time,dim)

Return type:

numpy.ndarray
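
Since the output is time-major, iterating over the first axis yields one frame per time step. For instance, a per-frame argmax over a hypothetical (time=3, dim=2) output (plain lists standing in for the numpy array):

```python
output = [[0.2, 0.8], [0.6, 0.4], [0.9, 0.1]]  # hypothetical (time, dim) output
# Index of the highest-scoring dim in each time frame:
best = [max(range(len(frame)), key=frame.__getitem__) for frame in output]
```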

forward_to_hdf(data, output_file, combine_labels='', batch_size=0)[source]

Aims to recreate the same interface and output as Engine.forward_to_hdf.

get_all_merged_summaries()[source]
Returns:merged summaries, serialized string
Return type:tf.Tensor
get_const_tensor(key, value)[source]
get_epoch_model(config)[source]

Returns:(epoch, modelFilename)
Return type:(int|None, str|None)

get_epoch_model_filename(epoch=None)[source]
get_epoch_str()[source]
get_eval_datasets()[source]
get_specific_feed_dict(dataset, seq_idx)[source]
Parameters:
  • dataset (Dataset.Dataset) –
  • seq_idx (int) –
Returns:

feed_dict for self.tf_session.run()

Return type:

dict[str,numpy.ndarray]

get_train_start_epoch_batch(config)[source]

We will always automatically determine the best start (epoch,batch) tuple based on existing model files. This ensures that the files are present and enforces that there are no old outdated files which should be ignored. Note that epochs start at idx 1 and batches at idx 0.

Parameters:config (Config.Config) –
Returns:(epoch,batch)
Return type:(int,int)
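
The selection logic can be sketched as: extract epoch numbers from the existing model filenames and resume with the next epoch (epochs start at 1, batches at 0). A simplified, hypothetical stand-in operating on a plain list of filenames, assuming the zero-padded-epoch naming scheme:

```python
import re

def start_epoch_batch(model_files):
    # Extract epoch numbers from filenames like "network.003" (assumed scheme).
    epochs = [int(m.group(1)) for f in model_files
              for m in [re.search(r"\.(\d+)$", f)] if m]
    if not epochs:
        return 1, 0  # no model yet: start at epoch 1, batch 0
    return max(epochs) + 1, 0  # resume with the next epoch, batch 0

start = start_epoch_batch(["network.001", "network.002"])
```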

init_network_from_config(config)[source]
Parameters:config (Config.Config) –
init_train_epoch()[source]
init_train_from_config(config, train_data, dev_data=None, eval_data=None)[source]
is_first_epoch_after_pretrain()[source]
is_pretrain_epoch(epoch=None)[source]
is_requesting_for_gpu()[source]
load_model(epoch=None, filename=None)[source]
Parameters:
  • epoch (int) –
  • filename (str) –
maybe_init_new_network(net_desc)[source]
save_model(filename=None)[source]
Parameters:filename (str) – full filename for model
search(dataset, do_eval=True, output_layer_name='output')[source]
Parameters:
  • dataset (Dataset.Dataset) –
  • do_eval (bool) – calculate errors; can only be done if we have the reference target
  • output_layer_name (str) –
train()[source]
train_epoch()[source]
class TFEngine.Runner(engine, dataset, batches, train, eval=True, extra_fetches=None, extra_fetches_callback=None)[source]
Parameters:
  • engine (Engine) –
  • dataset (Dataset.Dataset) –
  • batches (BatchSetGenerator) –
  • train (bool) – whether to do updates on the model
  • eval (bool) – whether to evaluate (i.e. calculate loss/error)
  • extra_fetches (dict[str,tf.Tensor|TFUtil.Data|TFNetworkLayer.LayerBase]|None) – additional fetches per step. extra_fetches_callback will be called with these. In case of Data/LayerBase, it will return a list, where each item corresponds to the batch-seq. It might also be useful to add network.get_extern_data(“seq_idx”) and network.get_extern_data(“seq_tag”).
  • extra_fetches_callback ((**dict[str,numpy.ndarray|str|list[numpy.ndarray|str]])->None) – called if extra_fetches is set

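The callback contract can be illustrated in plain Python: per step, the fetched values are passed as keyword arguments, and for Data/LayerBase fetches each value is a list with one entry per batch-seq. The fetch names and values here are hypothetical:

```python
collected = []

def extra_fetches_callback(seq_idx, seq_tag):
    # Called once per step; each argument is a list over the seqs in the batch.
    for idx, tag in zip(seq_idx, seq_tag):
        collected.append((idx, tag))

# Mock of what the Runner would fetch in one step for a batch of two seqs:
extra_fetches_callback(seq_idx=[0, 1], seq_tag=["seq-0", "seq-1"])
```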
run(report_prefix)[source]
Parameters:report_prefix (str) – prefix for logging