[dev.typeparams] cmd/compile/internal/syntax: go/ast-style walk API

This CL adds go/ast's Visitor, Walk, and Inspect functions to package
syntax. Having functions with the same API and semantics as their
go/ast counterparts reduces the mental load of context switching
between go/ast and syntax.

It also renames the existing Walk function into Crawl, and marks it as
a deprecated wrapper around Inspect. (I named it "Crawl" because it's
less functional than "Walk"... get it??)

There aren't that many callers to Crawl, so we can probably remove it
in the future. But it doesn't seem pressing, and I'm more concerned
about the risk of forgetting to invert a bool condition somewhere.

Change-Id: Ib2fb275873a1d1a730249c9cb584864cb6ec370e
Reviewed-on: https://go-review.googlesource.com/c/go/+/330429
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
Trust: Matthew Dempsky <mdempsky@google.com>
Trust: Robert Griesemer <gri@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
This commit is contained in:
Matthew Dempsky 2021-06-23 12:08:42 -07:00
parent a72a499c24
commit 8165256bc2
7 changed files with 74 additions and 30 deletions

View file

@ -191,7 +191,7 @@ Outer:
// Double check for any type-checking inconsistencies. This can be // Double check for any type-checking inconsistencies. This can be
// removed once we're confident in IR generation results. // removed once we're confident in IR generation results.
syntax.Walk(p.file, func(n syntax.Node) bool { syntax.Crawl(p.file, func(n syntax.Node) bool {
g.validate(n) g.validate(n)
return false return false
}) })

View file

