From 2df6a8242e71a5697353a88ef826c1f043a0a81d Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Mon, 2 Dec 2024 17:56:41 +0000 Subject: [PATCH] fix: prompt decorator crashes AI Gateway when trying to decorate "prepend" --- kong/llm/plugin/ctx.lua | 5 + .../shared-filters/normalize-request.lua | 11 +- .../filters/decorate-prompt.lua | 17 +- .../02-integration_spec.lua | 241 ++++++++++++++---- 4 files changed, 211 insertions(+), 63 deletions(-) diff --git a/kong/llm/plugin/ctx.lua b/kong/llm/plugin/ctx.lua index 69d4a475bc5a..aa227dac794b 100644 --- a/kong/llm/plugin/ctx.lua +++ b/kong/llm/plugin/ctx.lua @@ -139,6 +139,11 @@ local EMPTY_REQUEST_T = _M.immutable_table({}) function _M.get_request_body_table_inuse() local request_body_table + + if _M.has_namespace("decorate-prompt") then -- has ai-prompt-decorator and others in future + request_body_table = _M.get_namespaced_ctx("decorate-prompt", "request_body_table") + end + if _M.has_namespace("normalize-request") then -- has ai-proxy/ai-proxy-advanced request_body_table = _M.get_namespaced_ctx("normalize-request", "request_body_table") end diff --git a/kong/llm/plugin/shared-filters/normalize-request.lua b/kong/llm/plugin/shared-filters/normalize-request.lua index d46764785a12..2242a5737709 100644 --- a/kong/llm/plugin/shared-filters/normalize-request.lua +++ b/kong/llm/plugin/shared-filters/normalize-request.lua @@ -76,7 +76,14 @@ local function validate_and_transform(conf) local model_t = conf_m.model local model_provider = conf.model.provider -- use the one from conf, not the merged one to avoid potential security risk - local request_table = ai_plugin_ctx.get_namespaced_ctx("parse-request", "request_body_table") + local request_table + if ai_plugin_ctx.has_namespace("decorate-prompt") and + ai_plugin_ctx.get_namespaced_ctx("decorate-prompt", "decorated") then + request_table = ai_plugin_ctx.get_namespaced_ctx("decorate-prompt", "request_body_table") + else + request_table = ai_plugin_ctx.get_namespaced_ctx("parse-request", "request_body_table") + end + if not request_table then return bail(400, "content-type header does not match request body, or bad JSON formatting") end @@ -219,4 +226,4 @@ function _M:run(conf) return true end -return _M \ No newline at end of file +return _M diff --git a/kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua b/kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua index 525c08223631..69599999e3aa 100644 --- a/kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua +++ b/kong/plugins/ai-prompt-decorator/filters/decorate-prompt.lua @@ -6,16 +6,17 @@ -- [ END OF LICENSE 0867164ffc95e54f04670b5169c09574bdbd9bba ] local new_tab = require("table.new") -local deep_copy = require("kong.tools.table").deep_copy local ai_plugin_ctx = require("kong.llm.plugin.ctx") +local cycle_aware_deep_copy = require("kong.tools.table").cycle_aware_deep_copy local _M = { NAME = "decorate-prompt", STAGE = "REQ_TRANSFORMATION", - } +} local FILTER_OUTPUT_SCHEMA = { decorated = "boolean", + request_body_table = "table", } local _, set_ctx = ai_plugin_ctx.get_namespaced_accesors(_M.NAME, FILTER_OUTPUT_SCHEMA) @@ -24,7 +25,7 @@ local EMPTY = {} local function bad_request(msg) - kong.log.debug(msg) + kong.log.info(msg) return kong.response.exit(400, { error = { message = msg } }) end @@ -37,9 +38,6 @@ local function execute(request, conf) local prepend = conf.prompts.prepend or EMPTY local append = conf.prompts.append or EMPTY - -- ensure we don't modify the original request - request = deep_copy(request) - local old_messages = request.messages local new_messages = new_tab(#append + #prepend + #old_messages, 0) request.messages = new_messages @@ -81,9 +79,14 @@ function _M:run(conf) return bad_request("this LLM route only supports llm/chat type requests") end - kong.service.request.set_body(execute(request_body_table, conf), "application/json") + -- Deep copy to avoid modifying the immutable table. + -- Re-assign it to trigger GC of the old one and save memory. + request_body_table = execute(cycle_aware_deep_copy(request_body_table), conf) + + kong.service.request.set_body(request_body_table, "application/json") -- legacy set_ctx("decorated", true) + set_ctx("request_body_table", request_body_table) return true end diff --git a/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua b/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua index 80c00b1af944..89acb6fd211c 100644 --- a/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua +++ b/spec/03-plugins/41-ai-prompt-decorator/02-integration_spec.lua @@ -1,23 +1,45 @@ -local helpers = require "spec.helpers" +local helpers = require("spec.helpers") +local cjson = require("cjson") + local PLUGIN_NAME = "ai-prompt-decorator" -for _, strategy in helpers.all_strategies() do +local openai_flat_chat = { + messages = { + { + role = "user", + content = "I think that cheddar is the best cheese.", + }, + { + role = "assistant", + content = "No, brie is the best cheese.", + }, + { + role = "user", + content = "Why brie?", + }, + }, +} + + +for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() local client lazy_setup(function() - local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME }) + local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME, "ctx-checker-last", "ctx-checker" }) + - local route1 = bp.routes:insert({ - hosts = { "test1.com" }, + -- echo route, we don't need a mock AI here + local prepend = bp.routes:insert({ + hosts = { "prepend.decorate.local" }, }) bp.plugins:insert { name = PLUGIN_NAME, - route = { id = route1.id }, + route = { id = prepend.id }, config = { prompts = { prepend = { @@ -30,6 +52,28 @@ for _, strategy in helpers.all_strategies() do content = "Prepend text 2 here.", }, }, + }, + }, + } + + bp.plugins:insert { + name = "ctx-checker-last", + route = { id = prepend.id }, + config = { + ctx_check_field = "ai_namespaced_ctx", + } + } + + + local append = bp.routes:insert({ + hosts = { "append.decorate.local" }, + }) + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = append.id }, + config = { + prompts = { append = { [1] = { role = "assistant", @@ -44,72 +88,161 @@ for _, strategy in helpers.all_strategies() do }, } + bp.plugins:insert { + name = "ctx-checker-last", + route = { id = append.id }, + config = { + ctx_check_field = "ai_namespaced_ctx", + } + } + + local both = bp.routes:insert({ + hosts = { "both.decorate.local" }, + }) + + + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = both.id }, + config = { + prompts = { + prepend = { + [1] = { + role = "system", + content = "Prepend text 1 here.", + }, + [2] = { + role = "assistant", + content = "Prepend text 2 here.", + }, + }, + append = { + [1] = { + role = "assistant", + content = "Append text 3 here.", + }, + [2] = { + role = "user", + content = "Append text 4 here.", + }, + }, + }, + }, + } + + bp.plugins:insert { + name = "ctx-checker-last", + route = { id = both.id }, + config = { + ctx_check_field = "ai_namespaced_ctx", + } + } + + assert(helpers.start_kong({ database = strategy, nginx_conf = "spec/fixtures/custom_nginx.template", - plugins = "bundled," .. PLUGIN_NAME, + plugins = "bundled,ctx-checker-last,ctx-checker," .. PLUGIN_NAME, declarative_config = strategy == "off" and helpers.make_yaml_file() or nil, })) end) - lazy_teardown(function() - helpers.stop_kong() + helpers.stop_kong(nil, true) end) - before_each(function() client = helpers.proxy_client() end) - after_each(function() if client then client:close() end end) - - - it("blocks a non-chat message", function() - local r = client:get("/request", { - headers = { - host = "test1.com", - ["Content-Type"] = "application/json", - }, - body = [[ - { - "anything": [ - { - "random": "data" - } - ] - }]], - method = "POST", - }) - - assert.response(r).has.status(400) - local json = assert.response(r).has.jsonbody() - assert.same({ error = { message = "this LLM route only supports llm/chat type requests" }}, json) - end) - - - it("blocks an empty messages array", function() - local r = client:get("/request", { - headers = { - host = "test1.com", - ["Content-Type"] = "application/json", - }, - body = [[ - { - "messages": [] - }]], - method = "POST", - }) - - assert.response(r).has.status(400) - local json = assert.response(r).has.jsonbody() - assert.same({ error = { message = "this LLM route only supports llm/chat type requests" }}, json) + describe("request", function() + it("modifies the LLM chat request - prepend", function() + local r = client:get("/", { + headers = { + host = "prepend.decorate.local", + ["Content-Type"] = "application/json" + }, + body = cjson.encode(openai_flat_chat), + }) + + -- get the REQUEST body, that left Kong for the upstream, using the echo system + assert.response(r).has.status(200) + local request = assert.response(r).has.jsonbody() + request = cjson.decode(request.post_data.text) + + assert.same({ content = "Prepend text 1 here.", role = "system" }, request.messages[1]) + assert.same({ content = "Prepend text 2 here.", role = "system" }, request.messages[2]) + + -- check ngx.ctx was set properly for later AI chain filters + local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx") + ctx = ngx.unescape_uri(ctx) + assert.match_re(ctx, [[.*decorate-prompt.*]]) + assert.match_re(ctx, [[.*decorated = true.*]]) + assert.match_re(ctx, [[.*Prepend text 1 here.*]]) + assert.match_re(ctx, [[.*Prepend text 2 here.*]]) + end) + + it("modifies the LLM chat request - append", function() + local r = client:get("/", { + headers = { + host = "append.decorate.local", + ["Content-Type"] = "application/json" + }, + body = cjson.encode(openai_flat_chat), + }) + + -- get the REQUEST body, that left Kong for the upstream, using the echo system + assert.response(r).has.status(200) + local request = assert.response(r).has.jsonbody() + request = cjson.decode(request.post_data.text) + + assert.same({ content = "Append text 1 here.", role = "assistant" }, request.messages[#request.messages-1]) + assert.same({ content = "Append text 2 here.", role = "user" }, request.messages[#request.messages]) + + -- check ngx.ctx was set properly for later AI chain filters + local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx") + ctx = ngx.unescape_uri(ctx) + assert.match_re(ctx, [[.*decorate-prompt.*]]) + assert.match_re(ctx, [[.*decorated = true.*]]) + assert.match_re(ctx, [[.*Append text 1 here.*]]) + assert.match_re(ctx, [[.*Append text 2 here.*]]) + end) + + + it("modifies the LLM chat request - both", function() + local r = client:get("/", { + headers = { + host = "both.decorate.local", + ["Content-Type"] = "application/json" + }, + body = cjson.encode(openai_flat_chat), + }) + + -- get the REQUEST body, that left Kong for the upstream, using the echo system + assert.response(r).has.status(200) + local request = assert.response(r).has.jsonbody() + request = cjson.decode(request.post_data.text) + + assert.same({ content = "Prepend text 1 here.", role = "system" }, request.messages[1]) + assert.same({ content = "Prepend text 2 here.", role = "assistant" }, request.messages[2]) + assert.same({ content = "Append text 3 here.", role = "assistant" }, request.messages[#request.messages-1]) + assert.same({ content = "Append text 4 here.", role = "user" }, request.messages[#request.messages]) + + -- check ngx.ctx was set properly for later AI chain filters + local ctx = assert.response(r).has.header("ctx-checker-last-ai-namespaced-ctx") + ctx = ngx.unescape_uri(ctx) + assert.match_re(ctx, [[.*decorate-prompt.*]]) + assert.match_re(ctx, [[.*decorated = true.*]]) + assert.match_re(ctx, [[.*Prepend text 1 here.*]]) + assert.match_re(ctx, [[.*Prepend text 2 here.*]]) + assert.match_re(ctx, [[.*Append text 3 here.*]]) + assert.match_re(ctx, [[.*Append text 4 here.*]]) + end) end) - end) -end +end end