Patch against neuralnetlib/models.py (unified diff, four hunks).

Summary of what the hunks change:
  * Hunk 1: the config dict built inside a save() method gains a
    'variational' entry.
  * Hunks 2-3: Transformer.get_config() now builds a dict that carries
    'type' plus the get_config() output of the embeddings, positional
    encoding, encoder/decoder layers, dropouts, output layer, loss function
    and optimizer, and returns it. The old save() is removed from this spot.
  * Hunk 4: Transformer.load() builds the model from the scalar
    hyperparameters and then replaces each sub-component with the result of
    the matching from_config(...) call. A new save() is added after load();
    it writes one JSON file per sub-component next to the main file and
    records those paths in the main JSON.

NOTE(review): load() only reads the nested entries of the main JSON file; the
  '*_file' / '*_files' paths written by save() are never read back here.
NOTE(review): get_config() calls get_config() on src_embedding,
  tgt_embedding and output_layer unconditionally, while save() guards those
  same attributes with 'is not None' only after get_config() has already run.
NOTE(review): save() uses os.path.splitext; confirm 'os' is imported in
  models.py.
NOTE(review): the line structure of this patch was restored from a copy whose
  newlines and indentation had been collapsed. Blank lines were placed so that
  every hunk matches its declared line counts; leading indentation assumes the
  usual 4-space Python layout. Verify context lines against the target file
  before applying.

diff --git a/neuralnetlib/models.py b/neuralnetlib/models.py
index 54befbe..748bbf1 100644
--- a/neuralnetlib/models.py
+++ b/neuralnetlib/models.py
@@ -1205,7 +1205,8 @@ def save(self, filename: str):
             'random_state': self.random_state,
             'skip_connections': self.skip_connections,
             'l1_reg': self.l1_reg,
-            'l2_reg': self.l2_reg
+            'l2_reg': self.l2_reg,
+            'variational': self.variational
         }
 
         for layer in self.encoder_layers:
