[dev.typeparams] go/*: switch from ListExpr to MultiIndexExpr

When instantiating a generic type or function with multiple type
arguments, we need to represent an index expression with multiple
indexes in the AST. Previous to this CL this was done with a new
ast.ListExpr node, which allowed packing multiple expressions into a
single ast.Expr. This compositional pattern can be both inefficient and
cumbersome to work with, and introduces a new node type that only exists
to augment the meaning of an existing node type.

By comparison, other specializations of syntax are given distinct nodes
in go/ast, for example variations of switch or for statements, so the
use of ListExpr was also (arguably) inconsistent.

This CL removes ListExpr, and instead adds a MultiIndexExpr node, which
is exactly like IndexExpr but allows for multiple index arguments. This
requires special handling for this new node type, but a new wrapper in
the typeparams helper package largely mitigates this special handling.

Change-Id: I65eb29c025c599bae37501716284dc7eb953b2ad
Reviewed-on: https://go-review.googlesource.com/c/go/+/327149
Trust: Robert Findley <rfindley@google.com>
Reviewed-by: Robert Griesemer <gri@golang.org>
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
This commit is contained in:
Rob Findley 2021-06-11 10:58:43 -04:00 committed by Robert Findley
parent 6b85a218b8
commit 334f2fc045
14 changed files with 170 additions and 167 deletions

View file

@ -344,6 +344,15 @@ type (
Rbrack token.Pos // position of "]" Rbrack token.Pos // position of "]"
} }
// A MultiIndexExpr node represents an expression followed by multiple
// indices.
MultiIndexExpr struct {
X Expr // expression
Lbrack token.Pos // position of "["
Indices []Expr // index expressions
Rbrack token.Pos // position of "]"
}
// A SliceExpr node represents an expression followed by slice indices. // A SliceExpr node represents an expression followed by slice indices.
SliceExpr struct { SliceExpr struct {
X Expr // expression X Expr // expression
@ -374,13 +383,6 @@ type (
Rparen token.Pos // position of ")" Rparen token.Pos // position of ")"
} }
// A ListExpr node represents a list of expressions separated by commas.
// ListExpr nodes are used as index in IndexExpr nodes representing type
// or function instantiations with more than one type argument.
ListExpr struct {
ElemList []Expr
}
// A StarExpr node represents an expression of the form "*" Expression. // A StarExpr node represents an expression of the form "*" Expression.
// Semantically it could be a unary "*" expression, or a pointer type. // Semantically it could be a unary "*" expression, or a pointer type.
// //
@ -494,15 +496,10 @@ func (x *CompositeLit) Pos() token.Pos {
func (x *ParenExpr) Pos() token.Pos { return x.Lparen } func (x *ParenExpr) Pos() token.Pos { return x.Lparen }
func (x *SelectorExpr) Pos() token.Pos { return x.X.Pos() } func (x *SelectorExpr) Pos() token.Pos { return x.X.Pos() }
func (x *IndexExpr) Pos() token.Pos { return x.X.Pos() } func (x *IndexExpr) Pos() token.Pos { return x.X.Pos() }
func (x *MultiIndexExpr) Pos() token.Pos { return x.X.Pos() }
func (x *SliceExpr) Pos() token.Pos { return x.X.Pos() } func (x *SliceExpr) Pos() token.Pos { return x.X.Pos() }
func (x *TypeAssertExpr) Pos() token.Pos { return x.X.Pos() } func (x *TypeAssertExpr) Pos() token.Pos { return x.X.Pos() }
func (x *CallExpr) Pos() token.Pos { return x.Fun.Pos() } func (x *CallExpr) Pos() token.Pos { return x.Fun.Pos() }
func (x *ListExpr) Pos() token.Pos {
if len(x.ElemList) > 0 {
return x.ElemList[0].Pos()
}
return token.NoPos
}
func (x *StarExpr) Pos() token.Pos { return x.Star } func (x *StarExpr) Pos() token.Pos { return x.Star }
func (x *UnaryExpr) Pos() token.Pos { return x.OpPos } func (x *UnaryExpr) Pos() token.Pos { return x.OpPos }
func (x *BinaryExpr) Pos() token.Pos { return x.X.Pos() } func (x *BinaryExpr) Pos() token.Pos { return x.X.Pos() }
@ -533,15 +530,10 @@ func (x *CompositeLit) End() token.Pos { return x.Rbrace + 1 }
func (x *ParenExpr) End() token.Pos { return x.Rparen + 1 } func (x *ParenExpr) End() token.Pos { return x.Rparen + 1 }
func (x *SelectorExpr) End() token.Pos { return x.Sel.End() } func (x *SelectorExpr) End() token.Pos { return x.Sel.End() }
func (x *IndexExpr) End() token.Pos { return x.Rbrack + 1 } func (x *IndexExpr) End() token.Pos { return x.Rbrack + 1 }
func (x *MultiIndexExpr) End() token.Pos { return x.Rbrack + 1 }
func (x *SliceExpr) End() token.Pos { return x.Rbrack + 1 } func (x *SliceExpr) End() token.Pos { return x.Rbrack + 1 }
func (x *TypeAssertExpr) End() token.Pos { return x.Rparen + 1 } func (x *TypeAssertExpr) End() token.Pos { return x.Rparen + 1 }
func (x *CallExpr) End() token.Pos { return x.Rparen + 1 } func (x *CallExpr) End() token.Pos { return x.Rparen + 1 }
func (x *ListExpr) End() token.Pos {
if len(x.ElemList) > 0 {
return x.ElemList[len(x.ElemList)-1].End()
}
return token.NoPos
}
func (x *StarExpr) End() token.Pos { return x.X.End() } func (x *StarExpr) End() token.Pos { return x.X.End() }
func (x *UnaryExpr) End() token.Pos { return x.X.End() } func (x *UnaryExpr) End() token.Pos { return x.X.End() }
func (x *BinaryExpr) End() token.Pos { return x.Y.End() } func (x *BinaryExpr) End() token.Pos { return x.Y.End() }
@ -570,10 +562,10 @@ func (*CompositeLit) exprNode() {}
func (*ParenExpr) exprNode() {} func (*ParenExpr) exprNode() {}
func (*SelectorExpr) exprNode() {} func (*SelectorExpr) exprNode() {}
func (*IndexExpr) exprNode() {} func (*IndexExpr) exprNode() {}
func (*MultiIndexExpr) exprNode() {}
func (*SliceExpr) exprNode() {} func (*SliceExpr) exprNode() {}
func (*TypeAssertExpr) exprNode() {} func (*TypeAssertExpr) exprNode() {}
func (*CallExpr) exprNode() {} func (*CallExpr) exprNode() {}
func (*ListExpr) exprNode() {}
func (*StarExpr) exprNode() {} func (*StarExpr) exprNode() {}
func (*UnaryExpr) exprNode() {} func (*UnaryExpr) exprNode() {}
func (*BinaryExpr) exprNode() {} func (*BinaryExpr) exprNode() {}

View file

@ -116,6 +116,12 @@ func Walk(v Visitor, node Node) {
Walk(v, n.X) Walk(v, n.X)
Walk(v, n.Index) Walk(v, n.Index)
case *MultiIndexExpr:
Walk(v, n.X)
for _, index := range n.Indices {
Walk(v, index)
}
case *SliceExpr: case *SliceExpr:
Walk(v, n.X) Walk(v, n.X)
if n.Low != nil { if n.Low != nil {
@ -138,11 +144,6 @@ func Walk(v Visitor, node Node) {
Walk(v, n.Fun) Walk(v, n.Fun)
walkExprList(v, n.Args) walkExprList(v, n.Args)
case *ListExpr:
for _, elem := range n.ElemList {
Walk(v, elem)
}
case *StarExpr: case *StarExpr:
Walk(v, n.X) Walk(v, n.X)

View file

@ -7,42 +7,56 @@ package typeparams
import ( import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/token"
) )
const Enabled = true const Enabled = true
func PackExpr(list []ast.Expr) ast.Expr { func PackIndexExpr(x ast.Expr, lbrack token.Pos, exprs []ast.Expr, rbrack token.Pos) ast.Expr {
switch len(list) { switch len(exprs) {
case 0: case 0:
// Return an empty ListExpr here, rather than nil, as IndexExpr.Index must panic("internal error: PackIndexExpr with empty expr slice")
// never be nil.
// TODO(rFindley) would a BadExpr be more appropriate here?
return &ast.ListExpr{}
case 1: case 1:
return list[0] return &ast.IndexExpr{
X: x,
Lbrack: lbrack,
Index: exprs[0],
Rbrack: rbrack,
}
default: default:
return &ast.ListExpr{ElemList: list} return &ast.MultiIndexExpr{
X: x,
Lbrack: lbrack,
Indices: exprs,
Rbrack: rbrack,
}
} }
} }
// TODO(gri) Should find a more efficient solution that doesn't // IndexExpr wraps an ast.IndexExpr or ast.MultiIndexExpr into the
// require introduction of a new slice for simple // MultiIndexExpr interface.
// expressions. //
func UnpackExpr(x ast.Expr) []ast.Expr { // Orig holds the original ast.Expr from which this IndexExpr was derived.
if x, _ := x.(*ast.ListExpr); x != nil { type IndexExpr struct {
return x.ElemList Orig ast.Expr // the wrapped expr, which may be distinct from MultiIndexExpr below.
*ast.MultiIndexExpr
} }
if x != nil {
return []ast.Expr{x} func UnpackIndexExpr(n ast.Node) *IndexExpr {
switch e := n.(type) {
case *ast.IndexExpr:
return &IndexExpr{e, &ast.MultiIndexExpr{
X: e.X,
Lbrack: e.Lbrack,
Indices: []ast.Expr{e.Index},
Rbrack: e.Rbrack,
}}
case *ast.MultiIndexExpr:
return &IndexExpr{e, e}
} }
return nil return nil
} }
func IsListExpr(n ast.Node) bool {
_, ok := n.(*ast.ListExpr)
return ok
}
func Get(n ast.Node) *ast.FieldList { func Get(n ast.Node) *ast.FieldList {
switch n := n.(type) { switch n := n.(type) {
case *ast.TypeSpec: case *ast.TypeSpec:

View file

@ -600,7 +600,7 @@ func (p *parser) parseArrayFieldOrTypeInstance(x *ast.Ident) (*ast.Ident, ast.Ex
} }
// x[P], x[P1, P2], ... // x[P], x[P1, P2], ...
return nil, &ast.IndexExpr{X: x, Lbrack: lbrack, Index: typeparams.PackExpr(args), Rbrack: rbrack} return nil, typeparams.PackIndexExpr(x, lbrack, args, rbrack)
} }
func (p *parser) parseFieldDecl() *ast.Field { func (p *parser) parseFieldDecl() *ast.Field {
@ -991,7 +991,7 @@ func (p *parser) parseMethodSpec() *ast.Field {
p.exprLev-- p.exprLev--
} }
rbrack := p.expectClosing(token.RBRACK, "type argument list") rbrack := p.expectClosing(token.RBRACK, "type argument list")
typ = &ast.IndexExpr{X: ident, Lbrack: lbrack, Index: typeparams.PackExpr(list), Rbrack: rbrack} typ = typeparams.PackIndexExpr(ident, lbrack, list, rbrack)
} }
case p.tok == token.LPAREN: case p.tok == token.LPAREN:
// ordinary method // ordinary method
@ -1178,7 +1178,6 @@ func (p *parser) parseTypeInstance(typ ast.Expr) ast.Expr {
} }
opening := p.expect(token.LBRACK) opening := p.expect(token.LBRACK)
p.exprLev++ p.exprLev++
var list []ast.Expr var list []ast.Expr
for p.tok != token.RBRACK && p.tok != token.EOF { for p.tok != token.RBRACK && p.tok != token.EOF {
@ -1192,7 +1191,17 @@ func (p *parser) parseTypeInstance(typ ast.Expr) ast.Expr {
closing := p.expectClosing(token.RBRACK, "type argument list") closing := p.expectClosing(token.RBRACK, "type argument list")
return &ast.IndexExpr{X: typ, Lbrack: opening, Index: typeparams.PackExpr(list), Rbrack: closing} if len(list) == 0 {
p.errorExpected(closing, "type argument list")
return &ast.IndexExpr{
X: typ,
Lbrack: opening,
Index: &ast.BadExpr{From: opening + 1, To: closing},
Rbrack: closing,
}
}
return typeparams.PackIndexExpr(typ, opening, list, closing)
} }
func (p *parser) tryIdentOrType() ast.Expr { func (p *parser) tryIdentOrType() ast.Expr {
@ -1455,7 +1464,7 @@ func (p *parser) parseIndexOrSliceOrInstance(x ast.Expr) ast.Expr {
} }
// instance expression // instance expression
return &ast.IndexExpr{X: x, Lbrack: lbrack, Index: typeparams.PackExpr(args), Rbrack: rbrack} return typeparams.PackIndexExpr(x, lbrack, args, rbrack)
} }
func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr { func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr {
@ -1557,6 +1566,7 @@ func (p *parser) checkExpr(x ast.Expr) ast.Expr {
panic("unreachable") panic("unreachable")
case *ast.SelectorExpr: case *ast.SelectorExpr:
case *ast.IndexExpr: case *ast.IndexExpr:
case *ast.MultiIndexExpr:
case *ast.SliceExpr: case *ast.SliceExpr:
case *ast.TypeAssertExpr: case *ast.TypeAssertExpr:
// If t.Type == nil we have a type assertion of the form // If t.Type == nil we have a type assertion of the form
@ -1646,7 +1656,7 @@ func (p *parser) parsePrimaryExpr() (x ast.Expr) {
return return
} }
// x is possibly a composite literal type // x is possibly a composite literal type
case *ast.IndexExpr: case *ast.IndexExpr, *ast.MultiIndexExpr:
if p.exprLev < 0 { if p.exprLev < 0 {
return return
} }

View file

@ -871,17 +871,15 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
// TODO(gri): should treat[] like parentheses and undo one level of depth // TODO(gri): should treat[] like parentheses and undo one level of depth
p.expr1(x.X, token.HighestPrec, 1) p.expr1(x.X, token.HighestPrec, 1)
p.print(x.Lbrack, token.LBRACK) p.print(x.Lbrack, token.LBRACK)
// Note: we're a bit defensive here to handle the case of a ListExpr of
// length 1.
if list := typeparams.UnpackExpr(x.Index); len(list) > 0 {
if len(list) > 1 {
p.exprList(x.Lbrack, list, depth+1, commaTerm, x.Rbrack, false)
} else {
p.expr0(list[0], depth+1)
}
} else {
p.expr0(x.Index, depth+1) p.expr0(x.Index, depth+1)
} p.print(x.Rbrack, token.RBRACK)
case *ast.MultiIndexExpr:
// TODO(gri): as for IndexExpr, should treat [] like parentheses and undo
// one level of depth
p.expr1(x.X, token.HighestPrec, 1)
p.print(x.Lbrack, token.LBRACK)
p.exprList(x.Lbrack, x.Indices, depth+1, commaTerm, x.Rbrack, false)
p.print(x.Rbrack, token.RBRACK) p.print(x.Rbrack, token.RBRACK)
case *ast.SliceExpr: case *ast.SliceExpr:

View file

@ -16,23 +16,22 @@ import (
// funcInst type-checks a function instantiation inst and returns the result in x. // funcInst type-checks a function instantiation inst and returns the result in x.
// The operand x must be the evaluation of inst.X and its type must be a signature. // The operand x must be the evaluation of inst.X and its type must be a signature.
func (check *Checker) funcInst(x *operand, inst *ast.IndexExpr) { func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
xlist := typeparams.UnpackExpr(inst.Index) targs := check.typeList(ix.Indices)
targs := check.typeList(xlist)
if targs == nil { if targs == nil {
x.mode = invalid x.mode = invalid
x.expr = inst x.expr = ix.Orig
return return
} }
assert(len(targs) == len(xlist)) assert(len(targs) == len(ix.Indices))
// check number of type arguments (got) vs number of type parameters (want) // check number of type arguments (got) vs number of type parameters (want)
sig := x.typ.(*Signature) sig := x.typ.(*Signature)
got, want := len(targs), len(sig.tparams) got, want := len(targs), len(sig.tparams)
if got > want { if got > want {
check.errorf(xlist[got-1], _Todo, "got %d type arguments but want %d", got, want) check.errorf(ix.Indices[got-1], _Todo, "got %d type arguments but want %d", got, want)
x.mode = invalid x.mode = invalid
x.expr = inst x.expr = ix.Orig
return return
} }
@ -40,11 +39,11 @@ func (check *Checker) funcInst(x *operand, inst *ast.IndexExpr) {
inferred := false inferred := false
if got < want { if got < want {
targs = check.infer(inst, sig.tparams, targs, nil, nil, true) targs = check.infer(ix.Orig, sig.tparams, targs, nil, nil, true)
if targs == nil { if targs == nil {
// error was already reported // error was already reported
x.mode = invalid x.mode = invalid
x.expr = inst x.expr = ix.Orig
return return
} }
got = len(targs) got = len(targs)
@ -55,8 +54,8 @@ func (check *Checker) funcInst(x *operand, inst *ast.IndexExpr) {
// determine argument positions (for error reporting) // determine argument positions (for error reporting)
// TODO(rFindley) use a positioner here? instantiate would need to be // TODO(rFindley) use a positioner here? instantiate would need to be
// updated accordingly. // updated accordingly.
poslist := make([]token.Pos, len(xlist)) poslist := make([]token.Pos, len(ix.Indices))
for i, x := range xlist { for i, x := range ix.Indices {
poslist[i] = x.Pos() poslist[i] = x.Pos()
} }
@ -64,25 +63,27 @@ func (check *Checker) funcInst(x *operand, inst *ast.IndexExpr) {
res := check.instantiate(x.Pos(), sig, targs, poslist).(*Signature) res := check.instantiate(x.Pos(), sig, targs, poslist).(*Signature)
assert(res.tparams == nil) // signature is not generic anymore assert(res.tparams == nil) // signature is not generic anymore
if inferred { if inferred {
check.recordInferred(inst, targs, res) check.recordInferred(ix.Orig, targs, res)
} }
x.typ = res x.typ = res
x.mode = value x.mode = value
x.expr = inst x.expr = ix.Orig
} }
func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind { func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
var inst *ast.IndexExpr ix := typeparams.UnpackIndexExpr(call.Fun)
if iexpr, _ := call.Fun.(*ast.IndexExpr); iexpr != nil { if ix != nil {
if check.indexExpr(x, iexpr) { if check.indexExpr(x, ix) {
// Delay function instantiation to argument checking, // Delay function instantiation to argument checking,
// where we combine type and value arguments for type // where we combine type and value arguments for type
// inference. // inference.
assert(x.mode == value) assert(x.mode == value)
inst = iexpr } else {
ix = nil
} }
x.expr = iexpr x.expr = call.Fun
check.record(x) check.record(x)
} else { } else {
check.exprOrType(x, call.Fun) check.exprOrType(x, call.Fun)
} }
@ -149,21 +150,20 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
// evaluate type arguments, if any // evaluate type arguments, if any
var targs []Type var targs []Type
if inst != nil { if ix != nil {
xlist := typeparams.UnpackExpr(inst.Index) targs = check.typeList(ix.Indices)
targs = check.typeList(xlist)
if targs == nil { if targs == nil {
check.use(call.Args...) check.use(call.Args...)
x.mode = invalid x.mode = invalid
x.expr = call x.expr = call
return statement return statement
} }
assert(len(targs) == len(xlist)) assert(len(targs) == len(ix.Indices))
// check number of type arguments (got) vs number of type parameters (want) // check number of type arguments (got) vs number of type parameters (want)
got, want := len(targs), len(sig.tparams) got, want := len(targs), len(sig.tparams)
if got > want { if got > want {
check.errorf(xlist[want], _Todo, "got %d type arguments but want %d", got, want) check.errorf(ix.Indices[want], _Todo, "got %d type arguments but want %d", got, want)
check.use(call.Args...) check.use(call.Args...)
x.mode = invalid x.mode = invalid
x.expr = call x.expr = call

View file

@ -1331,9 +1331,10 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
case *ast.SelectorExpr: case *ast.SelectorExpr:
check.selector(x, e) check.selector(x, e)
case *ast.IndexExpr: case *ast.IndexExpr, *ast.MultiIndexExpr:
if check.indexExpr(x, e) { ix := typeparams.UnpackIndexExpr(e)
check.funcInst(x, e) if check.indexExpr(x, ix) {
check.funcInst(x, ix)
} }
if x.mode == invalid { if x.mode == invalid {
goto Error goto Error
@ -1423,13 +1424,8 @@ func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
// types, which are comparatively rare. // types, which are comparatively rare.
default: default:
if typeparams.IsListExpr(e) {
// catch-all for unexpected expression lists
check.errorf(e, _Todo, "unexpected list of expressions")
} else {
panic(fmt.Sprintf("%s: unknown expression type %T", check.fset.Position(e.Pos()), e)) panic(fmt.Sprintf("%s: unknown expression type %T", check.fset.Position(e.Pos()), e))
} }
}
// everything went well // everything went well
x.expr = e x.expr = e

View file

@ -67,11 +67,11 @@ func WriteExpr(buf *bytes.Buffer, x ast.Expr) {
buf.WriteByte('.') buf.WriteByte('.')
buf.WriteString(x.Sel.Name) buf.WriteString(x.Sel.Name)
case *ast.IndexExpr: case *ast.IndexExpr, *ast.MultiIndexExpr:
WriteExpr(buf, x.X) ix := typeparams.UnpackIndexExpr(x)
WriteExpr(buf, ix.X)
buf.WriteByte('[') buf.WriteByte('[')
exprs := typeparams.UnpackExpr(x.Index) for i, e := range ix.Indices {
for i, e := range exprs {
if i > 0 { if i > 0 {
buf.WriteString(", ") buf.WriteString(", ")
} }

View file

@ -15,18 +15,18 @@ import (
// If e is a valid function instantiation, indexExpr returns true. // If e is a valid function instantiation, indexExpr returns true.
// In that case x represents the uninstantiated function value and // In that case x represents the uninstantiated function value and
// it is the caller's responsibility to instantiate the function. // it is the caller's responsibility to instantiate the function.
func (check *Checker) indexExpr(x *operand, e *ast.IndexExpr) (isFuncInst bool) { func (check *Checker) indexExpr(x *operand, expr *typeparams.IndexExpr) (isFuncInst bool) {
check.exprOrType(x, e.X) check.exprOrType(x, expr.X)
switch x.mode { switch x.mode {
case invalid: case invalid:
check.use(typeparams.UnpackExpr(e.Index)...) check.use(expr.Indices...)
return false return false
case typexpr: case typexpr:
// type instantiation // type instantiation
x.mode = invalid x.mode = invalid
x.typ = check.varType(e) x.typ = check.varType(expr.Orig)
if x.typ != Typ[Invalid] { if x.typ != Typ[Invalid] {
x.mode = typexpr x.mode = typexpr
} }
@ -77,7 +77,7 @@ func (check *Checker) indexExpr(x *operand, e *ast.IndexExpr) (isFuncInst bool)
x.typ = typ.elem x.typ = typ.elem
case *Map: case *Map:
index := check.singleIndex(e) index := check.singleIndex(expr)
if index == nil { if index == nil {
x.mode = invalid x.mode = invalid
return return
@ -88,7 +88,7 @@ func (check *Checker) indexExpr(x *operand, e *ast.IndexExpr) (isFuncInst bool)
// ok to continue even if indexing failed - map element type is known // ok to continue even if indexing failed - map element type is known
x.mode = mapindex x.mode = mapindex
x.typ = typ.elem x.typ = typ.elem
x.expr = e x.expr = expr.Orig
return return
case *Union: case *Union:
@ -137,7 +137,7 @@ func (check *Checker) indexExpr(x *operand, e *ast.IndexExpr) (isFuncInst bool)
// If there are maps, the index expression must be assignable // If there are maps, the index expression must be assignable
// to the map key type (as for simple map index expressions). // to the map key type (as for simple map index expressions).
if nmaps > 0 { if nmaps > 0 {
index := check.singleIndex(e) index := check.singleIndex(expr)
if index == nil { if index == nil {
x.mode = invalid x.mode = invalid
return return
@ -151,7 +151,7 @@ func (check *Checker) indexExpr(x *operand, e *ast.IndexExpr) (isFuncInst bool)
if nmaps == typ.NumTerms() { if nmaps == typ.NumTerms() {
x.mode = mapindex x.mode = mapindex
x.typ = telem x.typ = telem
x.expr = e x.expr = expr.Orig
return return
} }
@ -180,7 +180,7 @@ func (check *Checker) indexExpr(x *operand, e *ast.IndexExpr) (isFuncInst bool)
return return
} }
index := check.singleIndex(e) index := check.singleIndex(expr)
if index == nil { if index == nil {
x.mode = invalid x.mode = invalid
return return
@ -311,23 +311,16 @@ L:
// singleIndex returns the (single) index from the index expression e. // singleIndex returns the (single) index from the index expression e.
// If the index is missing, or if there are multiple indices, an error // If the index is missing, or if there are multiple indices, an error
// is reported and the result is nil. // is reported and the result is nil.
func (check *Checker) singleIndex(e *ast.IndexExpr) ast.Expr { func (check *Checker) singleIndex(expr *typeparams.IndexExpr) ast.Expr {
index := e.Index if len(expr.Indices) == 0 {
if index == nil { check.invalidAST(expr.Orig, "index expression %v with 0 indices", expr)
check.invalidAST(e, "missing index for %s", e)
return nil return nil
} }
if len(expr.Indices) > 1 {
indexes := typeparams.UnpackExpr(index)
if len(indexes) == 0 {
check.invalidAST(index, "index expression %v with 0 indices", index)
return nil
}
if len(indexes) > 1 {
// TODO(rFindley) should this get a distinct error code? // TODO(rFindley) should this get a distinct error code?
check.invalidOp(indexes[1], _InvalidIndex, "more than one index") check.invalidOp(expr.Indices[1], _InvalidIndex, "more than one index")
} }
return indexes[0] return expr.Indices[0]
} }
// index checks an index expression for validity. // index checks an index expression for validity.

View file

@ -499,10 +499,12 @@ L: // unpack receiver type
} }
// unpack type parameters, if any // unpack type parameters, if any
if ptyp, _ := rtyp.(*ast.IndexExpr); ptyp != nil { switch rtyp.(type) {
rtyp = ptyp.X case *ast.IndexExpr, *ast.MultiIndexExpr:
ix := typeparams.UnpackIndexExpr(rtyp)
rtyp = ix.X
if unpackParams { if unpackParams {
for _, arg := range typeparams.UnpackExpr(ptyp.Index) { for _, arg := range ix.Indices {
var par *ast.Ident var par *ast.Ident
switch arg := arg.(type) { switch arg := arg.(type) {
case *ast.Ident: case *ast.Ident:
@ -510,7 +512,7 @@ L: // unpack receiver type
case *ast.BadExpr: case *ast.BadExpr:
// ignore - error already reported by parser // ignore - error already reported by parser
case nil: case nil:
check.invalidAST(ptyp, "parameterized receiver contains nil parameters") check.invalidAST(ix.Orig, "parameterized receiver contains nil parameters")
default: default:
check.errorf(arg, _Todo, "receiver type parameter %s must be an identifier", arg) check.errorf(arg, _Todo, "receiver type parameter %s must be an identifier", arg)
} }

View file

@ -244,24 +244,21 @@ func isubst(x ast.Expr, smap map[*ast.Ident]*ast.Ident) ast.Expr {
new.X = X new.X = X
return &new return &new
} }
case *ast.IndexExpr: case *ast.IndexExpr, *ast.MultiIndexExpr:
elems := typeparams.UnpackExpr(n.Index) ix := typeparams.UnpackIndexExpr(x)
var newElems []ast.Expr var newIndexes []ast.Expr
for i, elem := range elems { for i, index := range ix.Indices {
new := isubst(elem, smap) new := isubst(index, smap)
if new != elem { if new != index {
if newElems == nil { if newIndexes == nil {
newElems = make([]ast.Expr, len(elems)) newIndexes = make([]ast.Expr, len(ix.Indices))
copy(newElems, elems) copy(newIndexes, ix.Indices)
} }
newElems[i] = new newIndexes[i] = new
} }
} }
if newElems != nil { if newIndexes != nil {
index := typeparams.PackExpr(newElems) return typeparams.PackIndexExpr(ix.X, ix.Lbrack, newIndexes, ix.Rbrack)
new := *n
new.Index = index
return &new
} }
case *ast.ParenExpr: case *ast.ParenExpr:
return isubst(n.X, smap) // no need to keep parentheses return isubst(n.X, smap) // no need to keep parentheses

View file

@ -33,11 +33,11 @@ var _ A3
var x int var x int
type _ x /* ERROR not a type */ [int] type _ x /* ERROR not a type */ [int]
type _ int /* ERROR not a generic type */ [] type _ int /* ERROR not a generic type */ [] // ERROR expected type argument list
type _ myInt /* ERROR not a generic type */ [] type _ myInt /* ERROR not a generic type */ [] // ERROR expected type argument list
// TODO(gri) better error messages // TODO(gri) better error messages
type _ T1 /* ERROR got 0 arguments but 1 type parameters */ [] type _ T1[] // ERROR expected type argument list
type _ T1[x /* ERROR not a type */ ] type _ T1[x /* ERROR not a type */ ]
type _ T1 /* ERROR got 2 arguments but 1 type parameters */ [int, float32] type _ T1 /* ERROR got 2 arguments but 1 type parameters */ [int, float32]

