diff --git a/octis/models/ETM.py b/octis/models/ETM.py
index fe0116c8..9f200e9c 100644
--- a/octis/models/ETM.py
+++ b/octis/models/ETM.py
@@ -16,7 +16,7 @@ def __init__(
         self, num_topics=10, num_epochs=100, t_hidden_size=800, rho_size=300,
         embedding_size=300, activation='relu', dropout=0.5, lr=0.005,
         optimizer='adam', batch_size=128, clip=0.0, wdecay=1.2e-6, bow_norm=1,
-        device='cpu', train_embeddings=True, embeddings_path=None,
+        device='cuda', train_embeddings=True, embeddings_path=None,
         embeddings_type='pickle', binary_embeddings=True,
         headerless_embeddings=False, use_partitions=True):
         """
@@ -131,9 +131,12 @@ def set_model(self, dataset, hyperparameters):
             self.train_tokens, self.train_counts = self.preprocess(
                 vocab2id, data_corpus, None)
 
-        self.device = torch.device(
-            "cuda" if torch.cuda.is_available() else "cpu")
-
+        if isinstance(self.device, str):
+            self.device = torch.device(self.device)
+
+        if (self.device.type == 'cuda' and not torch.cuda.is_available()) or (self.device.type == 'mps' and not torch.backends.mps.is_available()):
+            self.device = torch.device('cpu')
+
         self.set_default_hyperparameters(hyperparameters)
         self.load_embeddings()
         # define model and optimizer
diff --git a/octis/models/LSI.py b/octis/models/LSI.py
index 1d6229ec..1c74def4 100644
--- a/octis/models/LSI.py
+++ b/octis/models/LSI.py
@@ -106,10 +106,10 @@ def train_model(self, dataset, hyperparameters={}, top_words=10):
         else:
             partition = [dataset.get_corpus(), []]
 
-        if self.id2word == None:
+        if self.id2word is None:
             self.id2word = corpora.Dictionary(dataset.get_corpus())
 
-        if self.id_corpus == None:
+        if self.id_corpus is None:
             self.id_corpus = [self.id2word.doc2bow(
                 document) for document in partition[0]]
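A minimal usage sketch of the behavior after this patch (a hypothetical caller script, not part of the diff): the `device` argument of `ETM` now accepts any device string, and `set_model` demotes an unavailable CUDA/MPS device to CPU instead of probing CUDA unconditionally.

    # sketch: assumes OCTIS is installed; only calls shown in the patched
    # constructor signature above are used
    from octis.models.ETM import ETM

    model = ETM(num_topics=10, device='mps')  # default is now 'cuda'
    # when set_model() runs, the string is converted via torch.device() and
    # silently falls back to torch.device('cpu') if MPS is not available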