build source

This commit is contained in:
build 2026-04-16 04:16:36 +00:00
commit ee1fec43ed
4171 changed files with 1351288 additions and 0 deletions

View file

@ -0,0 +1,78 @@
// Package iobufpool implements a global segregated-fit pool of buffers for IO.
//
// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid
// an allocation is purposely not documented. https://github.com/golang/go/issues/16323
package iobufpool
import (
"math/bits"
"sync"
)
const minPoolExpOf2 = 8
var pools [18]*sync.Pool
func init() {
for i := range pools {
bufLen := 1 << (minPoolExpOf2 + i)
pools[i] = &sync.Pool{
New: func() any {
buf := make([]byte, bufLen)
return &buf
},
}
}
}
// Get gets a []byte of len size with cap <= size*2.
func Get(size int) *[]byte {
i := getPoolIdx(size)
if i >= len(pools) {
buf := make([]byte, size)
return &buf
}
ptrBuf := (pools[i].Get().(*[]byte))
*ptrBuf = (*ptrBuf)[:size]
return ptrBuf
}
func getPoolIdx(size int) int {
if size < 2 {
return 0
}
idx := bits.Len(uint(size-1)) - minPoolExpOf2
if idx < 0 {
return 0
}
return idx
}
// Put returns buf to the pool.
func Put(buf *[]byte) {
i := putPoolIdx(cap(*buf))
if i < 0 {
return
}
pools[i].Put(buf)
}
func putPoolIdx(size int) int {
// Only exact power-of-2 sizes match pool buckets
if size&(size-1) != 0 {
return -1
}
// Calculate log2(size) using trailing zeros count
exp := bits.TrailingZeros(uint(size))
idx := exp - minPoolExpOf2
if idx < 0 || idx >= len(pools) {
return -1
}
return idx
}

View file

@ -0,0 +1,6 @@
# pgio
Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
pgio provides functions for appending integers to a []byte while doing byte
order conversion.

6
vendor/github.com/jackc/pgx/v5/internal/pgio/doc.go generated vendored Normal file
View file

@ -0,0 +1,6 @@
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
/*
pgio provides functions for appending integers to a []byte while doing byte
order conversion.
*/
package pgio

32
vendor/github.com/jackc/pgx/v5/internal/pgio/write.go generated vendored Normal file
View file

