// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package inline

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/constant"
	"go/format"
	"go/parser"
	"go/token"
	"go/types"
	pathpkg "path"
	"reflect"
	"strconv"
	"strings"

	"golang.org/x/tools/go/ast/astutil"
	"golang.org/x/tools/go/types/typeutil"
	"golang.org/x/tools/imports"
	internalastutil "golang.org/x/tools/internal/astutil"
	"golang.org/x/tools/internal/typeparams"
)

// A Caller describes the function call and its enclosing context.
//
// The client is responsible for populating this struct and passing it to Inline.
type Caller struct {
	Fset    *token.FileSet // file set used for the positions of File and Call
	Types   *types.Package // package containing the call
	Info    *types.Info    // type-checker results for the syntax in File
	File    *ast.File      // syntax tree of the file containing the call
	Call    *ast.CallExpr  // the call expression to be inlined
	Content []byte         // source of the file containing the call

	// The unexported fields are computed by the inliner itself,
	// not supplied by the client.
	path          []ast.Node    // path from call to root of file syntax tree
	enclosingFunc *ast.FuncDecl // top-level function/method enclosing the call, if any
}

// Options specifies parameters affecting the inliner algorithm.
// All fields are optional.
type Options struct {
	Logf          func(string, ...any) // log output function, records decision-making process
	IgnoreEffects bool                 // ignore potential side effects of arguments (unsound)
}

// Result holds the result of code transformation.
type Result struct {
	Content     []byte // formatted, transformed content of caller file
	Literalized bool   // chosen strategy replaced callee() with func(){...}()

	// TODO(adonovan): provide an API for clients that want structured
	// output: a list of import additions and deletions plus one or more
	// localized diffs (or even AST transformations, though ownership and
	// mutation are tricky) near the call site.
}

