forked from rwth-i6/returnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdump-forward.py
executable file
·76 lines (61 loc) · 2.52 KB
/
dump-forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python
"""
For debugging, go through some dataset, forward it through the net, and output the layer activations on stdout.
"""
from __future__ import annotations
import sys
import _setup_returnn_env # noqa
import returnn.__main__ as rnn
from returnn.log import log
import argparse
from returnn.util.basic import pretty_print
def dump(dataset, options):
    """
    Forward sequences of ``dataset`` through the current network and print every layer's fetched values.

    For each layer in ``rnn.engine.network.layers`` this fetches the layer output placeholder
    (key ``"<name>:out"``) and every entry of the layer output's ``size_placeholder``
    (key ``"<name>:shape(<axis>)"``, where the axis is mapped through ``get_batch_axis``).
    Sequences ``options.startseq`` .. ``options.endseq`` (both inclusive; a negative ``endseq`` means
    "until the dataset runs out") are run one at a time via ``rnn.engine.run_single``.
    The fetched values of each sequence are printed to stdout, sorted by key.

    :param dataset: dataset to iterate over; must provide ``init_seq_order(epoch)`` and
        ``is_less_than_num_seqs(seq_idx)`` (presumably a RETURNN ``Dataset``).
    :param argparse.Namespace options: parsed command line options; reads ``epoch``, ``startseq``
        and ``endseq``. ``options`` is not modified.
    """
    print("Epoch: %i" % options.epoch, file=log.v3)
    dataset.init_seq_order(options.epoch)

    # Fetch every layer's output plus its dynamic-size placeholders, keyed by a readable name.
    output_dict = {}
    for name, layer in rnn.engine.network.layers.items():
        output_dict["%s:out" % name] = layer.output.placeholder
        for i, v in layer.output.size_placeholder.items():
            output_dict["%s:shape(%i)" % (name, layer.output.get_batch_axis(i))] = v

    # A negative end index means "no upper bound". Keep this in a local variable instead of
    # overwriting options.endseq, so the caller's namespace is left untouched.
    end_seq = float("inf") if options.endseq < 0 else options.endseq
    seq_idx = options.startseq
    while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= end_seq:
        print("Seq idx: %i" % (seq_idx,), file=log.v3)
        out = rnn.engine.run_single(dataset=dataset, seq_idx=seq_idx, output_dict=output_dict)
        for name, v in sorted(out.items()):
            print(" %s: %s" % (name, pretty_print(v)))
        seq_idx += 1

    # Report whether the loop stopped because of end_seq (True) or because the dataset ran out (False).
    print("Done. More seqs which we did not dumped: %s" % dataset.is_less_than_num_seqs(seq_idx), file=log.v1)
def init(config_filename, command_line_options):
    """
    Initialize RETURNN from a config file and prepare the engine with the training data.

    :param str config_filename: path of the RETURNN config to load
    :param list[str] command_line_options: extra RETURNN command line options forwarded to ``rnn.init``
    """
    rnn_init_kwargs = dict(
        config_filename=config_filename,
        command_line_options=command_line_options,
        # Overrides the config's "log" entry with None (presumably disables the log file) -- TODO confirm.
        config_updates={"log": None},
        extra_greeting="RETURNN dump-forward starting up.",
    )
    rnn.init(**rnn_init_kwargs)

    # rnn.engine, rnn.config and rnn.train_data are read only after rnn.init() has run.
    # NOTE: rnn.engine.init_network_from_config(rnn.config) was considered as an alternative here;
    # init_train_from_config is what is actually used.
    engine = rnn.engine
    engine.init_train_from_config(config=rnn.config, train_data=rnn.train_data)
def main(argv):
    """
    Main entry: parse the command line, initialize RETURNN, dump the layer outputs, then finalize RETURNN.

    :param list[str] argv: full argument vector including the program name (like ``sys.argv``);
        ``argv[0]`` is skipped when parsing.
    """
    arg_parser = argparse.ArgumentParser(description="Forward something and dump it.")
    arg_parser.add_argument("returnn_config")
    arg_parser.add_argument("--epoch", type=int, default=1)
    arg_parser.add_argument("--startseq", type=int, default=0, help="start seq idx (inclusive) (default: 0)")
    arg_parser.add_argument("--endseq", type=int, default=10, help="end seq idx (inclusive) or -1 (default: 10)")
    args = arg_parser.parse_args(argv[1:])

    init(config_filename=args.returnn_config, command_line_options=[])
    # Finalize RETURNN even if dumping fails part-way (e.g. an error during a forward pass);
    # previously an exception from dump() skipped rnn.finalize() entirely.
    try:
        dump(rnn.train_data, args)
    finally:
        rnn.finalize()
if __name__ == "__main__":
    # Script entry point: pass the complete argv (program name included) on to main().
    main(argv=sys.argv)