diff --git a/internal/handler/generic_document_usecase.go b/internal/handler/generic_document_usecase.go index fbd3bbb7..f6067bd1 100644 --- a/internal/handler/generic_document_usecase.go +++ b/internal/handler/generic_document_usecase.go @@ -9,10 +9,19 @@ import ( lsp "go.lsp.dev/protocol" ) -func (h *langHandler) genericDocumentUseCase(params lsp.TextDocumentPositionParams) (*lsplocal.Document, *charts.Chart, *sitter.Node, error) { +type GenericDocumentUseCase struct { + *lsplocal.Document + *charts.Chart + *sitter.Node + error +} + +func (h *langHandler) NewGenericDocumentUseCase(params lsp.TextDocumentPositionParams) GenericDocumentUseCase { doc, ok := h.documents.Get(params.TextDocument.URI) if !ok { - return nil, nil, nil, errors.New("Could not get document: " + params.TextDocument.URI.Filename()) + return GenericDocumentUseCase{ + error: errors.New("Could not get document: " + params.TextDocument.URI.Filename()), + } } chart, err := h.chartStore.GetChartForDoc(params.TextDocument.URI) if err != nil { @@ -20,9 +29,15 @@ func (h *langHandler) genericDocumentUseCase(params lsp.TextDocumentPositionPara } node := h.getNode(doc, params.Position) if node == nil { - return doc, chart, nil, errors.New("Could not get node for: " + params.TextDocument.URI.Filename()) + return GenericDocumentUseCase{ + error: errors.New("Could not get node for: " + params.TextDocument.URI.Filename()), + } + } + return GenericDocumentUseCase{ + Document: doc, + Chart: chart, + Node: node, } - return doc, chart, node, nil } func (h *langHandler) getNode(doc *lsplocal.Document, position lsp.Position) *sitter.Node { diff --git a/internal/handler/references.go b/internal/handler/references.go index 06cd6cb2..248edd42 100644 --- a/internal/handler/references.go +++ b/internal/handler/references.go @@ -3,6 +3,7 @@ package handler import ( "context" + lsplocal "github.com/mrjosh/helm-ls/internal/lsp" "github.com/mrjosh/helm-ls/internal/tree-sitter/gotemplate" "github.com/mrjosh/helm-ls/internal/util" lsp "go.lsp.dev/protocol" @@ -10,17 +11,19 @@ import ( // References implements protocol.Server. func (h *langHandler) References(ctx context.Context, params *lsp.ReferenceParams) (result []lsp.Location, err error) { - doc, _, node, err := h.genericDocumentUseCase(params.TextDocumentPositionParams) - if err != nil { + genericDocumentUseCase := h.NewGenericDocumentUseCase(params.TextDocumentPositionParams) + if genericDocumentUseCase.error != nil { return nil, err } - parentNode := node.Parent() + parentNode := genericDocumentUseCase.Node.Parent() pt := parentNode.Type() - ct := node.Type() + ct := genericDocumentUseCase.Node.Type() if pt == gotemplate.NodeTypeDefineAction && ct == gotemplate.NodeTypeInterpretedStringLiteral { - referenceRanges, ok := doc.SymbolTable.GetIncludeReference(util.RemoveQuotes(node.Content([]byte(doc.Content)))) + referenceRanges := genericDocumentUseCase.Document.SymbolTable.GetIncludeReference( + util.RemoveQuotes(genericDocumentUseCase.Node.Content([]byte(genericDocumentUseCase.Document.Content))), + ) locations := []lsp.Location{} for _, referenceRange := range referenceRanges { @@ -33,8 +36,25 @@ func (h *langHandler) References(ctx context.Context, params *lsp.ReferenceParam }) } - if ok { + return locations, nil + } + + if pt == gotemplate.NodeTypeArgumentList { + includeName, err := lsplocal.ParseIncludeFunctionCall(parentNode.Parent(), []byte(genericDocumentUseCase.Document.Content)) + if err == nil { + referenceRanges := genericDocumentUseCase.Document.SymbolTable.GetIncludeReference(includeName) + locations := []lsp.Location{} + for _, referenceRange := range referenceRanges { + locations = append(locations, lsp.Location{ + URI: params.TextDocumentPositionParams.TextDocument.URI, + Range: lsp.Range{ + Start: util.PointToPosition(referenceRange.StartPoint), + End: util.PointToPosition(referenceRange.EndPoint), + }, + }) + } return locations, nil + } } return nil, nil diff --git a/internal/handler/text_document.go b/internal/handler/text_document.go index d9752110..70f874ed 100644 --- a/internal/handler/text_document.go +++ b/internal/handler/text_document.go @@ -4,7 +4,6 @@ import ( "context" "errors" - lspinternal "github.com/mrjosh/helm-ls/internal/lsp" lsplocal "github.com/mrjosh/helm-ls/internal/lsp" lsp "go.lsp.dev/protocol" ) @@ -63,7 +62,7 @@ func (h *langHandler) DidChange(ctx context.Context, params *lsp.DidChangeTextDo doc.ApplyChanges(params.ContentChanges) for _, change := range params.ContentChanges { - node := lspinternal.NodeAtPosition(doc.Ast, change.Range.Start) + node := lsplocal.NodeAtPosition(doc.Ast, change.Range.Start) if node.Type() != "text" { shouldSendFullUpdateToYamlls = true break diff --git a/internal/lsp/symbol_table.go b/internal/lsp/symbol_table.go index 418e0ccd..bdb5e8be 100644 --- a/internal/lsp/symbol_table.go +++ b/internal/lsp/symbol_table.go @@ -46,12 +46,10 @@ func (s *SymbolTable) GetIncludeDefinitions(symbol string) ([]sitter.Range, bool return result, true } -func (s *SymbolTable) GetIncludeReference(symbol string) ([]sitter.Range, bool) { - result, ok := s.includeReferences[symbol] - if !ok { - return []sitter.Range{}, false - } - return result, true +func (s *SymbolTable) GetIncludeReference(symbol string) []sitter.Range { + result := s.includeReferences[symbol] + definitions := s.includeDefinitions[symbol] + return append(result, definitions...) } func (s *SymbolTable) parseTree(ast *sitter.Tree, content []byte) { diff --git a/internal/lsp/symbol_table_includes.go b/internal/lsp/symbol_table_includes.go index 2b66629f..e47735a3 100644 --- a/internal/lsp/symbol_table_includes.go +++ b/internal/lsp/symbol_table_includes.go @@ -1,6 +1,8 @@ package lsp import ( + "fmt" + "github.com/mrjosh/helm-ls/internal/tree-sitter/gotemplate" "github.com/mrjosh/helm-ls/internal/util" sitter "github.com/smacker/go-tree-sitter" @@ -24,20 +26,37 @@ func (v *IncludeDefinitionsVisitor) Enter(node *sitter.Node) { v.symbolTable.AddIncludeDefinition(util.RemoveQuotes(content), getRangeForNode(node)) } - // TODO: move this to separate function and use early returns if node.Type() == gotemplate.NodeTypeFunctionCall { - functionName := node.ChildByFieldName("function").Content(v.content) - if functionName == "include" { - arguments := node.ChildByFieldName("arguments") - if arguments.ChildCount() > 0 { - firstArgument := arguments.Child(0) - if firstArgument.Type() == gotemplate.NodeTypeInterpretedStringLiteral { - content := firstArgument.Content(v.content) - v.symbolTable.AddIncludeReference(util.RemoveQuotes(content), getRangeForNode(node)) - } - } - } + v.EnterFunctionCall(node) + } +} + +func (v *IncludeDefinitionsVisitor) EnterFunctionCall(node *sitter.Node) { + includeName, err := ParseIncludeFunctionCall(node, v.content) + if err != nil { + return + } + + v.symbolTable.AddIncludeReference(includeName, getRangeForNode(node)) +} + +func ParseIncludeFunctionCall(node *sitter.Node, content []byte) (string, error) { + if node.Type() != gotemplate.NodeTypeFunctionCall { + return "", fmt.Errorf("node is not a function call") + } + functionName := node.ChildByFieldName("function").Content(content) + if functionName != "include" { + return "", fmt.Errorf("function name is not include") + } + arguments := node.ChildByFieldName("arguments") + if arguments == nil || arguments.ChildCount() == 0 { + return "", fmt.Errorf("no arguments") + } + firstArgument := arguments.Child(0) + if firstArgument.Type() != gotemplate.NodeTypeInterpretedStringLiteral { + return "", fmt.Errorf("first argument is not an interpreted string literal") } + return util.RemoveQuotes(firstArgument.Content(content)), nil } func (v *IncludeDefinitionsVisitor) Exit(node *sitter.Node) {} diff --git a/internal/lsp/symbol_table_values.go b/internal/lsp/symbol_table_values.go index f0d3ad34..38434b7b 100644 --- a/internal/lsp/symbol_table_values.go +++ b/internal/lsp/symbol_table_values.go @@ -60,9 +60,7 @@ func (v *ValuesVisitor) Enter(node *sitter.Node) { case gotemplate.NodeTypeSelectorExpression: operandNode := node.ChildByFieldName("operand") if operandNode.Type() == gotemplate.NodeTypeVariable && operandNode.Content(v.content) == "$" { - v.stashedContext = append(v.stashedContext, v.currentContext) - v.currentContext = []string{} - // v.StashContext() + v.StashContext() } } } @@ -72,9 +70,7 @@ func (v *ValuesVisitor) Exit(node *sitter.Node) { case gotemplate.NodeTypeSelectorExpression: operandNode := node.ChildByFieldName("operand") if operandNode.Type() == gotemplate.NodeTypeVariable && operandNode.Content(v.content) == "$" { - v.currentContext = v.stashedContext[len(v.stashedContext)-1] - v.stashedContext = v.stashedContext[:len(v.stashedContext)-1] - // v.RestoreStashedContext() + v.RestoreStashedContext() } } } @@ -86,39 +82,30 @@ func (v *ValuesVisitor) EnterContextShift(node *sitter.Node, suffix string) { v.currentContext = append(v.currentContext, content) case gotemplate.NodeTypeField: content := node.ChildByFieldName("name").Content(v.content) + suffix - v.currentContext = append(v.currentContext, content) - // v.PushContext(content) + v.PushContext(content) case gotemplate.NodeTypeSelectorExpression: s := getContextForSelectorExpression(node, v.content) if len(s) > 0 { s[len(s)-1] = s[len(s)-1] + suffix if s[0] == "$" { - v.stashedContext = append(v.stashedContext, v.currentContext) - v.currentContext = []string{} - // v.StashContext() + v.StashContext() s = s[1:] } } - v.currentContext = append(v.currentContext, s...) - // v.PushContextMany(s) + v.PushContextMany(s) } } func (v *ValuesVisitor) ExitContextShift(node *sitter.Node) { switch node.Type() { case gotemplate.NodeTypeField, gotemplate.NodeTypeFieldIdentifier: - v.currentContext = v.currentContext[:len(v.currentContext)-1] - // v.PopContext() + v.PopContext() case gotemplate.NodeTypeSelectorExpression: s := getContextForSelectorExpression(node, v.content) if len(s) > 0 && s[0] == "$" { - v.currentContext = v.stashedContext[len(v.stashedContext)-1] - v.stashedContext = v.stashedContext[:len(v.stashedContext)-1] - s = s[1:] - // v.RestoreStashedContext() + v.RestoreStashedContext() } else { - v.currentContext = v.currentContext[:len(v.currentContext)-len(s)] - // v.PopContextN(len(s)) + v.PopContextN(len(s)) } } } diff --git a/internal/tree-sitter/gotemplate/node-types.go b/internal/tree-sitter/gotemplate/node-types.go index 0f535c40..085a8413 100644 --- a/internal/tree-sitter/gotemplate/node-types.go +++ b/internal/tree-sitter/gotemplate/node-types.go @@ -2,6 +2,7 @@ package gotemplate const ( NodeTypeAssignment = "assignment" + NodeTypeArgumentList = "argument_list" NodeTypeBlock = "block" NodeTypeBlockAction = "block_action" NodeTypeChainedPipeline = "chained_pipeline"