// Inline inlines the called function (callee) into the function call (caller)
// and returns the updated, formatted content of the caller source file.
//
// Inline does not mutate any public fields of Caller or Callee.
func Inline(caller *Caller, callee *Callee, opts *Options) (*Result, error) { copy := *opts // shallow copy opts = © // Set default options. if opts.Logf == nil { opts.Logf = func(string, ...any) {} } st := &state{ caller: caller, callee: callee, opts: opts, } return st.inline() } // state holds the working state of the inliner. type state struct { caller *Caller callee *Callee opts *Options } func (st *state) inline() (*Result, error) { logf, caller, callee := st.opts.Logf, st.caller, st.callee logf("inline %s @ %v", debugFormatNode(caller.Fset, caller.Call), caller.Fset.PositionFor(caller.Call.Lparen, false)) if !consistentOffsets(caller) { return nil, fmt.Errorf("internal error: caller syntax positions are inconsistent with file content (did you forget to use FileSet.PositionFor when computing the file name?)") } // TODO(adonovan): use go1.21's ast.IsGenerated. // Break the string literal so we can use inlining in this file. :) if bytes.Contains(caller.Content, []byte("// Code generated by "+"cmd/cgo; DO NOT EDIT.")) { return nil, fmt.Errorf("cannot inline calls from files that import \"C\"") } res, err := st.inlineCall() if err != nil { return nil, err } // Replace the call (or some node that encloses it) by new syntax. assert(res.old != nil, "old is nil") assert(res.new != nil, "new is nil") // A single return operand inlined to a unary // expression context may need parens. Otherwise: // func two() int { return 1+1 } // print(-two()) => print(-1+1) // oops! // // Usually it is not necessary to insert ParenExprs // as the formatter is smart enough to insert them as // needed by the context. But the res.{old,new} // substitution is done by formatting res.new in isolation // and then splicing its text over res.old, so the // formatter doesn't see the parent node and cannot do // the right thing. (One solution would be to always // format the enclosing node of old, but that requires // non-lossy comment handling, #20744.) 
// // So, we must analyze the call's context // to see whether ambiguity is possible. // For example, if the context is x[y:z], then // the x subtree is subject to precedence ambiguity // (replacing x by p+q would give p+q[y:z] which is wrong) // but the y and z subtrees are safe. if needsParens(caller.path, res.old, res.new) { res.new = &ast.ParenExpr{X: res.new.(ast.Expr)} } // Some reduction strategies return a new block holding the // callee's statements. The block's braces may be elided when // there is no conflict between names declared in the block // with those declared by the parent block, and no risk of // a caller's goto jumping forward across a declaration. // // This elision is only safe when the ExprStmt is beneath a // BlockStmt, CaseClause.Body, or CommClause.Body; // (see "statement theory"). // // The inlining analysis may have already determined that eliding braces is // safe. Otherwise, we analyze its safety here. elideBraces := res.elideBraces if !elideBraces { if newBlock, ok := res.new.(*ast.BlockStmt); ok { i := nodeIndex(caller.path, res.old) parent := caller.path[i+1] var body []ast.Stmt switch parent := parent.(type) { case *ast.BlockStmt: body = parent.List case *ast.CommClause: body = parent.Body case *ast.CaseClause: body = parent.Body } if body != nil { callerNames := declares(body) // If BlockStmt is a function body, // include its receiver, params, and results. addFieldNames := func(fields *ast.FieldList) { if fields != nil { for _, field := range fields.List { for _, id := range field.Names { callerNames[id.Name] = true } } } } switch f := caller.path[i+2].(type) { case *ast.FuncDecl: addFieldNames(f.Recv) addFieldNames(f.Type.Params) addFieldNames(f.Type.Results) case *ast.FuncLit: addFieldNames(f.Type.Params) addFieldNames(f.Type.Results) } if len(callerLabels(caller.path)) > 0 { // TODO(adonovan): be more precise and reject // only forward gotos across the inlined block. 
logf("keeping block braces: caller uses control labels") } else if intersects(declares(newBlock.List), callerNames) { logf("keeping block braces: avoids name conflict") } else { elideBraces = true } } } } // Don't call replaceNode(caller.File, res.old, res.new) // as it mutates the caller's syntax tree. // Instead, splice the file, replacing the extent of the "old" // node by a formatting of the "new" node, and re-parse. // We'll fix up the imports on this new tree, and format again. var f *ast.File { start := offsetOf(caller.Fset, res.old.Pos()) end := offsetOf(caller.Fset, res.old.End()) var out bytes.Buffer out.Write(caller.Content[:start]) // TODO(adonovan): might it make more sense to use // callee.Fset when formatting res.new? // The new tree is a mix of (cloned) caller nodes for // the argument expressions and callee nodes for the // function body. In essence the question is: which // is more likely to have comments? // Usually the callee body will be larger and more // statement-heavy than the arguments, but a // strategy may widen the scope of the replacement // (res.old) from CallExpr to, say, its enclosing // block, so the caller nodes dominate. // Precise comment handling would make this a // non-issue. Formatting wouldn't really need a // FileSet at all. if elideBraces { for i, stmt := range res.new.(*ast.BlockStmt).List { if i > 0 { out.WriteByte('\n') } if err := format.Node(&out, caller.Fset, stmt); err != nil { return nil, err } } } else { if err := format.Node(&out, caller.Fset, res.new); err != nil { return nil, err } } out.Write(caller.Content[end:]) const mode = parser.ParseComments | parser.SkipObjectResolution | parser.AllErrors f, err = parser.ParseFile(caller.Fset, "callee.go", &out, mode) if err != nil { // Something has gone very wrong. logf("failed to parse <<%s>>", &out) // debugging return nil, err } } // Add new imports. // // Insert new imports after last existing import, // to avoid migration of pre-import comments. 
// The imports will be organized below. if len(res.newImports) > 0 { var importDecl *ast.GenDecl if len(f.Imports) > 0 { // Append specs to existing import decl importDecl = f.Decls[0].(*ast.GenDecl) } else { // Insert new import decl. importDecl = &ast.GenDecl{Tok: token.IMPORT} f.Decls = prepend[ast.Decl](importDecl, f.Decls...) } for _, imp := range res.newImports { // Check that the new imports are accessible. path, _ := strconv.Unquote(imp.spec.Path.Value) if !canImport(caller.Types.Path(), path) { return nil, fmt.Errorf("can't inline function %v as its body refers to inaccessible package %q", callee, path) } importDecl.Specs = append(importDecl.Specs, imp.spec) } } var out bytes.Buffer if err := format.Node(&out, caller.Fset, f); err != nil { return nil, err } newSrc := out.Bytes() // Remove imports that are no longer referenced. // // It ought to be possible to compute the set of PkgNames used // by the "old" code, compute the free identifiers of the // "new" code using a syntax-only (no go/types) algorithm, and // see if the reduction in the number of uses of any PkgName // equals the number of times it appears in caller.Info.Uses, // indicating that it is no longer referenced by res.new. // // However, the notorious ambiguity of resolving T{F: 0} makes this // unreliable: without types, we can't tell whether F refers to // a field of struct T, or a package-level const/var of a // dot-imported (!) package. // // So, for now, we run imports.Process, which is // unsatisfactory as it has to run the go command, and it // looks at the user's module cache state--unnecessarily, // since this step cannot add new imports. // // TODO(adonovan): replace with a simpler implementation since // all the necessary imports are present but merely untidy. // That will be faster, and also less prone to nondeterminism // if there are bugs in our logic for import maintenance. 
// // However, golang.org/x/tools/internal/imports.ApplyFixes is // too simple as it requires the caller to have figured out // all the logical edits. In our case, we know all the new // imports that are needed (see newImports), each of which can // be specified as: // // &imports.ImportFix{ // StmtInfo: imports.ImportInfo{path, name, // IdentName: name, // FixType: imports.AddImport, // } // // but we don't know which imports are made redundant by the // inlining itself. For example, inlining a call to // fmt.Println may make the "fmt" import redundant. // // Also, both imports.Process and internal/imports.ApplyFixes // reformat the entire file, which is not ideal for clients // such as gopls. (That said, the point of a canonical format // is arguably that any tool can reformat as needed without // this being inconvenient.) // // We could invoke imports.Process and parse its result, // compare against the original AST, compute a list of import // fixes, and return that too. // Recompute imports only if there were existing ones. if len(f.Imports) > 0 { formatted, err := imports.Process("output", newSrc, nil) if err != nil { logf("cannot reformat: %v <<%s>>", err, &out) return nil, err // cannot reformat (a bug?) } newSrc = formatted } literalized := false if call, ok := res.new.(*ast.CallExpr); ok && is[*ast.FuncLit](call.Fun) { literalized = true } return &Result{ Content: newSrc, Literalized: literalized, }, nil } type newImport struct { pkgName string spec *ast.ImportSpec } type inlineCallResult struct { newImports []newImport // If elideBraces is set, old is an ast.Stmt and new is an ast.BlockStmt to // be spliced in. This allows the inlining analysis to assert that inlining // the block is OK; if elideBraces is unset and old is an ast.Stmt and new is // an ast.BlockStmt, braces may still be elided if the post-processing // analysis determines that it is safe to do so. 
// // Ideally, it would not be necessary for the inlining analysis to "reach // through" to the post-processing pass in this way. Instead, inlining could // just set old to be an ast.BlockStmt and rewrite the entire BlockStmt, but // unfortunately in order to preserve comments, it is important that inlining // replace as little syntax as possible. elideBraces bool old, new ast.Node // e.g. replace call expr by callee function body expression } // inlineCall returns a pair of an old node (the call, or something // enclosing it) and a new node (its replacement, which may be a // combination of caller, callee, and new nodes), along with the set // of new imports needed. // // TODO(adonovan): rethink the 'result' interface. The assumption of a // one-to-one replacement seems fragile. One can easily imagine the // transformation replacing the call and adding new variable // declarations, for example, or replacing a call statement by zero or // many statements.) // // TODO(adonovan): in earlier drafts, the transformation was expressed // by splicing substrings of the two source files because syntax // trees don't preserve comments faithfully (see #20744), but such // transformations don't compose. The current implementation is // tree-based but is very lossy wrt comments. It would make a good // candidate for evaluating an alternative fully self-contained tree // representation, such as any proposed solution to #20744, or even // dst or some private fork of go/ast.) func (st *state) inlineCall() (*inlineCallResult, error) { logf, caller, callee := st.opts.Logf, st.caller, &st.callee.impl checkInfoFields(caller.Info) // Inlining of dynamic calls is not currently supported, // even for local closure calls. (This would be a lot of work.) calleeSymbol := typeutil.StaticCallee(caller.Info, caller.Call) if calleeSymbol == nil { // e.g. 
interface method return nil, fmt.Errorf("cannot inline: not a static function call") } // Reject cross-package inlining if callee has // free references to unexported symbols. samePkg := caller.Types.Path() == callee.PkgPath if !samePkg && len(callee.Unexported) > 0 { return nil, fmt.Errorf("cannot inline call to %s because body refers to non-exported %s", callee.Name, callee.Unexported[0]) } // -- analyze callee's free references in caller context -- // Compute syntax path enclosing Call, innermost first (Path[0]=Call), // and outermost enclosing function, if any. caller.path, _ = astutil.PathEnclosingInterval(caller.File, caller.Call.Pos(), caller.Call.End()) for _, n := range caller.path { if decl, ok := n.(*ast.FuncDecl); ok { caller.enclosingFunc = decl break } } // If call is within a function, analyze all its // local vars for the "single assignment" property. // (Taking the address &v counts as a potential assignment.) var assign1 func(v *types.Var) bool // reports whether v a single-assignment local var { updatedLocals := make(map[*types.Var]bool) if caller.enclosingFunc != nil { escape(caller.Info, caller.enclosingFunc, func(v *types.Var, _ bool) { updatedLocals[v] = true }) logf("multiple-assignment vars: %v", updatedLocals) } assign1 = func(v *types.Var) bool { return !updatedLocals[v] } } // import map, initially populated with caller imports. // // For simplicity we ignore existing dot imports, so that a // qualified identifier (QI) in the callee is always // represented by a QI in the caller, allowing us to treat a // QI like a selection on a package name. importMap := make(map[string][]string) // maps package path to local name(s) for _, imp := range caller.File.Imports { if pkgname, ok := importedPkgName(caller.Info, imp); ok && pkgname.Name() != "." && pkgname.Name() != "_" { path := pkgname.Imported().Path() importMap[path] = append(importMap[path], pkgname.Name()) } } // localImportName returns the local name for a given imported package path. 
var newImports []newImport localImportName := func(obj *object) string { // Does an import exist? for _, name := range importMap[obj.PkgPath] { // Check that either the import preexisted, // or that it was newly added (no PkgName) but is not shadowed, // either in the callee (shadows) or caller (caller.lookup). if !obj.Shadow[name] { found := caller.lookup(name) if is[*types.PkgName](found) || found == nil { return name } } } newlyAdded := func(name string) bool { for _, new := range newImports { if new.pkgName == name { return true } } return false } // import added by callee // // Choose local PkgName based on last segment of // package path plus, if needed, a numeric suffix to // ensure uniqueness. // // "init" is not a legal PkgName. // // TODO(rfindley): is it worth preserving local package names for callee // imports? Are they likely to be better or worse than the name we choose // here? base := obj.PkgName name := base for n := 0; obj.Shadow[name] || caller.lookup(name) != nil || newlyAdded(name) || name == "init"; n++ { name = fmt.Sprintf("%s%d", base, n) } logf("adding import %s %q", name, obj.PkgPath) spec := &ast.ImportSpec{ Path: &ast.BasicLit{ Kind: token.STRING, Value: strconv.Quote(obj.PkgPath), }, } // Use explicit pkgname (out of necessity) when it differs from the declared name, // or (for good style) when it differs from base(pkgpath). if name != obj.PkgName || name != pathpkg.Base(obj.PkgPath) { spec.Name = makeIdent(name) } newImports = append(newImports, newImport{ pkgName: name, spec: spec, }) importMap[obj.PkgPath] = append(importMap[obj.PkgPath], name) return name } // Compute the renaming of the callee's free identifiers. objRenames := make([]ast.Expr, len(callee.FreeObjs)) // nil => no change for i, obj := range callee.FreeObjs { // obj is a free object of the callee. // // Possible cases are: // - builtin function, type, or value (e.g. nil, zero) // => check not shadowed in caller. 
// - package-level var/func/const/types // => same package: check not shadowed in caller. // => otherwise: import other package, form a qualified identifier. // (Unexported cross-package references were rejected already.) // - type parameter // => not yet supported // - pkgname // => import other package and use its local name. // // There can be no free references to labels, fields, or methods. // Note that we must consider potential shadowing both // at the caller side (caller.lookup) and, when // choosing new PkgNames, within the callee (obj.shadow). var newName ast.Expr if obj.Kind == "pkgname" { // Use locally appropriate import, creating as needed. newName = makeIdent(localImportName(&obj)) // imported package } else if !obj.ValidPos { // Built-in function, type, or value (e.g. nil, zero): // check not shadowed at caller. found := caller.lookup(obj.Name) // always finds something if found.Pos().IsValid() { return nil, fmt.Errorf("cannot inline, because the callee refers to built-in %q, which in the caller is shadowed by a %s (declared at line %d)", obj.Name, objectKind(found), caller.Fset.PositionFor(found.Pos(), false).Line) } } else { // Must be reference to package-level var/func/const/type, // since type parameters are not yet supported. qualify := false if obj.PkgPath == callee.PkgPath { // reference within callee package if samePkg { // Caller and callee are in same package. // Check caller has not shadowed the decl. // // This may fail if the callee is "fake", such as for signature // refactoring where the callee is modified to be a trivial wrapper // around the refactored signature. found := caller.lookup(obj.Name) if found != nil && !isPkgLevel(found) { return nil, fmt.Errorf("cannot inline, because the callee refers to %s %q, which in the caller is shadowed by a %s (declared at line %d)", obj.Kind, obj.Name, objectKind(found), caller.Fset.PositionFor(found.Pos(), false).Line) } } else { // Cross-package reference. 
qualify = true } } else { // Reference to a package-level declaration // in another package, without a qualified identifier: // it must be a dot import. qualify = true } // Form a qualified identifier, pkg.Name. if qualify { pkgName := localImportName(&obj) newName = &ast.SelectorExpr{ X: makeIdent(pkgName), Sel: makeIdent(obj.Name), } } } objRenames[i] = newName } res := &inlineCallResult{ newImports: newImports, } // Parse callee function declaration. calleeFset, calleeDecl, err := parseCompact(callee.Content) if err != nil { return nil, err // "can't happen" } // replaceCalleeID replaces an identifier in the callee. // The replacement tree must not belong to the caller; use cloneNode as needed. replaceCalleeID := func(offset int, repl ast.Expr) { id := findIdent(calleeDecl, calleeDecl.Pos()+token.Pos(offset)) logf("- replace id %q @ #%d to %q", id.Name, offset, debugFormatNode(calleeFset, repl)) replaceNode(calleeDecl, id, repl) } // Generate replacements for each free identifier. // (The same tree may be spliced in multiple times, resulting in a DAG.) for _, ref := range callee.FreeRefs { if repl := objRenames[ref.Object]; repl != nil { replaceCalleeID(ref.Offset, repl) } } // Gather the effective call arguments, including the receiver. // Later, elements will be eliminated (=> nil) by parameter substitution. args, err := st.arguments(caller, calleeDecl, assign1) if err != nil { return nil, err // e.g. implicit field selection cannot be made explicit } // Gather effective parameter tuple, including the receiver if any. // Simplify variadic parameters to slices (in all cases but one). var params []*parameter // including receiver; nil => parameter substituted { sig := calleeSymbol.Type().(*types.Signature) if sig.Recv() != nil { params = append(params, ¶meter{ obj: sig.Recv(), fieldType: calleeDecl.Recv.List[0].Type, info: callee.Params[0], }) } // Flatten the list of syntactic types. 
var types []ast.Expr for _, field := range calleeDecl.Type.Params.List { if field.Names == nil { types = append(types, field.Type) } else { for range field.Names { types = append(types, field.Type) } } } for i := 0; i < sig.Params().Len(); i++ { params = append(params, ¶meter{ obj: sig.Params().At(i), fieldType: types[i], info: callee.Params[len(params)], }) } // Variadic function? // // There are three possible types of call: // - ordinary f(a1, ..., aN) // - ellipsis f(a1, ..., slice...) // - spread f(recv?, g()) where g() is a tuple. // The first two are desugared to non-variadic calls // with an ordinary slice parameter; // the third is tricky and cannot be reduced, and (if // a receiver is present) cannot even be literalized. // Fortunately it is vanishingly rare. // // TODO(adonovan): extract this to a function. if sig.Variadic() { lastParam := last(params) if len(args) > 0 && last(args).spread { // spread call to variadic: tricky lastParam.variadic = true } else { // ordinary/ellipsis call to variadic // simplify decl: func(T...) -> func([]T) lastParamField := last(calleeDecl.Type.Params.List) lastParamField.Type = &ast.ArrayType{ Elt: lastParamField.Type.(*ast.Ellipsis).Elt, } if caller.Call.Ellipsis.IsValid() { // ellipsis call: f(slice...) -> f(slice) // nop } else { // ordinary call: f(a1, ... aN) -> f([]T{a1, ..., aN}) n := len(params) - 1 ordinary, extra := args[:n], args[n:] var elts []ast.Expr pure, effects := true, false for _, arg := range extra { elts = append(elts, arg.expr) pure = pure && arg.pure effects = effects || arg.effects } args = append(ordinary, &argument{ expr: &ast.CompositeLit{ Type: lastParamField.Type, Elts: elts, }, typ: lastParam.obj.Type(), constant: nil, pure: pure, effects: effects, duplicable: false, freevars: nil, // not needed }) } } } } // Log effective arguments. 
for i, arg := range args { logf("arg #%d: %s pure=%t effects=%t duplicable=%t free=%v type=%v", i, debugFormatNode(caller.Fset, arg.expr), arg.pure, arg.effects, arg.duplicable, arg.freevars, arg.typ) } // Note: computation below should be expressed in terms of // the args and params slices, not the raw material. // Perform parameter substitution. // May eliminate some elements of params/args. substitute(logf, caller, params, args, callee.Effects, callee.Falcon, replaceCalleeID) // Update the callee's signature syntax. updateCalleeParams(calleeDecl, params) // Create a var (param = arg; ...) decl for use by some strategies. bindingDecl := createBindingDecl(logf, caller, args, calleeDecl, callee.Results) var remainingArgs []ast.Expr for _, arg := range args { if arg != nil { remainingArgs = append(remainingArgs, arg.expr) } } // -- let the inlining strategies begin -- // // When we commit to a strategy, we log a message of the form: // // "strategy: reduce expr-context call to { return expr }" // // This is a terse way of saying: // // we plan to reduce a call // that appears in expression context // to a function whose body is of the form { return expr } // TODO(adonovan): split this huge function into a sequence of // function calls with an error sentinel that means "try the // next strategy", and make sure each strategy writes to the // log the reason it didn't match. // Special case: eliminate a call to a function whose body is empty. // (=> callee has no results and caller is a statement.) // // func f(params) {} // f(args) // => _, _ = args // if len(calleeDecl.Body.List) == 0 { logf("strategy: reduce call to empty body") // Evaluate the arguments for effects and delete the call entirely. stmt := callStmt(caller.path, false) // cannot fail res.old = stmt if nargs := len(remainingArgs); nargs > 0 { // Emit "_, _ = args" to discard results. 
// TODO(adonovan): if args is the []T{a1, ..., an} // literal synthesized during variadic simplification, // consider unwrapping it to its (pure) elements. // Perhaps there's no harm doing this for any slice literal. // Make correction for spread calls // f(g()) or recv.f(g()) where g() is a tuple. if last := last(args); last != nil && last.spread { nspread := last.typ.(*types.Tuple).Len() if len(args) > 1 { // [recv, g()] // A single AssignStmt cannot discard both, so use a 2-spec var decl. res.new = &ast.GenDecl{ Tok: token.VAR, Specs: []ast.Spec{ &ast.ValueSpec{ Names: []*ast.Ident{makeIdent("_")}, Values: []ast.Expr{args[0].expr}, }, &ast.ValueSpec{ Names: blanks[*ast.Ident](nspread), Values: []ast.Expr{args[1].expr}, }, }, } return res, nil } // Sole argument is spread call. nargs = nspread } res.new = &ast.AssignStmt{ Lhs: blanks[ast.Expr](nargs), Tok: token.ASSIGN, Rhs: remainingArgs, } } else { // No remaining arguments: delete call statement entirely res.new = &ast.EmptyStmt{} } return res, nil } // If all parameters have been substituted and no result // variable is referenced, we don't need a binding decl. // This may enable better reduction strategies. allResultsUnreferenced := forall(callee.Results, func(i int, r *paramInfo) bool { return len(r.Refs) == 0 }) needBindingDecl := !allResultsUnreferenced || exists(params, func(i int, p *parameter) bool { return p != nil }) // The two strategies below overlap for a tail call of {return exprs}: // The expr-context reduction is nice because it keeps the // caller's return stmt and merely switches its operand, // without introducing a new block, but it doesn't work with // implicit return conversions. // // TODO(adonovan): unify these cases more cleanly, allowing return- // operand replacement and implicit conversions, by adding // conversions around each return operand (if not a spread return). // Special case: call to { return exprs }. 
// // Reduces to: // { var (bindings); _, _ = exprs } // or _, _ = exprs // or expr // // If: // - the body is just "return expr" with trivial implicit conversions, // or the caller's return type matches the callee's, // - all parameters and result vars can be eliminated // or replaced by a binding decl, // then the call expression can be replaced by the // callee's body expression, suitably substituted. if len(calleeDecl.Body.List) == 1 && is[*ast.ReturnStmt](calleeDecl.Body.List[0]) && len(calleeDecl.Body.List[0].(*ast.ReturnStmt).Results) > 0 { // not a bare return results := calleeDecl.Body.List[0].(*ast.ReturnStmt).Results parent, grandparent := callContext(caller.path) // statement context if stmt, ok := parent.(*ast.ExprStmt); ok && (!needBindingDecl || bindingDecl != nil) { logf("strategy: reduce stmt-context call to { return exprs }") clearPositions(calleeDecl.Body) if callee.ValidForCallStmt { logf("callee body is valid as statement") // Inv: len(results) == 1 if !needBindingDecl { // Reduces to: expr res.old = caller.Call res.new = results[0] } else { // Reduces to: { var (bindings); expr } res.old = stmt res.new = &ast.BlockStmt{ List: []ast.Stmt{ bindingDecl.stmt, &ast.ExprStmt{X: results[0]}, }, } } } else { logf("callee body is not valid as statement") // The call is a standalone statement, but the // callee body is not suitable as a standalone statement // (f() or <-ch), explicitly discard the results: // Reduces to: _, _ = exprs discard := &ast.AssignStmt{ Lhs: blanks[ast.Expr](callee.NumResults), Tok: token.ASSIGN, Rhs: results, } res.old = stmt if !needBindingDecl { // Reduces to: _, _ = exprs res.new = discard } else { // Reduces to: { var (bindings); _, _ = exprs } res.new = &ast.BlockStmt{ List: []ast.Stmt{ bindingDecl.stmt, discard, }, } } } return res, nil } // Assignment context. // // If there is no binding decl, or if the binding decl declares no names, // an assignment a, b := f() can be reduced to a, b := x, y. 
if stmt, ok := parent.(*ast.AssignStmt); ok && is[*ast.BlockStmt](grandparent) && (!needBindingDecl || (bindingDecl != nil && len(bindingDecl.names) == 0)) { // Reduces to: { var (bindings); lhs... := rhs... } if newStmts, ok := st.assignStmts(stmt, results); ok { logf("strategy: reduce assign-context call to { return exprs }") clearPositions(calleeDecl.Body) block := &ast.BlockStmt{ List: newStmts, } if needBindingDecl { block.List = prepend(bindingDecl.stmt, block.List...) } // assignStmts does not introduce new bindings, and replacing an // assignment only works if the replacement occurs in the same scope. // Therefore, we must ensure that braces are elided. res.elideBraces = true res.old = stmt res.new = block return res, nil } } // expression context if !needBindingDecl { clearPositions(calleeDecl.Body) anyNonTrivialReturns := hasNonTrivialReturn(callee.Returns) if callee.NumResults == 1 { logf("strategy: reduce expr-context call to { return expr }") // (includes some simple tail-calls) // Make implicit return conversion explicit. if anyNonTrivialReturns { results[0] = convert(calleeDecl.Type.Results.List[0].Type, results[0]) } res.old = caller.Call res.new = results[0] return res, nil } else if !anyNonTrivialReturns { logf("strategy: reduce spread-context call to { return expr }") // There is no general way to reify conversions in a spread // return, hence the requirement above. // // TODO(adonovan): allow this reduction when no // conversion is required by the context. // The call returns multiple results but is // not a standalone call statement. It must // be the RHS of a spread assignment: // var x, y = f() // x, y := f() // x, y = f() // or the sole argument to a spread call: // printf(f()) // or spread return statement: // return f() res.old = parent switch context := parent.(type) { case *ast.AssignStmt: // Inv: the call must be in Rhs[0], not Lhs. 
assign := shallowCopy(context) assign.Rhs = results res.new = assign case *ast.ValueSpec: // Inv: the call must be in Values[0], not Names. spec := shallowCopy(context) spec.Values = results res.new = spec case *ast.CallExpr: // Inv: the call must be in Args[0], not Fun. call := shallowCopy(context) call.Args = results res.new = call case *ast.ReturnStmt: // Inv: the call must be Results[0]. ret := shallowCopy(context) ret.Results = results res.new = ret default: return nil, fmt.Errorf("internal error: unexpected context %T for spread call", context) } return res, nil } } } // Special case: tail-call. // // Inlining: // return f(args) // where: // func f(params) (results) { body } // reduces to: // { var (bindings); body } // { body } // so long as: // - all parameters can be eliminated or replaced by a binding decl, // - call is a tail-call; // - all returns in body have trivial result conversions, // or the caller's return type matches the callee's, // - there is no label conflict; // - no result variable is referenced by name, // or implicitly by a bare return. // // The body may use defer, arbitrary control flow, and // multiple returns. // // TODO(adonovan): add a strategy for a 'void tail // call', i.e. a call statement prior to an (explicit // or implicit) return. parent, _ := callContext(caller.path) if ret, ok := parent.(*ast.ReturnStmt); ok && len(ret.Results) == 1 && tailCallSafeReturn(caller, calleeSymbol, callee) && !callee.HasBareReturn && (!needBindingDecl || bindingDecl != nil) && !hasLabelConflict(caller.path, callee.Labels) && allResultsUnreferenced { logf("strategy: reduce tail-call") body := calleeDecl.Body clearPositions(body) if needBindingDecl { body.List = prepend(bindingDecl.stmt, body.List...) 
} res.old = ret res.new = body return res, nil } // Special case: call to void function // // Inlining: // f(args) // where: // func f(params) { stmts } // reduces to: // { var (bindings); stmts } // { stmts } // so long as: // - callee is a void function (no returns) // - callee does not use defer // - there is no label conflict between caller and callee // - all parameters and result vars can be eliminated // or replaced by a binding decl, // - caller ExprStmt is in unrestricted statement context. if stmt := callStmt(caller.path, true); stmt != nil && (!needBindingDecl || bindingDecl != nil) && !callee.HasDefer && !hasLabelConflict(caller.path, callee.Labels) && len(callee.Returns) == 0 { logf("strategy: reduce stmt-context call to { stmts }") body := calleeDecl.Body var repl ast.Stmt = body clearPositions(repl) if needBindingDecl { body.List = prepend(bindingDecl.stmt, body.List...) } res.old = stmt res.new = repl return res, nil } // TODO(adonovan): parameterless call to { stmts; return expr } // from one of these contexts: // x, y = f() // x, y := f() // var x, y = f() // => // var (x T1, y T2); { stmts; x, y = expr } // // Because the params are no longer declared simultaneously // we need to check that (for example) x ∉ freevars(T2), // in addition to the usual checks for arg/result conversions, // complex control, etc. // Also test cases where expr is an n-ary call (spread returns). // Literalization isn't quite infallible. // Consider a spread call to a method in which // no parameters are eliminated, e.g. // new(T).f(g()) // where // func (recv *T) f(x, y int) { body } // func g() (int, int) // This would be literalized to: // func (recv *T, x, y int) { body }(new(T), g()), // which is not a valid argument list because g() must appear alone. // Reject this case for now. 
if len(args) == 2 && args[0] != nil && args[1] != nil && is[*types.Tuple](args[1].typ) { return nil, fmt.Errorf("can't yet inline spread call to method") } // Infallible general case: literalization. // // func(params) { body }(args) // logf("strategy: literalization") funcLit := &ast.FuncLit{ Type: calleeDecl.Type, Body: calleeDecl.Body, } // Literalization can still make use of a binding // decl as it gives a more natural reading order: // // func() { var params = args; body }() // // TODO(adonovan): relax the allResultsUnreferenced requirement // by adding a parameter-only (no named results) binding decl. if bindingDecl != nil && allResultsUnreferenced { funcLit.Type.Params.List = nil remainingArgs = nil funcLit.Body.List = prepend(bindingDecl.stmt, funcLit.Body.List...) } // Emit a new call to a function literal in place of // the callee name, with appropriate replacements. newCall := &ast.CallExpr{ Fun: funcLit, Ellipsis: token.NoPos, // f(slice...) is always simplified Args: remainingArgs, } clearPositions(newCall.Fun) res.old = caller.Call res.new = newCall return res, nil } type argument struct { expr ast.Expr typ types.Type // may be tuple for sole non-receiver arg in spread call constant constant.Value // value of argument if constant spread bool // final arg is call() assigned to multiple params pure bool // expr is pure (doesn't read variables) effects bool // expr has effects (updates variables) duplicable bool // expr may be duplicated freevars map[string]bool // free names of expr substitutable bool // is candidate for substitution } // arguments returns the effective arguments of the call. // // If the receiver argument and parameter have // different pointerness, make the "&" or "*" explicit. // // Also, if x.f() is shorthand for promoted method x.y.f(), // make the .y explicit in T.f(x.y, ...). 
//
// Beware that:
//
//   - a method can only be called through a selection, but only
//     the first of these two forms needs special treatment:
//
//     expr.f(args)     -> ([&*]expr, args)    MethodVal
//     T.f(recv, args)  -> (    expr, args)    MethodExpr
//
//   - the presence of a value in receiver-position in the call
//     is a property of the caller, not the callee. A method
//     (calleeDecl.Recv != nil) may be called like an ordinary
//     function.
//
//   - the types.Signatures seen by the caller (from
//     StaticCallee) and by the callee (from decl type)
//     differ in this case.
//
// In a spread call f(g()), the sole ordinary argument g(),
// always last in args, has a tuple type.
//
// We compute type-based predicates like pure, duplicable,
// freevars, etc, now, before we start modifying syntax.
//
// The assign1 callback is passed through to pure. An error is
// returned if making a promoted-field selection explicit would name
// an unexported field declared in a package other than the caller's,
// or if re-typechecking a constant argument fails.
func (st *state) arguments(caller *Caller, calleeDecl *ast.FuncDecl, assign1 func(*types.Var) bool) ([]*argument, error) {
	var args []*argument

	callArgs := caller.Call.Args
	if calleeDecl.Recv != nil {
		sel := astutil.Unparen(caller.Call.Fun).(*ast.SelectorExpr)
		seln := caller.Info.Selections[sel]
		var recvArg ast.Expr
		switch seln.Kind() {
		case types.MethodVal: // recv.f(callArgs)
			recvArg = sel.X
		case types.MethodExpr: // T.f(recv, callArgs)
			// The receiver is already the first ordinary argument.
			recvArg = callArgs[0]
			callArgs = callArgs[1:]
		}
		if recvArg != nil {
			// Compute all the type-based predicates now,
			// before we start meddling with the syntax;
			// the meddling will update them.
			arg := &argument{
				expr:       recvArg,
				typ:        caller.Info.TypeOf(recvArg),
				constant:   caller.Info.Types[recvArg].Value,
				pure:       pure(caller.Info, assign1, recvArg),
				effects:    st.effects(caller.Info, recvArg),
				duplicable: duplicable(caller.Info, recvArg),
				freevars:   freeVars(caller.Info, recvArg),
			}
			recvArg = nil // prevent accidental use

			// Move receiver argument recv.f(args) to argument list f(&recv, args).
			args = append(args, arg)

			// Make field selections explicit (recv.f -> recv.y.f),
			// updating arg.{expr,typ}.
			// All but the last index of the selection path name
			// embedded fields; the last names the method itself.
			indices := seln.Index()
			for _, index := range indices[:len(indices)-1] {
				fld := typeparams.CoreType(typeparams.Deref(arg.typ)).(*types.Struct).Field(index)
				if fld.Pkg() != caller.Types && !fld.Exported() {
					return nil, fmt.Errorf("in %s, implicit reference to unexported field .%s cannot be made explicit",
						debugFormatNode(caller.Fset, caller.Call.Fun),
						fld.Name())
				}
				if isPointer(arg.typ) {
					arg.pure = false // implicit *ptr operation => impure
				}
				arg.expr = &ast.SelectorExpr{
					X:   arg.expr,
					Sel: makeIdent(fld.Name()),
				}
				arg.typ = fld.Type()
				arg.duplicable = false
			}

			// Make * or & explicit.
			argIsPtr := isPointer(arg.typ)
			paramIsPtr := isPointer(seln.Obj().Type().Underlying().(*types.Signature).Recv().Type())
			if !argIsPtr && paramIsPtr {
				// &recv
				arg.expr = &ast.UnaryExpr{Op: token.AND, X: arg.expr}
				arg.typ = types.NewPointer(arg.typ)
			} else if argIsPtr && !paramIsPtr {
				// *recv
				// The explicit load is neither duplicable nor pure.
				arg.expr = &ast.StarExpr{X: arg.expr}
				arg.typ = typeparams.Deref(arg.typ)
				arg.duplicable = false
				arg.pure = false
			}
		}
	}
	// Append the ordinary (non-receiver) arguments, again computing
	// the type-based predicates from the caller's type information.
	for _, expr := range callArgs {
		tv := caller.Info.Types[expr]
		args = append(args, &argument{
			expr:       expr,
			typ:        tv.Type,
			constant:   tv.Value,
			spread:     is[*types.Tuple](tv.Type), // => last
			pure:       pure(caller.Info, assign1, expr),
			effects:    st.effects(caller.Info, expr),
			duplicable: duplicable(caller.Info, expr),
			freevars:   freeVars(caller.Info, expr),
		})
	}

	// Re-typecheck each constant argument expression in a neutral context.
	//
	// In a call such as func(int16){}(1), the type checker infers
	// the type "int16", not "untyped int", for the argument 1,
	// because it has incorporated information from the left-hand
	// side of the assignment implicit in parameter passing, but
	// of course in a different context, the expression 1 may have
	// a different type.
	//
	// So, we must use CheckExpr to recompute the type of the
	// argument in a neutral context to find its inherent type.
	// (This is arguably a bug in go/types, but I'm pretty certain
	// I requested it be this way long ago... -adonovan)
	//
	// This is only needed for constants. Other implicit
	// assignment conversions, such as unnamed-to-named struct or
	// chan to <-chan, do not result in the type-checker imposing
	// the LHS type on the RHS value.
	for _, arg := range args {
		if arg.constant == nil {
			continue
		}
		// Use a fresh Info so that the caller's own type
		// information is not overwritten by the re-check.
		info := &types.Info{Types: make(map[ast.Expr]types.TypeAndValue)}
		if err := types.CheckExpr(caller.Fset, caller.Types, caller.Call.Pos(), arg.expr, info); err != nil {
			return nil, err
		}
		arg.typ = info.TypeOf(arg.expr)
	}

	return args, nil
}