@ -36,7 +36,7 @@ func posBasesOf(noders []*noder) []*syntax.PosBase {
var bases []*syntax.PosBase var bases []*syntax.PosBase
for _, p := range noders { for _, p := range noders {
syntax.Walk(p.file, func(n syntax.Node) bool { syntax.Crawl(p.file, func(n syntax.Node) bool {
if b := n.Pos().Base(); !seen[b] { if b := n.Pos().Base(); !seen[b] {
bases = append(bases, b) bases = append(bases, b)
seen[b] = true seen[b] = true
@ -74,7 +74,7 @@ func importedObjsOf(curpkg *types2.Package, info *types2.Info, noders []*noder)
} }
for _, p := range noders { for _, p := range noders {
syntax.Walk(p.file, func(n syntax.Node) bool { syntax.Crawl(p.file, func(n syntax.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *syntax.ConstDecl: case *syntax.ConstDecl:
assoc(n, n.NameList...) assoc(n, n.NameList...)
@ -167,7 +167,7 @@ func importedObjsOf(curpkg *types2.Package, info *types2.Info, noders []*noder)
if n == nil { if n == nil {
return return
} }
syntax.Walk(n, func(n syntax.Node) bool { syntax.Crawl(n, func(n syntax.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *syntax.Name: case *syntax.Name:
checkdef(n) checkdef(n)
@ -237,7 +237,7 @@ func importedObjsOf(curpkg *types2.Package, info *types2.Info, noders []*noder)
} }
if phase >= 5 { if phase >= 5 {
syntax.Walk(p.file, func(n syntax.Node) bool { syntax.Crawl(p.file, func(n syntax.Node) bool {
if name, ok := n.(*syntax.Name); ok { if name, ok := n.(*syntax.Name); ok {
if obj, ok := info.Uses[name]; ok { if obj, ok := info.Uses[name]; ok {
resolveObj(name.Pos(), obj) resolveObj(name.Pos(), obj)

View file

@ -1318,7 +1318,7 @@ func (w *writer) captureVars(expr *syntax.FuncLit) (closureVars []posObj, locals
// function literal as the position of the intermediary capture. // function literal as the position of the intermediary capture.
if quirksMode() && rbracePos == (syntax.Pos{}) { if quirksMode() && rbracePos == (syntax.Pos{}) {
rbracePos = n.Body.Rbrace rbracePos = n.Body.Rbrace
syntax.Walk(n.Body, visitor) syntax.Crawl(n.Body, visitor)
rbracePos = syntax.Pos{} rbracePos = syntax.Pos{}
return true return true
} }
@ -1327,17 +1327,17 @@ func (w *writer) captureVars(expr *syntax.FuncLit) (closureVars []posObj, locals
// Quirk: typecheck visits (and thus captures) the RHS of // Quirk: typecheck visits (and thus captures) the RHS of
// assignment statements before the LHS. // assignment statements before the LHS.
if quirksMode() && (n.Op == 0 || n.Op == syntax.Def) { if quirksMode() && (n.Op == 0 || n.Op == syntax.Def) {
syntax.Walk(n.Rhs, visitor) syntax.Crawl(n.Rhs, visitor)
syntax.Walk(n.Lhs, visitor) syntax.Crawl(n.Lhs, visitor)
return true return true
} }
case *syntax.RangeClause: case *syntax.RangeClause:
// Quirk: Similarly, it visits the expression to be iterated // Quirk: Similarly, it visits the expression to be iterated
// over before the iteration variables. // over before the iteration variables.
if quirksMode() { if quirksMode() {
syntax.Walk(n.X, visitor) syntax.Crawl(n.X, visitor)
if n.Lhs != nil { if n.Lhs != nil {
syntax.Walk(n.Lhs, visitor) syntax.Crawl(n.Lhs, visitor)
} }
return true return true
} }
@ -1345,7 +1345,7 @@ func (w *writer) captureVars(expr *syntax.FuncLit) (closureVars []posObj, locals
return false return false
} }
syntax.Walk(expr.Body, visitor) syntax.Crawl(expr.Body, visitor)
return return
} }
@ -1392,7 +1392,7 @@ func (pw *pkgWriter) collectDecls(noders []*noder) {
for _, p := range noders { for _, p := range noders {
var importedEmbed, importedUnsafe bool var importedEmbed, importedUnsafe bool
syntax.Walk(p.file, func(n syntax.Node) bool { syntax.Crawl(p.file, func(n syntax.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *syntax.File: case *syntax.File:
pw.checkPragmas(n.Pragma, ir.GoBuildPragma, false) pw.checkPragmas(n.Pragma, ir.GoBuildPragma, false)

View file

@ -8,31 +8,73 @@ package syntax
import "fmt" import "fmt"
// Walk traverses a syntax in pre-order: It starts by calling f(root); // Inspect traverses an AST in pre-order: It starts by calling
// root must not be nil. If f returns false (== "continue"), Walk calls // f(node); node must not be nil. If f returns true, Inspect invokes f
// recursively for each of the non-nil children of node, followed by a
// call of f(nil).
//
// See Walk for caveats about shared nodes.
func Inspect(root Node, f func(Node) bool) {
Walk(root, inspector(f))
}
type inspector func(Node) bool
func (v inspector) Visit(node Node) Visitor {
if v(node) {
return v
}
return nil
}
// Crawl traverses a syntax in pre-order: It starts by calling f(root);
// root must not be nil. If f returns false (== "continue"), Crawl calls
// f recursively for each of the non-nil children of that node; if f // f recursively for each of the non-nil children of that node; if f
// returns true (== "stop"), Walk does not traverse the respective node's // returns true (== "stop"), Crawl does not traverse the respective node's
// children. // children.
//
// See Walk for caveats about shared nodes.
//
// Deprecated: Use Inspect instead.
func Crawl(root Node, f func(Node) bool) {
Inspect(root, func(node Node) bool {
return node != nil && !f(node)
})
}
// Walk traverses an AST in pre-order: It starts by calling
// v.Visit(node); node must not be nil. If the visitor w returned by
// v.Visit(node) is not nil, Walk is invoked recursively with visitor
// w for each of the non-nil children of node, followed by a call of
// w.Visit(nil).
//
// Some nodes may be shared among multiple parent nodes (e.g., types in // Some nodes may be shared among multiple parent nodes (e.g., types in
// field lists such as type T in "a, b, c T"). Such shared nodes are // field lists such as type T in "a, b, c T"). Such shared nodes are
// walked multiple times. // walked multiple times.
// TODO(gri) Revisit this design. It may make sense to walk those nodes // TODO(gri) Revisit this design. It may make sense to walk those nodes
// only once. A place where this matters is types2.TestResolveIdents. // only once. A place where this matters is types2.TestResolveIdents.
func Walk(root Node, f func(Node) bool) { func Walk(root Node, v Visitor) {
w := walker{f} walker{v}.node(root)
w.node(root) }
// A Visitor's Visit method is invoked for each node encountered by Walk.
// If the result visitor w is not nil, Walk visits each of the children
// of node with the visitor w, followed by a call of w.Visit(nil).
type Visitor interface {
Visit(node Node) (w Visitor)
} }
type walker struct { type walker struct {
f func(Node) bool v Visitor
} }
func (w *walker) node(n Node) { func (w walker) node(n Node) {
if n == nil { if n == nil {
panic("invalid syntax tree: nil node") panic("invalid syntax tree: nil node")
} }
if w.f(n) { w.v = w.v.Visit(n)
if w.v == nil {
return return
} }
@ -285,33 +327,35 @@ func (w *walker) node(n Node) {
default: default:
panic(fmt.Sprintf("internal error: unknown node type %T", n)) panic(fmt.Sprintf("internal error: unknown node type %T", n))
} }
w.v.Visit(nil)
} }
func (w *walker) declList(list []Decl) { func (w walker) declList(list []Decl) {
for _, n := range list { for _, n := range list {
w.node(n) w.node(n)
} }
} }
func (w *walker) exprList(list []Expr) { func (w walker) exprList(list []Expr) {
for _, n := range list { for _, n := range list {
w.node(n) w.node(n)
} }
} }
func (w *walker) stmtList(list []Stmt) { func (w walker) stmtList(list []Stmt) {
for _, n := range list { for _, n := range list {
w.node(n) w.node(n)
} }
} }
func (w *walker) nameList(list []*Name) { func (w walker) nameList(list []*Name) {
for _, n := range list { for _, n := range list {
w.node(n) w.node(n)
} }
} }
func (w *walker) fieldList(list []*Field) { func (w walker) fieldList(list []*Field) {
for _, n := range list { for _, n := range list {
w.node(n) w.node(n)
} }

View file

@ -18,7 +18,7 @@ func TestErrorCalls(t *testing.T) {
} }
for _, file := range files { for _, file := range files {
syntax.Walk(file, func(n syntax.Node) bool { syntax.Crawl(file, func(n syntax.Node) bool {
call, _ := n.(*syntax.CallExpr) call, _ := n.(*syntax.CallExpr)
if call == nil { if call == nil {
return false return false

View file

@ -321,7 +321,7 @@ func TestIssue25627(t *testing.T) {
} }
} }
syntax.Walk(f, func(n syntax.Node) bool { syntax.Crawl(f, func(n syntax.Node) bool {
if decl, _ := n.(*syntax.TypeDecl); decl != nil { if decl, _ := n.(*syntax.TypeDecl); decl != nil {
if tv, ok := info.Types[decl.Type]; ok && decl.Name.Value == "T" { if tv, ok := info.Types[decl.Type]; ok && decl.Name.Value == "T" {
want := strings.Count(src, ";") + 1 want := strings.Count(src, ";") + 1

View file

@ -143,7 +143,7 @@ func TestResolveIdents(t *testing.T) {
// check that qualified identifiers are resolved // check that qualified identifiers are resolved
for _, f := range files { for _, f := range files {
syntax.Walk(f, func(n syntax.Node) bool { syntax.Crawl(f, func(n syntax.Node) bool {
if s, ok := n.(*syntax.SelectorExpr); ok { if s, ok := n.(*syntax.SelectorExpr); ok {
if x, ok := s.X.(*syntax.Name); ok { if x, ok := s.X.(*syntax.Name); ok {
obj := uses[x] obj := uses[x]
@ -177,7 +177,7 @@ func TestResolveIdents(t *testing.T) {
foundDefs := make(map[*syntax.Name]bool) foundDefs := make(map[*syntax.Name]bool)
var both []string var both []string
for _, f := range files { for _, f := range files {
syntax.Walk(f, func(n syntax.Node) bool { syntax.Crawl(f, func(n syntax.Node) bool {
if x, ok := n.(*syntax.Name); ok { if x, ok := n.(*syntax.Name); ok {
var objects int var objects int
if _, found := uses[x]; found { if _, found := uses[x]; found {