Skip to content

Commit

Permalink
limit "t" and correct prev non blank for search
Browse files Browse the repository at this point in the history
  • Loading branch information
jotix16 committed Jun 1, 2021
1 parent e62f264 commit dcae584
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions common/models/transducer/transducer_fullsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,15 @@ def make(self, encoder: LayerRef):
blank_idx = self.ctx.blank_idx

rec_decoder = {
"am0": {"class": "gather_nd", "from": _base(encoder), "position": "prev:t"}, # [B,D]
"index": {"class": "eval", "from": ["prev:t", "enc_seq_len"], "eval": 'tf.minimum(source(0), source(1)-1)'},
"am0": {"class": "gather_nd", "from": _base(encoder), "position": "index"}, # [B,D]
"am": {"class": "copy", "from": "am0" if search else "data:source"},

"prev_output_wo_b": {
"class": "masked_computation", "unit": {"class": "copy", "initial_output": 0},
"from": "prev:output_", "mask": "prev:output_emit", "initial_output": 0},
"prev_out_non_blank": {
"class": "reinterpret_data", "from": "prev:output_", "set_sparse_dim": target.get_num_classes()},
"class": "reinterpret_data", "from": "prev_output_wo_b", "set_sparse_dim": target.get_num_classes()},

"slow_rnn": self.slow_rnn.make(
prev_sparse_label_nb="prev_out_non_blank",
Expand All @@ -252,7 +256,7 @@ def make(self, encoder: LayerRef):

"output": {
"class": 'choice',
'target': target.key, # note: wrong! but this is ignored both in full-sum training and in search
'target': target.key if train else None, # note: wrong! but this is ignored both in full-sum training and in search
'beam_size': beam_size,
'from': "output_log_prob_wb", "input_type": "log_prob",
"initial_output": 0,
Expand Down

0 comments on commit dcae584

Please sign in to comment.