From e713ac31638671f60cc3cf62fa514f784e834e66 Mon Sep 17 00:00:00 2001 From: apocelipes Date: Sat, 30 Mar 2024 04:06:54 +0800 Subject: [PATCH] database/sql: use slices to simplify the code --- src/database/sql/fakedb_test.go | 50 ++++++++++++--------------------- src/database/sql/sql.go | 10 +++---- src/database/sql/sql_test.go | 23 ++++++--------- 3 files changed, 30 insertions(+), 53 deletions(-) diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index c6c3172b5c..3dfcd447b5 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -11,7 +11,7 @@ import ( "fmt" "io" "reflect" - "sort" + "slices" "strconv" "strings" "sync" @@ -120,12 +120,7 @@ type table struct { } func (t *table) columnIndex(name string) int { - for n, nname := range t.colname { - if name == nname { - return n - } - } - return -1 + return slices.Index(t.colname, name) } type row struct { @@ -217,15 +212,6 @@ func init() { Register("test", fdriver) } -func contains(list []string, y string) bool { - for _, x := range list { - if x == y { - return true - } - } - return false -} - type Dummy struct { driver.Driver } @@ -235,7 +221,7 @@ func TestDrivers(t *testing.T) { Register("test", fdriver) Register("invalid", Dummy{}) all := Drivers() - if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { + if len(all) < 2 || !slices.IsSorted(all) || !slices.Contains(all, "test") || !slices.Contains(all, "invalid") { t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) } } @@ -345,10 +331,8 @@ func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { if !ok { return } - for n, cname := range t.colname { - if cname == column { - return t.coltype[n], true - } + if i := slices.Index(t.colname, column); i != -1 { + return t.coltype[i], true } return "", false } @@ -823,6 +807,15 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) } +func valueFromPlaceholderName(args []driver.NamedValue, name string) driver.Value { + for i := range args { + if args[i].Name == name { + return args[i].Value + } + } + return nil +} + // When doInsert is true, add the row to the table. // When doInsert is false do prep-work and error checking, but don't // actually add the row to the table. @@ -857,11 +850,8 @@ func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.R val = args[argPos].Value } else { // Assign value from argument placeholder name. - for _, a := range args { - if a.Name == strvalue[1:] { - val = a.Value - break - } + if v := valueFromPlaceholderName(args, strvalue[1:]); v != nil { + val = v } } argPos++ @@ -997,12 +987,8 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( if wcol.Placeholder == "?" { argValue = args[wcol.Ordinal-1].Value } else { - // Assign arg value from placeholder name. - for _, a := range args { - if a.Name == wcol.Placeholder[1:] { - argValue = a.Value - break - } + if v := valueFromPlaceholderName(args, wcol.Placeholder[1:]); v != nil { + argValue = v } } if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 36995a1059..5b4a3f5409 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -24,7 +24,7 @@ import ( "math/rand/v2" "reflect" "runtime" - "sort" + "slices" "strconv" "sync" "sync/atomic" @@ -69,7 +69,7 @@ func Drivers() []string { for name := range drivers { list = append(list, name) } - sort.Strings(list) + slices.Sort(list) return list } @@ -3452,10 +3452,8 @@ func (r *Row) Scan(dest ...any) error { // they were obtained from the network anyway) But for now we // don't care. defer r.rows.Close() - for _, dp := range dest { - if _, ok := dp.(*RawBytes); ok { - return errors.New("sql: RawBytes isn't allowed on Row.Scan") - } + if scanArgsContainRawBytes(dest) { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") } if !r.rows.Next() { diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 7bf3ebbe08..25ca5ff0ad 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -40,14 +40,7 @@ func init() { freedFrom[c] = s } putConnHook = func(db *DB, c *driverConn) { - idx := -1 - for i, v := range db.freeConn { - if v == c { - idx = i - break - } - } - if idx >= 0 { + if slices.Contains(db.freeConn, c) { // print before panic, as panic may get lost due to conflicting panic // (all goroutines asleep) elsewhere, since we might not unlock // the mutex in freeConn here. @@ -291,7 +284,7 @@ func TestQuery(t *testing.T) { {age: 2, name: "Bob"}, {age: 3, name: "Chris"}, } - if !reflect.DeepEqual(got, want) { + if !slices.Equal(got, want) { t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) } @@ -355,7 +348,7 @@ func TestQueryContext(t *testing.T) { {age: 1, name: "Alice"}, {age: 2, name: "Bob"}, } - if !reflect.DeepEqual(got, want) { + if !slices.Equal(got, want) { t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) } @@ -540,7 +533,7 @@ func TestMultiResultSetQuery(t *testing.T) { {age: 2, name: "Bob"}, {age: 3, name: "Chris"}, } - if !reflect.DeepEqual(got1, want1) { + if !slices.Equal(got1, want1) { t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1) } @@ -566,7 +559,7 @@ func TestMultiResultSetQuery(t *testing.T) { {name: "Bob"}, {name: "Chris"}, } - if !reflect.DeepEqual(got2, want2) { + if !slices.Equal(got2, want2) { t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2) } if rows.NextResultSet() { @@ -614,7 +607,7 @@ func TestQueryNamedArg(t *testing.T) { want := []row{ {age: 2, name: "Bob"}, } - if !reflect.DeepEqual(got, want) { + if !slices.Equal(got, want) { t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) } @@ -724,7 +717,7 @@ func TestRowsColumns(t *testing.T) { t.Fatalf("Columns: %v", err) } want := []string{"age", "name"} - if !reflect.DeepEqual(cols, want) { + if !slices.Equal(cols, want) { t.Errorf("got %#v; want %#v", cols, want) } if err := rows.Close(); err != nil { @@ -827,7 +820,7 @@ func TestQueryRow(t *testing.T) { t.Fatalf("photo QueryRow+Scan: %v", err) } want := []byte("APHOTO") - if !reflect.DeepEqual(photo, want) { + if !slices.Equal(photo, want) { t.Errorf("photo = %q; want %q", photo, want) } }