Add a way to disable factorization #137

Open · wants to merge 2 commits into master
41 changes: 39 additions & 2 deletions model.py
@@ -5,6 +5,7 @@
import pytorch_lightning as pl
import copy
from feature_transformer import DoubleFeatureTransformerSlice
from optimizer import NnueOptimizer

# 3 layer fully connected network
L1 = 512
@@ -95,6 +96,14 @@ def forward(self, x, ls_indices):

        return l3x_

    def disable_l1_factorization(self, optimizer):
        with torch.no_grad():
            for i in range(self.count):
                self.l1.weight[i*L2:(i+1)*L2, :].add_(self.l1_fact.weight.data)
            self.l1_fact.weight.data.fill_(0.0)
            self.l1_fact.weight.requires_grad = False


    def get_coalesced_layer_stacks(self):
        for i in range(self.count):
            with torch.no_grad():
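A quick standalone check of the fold above (not part of the diff; the sizes are made up, and it assumes, as the fold implies, that l1_fact's output is simply added on top of every bucket's slice of l1): adding l1_fact.weight into a bucket's rows of l1.weight and then zeroing the factorizer leaves that bucket's output unchanged.

import torch

# Illustrative sizes only; the real L1/L2/count come from model.py.
in_dim, out_dim, count = 8, 4, 3
l1 = torch.nn.Linear(in_dim, out_dim * count)
l1_fact = torch.nn.Linear(in_dim, out_dim, bias=False)

x = torch.randn(5, in_dim)
bucket = 1
rows = slice(bucket * out_dim, (bucket + 1) * out_dim)

# One bucket's output while factorization is still enabled.
before = x @ l1.weight[rows].t() + l1.bias[rows] + l1_fact(x)

with torch.no_grad():
    l1.weight[rows].add_(l1_fact.weight)   # the fold done in disable_l1_factorization
    l1_fact.weight.fill_(0.0)

# Same bucket after the fold; the factorizer now contributes nothing.
after = x @ l1.weight[rows].t() + l1.bias[rows] + l1_fact(x)
assert torch.allclose(before, after, atol=1e-6)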
@@ -128,6 +137,9 @@ def __init__(self, feature_set, lambda_=1.0):
        self.layer_stacks = LayerStacks(self.num_ls_buckets)
        self.lambda_ = lambda_

        self.disable_ft_factorization_after_steps = -1
        self.disable_l1_factorization_after_steps = -1

        self.weight_clipping = [
            {'params' : [self.layer_stacks.l1.weight], 'min_weight' : -127/64, 'max_weight' : 127/64, 'virtual_params' : self.layer_stacks.l1_fact.weight },
            {'params' : [self.layer_stacks.l2.weight], 'min_weight' : -127/64, 'max_weight' : 127/64 },
@@ -293,11 +305,35 @@ def validation_step(self, batch, batch_idx):
    def test_step(self, batch, batch_idx):
        self.step_(batch, batch_idx, 'test_loss')

    def post_optimizer_step(self, optimizer):
        num_finished_steps = optimizer.get_num_finished_steps()

        if num_finished_steps == self.disable_l1_factorization_after_steps:
            self.disable_l1_factorization(optimizer)

        if num_finished_steps == self.disable_ft_factorization_after_steps:
            self.disable_ft_factorization(optimizer)

    def disable_l1_factorization(self, optimizer):
        self.layer_stacks.disable_l1_factorization(optimizer)

    def disable_ft_factorization(self, optimizer):
        with torch.no_grad():
            weight = self.input.weight.data
            indices = self.feature_set.get_virtual_to_real_features_gather_indices()
            for i_real, is_virtual in enumerate(indices):
                weight[i_real, :] = sum(weight[i_virtual, :] for i_virtual in is_virtual)

            for a, b in self.feature_set.get_virtual_feature_ranges():
                weight[a:b, :].fill_(0.0)
                optimizer.freeze_parameter_region(self.input.weight, (slice(a, b), slice(None, None)))

    def configure_optimizers(self):
        # Train with a lower LR on the output layer
        LR = 1.5e-3
        train_params = [
            {'params' : get_parameters([self.input]), 'lr' : LR, 'gc_dim' : 0 },
            {'params' : [self.input.weight], 'lr' : LR, 'gc_dim' : 0, 'maskable' : (self.disable_ft_factorization_after_steps != -1) },
            {'params' : [self.input.bias], 'lr' : LR },
            {'params' : [self.layer_stacks.l1_fact.weight], 'lr' : LR },
            {'params' : [self.layer_stacks.l1.weight], 'lr' : LR },
            {'params' : [self.layer_stacks.l1.bias], 'lr' : LR },
@@ -307,7 +343,8 @@ def configure_optimizers(self):
            {'params' : [self.layer_stacks.output.bias], 'lr' : LR / 10 },
        ]
        # increasing the eps leads to less saturated nets with a few dead neurons
        optimizer = ranger.Ranger(train_params, betas=(.9, 0.999), eps=1.0e-7, gc_loc=False, use_gc=False)
        optimizer = NnueOptimizer(ranger.Ranger, train_params, betas=(.9, 0.999), eps=1.0e-7, gc_loc=False, use_gc=False)
        optimizer.set_post_step_callback(lambda opt: self.post_optimizer_step(opt))
        # Drop learning rate after 75 epochs
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.987)
        return [optimizer], [scheduler]
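For anyone trying this out, a minimal sketch of how the new switches could be driven from a training script (the step count and the construction arguments are illustrative assumptions, not taken from this diff). Both attributes default to -1, meaning factorization stays enabled for the whole run, and disable_ft_factorization_after_steps has to be set before configure_optimizers runs, since that is where the feature-transformer weights are flagged as maskable.

import model as M

# `features` is whatever feature set object the training script already builds.
nnue = M.NNUE(feature_set=features, lambda_=1.0)

# Hypothetical schedule: train with factorization for the first 20k optimizer
# steps, then fold the virtual weights in and freeze them.
nnue.disable_ft_factorization_after_steps = 20000
nnue.disable_l1_factorization_after_steps = 20000

# ... hand `nnue` to the trainer as usual; post_optimizer_step performs the
# actual disabling once the optimizer reports the matching step count.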
81 changes: 81 additions & 0 deletions optimizer.py
@@ -0,0 +1,81 @@
from torch.optim.optimizer import Optimizer, required
import torch

def NnueOptimizer(optimizer_cls, params, **kwargs):
    class SpecificNnueOptimizer(optimizer_cls):

[Member Author, commenting on the line above] if anyone has an idea how to do this part better I'd be happy to know

        def __init__(self, params, **kwargs):
            super().__init__(params, **kwargs)

            self.state['nnue_optimizer']['num_finished_steps'] = 0
            self.post_step_callback = None

            self._add_default('maskable', False)



        def set_post_step_callback(self, callback):
            self.post_step_callback = callback



        def get_num_finished_steps(self):
            return self.state['nnue_optimizer']['num_finished_steps']



        def freeze_parameter_region(self, param, indices):
            if len(indices) != len(param.shape):
                raise Exception('Invalid indices for parameter region.')

            if param not in self.state:
                raise Exception('No state for parameter.')

            state = self.state[param]

            if 'weight_mask' not in state:
                raise Exception('Parameter not masked.')

            state['weight_mask'][indices].fill_(0.0)



        def step(self, closure=None):
            loss = super(SpecificNnueOptimizer, self).step(closure)

            for group in self.param_groups:
                for p in group['params']:
                    state = self.state[p]  # get state dict for this param

                    if 'nnue_optimizer_initialized' not in state:
                        state['nnue_optimizer_initialized'] = True

                        if group['maskable']:
                            state['weight_mask'] = torch.ones_like(p)

                    if 'weight_mask' in state:
                        p.data.mul_(state['weight_mask'])

            self.state['nnue_optimizer']['num_finished_steps'] += 1

            if self.post_step_callback is not None:
                self.post_step_callback(self)

            return loss



        def _add_default(self, default_name, default_value):
            if default_name in self.defaults:
                raise Exception('Default already exists.')

            self.defaults[default_name] = default_value

            for group in self.param_groups:
                if default_value is required and default_name not in group:
                    raise ValueError("parameter group didn't specify a value of required optimization parameter " + default_name)
                else:
                    group.setdefault(default_name, default_value)



    return SpecificNnueOptimizer(params, **kwargs)
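And a self-contained sketch of the wrapper on its own, outside the chess trainer (plain SGD stands in for Ranger here; the layer sizes and the frozen slice are made up). It exercises the three pieces the PR relies on: the 'maskable' group flag, the lazily created weight mask, and freeze_parameter_region zeroing a region of the weights after every subsequent step.

import torch
from optimizer import NnueOptimizer

layer = torch.nn.Linear(4, 4)
params = [
    {'params': [layer.weight], 'lr': 1e-2, 'maskable': True},
    {'params': [layer.bias], 'lr': 1e-2},
]
opt = NnueOptimizer(torch.optim.SGD, params, lr=1e-2)
opt.set_post_step_callback(lambda o: print('finished steps:', o.get_num_finished_steps()))

# First step: the mask for layer.weight is created lazily inside step().
layer(torch.randn(8, 4)).pow(2).mean().backward()
opt.step()

# Freeze the first two output rows; they are zeroed after every later step,
# which is how the virtual feature ranges get frozen in model.py.
opt.freeze_parameter_region(layer.weight, (slice(0, 2), slice(None, None)))

layer(torch.randn(8, 4)).pow(2).mean().backward()
opt.step()
assert torch.count_nonzero(layer.weight[:2]) == 0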