From fc2687a4ff42f3258caf3f01d29b0a4fe7825f83 Mon Sep 17 00:00:00 2001 From: luozhiya Date: Wed, 22 May 2024 09:17:26 +0800 Subject: [PATCH] Normalize range --- lua/fittencode/engines/actions.lua | 38 +++++++++++++++++++++ lua/fittencode/prompt_providers/actions.lua | 12 +------ lua/fittencode/unicode.lua | 6 ++-- 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/lua/fittencode/engines/actions.lua b/lua/fittencode/engines/actions.lua index 32db5adb..73cb6b29 100644 --- a/lua/fittencode/engines/actions.lua +++ b/lua/fittencode/engines/actions.lua @@ -12,6 +12,7 @@ local Sessions = require('fittencode.sessions') local Status = require('fittencode.status') local SuggestionsPreprocessing = require('fittencode.suggestions_preprocessing') local TaskScheduler = require('fittencode.tasks') +local Unicode = require('fittencode.unicode') local schedule = Base.schedule @@ -255,6 +256,40 @@ end local VMODE = { ['v'] = true, ['V'] = true, [api.nvim_replace_termcodes('', true, true, true)] = true } +---@param buffer number +---@param range ActionRange +local function normalize_range(buffer, range) + local start = range.start + local end_ = range['end'] + + if end_[1] < start[1] then + start[1], end_[1] = end_[1], start[1] + start[2], end_[2] = end_[2], start[2] + end + if end_[2] < start[2] and end_[1] == start[1] then + start[2], end_[2] = end_[2], start[2] + end + + local utf_end_byte = function(row, col) + local line = api.nvim_buf_get_lines(buffer, row - 1, row, false)[1] + local byte_start = math.min(col + 1, #line) + local utf_index = Unicode.calculate_utf8_index(line) + local flag = utf_index[byte_start] + assert(flag == 0) + local byte_end = #line + local next = Unicode.find_zero(utf_index, byte_start + 1) + if next then + byte_end = next - 1 + end + return byte_end + end + + end_[2] = utf_end_byte(end_[1], end_[2]) + + range.start = start + range['end'] = end_ +end + local function make_range(buffer) local in_v = false local region = nil @@ -273,12 +308,15 @@ local function make_range(buffer) local start = api.nvim_buf_get_mark(buffer, '<') local end_ = api.nvim_buf_get_mark(buffer, '>') + ---@type ActionRange local range = { start = start, ['end'] = end_, vmode = in_v, region = region, } + normalize_range(buffer, range) + return range end diff --git a/lua/fittencode/prompt_providers/actions.lua b/lua/fittencode/prompt_providers/actions.lua index 3dee89ad..bc37ef55 100644 --- a/lua/fittencode/prompt_providers/actions.lua +++ b/lua/fittencode/prompt_providers/actions.lua @@ -28,14 +28,6 @@ function M:get_priority() return self.priority end -local function max_len(buffer, row, len) - local max = string.len(api.nvim_buf_get_lines(buffer, row - 1, row, false)[1]) - if len > max then - return max - end - return len -end - ---@param buffer integer ---@param range ActionRange ---@return string @@ -44,14 +36,12 @@ local function make_range_content(buffer, range) if range.vmode and range.region then lines = range.region or {} else - -- lines = api.nvim_buf_get_text(buffer, range.start[1] - 1, 0, range.start[1] - 1, -1, {}) - local end_col = max_len(buffer, range['end'][1], range['end'][2]) lines = api.nvim_buf_get_text( buffer, range.start[1] - 1, range.start[2], range['end'][1] - 1, - end_col + 1, {}) + range['end'][2], {}) end return table.concat(lines, '\n') end diff --git a/lua/fittencode/unicode.lua b/lua/fittencode/unicode.lua index 3d6c88f4..af2c95bc 100644 --- a/lua/fittencode/unicode.lua +++ b/lua/fittencode/unicode.lua @@ -17,7 +17,7 @@ function M.calculate_utf8_index_tbl(lines) return index end -local function find_zero(tbl, start_index) +function M.find_zero(tbl, start_index) for i = start_index, #tbl do if tbl[i] == 0 then return i @@ -30,14 +30,14 @@ function M.find_first_character(s, tbl, start_index) return nil end - local v1 = find_zero(tbl, start_index) + local v1 = M.find_zero(tbl, start_index) assert(v1 == start_index) if v1 == nil then -- Invalid UTF-8 sequence return nil end - local v2 = find_zero(tbl, v1 + 1) + local v2 = M.find_zero(tbl, v1 + 1) if v2 == nil then v2 = #tbl else