// A parameter describes one parameter (or receiver) of the callee.
type parameter struct {
	obj       *types.Var // parameter var from caller's signature
	fieldType ast.Expr   // syntax of type, from calleeDecl.Type.{Recv,Params}
	info      *paramInfo // information from AnalyzeCallee
	variadic  bool       // (final) parameter is unsimplified ...T
}

// substitute implements parameter elimination by substitution.
//
// It considers each parameter and its corresponding argument in turn
// and evaluates these conditions:
//
//   - the parameter is neither address-taken nor assigned;
//   - the argument is pure;
//   - if the parameter refcount is zero, the argument must
//     not contain the last use of a local var;
//   - if the parameter refcount is > 1, the argument must be duplicable;
//   - the argument (or types.Default(argument) if it's untyped) has
//     the same type as the parameter.
//
// If all conditions are met then the parameter can be substituted and
// each reference to it replaced by the argument. In that case, the
// replaceCalleeID function is called for each reference to the
// parameter, and is provided with its relative offset and replacement
// expression (argument), and the corresponding elements of params and
// args are replaced by nil.
func substitute(logf func(string, ...any), caller *Caller, params []*parameter, args []*argument, effects []int, falcon falconResult, replaceCalleeID func(offset int, repl ast.Expr)) {
	// Inv:
	//  in        calls to     variadic, len(args) >= len(params)-1
	//  in spread calls to non-variadic, len(args) <  len(params)
	//  in spread calls to     variadic, len(args) <= len(params)
	// (In spread calls len(args) = 1, or 2 if call has receiver.)
	// Non-spread variadics have been simplified away already,
	// so the args[i] lookup is safe if we stop after the spread arg.

	// First pass: mark each argument that is a candidate for
	// substitution by setting arg.substitutable.
next:
	for i, param := range params {
		arg := args[i]
		// Check argument against parameter.
		//
		// Beware: don't use types.Info on arg since
		// the syntax may be synthetic (not created by parser)
		// and thus lacking positions and types;
		// do it earlier (see pure/duplicable/freevars).

		if arg.spread {
			// spread => last argument, but not always last parameter
			logf("keeping param %q and following ones: argument %s is spread",
				param.info.Name, debugFormatNode(caller.Fset, arg.expr))
			return // give up
		}
		assert(!param.variadic, "unsimplified variadic parameter")
		if param.info.Escapes {
			logf("keeping param %q: escapes from callee", param.info.Name)
			continue
		}
		if param.info.Assigned {
			logf("keeping param %q: assigned by callee", param.info.Name)
			continue // callee needs the parameter variable
		}
		if len(param.info.Refs) > 1 && !arg.duplicable {
			logf("keeping param %q: argument is not duplicable", param.info.Name)
			continue // incorrect or poor style to duplicate an expression
		}
		if len(param.info.Refs) == 0 {
			// Substituting an unreferenced parameter discards
			// its argument expression entirely.
			if arg.effects {
				logf("keeping param %q: though unreferenced, it has effects", param.info.Name)
				continue
			}

			// If the caller is within a function body,
			// eliminating an unreferenced parameter might
			// remove the last reference to a caller local var.
			if caller.enclosingFunc != nil {
				for free := range arg.freevars {
					// TODO(rfindley): we can get this 100% right by looking for
					// references among other arguments which have non-zero references
					// within the callee.
					if v, ok := caller.lookup(free).(*types.Var); ok && within(v.Pos(), caller.enclosingFunc.Body) && !isUsedOutsideCall(caller, v) {
						logf("keeping param %q: arg contains perhaps the last reference to caller local %v @ %v",
							param.info.Name, v, caller.Fset.PositionFor(v.Pos(), false))
						continue next
					}
				}
			}
		}

		// Check for shadowing.
		//
		// Consider inlining a call f(z, 1) to
		// func f(x, y int) int { z := y; return x + y + z }:
		// we can't replace x in the body by z (or any
		// expression that has z as a free identifier)
		// because there's an intervening declaration of z
		// that would shadow the caller's one.
		for free := range arg.freevars {
			if param.info.Shadow[free] {
				logf("keeping param %q: cannot replace with argument as it has free ref to %s that is shadowed", param.info.Name, free)
				continue next // shadowing conflict
			}
		}

		arg.substitutable = true // may be substituted, if effects permit
	}

	// Reject constant arguments as substitution candidates
	// if they cause violation of falcon constraints.
	checkFalconConstraints(logf, params, args, falcon)

	// As a final step, introduce bindings to resolve any
	// evaluation order hazards. This must be done last, as
	// additional subsequent bindings could introduce new hazards.
	resolveEffects(logf, args, effects)

	// The remaining candidates are safe to substitute.
	for i, param := range params {
		if arg := args[i]; arg.substitutable {

			// Wrap the argument in an explicit conversion if
			// substitution might materially change its type.
			// (We already did the necessary shadowing check
			// on the parameter type syntax.)
			//
			// This is only needed for substituted arguments. All
			// other arguments are given explicit types in either
			// a binding decl or when using the literalization
			// strategy.
			if len(param.info.Refs) > 0 && !trivialConversion(args[i].constant, args[i].typ, params[i].obj.Type()) {
				arg.expr = convert(params[i].fieldType, arg.expr)
				logf("param %q: adding explicit %s -> %s conversion around argument",
					param.info.Name, args[i].typ, params[i].obj.Type())
			}

			// It is safe to substitute param and replace it with arg.
			// The formatter introduces parens as needed for precedence.
			//
			// Because arg.expr belongs to the caller,
			// we clone it before splicing it into the callee tree.
			logf("replacing parameter %q by argument %q",
				param.info.Name, debugFormatNode(caller.Fset, arg.expr))
			for _, ref := range param.info.Refs {
				// Each reference receives its own clone.
				replaceCalleeID(ref, internalastutil.CloneNode(arg.expr).(ast.Expr))
			}
			params[i] = nil // substituted
			args[i] = nil   // substituted
		}
	}
}

