Skip to content

Commit

Permalink
Update the way to find ts_postfix intrested node
Browse files Browse the repository at this point in the history
  • Loading branch information
TwIStOy committed Aug 16, 2023
1 parent c02db45 commit 46e0e7a
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 33 deletions.
48 changes: 48 additions & 0 deletions lua/luasnip/extras/_treesitter.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,51 @@
---@operator call(TSNode): boolean
local TSNodeMatcher = {}

local function get_lang(bufnr)
local ft = vim.api.nvim_buf_get_option(bufnr, "ft")
local lang = vim.treesitter.language.get_lang(ft) or ft
return lang
end

---@param bufnr number
---@param match string
---@return LanguageTree, string
local function reparse_buffer_after_removing_match(bufnr, match)
local row, col = unpack(vim.api.nvim_win_get_cursor(0))
local lang = get_lang(bufnr)

local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
local current_line = lines[row]
local left_part = current_line:sub(1, col - #match)
local right_part = current_line:sub(col + 1)
lines[row] = left_part .. right_part

local source = table.concat(lines, "\n")
---@type LanguageTree
local parser = vim.treesitter.get_string_parser(source, lang, nil)
parser:parse(true)
return parser, source
end

---@param pos { [1]: number, [2]: number }
---@param parser number|LanguageTree
---@return TSNode?
local function get_node(pos, parser)
local row, col = pos[1], pos[2]
assert(
row >= 0 and col >= 0,
"Invalid position: row and col must be non-negative"
)
local range = { row, col, row, col }
if type(parser) == "number" then
parser = vim.treesitter.get_parser(parser)
end
if parser == nil then
return
end
return parser:named_node_for_range(range)
end

---@param root TSNode
---@param n number
---@param matcher LuaSnip.extra.TSNodeChecker|fun(node:TSNode):boolean|nil
Expand Down Expand Up @@ -136,4 +181,7 @@ return {
find_first_parent = find_first_parent,
inside_node = inside_node,
TSNodeMatcher = TSNodeMatcher,
get_lang = get_lang,
get_node = get_node,
reparse_buffer_after_removing_match = reparse_buffer_after_removing_match,
}
145 changes: 112 additions & 33 deletions lua/luasnip/extras/treesitter_postfix.lua
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ local function generate_cursor_not_in_node_range_checker(cursor)
---@param node TSNode
---@return boolean
return ts.TSNodeMatcher.new(function(node)
local start_row, start_col, end_row, end_col =
vim.treesitter.get_node_range(node)
return (start_row > cursor[1] or end_row < cursor[1])
or (start_row == cursor[1] and start_col > cursor[2])
or (end_row == cursor[1] and end_col < cursor[2])
return vim.treesitter.is_in_node_range(node, cursor[1], cursor[2])
end)
end

Expand All @@ -63,40 +59,114 @@ local function generate_current_cursor_not_in_node_range_checker()
return range_checker
end

---@param find_root nil|fun(root: TSNode):TSNode?
---@param lang string
---@param query_name string
---@param capture string
---@param source string|number
---@param cursor { [1]: number, [2]: number }
---@return fun(parser: LanguageTree):TSNode?
local function generate_tsquery_node_resolver(
lang,
query_name,
capture,
source,
cursor
)
local query = vim.treesitter.query.get(lang, query_name)
---@param parser LanguageTree
---@return TSNode?
return function(parser)
local trees = parser:parse()
for _, tree in ipairs(trees) do
local matches = query:iter_matches(tree:root(), source, 0, -1)
while true do
local pattern, match = matches()
if pattern == nil then
break
end
for id, node in pairs(match) do
if
query.captures[id] == capture
and vim.treesitter.is_in_node_range(
node,
cursor[1],
cursor[2]
)
then
return node
end
end
end
end
end
end

---@param resolve_node fun(parser: LanguageTree):TSNode?|string
---@param user_resolver nil|fun(snippet, line_to_cursor, matched_trigger, captures):table?
---@param reparse_buffer boolean
---@return fun(snippet, line_to_cursor, matched_trigger, captures):table?
local function wrap_resolve_expand_params_func(find_root, user_resolver)
local function wrap_resolve_expand_params_func(
resolve_node,
user_resolver,
reparse_buffer
)
return function(snippet, line_to_cursor, matched_trigger, captures)
local cursor = vim.api.nvim_win_get_cursor(0)
local cursor_range = { cursor[1] - 1, cursor[2] - #matched_trigger - 1 }

local buf = vim.api.nvim_win_get_buf(0)
local root = vim.treesitter.get_node({
bufnr = buf,
pos = cursor_range,
})
if root and find_root then
root = find_root(root)
local bufnr = vim.api.nvim_win_get_buf(0)
local lang = ts.get_lang(bufnr)

local parser, source
local row, col
if reparse_buffer then
parser, source =
ts.reparse_buffer_after_removing_match(0, matched_trigger)
row, col = cursor[1] - 1, cursor[2] - #matched_trigger
else
parser, source = vim.treesitter.get_parser(bufnr), bufnr
row, col = cursor[1] - 1, cursor[2] - #matched_trigger - 1
end
if root == nil then
if parser == nil then
return
end

local match_node
if type(resolve_node) == "string" then
match_node = generate_tsquery_node_resolver(
lang,
"luasnip",
resolve_node,
source,
{ row, col }
)
else
match_node = resolve_node
end

local matched_node = match_node(parser)
if matched_node == nil then
return nil
end

-- try to use the text from `line_to_cursor`
local start_row, start_col, end_row, end_col =
vim.treesitter.get_node_range(root)

local node_text = vim.api.nvim_buf_get_text(
0,
start_row,
start_col,
cursor[1] - 1,
cursor[2],
{}
)
local last_line = node_text[#node_text]
node_text[#node_text] = last_line:sub(1, #last_line - #matched_trigger)
local start_row, start_col, _, _ =
vim.treesitter.get_node_range(matched_node)

local node_text
if reparse_buffer then
node_text = vim.treesitter.get_node_text(matched_node, source)
else
node_text = vim.api.nvim_buf_get_text(
0,
start_row,
start_col,
cursor[1] - 1,
cursor[2],
{}
)
local last_line = node_text[#node_text]
node_text[#node_text] =
last_line:sub(1, #last_line - #matched_trigger)
end

local ret = {
trigger = matched_trigger,
Expand Down Expand Up @@ -184,10 +254,19 @@ local function treesitter_postfix(context, nodes, opts)
context = node_util.wrap_context(context)
context.wordTrig = false
context.trigEngine = "plain"
local find_tsnode = context.find_tsnode
local reparse_buffer = context.reparseBuffer or false
local resolve_node = context.matchNodeBeforeCursor
local capture = context.matchCapture
assert(
resolve_node or capture,
"matchNodeBeforeCursor and matchCapture are both nil"
)
local user_resolve = context.resolveExpandParams
context.resolveExpandParams =
wrap_resolve_expand_params_func(find_tsnode, user_resolve)
context.resolveExpandParams = wrap_resolve_expand_params_func(
resolve_node or capture,
user_resolve,
reparse_buffer
)

return snip(context, nodes, opts)
end
Expand Down

0 comments on commit 46e0e7a

Please sign in to comment.