Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize range #69

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions lua/fittencode/engines/actions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ local Sessions = require('fittencode.sessions')
local Status = require('fittencode.status')
local SuggestionsPreprocessing = require('fittencode.suggestions_preprocessing')
local TaskScheduler = require('fittencode.tasks')
local Unicode = require('fittencode.unicode')

local schedule = Base.schedule

Expand Down Expand Up @@ -255,6 +256,40 @@ end

local VMODE = { ['v'] = true, ['V'] = true, [api.nvim_replace_termcodes('<C-V>', true, true, true)] = true }

---@param buffer number
---@param range ActionRange
local function normalize_range(buffer, range)
local start = range.start
local end_ = range['end']

if end_[1] < start[1] then
start[1], end_[1] = end_[1], start[1]
start[2], end_[2] = end_[2], start[2]
end
if end_[2] < start[2] and end_[1] == start[1] then
start[2], end_[2] = end_[2], start[2]
end

local utf_end_byte = function(row, col)
local line = api.nvim_buf_get_lines(buffer, row - 1, row, false)[1]
local byte_start = math.min(col + 1, #line)
local utf_index = Unicode.calculate_utf8_index(line)
local flag = utf_index[byte_start]
assert(flag == 0)
local byte_end = #line
local next = Unicode.find_zero(utf_index, byte_start + 1)
if next then
byte_end = next - 1
end
return byte_end
end

end_[2] = utf_end_byte(end_[1], end_[2])

range.start = start
range['end'] = end_
end

local function make_range(buffer)
local in_v = false
local region = nil
Expand All @@ -273,12 +308,15 @@ local function make_range(buffer)
local start = api.nvim_buf_get_mark(buffer, '<')
local end_ = api.nvim_buf_get_mark(buffer, '>')

---@type ActionRange
local range = {
start = start,
['end'] = end_,
vmode = in_v,
region = region,
}
normalize_range(buffer, range)

return range
end

Expand Down
12 changes: 1 addition & 11 deletions lua/fittencode/prompt_providers/actions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@ function M:get_priority()
return self.priority
end

local function max_len(buffer, row, len)
local max = string.len(api.nvim_buf_get_lines(buffer, row - 1, row, false)[1])
if len > max then
return max
end
return len
end

---@param buffer integer
---@param range ActionRange
---@return string
Expand All @@ -44,14 +36,12 @@ local function make_range_content(buffer, range)
if range.vmode and range.region then
lines = range.region or {}
else
-- lines = api.nvim_buf_get_text(buffer, range.start[1] - 1, 0, range.start[1] - 1, -1, {})
local end_col = max_len(buffer, range['end'][1], range['end'][2])
lines = api.nvim_buf_get_text(
buffer,
range.start[1] - 1,
range.start[2],
range['end'][1] - 1,
end_col + 1, {})
range['end'][2], {})
end
return table.concat(lines, '\n')
end
Expand Down
6 changes: 3 additions & 3 deletions lua/fittencode/unicode.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function M.calculate_utf8_index_tbl(lines)
return index
end

local function find_zero(tbl, start_index)
function M.find_zero(tbl, start_index)
for i = start_index, #tbl do
if tbl[i] == 0 then
return i
Expand All @@ -30,14 +30,14 @@ function M.find_first_character(s, tbl, start_index)
return nil
end

local v1 = find_zero(tbl, start_index)
local v1 = M.find_zero(tbl, start_index)
assert(v1 == start_index)
if v1 == nil then
-- Invalid UTF-8 sequence
return nil
end

local v2 = find_zero(tbl, v1 + 1)
local v2 = M.find_zero(tbl, v1 + 1)
if v2 == nil then
v2 = #tbl
else
Expand Down