feat: allow other optimizer parameters
eduardocarvp committed Jun 12, 2020
1 parent 574ad66 commit 16d92d5
Showing 6 changed files with 18 additions and 12 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -102,9 +102,7 @@ You can also get comfortable with how the code works by playing with the **noteb
- momentum : float

Momentum for batch normalization, typically ranges from 0.01 to 0.4 (default=0.02)
- lr : float (default = 0.02)

Initial learning rate used for training. As mentioned in the original paper, a large initial learning rate of ```0.02``` with decay is a good option.
- clip_value : float (default None)

If a float is given this will clip the gradient at clip_value.
@@ -116,6 +114,10 @@ You can also get comfortable with how the code works by playing with the **noteb

Pytorch optimizer function

- optimizer_params : dict (default=dict(lr=2e-2))

Parameters compatible with optimizer_fn, used to initialize the optimizer. Since Adam is the default optimizer, this is where the initial learning rate used for training is set. As mentioned in the original paper, a large initial learning rate of ```0.02``` with decay is a good option (a short usage sketch follows below).

- scheduler_fn : torch.optim.lr_scheduler (default=None)

Pytorch Scheduler to change learning rates during training.
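
The new optimizer_fn / optimizer_params pair replaces the single lr argument. Below is a minimal sketch of how they combine, assuming pytorch_tabnet is installed; the switch to SGD and its momentum value are illustrative assumptions, not part of this commit:

```python
# Sketch: any torch optimizer class can be paired with its own keyword arguments.
# The SGD/momentum choice below is an assumption for illustration only.
import torch
from pytorch_tabnet.tab_model import TabNetClassifier

clf = TabNetClassifier(
    optimizer_fn=torch.optim.SGD,                  # any torch.optim optimizer class
    optimizer_params=dict(lr=2e-2, momentum=0.9),  # forwarded to the optimizer when fit builds it
)
```

Keeping the extra settings in a dict means the TabNet constructor stays independent of any particular optimizer's signature.
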
11 changes: 7 additions & 4 deletions census_example.ipynb
@@ -147,8 +147,11 @@
"metadata": {},
"outputs": [],
"source": [
"clf = TabNetClassifier(cat_idxs=cat_idxs, cat_dims=cat_dims,\n",
" cat_emb_dim=1)"
"clf = TabNetClassifier(cat_idxs=cat_idxs,\n",
" cat_dims=cat_dims,\n",
" cat_emb_dim=1,\n",
" optimizer_fn=torch.optim.Adam,\n",
" optimizer_params=dict(lr=2e-2))"
]
},
{
@@ -180,14 +183,14 @@
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 20"
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [],
"source": [
2 changes: 1 addition & 1 deletion forest_example.ipynb
@@ -190,13 +190,13 @@
"source": [
"clf = TabNetClassifier(\n",
" n_d=64, n_a=64, n_steps=5,\n",
" lr=0.02,\n",
" gamma=1.5, n_independent=2, n_shared=2,\n",
" cat_idxs=cat_idxs,\n",
" cat_dims=cat_dims,\n",
" cat_emb_dim=1,\n",
" lambda_sparse=1e-4, momentum=0.3, clip_value=2.,\n",
" optimizer_fn=torch.optim.Adam,\n",
" optimizer_params=dict(lr=2e-2),\n",
" scheduler_params = {\"gamma\": 0.95,\n",
" \"step_size\": 20},\n",
" scheduler_fn=torch.optim.lr_scheduler.StepLR, epsilon=1e-15\n",
2 changes: 1 addition & 1 deletion multi_regression_example.ipynb
@@ -188,7 +188,7 @@
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 20"
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
]
},
{
7 changes: 4 additions & 3 deletions pytorch_tabnet/tab_model.py
@@ -20,7 +20,8 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[],
n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02,
lambda_sparse=1e-3, seed=0,
clip_value=1, verbose=1,
lr=2e-2, optimizer_fn=torch.optim.Adam,
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
scheduler_params=None, scheduler_fn=None,
device_name='auto'):
""" Class for TabNet model
@@ -45,8 +46,8 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[],
self.lambda_sparse = lambda_sparse
self.clip_value = clip_value
self.verbose = verbose
self.lr = lr
self.optimizer_fn = optimizer_fn
self.optimizer_params = optimizer_params
self.device_name = device_name
self.scheduler_params = scheduler_params
self.scheduler_fn = scheduler_fn
@@ -140,7 +141,7 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
self.network.post_embed_dim)

self.optimizer = self.optimizer_fn(self.network.parameters(),
lr=self.lr)
**self.optimizer_params)

if self.scheduler_fn:
self.scheduler = self.scheduler_fn(self.optimizer, **self.scheduler_params)
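
The fit-time change above simply unpacks the stored dict when the optimizer is built. A standalone sketch of the same pattern, with a placeholder module standing in for self.network:

```python
import torch
import torch.nn as nn

# Placeholder module standing in for self.network; only the unpacking pattern matters here.
network = nn.Linear(10, 2)
optimizer_fn = torch.optim.Adam
optimizer_params = dict(lr=2e-2)

# Mirrors: self.optimizer = self.optimizer_fn(self.network.parameters(), **self.optimizer_params)
optimizer = optimizer_fn(network.parameters(), **optimizer_params)
```
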
2 changes: 1 addition & 1 deletion regression_example.ipynb
@@ -176,7 +176,7 @@
"metadata": {},
"outputs": [],
"source": [
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 20"
"max_epochs = 1000 if not os.getenv(\"CI\", False) else 2"
]
},
{
