-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
246cc65
commit 2aebae2
Showing
4 changed files
with
52 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
""" | ||
Something like here: https://github.com/yang-song/score_sde_pytorch/blob/main/sampling.py | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
This file should contain a general abstraction of the score models and | ||
should function as a wrapper for different models we might want to use. | ||
tThe idea is to "hide" the particular tree we want to use so that | ||
we can easily switch between different models without having to change | ||
the rest of the code. | ||
""" | ||
|
||
import abc | ||
|
||
|
||
class ScoreModel(abc.ABC): | ||
@abc.abstractmethod | ||
def score(self, data): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
""" | ||
Similar to https://github.com/yang-song/score_sde_pytorch/blob/main/sde_lib.py | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
""" | ||
This should be the main file corresponding to the project. | ||
""" | ||
|
||
from sklearn.base import BaseEstimator | ||
|
||
|
||
class Treeffuser(BaseEstimator): | ||
|
||
def __init__(self, *args, **kwargs): | ||
pass | ||
|
||
def fit(self, X, y): | ||
pass | ||
|
||
def predict(self, X): | ||
pass | ||
|
||
def sample(self, X): | ||
pass | ||
|
||
def likelihood(self, X, y): | ||
""" | ||
Something that computes the log-likelihood of the model. | ||
""" | ||
|
||
def pred_distribution(self, X): | ||
""" | ||
Maybe the CDF? | ||
""" |