Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions database/dbJob.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ func NewJobDBHandler(dbConnection *helper.Database, withTableDrop bool, encrypti
jobDbHandler.EncryptionKey = encryptionKey[0]
}

if withTableDrop {
err := dbConnection.DropFunctionsFromPublicSchema(loadSql.JobFunctions)
if err != nil {
return nil, helper.NewError("drop job functions", err)
}

err = jobDbHandler.DropTables()
if err != nil {
return nil, helper.NewError("drop tables", err)
}
}

err := loadSql.LoadJobSql(jobDbHandler.db.Instance, false)
if err != nil {
return nil, helper.NewError("load job sql", err)
Expand All @@ -70,13 +82,6 @@ func NewJobDBHandler(dbConnection *helper.Database, withTableDrop bool, encrypti
return nil, helper.NewError("load notify sql", err)
}

if withTableDrop {
err := jobDbHandler.DropTables()
if err != nil {
return nil, helper.NewError("drop tables", err)
}
}

err = jobDbHandler.CreateTable()
if err != nil {
return nil, helper.NewError("create table", err)
Expand Down
15 changes: 10 additions & 5 deletions database/dbMaster.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,23 @@ func NewMasterDBHandler(dbConnection *helper.Database, withTableDrop bool) (*Mas
db: dbConnection,
}

err := loadSql.LoadMasterSql(masterDbHandler.db.Instance, false)
if err != nil {
return nil, helper.NewError("load master sql", err)
}

if withTableDrop {
err := dbConnection.DropFunctionsFromPublicSchema(loadSql.MasterFunctions)
if err != nil {
return nil, helper.NewError("drop master functions", err)
}

err = masterDbHandler.DropTable()
if err != nil {
return nil, helper.NewError("drop master table", err)
}
}

err := loadSql.LoadMasterSql(masterDbHandler.db.Instance, false)
if err != nil {
return nil, helper.NewError("load master sql", err)
}

err = masterDbHandler.CreateTable()
if err != nil {
return nil, helper.NewError("create master table", err)
Expand Down
17 changes: 11 additions & 6 deletions database/dbWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,23 @@ func NewWorkerDBHandler(dbConnection *helper.Database, withTableDrop bool) (*Wor
db: dbConnection,
}

err := loadSql.LoadWorkerSql(dbConnection.Instance, withTableDrop)
if err != nil {
return nil, helper.NewError("load worker sql", err)
}

if withTableDrop {
err := workerDbHandler.DropTable()
err := dbConnection.DropFunctionsFromPublicSchema(loadSql.WorkerFunctions)
if err != nil {
return nil, helper.NewError("drop worker functions", err)
}

err = workerDbHandler.DropTable()
if err != nil {
return nil, helper.NewError("drop worker table", err)
}
}

err := loadSql.LoadWorkerSql(dbConnection.Instance, withTableDrop)
if err != nil {
return nil, helper.NewError("load worker sql", err)
}

err = workerDbHandler.CreateTable()
if err != nil {
return nil, helper.NewError("create worker table", err)
Expand Down
49 changes: 49 additions & 0 deletions helper/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,52 @@ func (d *Database) Close() error {
log.Printf("Disconnected from database: %v", d.Instance)
return d.Instance.Close()
}

// DropFunctionsFromPublicSchema drops user-defined functions from the public schema.
// It queries pg_proc to find all overloaded versions of the specified functions,
// filters to only public schema functions, excludes extension-owned functions,
// and drops each by its full signature. This is SQL injection safe as it uses
// parameterized queries for function name lookup.
func (d *Database) DropFunctionsFromPublicSchema(functionNames []string) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

for _, functionName := range functionNames {
// Query for all overloaded versions, filtering to public schema and excluding extensions
rows, queryErr := d.Instance.QueryContext(ctx, `
SELECT pg_proc.oid::regprocedure::text
FROM pg_proc
JOIN pg_namespace ON pg_proc.pronamespace = pg_namespace.oid
WHERE pg_proc.proname = $1
AND pg_namespace.nspname = 'public'
AND NOT EXISTS (
SELECT 1 FROM pg_depend
WHERE objid = pg_proc.oid AND deptype = 'e'
)
`, functionName)
if queryErr != nil {
return NewError("query function overloads "+functionName, queryErr)
}
defer rows.Close()

var signatures []string
for rows.Next() {
var signature string
if scanErr := rows.Scan(&signature); scanErr != nil {
return NewError("scan function signature", scanErr)
}
signatures = append(signatures, signature)
}

// Drop each overloaded function by its full signature
for _, signature := range signatures {
dropQuery := fmt.Sprintf(`DROP FUNCTION IF EXISTS %s;`, signature)
_, err := d.Instance.ExecContext(ctx, dropQuery)
if err != nil {
return NewError("drop function "+signature, err)
}
}
}

return nil
}
6 changes: 3 additions & 3 deletions queuerJob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestAddJobRunning(t *testing.T) {
t.Run("Successfully runs a job with schedule options once", func(t *testing.T) {
options := &model.Options{
Schedule: &model.Schedule{
Start: time.Now().Add(1 * time.Second),
Start: time.Now().Add(2 * time.Second),
MaxCount: 1,
Interval: 15 * time.Second,
},
Expand All @@ -144,8 +144,8 @@ func TestAddJobRunning(t *testing.T) {

queuedJob, err := testQueuer.GetJob(job.RID)
require.NoError(t, err, "GetJob should not return an error")
require.NotNil(t, queuedJob, "GetJob should return the job that is currently running")
assert.Equal(t, model.JobStatusScheduled, queuedJob.Status, "Job should be in Running status")
require.NotNil(t, queuedJob, "GetJob should return the scheduled job")
assert.Equal(t, model.JobStatusScheduled, queuedJob.Status, "Job should be in Scheduled status")

job = testQueuer.WaitForJobFinished(job.RID, 5*time.Second)
assert.NotNil(t, job, "WaitForJobFinished should return the finished job")
Expand Down