mirror of
https://github.com/golang/go
synced 2024-10-06 08:00:07 +00:00
database/sql: use slices to simplify the code
This commit is contained in:
parent
9a028e14a5
commit
e713ac3163
|
@ -11,7 +11,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -120,12 +120,7 @@ type table struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *table) columnIndex(name string) int {
|
func (t *table) columnIndex(name string) int {
|
||||||
for n, nname := range t.colname {
|
return slices.Index(t.colname, name)
|
||||||
if name == nname {
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type row struct {
|
type row struct {
|
||||||
|
@ -217,15 +212,6 @@ func init() {
|
||||||
Register("test", fdriver)
|
Register("test", fdriver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func contains(list []string, y string) bool {
|
|
||||||
for _, x := range list {
|
|
||||||
if x == y {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
type Dummy struct {
|
type Dummy struct {
|
||||||
driver.Driver
|
driver.Driver
|
||||||
}
|
}
|
||||||
|
@ -235,7 +221,7 @@ func TestDrivers(t *testing.T) {
|
||||||
Register("test", fdriver)
|
Register("test", fdriver)
|
||||||
Register("invalid", Dummy{})
|
Register("invalid", Dummy{})
|
||||||
all := Drivers()
|
all := Drivers()
|
||||||
if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
|
if len(all) < 2 || !slices.IsSorted(all) || !slices.Contains(all, "test") || !slices.Contains(all, "invalid") {
|
||||||
t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
|
t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -345,10 +331,8 @@ func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for n, cname := range t.colname {
|
if i := slices.Index(t.colname, column); i != -1 {
|
||||||
if cname == column {
|
return t.coltype[i], true
|
||||||
return t.coltype[n], true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
@ -823,6 +807,15 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
|
||||||
return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
|
return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func valueFromPlaceholderName(args []driver.NamedValue, name string) driver.Value {
|
||||||
|
for i := range args {
|
||||||
|
if args[i].Name == name {
|
||||||
|
return args[i].Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// When doInsert is true, add the row to the table.
|
// When doInsert is true, add the row to the table.
|
||||||
// When doInsert is false do prep-work and error checking, but don't
|
// When doInsert is false do prep-work and error checking, but don't
|
||||||
// actually add the row to the table.
|
// actually add the row to the table.
|
||||||
|
@ -857,11 +850,8 @@ func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.R
|
||||||
val = args[argPos].Value
|
val = args[argPos].Value
|
||||||
} else {
|
} else {
|
||||||
// Assign value from argument placeholder name.
|
// Assign value from argument placeholder name.
|
||||||
for _, a := range args {
|
if v := valueFromPlaceholderName(args, strvalue[1:]); v != nil {
|
||||||
if a.Name == strvalue[1:] {
|
val = v
|
||||||
val = a.Value
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
argPos++
|
argPos++
|
||||||
|
@ -997,12 +987,8 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
|
||||||
if wcol.Placeholder == "?" {
|
if wcol.Placeholder == "?" {
|
||||||
argValue = args[wcol.Ordinal-1].Value
|
argValue = args[wcol.Ordinal-1].Value
|
||||||
} else {
|
} else {
|
||||||
// Assign arg value from placeholder name.
|
if v := valueFromPlaceholderName(args, wcol.Placeholder[1:]); v != nil {
|
||||||
for _, a := range args {
|
argValue = v
|
||||||
if a.Name == wcol.Placeholder[1:] {
|
|
||||||
argValue = a.Value
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
|
if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
|
||||||
|
|
|
@ -24,7 +24,7 @@ import (
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sort"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -69,7 +69,7 @@ func Drivers() []string {
|
||||||
for name := range drivers {
|
for name := range drivers {
|
||||||
list = append(list, name)
|
list = append(list, name)
|
||||||
}
|
}
|
||||||
sort.Strings(list)
|
slices.Sort(list)
|
||||||
return list
|
return list
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3452,10 +3452,8 @@ func (r *Row) Scan(dest ...any) error {
|
||||||
// they were obtained from the network anyway) But for now we
|
// they were obtained from the network anyway) But for now we
|
||||||
// don't care.
|
// don't care.
|
||||||
defer r.rows.Close()
|
defer r.rows.Close()
|
||||||
for _, dp := range dest {
|
if scanArgsContainRawBytes(dest) {
|
||||||
if _, ok := dp.(*RawBytes); ok {
|
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
|
||||||
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !r.rows.Next() {
|
if !r.rows.Next() {
|
||||||
|
|
|
@ -40,14 +40,7 @@ func init() {
|
||||||
freedFrom[c] = s
|
freedFrom[c] = s
|
||||||
}
|
}
|
||||||
putConnHook = func(db *DB, c *driverConn) {
|
putConnHook = func(db *DB, c *driverConn) {
|
||||||
idx := -1
|
if slices.Contains(db.freeConn, c) {
|
||||||
for i, v := range db.freeConn {
|
|
||||||
if v == c {
|
|
||||||
idx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if idx >= 0 {
|
|
||||||
// print before panic, as panic may get lost due to conflicting panic
|
// print before panic, as panic may get lost due to conflicting panic
|
||||||
// (all goroutines asleep) elsewhere, since we might not unlock
|
// (all goroutines asleep) elsewhere, since we might not unlock
|
||||||
// the mutex in freeConn here.
|
// the mutex in freeConn here.
|
||||||
|
@ -291,7 +284,7 @@ func TestQuery(t *testing.T) {
|
||||||
{age: 2, name: "Bob"},
|
{age: 2, name: "Bob"},
|
||||||
{age: 3, name: "Chris"},
|
{age: 3, name: "Chris"},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, want) {
|
if !slices.Equal(got, want) {
|
||||||
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
|
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -355,7 +348,7 @@ func TestQueryContext(t *testing.T) {
|
||||||
{age: 1, name: "Alice"},
|
{age: 1, name: "Alice"},
|
||||||
{age: 2, name: "Bob"},
|
{age: 2, name: "Bob"},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, want) {
|
if !slices.Equal(got, want) {
|
||||||
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
|
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -540,7 +533,7 @@ func TestMultiResultSetQuery(t *testing.T) {
|
||||||
{age: 2, name: "Bob"},
|
{age: 2, name: "Bob"},
|
||||||
{age: 3, name: "Chris"},
|
{age: 3, name: "Chris"},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got1, want1) {
|
if !slices.Equal(got1, want1) {
|
||||||
t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
|
t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -566,7 +559,7 @@ func TestMultiResultSetQuery(t *testing.T) {
|
||||||
{name: "Bob"},
|
{name: "Bob"},
|
||||||
{name: "Chris"},
|
{name: "Chris"},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got2, want2) {
|
if !slices.Equal(got2, want2) {
|
||||||
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
|
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
|
||||||
}
|
}
|
||||||
if rows.NextResultSet() {
|
if rows.NextResultSet() {
|
||||||
|
@ -614,7 +607,7 @@ func TestQueryNamedArg(t *testing.T) {
|
||||||
want := []row{
|
want := []row{
|
||||||
{age: 2, name: "Bob"},
|
{age: 2, name: "Bob"},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, want) {
|
if !slices.Equal(got, want) {
|
||||||
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
|
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -724,7 +717,7 @@ func TestRowsColumns(t *testing.T) {
|
||||||
t.Fatalf("Columns: %v", err)
|
t.Fatalf("Columns: %v", err)
|
||||||
}
|
}
|
||||||
want := []string{"age", "name"}
|
want := []string{"age", "name"}
|
||||||
if !reflect.DeepEqual(cols, want) {
|
if !slices.Equal(cols, want) {
|
||||||
t.Errorf("got %#v; want %#v", cols, want)
|
t.Errorf("got %#v; want %#v", cols, want)
|
||||||
}
|
}
|
||||||
if err := rows.Close(); err != nil {
|
if err := rows.Close(); err != nil {
|
||||||
|
@ -827,7 +820,7 @@ func TestQueryRow(t *testing.T) {
|
||||||
t.Fatalf("photo QueryRow+Scan: %v", err)
|
t.Fatalf("photo QueryRow+Scan: %v", err)
|
||||||
}
|
}
|
||||||
want := []byte("APHOTO")
|
want := []byte("APHOTO")
|
||||||
if !reflect.DeepEqual(photo, want) {
|
if !slices.Equal(photo, want) {
|
||||||
t.Errorf("photo = %q; want %q", photo, want)
|
t.Errorf("photo = %q; want %q", photo, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue