Skip to content

Commit

Permalink
Issue/160 2 user (#162)
Browse files Browse the repository at this point in the history
* TDD: add failing test

* lint

* fix by inserting using on_conflict_do_nothing
  • Loading branch information
peterdudfield authored Oct 1, 2024
1 parent 92503cf commit ceaec24
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 19 deletions.
39 changes: 20 additions & 19 deletions pvsite_datamodel/read/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from typing import List, Optional

from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session, contains_eager

from pvsite_datamodel.sqlmodels import APIRequestSQL, SiteGroupSQL, UserSQL
Expand All @@ -27,21 +28,16 @@ def get_user_by_email(session: Session, email: str, make_new_user_if_none: bool
logger.info(f"User with email {email} not found, so making one")

# checking for site_group
site_group = (
session.query(SiteGroupSQL)
.filter(SiteGroupSQL.site_group_name == f"site_group_for_{email}")
.first()
)
# making a new site group if one doesn't exist
if site_group is None:
site_group = SiteGroupSQL(site_group_name=f"site_group_for_{email}")
session.add(site_group)
session.commit()
site_group_name = f"site_group_for_{email}"
site_group = get_site_group_by_name(session=session, site_group_name=site_group_name)

# make a new user
user = UserSQL(email=email, site_group_uuid=site_group.site_group_uuid)
session.add(user)
session.commit()
stmt = postgresql.insert(UserSQL.__table__)
stmt = stmt.on_conflict_do_nothing()
session.execute(stmt, [{"site_group_uuid": site_group.site_group_uuid, "email": email}])

# get a new user
user = session.query(UserSQL).filter(UserSQL.email == email).first()

return user

Expand All @@ -62,7 +58,7 @@ def get_all_users(session: Session) -> List[UserSQL]:
return users


def get_site_group_by_name(session: Session, site_group_name: str):
def get_site_group_by_name(session: Session, site_group_name: str, create_if_none: bool = True):
"""
Get site group by name. If site group does not exist, make one.
Expand All @@ -75,13 +71,18 @@ def get_site_group_by_name(session: Session, site_group_name: str):
session.query(SiteGroupSQL).filter(SiteGroupSQL.site_group_name == site_group_name).first()
)

if site_group is None:
if (site_group is None) and (create_if_none is True):
logger.info(f"Site group with name {site_group_name} not found, so making one")

# make a new site group
site_group = SiteGroupSQL(site_group_name=site_group_name)
session.add(site_group)
session.commit()
stmt = postgresql.insert(SiteGroupSQL.__table__)
stmt = stmt.on_conflict_do_nothing()
session.execute(stmt, [{"site_group_name": site_group_name}])

site_group = (
session.query(SiteGroupSQL)
.filter(SiteGroupSQL.site_group_name == site_group_name)
.first()
)

return site_group

Expand Down
14 changes: 14 additions & 0 deletions tests/read/test_get_user_by_email.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from pvsite_datamodel import SiteGroupSQL, UserSQL
from pvsite_datamodel.read import get_user_by_email
from pvsite_datamodel.write.user_and_site import create_site_group, create_user
Expand Down Expand Up @@ -36,3 +39,14 @@ def test_get_user_by_email_no_user_maker_user_false(self, db_session):
)
assert user is None
assert len(db_session.query(UserSQL).all()) == 0

def test_make_user_db_twice(self, db_session):
get_user_by_email_partial = partial(get_user_by_email, db_session, "[email protected]")
tasks = [get_user_by_email_partial for _ in range(5)]

with ThreadPoolExecutor() as executor:
running_tasks = [executor.submit(task) for task in tasks]
for running_task in running_tasks:
running_task.result()

assert len(db_session.query(UserSQL).all()) == 1

0 comments on commit ceaec24

Please sign in to comment.