Skip to content

Commit

Permalink
update impl
Browse files Browse the repository at this point in the history
  • Loading branch information
TwIStOy committed Aug 20, 2023
1 parent 254ae17 commit 98890fb
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 114 deletions.
1 change: 1 addition & 0 deletions lua/luasnip/_types.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
---@alias LuaSnip.Cursor {[1]: number, [2]: number}
24 changes: 24 additions & 0 deletions lua/luasnip/extras/_extra_types.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
---@class LuaSnip.extra.MatchTSNodeOpts
---@field query string|{[1]: string, [2]: string?}
---@field select LuaSnip.extra.SelectTSNodeOpts|LuaSnip.extra.SelectTSNodeFunc|nil

---@class LuaSnip.extra.MatchedTSNodeInfo
---@field capture_name string
---@field node TSNode

---@class LuaSnip.extra.MatchTSNodeResult
---@field best_match TSNode
---@field matches LuaSnip.extra.MatchedTSNodeInfo[]

---@alias LuaSnip.extra.BuiltinCaptureSelector
---| '"any"' # The default selector
---| '"shortest"'
---| '"longest"'

---@class LuaSnip.extra.SelectTSNodeOpts
---@field captures string|string[]|nil
---@field select_capture LuaSnip.extra.BuiltinCaptureSelector?

---@alias LuaSnip.extra.SelectTSNodeFunc fun(nodes: LuaSnip.extra.MatchedTSNodeInfo[]):TSNode?
---@alias LuaSnip.extra.MatchTSNodeFunc fun(parser: LuaSnip.extra.TSParser, cursor: LuaSnip.Cursor):LuaSnip.extra.MatchTSNodeResult?

202 changes: 143 additions & 59 deletions lua/luasnip/extras/_treesitter.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
local util = require("luasnip.util.util")
local func = require("luasnip.util.functions")
local tbl = require("luasnip.util.table")

local function get_lang(bufnr)
Expand Down Expand Up @@ -121,6 +120,99 @@ local function wrap_enter_and_leave_func(ori_bufnr, match)
end
end

local builtin_tsnode_selectors = {
any = function(captures)
local best_match
return {
match = function(node)
if captures and not captures[node:type()] then
return false
end
if best_match == nil then
best_match = node
end
return true
end,
get_best_match = function()
return best_match
end,
}
end,
shortest = function(captures)
local best_match
local best_match_start
return {
match = function(node)
if captures and not captures[node:type()] then
return false
end

local start_row, start_col, _, _ =
vim.treesitter.get_node_range(node)
if
(best_match == nil)
or (start_row > best_match_start[1])
or (
start_row == best_match_start[1]
and start_col > best_match_start[2]
)
then
best_match = node
best_match_start = { start_row, start_col }
end
return true
end,
get_best_match = function()
return best_match
end,
}
end,
longest = function(captures)
local best_match
local best_match_start
return {
match = function(node)
if captures and not captures[node:type()] then
return false
end

local start_row, start_col, _, _ =
vim.treesitter.get_node_range(node)
if
(best_match == nil)
or (start_row < best_match_start[1])
or (
start_row == best_match_start[1]
and start_col < best_match_start[2]
)
then
best_match = node
best_match_start = { start_row, start_col }
end
return true
end,
get_best_match = function()
return best_match
end,
}
end,
}

---@param opts LuaSnip.extra.SelectTSNodeOpts?
---@return { match: (fun(node:TSNode):boolean), get_best_match: fun():TSNode? }
local function generate_select_tsnode_func(opts)
---@type LuaSnip.extra.SelectTSNodeOpts
opts = vim.F.if_nil(opts, {})
local captures = tbl.normalize_search_table(opts.captures)
local select_capture = opts.select_capture or "any"

local selector = builtin_tsnode_selectors[select_capture]
if selector == nil then
error("Unknown select_capture: " .. select_capture)
end
return selector.new_state(captures)
end

---@class LuaSnip.extra.TSParser
---@field parser LanguageTree
---@field source string|number
Expand Down Expand Up @@ -175,10 +267,9 @@ function TSParser:get_node_at_pos(pos)
return self.parser:named_node_for_range(range)
end

