Skip to content

Commit

Permalink
refactor: extract repository
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasn committed Feb 1, 2025
1 parent 7b05349 commit e0ff9c1
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 19 deletions.
35 changes: 35 additions & 0 deletions lib/features/ai/repositories/ollama_repository.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'package:ollama/ollama.dart';
import 'package:riverpod_annotation/riverpod_annotation.dart';

part 'ollama_repository.g.dart';

class OllamaRepository {
OllamaRepository({
Ollama? ollama,
}) : _ollama = ollama ?? Ollama();
final Ollama _ollama;

Stream<CompletionChunk> generate(
String prompt, {
required String model,
required String system,
required double temperature,
List<String>? images,
}) {
return _ollama.generate(
prompt,
model: model,
system: system,
options: ModelOptions(
temperature: temperature,
),
images: images,
);
}
}

@riverpod
OllamaRepository ollamaRepository(Ref ref) {
return OllamaRepository();
}
27 changes: 27 additions & 0 deletions lib/features/ai/repositories/ollama_repository.g.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 8 additions & 11 deletions lib/features/ai/state/ollama_task_summary.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import 'dart:io';
import 'package:lotti/classes/entity_definitions.dart';
import 'package:lotti/classes/journal_entities.dart';
import 'package:lotti/database/database.dart';
import 'package:lotti/features/ai/repositories/ollama_repository.dart';
import 'package:lotti/features/journal/util/entry_tools.dart';
import 'package:lotti/get_it.dart';
import 'package:lotti/logic/persistence_logic.dart';
import 'package:lotti/services/entities_cache_service.dart';
import 'package:lotti/utils/image_utils.dart';
import 'package:ollama/ollama.dart';
import 'package:riverpod_annotation/riverpod_annotation.dart';

