diff --git a/src/cmd/compile/internal/noder/irgen.go b/src/cmd/compile/internal/noder/irgen.go index aac8b5e641..9d14b06d3c 100644 --- a/src/cmd/compile/internal/noder/irgen.go +++ b/src/cmd/compile/internal/noder/irgen.go @@ -191,7 +191,7 @@ Outer: // Double check for any type-checking inconsistencies. This can be // removed once we're confident in IR generation results. - syntax.Walk(p.file, func(n syntax.Node) bool { + syntax.Crawl(p.file, func(n syntax.Node) bool { g.validate(n) return false }) diff --git a/src/cmd/compile/internal/noder/quirks.go b/src/cmd/compile/internal/noder/quirks.go index 91b4c22025..914c5d2bd7 100644 --- a/src/cmd/compile/internal/noder/quirks.go +++ b/src/cmd/compile/internal/noder/quirks.go @@ -36,7 +36,7 @@ func posBasesOf(noders []*noder) []*syntax.PosBase { var bases []*syntax.PosBase for _, p := range noders { - syntax.Walk(p.file, func(n syntax.Node) bool { + syntax.Crawl(p.file, func(n syntax.Node) bool { if b := n.Pos().Base(); !seen[b] { bases = append(bases, b) seen[b] = true @@ -74,7 +74,7 @@ func importedObjsOf(curpkg *types2.Package, info *types2.Info, noders []*noder) } for _, p := range noders { - syntax.Walk(p.file, func(n syntax.Node) bool { + syntax.Crawl(p.file, func(n syntax.Node) bool { switch n := n.(type) { case *syntax.ConstDecl: assoc(n, n.NameList...) @@ -167,7 +167,7 @@ func importedObjsOf(curpkg *types2.Package, info *types2.Info, noders []*noder) if n == nil { return } - syntax.Walk(n, func(n syntax.Node) bool { + syntax.Crawl(n, func(n syntax.Node) bool { switch n := n.(type) { case *syntax.Name: checkdef(n) @@ -237,7 +237,7 @@ func importedObjsOf(curpkg *types2.Package, info *types2.Info, noders []*noder) } if phase >= 5 { - syntax.Walk(p.file, func(n syntax.Node) bool { + syntax.Crawl(p.file, func(n syntax.Node) bool { if name, ok := n.(*syntax.Name); ok { if obj, ok := info.Uses[name]; ok { resolveObj(name.Pos(), obj) diff --git a/src/cmd/compile/internal/noder/writer.go b/src/cmd/compile/internal/noder/writer.go index cc749b0d1e..bc89e1a262 100644 --- a/src/cmd/compile/internal/noder/writer.go +++ b/src/cmd/compile/internal/noder/writer.go @@ -1318,7 +1318,7 @@ func (w *writer) captureVars(expr *syntax.FuncLit) (closureVars []posObj, locals // function literal as the position of the intermediary capture. if quirksMode() && rbracePos == (syntax.Pos{}) { rbracePos = n.Body.Rbrace - syntax.Walk(n.Body, visitor) + syntax.Crawl(n.Body, visitor) rbracePos = syntax.Pos{} return true } @@ -1327,17 +1327,17 @@ func (w *writer) captureVars(expr *syntax.FuncLit) (closureVars []posObj, locals // Quirk: typecheck visits (and thus captures) the RHS of // assignment statements before the LHS. if quirksMode() && (n.Op == 0 || n.Op == syntax.Def) { - syntax.Walk(n.Rhs, visitor) - syntax.Walk(n.Lhs, visitor) + syntax.Crawl(n.Rhs, visitor) + syntax.Crawl(n.Lhs, visitor) return true } case *syntax.RangeClause: // Quirk: Similarly, it visits the expression to be iterated // over before the iteration variables. if quirksMode() { - syntax.Walk(n.X, visitor) + syntax.Crawl(n.X, visitor) if n.Lhs != nil { - syntax.Walk(n.Lhs, visitor) + syntax.Crawl(n.Lhs, visitor) } return true } @@ -1345,7 +1345,7 @@ func (w *writer) captureVars(expr *syntax.FuncLit) (closureVars []posObj, locals return false } - syntax.Walk(expr.Body, visitor) + syntax.Crawl(expr.Body, visitor) return } @@ -1392,7 +1392,7 @@ func (pw *pkgWriter) collectDecls(noders []*noder) { for _, p := range noders { var importedEmbed, importedUnsafe bool - syntax.Walk(p.file, func(n syntax.Node) bool { + syntax.Crawl(p.file, func(n syntax.Node) bool { switch n := n.(type) { case *syntax.File: pw.checkPragmas(n.Pragma, ir.GoBuildPragma, false) diff --git a/src/cmd/compile/internal/syntax/walk.go b/src/cmd/compile/internal/syntax/walk.go index c26e97a0d8..ef213daf7d 100644 --- a/src/cmd/compile/internal/syntax/walk.go +++ b/src/cmd/compile/internal/syntax/walk.go @@ -8,31 +8,73 @@ package syntax import "fmt" -// Walk traverses a syntax in pre-order: It starts by calling f(root); -// root must not be nil. If f returns false (== "continue"), Walk calls +// Inspect traverses an AST in pre-order: It starts by calling +// f(node); node must not be nil. If f returns true, Inspect invokes f +// recursively for each of the non-nil children of node, followed by a +// call of f(nil). +// +// See Walk for caveats about shared nodes. +func Inspect(root Node, f func(Node) bool) { + Walk(root, inspector(f)) +} + +type inspector func(Node) bool + +func (v inspector) Visit(node Node) Visitor { + if v(node) { + return v + } + return nil +} + +// Crawl traverses a syntax in pre-order: It starts by calling f(root); +// root must not be nil. If f returns false (== "continue"), Crawl calls // f recursively for each of the non-nil children of that node; if f -// returns true (== "stop"), Walk does not traverse the respective node's +// returns true (== "stop"), Crawl does not traverse the respective node's // children. +// +// See Walk for caveats about shared nodes. +// +// Deprecated: Use Inspect instead. +func Crawl(root Node, f func(Node) bool) { + Inspect(root, func(node Node) bool { + return node != nil && !f(node) + }) +} + +// Walk traverses an AST in pre-order: It starts by calling +// v.Visit(node); node must not be nil. If the visitor w returned by +// v.Visit(node) is not nil, Walk is invoked recursively with visitor +// w for each of the non-nil children of node, followed by a call of +// w.Visit(nil). +// // Some nodes may be shared among multiple parent nodes (e.g., types in // field lists such as type T in "a, b, c T"). Such shared nodes are // walked multiple times. // TODO(gri) Revisit this design. It may make sense to walk those nodes // only once. A place where this matters is types2.TestResolveIdents. -func Walk(root Node, f func(Node) bool) { - w := walker{f} - w.node(root) +func Walk(root Node, v Visitor) { + walker{v}.node(root) +} + +// A Visitor's Visit method is invoked for each node encountered by Walk. +// If the result visitor w is not nil, Walk visits each of the children +// of node with the visitor w, followed by a call of w.Visit(nil). +type Visitor interface { + Visit(node Node) (w Visitor) } type walker struct { - f func(Node) bool + v Visitor } -func (w *walker) node(n Node) { +func (w walker) node(n Node) { if n == nil { panic("invalid syntax tree: nil node") } - if w.f(n) { + w.v = w.v.Visit(n) + if w.v == nil { return } @@ -285,33 +327,35 @@ func (w *walker) node(n Node) { default: panic(fmt.Sprintf("internal error: unknown node type %T", n)) } + + w.v.Visit(nil) } -func (w *walker) declList(list []Decl) { +func (w walker) declList(list []Decl) { for _, n := range list { w.node(n) } } -func (w *walker) exprList(list []Expr) { +func (w walker) exprList(list []Expr) { for _, n := range list { w.node(n) } } -func (w *walker) stmtList(list []Stmt) { +func (w walker) stmtList(list []Stmt) { for _, n := range list { w.node(n) } } -func (w *walker) nameList(list []*Name) { +func (w walker) nameList(list []*Name) { for _, n := range list { w.node(n) } } -func (w *walker) fieldList(list []*Field) { +func (w walker) fieldList(list []*Field) { for _, n := range list { w.node(n) } diff --git a/src/cmd/compile/internal/types2/errorcalls_test.go b/src/cmd/compile/internal/types2/errorcalls_test.go index 28bb33aaff..80b05f9f0f 100644 --- a/src/cmd/compile/internal/types2/errorcalls_test.go +++ b/src/cmd/compile/internal/types2/errorcalls_test.go @@ -18,7 +18,7 @@ func TestErrorCalls(t *testing.T) { } for _, file := range files { - syntax.Walk(file, func(n syntax.Node) bool { + syntax.Crawl(file, func(n syntax.Node) bool { call, _ := n.(*syntax.CallExpr) if call == nil { return false diff --git a/src/cmd/compile/internal/types2/issues_test.go b/src/cmd/compile/internal/types2/issues_test.go index e716a48038..aafe8de367 100644 --- a/src/cmd/compile/internal/types2/issues_test.go +++ b/src/cmd/compile/internal/types2/issues_test.go @@ -321,7 +321,7 @@ func TestIssue25627(t *testing.T) { } } - syntax.Walk(f, func(n syntax.Node) bool { + syntax.Crawl(f, func(n syntax.Node) bool { if decl, _ := n.(*syntax.TypeDecl); decl != nil { if tv, ok := info.Types[decl.Type]; ok && decl.Name.Value == "T" { want := strings.Count(src, ";") + 1 diff --git a/src/cmd/compile/internal/types2/resolver_test.go b/src/cmd/compile/internal/types2/resolver_test.go index aee435ff5f..a02abce081 100644 --- a/src/cmd/compile/internal/types2/resolver_test.go +++ b/src/cmd/compile/internal/types2/resolver_test.go @@ -143,7 +143,7 @@ func TestResolveIdents(t *testing.T) { // check that qualified identifiers are resolved for _, f := range files { - syntax.Walk(f, func(n syntax.Node) bool { + syntax.Crawl(f, func(n syntax.Node) bool { if s, ok := n.(*syntax.SelectorExpr); ok { if x, ok := s.X.(*syntax.Name); ok { obj := uses[x] @@ -177,7 +177,7 @@ func TestResolveIdents(t *testing.T) { foundDefs := make(map[*syntax.Name]bool) var both []string for _, f := range files { - syntax.Walk(f, func(n syntax.Node) bool { + syntax.Crawl(f, func(n syntax.Node) bool { if x, ok := n.(*syntax.Name); ok { var objects int if _, found := uses[x]; found {