diff --git a/seq_ehr_model.py b/seq_ehr_model.py index 6dcbc09..6f6b709 100644 --- a/seq_ehr_model.py +++ b/seq_ehr_model.py @@ -54,8 +54,8 @@ def __init__(self, config, model_type=ModelType.M_LSTM): # TLSTM hidden state dim = (B, h) self.seq_model = TLSTMCell(config.seq_input_dim, config.seq_hidden_dim) elif model_type is ModelType.M_LSTM: - # LSTM hidden state dim = (num_layers * num_directions, batch, hidden_size) - self.seq_model = N.LSTM(config.seq_input_dim, config.seq_hidden_dim) + # LSTM hidden state dim = (batch, num_layers * num_directions, hidden_size) + self.seq_model = N.LSTM(config.seq_input_dim, config.seq_hidden_dim, batch_first=True) else: raise NotImplementedError("We only support model ctlsm and clstm but get {}".format(model_type.value))