dev-pod-api-build/vendor/github.com/fergusstrange/embedded-postgres/embedded_postgres.go
2026-04-16 04:16:36 +00:00

268 lines
7.2 KiB
Go

package embeddedpostgres
import (
	"errors"
	"fmt"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"sort"
	"strings"
	"sync"
)
var (
	// mu is a package-wide lock held by downloadAndExtractBinary so that concurrent
	// instances do not download or extract the Postgres binaries at the same time.
	mu sync.Mutex

	// ErrServerNotStarted is returned by Stop when the server has not been started.
	ErrServerNotStarted = errors.New("server has not been started")

	// ErrServerAlreadyStarted is returned by Start when the server is already started.
	ErrServerAlreadyStarted = errors.New("server is already started")
)
// EmbeddedPostgres holds the configuration and the pluggable runtime functions used to manage
// the lifecycle (Start/Stop) of a single Postgres process.
type EmbeddedPostgres struct {
	// config is the effective configuration; Start fills in empty runtime, data and binaries paths.
	config Config
	// cacheLocator returns the location of the cached binaries archive and whether it exists.
	cacheLocator CacheLocator
	// remoteFetchStrategy is invoked to fetch the binaries when the cache does not exist.
	remoteFetchStrategy RemoteFetchStrategy
	// initDatabase prepares a fresh data directory (defaultInitDatabase unless overridden).
	initDatabase initDatabase
	// createDatabase creates the configured database after a start on a fresh data directory.
	createDatabase createDatabase
	// started is set by Start once the server process has been launched and cleared by Stop.
	started bool
	// syncedLogger captures pg_ctl/init output; it is created by Start.
	syncedLogger *syncedLogger
}
// NewDatabase returns an EmbeddedPostgres that can be used to start and stop a Postgres process.
// Called without arguments it uses the configuration returned by DefaultConfig; otherwise the
// first Config argument is used and any further arguments are ignored.
func NewDatabase(config ...Config) *EmbeddedPostgres {
	if len(config) > 0 {
		return newDatabaseWithConfig(config[0])
	}

	return newDatabaseWithConfig(DefaultConfig())
}
// newDatabaseWithConfig builds an unstarted EmbeddedPostgres for config. It derives a version
// strategy from config and the current runtime.GOOS/runtime.GOARCH, builds the cache locator and
// remote fetch strategy from that, and installs the default init/create database functions.
func newDatabaseWithConfig(config Config) *EmbeddedPostgres {
	strategy := defaultVersionStrategy(config, runtime.GOOS, runtime.GOARCH, linuxMachineName, shouldUseAlpineLinuxBuild)
	locator := defaultCacheLocator(config.cachePath, strategy)

	return &EmbeddedPostgres{
		config:              config,
		cacheLocator:        locator,
		remoteFetchStrategy: defaultRemoteFetchStrategy(config.binaryRepositoryURL, strategy, locator),
		initDatabase:        defaultInitDatabase,
		createDatabase:      defaultCreateDatabase,
	}
}
// Start will try to start the configured Postgres process returning an error when there were any problems with invocation.
// Once the Postgres process has been launched, any later failure (log flush, database creation, health check)
// makes Start stop the process again so that no sub-process is left running. When that stop succeeds the
// server is reported as not started, so Start may be called again.
//
// Unset paths are filled in on ep.config: runtimePath defaults to an "extracted" directory next to the cache
// location, dataPath to <runtimePath>/data and binariesPath to runtimePath. The runtime directory is removed
// and recreated on every call.
//
//nolint:funlen
func (ep *EmbeddedPostgres) Start() error {
	if ep.started {
		return ErrServerAlreadyStarted
	}

	if err := ensurePortAvailable(ep.config.port); err != nil {
		return err
	}

	logger, err := newSyncedLogger("", ep.config.logger)
	if err != nil {
		return fmt.Errorf("unable to create logger: %w", err)
	}

	ep.syncedLogger = logger

	cacheLocation, cacheExists := ep.cacheLocator()

	if ep.config.runtimePath == "" {
		ep.config.runtimePath = filepath.Join(filepath.Dir(cacheLocation), "extracted")
	}

	if ep.config.dataPath == "" {
		ep.config.dataPath = filepath.Join(ep.config.runtimePath, "data")
	}

	// NOTE(review): with the default dataPath (<runtimePath>/data) this also wipes the data directory,
	// so existing data can only be reused when dataPath is configured outside runtimePath.
	if err := os.RemoveAll(ep.config.runtimePath); err != nil {
		return fmt.Errorf("unable to clean up runtime directory %s with error: %w", ep.config.runtimePath, err)
	}

	if ep.config.binariesPath == "" {
		ep.config.binariesPath = ep.config.runtimePath
	}

	if err := ep.downloadAndExtractBinary(cacheExists, cacheLocation); err != nil {
		return err
	}

	if err := os.MkdirAll(ep.config.runtimePath, os.ModePerm); err != nil {
		return fmt.Errorf("unable to create runtime directory %s with error: %w", ep.config.runtimePath, err)
	}

	// When dataDirIsValid accepts the existing data directory it is reused as-is:
	// neither initDatabase nor createDatabase runs.
	reuseData := dataDirIsValid(ep.config.dataPath, ep.config.version)
	if !reuseData {
		if err := ep.cleanDataDirectoryAndInit(); err != nil {
			return err
		}
	}

	if err := startPostgres(ep); err != nil {
		return err
	}

	// The server process is running from here on; every failure below must stop it again.
	ep.started = true

	if err := ep.syncedLogger.flush(); err != nil {
		return ep.stopAfterFailure(err)
	}

	if !reuseData {
		if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil {
			return ep.stopAfterFailure(err)
		}
	}

	if err := healthCheckDatabaseOrTimeout(ep.config); err != nil {
		return ep.stopAfterFailure(err)
	}

	return nil
}

