From 191f28c6d25447b4be23b691e4416f586b56af4d Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Mon, 27 May 2024 20:49:37 +0100 Subject: [PATCH] Fix `build_index` (#24) --- .github/workflows/CI.yml | 2 ++ CHANGELOG.md | 5 +++ Project.toml | 2 +- src/preparation.jl | 24 +++++++++----- test/preparation.jl | 71 +++++++++++++++++++++++++++++++++++++++- test/runtests.jl | 1 + 6 files changed, 95 insertions(+), 10 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0d3c43b..85f2a6a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -11,6 +11,8 @@ concurrency: # Cancel intermediate builds: only if it is a pull request build. group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +env: + OPENAI_API_KEY: "invalid-key-just-for-testing" jobs: test: name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ab40d4..963a951 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.1.1] + +### Fixed +- Fixed a bug in `build_index` where imports were missing and keywords were not passed properly in all scenarios. + ## [0.1.0] ### Added diff --git a/Project.toml b/Project.toml index b38f032..9a0bb68 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AIHelpMe" uuid = "01402e1f-dc83-4213-a98b-42887d758baa" authors = ["J S <49557684+svilupp@users.noreply.github.com> and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" diff --git a/src/preparation.jl b/src/preparation.jl index 8ad13fb..73c4bf6 100644 --- a/src/preparation.jl +++ b/src/preparation.jl @@ -97,16 +97,20 @@ function RT.build_index(mod::Module; verbose::Int = 1, kwargs...) ## Extract current configuration chunker_kwargs_ = (; sources = all_sources) chunker_kwargs = haskey(kwargs, :chunker_kwargs) ? - merge(kwargs.chunker_kwargs, chunker_kwargs_) : chunker_kwargs_ + merge(kwargs[:chunker_kwargs], chunker_kwargs_) : chunker_kwargs_ embedder_kwargs_ = RT.getpropertynested( RAG_KWARGS[], [:retriever_kwargs], :embedder_kwargs, nothing) + # Note: force Matrix{Bool} structure for now, switch to Int8-based binary embeddings with the latest PT embedder_kwargs = haskey(kwargs, :embedder_kwargs) ? - merge(kwargs.embedder_kwargs, embedder_kwargs_) : embedder_kwargs_ + merge( + (; return_type = Matrix{Bool}), embedder_kwargs_, kwargs[:embedder_kwargs]) : + merge((; return_type = Matrix{Bool}), embedder_kwargs_) new_index = RT.build_index(RAG_CONFIG[].indexer, all_docs; - embedder_kwargs, chunker = TextChunker(), chunker_kwargs, - verbose, index_id = nameof(mod), kwargs...) + kwargs..., + embedder_kwargs, chunker = RT.TextChunker(), chunker_kwargs, + verbose, index_id = nameof(mod)) end """ @@ -124,14 +128,18 @@ function RT.build_index(modules::Vector{Module} = Base.Docs.modules; verbose::In ## Extract current configuration chunker_kwargs_ = (; sources = all_sources) chunker_kwargs = haskey(kwargs, :chunker_kwargs) ? - merge(kwargs.chunker_kwargs, chunker_kwargs_) : chunker_kwargs_ + merge(kwargs[:chunker_kwargs], chunker_kwargs_) : chunker_kwargs_ + # Note: force Matrix{Bool} structure for now, switch to Int8-based binary embeddings with the latest PT embedder_kwargs_ = RT.getpropertynested( RAG_KWARGS[], [:retriever_kwargs], :embedder_kwargs, nothing) embedder_kwargs = haskey(kwargs, :embedder_kwargs) ? - merge(kwargs.embedder_kwargs, embedder_kwargs_) : embedder_kwargs_ + merge( + (; return_type = Matrix{Bool}), embedder_kwargs_, kwargs[:embedder_kwargs]) : + merge((; return_type = Matrix{Bool}), embedder_kwargs_) new_index = RT.build_index(RAG_CONFIG[].indexer, all_docs; - embedder_kwargs, chunker = TextChunker(), chunker_kwargs, - verbose, index_id = nameof(mod), kwargs...) + kwargs..., + embedder_kwargs, chunker = RT.TextChunker(), chunker_kwargs, + verbose, index_id = :all_modules) end diff --git a/test/preparation.jl b/test/preparation.jl index 7c23f51..112dccd 100644 --- a/test/preparation.jl +++ b/test/preparation.jl @@ -1,4 +1,4 @@ -using AIHelpMe: docextract +using AIHelpMe: docextract, build_index # create an empty module module ABC123 @@ -24,3 +24,72 @@ end @test length(all_sources) == 2 @test occursin("ABC1234", all_sources[2]) end + +@testset "build_index" begin + # test with a mock server + PORT = rand(9000:31000) + PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) + PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) + PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema()) + + echo_server = HTTP.serve!(PORT; verbose = -1) do req + content = JSON3.read(req.body) + + if content[:model] == "mock-gen" + user_msg = last(content[:messages]) + response = Dict( + :choices => [ + Dict(:message => user_msg, :finish_reason => "stop") + ], + :model => content[:model], + :usage => Dict(:total_tokens => length(user_msg[:content]), + :prompt_tokens => length(user_msg[:content]), + :completion_tokens => 0)) + elseif content[:model] == "mock-emb" + response = Dict( + :data => [Dict(:embedding => ones(Float32, 1536)) + for i in 1:length(content[:input])], + :usage => Dict(:total_tokens => length(content[:input]), + :prompt_tokens => length(content[:input]), + :completion_tokens => 0)) + elseif content[:model] == "mock-meta" + user_msg = last(content[:messages]) + response = Dict( + :choices => [ + Dict(:finish_reason => "stop", + :message => Dict(:tool_calls => [ + Dict(:function => Dict(:arguments => JSON3.write(MaybeTags([ + Tag("yes", "category") + ]))))]))], + :model => content[:model], + :usage => Dict(:total_tokens => length(user_msg[:content]), + :prompt_tokens => length(user_msg[:content]), + :completion_tokens => 0)) + else + @info content + end + return HTTP.Response(200, JSON3.write(response)) + end + + # One module + index = build_index(AIHelpMe; verbose = 2, embedder_kwargs = (; model = "mock-emb"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test index.embeddings == ones(Bool, 1024, length(index.chunks)) + @test all(x -> occursin("AIHelpMe", x), index.sources) + @test index.tags == nothing + @test index.tags_vocab == nothing + + # Many modules + index = build_index( + [AIHelpMe, Test]; verbose = 2, embedder_kwargs = (; model = "mock-emb"), + tagger_kwargs = (; model = "mock-meta"), api_kwargs = (; + url = "http://localhost:$(PORT)")) + @test index.embeddings == ones(Bool, 1024, length(index.chunks)) + @test all(x -> occursin("AIHelpMe", x) || occursin("Test", x), index.sources) + @test index.tags == nothing + @test index.tags_vocab == nothing + + # clean up + close(echo_server) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d4365c4..cb200d3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using PromptingTools using PromptingTools.Experimental.RAGTools const PT = PromptingTools const RT = PromptingTools.Experimental.RAGTools +using PromptingTools: HTTP, JSON3 using HDF5, Serialization using Test using Aqua