part 'ollama_task_summary.g.dart';
Expand Down Expand Up @@ -57,23 +57,20 @@ class AiTaskSummaryController extends _$AiTaskSummaryController {
'Keep it short and succinct. '
'Calculate total time spent on the task. ';

final llm = Ollama();
final buffer = StringBuffer();
final images =
processImages && entry is Task ? await getImages(entry) : null;

const model = 'deepseek-r1:8b'; // TODO: make configurable
const temperature = 0.6;

final stream = llm.generate(
markdown,
model: model,
system: systemMessage,
options: ModelOptions(
temperature: temperature,
),
images: images,
);
final stream = ref.read(ollamaRepositoryProvider).generate(
markdown,
model: model,
system: systemMessage,
temperature: temperature,
images: images,
);

await for (final chunk in stream) {
buffer.write(chunk.text);
Expand Down
2 changes: 1 addition & 1 deletion lib/features/ai/state/ollama_task_summary.g.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

89 changes: 82 additions & 7 deletions test/features/ai/state/ollama_task_summary_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,44 @@ import 'package:lotti/classes/entry_text.dart';
import 'package:lotti/classes/journal_entities.dart';
import 'package:lotti/classes/task.dart';
import 'package:lotti/database/database.dart';
import 'package:lotti/features/ai/repositories/ollama_repository.dart';
import 'package:lotti/features/ai/state/ollama_task_summary.dart';
import 'package:lotti/get_it.dart';
import 'package:lotti/logic/persistence_logic.dart';
import 'package:lotti/services/entities_cache_service.dart';
import 'package:mocktail/mocktail.dart';
import 'package:ollama/ollama.dart';
import 'package:riverpod/riverpod.dart';

import '../../../mocks/mocks.dart';

// Add mock class for OllamaRepository
class MockOllamaRepository extends Mock implements OllamaRepository {}

void main() {
late MockJournalDb mockDb;
late MockPersistenceLogic mockPersistenceLogic;
late MockEntitiesCacheService mockEntitiesCacheService;
late MockOllamaRepository mockOllamaRepository;
late ProviderContainer container;

setUp(() {
mockDb = MockJournalDb();
mockPersistenceLogic = MockPersistenceLogic();
mockEntitiesCacheService = MockEntitiesCacheService();
mockOllamaRepository = MockOllamaRepository();

// Register mocks with GetIt
getIt
..registerSingleton<JournalDb>(mockDb)
..registerSingleton<PersistenceLogic>(mockPersistenceLogic)
..registerSingleton<EntitiesCacheService>(mockEntitiesCacheService);

container = ProviderContainer();
container = ProviderContainer(
overrides: [
ollamaRepositoryProvider.overrideWith((ref) => mockOllamaRepository),
],
);

registerFallbackValue(
const AiResponseData(
Expand Down Expand Up @@ -65,10 +76,11 @@ void main() {
),
data: TaskData(
title: 'Test Task',
status: TaskStatus.done(
id: 'test-task-id',
status: TaskStatus.inProgress(
id: taskId,
createdAt: DateTime.now(),
utcOffset: 60,
timezone: 'Europe/Berlin',
),
dateFrom: DateTime.now(),
dateTo: DateTime.now(),
Expand All @@ -91,11 +103,27 @@ void main() {
vectorClock: null,
);

// Setup mocks
// Setup mocks including OllamaRepository
when(() => mockDb.journalEntityById(taskId)).thenAnswer((_) async => task);
when(() => mockEntitiesCacheService.getCategoryById(categoryId))
.thenReturn(category);
when(() => mockDb.getLinkedEntities(taskId)).thenAnswer((_) async => []);
when(
() => mockOllamaRepository.generate(
any(),
model: any(named: 'model'),
system: any(named: 'system'),
temperature: 0.6,
),
).thenAnswer(
(_) => Stream.fromIterable([
CompletionChunk(
text: '<think>Some thoughts</think>\nAI generated summary',
model: 'deepseek-r1:8b',
createdAt: DateTime.now(),
),
]),
);
when(
() => mockPersistenceLogic.createAiResponseEntry(
data: any(named: 'data'),
Expand All @@ -119,14 +147,22 @@ void main() {
// Verify interactions
verify(() => mockDb.journalEntityById(taskId)).called(3);
verify(() => mockDb.getLinkedEntities(taskId)).called(1);
verify(
() => mockOllamaRepository.generate(
any(),
model: any(named: 'model'),
system: any(named: 'system'),
temperature: 0.6,
),
).called(2);
verify(
() => mockPersistenceLogic.createAiResponseEntry(
data: any(named: 'data'),
dateFrom: any(named: 'dateFrom'),
linkedId: taskId,
categoryId: categoryId,
),
).called(1);
).called(2);
});

test(
Expand All @@ -150,7 +186,15 @@ void main() {
await provider.summarizeEntry();

// Verify interactions
verify(() => mockDb.journalEntityById(taskId)).called(1);
verify(() => mockDb.journalEntityById(taskId)).called(3);
verifyNever(
() => mockOllamaRepository.generate(
any(),
model: any(named: 'model'),
system: any(named: 'system'),
temperature: 0.6,
),
);
verifyNever(
() => mockPersistenceLogic.createAiResponseEntry(
data: any(named: 'data'),
Expand All @@ -160,6 +204,37 @@ void main() {
),
);
},
skip: true,
);

test('getMarkdown handles all task statuses correctly', () {
final now = DateTime.now();
final task = Task(
meta: Metadata(
id: 'test-id',
createdAt: now,
updatedAt: now,
dateFrom: now,
dateTo: now,
),
data: TaskData(
title: 'Test Task',
status: TaskStatus.done(
id: 'test-id',
createdAt: now,
utcOffset: 60,
timezone: 'Europe/Berlin',
),
dateFrom: now,
dateTo: now,
statusHistory: [],
),
entryText: const EntryText(
markdown: 'Test description',
plainText: 'Test description',
),
);

// This should not throw a pattern matching error
expect(task.getMarkdown, returnsNormally);
});
}

0 comments on commit e0ff9c1

Please sign in to comment.