Implement or-patterns in case clauses

You can now say

    expr_move(?dst, ?src) | expr_assign(?dst, ?src) { ... }

to match both expr_move and expr_assign. The names, types, and number
of bound names have to match in all the patterns.

Closes #449.
This commit is contained in:
Marijn Haverbeke 2011-07-08 16:27:55 +02:00
parent 4d325b1a15
commit 86ee3454a1
13 changed files with 168 additions and 40 deletions

View file

@ -300,7 +300,7 @@ fn walk_pat(&mutable vec[node_id] found, &@ast::pat p) {
case (_) { }
}
}
walk_pat(dnums, arm.pat);
walk_pat(dnums, arm.pats.(0));
ret dnums;
}

View file

@ -288,7 +288,7 @@ fn walk_constr(@env e, &@ast::constr c, &scopes sc, &vt[scopes] v) {
c.node.path.node.idents, ns_value));
}
fn walk_arm(@env e, &ast::arm a, &scopes sc, &vt[scopes] v) {
walk_pat(*e, sc, a.pat);
for (@ast::pat p in a.pats) { walk_pat(*e, sc, p); }
visit_arm_with_scope(a, sc, v);
}
fn walk_pat(&env e, &scopes sc, &@ast::pat pat) {
@ -648,7 +648,7 @@ fn in_scope(&env e, &span sp, &ident name, &scope s, namespace ns) ->
}
case (scope_block(?b)) { ret lookup_in_block(name, b.node, ns); }
case (scope_arm(?a)) {
if (ns == ns_value) { ret lookup_in_pat(name, *a.pat); }
if (ns == ns_value) { ret lookup_in_pat(name, *a.pats.(0)); }
}
}
ret none[def];
@ -1264,7 +1264,32 @@ fn walk_pat(checker ch, &@ast::pat p) {
case (_) { }
}
}
walk_pat(checker(*e, "binding"), a.pat);
auto ch0 = checker(*e, "binding");
walk_pat(ch0, a.pats.(0));
auto seen0 = ch0.seen;
auto i = ivec::len(a.pats);
while (i > 1u) {
i -= 1u;
auto ch = checker(*e, "binding");
walk_pat(ch, a.pats.(i));
// Ensure the bindings introduced in this pattern are the same as in
// the first pattern.
if (vec::len(ch.seen) != vec::len(seen0)) {
e.sess.span_err(a.pats.(i).span,
"inconsistent number of bindings");
} else {
for (ident name in ch.seen) {
if (option::is_none(vec::find(bind str::eq(name, _),
seen0))) {
// Fight the alias checker
auto name_ = name;
e.sess.span_err
(a.pats.(i).span, "binding " + name_ +
" does not occur in first pattern");
}
}
}
}
}
fn check_block(@env e, &ast::block b, &() x, &vt[()] v) {

View file

@ -4842,8 +4842,10 @@ fn trans_pat_match(&@block_ctxt cx, &@ast::pat pat, ValueRef llval,
}
}
type bind_map = hashmap[ast::ident, result];
fn trans_pat_binding(&@block_ctxt cx, &@ast::pat pat, ValueRef llval,
bool is_mem) -> result {
bool is_mem, &bind_map bound) -> result {
alt (pat.node) {
case (ast::pat_wild) { ret rslt(cx, llval); }
case (ast::pat_lit(_)) { ret rslt(cx, llval); }
@ -4853,8 +4855,9 @@ fn trans_pat_binding(&@block_ctxt cx, &@ast::pat pat, ValueRef llval,
val = spill_if_immediate
(cx, llval, node_id_type(cx.fcx.lcx.ccx, pat.id));
}
cx.fcx.lllocals.insert(pat.id, val);
ret rslt(cx, val);
auto r = rslt(cx, val);
bound.insert(name, r);
ret r;
}
case (ast::pat_tag(_, ?subpats)) {
if (std::ivec::len[@ast::pat](subpats) == 0u) {
@ -4887,7 +4890,7 @@ fn trans_pat_binding(&@block_ctxt cx, &@ast::pat pat, ValueRef llval,
ty_param_substs, i);
this_cx = rslt.bcx;
auto subpat_res =
trans_pat_binding(this_cx, subpat, rslt.val, true);
trans_pat_binding(this_cx, subpat, rslt.val, true, bound);
this_cx = subpat_res.bcx;
i += 1;
}
@ -4902,16 +4905,34 @@ fn trans_alt(&@block_ctxt cx, &@ast::expr expr, &ast::arm[] arms,
auto this_cx = expr_res.bcx;
let result[] arm_results = ~[];
for (ast::arm arm in arms) {
auto bind_maps = ~[];
auto block_cx = new_scope_block_ctxt(expr_res.bcx, "case block");
for (@ast::pat pat in arm.pats) {
auto next_cx = new_sub_block_ctxt(expr_res.bcx, "next");
auto match_res =
trans_pat_match(this_cx, arm.pat, expr_res.val, next_cx);
auto binding_res =
trans_pat_binding(match_res.bcx, arm.pat, expr_res.val, false);
auto block_cx = new_scope_block_ctxt(match_res.bcx, "case block");
trans_pat_match(this_cx, pat, expr_res.val, next_cx);
auto bind_map = new_str_hash[result]();
auto binding_res = trans_pat_binding
(match_res.bcx, pat, expr_res.val, false, bind_map);
bind_maps += ~[bind_map];
binding_res.bcx.build.Br(block_cx.llbb);
this_cx = next_cx;
}
// Go over the names and node_ids of the bound variables, add a Phi
// node for each and register the bindings.
for each (@tup(ast::ident, ast::node_id) item in
ast::pat_id_map(arm.pats.(0)).items()) {
auto vals = ~[]; auto llbbs = ~[];
for (bind_map map in bind_maps) {
auto rslt = map.get(item._0);
vals += ~[rslt.val];
llbbs += ~[rslt.bcx.llbb];
}
auto phi = block_cx.build.Phi(val_ty(vals.(0)), vals, llbbs);
block_cx.fcx.lllocals.insert(item._1, phi);
}
auto block_res = trans_block(block_cx, arm.block, output);
arm_results += ~[block_res];
this_cx = next_cx;
}
auto default_cx = this_cx;
trans_fail(default_cx, some[span](expr.span),

View file

@ -10,6 +10,7 @@
import util::common;
import syntax::codemap::span;
import std::map::new_int_hash;
import std::map::new_str_hash;
import util::common::new_def_hash;
import util::common::log_expr_err;
import middle::ty;
@ -1286,10 +1287,10 @@ fn check_lit(@crate_ctxt ccx, &@ast::lit lit) -> ty::t {
}
}
// Pattern checking is top-down rather than bottom-up so that bindings get
// their types immediately.
fn check_pat(&@fn_ctxt fcx, &@ast::pat pat, ty::t expected) {
fn check_pat(&@fn_ctxt fcx, &ast::pat_id_map map, &@ast::pat pat,
ty::t expected) {
alt (pat.node) {
case (ast::pat_wild) {
write::ty_only_fixup(fcx, pat.id, expected);
@ -1303,6 +1304,12 @@ fn check_pat(&@fn_ctxt fcx, &@ast::pat pat, ty::t expected) {
auto vid = lookup_local(fcx, pat.span, pat.id);
auto typ = ty::mk_var(fcx.ccx.tcx, vid);
typ = demand::simple(fcx, pat.span, expected, typ);
auto canon_id = map.get(name);
if (canon_id != pat.id) {
auto ct = ty::mk_var(fcx.ccx.tcx,
lookup_local(fcx, pat.span, canon_id));
typ = demand::simple(fcx, pat.span, ct, typ);
}
write::ty_only_fixup(fcx, pat.id, typ);
}
case (ast::pat_tag(?path, ?subpats)) {
@ -1358,7 +1365,7 @@ fn check_pat(&@fn_ctxt fcx, &@ast::pat pat, ty::t expected) {
auto i = 0u;
for (@ast::pat subpat in subpats) {
check_pat(fcx, subpat, arg_types.(i));
check_pat(fcx, map, subpat, arg_types.(i));
i += 1u;
}
} else if (subpats_len > 0u) {
@ -1884,10 +1891,11 @@ fn check_binop_type_compat(&@fn_ctxt fcx, span span,
// bindings.
auto pattern_ty = ty::expr_ty(fcx.ccx.tcx, expr);
let vec[@ast::pat] pats = [];
for (ast::arm arm in arms) {
check_pat(fcx, arm.pat, pattern_ty);
pats += [arm.pat];
auto id_map = ast::pat_id_map(arm.pats.(0));
for (@ast::pat p in arm.pats) {
check_pat(fcx, id_map, p, pattern_ty);
}
}
// Now typecheck the blocks.

View file

@ -125,6 +125,25 @@ fn def_id_of_def(def d) -> def_id {
pat_tag(path, (@pat)[]);
}
type pat_id_map = std::map::hashmap[str, ast::node_id];
// This is used because same-named variables in alternative patterns need to
// use the node_id of their namesake in the first pattern.
fn pat_id_map(&@pat pat) -> pat_id_map {
auto map = std::map::new_str_hash[node_id]();
fn walk(&pat_id_map map, &@pat pat) {
alt (pat.node) {
pat_bind(?name) { map.insert(name, pat.id); }
pat_tag(_, ?sub) {
for (@pat p in sub) { walk(map, p); }
}
_ {}
}
}
walk(map, pat);
ret map;
}
tag mutability { mut; imm; maybe_mut; }
tag layer { layer_value; layer_state; layer_gc; }
@ -227,7 +246,7 @@ fn unop_to_str(unop op) -> str {
tag decl_ { decl_local(@local); decl_item(@item); }
type arm = rec(@pat pat, block block);
type arm = rec((@pat)[] pats, block block);
type elt = rec(mutability mut, @expr expr);

View file

@ -255,7 +255,8 @@ fn noop_fold_stmt(&stmt_ s, ast_fold fld) -> stmt_ {
}
fn noop_fold_arm(&arm a, ast_fold fld) -> arm {
ret rec(pat=fld.fold_pat(a.pat), block=fld.fold_block(a.block));
ret rec(pats=ivec::map(fld.fold_pat, a.pats),
block=fld.fold_block(a.block));
}
fn noop_fold_pat(&pat_ p, ast_fold fld) -> pat_ {

View file

@ -1359,10 +1359,10 @@ fn parse_alt_expr(&parser p) -> @ast::expr {
eat_word(p, "case");
auto parens = false;
if (p.peek() == token::LPAREN) { parens = true; p.bump(); }
auto pat = parse_pat(p);
auto pats = parse_pats(p);
if (parens) { expect(p, token::RPAREN); }
auto block = parse_block(p);
arms += ~[rec(pat=pat, block=block)];
arms += ~[rec(pats=pats, block=block)];
}
auto hi = p.get_hi_pos();
p.bump();
@ -1405,7 +1405,6 @@ fn parse_initializer(&parser p) -> option::t[ast::initializer] {
p.bump();
ret some(rec(op=ast::init_move, expr=parse_expr(p)));
}
case (
// Now that the the channel is the first argument to receive,
// combining it with an initializer doesn't really make sense.
// case (token::RECV) {
@ -1413,12 +1412,25 @@ fn parse_initializer(&parser p) -> option::t[ast::initializer] {
// ret some(rec(op = ast::init_recv,
// expr = parse_expr(p)));
// }
_) {
case (_) {
ret none;
}
}
}
fn parse_pats(&parser p) -> (@ast::pat)[] {
auto pats = ~[];
while (true) {
pats += ~[parse_pat(p)];
if (p.peek() == token::BINOP(token::OR)) {
p.bump();
} else {
break;
}
}
ret pats;
}
fn parse_pat(&parser p) -> @ast::pat {
auto lo = p.get_lo_pos();
auto hi = p.get_hi_pos();

View file

@ -832,9 +832,12 @@ fn print_opt(&ps s, &option::t[@ast::expr] expr) {
for (ast::arm arm in arms) {
space(s.s);
head(s, "case");
popen(s);
print_pat(s, arm.pat);
pclose(s);
auto first = true;
for (@ast::pat p in arm.pats) {
if (first) { first = false; }
else { space(s.s); word_space(s, "|"); }
print_pat(s, p);
}
space(s.s);
print_block(s, arm.block);
}

View file

@ -394,7 +394,7 @@ fn visit_expr[E](&@expr ex, &E e, &vt[E] v) {
}
fn visit_arm[E](&arm a, &E e, &vt[E] v) {
vt(v).visit_pat(a.pat, e, v);
for (@pat p in a.pats) { vt(v).visit_pat(p, e, v); }
vt(v).visit_block(a.block, e, v);
}
// Local Variables:

View file

@ -334,7 +334,7 @@ fn walk_expr(&ast_visitor v, @ast::expr e) {
case (ast::expr_alt(?x, ?arms)) {
walk_expr(v, x);
for (ast::arm a in arms) {
walk_pat(v, a.pat);
for (@ast::pat p in a.pats) { walk_pat(v, p); }
v.visit_arm_pre(a);
walk_block(v, a.block);
v.visit_arm_post(a);

View file

@ -206,6 +206,15 @@ fn all[T](fn(&T)->bool f, &T[] v) -> bool {
ret true;
}
fn member[T](&T x, &T[] v) -> bool {
for (T elt in v) { if (x == elt) { ret true; } }
ret false;
}
fn find[T](fn(&T) -> bool f, &T[] v) -> option::t[T] {
for (T elt in v) { if (f(elt)) { ret some[T](elt); } }
ret none[T];
}
mod unsafe {
fn copy_from_buf[T](&mutable T[] v, *T ptr, uint count) {

View file

@ -0,0 +1,12 @@
// error-pattern: mismatched types
tag blah {
a(int, int, uint);
b(int, int);
}
fn main() {
alt a(1, 1, 2u) {
a(_, ?x, ?y) | b(?x, ?y) { }
}
}

View file

@ -0,0 +1,18 @@
tag blah {
a(int, int, uint);
b(int, int);
c;
}
fn or_alt(&blah q) -> int {
alt q {
a(?x, ?y, _) | b(?x, ?y) { ret x + y;}
c { ret 0; }
}
}
fn main() {
assert or_alt(c) == 0;
assert or_alt(a(10, 100, 0u)) == 110;
assert or_alt(b(20, 200)) == 220;
}