Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Refactor links and distributions #20

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/rolch/abc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from abc import ABC, abstractmethod
from typing import Tuple, Union
from typing import Dict, List, Tuple, Union

import numpy as np


class LinkFunction(ABC):
"""The base class for the link functions."""

_valid_structures: List[str]

@abstractmethod
def link(self, x: np.ndarray) -> np.ndarray:
"""Calculate the Link"""
Expand All @@ -22,6 +24,18 @@ def derivative(self, x: np.ndarray) -> np.ndarray:

class Distribution(ABC):

links: Dict[int, LinkFunction] | List[LinkFunction]
_param_structure: Dict[int, str]

def _check_links(self):
for p in range(self.n_params):
if self.param_structure[p] not in self.links[p]._valid_structures:
raise ValueError(
f"Link function does not match parameter structure for parameter {p}. \n"
f"Parameter structure is {self.param_structure[p]}. \n"
f"Link function supports {self.links[p]._valid_structures}"
)

@abstractmethod
def theta_to_params(self, theta: np.ndarray) -> Tuple:
"""Take the fitted values and return tuple of vectors for distribution parameters."""
Expand Down
2 changes: 2 additions & 0 deletions src/rolch/distributions/johnsonsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def __init__(
self.shape_link, # skew
self.tail_link, # tail
]
self._param_structure = {0: "vector", 1: "vector", 2: "vector", 3: "vector"}
self._check_links()

def theta_to_params(self, theta):
mu = theta[:, 0]
Expand Down
2 changes: 2 additions & 0 deletions src/rolch/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def __init__(self, loc_link=IdentityLink(), scale_link=LogLink()):
self.loc_link = loc_link
self.scale_link = scale_link
self.links = [self.loc_link, self.scale_link]
self._param_structure = {0: "vector", 1: "vector"}
self._check_links()

def theta_to_params(self, theta):
mu = theta[:, 0]
Expand Down
2 changes: 2 additions & 0 deletions src/rolch/distributions/studentt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def __init__(
self.scale_link = scale_link
self.tail_link = tail_link
self.links = [self.loc_link, self.scale_link, self.tail_link]
self._param_structure = {0: "vector", 1: "vector", 2: "vector"}
self._check_links()

def theta_to_params(self, theta):
mu = theta[:, 0]
Expand Down
10 changes: 6 additions & 4 deletions src/rolch/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LogLink(LinkFunction):
"""

def __init__(self):
pass
self._valid_structures = ["vector", "matrix", "square_matrix"]

def link(self, x):
return np.log(np.fmax(x, LOG_LOWER_BOUND))
Expand All @@ -38,7 +38,7 @@ class IdentityLink(LinkFunction):
"""

def __init__(self):
pass
self._valid_structures = ["vector", "matrix", "square_matrix"]

def link(self, x):
return x
Expand All @@ -62,6 +62,7 @@ class LogShiftValueLink(LinkFunction):

def __init__(self, value):
self.value = value
self._valid_structures = ["vector", "matrix", "square_matrix"]

def link(self, x):
return np.log(x - self.value + LOG_LOWER_BOUND)
Expand Down Expand Up @@ -97,7 +98,7 @@ class SqrtLink(LinkFunction):
"""

def __init__(self):
pass
self._valid_structures = ["vector", "matrix", "square_matrix"]

def link(self, x):
return np.sqrt(x)
Expand All @@ -121,6 +122,7 @@ class SqrtShiftValueLink(LinkFunction):

def __init__(self, value):
self.value = value
self._valid_structures = ["vector", "matrix", "square_matrix"]

def link(self, x):
return np.sqrt(x - self.value + LOG_LOWER_BOUND)
Expand Down Expand Up @@ -156,7 +158,7 @@ class LogIdentLink(LinkFunction):
"""

def __init__(self):
pass
self._valid_structures = ["vector", "matrix", "square_matrix"]

def link(self, x: np.ndarray):
return np.where(x <= 1, np.log(x), x - 1)
Expand Down
Loading