
batch size for learning rate
FelixBenning committed May 12, 2024
1 parent 1a0fdf8 commit 6c5c3ab
Showing 4 changed files with 62 additions and 22 deletions.
benchmarking/classification/__main__.py (7 changes: 5 additions & 2 deletions)
@@ -45,7 +45,10 @@ def train_rfd(problem, cov_string, covariance_model):
     )
 
     classifier = Classifier(
-        problem["model"](), optimizer=RFD, covariance_model=covariance_model
+        problem["model"](),
+        optimizer=RFD,
+        covariance_model=covariance_model,
+        b_size_inv=1/problem["batch_size"],
     )
 
     trainer = trainer_from_problem(problem, opt_name=f"RFD({cov_string})")
@@ -62,7 +65,7 @@ def trainer_from_problem(problem, opt_name):
     tlogger.log_hyperparams(params={"batch_size": problem["batch_size"]})
     csvlogger.log_hyperparams(params={"batch_size": problem["batch_size"]})
     return L.Trainer(
-        max_epochs=2,
+        max_epochs=5,
         log_every_n_steps=1,
         logger=[tlogger, csvlogger],
     )
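For context, the new keyword forwards the inverse of the configured batch size, so the optimizer's default of b_size_inv=0 keeps the original, noise-free behaviour. A minimal sketch (the problem dict and the batch size of 128 below are illustrative assumptions, not values from the repo):

    # Illustrative only: `problem` is a plain stand-in dict here; the benchmark's
    # real problem definition is not part of this diff.
    problem = {"batch_size": 128}

    # The commit passes the *inverse* batch size, so 0 can mean "no mini-batch noise"
    # (the full-batch limit) and values in (0, 1] encode increasing sampling noise.
    b_size_inv = 1 / problem["batch_size"]
    assert 0 < b_size_inv <= 1  # mirrors the assertion in pyrfd/covariance.py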
benchmarking/classification/classifier.py (28 changes: 22 additions & 6 deletions)
@@ -28,27 +28,43 @@ def training_step(self, batch, *args, **kwargs):
         x_in, y_out = batch
         prediction: torch.Tensor = self.model(x_in)
         loss_value = F.nll_loss(prediction, y_out)
-        self.log("train_loss", loss_value, prog_bar=True)
+        self.log("train/loss", loss_value, prog_bar=True)
         acc = metrics.multiclass_accuracy(prediction, y_out, prediction.size(dim=1))
-        self.log("train_accuracy", acc, on_epoch=True)
+        self.log("train/accuracy", acc, on_epoch=True)
         return loss_value
 
+    def validation_step(self, batch):
+        """Apply Negative Log-Likelihood loss to the model output and logging"""
+        x_in, y_out = batch
+        prediction: torch.Tensor = self.model(x_in)
+        loss_value = F.nll_loss(prediction, y_out)
+        self.log("val/loss", loss_value, on_epoch=True)
+
+        acc = metrics.multiclass_accuracy(prediction, y_out, prediction.size(dim=1))
+        self.log("val/accuracy", acc, on_epoch=True)
+
+        rec = metrics.multiclass_recall(prediction, y_out, prediction.size(dim=1))
+        self.log("val/recall", rec, on_epoch=True)
+
+        prec = metrics.multiclass_precision(prediction, y_out, prediction.size(dim=1))
+        self.log("val/precision", prec, on_epoch=True)
+
     # pylint: disable=arguments-differ
     def test_step(self, batch):
         """Apply Negative Log-Likelihood loss to the model output and logging"""
         x_in, y_out = batch
         prediction: torch.Tensor = self.model(x_in)
         loss_value = F.nll_loss(prediction, y_out)
-        self.log("test_loss", loss_value, on_epoch=True)
+        self.log("test/loss", loss_value, on_epoch=True)
 
         acc = metrics.multiclass_accuracy(prediction, y_out, prediction.size(dim=1))
-        self.log("test_accuracy", acc, on_epoch=True)
+        self.log("test/accuracy", acc, on_epoch=True)
 
         rec = metrics.multiclass_recall(prediction, y_out, prediction.size(dim=1))
-        self.log("test_recall", rec, on_epoch=True)
+        self.log("test/recall", rec, on_epoch=True)
 
         prec = metrics.multiclass_precision(prediction, y_out, prediction.size(dim=1))
-        self.log("test_precision", prec, on_epoch=True)
+        self.log("test/precision", prec, on_epoch=True)
 
     def configure_optimizers(self):
         return self.optimizer(self.parameters(), **self.hyperparemeters)
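The b_size_inv argument added in __main__.py is not referenced explicitly in this file, which suggests it travels through the same keyword path as the other optimizer hyperparameters and is handed to the optimizer in configure_optimizers. A rough sketch of that forwarding path, assuming the Classifier stores extra keywords in self.hyperparemeters (the attribute name mirrors the line above; the stand-in class below is not the real LightningModule):

    import torch


    class TinyClassifier:
        """Stand-in for the real Classifier, reduced to the optimizer-forwarding logic."""

        def __init__(self, model, *, optimizer, **hyperparemeters):
            self.model = model
            self.optimizer = optimizer              # an optimizer *class*, e.g. pyrfd's RFD
            self.hyperparemeters = hyperparemeters  # e.g. covariance_model=..., b_size_inv=...

        def parameters(self):
            return self.model.parameters()

        def configure_optimizers(self):
            # Every extra keyword (including b_size_inv) reaches the optimizer constructor.
            return self.optimizer(self.parameters(), **self.hyperparemeters)


    clf = TinyClassifier(torch.nn.Linear(4, 2), optimizer=torch.optim.SGD, lr=0.1)  # SGD stands in for RFD
    opt = clf.configure_optimizers()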
pyrfd/covariance.py (33 changes: 22 additions & 11 deletions)
@@ -178,11 +178,11 @@ def __repr__(self) -> str:
         )
 
     @abstractmethod
-    def learning_rate(self, loss, grad_norm, b_size_inverse=0):
+    def learning_rate(self, loss, grad_norm, b_size_inv=0):
         """learning rate of this covariance model from the RFD paper"""
         return NotImplemented
 
-    def asymptotic_learning_rate(self, b_size_inverse=0, limiting_loss=0):
+    def asymptotic_learning_rate(self, b_size_inv=0, limiting_loss=0):
         """asymptotic learning rate of RFD
 
         b_size_inverse:
@@ -193,11 +193,11 @@ def asymptotic_learning_rate(self, b_size_inverse=0, limiting_loss=0):
         """
         assert self.fitted, "The covariance has not been fitted yet."
         assert (
-            b_size_inverse <= 1
+            b_size_inv <= 1
         ), "Please pass the batch size inverse 1/b not the batch size b"
-        enumerator = self.var_reg.predict(b_size_inverse)
+        enumerator = self.var_reg.predict(np.array(b_size_inv).reshape(-1, 1))[0]
         denominator = (
-            self.g_var_reg.predict(b_size_inverse)
+            self.g_var_reg.predict(np.array(b_size_inv).reshape(-1, 1))[0]
             / self.dims
             * (self.mean - limiting_loss)
         )
@@ -275,7 +275,9 @@ def auto_fit(
             different batch size parameters such that it returns (x,y) tuples when
             iterated on
         """
-        self.dims = sum(p.numel() for p in model_factory().parameters() if p.requires_grad)
+        self.dims = sum(
+            p.numel() for p in model_factory().parameters() if p.requires_grad
+        )
         print(f"\n\nAutomatically fitting Covariance Model: {repr(self)}")
 
         cached_samples = CachedSamples(cache)
@@ -390,16 +392,25 @@ def scale(self):
             "The covariance is not fitted yet, use `auto_fit` or `fit` before use"
         )
 
-    def learning_rate(self, loss, grad_norm, b_size_inverse=0):
+    def learning_rate(self, loss, grad_norm, b_size_inv=0):
         """RFD learning rate from Random Function Descent paper"""
 
+        var_reg = self.var_reg
+        g_var_reg = self.g_var_reg
+
         # C(0)/(C(0) + 1/b * C_eps(0))
-        var_adjust = self.var_reg.intercept_/self.var_reg.predict(b_size_inverse)
-        var_g_adjust = self.g_var_reg.intercept_/self.g_var_reg.predict(b_size_inverse)
+        var_adjust = var_reg.intercept_ / (
+            var_reg.intercept_ + var_reg.coef_[0] * b_size_inv
+        )
+        var_g_adjust = g_var_reg.intercept_ / (
+            g_var_reg.intercept_ + g_var_reg.coef_[0] * b_size_inv
+        )
 
         tmp = var_adjust * (self.mean - loss) / 2
-        return var_g_adjust * (self.scale**2) / (
-            torch.sqrt(tmp**2 + (self.scale * grad_norm * var_g_adjust) ** 2) + tmp
+        return (
+            var_g_adjust
+            * (self.scale**2)
+            / (torch.sqrt(tmp**2 + (self.scale * grad_norm * var_g_adjust) ** 2) + tmp)
         )


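The adjustment factors implement C(0) / (C(0) + C_eps(0)/b): var_reg and g_var_reg are linear regressions of the loss and gradient-norm variances against the inverse batch size, so intercept_ estimates the noise-free variance C(0) and coef_[0] the per-sample noise variance C_eps(0). With b_size_inv = 0 both factors are 1 and the original noise-free RFD step is recovered. A numeric sketch of the same algebra with plain floats (all constants below are made-up illustrative values, not fitted quantities from the repo):

    import math

    # Stand-ins for the fitted regression parameters and covariance attributes.
    c0, c_eps = 1.0, 2.0        # var_reg.intercept_, var_reg.coef_[0]
    gc0, gc_eps = 0.5, 4.0      # g_var_reg.intercept_, g_var_reg.coef_[0]
    scale, mean = 0.3, 2.3      # self.scale, self.mean
    loss, grad_norm = 1.8, 0.7  # current mini-batch loss and gradient norm


    def rfd_lr(b_size_inv: float) -> float:
        """Same algebra as the learning_rate method above, with floats instead of tensors."""
        var_adjust = c0 / (c0 + c_eps * b_size_inv)      # C(0) / (C(0) + C_eps(0)/b)
        var_g_adjust = gc0 / (gc0 + gc_eps * b_size_inv)
        tmp = var_adjust * (mean - loss) / 2
        root = math.sqrt(tmp**2 + (scale * grad_norm * var_g_adjust) ** 2)
        return var_g_adjust * scale**2 / (root + tmp)


    print(rfd_lr(0.0))     # b_size_inv = 0: the noise-free (full-batch) RFD step
    print(rfd_lr(1 / 32))  # positive 1/b: the noise corrections rescale the step (here it shrinks)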
pyrfd/optimizer.py (16 changes: 13 additions & 3 deletions)
@@ -16,14 +16,21 @@ class RFD(Optimizer):
     """
 
     def __init__(
-        self, params, *, covariance_model: IsotropicCovariance, momentum=0, lr=1
+        self,
+        params,
+        *,
+        covariance_model: IsotropicCovariance,
+        momentum=0,
+        lr=1,
+        b_size_inv=0,
     ):
         defaults = {
             "cov": covariance_model,
             "momentum": momentum,
             "lr": lr,  # really a learning rate multiplier,
             # but this name ensures compatibility with schedulers
             "learning_rate": None,
+            "b_size_inv": b_size_inv,
         }
         super().__init__(params, defaults)
@@ -52,10 +59,13 @@ def step(self, closure):  # pylint: disable=locally-disabled, signature-differs
             grad_norm = torch.cat(grads).norm()
 
             momentum = group["momentum"]
-            lr = group["lr"]
+            lr_multiplier = group["lr"]
+            b_size_inv = group["b_size_inv"]
 
             cov_model: IsotropicCovariance = group["cov"]
-            learning_rate = lr * cov_model.learning_rate(loss, grad_norm)
+            learning_rate = lr_multiplier * cov_model.learning_rate(
+                loss, grad_norm, b_size_inv=b_size_inv
+            )
             group["learning_rate"] = learning_rate
 
             for param in group["params"]:
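Taken together, the commit lets a user report the mini-batch size directly to the optimizer. A hypothetical usage sketch: the RFD keyword arguments are those shown in this diff, but the import paths, the concrete covariance class, and its fitting interface are assumptions and may differ in the actual package:

    import torch

    from pyrfd.covariance import SquaredExponential  # assumed concrete IsotropicCovariance
    from pyrfd.optimizer import RFD

    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(28 * 28, 10),
        torch.nn.LogSoftmax(dim=1),
    )

    cov_model = SquaredExponential()
    # cov_model.auto_fit(...)  # the covariance model has to be fitted before the first step

    batch_size = 128
    opt = RFD(
        model.parameters(),
        covariance_model=cov_model,
        b_size_inv=1 / batch_size,  # tells learning_rate() how noisy each mini-batch is
    )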
