diff --git a/app/general/tests/test_filter.py b/app/general/tests/test_filter.py index c9d0ca92..3f78cdbf 100644 --- a/app/general/tests/test_filter.py +++ b/app/general/tests/test_filter.py @@ -3,7 +3,7 @@ from django.test import TestCase from general.filters import DocumentFileFilter -from general.models import DocumentFile, Institution, Language, Subject +from general.models import DocumentFile, Institution, Language, Project, Subject class TestSearchFilter(TestCase): @@ -38,30 +38,48 @@ def setUp(self): self.doc2.subjects.add(self.subject2) self.doc2.languages.add(self.language2) + # Create Projects for search testing + self.project1 = Project.objects.create( + name="Project 1", + description="Project 1 description", + institution=self.institution1, + logo="logo1.png", + ) + def test_institution_filter(self): data = {"institution": [self.institution1.id]} filter = DocumentFileFilter(data=data) - self.assertEqual(len(filter.qs), 1) - self.assertIn(self.doc1, filter.qs) + qs = filter.qs + self.assertEqual(len(qs), 1) + self.assertEqual(qs[0]["id"], self.doc1.id) + + def test_institution_filter(self): + data = {"institution": [self.institution1.id]} + filter = DocumentFileFilter(data=data) + qs = filter.qs + self.assertEqual(len(qs), 1) + self.assertEqual(qs[0]["id"], self.doc1.id) def test_document_type_filter(self): data = {"document_type": ["glossary"]} filter = DocumentFileFilter(data=data) - self.assertEqual(len(filter.qs), 2) - self.assertIn(self.doc1, filter.qs) - self.assertIn(self.doc2, filter.qs) + qs = filter.qs + self.assertEqual(len(qs), 2) + self.assertCountEqual([qs[0]["id"], qs[1]["id"]], [self.doc1.id, self.doc2.id]) def test_subjects_filter(self): data = {"subjects": [self.subject1.id]} filter = DocumentFileFilter(data=data) - self.assertEqual(len(filter.qs), 1) - self.assertIn(self.doc1, filter.qs) + qs = filter.qs + self.assertEqual(len(qs), 1) + self.assertEqual(qs[0]["id"], self.doc1.id) def test_languages_filter(self): data = {"languages": [self.language1.id]} filter = DocumentFileFilter(data=data) - self.assertEqual(len(filter.qs), 1) - self.assertIn(self.doc1, filter.qs) + qs = filter.qs + self.assertEqual(len(qs), 1) + self.assertEqual(qs[0]["id"], self.doc1.id) def test_combined_filters(self): data = { @@ -71,20 +89,36 @@ def test_combined_filters(self): "languages": [self.language1.id], } filter = DocumentFileFilter(data=data) - self.assertEqual(len(filter.qs), 1) - self.assertIn(self.doc1, filter.qs) + qs = filter.qs + self.assertEqual(len(qs), 1) + self.assertEqual(qs[0]["id"], self.doc1.id) - def test_search_filter(self): + def test_search_filter_documents(self): data = {"search": "Document"} filter = DocumentFileFilter(data=data) - self.assertEqual(len(filter.qs), 4) - self.assertIn(self.doc1, filter.qs) - self.assertIn(self.doc2, filter.qs) + qs = filter.qs + self.assertEqual(len(qs), 2) + self.assertCountEqual([qs[0]["id"], qs[1]["id"]], [self.doc1.id, self.doc2.id]) data = {"search": "Document 1"} filter = DocumentFileFilter(data=data) - self.assertEqual(len(filter.qs), 2) - self.assertIn(self.doc1, filter.qs) + qs = filter.qs + self.assertEqual(len(qs), 1) + self.assertEqual(qs[0]["id"], self.doc1.id) + + def test_search_filter_projects(self): + data = {"search": "Project 1"} + filter = DocumentFileFilter(data=data) + qs = filter.qs + self.assertEqual(len(qs), 1) + self.assertEqual(qs[0]["id"], self.project1.id) + + def test_search_filter_combined(self): + data = {"search": "1"} + filter = DocumentFileFilter(data=data) + qs = filter.qs + self.assertEqual(len(qs), 2) + self.assertCountEqual([qs[0]["id"], qs[1]["id"]], [self.doc1.id, self.project1.id]) if __name__ == "__main__": diff --git a/app/general/tests/test_view_search.py b/app/general/tests/test_view_search.py index 5d0386ab..21ab0c42 100644 --- a/app/general/tests/test_view_search.py +++ b/app/general/tests/test_view_search.py @@ -48,7 +48,7 @@ def test_search_filtering(self): client = Client() response = client.get(reverse("search"), {"search": "Document 1"}) self.assertEqual(response.status_code, 200) - self.assertEqual(response.context["documents"][0].title, "Document 1") + self.assertEqual(response.context["documents"][0]["heading"], "Document 1") def test_invalid_page_number(self): client = Client()