Skip to content

Commit

Permalink
Merge pull request #1921 from matthiasn/feat/use_deepseek_r1
Browse files Browse the repository at this point in the history
test: add task summary tests
  • Loading branch information
matthiasn authored Feb 1, 2025
2 parents c05022e + e0ff9c1 commit 330f82a
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 14 deletions.
35 changes: 35 additions & 0 deletions lib/features/ai/repositories/ollama_repository.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'package:ollama/ollama.dart';
import 'package:riverpod_annotation/riverpod_annotation.dart';

part 'ollama_repository.g.dart';

/// Thin wrapper around the Ollama client.
///
/// Exists so that LLM text generation can be stubbed in tests by
/// overriding the Riverpod provider instead of talking to a live server.
class OllamaRepository {
  /// Creates a repository backed by [ollama], falling back to a
  /// default-configured [Ollama] client when none is supplied.
  OllamaRepository({Ollama? ollama}) : _ollama = ollama ?? Ollama();

  final Ollama _ollama;

  /// Streams completion chunks produced by [model] for [prompt].
  ///
  /// [system] is the system message, [temperature] is forwarded as a
  /// model option, and [images] are passed through to the client
  /// unchanged (presumably for multimodal models — confirm with the
  /// ollama package docs).
  Stream<CompletionChunk> generate(
    String prompt, {
    required String model,
    required String system,
    required double temperature,
    List<String>? images,
  }) =>
      _ollama.generate(
        prompt,
        model: model,
        system: system,
        options: ModelOptions(temperature: temperature),
        images: images,
      );
}

/// Riverpod provider exposing an [OllamaRepository].
///
/// Overridden in tests with a mock repository so no real Ollama
/// server is required.
@riverpod
OllamaRepository ollamaRepository(Ref ref) => OllamaRepository();
27 changes: 27 additions & 0 deletions lib/features/ai/repositories/ollama_repository.g.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 8 additions & 11 deletions lib/features/ai/state/ollama_task_summary.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import 'dart:io';
import 'package:lotti/classes/entity_definitions.dart';
import 'package:lotti/classes/journal_entities.dart';
import 'package:lotti/database/database.dart';
import 'package:lotti/features/ai/repositories/ollama_repository.dart';
import 'package:lotti/features/journal/util/entry_tools.dart';
import 'package:lotti/get_it.dart';
import 'package:lotti/logic/persistence_logic.dart';
import 'package:lotti/services/entities_cache_service.dart';
import 'package:lotti/utils/image_utils.dart';
import 'package:ollama/ollama.dart';
import 'package:riverpod_annotation/riverpod_annotation.dart';

