From 66a1c97c244bd046d72d5dae8e0506bfcf2af94d Mon Sep 17 00:00:00 2001 From: Anshul Gupta Date: Wed, 8 May 2024 15:25:14 +0200 Subject: [PATCH] Removed Login after Registration and Additional Info Tests --- tabpfn_client/client.py | 21 ---------- tabpfn_client/service_wrapper.py | 5 --- tabpfn_client/tests/unit/test_client.py | 40 +++++++++++-------- .../tests/unit/test_service_wrapper.py | 29 +++----------- 4 files changed, 30 insertions(+), 65 deletions(-) diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index ce2c661..dfd1d83 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -348,27 +348,6 @@ def get_password_policy(self) -> {}: return response.json()["requirements"] - def add_user_information( - self, company: str | None, role: str | None, use_case: str | None, contact_via_email: bool - ): - """ - Send additional user information to the server. - """ - information = {"contact_via_email": contact_via_email} - if company: - information["company"] = company - if role: - information["role"] = role - if use_case: - information["use_case"] = use_case - - response = self.httpx_client.post( - self.server_endpoints.add_user_information.path, - json=information - ) - - self._validate_response(response, "add_user_information") - def retrieve_greeting_messages(self) -> list[str]: """ Retrieve greeting messages that are new for the user. diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index 6c46d19..4a2a753 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -78,11 +78,6 @@ def try_reuse_existing_token(self) -> bool: def get_password_policy(self): return self.service_client.get_password_policy() - def add_user_information( - self, company: str | None, role: str | None, use_case: str | None, contact_via_email: bool - ): - self.service_client.add_user_information(company, role, use_case, contact_via_email) - def reset_cache(self): self._reset_token() diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py index ab2a850..65553a2 100644 --- a/tabpfn_client/tests/unit/test_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -50,22 +50,42 @@ def test_validate_email_invalid(self, mock_server): @with_mock_server() def test_register_user(self, mock_server): mock_server.router.post(mock_server.endpoints.register.path).respond(200, json={"message": "dummy_message"}) - self.assertTrue(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation")[0]) + self.assertTrue(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False + })[0]) @with_mock_server() def test_register_user_with_invalid_email(self, mock_server): mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={"detail": "dummy_message"}) - self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation")[0]) + self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False + })[0]) @with_mock_server() def test_register_user_with_invalid_validation_link(self, mock_server): mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={"detail": "dummy_message"}) - self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation")[0]) + self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False + })[0]) @with_mock_server() def test_register_user_with_limit_reached(self, mock_server): mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={"detail": "dummy_message"}) - self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation")[0]) + self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False + })[0]) @with_mock_server() def test_invalid_auth_token(self, mock_server): @@ -101,18 +121,6 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server): ) self.assertTrue(np.array_equal(pred, dummy_result["y_pred"])) - @with_mock_server() - def test_add_user_information(self, mock_server): - mock_server.router.post(mock_server.endpoints.add_user_information.path).respond(200) - self.assertIsNone(self.client.add_user_information( - "company", "dev", "", True)) - - @with_mock_server() - def test_add_user_information_raises_runtime_error(self, mock_server): - mock_server.router.post(mock_server.endpoints.add_user_information.path).respond(500) - with self.assertRaises(RuntimeError): - self.client.add_user_information("company", "dev", "", True) - def test_validate_response_no_error(self): response = Mock() response.status_code = 200 diff --git a/tabpfn_client/tests/unit/test_service_wrapper.py b/tabpfn_client/tests/unit/test_service_wrapper.py index 66aa76a..ee44800 100644 --- a/tabpfn_client/tests/unit/test_service_wrapper.py +++ b/tabpfn_client/tests/unit/test_service_wrapper.py @@ -72,28 +72,6 @@ def test_try_reusing_non_existing_token(self): # assert token is not set self.assertIsNone(ServiceClient().access_token) - @with_mock_server() - def test_set_token_by_valid_registration(self, mock_server): - # mock valid registration response, and valid login response - dummy_token = "dummy_token" - mock_server.router.post(mock_server.endpoints.register.path).respond( - 200, - json={"message": "doesn't matter"} - ) - mock_server.router.post(mock_server.endpoints.login.path).respond( - 200, - json={"access_token": dummy_token} - ) - - self.assertTrue( - UserAuthenticationClient(ServiceClient()).set_token_by_registration( - "dummy_email", "dummy_password", "dummy_password", "dummy_validation" - )[0] - ) - - # assert token is set - self.assertEqual(dummy_token, ServiceClient().access_token) - @with_mock_server() def test_set_token_by_invalid_registration(self, mock_server): # mock invalid registration response @@ -103,7 +81,12 @@ def test_set_token_by_invalid_registration(self, mock_server): (False, "Password mismatch"), UserAuthenticationClient(ServiceClient()).set_token_by_registration( "dummy_email", "dummy_password", "dummy_password", - "dummy_validation") + "dummy_validation", { + "company": "dummy_company", + "use_case": "dummy_usecase", + "role": "dummy_role", + "contact_via_email": False + }) ) # assert token is not set