---@param opts LuaSnip.extra.MatchTSNodeFromCaptures
---@param opts LuaSnip.extra.MatchTSNodeOpts
---@param root TSNode?
---@param root_lang string?
function TSParser:prepare_query(opts, root, root_lang)
function TSParser:prepare_query(opts, root)
if root == nil then
-- try first tree's root
local first_tree = self.parser:trees()[1]
Expand All @@ -192,61 +283,56 @@ function TSParser:prepare_query(opts, root, root_lang)

local range = { root:range() }

if not root_lang then
local lang_tree = self.parser:language_for_range(range)

if lang_tree then
root_lang = lang_tree:lang()
end
end
if not root_lang then
return
end

---@type Query?
local query
local insert

if opts.query_group == nil and opts.captures == nil then
return
end
if opts.capture_text then
query = vim.treesitter.query.parse(self.buf_lang, opts.capture_text)
insert = function()
return true
end
if type(opts.query) == "string" then
query = vim.treesitter.query.parse(
self.buf_lang,
opts.query --[[@as string]]
)
else
opts.query_group = opts.query_group or "luasnip"
opts.captures = tbl.normalize_search_table(opts.captures)
query = vim.treesitter.query.get(self.buf_lang, opts.query_group)
insert = function(capture_name)
return opts.captures and opts.captures[capture_name]
end
local query_group = opts.query[1]
local lang = opts.query[2] or self.buf_lang
query = vim.treesitter.query.get(lang, query_group)
end

if not query then
return
end

return query,
{
root = root,
insert = insert,
start = range[1],
stop = range[3] + 1,
}
return query, {
root = root,
start = range[1],
stop = range[3] + 1,
}
end

---@param captures LuaSnip.extra.MatchTSNodeFromCaptures
---@return { name: string, node: TSNode }[]?
function TSParser:get_capture_matches(captures)
---@param captures LuaSnip.extra.MatchTSNodeOpts
---@return LuaSnip.extra.MatchTSNodeResult?
function TSParser:get_capture_matches(captures) end

---@param captures LuaSnip.extra.MatchTSNodeOpts
---@param pos { [1]: number, [2]: number }?
---@return LuaSnip.extra.MatchTSNodeResult?
function TSParser:captures_at_pos(captures, pos)
pos = vim.F.if_nil(pos, util.get_cursor_0ind())

---@type LuaSnip.extra.MatchedTSNodeInfo[]
local results = {}

local query, info = self:prepare_query(captures)
if query == nil or info == nil then
return
end

local selector
if type(captures.select) ~= "function" then
selector = generate_select_tsnode_func(
captures.select --[[@as LuaSnip.extra.SelectTSNodeOpts?]]
)
end

local matches =
query:iter_matches(info.root, self.source, info.start, info.stop)

Expand All @@ -258,34 +344,32 @@ function TSParser:get_capture_matches(captures)

for id, node in ipairs(match) do
local capture_name = query.captures[id]
if info.insert and info.insert(capture_name) then

if
(vim.treesitter.is_in_node_range(match.node, pos[1], pos[2]))
and (selector == nil or selector.match(node))
then
results[#results + 1] = {
name = capture_name,
capture_name = capture_name,
node = node,
}
end
end
end

return results
end

---@param captures LuaSnip.extra.MatchTSNodeFromCaptures
---@param pos { [1]: number, [2]: number }?
---@return { name: string, node: TSNode }[]?
function TSParser:captures_at_pos(captures, pos)
pos = vim.F.if_nil(pos, util.get_cursor_0ind())
local matches = self:get_capture_matches(captures)
if matches == nil then
return
local best_match
if selector ~= nil then
best_match = selector.get_best_match()
else
local select_best_match = captures --[[@as LuaSnip.extra.SelectTSNodeFunc]]
best_match = select_best_match(results)
end
local results = {}
for _, match in ipairs(matches) do
if vim.treesitter.is_in_node_range(match.node, pos[1], pos[2]) then
results[#results + 1] = match
end
if best_match ~= nil then
return {
best_match = best_match,
matches = results,
}
end
return results
end

---@param node TSNode
Expand Down
Loading

0 comments on commit 98890fb

Please sign in to comment.