fix(configs): models states saving and loading
marcpinet committed Dec 10, 2024
1 parent 1fcafb5 commit 66be639
Showing 1 changed file with 109 additions and 11 deletions.
120 changes: 109 additions & 11 deletions neuralnetlib/models.py
@@ -1205,7 +1205,8 @@ def save(self, filename: str):
             'random_state': self.random_state,
             'skip_connections': self.skip_connections,
             'l1_reg': self.l1_reg,
-            'l2_reg': self.l2_reg
+            'l2_reg': self.l2_reg,
+            'variational': self.variational
         }
 
         for layer in self.encoder_layers:
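
Worth noting why this one-key change matters: any constructor flag left out of the serialized config silently reverts to its default when the model is rebuilt, so a variational autoencoder would reload as a plain one. A minimal, self-contained sketch of the failure mode (the names here are illustrative, not the library's actual API):

```python
import json

# Hypothetical defaults standing in for the autoencoder's constructor.
DEFAULTS = {'l1_reg': 0.0, 'l2_reg': 0.0, 'variational': False}

def rebuild(saved: dict) -> dict:
    # Keys absent from the saved config fall back to the defaults --
    # exactly how an unsaved 'variational' flag downgrades the model.
    return {**DEFAULTS, **saved}

config = {'l1_reg': 0.0, 'l2_reg': 0.01}        # 'variational' not persisted
restored = rebuild(json.loads(json.dumps(config)))
assert restored['variational'] is False         # flag lost on reload

config['variational'] = True                    # the fix: persist the flag
restored = rebuild(json.loads(json.dumps(config)))
assert restored['variational'] is True
```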
@@ -2265,7 +2266,8 @@ def evaluate(self, x_test: list[np.ndarray], y_test: np.ndarray, batch_size: int
         return avg_loss, all_predictions
 
     def get_config(self) -> dict:
-        return {
+        config = {
+            'type': 'Transformer',
             'src_vocab_size': self.src_vocab_size,
             'tgt_vocab_size': self.tgt_vocab_size,
             'd_model': self.d_model,
@@ -2278,15 +2280,24 @@ def get_config(self) -> dict:
             'gradient_clip_threshold': self.gradient_clip_threshold,
             'enable_padding': self.enable_padding,
             'padding_size': self.padding_size,
-            'random_state': self.random_state
+            'random_state': self.random_state,
+
+            'src_embedding': self.src_embedding.get_config(),
+            'tgt_embedding': self.tgt_embedding.get_config(),
+            'positional_encoding': self.positional_encoding.get_config(),
+
+            'encoder_layers': [layer.get_config() for layer in self.encoder_layers],
+            'decoder_layers': [layer.get_config() for layer in self.decoder_layers],
+
+            'encoder_dropout': self.encoder_dropout.get_config(),
+            'decoder_dropout': self.decoder_dropout.get_config(),
+
+            'output_layer': self.output_layer.get_config(),
+
+            'loss_function': self.loss_function.get_config() if self.loss_function is not None else None,
+            'optimizer': self.optimizer.get_config() if self.optimizer is not None else None
         }
-
-    def save(self, filename: str) -> None:
-        config = self.get_config()
-        config['type'] = 'Transformer'
-
-        with open(filename, 'w') as f:
-            json.dump(config, f, indent=4)
+        return config
 
     @classmethod
     def load(cls, filename: str) -> 'Transformer':
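
The reworked get_config composes the model's serialized form from its parts: every sub-component contributes its own get_config() dict, nested under a descriptive key, and load() later reverses the nesting component by component. A hedged sketch of the pattern (this Dropout is a simplified stand-in; the library's real layer classes carry more state):

```python
class Dropout:
    """Simplified stand-in for the library's Dropout layer."""

    def __init__(self, rate: float):
        self.rate = rate

    def get_config(self) -> dict:
        return {'rate': self.rate}

    @classmethod
    def from_config(cls, config: dict) -> 'Dropout':
        return cls(rate=config['rate'])

# The parent model nests each component's config under its own key ...
parent_config = {'encoder_dropout': Dropout(0.1).get_config()}
# ... and reconstruction walks the same keys back through from_config().
restored = Dropout.from_config(parent_config['encoder_dropout'])
assert restored.rate == 0.1
```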
@@ -2296,7 +2307,94 @@ def load(cls, filename: str) -> 'Transformer':
         if config['type'] != 'Transformer':
             raise ValueError(f"Invalid model type {config['type']}")
 
-        return cls(**{k: v for k, v in config.items() if k != 'type'})
+        model = cls(
+            src_vocab_size=config['src_vocab_size'],
+            tgt_vocab_size=config['tgt_vocab_size'],
+            d_model=config['d_model'],
+            n_heads=config['n_heads'],
+            n_encoder_layers=config['n_encoder_layers'],
+            n_decoder_layers=config['n_decoder_layers'],
+            d_ff=config['d_ff'],
+            dropout_rate=config['dropout_rate'],
+            max_sequence_length=config['max_sequence_length'],
+            gradient_clip_threshold=config['gradient_clip_threshold'],
+            enable_padding=config['enable_padding'],
+            padding_size=config['padding_size'],
+            random_state=config['random_state']
+        )
+
+        model.src_embedding = Embedding.from_config(config['src_embedding'])
+        model.tgt_embedding = Embedding.from_config(config['tgt_embedding'])
+        model.positional_encoding = PositionalEncoding.from_config(config['positional_encoding'])
+
+        model.encoder_dropout = Dropout.from_config(config['encoder_dropout'])
+        model.decoder_dropout = Dropout.from_config(config['decoder_dropout'])
+
+        model.encoder_layers = [TransformerEncoderLayer.from_config(layer_config)
+                                for layer_config in config['encoder_layers']]
+        model.decoder_layers = [TransformerDecoderLayer.from_config(layer_config)
+                                for layer_config in config['decoder_layers']]
+
+        model.output_layer = Dense.from_config(config['output_layer'])
+
+        if config['loss_function']:
+            model.loss_function = LossFunction.from_config(config['loss_function'])
+        if config['optimizer']:
+            model.optimizer = Optimizer.from_config(config['optimizer'])
+
+        return model
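
End to end, restoring a model then takes a single call: the constructor rebuilds the architecture from the scalar hyperparameters, after which each sub-component is replaced via its from_config() twin. A usage sketch (the file name is illustrative):

```python
from neuralnetlib.models import Transformer

# Assumes 'transformer.json' was written by Transformer.save() below.
model = Transformer.load('transformer.json')
```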
+
+    def save(self, filename: str) -> None:
+        base, ext = os.path.splitext(filename)
+
+        config = self.get_config()
+
+        if self.src_embedding is not None:
+            src_emb_file = f"{base}_src_embedding{ext}"
+            config['src_embedding_file'] = src_emb_file
+            with open(src_emb_file, 'w') as f:
+                json.dump(self.src_embedding.get_config(), f, indent=4)
+
+        if self.tgt_embedding is not None:
+            tgt_emb_file = f"{base}_tgt_embedding{ext}"
+            config['tgt_embedding_file'] = tgt_emb_file
+            with open(tgt_emb_file, 'w') as f:
+                json.dump(self.tgt_embedding.get_config(), f, indent=4)
+
+        config['encoder_layers_files'] = []
+        for i, layer in enumerate(self.encoder_layers):
+            encoder_file = f"{base}_encoder_layer_{i}{ext}"
+            config['encoder_layers_files'].append(encoder_file)
+            with open(encoder_file, 'w') as f:
+                json.dump(layer.get_config(), f, indent=4)
+
+        config['decoder_layers_files'] = []
+        for i, layer in enumerate(self.decoder_layers):
+            decoder_file = f"{base}_decoder_layer_{i}{ext}"
+            config['decoder_layers_files'].append(decoder_file)
+            with open(decoder_file, 'w') as f:
+                json.dump(layer.get_config(), f, indent=4)
+
+        if self.output_layer is not None:
+            output_file = f"{base}_output_layer{ext}"
+            config['output_layer_file'] = output_file
+            with open(output_file, 'w') as f:
+                json.dump(self.output_layer.get_config(), f, indent=4)
+
+        if self.optimizer is not None:
+            optimizer_file = f"{base}_optimizer{ext}"
+            config['optimizer_file'] = optimizer_file
+            with open(optimizer_file, 'w') as f:
+                json.dump(self.optimizer.get_config(), f, indent=4)
+
+        if self.loss_function is not None:
+            loss_file = f"{base}_loss{ext}"
+            config['loss_file'] = loss_file
+            with open(loss_file, 'w') as f:
+                json.dump(self.loss_function.get_config(), f, indent=4)
+
+        with open(filename, 'w') as f:
+            json.dump(config, f, indent=4)
 
     def __str__(self) -> str:
         return (f"Transformer(\n"
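
On the writing side, the new save() additionally mirrors each component's config into a sidecar file next to the main one and records the paths under *_file keys. Continuing the sketch above, a hypothetical save('model.json') would leave this layout on disk:

```python
model.save('model.json')  # continuing the load sketch above
# model.json                  -> full config plus the *_file path entries
# model_src_embedding.json    -> source embedding config
# model_tgt_embedding.json    -> target embedding config
# model_encoder_layer_0.json  -> one file per encoder layer (0..n-1)
# model_decoder_layer_0.json  -> one file per decoder layer (0..n-1)
# model_output_layer.json     -> output Dense layer config
# model_optimizer.json        -> optimizer config
# model_loss.json             -> loss function config
```

As written, load() consumes the component dicts embedded inline in the main file, so the sidecar copies act as human-inspectable duplicates rather than required inputs.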