@@ -2265,7 +2266,8 @@ def evaluate(self, x_test: list[np.ndarray], y_test: np.ndarray, batch_size: int
         return avg_loss, all_predictions
 
     def get_config(self) -> dict:
-        return {
+        config = {
+            'type': 'Transformer',
             'src_vocab_size': self.src_vocab_size,
             'tgt_vocab_size': self.tgt_vocab_size,
             'd_model': self.d_model,
@@ -2278,15 +2280,24 @@ def get_config(self) -> dict:
             'gradient_clip_threshold': self.gradient_clip_threshold,
             'enable_padding': self.enable_padding,
             'padding_size': self.padding_size,
-            'random_state': self.random_state
+            'random_state': self.random_state,
+
+            'src_embedding': self.src_embedding.get_config(),
+            'tgt_embedding': self.tgt_embedding.get_config(),
+            'positional_encoding': self.positional_encoding.get_config(),
+
+            'encoder_layers': [layer.get_config() for layer in self.encoder_layers],
+            'decoder_layers': [layer.get_config() for layer in self.decoder_layers],
+
+            'encoder_dropout': self.encoder_dropout.get_config(),
+            'decoder_dropout': self.decoder_dropout.get_config(),
+
+            'output_layer': self.output_layer.get_config(),
+
+            'loss_function': self.loss_function.get_config() if self.loss_function is not None else None,
+            'optimizer': self.optimizer.get_config() if self.optimizer is not None else None
         }
-
-    def save(self, filename: str) -> None:
-        config = self.get_config()
-        config['type'] = 'Transformer'
-
-        with open(filename, 'w') as f:
-            json.dump(config, f, indent=4)
+        return config
 
     @classmethod
     def load(cls, filename: str) -> 'Transformer':
@@ -2296,7 +2307,94 @@ def load(cls, filename: str) -> 'Transformer':
         if config['type'] != 'Transformer':
             raise ValueError(f"Invalid model type {config['type']}")
 
-        return cls(**{k: v for k, v in config.items() if k != 'type'})
+        model = cls(
+            src_vocab_size=config['src_vocab_size'],
+            tgt_vocab_size=config['tgt_vocab_size'],
+            d_model=config['d_model'],
+            n_heads=config['n_heads'],
+            n_encoder_layers=config['n_encoder_layers'],
+            n_decoder_layers=config['n_decoder_layers'],
+            d_ff=config['d_ff'],
+            dropout_rate=config['dropout_rate'],
+            max_sequence_length=config['max_sequence_length'],
+            gradient_clip_threshold=config['gradient_clip_threshold'],
+            enable_padding=config['enable_padding'],
+            padding_size=config['padding_size'],
+            random_state=config['random_state']
+        )
+
+        model.src_embedding = Embedding.from_config(config['src_embedding'])
+        model.tgt_embedding = Embedding.from_config(config['tgt_embedding'])
+        model.positional_encoding = PositionalEncoding.from_config(config['positional_encoding'])
+
+        model.encoder_dropout = Dropout.from_config(config['encoder_dropout'])
+        model.decoder_dropout = Dropout.from_config(config['decoder_dropout'])
+
+        model.encoder_layers = [TransformerEncoderLayer.from_config(layer_config)
+                                for layer_config in config['encoder_layers']]
+        model.decoder_layers = [TransformerDecoderLayer.from_config(layer_config)
+                                for layer_config in config['decoder_layers']]
+
+        model.output_layer = Dense.from_config(config['output_layer'])
+
+        if config['loss_function']:
+            model.loss_function = LossFunction.from_config(config['loss_function'])
+        if config['optimizer']:
+            model.optimizer = Optimizer.from_config(config['optimizer'])
+
+        return model
+
+    def save(self, filename: str) -> None:
+        base, ext = os.path.splitext(filename)
+
+        config = self.get_config()
+
+        if self.src_embedding is not None:
+            src_emb_file = f"{base}_src_embedding{ext}"
+            config['src_embedding_file'] = src_emb_file
+            with open(src_emb_file, 'w') as f:
+                json.dump(self.src_embedding.get_config(), f, indent=4)
+
+        if self.tgt_embedding is not None:
+            tgt_emb_file = f"{base}_tgt_embedding{ext}"
+            config['tgt_embedding_file'] = tgt_emb_file
+            with open(tgt_emb_file, 'w') as f:
+                json.dump(self.tgt_embedding.get_config(), f, indent=4)
+
+        config['encoder_layers_files'] = []
+        for i, layer in enumerate(self.encoder_layers):
+            encoder_file = f"{base}_encoder_layer_{i}{ext}"
+            config['encoder_layers_files'].append(encoder_file)
+            with open(encoder_file, 'w') as f:
+                json.dump(layer.get_config(), f, indent=4)
+
+        config['decoder_layers_files'] = []
+        for i, layer in enumerate(self.decoder_layers):
+            decoder_file = f"{base}_decoder_layer_{i}{ext}"
+            config['decoder_layers_files'].append(decoder_file)
+            with open(decoder_file, 'w') as f:
+                json.dump(layer.get_config(), f, indent=4)
+
+        if self.output_layer is not None:
+            output_file = f"{base}_output_layer{ext}"
+            config['output_layer_file'] = output_file
+            with open(output_file, 'w') as f:
+                json.dump(self.output_layer.get_config(), f, indent=4)
+
+        if self.optimizer is not None:
+            optimizer_file = f"{base}_optimizer{ext}"
+            config['optimizer_file'] = optimizer_file
+            with open(optimizer_file, 'w') as f:
+                json.dump(self.optimizer.get_config(), f, indent=4)
+
+        if self.loss_function is not None:
+            loss_file = f"{base}_loss{ext}"
+            config['loss_file'] = loss_file
+            with open(loss_file, 'w') as f:
+                json.dump(self.loss_function.get_config(), f, indent=4)
+
+        with open(filename, 'w') as f:
+            json.dump(config, f, indent=4)
 
     def __str__(self) -> str:
         return (f"Transformer(\n"