View file

@ -10,7 +10,7 @@ func main() {
type N[T any] struct{} type N[T any] struct{}
var _ N /* ERROR "0 arguments but 1 type parameters" */ [] var _ N [] // ERROR expected type argument list
type I interface { type I interface {
~map[int]int | ~[]int ~map[int]int | ~[]int

View file

@ -261,13 +261,13 @@ func (check *Checker) typInternal(e0 ast.Expr, def *Named) (T Type) {
check.errorf(&x, _NotAType, "%s is not a type", &x) check.errorf(&x, _NotAType, "%s is not a type", &x)
} }
case *ast.IndexExpr: case *ast.IndexExpr, *ast.MultiIndexExpr:
ix := typeparams.UnpackIndexExpr(e)
if typeparams.Enabled { if typeparams.Enabled {
exprs := typeparams.UnpackExpr(e.Index) return check.instantiatedType(ix, def)
return check.instantiatedType(e.X, exprs, def)
} }
check.errorf(e0, _NotAType, "%s is not a type", e0) check.errorf(e0, _NotAType, "%s is not a type", e0)
check.use(e.X) check.use(ix.X)
case *ast.ParenExpr: case *ast.ParenExpr:
// Generic types must be instantiated before they can be used in any form. // Generic types must be instantiated before they can be used in any form.
@ -403,8 +403,8 @@ func (check *Checker) typeOrNil(e ast.Expr) Type {
return Typ[Invalid] return Typ[Invalid]
} }
func (check *Checker) instantiatedType(x ast.Expr, targs []ast.Expr, def *Named) Type { func (check *Checker) instantiatedType(ix *typeparams.IndexExpr, def *Named) Type {
b := check.genericType(x, true) // TODO(gri) what about cycles? b := check.genericType(ix.X, true) // TODO(gri) what about cycles?
if b == Typ[Invalid] { if b == Typ[Invalid] {
return b // error already reported return b // error already reported
} }
@ -420,19 +420,19 @@ func (check *Checker) instantiatedType(x ast.Expr, targs []ast.Expr, def *Named)
def.setUnderlying(typ) def.setUnderlying(typ)
typ.check = check typ.check = check
typ.pos = x.Pos() typ.pos = ix.X.Pos()
typ.base = base typ.base = base
// evaluate arguments (always) // evaluate arguments (always)
typ.targs = check.typeList(targs) typ.targs = check.typeList(ix.Indices)
if typ.targs == nil { if typ.targs == nil {
def.setUnderlying(Typ[Invalid]) // avoid later errors due to lazy instantiation def.setUnderlying(Typ[Invalid]) // avoid later errors due to lazy instantiation
return Typ[Invalid] return Typ[Invalid]
} }
// determine argument positions (for error reporting) // determine argument positions (for error reporting)
typ.poslist = make([]token.Pos, len(targs)) typ.poslist = make([]token.Pos, len(ix.Indices))
for i, arg := range targs { for i, arg := range ix.Indices {
typ.poslist[i] = arg.Pos() typ.poslist[i] = arg.Pos()
} }