dev-pod-api-build/vendor/github.com/fergusstrange/embedded-postgres/prepare_database.go
2026-04-16 04:16:36 +00:00

173 lines
4.4 KiB
Go

package embeddedpostgres
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"

	"github.com/lib/pq"
)
// Error-message format strings used when closing database connections.
const (
	// fmtCloseDBConn wraps the error returned by closing a database handle.
	fmtCloseDBConn = `unable to close database connection: %w`
	// fmtAfterError reports a secondary failure (first verb) that followed a
	// primary error (second verb); the primary error is wrapped with %w.
	fmtAfterError = `%v happened after error: %w`
)
// Function types for the two database-preparation steps; defaultInitDatabase
// and defaultCreateDatabase in this file have exactly these signatures.
type (
	// initDatabase is the signature of a data-directory initialiser: it receives
	// the extracted binary location, a runtime scratch path, the data directory,
	// credentials, optional locale/encoding, and a file to receive tool output.
	initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale, encoding string, logger *os.File) error

	// createDatabase is the signature of a step that creates a named database
	// on the server listening on port, using the given credentials.
	createDatabase func(port uint32, username, password, database string) error
)
// defaultInitDatabase runs binaryExtractLocation/bin/initdb to create a new
// data directory at pgDataDir with "-A password" and "-U username".
//
// The password reaches initdb through a temporary file in runtimePath
// (--pwfile) rather than the command line. That file is removed when this
// function returns, whether or not initdb succeeded. locale and encoding are
// forwarded only when non-empty. initdb's stdout and stderr are written to
// logger. On failure, the text obtained from readLogsOrTimeout(logger) is
// appended to the returned error, which wraps the exec error.
func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) (err error) {
	passwordFile, err := createPasswordFile(runtimePath, password)
	if err != nil {
		return err
	}
	// Remove the credential file on every exit path, not only on success, so a
	// failed initdb run does not leave the password on disk. A removal failure
	// is combined with any earlier error in the same way connectionClose does.
	defer func() {
		removeErr := os.Remove(passwordFile)
		if removeErr == nil {
			return
		}
		removeErr = fmt.Errorf("unable to remove password file '%v': %w", passwordFile, removeErr)
		if err != nil {
			err = fmt.Errorf(fmtAfterError, removeErr, err)
			return
		}
		err = removeErr
	}()

	args := []string{
		"-A", "password",
		"-U", username,
		"-D", pgDataDir,
		fmt.Sprintf("--pwfile=%s", passwordFile),
	}
	if locale != "" {
		args = append(args, fmt.Sprintf("--locale=%s", locale))
	}
	if encoding != "" {
		args = append(args, fmt.Sprintf("--encoding=%s", encoding))
	}

	postgresInitDBBinary := filepath.Join(binaryExtractLocation, "bin/initdb")
	postgresInitDBProcess := exec.Command(postgresInitDBBinary, args...)
	postgresInitDBProcess.Stderr = logger
	postgresInitDBProcess.Stdout = logger

	if err = postgresInitDBProcess.Run(); err != nil {
		// Keep the Run error as the wrapped cause; a failure to read the logs
		// is only appended to the log text.
		logContent, readLogsErr := readLogsOrTimeout(logger)
		if readLogsErr != nil {
			logContent = []byte(string(logContent) + " - " + readLogsErr.Error())
		}
		return fmt.Errorf("unable to init database using '%s': %w\n%s", postgresInitDBProcess.String(), err, string(logContent))
	}
	return nil
}
// createPasswordFile writes password to a file named "pwfile" inside
// runtimePath and returns that file's path.
//
// The file is created with permission 0600 (before umask) because it holds a
// credential. A write failure is wrapped with %w so callers can still inspect
// the underlying cause (for example, a missing runtimePath directory).
func createPasswordFile(runtimePath, password string) (string, error) {
	passwordFileLocation := filepath.Join(runtimePath, "pwfile")
	if err := os.WriteFile(passwordFileLocation, []byte(password), 0600); err != nil {
		return "", fmt.Errorf("unable to write password file to %s: %w", passwordFileLocation, err)
	}
	return passwordFileLocation, nil
}
// defaultCreateDatabase creates database on the server at localhost:port by
// connecting to the "postgres" database as username/password and issuing
// CREATE DATABASE.
//
// It is a no-op when database is "postgres", the database it connects to.
// Connection and creation failures are reported through errorCustomDatabase.
// Any error from closing the connection is folded into the result by
// connectionClose.
func defaultCreateDatabase(port uint32, username, password, database string) (err error) {
	if database == "postgres" {
		return nil
	}
	conn, err := openDatabaseConnection(port, username, password, "postgres")
	if err != nil {
		return errorCustomDatabase(database, err)
	}
	db := sql.OpenDB(conn)
	defer func() {
		err = connectionClose(db, err)
	}()
	// Identifiers cannot be bound as query parameters, so quote the name with
	// pq.QuoteIdentifier. Embedded double quotes are escaped instead of
	// terminating the identifier early.
	if _, execErr := db.Exec("CREATE DATABASE " + pq.QuoteIdentifier(database)); execErr != nil {
		return errorCustomDatabase(database, execErr)
	}
	return nil
}
// connectionClose closes db and merges any close failure into err, the
// error already produced by the code that used the connection.
//
// If Close succeeds, err is returned unchanged (possibly nil). If Close fails
// and err is nil, the close error is returned wrapped with fmtCloseDBConn. If
// both failed, the result reports the close error as happening after err, and
// err stays reachable through %w.
func connectionClose(db io.Closer, err error) error {
	if closeErr := db.Close(); closeErr != nil {
		wrappedClose := fmt.Errorf(fmtCloseDBConn, closeErr)
		if err == nil {
			return wrappedClose
		}
		return fmt.Errorf(fmtAfterError, wrappedClose, err)
	}
	return err
}
// healthCheckDatabaseOrTimeout repeatedly runs healthCheckDatabase against the
// server described by config until one check succeeds or config.startTimeout
// elapses.
//
// It returns nil on the first successful check. If the timeout is reached
// first, it returns an error. Failed checks are retried after a short pause
// instead of in a tight loop.
func healthCheckDatabaseOrTimeout(config Config) error {
	const retryInterval = 100 * time.Millisecond

	ctx, cancel := context.WithTimeout(context.Background(), config.startTimeout)
	defer cancel()

	// The channel is buffered (capacity 1) so the poller can always deliver its
	// result and exit, even after this function has returned on timeout. It is
	// deliberately never closed, because closing it here could race with that
	// late send and panic.
	healthy := make(chan struct{}, 1)

	go func() {
		for ctx.Err() == nil {
			if err := healthCheckDatabase(config.port, config.database, config.username, config.password); err == nil {
				healthy <- struct{}{}
				return
			}
			// Wait before retrying, but stop promptly once the deadline passes.
			select {
			case <-ctx.Done():
				return
			case <-time.After(retryInterval):
			}
		}
	}()

	select {
	case <-healthy:
		return nil
	case <-ctx.Done():
		return errors.New("timed out waiting for database to become available")
	}
}
// healthCheckDatabase opens a connection to database on localhost:port as
// username/password and runs "SELECT 1", returning the first error encountered.
// An error from closing the connection is folded into the result by
// connectionClose.
func healthCheckDatabase(port uint32, database, username, password string) (err error) {
	conn, err := openDatabaseConnection(port, username, password, database)
	if err != nil {
		return err
	}
	db := sql.OpenDB(conn)
	defer func() {
		err = connectionClose(db, err)
	}()
	// QueryRow+Scan closes the result rows itself, so the connection is
	// released before db.Close runs. An unclosed *sql.Rows would pin it.
	var one int
	return db.QueryRow("SELECT 1").Scan(&one)
}
// openDatabaseConnection builds a pq connector for database on localhost:port,
// authenticating as username/password with sslmode=disable.
//
// Each string value is single-quoted, with backslashes and single quotes
// backslash-escaped, as the keyword/value connection-string format requires.
// Without quoting, a value containing whitespace or quotes is mis-parsed. An
// empty value is also mis-parsed: an empty password would absorb the following
// "dbname=..." pair as its value.
func openDatabaseConnection(port uint32, username string, password string, database string) (*pq.Connector, error) {
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	quote := func(value string) string {
		return "'" + escaper.Replace(value) + "'"
	}
	dsn := fmt.Sprintf("host=localhost port=%d user=%s password=%s dbname=%s sslmode=disable",
		port,
		quote(username),
		quote(password),
		quote(database))
	return pq.NewConnector(dsn)
}
// errorCustomDatabase annotates err, which was raised while connecting for or
// creating the custom-named database, with that database's name.
//
// err is wrapped with %w so callers can still match it with errors.Is/As; the
// message text is unchanged from the %s form.
func errorCustomDatabase(database string, err error) error {
	return fmt.Errorf("unable to connect to create database with custom name %s with the following error: %w", database, err)
}