From 1beeba379ec0364d6a5a4cfd59a992c4bc40ce16 Mon Sep 17 00:00:00 2001 From: Jaglyser Date: Thu, 6 Nov 2025 10:39:48 +0100 Subject: [PATCH] Added dynamic port allocation, GetPort helper, and test --- embedded_postgres.go | 28 +++++++++++++++++++++ embedded_postgres_test.go | 52 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/embedded_postgres.go b/embedded_postgres.go index afe8497..cff0132 100644 --- a/embedded_postgres.go +++ b/embedded_postgres.go @@ -71,6 +71,15 @@ func (ep *EmbeddedPostgres) Start() error { return ErrServerAlreadyStarted } + if ep.config.port == 0 { + port, err := getFreePort() + if err != nil { + return err + } + + ep.config.port = port + } + if err := ensurePortAvailable(ep.config.port); err != nil { return err } @@ -147,6 +156,10 @@ func (ep *EmbeddedPostgres) Start() error { return nil } +func (ep *EmbeddedPostgres) GetPort() uint32 { + return ep.config.port +} + func (ep *EmbeddedPostgres) downloadAndExtractBinary(cacheExists bool, cacheLocation string) error { // lock to prevent collisions with duplicate downloads mu.Lock() @@ -254,6 +267,21 @@ func ensurePortAvailable(port uint32) error { return nil } +func getFreePort() (uint32, error) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return 0, fmt.Errorf("failed to resolve TCP address: %w", err) + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return 0, fmt.Errorf("failed to listen on TCP: %w", err) + } + defer l.Close() + + return uint32(l.Addr().(*net.TCPAddr).Port), nil +} + func dataDirIsValid(dataDir string, version PostgresVersion) bool { pgVersion := filepath.Join(dataDir, "PG_VERSION") diff --git a/embedded_postgres_test.go b/embedded_postgres_test.go index 6429a89..bead4d3 100644 --- a/embedded_postgres_test.go +++ b/embedded_postgres_test.go @@ -833,3 +833,55 @@ func Test_RunningInParallel(t *testing.T) { waitGroup.Wait() } + +func Test_DynamicallyAllocatingPort(t *testing.T) { + tempDir, err := os.MkdirTemp("", "embedded_postgres_test") + if err != nil { + panic(err) + } + + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + panic(err) + } + }() + + database := NewDatabase(DefaultConfig(). + Username("gin"). + Password("wine"). + Database("beer"). + Version(V15). + RuntimePath(tempDir). + Port(0). + StartTimeout(10 * time.Second). + Locale("C"). + Encoding("UTF8"). + Logger(nil)) + + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + port := database.GetPort() + + db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d user=gin password=wine dbname=beer sslmode=disable", port)) + if err != nil { + shutdownDBAndFail(t, err, database) + } + + if !strings.Contains(database.config.GetConnectionURL(), fmt.Sprint(port)) { + shutdownDBAndFail(t, errors.New("wrong port in connection url"), database) + } + + if err = db.Ping(); err != nil { + shutdownDBAndFail(t, err, database) + } + + if err := db.Close(); err != nil { + shutdownDBAndFail(t, err, database) + } + + if err := database.Stop(); err != nil { + shutdownDBAndFail(t, err, database) + } +}