Skip to content

Commit

Permalink
Merge pull request #9 from uf-hobi-informatics-lab/origin/advDev
Browse files Browse the repository at this point in the history
fix the batch first issue for LSTM
  • Loading branch information
bugface authored Sep 13, 2020
2 parents a10785c + d15bd09 commit 3a574fe
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions seq_ehr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def __init__(self, config, model_type=ModelType.M_LSTM):
# TLSTM hidden state dim = (B, h)
self.seq_model = TLSTMCell(config.seq_input_dim, config.seq_hidden_dim)
elif model_type is ModelType.M_LSTM:
# LSTM hidden state dim = (num_layers * num_directions, batch, hidden_size)
self.seq_model = N.LSTM(config.seq_input_dim, config.seq_hidden_dim)
# LSTM hidden state dim = (batch, num_layers * num_directions, hidden_size)
self.seq_model = N.LSTM(config.seq_input_dim, config.seq_hidden_dim, batch_first=True)
else:
raise NotImplementedError("We only support model ctlsm and clstm but get {}".format(model_type.value))

Expand Down

0 comments on commit 3a574fe

Please sign in to comment.