Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add CoT support for DeepSeek-R1 (only for reference) #228

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion lua/gp/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ local config = {
-- secret : "sk-...",
-- secret = os.getenv("env_name.."),
openai = {
disable = false,
disable = true,
endpoint = "https://api.openai.com/v1/chat/completions",
-- secret = os.getenv("OPENAI_API_KEY"),
},
Expand Down Expand Up @@ -103,6 +103,7 @@ local config = {
disable = true,
},
{
provider = "openai",
name = "ChatGPT4o",
chat = true,
command = false,
Expand Down
55 changes: 40 additions & 15 deletions lua/gp/dispatcher.lua
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ end
---@param handler function # response handler
---@param on_exit function | nil # optional on_exit handler
---@param callback function | nil # optional callback handler
local query = function(buf, provider, payload, handler, on_exit, callback)
---@param is_reasoning boolean # whether model is reasoning model
local query = function(buf, provider, payload, handler, on_exit, callback, is_reasoning)
-- make sure handler is a function
if type(handler) ~= "function" then
logger.error(
Expand Down Expand Up @@ -238,9 +239,15 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
qt.raw_response = qt.raw_response .. line .. "\n"
end
line = line:gsub("^data: ", "")

local content = ""
local reasoning_content = ""

if line:match("choices") and line:match("delta") and line:match("content") then
line = vim.json.decode(line)
if line.choices[1] and line.choices[1].delta and line.choices[1].delta.reasoning_content then
reasoning_content = line.choices[1].delta.reasoning_content
end
if line.choices[1] and line.choices[1].delta and line.choices[1].delta.content then
content = line.choices[1].delta.content
end
Expand All @@ -264,10 +271,15 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end
end


if content and type(content) == "string" then
if reasoning_content ~= "" and type(reasoning_content) == "string" then
handler(qid, reasoning_content, true)
elseif content ~= "" and type(content) == "string" then
if is_reasoning then
handler(qid, "\n</details>\n</think>\n", false)
is_reasoning = false
end
qt.response = qt.response .. content
handler(qid, content)
handler(qid, content, false)
end
end
end
Expand Down Expand Up @@ -311,11 +323,16 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end
end


if qt.response == "" then
logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response))
if is_reasoning then
handler(qid, "\n", false)
handler(qid, "\n</details>\n</think>\n", false)
is_reasoning = false
end

-- if qt.response == "" then
-- logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response))
-- end

-- optional on_exit handler
if type(on_exit) == "function" then
on_exit(qid)
Expand Down Expand Up @@ -393,7 +410,7 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end

local temp_file = D.query_dir ..
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
helpers.table_to_file(payload, temp_file)

local curl_params = vim.deepcopy(D.config.curl_params or {})
Expand Down Expand Up @@ -425,16 +442,17 @@ end
---@param handler function # response handler
---@param on_exit function | nil # optional on_exit handler
---@param callback function | nil # optional callback handler
D.query = function(buf, provider, payload, handler, on_exit, callback)
---@param is_reasoning boolean # whether the model is reasoning model
D.query = function(buf, provider, payload, handler, on_exit, callback, is_reasoning)
if provider == "copilot" then
return vault.run_with_secret(provider, function()
vault.refresh_copilot_bearer(function()
query(buf, provider, payload, handler, on_exit, callback)
query(buf, provider, payload, handler, on_exit, callback, is_reasoning)
end)
end)
end
vault.run_with_secret(provider, function()
query(buf, provider, payload, handler, on_exit, callback)
query(buf, provider, payload, handler, on_exit, callback, is_reasoning)
end)
end

Expand Down Expand Up @@ -463,7 +481,7 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
})

