diff --git a/lua/luasnip/extras/_treesitter.lua b/lua/luasnip/extras/_treesitter.lua index 53a736199..d4a1f1eac 100644 --- a/lua/luasnip/extras/_treesitter.lua +++ b/lua/luasnip/extras/_treesitter.lua @@ -7,6 +7,51 @@ ---@operator call(TSNode): boolean local TSNodeMatcher = {} +local function get_lang(bufnr) + local ft = vim.api.nvim_buf_get_option(bufnr, "ft") + local lang = vim.treesitter.language.get_lang(ft) or ft + return lang +end + +---@param bufnr number +---@param match string +---@return LanguageTree, string +local function reparse_buffer_after_removing_match(bufnr, match) + local row, col = unpack(vim.api.nvim_win_get_cursor(0)) + local lang = get_lang(bufnr) + + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local current_line = lines[row] + local left_part = current_line:sub(1, col - #match) + local right_part = current_line:sub(col + 1) + lines[row] = left_part .. right_part + + local source = table.concat(lines, "\n") + ---@type LanguageTree + local parser = vim.treesitter.get_string_parser(source, lang, nil) + parser:parse(true) + return parser, source +end + +---@param pos { [1]: number, [2]: number } +---@param parser number|LanguageTree +---@return TSNode? +local function get_node(pos, parser) + local row, col = pos[1], pos[2] + assert( + row >= 0 and col >= 0, + "Invalid position: row and col must be non-negative" + ) + local range = { row, col, row, col } + if type(parser) == "number" then + parser = vim.treesitter.get_parser(parser) + end + if parser == nil then + return + end + return parser:named_node_for_range(range) +end + ---@param root TSNode ---@param n number ---@param matcher LuaSnip.extra.TSNodeChecker|fun(node:TSNode):boolean|nil @@ -136,4 +181,7 @@ return { find_first_parent = find_first_parent, inside_node = inside_node, TSNodeMatcher = TSNodeMatcher, + get_lang = get_lang, + get_node = get_node, + reparse_buffer_after_removing_match = reparse_buffer_after_removing_match, } diff --git a/lua/luasnip/extras/treesitter_postfix.lua b/lua/luasnip/extras/treesitter_postfix.lua index e82df7caa..ce5576e7b 100644 --- a/lua/luasnip/extras/treesitter_postfix.lua +++ b/lua/luasnip/extras/treesitter_postfix.lua @@ -46,11 +46,7 @@ local function generate_cursor_not_in_node_range_checker(cursor) ---@param node TSNode ---@return boolean return ts.TSNodeMatcher.new(function(node) - local start_row, start_col, end_row, end_col = - vim.treesitter.get_node_range(node) - return (start_row > cursor[1] or end_row < cursor[1]) - or (start_row == cursor[1] and start_col > cursor[2]) - or (end_row == cursor[1] and end_col < cursor[2]) + return vim.treesitter.is_in_node_range(node, cursor[1], cursor[2]) end) end @@ -63,40 +59,114 @@ local function generate_current_cursor_not_in_node_range_checker() return range_checker end ----@param find_root nil|fun(root: TSNode):TSNode? +---@param lang string +---@param query_name string +---@param capture string +---@param source string|number +---@param cursor { [1]: number, [2]: number } +---@return fun(parser: LanguageTree):TSNode? +local function generate_tsquery_node_resolver( + lang, + query_name, + capture, + source, + cursor +) + local query = vim.treesitter.query.get(lang, query_name) + ---@param parser LanguageTree + ---@return TSNode? + return function(parser) + local trees = parser:parse() + for _, tree in ipairs(trees) do + local matches = query:iter_matches(tree:root(), source, 0, -1) + while true do + local pattern, match = matches() + if pattern == nil then + break + end + for id, node in pairs(match) do + if + query.captures[id] == capture + and vim.treesitter.is_in_node_range( + node, + cursor[1], + cursor[2] + ) + then + return node + end + end + end + end + end +end + +---@param resolve_node fun(parser: LanguageTree):TSNode?|string ---@param user_resolver nil|fun(snippet, line_to_cursor, matched_trigger, captures):table? +---@param reparse_buffer boolean ---@return fun(snippet, line_to_cursor, matched_trigger, captures):table? -local function wrap_resolve_expand_params_func(find_root, user_resolver) +local function wrap_resolve_expand_params_func( + resolve_node, + user_resolver, + reparse_buffer +) return function(snippet, line_to_cursor, matched_trigger, captures) local cursor = vim.api.nvim_win_get_cursor(0) - local cursor_range = { cursor[1] - 1, cursor[2] - #matched_trigger - 1 } - - local buf = vim.api.nvim_win_get_buf(0) - local root = vim.treesitter.get_node({ - bufnr = buf, - pos = cursor_range, - }) - if root and find_root then - root = find_root(root) + local bufnr = vim.api.nvim_win_get_buf(0) + local lang = ts.get_lang(bufnr) + + local parser, source + local row, col + if reparse_buffer then + parser, source = + ts.reparse_buffer_after_removing_match(0, matched_trigger) + row, col = cursor[1] - 1, cursor[2] - #matched_trigger + else + parser, source = vim.treesitter.get_parser(bufnr), bufnr + row, col = cursor[1] - 1, cursor[2] - #matched_trigger - 1 end - if root == nil then + if parser == nil then + return + end + + local match_node + if type(resolve_node) == "string" then + match_node = generate_tsquery_node_resolver( + lang, + "luasnip", + resolve_node, + source, + { row, col } + ) + else + match_node = resolve_node + end + + local matched_node = match_node(parser) + if matched_node == nil then return nil end -- try to use the text from `line_to_cursor` - local start_row, start_col, end_row, end_col = - vim.treesitter.get_node_range(root) - - local node_text = vim.api.nvim_buf_get_text( - 0, - start_row, - start_col, - cursor[1] - 1, - cursor[2], - {} - ) - local last_line = node_text[#node_text] - node_text[#node_text] = last_line:sub(1, #last_line - #matched_trigger) + local start_row, start_col, _, _ = + vim.treesitter.get_node_range(matched_node) + + local node_text + if reparse_buffer then + node_text = vim.treesitter.get_node_text(matched_node, source) + else + node_text = vim.api.nvim_buf_get_text( + 0, + start_row, + start_col, + cursor[1] - 1, + cursor[2], + {} + ) + local last_line = node_text[#node_text] + node_text[#node_text] = + last_line:sub(1, #last_line - #matched_trigger) + end local ret = { trigger = matched_trigger, @@ -184,10 +254,19 @@ local function treesitter_postfix(context, nodes, opts) context = node_util.wrap_context(context) context.wordTrig = false context.trigEngine = "plain" - local find_tsnode = context.find_tsnode + local reparse_buffer = context.reparseBuffer or false + local resolve_node = context.matchNodeBeforeCursor + local capture = context.matchCapture + assert( + resolve_node or capture, + "matchNodeBeforeCursor and matchCapture are both nil" + ) local user_resolve = context.resolveExpandParams - context.resolveExpandParams = - wrap_resolve_expand_params_func(find_tsnode, user_resolve) + context.resolveExpandParams = wrap_resolve_expand_params_func( + resolve_node or capture, + user_resolve, + reparse_buffer + ) return snip(context, nodes, opts) end