@ -0,0 +1,32 @@
package pgio
func AppendUint16(buf []byte, n uint16) []byte {
return append(buf, byte(n>>8), byte(n))
}
func AppendUint32(buf []byte, n uint32) []byte {
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
}
func AppendUint64(buf []byte, n uint64) []byte {
return append(buf,
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32),
byte(n>>24), byte(n>>16), byte(n>>8), byte(n),
)
}
func AppendInt16(buf []byte, n int16) []byte {
return AppendUint16(buf, uint16(n))
}
func AppendInt32(buf []byte, n int32) []byte {
return AppendUint32(buf, uint32(n))
}
func AppendInt64(buf []byte, n int64) []byte {
return AppendUint64(buf, uint64(n))
}
func SetInt32(buf []byte, n int32) {
*(*[4]byte)(buf) = [4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
}

View file

@ -0,0 +1,60 @@
#!/usr/bin/env bash
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
current_branch=$(git rev-parse HEAD)
fi
restore_branch() {
echo "Restoring original branch/commit: $current_branch"
git checkout "$current_branch"
}
trap restore_branch EXIT
# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
echo "There are uncommitted changes. Please commit or stash them before running this script."
exit 1
fi
# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
exit 1
fi
commits=("$@")
benchmarks_dir=benchmarks
if ! mkdir -p "${benchmarks_dir}"; then
echo "Unable to create dir for benchmarks data"
exit 1
fi
# Benchmark results
bench_files=()
# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}
# Sanitized commit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi
bench_files+=("$bench_file")
done
# go install golang.org/x/perf/cmd/benchstat[@latest]
benchstat "${bench_files[@]}"

View file

@ -0,0 +1,460 @@
package sanitize
import (
"bytes"
"encoding/hex"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
// Part is either a string or an int. A string is raw SQL. An int is a
// argument placeholder.
type Part any
type Query struct {
Parts []Part
}
// utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement
// character. utf8.RuneError is not an error if it is also width 3.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
const maxBufSize = 16384 // 16 Ki
var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func(b *bytes.Buffer) bool {
n := b.Len()
b.Reset()
return n < maxBufSize
},
}
var null = []byte("null")
func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args))
buf := bufPool.get()
defer bufPool.put(buf)
for _, part := range q.Parts {
switch part := part.(type) {
case string:
buf.WriteString(part)
case int:
argIdx := part - 1
var p []byte
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}
if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
p = null
case int64:
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
case float64:
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
case bool:
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
case []byte:
p = QuoteBytes(buf.AvailableBuffer(), arg)
case string:
p = QuoteString(buf.AvailableBuffer(), arg)
case time.Time:
p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
buf.Write(p)
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
}
for i, used := range argUse {
if !used {
return "", fmt.Errorf("unused argument: %d", i)
}
}
return buf.String(), nil
}
func NewQuery(sql string) (*Query, error) {
query := &Query{}
query.init(sql)
return query, nil
}
var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func(sl *sqlLexer) bool {
*sl = sqlLexer{}
return true
},
}
func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty, but fast heuristic to preallocate for ~90% usecases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
}
l := sqlLexerPool.get()
defer sqlLexerPool.put(l)
l.src = sql
l.stateFn = rawState
l.parts = parts
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
q.Parts = l.parts
}
func QuoteString(dst []byte, str string) []byte {
const quote = '\''
// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)
// Add opening quote
dst = append(dst, quote)
// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}
// Add closing quote
dst = append(dst, quote)
return dst
}
func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}
// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}
// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]
// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'
// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)
// Add suffix
dst[len(dst)-1] = '\''
return dst
}
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []Part
}
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '$':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if '0' <= nextRune && nextRune <= '9' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return placeholderState
}
case '-':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '-' {
l.pos += width
return oneLineCommentState
}
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
return multilineCommentState
}
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
// placeholderState consumes a placeholder value. The $ must have already has
// already been consumed. The first rune must be a digit.
func placeholderState(l *sqlLexer) stateFn {
num := 0
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if '0' <= r && r <= '9' {
num *= 10
num += int(r - '0')
} else {
l.parts = append(l.parts, num)
l.pos -= width
l.start = l.pos
return rawState
}
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
func oneLineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\n', '\r':
return rawState
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
func multilineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
l.nested++
}
case '*':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '/' {
continue
}
l.pos += width
if l.nested == 0 {
return rawState
}
l.nested--
case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
}
var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func(q *Query) bool {
n := len(q.Parts)
q.Parts = q.Parts[:0]
return n < 64 // drop too large queries
},
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...any) (string, error) {
query := queryPool.get()
query.init(sql)
defer queryPool.put(query)
return query.Sanitize(args...)
}
type pool[E any] struct {
p sync.Pool
new func() E
reset func(E) bool
}
func (pool *pool[E]) get() E {
v, ok := pool.p.Get().(E)
if !ok {
v = pool.new()
}
return v
}
func (p *pool[E]) put(v E) {
if p.reset(v) {
p.p.Put(v)
}
}

View file

