diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/pg_options_generator.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/pg_options_generator.go index 556c31ac..4ab1fcba 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/pg_options_generator.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/pg_options_generator.go @@ -98,16 +98,24 @@ func (o *optionGenerator) pgdumpOptions(ctx context.Context, schemaTables map[st Role: o.role, } - switch { - case hasWildcardSchema(schemaTables): - // no need to filter schemas, since we are including all of them + if hasWildcardSchema(schemaTables) && !o.includeGlobalDBObjects { + // wildcard schema without global objects: discover all user schemas + // and use schema inclusion filter to exclude global objects + allSchemas, err := o.discoverAllSchemas(ctx) + if err != nil { + return nil, err + } + opts.Schemas = allSchemas + } else if hasWildcardSchema(schemaTables) { + // wildcard schema with global objects: no filter needed, just + // exclude the pgstream internal schema opts.Schemas = nil opts.ExcludeSchemas = []string{pglib.QuoteIdentifier(pgstreamSchema)} - case o.includeGlobalDBObjects: - // instead of using the schema filter, we use the exclude schemas filter - // to make sure extensions and other database global objects are - // created. pg_dump will not include them when using the schema filter, - // since they do not belong to the schema. + } else if o.includeGlobalDBObjects { + // specific schemas with global objects: use exclude filter to make + // sure extensions and other database global objects are created. + // pg_dump will not include them when using the schema filter, since + // they do not belong to the schema. var err error opts.ExcludeSchemas, err = o.pgdumpExcludedSchemas(ctx, schemas) if err != nil { @@ -230,6 +238,10 @@ func (o *optionGenerator) pgdumpExcludedSchemas(ctx context.Context, includeSche return excludeSchemas, nil } +func (o *optionGenerator) discoverAllSchemas(ctx context.Context) ([]string, error) { + return pglib.DiscoverAllSchemas(ctx, o.querier) +} + func quoteSchema(schema string) string { if schema == wildcard { return wildcard diff --git a/pkg/snapshot/generator/postgres/schema/pgdumprestore/pg_options_generator_test.go b/pkg/snapshot/generator/postgres/schema/pgdumprestore/pg_options_generator_test.go index 17b9ef65..69eb146f 100644 --- a/pkg/snapshot/generator/postgres/schema/pgdumprestore/pg_options_generator_test.go +++ b/pkg/snapshot/generator/postgres/schema/pgdumprestore/pg_options_generator_test.go @@ -279,14 +279,32 @@ func TestOptionsGenerator_pgdumpOptions(t *testing.T) { includeGlobal: false, conn: &pglibmocks.Querier{ QueryFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.Rows, error) { - return nil, fmt.Errorf("QueryFn should not be called") + require.Equal(t, pglib.DiscoverAllSchemasQuery, query) + return &pglibmocks.Rows{ + NextFn: func(i uint) bool { return i <= 2 }, + ScanFn: func(i uint, dest ...any) error { + require.Len(t, dest, 1) + schema, ok := dest[0].(*string) + require.True(t, ok) + switch i { + case 1: + *schema = "public" + case 2: + *schema = "other" + } + return nil + }, + ErrFn: func() error { return nil }, + CloseFn: func() {}, + }, nil }, }, wantOpts: &pglib.PGDumpOptions{ ConnectionString: "source-url", Format: "p", - ExcludeSchemas: []string{`"pgstream"`, `"excluded_schema"`}, + Schemas: []string{"public", "other"}, + ExcludeSchemas: []string{`"excluded_schema"`}, SchemaOnly: true, }, wantErr: nil, @@ -300,18 +318,69 @@ func TestOptionsGenerator_pgdumpOptions(t *testing.T) { includeGlobal: false, conn: &pglibmocks.Querier{ QueryFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.Rows, error) { - require.Equal(t, selectTablesQuery, query) - require.Equal(t, []any{[]string{"table1", "table2"}}, args) + switch query { + case pglib.DiscoverAllSchemasQuery: + return &pglibmocks.Rows{ + NextFn: func(i uint) bool { return i == 1 }, + ScanFn: func(i uint, dest ...any) error { + require.Len(t, dest, 1) + schema, ok := dest[0].(*string) + require.True(t, ok) + *schema = "public" + return nil + }, + ErrFn: func() error { return nil }, + CloseFn: func() {}, + }, nil + case selectTablesQuery: + require.Equal(t, []any{[]string{"table1", "table2"}}, args) + return &pglibmocks.Rows{ + NextFn: func(i uint) bool { return i == 1 }, + ScanFn: func(i uint, dest ...any) error { + require.Len(t, dest, 2) + schema, ok := dest[0].(*string) + require.True(t, ok) + *schema = "public" + table, ok := dest[1].(*string) + require.True(t, ok) + *table = "table3" + return nil + }, + ErrFn: func() error { return nil }, + CloseFn: func() {}, + }, nil + default: + return nil, fmt.Errorf("unexpected query: %s", query) + } + }, + }, + + wantOpts: &pglib.PGDumpOptions{ + ConnectionString: "source-url", + Format: "p", + Schemas: []string{"public"}, + SchemaOnly: true, + ExcludeTables: []string{`"public"."table3"`}, + }, + wantErr: nil, + }, + { + name: "wildcard schema and wildcard tables", + schemaTables: map[string][]string{ + "*": {"*"}, + }, + excludedTables: map[string][]string{}, + includeGlobal: false, + conn: &pglibmocks.Querier{ + QueryFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.Rows, error) { + require.Equal(t, pglib.DiscoverAllSchemasQuery, query) return &pglibmocks.Rows{ NextFn: func(i uint) bool { return i == 1 }, ScanFn: func(i uint, dest ...any) error { - require.Len(t, dest, 2) + require.Len(t, dest, 1) schema, ok := dest[0].(*string) require.True(t, ok) *schema = "public" - table, ok := dest[1].(*string) - require.True(t, ok) - *table = "table3" return nil }, ErrFn: func() error { return nil }, @@ -323,19 +392,18 @@ func TestOptionsGenerator_pgdumpOptions(t *testing.T) { wantOpts: &pglib.PGDumpOptions{ ConnectionString: "source-url", Format: "p", - ExcludeSchemas: []string{`"pgstream"`}, + Schemas: []string{"public"}, SchemaOnly: true, - ExcludeTables: []string{`"public"."table3"`}, }, wantErr: nil, }, { - name: "wildcard schema and wildcard tables", + name: "wildcard schema and wildcard tables with include global objects enabled", schemaTables: map[string][]string{ "*": {"*"}, }, excludedTables: map[string][]string{}, - includeGlobal: false, + includeGlobal: true, conn: &pglibmocks.Querier{ QueryFn: func(ctx context.Context, _ uint, query string, args ...any) (pglib.Rows, error) { return nil, errors.New("QueryFn should not be called")