part 'ollama_task_summary.g.dart';
Expand Down Expand Up @@ -57,23 +57,20 @@ class AiTaskSummaryController extends _$AiTaskSummaryController {
'Keep it short and succinct. '
'Calculate total time spent on the task. ';

final llm = Ollama();
final buffer = StringBuffer();
final images =
processImages && entry is Task ? await getImages(entry) : null;

const model = 'deepseek-r1:8b'; // TODO: make configurable
const temperature = 0.6;

final stream = llm.generate(
markdown,
model: model,
system: systemMessage,
options: ModelOptions(
temperature: temperature,
),
images: images,
);
final stream = ref.read(ollamaRepositoryProvider).generate(
markdown,
model: model,
system: systemMessage,
temperature: temperature,
images: images,
);

await for (final chunk in stream) {
buffer.write(chunk.text);
Expand Down
2 changes: 1 addition & 1 deletion lib/features/ai/state/ollama_task_summary.g.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -2446,7 +2446,7 @@ packages:
source: hosted
version: "1.0.5"
riverpod:
dependency: transitive
dependency: "direct main"
description:
name: riverpod
sha256: "59062512288d3056b2321804332a13ffdd1bf16df70dcc8e506e411280a72959"
Expand Down
3 changes: 2 additions & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: lotti
description: Achieve your goals and keep your data private with Lotti.
publish_to: 'none'
version: 0.9.566+2860
version: 0.9.566+2861

msix_config:
display_name: LottiApp
Expand Down Expand Up @@ -137,6 +137,7 @@ dependencies:
url: https://github.com/matthiasn/research.package
ref: fc229b9d

riverpod: ^2.6.1
riverpod_annotation: ^2.3.5
riverpod_lint: ^2.3.10
rxdart: ^0.28.0
Expand Down
240 changes: 240 additions & 0 deletions test/features/ai/state/ollama_task_summary_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import 'package:flutter_test/flutter_test.dart';
import 'package:lotti/classes/entity_definitions.dart';
import 'package:lotti/classes/entry_text.dart';
import 'package:lotti/classes/journal_entities.dart';
import 'package:lotti/classes/task.dart';
import 'package:lotti/database/database.dart';
import 'package:lotti/features/ai/repositories/ollama_repository.dart';
import 'package:lotti/features/ai/state/ollama_task_summary.dart';
import 'package:lotti/get_it.dart';
import 'package:lotti/logic/persistence_logic.dart';
import 'package:lotti/services/entities_cache_service.dart';
import 'package:mocktail/mocktail.dart';
import 'package:ollama/ollama.dart';
import 'package:riverpod/riverpod.dart';

import '../../../mocks/mocks.dart';

// Mocktail mock of [OllamaRepository]; stubbed per-test so summaries are
// generated from canned [CompletionChunk] streams instead of a live server.
class MockOllamaRepository extends Mock implements OllamaRepository {}

void main() {
  late MockJournalDb mockDb;
  late MockPersistenceLogic mockPersistenceLogic;
  late MockEntitiesCacheService mockEntitiesCacheService;
  late MockOllamaRepository mockOllamaRepository;
  late ProviderContainer container;

  setUp(() {
    mockDb = MockJournalDb();
    mockPersistenceLogic = MockPersistenceLogic();
    mockEntitiesCacheService = MockEntitiesCacheService();
    mockOllamaRepository = MockOllamaRepository();

    // Register mocks with GetIt — the controller resolves these services
    // through the service locator rather than through Riverpod.
    getIt
      ..registerSingleton<JournalDb>(mockDb)
      ..registerSingleton<PersistenceLogic>(mockPersistenceLogic)
      ..registerSingleton<EntitiesCacheService>(mockEntitiesCacheService);

    // The Ollama repository, by contrast, is resolved via Riverpod, so it
    // is replaced with a provider override on a fresh container per test.
    container = ProviderContainer(
      overrides: [
        ollamaRepositoryProvider.overrideWith((ref) => mockOllamaRepository),
      ],
    );

    // Fallback value so `any(named: 'data')` works for the AiResponseData
    // argument matcher in the createAiResponseEntry stubs/verifications.
    registerFallbackValue(
      const AiResponseData(
        model: 'test-model',
        systemMessage: 'test-system-message',
        prompt: 'test-prompt',
        thoughts: 'test-thoughts',
        response: 'test-response',
      ),
    );
  });

  tearDown(() {
    // Dispose the container first, then clear GetIt so the next test's
    // setUp can re-register singletons without collisions.
    container.dispose();
    getIt.reset();
  });

  test('summarizeEntry processes task and creates AI response', () async {
    const taskId = 'test-task-id';
    const categoryId = 'test-category-id';
    final now = DateTime.now();

    // A minimal in-progress task fixture with markdown entry text, which is
    // what the controller summarizes.
    final task = Task(
      meta: Metadata(
        id: taskId,
        createdAt: now,
        updatedAt: now,
        dateFrom: now,
        dateTo: now,
        categoryId: categoryId,
      ),
      data: TaskData(
        title: 'Test Task',
        status: TaskStatus.inProgress(
          id: taskId,
          createdAt: DateTime.now(),
          utcOffset: 60,
          timezone: 'Europe/Berlin',
        ),
        dateFrom: DateTime.now(),
        dateTo: DateTime.now(),
        statusHistory: [],
      ),
      entryText: const EntryText(
        markdown: 'Test task description',
        plainText: 'Test task description',
      ),
    );

    // Category the task belongs to; resolved via the entities cache.
    final category = CategoryDefinition(
      id: categoryId,
      name: 'Test Category',
      createdAt: now,
      updatedAt: now,
      private: false,
      active: true,
      vectorClock: null,
    );

    // Stub DB/cache lookups and have the mocked repository emit a single
    // chunk containing a <think> section followed by the summary text, the
    // shape produced by deepseek-r1 style reasoning models.
    when(() => mockDb.journalEntityById(taskId)).thenAnswer((_) async => task);
    when(() => mockEntitiesCacheService.getCategoryById(categoryId))
        .thenReturn(category);
    when(() => mockDb.getLinkedEntities(taskId)).thenAnswer((_) async => []);
    when(
      () => mockOllamaRepository.generate(
        any(),
        model: any(named: 'model'),
        system: any(named: 'system'),
        temperature: 0.6,
      ),
    ).thenAnswer(
      (_) => Stream.fromIterable([
        CompletionChunk(
          text: '<think>Some thoughts</think>\nAI generated summary',
          model: 'deepseek-r1:8b',
          createdAt: DateTime.now(),
        ),
      ]),
    );
    when(
      () => mockPersistenceLogic.createAiResponseEntry(
        data: any(named: 'data'),
        dateFrom: any(named: 'dateFrom'),
        linkedId: taskId,
        categoryId: categoryId,
      ),
    ).thenAnswer((_) async => null);

    // Create and watch the provider
    final provider = container.read(
      aiTaskSummaryControllerProvider(
        id: taskId,
        processImages: false,
      ).notifier,
    );

    // Trigger summarization
    await provider.summarizeEntry();

    // NOTE(review): the called(3)/called(2) counts suggest the controller's
    // build() already runs a summarization pass before the explicit
    // summarizeEntry() call above — confirm against the controller
    // implementation before tightening these expectations.
    verify(() => mockDb.journalEntityById(taskId)).called(3);
    verify(() => mockDb.getLinkedEntities(taskId)).called(1);
    verify(
      () => mockOllamaRepository.generate(
        any(),
        model: any(named: 'model'),
        system: any(named: 'system'),
        temperature: 0.6,
      ),
    ).called(2);
    verify(
      () => mockPersistenceLogic.createAiResponseEntry(
        data: any(named: 'data'),
        dateFrom: any(named: 'dateFrom'),
        linkedId: taskId,
        categoryId: categoryId,
      ),
    ).called(2);
  });

  test(
    'summarizeEntry handles null task gracefully',
    () async {
      const taskId = 'non-existent-task-id';

      // DB returns no entity for the id: the controller should bail out
      // without calling the LLM or persisting anything.
      when(() => mockDb.journalEntityById(taskId))
          .thenAnswer((_) async => null);

      // Create and watch the provider
      final provider = container.read(
        aiTaskSummaryControllerProvider(
          id: taskId,
          processImages: false,
        ).notifier,
      );

      // Trigger summarization
      await provider.summarizeEntry();

      // The lookup still happens (three times — see NOTE in the previous
      // test), but neither generation nor persistence is attempted.
      verify(() => mockDb.journalEntityById(taskId)).called(3);
      verifyNever(
        () => mockOllamaRepository.generate(
          any(),
          model: any(named: 'model'),
          system: any(named: 'system'),
          temperature: 0.6,
        ),
      );
      verifyNever(
        () => mockPersistenceLogic.createAiResponseEntry(
          data: any(named: 'data'),
          dateFrom: any(named: 'dateFrom'),
          linkedId: any(named: 'linkedId'),
          categoryId: any(named: 'categoryId'),
        ),
      );
    },
  );

  test('getMarkdown handles all task statuses correctly', () {
    final now = DateTime.now();
    // Uses the `done` status (unlike the `inProgress` fixture above) to
    // exercise a different branch of the status pattern match.
    final task = Task(
      meta: Metadata(
        id: 'test-id',
        createdAt: now,
        updatedAt: now,
        dateFrom: now,
        dateTo: now,
      ),
      data: TaskData(
        title: 'Test Task',
        status: TaskStatus.done(
          id: 'test-id',
          createdAt: now,
          utcOffset: 60,
          timezone: 'Europe/Berlin',
        ),
        dateFrom: now,
        dateTo: now,
        statusHistory: [],
      ),
      entryText: const EntryText(
        markdown: 'Test description',
        plainText: 'Test description',
      ),
    );

    // This should not throw a pattern matching error
    expect(task.getMarkdown, returnsNormally);
  });
}

0 comments on commit 330f82a

Please sign in to comment.