Skip to content

Commit

Permalink
replace_all
Browse files Browse the repository at this point in the history
new test
  • Loading branch information
xzbdmw committed Nov 24, 2024
1 parent 51e54e8 commit 0798b03
Show file tree
Hide file tree
Showing 16 changed files with 846 additions and 68 deletions.
Binary file added gopls/doc/assets/extract-expressions-after.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added gopls/doc/assets/extract-expressions-before.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified gopls/doc/assets/extract-var-after.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed gopls/doc/assets/extract-var-before.png
Binary file not shown.
12 changes: 10 additions & 2 deletions gopls/doc/features/transformation.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Gopls supports the following code actions:
- [`refactor.extract.method`](#extract)
- [`refactor.extract.toNewFile`](#extract.toNewFile)
- [`refactor.extract.variable`](#extract)
- [`refactor.extract.variable-all`](#extract)
- [`refactor.inline.call`](#refactor.inline.call)
- [`refactor.rewrite.changeQuote`](#refactor.rewrite.changeQuote)
- [`refactor.rewrite.fillStruct`](#refactor.rewrite.fillStruct)
Expand Down Expand Up @@ -353,11 +354,18 @@ newly created declaration that contains the selected code:
will be a method of the same receiver type.

- **`refactor.extract.variable`** replaces an expression by a reference to a new
local variable named `x` initialized by the expression:
local variable named `newVar` initialized by the expression:

![Before extracting a var](../assets/extract-var-before.png)
![Before extracting a var](../assets/extract-expressions-before.png)
![After extracting a var](../assets/extract-var-after.png)

- **`refactor.extract.variable-all`** replaces all occurrences of the selected expression
within the function with a reference to a new local variable named `newVar`.
This extracts the expression once and reuses it wherever it appears in the function.

![Before extracting all expressions](../assets/extract-expressions-before.png)
![After extracting all expressions](../assets/extract-expressions-after.png)

If the default name for the new declaration is already in use, gopls
generates a fresh name.

Expand Down
7 changes: 6 additions & 1 deletion gopls/doc/release/v0.17.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,9 @@ into account its signature, including input parameters and results.
Since this feature is implemented by the server (gopls), it is compatible with
all LSP-compliant editors. VS Code users may continue to use the client-side
`Go: Generate Unit Tests For file/function/package` command which utilizes the
[gotests](https://github.com/cweill/gotests) tool.
[gotests](https://github.com/cweill/gotests) tool.

## Extract all occurrences of the same expression under selection
When you have multiple instances of the same expression in a function,
you can use this code action to extract it into a variable.
All occurrences of the expression will be replaced with a reference to the new variable.
25 changes: 23 additions & 2 deletions gopls/internal/golang/codeaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"golang.org/x/tools/gopls/internal/protocol"
"golang.org/x/tools/gopls/internal/protocol/command"
"golang.org/x/tools/gopls/internal/settings"
"golang.org/x/tools/gopls/internal/util/safetoken"
"golang.org/x/tools/gopls/internal/util/typesutil"
"golang.org/x/tools/internal/event"
"golang.org/x/tools/internal/imports"
Expand Down Expand Up @@ -236,7 +237,8 @@ var codeActionProducers = [...]codeActionProducer{
{kind: settings.RefactorExtractFunction, fn: refactorExtractFunction},
{kind: settings.RefactorExtractMethod, fn: refactorExtractMethod},
{kind: settings.RefactorExtractToNewFile, fn: refactorExtractToNewFile},
{kind: settings.RefactorExtractVariable, fn: refactorExtractVariable},
{kind: settings.RefactorExtractVariableAll, fn: refactorExtractVairableAll, needPkg: true},
{kind: settings.RefactorExtractVariable, fn: refactorExtractVariable, needPkg: true},
{kind: settings.RefactorInlineCall, fn: refactorInlineCall, needPkg: true},
{kind: settings.RefactorRewriteChangeQuote, fn: refactorRewriteChangeQuote},
{kind: settings.RefactorRewriteFillStruct, fn: refactorRewriteFillStruct, needPkg: true},
Expand Down Expand Up @@ -463,12 +465,31 @@ func refactorExtractMethod(ctx context.Context, req *codeActionsRequest) error {
// refactorExtractVariable produces "Extract variable" code actions.
// See [extractVariable] for command implementation.
func refactorExtractVariable(ctx context.Context, req *codeActionsRequest) error {
if _, _, ok, _ := canExtractVariable(req.start, req.end, req.pgf.File); ok {
if _, ok, _ := canExtractVariable(req.pkg.TypesInfo(), req.start, req.end, req.pgf.File, false); ok {
req.addApplyFixAction("Extract variable", fixExtractVariable, req.loc)
}
return nil
}

// refactorExtractVairableAll produces "Extract n occurrences of expression" code action.
// See [extractAllOccursOfExpr] for command implementation.
func refactorExtractVairableAll(ctx context.Context, req *codeActionsRequest) error {
// Don't suggest if only one expr is found,
// otherwise it will duplicate with [refactorExtractVariable]
if exprs, ok, _ := canExtractVariable(req.pkg.TypesInfo(), req.start, req.end, req.pgf.File, true); ok && len(exprs) > 1 {
startOffset, endOffset, err := safetoken.Offsets(req.pgf.Tok, exprs[0].Pos(), exprs[0].End())
if err != nil {
return err
}
desc := string(req.pgf.Src[startOffset:endOffset])
if len(desc) >= 40 || strings.Contains(desc, "\n") {
desc = astutil.NodeDescription(exprs[0])
}
req.addApplyFixAction(fmt.Sprintf("Extract %d occurrences of %s", len(exprs), desc), fixExtractVariableAll, req.loc)
}
return nil
}

// refactorExtractToNewFile produces "Extract declarations to new file" code actions.
// See [server.commandHandler.ExtractToNewFile] for command implementation.
func refactorExtractToNewFile(ctx context.Context, req *codeActionsRequest) error {
Expand Down
206 changes: 173 additions & 33 deletions gopls/internal/golang/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,47 +20,112 @@ import (

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/astutil"
astutilinternal "golang.org/x/tools/gopls/internal/util/astutil"
"golang.org/x/tools/gopls/internal/util/bug"
"golang.org/x/tools/gopls/internal/util/safetoken"
"golang.org/x/tools/internal/analysisinternal"
"golang.org/x/tools/internal/typesinternal"
)

func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) {
return extractExprs(fset, start, end, src, file, pkg, info, false)
}

func extractVariableAll(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) {
return extractExprs(fset, start, end, src, file, pkg, info, true)
}

// extractExprs replaces occurrence(s) of a specified expression within the same function
// with newVar. If 'all' is true, it replaces all occurrences of the same expression;
// otherwise, it only replaces the selected expression.
//
// The new variable is declared as close as possible to the first found expression
// within the deepest common scope accessible to all candidate occurrences.
func extractExprs(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, pkg *types.Package, info *types.Info, all bool) (*token.FileSet, *analysis.SuggestedFix, error) {
tokFile := fset.File(file.FileStart)
expr, path, ok, err := canExtractVariable(start, end, file)
if !ok {
return nil, nil, fmt.Errorf("extractVariable: cannot extract %s: %v", safetoken.StartPosition(fset, start), err)
exprs, _, err := canExtractVariable(info, start, end, file, all)
if err != nil {
return nil, nil, fmt.Errorf("extractVariable: cannot extract: %v", err)
}

scopes := make([][]*types.Scope, len(exprs))
for i, e := range exprs {
path, _ := astutil.PathEnclosingInterval(file, e.Pos(), e.End())
scopes[i] = CollectScopes(info, path, e.Pos())
}

// Deduplicate, prepare to generate new variable name.
var scopeSet []*types.Scope
seen := make(map[*types.Scope]struct{})
for _, scope := range scopes {
for _, s := range scope {
if s != nil {
if _, exist := seen[s]; !exist {
seen[s] = struct{}{}
scopeSet = append(scopeSet, s)
}
}
}
}
scopeSet = append(scopeSet, pkg.Scope())

// Create new AST node for extracted code.
var lhsNames []string
switch expr := expr.(type) {
switch expr := exprs[0].(type) {
// TODO: stricter rules for selectorExpr.
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
lhsName, _ := generateAvailableName(expr.Pos(), path, pkg, info, "x", 0)
*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr, *ast.FuncLit:
lhsName, _ := generateAvailableNameByScopes(scopeSet, "newVar", 0)
lhsNames = append(lhsNames, lhsName)
case *ast.CallExpr:
tup, ok := info.TypeOf(expr).(*types.Tuple)
if !ok {
// If the call expression only has one return value, we can treat it the
// same as our standard extract variable case.
lhsName, _ := generateAvailableName(expr.Pos(), path, pkg, info, "x", 0)
lhsName, _ := generateAvailableNameByScopes(scopeSet, "newVar", 0)
lhsNames = append(lhsNames, lhsName)
break
}
idx := 0
for i := 0; i < tup.Len(); i++ {
// Generate a unique variable for each return value.
var lhsName string
lhsName, idx = generateAvailableName(expr.Pos(), path, pkg, info, "x", idx)
lhsName, idx = generateAvailableNameByScopes(scopeSet, "newVar", idx)
lhsNames = append(lhsNames, lhsName)
}
default:
return nil, nil, fmt.Errorf("cannot extract %T", expr)
}

var enclosingScopeOfFirstExpr *types.Scope
for _, scope := range scopes[0] {
if scope != nil {
enclosingScopeOfFirstExpr = scope
break
}
}
// Where all the extractable positions can see variable being declared.
commonScope, err := findDeepestCommonScope(scopes)
if err != nil {
return nil, nil, fmt.Errorf("extractVariable: %v", err)
}
var visiblePath []ast.Node
if commonScope != enclosingScopeOfFirstExpr {
// This means the first expr within function body is not the largest scope,
// we need to find the scope immediately follow the common
// scope where we will insert the statement before.
child := enclosingScopeOfFirstExpr
for p := child; p != nil; p = p.Parent() {
if p == commonScope {
break
}
child = p
}
visiblePath, _ = astutil.PathEnclosingInterval(file, child.Pos(), child.End())
} else {
// Insert newVar inside commonScope before the first occurrence of the expression.
visiblePath, _ = astutil.PathEnclosingInterval(file, exprs[0].Pos(), exprs[0].End())
}
// TODO: There is a bug here: for a variable declared in a labeled
// switch/for statement it returns the for/switch statement itself
// which produces the below code which is a compiler error e.g.
Expand All @@ -70,7 +135,7 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
// label:
// x := r()
// switch r1 := x { ... break label ... } // compiler error
insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(visiblePath)
if insertBeforeStmt == nil {
return nil, nil, fmt.Errorf("cannot find location to insert extraction")
}
Expand All @@ -84,59 +149,134 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
assignStmt := &ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(lhs)},
Tok: token.DEFINE,
Rhs: []ast.Expr{expr},
Rhs: []ast.Expr{exprs[0]},
}
var buf bytes.Buffer
if err := format.Node(&buf, fset, assignStmt); err != nil {
return nil, nil, err
}
assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent

textEdits := []analysis.TextEdit{{
Pos: insertBeforeStmt.Pos(),
End: insertBeforeStmt.Pos(),
NewText: []byte(assignment),
}}
for _, e := range exprs {
textEdits = append(textEdits, analysis.TextEdit{
Pos: e.Pos(),
End: e.End(),
NewText: []byte(lhs),
})
}
return fset, &analysis.SuggestedFix{
TextEdits: []analysis.TextEdit{
{
Pos: insertBeforeStmt.Pos(),
End: insertBeforeStmt.Pos(),
NewText: []byte(assignment),
},
{
Pos: start,
End: end,
NewText: []byte(lhs),
},
},
TextEdits: textEdits,
}, nil
}

// findDeepestCommonScope finds the deepest (innermost) scope that is common to all provided scope chains.
// Each scope chain represents the scopes of an expression from innermost to outermost.
// If no common scope is found, it returns an error.
func findDeepestCommonScope(scopeChains [][]*types.Scope) (*types.Scope, error) {
if len(scopeChains) == 0 {
return nil, fmt.Errorf("no scopes provided")
}
// Get the first scope chain as the reference.
referenceChain := scopeChains[0]

// Iterate from innermost to outermost scope.
for i := 0; i < len(referenceChain); i++ {
candidateScope := referenceChain[i]
if candidateScope == nil {
continue
}
isCommon := true
// See if other exprs' chains all have candidateScope as a common ancestor.
for _, chain := range scopeChains[1:] {
found := false
for j := 0; j < len(chain); j++ {
if chain[j] == candidateScope {
found = true
break
}
}
if !found {
isCommon = false
break
}
}
if isCommon {
return candidateScope, nil
}
}
return nil, fmt.Errorf("no common scope found")
}

// generateAvailableNameByScopes adjusts the new identifier name
// until there are no collisions in any of the provided scopes.
func generateAvailableNameByScopes(scopes []*types.Scope, prefix string, idx int) (string, int) {
return generateName(idx, prefix, func(name string) bool {
for _, scope := range scopes {
if scope != nil && scope.Lookup(name) != nil {
return true
}
}
return false
})
}

// canExtractVariable reports whether the code in the given range can be
// extracted to a variable.
func canExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
// extracted to a variable. It returns the selected expression or
// all occurrences of expression structural equal to selected one
// sorted by position depends on all.
func canExtractVariable(info *types.Info, start, end token.Pos, file *ast.File, all bool) ([]ast.Expr, bool, error) {
if start == end {
return nil, nil, false, fmt.Errorf("start and end are equal")
return nil, false, fmt.Errorf("start and end are equal")
}
path, _ := astutil.PathEnclosingInterval(file, start, end)
if len(path) == 0 {
return nil, nil, false, fmt.Errorf("no path enclosing interval")
return nil, false, fmt.Errorf("no path enclosing interval")
}
for _, n := range path {
if _, ok := n.(*ast.ImportSpec); ok {
return nil, nil, false, fmt.Errorf("cannot extract variable in an import block")
return nil, false, fmt.Errorf("cannot extract variable in an import block")
}
}
node := path[0]
if start != node.Pos() || end != node.End() {
return nil, nil, false, fmt.Errorf("range does not map to an AST node")
return nil, false, fmt.Errorf("range does not map to an AST node")
}
expr, ok := node.(ast.Expr)
if !ok {
return nil, nil, false, fmt.Errorf("node is not an expression")
return nil, false, fmt.Errorf("node is not an expression")
}

var exprs []ast.Expr
if !all {
exprs = append(exprs, expr)
} else if funcDecl, ok := path[len(path)-2].(*ast.FuncDecl); ok {
ast.Inspect(funcDecl, func(n ast.Node) bool {
if e, ok := n.(ast.Expr); ok {
if astutilinternal.Equal(e, expr, func(x, y *ast.Ident) bool {
return x.Name == y.Name && info.ObjectOf(x) == info.ObjectOf(y)
}) {
exprs = append(exprs, e)
}
}
return true
})
} else {
return nil, false, fmt.Errorf("node %T is not inside a function", expr)
}
sort.Slice(exprs, func(i, j int) bool {
return exprs[i].Pos() < exprs[j].Pos()
})

switch expr.(type) {
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr,
*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
return expr, path, true, nil
*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr, *ast.FuncLit:
return exprs, true, nil
}
return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
return nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
}

// Calculate indentation for insertion.
Expand Down
2 changes: 2 additions & 0 deletions gopls/internal/golang/fix.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func singleFile(fixer1 singleFileFixer) fixer {
// Names of ApplyFix.Fix created directly by the CodeAction handler.
const (
fixExtractVariable = "extract_variable"
fixExtractVariableAll = "extract_variable_all"
fixExtractFunction = "extract_function"
fixExtractMethod = "extract_method"
fixInlineCall = "inline_call"
Expand Down Expand Up @@ -106,6 +107,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
fixExtractFunction: singleFile(extractFunction),
fixExtractMethod: singleFile(extractMethod),
fixExtractVariable: singleFile(extractVariable),
fixExtractVariableAll: singleFile(extractVariableAll),
fixInlineCall: inlineCall,
fixInvertIfCondition: singleFile(invertIfCondition),
fixSplitLines: singleFile(splitLines),
Expand Down
Loading

0 comments on commit 0798b03

Please sign in to comment.