diff --git a/lib/features/ai/repositories/ollama_repository.dart b/lib/features/ai/repositories/ollama_repository.dart new file mode 100644 index 000000000..b553e8aa4 --- /dev/null +++ b/lib/features/ai/repositories/ollama_repository.dart @@ -0,0 +1,35 @@ +import 'package:flutter_riverpod/flutter_riverpod.dart'; +import 'package:ollama/ollama.dart'; +import 'package:riverpod_annotation/riverpod_annotation.dart'; + +part 'ollama_repository.g.dart'; + +class OllamaRepository { + OllamaRepository({ + Ollama? ollama, + }) : _ollama = ollama ?? Ollama(); + final Ollama _ollama; + + Stream generate( + String prompt, { + required String model, + required String system, + required double temperature, + List? images, + }) { + return _ollama.generate( + prompt, + model: model, + system: system, + options: ModelOptions( + temperature: temperature, + ), + images: images, + ); + } +} + +@riverpod +OllamaRepository ollamaRepository(Ref ref) { + return OllamaRepository(); +} diff --git a/lib/features/ai/repositories/ollama_repository.g.dart b/lib/features/ai/repositories/ollama_repository.g.dart new file mode 100644 index 000000000..11d175fe6 --- /dev/null +++ b/lib/features/ai/repositories/ollama_repository.g.dart @@ -0,0 +1,27 @@ +// GENERATED CODE - DO NOT MODIFY BY HAND + +part of 'ollama_repository.dart'; + +// ************************************************************************** +// RiverpodGenerator +// ************************************************************************** + +String _$ollamaRepositoryHash() => r'5352c524ff4a97ad9f95e72d8bbe90315361161f'; + +/// See also [ollamaRepository]. +@ProviderFor(ollamaRepository) +final ollamaRepositoryProvider = AutoDisposeProvider.internal( + ollamaRepository, + name: r'ollamaRepositoryProvider', + debugGetCreateSourceHash: const bool.fromEnvironment('dart.vm.product') + ? null + : _$ollamaRepositoryHash, + dependencies: null, + allTransitiveDependencies: null, +); + +@Deprecated('Will be removed in 3.0. Use Ref instead') +// ignore: unused_element +typedef OllamaRepositoryRef = AutoDisposeProviderRef; +// ignore_for_file: type=lint +// ignore_for_file: subtype_of_sealed_class, invalid_use_of_internal_member, invalid_use_of_visible_for_testing_member, deprecated_member_use_from_same_package diff --git a/lib/features/ai/state/ollama_task_summary.dart b/lib/features/ai/state/ollama_task_summary.dart index f421839f8..7bb7936fa 100644 --- a/lib/features/ai/state/ollama_task_summary.dart +++ b/lib/features/ai/state/ollama_task_summary.dart @@ -5,12 +5,12 @@ import 'dart:io'; import 'package:lotti/classes/entity_definitions.dart'; import 'package:lotti/classes/journal_entities.dart'; import 'package:lotti/database/database.dart'; +import 'package:lotti/features/ai/repositories/ollama_repository.dart'; import 'package:lotti/features/journal/util/entry_tools.dart'; import 'package:lotti/get_it.dart'; import 'package:lotti/logic/persistence_logic.dart'; import 'package:lotti/services/entities_cache_service.dart'; import 'package:lotti/utils/image_utils.dart'; -import 'package:ollama/ollama.dart'; import 'package:riverpod_annotation/riverpod_annotation.dart'; part 'ollama_task_summary.g.dart'; @@ -57,7 +57,6 @@ class AiTaskSummaryController extends _$AiTaskSummaryController { 'Keep it short and succinct. ' 'Calculate total time spent on the task. '; - final llm = Ollama(); final buffer = StringBuffer(); final images = processImages && entry is Task ? await getImages(entry) : null; @@ -65,15 +64,13 @@ class AiTaskSummaryController extends _$AiTaskSummaryController { const model = 'deepseek-r1:8b'; // TODO: make configurable const temperature = 0.6; - final stream = llm.generate( - markdown, - model: model, - system: systemMessage, - options: ModelOptions( - temperature: temperature, - ), - images: images, - ); + final stream = ref.read(ollamaRepositoryProvider).generate( + markdown, + model: model, + system: systemMessage, + temperature: temperature, + images: images, + ); await for (final chunk in stream) { buffer.write(chunk.text); diff --git a/lib/features/ai/state/ollama_task_summary.g.dart b/lib/features/ai/state/ollama_task_summary.g.dart index 317a1f36a..c39a78029 100644 --- a/lib/features/ai/state/ollama_task_summary.g.dart +++ b/lib/features/ai/state/ollama_task_summary.g.dart @@ -7,7 +7,7 @@ part of 'ollama_task_summary.dart'; // ************************************************************************** String _$aiTaskSummaryControllerHash() => - r'810f3c245ada01b17a3a21b4d4989d30c9b2f430'; + r'9a3e5bae712a5c05fb38b7edf83be9eb417d99f5'; /// Copied from Dart SDK class _SystemHash { diff --git a/pubspec.lock b/pubspec.lock index 2b0c617d4..151fb99a6 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -2446,7 +2446,7 @@ packages: source: hosted version: "1.0.5" riverpod: - dependency: transitive + dependency: "direct main" description: name: riverpod sha256: "59062512288d3056b2321804332a13ffdd1bf16df70dcc8e506e411280a72959" diff --git a/pubspec.yaml b/pubspec.yaml index de8a7de25..13dbafb61 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,7 +1,7 @@ name: lotti description: Achieve your goals and keep your data private with Lotti. publish_to: 'none' -version: 0.9.566+2860 +version: 0.9.566+2861 msix_config: display_name: LottiApp @@ -137,6 +137,7 @@ dependencies: url: https://github.com/matthiasn/research.package ref: fc229b9d + riverpod: ^2.6.1 riverpod_annotation: ^2.3.5 riverpod_lint: ^2.3.10 rxdart: ^0.28.0 diff --git a/test/features/ai/state/ollama_task_summary_test.dart b/test/features/ai/state/ollama_task_summary_test.dart new file mode 100644 index 000000000..46c757ff5 --- /dev/null +++ b/test/features/ai/state/ollama_task_summary_test.dart @@ -0,0 +1,240 @@ +import 'package:flutter_test/flutter_test.dart'; +import 'package:lotti/classes/entity_definitions.dart'; +import 'package:lotti/classes/entry_text.dart'; +import 'package:lotti/classes/journal_entities.dart'; +import 'package:lotti/classes/task.dart'; +import 'package:lotti/database/database.dart'; +import 'package:lotti/features/ai/repositories/ollama_repository.dart'; +import 'package:lotti/features/ai/state/ollama_task_summary.dart'; +import 'package:lotti/get_it.dart'; +import 'package:lotti/logic/persistence_logic.dart'; +import 'package:lotti/services/entities_cache_service.dart'; +import 'package:mocktail/mocktail.dart'; +import 'package:ollama/ollama.dart'; +import 'package:riverpod/riverpod.dart'; + +import '../../../mocks/mocks.dart'; + +// Add mock class for OllamaRepository +class MockOllamaRepository extends Mock implements OllamaRepository {} + +void main() { + late MockJournalDb mockDb; + late MockPersistenceLogic mockPersistenceLogic; + late MockEntitiesCacheService mockEntitiesCacheService; + late MockOllamaRepository mockOllamaRepository; + late ProviderContainer container; + + setUp(() { + mockDb = MockJournalDb(); + mockPersistenceLogic = MockPersistenceLogic(); + mockEntitiesCacheService = MockEntitiesCacheService(); + mockOllamaRepository = MockOllamaRepository(); + + // Register mocks with GetIt + getIt + ..registerSingleton(mockDb) + ..registerSingleton(mockPersistenceLogic) + ..registerSingleton(mockEntitiesCacheService); + + container = ProviderContainer( + overrides: [ + ollamaRepositoryProvider.overrideWith((ref) => mockOllamaRepository), + ], + ); + + registerFallbackValue( + const AiResponseData( + model: 'test-model', + systemMessage: 'test-system-message', + prompt: 'test-prompt', + thoughts: 'test-thoughts', + response: 'test-response', + ), + ); + }); + + tearDown(() { + container.dispose(); + getIt.reset(); + }); + + test('summarizeEntry processes task and creates AI response', () async { + const taskId = 'test-task-id'; + const categoryId = 'test-category-id'; + final now = DateTime.now(); + + // Create a mock task + final task = Task( + meta: Metadata( + id: taskId, + createdAt: now, + updatedAt: now, + dateFrom: now, + dateTo: now, + categoryId: categoryId, + ), + data: TaskData( + title: 'Test Task', + status: TaskStatus.inProgress( + id: taskId, + createdAt: DateTime.now(), + utcOffset: 60, + timezone: 'Europe/Berlin', + ), + dateFrom: DateTime.now(), + dateTo: DateTime.now(), + statusHistory: [], + ), + entryText: const EntryText( + markdown: 'Test task description', + plainText: 'Test task description', + ), + ); + + // Mock category + final category = CategoryDefinition( + id: categoryId, + name: 'Test Category', + createdAt: now, + updatedAt: now, + private: false, + active: true, + vectorClock: null, + ); + + // Setup mocks including OllamaRepository + when(() => mockDb.journalEntityById(taskId)).thenAnswer((_) async => task); + when(() => mockEntitiesCacheService.getCategoryById(categoryId)) + .thenReturn(category); + when(() => mockDb.getLinkedEntities(taskId)).thenAnswer((_) async => []); + when( + () => mockOllamaRepository.generate( + any(), + model: any(named: 'model'), + system: any(named: 'system'), + temperature: 0.6, + ), + ).thenAnswer( + (_) => Stream.fromIterable([ + CompletionChunk( + text: 'Some thoughts\nAI generated summary', + model: 'deepseek-r1:8b', + createdAt: DateTime.now(), + ), + ]), + ); + when( + () => mockPersistenceLogic.createAiResponseEntry( + data: any(named: 'data'), + dateFrom: any(named: 'dateFrom'), + linkedId: taskId, + categoryId: categoryId, + ), + ).thenAnswer((_) async => null); + + // Create and watch the provider + final provider = container.read( + aiTaskSummaryControllerProvider( + id: taskId, + processImages: false, + ).notifier, + ); + + // Trigger summarization + await provider.summarizeEntry(); + + // Verify interactions + verify(() => mockDb.journalEntityById(taskId)).called(3); + verify(() => mockDb.getLinkedEntities(taskId)).called(1); + verify( + () => mockOllamaRepository.generate( + any(), + model: any(named: 'model'), + system: any(named: 'system'), + temperature: 0.6, + ), + ).called(2); + verify( + () => mockPersistenceLogic.createAiResponseEntry( + data: any(named: 'data'), + dateFrom: any(named: 'dateFrom'), + linkedId: taskId, + categoryId: categoryId, + ), + ).called(2); + }); + + test( + 'summarizeEntry handles null task gracefully', + () async { + const taskId = 'non-existent-task-id'; + + // Setup mocks + when(() => mockDb.journalEntityById(taskId)) + .thenAnswer((_) async => null); + + // Create and watch the provider + final provider = container.read( + aiTaskSummaryControllerProvider( + id: taskId, + processImages: false, + ).notifier, + ); + + // Trigger summarization + await provider.summarizeEntry(); + + // Verify interactions + verify(() => mockDb.journalEntityById(taskId)).called(3); + verifyNever( + () => mockOllamaRepository.generate( + any(), + model: any(named: 'model'), + system: any(named: 'system'), + temperature: 0.6, + ), + ); + verifyNever( + () => mockPersistenceLogic.createAiResponseEntry( + data: any(named: 'data'), + dateFrom: any(named: 'dateFrom'), + linkedId: any(named: 'linkedId'), + categoryId: any(named: 'categoryId'), + ), + ); + }, + ); + + test('getMarkdown handles all task statuses correctly', () { + final now = DateTime.now(); + final task = Task( + meta: Metadata( + id: 'test-id', + createdAt: now, + updatedAt: now, + dateFrom: now, + dateTo: now, + ), + data: TaskData( + title: 'Test Task', + status: TaskStatus.done( + id: 'test-id', + createdAt: now, + utcOffset: 60, + timezone: 'Europe/Berlin', + ), + dateFrom: now, + dateTo: now, + statusHistory: [], + ), + entryText: const EntryText( + markdown: 'Test description', + plainText: 'Test description', + ), + ); + + // This should not throw a pattern matching error + expect(task.getMarkdown, returnsNormally); + }); +}