diff --git a/config.go b/config.go index 35e9a98..9e63e00 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package embeddedpostgres import ( "fmt" "io" + "net/url" "os" "time" ) @@ -11,6 +12,8 @@ import ( type Config struct { version PostgresVersion port uint32 + useUnixSocket bool + unixSocketDirectory string database string username string password string @@ -38,6 +41,8 @@ func DefaultConfig() Config { return Config{ version: V16, port: 5432, + useUnixSocket: false, + unixSocketDirectory: "/tmp/", database: "postgres", username: "postgres", password: "postgres", @@ -59,6 +64,17 @@ func (c Config) Port(port uint32) Config { return c } +// WithoutTcp makes Posgres listen on a UNIX socket instead of opening a TCP port. +func (c Config) WithoutTcp() Config { + c.useUnixSocket = true + return c +} + +func (c Config) WithUnixSocketDirectory(dir string) Config { + c.unixSocketDirectory = dir + return c +} + // Database sets the database name that will be created. func (c Config) Database(database string) Config { c.database = database @@ -145,7 +161,23 @@ func (c Config) BinaryRepositoryURL(binaryRepositoryURL string) Config { } func (c Config) GetConnectionURL() string { - return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", c.username, c.password, "localhost", c.port, c.database) + u := &url.URL{ + Scheme: "postgresql", + User: url.UserPassword(c.username, c.password), + Path: "/" + c.database, + } + + if c.useUnixSocket { + u.Host = fmt.Sprintf(":%d", c.port) + + q := url.Values{} + q.Set("host", c.unixSocketDirectory) + u.RawQuery = q.Encode() + } else { + u.Host = fmt.Sprintf("localhost:%d", c.port) + } + + return u.String() } // PostgresVersion represents the semantic version used to fetch and run the Postgres process. diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..1efb472 --- /dev/null +++ b/config_test.go @@ -0,0 +1,32 @@ +package embeddedpostgres + +import ( + "testing" +) + +func TestGetConnectionURL(t *testing.T) { + config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass") + expect := "postgresql://myuser:mypass@localhost:5432/mydb" + + if got := config.GetConnectionURL(); got != expect { + t.Errorf("expected \"%s\" got \"%s\"", expect, got) + } +} + +func TestGetConnectionURLWithUnixSocket(t *testing.T) { + config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass").WithoutTcp() + expect := "postgresql://myuser:mypass@:5432/mydb?host=%2Ftmp%2F" + + if got := config.GetConnectionURL(); got != expect { + t.Errorf("expected \"%s\" got \"%s\"", expect, got) + } +} + +func TestGetConnectionURLWithUnixSocketInCustomDir(t *testing.T) { + config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass").WithoutTcp().WithUnixSocketDirectory("/path/to/socks") + expect := "postgresql://myuser:mypass@:5432/mydb?host=%2Fpath%2Fto%2Fsocks" + + if got := config.GetConnectionURL(); got != expect { + t.Errorf("expected \"%s\" got \"%s\"", expect, got) + } +} diff --git a/embedded_postgres.go b/embedded_postgres.go index afe8497..317b765 100644 --- a/embedded_postgres.go +++ b/embedded_postgres.go @@ -127,7 +127,12 @@ func (ep *EmbeddedPostgres) Start() error { ep.started = true if !reuseData { - if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil { + host := "localhost" + if ep.config.useUnixSocket { + host = ep.config.unixSocketDirectory + } + + if err := ep.createDatabase(host, ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil { if stopErr := stopPostgres(ep); stopErr != nil { return fmt.Errorf("unable to stop database caused by error %s", err) } @@ -167,6 +172,10 @@ func (ep *EmbeddedPostgres) downloadAndExtractBinary(cacheExists bool, cacheLoca return nil } +func (ep *EmbeddedPostgres) GetConnectionURL() string { + return ep.config.GetConnectionURL() +} + func (ep *EmbeddedPostgres) cleanDataDirectoryAndInit() error { if err := os.RemoveAll(ep.config.dataPath); err != nil { return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err) @@ -210,7 +219,20 @@ func encodeOptions(port uint32, parameters map[string]string) string { } func startPostgres(ep *EmbeddedPostgres) error { + if ep.config.startParameters == nil { + ep.config.startParameters = make(map[string]string) + } + + if ep.config.useUnixSocket { + ep.config.startParameters["listen_addresses"] = "" + ep.config.startParameters["unix_socket_directories"] = ep.config.unixSocketDirectory + } + postgresBinary := filepath.Join(ep.config.binariesPath, "bin/pg_ctl") + fmt.Println(postgresBinary, "start", "-w", + "-D", ep.config.dataPath, + "-o", encodeOptions(ep.config.port, ep.config.startParameters)) + postgresProcess := exec.Command(postgresBinary, "start", "-w", "-D", ep.config.dataPath, "-o", encodeOptions(ep.config.port, ep.config.startParameters)) diff --git a/embedded_postgres_test.go b/embedded_postgres_test.go index e7e98b3..d76249b 100644 --- a/embedded_postgres_test.go +++ b/embedded_postgres_test.go @@ -156,7 +156,7 @@ func Test_ErrorWhenUnableToCreateDatabase(t *testing.T) { RuntimePath(extractPath). StartTimeout(10 * time.Second)) - database.createDatabase = func(port uint32, username, password, database string) error { + database.createDatabase = func(host string, port uint32, username, password, database string) error { return errors.New("ah noes") } @@ -176,7 +176,7 @@ func Test_TimesOutWhenCannotStart(t *testing.T) { Database("something-fancy"). StartTimeout(500 * time.Millisecond)) - database.createDatabase = func(port uint32, username, password, database string) error { + database.createDatabase = func(host string, port uint32, username, password, database string) error { return nil } @@ -802,3 +802,34 @@ func Test_RunningInParallel(t *testing.T) { waitGroup.Wait() } + +func Test_RunOnUnixSocket(t *testing.T) { + database := NewDatabase(DefaultConfig().Port(9876).WithoutTcp()) + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + defer database.Stop() + + if _, err := os.Stat("/tmp/.s.PGSQL.9876"); err != nil { + shutdownDBAndFail(t, err, database) + } +} + +func Test_RunOnUnixSocketOnCustomPath(t *testing.T) { + tempPath, err := os.MkdirTemp("", "custom_dir_socks") + if err != nil { + panic(err) + } + + database := NewDatabase(DefaultConfig().Port(9876).WithoutTcp().WithUnixSocketDirectory(tempPath)) + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + defer database.Stop() + + if _, err := os.Stat(fmt.Sprintf("%s/.s.PGSQL.9876", tempPath)); err != nil { + shutdownDBAndFail(t, err, database) + } +} diff --git a/examples/examples_test.go b/examples/examples_test.go index 9b100ac..85779c9 100644 --- a/examples/examples_test.go +++ b/examples/examples_test.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "reflect" "testing" @@ -27,7 +28,7 @@ func Test_GooseMigrations(t *testing.T) { } }() - db, err := connect() + db, err := connect(database.GetConnectionURL()) if err != nil { t.Fatal(err) } @@ -57,7 +58,7 @@ func Test_ZapioLogger(t *testing.T) { } }() - db, err := connect() + db, err := connect(database.GetConnectionURL()) if err != nil { t.Fatal(err) } @@ -79,7 +80,36 @@ func Test_Sqlx_SelectOne(t *testing.T) { } }() - db, err := connect() + db, err := connect(database.GetConnectionURL()) + if err != nil { + t.Fatal(err) + } + + rows := make([]int32, 0) + + err = db.Select(&rows, "SELECT 1") + if err != nil { + t.Fatal(err) + } + + if len(rows) != 1 { + t.Fatal("Expected one row returned") + } +} + +func Test_UnixSocket_Sqlx_SelectOne(t *testing.T) { + database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig().WithoutTcp()) + if err := database.Start(); err != nil { + t.Fatal(err) + } + + defer func() { + if err := database.Stop(); err != nil { + t.Fatal(err) + } + }() + + db, err := connect(database.GetConnectionURL()) if err != nil { t.Fatal(err) } @@ -108,7 +138,7 @@ func Test_ManyTestsAgainstOneDatabase(t *testing.T) { } }() - db, err := connect() + db, err := connect(database.GetConnectionURL()) if err != nil { t.Fatal(err) } @@ -188,7 +218,18 @@ func Test_SimpleHttpWebApp(t *testing.T) { } } -func connect() (*sqlx.DB, error) { - db, err := sqlx.Connect("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable") +func connect(u string) (*sqlx.DB, error) { + parsed, err := url.Parse(u) + if err != nil { + return nil, err + } + + q := parsed.Query() + if q.Get("sock") == "" { + q.Set("sslmode", "disable") + } + parsed.RawQuery = q.Encode() + + db, err := sqlx.Connect("postgres", parsed.String()) return db, err } diff --git a/prepare_database.go b/prepare_database.go index 751aaea..d7c9fcc 100644 --- a/prepare_database.go +++ b/prepare_database.go @@ -19,7 +19,7 @@ const ( ) type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error -type createDatabase func(port uint32, username, password, database string) error +type createDatabase func(host string, port uint32, username, password, database string) error func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error { passwordFile, err := createPasswordFile(runtimePath, password) @@ -71,12 +71,12 @@ func createPasswordFile(runtimePath, password string) (string, error) { return passwordFileLocation, nil } -func defaultCreateDatabase(port uint32, username, password, database string) (err error) { +func defaultCreateDatabase(host string, port uint32, username, password, database string) (err error) { if database == "postgres" { return nil } - conn, err := openDatabaseConnection(port, username, password, "postgres") + conn, err := openDatabaseConnection(host, port, username, password, "postgres") if err != nil { return errorCustomDatabase(database, err) } @@ -120,7 +120,12 @@ func healthCheckDatabaseOrTimeout(config Config) error { go func() { for timeout.Err() == nil { - if err := healthCheckDatabase(config.port, config.database, config.username, config.password); err != nil { + host := "localhost" + if config.useUnixSocket { + host = config.unixSocketDirectory + } + + if err := healthCheckDatabase(host, config.port, config.database, config.username, config.password); err != nil { continue } healthCheckSignal <- true @@ -137,8 +142,8 @@ func healthCheckDatabaseOrTimeout(config Config) error { } } -func healthCheckDatabase(port uint32, database, username, password string) (err error) { - conn, err := openDatabaseConnection(port, username, password, database) +func healthCheckDatabase(host string, port uint32, database, username, password string) (err error) { + conn, err := openDatabaseConnection(host, port, username, password, database) if err != nil { return err } @@ -155,8 +160,9 @@ func healthCheckDatabase(port uint32, database, username, password string) (err return nil } -func openDatabaseConnection(port uint32, username string, password string, database string) (*pq.Connector, error) { - conn, err := pq.NewConnector(fmt.Sprintf("host=localhost port=%d user=%s password=%s dbname=%s sslmode=disable", +func openDatabaseConnection(host string, port uint32, username string, password string, database string) (*pq.Connector, error) { + conn, err := pq.NewConnector(fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + host, port, username, password, diff --git a/prepare_database_test.go b/prepare_database_test.go index 2700d27..477bf2c 100644 --- a/prepare_database_test.go +++ b/prepare_database_test.go @@ -132,7 +132,13 @@ func Test_defaultInitDatabase_PwFileRemoved(t *testing.T) { } func Test_defaultCreateDatabase_ErrorWhenSQLOpenError(t *testing.T) { - err := defaultCreateDatabase(1234, "user client_encoding=lol", "password", "database") + err := defaultCreateDatabase("localhost", 1234, "user client_encoding=lol", "password", "database") + + assert.EqualError(t, err, "unable to connect to create database with custom name database with the following error: client_encoding must be absent or 'UTF8'") +} + +func Test_defaultCreateDatabase_ErrorWhenSQLOpenError_UnixSocket(t *testing.T) { + err := defaultCreateDatabase("/tmp", 1234, "user client_encoding=lol", "password", "database") assert.EqualError(t, err, "unable to connect to create database with custom name database with the following error: client_encoding must be absent or 'UTF8'") } @@ -165,13 +171,13 @@ func Test_defaultCreateDatabase_ErrorWhenQueryError(t *testing.T) { } }() - err := defaultCreateDatabase(9831, "postgres", "postgres", "b33r") + err := defaultCreateDatabase("localhost", 9831, "postgres", "postgres", "b33r") assert.EqualError(t, err, `unable to connect to create database with custom name b33r with the following error: pq: database "b33r" already exists`) } func Test_healthCheckDatabase_ErrorWhenSQLConnectingError(t *testing.T) { - err := healthCheckDatabase(1234, "tom client_encoding=lol", "more", "b33r") + err := healthCheckDatabase("localhost", 1234, "tom client_encoding=lol", "more", "b33r") assert.EqualError(t, err, "client_encoding must be absent or 'UTF8'") } diff --git a/test_config.go b/test_config.go deleted file mode 100644 index 5646c9b..0000000 --- a/test_config.go +++ /dev/null @@ -1,12 +0,0 @@ -package embeddedpostgres - -import "testing" - -func TestGetConnectionURL(t *testing.T) { - config := DefaultConfig().Database("mydb").Username("myuser").Password("mypass") - expect := "postgresql://myuser:mypass@localhost:5432/mydb" - - if got := config.GetConnectionURL(); got != expect { - t.Errorf("expected \"%s\" got \"%s\"", expect, got) - } -}