169 lines
4.6 KiB
Go
169 lines
4.6 KiB
Go
package embeddedpostgres
|
|
|
|
import (
|
|
"archive/zip"
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
)
|
|
|
|
// RemoteFetchStrategy is a function that makes a Postgres binary available for
// use by fetching it from a remote source. A non-nil error reports that the
// fetch did not succeed.
type RemoteFetchStrategy func() error
|
|
|
|
//nolint:funlen
|
|
func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy {
|
|
return func() error {
|
|
operatingSystem, architecture, version := versionStrategy()
|
|
|
|
jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",
|
|
remoteFetchHost,
|
|
operatingSystem,
|
|
architecture,
|
|
version,
|
|
operatingSystem,
|
|
architecture,
|
|
version)
|
|
|
|
jarDownloadResponse, err := http.Get(jarDownloadURL)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to connect to %s", remoteFetchHost)
|
|
}
|
|
|
|
defer closeBody(jarDownloadResponse)()
|
|
|
|
if jarDownloadResponse.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("no version found matching %s", version)
|
|
}
|
|
|
|
jarBodyBytes, err := io.ReadAll(jarDownloadResponse.Body)
|
|
if err != nil {
|
|
return errorFetchingPostgres(err)
|
|
}
|
|
|
|
shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL)
|
|
shaDownloadResponse, err := http.Get(shaDownloadURL)
|
|
if err != nil {
|
|
return fmt.Errorf("download sha256 from %s failed: %w", shaDownloadURL, err)
|
|
}
|
|
defer closeBody(shaDownloadResponse)()
|
|
|
|
if err == nil && shaDownloadResponse.StatusCode == http.StatusOK {
|
|
if shaBodyBytes, err := io.ReadAll(shaDownloadResponse.Body); err == nil {
|
|
jarChecksum := sha256.Sum256(jarBodyBytes)
|
|
if !bytes.Equal(shaBodyBytes, []byte(hex.EncodeToString(jarChecksum[:]))) {
|
|
return errors.New("downloaded checksums do not match")
|
|
}
|
|
}
|
|
}
|
|
|
|
return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL)
|
|
}
|
|
}
|
|
|
|
// closeBody returns a function that closes resp.Body. It is meant to be used
// as `defer closeBody(resp)()`. The returned function does nothing when resp
// or resp.Body is nil.
//
// A failure to close the body is logged instead of passed to log.Fatal.
// log.Fatal would terminate the entire host process from library code.
func closeBody(resp *http.Response) func() {
	return func() {
		if resp == nil || resp.Body == nil {
			return
		}

		if err := resp.Body.Close(); err != nil {
			log.Printf("embedded-postgres: unable to close response body: %s", err)
		}
	}
}
|
|
|
|
func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error {
|
|
size := contentLength
|
|
// if the content length is not set (i.e. chunked encoding),
|
|
// we need to use the length of the bodyBytes otherwise
|
|
// the unzip operation will fail
|
|
if contentLength < 0 {
|
|
size = int64(len(bodyBytes))
|
|
}
|
|
zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), size)
|
|
if err != nil {
|
|
return errorFetchingPostgres(err)
|
|
}
|
|
|
|
cacheLocation, _ := cacheLocator()
|
|
|
|
if err := os.MkdirAll(filepath.Dir(cacheLocation), 0755); err != nil {
|
|
return errorExtractingPostgres(err)
|
|
}
|
|
|
|
for _, file := range zipReader.File {
|
|
if !file.FileHeader.FileInfo().IsDir() && strings.HasSuffix(file.FileHeader.Name, ".txz") {
|
|
if err := decompressSingleFile(file, cacheLocation); err != nil {
|
|
return err
|
|
}
|
|
|
|
// we have successfully found the file, return early
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("error fetching postgres: cannot find binary in archive retrieved from %s", downloadURL)
|
|
}
|
|
|
|
func decompressSingleFile(file *zip.File, cacheLocation string) error {
|
|
renamed := false
|
|
|
|
archiveReader, err := file.Open()
|
|
if err != nil {
|
|
return errorExtractingPostgres(err)
|
|
}
|
|
|
|
archiveBytes, err := io.ReadAll(archiveReader)
|
|
if err != nil {
|
|
return errorExtractingPostgres(err)
|
|
}
|
|
|
|
// if multiple processes attempt to extract
|
|
// to prevent file corruption when multiple processes attempt to extract at the same time
|
|
// first to a cache location, and then move the file into place.
|
|
tmp, err := os.CreateTemp(filepath.Dir(cacheLocation), "temp_")
|
|
if err != nil {
|
|
return errorExtractingPostgres(err)
|
|
}
|
|
defer func() {
|
|
// if anything failed before the rename then the temporary file should be cleaned up.
|
|
// if the rename was successful then there is no temporary file to remove.
|
|
if !renamed {
|
|
if err := os.Remove(tmp.Name()); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
}()
|
|
|
|
if _, err := tmp.Write(archiveBytes); err != nil {
|
|
return errorExtractingPostgres(err)
|
|
}
|
|
|
|
// Windows cannot rename a file if is it still open.
|
|
// The file needs to be manually closed to allow the rename to happen
|
|
if err := tmp.Close(); err != nil {
|
|
return errorExtractingPostgres(err)
|
|
}
|
|
|
|
if err := renameOrIgnore(tmp.Name(), cacheLocation); err != nil {
|
|
return errorExtractingPostgres(err)
|
|
}
|
|
renamed = true
|
|
|
|
return nil
|
|
}
|
|
|
|
// errorExtractingPostgres wraps err in an error reporting that the Postgres
// archive could not be extracted. The cause is wrapped with %w, so callers can
// still inspect it with errors.Is / errors.As. The message text is the same as
// when it was formatted with %s.
func errorExtractingPostgres(err error) error {
	return fmt.Errorf("unable to extract postgres archive: %w", err)
}
|
|
|
|
// errorFetchingPostgres wraps err in an error reporting that the Postgres
// binaries could not be fetched. The cause is wrapped with %w, so callers can
// still inspect it with errors.Is / errors.As. The message text is the same as
// when it was formatted with %s.
func errorFetchingPostgres(err error) error {
	return fmt.Errorf("error fetching postgres: %w", err)
}
|