From c9277b982a8d71ae2f6637853a76da3facb1ecdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20K=C3=B6rber?= <56073945+jakobkoerber@users.noreply.github.com> Date: Tue, 3 Oct 2023 16:21:17 +0200 Subject: [PATCH] Add Search to Web, Fix Search for Unauthorized User (#96) --- lib/base/enums/search_category.dart | 9 ++ lib/navigation.dart | 4 +- .../viewmodels/global_search_viewmodel.dart | 103 +++++++++--------- .../views/search_body_view.dart | 1 + .../views/search_category_picker_view.dart | 28 +++-- 5 files changed, 80 insertions(+), 65 deletions(-) diff --git a/lib/base/enums/search_category.dart b/lib/base/enums/search_category.dart index 21d8e121..f23386ca 100644 --- a/lib/base/enums/search_category.dart +++ b/lib/base/enums/search_category.dart @@ -41,4 +41,13 @@ extension SearchCategoryExtension on SearchCategory { static List lectureSearch() { return [SearchCategory.personalLectures, SearchCategory.lectures]; } + + static List unAuthorizedSearch() { + return [ + SearchCategory.studyRoom, + SearchCategory.cafeterias, + SearchCategory.movie, + SearchCategory.news + ]; + } } diff --git a/lib/navigation.dart b/lib/navigation.dart index 61768e2f..6f4c6559 100644 --- a/lib/navigation.dart +++ b/lib/navigation.dart @@ -96,7 +96,9 @@ class _NavigationState extends ConsumerState { }()), actions: [ if (kIsWeb && isLandScape) - IconButton(onPressed: () {}, icon: const Icon(Icons.search)), + IconButton( + onPressed: () => _toggleSearch(), + icon: const Icon(Icons.search)), IconButton( onPressed: () { Navigator.of(context).push(MaterialPageRoute( diff --git a/lib/searchComponent/viewmodels/global_search_viewmodel.dart b/lib/searchComponent/viewmodels/global_search_viewmodel.dart index ef9d2954..91ebe254 100644 --- a/lib/searchComponent/viewmodels/global_search_viewmodel.dart +++ b/lib/searchComponent/viewmodels/global_search_viewmodel.dart @@ -5,6 +5,7 @@ import 'package:campus_flutter/base/enums/search_category.dart'; import 'package:campus_flutter/loginComponent/viewModels/login_viewmodel.dart'; import 'package:campus_flutter/providers_get_it.dart'; import 'package:campus_flutter/searchComponent/model/vocab.dart'; +import 'package:flutter/foundation.dart'; import 'package:flutter/services.dart'; import 'package:flutter_riverpod/flutter_riverpod.dart'; import 'package:rxdart/rxdart.dart'; @@ -24,8 +25,10 @@ class GlobalSearchViewModel { late HashMap vocab2; GlobalSearchViewModel(this.ref) { - loadVocabulary(); - initializeNaturalLanguageModel(); + if (!kIsWeb) { + loadVocabulary(); + initializeNaturalLanguageModel(); + } } Future initializeNaturalLanguageModel() async { @@ -50,37 +53,11 @@ class GlobalSearchViewModel { this.searchString = searchString; if (selectedCategories.value.isEmpty) { if (index == 0) { - final tokens = await tokenizeBert(searchString); - var output = List.filled(1 * 6, 0).reshape([1, 6]); - interpreter.run(tokens, output); - final probabilities = output.first as List; - final List categoryNames = [ - "cafeterias", - "calendar", - "grade", - "movie", - "news", - "studyroom" - ]; - final categories = Map.fromIterables(categoryNames, probabilities); - if (ref.read(loginViewModel).credentials.value != Credentials.tumId) { - categories - .removeWhere((key, value) => ["calendar, grade"].contains(key)); - } - List> sortedEntries = categories.entries - .toList() - ..sort((a, b) => b.value.compareTo(a.value)); - final sortedCategories = Map.fromEntries(sortedEntries) - .keys - .map((key) => SearchCategory.fromString(key)) - .toList(); - - /// if authenticated add lecture and person search - if (ref.read(loginViewModel).credentials.value == Credentials.tumId) { - sortedCategories - .addAll([SearchCategory.lectures, SearchCategory.persons]); + if (!kIsWeb) { + _textClassificationModel(searchString); + } else { + _webSearch(searchString); } - result.add(sortedCategories); } else { switch (index) { case 1: @@ -99,6 +76,46 @@ class GlobalSearchViewModel { } } + void _webSearch(String searchString) async { + if (ref.read(loginViewModel).credentials.value == Credentials.tumId) { + result.add(SearchCategory.values); + } else { + result.add(SearchCategoryExtension.unAuthorizedSearch()); + } + } + + Future _textClassificationModel(String searchString) async { + final tokens = await tokenizeBert(searchString); + var output = List.filled(1 * 6, 0).reshape([1, 6]); + interpreter.run(tokens, output); + final probabilities = output.first as List; + final List categoryNames = [ + "cafeterias", + "calendar", + "grade", + "movie", + "news", + "studyroom" + ]; + final categories = Map.fromIterables(categoryNames, probabilities); + if (ref.read(loginViewModel).credentials.value != Credentials.tumId) { + categories.removeWhere((key, value) => ["calendar, grade"].contains(key)); + } + List> sortedEntries = categories.entries.toList() + ..sort((a, b) => b.value.compareTo(a.value)); + final sortedCategories = Map.fromEntries(sortedEntries) + .keys + .map((key) => SearchCategory.fromString(key)) + .toList(); + + /// if authenticated add lecture and person search + if (ref.read(loginViewModel).credentials.value == Credentials.tumId) { + sortedCategories + .addAll([SearchCategory.lectures, SearchCategory.persons]); + } + result.add(sortedCategories); + } + Future> tokenizeBert(String input) async { List substrings = input.split(" "); List tokenized = []; @@ -137,26 +154,14 @@ class GlobalSearchViewModel { return tokenized; } - void addCategory(SearchCategory searchCategory) { - if (!selectedCategories.value.contains(searchCategory)) { - final categories = selectedCategories.value; - categories.add(searchCategory); - if (searchString.isEmpty) { - search(index, searchString); - } - selectedCategories.add(categories); - } - } - - void removeCategory(SearchCategory searchCategory) { + void updateCategory(SearchCategory searchCategory) { + final categories = selectedCategories.value; if (selectedCategories.value.contains(searchCategory)) { - final categories = selectedCategories.value; categories.remove(searchCategory); - if (searchString.isEmpty) { - search(index, searchString); - } - selectedCategories.add(categories); + } else { + categories.add(searchCategory); } + selectedCategories.add(categories); } void clear() { diff --git a/lib/searchComponent/views/search_body_view.dart b/lib/searchComponent/views/search_body_view.dart index e3db34c0..4666204d 100644 --- a/lib/searchComponent/views/search_body_view.dart +++ b/lib/searchComponent/views/search_body_view.dart @@ -39,6 +39,7 @@ class SearchView extends ConsumerWidget { child: StreamBuilder( stream: ref.watch(searchViewModel).result, builder: (context, snapshot) { + print(snapshot); if (!snapshot.hasData && textEditingController.text.isEmpty) { return const Center(child: Text("Enter a Query to Start")); } else { diff --git a/lib/searchComponent/views/search_category_picker_view.dart b/lib/searchComponent/views/search_category_picker_view.dart index 988d159e..b29d6d27 100644 --- a/lib/searchComponent/views/search_category_picker_view.dart +++ b/lib/searchComponent/views/search_category_picker_view.dart @@ -1,4 +1,5 @@ import 'package:campus_flutter/base/helpers/horizontal_slider.dart'; +import 'package:campus_flutter/loginComponent/viewModels/login_viewmodel.dart'; import 'package:campus_flutter/providers_get_it.dart'; import 'package:campus_flutter/base/enums/search_category.dart'; import 'package:flutter/material.dart'; @@ -17,21 +18,14 @@ class SearchCategoryPickerView extends ConsumerWidget { return Padding( padding: const EdgeInsets.symmetric(vertical: 10.0), child: HorizontalSlider( - data: _getData(snapshot.data ?? []), + data: _getData(snapshot.data ?? [], ref), height: 40, child: (searchCategory) => FilterChip( label: Text(searchCategory.title), onSelected: (selected) { - if (snapshot.data?.contains(searchCategory) ?? - false) { - ref - .read(searchViewModel) - .removeCategory(searchCategory); - } else { - ref - .read(searchViewModel) - .addCategory(searchCategory); - } + ref + .read(searchViewModel) + .updateCategory(searchCategory); ref .read(searchViewModel) .triggerSearchAfterUpdate(null, null); @@ -43,14 +37,18 @@ class SearchCategoryPickerView extends ConsumerWidget { }); } - List _getData(List data) { + List _getData(List data, WidgetRef ref) { List searchCategories = []; if (index == 2) { searchCategories = SearchCategoryExtension.lectureSearch(); } else { - searchCategories = SearchCategory.values - .where((element) => element != SearchCategory.unknown) - .toList(); + if (ref.read(loginViewModel).credentials.value == Credentials.tumId) { + searchCategories = SearchCategory.values + .where((element) => element != SearchCategory.unknown) + .toList(); + } else { + searchCategories = SearchCategoryExtension.unAuthorizedSearch(); + } } searchCategories.sort((a, b) { return data.contains(a) && data.contains(b)