dev-pod-api-build/vendor/github.com/fergusstrange/embedded-postgres/remote_fetch.go
2026-04-16 04:16:36 +00:00

169 lines
4.6 KiB
Go

package embeddedpostgres
import (
"archive/zip"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
)
// RemoteFetchStrategy fetches a Postgres binary so that it is available for
// use. It returns a non-nil error when the fetch fails.
type RemoteFetchStrategy func() error
// defaultRemoteFetchStrategy returns a RemoteFetchStrategy that downloads the
// embedded-postgres-binaries JAR for the operating system, architecture and
// version reported by versionStrategy from remoteFetchHost (using a
// Maven-repository-style path). It verifies the JAR against the published
// "<jar URL>.sha256" file when the server provides one. It then hands the JAR
// to decompressResponse, which writes the contained .txz archive to the
// location returned by cacheLocator.
//
//nolint:funlen
func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy {
	return func() error {
		operatingSystem, architecture, version := versionStrategy()

		jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",
			remoteFetchHost,
			operatingSystem,
			architecture,
			version,
			operatingSystem,
			architecture,
			version)

		// NOTE(review): http.Get uses http.DefaultClient, which has no timeout.
		jarDownloadResponse, err := http.Get(jarDownloadURL)
		if err != nil {
			return fmt.Errorf("unable to connect to %s", remoteFetchHost)
		}
		defer closeBody(jarDownloadResponse)()

		if jarDownloadResponse.StatusCode != http.StatusOK {
			return fmt.Errorf("no version found matching %s", version)
		}

		jarBodyBytes, err := io.ReadAll(jarDownloadResponse.Body)
		if err != nil {
			return errorFetchingPostgres(err)
		}

		shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL)
		shaDownloadResponse, err := http.Get(shaDownloadURL)
		if err != nil {
			return fmt.Errorf("download sha256 from %s failed: %w", shaDownloadURL, err)
		}
		defer closeBody(shaDownloadResponse)()

		// Verification is best-effort with respect to availability: if the
		// server does not serve the .sha256 file (any non-200 status), the
		// JAR is used unverified. Once a checksum file IS served, however,
		// failing to read it is an error rather than a silent skip.
		if shaDownloadResponse.StatusCode == http.StatusOK {
			shaBodyBytes, err := io.ReadAll(shaDownloadResponse.Body)
			if err != nil {
				return fmt.Errorf("read sha256 from %s failed: %w", shaDownloadURL, err)
			}

			// Accept a bare digest as well as the "<digest>  <filename>"
			// layout, ignoring surrounding whitespace (e.g. a trailing
			// newline) and hex letter case.
			var expectedChecksum string
			if fields := strings.Fields(string(shaBodyBytes)); len(fields) > 0 {
				expectedChecksum = fields[0]
			}

			jarChecksum := sha256.Sum256(jarBodyBytes)
			if !strings.EqualFold(expectedChecksum, hex.EncodeToString(jarChecksum[:])) {
				return errors.New("downloaded checksums do not match")
			}
		}

		// The body has been read in full, so its length is the authoritative
		// archive size; the ContentLength header is -1 for chunked responses.
		return decompressResponse(jarBodyBytes, int64(len(jarBodyBytes)), cacheLocator, jarDownloadURL)
	}
}
// closeBody returns a function, intended to be deferred, that closes
// resp.Body. A nil response or a nil body is ignored.
//
// A Close error is logged rather than passed to log.Fatal: library code must
// not terminate the host process over a failure to close a response body.
func closeBody(resp *http.Response) func() {
	return func() {
		if resp == nil || resp.Body == nil {
			return
		}
		if err := resp.Body.Close(); err != nil {
			log.Printf("embedded-postgres: closing response body: %v", err)
		}
	}
}
// decompressResponse opens bodyBytes as a zip archive (the downloaded JAR). It
// writes the first non-directory entry whose name ends in ".txz" to the path
// returned by cacheLocator, creating that path's parent directory if needed.
//
// contentLength is accepted for backward compatibility only. The zip reader
// is sized from len(bodyBytes), because the slice holds the complete archive.
// A Content-Length value can be -1 (chunked encoding) or otherwise disagree
// with the bytes actually held, and a wrong size makes zip.NewReader fail or
// read a truncated view.
//
// It returns an error if the archive cannot be opened, the cache directory
// cannot be created, extracting the entry fails, or no ".txz" entry exists.
func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error {
	zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), int64(len(bodyBytes)))
	if err != nil {
		return errorFetchingPostgres(err)
	}

	// Only the location is needed here; the second result is not used.
	cacheLocation, _ := cacheLocator()
	if err := os.MkdirAll(filepath.Dir(cacheLocation), 0755); err != nil {
		return errorExtractingPostgres(err)
	}

	for _, file := range zipReader.File {
		if file.FileInfo().IsDir() || !strings.HasSuffix(file.Name, ".txz") {
			continue
		}
		// The first matching entry is the binary archive; stop looking.
		return decompressSingleFile(file, cacheLocation)
	}

	return fmt.Errorf("error fetching postgres: cannot find binary in archive retrieved from %s", downloadURL)
}
// decompressSingleFile copies the contents of the zip entry file to
// cacheLocation.
//
// Several processes may try to populate the same cache location at once. To
// prevent file corruption, the data is first written to a temporary file in
// cacheLocation's directory and then moved into place with renameOrIgnore. If
// anything fails before the move, the temporary file is closed and removed.
// A failed removal is logged rather than panicking.
func decompressSingleFile(file *zip.File, cacheLocation string) error {
	archiveReader, err := file.Open()
	if err != nil {
		return errorExtractingPostgres(err)
	}
	// The entry is only read; a close error carries nothing actionable.
	defer func() { _ = archiveReader.Close() }()

	tmp, err := os.CreateTemp(filepath.Dir(cacheLocation), "temp_")
	if err != nil {
		return errorExtractingPostgres(err)
	}

	renamed := false
	defer func() {
		// After a successful rename there is no temporary file to remove.
		// NOTE(review): if renameOrIgnore can return nil without moving tmp
		// (e.g. when cacheLocation already exists), the temp file is left
		// behind here — confirm against renameOrIgnore.
		if renamed {
			return
		}
		// Close first: on Windows an open file generally cannot be removed.
		// The file may already be closed, so the close error is irrelevant.
		_ = tmp.Close()
		if err := os.Remove(tmp.Name()); err != nil {
			log.Printf("embedded-postgres: removing temporary file %s: %v", tmp.Name(), err)
		}
	}()

	// Stream the entry instead of buffering it in memory. Reading to EOF also
	// surfaces the zip reader's integrity errors.
	if _, err := io.Copy(tmp, archiveReader); err != nil {
		return errorExtractingPostgres(err)
	}

	// Windows cannot rename a file while it is still open, so close it
	// explicitly before the rename.
	if err := tmp.Close(); err != nil {
		return errorExtractingPostgres(err)
	}

	if err := renameOrIgnore(tmp.Name(), cacheLocation); err != nil {
		return errorExtractingPostgres(err)
	}
	renamed = true

	return nil
}
// errorExtractingPostgres adds the context "unable to extract postgres
// archive" to err. The cause is wrapped with %w so callers can still match it
// with errors.Is / errors.As; the message text is the same as with %s.
func errorExtractingPostgres(err error) error {
	return fmt.Errorf("unable to extract postgres archive: %w", err)
}
// errorFetchingPostgres adds the context "error fetching postgres" to err.
// The cause is wrapped with %w so callers can still match it with errors.Is /
// errors.As; the message text is the same as with %s.
func errorFetchingPostgres(err error) error {
	return fmt.Errorf("error fetching postgres: %w", err)
}