package squirrel import ( "bytes" "fmt" "strings" ) // PlaceholderFormat is the interface that wraps the ReplacePlaceholders method. // // ReplacePlaceholders takes a SQL statement and replaces each question mark // placeholder with a (possibly different) SQL placeholder. type PlaceholderFormat interface { ReplacePlaceholders(sql string) (string, error) } type placeholderDebugger interface { debugPlaceholder() string } var ( // Question is a PlaceholderFormat instance that leaves placeholders as // question marks. Question = questionFormat{} // Dollar is a PlaceholderFormat instance that replaces placeholders with // dollar-prefixed positional placeholders (e.g. $1, $2, $3). Dollar = dollarFormat{} // Colon is a PlaceholderFormat instance that replaces placeholders with // colon-prefixed positional placeholders (e.g. :1, :2, :3). Colon = colonFormat{} // AtP is a PlaceholderFormat instance that replaces placeholders with // "@p"-prefixed positional placeholders (e.g. @p1, @p2, @p3). AtP = atpFormat{} ) type questionFormat struct{} func (questionFormat) ReplacePlaceholders(sql string) (string, error) { return sql, nil } func (questionFormat) debugPlaceholder() string { return "?" } type dollarFormat struct{} func (dollarFormat) ReplacePlaceholders(sql string) (string, error) { return replacePositionalPlaceholders(sql, "$") } func (dollarFormat) debugPlaceholder() string { return "$" } type colonFormat struct{} func (colonFormat) ReplacePlaceholders(sql string) (string, error) { return replacePositionalPlaceholders(sql, ":") } func (colonFormat) debugPlaceholder() string { return ":" } type atpFormat struct{} func (atpFormat) ReplacePlaceholders(sql string) (string, error) { return replacePositionalPlaceholders(sql, "@p") } func (atpFormat) debugPlaceholder() string { return "@p" } // Placeholders returns a string with count ? placeholders joined with commas. 
func Placeholders(count int) string {
	if count < 1 {
		return ""
	}
	// Build ",?,?,...,?" and drop the leading comma.
	return strings.Repeat(",?", count)[1:]
}

// replacePositionalPlaceholders rewrites every "?" placeholder in sql as
// prefix followed by its 1-based position (e.g. "$1", "$2" for prefix "$").
//
// A doubled question mark ("??") is an escape: it is emitted as a single
// literal "?" and does not consume a position. Pairs are matched left to
// right, so "???" yields a literal "?" followed by one placeholder.
//
// The returned error is always nil.
func replacePositionalPlaceholders(sql, prefix string) (string, error) {
	var buf bytes.Buffer
	pos := 0
	for {
		p := strings.Index(sql, "?")
		if p == -1 {
			break
		}

		// Text before the question mark is copied through untouched.
		buf.WriteString(sql[:p])

		if strings.HasPrefix(sql[p:], "??") {
			// Escape: "??" => "?".
			buf.WriteString("?")
			sql = sql[p+2:]
			continue
		}

		pos++
		fmt.Fprintf(&buf, "%s%d", prefix, pos)
		sql = sql[p+1:]
	}

	// Whatever is left contains no question marks.
	buf.WriteString(sql)
	return buf.String(), nil
}