diff --git a/.gitignore b/.gitignore index 596ba0b..119b94a 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,4 @@ cython_debug/ .idea/ models_diff/ +tabpfn_client/.tabpfn/ diff --git a/tabpfn_client/.tabpfn/config b/tabpfn_client/.tabpfn/config deleted file mode 100644 index a573f8f..0000000 --- a/tabpfn_client/.tabpfn/config +++ /dev/null @@ -1 +0,0 @@ -eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiNjUyNDc4Y2QtMWVkOS00MDdhLTk3MDQtOGQxZTA4MWI2MzBiIiwiZXhwIjoxNzEwMTY3MTAyfQ.bIe0Gan4OhWSKPXoPOxz8z9syxVklZYlQmwrs3eUbK0 \ No newline at end of file diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index ab9c96b..4800f8b 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -293,18 +293,18 @@ def get_password_policy(self) -> {}: return response.json()["requirements"] - def retrieve_messages(self) -> list[str]: + def retrieve_greeting_messages(self) -> list[str]: """ - Retrieve messages that are new for the user. + Retrieve greeting messages that are new for the user. """ - response = self.httpx_client.get(self.server_endpoints.retrieve_messages.path) + response = self.httpx_client.get(self.server_endpoints.retrieve_greeting_messages.path) - self.error_raising(response, "retrieve_messages", only_version_check=True) + self.error_raising(response, "retrieve_greeting_messages", only_version_check=True) if response.status_code != 200: return [] - messages = response.json()["messages"] - return messages + greeting_messages = response.json()["messages"] + return greeting_messages def get_data_summary(self) -> {}: """ diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index dc4fd7c..b88de66 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -100,8 +100,8 @@ def prompt_reusing_existing_token(cls): print(cls.indent(prompt)) @classmethod - def prompt_retrieved_messages(cls, messages: list[str]): - for message in messages: + def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]): + for message in greeting_messages: print(cls.indent(message)) diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index 9c46772..68efae2 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -28,10 +28,10 @@ endpoints: methods: [ "POST" ] description: "User login" - retrieve_messages: - path: "/retrieve_messages/" + retrieve_greeting_messages: + path: "/retrieve_greeting_messages/" methods: [ "GET" ] - description: "Retrieve new messages" + description: "Retrieve new greeting messages" protected_root: path: "/protected/" diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index 816a616..e3d9e47 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -85,8 +85,8 @@ def _reset_token(self): self.service_client.reset_authorization() self.CACHED_TOKEN_FILE.unlink(missing_ok=True) - def retrieve_messages(self): - return self.service_client.retrieve_messages() + def retrieve_greeting_messages(self): + return self.service_client.retrieve_greeting_messages() class UserDataClient(ServiceClientWrapper): diff --git a/tabpfn_client/tabpfn_classifier.py b/tabpfn_client/tabpfn_classifier.py index a7c6d4f..beb3db9 100644 --- a/tabpfn_client/tabpfn_classifier.py +++ b/tabpfn_client/tabpfn_classifier.py @@ -52,7 +52,7 @@ def init(use_server=True): # prompt for login / register PromptAgent.prompt_and_set_token(user_auth_handler) - PromptAgent.prompt_retrieved_messages(user_auth_handler.retrieve_messages()) + PromptAgent.prompt_retrieved_greeting_messages(user_auth_handler.retrieve_greeting_messages()) g_tabpfn_config.use_server = True g_tabpfn_config.user_auth_handler = user_auth_handler diff --git a/tabpfn_client/tests/integration/test_tabpfn_classifier.py b/tabpfn_client/tests/integration/test_tabpfn_classifier.py index 62183ef..1fe4dbf 100644 --- a/tabpfn_client/tests/integration/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/integration/test_tabpfn_classifier.py @@ -38,7 +38,7 @@ def test_use_remote_tabpfn_classifier(self, mock_server): # mock connection and authentication mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) - mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond( + mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( 200, json={"messages": []}) tabpfn_classifier.init(use_server=True) diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py index 172ab3b..2386d04 100644 --- a/tabpfn_client/tests/unit/test_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -72,10 +72,10 @@ def test_valid_auth_token(self, mock_server): self.assertTrue(self.client.try_authenticate("true_token")) @with_mock_server() - def test_retrieve_messages(self, mock_server): - mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond( + def test_retrieve_greeting_messages(self, mock_server): + mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( 200, json={"messages": ["message_1", "message_2"]}) - self.assertEqual(self.client.retrieve_messages(), ["message_1", "message_2"]) + self.assertEqual(self.client.retrieve_greeting_messages(), ["message_1", "message_2"]) @with_mock_server() def test_predict_with_valid_train_set_and_test_set(self, mock_server): diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index 22819ae..ccf8a88 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -36,11 +36,7 @@ def tearDown(self): def test_init_local_classifier(self): tabpfn_classifier.init(use_server=False) -<<<<<<< HEAD tabpfn = TabPFNClassifier(model="tabpfn_1_local").fit(self.X_train, self.y_train) -======= - tabpfn = TabPFNClassifier(model="public_tabpfn_hosted").fit(self.X_train, self.y_train) ->>>>>>> ae59bf5 (Fix: Test Cases in Client and Add Model to TabPFN Classifier) self.assertTrue(isinstance(tabpfn.classifier_, LocalTabPFNClassifier)) @with_mock_server() @@ -56,7 +52,7 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond( 200, json={"train_set_uid": 5} ) - mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond( + mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( 200, json={"messages": []}) tabpfn_classifier.init(use_server=True) @@ -70,7 +66,7 @@ def test_reuse_saved_access_token(self, mock_server): # mock connection and authentication mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) - mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond( + mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( 200, json={"messages": []}) # create dummy token file @@ -118,7 +114,7 @@ def test_reset_on_remote_classifier(self, mock_server): # init classifier as usual mock_server.router.get(mock_server.endpoints.root.path).respond(200) mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200) - mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond( + mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( 200, json={"messages": []}) tabpfn_classifier.init(use_server=True)