Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: LCFS - Reformat In-App Notifications tied to Email Subscriptions into AG Grid #1464

Merged
merged 13 commits into from
Dec 17, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""update notification message model

Revision ID: f93546eaec61
Revises: 5d729face5ab
Create Date: 2024-12-17 11:23:19.563138

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "f93546eaec61"
down_revision = "5d729face5ab"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("notification_message", sa.Column("type", sa.Text(), nullable=False))
op.add_column(
"notification_message",
sa.Column("related_transaction_id", sa.Text(), nullable=False),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("notification_message", "related_transaction_id")
op.drop_column("notification_message", "type")
# ### end Alembic commands ###
8 changes: 3 additions & 5 deletions backend/lcfs/db/models/notification/NotificationMessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class NotificationMessage(BaseModel, Auditable):
is_warning = Column(Boolean, default=False)
is_error = Column(Boolean, default=False)
is_archived = Column(Boolean, default=False)
type = Column(Text, nullable=False)
message = Column(Text, nullable=False)

related_organization_id = Column(
Expand All @@ -32,12 +33,9 @@ class NotificationMessage(BaseModel, Auditable):
notification_type_id = Column(
Integer, ForeignKey("notification_type.notification_type_id")
)
related_transaction_id = Column(Text, nullable=False)

# Models not created yet
# related_transaction_id = Column(Integer,ForeignKey(''))
# related_document_id = Column(Integer, ForeignKey('document.id'))
# related_report_id = Column(Integer, ForeignKey('compliance_report.id'))

# Relationships
related_organization = relationship(
"Organization", back_populates="notification_messages"
)
Expand Down
27 changes: 17 additions & 10 deletions backend/lcfs/tests/compliance_report/test_update_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def mock_user_has_roles():
def mock_notification_service():
mock_service = AsyncMock(spec=NotificationService)
with patch(
"lcfs.web.api.compliance_report.update_service.Depends",
return_value=mock_service
"lcfs.web.api.compliance_report.update_service.Depends",
return_value=mock_service,
):
yield mock_service

Expand All @@ -47,6 +47,7 @@ def mock_environment_vars():
mock_settings.ches_sender_name = "Mock Notification System"
yield mock_settings


# Mock for adjust_balance method within the OrganizationsService
@pytest.fixture
def mock_org_service():
Expand All @@ -66,6 +67,8 @@ async def test_update_compliance_report_status_change(
mock_report.compliance_report_id = report_id
mock_report.current_status = MagicMock(spec=ComplianceReportStatus)
mock_report.current_status.status = ComplianceReportStatusEnum.Draft
mock_report.compliance_period = MagicMock()
mock_report.compliance_period.description = "2024"

new_status = MagicMock(spec=ComplianceReportStatus)
new_status.status = ComplianceReportStatusEnum.Submitted
Expand All @@ -78,8 +81,8 @@ async def test_update_compliance_report_status_change(
mock_repo.get_compliance_report_by_id.return_value = mock_report
mock_repo.get_compliance_report_status_by_desc.return_value = new_status
compliance_report_update_service.handle_status_change = AsyncMock()
compliance_report_update_service.notfn_service = mock_notification_service
mock_repo.update_compliance_report.return_value = mock_report
compliance_report_update_service._perform_notificaiton_call = AsyncMock()

# Call the method
updated_report = await compliance_report_update_service.update_compliance_report(
Expand All @@ -101,10 +104,9 @@ async def test_update_compliance_report_status_change(
mock_report, compliance_report_update_service.request.user
)
mock_repo.update_compliance_report.assert_called_once_with(mock_report)

assert mock_report.current_status == new_status
assert mock_report.supplemental_note == report_data.supplemental_note
mock_notification_service.send_notification.assert_called_once()
compliance_report_update_service._perform_notificaiton_call.assert_called_once_with(
mock_report, "Submitted"
)


@pytest.mark.anyio
Expand All @@ -118,6 +120,10 @@ async def test_update_compliance_report_no_status_change(
mock_report.current_status = MagicMock(spec=ComplianceReportStatus)
mock_report.current_status.status = ComplianceReportStatusEnum.Draft

# Fix for JSON serialization
mock_report.compliance_period = MagicMock()
mock_report.compliance_period.description = "2024"

report_data = ComplianceReportUpdateSchema(
status="Draft", supplemental_note="Test note"
)
Expand All @@ -131,6 +137,7 @@ async def test_update_compliance_report_no_status_change(

# Mock the handle_status_change method
compliance_report_update_service.handle_status_change = AsyncMock()
compliance_report_update_service._perform_notificaiton_call = AsyncMock()

# Call the method
updated_report = await compliance_report_update_service.update_compliance_report(
Expand All @@ -148,9 +155,9 @@ async def test_update_compliance_report_no_status_change(
compliance_report_update_service.handle_status_change.assert_not_called()
mock_repo.add_compliance_report_history.assert_not_called()
mock_repo.update_compliance_report.assert_called_once_with(mock_report)

assert mock_report.current_status == mock_report.current_status
assert mock_report.supplemental_note == report_data.supplemental_note
compliance_report_update_service._perform_notificaiton_call.assert_called_once_with(
mock_report, "Draft"
)


@pytest.mark.anyio
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,23 @@ async def test_get_initiative_agreement(service, mock_repo):
mock_repo.get_initiative_agreement_by_id.assert_called_once_with(1)


@pytest.mark.anyio
@pytest.mark.anyio
async def test_create_initiative_agreement(service, mock_repo, mock_request):
# Mock status for the initiative agreement
mock_status = MagicMock(status=InitiativeAgreementStatusEnum.Recommended)
mock_repo.get_initiative_agreement_status_by_name.return_value = mock_status
mock_repo.create_initiative_agreement.return_value = MagicMock(
spec=InitiativeAgreement
)

# Create a mock initiative agreement with serializable fields
mock_initiative_agreement = MagicMock(spec=InitiativeAgreement)
mock_initiative_agreement.initiative_agreement_id = 1
mock_initiative_agreement.current_status.status = "Recommended"
mock_initiative_agreement.to_organization_id = 3

# Mock return value of create_initiative_agreement
mock_repo.create_initiative_agreement.return_value = mock_initiative_agreement

# Create input data
create_data = InitiativeAgreementCreateSchema(
compliance_units=150,
current_status="Recommended",
Expand All @@ -104,10 +113,18 @@ async def test_create_initiative_agreement(service, mock_repo, mock_request):
internal_comment=None,
)

# Mock _perform_notificaiton_call to isolate it
service._perform_notificaiton_call = AsyncMock()

# Call the service method
result = await service.create_initiative_agreement(create_data)

assert isinstance(result, InitiativeAgreement)
# Assertions
assert result == mock_initiative_agreement
mock_repo.create_initiative_agreement.assert_called_once()
service._perform_notificaiton_call.assert_called_once_with(
mock_initiative_agreement
)


@pytest.mark.anyio
Expand Down
107 changes: 88 additions & 19 deletions backend/lcfs/tests/notification/test_notification_repo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from lcfs.db.models.notification.NotificationChannel import ChannelEnum
from lcfs.web.api.base import NotificationTypeEnum, PaginationRequestSchema
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import delete
Expand Down Expand Up @@ -92,34 +94,21 @@ async def mock_execute(*args, **kwargs):
@pytest.mark.anyio
async def test_get_notification_messages_by_user(notification_repo, mock_db_session):
mock_notification1 = MagicMock(spec=NotificationMessage)
mock_notification1.related_user_id = 1
mock_notification1.origin_user_id = 2
mock_notification1.notification_message_id = 1
mock_notification1.message = "Test message 1"

mock_notification2 = MagicMock(spec=NotificationMessage)
mock_notification2.related_user_id = 1
mock_notification2.origin_user_id = 2
mock_notification2.notification_message_id = 2
mock_notification2.message = "Test message 2"

mock_result_chain = MagicMock()
mock_result_chain.scalars.return_value.all.return_value = [
mock_result = MagicMock()
mock_result.unique.return_value.scalars.return_value.all.return_value = [
mock_notification1,
mock_notification2,
]

async def mock_execute(*args, **kwargs):
return mock_result_chain

# Inject the mocked execute method into the session
mock_db_session.execute = mock_execute
mock_db_session.execute = AsyncMock(return_value=mock_result)

result = await notification_repo.get_notification_messages_by_user(1)

assert len(result) == 2
assert result[0].notification_message_id == 1
assert result[1].notification_message_id == 2
assert result == [mock_notification1, mock_notification2]
mock_db_session.execute.assert_called_once()


@pytest.mark.anyio
Expand Down Expand Up @@ -158,7 +147,7 @@ async def test_delete_notification_message(notification_repo, mock_db_session):
NotificationMessage.notification_message_id == notification_id
)
assert str(executed_query) == str(expected_query)

mock_db_session.execute.assert_called_once()
mock_db_session.flush.assert_called_once()


Expand Down Expand Up @@ -277,3 +266,83 @@ async def mock_execute(*args, **kwargs):

assert result is not None
assert result.notification_channel_subscription_id == subscription_id


@pytest.mark.anyio
async def test_create_notification_messages(notification_repo, mock_db_session):
messages = [
MagicMock(spec=NotificationMessage),
MagicMock(spec=NotificationMessage),
]

await notification_repo.create_notification_messages(messages)

mock_db_session.add_all.assert_called_once_with(messages)
mock_db_session.flush.assert_called_once()


@pytest.mark.anyio
async def test_mark_notifications_as_read(notification_repo, mock_db_session):
user_id = 1
notification_ids = [1, 2, 3]

mock_db_session.execute = AsyncMock()
mock_db_session.flush = AsyncMock()

result = await notification_repo.mark_notifications_as_read(
user_id, notification_ids
)

assert result == notification_ids
mock_db_session.execute.assert_called_once()
mock_db_session.flush.assert_called_once()


@pytest.mark.anyio
async def test_get_notification_type_by_name(notification_repo, mock_db_session):
# Create a mock result that properly simulates the SQLAlchemy result
mock_result = MagicMock()
mock_scalars = MagicMock()
mock_scalars.first.return_value = 123
mock_result.scalars.return_value = mock_scalars

mock_db_session.execute = AsyncMock(return_value=mock_result)

result = await notification_repo.get_notification_type_by_name("TestNotification")

assert result == 123
mock_db_session.execute.assert_called_once()


@pytest.mark.anyio
async def test_get_notification_channel_by_name(notification_repo, mock_db_session):
# Similar setup to the previous test
mock_result = MagicMock()
mock_scalars = MagicMock()
mock_scalars.first.return_value = 456
mock_result.scalars.return_value = mock_scalars

mock_db_session.execute = AsyncMock(return_value=mock_result)

result = await notification_repo.get_notification_channel_by_name(ChannelEnum.EMAIL)

assert result == 456
mock_db_session.execute.assert_called_once()


@pytest.mark.anyio
async def test_get_subscribed_users_by_channel(notification_repo, mock_db_session):
# Similar setup, but using .all() instead of .first()
mock_result = MagicMock()
mock_scalars = MagicMock()
mock_scalars.all.return_value = [1, 2, 3]
mock_result.scalars.return_value = mock_scalars

mock_db_session.execute = AsyncMock(return_value=mock_result)

result = await notification_repo.get_subscribed_users_by_channel(
NotificationTypeEnum.BCEID__TRANSFER__PARTNER_ACTIONS, ChannelEnum.EMAIL
)

assert result == [1, 2, 3]
mock_db_session.execute.assert_called_once()
18 changes: 18 additions & 0 deletions backend/lcfs/tests/transfer/test_transfer_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,15 @@ async def test_update_transfer_success(
):
transfer_status = TransferStatus(transfer_status_id=1, status="status")
transfer_id = 1
# Create valid nested organization objects
from_org = Organization(organization_id=1, name="org1")
to_org = Organization(organization_id=2, name="org2")

# Create a Transfer object with the necessary attributes
transfer = Transfer(
transfer_id=transfer_id,
from_organization=from_org,
to_organization=to_org,
from_organization_id=1,
to_organization_id=2,
from_transaction_id=1,
Expand All @@ -114,11 +121,22 @@ async def test_update_transfer_success(
mock_transfer_repo.get_transfer_by_id.return_value = transfer
mock_transfer_repo.update_transfer.return_value = transfer

# Replace _perform_notificaiton_call with an AsyncMock
transfer_service._perform_notificaiton_call = AsyncMock()

result = await transfer_service.update_transfer(transfer)

# Assertions
assert result.transfer_id == transfer_id
assert isinstance(result, Transfer)

# Verify mocks
mock_transfer_repo.get_transfer_by_id.assert_called_once_with(transfer_id)
mock_transfer_repo.update_transfer.assert_called_once_with(transfer)
transfer_service._perform_notificaiton_call.assert_awaited_once_with(
transfer, status="Return to analyst"
)


@pytest.mark.anyio
async def test_update_category_success(transfer_service, mock_transfer_repo):
Expand Down
Loading