database/sql: use slices to simplify the code

This commit is contained in:
apocelipes 2024-03-30 04:06:54 +08:00
parent 9a028e14a5
commit e713ac3163
3 changed files with 30 additions and 53 deletions

View file

@ -11,7 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"reflect" "reflect"
"sort" "slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -120,12 +120,7 @@ type table struct {
} }
func (t *table) columnIndex(name string) int { func (t *table) columnIndex(name string) int {
for n, nname := range t.colname { return slices.Index(t.colname, name)
if name == nname {
return n
}
}
return -1
} }
type row struct { type row struct {
@ -217,15 +212,6 @@ func init() {
Register("test", fdriver) Register("test", fdriver)
} }
func contains(list []string, y string) bool {
for _, x := range list {
if x == y {
return true
}
}
return false
}
type Dummy struct { type Dummy struct {
driver.Driver driver.Driver
} }
@ -235,7 +221,7 @@ func TestDrivers(t *testing.T) {
Register("test", fdriver) Register("test", fdriver)
Register("invalid", Dummy{}) Register("invalid", Dummy{})
all := Drivers() all := Drivers()
if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { if len(all) < 2 || !slices.IsSorted(all) || !slices.Contains(all, "test") || !slices.Contains(all, "invalid") {
t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
} }
} }
@ -345,10 +331,8 @@ func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
if !ok { if !ok {
return return
} }
for n, cname := range t.colname { if i := slices.Index(t.colname, column); i != -1 {
if cname == column { return t.coltype[i], true
return t.coltype[n], true
}
} }
return "", false return "", false
} }
@ -823,6 +807,15 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
} }
func valueFromPlaceholderName(args []driver.NamedValue, name string) driver.Value {
for i := range args {
if args[i].Name == name {
return args[i].Value
}
}
return nil
}
// When doInsert is true, add the row to the table. // When doInsert is true, add the row to the table.
// When doInsert is false do prep-work and error checking, but don't // When doInsert is false do prep-work and error checking, but don't
// actually add the row to the table. // actually add the row to the table.
@ -857,11 +850,8 @@ func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.R
val = args[argPos].Value val = args[argPos].Value
} else { } else {
// Assign value from argument placeholder name. // Assign value from argument placeholder name.
for _, a := range args { if v := valueFromPlaceholderName(args, strvalue[1:]); v != nil {
if a.Name == strvalue[1:] { val = v
val = a.Value
break
}
} }
} }
argPos++ argPos++
@ -997,12 +987,8 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
if wcol.Placeholder == "?" { if wcol.Placeholder == "?" {
argValue = args[wcol.Ordinal-1].Value argValue = args[wcol.Ordinal-1].Value
} else { } else {
// Assign arg value from placeholder name. if v := valueFromPlaceholderName(args, wcol.Placeholder[1:]); v != nil {
for _, a := range args { argValue = v
if a.Name == wcol.Placeholder[1:] {
argValue = a.Value
break
}
} }
} }
if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {

View file

@ -24,7 +24,7 @@ import (
"math/rand/v2" "math/rand/v2"
"reflect" "reflect"
"runtime" "runtime"
"sort" "slices"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -69,7 +69,7 @@ func Drivers() []string {
for name := range drivers { for name := range drivers {
list = append(list, name) list = append(list, name)
} }
sort.Strings(list) slices.Sort(list)
return list return list
} }
@ -3452,11 +3452,9 @@ func (r *Row) Scan(dest ...any) error {
// they were obtained from the network anyway) But for now we // they were obtained from the network anyway) But for now we
// don't care. // don't care.
defer r.rows.Close() defer r.rows.Close()
for _, dp := range dest { if scanArgsContainRawBytes(dest) {
if _, ok := dp.(*RawBytes); ok {
return errors.New("sql: RawBytes isn't allowed on Row.Scan") return errors.New("sql: RawBytes isn't allowed on Row.Scan")
} }
}
if !r.rows.Next() { if !r.rows.Next() {
if err := r.rows.Err(); err != nil { if err := r.rows.Err(); err != nil {

View file

@ -40,14 +40,7 @@ func init() {
freedFrom[c] = s freedFrom[c] = s
} }
putConnHook = func(db *DB, c *driverConn) { putConnHook = func(db *DB, c *driverConn) {
idx := -1 if slices.Contains(db.freeConn, c) {
for i, v := range db.freeConn {
if v == c {
idx = i
break
}
}
if idx >= 0 {
// print before panic, as panic may get lost due to conflicting panic // print before panic, as panic may get lost due to conflicting panic
// (all goroutines asleep) elsewhere, since we might not unlock // (all goroutines asleep) elsewhere, since we might not unlock
// the mutex in freeConn here. // the mutex in freeConn here.
@ -291,7 +284,7 @@ func TestQuery(t *testing.T) {
{age: 2, name: "Bob"}, {age: 2, name: "Bob"},
{age: 3, name: "Chris"}, {age: 3, name: "Chris"},
} }
if !reflect.DeepEqual(got, want) { if !slices.Equal(got, want) {
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
} }
@ -355,7 +348,7 @@ func TestQueryContext(t *testing.T) {
{age: 1, name: "Alice"}, {age: 1, name: "Alice"},
{age: 2, name: "Bob"}, {age: 2, name: "Bob"},
} }
if !reflect.DeepEqual(got, want) { if !slices.Equal(got, want) {
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
} }
@ -540,7 +533,7 @@ func TestMultiResultSetQuery(t *testing.T) {
{age: 2, name: "Bob"}, {age: 2, name: "Bob"},
{age: 3, name: "Chris"}, {age: 3, name: "Chris"},
} }
if !reflect.DeepEqual(got1, want1) { if !slices.Equal(got1, want1) {
t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1) t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
} }
@ -566,7 +559,7 @@ func TestMultiResultSetQuery(t *testing.T) {
{name: "Bob"}, {name: "Bob"},
{name: "Chris"}, {name: "Chris"},
} }
if !reflect.DeepEqual(got2, want2) { if !slices.Equal(got2, want2) {
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2) t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
} }
if rows.NextResultSet() { if rows.NextResultSet() {
@ -614,7 +607,7 @@ func TestQueryNamedArg(t *testing.T) {
want := []row{ want := []row{
{age: 2, name: "Bob"}, {age: 2, name: "Bob"},
} }
if !reflect.DeepEqual(got, want) { if !slices.Equal(got, want) {
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
} }
@ -724,7 +717,7 @@ func TestRowsColumns(t *testing.T) {
t.Fatalf("Columns: %v", err) t.Fatalf("Columns: %v", err)
} }
want := []string{"age", "name"} want := []string{"age", "name"}
if !reflect.DeepEqual(cols, want) { if !slices.Equal(cols, want) {
t.Errorf("got %#v; want %#v", cols, want) t.Errorf("got %#v; want %#v", cols, want)
} }
if err := rows.Close(); err != nil { if err := rows.Close(); err != nil {
@ -827,7 +820,7 @@ func TestQueryRow(t *testing.T) {
t.Fatalf("photo QueryRow+Scan: %v", err) t.Fatalf("photo QueryRow+Scan: %v", err)
} }
want := []byte("APHOTO") want := []byte("APHOTO")
if !reflect.DeepEqual(photo, want) { if !slices.Equal(photo, want) {
t.Errorf("photo = %q; want %q", photo, want) t.Errorf("photo = %q; want %q", photo, want)
} }
} }