package sqlx import ( "bytes" "database/sql/driver" "errors" "reflect" "strconv" "strings" "sync" "github.com/jmoiron/sqlx/reflectx" ) // Bindvar types supported by Rebind, BindMap and BindStruct. const ( UNKNOWN = iota QUESTION DOLLAR NAMED AT ) var defaultBinds = map[int][]string{ DOLLAR: []string{"postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql", "nrpostgres", "cockroach"}, QUESTION: []string{"mysql", "sqlite3", "nrmysql", "nrsqlite3"}, NAMED: []string{"oci8", "ora", "goracle"}, AT: []string{"sqlserver"}, } var binds sync.Map func init() { for bind, drivers := range defaultBinds { for _, driver := range drivers { BindDriver(driver, bind) } } } // BindType returns the bindtype for a given database given a drivername. func BindType(driverName string) int { itype, ok := binds.Load(driverName) if !ok { return UNKNOWN } return itype.(int) } // BindDriver sets the BindType for driverName to bindType. func BindDriver(driverName string, bindType int) { binds.Store(driverName, bindType) } // FIXME: this should be able to be tolerant of escaped ?'s in queries without // losing much speed, and should be to avoid confusion. // Rebind a query from the default bindtype (QUESTION) to the target bindtype. func Rebind(bindType int, query string) string { switch bindType { case QUESTION, UNKNOWN: return query } // Add space enough for 10 params before we have to allocate rqb := make([]byte, 0, len(query)+10) var i, j int for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") { rqb = append(rqb, query[:i]...) switch bindType { case DOLLAR: rqb = append(rqb, '$') case NAMED: rqb = append(rqb, ':', 'a', 'r', 'g') case AT: rqb = append(rqb, '@', 'p') } j++ rqb = strconv.AppendInt(rqb, int64(j), 10) query = query[i+1:] } return string(append(rqb, query...)) } // Experimental implementation of Rebind which uses a bytes.Buffer. The code is // much simpler and should be more resistant to odd unicode, but it is twice as // slow. Kept here for benchmarking purposes and to possibly replace Rebind if // problems arise with its somewhat naive handling of unicode. func rebindBuff(bindType int, query string) string { if bindType != DOLLAR { return query } b := make([]byte, 0, len(query)) rqb := bytes.NewBuffer(b) j := 1 for _, r := range query { if r == '?' { rqb.WriteRune('$') rqb.WriteString(strconv.Itoa(j)) j++ } else { rqb.WriteRune(r) } } return rqb.String() } func asSliceForIn(i interface{}) (v reflect.Value, ok bool) { if i == nil { return reflect.Value{}, false } v = reflect.ValueOf(i) t := reflectx.Deref(v.Type()) // Only expand slices if t.Kind() != reflect.Slice { return reflect.Value{}, false } // []byte is a driver.Value type so it should not be expanded if t == reflect.TypeOf([]byte{}) { return reflect.Value{}, false } return v, true } // In expands slice values in args, returning the modified query string // and a new arg list that can be executed by a database. The `query` should // use the `?` bindVar. The return value uses the `?` bindVar. func In(query string, args ...interface{}) (string, []interface{}, error) { // argMeta stores reflect.Value and length for slices and // the value itself for non-slice arguments type argMeta struct { v reflect.Value i interface{} length int } var flatArgsCount int var anySlices bool var stackMeta [32]argMeta var meta []argMeta if len(args) <= len(stackMeta) { meta = stackMeta[:len(args)] } else { meta = make([]argMeta, len(args)) } for i, arg := range args { if a, ok := arg.(driver.Valuer); ok { var err error arg, err = a.Value() if err != nil { return "", nil, err } } if v, ok := asSliceForIn(arg); ok { meta[i].length = v.Len() meta[i].v = v anySlices = true flatArgsCount += meta[i].length if meta[i].length == 0 { return "", nil, errors.New("empty slice passed to 'in' query") } } else { meta[i].i = arg flatArgsCount++ } } // don't do any parsing if there aren't any slices; note that this means // some errors that we might have caught below will not be returned. if !anySlices { return query, args, nil } newArgs := make([]interface{}, 0, flatArgsCount) var buf strings.Builder buf.Grow(len(query) + len(", ?")*flatArgsCount) var arg, offset int for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { if arg >= len(meta) { // if an argument wasn't passed, lets return an error; this is // not actually how database/sql Exec/Query works, but since we are // creating an argument list programmatically, we want to be able // to catch these programmer errors earlier. return "", nil, errors.New("number of bindVars exceeds arguments") } argMeta := meta[arg] arg++ // not a slice, continue. // our questionmark will either be written before the next expansion // of a slice or after the loop when writing the rest of the query if argMeta.length == 0 { offset = offset + i + 1 newArgs = append(newArgs, argMeta.i) continue } // write everything up to and including our ? character buf.WriteString(query[:offset+i+1]) for si := 1; si < argMeta.length; si++ { buf.WriteString(", ?") } newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length) // slice the query and reset the offset. this avoids some bookkeeping for // the write after the loop query = query[offset+i+1:] offset = 0 } buf.WriteString(query) if arg < len(meta) { return "", nil, errors.New("number of bindVars less than number arguments") } return buf.String(), newArgs, nil } func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} { switch val := v.Interface().(type) { case []interface{}: args = append(args, val...) case []int: for i := range val { args = append(args, val[i]) } case []string: for i := range val { args = append(args, val[i]) } default: for si := 0; si < vlen; si++ { args = append(args, v.Index(si).Interface()) } } return args }