// stopAfterFailure stops the already-launched Postgres process after cause made Start fail.
// When the stop succeeds, started is cleared and cause is returned unchanged. Otherwise the server stays
// marked as started, so Stop can be retried, and the returned error wraps cause and names the stop error.
func (ep *EmbeddedPostgres) stopAfterFailure(cause error) error {
	if stopErr := stopPostgres(ep); stopErr != nil {
		return fmt.Errorf("unable to stop database caused by error %w (stop error: %v)", cause, stopErr)
	}

	ep.started = false

	return cause
}
// downloadAndExtractBinary makes sure the Postgres binaries are present under ep.config.binariesPath.
// If <binariesPath>/bin/pg_ctl already exists (any Stat error other than "not exist" is treated as
// present), nothing is done. Otherwise the archive is fetched via remoteFetchStrategy when it is not
// cached (cacheExists == false) and then extracted from cacheLocation into binariesPath.
// The package-level mu is held throughout so concurrent instances don't download/extract simultaneously.
func (ep *EmbeddedPostgres) downloadAndExtractBinary(cacheExists bool, cacheLocation string) error {
	mu.Lock()
	defer mu.Unlock()

	pgCtlPath := filepath.Join(ep.config.binariesPath, "bin", "pg_ctl")
	if _, statErr := os.Stat(pgCtlPath); !os.IsNotExist(statErr) {
		return nil
	}

	if !cacheExists {
		if fetchErr := ep.remoteFetchStrategy(); fetchErr != nil {
			return fetchErr
		}
	}

	return decompressTarXz(defaultTarReader, cacheLocation, ep.config.binariesPath)
}
// cleanDataDirectoryAndInit deletes everything under ep.config.dataPath and then calls the configured
// initDatabase function (defaultInitDatabase unless overridden) to prepare a fresh data directory,
// passing the configured paths, credentials, locale and encoding plus the synced logger's file for output.
//
// The RemoveAll error is wrapped with %w so callers can inspect the underlying filesystem error.
func (ep *EmbeddedPostgres) cleanDataDirectoryAndInit() error {
	if err := os.RemoveAll(ep.config.dataPath); err != nil {
		return fmt.Errorf("unable to clean up data directory %s with error: %w", ep.config.dataPath, err)
	}

	if err := ep.initDatabase(ep.config.binariesPath, ep.config.runtimePath, ep.config.dataPath,
		ep.config.username, ep.config.password, ep.config.locale, ep.config.encoding,
		ep.syncedLogger.file); err != nil {
		return err
	}

	return nil
}
// Stop gracefully stops the Postgres process started by Start.
// It returns ErrServerNotStarted if the server is not marked as started. If pg_ctl fails, the
// error is returned and the server remains marked as started; otherwise the server is marked
// stopped and any error from flushing the synced logger is returned.
func (ep *EmbeddedPostgres) Stop() error {
	if !ep.started {
		return ErrServerNotStarted
	}

	stopErr := stopPostgres(ep)
	if stopErr != nil {
		return stopErr
	}

	ep.started = false

	return ep.syncedLogger.flush()
}
// encodeOptions renders the server port and extra server parameters as one string suitable for
// pg_ctl's -o flag, e.g. `-p 5432 -c max_connections="100"`.
//
// Parameters are emitted in sorted key order. Go randomizes map iteration, and sorting keeps the
// generated command line (and therefore logs and error messages) reproducible between runs.
//
// NOTE(review): values are wrapped in double quotes, but embedded double quotes are not escaped,
// so a value containing `"` produces a malformed option string.
func encodeOptions(port uint32, parameters map[string]string) string {
	keys := make([]string, 0, len(parameters))
	for k := range parameters {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	options := make([]string, 0, len(keys)+1)
	options = append(options, fmt.Sprintf("-p %d", port))

	for _, k := range keys {
		// Double-quote parameter values - they may have spaces.
		// Careful: CMD on Windows uses only double quotes to delimit strings.
		// It treats single quotes as regular characters.
		options = append(options, fmt.Sprintf("-c %s=\"%s\"", k, parameters[k]))
	}

	return strings.Join(options, " ")
}
// startPostgres runs `pg_ctl start -w -D <dataPath> -o <options>` from the configured binaries
// directory, sending pg_ctl's stdout and stderr to the synced logger's file.
//
// On failure the returned error wraps the pg_ctl run error (exit status or exec failure) and includes
// the command line plus whatever log content could be read back, to aid diagnosis.
func startPostgres(ep *EmbeddedPostgres) error {
	postgresBinary := filepath.Join(ep.config.binariesPath, "bin/pg_ctl")
	postgresProcess := exec.Command(postgresBinary, "start", "-w",
		"-D", ep.config.dataPath,
		"-o", encodeOptions(ep.config.port, ep.config.startParameters))
	postgresProcess.Stdout = ep.syncedLogger.file
	postgresProcess.Stderr = ep.syncedLogger.file

	if err := postgresProcess.Run(); err != nil {
		// Best effort: flush the logger and read the captured output back for the error message.
		// Failures of either are deliberately ignored so they cannot mask the start error.
		_ = ep.syncedLogger.flush()
		logContent, _ := readLogsOrTimeout(ep.syncedLogger.file)

		return fmt.Errorf("could not start postgres using %s (%w):\n%s", postgresProcess.String(), err, string(logContent))
	}

	return nil
}
// stopPostgres runs `pg_ctl stop -w -D <dataPath>` from the configured binaries directory,
// sending pg_ctl's stdout and stderr to the synced logger's file, and returns the run error, if any.
func stopPostgres(ep *EmbeddedPostgres) error {
	pgCtl := exec.Command(
		filepath.Join(ep.config.binariesPath, "bin/pg_ctl"),
		"stop", "-w", "-D", ep.config.dataPath,
	)

	logFile := ep.syncedLogger.file
	pgCtl.Stderr, pgCtl.Stdout = logFile, logFile

	return pgCtl.Run()
}
// ensurePortAvailable returns nil if a TCP listener can be bound on localhost:port, and an error otherwise.
// It probes by briefly binding the port and closing the listener again, so the port is NOT reserved
// afterwards; another process may still claim it before Postgres starts.
//
// The error keeps the historical "process already listening on port" message but wraps the underlying
// listen error, so callers can tell "address in use" apart from e.g. permission problems.
func ensurePortAvailable(port uint32) error {
	listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
	if err != nil {
		return fmt.Errorf("process already listening on port %d: %w", port, err)
	}

	return listener.Close()
}
// dataDirIsValid reports whether dataDir contains a PG_VERSION file whose contents match the start of
// the requested version. If so, Start reuses the data directory instead of wiping and re-initialising it.
//
// PG_VERSION presumably holds just the major version (e.g. "16" or "9.6") while version is a full release
// string such as "16.4.0" — confirm against the PostgresVersion constants. Surrounding whitespace is trimmed.
// An empty file is rejected, because an empty prefix would match every version. A match must end at a
// non-digit boundary, so "1" does not match "16.4.0".
func dataDirIsValid(dataDir string, version PostgresVersion) bool {
	contents, err := os.ReadFile(filepath.Join(dataDir, "PG_VERSION"))
	if err != nil {
		return false
	}

	major := strings.TrimSpace(string(contents))
	if major == "" {
		return false
	}

	full := string(version)
	if !strings.HasPrefix(full, major) {
		return false
	}

	// Reject partial numeric matches such as "1" against "16.4.0".
	rest := full[len(major):]

	return rest == "" || rest[0] < '0' || rest[0] > '9'
}