// isUsedOutsideCall reports whether v is used outside of caller.Call, within
// the body of caller.enclosingFunc.
//
// A variable also counts as used if it is declared as a parameter of
// any function type that appears within that body.
func isUsedOutsideCall(caller *Caller, v *types.Var) bool {
	used := false
	ast.Inspect(caller.enclosingFunc.Body, func(n ast.Node) bool {
		if n == caller.Call {
			return false // don't descend into the call being inlined
		}
		switch n := n.(type) {
		case *ast.Ident:
			if use := caller.Info.Uses[n]; use == v {
				used = true
			}
		case *ast.FuncType:
			// All params are used.
			for _, fld := range n.Params.List {
				for _, n := range fld.Names {
					if def := caller.Info.Defs[n]; def == v {
						used = true
					}
				}
			}
		}
		return !used // keep going until we find a use
	})
	return used
}

// checkFalconConstraints checks whether constant arguments
// are safe to substitute (e.g. s[i] -> ""[0] is not safe.)
//
// Any failed constraint causes us to reject all constant arguments as
// substitution candidates (by clearing args[i].substitutable).
//
// TODO(adonovan): we could obtain a finer result rejecting only the
// freevars of each failed constraint, and processing constraints in
// order of increasing arity, but failures are quite rare.
func checkFalconConstraints(logf func(string, ...any), params []*parameter, args []*argument, falcon falconResult) { // Create a dummy package, as this is the only // way to create an environment for CheckExpr. pkg := types.NewPackage("falcon", "falcon") // Declare types used by constraints. for _, typ := range falcon.Types { logf("falcon env: type %s %s", typ.Name, types.Typ[typ.Kind]) pkg.Scope().Insert(types.NewTypeName(token.NoPos, pkg, typ.Name, types.Typ[typ.Kind])) } // Declared constants and variables for for parameters. nconst := 0 for i, param := range params { name := param.info.Name if name == "" { continue // unreferenced } arg := args[i] if arg.constant != nil && arg.substitutable && param.info.FalconType != "" { t := pkg.Scope().Lookup(param.info.FalconType).Type() pkg.Scope().Insert(types.NewConst(token.NoPos, pkg, name, t, arg.constant)) logf("falcon env: const %s %s = %v", name, param.info.FalconType, arg.constant) nconst++ } else { pkg.Scope().Insert(types.NewVar(token.NoPos, pkg, name, arg.typ)) logf("falcon env: var %s %s", name, arg.typ) } } if nconst == 0 { return // nothing to do } // Parse and evaluate the constraints in the environment. fset := token.NewFileSet() for _, falcon := range falcon.Constraints { expr, err := parser.ParseExprFrom(fset, "falcon", falcon, 0) if err != nil { panic(fmt.Sprintf("failed to parse falcon constraint %s: %v", falcon, err)) } if err := types.CheckExpr(fset, pkg, token.NoPos, expr, nil); err != nil { logf("falcon: constraint %s violated: %v", falcon, err) for j, arg := range args { if arg.constant != nil && arg.substitutable { logf("keeping param %q due falcon violation", params[j].info.Name) arg.substitutable = false } } break } logf("falcon: constraint %s satisfied", falcon) } } // resolveEffects marks arguments as non-substitutable to resolve // hazards resulting from the callee evaluation order described by the // effects list. 
// // To do this, each argument is categorized as a read (R), write (W), // or pure. A hazard occurs when the order of evaluation of a W // changes with respect to any R or W. Pure arguments can be // effectively ignored, as they can be safely evaluated in any order. // // The callee effects list contains the index of each parameter in the // order it is first evaluated during execution of the callee. In // addition, the two special values R∞ and W∞ indicate the relative // position of the callee's first non-parameter read and its first // effects (or other unknown behavior). // For example, the list [0 2 1 R∞ 3 W∞] for func(a, b, c, d) // indicates that the callee referenced parameters a, c, and b, // followed by an arbitrary read, then parameter d, and finally // unknown behavior. // // When an argument is marked as not substitutable, we say that it is // 'bound', in the sense that its evaluation occurs in a binding decl // or literalized call. Such bindings always occur in the original // callee parameter order. // // In this context, "resolving hazards" means binding arguments so // that they are evaluated in a valid, hazard-free order. A trivial // solution to this problem would be to bind all arguments, but of // course that's not useful. The goal is to bind as few arguments as // possible. // // The algorithm proceeds by inspecting arguments in reverse parameter // order (right to left), preserving the invariant that every // higher-ordered argument is either already substituted or does not // need to be substituted. At each iteration, if there is an // evaluation hazard in the callee effects relative to the current // argument, the argument must be bound. Subsequently, if the argument // is bound for any reason, each lower-ordered argument must also be // bound if either the argument or lower-order argument is a // W---otherwise the binding itself would introduce a hazard. 
//
// Thus, after each iteration, there are no hazards relative to the
// current argument. Subsequent iterations cannot introduce hazards
// with that argument because they can result only in additional
// binding of lower-ordered arguments.
func resolveEffects(logf func(string, ...any), args []*argument, effects []int) {
	// effectStr formats an argument for logging as R or W
	// followed by its index, or "∞" when idx == len(args).
	effectStr := func(effects bool, idx int) string {
		i := fmt.Sprint(idx)
		if idx == len(args) {
			i = "∞"
		}
		return string("RW"[btoi(effects)]) + i
	}
	for i := len(args) - 1; i >= 0; i-- {
		argi := args[i]
		if argi.substitutable && !argi.pure {
			// i is not bound: check whether it must be bound due to hazards.
			idx := index(effects, i)
			if idx >= 0 {
				// Examine everything the callee evaluates before parameter i.
				for _, j := range effects[:idx] {
					var (
						ji int  // effective param index
						jw bool // j is a write
					)
					if j == winf || j == rinf {
						// The special entries are treated as
						// ordered after every real argument.
						jw = j == winf
						ji = len(args)
					} else {
						jw = args[j].effects
						ji = j
					}
					if ji > i && (jw || argi.effects) { // out of order evaluation
						logf("binding argument %s: preceded by %s",
							effectStr(argi.effects, i), effectStr(jw, ji))
						argi.substitutable = false
						break
					}
				}
			}
		}
		if !argi.substitutable {
			// i is bound: bind each impure lower-ordered argument
			// if it or argument i is a write.
			for j := 0; j < i; j++ {
				argj := args[j]
				if argj.pure {
					continue
				}
				if (argi.effects || argj.effects) && argj.substitutable {
					logf("binding argument %s: %s is bound",
						effectStr(argj.effects, j), effectStr(argi.effects, i))
					argj.substitutable = false
				}
			}
		}
	}
}

// updateCalleeParams updates the calleeDecl syntax to remove
// substituted parameters and move the receiver (if any) to the head
// of the ordinary parameters.
//
// A nil element of params denotes a substituted parameter.
func updateCalleeParams(calleeDecl *ast.FuncDecl, params []*parameter) {
	// The logic is fiddly because of the three forms of ast.Field:
	//
	//	func(int), func(x int), func(x, y int)
	//
	// Also, ensure that all remaining parameters are named
	// to avoid a mix of named/unnamed when joining (recv, params...).
	// func (T) f(int, bool) -> (_ T, _ int, _ bool)
	// (Strictly, we need do this only for methods and only when
	// the namednesses of Recv and Params differ; that might be tidier.)

	paramIdx := 0 // index in original parameter list (incl. receiver)
	var newParams []*ast.Field
	// filterParams appends to newParams a copy of field that
	// retains only its unsubstituted names, advancing paramIdx
	// past every parameter the field declares.
	filterParams := func(field *ast.Field) {
		var names []*ast.Ident
		if field.Names == nil {
			// Unnamed parameter field (e.g. func f(int)
			if params[paramIdx] != nil {
				// Give it an explicit name "_" since we will
				// make the receiver (if any) a regular parameter
				// and one cannot mix named and unnamed parameters.
				names = append(names, makeIdent("_"))
			}
			paramIdx++
		} else {
			// Named parameter field e.g. func f(x, y int)
			// Remove substituted parameters in place.
			// If all were substituted, delete field.
			for _, id := range field.Names {
				if pinfo := params[paramIdx]; pinfo != nil {
					// Rename unreferenced parameters with "_".
					// This is crucial for binding decls, since
					// unlike parameters, they are subject to
					// "unreferenced var" checks.
					if len(pinfo.info.Refs) == 0 {
						id = makeIdent("_")
					}
					names = append(names, id)
				}
				paramIdx++
			}
		}
		if names != nil {
			newParams = append(newParams, &ast.Field{
				Names: names,
				Type:  field.Type,
			})
		}
	}
	// The receiver, if any, comes first and is then dropped
	// from the declaration.
	if calleeDecl.Recv != nil {
		filterParams(calleeDecl.Recv.List[0])
		calleeDecl.Recv = nil
	}
	for _, field := range calleeDecl.Type.Params.List {
		filterParams(field)
	}
	calleeDecl.Type.Params.List = newParams
}

// bindingDeclInfo records information about the binding decl produced by
// createBindingDecl.
type bindingDeclInfo struct {
	names map[string]bool // names bound by the binding decl; possibly empty
	stmt  ast.Stmt        // the binding decl itself
}

// createBindingDecl constructs a "binding decl" that implements
// parameter assignment and declares any named result variables
// referenced by the callee. It returns nil if there were no
// unsubstituted parameters.
//
// It may not always be possible to create the decl (e.g. due to
// shadowing), in which case it also returns nil; but if it succeeds,
// the declaration may be used by reduction strategies to relax the
// requirement that all parameters have been substituted.
//
// For example, a call:
//
//	f(a0, a1, a2)
//
// where:
//
//	func f(p0, p1 T0, p2 T1) { body }
//
// reduces to:
//
//	{
//	  var (
//	    p0, p1 T0 = a0, a1
//	    p2     T1 = a2
//	  )
//	  body
//	}
//
// so long as p0, p1 ∉ freevars(T1) or freevars(a2), and so on,
// because each spec is statically resolved in sequence and
// dynamically assigned in sequence. By contrast, all
// parameters are resolved simultaneously and assigned
// simultaneously.
//
// The pX names should already be blank ("_") if the parameter
// is unreferenced; this avoids "unreferenced local var" checks.
//
// Strategies may impose additional checks on return
// conversions, labels, defer, etc.
func createBindingDecl(logf func(string, ...any), caller *Caller, args []*argument, calleeDecl *ast.FuncDecl, results []*paramInfo) *bindingDeclInfo {
	// Spread calls are tricky as they may not align with the
	// parameters' field groupings nor types.
	// For example, given
	//   func g() (int, string)
	// the call
	//   f(g())
	// is legal with these decls of f:
	//   func f(int, string)
	//   func f(x, y any)
	//   func f(x, y ...any)
	// TODO(adonovan): support binding decls for spread calls by
	// splitting parameter groupings as needed.
	if lastArg := last(args); lastArg != nil && lastArg.spread {
		logf("binding decls not yet supported for spread calls")
		return nil
	}

	var (
		specs []ast.Spec
		names = make(map[string]bool) // names defined by previous specs
	)
	// shadow reports whether any name referenced by spec is
	// shadowed by a name declared by a previous spec (since,
	// unlike parameters, each spec of a var decl is within the
	// scope of the previous specs).
	//
	// If there is no shadowing, it also records the non-blank
	// names declared by spec in names.
	shadow := func(spec *ast.ValueSpec) bool {
		// Compute union of free names of type and values
		// and detect shadowing. Values is the arguments
		// (caller syntax), so we can use type info.
		// But Type is the untyped callee syntax,
		// so we have to use a syntax-only algorithm.
		free := make(map[string]bool)
		for _, value := range spec.Values {
			for name := range freeVars(caller.Info, value) {
				free[name] = true
			}
		}
		freeishNames(free, spec.Type)
		for name := range free {
			if names[name] {
				logf("binding decl would shadow free name %q", name)
				return true
			}
		}
		for _, id := range spec.Names {
			if id.Name != "_" {
				names[id.Name] = true
			}
		}
		return false
	}

	// parameters
	//
	// Bind parameters that were not eliminated through
	// substitution. (Non-nil arguments correspond to the
	// remaining parameters in calleeDecl.)
	var values []ast.Expr
	for _, arg := range args {
		if arg != nil {
			values = append(values, arg.expr)
		}
	}
	for _, field := range calleeDecl.Type.Params.List {
		// Each field (param group) becomes a ValueSpec.
		// It consumes one value per declared name.
		spec := &ast.ValueSpec{
			Names:  field.Names,
			Type:   field.Type,
			Values: values[:len(field.Names)],
		}
		values = values[len(field.Names):]
		if shadow(spec) {
			return nil
		}
		specs = append(specs, spec)
	}
	assert(len(values) == 0, "args/params mismatch")

	// results
	//
	// Add specs to declare any named result
	// variables that are referenced by the body.
	if calleeDecl.Type.Results != nil {
		resultIdx := 0 // index into results, counting unnamed results too
		for _, field := range calleeDecl.Type.Results.List {
			if field.Names == nil {
				resultIdx++
				continue // unnamed field
			}
			// Note: this local names shadows the map of the same
			// name above; the shadow closure still uses the map.
			var names []*ast.Ident
			for _, id := range field.Names {
				if len(results[resultIdx].Refs) > 0 {
					names = append(names, id)
				}
				resultIdx++
			}
			if len(names) > 0 {
				// No Values: the result variable is declared
				// without an initializer.
				spec := &ast.ValueSpec{
					Names: names,
					Type:  field.Type,
				}
				if shadow(spec) {
					return nil
				}
				specs = append(specs, spec)
			}
		}
	}

	if len(specs) == 0 {
		logf("binding decl not needed: all parameters substituted")
		return nil
	}

	stmt := &ast.DeclStmt{
		Decl: &ast.GenDecl{
			Tok:   token.VAR,
			Specs: specs,
		},
	}
	logf("binding decl: %s", debugFormatNode(caller.Fset, stmt))
	return &bindingDeclInfo{names: names, stmt: stmt}
}

// lookup does a symbol lookup in the lexical environment of the caller.
func (caller *Caller) lookup(name string) types.Object { pos := caller.Call.Pos() for _, n := range caller.path { if scope := scopeFor(caller.Info, n); scope != nil { if _, obj := scope.LookupParent(name, pos); obj != nil { return obj } } } return nil } func scopeFor(info *types.Info, n ast.Node) *types.Scope { // The function body scope (containing not just params) // is associated with the function's type, not body. switch fn := n.(type) { case *ast.FuncDecl: n = fn.Type case *ast.FuncLit: n = fn.Type } return info.Scopes[n] } // -- predicates over expressions -- // freeVars returns the names of all free identifiers of e: // those lexically referenced by it but not defined within it. // (Fields and methods are not included.) func freeVars(info *types.Info, e ast.Expr) map[string]bool { free := make(map[string]bool) ast.Inspect(e, func(n ast.Node) bool { if id, ok := n.(*ast.Ident); ok { // The isField check is so that we don't treat T{f: 0} as a ref to f. if obj, ok := info.Uses[id]; ok && !within(obj.Pos(), e) && !isField(obj) { free[obj.Name()] = true } } return true }) return free } // freeishNames computes an over-approximation to the free names // of the type syntax t, inserting values into the map. // // Because we don't have go/types annotations, we can't give an exact // result in all cases. In particular, an array type [n]T might have a // size such as unsafe.Sizeof(func() int{stmts...}()) and now the // precise answer depends upon all the statement syntax too. But that // never happens in practice. 
func freeishNames(free map[string]bool, t ast.Expr) { var visit func(n ast.Node) bool visit = func(n ast.Node) bool { switch n := n.(type) { case *ast.Ident: free[n.Name] = true case *ast.SelectorExpr: ast.Inspect(n.X, visit) return false // don't visit .Sel case *ast.Field: ast.Inspect(n.Type, visit) // Don't visit .Names: // FuncType parameters, interface methods, struct fields return false } return true } ast.Inspect(t, visit) } // effects reports whether an expression might change the state of the // program (through function calls and channel receives) and affect // the evaluation of subsequent expressions. func (st *state) effects(info *types.Info, expr ast.Expr) bool { effects := false ast.Inspect(expr, func(n ast.Node) bool { switch n := n.(type) { case *ast.FuncLit: return false // prune descent case *ast.CallExpr: if info.Types[n.Fun].IsType() { // A conversion T(x) has only the effect of its operand. } else if !callsPureBuiltin(info, n) { // A handful of built-ins have no effect // beyond those of their arguments. // All other calls (including append, copy, recover) // have unknown effects. // // As with 'pure', there is room for // improvement by inspecting the callee. effects = true } case *ast.UnaryExpr: if n.Op == token.ARROW { // <-ch effects = true } } return true }) // Even if consideration of effects is not desired, // we continue to compute, log, and discard them. if st.opts.IgnoreEffects && effects { effects = false st.opts.Logf("ignoring potential effects of argument %s", debugFormatNode(st.caller.Fset, expr)) } return effects } // pure reports whether an expression has the same result no matter // when it is executed relative to other expressions, so it can be // commuted with any other expression or statement without changing // its meaning. 
// // An expression is considered impure if it reads the contents of any // variable, with the exception of "single assignment" local variables // (as classified by the provided callback), which are never updated // after their initialization. // // Pure does not imply duplicable: for example, new(T) and T{} are // pure expressions but both return a different value each time they // are evaluated, so they are not safe to duplicate. // // Purity does not imply freedom from run-time panics. We assume that // target programs do not encounter run-time panics nor depend on them // for correct operation. // // TODO(adonovan): add unit tests of this function. func pure(info *types.Info, assign1 func(*types.Var) bool, e ast.Expr) bool { var pure func(e ast.Expr) bool pure = func(e ast.Expr) bool { switch e := e.(type) { case *ast.ParenExpr: return pure(e.X) case *ast.Ident: if v, ok := info.Uses[e].(*types.Var); ok { // In general variables are impure // as they may be updated, but // single-assignment local variables // never change value. // // We assume all package-level variables // may be updated, but for non-exported // ones we could do better by analyzing // the complete package. return !isPkgLevel(v) && assign1(v) } // All other kinds of reference are pure. return true case *ast.FuncLit: // A function literal may allocate a closure that // references mutable variables, but mutation // cannot be observed without calling the function, // and calls are considered impure. return true case *ast.BasicLit: return true case *ast.UnaryExpr: // + - ! ^ & but not <- return e.Op != token.ARROW && pure(e.X) case *ast.BinaryExpr: // arithmetic, shifts, comparisons, &&/|| return pure(e.X) && pure(e.Y) case *ast.CallExpr: // A conversion is as pure as its operand. if info.Types[e.Fun].IsType() { return pure(e.Args[0]) } // Calls to some built-ins are as pure as their arguments. 
if callsPureBuiltin(info, e) { for _, arg := range e.Args { if !pure(arg) { return false } } return true } // All other calls are impure, so we can // reject them without even looking at e.Fun. // // More sophisticated analysis could infer purity in // commonly used functions such as strings.Contains; // perhaps we could offer the client a hook so that // go/analysis-based implementation could exploit the // results of a purity analysis. But that would make // the inliner's choices harder to explain. return false case *ast.CompositeLit: // T{...} is as pure as its elements. for _, elt := range e.Elts { if kv, ok := elt.(*ast.KeyValueExpr); ok { if !pure(kv.Value) { return false } if id, ok := kv.Key.(*ast.Ident); ok { if v, ok := info.Uses[id].(*types.Var); ok && v.IsField() { continue // struct {field: value} } } // map/slice/array {key: value} if !pure(kv.Key) { return false } } else if !pure(elt) { return false } } return true case *ast.SelectorExpr: if seln, ok := info.Selections[e]; ok { // See types.SelectionKind for background. switch seln.Kind() { case types.MethodExpr: // A method expression T.f acts like a // reference to a func decl, so it is pure. return true case types.MethodVal, types.FieldVal: // A field or method selection x.f is pure // if x is pure and the selection does // not indirect a pointer. return !indirectSelection(seln) && pure(e.X) default: panic(seln) } } else { // A qualified identifier is // treated like an unqualified one. return pure(e.Sel) } case *ast.StarExpr: return false // *ptr depends on the state of the heap default: return false } } return pure(e) } // callsPureBuiltin reports whether call is a call of a built-in // function that is a pure computation over its operands (analogous to // a + operator). Because it does not depend on program state, it may // be evaluated at any point--though not necessarily at multiple // points (consider new, make). 
func callsPureBuiltin(info *types.Info, call *ast.CallExpr) bool { if id, ok := astutil.Unparen(call.Fun).(*ast.Ident); ok { if b, ok := info.ObjectOf(id).(*types.Builtin); ok { switch b.Name() { case "len", "cap", "complex", "imag", "real", "make", "new", "max", "min": return true } // Not: append clear close copy delete panic print println recover } } return false } // duplicable reports whether it is appropriate for the expression to // be freely duplicated. // // Given the declaration // // func f(x T) T { return x + g() + x } // // an argument y is considered duplicable if we would wish to see a // call f(y) simplified to y+g()+y. This is true for identifiers, // integer literals, unary negation, and selectors x.f where x is not // a pointer. But we would not wish to duplicate expressions that: // - have side effects (e.g. nearly all calls), // - are not referentially transparent (e.g. &T{}, ptr.field, *ptr), or // - are long (e.g. "huge string literal"). func duplicable(info *types.Info, e ast.Expr) bool { switch e := e.(type) { case *ast.ParenExpr: return duplicable(info, e.X) case *ast.Ident: return true case *ast.BasicLit: v := info.Types[e].Value switch e.Kind { case token.INT: return true // any int case token.STRING: return consteq(v, kZeroString) // only "" case token.FLOAT: return consteq(v, kZeroFloat) || consteq(v, kOneFloat) // only 0.0 or 1.0 } case *ast.UnaryExpr: // e.g. +1, -1 return (e.Op == token.ADD || e.Op == token.SUB) && duplicable(info, e.X) case *ast.CompositeLit: // Empty struct or array literals T{} are duplicable. // (Non-empty literals are too verbose, and slice/map // literals allocate indirect variables.) if len(e.Elts) == 0 { switch info.TypeOf(e).Underlying().(type) { case *types.Struct, *types.Array: return true } } return false case *ast.CallExpr: // Don't treat a conversion T(x) as duplicable even // if x is duplicable because it could duplicate // allocations. 
// // TODO(adonovan): there are cases to tease apart here: // duplicating string([]byte) conversions increases // allocation but doesn't change behavior, but the // reverse, []byte(string), allocates a distinct array, // which is observable return false case *ast.SelectorExpr: if seln, ok := info.Selections[e]; ok { // A field or method selection x.f is referentially // transparent if it does not indirect a pointer. return !indirectSelection(seln) } // A qualified identifier pkg.Name is referentially transparent. return true } return false } func consteq(x, y constant.Value) bool { return constant.Compare(x, token.EQL, y) } var ( kZeroInt = constant.MakeInt64(0) kZeroString = constant.MakeString("") kZeroFloat = constant.MakeFloat64(0.0) kOneFloat = constant.MakeFloat64(1.0) ) // -- inline helpers -- func assert(cond bool, msg string) { if !cond { panic(msg) } } // blanks returns a slice of n > 0 blank identifiers. func blanks[E ast.Expr](n int) []E { if n == 0 { panic("blanks(0)") } res := make([]E, n) for i := range res { res[i] = ast.Expr(makeIdent("_")).(E) // ugh } return res } func makeIdent(name string) *ast.Ident { return &ast.Ident{Name: name} } // importedPkgName returns the PkgName object declared by an ImportSpec. // TODO(adonovan): make this a method of types.Info (#62037). func importedPkgName(info *types.Info, imp *ast.ImportSpec) (*types.PkgName, bool) { var obj types.Object if imp.Name != nil { obj = info.Defs[imp.Name] } else { obj = info.Implicits[imp] } pkgname, ok := obj.(*types.PkgName) return pkgname, ok } func isPkgLevel(obj types.Object) bool { // TODO(adonovan): consider using the simpler obj.Parent() == // obj.Pkg().Scope() instead. But be sure to test carefully // with instantiations of generics. return obj.Pkg().Scope().Lookup(obj.Name()) == obj } // callContext returns the two nodes immediately enclosing the call // (specified as a PathEnclosingInterval), ignoring parens. 
func callContext(callPath []ast.Node) (parent, grandparent ast.Node) { _ = callPath[0].(*ast.CallExpr) // sanity check for _, n := range callPath[1:] { if !is[*ast.ParenExpr](n) { if parent == nil { parent = n } else { return parent, n } } } return parent, nil } // hasLabelConflict reports whether the set of labels of the function // enclosing the call (specified as a PathEnclosingInterval) // intersects with the set of callee labels. func hasLabelConflict(callPath []ast.Node, calleeLabels []string) bool { labels := callerLabels(callPath) for _, label := range calleeLabels { if labels[label] { return true // conflict } } return false } // callerLabels returns the set of control labels in the function (if // any) enclosing the call (specified as a PathEnclosingInterval). func callerLabels(callPath []ast.Node) map[string]bool { var callerBody *ast.BlockStmt switch f := callerFunc(callPath).(type) { case *ast.FuncDecl: callerBody = f.Body case *ast.FuncLit: callerBody = f.Body } var labels map[string]bool if callerBody != nil { ast.Inspect(callerBody, func(n ast.Node) bool { switch n := n.(type) { case *ast.FuncLit: return false // prune traversal case *ast.LabeledStmt: if labels == nil { labels = make(map[string]bool) } labels[n.Label.Name] = true } return true }) } return labels } // callerFunc returns the innermost Func{Decl,Lit} node enclosing the // call (specified as a PathEnclosingInterval). func callerFunc(callPath []ast.Node) ast.Node { _ = callPath[0].(*ast.CallExpr) // sanity check for _, n := range callPath[1:] { if is[*ast.FuncDecl](n) || is[*ast.FuncLit](n) { return n } } return nil } // callStmt reports whether the function call (specified // as a PathEnclosingInterval) appears within an ExprStmt, // and returns it if so. // // If unrestricted, callStmt returns nil if the ExprStmt f() appears // in a restricted context (such as "if f(); cond {") where it cannot // be replaced by an arbitrary statement. (See "statement theory".) 
func callStmt(callPath []ast.Node, unrestricted bool) *ast.ExprStmt { parent, _ := callContext(callPath) stmt, ok := parent.(*ast.ExprStmt) if ok && unrestricted { switch callPath[nodeIndex(callPath, stmt)+1].(type) { case *ast.LabeledStmt, *ast.BlockStmt, *ast.CaseClause, *ast.CommClause: // unrestricted default: // TODO(adonovan): handle restricted // XYZStmt.Init contexts (but not ForStmt.Post) // by creating a block around the if/for/switch: // "if f(); cond {" -> "{ stmts; if cond {" return nil // restricted } } return stmt } // Statement theory // // These are all the places a statement may appear in the AST: // // LabeledStmt.Stmt Stmt -- any // BlockStmt.List []Stmt -- any (but see switch/select) // IfStmt.Init Stmt? -- simple // IfStmt.Body BlockStmt // IfStmt.Else Stmt? -- IfStmt or BlockStmt // CaseClause.Body []Stmt -- any // SwitchStmt.Init Stmt? -- simple // SwitchStmt.Body BlockStmt -- CaseClauses only // TypeSwitchStmt.Init Stmt? -- simple // TypeSwitchStmt.Assign Stmt -- AssignStmt(TypeAssertExpr) or ExprStmt(TypeAssertExpr) // TypeSwitchStmt.Body BlockStmt -- CaseClauses only // CommClause.Comm Stmt? -- SendStmt or ExprStmt(UnaryExpr) or AssignStmt(UnaryExpr) // CommClause.Body []Stmt -- any // SelectStmt.Body BlockStmt -- CommClauses only // ForStmt.Init Stmt? -- simple // ForStmt.Post Stmt? -- simple // ForStmt.Body BlockStmt // RangeStmt.Body BlockStmt // // simple = AssignStmt | SendStmt | IncDecStmt | ExprStmt. // // A BlockStmt cannot replace an ExprStmt in // {If,Switch,TypeSwitch}Stmt.Init or ForStmt.Post. // That is allowed only within: // LabeledStmt.Stmt Stmt // BlockStmt.List []Stmt // CaseClause.Body []Stmt // CommClause.Body []Stmt // replaceNode performs a destructive update of the tree rooted at // root, replacing each occurrence of "from" with "to". If to is nil and // the element is within a slice, the slice element is removed. // // The root itself cannot be replaced; an attempt will panic. 
//
// This function must not be called on the caller's syntax tree.
//
// TODO(adonovan): polish this up and move it to astutil package.
// TODO(adonovan): needs a unit test.
func replaceNode(root ast.Node, from, to ast.Node) {
	// Reject both an untyped nil and a typed nil pointer for from,
	// and refuse to replace the root itself.
	if from == nil {
		panic("from == nil")
	}
	if reflect.ValueOf(from).IsNil() {
		panic(fmt.Sprintf("from == (%T)(nil)", from))
	}
	if from == root {
		panic("from == root")
	}
	found := false
	var parent reflect.Value // parent variable of interface type, containing a pointer
	var visit func(reflect.Value)
	visit = func(v reflect.Value) {
		switch v.Kind() {
		case reflect.Ptr:
			if v.Interface() == from {
				found = true

				// If v is a struct field or array element
				// (e.g. Field.Comment or Field.Names[i])
				// then it is addressable (a pointer variable).
				//
				// But if it was the value of an interface
				// (e.g. *ast.Ident within ast.Node)
				// then it is non-addressable, and we need
				// to set the enclosing interface (parent).
				if !v.CanAddr() {
					v = parent
				}

				// to=nil => use zero value
				var toV reflect.Value
				if to != nil {
					toV = reflect.ValueOf(to)
				} else {
					toV = reflect.Zero(v.Type()) // e.g. ast.Expr(nil)
				}
				v.Set(toV)

			} else if !v.IsNil() {
				switch v.Interface().(type) {
				case *ast.Object, *ast.Scope:
					// Skip fields of types potentially involved in cycles.
				default:
					visit(v.Elem())
				}
			}

		case reflect.Struct:
			for i := 0; i < v.Type().NumField(); i++ {
				visit(v.Field(i))
			}

		case reflect.Slice:
			// Visit each element, noting whether any element is nil
			// afterwards so that nils can be squeezed out below.
			compact := false
			for i := 0; i < v.Len(); i++ {
				visit(v.Index(i))
				if v.Index(i).IsNil() {
					compact = true
				}
			}
			if compact {
				// Elements were deleted. Eliminate nils.
				// (Do this in a second pass to avoid
				// unnecessary writes in the common case.)
				j := 0
				for i := 0; i < v.Len(); i++ {
					if !v.Index(i).IsNil() {
						v.Index(j).Set(v.Index(i))
						j++
					}
				}
				v.SetLen(j)
			}

		case reflect.Interface:
			// Remember the interface variable so that a matching
			// pointer found directly inside it can be replaced
			// by assigning to the interface.
			parent = v
			visit(v.Elem())

		case reflect.Array, reflect.Chan, reflect.Func, reflect.Map, reflect.UnsafePointer:
			panic(v) // unreachable in AST

		default:
			// bool, string, number: nop
		}
		// parent is meaningful only for the value directly inside
		// an interface; reset it once any value has been visited.
		parent = reflect.Value{}
	}
	visit(reflect.ValueOf(root))
	if !found {
		panic(fmt.Sprintf("%T not found", from))
	}
}

// clearPositions destroys token.Pos information within the tree rooted at root,
// as positions in callee trees may cause caller comments to be emitted prematurely.
//
// In general it isn't safe to clear a valid Pos because some of them
// (e.g. CallExpr.Ellipsis, TypeSpec.Assign) are significant to
// go/printer, so this function sets each non-zero Pos to 1, which
// suffices to avoid advancing the printer's comment cursor.
//
// This function mutates its argument; do not invoke on caller syntax.
//
// TODO(adonovan): remove this horrendous workaround when #20744 is finally fixed.
func clearPositions(root ast.Node) {
	posType := reflect.TypeOf(token.NoPos)
	ast.Inspect(root, func(n ast.Node) bool {
		if n != nil {
			v := reflect.ValueOf(n).Elem() // deref the pointer to struct
			fields := v.Type().NumField()
			for i := 0; i < fields; i++ {
				f := v.Field(i)
				// Clearing Pos arbitrarily is destructive,
				// as its presence may be semantically significant
				// (e.g. CallExpr.Ellipsis, TypeSpec.Assign)
				// or affect formatting preferences (e.g. GenDecl.Lparen).
				//
				// Note: for proper formatting, it may be necessary to be selective
				// about which positions we set to 1 vs which we set to token.NoPos.
				// (e.g. we can set most to token.NoPos, save the few that are
				// significant).
				if f.Type() == posType {
					// Only fields of exactly type token.Pos are touched,
					// and only when they are currently valid.
					if f.Interface() != token.NoPos {
						f.Set(reflect.ValueOf(token.Pos(1)))
					}
				}
			}
		}
		return true
	})
}

// findIdent returns the Ident beneath root that has the given pos.
func findIdent(root ast.Node, pos token.Pos) *ast.Ident { // TODO(adonovan): opt: skip subtrees that don't contain pos. var found *ast.Ident ast.Inspect(root, func(n ast.Node) bool { if found != nil { return false } if id, ok := n.(*ast.Ident); ok { if id.Pos() == pos { found = id } } return true }) if found == nil { panic(fmt.Sprintf("findIdent %d not found in %s", pos, debugFormatNode(token.NewFileSet(), root))) } return found } func prepend[T any](elem T, slice ...T) []T { return append([]T{elem}, slice...) } // debugFormatNode formats a node or returns a formatting error. // Its sloppy treatment of errors is appropriate only for logging. func debugFormatNode(fset *token.FileSet, n ast.Node) string { var out strings.Builder if err := format.Node(&out, fset, n); err != nil { out.WriteString(err.Error()) } return out.String() } func shallowCopy[T any](ptr *T) *T { copy := *ptr return © } // ∀ func forall[T any](list []T, f func(i int, x T) bool) bool { for i, x := range list { if !f(i, x) { return false } } return true } // ∃ func exists[T any](list []T, f func(i int, x T) bool) bool { for i, x := range list { if f(i, x) { return true } } return false } // last returns the last element of a slice, or zero if empty. func last[T any](slice []T) T { n := len(slice) if n > 0 { return slice[n-1] } return *new(T) } // canImport reports whether one package is allowed to import another. // // TODO(adonovan): allow customization of the accessibility relation // (e.g. for Bazel). func canImport(from, to string) bool { // TODO(adonovan): better segment hygiene. if strings.HasPrefix(to, "internal/") { // Special case: only std packages may import internal/... // We can't reliably know whether we're in std, so we // use a heuristic on the first segment. 
first, _, _ := strings.Cut(from, "/") if strings.Contains(first, ".") { return false // example.com/foo ∉ std } if first == "testdata" { return false // testdata/foo ∉ std } } if i := strings.LastIndex(to, "/internal/"); i >= 0 { return strings.HasPrefix(from, to[:i]) } return true } // consistentOffsets reports whether the portion of caller.Content // that corresponds to caller.Call can be parsed as a call expression. // If not, the client has provided inconsistent information, possibly // because they forgot to ignore line directives when computing the // filename enclosing the call. // This is just a heuristic. func consistentOffsets(caller *Caller) bool { start := offsetOf(caller.Fset, caller.Call.Pos()) end := offsetOf(caller.Fset, caller.Call.End()) if !(0 < start && start < end && end <= len(caller.Content)) { return false } expr, err := parser.ParseExpr(string(caller.Content[start:end])) if err != nil { return false } return is[*ast.CallExpr](expr) } // needsParens reports whether parens are required to avoid ambiguity // around the new node replacing the specified old node (which is some // ancestor of the CallExpr identified by its PathEnclosingInterval). func needsParens(callPath []ast.Node, old, new ast.Node) bool { // Find enclosing old node and its parent. i := nodeIndex(callPath, old) if i == -1 { panic("not found") } // There is no precedence ambiguity when replacing // (e.g.) a statement enclosing the call. if !is[ast.Expr](old) { return false } // An expression beneath a non-expression // has no precedence ambiguity. parent, ok := callPath[i+1].(ast.Expr) if !ok { return false } precedence := func(n ast.Node) int { switch n := n.(type) { case *ast.UnaryExpr, *ast.StarExpr: return token.UnaryPrec case *ast.BinaryExpr: return n.Op.Precedence() } return -1 } // Parens are not required if the new node // is not unary or binary. 
newprec := precedence(new) if newprec < 0 { return false } // Parens are required if parent and child are both // unary or binary and the parent has higher precedence. if precedence(parent) > newprec { return true } // Was the old node the operand of a postfix operator? // f().sel // f()[i:j] // f()[i] // f().(T) // f()(x) switch parent := parent.(type) { case *ast.SelectorExpr: return parent.X == old case *ast.IndexExpr: return parent.X == old case *ast.SliceExpr: return parent.X == old case *ast.TypeAssertExpr: return parent.X == old case *ast.CallExpr: return parent.Fun == old } return false } func nodeIndex(nodes []ast.Node, n ast.Node) int { // TODO(adonovan): Use index[ast.Node]() in go1.20. for i, node := range nodes { if node == n { return i } } return -1 } // declares returns the set of lexical names declared by a // sequence of statements from the same block, excluding sub-blocks. // (Lexical names do not include control labels.) func declares(stmts []ast.Stmt) map[string]bool { names := make(map[string]bool) for _, stmt := range stmts { switch stmt := stmt.(type) { case *ast.DeclStmt: for _, spec := range stmt.Decl.(*ast.GenDecl).Specs { switch spec := spec.(type) { case *ast.ValueSpec: for _, id := range spec.Names { names[id.Name] = true } case *ast.TypeSpec: names[spec.Name.Name] = true } } case *ast.AssignStmt: if stmt.Tok == token.DEFINE { for _, lhs := range stmt.Lhs { names[lhs.(*ast.Ident).Name] = true } } } } delete(names, "_") return names } // assignStmts rewrites a statement assigning the results of a call into zero // or more statements that assign its return operands, or (nil, false) if no // such rewrite is possible. The set of bindings created by the result of // assignStmts is the same as the set of bindings created by the callerStmt. // // The callee must contain exactly one return statement. // // This is (once again) a surprisingly complex task. 
// For example, depending on
// types and existing bindings, the assignment
//
//	a, b := f()
//
// could be rewritten as:
//
//	a, b := 1, 2
//
// but may need to be written as:
//
//	a, b := int8(1), int32(2)
//
// In the case where the return statement within f is a spread call to another
// function g(), we cannot explicitly convert the return values inline, and so
// it may be necessary to split the declaration and assignment of variables
// into separate statements:
//
//	a, b := g()
//
// or
//
//	var a int32
//	a, b = g()
//
// or
//
//	var (
//		a int8
//		b int32
//	)
//	a, b = g()
//
// Note: assignStmts may return (nil, true) if it determines that the rewritten
// assignment consists only of _ = nil assignments.
func (st *state) assignStmts(callerStmt *ast.AssignStmt, returnOperands []ast.Expr) ([]ast.Stmt, bool) {
	logf, caller, callee := st.opts.Logf, st.caller, &st.callee.impl

	assert(len(callee.Returns) == 1, "unexpected multiple returns")
	resultInfo := callee.Returns[0]

	// When constructing assign statements, we need to make sure that we don't
	// modify types on the left-hand side, such as would happen if the type of a
	// RHS expression does not match the corresponding LHS type at the caller
	// (due to untyped conversion or interface widening).
	//
	// This turns out to be remarkably tricky to handle correctly.
	//
	// Substrategies below are labeled as `Substrategy "<name>":`.

	// Collect LHS information.
	var (
		lhs    []ast.Expr                                // shallow copy of the LHS slice, for mutation
		defs   = make([]*ast.Ident, len(callerStmt.Lhs)) // indexes in lhs of defining identifiers
		blanks = make([]bool, len(callerStmt.Lhs))       // indexes in lhs of blank identifiers
		byType typeutil.Map                              // map of distinct types -> indexes, for writing specs later
	)
	for i, expr := range callerStmt.Lhs {
		lhs = append(lhs, expr)
		if name, ok := expr.(*ast.Ident); ok {
			if name.Name == "_" {
				blanks[i] = true
				continue // no type
			}

			// Only identifiers recorded in Info.Defs are newly
			// declared by this statement; group them by type.
			if obj, isDef := caller.Info.Defs[name]; isDef {
				defs[i] = name
				typ := obj.Type()
				idxs, _ := byType.At(typ).([]int)
				idxs = append(idxs, i)
				byType.Set(typ, idxs)
			}
		}
	}

	// Collect RHS information
	//
	// The RHS is either a parallel assignment or spread assignment, but by
	// looping over both callerStmt.Rhs and returnOperands we handle both.
	var (
		rhs             []ast.Expr              // new RHS of assignment, owned by the inliner
		callIdx         = -1                    // index of the call among the original RHS
		nilBlankAssigns = make(map[int]unit)    // indexes in rhs of _ = nil assignments, which can be deleted
		freeNames       = make(map[string]bool) // free(ish) names among rhs expressions
		nonTrivial      = make(map[int]bool)    // indexes in rhs of nontrivial result conversions
	)
	for i, expr := range callerStmt.Rhs {
		if expr == caller.Call {
			assert(callIdx == -1, "malformed (duplicative) AST")
			callIdx = i
			// Substitute the callee's return operands for the call,
			// recording per-operand flags at their new rhs indexes.
			for j, returnOperand := range returnOperands {
				freeishNames(freeNames, returnOperand)
				rhs = append(rhs, returnOperand)
				if resultInfo[j]&nonTrivialResult != 0 {
					nonTrivial[i+j] = true
				}
				if blanks[i+j] && resultInfo[j]&untypedNilResult != 0 {
					nilBlankAssigns[i+j] = unit{}
				}
			}
		} else {
			// We must clone before clearing positions, since e came from the caller.
			expr = internalastutil.CloneNode(expr)
			clearPositions(expr)
			freeishNames(freeNames, expr)
			rhs = append(rhs, expr)
		}
	}
	assert(callIdx >= 0, "failed to find call in RHS")

	// Substrategy "splice": Check to see if we can simply splice in the result
	// expressions from the callee, such as simplifying
	//
	//	x, y := f()
	//
	// to
	//
	//	x, y := e1, e2
	//
	// where the types of x and y match the types of e1 and e2.
	//
	// This works as long as we don't need to write any additional type
	// information.
	if callerStmt.Tok == token.ASSIGN && // LHS types already determined before call
		len(nonTrivial) == 0 { // no non-trivial conversions to worry about

		logf("substrategy: slice assignment")
		return []ast.Stmt{&ast.AssignStmt{
			Lhs:    lhs,
			Tok:    callerStmt.Tok,
			TokPos: callerStmt.TokPos,
			Rhs:    rhs,
		}}, true
	}

	// Inlining techniques below will need to write type information in order to
	// preserve the correct types of LHS identifiers.
	//
	// typeExpr is a simple helper to write out type expressions. It returns
	// nil if the type cannot be written as a bare name that is unshadowed
	// and resolves, at the caller, to an identical type.
	// TODO(rfindley):
	//   1. handle qualified type names (potentially adding new imports)
	//   2. expand this to handle more type expressions.
	//   3. refactor to share logic with callee rewriting.
	universeAny := types.Universe.Lookup("any")
	typeExpr := func(typ types.Type, shadows ...map[string]bool) ast.Expr {
		var typeName string
		switch typ := typ.(type) {
		case *types.Basic:
			typeName = typ.Name()
		case interface{ Obj() *types.TypeName }: // Named, Alias, TypeParam
			typeName = typ.Obj().Name()
		}

		// Special case: check for universe "any".
		// TODO(golang/go#66921): this may become unnecessary if any becomes a proper alias.
		if typ == universeAny.Type() {
			typeName = "any"
		}

		if typeName == "" {
			return nil
		}

		for _, shadow := range shadows {
			if shadow[typeName] {
				logf("cannot write shadowed type name %q", typeName)
				return nil
			}
		}
		// The name is usable only if, looked up at the caller, it
		// denotes a type name whose type is identical to typ.
		obj, _ := caller.lookup(typeName).(*types.TypeName)
		if obj != nil && types.Identical(obj.Type(), typ) {
			return ast.NewIdent(typeName)
		}
		return nil
	}

	// Substrategy "spread": in the case of a spread call (func f() (T1, T2) return
	// g()), since we didn't hit the 'splice' substrategy, there must be some
	// non-declaring expression on the LHS. Simplify this by pre-declaring
	// variables, rewriting
	//
	//	x, y := f()
	//
	// to
	//
	//	var x int
	//	x, y = g()
	//
	// Which works as long as the predeclared variables do not overlap with free
	// names on the RHS.
	if len(rhs) != len(lhs) {
		assert(len(rhs) == 1 && len(returnOperands) == 1, "expected spread call")

		for _, id := range defs {
			if id != nil && freeNames[id.Name] {
				// By predeclaring variables, we're changing them to be in scope of the
				// RHS. We can't do this if their names are free on the RHS.
				return nil, false
			}
		}

		// Write out the specs, being careful to avoid shadowing free names in
		// their type expressions.
		var (
			specs    []ast.Spec
			specIdxs []int
			shadow   = make(map[string]bool)
		)
		failed := false
		byType.Iterate(func(typ types.Type, v any) {
			if failed {
				return
			}
			idxs := v.([]int)
			specIdxs = append(specIdxs, idxs[0])
			texpr := typeExpr(typ, shadow)
			if texpr == nil {
				failed = true
				return
			}
			// One ValueSpec per distinct type, naming every
			// variable defined with that type.
			spec := &ast.ValueSpec{
				Type: texpr,
			}
			for _, idx := range idxs {
				spec.Names = append(spec.Names, ast.NewIdent(defs[idx].Name))
			}
			specs = append(specs, spec)
		})
		if failed {
			return nil, false
		}
		logf("substrategy: spread assignment")
		return []ast.Stmt{
			&ast.DeclStmt{
				Decl: &ast.GenDecl{
					Tok:   token.VAR,
					Specs: specs,
				},
			},
			&ast.AssignStmt{
				Lhs: callerStmt.Lhs,
				Tok: token.ASSIGN,
				Rhs: returnOperands,
			},
		}, true
	}

	assert(len(lhs) == len(rhs), "mismatching LHS and RHS")

	// Substrategy "convert": write out RHS expressions with explicit type conversions
	// as necessary, rewriting
	//
	//	x, y := f()
	//
	// to
	//
	//	x, y := 1, int32(2)
	//
	// As required to preserve types.
	//
	// In the special case of _ = nil, which is disallowed by the type checker
	// (since nil has no default type), we delete the assignment.
	var origIdxs []int // maps back to original indexes after lhs and rhs are pruned
	i := 0
	for j := range lhs {
		if _, ok := nilBlankAssigns[j]; !ok {
			lhs[i] = lhs[j]
			rhs[i] = rhs[j]
			origIdxs = append(origIdxs, j)
			i++
		}
	}
	lhs = lhs[:i]
	rhs = rhs[:i]

	if len(lhs) == 0 {
		logf("trivial assignment after pruning nil blanks assigns")
		// After pruning, we have no remaining assignments.
		// Signal this by returning (nil, true): success with
		// no statements.
		return nil, true
	}

	// Write out explicit conversions as necessary.
	//
	// A conversion is necessary if the LHS is being defined, and the RHS return
	// involved a nontrivial implicit conversion.
	for i, expr := range rhs {
		idx := origIdxs[i]
		if nonTrivial[idx] && defs[idx] != nil {
			typ := caller.Info.TypeOf(lhs[i])
			texpr := typeExpr(typ)
			if texpr == nil {
				return nil, false
			}
			if _, ok := texpr.(*ast.StarExpr); ok {
				// TODO(rfindley): is this necessary? Doesn't the formatter add these parens?
				texpr = &ast.ParenExpr{X: texpr} // *T -> (*T) so that (*T)(x) is valid
			}
			rhs[i] = &ast.CallExpr{
				Fun:  texpr,
				Args: []ast.Expr{expr},
			}
		}
	}
	logf("substrategy: convert assignment")
	return []ast.Stmt{&ast.AssignStmt{
		Lhs: lhs,
		Tok: callerStmt.Tok,
		Rhs: rhs,
	}}, true
}

// tailCallSafeReturn reports whether the callee's return statements may be safely
// used to return from the function enclosing the caller (which must exist).
func tailCallSafeReturn(caller *Caller, calleeSymbol *types.Func, callee *gobCallee) bool {
	// It is safe if all callee returns involve only trivial conversions.
	if !hasNonTrivialReturn(callee.Returns) {
		return true
	}

	var callerType types.Type
	// Find type of innermost function enclosing call.
	// (Beware: Caller.enclosingFunc is the outermost.)
loop:
	for _, n := range caller.path {
		switch f := n.(type) {
		case *ast.FuncDecl:
			callerType = caller.Info.ObjectOf(f.Name).Type()
			break loop
		case *ast.FuncLit:
			callerType = caller.Info.TypeOf(f)
			break loop
		}
	}

	// Non-trivial return conversions in the callee are permitted
	// if the same non-trivial conversion would occur after inlining,
	// i.e. if the caller and callee results tuples are identical.
	// (The type assertion below panics if no enclosing function was
	// found, since callerType is then nil.)
	callerResults := callerType.(*types.Signature).Results()
	calleeResults := calleeSymbol.Type().(*types.Signature).Results()
	return types.Identical(callerResults, calleeResults)
}

// hasNonTrivialReturn reports whether any of the returns involve a nontrivial
// implicit conversion of a result expression.
func hasNonTrivialReturn(returnInfo [][]returnOperandFlags) bool {
	for _, resultInfo := range returnInfo {
		for _, r := range resultInfo {
			if r&nonTrivialResult != 0 {
				return true
			}
		}
	}
	return false
}

type unit struct{} // for representing sets as maps