Skip to content

Commit

Permalink
Increase timeout, fix model argument
Browse files Browse the repository at this point in the history
  • Loading branch information
liam-sbhoo committed Oct 2, 2024
1 parent 4b2c1c7 commit 16eafb0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self):
self.server_endpoints = SERVER_CONFIG["endpoints"]
self.base_url = f"{self.server_config.protocol}://{self.server_config.host}:{self.server_config.port}"
self.httpx_timeout_s = (
30 # temporary workaround for slow computation on server side
4 * 5 * 60 + 15 # temporary workaround for slow computation on server side
)
self.httpx_client = httpx.Client(
base_url=self.base_url,
Expand Down
13 changes: 9 additions & 4 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def to_dict(self):
class TabPFNClassifier(BaseEstimator, ClassifierMixin):
def __init__(
self,
model="latest_tabpfn_hosted",
model="default",
n_estimators: int = 4,
preprocess_transforms: Tuple[PreprocessorConfig, ...] = (
PreprocessorConfig(
Expand Down Expand Up @@ -212,8 +212,8 @@ def fit(self, X, y):
if config.g_tabpfn_config.use_server:
try:
assert (
self.model == "latest_tabpfn_hosted"
), "Only 'latest_tabpfn_hosted' model is supported at the moment for init(use_server=True)"
self.model == "default"
), "Only 'default' model is supported at the moment for init(use_server=True)"
except AssertionError as e:
print(e)

Expand All @@ -235,11 +235,16 @@ def predict_proba(self, X):
check_is_fitted(self)
validate_data_size(X)

estimator_param = self.get_params()
if "model" in estimator_param:
# TabPFNClassifier doesn't support different models at the moment.
estimator_param.pop("model")

return config.g_tabpfn_config.inference_handler.predict(
X,
task="classification",
train_set_uid=self.last_train_set_uid,
config=self.get_params(),
config=estimator_param,
)["probas"]


Expand Down

0 comments on commit 16eafb0

Please sign in to comment.