Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
TwIStOy committed Sep 1, 2023
1 parent 4243d25 commit 79bcc0c
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 101 deletions.
4 changes: 4 additions & 0 deletions lua/luasnip/_types.lua
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
---@alias LuaSnip.Cursor {[1]: number, [2]: number}

---@class LuaSnip.MatchRegion 0-based region
---@field row integer 0-based row
---@field col { [1]: integer, [2]: integer } 0-based column range, inclusive
64 changes: 23 additions & 41 deletions lua/luasnip/extras/_treesitter.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,16 @@ local function inspect_node(node)
end

---@param bufnr number
---@param match string
---@param region LuaSnip.MatchRegion
---@return LanguageTree, string
local function reparse_buffer_after_removing_match(bufnr, match)
local row, col = unpack(vim.api.nvim_win_get_cursor(0))
local function reparse_buffer_after_removing_match(bufnr, region)
local lang = get_lang(bufnr)

local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local current_line = lines[row]
local left_part = current_line:sub(1, col - #match)
local right_part = current_line:sub(col + 1)
lines[row] = left_part .. right_part
local current_line = lines[region.row]
local left_part = current_line:sub(1, region.col[1])
local right_part = current_line:sub(region.col[2] + 1)
lines[region.row] = left_part .. right_part

local source = table.concat(lines, "\n")
---@type LanguageTree
Expand All @@ -50,24 +49,22 @@ end
---@class LuaSnip.extra.FixBufferContext
---@field ori_bufnr number
---@field ori_line string
---@field ori_row number
---@field ori_col number
---@field match string
---@field region LuaSnip.MatchRegion
local FixBufferContext = {}

---@param ori_bufnr number
---@param region LuaSnip.MatchRegion
---@return LuaSnip.extra.FixBufferContext
function FixBufferContext.new(ori_bufnr, match)
local row, col = unpack(vim.api.nvim_win_get_cursor(0))
local lines = vim.api.nvim_buf_get_lines(ori_bufnr, row - 1, row, true)
function FixBufferContext.new(ori_bufnr, region)
local lines =
vim.api.nvim_buf_get_lines(ori_bufnr, region.row, region.row + 1, true)
assert(#lines == 1)
local current_line = lines[1]

local o = {
ori_bufnr = ori_bufnr,
ori_line = current_line,
ori_row = row,
ori_col = col,
match = match,
region = region,
}
setmetatable(o, {
__index = FixBufferContext,
Expand All @@ -77,13 +74,13 @@ function FixBufferContext.new(ori_bufnr, match)
end

function FixBufferContext:enter()
local current_line_left = self.ori_line:sub(1, self.ori_col - #self.match)
local current_line_right = self.ori_line:sub(self.ori_col + 1)
local current_line_left = self.ori_line:sub(1, self.region.col[1])
local current_line_right = self.ori_line:sub(self.region.col[2] + 1)

vim.api.nvim_buf_set_lines(
self.ori_bufnr,
self.ori_row - 1,
self.ori_row,
self.region.row,
self.region.row + 1,
true,
{ current_line_left .. current_line_right }
)
Expand All @@ -97,8 +94,8 @@ end
function FixBufferContext:leave()
vim.api.nvim_buf_set_lines(
self.ori_bufnr,
self.ori_row - 1,
self.ori_row,
self.region.row,
self.region.row + 1,
true,
{
self.ori_line,
Expand All @@ -110,16 +107,6 @@ function FixBufferContext:leave()
return parser, source
end

local function wrap_enter_and_leave_func(ori_bufnr, match)
local context = FixBufferContext.new(ori_bufnr, match)

return function()
context:enter()
end, function()
context:leave()
end
end

local builtin_tsnode_selectors = {
any = function(captures)
local best_match
Expand Down Expand Up @@ -375,10 +362,9 @@ function TSParser:captures_at_pos(captures, pos)
end

---@param node TSNode
---@param cursor { [1]: number, [2]: number }
---@param matched_trigger string
---@param end_pos { [1]: number, [2]: number } 0-based position
---@return number, number, string[]
function TSParser:get_node_text(node, cursor, matched_trigger)
function TSParser:get_node_text(node, end_pos)
local text
local start_row, start_col, _, _ = vim.treesitter.get_node_range(node)
if type(self.source) == "string" then
Expand All @@ -392,12 +378,10 @@ function TSParser:get_node_text(node, cursor, matched_trigger)
0,
start_row,
start_col,
cursor[1] - 1,
cursor[2],
end_pos[1],
end_pos[2] + 1,
{}
)
local last_line = text[#text]
text[#text] = last_line:sub(1, #last_line - #matched_trigger)
end
return start_row, start_col, text
end
Expand Down Expand Up @@ -467,10 +451,8 @@ return {
reparse_buffer_after_removing_match = reparse_buffer_after_removing_match,
TSParser = TSParser,
FixBufferContext = FixBufferContext,
wrap_enter_and_leave_func = wrap_enter_and_leave_func,
find_topmost_parent = find_topmost_parent,
find_first_parent = find_first_parent,
find_nth_parent = find_nth_parent,
inspect_node = inspect_node,
}

105 changes: 45 additions & 60 deletions lua/luasnip/extras/treesitter_postfix.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,68 +25,51 @@ local function generate_match_tsnode_func(opts)
end
end

---@param reparse boolean|string|nil
---@param real_resolver function
---@return fun(snippet, line_to_cursor, matched_trigger, captures):table?
local function wrap_with_reparse_context(reparse, real_resolver)
local function make_reparse_enter_and_leave_func(reparse, bufnr, matched_region)
if reparse == "live" then
return function(snippet, line_to_cursor, matched_trigger, captures)
local bufnr = vim.api.nvim_win_get_buf(0)
local cursor = vim.api.nvim_win_get_cursor(0)
local enter, leave =
ts.wrap_enter_and_leave_func(bufnr, matched_trigger)

enter()
local parser, source = vim.treesitter.get_parser(bufnr), bufnr
if parser == nil or source == nil then
return nil
end
local ret = real_resolver(
snippet,
line_to_cursor,
matched_trigger,
captures,
parser,
source,
bufnr,
{ cursor[1] - 1, cursor[2] - #matched_trigger - 1 }
)
leave()
return ret
local context = ts.FixBufferContext.new(bufnr, matched_region)
return function()
context:enter()
return vim.treesitter.get_parser(bufnr), bufnr
end, function(...)
context:leave()
end
end
if reparse == "copy" then
return function(snippet, line_to_cursor, matched_trigger, captures)
local bufnr = vim.api.nvim_win_get_buf(0)
local cursor = vim.api.nvim_win_get_cursor(0)
local parser, source =
ts.reparse_buffer_after_removing_match(bufnr, matched_trigger)
if parser == nil or source == nil then
return nil
end
local ret = real_resolver(
snippet,
line_to_cursor,
matched_trigger,
captures,
parser,
source,
bufnr,
{ cursor[1] - 1, cursor[2] - #matched_trigger - 1 }
)
elseif reparse == "copy" then
return function()
return ts.reparse_buffer_after_removing_match(bufnr, matched_region)
end, function(parser)
parser:destroy()
return ret
end
else
return function()
return vim.treesitter.get_parser(bufnr), bufnr
end, function(...) end
end
end

---@param reparse boolean|string|nil
---@param real_resolver function
---@return fun(snippet, line_to_cursor, matched_trigger, captures):table?
local function wrap_with_reparse_context(reparse, real_resolver)
return function(snippet, line_to_cursor, matched_trigger, captures)
local bufnr = vim.api.nvim_win_get_buf(0)
local cursor = vim.api.nvim_win_get_cursor(0)
local parser, source = vim.treesitter.get_parser(bufnr), bufnr
local matched_region = {
row = cursor[1] - 1,
col = {
cursor[2] - #matched_trigger,
cursor[2],
},
}

local enter, leave =
make_reparse_enter_and_leave_func(reparse, bufnr, matched_region)
local parser, source = enter()
if parser == nil or source == nil then
return nil
end
return real_resolver(

local ret = real_resolver(
snippet,
line_to_cursor,
matched_trigger,
Expand All @@ -96,6 +79,10 @@ local function wrap_with_reparse_context(reparse, real_resolver)
bufnr,
{ cursor[1] - 1, cursor[2] - #matched_trigger - 1 }
)

leave(parser)

return ret
end
end

Expand All @@ -118,7 +105,6 @@ local function generate_resolve_expand_param(match_tsnode, user_resolver)
bufnr,
pos
)
local cursor = vim.api.nvim_win_get_cursor(0)
local ts_parser = ts.TSParser.new(bufnr, parser, source)
if ts_parser == nil then
return
Expand All @@ -145,7 +131,8 @@ local function generate_resolve_expand_param(match_tsnode, user_resolver)

for _, info in ipairs(match_result.matches) do
local start_row, start_col, text =
ts_parser:get_node_text(info.node, cursor, matched_trigger)
ts_parser:get_node_text(info.node, pos)

if matches[info.capture_name] == nil then
matches[info.capture_name] = {}
end
Expand All @@ -156,11 +143,8 @@ local function generate_resolve_expand_param(match_tsnode, user_resolver)
})
end

local start_row, start_col, best_match_text = ts_parser:get_node_text(
match_result.best_match,
cursor,
matched_trigger
)
local start_row, start_col, best_match_text =
ts_parser:get_node_text(match_result.best_match, pos)

local ret = {
trigger = matched_trigger,
Expand All @@ -171,14 +155,14 @@ local function generate_resolve_expand_param(match_tsnode, user_resolver)
start_col,
},
to = {
cursor[1] - 1,
cursor[2],
pos[1],
pos[2] + #matched_trigger + 1,
},
},
env_override = {
TREESITTER_MATCHES = matches,
TREESITTER_BEST_MATCH = {
start = { start_col, start_col },
start = { start_row, start_col },
text = best_match_text,
type = match_result.best_match:type(),
},
Expand Down Expand Up @@ -285,6 +269,7 @@ local function treesitter_postfix(context, nodes, opts)
})

context = node_util.wrap_context(context)
context.wordTrig = false
---@type string|string[]|LuaSnip.extra.MatchTSNodeOpts|LuaSnip.extra.MatchTSNodeFunc
local match_tsnode = context.matchTSNode
local user_resolve = context.resolveExpandParams
Expand Down
1 change: 1 addition & 0 deletions lua/luasnip/util/table.lua
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
---Convert string or list of string to a table of booleans for fast lookup.
---@generic T
---@param values T|T[]|table<T, boolean>
---@return table<T, boolean>
Expand Down

0 comments on commit 79bcc0c

Please sign in to comment.