Skip to content

Commit

Permalink
Removed Login after Registration and Additional Info Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
anshulg954 committed May 8, 2024
1 parent 5be3d36 commit 66a1c97
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 65 deletions.
21 changes: 0 additions & 21 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,27 +348,6 @@ def get_password_policy(self) -> {}:

return response.json()["requirements"]

def add_user_information(
self, company: str | None, role: str | None, use_case: str | None, contact_via_email: bool
):
"""
Send additional user information to the server.
"""
information = {"contact_via_email": contact_via_email}
if company:
information["company"] = company
if role:
information["role"] = role
if use_case:
information["use_case"] = use_case

response = self.httpx_client.post(
self.server_endpoints.add_user_information.path,
json=information
)

self._validate_response(response, "add_user_information")

def retrieve_greeting_messages(self) -> list[str]:
"""
Retrieve greeting messages that are new for the user.
Expand Down
5 changes: 0 additions & 5 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ def try_reuse_existing_token(self) -> bool:
def get_password_policy(self):
return self.service_client.get_password_policy()

def add_user_information(
self, company: str | None, role: str | None, use_case: str | None, contact_via_email: bool
):
self.service_client.add_user_information(company, role, use_case, contact_via_email)

def reset_cache(self):
self._reset_token()

Expand Down
40 changes: 24 additions & 16 deletions tabpfn_client/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,42 @@ def test_validate_email_invalid(self, mock_server):
@with_mock_server()
def test_register_user(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(200, json={"message": "dummy_message"})
self.assertTrue(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation")[0])
self.assertTrue(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", {
"company": "dummy_company",
"use_case": "dummy_usecase",
"role": "dummy_role",
"contact_via_email": False
})[0])

@with_mock_server()
def test_register_user_with_invalid_email(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={"detail": "dummy_message"})
self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation")[0])
self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", {
"company": "dummy_company",
"use_case": "dummy_usecase",
"role": "dummy_role",
"contact_via_email": False
})[0])

@with_mock_server()
def test_register_user_with_invalid_validation_link(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={"detail": "dummy_message"})
self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation")[0])
self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", {
"company": "dummy_company",
"use_case": "dummy_usecase",
"role": "dummy_role",
"contact_via_email": False
})[0])

@with_mock_server()
def test_register_user_with_limit_reached(self, mock_server):
mock_server.router.post(mock_server.endpoints.register.path).respond(401, json={"detail": "dummy_message"})
self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation")[0])
self.assertFalse(self.client.register("dummy_email", "dummy_password", "dummy_password", "dummy_validation", {
"company": "dummy_company",
"use_case": "dummy_usecase",
"role": "dummy_role",
"contact_via_email": False
})[0])

@with_mock_server()
def test_invalid_auth_token(self, mock_server):
Expand Down Expand Up @@ -101,18 +121,6 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server):
)
self.assertTrue(np.array_equal(pred, dummy_result["y_pred"]))

@with_mock_server()
def test_add_user_information(self, mock_server):
mock_server.router.post(mock_server.endpoints.add_user_information.path).respond(200)
self.assertIsNone(self.client.add_user_information(
"company", "dev", "", True))

@with_mock_server()
def test_add_user_information_raises_runtime_error(self, mock_server):
mock_server.router.post(mock_server.endpoints.add_user_information.path).respond(500)
with self.assertRaises(RuntimeError):
self.client.add_user_information("company", "dev", "", True)

def test_validate_response_no_error(self):
response = Mock()
response.status_code = 200
Expand Down
29 changes: 6 additions & 23 deletions tabpfn_client/tests/unit/test_service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,28 +72,6 @@ def test_try_reusing_non_existing_token(self):
# assert token is not set
self.assertIsNone(ServiceClient().access_token)

@with_mock_server()
def test_set_token_by_valid_registration(self, mock_server):
# mock valid registration response, and valid login response
dummy_token = "dummy_token"
mock_server.router.post(mock_server.endpoints.register.path).respond(
200,
json={"message": "doesn't matter"}
)
mock_server.router.post(mock_server.endpoints.login.path).respond(
200,
json={"access_token": dummy_token}
)

self.assertTrue(
UserAuthenticationClient(ServiceClient()).set_token_by_registration(
"dummy_email", "dummy_password", "dummy_password", "dummy_validation"
)[0]
)

# assert token is set
self.assertEqual(dummy_token, ServiceClient().access_token)

@with_mock_server()
def test_set_token_by_invalid_registration(self, mock_server):
# mock invalid registration response
Expand All @@ -103,7 +81,12 @@ def test_set_token_by_invalid_registration(self, mock_server):
(False, "Password mismatch"),
UserAuthenticationClient(ServiceClient()).set_token_by_registration(
"dummy_email", "dummy_password", "dummy_password",
"dummy_validation")
"dummy_validation", {
"company": "dummy_company",
"use_case": "dummy_usecase",
"role": "dummy_role",
"contact_via_email": False
})
)

# assert token is not set
Expand Down

0 comments on commit 66a1c97

Please sign in to comment.