diff --git a/causallearn/score/LocalScoreFunctionClass.py b/causallearn/score/LocalScoreFunctionClass.py index 8171d51..64b2081 100644 --- a/causallearn/score/LocalScoreFunctionClass.py +++ b/causallearn/score/LocalScoreFunctionClass.py @@ -29,7 +29,7 @@ def __init__( self.parameters = parameters self.score_cache = {} - if self.local_score_fun == local_score_BIC_from_cov: + if self.local_score_fun.__name__ == 'local_score_BIC_from_cov': self.cov = np.cov(self.data.T) self.n = self.data.shape[0] @@ -40,7 +40,7 @@ def score(self, i: int, PAi: List[int]) -> float: hash_key = tuple(sorted(PAi)) if not self.score_cache[i].__contains__(hash_key): - if self.local_score_fun == local_score_BIC_from_cov: + if self.local_score_fun.__name__ == 'local_score_BIC_from_cov': self.score_cache[i][hash_key] = self.local_score_fun((self.cov, self.n), i, PAi, self.parameters) else: self.score_cache[i][hash_key] = self.local_score_fun(self.data, i, PAi, self.parameters) @@ -48,7 +48,7 @@ def score(self, i: int, PAi: List[int]) -> float: return self.score_cache[i][hash_key] def score_nocache(self, i: int, PAi: List[int]) -> float: - if self.local_score_fun == local_score_BIC_from_cov: + if self.local_score_fun.__name__ == 'local_score_BIC_from_cov': return self.local_score_fun((self.cov, self.n), i, PAi, self.parameters) else: - return self.local_score_fun(self.data, i, PAi, self.parameters) \ No newline at end of file + return self.local_score_fun(self.data, i, PAi, self.parameters) diff --git a/causallearn/search/PermutationBased/BOSS.py b/causallearn/search/PermutationBased/BOSS.py index cf18d43..85551b8 100644 --- a/causallearn/search/PermutationBased/BOSS.py +++ b/causallearn/search/PermutationBased/BOSS.py @@ -23,7 +23,7 @@ def boss( X: np.ndarray, - score_func: str = "local_score_BIC", + score_func: str = "local_score_BIC_from_cov", parameters: Optional[Dict[str, Any]] = None, verbose: Optional[bool] = True, node_names: Optional[List[str]] = None,