From fd272d9b9f39b683675d05a660cd80cee393f7f8 Mon Sep 17 00:00:00 2001 From: Eric Klatzer Date: Sat, 26 Oct 2024 22:59:03 +0200 Subject: [PATCH] implemented optional logger to log details when applying migrations --- log.go | 11 +++++++ log/log.go | 70 +++++++++++++++++++++++++++++++++++++++++++ log/log_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++++++++ migrate.go | 50 +++++++++++++++++++++++-------- migrate_test.go | 52 ++++++++++++++++++++++++++++++++ 5 files changed, 249 insertions(+), 13 deletions(-) create mode 100644 log.go create mode 100644 log/log.go create mode 100644 log/log_test.go diff --git a/log.go b/log.go new file mode 100644 index 00000000..c43c8d22 --- /dev/null +++ b/log.go @@ -0,0 +1,11 @@ +package migrate + +import ( + "context" +) + +type Logger interface { + Info(context.Context, string, ...any) + Warn(context.Context, string, ...any) + Error(context.Context, string, ...any) +} diff --git a/log/log.go b/log/log.go new file mode 100644 index 00000000..d95f3a8d --- /dev/null +++ b/log/log.go @@ -0,0 +1,70 @@ +package log + +import ( + "context" + "fmt" +) + +// Level is a type to represent the log level +type Level int + +const ( + LevelSilent Level = iota + LevelError + LevelWarn + LevelInfo +) + +type DefaultLogger struct { + level Level + writer Writer +} + +type Writer interface { + Printf(string, ...interface{}) +} + +// DefaultLogWriter is a default implementation of the Writer interface +// that writes using fmt.Printf +type DefaultLogWriter struct{} + +func (w *DefaultLogWriter) Printf(format string, args ...interface{}) { + fmt.Printf(format, args...) +} + +// NewDefaultLogger creates a new DefaultLogger with a silent log level +func NewDefaultLogger() *DefaultLogger { + return &DefaultLogger{ + level: LevelSilent, + writer: &DefaultLogWriter{}, + } +} + +func (l *DefaultLogger) WithLevel(level Level) *DefaultLogger { + l.level = level + return l +} + +func (l *DefaultLogger) WithWriter(writer Writer) *DefaultLogger { + l.writer = writer + return l +} + +func (l *DefaultLogger) Info(_ context.Context, format string, args ...interface{}) { + l.logIfPermittedByLevel(LevelInfo, format, args...) +} + +func (l *DefaultLogger) Warn(_ context.Context, format string, args ...interface{}) { + l.logIfPermittedByLevel(LevelWarn, format, args...) +} + +func (l *DefaultLogger) Error(_ context.Context, format string, args ...interface{}) { + l.logIfPermittedByLevel(LevelError, format, args...) +} + +func (l *DefaultLogger) logIfPermittedByLevel(requiredLevel Level, format string, args ...interface{}) { + if l.level < requiredLevel { + return + } + l.writer.Printf(format, args...) +} diff --git a/log/log_test.go b/log/log_test.go new file mode 100644 index 00000000..35d880d5 --- /dev/null +++ b/log/log_test.go @@ -0,0 +1,79 @@ +package log_test + +import ( + "context" + "fmt" + "testing" + + "github.com/rubenv/sql-migrate/log" +) + +type mockWriter struct { + logs []string +} + +func (mw *mockWriter) Printf(format string, args ...interface{}) { + mw.logs = append(mw.logs, fmt.Sprintf(format, args...)) +} + +func TestDefaultLoggerWithLevelInfo(t *testing.T) { + mockWriter := &mockWriter{logs: []string{}} + + logger := log.NewDefaultLogger().WithLevel(log.LevelInfo).WithWriter(mockWriter) + logger.Info(context.Background(), "This should be logged") + logger.Warn(context.Background(), "This should also be logged") + logger.Error(context.Background(), "This should also be logged") + + expectedLogs := []string{ + "This should be logged", + "This should also be logged", + "This should also be logged", + } + + if len(mockWriter.logs) != len(expectedLogs) { + t.Fatalf("Expected %d logs, got %d", len(expectedLogs), len(mockWriter.logs)) + } + + for i, expectedLog := range expectedLogs { + if expectedLog != mockWriter.logs[i] { + t.Fatalf("Expected log %d to be %s, got %s", i, expectedLog, mockWriter.logs[i]) + } + } +} + +func TestDefaultLoggerWithLevelSilent(t *testing.T) { + mockWriter := &mockWriter{logs: []string{}} + + logger := log.NewDefaultLogger().WithLevel(log.LevelSilent).WithWriter(mockWriter) + logger.Info(context.Background(), "This should not be logged") + logger.Warn(context.Background(), "This should not be logged") + logger.Error(context.Background(), "This should not be logged") + + if len(mockWriter.logs) != 0 { + t.Fatalf("Expected no logs, got %d", len(mockWriter.logs)) + } +} + +func TestDefaultLoggerWithLevelWarn(t *testing.T) { + mockWriter := &mockWriter{logs: []string{}} + + logger := log.NewDefaultLogger().WithLevel(log.LevelWarn).WithWriter(mockWriter) + logger.Info(context.Background(), "This should not be logged") + logger.Warn(context.Background(), "This should be logged") + logger.Error(context.Background(), "This should also be logged") + + expectedLogs := []string{ + "This should be logged", + "This should also be logged", + } + + if len(mockWriter.logs) != len(expectedLogs) { + t.Fatalf("Expected %d logs, got %d", len(expectedLogs), len(mockWriter.logs)) + } + + for i, expectedLog := range expectedLogs { + if expectedLog != mockWriter.logs[i] { + t.Fatalf("Expected log %d to be %s, got %s", i, expectedLog, mockWriter.logs[i]) + } + } +} diff --git a/migrate.go b/migrate.go index 7fb56f1a..68020718 100644 --- a/migrate.go +++ b/migrate.go @@ -19,6 +19,7 @@ import ( "github.com/go-gorp/gorp/v3" + "github.com/rubenv/sql-migrate/log" "github.com/rubenv/sql-migrate/sqlparse" ) @@ -42,12 +43,14 @@ type MigrationSet struct { IgnoreUnknown bool // DisableCreateTable disable the creation of the migration table DisableCreateTable bool + // Logger is used to log additional information during the migration process. + Logger Logger } var migSet = MigrationSet{} // NewMigrationSet returns a parametrized Migration object -func (ms MigrationSet) getTableName() string { +func (ms *MigrationSet) getTableName() string { if ms.TableName == "" { return "gorp_migrations" } @@ -124,6 +127,10 @@ func SetIgnoreUnknown(v bool) { migSet.IgnoreUnknown = v } +func SetLogger(l Logger) { + migSet.Logger = l +} + type Migration struct { Id string Up []string @@ -448,7 +455,7 @@ func Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) } // Returns the number of applied migrations. -func (ms MigrationSet) Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) { +func (ms *MigrationSet) Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) { return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, 0) } @@ -460,7 +467,7 @@ func ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSou } // Returns the number of applied migrations. -func (ms MigrationSet) ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) { +func (ms *MigrationSet) ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) { return ms.ExecMaxContext(ctx, db, dialect, m, dir, 0) } @@ -504,12 +511,12 @@ func ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m Migra } // Returns the number of applied migrations. -func (ms MigrationSet) ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) { +func (ms *MigrationSet) ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) { return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, max) } // Returns the number of applied migrations, but applies with an input context. -func (ms MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) { +func (ms *MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) { migrations, dbMap, err := ms.PlanMigration(db, dialect, m, dir, max) if err != nil { return 0, err @@ -518,11 +525,11 @@ func (ms MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect s } // Returns the number of applied migrations. -func (ms MigrationSet) ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) { +func (ms *MigrationSet) ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) { return ms.ExecVersionContext(context.Background(), db, dialect, m, dir, version) } -func (ms MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) { +func (ms *MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) { migrations, dbMap, err := ms.PlanMigrationToVersion(db, dialect, m, dir, version) if err != nil { return 0, err @@ -531,9 +538,11 @@ func (ms MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, diale } // Applies the planned migrations and returns the number of applied migrations. -func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) { +func (m MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) { applied := 0 for _, migration := range migrations { + m.logger().Info(ctx, "Applying migration %s", migration.Id) + var executor SqlExecutor var err error @@ -563,6 +572,8 @@ func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, switch dir { case Up: + m.logger().Info(ctx, "Migrating up %s", migration.Id) + err = executor.Insert(&MigrationRecord{ Id: migration.Id, AppliedAt: time.Now(), @@ -575,6 +586,8 @@ func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, return applied, newTxError(migration, err) } case Down: + m.logger().Info(ctx, "Migrating down %s", migration.Id) + _, err := executor.Delete(&MigrationRecord{ Id: migration.Id, }) @@ -590,12 +603,16 @@ func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, } if trans, ok := executor.(*gorp.Transaction); ok { + m.logger().Info(ctx, "Committing transaction for %s", migration.Id) + if err := trans.Commit(); err != nil { return applied, newTxError(migration, err) } } applied++ + + m.logger().Info(ctx, "Applied %d/%d migrations", applied, len(migrations)) } return applied, nil @@ -612,17 +629,17 @@ func PlanMigrationToVersion(db *sql.DB, dialect string, m MigrationSource, dir M } // Plan a migration. -func (ms MigrationSet) PlanMigration(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) ([]*PlannedMigration, *gorp.DbMap, error) { +func (ms *MigrationSet) PlanMigration(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) ([]*PlannedMigration, *gorp.DbMap, error) { return ms.planMigrationCommon(db, dialect, m, dir, max, -1) } // Plan a migration to version. -func (ms MigrationSet) PlanMigrationToVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) ([]*PlannedMigration, *gorp.DbMap, error) { +func (ms *MigrationSet) PlanMigrationToVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) ([]*PlannedMigration, *gorp.DbMap, error) { return ms.planMigrationCommon(db, dialect, m, dir, 0, version) } // A common method to plan a migration. -func (ms MigrationSet) planMigrationCommon(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int, version int64) ([]*PlannedMigration, *gorp.DbMap, error) { +func (ms *MigrationSet) planMigrationCommon(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int, version int64) ([]*PlannedMigration, *gorp.DbMap, error) { dbMap, err := ms.getMigrationDbMap(db, dialect) if err != nil { return nil, nil, err @@ -822,7 +839,7 @@ func GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error) return migSet.GetMigrationRecords(db, dialect) } -func (ms MigrationSet) GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error) { +func (ms *MigrationSet) GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error) { dbMap, err := ms.getMigrationDbMap(db, dialect) if err != nil { return nil, err @@ -838,7 +855,14 @@ func (ms MigrationSet) GetMigrationRecords(db *sql.DB, dialect string) ([]*Migra return records, nil } -func (ms MigrationSet) getMigrationDbMap(db *sql.DB, dialect string) (*gorp.DbMap, error) { +func (ms *MigrationSet) logger() Logger { + if migSet.Logger == nil { + migSet.Logger = log.NewDefaultLogger() + } + return migSet.Logger +} + +func (ms *MigrationSet) getMigrationDbMap(db *sql.DB, dialect string) (*gorp.DbMap, error) { d, ok := MigrationDialects[dialect] if !ok { return nil, fmt.Errorf("Unknown dialect: %s", dialect) diff --git a/migrate_test.go b/migrate_test.go index f1d66d6f..c4055974 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "embed" + "fmt" "net/http" "time" @@ -63,6 +64,57 @@ func (s *SqliteMigrateSuite) TestRunMigration(c *C) { c.Assert(n, Equals, 0) } +type memoryLogger struct { + logs []string +} + +func (l *memoryLogger) appendLogWithLevel(level string, format string, args ...interface{}) { + l.logs = append(l.logs, fmt.Sprintf("%s:%s", level, fmt.Sprintf(format, args...))) +} + +func (l *memoryLogger) Info(_ context.Context, format string, args ...interface{}) { + l.appendLogWithLevel("INFO", format, args...) +} + +func (l *memoryLogger) Warn(_ context.Context, format string, args ...interface{}) { + l.appendLogWithLevel("WARN", format, args...) +} + +func (l *memoryLogger) Error(_ context.Context, format string, args ...interface{}) { + l.appendLogWithLevel("ERROR", format, args...) +} + +func (s *SqliteMigrateSuite) TestRunMigrationWithLogger(c *C) { + logger := &memoryLogger{} + SetLogger(logger) + + migrations := &MemoryMigrationSource{ + Migrations: sqliteMigrations, + } + + // Execute migrations + n, err := Exec(s.Db, "sqlite3", migrations, Up) + c.Assert(err, IsNil) + c.Assert(n, Equals, len(sqliteMigrations)) + + // Can use table now + _, err = s.DbMap.Exec("SELECT * FROM people") + c.Assert(err, IsNil) + + // Check logs + c.Assert(logger.logs, HasLen, 8) + c.Assert(logger.logs, Equals, []string{ + "INFO:Applying migration 123", + "INFO:Migrating up 123", + "INFO:Committing transaction for 123", + "INFO:Applied 1/2 migrations", + "INFO:Applying migration 124", + "INFO:Migrating up 124", + "INFO:Committing transaction for 124", + "INFO:Applied 2/2 migrations", + }) +} + func (s *SqliteMigrateSuite) TestRunMigrationEscapeTable(c *C) { migrations := &MemoryMigrationSource{ Migrations: sqliteMigrations[:1],