diff --git a/README.md b/README.md index 9b0911f..c5cb7fb 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # The client for the (all new) TabPFN -This is an alpha family and friends service, so please do not expect this to never be down or run into errors. It worked fine in the settings that we tried, though. +This is an alpha family and friends service, so please do not expect this to never be down or run into errors. It worked fine in the settings that we tried, though. What model is behind the API? It is a new TabPFN which we allow to handle up to 10K data points with up to 500 features. You can control all pre-processing, the amount of ensembling etc. -### We would really appreciate your feedback! If you encounter bugs or suggestions for improvement please create an issue or email me (samuelgabrielmuller (at) gmail com). +### We would really appreciate your feedback! Please join our discord community here https://discord.gg/VJRuU3bSxt or email us at hello@priorlabs.ai # How To @@ -38,6 +38,26 @@ tabpfn.predict(X_test) # or you can also use tabpfn.predict_proba(X_test) ``` +To login using your access token, skipping the interactive flow, use: + +```python +from tabpfn_client import config + +# Retrieve Token +with open(config.g_tabpfn_config.user_auth_handler.CACHED_TOKEN_FILE, 'r') as file: + token = file.read() +print(f"TOKEN: {token}") +``` + +```python +from tabpfn_client import config + +# Set Token +service_client = config.ServiceClient() +config.g_tabpfn_config.user_auth_handler = config.UserAuthenticationClient(service_client=service_client) +user_auth = config.g_tabpfn_config.user_auth_handler.set_token(token) +``` + # Development To encourage better coding practices, `ruff` has been added to the pre-commit hooks. This will ensure that the code is formatted properly before being committed. To enable pre-commit (if you haven't), run the following command: diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index 37da8bc..25dab78 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -193,6 +193,10 @@ def _validate_response( # Read response. load = None try: + # This if clause is necessary for streaming responses (e.g. download) to + # prevent httpx.ResponseNotRead error. + if not response.is_closed: + response.read() load = response.json() except json.JSONDecodeError as e: logging.info(f"Failed to parse JSON from response in {method_name}: {e}") @@ -487,7 +491,12 @@ def download_all_data(self, save_dir: Path) -> Path | None: full_url = self.base_url + self.server_endpoints.download_all_data.path with httpx.stream( - "GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"} + "GET", + full_url, + headers={ + "Authorization": f"Bearer {self.access_token}", + "client-version": get_client_version(), + }, ) as response: self._validate_response(response, "download_all_data")