diff --git a/model.py b/model.py
index 46efc93d..0601971f 100644
--- a/model.py
+++ b/model.py
@@ -5,6 +5,7 @@
 import pytorch_lightning as pl
 import copy
 from feature_transformer import DoubleFeatureTransformerSlice
+from optimizer import NnueOptimizer
 
 # 3 layer fully connected network
 L1 = 512
@@ -95,6 +96,14 @@ def forward(self, x, ls_indices):
 
         return l3x_
 
+    def disable_l1_factorization(self, optimizer):
+        with torch.no_grad():
+            for i in range(self.count):
+                self.l1.weight[i*L2:(i+1)*L2, :].add_(self.l1_fact.weight.data)
+            self.l1_fact.weight.data.fill_(0.0)
+            self.l1_fact.weight.requires_grad = False
+
+
     def get_coalesced_layer_stacks(self):
         for i in range(self.count):
             with torch.no_grad():
@@ -128,6 +137,9 @@ def __init__(self, feature_set, lambda_=1.0):
         self.layer_stacks = LayerStacks(self.num_ls_buckets)
         self.lambda_ = lambda_
 
+        self.disable_ft_factorization_after_steps = -1
+        self.disable_l1_factorization_after_steps = -1
+
         self.weight_clipping = [
             {'params' : [self.layer_stacks.l1.weight], 'min_weight' : -127/64, 'max_weight' : 127/64, 'virtual_params' : self.layer_stacks.l1_fact.weight },
             {'params' : [self.layer_stacks.l2.weight], 'min_weight' : -127/64, 'max_weight' : 127/64 },
@@ -293,11 +305,35 @@ def validation_step(self, batch, batch_idx):
     def test_step(self, batch, batch_idx):
         self.step_(batch, batch_idx, 'test_loss')
 
+    def post_optimizer_step(self, optimizer):
+        num_finished_steps = optimizer.get_num_finished_steps()
+
+        if num_finished_steps == self.disable_l1_factorization_after_steps:
+            self.disable_l1_factorization(optimizer)
+
+        if num_finished_steps == self.disable_ft_factorization_after_steps:
+            self.disable_ft_factorization(optimizer)
+
+    def disable_l1_factorization(self, optimizer):
+        self.layer_stacks.disable_l1_factorization(optimizer)
+
+    def disable_ft_factorization(self, optimizer):
+        with torch.no_grad():
+            weight = self.input.weight.data
+            indices = self.feature_set.get_virtual_to_real_features_gather_indices()
+            for i_real, is_virtual in enumerate(indices):
+                weight[i_real, :] = sum(weight[i_virtual, :] for i_virtual in is_virtual)
+
+            for a, b in self.feature_set.get_virtual_feature_ranges():
+                weight[a:b, :].fill_(0.0)
+                optimizer.freeze_parameter_region(self.input.weight, (slice(a, b), slice(None, None)))
+
     def configure_optimizers(self):
         # Train with a lower LR on the output layer
         LR = 1.5e-3
         train_params = [
-            {'params' : get_parameters([self.input]), 'lr' : LR, 'gc_dim' : 0 },
+            {'params' : [self.input.weight], 'lr' : LR, 'gc_dim' : 0, 'maskable' : (self.disable_ft_factorization_after_steps != -1) },
+            {'params' : [self.input.bias], 'lr' : LR },
             {'params' : [self.layer_stacks.l1_fact.weight], 'lr' : LR },
             {'params' : [self.layer_stacks.l1.weight], 'lr' : LR },
             {'params' : [self.layer_stacks.l1.bias], 'lr' : LR },
@@ -307,7 +343,8 @@ def configure_optimizers(self):
             {'params' : [self.layer_stacks.output.bias], 'lr' : LR / 10 },
         ]
         # increasing the eps leads to less saturated nets with a few dead neurons
-        optimizer = ranger.Ranger(train_params, betas=(.9, 0.999), eps=1.0e-7, gc_loc=False, use_gc=False)
+        optimizer = NnueOptimizer(ranger.Ranger, train_params, betas=(.9, 0.999), eps=1.0e-7, gc_loc=False, use_gc=False)
+        optimizer.set_post_step_callback(lambda opt: self.post_optimizer_step(opt))
         # Drop learning rate after 75 epochs
         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.987)
         return [optimizer], [scheduler]
diff --git a/optimizer.py b/optimizer.py
new file mode 100644
index 00000000..8efcce0c
--- /dev/null
+++ b/optimizer.py
@@ -0,0 +1,81 @@
+from torch.optim.optimizer import Optimizer, required
+import torch
+
+def NnueOptimizer(optimizer_cls, params, **kwargs):
+    class SpecificNnueOptimizer(optimizer_cls):
+        def __init__(self, params, **kwargs):
+            super().__init__(params, **kwargs)
+
+            self.state['nnue_optimizer']['num_finished_steps'] = 0
+            self.post_step_callback = None
+
+            self._add_default('maskable', False)
+
+
+
+        def set_post_step_callback(self, callback):
+            self.post_step_callback = callback
+
+
+
+        def get_num_finished_steps(self):
+            return self.state['nnue_optimizer']['num_finished_steps']
+
+
+
+        def freeze_parameter_region(self, param, indices):
+            if len(indices) != len(param.shape):
+                raise Exception('Invalid indices for parameter region.')
+
+            if param not in self.state:
+                raise Exception('No state for parameter.')
+
+            state = self.state[param]
+
+            if 'weight_mask' not in state:
+                raise Exception('Parameter not masked.')
+
+            state['weight_mask'][indices].fill_(0.0)
+
+
+
+        def step(self, closure=None):
+            loss = super(SpecificNnueOptimizer, self).step(closure)
+
+            for group in self.param_groups:
+                for p in group['params']:
+                    state = self.state[p] # get state dict for this param
+
+                    if not 'nnue_optimizer_initialized' in state:
+                        state['nnue_optimizer_initialized'] = True
+
+                        if group['maskable']:
+                            state['weight_mask'] = torch.ones_like(p)
+
+                    if 'weight_mask' in state:
+                        p.data.mul_(state['weight_mask'])
+
+            self.state['nnue_optimizer']['num_finished_steps'] += 1
+
+            if self.post_step_callback is not None:
+                self.post_step_callback(self)
+
+            return loss
+
+
+
+        def _add_default(self, default_name, default_value):
+            if default_name in self.defaults:
+                raise Exception('Default already exists.')
+
+            self.defaults[default_name] = default_value
+
+            for group in self.param_groups:
+                if default_value is required and not default_name in group:
+                    raise ValueError("parameter group didn't specify a value of required optimization parameter " + default_name)
+                else:
+                    group.setdefault(default_name, default_value)
+
+
+
+    return SpecificNnueOptimizer(params, **kwargs)
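
For context (not part of the patch): a minimal sketch of how the wrapper is meant to be driven. The toy Linear module and torch.optim.SGD below are illustrative stand-ins; in the patch the wrapper is built around ranger.Ranger inside configure_optimizers(), and post_optimizer_step() only calls freeze_parameter_region() when disable_ft_factorization_after_steps / disable_l1_factorization_after_steps are set to something other than -1 before training.

import torch
from optimizer import NnueOptimizer

layer = torch.nn.Linear(8, 4)  # stand-in for the real network
train_params = [
    # 'maskable' asks the wrapper to keep a per-parameter weight mask for this group
    {'params': [layer.weight], 'lr': 1e-3, 'maskable': True},
    {'params': [layer.bias], 'lr': 1e-3},
]
optimizer = NnueOptimizer(torch.optim.SGD, train_params, lr=1e-3)

def post_step(opt):
    # Runs after every step(). After 10 finished steps, pin rows 2..3 of the weight;
    # the patch zeroes the virtual-feature rows first, then keeps them at zero this way.
    if opt.get_num_finished_steps() == 10:
        opt.freeze_parameter_region(layer.weight, (slice(2, 4), slice(None, None)))

optimizer.set_post_step_callback(post_step)

for _ in range(20):
    optimizer.zero_grad()
    layer(torch.randn(16, 8)).pow(2).mean().backward()
    optimizer.step()  # applies the masks, bumps the step counter, then fires the callback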