Skip to content

Commit

Permalink
Merge pull request #216 from kunwuz/main
Browse files Browse the repository at this point in the history
Update LocalScoreFunctionClass to fix issue calling local_score_BIC_from_cov
  • Loading branch information
bja43 authored Jan 11, 2025
2 parents f6a96e3 + bc9b002 commit 2a01a86
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions causallearn/score/LocalScoreFunctionClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
self.parameters = parameters
self.score_cache = {}

if self.local_score_fun == local_score_BIC_from_cov:
if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
self.cov = np.cov(self.data.T)
self.n = self.data.shape[0]

Expand All @@ -40,15 +40,15 @@ def score(self, i: int, PAi: List[int]) -> float:
hash_key = tuple(sorted(PAi))

if not self.score_cache[i].__contains__(hash_key):
if self.local_score_fun == local_score_BIC_from_cov:
if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
self.score_cache[i][hash_key] = self.local_score_fun((self.cov, self.n), i, PAi, self.parameters)
else:
self.score_cache[i][hash_key] = self.local_score_fun(self.data, i, PAi, self.parameters)

return self.score_cache[i][hash_key]

def score_nocache(self, i: int, PAi: List[int]) -> float:
if self.local_score_fun == local_score_BIC_from_cov:
if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
return self.local_score_fun((self.cov, self.n), i, PAi, self.parameters)
else:
return self.local_score_fun(self.data, i, PAi, self.parameters)
return self.local_score_fun(self.data, i, PAi, self.parameters)
2 changes: 1 addition & 1 deletion causallearn/search/PermutationBased/BOSS.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def boss(
X: np.ndarray,
score_func: str = "local_score_BIC",
score_func: str = "local_score_BIC_from_cov",
parameters: Optional[Dict[str, Any]] = None,
verbose: Optional[bool] = True,
node_names: Optional[List[str]] = None,
Expand Down

0 comments on commit 2a01a86

Please sign in to comment.