Skip to content

Commit

Permalink
[Lint] Black linter
Browse files Browse the repository at this point in the history
  • Loading branch information
ephoris committed Jan 23, 2024
1 parent 6d5d465 commit bd9b38f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
6 changes: 5 additions & 1 deletion endure/lcm/model/kaplsm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

DECISION_DIM = 64


class KapEmbedding(nn.Module):
"""
Special embedding that creates separate embeddings for each K_i on each
level. Number of k's will dictate the number of linear layers.
"""

def __init__(
self,
input_size: int,
Expand All @@ -37,6 +39,7 @@ def forward(self, x: Tensor) -> Tensor:

return out


class KapModel(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -90,7 +93,8 @@ def _split_input(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
capacities = capacities.to(torch.long)
capacities = F.one_hot(capacities, num_classes=self.capacity_range)
else:
capacities = torch.unflatten(capacities, 1, (-1, self.capacity_range))
capacities = torch.unflatten(
capacities, 1, (-1, self.capacity_range))

size_ratio = capacities[:, 0, :]
k_cap = capacities[:, 1:, :]
Expand Down
7 changes: 5 additions & 2 deletions endure/lsm/solver/klsm_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .util import kl_div_con, get_bounds
from .util import H_DEFAULT, T_DEFAULT, LAMBDA_DEFAULT, ETA_DEFAULT, K_DEFAULT


class KLSMSolver:
def __init__(self, config: dict[str, Any]):
self.config = config
Expand All @@ -28,8 +29,10 @@ def robust_objective(
lamb, eta = x[-2:]
design = LSMDesign(h=h, T=t, K=kaps, policy=Policy.KHybrid)
query_cost = 0
query_cost += z0 * kl_div_con((self.cf.Z0(design, system) - eta) / lamb)
query_cost += z1 * kl_div_con((self.cf.Z1(design, system) - eta) / lamb)
query_cost += z0 * \
kl_div_con((self.cf.Z0(design, system) - eta) / lamb)
query_cost += z1 * \
kl_div_con((self.cf.Z1(design, system) - eta) / lamb)
query_cost += q * kl_div_con((self.cf.Q(design, system) - eta) / lamb)
query_cost += w * kl_div_con((self.cf.W(design, system) - eta) / lamb)
cost = eta + (rho * lamb) + (lamb * query_cost)
Expand Down

0 comments on commit bd9b38f

Please sign in to comment.