
Commit

implemented initialization from jax weights for linear and atomistic readout
M-R-Schaefer committed Apr 26, 2024
1 parent cc23f3d commit 7f77443
Showing 2 changed files with 25 additions and 12 deletions.
13 changes: 10 additions & 3 deletions apax/nn/torch/layers/ntk_linear.py
@@ -1,17 +1,24 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np


 class NTKLinearT(nn.Module):
-    def __init__(self, units_in, units_out) -> None:
+    def __init__(self, units_in=None, units_out=None, params=None) -> None:
         super().__init__()

         self.bias_factor = 0.1
         # self.weight_factor = torch.sqrt(1.0 / dim_in)
+        if params:
+            w = torch.from_numpy(np.array(params["w"]).T)
+            b = torch.from_numpy(np.array(params["b"]))
+        else:
+            w = torch.rand((units_out, units_in))
+            b = torch.rand((units_out))

-        self.w = nn.Parameter(torch.rand((units_out, units_in)))
-        self.b = nn.Parameter(torch.rand((units_out)))
+        self.w = nn.Parameter(w)
+        self.b = nn.Parameter(b)
         self.one = torch.tensor(1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
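The new params path expects a JAX-style parameter dict with the kernel under "w" (stored in the JAX layout (units_in, units_out), hence the transpose) and the bias under "b". A minimal usage sketch, with placeholder shapes and random arrays standing in for weights exported from a trained apax JAX model:

import numpy as np
from apax.nn.torch.layers.ntk_linear import NTKLinearT

# Placeholder arrays; a real params dict would come from a trained JAX checkpoint.
params = {
    "w": np.random.randn(360, 512),  # JAX layout: (units_in, units_out)
    "b": np.random.randn(512),
}
layer = NTKLinearT(params=params)
print(layer.w.shape)  # torch.Size([512, 360]) after the transpose
print(layer.b.shape)  # torch.Size([512])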
24 changes: 15 additions & 9 deletions apax/nn/torch/layers/readout.py
@@ -1,25 +1,31 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Any, Callable, List
+from typing import Any, Callable, List, Optional

 from apax.nn.torch.layers.activation import SwishT
 from apax.nn.torch.layers.ntk_linear import NTKLinearT


 class AtomisticReadoutT(nn.Module):
     def __init__(
-        self, units: List[int] = [512, 512], activation_fn: Callable = SwishT
+        self, units: Optional[List[int]] = None, params_list=None, activation_fn: Callable = SwishT
     ) -> None:
         super().__init__()

-        units = [360] + [u for u in units] + [1]
         dense = []
-        for ii in range(len(units) - 1):
-            units_in, units_out = units[ii], units[ii + 1]
-            dense.append(NTKLinearT(units_in, units_out))
-            if ii < len(units) - 2:
-                dense.append(activation_fn())
+        if params_list:
+            for ii, params in enumerate(params_list):
+                dense.append(NTKLinearT(params=params))
+                if ii < len(params_list) - 1:
+                    dense.append(activation_fn())
+        else:
+            units = [360] + [u for u in units] + [1]
+            for ii in range(len(units) - 1):
+                units_in, units_out = units[ii], units[ii + 1]
+                dense.append(NTKLinearT(units_in, units_out))
+                if ii < len(units) - 2:
+                    dense.append(activation_fn())

         self.sequential = nn.Sequential(*dense)

     def forward(self, x):
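Both construction paths in a minimal sketch; the per-layer dicts and shapes are illustrative, assuming params_list carries one {"w", "b"} dict per linear layer as consumed by NTKLinearT above:

import numpy as np
from apax.nn.torch.layers.readout import AtomisticReadoutT

# Random initialization: hidden widths from `units`, input width fixed at 360.
readout = AtomisticReadoutT(units=[512, 512])

# Initialization from JAX weights: one dict per layer of the [360, 512, 512, 1] stack.
shapes = [(360, 512), (512, 512), (512, 1)]
params_list = [
    {"w": np.random.randn(n_in, n_out), "b": np.random.randn(n_out)}
    for n_in, n_out in shapes
]
readout_from_jax = AtomisticReadoutT(params_list=params_list)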