local response = ""
return vim.schedule_wrap(function(qid, chunk)
return vim.schedule_wrap(function(qid, chunk, is_reasoning)
local qt = tasker.get_query(qid)
if not qt then
return
Expand Down Expand Up @@ -503,6 +521,13 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
lines[i] = prefix .. l
end

-- prepend prefix > to each line inside CoT
if is_reasoning then
for i, l in ipairs(lines) do
lines[i] = "> " .. l
end
end

local unfinished_lines = {}
for i = finished_lines + 1, #lines do
table.insert(unfinished_lines, lines[i])
Expand All @@ -511,9 +536,9 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
vim.api.nvim_buf_set_lines(buf, first_line + finished_lines, first_line + finished_lines, false, unfinished_lines)

local new_finished_lines = math.max(0, #lines - 1)
for i = finished_lines, new_finished_lines do
vim.api.nvim_buf_add_highlight(buf, qt.ns_id, hl_handler_group, first_line + i, 0, -1)
end
-- for i = finished_lines, new_finished_lines do
-- vim.api.nvim_buf_add_highlight(buf, qt.ns_id, hl_handler_group, first_line + i, 0, -1)
-- end
finished_lines = new_finished_lines

local end_line = first_line + #vim.split(response, "\n")
Expand Down
72 changes: 53 additions & 19 deletions lua/gp/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1031,22 +1031,37 @@ M.chat_respond = function(params)
agent_suffix = M.render.template(agent_suffix, { ["{{agent}}"] = agent_name })

local old_default_user_prefix = "🗨:"
local in_cot_block = false -- Flag to track if we're inside a CoT block

for index = start_index, end_index do
local line = lines[index]
if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#M.config.chat_user_prefix + 1)
elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#old_default_user_prefix + 1)
elseif line:sub(1, #agent_prefix) == agent_prefix then
table.insert(messages, { role = role, content = content })
role = "assistant"
content = ""
elseif role ~= "" then
content = content .. "\n" .. line

if line:match("^<think>$") then
in_cot_block = true
end

-- Skip lines if we're inside a CoT block
if not in_cot_block then
-- Original logic for handling chat messages
if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#M.config.chat_user_prefix + 1)
elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#old_default_user_prefix + 1)
elseif line:sub(1, #agent_prefix) == agent_prefix then
table.insert(messages, { role = role, content = content })
role = "assistant"
content = ""
elseif role ~= "" then
content = content .. "\n" .. line
end
end

if line:match("^</think>$") then
in_cot_block = false
end
end
-- insert last message not handled in loop
Expand All @@ -1063,6 +1078,8 @@ M.chat_respond = function(params)
-- make it multiline again if it contains escaped newlines
content = content:gsub("\\n", "\n")
messages[1] = { role = "system", content = content }
else
table.remove(messages, 1)
end

-- strip whitespace from ends of content
Expand All @@ -1074,12 +1091,23 @@ M.chat_respond = function(params)
local last_content_line = M.helpers.last_content_line(buf)
vim.api.nvim_buf_set_lines(buf, last_content_line, last_content_line, false, { "", agent_prefix .. agent_suffix, "" })

local offset = 0
local is_reasoning = false
-- Add CoT for DeepSeekReasoner
if string.match(agent_name, "^DeepSeekReasoner") then
vim.api.nvim_buf_set_lines(buf, last_content_line + 3, last_content_line + 3, false,
{ "<think>", "<details>", "<summary>CoT</summary>", "" })
offset = 1
is_reasoning = true
end

-- call the model and write response
M.dispatcher.query(
buf,
headers.provider or agent.provider,
M.dispatcher.prepare_payload(messages, headers.model or agent.model, headers.provider or agent.provider),
M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf), true, "", not M.config.chat_free_cursor),
M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf) + offset, true, "",
not M.config.chat_free_cursor),
vim.schedule_wrap(function(qid)
local qt = M.tasker.get_query(qid)
if not qt then
Expand Down Expand Up @@ -1125,7 +1153,8 @@ M.chat_respond = function(params)
topic_handler,
vim.schedule_wrap(function()
-- get topic from invisible buffer
local topic = vim.api.nvim_buf_get_lines(topic_buf, 0, -1, false)[1]
-- instead of the first line, get the last two line can skip CoT
local topic = vim.api.nvim_buf_get_lines(topic_buf, -3, -1, false)[1]
-- close invisible buffer
vim.api.nvim_buf_delete(topic_buf, { force = true })
-- strip whitespace from ends of topic
Expand All @@ -1141,15 +1170,19 @@ M.chat_respond = function(params)
-- replace topic in current buffer
M.helpers.undojoin(buf)
vim.api.nvim_buf_set_lines(buf, 0, 1, false, { "# topic: " .. topic })
end)
end),
nil,
false
)
end
if not M.config.chat_free_cursor then
local line = vim.api.nvim_buf_line_count(buf)
M.helpers.cursor_to_line(line, buf, win)
end
vim.cmd("doautocmd User GpDone")
end)
end),
nil,
is_reasoning
)
end

Expand Down Expand Up @@ -1935,7 +1968,8 @@ M.Prompt = function(params, target, agent, template, prompt, whisper, callback)
on_exit(qid)
vim.cmd("doautocmd User GpDone")
end),
callback
callback,
false
)
end

Expand Down