diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py
index 7130bb6..b98ef9f 100644
--- a/tabpfn_client/client.py
+++ b/tabpfn_client/client.py
@@ -179,7 +179,7 @@ def reset_authorization(cls):
         cls.httpx_client.headers.pop("Authorization", None)
 
     @classmethod
-    def upload_train_set(cls, X, y) -> str:
+    def fit(cls, X, y, config=None) -> str:
         """
         Upload a train set to server and return the train set UID if successful.
 
@@ -189,6 +189,8 @@ def upload_train_set(cls, X, y) -> str:
             The training input samples.
         y : array-like of shape (n_samples,) or (n_samples, n_outputs)
             The target values.
+        config : dict, optional
+            Configuration for the fit method. Includes tabpfn_systems and paper_version.
 
         Returns
         -------
@@ -200,26 +202,34 @@ def upload_train_set(cls, X, y) -> str:
         X_serialized = common_utils.serialize_to_csv_formatted_bytes(X)
         y_serialized = common_utils.serialize_to_csv_formatted_bytes(y)
 
+        if config is None:
+            tabpfn_systems = ["preprocessing", "text"]
+        else:
+            tabpfn_systems = (
+                [] if config["paper_version"] else ["preprocessing", "text"]
+            )
+
         # Get hash for dataset. Include access token for the case that one user uses different accounts.
         cached_dataset_uid, dataset_hash = (
             cls.dataset_uid_cache_manager.get_dataset_uid(
-                X_serialized, y_serialized, cls._access_token
+                X_serialized, y_serialized, cls._access_token, "_".join(tabpfn_systems)
             )
         )
         if cached_dataset_uid:
            return cached_dataset_uid
 
         response = cls.httpx_client.post(
-            url=cls.server_endpoints.upload_train_set.path,
+            url=cls.server_endpoints.fit.path,
             files=common_utils.to_httpx_post_file_format(
                 [
                     ("x_file", "x_train_filename", X_serialized),
                     ("y_file", "y_train_filename", y_serialized),
                 ]
             ),
+            params={"tabpfn_systems": json.dumps(tabpfn_systems)},
         )
 
-        cls._validate_response(response, "upload_train_set")
+        cls._validate_response(response, "fit")
 
         train_set_uid = response.json()["train_set_uid"]
         cls.dataset_uid_cache_manager.add_dataset_uid(dataset_hash, train_set_uid)
@@ -253,20 +263,28 @@ def predict(
 
         x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test)
 
+        params = {"train_set_uid": train_set_uid, "task": task}
+        if tabpfn_config is not None:
+            paper_version = tabpfn_config.pop("paper_version")
+            params["tabpfn_config"] = json.dumps(
+                tabpfn_config, default=lambda x: x.to_dict()
+            )
+        else:
+            paper_version = False
+        tabpfn_systems = [] if paper_version else ["preprocessing", "text"]
+        params["tabpfn_systems"] = json.dumps(tabpfn_systems)
+
         # In the arguments for hashing, include train_set_uid for the case that the same test set was previously used
         # with different train set. Include access token for the case that a user uses different accounts.
         cached_test_set_uid, dataset_hash = (
             cls.dataset_uid_cache_manager.get_dataset_uid(
-                x_test_serialized, train_set_uid, cls._access_token
+                x_test_serialized,
+                train_set_uid,
+                cls._access_token,
+                "_".join(tabpfn_systems),
             )
         )
 
-        params = {"train_set_uid": train_set_uid, "task": task}
-        if tabpfn_config is not None:
-            params["tabpfn_config"] = json.dumps(
-                tabpfn_config, default=lambda x: x.to_dict()
-            )
-
         # Send prediction request. Loop two times, such that if anything cached is not correct
         # anymore, there is a second iteration where the datasets are uploaded.
         results = None
@@ -331,7 +349,14 @@ def run_progress():
                    raise RuntimeError(
                        "Train set data is required to re-upload but was not provided."
                    )
-                train_set_uid = cls.upload_train_set(X_train, y_train)
+                train_set_uid = cls.fit(
+                    X_train,
+                    y_train,
+                    config=dict(
+                        tabpfn_config if tabpfn_config else {},
+                        **{"paper_version": paper_version},
+                    ),
+                )
                 params["train_set_uid"] = train_set_uid
                 cached_test_set_uid = None
             else:
diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py
index 71763f9..3ef3ec8 100644
--- a/tabpfn_client/estimator.py
+++ b/tabpfn_client/estimator.py
@@ -185,6 +185,7 @@ def __init__(
         remove_outliers=12.0,
         add_fingerprint_features=True,
         subsample_samples=-1,
+        paper_version=False,
     ):
         """
         Parameters:
@@ -208,6 +209,7 @@ def __init__(
             remove_outliers: If not 0.0, will remove outliers from the input features, where values with a standard deviation larger than remove_outliers will be removed.
             add_fingerprint_features: If True, will add one feature of random values, that will be added to the input features. This helps discern duplicated samples in the transformer model.
             subsample_samples: If not None, will use a random subset of the samples for training in each ensemble configuration. If 1 or above, this will subsample to the specified number of samples. If in 0 to 1, the value is viewed as a fraction of the training set size.
+            paper_version: If True, will use the model described in the paper. Otherwise, will use a better model. Default is False.
         """
         self.model = model
         self.n_estimators = n_estimators
@@ -224,6 +226,7 @@ def __init__(
         self.remove_outliers = remove_outliers
         self.add_fingerprint_features = add_fingerprint_features
         self.subsample_samples = subsample_samples
+        self.paper_version = paper_version
         self.last_train_set_uid = None
         self.last_train_X = None
         self.last_train_y = None
@@ -239,20 +242,17 @@ def _validate_targets_and_classes(self, y) -> np.ndarray:
         not_nan_mask = ~np.isnan(y)
         self.classes_ = np.unique(y_[not_nan_mask])
 
-    @staticmethod
-    def _validate_data_size(X: np.ndarray, y: np.ndarray | None):
-        if X.shape[0] != y.shape[0]:
-            raise ValueError("X and y must have the same number of samples")
-
     def fit(self, X, y):
         # assert init() is called
         init()
 
         validate_data_size(X, y)
         self._validate_targets_and_classes(y)
+        _check_paper_version(self.paper_version, X)
+        estimator_param = self.get_params()
 
         if Config.use_server:
-            self.last_train_set_uid = InferenceClient.fit(X, y)
+            self.last_train_set_uid = InferenceClient.fit(X, y, config=estimator_param)
             self.last_train_X = X
             self.last_train_y = y
         self.fitted_ = True
@@ -271,6 +271,7 @@ def predict(self, X):
     def predict_proba(self, X):
         check_is_fitted(self)
         validate_data_size(X)
+        _check_paper_version(self.paper_version, X)
 
         estimator_param = self.get_params()
         if "model" in estimator_param:
@@ -340,6 +341,7 @@ def __init__(
         cancel_nan_borders: bool = True,
         super_bar_dist_averaging: bool = False,
         subsample_samples: float = -1,
+        paper_version: bool = False,
     ):
         """
         Parameters:
@@ -374,6 +376,7 @@ def __init__(
             subsample_samples: If not None, will use a random subset of the samples for training in each ensemble configuration.
                 If 1 or above, this will subsample to the specified number of samples.
                 If in 0 to 1, the value is viewed as a fraction of the training set size.
+            paper_version: If True, will use the model described in the paper. Otherwise, will use a better model. Default is False.
""" if model not in self._AVAILABLE_MODELS: @@ -399,15 +402,18 @@ def __init__( self.last_train_set_uid = None self.last_train_X = None self.last_train_y = None + self.paper_version = paper_version def fit(self, X, y): # assert init() is called init() validate_data_size(X, y) + _check_paper_version(self.paper_version, X) + estimator_param = self.get_params() if Config.use_server: - self.last_train_set_uid = InferenceClient.fit(X, y) + self.last_train_set_uid = InferenceClient.fit(X, y, config=estimator_param) self.last_train_X = X self.last_train_y = y self.fitted_ = True @@ -431,6 +437,7 @@ def predict(self, X): def predict_full(self, X): check_is_fitted(self) validate_data_size(X) + _check_paper_version(self.paper_version, X) estimator_param = self.get_params() if "model" in estimator_param: @@ -468,3 +475,13 @@ def validate_data_size(X: np.ndarray, y: np.ndarray | None = None): raise ValueError(f"The number of rows cannot be more than {MAX_ROWS}.") if X.shape[1] > MAX_COLS: raise ValueError(f"The number of columns cannot be more than {MAX_COLS}.") + + +def _check_paper_version(paper_version, X): + if paper_version: + # check if X can be converted to numerical values + try: + np.array(X, dtype=np.float32) + except ValueError: + raise ValueError("""X must be numerical to use the paper version of the model. + Preprocess your data or use `paper_version=False`.""") diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index acb3f7b..6a6b897 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -68,6 +68,11 @@ endpoints: methods: [ "POST" ] description: "Upload train set" + fit: + path: "/fit/" + methods: [ "POST" ] + description: "Fit" + predict: path: "/predict/" methods: [ "POST" ] diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index b23e560..f7860df 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -210,8 +210,8 @@ def __new__(self): ) @classmethod - def fit(cls, X, y) -> str: - return ServiceClient.upload_train_set(X, y) + def fit(cls, X, y, config=None) -> str: + return ServiceClient.fit(X, y, config=config) @classmethod def predict( diff --git a/tabpfn_client/tests/integration/test_tabpfn_classifier.py b/tabpfn_client/tests/integration/test_tabpfn_classifier.py index 928ec8e..92da77b 100644 --- a/tabpfn_client/tests/integration/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/integration/test_tabpfn_classifier.py @@ -42,7 +42,7 @@ def test_use_remote_tabpfn_classifier(self, mock_server): tabpfn = TabPFNClassifier() # mock fitting - mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( + mock_server.router.post(mock_server.endpoints.fit.path).respond( 200, json={"train_set_uid": "5"} ) tabpfn.fit(self.X_train, self.y_train) diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py index 6847b09..740bc72 100644 --- a/tabpfn_client/tests/unit/test_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -193,11 +193,11 @@ def test_retrieve_greeting_messages(self, mock_server): @with_mock_server() def test_predict_with_valid_train_set_and_test_set(self, mock_server): dummy_json = {"train_set_uid": "5"} - mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( + mock_server.router.post(mock_server.endpoints.fit.path).respond( 200, json=dummy_json ) ServiceClient.authorize("dummy_token") - ServiceClient.upload_train_set(self.X_train, self.y_train) + 
 
         dummy_result = {"test_set_uid": "dummy_uid", "classification": [1, 2, 3]}
         mock_server.router.post(mock_server.endpoints.predict.path).respond(
@@ -251,14 +251,14 @@ def test_validate_response_only_version_check(self):
         self.assertIsNone(r)
 
     @with_mock_server()
-    def test_upload_train_set_with_caching(self, mock_server):
+    def test_fit_with_caching(self, mock_server):
         """
-        Test that uploading the same training set multiple times uses the cache and
-        only calls the upload_train_set endpoint once.
+        Test that calling fit with the same training set multiple times uses the cache and
+        only calls the fit endpoint once.
         """
         ServiceClient.authorize("dummy_access_token")
 
-        # Mock the upload_train_set endpoint to return a fixed train_set_uid
+        # Mock the fit endpoint to return a fixed train_set_uid
         with patch.object(
             ServiceClient.httpx_client, "post", wraps=ServiceClient.httpx_client.post
         ) as mock_post:
@@ -269,15 +269,15 @@ def test_fit_with_caching(self, mock_server):
             mock_post.return_value = mock_response
 
             # First upload
-            train_set_uid1 = ServiceClient.upload_train_set(self.X_train, self.y_train)
+            train_set_uid1 = ServiceClient.fit(self.X_train, self.y_train)
 
             # Second upload with the same data
-            train_set_uid2 = ServiceClient.upload_train_set(self.X_train, self.y_train)
+            train_set_uid2 = ServiceClient.fit(self.X_train, self.y_train)
 
             # The train_set_uid should be the same due to caching
             self.assertEqual(train_set_uid1, train_set_uid2)
 
-            # The upload_train_set endpoint should have been called only once
+            # The fit endpoint should have been called only once
             mock_post.assert_called_once()
 
     def test_predict_with_caching(self):
@@ -287,7 +287,7 @@ def test_predict_with_caching(self):
         """
         ServiceClient.authorize("dummy_access_token")
 
-        # Mock the upload_train_set and predict endpoints
+        # Mock the fit and predict endpoints
         with (
             patch.object(
                 ServiceClient.httpx_client,
                 "post",
                 wraps=ServiceClient.httpx_client.post,
             ) as mock_post,
             patch.object(
                 ServiceClient.httpx_client,
                 "stream",
                 wraps=ServiceClient.httpx_client.stream,
             ) as mock_stream,
         ):
             # Mock responses
             def side_effect(*args, **kwargs):
-                if (
-                    kwargs.get("url")
-                    == ServiceClient.server_endpoints.upload_train_set.path
-                ):
+                if kwargs.get("url") == ServiceClient.server_endpoints.fit.path:
                     response = Mock()
                     response.status_code = 200
                     response.json.return_value = {
@@ -333,7 +330,7 @@ def side_effect(*args, **kwargs):
             mock_stream.side_effect = side_effect
 
             # Upload train set
-            train_set_uid = ServiceClient.upload_train_set(self.X_train, self.y_train)
+            train_set_uid = ServiceClient.fit(self.X_train, self.y_train)
 
             # First prediction
             pred1 = ServiceClient.predict(
@@ -351,7 +348,7 @@ def side_effect(*args, **kwargs):
             # The predict endpoint should have been called twice
             self.assertEqual(
                 mock_post.call_count + mock_stream.call_count, 3
-            )  # 1 for upload_train_set, 2 for predict
+            )  # 1 for fit, 2 for predict
 
             # Check that the test set was uploaded only once (first predict call)
             upload_calls = [
@@ -366,7 +363,7 @@ def test_predict_with_invalid_cached_uids(self):
         """
         ServiceClient.authorize("dummy_access_token")
 
-        # Mock the upload_train_set and predict endpoints
+        # Mock the fit and predict endpoints
         with (
             patch.object(
                 ServiceClient.httpx_client,
                 "post",
                 wraps=ServiceClient.httpx_client.post,
             ) as mock_post,
             patch.object(
                 ServiceClient.httpx_client,
                 "stream",
                 wraps=ServiceClient.httpx_client.stream,
             ) as mock_stream,
         ):
             # Mock responses with side effects to simulate invalid cached UIDs
             def side_effect(*args, **kwargs):
-                if (
-                    kwargs.get("url")
-                    == ServiceClient.server_endpoints.upload_train_set.path
-                ):
+                if kwargs.get("url") == ServiceClient.server_endpoints.fit.path:
                     response = Mock()
                     response.status_code = 200
                     response.json.return_value = {
@@ -430,7 +424,7 @@ def side_effect_counter(*args, **kwargs):
             mock_stream.side_effect = side_effect_counter
 
             # Upload train set
-            train_set_uid = ServiceClient.upload_train_set(self.X_train, self.y_train)
+            train_set_uid = ServiceClient.fit(self.X_train, self.y_train)
 
             # Attempt prediction, which should fail and trigger retry
             pred = ServiceClient.predict(
@@ -447,14 +441,13 @@ def side_effect_counter(*args, **kwargs):
             # The predict endpoint should have been called twice due to retry
             self.assertEqual(
                 mock_post.call_count + mock_stream.call_count, 4
-            )  # 1 upload_train_set + 2 predict + 1 re-upload
+            )  # 1 fit + 2 predict + 1 re-upload
 
-            # Ensure that upload_train_set was called again (re-upload)
+            # Ensure that fit was called again (re-upload)
             upload_calls = [
                 call
                 for call in mock_post.call_args_list
-                if call.kwargs.get("url")
-                == ServiceClient.server_endpoints.upload_train_set.path
+                if call.kwargs.get("url") == ServiceClient.server_endpoints.fit.path
             ]
             self.assertEqual(len(upload_calls), 2)
 
diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py
index c74ce96..70b8b6e 100644
--- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py
+++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py
@@ -3,7 +3,10 @@
 import shutil
 import json
 
+import pandas as pd
+
 import numpy as np
+
 from sklearn.datasets import load_breast_cancer
 from sklearn.model_selection import train_test_split
 from sklearn.exceptions import NotFittedError
@@ -11,7 +14,6 @@
 from tabpfn_client import init, reset
 from tabpfn_client.estimator import TabPFNClassifier
 from tabpfn_client.service_wrapper import UserAuthenticationClient, InferenceClient
-from tabpfn_client.client import ServiceClient
 from tabpfn_client.tests.mock_tabpfn_server import with_mock_server
 from tabpfn_client.constants import CACHE_DIR
 from tabpfn_client import config
@@ -47,7 +49,7 @@ def test_init_remote_classifier(
 
         # mock server connection
         mock_server.router.get(mock_server.endpoints.root.path).respond(200)
-        mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond(
+        mock_server.router.post(mock_server.endpoints.fit.path).respond(
             200, json={"train_set_uid": "5"}
         )
@@ -62,9 +64,6 @@ def test_init_remote_classifier(
             content=f'data: {json.dumps({"event": "result", "data": {"classification": mock_predict_response, "test_set_uid": "6"}})}\n\n',
             headers={"Content-Type": "text/event-stream"},
         )
-        print(
-            f"{UserAuthenticationClient.CACHED_TOKEN_FILE.exists()=}, {ServiceClient.get_access_token()=}"
-        )
 
         init(use_server=True)
         self.assertTrue(mock_prompt_and_set_token.called)
@@ -165,6 +164,141 @@ def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_con
         self.assertRaises(RuntimeError, init, use_server=True)
         self.assertTrue(mock_prompt_for_terms_and_cond.called)
 
+    @with_mock_server()
+    @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token")
+    @patch(
+        "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
+        return_value=True,
+    )
+    def test_cache_based_on_paper_version(
+        self, mock_server, mock_prompt_for_terms_and_cond, mock_prompt_and_set_token
+    ):
+        """
+        Check that the dataset cached is different for different paper_version
+        and similar for the same paper_version
+        """
+        # a bit out of place but we don't want to skip init for this test
+        mock_prompt_and_set_token.side_effect = (
+            lambda: UserAuthenticationClient.set_token(self.dummy_token)
+        )
+
+        # mock server connection
+        mock_server.router.get(mock_server.endpoints.root.path).respond(200)
+        fit_route = mock_server.router.post(mock_server.endpoints.fit.path)
+        fit_route.respond(200, json={"train_set_uid": "5"})
+
+        mock_server.router.get(
+            mock_server.endpoints.retrieve_greeting_messages.path
+        ).respond(200, json={"messages": []})
+
+        mock_predict_response = [[1, 0.0], [0.9, 0.1], [0.01, 0.99]]
+        predict_route = mock_server.router.post(mock_server.endpoints.predict.path)
+        predict_route.respond(
+            200,
+            content=f'data: {json.dumps({"event": "result", "data": {"classification": mock_predict_response, "test_set_uid": "6"}})}\n\n',
+            headers={"Content-Type": "text/event-stream"},
+        )
+
+        init(use_server=True)
+
+        X = np.random.rand(10, 5)
+        y = np.random.randint(0, 2, 10)
+        test_X = np.random.rand(5, 5)
+
+        # Initialize with paper_version=True
+        tabpfn_true = TabPFNClassifier(paper_version=True)
+
+        tabpfn_true.fit(X, y)
+        tabpfn_true.predict(test_X)
+
+        # Call fit and predict again with the same paper_version
+        tabpfn_true.fit(X, y)
+        tabpfn_true.predict(test_X)
+
+        # Ensure fit endpoint is not called again
+        self.assertEqual(
+            fit_route.call_count,
+            1,
+            "Fit endpoint should not be called again with the same paper_version",
+        )
+
+        # Initialize with paper_version=False
+        tabpfn_false = TabPFNClassifier(paper_version=False)
+
+        tabpfn_false.fit(X, y)
+        tabpfn_false.predict(test_X)
+
+        # check fit is called
+        self.assertEqual(
+            fit_route.call_count,
+            2,
+            "Fit endpoint should be called again with a different paper_version",
+        )
+
+        # Call fit and predict again with the same paper_version
+        tabpfn_false.fit(X, y)
+        tabpfn_false.predict(test_X)
+
+        # Ensure fit endpoint is not called again
+        self.assertEqual(
+            fit_route.call_count,
+            2,
+            "Fit endpoint should not be called again with the same paper_version",
+        )
+
+        # TODO: fix this
+        # # Check that different cache entries are created for training set
+        # cache_manager = ServiceClient.dataset_uid_cache_manager
+        # X_serialized = common_utils.serialize_to_csv_formatted_bytes(X)
+        # y_serialized = common_utils.serialize_to_csv_formatted_bytes(y)
+        # uid_true_train, hash_true_train = cache_manager.get_dataset_uid(
+        #     X_serialized, y_serialized, self.dummy_token, "_".join([])
+        # )
+        # uid_false_train, hash_false_train = cache_manager.get_dataset_uid(
+        #     X_serialized,
+        #     y_serialized,
+        #     self.dummy_token,
+        #     "_".join(["preprocessing", "text"]),
+        # )
+
+        # self.assertNotEqual(
+        #     hash_true_train,
+        #     hash_false_train,
+        #     "Cache hash should differ based on paper_version for training set",
+        # )
+
+        # # Check that different cache entries are created for test set
+        # test_X_serialized = common_utils.serialize_to_csv_formatted_bytes(test_X)
+        # uid_true_test, hash_true_test = cache_manager.get_dataset_uid(
+        #     test_X_serialized, uid_true_train, self.dummy_token, "_".join([])
+        # )
+        # uid_false_test, hash_false_test = cache_manager.get_dataset_uid(
+        #     test_X_serialized,
+        #     uid_false_train,
+        #     self.dummy_token,
+        #     "_".join(["preprocessing", "text"]),
+        # )
+
+        # self.assertNotEqual(
+        #     hash_true_test,
+        #     hash_false_test,
+        #     "Cache hash should differ based on paper_version for test set",
+        # )
+
+        # # Verify that the cache entries are used correctly
+        # self.assertIsNotNone(
+        #     uid_true_train, "Training set cache should be used for paper_version=True"
+        # )
+        # self.assertIsNotNone(
+        #     uid_false_train, "Training set cache should be used for paper_version=False"
+        # )
+        # self.assertIsNotNone(
+        #     uid_true_test, "Test set cache should be used for paper_version=True"
paper_version=True" + # ) + # self.assertIsNotNone( + # uid_false_test, "Test set cache should be used for paper_version=False" + # ) + class TestTabPFNClassifierInference(unittest.TestCase): def setUp(self): @@ -302,3 +436,61 @@ def test_predict_proba_uses_correct_model_path(self): self.assertEqual( predict_kwargs["config"]["model_path"], expected_model_path ) + + @patch.object(InferenceClient, "fit", return_value="dummy_uid") + @patch.object( + InferenceClient, "predict", return_value={"probas": np.random.rand(10, 2)} + ) + def test_paper_version_behavior(self, mock_predict, mock_fit): + # this just test that it doesn't break, + # but the actual behavior is easier to test + # on the server side + X = np.random.rand(10, 5) + y = np.random.randint(0, 2, 10) + test_X = np.random.rand(5, 5) + + # Test with paper_version=True + tabpfn_true = TabPFNClassifier(paper_version=True) + tabpfn_true.fit(X, y) + y_pred_true = tabpfn_true.predict(test_X) + self.assertIsNotNone(y_pred_true) + + # Test with paper_version=False + tabpfn_false = TabPFNClassifier(paper_version=False) + tabpfn_false.fit(X, y) + y_pred_false = tabpfn_false.predict(test_X) + self.assertIsNotNone(y_pred_false) + + @patch.object(InferenceClient, "fit", return_value="dummy_uid") + @patch.object( + InferenceClient, "predict", return_value={"probas": np.random.rand(10, 2)} + ) + def test_check_paper_version_with_non_numerical_data_raises_error( + self, mock_predict, mock_fit + ): + # Create a TabPFNClassifier with paper_version=True + tabpfn = TabPFNClassifier(paper_version=True) + + # Create non-numerical data + X = pd.DataFrame({"feature1": ["a", "b", "c"], "feature2": ["d", "e", "f"]}) + y = np.array([0, 1, 0]) + + with self.assertRaises(ValueError) as context: + tabpfn.fit(X, y) + + self.assertIn( + "X must be numerical to use the paper version of the model", + str(context.exception), + ) + + # check that it works with paper_version=False + tabpfn = TabPFNClassifier(paper_version=False) + tabpfn.fit(X, y) + + # check that paper_version=True works with numerical data even with the wrong type + X = np.random.rand(10, 5).astype(str) + y = np.random.randint(0, 2, 10) + tabpfn = TabPFNClassifier(paper_version=True) + tabpfn.fit(X, y) + X = pd.DataFrame(X).astype(str) + tabpfn.predict(X) diff --git a/tabpfn_client/tests/unit/test_tabpfn_regressor.py b/tabpfn_client/tests/unit/test_tabpfn_regressor.py index 44dbafc..1689b0e 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_regressor.py +++ b/tabpfn_client/tests/unit/test_tabpfn_regressor.py @@ -14,6 +14,7 @@ from tabpfn_client.constants import CACHE_DIR from tabpfn_client import config import json +import pandas as pd class TestTabPFNRegressorInit(unittest.TestCase): @@ -46,7 +47,7 @@ def test_init_remote_regressor( # mock server connection mock_server.router.get(mock_server.endpoints.root.path).respond(200) - mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( + mock_server.router.post(mock_server.endpoints.fit.path).respond( 200, json={"train_set_uid": "5"} ) mock_server.router.get( @@ -177,6 +178,140 @@ def test_decline_terms_and_cond(self, mock_server, mock_prompt_for_terms_and_con self.assertRaises(RuntimeError, init, use_server=True) self.assertTrue(mock_prompt_for_terms_and_cond.called) + @with_mock_server() + @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") + @patch( + "tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", + return_value=True, + ) + def test_cache_based_on_paper_version( + self, mock_server, 
+    ):
+        mock_prompt_and_set_token.side_effect = (
+            lambda: UserAuthenticationClient.set_token(self.dummy_token)
+        )
+
+        # mock server connection
+        mock_server.router.get(mock_server.endpoints.root.path).respond(200)
+        fit_route = mock_server.router.post(mock_server.endpoints.fit.path)
+        fit_route.respond(200, json={"train_set_uid": "5"})
+
+        mock_server.router.get(
+            mock_server.endpoints.retrieve_greeting_messages.path
+        ).respond(200, json={"messages": []})
+
+        mock_predict_response = {
+            "mean": [100, 200, 300],
+            "median": [110, 210, 310],
+            "mode": [120, 220, 320],
+        }
+        predict_route = mock_server.router.post(mock_server.endpoints.predict.path)
+        predict_route.respond(
+            200,
+            content=f'data: {json.dumps({"event": "result", "data": {"regression": mock_predict_response, "test_set_uid": "6"}})}\n\n',
+            headers={"Content-Type": "text/event-stream"},
+        )
+
+        init(use_server=True)
+
+        X = np.random.rand(10, 5)
+        y = np.random.rand(10)
+        test_X = np.random.rand(5, 5)
+
+        # Initialize with paper_version=True
+        tabpfn_true = TabPFNRegressor(paper_version=True)
+
+        tabpfn_true.fit(X, y)
+        tabpfn_true.predict(test_X)
+
+        # Call fit and predict again with the same paper_version
+        tabpfn_true.fit(X, y)
+        tabpfn_true.predict(test_X)
+
+        # Ensure fit endpoint is not called again
+        self.assertEqual(
+            fit_route.call_count,
+            1,
+            "Fit endpoint should not be called again with the same paper_version",
+        )
+
+        # Initialize with paper_version=False
+        tabpfn_false = TabPFNRegressor(paper_version=False)
+
+        tabpfn_false.fit(X, y)
+        tabpfn_false.predict(test_X)
+
+        # check fit is called
+        self.assertEqual(
+            fit_route.call_count,
+            2,
+            "Fit endpoint should be called again with a different paper_version",
+        )
+
+        # Call fit and predict again with the same paper_version
+        tabpfn_false.fit(X, y)
+        tabpfn_false.predict(test_X)
+
+        # Ensure fit endpoint is not called again
+        self.assertEqual(
+            fit_route.call_count,
+            2,
+            "Fit endpoint should not be called again with the same paper_version",
+        )
+
+        # TODO: fix this
+        # # Check that different cache entries are created for training set
+        # cache_manager = ServiceClient.dataset_uid_cache_manager
+        # X_serialized = common_utils.serialize_to_csv_formatted_bytes(X)
+        # y_serialized = common_utils.serialize_to_csv_formatted_bytes(y)
+        # uid_true_train, hash_true_train = cache_manager.get_dataset_uid(
+        #     X_serialized, y_serialized, self.dummy_token, "_".join([])
+        # )
+        # uid_false_train, hash_false_train = cache_manager.get_dataset_uid(
+        #     X_serialized,
+        #     y_serialized,
+        #     self.dummy_token,
+        #     "_".join(["preprocessing", "text"]),
+        # )
+
+        # self.assertNotEqual(
+        #     hash_true_train,
+        #     hash_false_train,
+        #     "Cache hash should differ based on paper_version for training set",
+        # )
+
+        # # Check that different cache entries are created for test set
+        # test_X_serialized = common_utils.serialize_to_csv_formatted_bytes(test_X)
+        # uid_true_test, hash_true_test = cache_manager.get_dataset_uid(
+        #     test_X_serialized, uid_true_train, self.dummy_token, "_".join([])
+        # )
+        # uid_false_test, hash_false_test = cache_manager.get_dataset_uid(
+        #     test_X_serialized,
+        #     uid_false_train,
+        #     self.dummy_token,
+        #     "_".join(["preprocessing", "text"]),
+        # )
+
+        # self.assertNotEqual(
+        #     hash_true_test,
+        #     hash_false_test,
+        #     "Cache hash should differ based on paper_version for test set",
+        # )
+
+        # # Verify that the cache entries are used correctly
+        # self.assertIsNotNone(
+        #     uid_true_train, "Training set cache should be used for paper_version=True"
+        # )
+        # self.assertIsNotNone(
+        #     uid_false_train, "Training set cache should be used for paper_version=False"
+        # )
+        # self.assertIsNotNone(
+        #     uid_true_test, "Test set cache should be used for paper_version=True"
+        # )
+        # self.assertIsNotNone(
+        #     uid_false_test, "Test set cache should be used for paper_version=False"
+        # )
+
 
 class TestTabPFNRegressorInference(unittest.TestCase):
     def setUp(self):
@@ -306,3 +441,57 @@ def test_predict_uses_correct_model_path(self):
             self.assertEqual(
                 predict_kwargs["config"]["model_path"], expected_model_path
             )
+
+    @patch.object(InferenceClient, "fit", return_value="dummy_uid")
+    @patch.object(InferenceClient, "predict", return_value={"mean": np.random.rand(10)})
+    def test_paper_version_behavior(self, mock_predict, mock_fit):
+        # this just tests that it doesn't break,
+        # but the actual behavior is easier to test
+        # on the server side
+        X = np.random.rand(10, 5)
+        y = np.random.rand(10)
+        test_X = np.random.rand(5, 5)
+
+        # Test with paper_version=True
+        tabpfn_true = TabPFNRegressor(paper_version=True)
+        tabpfn_true.fit(X, y)
+        y_pred_true = tabpfn_true.predict(test_X)
+        self.assertIsNotNone(y_pred_true)
+
+        # Test with paper_version=False
+        tabpfn_false = TabPFNRegressor(paper_version=False)
+        tabpfn_false.fit(X, y)
+        y_pred_false = tabpfn_false.predict(test_X)
+        self.assertIsNotNone(y_pred_false)
+
+    @patch.object(InferenceClient, "fit", return_value="dummy_uid")
+    @patch.object(InferenceClient, "predict", return_value={"mean": np.random.rand(10)})
+    def test_check_paper_version_with_non_numerical_data_raises_error(
+        self, mock_predict, mock_fit
+    ):
+        # Create a TabPFNRegressor with paper_version=True
+        tabpfn = TabPFNRegressor(paper_version=True)
+
+        # Create non-numerical data
+        X = pd.DataFrame({"feature1": ["a", "b", "c"], "feature2": ["d", "e", "f"]})
+        y = np.array([0.1, 0.2, 0.3])
+
+        with self.assertRaises(ValueError) as context:
+            tabpfn.fit(X, y)
+
+        self.assertIn(
+            "X must be numerical to use the paper version of the model",
+            str(context.exception),
+        )
+
+        # check that it works with paper_version=False
+        tabpfn = TabPFNRegressor(paper_version=False)
+        tabpfn.fit(X, y)
+
+        # check that paper_version=True works with numerical data even with the wrong type
+        X = np.random.rand(10, 5).astype(str)
+        y = np.random.rand(10)  # Continuous target for regression
+        tabpfn = TabPFNRegressor(paper_version=True)
+        tabpfn.fit(X, y)
+        X = pd.DataFrame(X).astype(str)
+        tabpfn.predict(X)
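
Usage sketch (not part of the patch). The snippet below is a hedged illustration of how the paper_version flag introduced in this diff is meant to be used from the client; the random data, variable names, and printouts are assumptions made for the example, not taken from the patch.

# Illustrative sketch only: assumes an initialized tabpfn_client session and made-up data.
import numpy as np
from tabpfn_client import init
from tabpfn_client.estimator import TabPFNClassifier

init(use_server=True)

X_train = np.random.rand(100, 5)          # numerical features
y_train = np.random.randint(0, 2, 100)    # binary labels
X_test = np.random.rand(20, 5)

# paper_version=True selects the model described in the paper; the client then sends
# an empty tabpfn_systems list, so X must be convertible to numerical values.
clf_paper = TabPFNClassifier(paper_version=True)
clf_paper.fit(X_train, y_train)
print(clf_paper.predict_proba(X_test))

# paper_version=False (the default) keeps the "preprocessing" and "text" systems
# enabled, so non-numerical inputs such as string-valued DataFrames are also accepted.
clf_default = TabPFNClassifier(paper_version=False)
clf_default.fit(X_train, y_train)
print(clf_default.predict(X_test))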