diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go index 630a585ab2..4983181fe7 100644 --- a/src/database/sql/convert.go +++ b/src/database/sql/convert.go @@ -12,6 +12,7 @@ import ( "fmt" "reflect" "strconv" + "sync" "time" "unicode" "unicode/utf8" @@ -37,86 +38,180 @@ func validateNamedValueName(name string) error { return fmt.Errorf("name %q does not begin with a letter", name) } +func driverNumInput(ds *driverStmt) int { + ds.Lock() + defer ds.Unlock() // in case NumInput panics + return ds.si.NumInput() +} + +// ccChecker wraps the driver.ColumnConverter and allows it to be used +// as if it were a NamedValueChecker. If the driver ColumnConverter +// is not present then the NamedValueChecker will return driver.ErrSkip. +type ccChecker struct { + sync.Locker + cci driver.ColumnConverter + want int +} + +func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error { + if c.cci == nil { + return driver.ErrSkip + } + // The column converter shouldn't be called on any index + // it isn't expecting. The final error will be thrown + // in the argument converter loop. + index := nv.Ordinal - 1 + if c.want <= index { + return nil + } + + // First, see if the value itself knows how to convert + // itself to a driver type. For example, a NullString + // struct changing into a string or nil. + if vr, ok := nv.Value.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return err + } + if !driver.IsValue(sv) { + return fmt.Errorf("non-subset type %T returned from Value", sv) + } + nv.Value = sv + } + + // Second, ask the column to sanity check itself. For + // example, drivers might use this to make sure that + // an int64 values being inserted into a 16-bit + // integer field is in range (before getting + // truncated), or that a nil can't go into a NOT NULL + // column before going across the network to get the + // same error. + var err error + arg := nv.Value + c.Lock() + nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg) + c.Unlock() + if err != nil { + return err + } + if !driver.IsValue(nv.Value) { + return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value) + } + return nil +} + +// defaultCheckNamedValue wraps the default ColumnConverter to have the same +// function signature as the CheckNamedValue in the driver.NamedValueChecker +// interface. +func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) + return err +} + // driverArgs converts arguments from callers of Stmt.Exec and // Stmt.Query into driver Values. // // The statement ds may be nil, if no statement is available. -func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { +func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { nvargs := make([]driver.NamedValue, len(args)) + + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + want := -1 + var si driver.Stmt + var cc ccChecker if ds != nil { si = ds.si + want = driverNumInput(ds) + cc.Locker = ds.Locker + cc.want = want } - cc, ok := si.(driver.ColumnConverter) - // Normal path, for a driver.Stmt that is not a ColumnConverter. + // Check all types of interfaces from the start. + // Drivers may opt to use the NamedValueChecker for special + // argument types, then return driver.ErrSkip to pass it along + // to the column converter. + nvc, ok := si.(driver.NamedValueChecker) if !ok { - for n, arg := range args { - var err error - nv := &nvargs[n] - nv.Ordinal = n + 1 - if np, ok := arg.(NamedArg); ok { - if err := validateNamedValueName(np.Name); err != nil { - return nil, err - } - arg = np.Value - nvargs[n].Name = np.Name - } - nv.Value, err = driver.DefaultParameterConverter.ConvertValue(arg) - - if err != nil { - return nil, fmt.Errorf("sql: converting Exec argument %s type: %v", describeNamedValue(nv), err) - } - } - return nvargs, nil + nvc, ok = ci.(driver.NamedValueChecker) + } + cci, ok := si.(driver.ColumnConverter) + if ok { + cc.cci = cci } - // Let the Stmt convert its own arguments. - for n, arg := range args { + // Loop through all the arguments, checking each one. + // If no error is returned simply increment the index + // and continue. However if driver.ErrRemoveArgument + // is returned the argument is not included in the query + // argument list. + var err error + var n int + for _, arg := range args { nv := &nvargs[n] - nv.Ordinal = n + 1 if np, ok := arg.(NamedArg); ok { - if err := validateNamedValueName(np.Name); err != nil { + if err = validateNamedValueName(np.Name); err != nil { return nil, err } arg = np.Value nv.Name = np.Name } - // First, see if the value itself knows how to convert - // itself to a driver type. For example, a NullString - // struct changing into a string or nil. - if vr, ok := arg.(driver.Valuer); ok { - sv, err := callValuerValue(vr) - if err != nil { - return nil, fmt.Errorf("sql: argument %s from Value: %v", describeNamedValue(nv), err) - } - if !driver.IsValue(sv) { - return nil, fmt.Errorf("sql: argument %s: non-subset type %T returned from Value", describeNamedValue(nv), sv) - } - arg = sv + nv.Ordinal = n + 1 + nv.Value = arg + + // Checking sequence has four routes: + // A: 1. Default + // B: 1. NamedValueChecker 2. Column Converter 3. Default + // C: 1. NamedValueChecker 3. Default + // D: 1. Column Converter 2. Default + // + // The only time a Column Converter is called is first + // or after NamedValueConverter. If first it is handled before + // the nextCheck label. Thus for repeats tries only when the + // NamedValueConverter is selected should the Column Converter + // be used in the retry. + checker := defaultCheckNamedValue + nextCC := false + switch { + case nvc != nil: + nextCC = cci != nil + checker = nvc.CheckNamedValue + case cci != nil: + checker = cc.CheckNamedValue } - // Second, ask the column to sanity check itself. For - // example, drivers might use this to make sure that - // an int64 values being inserted into a 16-bit - // integer field is in range (before getting - // truncated), or that a nil can't go into a NOT NULL - // column before going across the network to get the - // same error. - var err error - ds.Lock() - nv.Value, err = cc.ColumnConverter(n).ConvertValue(arg) - ds.Unlock() - if err != nil { + nextCheck: + err = checker(nv) + switch err { + case nil: + n++ + continue + case driver.ErrRemoveArgument: + nvargs = nvargs[:len(nvargs)-1] + continue + case driver.ErrSkip: + if nextCC { + nextCC = false + checker = cc.CheckNamedValue + } else { + checker = defaultCheckNamedValue + } + goto nextCheck + default: return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err) } - if !driver.IsValue(nv.Value) { - return nil, fmt.Errorf("sql: for argument %s, driver ColumnConverter error converted %T to unsupported type %T", - describeNamedValue(nv), arg, nv.Value) - } + } + + // Check the length of arguments after convertion to allow for omitted + // arguments. + if want != -1 && len(nvargs) != want { + return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs)) } return nvargs, nil + } // convertAssign copies to dest the value in src, converting it if possible. diff --git a/src/database/sql/convert_test.go b/src/database/sql/convert_test.go index 853a12ce95..cfe52d7f54 100644 --- a/src/database/sql/convert_test.go +++ b/src/database/sql/convert_test.go @@ -10,6 +10,7 @@ import ( "reflect" "runtime" "strings" + "sync" "testing" "time" ) @@ -468,8 +469,8 @@ func TestDriverArgs(t *testing.T) { }, } for i, tt := range tests { - ds := new(driverStmt) - got, err := driverArgs(ds, tt.args) + ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}} + got, err := driverArgs(nil, ds, tt.args) if err != nil { t.Errorf("test[%d]: %v", i, err) continue diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index d66196fd48..0262ca24ba 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -262,9 +262,39 @@ type StmtQueryContext interface { QueryContext(ctx context.Context, args []NamedValue) (Rows, error) } +// ErrRemoveArgument may be returned from NamedValueChecker to instruct the +// sql package to not pass the argument to the driver query interface. +// Return when accepting query specific options or structures that aren't +// SQL query arguments. +var ErrRemoveArgument = errors.New("driver: remove argument from query") + +// NamedValueChecker may be optionally implemented by Conn or Stmt. It provides +// the driver more control to handle Go and database types beyond the default +// Values types allowed. +// +// The sql package checks for value checkers in the following order, +// stopping at the first found match: Stmt.NamedValueChecker, Conn.NamedValueChecker, +// Stmt.ColumnConverter, DefaultParameterConverter. +// +// If CheckNamedValue returns ErrRemoveArgument, the NamedValue will not be included in +// the final query arguments. This may be used to pass special options to +// the query itself. +// +// If ErrSkip is returned the column converter error checking +// path is used for the argument. Drivers may wish to return ErrSkip after +// they have exhausted their own special cases. +type NamedValueChecker interface { + // CheckNamedValue is called before passing arguments to the driver + // and is called in place of any ColumnConverter. CheckNamedValue must do type + // validation and conversion as appropriate for the driver. + CheckNamedValue(*NamedValue) error +} + // ColumnConverter may be optionally implemented by Stmt if the // statement is aware of its own columns' types and can convert from // any type to a driver Value. +// +// Deprecated: Drivers should implement NamedValueChecker. type ColumnConverter interface { // ColumnConverter returns a ValueConverter for the provided // column index. If the type of a specific column isn't known diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 4b15f5bec7..1c95c35a68 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -58,9 +58,10 @@ type fakeDriver struct { type fakeDB struct { name string - mu sync.Mutex - tables map[string]*table - badConn bool + mu sync.Mutex + tables map[string]*table + badConn bool + allowAny bool } type table struct { @@ -352,12 +353,14 @@ func (c *fakeConn) Close() (err error) { return nil } -func checkSubsetTypes(args []driver.NamedValue) error { +func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { for _, arg := range args { switch arg.Value.(type) { case int64, float64, bool, nil, []byte, string, time.Time: default: - return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) + if !allowAny { + return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) + } } } return nil @@ -373,7 +376,7 @@ func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver. // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't // implement this at all. - err := checkSubsetTypes(args) + err := checkSubsetTypes(c.db.allowAny, args) if err != nil { return nil, err } @@ -390,7 +393,7 @@ func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't // implement this at all. - err := checkSubsetTypes(args) + err := checkSubsetTypes(c.db.allowAny, args) if err != nil { return nil, err } @@ -642,7 +645,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d return nil, driver.ErrBadConn } - err := checkSubsetTypes(args) + err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { return nil, err } @@ -753,7 +756,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( return nil, driver.ErrBadConn } - err := checkSubsetTypes(args) + err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { return nil, err } @@ -1004,6 +1007,12 @@ func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { return fmt.Sprintf("%v", v), nil } +type anyTypeConverter struct{} + +func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { + return v, nil +} + func converterForType(typ string) driver.ValueConverter { switch typ { case "bool": @@ -1030,6 +1039,8 @@ func converterForType(typ string) driver.ValueConverter { return driver.Null{Converter: driver.DefaultParameterConverter} case "datetime": return driver.DefaultParameterConverter + case "any": + return anyTypeConverter{} } panic("invalid fakedb column type of " + typ) } @@ -1056,6 +1067,8 @@ func colTypeToReflectType(typ string) reflect.Type { return reflect.TypeOf(NullFloat64{}) case "datetime": return reflect.TypeOf(time.Time{}) + case "any": + return reflect.TypeOf(new(interface{})).Elem() } panic("invalid fakedb column type of " + typ) } diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 03f66c6cb7..35a74bbdb3 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -278,6 +278,27 @@ type Scanner interface { Scan(src interface{}) error } +// Out may be used to retrieve OUTPUT value parameters from stored procedures. +// +// Not all drivers and databases support OUTPUT value parameters. +// +// Example usage: +// +// var outArg string +// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", Out{Dest: &outArg})) +type Out struct { + _Named_Fields_Required struct{} + + // Dest is a pointer to the value that will be set to the result of the + // stored procedure's OUTPUT parameter. + Dest interface{} + + // In is whether the parameter is an INOUT parameter. If so, the input value to the stored + // procedure is the dereferenced value of Dest's pointer, which is then replaced with + // the output value. + In bool +} + // ErrNoRows is returned by Scan when QueryRow doesn't return a // row. In such a case, QueryRow returns a placeholder *Row value that // defers this error until a Scan. @@ -1206,7 +1227,7 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q }() if execer, ok := dc.ci.(driver.Execer); ok { var dargs []driver.NamedValue - dargs, err = driverArgs(nil, args) + dargs, err = driverArgs(dc.ci, nil, args) if err != nil { return nil, err } @@ -1231,7 +1252,7 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q } ds := &driverStmt{Locker: dc, si: si} defer ds.Close() - return resultFromStatement(ctx, ds, args...) + return resultFromStatement(ctx, dc.ci, ds, args...) } // QueryContext executes a query that returns rows, typically a SELECT. @@ -1270,7 +1291,7 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat // The connection gets released by the releaseConn function. func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { if queryer, ok := dc.ci.(driver.Queryer); ok { - dargs, err := driverArgs(nil, args) + dargs, err := driverArgs(dc.ci, nil, args) if err != nil { releaseConn(err) return nil, err @@ -1307,7 +1328,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro } ds := &driverStmt{Locker: dc, si: si} - rowsi, err := rowsiFromStatement(ctx, ds, args...) + rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...) if err != nil { ds.Close() releaseConn(err) @@ -2009,7 +2030,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er var res Result for i := 0; i < maxBadConnRetries; i++ { - _, releaseConn, ds, err := s.connStmt(ctx) + dc, releaseConn, ds, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -2017,7 +2038,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er return nil, err } - res, err = resultFromStatement(ctx, ds, args...) + res, err = resultFromStatement(ctx, dc.ci, ds, args...) releaseConn(err) if err != driver.ErrBadConn { return res, err @@ -2032,23 +2053,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return s.ExecContext(context.Background(), args...) } -func driverNumInput(ds *driverStmt) int { - ds.Lock() - defer ds.Unlock() // in case NumInput panics - return ds.si.NumInput() -} - -func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) { - want := driverNumInput(ds) - - // -1 means the driver doesn't know how to count the number of - // placeholders, so we won't sanity check input here and instead let the - // driver deal with errors. - if want != -1 && len(args) != want { - return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) - } - - dargs, err := driverArgs(ds, args) +func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) { + dargs, err := driverArgs(ci, ds, args) if err != nil { return nil, err } @@ -2174,7 +2180,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er return nil, err } - rowsi, err = rowsiFromStatement(ctx, ds, args...) + rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...) if err == nil { // Note: ownership of ci passes to the *Rows, to be freed // with releaseConn. @@ -2211,7 +2217,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return s.QueryContext(context.Background(), args...) } -func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) { +func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) { var want int withLock(ds, func() { want = ds.si.NumInput() @@ -2224,7 +2230,7 @@ func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{} return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(ds, args) + dargs, err := driverArgs(ci, ds, args) if err != nil { return nil, err } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 2fd81f29a5..b7fdc8eb6c 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -3191,6 +3191,131 @@ func TestConnectionLeak(t *testing.T) { wg.Wait() } +type nvcDriver struct { + fakeDriver + skipNamedValueCheck bool +} + +func (d *nvcDriver) Open(dsn string) (driver.Conn, error) { + c, err := d.fakeDriver.Open(dsn) + fc := c.(*fakeConn) + fc.db.allowAny = true + return &nvcConn{fc, d.skipNamedValueCheck}, err +} + +type nvcConn struct { + *fakeConn + skipNamedValueCheck bool +} + +type decimal struct { + value int +} + +type doNotInclude struct{} + +var _ driver.NamedValueChecker = &nvcConn{} + +func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error { + if c.skipNamedValueCheck { + return driver.ErrSkip + } + switch v := nv.Value.(type) { + default: + return driver.ErrSkip + case Out: + switch ov := v.Dest.(type) { + default: + return errors.New("unkown NameValueCheck OUTPUT type") + case *string: + *ov = "from-server" + nv.Value = "OUT:*string" + } + return nil + case decimal, []int64: + return nil + case doNotInclude: + return driver.ErrRemoveArgument + } +} + +func TestNamedValueChecker(t *testing.T) { + Register("NamedValueCheck", &nvcDriver{}) + db, err := Open("NamedValueCheck", "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = db.ExecContext(ctx, "WIPE") + if err != nil { + t.Fatal("exec wipe", err) + } + + _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any") + if err != nil { + t.Fatal("exec create", err) + } + + o1 := "" + _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimal{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{}) + if err != nil { + t.Fatal("exec insert", err) + } + var ( + str1 string + dec1 decimal + arr1 []int64 + ) + err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1) + if err != nil { + t.Fatal("select", err) + } + + list := []struct{ got, want interface{} }{ + {o1, "from-server"}, + {dec1, decimal{123}}, + {str1, "hello"}, + {arr1, []int64{42, 128, 707}}, + } + + for index, item := range list { + if !reflect.DeepEqual(item.got, item.want) { + t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index) + } + } +} + +func TestNamedValueCheckerSkip(t *testing.T) { + Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true}) + db, err := Open("NamedValueCheckSkip", "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = db.ExecContext(ctx, "WIPE") + if err != nil { + t.Fatal("exec wipe", err) + } + + _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any") + if err != nil { + t.Fatal("exec create", err) + } + + _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimal{123})) + if err == nil { + t.Fatalf("expected error with bad argument, got %v", err) + } +} + // badConn implements a bad driver.Conn, for TestBadDriver. // The Exec method panics. type badConn struct{}