Skip to content

Commit

Permalink
Fix mistake during rebase and update naming of greeting message retri…
Browse files Browse the repository at this point in the history
…eval
  • Loading branch information
davidotte committed Mar 9, 2024
1 parent 2dd9fec commit 68e6d5f
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 26 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,4 @@ cython_debug/
.idea/

models_diff/
tabpfn_client/.tabpfn/
1 change: 0 additions & 1 deletion tabpfn_client/.tabpfn/config

This file was deleted.

12 changes: 6 additions & 6 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,18 +293,18 @@ def get_password_policy(self) -> {}:

return response.json()["requirements"]

def retrieve_messages(self) -> list[str]:
def retrieve_greeting_messages(self) -> list[str]:
"""
Retrieve messages that are new for the user.
Retrieve greeting messages that are new for the user.
"""
response = self.httpx_client.get(self.server_endpoints.retrieve_messages.path)
response = self.httpx_client.get(self.server_endpoints.retrieve_greeting_messages.path)

self.error_raising(response, "retrieve_messages", only_version_check=True)
self.error_raising(response, "retrieve_greeting_messages", only_version_check=True)
if response.status_code != 200:
return []

messages = response.json()["messages"]
return messages
greeting_messages = response.json()["messages"]
return greeting_messages

def get_data_summary(self) -> {}:
"""
Expand Down
4 changes: 2 additions & 2 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def prompt_reusing_existing_token(cls):
print(cls.indent(prompt))

@classmethod
def prompt_retrieved_messages(cls, messages: list[str]):
for message in messages:
def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]):
for message in greeting_messages:
print(cls.indent(message))


Expand Down
6 changes: 3 additions & 3 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ endpoints:
methods: [ "POST" ]
description: "User login"

retrieve_messages:
path: "/retrieve_messages/"
retrieve_greeting_messages:
path: "/retrieve_greeting_messages/"
methods: [ "GET" ]
description: "Retrieve new messages"
description: "Retrieve new greeting messages"

protected_root:
path: "/protected/"
Expand Down
4 changes: 2 additions & 2 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def _reset_token(self):
self.service_client.reset_authorization()
self.CACHED_TOKEN_FILE.unlink(missing_ok=True)

def retrieve_messages(self):
return self.service_client.retrieve_messages()
def retrieve_greeting_messages(self):
return self.service_client.retrieve_greeting_messages()


class UserDataClient(ServiceClientWrapper):
Expand Down
2 changes: 1 addition & 1 deletion tabpfn_client/tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def init(use_server=True):
# prompt for login / register
PromptAgent.prompt_and_set_token(user_auth_handler)

PromptAgent.prompt_retrieved_messages(user_auth_handler.retrieve_messages())
PromptAgent.prompt_retrieved_greeting_messages(user_auth_handler.retrieve_greeting_messages())

g_tabpfn_config.use_server = True
g_tabpfn_config.user_auth_handler = user_auth_handler
Expand Down
2 changes: 1 addition & 1 deletion tabpfn_client/tests/integration/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_use_remote_tabpfn_classifier(self, mock_server):
# mock connection and authentication
mock_server.router.get(mock_server.endpoints.root.path).respond(200)
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond(
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})
tabpfn_classifier.init(use_server=True)

Expand Down
6 changes: 3 additions & 3 deletions tabpfn_client/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def test_valid_auth_token(self, mock_server):
self.assertTrue(self.client.try_authenticate("true_token"))

@with_mock_server()
def test_retrieve_messages(self, mock_server):
mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond(
def test_retrieve_greeting_messages(self, mock_server):
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": ["message_1", "message_2"]})
self.assertEqual(self.client.retrieve_messages(), ["message_1", "message_2"])
self.assertEqual(self.client.retrieve_greeting_messages(), ["message_1", "message_2"])

@with_mock_server()
def test_predict_with_valid_train_set_and_test_set(self, mock_server):
Expand Down
10 changes: 3 additions & 7 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ def tearDown(self):

def test_init_local_classifier(self):
tabpfn_classifier.init(use_server=False)
<<<<<<< HEAD
tabpfn = TabPFNClassifier(model="tabpfn_1_local").fit(self.X_train, self.y_train)
=======
tabpfn = TabPFNClassifier(model="public_tabpfn_hosted").fit(self.X_train, self.y_train)
>>>>>>> ae59bf5 (Fix: Test Cases in Client and Add Model to TabPFN Classifier)
self.assertTrue(isinstance(tabpfn.classifier_, LocalTabPFNClassifier))

@with_mock_server()
Expand All @@ -56,7 +52,7 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con
mock_server.router.post(mock_server.endpoints.upload_train_set.path).respond(
200, json={"train_set_uid": 5}
)
mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond(
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})

tabpfn_classifier.init(use_server=True)
Expand All @@ -70,7 +66,7 @@ def test_reuse_saved_access_token(self, mock_server):
# mock connection and authentication
mock_server.router.get(mock_server.endpoints.root.path).respond(200)
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond(
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})

# create dummy token file
Expand Down Expand Up @@ -118,7 +114,7 @@ def test_reset_on_remote_classifier(self, mock_server):
# init classifier as usual
mock_server.router.get(mock_server.endpoints.root.path).respond(200)
mock_server.router.get(mock_server.endpoints.protected_root.path).respond(200)
mock_server.router.get(mock_server.endpoints.retrieve_messages.path).respond(
mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
200, json={"messages": []})
tabpfn_classifier.init(use_server=True)

Expand Down

0 comments on commit 68e6d5f

Please sign in to comment.