@ -0,0 +1,187 @@
package stmtcache
import (
"github.com/jackc/pgx/v5/pgconn"
)
// lruNode is a typed doubly-linked list node with freelist support.
type lruNode struct {
sd *pgconn.StatementDescription
prev *lruNode
next *lruNode
}
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
type LRUCache struct {
m map[string]*lruNode
head *lruNode
tail *lruNode
len int
cap int
freelist *lruNode
invalidStmts []*pgconn.StatementDescription
invalidSet map[string]struct{}
}
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
func NewLRUCache(cap int) *LRUCache {
head := &lruNode{}
tail := &lruNode{}
head.next = tail
tail.prev = head
return &LRUCache{
cap: cap,
m: make(map[string]*lruNode, cap),
head: head,
tail: tail,
invalidSet: make(map[string]struct{}),
}
}
// Get returns the statement description for sql. Returns nil if not found.
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
node, ok := c.m[key]
if !ok {
return nil
}
c.moveToFront(node)
return node.sd
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
// sd.SQL has been invalidated and HandleInvalidated has not been called yet.
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
if sd.SQL == "" {
panic("cannot store statement description with empty SQL")
}
if _, present := c.m[sd.SQL]; present {
return
}
// The statement may have been invalidated but not yet handled. Do not readd it to the cache.
if _, invalidated := c.invalidSet[sd.SQL]; invalidated {
return
}
if c.len == c.cap {
c.invalidateOldest()
}
node := c.allocNode()
node.sd = sd
c.insertAfter(c.head, node)
c.m[sd.SQL] = node
c.len++
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *LRUCache) Invalidate(sql string) {
node, ok := c.m[sql]
if !ok {
return
}
delete(c.m, sql)
c.invalidStmts = append(c.invalidStmts, node.sd)
c.invalidSet[sql] = struct{}{}
c.unlink(node)
c.len--
c.freeNode(node)
}
// InvalidateAll invalidates all statement descriptions.
func (c *LRUCache) InvalidateAll() {
for node := c.head.next; node != c.tail; {
next := node.next
c.invalidStmts = append(c.invalidStmts, node.sd)
c.invalidSet[node.sd.SQL] = struct{}{}
c.freeNode(node)
node = next
}
clear(c.m)
c.head.next = c.tail
c.tail.prev = c.head
c.len = 0
}
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *LRUCache) RemoveInvalidated() {
c.invalidStmts = c.invalidStmts[:0]
clear(c.invalidSet)
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRUCache) Len() int {
return c.len
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRUCache) Cap() int {
return c.cap
}
func (c *LRUCache) invalidateOldest() {
node := c.tail.prev
if node == c.head {
return
}
c.invalidStmts = append(c.invalidStmts, node.sd)
c.invalidSet[node.sd.SQL] = struct{}{}
delete(c.m, node.sd.SQL)
c.unlink(node)
c.len--
c.freeNode(node)
}
// List operations - sentinel nodes eliminate nil checks
func (c *LRUCache) insertAfter(at, node *lruNode) {
node.prev = at
node.next = at.next
at.next.prev = node
at.next = node
}
func (c *LRUCache) unlink(node *lruNode) {
node.prev.next = node.next
node.next.prev = node.prev
}
func (c *LRUCache) moveToFront(node *lruNode) {
if node.prev == c.head {
return
}
c.unlink(node)
c.insertAfter(c.head, node)
}
// Node pool operations - reuse evicted nodes to avoid allocations
func (c *LRUCache) allocNode() *lruNode {
if c.freelist != nil {
node := c.freelist
c.freelist = node.next
node.next = nil
node.prev = nil
return node
}
return &lruNode{}
}
func (c *LRUCache) freeNode(node *lruNode) {
node.sd = nil
node.prev = nil
node.next = c.freelist
c.freelist = node
}

View file

@ -0,0 +1,45 @@
// Package stmtcache is a cache for statement descriptions.
package stmtcache
import (
"crypto/sha256"
"encoding/hex"
"github.com/jackc/pgx/v5/pgconn"
)
// StatementName returns a statement name that will be stable for sql across multiple connections and program
// executions.
func StatementName(sql string) string {
digest := sha256.Sum256([]byte(sql))
return "stmtcache_" + hex.EncodeToString(digest[0:24])
}
// Cache caches statement descriptions.
type Cache interface {
// Get returns the statement description for sql. Returns nil if not found.
Get(sql string) *pgconn.StatementDescription
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
Put(sd *pgconn.StatementDescription)
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
Invalidate(sql string)
// InvalidateAll invalidates all statement descriptions.
InvalidateAll()
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
GetInvalidated() []*pgconn.StatementDescription
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
RemoveInvalidated()
// Len returns the number of cached prepared statement descriptions.
Len() int
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
}