|
7 | 7 | """
|
8 | 8 |
|
9 | 9 | import numpy as np
|
| 10 | +import pandas as pd |
10 | 11 | import xspline
|
11 | 12 | from numpy.typing import NDArray
|
12 | 13 |
|
@@ -451,7 +452,7 @@ def create_spline(
|
451 | 452 |
|
452 | 453 | Returns
|
453 | 454 | -------
|
454 |
| - xspline.XSpline |
| 455 | + XSpline |
455 | 456 | The spline object.
|
456 | 457 |
|
457 | 458 | """
|
@@ -535,7 +536,7 @@ def create_design_mat(self, data) -> tuple[NDArray, NDArray]:
|
535 | 536 |
|
536 | 537 | Returns
|
537 | 538 | -------
|
538 |
| - tuple[numpy.ndarray, numpy.ndarray] |
| 539 | + tuple[NDArray, NDArray] |
539 | 540 | Return the design matrix for linear cov or spline.
|
540 | 541 |
|
541 | 542 | """
|
@@ -832,7 +833,7 @@ def create_z_mat(self, data):
|
832 | 833 |
|
833 | 834 | Returns
|
834 | 835 | -------
|
835 |
| - numpy.ndarray |
| 836 | + NDArray |
836 | 837 | Design matrix for random effects.
|
837 | 838 |
|
838 | 839 | """
|
@@ -884,7 +885,7 @@ def create_z_mat(self, data):
|
884 | 885 |
|
885 | 886 | Returns
|
886 | 887 | -------
|
887 |
| - numpy.ndarray |
| 888 | + NDArray |
888 | 889 | Design matrix for random effects.
|
889 | 890 |
|
890 | 891 | """
|
@@ -929,3 +930,110 @@ def num_constraints(self):
|
929 | 930 | @property
|
930 | 931 | def num_z_vars(self):
|
931 | 932 | return int(self.use_re)
|
| 933 | + |
| 934 | + |
| 935 | +class CatCovModel(CovModel): |
| 936 | + """Categorical covariate model. |
| 937 | +
|
| 938 | + TODO: Add order prior. |
| 939 | + """ |
| 940 | + |
| 941 | + def __init__( |
| 942 | + self, |
| 943 | + alt_cov, |
| 944 | + name=None, |
| 945 | + ref_cov=None, |
| 946 | + ref_cat=None, |
| 947 | + use_re=False, |
| 948 | + prior_beta_gaussian=None, |
| 949 | + prior_beta_uniform=None, |
| 950 | + prior_beta_laplace=None, |
| 951 | + prior_gamma_gaussian=None, |
| 952 | + prior_gamma_uniform=None, |
| 953 | + prior_gamma_laplace=None, |
| 954 | + ) -> None: |
| 955 | + super().__init__( |
| 956 | + alt_cov=alt_cov, |
| 957 | + name=name, |
| 958 | + ref_cov=ref_cov, |
| 959 | + use_re=use_re, |
| 960 | + prior_beta_gaussian=prior_beta_gaussian, |
| 961 | + prior_beta_uniform=prior_beta_uniform, |
| 962 | + prior_beta_laplace=prior_beta_laplace, |
| 963 | + prior_gamma_gaussian=prior_gamma_gaussian, |
| 964 | + prior_gamma_uniform=prior_gamma_uniform, |
| 965 | + prior_gamma_laplace=prior_gamma_laplace, |
| 966 | + ) |
| 967 | + self.ref_cat = ref_cat |
| 968 | + if len(self.alt_cov) != 1: |
| 969 | + raise ValueError("alt_cov should be a single column.") |
| 970 | + if len(self.ref_cov) > 1: |
| 971 | + raise ValueError("ref_cov should be nothing or a single column.") |
| 972 | + |
| 973 | + self.cats: pd.Series |
| 974 | + |
| 975 | + def attach_data(self, data: MRData) -> None: |
| 976 | + """Attach data and parse the categories. Number of variables will be |
| 977 | + determined here and priors will be processed here as well. |
| 978 | +
|
| 979 | + """ |
| 980 | + alt_cov = data.get_covs(self.alt_cov) |
| 981 | + ref_cov = data.get_covs(self.ref_cov) |
| 982 | + self.cats = pd.Series( |
| 983 | + np.unique(np.hstack([alt_cov, ref_cov])), |
| 984 | + name="cats", |
| 985 | + ) |
| 986 | + self._process_priors() |
| 987 | + |
| 988 | + def has_data(self) -> bool: |
| 989 | + """Return if the data has been attached and categories has been parsed.""" |
| 990 | + return hasattr(self, "cats") |
| 991 | + |
| 992 | + def encode(self, x: NDArray) -> NDArray: |
| 993 | + """Encode the provided categories into dummy variables.""" |
| 994 | + col = pd.merge(pd.Series(x, name="cats"), self.cats.reset_index())[ |
| 995 | + "index" |
| 996 | + ] |
| 997 | + mat = np.zeros((len(x), self.num_x_vars)) |
| 998 | + mat[range(len(x)), col] = 1.0 |
| 999 | + return mat |
| 1000 | + |
| 1001 | + def create_design_mat(self, data: MRData) -> tuple[NDArray, NDArray]: |
| 1002 | + """Create design matrix for alternative and reference categories.""" |
| 1003 | + alt_cov = data.get_covs(self.alt_cov).ravel() |
| 1004 | + ref_cov = data.get_covs(self.ref_cov).ravel() |
| 1005 | + |
| 1006 | + alt_mat = self.encode(alt_cov) |
| 1007 | + if ref_cov.size == 0: |
| 1008 | + ref_mat = np.zeros((len(alt_cov), self.num_x_vars)) |
| 1009 | + else: |
| 1010 | + ref_mat = self.encode(ref_cov) |
| 1011 | + return alt_mat, ref_mat |
| 1012 | + |
| 1013 | + def create_constraint_mat(self) -> tuple[NDArray, NDArray]: |
| 1014 | + """TODO: Create constraint matrix from order priors.""" |
| 1015 | + return np.empty((0, self.num_x_vars)), np.empty((2, 0)) |
| 1016 | + |
| 1017 | + @property |
| 1018 | + def num_x_vars(self) -> int: |
| 1019 | + """Number of the fixed effects. Returns 0 if data is not attached |
| 1020 | + otherwise it will return the number of categories. |
| 1021 | +
|
| 1022 | + """ |
| 1023 | + if not hasattr(self, "cats"): |
| 1024 | + return 0 |
| 1025 | + return len(self.cats) |
| 1026 | + |
| 1027 | + @property |
| 1028 | + def num_z_vars(self) -> int: |
| 1029 | + """Number of the random effects. Currently it is the same with the |
| 1030 | + number of the fixed effects, but this is to be discussed. |
| 1031 | + TODO: Overwrite the number of random effects. |
| 1032 | +
|
| 1033 | + """ |
| 1034 | + return self.num_x_vars |
| 1035 | + |
| 1036 | + @property |
| 1037 | + def num_constraints(self) -> int: |
| 1038 | + """TODO: Overwrite the number of constraints.""" |
| 1039 | + return 0 |
0 commit comments