From 31e191b3201b03ae572f66e66a6ec3cce73bb084 Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 17:08:46 +0200 Subject: [PATCH 1/7] Add valid structures attribute to link functions --- src/rolch/abc.py | 4 +++- src/rolch/link.py | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/rolch/abc.py b/src/rolch/abc.py index f725d37..d36b60f 100644 --- a/src/rolch/abc.py +++ b/src/rolch/abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Tuple, Union +from typing import List, Tuple, Union import numpy as np @@ -7,6 +7,8 @@ class LinkFunction(ABC): """The base class for the link functions.""" + self._valid_structures: List[str] + @abstractmethod def link(self, x: np.ndarray) -> np.ndarray: """Calculate the Link""" diff --git a/src/rolch/link.py b/src/rolch/link.py index d35f8c4..d6e569f 100644 --- a/src/rolch/link.py +++ b/src/rolch/link.py @@ -15,7 +15,7 @@ class LogLink(LinkFunction): """ def __init__(self): - pass + self._valid_structures = ["vector", "matrix", "square_matrix"] def link(self, x): return np.log(np.fmax(x, LOG_LOWER_BOUND)) @@ -38,7 +38,7 @@ class IdentityLink(LinkFunction): """ def __init__(self): - pass + self._valid_structures = ["vector", "matrix", "square_matrix"] def link(self, x): return x @@ -62,6 +62,7 @@ class LogShiftValueLink(LinkFunction): def __init__(self, value): self.value = value + self._valid_structures = ["vector", "matrix", "square_matrix"] def link(self, x): return np.log(x - self.value + LOG_LOWER_BOUND) @@ -97,7 +98,7 @@ class SqrtLink(LinkFunction): """ def __init__(self): - pass + self._valid_structures = ["vector", "matrix", "square_matrix"] def link(self, x): return np.sqrt(x) @@ -121,6 +122,7 @@ class SqrtShiftValueLink(LinkFunction): def __init__(self, value): self.value = value + self._valid_structures = ["vector", "matrix", "square_matrix"] def link(self, x): return np.sqrt(x - self.value + LOG_LOWER_BOUND) @@ -156,7 +158,7 @@ class LogIdentLink(LinkFunction): """ def __init__(self): - pass + self._valid_structures = ["vector", "matrix", "square_matrix"] def link(self, x: np.ndarray): return np.where(x <= 1, np.log(x), x - 1) From e0cf3f85d11591978d4c4aced4cc2226ed357a08 Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 17:14:01 +0200 Subject: [PATCH 2/7] Add param structure to distributions --- src/rolch/abc.py | 12 ++++++++++++ src/rolch/distributions/johnsonsu.py | 2 ++ src/rolch/distributions/normal.py | 2 ++ src/rolch/distributions/studentt.py | 3 +++ 4 files changed, 19 insertions(+) diff --git a/src/rolch/abc.py b/src/rolch/abc.py index d36b60f..84f018c 100644 --- a/src/rolch/abc.py +++ b/src/rolch/abc.py @@ -24,6 +24,18 @@ def derivative(self, x: np.ndarray) -> np.ndarray: class Distribution(ABC): + self.links: Dict[int, Linkfunction] | List[Linkfunction] + self._param_structure: Dict[int, str] + + def _check_links(self): + for p in range(self.n_params): + if self.param_structure[p] not in self.links[p]._valid_structures: + raise ValueError( + f"Link function does not match parameter structure for parameter {p}. \n" + f"Parameter structure is {self.param_structure[p]}. \n" + f"Link function supports {self.links[p]._valid_structures}" + ) + @abstractmethod def theta_to_params(self, theta: np.ndarray) -> Tuple: """Take the fitted values and return tuple of vectors for distribution parameters.""" diff --git a/src/rolch/distributions/johnsonsu.py b/src/rolch/distributions/johnsonsu.py index 26644a8..371c7a1 100644 --- a/src/rolch/distributions/johnsonsu.py +++ b/src/rolch/distributions/johnsonsu.py @@ -34,6 +34,8 @@ def __init__( self.shape_link, # skew self.tail_link, # tail ] + self._param_structure = {0: "vector", 1: "vector", 2: "vector", 3: "vector"} + self._check_links() def theta_to_params(self, theta): mu = theta[:, 0] diff --git a/src/rolch/distributions/normal.py b/src/rolch/distributions/normal.py index 2ebe6f5..d7e55a0 100644 --- a/src/rolch/distributions/normal.py +++ b/src/rolch/distributions/normal.py @@ -13,6 +13,8 @@ def __init__(self, loc_link=IdentityLink(), scale_link=LogLink()): self.loc_link = loc_link self.scale_link = scale_link self.links = [self.loc_link, self.scale_link] + self._param_structure = {0: "vector", 1: "vector", 2} + self._check_links() def theta_to_params(self, theta): mu = theta[:, 0] diff --git a/src/rolch/distributions/studentt.py b/src/rolch/distributions/studentt.py index d54a012..c48a0d5 100644 --- a/src/rolch/distributions/studentt.py +++ b/src/rolch/distributions/studentt.py @@ -17,6 +17,9 @@ def __init__( self.scale_link = scale_link self.tail_link = tail_link self.links = [self.loc_link, self.scale_link, self.tail_link] + self._param_structure = {0: "vector", 1: "vector", 2, 3: "vector"} + self._check_links() + def theta_to_params(self, theta): mu = theta[:, 0] From d0d95c20b3355b8c32e46e2dc83988f8d729ed9f Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 17:21:39 +0200 Subject: [PATCH 3/7] Remove self for abstract base class attributes --- src/rolch/abc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rolch/abc.py b/src/rolch/abc.py index 84f018c..332217f 100644 --- a/src/rolch/abc.py +++ b/src/rolch/abc.py @@ -7,7 +7,7 @@ class LinkFunction(ABC): """The base class for the link functions.""" - self._valid_structures: List[str] + _valid_structures: List[str] @abstractmethod def link(self, x: np.ndarray) -> np.ndarray: @@ -24,8 +24,8 @@ def derivative(self, x: np.ndarray) -> np.ndarray: class Distribution(ABC): - self.links: Dict[int, Linkfunction] | List[Linkfunction] - self._param_structure: Dict[int, str] + links: Dict[int, Linkfunction] | List[Linkfunction] + _param_structure: Dict[int, str] def _check_links(self): for p in range(self.n_params): From bacfc2633c8e2eb903e30db3635f65bf38f40473 Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 17:23:02 +0200 Subject: [PATCH 4/7] Add Dict to typing imports --- src/rolch/abc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rolch/abc.py b/src/rolch/abc.py index 332217f..eabdcee 100644 --- a/src/rolch/abc.py +++ b/src/rolch/abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union import numpy as np From 3c7e2f43bb80fbe56890d0b9c5d611a09a6dfca4 Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 17:24:08 +0200 Subject: [PATCH 5/7] Fix typo --- src/rolch/abc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rolch/abc.py b/src/rolch/abc.py index eabdcee..b28c29f 100644 --- a/src/rolch/abc.py +++ b/src/rolch/abc.py @@ -24,7 +24,7 @@ def derivative(self, x: np.ndarray) -> np.ndarray: class Distribution(ABC): - links: Dict[int, Linkfunction] | List[Linkfunction] + links: Dict[int, LinkFunction] | List[LinkFunction] _param_structure: Dict[int, str] def _check_links(self): From b0e932377e9ab52d6576cb3e37be37c54e0ffc7f Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 17:25:46 +0200 Subject: [PATCH 6/7] Fix bugs --- src/rolch/distributions/johnsonsu.py | 2 +- src/rolch/distributions/normal.py | 2 +- src/rolch/distributions/studentt.py | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/rolch/distributions/johnsonsu.py b/src/rolch/distributions/johnsonsu.py index 371c7a1..a7ae4d7 100644 --- a/src/rolch/distributions/johnsonsu.py +++ b/src/rolch/distributions/johnsonsu.py @@ -34,7 +34,7 @@ def __init__( self.shape_link, # skew self.tail_link, # tail ] - self._param_structure = {0: "vector", 1: "vector", 2: "vector", 3: "vector"} + self._param_structure = {0: "vector", 1: "vector": 2: "vector", 3: "vector"} self._check_links() def theta_to_params(self, theta): diff --git a/src/rolch/distributions/normal.py b/src/rolch/distributions/normal.py index d7e55a0..ea209e4 100644 --- a/src/rolch/distributions/normal.py +++ b/src/rolch/distributions/normal.py @@ -13,7 +13,7 @@ def __init__(self, loc_link=IdentityLink(), scale_link=LogLink()): self.loc_link = loc_link self.scale_link = scale_link self.links = [self.loc_link, self.scale_link] - self._param_structure = {0: "vector", 1: "vector", 2} + self._param_structure = {0: "vector", 1: "vector"} self._check_links() def theta_to_params(self, theta): diff --git a/src/rolch/distributions/studentt.py b/src/rolch/distributions/studentt.py index c48a0d5..8b67560 100644 --- a/src/rolch/distributions/studentt.py +++ b/src/rolch/distributions/studentt.py @@ -17,7 +17,11 @@ def __init__( self.scale_link = scale_link self.tail_link = tail_link self.links = [self.loc_link, self.scale_link, self.tail_link] - self._param_structure = {0: "vector", 1: "vector", 2, 3: "vector"} + self._param_structure = { + 0: "vector", + 1: "vector": 2, + 3: "vector" + } self._check_links() From a12cb8ba18c1320e99ad0e677da4e3a9369a4421 Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 17:27:17 +0200 Subject: [PATCH 7/7] Fix some more bugs --- src/rolch/distributions/johnsonsu.py | 2 +- src/rolch/distributions/studentt.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/rolch/distributions/johnsonsu.py b/src/rolch/distributions/johnsonsu.py index a7ae4d7..371c7a1 100644 --- a/src/rolch/distributions/johnsonsu.py +++ b/src/rolch/distributions/johnsonsu.py @@ -34,7 +34,7 @@ def __init__( self.shape_link, # skew self.tail_link, # tail ] - self._param_structure = {0: "vector", 1: "vector": 2: "vector", 3: "vector"} + self._param_structure = {0: "vector", 1: "vector", 2: "vector", 3: "vector"} self._check_links() def theta_to_params(self, theta): diff --git a/src/rolch/distributions/studentt.py b/src/rolch/distributions/studentt.py index 8b67560..9628e2e 100644 --- a/src/rolch/distributions/studentt.py +++ b/src/rolch/distributions/studentt.py @@ -17,14 +17,9 @@ def __init__( self.scale_link = scale_link self.tail_link = tail_link self.links = [self.loc_link, self.scale_link, self.tail_link] - self._param_structure = { - 0: "vector", - 1: "vector": 2, - 3: "vector" - } + self._param_structure = {0: "vector", 1: "vector", 2: "vector"} self._check_links() - def theta_to_params(self, theta): mu = theta[:, 0] sigma = theta[:, 1]