diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index e681e5a1b..733cf2601 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -19,6 +19,24 @@ Here are a few things you can do that will increase the likelihood of your pull - Keep your change as focused as possible. If there are multiple changes you would like to make that are not dependent upon each other, consider submitting them as separate pull requests. - Write a [good commit message](http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html). +## Development Guidelines + +### Channel Safety + +When working with channels in goroutines, it's critical to prevent deadlocks that can occur when a channel receiver exits due to an error while senders are still trying to send values. Always use `base.SendWithContext` for channel sends to avoid deadlocks: + +```go +// ✅ CORRECT - Uses helper to prevent deadlock +if err := base.SendWithContext(ctx, ch, value); err != nil { + return err // context was cancelled +} + +// ❌ WRONG - Can deadlock if receiver exits +ch <- value +``` + +Even if the destination channel is buffered, deadlocks could still occur if the buffer fills up and the receiver exits, so it's important to use `SendWithContext` in those cases as well. 
+ ## Resources - [Contributing to Open Source on GitHub](https://guides.github.com/activities/contributing-to-open-source/) diff --git a/.github/workflows/replica-tests.yml b/.github/workflows/replica-tests.yml index 957d7a176..5a9bfb7f7 100644 --- a/.github/workflows/replica-tests.yml +++ b/.github/workflows/replica-tests.yml @@ -28,6 +28,20 @@ jobs: - name: Run tests run: script/docker-gh-ost-replica-tests run + - name: Set artifact name + if: failure() + run: | + ARTIFACT_NAME=$(echo "${{ matrix.image }}" | tr '/:' '-') + echo "ARTIFACT_NAME=test-logs-${ARTIFACT_NAME}" >> $GITHUB_ENV + + - name: Upload test logs on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: ${{ env.ARTIFACT_NAME }} + path: /tmp/gh-ost-test.* + retention-days: 7 + - name: Teardown environment if: always() run: script/docker-gh-ost-replica-tests down diff --git a/doc/command-line-flags.md b/doc/command-line-flags.md index c706a4fe4..2012c8c4c 100644 --- a/doc/command-line-flags.md +++ b/doc/command-line-flags.md @@ -24,6 +24,11 @@ By default, `gh-ost` would like you to connect to a replica, from where it figur If, for some reason, you do not wish `gh-ost` to connect to a replica, you may connect it directly to the master and approve this via `--allow-on-master`. +### allow-setup-metadata-lock-instruments + +`--allow-setup-metadata-lock-instruments` allows gh-ost to enable the [`metadata_locks`](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-metadata-locks-table.html) table in `performance_schema`, if it is not already enabled. This is used for a safety check before cut-over. +See also: [`skip-metadata-lock-check`](#skip-metadata-lock-check) + ### approve-renamed-columns When your migration issues a column rename (`change column old_name new_name ...`) `gh-ost` analyzes the statement to try and associate the old column name with new column name. Otherwise, the new structure may also look like some column was dropped and another was added. 
@@ -247,6 +252,13 @@ Defaults to an auto-determined and advertised upon startup file. Defines Unix so By default `gh-ost` verifies no foreign keys exist on the migrated table. On servers with large number of tables this check can take a long time. If you're absolutely certain no foreign keys exist (table does not reference other table nor is referenced by other tables) and wish to save the check time, provide with `--skip-foreign-key-checks`. +### skip-metadata-lock-check + +By default `gh-ost` performs a check before the cut-over to ensure the rename session holds the exclusive metadata lock on the table. If `performance_schema.metadata_locks` cannot be enabled on your setup, this check can be skipped with `--skip-metadata-lock-check`. +:warning: Disabling this check carries a small chance of data loss if a session accesses the ghost table during cut-over. See https://github.com/github/gh-ost/pull/1536 for details. + +See also: [`allow-setup-metadata-lock-instruments`](#allow-setup-metadata-lock-instruments) + ### skip-strict-mode By default `gh-ost` enforces STRICT_ALL_TABLES sql_mode as a safety measure. In some cases this changes the behaviour of other modes (namely ERROR_FOR_DIVISION_BY_ZERO, NO_ZERO_DATE, and NO_ZERO_IN_DATE) which may lead to errors during migration. Use `--skip-strict-mode` to explicitly tell `gh-ost` not to enforce this. **Danger** This may have some unexpected disastrous side effects. 
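The metadata-lock safety check documented above depends on the `wait/lock/metadata/sql/mdl` instrument being enabled. As a sketch (the host name is a placeholder), the instrument state gh-ost inspects can be checked manually with:

```shell
# Inspect the MDL instrument state (host is a placeholder).
# ENABLED and TIMED both 'YES' means the cut-over safety check can run as-is;
# otherwise consider --allow-setup-metadata-lock-instruments,
# or --skip-metadata-lock-check as a last resort.
mysql -h replica1.company.com -e "
  SELECT NAME, ENABLED, TIMED
  FROM performance_schema.setup_instruments
  WHERE NAME = 'wait/lock/metadata/sql/mdl';"
```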
diff --git a/doc/hooks.md b/doc/hooks.md index c1fe59453..0fa9fb628 100644 --- a/doc/hooks.md +++ b/doc/hooks.md @@ -49,6 +49,7 @@ The full list of supported hooks is best found in code: [hooks.go](https://githu - `gh-ost-on-before-cut-over` - `gh-ost-on-success` - `gh-ost-on-failure` +- `gh-ost-on-batch-copy-retry` ### Context @@ -76,11 +77,14 @@ The following variables are available on all hooks: - `GH_OST_HOOKS_HINT_OWNER` - copy of `--hooks-hint-owner` value - `GH_OST_HOOKS_HINT_TOKEN` - copy of `--hooks-hint-token` value - `GH_OST_DRY_RUN` - whether or not the `gh-ost` run is a dry run +- `GH_OST_REVERT` - whether or not `gh-ost` is running in revert mode The following variable are available on particular hooks: +- `GH_OST_INSTANT_DDL` is only available in `gh-ost-on-success`. The value is `true` if instant DDL was successful, and `false` if it was not. - `GH_OST_COMMAND` is only available in `gh-ost-on-interactive-command` - `GH_OST_STATUS` is only available in `gh-ost-on-status` +- `GH_OST_LAST_BATCH_COPY_ERROR` is only available in `gh-ost-on-batch-copy-retry` ### Examples diff --git a/doc/resume.md b/doc/resume.md index 46d6ac909..6f12ab7a3 100644 --- a/doc/resume.md +++ b/doc/resume.md @@ -4,6 +4,7 @@ - The first `gh-ost` process was invoked with `--checkpoint` - The first `gh-ost` process had at least one successful checkpoint - The binlogs from the last checkpoint's binlog coordinates still exist on the replica gh-ost is inspecting (specified by `--host`) +- The checkpoint table (name ends with `_ghk`) still exists To resume, invoke `gh-ost` again with the same arguments with the `--resume` flag. 
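To make the resume conditions concrete, here is a hypothetical before/after invocation pair (host, schema, table, and alter statement are placeholders; this is a sketch, not output from a real run):

```shell
# Original run: --checkpoint writes periodic checkpoints to a table ending in _ghk
gh-ost \
  --host=replica1.company.com \
  --database="mydb" \
  --table="mytable" \
  --alter="add column note varchar(64)" \
  --checkpoint \
  --checkpoint-seconds=60 \
  --execute

# After an interruption: re-run with the same arguments, plus --resume
gh-ost \
  --host=replica1.company.com \
  --database="mydb" \
  --table="mytable" \
  --alter="add column note varchar(64)" \
  --checkpoint \
  --checkpoint-seconds=60 \
  --resume \
  --execute
```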
diff --git a/doc/revert.md b/doc/revert.md new file mode 100644 index 000000000..6737da1b2 --- /dev/null +++ b/doc/revert.md @@ -0,0 +1,56 @@ +# Reverting Migrations + +`gh-ost` can attempt to revert a previously completed migration if the following conditions are met: +- The first `gh-ost` process was invoked with `--checkpoint` +- The checkpoint table (name ends with `_ghk`) still exists +- The binlogs from the time of the migration's cut-over still exist on the replica gh-ost is inspecting (specified by `--host`) + +To revert, find the name of the "old" table from the original migration, e.g. `_mytable_del`. Then invoke `gh-ost` with the same arguments and the flags `--revert` and `--old-table="_mytable_del"`. +gh-ost will read the binlog coordinates of the original cut-over from the checkpoint table and bring the old table up to date. Then it performs another cut-over to complete the reversion. +Note that the checkpoint table (name ends with `_ghk`) will not be automatically dropped unless `--ok-to-drop-table` is provided. + +> [!WARNING] +> It is recommended to use `--checkpoint` with `--gtid` enabled so that checkpoint binlog coordinates store GTID sets rather than file positions. In that case, `gh-ost` can revert using a different replica than the one it originally attached to. + +### ❗ Note ❗ +Reverting is roughly equivalent to applying the "reverse" migration. _Before attempting to revert you should determine if the reverse migration is possible and does not involve any unacceptable data loss._ + +For example: if the original migration drops a `NOT NULL` column that has no `DEFAULT`, then the reverse migration adds the column back. In this case, the reverse migration is impossible if rows were added after the original cut-over, and the revert will fail. +Another example: if the original migration modifies a `VARCHAR(32)` column to `VARCHAR(64)`, the reverse migration truncates the `VARCHAR(64)` column to `VARCHAR(32)`. 
If values were inserted with length > 32 after the cut-over then the revert will fail. + + +## Example +The migration starts with a `gh-ost` invocation such as: +```shell +gh-ost \ +--chunk-size=100 \ +--host=replica1.company.com \ +--database="mydb" \ +--table="mytable" \ +--alter="drop key idx1" \ +--gtid \ +--checkpoint \ +--checkpoint-seconds=60 \ +--execute +``` + +In this example `gh-ost` writes a cut-over checkpoint to `_mytable_ghk` after the cut-over is successful. The original table is renamed to `_mytable_del`. + +Suppose dropping the index causes problems; the migration can then be reverted with: +```shell +# revert migration +gh-ost \ +--chunk-size=100 \ +--host=replica1.company.com \ +--database="mydb" \ +--table="mytable" \ +--old-table="_mytable_del" \ +--gtid \ +--checkpoint \ +--checkpoint-seconds=60 \ +--revert \ +--execute +``` + +gh-ost then reconnects at the binlog coordinates stored in the cut-over checkpoint and applies DMLs until the old table is up-to-date. +Note that the "reverse" migration is `ADD KEY idx(...)`, so there is no potential data loss to consider in this case. diff --git a/go/base/context.go b/go/base/context.go index 35030e30b..2c8d28d56 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -6,6 +6,7 @@ package base import ( + "context" "fmt" "math" "os" @@ -104,6 +105,8 @@ type MigrationContext struct { AzureMySQL bool AttemptInstantDDL bool Resume bool + Revert bool + OldTableName string // SkipPortValidation allows skipping the port validation in `ValidateConnection` // This is useful when connecting to a MySQL instance where the external port @@ -223,6 +226,16 @@ type MigrationContext struct { InCutOverCriticalSectionFlag int64 PanicAbort chan error + // Context for cancellation signaling across all goroutines + // Stored in struct as it spans the entire migration lifecycle, not per-function. + // context.Context is safe for concurrent use by multiple goroutines. 
+ ctx context.Context //nolint:containedctx + cancelFunc context.CancelFunc + + // Stores the fatal error that triggered abort + AbortError error + abortMutex *sync.Mutex + OriginalTableColumnsOnApplier *sql.ColumnList OriginalTableColumns *sql.ColumnList OriginalTableVirtualColumns *sql.ColumnList @@ -254,6 +267,7 @@ type MigrationContext struct { BinlogSyncerMaxReconnectAttempts int AllowSetupMetadataLockInstruments bool + SkipMetadataLockCheck bool IsOpenMetadataLockInstruments bool Log Logger @@ -290,6 +304,7 @@ type ContextConfig struct { } func NewMigrationContext() *MigrationContext { + ctx, cancelFunc := context.WithCancel(context.Background()) return &MigrationContext{ Uuid: uuid.NewString(), defaultNumRetries: 60, @@ -310,6 +325,9 @@ func NewMigrationContext() *MigrationContext { lastHeartbeatOnChangelogMutex: &sync.Mutex{}, ColumnRenameMap: make(map[string]string), PanicAbort: make(chan error), + ctx: ctx, + cancelFunc: cancelFunc, + abortMutex: &sync.Mutex{}, Log: NewDefaultLogger(), } } @@ -348,6 +366,10 @@ func getSafeTableName(baseName string, suffix string) string { // GetGhostTableName generates the name of ghost table, based on original table name // or a given table name func (this *MigrationContext) GetGhostTableName() string { + if this.Revert { + // When reverting the "ghost" table is the _del table from the original migration. 
+ return this.OldTableName + } if this.ForceTmpTableName != "" { return getSafeTableName(this.ForceTmpTableName, "gho") } else { @@ -364,14 +386,18 @@ func (this *MigrationContext) GetOldTableName() string { tableName = this.OriginalTableName } + suffix := "del" + if this.Revert { + suffix = "rev_del" + } if this.TimestampOldTable { t := this.StartTime timestamp := fmt.Sprintf("%d%02d%02d%02d%02d%02d", t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second()) - return getSafeTableName(tableName, fmt.Sprintf("%s_del", timestamp)) + return getSafeTableName(tableName, fmt.Sprintf("%s_%s", timestamp, suffix)) } - return getSafeTableName(tableName, "del") + return getSafeTableName(tableName, suffix) } // GetChangelogTableName generates the name of changelog table, based on original table name @@ -600,6 +626,13 @@ func (this *MigrationContext) GetIteration() int64 { return atomic.LoadInt64(&this.Iteration) } +func (this *MigrationContext) SetNextIterationRangeMinValues() { + this.MigrationIterationRangeMinValues = this.MigrationIterationRangeMaxValues + if this.MigrationIterationRangeMinValues == nil { + this.MigrationIterationRangeMinValues = this.MigrationRangeMinValues + } +} + func (this *MigrationContext) MarkPointOfInterest() int64 { this.pointOfInterestTimeMutex.Lock() defer this.pointOfInterestTimeMutex.Unlock() @@ -959,9 +992,59 @@ func (this *MigrationContext) GetGhostTriggerName(triggerName string) string { return triggerName + this.TriggerSuffix } -// validateGhostTriggerLength check if the ghost trigger name length is not more than 64 characters +// ValidateGhostTriggerLengthBelowMaxLength checks if the given trigger name (already transformed +// by GetGhostTriggerName) does not exceed the maximum allowed length. 
func (this *MigrationContext) ValidateGhostTriggerLengthBelowMaxLength(triggerName string) bool { - ghostTriggerName := this.GetGhostTriggerName(triggerName) + return utf8.RuneCountInString(triggerName) <= mysql.MaxTableNameLength +} + +// GetContext returns the migration context for cancellation checking +func (this *MigrationContext) GetContext() context.Context { + return this.ctx +} - return utf8.RuneCountInString(ghostTriggerName) <= mysql.MaxTableNameLength +// SetAbortError stores the fatal error that triggered abort +// Only the first error is stored (subsequent errors are ignored) +func (this *MigrationContext) SetAbortError(err error) { + this.abortMutex.Lock() + defer this.abortMutex.Unlock() + if this.AbortError == nil { + this.AbortError = err + } +} + +// GetAbortError retrieves the stored abort error +func (this *MigrationContext) GetAbortError() error { + this.abortMutex.Lock() + defer this.abortMutex.Unlock() + return this.AbortError +} + +// CancelContext cancels the migration context to signal all goroutines to stop +// The cancel function is safe to call multiple times and from multiple goroutines. +func (this *MigrationContext) CancelContext() { + if this.cancelFunc != nil { + this.cancelFunc() + } +} + +// SendWithContext attempts to send a value to a channel, but returns early +// if the context is cancelled. This prevents goroutine deadlocks when the +// channel receiver has exited due to an error. +// +// Use this instead of bare channel sends (ch <- val) in goroutines to ensure +// proper cleanup when the migration is aborted. 
+// +// Example: +// +// if err := base.SendWithContext(ctx, ch, value); err != nil { +// return err // context was cancelled +// } +func SendWithContext[T any](ctx context.Context, ch chan<- T, val T) error { + select { + case ch <- val: + return nil + case <-ctx.Done(): + return ctx.Err() + } } diff --git a/go/base/context_test.go b/go/base/context_test.go index f87bc9f13..a9f62150d 100644 --- a/go/base/context_test.go +++ b/go/base/context_test.go @@ -6,8 +6,10 @@ package base import ( + "errors" "os" "strings" + "sync" "testing" "time" @@ -86,38 +88,69 @@ func TestGetTriggerNames(t *testing.T) { } func TestValidateGhostTriggerLengthBelowMaxLength(t *testing.T) { + // Tests simulate the real call pattern: GetGhostTriggerName first, then validate the result. { + // Short trigger name with suffix appended: well under 64 chars context := NewMigrationContext() context.TriggerSuffix = "_gho" - require.True(t, context.ValidateGhostTriggerLengthBelowMaxLength("my_trigger")) + ghostName := context.GetGhostTriggerName("my_trigger") // "my_trigger_gho" = 14 chars + require.True(t, context.ValidateGhostTriggerLengthBelowMaxLength(ghostName)) } { + // 64-char original + "_ghost" suffix = 70 chars → exceeds limit context := NewMigrationContext() context.TriggerSuffix = "_ghost" - require.False(t, context.ValidateGhostTriggerLengthBelowMaxLength(strings.Repeat("my_trigger_ghost", 4))) // 64 characters + "_ghost" + ghostName := context.GetGhostTriggerName(strings.Repeat("my_trigger_ghost", 4)) // 64 + 6 = 70 + require.False(t, context.ValidateGhostTriggerLengthBelowMaxLength(ghostName)) } { + // 48-char original + "_ghost" suffix = 54 chars → valid context := NewMigrationContext() context.TriggerSuffix = "_ghost" - require.True(t, context.ValidateGhostTriggerLengthBelowMaxLength(strings.Repeat("my_trigger_ghost", 3))) // 48 characters + "_ghost" + ghostName := context.GetGhostTriggerName(strings.Repeat("my_trigger_ghost", 3)) // 48 + 6 = 54 + require.True(t, 
context.ValidateGhostTriggerLengthBelowMaxLength(ghostName)) } { + // RemoveTriggerSuffix: 64-char name ending in "_ghost" → suffix removed → 58 chars → valid context := NewMigrationContext() context.TriggerSuffix = "_ghost" context.RemoveTriggerSuffix = true - require.True(t, context.ValidateGhostTriggerLengthBelowMaxLength(strings.Repeat("my_trigger_ghost", 4))) // 64 characters + "_ghost" removed + ghostName := context.GetGhostTriggerName(strings.Repeat("my_trigger_ghost", 4)) // suffix removed → 58 + require.True(t, context.ValidateGhostTriggerLengthBelowMaxLength(ghostName)) } { + // RemoveTriggerSuffix: name doesn't end in suffix → suffix appended → 65 + 6 = 71 chars → exceeds context := NewMigrationContext() context.TriggerSuffix = "_ghost" context.RemoveTriggerSuffix = true - require.False(t, context.ValidateGhostTriggerLengthBelowMaxLength(strings.Repeat("my_trigger_ghost", 4)+"X")) // 65 characters + "_ghost" not removed + ghostName := context.GetGhostTriggerName(strings.Repeat("my_trigger_ghost", 4) + "X") // no match, appended → 71 + require.False(t, context.ValidateGhostTriggerLengthBelowMaxLength(ghostName)) } { + // RemoveTriggerSuffix: 70-char name ending in "_ghost" → suffix removed → 64 chars → exactly at limit → valid context := NewMigrationContext() context.TriggerSuffix = "_ghost" context.RemoveTriggerSuffix = true - require.True(t, context.ValidateGhostTriggerLengthBelowMaxLength(strings.Repeat("my_trigger_ghost", 4)+"_ghost")) // 70 characters + last "_ghost" removed + ghostName := context.GetGhostTriggerName(strings.Repeat("my_trigger_ghost", 4) + "_ghost") // suffix removed → 64 + require.True(t, context.ValidateGhostTriggerLengthBelowMaxLength(ghostName)) + } + { + // Edge case: exactly 64 chars after transformation → valid (boundary test) + context := NewMigrationContext() + context.TriggerSuffix = "_ght" + originalName := strings.Repeat("x", 60) // 60 chars + ghostName := context.GetGhostTriggerName(originalName) // 60 + 4 = 64 + 
require.Equal(t, 64, len(ghostName)) + require.True(t, context.ValidateGhostTriggerLengthBelowMaxLength(ghostName)) + } + { + // Edge case: 65 chars after transformation → exceeds (boundary test) + context := NewMigrationContext() + context.TriggerSuffix = "_ght" + originalName := strings.Repeat("x", 61) // 61 chars + ghostName := context.GetGhostTriggerName(originalName) // 61 + 4 = 65 + require.Equal(t, 65, len(ghostName)) + require.False(t, context.ValidateGhostTriggerLengthBelowMaxLength(ghostName)) } } @@ -182,3 +215,58 @@ func TestReadConfigFile(t *testing.T) { } } } + +func TestSetAbortError_StoresFirstError(t *testing.T) { + ctx := NewMigrationContext() + + err1 := errors.New("first error") + err2 := errors.New("second error") + + ctx.SetAbortError(err1) + ctx.SetAbortError(err2) + + got := ctx.GetAbortError() + if got != err1 { //nolint:errorlint // Testing pointer equality for sentinel error + t.Errorf("Expected first error %v, got %v", err1, got) + } +} + +func TestSetAbortError_ThreadSafe(t *testing.T) { + ctx := NewMigrationContext() + + var wg sync.WaitGroup + errs := []error{ + errors.New("error 1"), + errors.New("error 2"), + errors.New("error 3"), + } + + // Launch 3 goroutines trying to set error concurrently + for _, err := range errs { + wg.Add(1) + go func(e error) { + defer wg.Done() + ctx.SetAbortError(e) + }(err) + } + + wg.Wait() + + // Should store exactly one of the errors + got := ctx.GetAbortError() + if got == nil { + t.Fatal("Expected error to be stored, got nil") + } + + // Verify it's one of the errors we sent + found := false + for _, err := range errs { + if got == err { //nolint:errorlint // Testing pointer equality for sentinel error + found = true + break + } + } + if !found { + t.Errorf("Stored error %v not in list of sent errors", got) + } +} diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 0a2a8afa4..567137fd5 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -138,7 +138,8 @@ func main() { 
flag.Int64Var(&migrationContext.HooksStatusIntervalSec, "hooks-status-interval", 60, "how many seconds to wait between calling onStatus hook") flag.UintVar(&migrationContext.ReplicaServerId, "replica-server-id", 99999, "server id used by gh-ost process. Default: 99999") - flag.BoolVar(&migrationContext.AllowSetupMetadataLockInstruments, "allow-setup-metadata-lock-instruments", false, "validate rename session hold the MDL of original table before unlock tables in cut-over phase") + flag.BoolVar(&migrationContext.AllowSetupMetadataLockInstruments, "allow-setup-metadata-lock-instruments", false, "Validate that the rename session holds the MDL on the original table before unlocking tables in the cut-over phase") + flag.BoolVar(&migrationContext.SkipMetadataLockCheck, "skip-metadata-lock-check", false, "Skip metadata lock check at cut-over time. The check requires performance_schema.metadata_locks to be enabled") flag.IntVar(&migrationContext.BinlogSyncerMaxReconnectAttempts, "binlogsyncer-max-reconnect-attempts", 0, "when master node fails, the maximum number of binlog synchronization attempts to reconnect. 0 is unlimited") flag.BoolVar(&migrationContext.IncludeTriggers, "include-triggers", false, "When true, the triggers (if exist) will be created on the new table") @@ -148,6 +149,8 @@ func main() { flag.BoolVar(&migrationContext.Checkpoint, "checkpoint", false, "Enable migration checkpoints") flag.Int64Var(&migrationContext.CheckpointIntervalSeconds, "checkpoint-seconds", 300, "The number of seconds between checkpoints") flag.BoolVar(&migrationContext.Resume, "resume", false, "Attempt to resume migration from checkpoint") + flag.BoolVar(&migrationContext.Revert, "revert", false, "Attempt to revert completed migration") + flag.StringVar(&migrationContext.OldTableName, "old-table", "", "The name of the old table when using --revert, e.g. '_mytable_del'") maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'. 
When status exceeds threshold, app throttles writes") criticalLoad := flag.String("critical-load", "", "Comma delimited status-name=threshold, same format as --max-load. When status exceeds threshold, app panics and quits") @@ -206,12 +209,35 @@ func main() { migrationContext.SetConnectionCharset(*charset) - if migrationContext.AlterStatement == "" { + if migrationContext.AlterStatement == "" && !migrationContext.Revert { log.Fatal("--alter must be provided and statement must not be empty") } parser := sql.NewParserFromAlterStatement(migrationContext.AlterStatement) migrationContext.AlterStatementOptions = parser.GetAlterStatementOptions() + if migrationContext.Revert { + if migrationContext.Resume { + log.Fatal("--revert cannot be used with --resume") + } + if migrationContext.OldTableName == "" { + migrationContext.Log.Fatalf("--revert must be called with --old-table") + } + + // options irrelevant to revert mode + if migrationContext.AlterStatement != "" { + log.Warning("--alter was provided with --revert, it will be ignored") + } + if migrationContext.AttemptInstantDDL { + log.Warning("--attempt-instant-ddl was provided with --revert, it will be ignored") + } + if migrationContext.IncludeTriggers { + log.Warning("--include-triggers was provided with --revert, it will be ignored") + } + if migrationContext.DiscardForeignKeys { + log.Warning("--discard-foreign-keys was provided with --revert, it will be ignored") + } + } + if migrationContext.DatabaseName == "" { if parser.HasExplicitSchema() { migrationContext.DatabaseName = parser.GetExplicitSchema() @@ -290,6 +316,9 @@ func main() { if migrationContext.CheckpointIntervalSeconds < 10 { migrationContext.Log.Fatalf("--checkpoint-seconds should be >=10") } + if migrationContext.CountTableRows && migrationContext.PanicOnWarnings { + migrationContext.Log.Warning("--exact-rowcount with --panic-on-warnings: row counts cannot be exact due to warning detection") + } switch *cutOver { case "atomic", "default", "": @@ 
-347,7 +376,14 @@ func main() { acceptSignals(migrationContext) migrator := logic.NewMigrator(migrationContext, AppVersion) - if err := migrator.Migrate(); err != nil { + var err error + if migrationContext.Revert { + err = migrator.Revert() + } else { + err = migrator.Migrate() + } + + if err != nil { migrator.ExecOnFailureHook() migrationContext.Log.Fatale(err) } diff --git a/go/logic/applier.go b/go/logic/applier.go index 3fe7f2287..709fd08da 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -94,6 +94,21 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier { } } +// compileMigrationKeyWarningRegex compiles a regex pattern that matches duplicate key warnings +// for the migration's unique key. Duplicate warnings are formatted differently across MySQL versions, +// hence the optional table name prefix. Metacharacters in table/index names are escaped to avoid +// regex syntax errors. +func (this *Applier) compileMigrationKeyWarningRegex() (*regexp.Regexp, error) { + escapedTable := regexp.QuoteMeta(this.migrationContext.GetGhostTableName()) + escapedKey := regexp.QuoteMeta(this.migrationContext.UniqueKey.NameInGhostTable) + migrationUniqueKeyPattern := fmt.Sprintf(`for key '(%s\.)?%s'`, escapedTable, escapedKey) + migrationKeyRegex, err := regexp.Compile(migrationUniqueKeyPattern) + if err != nil { + return nil, fmt.Errorf("failed to compile migration key pattern: %w", err) + } + return migrationKeyRegex, nil +} + func (this *Applier) InitDBConnections() (err error) { applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri) @@ -290,7 +305,32 @@ func (this *Applier) AttemptInstantDDL() error { return err } // We don't need a trx, because for instant DDL the SQL mode doesn't matter. 
- _, err := this.db.Exec(query) + return retryOnLockWaitTimeout(func() error { + _, err := this.db.Exec(query) + return err + }, this.migrationContext.Log) +} + +// retryOnLockWaitTimeout retries the given operation on MySQL lock wait timeout +// (errno 1205). Non-timeout errors return immediately. This is used for instant +// DDL attempts where the operation may be blocked by a long-running transaction. +func retryOnLockWaitTimeout(operation func() error, logger base.Logger) error { + const maxRetries = 5 + var err error + for i := 0; i < maxRetries; i++ { + if i != 0 { + logger.Infof("Retrying after lock wait timeout (attempt %d/%d)", i+1, maxRetries) + RetrySleepFn(time.Duration(i) * 5 * time.Second) + } + err = operation() + if err == nil { + return nil + } + var mysqlErr *drivermysql.MySQLError + if !errors.As(err, &mysqlErr) || mysqlErr.Number != 1205 { + return err + } + } return err } @@ -433,10 +473,11 @@ func (this *Applier) CreateCheckpointTable() error { colDefs := []string{ "`gh_ost_chk_id` bigint auto_increment primary key", "`gh_ost_chk_timestamp` bigint", - "`gh_ost_chk_coords` varchar(4096)", + "`gh_ost_chk_coords` text charset ascii", "`gh_ost_chk_iteration` bigint", "`gh_ost_rows_copied` bigint", "`gh_ost_dml_applied` bigint", + "`gh_ost_is_cutover` tinyint(1) DEFAULT '0'", } for _, col := range this.migrationContext.UniqueKey.Columns.Columns() { if col.MySQLType == "" { @@ -444,18 +485,12 @@ func (this *Applier) CreateCheckpointTable() error { } minColName := sql.TruncateColumnName(col.Name, sql.MaxColumnNameLength-4) + "_min" colDef := fmt.Sprintf("%s %s", sql.EscapeName(minColName), col.MySQLType) - if !col.Nullable { - colDef += " NOT NULL" - } colDefs = append(colDefs, colDef) } for _, col := range this.migrationContext.UniqueKey.Columns.Columns() { maxColName := sql.TruncateColumnName(col.Name, sql.MaxColumnNameLength-4) + "_max" colDef := fmt.Sprintf("%s %s", sql.EscapeName(maxColName), col.MySQLType) - if !col.Nullable { - colDef += " NOT 
NULL" - } colDefs = append(colDefs, colDef) } @@ -488,10 +523,16 @@ func (this *Applier) dropTable(tableName string) error { return nil } +// StateMetadataLockInstrument checks if metadata_locks is enabled in performance_schema. +// If not it attempts to enable metadata_locks if this is allowed. func (this *Applier) StateMetadataLockInstrument() error { query := `select /*+ MAX_EXECUTION_TIME(300) */ ENABLED, TIMED from performance_schema.setup_instruments WHERE NAME = 'wait/lock/metadata/sql/mdl'` var enabled, timed string if err := this.db.QueryRow(query).Scan(&enabled, &timed); err != nil { + if errors.Is(err, gosql.ErrNoRows) { + // performance_schema may be disabled. + return nil + } return this.migrationContext.Log.Errorf("query performance_schema.setup_instruments with name wait/lock/metadata/sql/mdl error: %s", err) } if strings.EqualFold(enabled, "YES") && strings.EqualFold(timed, "YES") { @@ -627,7 +668,7 @@ func (this *Applier) WriteCheckpoint(chk *Checkpoint) (int64, error) { if err != nil { return insertId, err } - args := sqlutils.Args(chk.LastTrxCoords.String(), chk.Iteration, chk.RowsCopied, chk.DMLApplied) + args := sqlutils.Args(chk.LastTrxCoords.String(), chk.Iteration, chk.RowsCopied, chk.DMLApplied, chk.IsCutover) args = append(args, uniqueKeyArgs...) res, err := this.db.Exec(query, args...) 
if err != nil { @@ -637,7 +678,7 @@ func (this *Applier) WriteCheckpoint(chk *Checkpoint) (int64, error) { } func (this *Applier) ReadLastCheckpoint() (*Checkpoint, error) { - row := this.db.QueryRow(fmt.Sprintf(`select /* gh-ost */ * from %s.%s order by gh_ost_chk_id desc limit 1`, this.migrationContext.DatabaseName, this.migrationContext.GetCheckpointTableName())) + row := this.db.QueryRow(fmt.Sprintf(`select /* gh-ost */ * from %s.%s order by gh_ost_chk_id desc limit 1`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetCheckpointTableName()))) chk := &Checkpoint{ IterationRangeMin: sql.NewColumnValues(this.migrationContext.UniqueKey.Columns.Len()), IterationRangeMax: sql.NewColumnValues(this.migrationContext.UniqueKey.Columns.Len()), @@ -645,7 +686,7 @@ func (this *Applier) ReadLastCheckpoint() (*Checkpoint, error) { var coordStr string var timestamp int64 - ptrs := []interface{}{&chk.Id, ×tamp, &coordStr, &chk.Iteration, &chk.RowsCopied, &chk.DMLApplied} + ptrs := []interface{}{&chk.Id, ×tamp, &coordStr, &chk.Iteration, &chk.RowsCopied, &chk.DMLApplied, &chk.IsCutover} ptrs = append(ptrs, chk.IterationRangeMin.ValuesPointers...) ptrs = append(ptrs, chk.IterationRangeMax.ValuesPointers...) err := row.Scan(ptrs...) 
@@ -694,7 +735,17 @@ func (this *Applier) InitiateHeartbeat() {
 	ticker := time.NewTicker(time.Duration(this.migrationContext.HeartbeatIntervalMilliseconds) * time.Millisecond)
 	defer ticker.Stop()
 
-	for range ticker.C {
+	for {
+		// Check for context cancellation each iteration
+		ctx := this.migrationContext.GetContext()
+		select {
+		case <-ctx.Done():
+			this.migrationContext.Log.Debugf("Heartbeat injection cancelled")
+			return
+		case <-ticker.C:
+			// Process heartbeat
+		}
+
 		if atomic.LoadInt64(&this.finishedMigrating) > 0 {
 			return
 		}
@@ -705,7 +756,8 @@ func (this *Applier) InitiateHeartbeat() {
 			continue
 		}
 		if err := injectHeartbeat(); err != nil {
-			this.migrationContext.PanicAbort <- fmt.Errorf("injectHeartbeat writing failed %d times, last error: %w", numSuccessiveFailures, err)
+			// Use helper to prevent deadlock if listenOnPanicAbort already exited
+			_ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, fmt.Errorf("injectHeartbeat writing failed %d times, last error: %w", numSuccessiveFailures, err))
 			return
 		}
 	}
@@ -818,17 +870,6 @@ func (this *Applier) ReadMigrationRangeValues() error {
 // no further chunk to work through, i.e. we're past the last chunk and are done with
 // iterating the range (and thus done with copying row chunks)
 func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange bool, err error) {
-	this.LastIterationRangeMutex.Lock()
-	if this.migrationContext.MigrationIterationRangeMinValues != nil && this.migrationContext.MigrationIterationRangeMaxValues != nil {
-		this.LastIterationRangeMinValues = this.migrationContext.MigrationIterationRangeMinValues.Clone()
-		this.LastIterationRangeMaxValues = this.migrationContext.MigrationIterationRangeMaxValues.Clone()
-	}
-	this.LastIterationRangeMutex.Unlock()
-
-	this.migrationContext.MigrationIterationRangeMinValues = this.migrationContext.MigrationIterationRangeMaxValues
-	if this.migrationContext.MigrationIterationRangeMinValues == nil {
-		this.migrationContext.MigrationIterationRangeMinValues = this.migrationContext.MigrationRangeMinValues
-	}
 	for i := 0; i < 2; i++ {
 		buildFunc := sql.BuildUniqueKeyRangeEndPreparedQueryViaOffset
 		if i == 1 {
@@ -927,6 +968,12 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected
 		return nil, err
 	}
 
+	// Compile regex once before loop to avoid performance penalty and handle errors properly
+	migrationKeyRegex, err := this.compileMigrationKeyWarningRegex()
+	if err != nil {
+		return nil, err
+	}
+
 	var sqlWarnings []string
 	for rows.Next() {
 		var level, message string
@@ -935,10 +982,7 @@
 			this.migrationContext.Log.Warningf("Failed to read SHOW WARNINGS row")
 			continue
 		}
-		// Duplicate warnings are formatted differently across mysql versions, hence the optional table name prefix
-		migrationUniqueKeyExpression := fmt.Sprintf("for key '(%s\\.)?%s'", this.migrationContext.GetGhostTableName(), this.migrationContext.UniqueKey.NameInGhostTable)
-		matched, _ := regexp.MatchString(migrationUniqueKeyExpression, message)
-		if strings.Contains(message, "Duplicate entry") && matched {
+		if strings.Contains(message, "Duplicate entry") && migrationKeyRegex.MatchString(message) {
 			continue
 		}
 		sqlWarnings = append(sqlWarnings, fmt.Sprintf("%s: %s (%d)", level, message, code))
@@ -1349,7 +1393,7 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke
 	this.migrationContext.Log.Infof("Session renameLockSessionId is %+v", *renameLockSessionId)
 
 	// Checking the lock is held by rename session
-	if *renameLockSessionId > 0 && this.migrationContext.IsOpenMetadataLockInstruments {
+	if *renameLockSessionId > 0 && this.migrationContext.IsOpenMetadataLockInstruments && !this.migrationContext.SkipMetadataLockCheck {
 		sleepDuration := time.Duration(10*this.migrationContext.CutOverLockTimeoutSeconds) * time.Millisecond
 		for i := 1; i <= 100; i++ {
 			err := this.ExpectMetadataLock(*renameLockSessionId)
@@ -1469,16 +1513,116 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB
 			results = append(results, this.buildDMLEventQuery(dmlEvent)...)
 			return results
 		}
-		query, sharedArgs, uniqueKeyArgs, err := this.dmlUpdateQueryBuilder.BuildQuery(dmlEvent.NewColumnValues.AbstractValues(), dmlEvent.WhereColumnValues.AbstractValues())
+		query, updateArgs, err := this.dmlUpdateQueryBuilder.BuildQuery(dmlEvent.NewColumnValues.AbstractValues(), dmlEvent.WhereColumnValues.AbstractValues())
 		args := sqlutils.Args()
-		args = append(args, sharedArgs...)
-		args = append(args, uniqueKeyArgs...)
+		args = append(args, updateArgs...)
 		return []*dmlBuildResult{newDmlBuildResult(query, args, 0, err)}
 		}
 	}
 	return []*dmlBuildResult{newDmlBuildResultError(fmt.Errorf("Unknown dml event type: %+v", dmlEvent.DML))}
 }
 
+// executeBatchWithWarningChecking executes a batch of DML statements with SHOW WARNINGS
+// interleaved after each statement to detect warnings from any statement in the batch.
+// This is used when PanicOnWarnings is enabled to ensure warnings from middle statements
+// are not lost (SHOW WARNINGS only shows warnings from the last statement in a multi-statement batch).
+func (this *Applier) executeBatchWithWarningChecking(ctx context.Context, tx *gosql.Tx, buildResults []*dmlBuildResult) (int64, error) {
+	// Build query with interleaved SHOW WARNINGS: stmt1; SHOW WARNINGS; stmt2; SHOW WARNINGS; ...
+	var queryBuilder strings.Builder
+	args := make([]interface{}, 0)
+
+	for _, buildResult := range buildResults {
+		queryBuilder.WriteString(buildResult.query)
+		queryBuilder.WriteString(";\nSHOW WARNINGS;\n")
+		args = append(args, buildResult.args...)
+	}
+
+	query := queryBuilder.String()
+
+	// Execute the multi-statement query
+	rows, err := tx.QueryContext(ctx, query, args...)
+	if err != nil {
+		return 0, fmt.Errorf("%w; query=%s; args=%+v", err, query, args)
+	}
+	defer rows.Close()
+
+	var totalDelta int64
+
+	// QueryContext with multi-statement queries returns rows positioned at the first result set
+	// that produces rows (i.e., the first SHOW WARNINGS), automatically skipping DML results.
+	// Verify we're at a SHOW WARNINGS result set (should have 3 columns: Level, Code, Message)
+	cols, err := rows.Columns()
+	if err != nil {
+		return 0, fmt.Errorf("failed to get columns: %w", err)
+	}
+
+	// If somehow we're not at a result set with columns, try to advance
+	if len(cols) == 0 {
+		if !rows.NextResultSet() {
+			return 0, fmt.Errorf("expected SHOW WARNINGS result set after first statement")
+		}
+	}
+
+	// Compile regex once before loop to avoid performance penalty and handle errors properly
+	migrationKeyRegex, err := this.compileMigrationKeyWarningRegex()
+	if err != nil {
+		return 0, err
+	}
+
+	// Iterate through SHOW WARNINGS result sets.
+	// DML statements don't create navigable result sets, so we move directly between SHOW WARNINGS.
+	// Pattern: [at SHOW WARNINGS #1] -> read warnings -> NextResultSet() -> [at SHOW WARNINGS #2] -> ...
+	for i := 0; i < len(buildResults); i++ {
+		// We can't get exact rows affected with QueryContext (needed for reading SHOW WARNINGS).
+		// Use the theoretical delta (+1 for INSERT, -1 for DELETE, 0 for UPDATE) as an approximation.
+		// This may be inaccurate (e.g., INSERT IGNORE with duplicate affects 0 rows but we count +1).
+		totalDelta += buildResults[i].rowsDelta
+
+		// Read warnings from this statement's SHOW WARNINGS result set
+		var sqlWarnings []string
+		for rows.Next() {
+			var level, message string
+			var code int
+			if err := rows.Scan(&level, &code, &message); err != nil {
+				// Scan failure means we cannot reliably read warnings.
+				// Since PanicOnWarnings is a safety feature, we must fail hard rather than silently skip.
+				return 0, fmt.Errorf("failed to scan SHOW WARNINGS for statement %d: %w", i+1, err)
+			}
+
+			if strings.Contains(message, "Duplicate entry") && migrationKeyRegex.MatchString(message) {
+				// Duplicate entry on migration unique key is expected during binlog replay
+				// (row was already copied during bulk copy phase)
+				continue
+			}
+			sqlWarnings = append(sqlWarnings, fmt.Sprintf("%s: %s (%d)", level, message, code))
+		}
+
+		// Check for errors that occurred while iterating through warnings
+		if err := rows.Err(); err != nil {
+			return 0, fmt.Errorf("error reading SHOW WARNINGS result set for statement %d: %w", i+1, err)
+		}
+
+		if len(sqlWarnings) > 0 {
+			return 0, fmt.Errorf("warnings detected in statement %d of %d: %v", i+1, len(buildResults), sqlWarnings)
+		}
+
+		// Move to the next statement's SHOW WARNINGS result set
+		// For the last statement, there's no next result set
+		// DML statements don't create result sets, so we only need one NextResultSet call
+		// to move from SHOW WARNINGS #N to SHOW WARNINGS #(N+1)
+		if i < len(buildResults)-1 {
+			if !rows.NextResultSet() {
+				if err := rows.Err(); err != nil {
+					return 0, fmt.Errorf("error moving to SHOW WARNINGS for statement %d: %w", i+2, err)
+				}
+				return 0, fmt.Errorf("expected SHOW WARNINGS result set for statement %d", i+2)
+			}
+		}
+	}
+
+	return totalDelta, nil
+}
+
 // ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table
 func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error {
 	var totalDelta int64
@@ -1518,45 +1662,55 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent))
 		}
 	}
 
-	// We batch together the DML queries into multi-statements to minimize network trips.
-	// We have to use the raw driver connection to access the rows affected
-	// for each statement in the multi-statement.
-	execErr := conn.Raw(func(driverConn any) error {
-		ex := driverConn.(driver.ExecerContext)
-		nvc := driverConn.(driver.NamedValueChecker)
-
-		multiArgs := make([]driver.NamedValue, 0, nArgs)
-		multiQueryBuilder := strings.Builder{}
-		for _, buildResult := range buildResults {
-			for _, arg := range buildResult.args {
-				nv := driver.NamedValue{Value: driver.Value(arg)}
-				nvc.CheckNamedValue(&nv)
-				multiArgs = append(multiArgs, nv)
+	// When PanicOnWarnings is enabled, we need to check warnings after each statement
+	// in the batch. SHOW WARNINGS only shows warnings from the last statement in a
+	// multi-statement query, so we interleave SHOW WARNINGS after each DML statement.
+	if this.migrationContext.PanicOnWarnings {
+		totalDelta, err = this.executeBatchWithWarningChecking(ctx, tx, buildResults)
+		if err != nil {
+			return rollback(err)
+		}
+	} else {
+		// Fast path: batch together DML queries into multi-statements to minimize network trips.
+		// We use the raw driver connection to access the rows affected for each statement.
+		execErr := conn.Raw(func(driverConn any) error {
+			ex := driverConn.(driver.ExecerContext)
+			nvc := driverConn.(driver.NamedValueChecker)
+
+			multiArgs := make([]driver.NamedValue, 0, nArgs)
+			multiQueryBuilder := strings.Builder{}
+			for _, buildResult := range buildResults {
+				for _, arg := range buildResult.args {
+					nv := driver.NamedValue{Value: driver.Value(arg)}
+					nvc.CheckNamedValue(&nv)
+					multiArgs = append(multiArgs, nv)
+				}
+
+				multiQueryBuilder.WriteString(buildResult.query)
+				multiQueryBuilder.WriteString(";\n")
 			}
-			multiQueryBuilder.WriteString(buildResult.query)
-			multiQueryBuilder.WriteString(";\n")
-		}
+			res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs)
+			if err != nil {
+				err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs)
+				return err
+			}
-		res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs)
-		if err != nil {
-			err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs)
-			return err
-		}
+			mysqlRes := res.(drivermysql.Result)
-		mysqlRes := res.(drivermysql.Result)
+			// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
+			// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
+			for i, rowsAffected := range mysqlRes.AllRowsAffected() {
+				totalDelta += buildResults[i].rowsDelta * rowsAffected
+			}
+			return nil
+		})
-		// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
-		// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
-		for i, rowsAffected := range mysqlRes.AllRowsAffected() {
-			totalDelta += buildResults[i].rowsDelta * rowsAffected
+		if execErr != nil {
+			return rollback(execErr)
 		}
-		return nil
-	})
-
-	if execErr != nil {
-		return rollback(execErr)
 	}
+
 	if err := tx.Commit(); err != nil {
 		return err
 	}
diff --git a/go/logic/applier_test.go b/go/logic/applier_test.go
index 232349d16..a9ecb889d 100644
--- a/go/logic/applier_test.go
+++ b/go/logic/applier_test.go
@@ -8,9 +8,12 @@ package logic
 import (
 	"context"
 	gosql "database/sql"
+	"errors"
 	"strings"
 	"testing"
+	"time"
 
+	drivermysql "github.com/go-sql-driver/mysql"
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/suite"
@@ -23,7 +26,6 @@ import (
 	"github.com/github/gh-ost/go/binlog"
 	"github.com/github/gh-ost/go/mysql"
 	"github.com/github/gh-ost/go/sql"
-	"github.com/testcontainers/testcontainers-go/wait"
 )
 
 func TestApplierGenerateSqlModeQuery(t *testing.T) {
@@ -147,7 +149,7 @@ func TestApplierBuildDMLEventQuery(t *testing.T) {
 	require.Len(t, res, 1)
 	require.NoError(t, res[0].err)
 	require.Equal(t,
-		`replace /* gh-ost `+"`test`.`_test_gho`"+` */
+		`insert /* gh-ost `+"`test`.`_test_gho`"+` */ ignore
 		into
 			`+"`test`.`_test_gho`"+`
 			`+"(`id`, `item_id`)"+`
@@ -199,6 +201,71 @@ func TestApplierInstantDDL(t *testing.T) {
 	})
 }
 
+func TestRetryOnLockWaitTimeout(t *testing.T) {
+	oldRetrySleepFn := RetrySleepFn
+	defer func() { RetrySleepFn = oldRetrySleepFn }()
+	RetrySleepFn = func(d time.Duration) {} // no-op for tests
+
+	logger := base.NewMigrationContext().Log
+
+	lockWaitTimeoutErr := &drivermysql.MySQLError{Number: 1205, Message: "Lock wait timeout exceeded"}
+	nonRetryableErr := &drivermysql.MySQLError{Number: 1845, Message: "ALGORITHM=INSTANT is not supported"}
+
+	t.Run("success on first attempt", func(t *testing.T) {
+		calls := 0
+		err := retryOnLockWaitTimeout(func() error {
+			calls++
+			return nil
+		}, logger)
+		require.NoError(t, err)
+		require.Equal(t, 1, calls)
+	})
+
+	t.Run("retry on lock wait timeout then succeed", func(t *testing.T) {
+		calls := 0
+		err := retryOnLockWaitTimeout(func() error {
+			calls++
+			if calls < 3 {
+				return lockWaitTimeoutErr
+			}
+			return nil
+		}, logger)
+		require.NoError(t, err)
+		require.Equal(t, 3, calls)
+	})
+
+	t.Run("non-retryable error returns immediately", func(t *testing.T) {
+		calls := 0
+		err := retryOnLockWaitTimeout(func() error {
+			calls++
+			return nonRetryableErr
+		}, logger)
+		require.ErrorIs(t, err, nonRetryableErr)
+		require.Equal(t, 1, calls)
+	})
+
+	t.Run("non-mysql error returns immediately", func(t *testing.T) {
+		calls := 0
+		genericErr := errors.New("connection refused")
+		err := retryOnLockWaitTimeout(func() error {
+			calls++
+			return genericErr
+		}, logger)
+		require.ErrorIs(t, err, genericErr)
+		require.Equal(t, 1, calls)
+	})
+
+	t.Run("exhausts all retries", func(t *testing.T) {
+		calls := 0
+		err := retryOnLockWaitTimeout(func() error {
+			calls++
+			return lockWaitTimeoutErr
+		}, logger)
+		require.ErrorIs(t, err, lockWaitTimeoutErr)
+		require.Equal(t, 5, calls)
+	})
+}
+
 type ApplierTestSuite struct {
 	suite.Suite
 
@@ -213,7 +280,7 @@ func (suite *ApplierTestSuite) SetupSuite() {
 		testmysql.WithDatabase(testMysqlDatabase),
 		testmysql.WithUsername(testMysqlUser),
 		testmysql.WithPassword(testMysqlPass),
-		testcontainers.WithWaitStrategy(wait.ForExposedPort()),
+		testmysql.WithConfigFile("my.cnf.test"),
 	)
 	suite.Require().NoError(err)
 
@@ -272,7 +339,7 @@ func (suite *ApplierTestSuite) TestInitDBConnections() {
 	mysqlVersion, _ := strings.CutPrefix(testMysqlContainerImage, "mysql:")
 	suite.Require().Equal(mysqlVersion, migrationContext.ApplierMySQLVersion)
 	suite.Require().Equal(int64(28800), migrationContext.ApplierWaitTimeout)
-	suite.Require().Equal("SYSTEM", migrationContext.ApplierTimeZone)
+	suite.Require().Equal("+00:00", migrationContext.ApplierTimeZone)
 	suite.Require().Equal(sql.NewColumnList([]string{"id", "item_id"}), migrationContext.OriginalTableColumnsOnApplier)
 }
 
@@ -541,7 +608,9 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsInApplyIterationInsertQuerySuc
 	err = applier.ReadMigrationRangeValues()
 	suite.Require().NoError(err)
 
+	migrationContext.SetNextIterationRangeMinValues()
 	hasFurtherRange, err := applier.CalculateNextIterationRangeEndValues()
+
 	suite.Require().NoError(err)
 	suite.Require().True(hasFurtherRange)
 
@@ -619,6 +688,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsInApplyIterationInsertQueryFai
 	err = applier.AlterGhost()
 	suite.Require().NoError(err)
 
+	migrationContext.SetNextIterationRangeMinValues()
 	hasFurtherRange, err := applier.CalculateNextIterationRangeEndValues()
 	suite.Require().NoError(err)
 	suite.Require().True(hasFurtherRange)
 
@@ -702,6 +772,7 @@ func (suite *ApplierTestSuite) TestWriteCheckpoint() {
 		Iteration:  2,
 		RowsCopied: 100000,
 		DMLApplied: 200000,
+		IsCutover:  true,
 	}
 	id, err := applier.WriteCheckpoint(chk)
 	suite.Require().NoError(err)
 
@@ -716,6 +787,762 @@ func (suite *ApplierTestSuite) TestWriteCheckpoint() {
 	suite.Require().Equal(chk.IterationRangeMax.String(), gotChk.IterationRangeMax.String())
 	suite.Require().Equal(chk.RowsCopied, gotChk.RowsCopied)
 	suite.Require().Equal(chk.DMLApplied, gotChk.DMLApplied)
+	suite.Require().Equal(chk.IsCutover, gotChk.IsCutover)
+}
+
+func (suite *ApplierTestSuite) TestPanicOnWarningsWithDuplicateKeyOnNonMigrationIndex() {
+	ctx := context.Background()
+
+	var err error
+
+	// Create table with id and email columns, where id is the primary key
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100));", getTestTableName()))
+	suite.Require().NoError(err)
+
+	// Create ghost table with same schema plus a new unique index on email
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), UNIQUE KEY email_unique (email));", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer)
+	suite.Require().NoError(err)
+
+	migrationContext := newTestMigrationContext()
+	migrationContext.ApplierConnectionConfig = connectionConfig
+	migrationContext.SetConnectionConfig("innodb")
+
+	migrationContext.PanicOnWarnings = true
+
+	migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.UniqueKey = &sql.UniqueKey{
+		Name:             "PRIMARY",
+		NameInGhostTable: "PRIMARY",
+		Columns:          *sql.NewColumnList([]string{"id"}),
+	}
+
+	applier := NewApplier(migrationContext)
+	suite.Require().NoError(applier.prepareQueries())
+	defer applier.Teardown()
+
+	err = applier.InitDBConnections()
+	suite.Require().NoError(err)
+
+	// Insert initial rows into ghost table (simulating bulk copy phase)
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, email) VALUES (1, 'user1@example.com'), (2, 'user2@example.com'), (3, 'user3@example.com');", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	// Simulate binlog event: try to insert a row with duplicate email
+	// This should fail with a warning because the ghost table has a unique index on email
+	dmlEvents := []*binlog.BinlogDMLEvent{
+		{
+			DatabaseName:    testMysqlDatabase,
+			TableName:       testMysqlTableName,
+			DML:             binlog.InsertDML,
+			NewColumnValues: sql.ToColumnValues([]interface{}{4, "user2@example.com"}), // duplicate email
+		},
+	}
+
+	// This should return an error when PanicOnWarnings is enabled
+	err = applier.ApplyDMLEventQueries(dmlEvents)
+	suite.Require().Error(err)
+	suite.Require().Contains(err.Error(), "Duplicate entry")
+
+	// Verify that the ghost table still has only the original 3 rows with correct data (no data loss)
+	rows, err := suite.db.Query("SELECT id, email FROM " + getTestGhostTableName() + " ORDER BY id")
+	suite.Require().NoError(err)
+	defer rows.Close()
+
+	var results []struct {
+		id    int
+		email string
+	}
+	for rows.Next() {
+		var id int
+		var email string
+		err = rows.Scan(&id, &email)
+		suite.Require().NoError(err)
+		results = append(results, struct {
+			id    int
+			email string
+		}{id, email})
+	}
+	suite.Require().NoError(rows.Err())
+
+	// All 3 original rows should still be present with correct data
+	suite.Require().Len(results, 3)
+	suite.Require().Equal(1, results[0].id)
+	suite.Require().Equal("user1@example.com", results[0].email)
+	suite.Require().Equal(2, results[1].id)
+	suite.Require().Equal("user2@example.com", results[1].email)
+	suite.Require().Equal(3, results[2].id)
+	suite.Require().Equal("user3@example.com", results[2].email)
+}
+
+func (suite *ApplierTestSuite) TestPanicOnWarningsWithDuplicateCompositeUniqueKey() {
+	ctx := context.Background()
+
+	var err error
+
+	// Create table with id, email, and username columns
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), username VARCHAR(100));", getTestTableName()))
+	suite.Require().NoError(err)
+
+	// Create ghost table with same schema plus a composite unique index on (email, username)
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), username VARCHAR(100), UNIQUE KEY email_username_unique (email, username));", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer)
+	suite.Require().NoError(err)
+
+	migrationContext := newTestMigrationContext()
+	migrationContext.ApplierConnectionConfig = connectionConfig
+	migrationContext.SetConnectionConfig("innodb")
+
+	migrationContext.PanicOnWarnings = true
+
+	migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "email", "username"})
+	migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "email", "username"})
+	migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "email", "username"})
+	migrationContext.UniqueKey = &sql.UniqueKey{
+		Name:             "PRIMARY",
+		NameInGhostTable: "PRIMARY",
+		Columns:          *sql.NewColumnList([]string{"id"}),
+	}
+
+	applier := NewApplier(migrationContext)
+	suite.Require().NoError(applier.prepareQueries())
+	defer applier.Teardown()
+
+	err = applier.InitDBConnections()
+	suite.Require().NoError(err)
+
+	// Insert initial rows into ghost table (simulating bulk copy phase)
+	// alice@example.com + bob is ok due to composite unique index
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, email, username) VALUES (1, 'alice@example.com', 'alice'), (2, 'alice@example.com', 'bob'), (3, 'charlie@example.com', 'charlie');", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	// Simulate binlog event: try to insert a row with duplicate composite key (email + username)
+	// This should fail with a warning because the ghost table has a composite unique index
+	dmlEvents := []*binlog.BinlogDMLEvent{
+		{
+			DatabaseName:    testMysqlDatabase,
+			TableName:       testMysqlTableName,
+			DML:             binlog.InsertDML,
+			NewColumnValues: sql.ToColumnValues([]interface{}{4, "alice@example.com", "alice"}), // duplicate (email, username)
+		},
+	}
+
+	// This should return an error when PanicOnWarnings is enabled
+	err = applier.ApplyDMLEventQueries(dmlEvents)
+	suite.Require().Error(err)
+	suite.Require().Contains(err.Error(), "Duplicate entry")
+
+	// Verify that the ghost table still has only the original 3 rows with correct data (no data loss)
+	rows, err := suite.db.Query("SELECT id, email, username FROM " + getTestGhostTableName() + " ORDER BY id")
+	suite.Require().NoError(err)
+	defer rows.Close()
+
+	var results []struct {
+		id       int
+		email    string
+		username string
+	}
+	for rows.Next() {
+		var id int
+		var email string
+		var username string
+		err = rows.Scan(&id, &email, &username)
+		suite.Require().NoError(err)
+		results = append(results, struct {
+			id       int
+			email    string
+			username string
+		}{id, email, username})
+	}
+	suite.Require().NoError(rows.Err())
+
+	// All 3 original rows should still be present with correct data
+	suite.Require().Len(results, 3)
+	suite.Require().Equal(1, results[0].id)
+	suite.Require().Equal("alice@example.com", results[0].email)
+	suite.Require().Equal("alice", results[0].username)
+	suite.Require().Equal(2, results[1].id)
+	suite.Require().Equal("alice@example.com", results[1].email)
+	suite.Require().Equal("bob", results[1].username)
+	suite.Require().Equal(3, results[2].id)
+	suite.Require().Equal("charlie@example.com", results[2].email)
+	suite.Require().Equal("charlie", results[2].username)
+}
+
+// TestUpdateModifyingUniqueKeyWithDuplicateOnOtherIndex tests the scenario where:
+// 1. An UPDATE modifies the unique key (converted to DELETE+INSERT)
+// 2. The INSERT would create a duplicate on a NON-migration unique index
+// 3. Without warning detection: DELETE succeeds, INSERT IGNORE skips = DATA LOSS
+// 4. With PanicOnWarnings: Warning detected, transaction rolled back, no data loss
+// This test verifies that PanicOnWarnings correctly prevents the data loss scenario.
+func (suite *ApplierTestSuite) TestUpdateModifyingUniqueKeyWithDuplicateOnOtherIndex() {
+	ctx := context.Background()
+
+	var err error
+
+	// Create table with id (PRIMARY) and email (NO unique constraint yet)
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100));", getTestTableName()))
+	suite.Require().NoError(err)
+
+	// Create ghost table with id (PRIMARY) AND email unique index (being added)
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), UNIQUE KEY email_unique (email));", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer)
+	suite.Require().NoError(err)
+
+	migrationContext := newTestMigrationContext()
+	migrationContext.ApplierConnectionConfig = connectionConfig
+	migrationContext.SetConnectionConfig("innodb")
+
+	migrationContext.PanicOnWarnings = true
+
+	migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.UniqueKey = &sql.UniqueKey{
+		Name:             "PRIMARY",
+		NameInGhostTable: "PRIMARY",
+		Columns:          *sql.NewColumnList([]string{"id"}),
+	}
+
+	applier := NewApplier(migrationContext)
+	suite.Require().NoError(applier.prepareQueries())
+	defer applier.Teardown()
+
+	err = applier.InitDBConnections()
+	suite.Require().NoError(err)
+
+	// Setup: Insert initial rows into ghost table
+	// Row 1: id=1, email='bob@example.com'
+	// Row 2: id=2, email='charlie@example.com'
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, email) VALUES (1, 'bob@example.com'), (2, 'charlie@example.com');", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	// Simulate binlog event: UPDATE that changes BOTH PRIMARY KEY and email
+	// From: id=2, email='charlie@example.com'
+	// To:   id=3, email='bob@example.com' (duplicate email with id=1)
+	// This will be converted to DELETE (id=2) + INSERT (id=3, 'bob@example.com')
+	// With INSERT IGNORE, the INSERT will skip because email='bob@example.com' already exists in id=1
+	// Result: id=2 deleted, id=3 never inserted = DATA LOSS
+	dmlEvents := []*binlog.BinlogDMLEvent{
+		{
+			DatabaseName:      testMysqlDatabase,
+			TableName:         testMysqlTableName,
+			DML:               binlog.UpdateDML,
+			NewColumnValues:   sql.ToColumnValues([]interface{}{3, "bob@example.com"}),     // new: id=3, email='bob@example.com'
+			WhereColumnValues: sql.ToColumnValues([]interface{}{2, "charlie@example.com"}), // old: id=2, email='charlie@example.com'
+		},
+	}
+
+	// First verify this would be converted to DELETE+INSERT
+	buildResults := applier.buildDMLEventQuery(dmlEvents[0])
+	suite.Require().Len(buildResults, 2, "UPDATE modifying unique key should be converted to DELETE+INSERT")
+
+	// Apply the event - this should FAIL because INSERT will have duplicate email warning
+	err = applier.ApplyDMLEventQueries(dmlEvents)
+	suite.Require().Error(err, "Should fail when DELETE+INSERT causes duplicate on non-migration unique key")
+	suite.Require().Contains(err.Error(), "Duplicate entry", "Error should mention duplicate entry")
+
+	// Verify that BOTH rows still exist (transaction rolled back)
+	rows, err := suite.db.Query("SELECT id, email FROM " + getTestGhostTableName() + " ORDER BY id")
+	suite.Require().NoError(err)
+	defer rows.Close()
+
+	var count int
+	var ids []int
+	var emails []string
+	for rows.Next() {
+		var id int
+		var email string
+		err = rows.Scan(&id, &email)
+		suite.Require().NoError(err)
+		ids = append(ids, id)
+		emails = append(emails, email)
+		count++
+	}
+	suite.Require().NoError(rows.Err())
+
+	// Transaction should have rolled back, so original 2 rows should still be there
+	suite.Require().Equal(2, count, "Should still have 2 rows after failed transaction")
+	suite.Require().Equal([]int{1, 2}, ids, "Should have original ids")
+	suite.Require().Equal([]string{"bob@example.com", "charlie@example.com"}, emails)
+}
+
+// TestNormalUpdateWithPanicOnWarnings tests that normal UPDATEs (not modifying unique key) work correctly
+func (suite *ApplierTestSuite) TestNormalUpdateWithPanicOnWarnings() {
+	ctx := context.Background()
+
+	var err error
+
+	// Create table with id (PRIMARY) and email
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100));", getTestTableName()))
+	suite.Require().NoError(err)
+
+	// Create ghost table with same schema plus unique index on email
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), UNIQUE KEY email_unique (email));", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer)
+	suite.Require().NoError(err)
+
+	migrationContext := newTestMigrationContext()
+	migrationContext.ApplierConnectionConfig = connectionConfig
+	migrationContext.SetConnectionConfig("innodb")
+
+	migrationContext.PanicOnWarnings = true
+
+	migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.UniqueKey = &sql.UniqueKey{
+		Name:             "PRIMARY",
+		NameInGhostTable: "PRIMARY",
+		Columns:          *sql.NewColumnList([]string{"id"}),
+	}
+
+	applier := NewApplier(migrationContext)
+	suite.Require().NoError(applier.prepareQueries())
+	defer applier.Teardown()
+
+	err = applier.InitDBConnections()
+	suite.Require().NoError(err)
+
+	// Setup: Insert initial rows into ghost table
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, email) VALUES (1, 'alice@example.com'), (2, 'bob@example.com');", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	// Simulate binlog event: Normal UPDATE that only changes email (not PRIMARY KEY)
+	// This should use UPDATE query, not DELETE+INSERT
+	dmlEvents := []*binlog.BinlogDMLEvent{
+		{
+			DatabaseName:      testMysqlDatabase,
+			TableName:         testMysqlTableName,
+			DML:               binlog.UpdateDML,
+			NewColumnValues:   sql.ToColumnValues([]interface{}{2, "robert@example.com"}), // update email only
+			WhereColumnValues: sql.ToColumnValues([]interface{}{2, "bob@example.com"}),
+		},
+	}
+
+	// Verify this generates a single UPDATE query (not DELETE+INSERT)
+	buildResults := applier.buildDMLEventQuery(dmlEvents[0])
+	suite.Require().Len(buildResults, 1, "Normal UPDATE should generate single UPDATE query")
+
+	// Apply the event - should succeed
+	err = applier.ApplyDMLEventQueries(dmlEvents)
+	suite.Require().NoError(err)
+
+	// Verify the update was applied correctly
+	rows, err := suite.db.Query("SELECT id, email FROM " + getTestGhostTableName() + " WHERE id = 2")
+	suite.Require().NoError(err)
+	defer rows.Close()
+
+	var id int
+	var email string
+	suite.Require().True(rows.Next(), "Should find updated row")
+	err = rows.Scan(&id, &email)
+	suite.Require().NoError(err)
+	suite.Require().Equal(2, id)
+	suite.Require().Equal("robert@example.com", email)
+	suite.Require().False(rows.Next(), "Should only have one row")
+	suite.Require().NoError(rows.Err())
+}
+
+// TestDuplicateOnMigrationKeyAllowedInBinlogReplay tests the positive case where
+// a duplicate on the migration unique key during binlog replay is expected and should be allowed
+func (suite *ApplierTestSuite) TestDuplicateOnMigrationKeyAllowedInBinlogReplay() {
+	ctx := context.Background()
+
+	var err error
+
+	// Create table with id and email columns, where id is the primary key
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100));", getTestTableName()))
+	suite.Require().NoError(err)
+
+	// Create ghost table with same schema plus a new unique index on email
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), UNIQUE KEY email_unique (email));", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer)
+	suite.Require().NoError(err)
+
+	migrationContext := newTestMigrationContext()
+	migrationContext.ApplierConnectionConfig = connectionConfig
+	migrationContext.SetConnectionConfig("innodb")
+
+	migrationContext.PanicOnWarnings = true
+
+	migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "email"})
+	migrationContext.UniqueKey = &sql.UniqueKey{
+		Name:             "PRIMARY",
+		NameInGhostTable: "PRIMARY",
+		Columns:          *sql.NewColumnList([]string{"id"}),
+	}
+
+	applier := NewApplier(migrationContext)
+	suite.Require().NoError(applier.prepareQueries())
+	defer applier.Teardown()
+
+	err = applier.InitDBConnections()
+	suite.Require().NoError(err)
+
+	// Insert initial rows into ghost table (simulating bulk copy phase)
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, email) VALUES (1, 'alice@example.com'), (2, 'bob@example.com');", getTestGhostTableName()))
+	suite.Require().NoError(err)
+
+	// Simulate binlog event: try to insert the same row again (duplicate on PRIMARY KEY - the migration key)
+	// This is expected during binlog replay when a row was already copied during bulk copy
+	dmlEvents := []*binlog.BinlogDMLEvent{
+		{
+			DatabaseName:    testMysqlDatabase,
+			TableName:       testMysqlTableName,
+			DML:             binlog.InsertDML,
+			NewColumnValues: sql.ToColumnValues([]interface{}{1, "alice@example.com"}), // duplicate PRIMARY KEY
+		},
+	}
+
+	// This should succeed - duplicate on migration unique key is expected and should be filtered out
+	err = applier.ApplyDMLEventQueries(dmlEvents)
+	suite.Require().NoError(err)
+
+	// Verify that the ghost table still has only the original 2 rows with correct data
+	rows, err := suite.db.Query("SELECT id, email FROM " + getTestGhostTableName() + " ORDER BY id")
+	suite.Require().NoError(err)
+	defer rows.Close()
+
+	var results []struct {
+		id    int
+		email string
+	}
+	for rows.Next() {
+		var id int
+		var email string
+		err = rows.Scan(&id, &email)
+		suite.Require().NoError(err)
+		results = append(results, struct {
+			id    int
+			email string
+		}{id, email})
+	}
+	suite.Require().NoError(rows.Err())
+
+	// Should still have exactly 2 rows with correct data
+	suite.Require().Len(results, 2)
+	suite.Require().Equal(1, results[0].id)
+	suite.Require().Equal("alice@example.com", results[0].email)
+	suite.Require().Equal(2, results[1].id)
+	suite.Require().Equal("bob@example.com", results[1].email)
+}
+
+// TestRegexMetacharactersInIndexName tests that index names with regex metacharacters
+// are properly escaped. We test with a plus sign in the index name, which without
+// QuoteMeta would be treated as a regex quantifier (one or more of 'x' in this case).
+// This test verifies the pattern matches ONLY the exact index name, not a regex pattern.
+func (suite *ApplierTestSuite) TestRegexMetacharactersInIndexName() { + ctx := context.Background() + + var err error + + // Create tables with an index name containing a plus sign + // Without QuoteMeta, "idx+email" would be treated as a regex pattern where + is a quantifier + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), UNIQUE KEY `idx+email` (email));", getTestTableName())) + suite.Require().NoError(err) + + // MySQL allows + in index names when quoted + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), UNIQUE KEY `idx+email` (email));", getTestGhostTableName())) + suite.Require().NoError(err) + + connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + migrationContext := newTestMigrationContext() + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.SetConnectionConfig("innodb") + + migrationContext.PanicOnWarnings = true + + migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "email"}) + migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "email"}) + migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "email"}) + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: "idx+email", + NameInGhostTable: "idx+email", + Columns: *sql.NewColumnList([]string{"email"}), + } + + applier := NewApplier(migrationContext) + suite.Require().NoError(applier.prepareQueries()) + defer applier.Teardown() + + err = applier.InitDBConnections() + suite.Require().NoError(err) + + // Insert initial rows + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, email) VALUES (1, 'alice@example.com'), (2, 'bob@example.com');", getTestGhostTableName())) + suite.Require().NoError(err) + + // Test: duplicate on idx+email (the migration key) should be allowed + // This verifies our regex correctly identifies "idx+email" as the 
migration key + // Without regexp.QuoteMeta, the + would be treated as a regex quantifier and might not match correctly + dmlEvents := []*binlog.BinlogDMLEvent{ + { + DatabaseName: testMysqlDatabase, + TableName: testMysqlTableName, + DML: binlog.InsertDML, + NewColumnValues: sql.ToColumnValues([]interface{}{3, "alice@example.com"}), + }, + } + + err = applier.ApplyDMLEventQueries(dmlEvents) + suite.Require().NoError(err, "Duplicate on idx+email (migration key) should be allowed with PanicOnWarnings enabled") + + // Test: duplicate on PRIMARY (not the migration key) should fail + dmlEvents = []*binlog.BinlogDMLEvent{ + { + DatabaseName: testMysqlDatabase, + TableName: testMysqlTableName, + DML: binlog.InsertDML, + NewColumnValues: sql.ToColumnValues([]interface{}{1, "charlie@example.com"}), + }, + } + + err = applier.ApplyDMLEventQueries(dmlEvents) + suite.Require().Error(err, "Duplicate on PRIMARY (not migration key) should fail with PanicOnWarnings enabled") + suite.Require().Contains(err.Error(), "Duplicate entry") + + // Verify final state - should still have only the original 2 rows + rows, err := suite.db.Query("SELECT id, email FROM " + getTestGhostTableName() + " ORDER BY id") + suite.Require().NoError(err) + defer rows.Close() + + var results []struct { + id int + email string + } + for rows.Next() { + var id int + var email string + err = rows.Scan(&id, &email) + suite.Require().NoError(err) + results = append(results, struct { + id int + email string + }{id, email}) + } + suite.Require().NoError(rows.Err()) + + suite.Require().Len(results, 2) + suite.Require().Equal(1, results[0].id) + suite.Require().Equal("alice@example.com", results[0].email) + suite.Require().Equal(2, results[1].id) + suite.Require().Equal("bob@example.com", results[1].email) +} + +// TestPanicOnWarningsDisabled tests that when PanicOnWarnings is false, +// warnings are not checked and duplicates are silently ignored +func (suite *ApplierTestSuite) TestPanicOnWarningsDisabled() { + 
ctx := context.Background() + + var err error + + // Create table with id and email columns + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100));", getTestTableName())) + suite.Require().NoError(err) + + // Create ghost table with unique index on email + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), UNIQUE KEY email_unique (email));", getTestGhostTableName())) + suite.Require().NoError(err) + + connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + migrationContext := newTestMigrationContext() + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.SetConnectionConfig("innodb") + + // PanicOnWarnings is false (default) + migrationContext.PanicOnWarnings = false + + migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "email"}) + migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "email"}) + migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "email"}) + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: "PRIMARY", + NameInGhostTable: "PRIMARY", + Columns: *sql.NewColumnList([]string{"id"}), + } + + applier := NewApplier(migrationContext) + suite.Require().NoError(applier.prepareQueries()) + defer applier.Teardown() + + err = applier.InitDBConnections() + suite.Require().NoError(err) + + // Insert initial rows into ghost table + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, email) VALUES (1, 'alice@example.com'), (2, 'bob@example.com');", getTestGhostTableName())) + suite.Require().NoError(err) + + // Simulate binlog event: insert duplicate email on non-migration index + // With PanicOnWarnings disabled, this should succeed (INSERT IGNORE skips it) + dmlEvents := []*binlog.BinlogDMLEvent{ + { + DatabaseName: testMysqlDatabase, + TableName: testMysqlTableName, + DML: binlog.InsertDML, + 
NewColumnValues: sql.ToColumnValues([]interface{}{3, "alice@example.com"}), // duplicate email + }, + } + + // Should succeed because PanicOnWarnings is disabled + err = applier.ApplyDMLEventQueries(dmlEvents) + suite.Require().NoError(err) + + // Verify that only 2 original rows exist with correct data (the duplicate was silently ignored) + rows, err := suite.db.Query("SELECT id, email FROM " + getTestGhostTableName() + " ORDER BY id") + suite.Require().NoError(err) + defer rows.Close() + + var results []struct { + id int + email string + } + for rows.Next() { + var id int + var email string + err = rows.Scan(&id, &email) + suite.Require().NoError(err) + results = append(results, struct { + id int + email string + }{id, email}) + } + suite.Require().NoError(rows.Err()) + + // Should still have exactly 2 original rows (id=3 was silently ignored) + suite.Require().Len(results, 2) + suite.Require().Equal(1, results[0].id) + suite.Require().Equal("alice@example.com", results[0].email) + suite.Require().Equal(2, results[1].id) + suite.Require().Equal("bob@example.com", results[1].email) +} + +// TestMultipleDMLEventsInBatch tests that multiple DML events are processed in a single transaction +// and that if one fails due to a warning, the entire batch is rolled back - including events that +// come AFTER the failure. This proves true transaction atomicity. 
+func (suite *ApplierTestSuite) TestMultipleDMLEventsInBatch() { + ctx := context.Background() + + var err error + + // Create table with id and email columns + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100));", getTestTableName())) + suite.Require().NoError(err) + + // Create ghost table with unique index on email + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, email VARCHAR(100), UNIQUE KEY email_unique (email));", getTestGhostTableName())) + suite.Require().NoError(err) + + connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + migrationContext := newTestMigrationContext() + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.SetConnectionConfig("innodb") + + migrationContext.PanicOnWarnings = true + + migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "email"}) + migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "email"}) + migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "email"}) + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: "PRIMARY", + NameInGhostTable: "PRIMARY", + Columns: *sql.NewColumnList([]string{"id"}), + } + + applier := NewApplier(migrationContext) + suite.Require().NoError(applier.prepareQueries()) + defer applier.Teardown() + + err = applier.InitDBConnections() + suite.Require().NoError(err) + + // Insert initial rows into ghost table + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, email) VALUES (1, 'alice@example.com'), (3, 'charlie@example.com');", getTestGhostTableName())) + suite.Require().NoError(err) + + // Simulate multiple binlog events in a batch: + // 1. Duplicate on PRIMARY KEY (allowed - expected during binlog replay) + // 2. Duplicate on email index (should fail) ← FAILURE IN MIDDLE + // 3. 
Valid insert (would succeed) ← SUCCESS AFTER FAILURE + // + // The critical test: Even though event #3 would succeed on its own, it must be rolled back + // because event #2 failed. This proves the entire batch is truly atomic. + dmlEvents := []*binlog.BinlogDMLEvent{ + { + DatabaseName: testMysqlDatabase, + TableName: testMysqlTableName, + DML: binlog.InsertDML, + NewColumnValues: sql.ToColumnValues([]interface{}{1, "alice@example.com"}), // duplicate PRIMARY (normally allowed) + }, + { + DatabaseName: testMysqlDatabase, + TableName: testMysqlTableName, + DML: binlog.InsertDML, + NewColumnValues: sql.ToColumnValues([]interface{}{4, "alice@example.com"}), // duplicate email (FAILS) + }, + { + DatabaseName: testMysqlDatabase, + TableName: testMysqlTableName, + DML: binlog.InsertDML, + NewColumnValues: sql.ToColumnValues([]interface{}{2, "bob@example.com"}), // valid insert (would succeed) + }, + } + + // Should fail due to the second event + err = applier.ApplyDMLEventQueries(dmlEvents) + suite.Require().Error(err) + suite.Require().Contains(err.Error(), "Duplicate entry") + + // Verify that the entire batch was rolled back - still only the original 2 rows + // Critically: id=2 (bob@example.com) from event #3 should NOT be present + rows, err := suite.db.Query("SELECT id, email FROM " + getTestGhostTableName() + " ORDER BY id") + suite.Require().NoError(err) + defer rows.Close() + + var results []struct { + id int + email string + } + for rows.Next() { + var id int + var email string + err = rows.Scan(&id, &email) + suite.Require().NoError(err) + results = append(results, struct { + id int + email string + }{id, email}) + } + suite.Require().NoError(rows.Err()) + + // Should still have exactly 2 original rows (entire batch was rolled back) + // This proves that even event #3 (which would have succeeded) was rolled back + suite.Require().Len(results, 2) + suite.Require().Equal(1, results[0].id) + suite.Require().Equal("alice@example.com", results[0].email) + 
suite.Require().Equal(3, results[1].id) + suite.Require().Equal("charlie@example.com", results[1].email) + // Critically: id=2 (bob@example.com) is NOT present, proving event #3 was rolled back } func TestApplier(t *testing.T) { diff --git a/go/logic/checkpoint.go b/go/logic/checkpoint.go index 69cc2bd20..cffe08c4b 100644 --- a/go/logic/checkpoint.go +++ b/go/logic/checkpoint.go @@ -28,4 +28,5 @@ type Checkpoint struct { Iteration int64 RowsCopied int64 DMLApplied int64 + IsCutover bool } diff --git a/go/logic/hooks.go b/go/logic/hooks.go index 2543f8e9a..a1853e7ed 100644 --- a/go/logic/hooks.go +++ b/go/logic/hooks.go @@ -28,6 +28,7 @@ const ( onInteractiveCommand = "gh-ost-on-interactive-command" onSuccess = "gh-ost-on-success" onFailure = "gh-ost-on-failure" + onBatchCopyRetry = "gh-ost-on-batch-copy-retry" onStatus = "gh-ost-on-status" onStopReplication = "gh-ost-on-stop-replication" onStartReplication = "gh-ost-on-start-replication" @@ -69,6 +70,7 @@ func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) [ env = append(env, fmt.Sprintf("GH_OST_HOOKS_HINT_OWNER=%s", this.migrationContext.HooksHintOwner)) env = append(env, fmt.Sprintf("GH_OST_HOOKS_HINT_TOKEN=%s", this.migrationContext.HooksHintToken)) env = append(env, fmt.Sprintf("GH_OST_DRY_RUN=%t", this.migrationContext.Noop)) + env = append(env, fmt.Sprintf("GH_OST_REVERT=%t", this.migrationContext.Revert)) env = append(env, extraVariables...) return env @@ -77,6 +79,7 @@ func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) [ // executeHook executes a command, and sets relevant environment variables // combined output & error are printed to the configured writer. func (this *HooksExecutor) executeHook(hook string, extraVariables ...string) error { + this.migrationContext.Log.Infof("executing hook: %+v", hook) cmd := exec.Command(hook) cmd.Env = this.applyEnvironmentVariables(extraVariables...) 
@@ -123,6 +126,11 @@ func (this *HooksExecutor) onBeforeRowCopy() error { return this.executeHooks(onBeforeRowCopy) } +func (this *HooksExecutor) onBatchCopyRetry(errorMessage string) error { + v := fmt.Sprintf("GH_OST_LAST_BATCH_COPY_ERROR=%s", errorMessage) + return this.executeHooks(onBatchCopyRetry, v) +} + func (this *HooksExecutor) onRowCopyComplete() error { return this.executeHooks(onRowCopyComplete) } @@ -140,8 +148,9 @@ func (this *HooksExecutor) onInteractiveCommand(command string) error { return this.executeHooks(onInteractiveCommand, v) } -func (this *HooksExecutor) onSuccess() error { - return this.executeHooks(onSuccess) +func (this *HooksExecutor) onSuccess(instantDDL bool) error { + v := fmt.Sprintf("GH_OST_INSTANT_DDL=%t", instantDDL) + return this.executeHooks(onSuccess, v) } func (this *HooksExecutor) onFailure() error { diff --git a/go/logic/hooks_test.go b/go/logic/hooks_test.go index 8fdcf9ec8..94ea62b07 100644 --- a/go/logic/hooks_test.go +++ b/go/logic/hooks_test.go @@ -105,6 +105,8 @@ func TestHooksExecutorExecuteHooks(t *testing.T) { require.Equal(t, 50.0, progress) case "GH_OST_TABLE_NAME": require.Equal(t, migrationContext.OriginalTableName, split[1]) + case "GH_OST_INSTANT_DDL": + require.Equal(t, "false", split[1]) case "TEST": require.Equal(t, t.Name(), split[1]) } diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 044360153..97895890d 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -596,7 +596,7 @@ func (this *Inspector) validateGhostTriggersDontExist() error { var foundTriggers []string for _, trigger := range this.migrationContext.Triggers { triggerName := this.migrationContext.GetGhostTriggerName(trigger.Name) - query := "select 1 from information_schema.triggers where trigger_name = ? and trigger_schema = ? and event_object_table = ?" + query := "select 1 from information_schema.triggers where trigger_name = ? and trigger_schema = ?" 
err := sqlutils.QueryRowsMap(this.db, query, func(rowMap sqlutils.RowMap) error { triggerExists := rowMap.GetInt("1") if triggerExists == 1 { @@ -606,7 +606,6 @@ func (this *Inspector) validateGhostTriggersDontExist() error { }, triggerName, this.migrationContext.DatabaseName, - this.migrationContext.OriginalTableName, ) if err != nil { return err diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 33271d01a..e282147ae 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -44,6 +44,11 @@ func ReadChangelogState(s string) ChangelogState { type tableWriteFunc func() error +type lockProcessedStruct struct { + state string + coords mysql.BinlogCoordinates +} + type applyEventStruct struct { writeFunc *tableWriteFunc dmlEvent *binlog.BinlogDMLEvent @@ -85,7 +90,8 @@ type Migrator struct { firstThrottlingCollected chan bool ghostTableMigrated chan bool rowCopyComplete chan error - allEventsUpToLockProcessed chan string + allEventsUpToLockProcessed chan *lockProcessedStruct + lastLockProcessed *lockProcessedStruct rowCopyCompleteFlag int64 // copyRowsQueue should not be buffered; if buffered some non-damaging but @@ -105,7 +111,7 @@ func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { ghostTableMigrated: make(chan bool), firstThrottlingCollected: make(chan bool, 3), rowCopyComplete: make(chan error), - allEventsUpToLockProcessed: make(chan string), + allEventsUpToLockProcessed: make(chan *lockProcessedStruct), copyRowsQueue: make(chan tableWriteFunc), applyEventsQueue: make(chan *applyEventStruct, base.MaxEventsBatchSize), @@ -118,6 +124,10 @@ func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { // (or fails with error) func (this *Migrator) sleepWhileTrue(operation func() (bool, error)) error { for { + // Check for abort before continuing + if err := this.checkAbort(); err != nil { + return err + } shouldSleep, err := operation() if err != nil { return err @@ -129,6 +139,18 @@ func (this 
*Migrator) sleepWhileTrue(operation func() (bool, error)) error { } } +func (this *Migrator) retryBatchCopyWithHooks(operation func() error, notFatalHint ...bool) (err error) { + wrappedOperation := func() error { + if err := operation(); err != nil { + this.hooksExecutor.onBatchCopyRetry(err.Error()) + return err + } + return nil + } + + return this.retryOperation(wrappedOperation, notFatalHint...) +} + // retryOperation attempts up to `count` attempts at running given function, // exiting as soon as it returns with non-error. func (this *Migrator) retryOperation(operation func() error, notFatalHint ...bool) (err error) { @@ -138,14 +160,26 @@ func (this *Migrator) retryOperation(operation func() error, notFatalHint ...boo // sleep after previous iteration RetrySleepFn(1 * time.Second) } + // Check for abort/context cancellation before each retry + if abortErr := this.checkAbort(); abortErr != nil { + return abortErr + } err = operation() if err == nil { return nil } + // Check if this is an unrecoverable error (data consistency issues won't resolve on retry) + if strings.Contains(err.Error(), "warnings detected") { + if len(notFatalHint) == 0 { + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) + } + return err + } // there's an error. Let's try again. 
} if len(notFatalHint) == 0 { - this.migrationContext.PanicAbort <- err + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) } return err } @@ -167,13 +201,25 @@ func (this *Migrator) retryOperationWithExponentialBackoff(operation func() erro if i != 0 { RetrySleepFn(time.Duration(interval) * time.Second) } + // Check for abort/context cancellation before each retry + if abortErr := this.checkAbort(); abortErr != nil { + return abortErr + } err = operation() if err == nil { return nil } + // Check if this is an unrecoverable error (data consistency issues won't resolve on retry) + if strings.Contains(err.Error(), "warnings detected") { + if len(notFatalHint) == 0 { + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) + } + return err + } } if len(notFatalHint) == 0 { - this.migrationContext.PanicAbort <- err + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) } return err } @@ -182,14 +228,19 @@ func (this *Migrator) retryOperationWithExponentialBackoff(operation func() erro // consumes and drops any further incoming events that may be left hanging. 
func (this *Migrator) consumeRowCopyComplete() { if err := <-this.rowCopyComplete; err != nil { - this.migrationContext.PanicAbort <- err + // Abort synchronously to ensure checkAbort() sees the error immediately + this.abort(err) + // Don't mark row copy as complete if there was an error + return } atomic.StoreInt64(&this.rowCopyCompleteFlag, 1) this.migrationContext.MarkRowCopyEndTime() go func() { for err := range this.rowCopyComplete { if err != nil { - this.migrationContext.PanicAbort <- err + // Abort synchronously to ensure the error is stored immediately + this.abort(err) + return } } }() @@ -220,11 +271,14 @@ func (this *Migrator) onChangelogStateEvent(dmlEntry *binlog.BinlogEntry) (err e case Migrated, ReadMigrationRangeValues: // no-op event case GhostTableMigrated: - this.ghostTableMigrated <- true + // Use helper to prevent deadlock if migration aborts before receiver is ready + _ = base.SendWithContext(this.migrationContext.GetContext(), this.ghostTableMigrated, true) case AllEventsUpToLockProcessed: var applyEventFunc tableWriteFunc = func() error { - this.allEventsUpToLockProcessed <- changelogStateString - return nil + return base.SendWithContext(this.migrationContext.GetContext(), this.allEventsUpToLockProcessed, &lockProcessedStruct{ + state: changelogStateString, + coords: dmlEntry.Coordinates.Clone(), + }) } // at this point we know all events up to lock have been read from the streamer, // because the streamer works sequentially. So those events are either already handled, @@ -232,7 +286,8 @@ func (this *Migrator) onChangelogStateEvent(dmlEntry *binlog.BinlogEntry) (err e // So as not to create a potential deadlock, we write this func to applyEventsQueue // asynchronously, understanding it doesn't really matter. 
go func() { - this.applyEventsQueue <- newApplyEventStructByFunc(&applyEventFunc) + // Use helper to prevent deadlock if buffer fills and executeWriteFuncs exits + _ = base.SendWithContext(this.migrationContext.GetContext(), this.applyEventsQueue, newApplyEventStructByFunc(&applyEventFunc)) }() default: return fmt.Errorf("Unknown changelog state: %+v", changelogState) @@ -256,10 +311,24 @@ func (this *Migrator) onChangelogHeartbeatEvent(dmlEntry *binlog.BinlogEntry) (e } } -// listenOnPanicAbort aborts on abort request +// abort stores the error, cancels the context, and logs the abort. +// This is the common abort logic used by both listenOnPanicAbort and +// consumeRowCopyComplete to ensure consistent error handling. +func (this *Migrator) abort(err error) { + // Store the error for Migrate() to return + this.migrationContext.SetAbortError(err) + + // Cancel the context to signal all goroutines to stop + this.migrationContext.CancelContext() + + // Log the error (but don't panic or exit) + this.migrationContext.Log.Errorf("Migration aborted: %v", err) +} + +// listenOnPanicAbort listens for fatal errors and initiates graceful shutdown func (this *Migrator) listenOnPanicAbort() { err := <-this.migrationContext.PanicAbort - this.migrationContext.Log.Fatale(err) + this.abort(err) } // validateAlterStatement validates the `alter` statement meets criteria. 
@@ -327,10 +396,36 @@ func (this *Migrator) createFlagFiles() (err error) { return nil } +// checkAbort returns abort error if migration was aborted +func (this *Migrator) checkAbort() error { + if abortErr := this.migrationContext.GetAbortError(); abortErr != nil { + return abortErr + } + + ctx := this.migrationContext.GetContext() + if ctx != nil { + select { + case <-ctx.Done(): + // Context cancelled but no abort error stored yet + if abortErr := this.migrationContext.GetAbortError(); abortErr != nil { + return abortErr + } + return ctx.Err() + default: + // Not cancelled + } + } + return nil +} + // Migrate executes the complete migration logic. This is *the* major gh-ost function. func (this *Migrator) Migrate() (err error) { this.migrationContext.Log.Infof("Migrating %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) this.migrationContext.StartTime = time.Now() + + // Ensure context is cancelled on exit (cleanup) + defer this.migrationContext.CancelContext() + if this.migrationContext.Hostname, err = os.Hostname(); err != nil { return err } @@ -354,18 +449,27 @@ func (this *Migrator) Migrate() (err error) { if err := this.initiateInspector(); err != nil { return err } + if err := this.checkAbort(); err != nil { + return err + } // If we are resuming, we will initiateStreaming later when we know - // the coordinates to resume streaming. + // the binlog coordinates to resume streaming from. // If not resuming, the streamer must be initiated before the applier, // so that the "GhostTableMigrated" event gets processed. 
if !this.migrationContext.Resume { if err := this.initiateStreaming(); err != nil { return err } + if err := this.checkAbort(); err != nil { + return err + } } if err := this.initiateApplier(); err != nil { return err } + if err := this.checkAbort(); err != nil { + return err + } if err := this.createFlagFiles(); err != nil { return err } @@ -380,7 +484,7 @@ func (this *Migrator) Migrate() (err error) { if err := this.finalCleanup(); err != nil { return nil } - if err := this.hooksExecutor.onSuccess(); err != nil { + if err := this.hooksExecutor.onSuccess(true); err != nil { return err } this.migrationContext.Log.Infof("Success! table %s.%s migrated instantly", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) @@ -461,7 +565,12 @@ func (this *Migrator) Migrate() (err error) { if err := this.hooksExecutor.onBeforeRowCopy(); err != nil { return err } - go this.executeWriteFuncs() + go func() { + if err := this.executeWriteFuncs(); err != nil { + // Send error to PanicAbort to trigger abort + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) + } + }() go this.iterateChunks() this.migrationContext.MarkRowCopyStartTime() go this.initiateStatus() @@ -472,6 +581,10 @@ func (this *Migrator) Migrate() (err error) { this.migrationContext.Log.Debugf("Operating until row copy is complete") this.consumeRowCopyComplete() this.migrationContext.Log.Infof("Row copy complete") + // Check if row copy was aborted due to error + if err := this.checkAbort(); err != nil { + return err + } if err := this.hooksExecutor.onRowCopyComplete(); err != nil { return err } @@ -495,13 +608,136 @@ func (this *Migrator) Migrate() (err error) { } atomic.StoreInt64(&this.migrationContext.CutOverCompleteFlag, 1) + if this.migrationContext.Checkpoint && !this.migrationContext.Noop { + cutoverChk, err := this.CheckpointAfterCutOver() + if err != nil { + this.migrationContext.Log.Warningf("failed to 
checkpoint after cutover: %+v", err) + } else { + this.migrationContext.Log.Infof("checkpoint success after cutover at coords=%+v", cutoverChk.LastTrxCoords.DisplayString()) + } + } + if err := this.finalCleanup(); err != nil { return nil } - if err := this.hooksExecutor.onSuccess(); err != nil { + if err := this.hooksExecutor.onSuccess(false); err != nil { return err } this.migrationContext.Log.Infof("Done migrating %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + // Final check for abort before declaring success + if err := this.checkAbort(); err != nil { + return err + } + return nil +} + +// Revert reverts a migration that previously completed by applying all DML events that happened +// after the original cutover, then doing another cutover to swap the tables back. +// The steps are similar to Migrate(), but without row copying. +func (this *Migrator) Revert() error { + this.migrationContext.Log.Infof("Reverting %s.%s from %s.%s", + sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), + sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OldTableName)) + this.migrationContext.StartTime = time.Now() + + // Ensure context is cancelled on exit (cleanup) + defer this.migrationContext.CancelContext() + + var err error + if this.migrationContext.Hostname, err = os.Hostname(); err != nil { + return err + } + + go this.listenOnPanicAbort() + + if err := this.hooksExecutor.onStartup(); err != nil { + return err + } + if err := this.validateAlterStatement(); err != nil { + return err + } + defer this.teardown() + + if err := this.initiateInspector(); err != nil { + return err + } + if err := this.checkAbort(); err != nil { + return err + } + if err := this.initiateApplier(); err != nil { + return err + } + if err := this.checkAbort(); err != nil { + return err + } + if err := this.createFlagFiles(); err != nil 
{ + return err + } + if err := this.inspector.inspectOriginalAndGhostTables(); err != nil { + return err + } + if err := this.applier.prepareQueries(); err != nil { + return err + } + + lastCheckpoint, err := this.applier.ReadLastCheckpoint() + if err != nil { + return this.migrationContext.Log.Errorf("No checkpoint found, unable to revert: %+v", err) + } + if !lastCheckpoint.IsCutover { + return this.migrationContext.Log.Errorf("Last checkpoint is not after cutover, unable to revert: coords=%+v time=%+v", lastCheckpoint.LastTrxCoords, lastCheckpoint.Timestamp) + } + this.migrationContext.InitialStreamerCoords = lastCheckpoint.LastTrxCoords + this.migrationContext.TotalRowsCopied = lastCheckpoint.RowsCopied + this.migrationContext.MigrationIterationRangeMinValues = lastCheckpoint.IterationRangeMin + this.migrationContext.MigrationIterationRangeMaxValues = lastCheckpoint.IterationRangeMax + if err := this.initiateStreaming(); err != nil { + return err + } + if err := this.checkAbort(); err != nil { + return err + } + if err := this.hooksExecutor.onValidated(); err != nil { + return err + } + if err := this.initiateServer(); err != nil { + return err + } + defer this.server.RemoveSocketFile() + if err := this.addDMLEventsListener(); err != nil { + return err + } + + this.initiateThrottler() + go this.initiateStatus() + go func() { + if err := this.executeDMLWriteFuncs(); err != nil { + // Send error to PanicAbort to trigger abort + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) + } + }() + + this.printStatus(ForcePrintStatusRule) + var retrier func(func() error, ...bool) error + if this.migrationContext.CutOverExponentialBackoff { + retrier = this.retryOperationWithExponentialBackoff + } else { + retrier = this.retryOperation + } + if err := this.hooksExecutor.onBeforeCutOver(); err != nil { + return err + } + if err := retrier(this.cutOver); err != nil { + return err + } + 
atomic.StoreInt64(&this.migrationContext.CutOverCompleteFlag, 1) + if err := this.finalCleanup(); err != nil { + return nil + } + if err := this.hooksExecutor.onSuccess(false); err != nil { + return err + } + this.migrationContext.Log.Infof("Done reverting %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) return nil } @@ -555,7 +791,7 @@ func (this *Migrator) cutOver() (err error) { this.migrationContext.MarkPointOfInterest() this.migrationContext.Log.Debugf("checking for cut-over postpone") - this.sleepWhileTrue( + if err := this.sleepWhileTrue( func() (bool, error) { heartbeatLag := this.migrationContext.TimeSinceLastHeartbeatOnChangelog() maxLagMillisecondsThrottle := time.Duration(atomic.LoadInt64(&this.migrationContext.MaxLagMillisecondsThrottleThreshold)) * time.Millisecond @@ -583,7 +819,9 @@ func (this *Migrator) cutOver() (err error) { } return false, nil }, - ) + ); err != nil { + return err + } atomic.StoreInt64(&this.migrationContext.IsPostponingCutOver, 0) this.migrationContext.MarkPointOfInterest() this.migrationContext.Log.Debugf("checking for cut-over postpone: complete") @@ -622,7 +860,7 @@ func (this *Migrator) cutOver() (err error) { // Inject the "AllEventsUpToLockProcessed" state hint, wait for it to appear in the binary logs, // make sure the queue is drained. 
-func (this *Migrator) waitForEventsUpToLock() (err error) { +func (this *Migrator) waitForEventsUpToLock() error { timeout := time.NewTimer(time.Second * time.Duration(this.migrationContext.CutOverLockTimeoutSeconds)) this.migrationContext.MarkPointOfInterest() @@ -635,19 +873,21 @@ func (this *Migrator) waitForEventsUpToLock() (err error) { } this.migrationContext.Log.Infof("Waiting for events up to lock") atomic.StoreInt64(&this.migrationContext.AllEventsUpToLockProcessedInjectedFlag, 1) + var lockProcessed *lockProcessedStruct for found := false; !found; { select { case <-timeout.C: { return this.migrationContext.Log.Errorf("Timeout while waiting for events up to lock") } - case state := <-this.allEventsUpToLockProcessed: + case lockProcessed = <-this.allEventsUpToLockProcessed: { - if state == allEventsUpToLockProcessedChallenge { - this.migrationContext.Log.Infof("Waiting for events up to lock: got %s", state) + if lockProcessed.state == allEventsUpToLockProcessedChallenge { + this.migrationContext.Log.Infof("Waiting for events up to lock: got %s", lockProcessed.state) found = true + this.lastLockProcessed = lockProcessed } else { - this.migrationContext.Log.Infof("Waiting for events up to lock: skipping %s", state) + this.migrationContext.Log.Infof("Waiting for events up to lock: skipping %s", lockProcessed.state) } } } @@ -730,7 +970,7 @@ func (this *Migrator) atomicCutOver() (err error) { // If we need to create triggers we need to do it here (only create part) if this.migrationContext.IncludeTriggers && len(this.migrationContext.Triggers) > 0 { if err := this.applier.CreateTriggersOnGhost(); err != nil { - this.migrationContext.Log.Errore(err) + return this.migrationContext.Log.Errore(err) } } @@ -1169,7 +1409,8 @@ func (this *Migrator) initiateStreaming() error { this.migrationContext.Log.Debugf("Beginning streaming") err := this.eventsStreamer.StreamEvents(this.canStopStreaming) if err != nil { - this.migrationContext.PanicAbort <- err + // Use helper 
to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) } this.migrationContext.Log.Debugf("Done streaming") }() @@ -1195,8 +1436,9 @@ func (this *Migrator) addDMLEventsListener() error { this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, func(dmlEntry *binlog.BinlogEntry) error { - this.applyEventsQueue <- newApplyEventStructByDML(dmlEntry) - return nil + // Use helper to prevent deadlock if buffer fills and executeWriteFuncs exits + // This is critical because this callback blocks the event streamer + return base.SendWithContext(this.migrationContext.GetContext(), this.applyEventsQueue, newApplyEventStructByDML(dmlEntry)) }, ) return err @@ -1220,7 +1462,12 @@ func (this *Migrator) initiateApplier() error { if err := this.applier.InitDBConnections(); err != nil { return err } - if !this.migrationContext.Resume { + if this.migrationContext.Revert { + if err := this.applier.CreateChangelogTable(); err != nil { + this.migrationContext.Log.Errorf("Unable to create changelog table, see further error details. Perhaps a previous migration failed without dropping the table? OR is there a running migration? Bailing out") + return err + } + } else if !this.migrationContext.Resume { if err := this.applier.ValidateOrDropExistingTables(); err != nil { return err } @@ -1245,12 +1492,22 @@ func (this *Migrator) initiateApplier() error { return err } } - this.applier.WriteChangelogState(string(GhostTableMigrated)) + if _, err := this.applier.WriteChangelogState(string(GhostTableMigrated)); err != nil { + return err + } } + + // ensure performance_schema.metadata_locks is available. if err := this.applier.StateMetadataLockInstrument(); err != nil { - this.migrationContext.Log.Errorf("Unable to enable metadata lock instrument, see further error details. 
Bailing out") - return err + this.migrationContext.Log.Warning("Unable to enable metadata lock instrument, see further error details.") } + if !this.migrationContext.IsOpenMetadataLockInstruments { + if !this.migrationContext.SkipMetadataLockCheck { + return this.migrationContext.Log.Errorf("Bailing out because metadata lock instrument not enabled. Use --skip-metadata-lock-check if you wish to proceed without. See https://github.com/github/gh-ost/pull/1536 for details.") + } + this.migrationContext.Log.Warning("Proceeding without metadata lock check. There is a small chance of data loss if another session accesses the ghost table during cut-over. See https://github.com/github/gh-ost/pull/1536 for details.") + } + go this.applier.InitiateHeartbeat() return nil } @@ -1259,7 +1516,7 @@ func (this *Migrator) initiateApplier() error { // a chunk of rows onto the ghost table. func (this *Migrator) iterateChunks() error { terminateRowIteration := func(err error) error { - this.rowCopyComplete <- err + _ = base.SendWithContext(this.migrationContext.GetContext(), this.rowCopyComplete, err) return this.migrationContext.Log.Errore(err) } if this.migrationContext.Noop { @@ -1274,33 +1531,33 @@ func (this *Migrator) iterateChunks() error { var hasNoFurtherRangeFlag int64 // Iterate per chunk: for { + if err := this.checkAbort(); err != nil { + return terminateRowIteration(err) + } if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { // Done // There's another such check down the line return nil } copyRowsFunc := func() error { - if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { - // Done. 
-				// There's another such check down the line
-				return nil
-			}
-
-			// When hasFurtherRange is false, original table might be write locked and CalculateNextIterationRangeEndValues would hangs forever
-
-			hasFurtherRange := false
-			if err := this.retryOperation(func() (e error) {
-				hasFurtherRange, e = this.applier.CalculateNextIterationRangeEndValues()
-				return e
-			}); err != nil {
-				return terminateRowIteration(err)
-			}
-			if !hasFurtherRange {
-				atomic.StoreInt64(&hasNoFurtherRangeFlag, 1)
-				return terminateRowIteration(nil)
-			}
+			this.migrationContext.SetNextIterationRangeMinValues()
 			// Copy task:
 			applyCopyRowsFunc := func() error {
+				if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 {
+					// Done.
+					// There's another such check down the line
+					return nil
+				}
+
+				// When hasFurtherRange is false, original table might be write locked and CalculateNextIterationRangeEndValues would hang forever
+				hasFurtherRange, err := this.applier.CalculateNextIterationRangeEndValues()
+				if err != nil {
+					return err // wrapping call will retry
+				}
+				if !hasFurtherRange {
+					atomic.StoreInt64(&hasNoFurtherRangeFlag, 1)
+					return terminateRowIteration(nil)
+				}
 				if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 {
 					// No need for more writes.
 					// This is the de-facto place where we avoid writing in the event of completed cut-over.
@@ -1323,7 +1580,7 @@ func (this *Migrator) iterateChunks() error {
 					this.migrationContext.Log.Infof("ApplyIterationInsertQuery has SQL warnings!
%s", warning) } joinedWarnings := strings.Join(this.migrationContext.MigrationLastInsertSQLWarnings, "; ") - terminateRowIteration(fmt.Errorf("ApplyIterationInsertQuery failed because of SQL warnings: [%s]", joinedWarnings)) + return terminateRowIteration(fmt.Errorf("ApplyIterationInsertQuery failed because of SQL warnings: [%s]", joinedWarnings)) } } @@ -1331,13 +1588,29 @@ func (this *Migrator) iterateChunks() error { atomic.AddInt64(&this.migrationContext.Iteration, 1) return nil } - if err := this.retryOperation(applyCopyRowsFunc); err != nil { + if err := this.retryBatchCopyWithHooks(applyCopyRowsFunc); err != nil { return terminateRowIteration(err) } + + // record last successfully copied range + this.applier.LastIterationRangeMutex.Lock() + if this.migrationContext.MigrationIterationRangeMinValues != nil && this.migrationContext.MigrationIterationRangeMaxValues != nil { + this.applier.LastIterationRangeMinValues = this.migrationContext.MigrationIterationRangeMinValues.Clone() + this.applier.LastIterationRangeMaxValues = this.migrationContext.MigrationIterationRangeMaxValues.Clone() + } + this.applier.LastIterationRangeMutex.Unlock() + return nil } // Enqueue copy operation; to be executed by executeWriteFuncs() - this.copyRowsQueue <- copyRowsFunc + // Use helper to prevent deadlock if executeWriteFuncs exits + if err := base.SendWithContext(this.migrationContext.GetContext(), this.copyRowsQueue, copyRowsFunc); err != nil { + // Context cancelled, check for abort and exit + if abortErr := this.checkAbort(); abortErr != nil { + return terminateRowIteration(abortErr) + } + return terminateRowIteration(err) + } } } @@ -1418,23 +1691,50 @@ func (this *Migrator) Checkpoint(ctx context.Context) (*Checkpoint, error) { this.applier.LastIterationRangeMutex.Unlock() for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - this.applier.CurrentCoordinatesMutex.Lock() - if coords.SmallerThanOrEquals(this.applier.CurrentCoordinates) { - id, err := 
this.applier.WriteCheckpoint(chk) - chk.Id = id - this.applier.CurrentCoordinatesMutex.Unlock() - return chk, err - } + if err := ctx.Err(); err != nil { + return nil, err + } + this.applier.CurrentCoordinatesMutex.Lock() + if coords.SmallerThanOrEquals(this.applier.CurrentCoordinates) { + id, err := this.applier.WriteCheckpoint(chk) + chk.Id = id this.applier.CurrentCoordinatesMutex.Unlock() - time.Sleep(500 * time.Millisecond) + return chk, err } + this.applier.CurrentCoordinatesMutex.Unlock() + time.Sleep(500 * time.Millisecond) } } +// CheckpointAfterCutOver writes a final checkpoint after the cutover completes successfully. +func (this *Migrator) CheckpointAfterCutOver() (*Checkpoint, error) { + if this.lastLockProcessed == nil || this.lastLockProcessed.coords.IsEmpty() { + return nil, this.migrationContext.Log.Errorf("lastLockProcessed coords are empty") + } + + chk := &Checkpoint{ + IsCutover: true, + LastTrxCoords: this.lastLockProcessed.coords, + IterationRangeMin: sql.NewColumnValues(this.migrationContext.UniqueKey.Len()), + IterationRangeMax: sql.NewColumnValues(this.migrationContext.UniqueKey.Len()), + Iteration: this.migrationContext.GetIteration(), + RowsCopied: atomic.LoadInt64(&this.migrationContext.TotalRowsCopied), + DMLApplied: atomic.LoadInt64(&this.migrationContext.TotalDMLEventsApplied), + } + this.applier.LastIterationRangeMutex.Lock() + if this.applier.LastIterationRangeMinValues != nil { + chk.IterationRangeMin = this.applier.LastIterationRangeMinValues.Clone() + } + if this.applier.LastIterationRangeMaxValues != nil { + chk.IterationRangeMax = this.applier.LastIterationRangeMaxValues.Clone() + } + this.applier.LastIterationRangeMutex.Unlock() + + id, err := this.applier.WriteCheckpoint(chk) + chk.Id = id + return chk, err +} + func (this *Migrator) checkpointLoop() { if this.migrationContext.Noop { this.migrationContext.Log.Debugf("Noop operation; not really checkpointing") @@ -1443,9 +1743,12 @@ func (this *Migrator) checkpointLoop() { 
checkpointInterval := time.Duration(this.migrationContext.CheckpointIntervalSeconds) * time.Second ticker := time.NewTicker(checkpointInterval) for t := range ticker.C { - if atomic.LoadInt64(&this.finishedMigrating) > 0 { + if atomic.LoadInt64(&this.finishedMigrating) > 0 || atomic.LoadInt64(&this.migrationContext.CutOverCompleteFlag) > 0 { return } + if atomic.LoadInt64(&this.migrationContext.InCutOverCriticalSectionFlag) > 0 { + continue + } this.migrationContext.Log.Infof("starting checkpoint at %+v", t) ctx, cancel := context.WithTimeout(context.Background(), checkpointTimeout) chk, err := this.Checkpoint(ctx) @@ -1472,6 +1775,9 @@ func (this *Migrator) executeWriteFuncs() error { return nil } for { + if err := this.checkAbort(); err != nil { + return err + } if atomic.LoadInt64(&this.finishedMigrating) > 0 { return nil } @@ -1517,6 +1823,31 @@ func (this *Migrator) executeWriteFuncs() error { } } +func (this *Migrator) executeDMLWriteFuncs() error { + if this.migrationContext.Noop { + this.migrationContext.Log.Debugf("Noop operation; not really executing DML write funcs") + return nil + } + for { + if atomic.LoadInt64(&this.finishedMigrating) > 0 { + return nil + } + + this.throttler.throttle(nil) + + select { + case eventStruct := <-this.applyEventsQueue: + { + if err := this.onApplyEventStruct(eventStruct); err != nil { + return err + } + } + case <-time.After(time.Second): + continue + } + } +} + // finalCleanup takes actions at very end of migration, dropping tables etc. 
func (this *Migrator) finalCleanup() error { atomic.StoreInt64(&this.migrationContext.CleanupImminentFlag, 1) @@ -1541,17 +1872,19 @@ func (this *Migrator) finalCleanup() error { if err := this.retryOperation(this.applier.DropChangelogTable); err != nil { return err } - if err := this.retryOperation(this.applier.DropCheckpointTable); err != nil { - return err - } if this.migrationContext.OkToDropTable && !this.migrationContext.TestOnReplica { if err := this.retryOperation(this.applier.DropOldTable); err != nil { return err } - } else { - if !this.migrationContext.Noop { - this.migrationContext.Log.Infof("Am not dropping old table because I want this operation to be as live as possible. If you insist I should do it, please add `--ok-to-drop-table` next time. But I prefer you do not. To drop the old table, issue:") - this.migrationContext.Log.Infof("-- drop table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetOldTableName())) + if err := this.retryOperation(this.applier.DropCheckpointTable); err != nil { + return err + } + } else if !this.migrationContext.Noop { + this.migrationContext.Log.Infof("Am not dropping old table because I want this operation to be as live as possible. If you insist I should do it, please add `--ok-to-drop-table` next time. But I prefer you do not. To drop the old table, issue:") + this.migrationContext.Log.Infof("-- drop table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetOldTableName())) + if this.migrationContext.Checkpoint { + this.migrationContext.Log.Infof("Am not dropping checkpoint table without `--ok-to-drop-table`. 
To drop the checkpoint table, issue:") + this.migrationContext.Log.Infof("-- drop table %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.GetCheckpointTableName())) } } if this.migrationContext.Noop { diff --git a/go/logic/migrator_test.go b/go/logic/migrator_test.go index b268054ab..f731035e1 100644 --- a/go/logic/migrator_test.go +++ b/go/logic/migrator_test.go @@ -6,10 +6,12 @@ package logic import ( + "bytes" "context" gosql "database/sql" "errors" "fmt" + "io" "os" "path/filepath" "strings" @@ -30,7 +32,6 @@ import ( "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/wait" ) func TestMigratorOnChangelogEvent(t *testing.T) { @@ -300,7 +301,7 @@ func (suite *MigratorTestSuite) SetupSuite() { testmysql.WithDatabase(testMysqlDatabase), testmysql.WithUsername(testMysqlUser), testmysql.WithPassword(testMysqlPass), - testcontainers.WithWaitStrategy(wait.ForExposedPort()), + testmysql.WithConfigFile("my.cnf.test"), ) suite.Require().NoError(err) @@ -324,6 +325,8 @@ func (suite *MigratorTestSuite) SetupTest() { _, err := suite.db.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+testMysqlDatabase) suite.Require().NoError(err) + + os.Remove("/tmp/gh-ost.sock") } func (suite *MigratorTestSuite) TearDownTest() { @@ -333,6 +336,10 @@ func (suite *MigratorTestSuite) TearDownTest() { suite.Require().NoError(err) _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestGhostTableName()) suite.Require().NoError(err) + _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestRevertedTableName()) + suite.Require().NoError(err) + _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestOldTableName()) + suite.Require().NoError(err) } func (suite *MigratorTestSuite) TestMigrateEmpty() { @@ -379,6 +386,126 @@ func (suite *MigratorTestSuite) TestMigrateEmpty() { suite.Require().Equal("_testing_del", 
tableName) } +func (suite *MigratorTestSuite) TestRetryBatchCopyWithHooks() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, "CREATE TABLE test.test_retry_batch (id INT PRIMARY KEY AUTO_INCREMENT, name TEXT)") + suite.Require().NoError(err) + + const initStride = 1000 + const totalBatches = 3 + for i := 0; i < totalBatches; i++ { + dataSize := 50 * i + for j := 0; j < initStride; j++ { + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO test.test_retry_batch (name) VALUES ('%s')", strings.Repeat("a", dataSize))) + suite.Require().NoError(err) + } + } + + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("SET GLOBAL max_binlog_cache_size = %d", 1024*8)) + suite.Require().NoError(err) + defer func() { + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("SET GLOBAL max_binlog_cache_size = %d", 1024*1024*1024)) + suite.Require().NoError(err) + }() + + tmpDir, err := os.MkdirTemp("", "gh-ost-hooks") + suite.Require().NoError(err) + defer os.RemoveAll(tmpDir) + + hookScript := filepath.Join(tmpDir, "gh-ost-on-batch-copy-retry") + hookContent := `#!/bin/bash +# Mock hook that reduces chunk size on binlog cache error +ERROR_MSG="$GH_OST_LAST_BATCH_COPY_ERROR" +SOCKET_PATH="/tmp/gh-ost.sock" + +if ! [[ "$ERROR_MSG" =~ "max_binlog_cache_size" ]]; then + echo "Nothing to do for error: $ERROR_MSG" + exit 0 +fi + +CHUNK_SIZE=$(echo "chunk-size=?" | nc -U $SOCKET_PATH | tr -d '\n') + +MIN_CHUNK_SIZE=10 +NEW_CHUNK_SIZE=$(( CHUNK_SIZE * 8 / 10 )) +if [ $NEW_CHUNK_SIZE -lt $MIN_CHUNK_SIZE ]; then + NEW_CHUNK_SIZE=$MIN_CHUNK_SIZE +fi + +if [ $CHUNK_SIZE -eq $NEW_CHUNK_SIZE ]; then + echo "Chunk size unchanged: $CHUNK_SIZE" + exit 0 +fi + +echo "[gh-ost-on-batch-copy-retry]: Changing chunk size from $CHUNK_SIZE to $NEW_CHUNK_SIZE" +echo "chunk-size=$NEW_CHUNK_SIZE" | nc -U $SOCKET_PATH +echo "[gh-ost-on-batch-copy-retry]: Done, exiting..." 
+` + err = os.WriteFile(hookScript, []byte(hookContent), 0755) + suite.Require().NoError(err) + + origStdout := os.Stdout + origStderr := os.Stderr + + rOut, wOut, _ := os.Pipe() + rErr, wErr, _ := os.Pipe() + os.Stdout = wOut + os.Stderr = wErr + + connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + migrationContext := base.NewMigrationContext() + migrationContext.AllowedRunningOnMaster = true + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.InspectorConnectionConfig = connectionConfig + migrationContext.DatabaseName = "test" + migrationContext.SkipPortValidation = true + migrationContext.OriginalTableName = "test_retry_batch" + migrationContext.SetConnectionConfig("innodb") + migrationContext.AlterStatementOptions = "MODIFY name LONGTEXT, ENGINE=InnoDB" + migrationContext.ReplicaServerId = 99999 + migrationContext.HeartbeatIntervalMilliseconds = 100 + migrationContext.ThrottleHTTPIntervalMillis = 100 + migrationContext.ThrottleHTTPTimeoutMillis = 1000 + migrationContext.HooksPath = tmpDir + migrationContext.ChunkSize = 1000 + migrationContext.SetDefaultNumRetries(10) + migrationContext.ServeSocketFile = "/tmp/gh-ost.sock" + + migrator := NewMigrator(migrationContext, "0.0.0") + + err = migrator.Migrate() + suite.Require().NoError(err) + + wOut.Close() + wErr.Close() + os.Stdout = origStdout + os.Stderr = origStderr + + var bufOut, bufErr bytes.Buffer + io.Copy(&bufOut, rOut) + io.Copy(&bufErr, rErr) + + outStr := bufOut.String() + errStr := bufErr.String() + + suite.Assert().Contains(outStr, "chunk-size: 1000") + suite.Assert().Contains(errStr, "[gh-ost-on-batch-copy-retry]: Changing chunk size from 1000 to 800") + suite.Assert().Contains(outStr, "chunk-size: 800") + + suite.Assert().Contains(errStr, "[gh-ost-on-batch-copy-retry]: Changing chunk size from 800 to 640") + suite.Assert().Contains(outStr, "chunk-size: 640") + + suite.Assert().Contains(errStr, 
"[gh-ost-on-batch-copy-retry]: Changing chunk size from 640 to 512") + suite.Assert().Contains(outStr, "chunk-size: 512") + + var count int + err = suite.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM test.test_retry_batch").Scan(&count) + suite.Require().NoError(err) + suite.Assert().Equal(3000, count) +} + func (suite *MigratorTestSuite) TestCopierIntPK() { ctx := context.Background() @@ -587,6 +714,118 @@ func TestMigratorRetryWithExponentialBackoff(t *testing.T) { assert.Equal(t, tries, 100) } +func TestMigratorRetryAbortsOnContextCancellation(t *testing.T) { + oldRetrySleepFn := RetrySleepFn + defer func() { RetrySleepFn = oldRetrySleepFn }() + + migrationContext := base.NewMigrationContext() + migrationContext.SetDefaultNumRetries(100) + migrator := NewMigrator(migrationContext, "1.2.3") + + RetrySleepFn = func(duration time.Duration) { + // No sleep needed for this test + } + + var tries = 0 + retryable := func() error { + tries++ + if tries == 5 { + // Cancel context on 5th try + migrationContext.CancelContext() + } + return errors.New("Simulated error") + } + + result := migrator.retryOperation(retryable, false) + assert.Error(t, result) + // Should abort after 6 tries: 5 failures + 1 checkAbort detection + assert.True(t, tries <= 6, "Expected tries <= 6, got %d", tries) + // Verify we got context cancellation error + assert.Contains(t, result.Error(), "context canceled") +} + +func TestMigratorRetryWithExponentialBackoffAbortsOnContextCancellation(t *testing.T) { + oldRetrySleepFn := RetrySleepFn + defer func() { RetrySleepFn = oldRetrySleepFn }() + + migrationContext := base.NewMigrationContext() + migrationContext.SetDefaultNumRetries(100) + migrationContext.SetExponentialBackoffMaxInterval(42) + migrator := NewMigrator(migrationContext, "1.2.3") + + RetrySleepFn = func(duration time.Duration) { + // No sleep needed for this test + } + + var tries = 0 + retryable := func() error { + tries++ + if tries == 5 { + // Cancel context on 5th try + 
migrationContext.CancelContext() + } + return errors.New("Simulated error") + } + + result := migrator.retryOperationWithExponentialBackoff(retryable, false) + assert.Error(t, result) + // Should abort after 6 tries: 5 failures + 1 checkAbort detection + assert.True(t, tries <= 6, "Expected tries <= 6, got %d", tries) + // Verify we got context cancellation error + assert.Contains(t, result.Error(), "context canceled") +} + +func TestMigratorRetrySkipsRetriesForWarnings(t *testing.T) { + oldRetrySleepFn := RetrySleepFn + defer func() { RetrySleepFn = oldRetrySleepFn }() + + migrationContext := base.NewMigrationContext() + migrationContext.SetDefaultNumRetries(100) + migrator := NewMigrator(migrationContext, "1.2.3") + + RetrySleepFn = func(duration time.Duration) { + t.Fatal("Should not sleep/retry for warning errors") + } + + var tries = 0 + retryable := func() error { + tries++ + return errors.New("warnings detected in statement 1 of 1: [Warning: Duplicate entry 'test' for key 'idx' (1062)]") + } + + result := migrator.retryOperation(retryable, false) + assert.Error(t, result) + // Should only try once - no retries for warnings + assert.Equal(t, 1, tries, "Expected exactly 1 try (no retries) for warning error") + assert.Contains(t, result.Error(), "warnings detected") +} + +func TestMigratorRetryWithExponentialBackoffSkipsRetriesForWarnings(t *testing.T) { + oldRetrySleepFn := RetrySleepFn + defer func() { RetrySleepFn = oldRetrySleepFn }() + + migrationContext := base.NewMigrationContext() + migrationContext.SetDefaultNumRetries(100) + migrationContext.SetExponentialBackoffMaxInterval(42) + migrator := NewMigrator(migrationContext, "1.2.3") + + RetrySleepFn = func(duration time.Duration) { + t.Fatal("Should not sleep/retry for warning errors") + } + + var tries = 0 + retryable := func() error { + tries++ + return errors.New("warnings detected in statement 1 of 1: [Warning: Duplicate entry 'test' for key 'idx' (1062)]") + } + + result := 
migrator.retryOperationWithExponentialBackoff(retryable, false) + assert.Error(t, result) + // Should only try once - no retries for warnings + assert.Equal(t, 1, tries, "Expected exactly 1 try (no retries) for warning error") + assert.Contains(t, result.Error(), "warnings detected") +} + func (suite *MigratorTestSuite) TestCutOverLossDataCaseLockGhostBeforeRename() { ctx := context.Background() @@ -667,6 +906,518 @@ func (suite *MigratorTestSuite) TestCutOverLossDataCaseLockGhostBeforeRename() { suite.Require().Equal("CREATE TABLE `testing` (\n `id` int NOT NULL,\n `name` varchar(64) DEFAULT NULL,\n `foobar` varchar(255) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", createTableSQL) } +func (suite *MigratorTestSuite) TestRevertEmpty() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, s CHAR(32))", getTestTableName())) + suite.Require().NoError(err) + + var oldTableName string + + // perform original migration + connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + { + migrationContext := newTestMigrationContext() + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.InspectorConnectionConfig = connectionConfig + migrationContext.SetConnectionConfig("innodb") + migrationContext.AlterStatement = "ADD COLUMN newcol CHAR(32)" + migrationContext.Checkpoint = true + migrationContext.CheckpointIntervalSeconds = 10 + migrationContext.DropServeSocket = true + migrationContext.InitiallyDropOldTable = true + migrationContext.UseGTIDs = true + + migrator := NewMigrator(migrationContext, "0.0.0") + + err = migrator.Migrate() + oldTableName = migrationContext.GetOldTableName() + suite.Require().NoError(err) + } + + // revert the original migration + { + migrationContext := newTestMigrationContext() + migrationContext.ApplierConnectionConfig = connectionConfig + 
migrationContext.InspectorConnectionConfig = connectionConfig + migrationContext.SetConnectionConfig("innodb") + migrationContext.DropServeSocket = true + migrationContext.UseGTIDs = true + migrationContext.Revert = true + migrationContext.OkToDropTable = true + migrationContext.OldTableName = oldTableName + + migrator := NewMigrator(migrationContext, "0.0.0") + + err = migrator.Revert() + suite.Require().NoError(err) + } +} + +func (suite *MigratorTestSuite) TestRevert() { + ctx := context.Background() + + _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, s CHAR(32))", getTestTableName())) + suite.Require().NoError(err) + + numRows := 0 + for range 100 { + _, err = suite.db.ExecContext(ctx, + fmt.Sprintf("INSERT INTO %s (id, s) VALUES (%d, MD5('%d'))", getTestTableName(), numRows, numRows)) + suite.Require().NoError(err) + numRows += 1 + } + + var oldTableName string + + connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + // perform original migration + { + migrationContext := newTestMigrationContext() + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.InspectorConnectionConfig = connectionConfig + migrationContext.SetConnectionConfig("innodb") + migrationContext.AlterStatement = "ADD INDEX idx1 (s)" + migrationContext.Checkpoint = true + migrationContext.CheckpointIntervalSeconds = 10 + migrationContext.DropServeSocket = true + migrationContext.InitiallyDropOldTable = true + migrationContext.UseGTIDs = true + + migrator := NewMigrator(migrationContext, "0.0.0") + + err = migrator.Migrate() + oldTableName = migrationContext.GetOldTableName() + suite.Require().NoError(err) + } + + // do some writes + for range 100 { + _, err = suite.db.ExecContext(ctx, + fmt.Sprintf("INSERT INTO %s (id, s) VALUES (%d, MD5('%d'))", getTestTableName(), numRows, numRows)) + suite.Require().NoError(err) + numRows += 1 + } + for i := 0; i < numRows; i += 7 { + _, 
err = suite.db.ExecContext(ctx,
+			fmt.Sprintf("UPDATE %s SET s=MD5('%d') where id=%d", getTestTableName(), 2*i, i))
+		suite.Require().NoError(err)
+	}
+
+	// revert the original migration
+	{
+		migrationContext := newTestMigrationContext()
+		migrationContext.ApplierConnectionConfig = connectionConfig
+		migrationContext.InspectorConnectionConfig = connectionConfig
+		migrationContext.SetConnectionConfig("innodb")
+		migrationContext.DropServeSocket = true
+		migrationContext.UseGTIDs = true
+		migrationContext.Revert = true
+		migrationContext.OldTableName = oldTableName
+
+		migrator := NewMigrator(migrationContext, "0.0.0")
+
+		err = migrator.Revert()
+		oldTableName = migrationContext.GetOldTableName()
+		suite.Require().NoError(err)
+	}
+
+	// checksum original and reverted table
+	var _tableName, checksum1, checksum2 string
+	rows, err := suite.db.Query(fmt.Sprintf("CHECKSUM TABLE %s, %s", testMysqlTableName, oldTableName))
+	suite.Require().NoError(err)
+	defer rows.Close()
+	suite.Require().True(rows.Next())
+	suite.Require().NoError(rows.Scan(&_tableName, &checksum1))
+	suite.Require().True(rows.Next())
+	suite.Require().NoError(rows.Scan(&_tableName, &checksum2))
+	suite.Require().NoError(rows.Err())
+
+	suite.Require().Equal(checksum1, checksum2)
+}
+
 func TestMigrator(t *testing.T) {
 	suite.Run(t, new(MigratorTestSuite))
 }
+
+func TestPanicAbort_PropagatesError(t *testing.T) {
+	migrationContext := base.NewMigrationContext()
+	migrator := NewMigrator(migrationContext, "1.0.0")
+
+	// Start listenOnPanicAbort
+	go migrator.listenOnPanicAbort()
+
+	// Send an error to PanicAbort
+	testErr := errors.New("test abort error")
+	go func() {
+		migrationContext.PanicAbort <- testErr
+	}()
+
+	// Wait a bit for error to be processed
+	time.Sleep(100 * time.Millisecond)
+
+	// Verify error was stored
+	got := migrationContext.GetAbortError()
+	if got != testErr { //nolint:errorlint // Testing pointer equality for sentinel error
+		t.Errorf("Expected error %v, got %v", testErr, got)
+	}
+
+	// Verify context was cancelled
+	ctx := migrationContext.GetContext()
+	select {
+	case <-ctx.Done():
+		// Success - context was cancelled
+	default:
+		t.Error("Expected context to be cancelled")
+	}
+}
+
+func TestPanicAbort_FirstErrorWins(t *testing.T) {
+	migrationContext := base.NewMigrationContext()
+	migrator := NewMigrator(migrationContext, "1.0.0")
+
+	// Start listenOnPanicAbort
+	go migrator.listenOnPanicAbort()
+
+	// Send first error
+	err1 := errors.New("first error")
+	go func() {
+		migrationContext.PanicAbort <- err1
+	}()
+
+	// Wait for first error to be processed
+	time.Sleep(50 * time.Millisecond)
+
+	// Try to send second error (should be ignored)
+	err2 := errors.New("second error")
+	migrationContext.SetAbortError(err2)
+
+	// Verify only first error is stored
+	got := migrationContext.GetAbortError()
+	if got != err1 { //nolint:errorlint // Testing pointer equality for sentinel error
+		t.Errorf("Expected first error %v, got %v", err1, got)
+	}
+}
+
+func TestAbort_AfterRowCopy(t *testing.T) {
+	migrationContext := base.NewMigrationContext()
+	migrator := NewMigrator(migrationContext, "1.0.0")
+
+	// Start listenOnPanicAbort
+	go migrator.listenOnPanicAbort()
+
+	// Give listenOnPanicAbort time to start
+	time.Sleep(20 * time.Millisecond)
+
+	// Simulate row copy error by sending to rowCopyComplete in a goroutine
+	// (unbuffered channel, so send must be async)
+	testErr := errors.New("row copy failed")
+	go func() {
+		migrator.rowCopyComplete <- testErr
+	}()
+
+	// Consume the error (simulating what Migrate() does)
+	// This is a blocking call that waits for the error
+	migrator.consumeRowCopyComplete()
+
+	// Wait for the error to be processed by listenOnPanicAbort
+	time.Sleep(50 * time.Millisecond)
+
+	// Check that error was stored
+	if got := migrationContext.GetAbortError(); got == nil {
+		t.Fatal("Expected abort error to be stored after row copy error")
+	} else if got.Error() != "row copy failed" {
+		t.Errorf("Expected 'row copy failed', got %v", got)
+	}
+
+	// Verify context was cancelled
+	ctx := migrationContext.GetContext()
+	select {
+	case <-ctx.Done():
+		// Success
+	case <-time.After(1 * time.Second):
+		t.Error("Expected context to be cancelled after row copy error")
+	}
+}
+
+func TestAbort_DuringInspection(t *testing.T) {
+	migrationContext := base.NewMigrationContext()
+	migrator := NewMigrator(migrationContext, "1.0.0")
+
+	// Start listenOnPanicAbort
+	go migrator.listenOnPanicAbort()
+
+	// Simulate error during inspection phase
+	testErr := errors.New("inspection failed")
+	go func() {
+		time.Sleep(10 * time.Millisecond)
+		select {
+		case migrationContext.PanicAbort <- testErr:
+		case <-migrationContext.GetContext().Done():
+		}
+	}()
+
+	// Wait for abort to be processed
+	time.Sleep(50 * time.Millisecond)
+
+	// Call checkAbort (simulating what Migrate() does after initiateInspector)
+	err := migrator.checkAbort()
+	if err == nil {
+		t.Fatal("Expected checkAbort to return error after abort during inspection")
+	}
+
+	if err.Error() != "inspection failed" {
+		t.Errorf("Expected 'inspection failed', got %v", err)
+	}
+}
+
+func TestAbort_DuringStreaming(t *testing.T) {
+	migrationContext := base.NewMigrationContext()
+	migrator := NewMigrator(migrationContext, "1.0.0")
+
+	// Start listenOnPanicAbort
+	go migrator.listenOnPanicAbort()
+
+	// Simulate error from streaming goroutine
+	testErr := errors.New("streaming error")
+	go func() {
+		time.Sleep(10 * time.Millisecond)
+		// Use select pattern like actual code does
+		select {
+		case migrationContext.PanicAbort <- testErr:
+		case <-migrationContext.GetContext().Done():
+		}
+	}()
+
+	// Wait for abort to be processed
+	time.Sleep(50 * time.Millisecond)
+
+	// Verify error stored and context cancelled
+	if got := migrationContext.GetAbortError(); got == nil {
+		t.Fatal("Expected abort error to be stored")
+	} else if got.Error() != "streaming error" {
+		t.Errorf("Expected 'streaming error', got %v", got)
+	}
+
+	// Verify checkAbort catches it
+	err := migrator.checkAbort()
+	if err == nil {
+		t.Fatal("Expected checkAbort to return error after streaming abort")
+	}
+}
+
+func TestRetryExhaustion_TriggersAbort(t *testing.T) {
+	migrationContext := base.NewMigrationContext()
+	migrationContext.SetDefaultNumRetries(2) // Only 2 retries
+	migrator := NewMigrator(migrationContext, "1.0.0")
+
+	// Start listenOnPanicAbort
+	go migrator.listenOnPanicAbort()
+
+	// Operation that always fails
+	callCount := 0
+	operation := func() error {
+		callCount++
+		return errors.New("persistent failure")
+	}
+
+	// Call retryOperation (with notFatalHint=false so it sends to PanicAbort)
+	err := migrator.retryOperation(operation)
+
+	// Should have called operation MaxRetries times
+	if callCount != 2 {
+		t.Errorf("Expected 2 retry attempts, got %d", callCount)
+	}
+
+	// Should return the error
+	if err == nil {
+		t.Fatal("Expected retryOperation to return error")
+	}
+
+	// Wait for abort to be processed
+	time.Sleep(100 * time.Millisecond)
+
+	// Verify error was sent to PanicAbort and stored
+	if got := migrationContext.GetAbortError(); got == nil {
+		t.Error("Expected abort error to be stored after retry exhaustion")
+	}
+
+	// Verify context was cancelled
+	ctx := migrationContext.GetContext()
+	select {
+	case <-ctx.Done():
+		// Success
+	default:
+		t.Error("Expected context to be cancelled after retry exhaustion")
+	}
+}
+
+func TestRevert_AbortsOnError(t *testing.T) {
+	migrationContext := base.NewMigrationContext()
+	migrationContext.Revert = true
+	migrationContext.OldTableName = "_test_del"
+	migrationContext.OriginalTableName = "test"
+	migrationContext.DatabaseName = "testdb"
+	migrator := NewMigrator(migrationContext, "1.0.0")
+
+	// Start listenOnPanicAbort
+	go migrator.listenOnPanicAbort()
+
+	// Simulate error during revert
+	testErr := errors.New("revert failed")
+	go func() {
+		time.Sleep(10 * time.Millisecond)
+		select {
+		case migrationContext.PanicAbort <- testErr:
+		case <-migrationContext.GetContext().Done():
+		}
+	}()
+
+	// Wait for abort to be processed
+	time.Sleep(50 * time.Millisecond)
+
+	// Verify checkAbort catches it
+	err := migrator.checkAbort()
+	if err == nil {
+		t.Fatal("Expected checkAbort to return error during revert")
+	}
+
+	if err.Error() != "revert failed" {
+		t.Errorf("Expected 'revert failed', got %v", err)
+	}
+
+	// Verify context was cancelled
+	ctx := migrationContext.GetContext()
+	select {
+	case <-ctx.Done():
+		// Success
+	default:
+		t.Error("Expected context to be cancelled during revert abort")
+	}
+}
+
+func TestCheckAbort_ReturnsNilWhenNoError(t *testing.T) {
+	migrationContext := base.NewMigrationContext()
+	migrator := NewMigrator(migrationContext, "1.0.0")
+
+	// No error has occurred
+	err := migrator.checkAbort()
+	if err != nil {
+		t.Errorf("Expected no error, got %v", err)
+	}
+}
+
+func TestCheckAbort_DetectsContextCancellation(t *testing.T) {
+	migrationContext := base.NewMigrationContext()
+	migrator := NewMigrator(migrationContext, "1.0.0")
+
+	// Cancel context directly (without going through PanicAbort)
+	migrationContext.CancelContext()
+
+	// checkAbort should detect the cancellation
+	err := migrator.checkAbort()
+	if err == nil {
+		t.Fatal("Expected checkAbort to return error when context is cancelled")
+	}
+}
+
+func (suite *MigratorTestSuite) TestPanicOnWarningsDuplicateDuringCutoverWithHighRetries() {
+	ctx := context.Background()
+
+	// Create table with email column (no unique constraint initially)
+	_, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY AUTO_INCREMENT, email VARCHAR(100))", getTestTableName()))
+	suite.Require().NoError(err)
+
+	// Insert initial rows with unique email values - passes pre-flight validation
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user1@example.com')", getTestTableName()))
+	suite.Require().NoError(err)
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user2@example.com')", getTestTableName()))
+	suite.Require().NoError(err)
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user3@example.com')", getTestTableName()))
+	suite.Require().NoError(err)
+
+	// Verify we have 3 rows
+	var count int
+	err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", getTestTableName())).Scan(&count)
+	suite.Require().NoError(err)
+	suite.Require().Equal(3, count)
+
+	// Create postpone flag file
+	tmpDir, err := os.MkdirTemp("", "gh-ost-postpone-test")
+	suite.Require().NoError(err)
+	defer os.RemoveAll(tmpDir)
+	postponeFlagFile := filepath.Join(tmpDir, "postpone.flag")
+	err = os.WriteFile(postponeFlagFile, []byte{}, 0644)
+	suite.Require().NoError(err)
+
+	// Start migration in goroutine
+	done := make(chan error, 1)
+	go func() {
+		connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer)
+		if err != nil {
+			done <- err
+			return
+		}
+
+		migrationContext := newTestMigrationContext()
+		migrationContext.ApplierConnectionConfig = connectionConfig
+		migrationContext.InspectorConnectionConfig = connectionConfig
+		migrationContext.SetConnectionConfig("innodb")
+		migrationContext.AlterStatementOptions = "ADD UNIQUE KEY unique_email_idx (email)"
+		migrationContext.HeartbeatIntervalMilliseconds = 100
+		migrationContext.PostponeCutOverFlagFile = postponeFlagFile
+		migrationContext.PanicOnWarnings = true
+
+		// High retry count + exponential backoff means retries will take a long time and fail the test if not properly aborted
+		migrationContext.SetDefaultNumRetries(30)
+		migrationContext.CutOverExponentialBackoff = true
+		migrationContext.SetExponentialBackoffMaxInterval(128)
+
+		migrator := NewMigrator(migrationContext, "0.0.0")
+
+		//nolint:contextcheck
+		done <- migrator.Migrate()
+	}()
+
+	// Wait for migration to reach postponed state
+	// TODO replace this with an actual check for postponed state
+	time.Sleep(3 * time.Second)
+
+	// Now insert a duplicate email value while migration is postponed
+	// This simulates data arriving during migration that would violate the unique constraint
+	_, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user1@example.com')", getTestTableName()))
+	suite.Require().NoError(err)
+
+	// Verify we now have 4 rows (including the duplicate)
+	err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", getTestTableName())).Scan(&count)
+	suite.Require().NoError(err)
+	suite.Require().Equal(4, count)
+
+	// Unpostpone the migration - gh-ost will now try to apply binlog events with the duplicate
+	err = os.Remove(postponeFlagFile)
+	suite.Require().NoError(err)
+
+	// Wait for Migrate() to return - with timeout to detect if it hangs
+	select {
+	case migrateErr := <-done:
+		// Success - Migrate() returned
+		// It should return an error due to the duplicate
+		suite.Require().Error(migrateErr, "Expected migration to fail due to duplicate key violation")
+		suite.Require().Contains(migrateErr.Error(), "Duplicate entry", "Error should mention duplicate entry")
+	case <-time.After(5 * time.Minute):
+		suite.FailNow("Migrate() hung and did not return within 5 minutes - failure to abort on warnings in retry loop")
+	}
+
+	// Verify all 4 rows are still in the original table (no silent data loss)
+	err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", getTestTableName())).Scan(&count)
+	suite.Require().NoError(err)
+	suite.Require().Equal(4, count, "Original table should still have all 4 rows")
+
+	// Verify both user1@example.com entries still exist
+	var duplicateCount int
+	err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE email = 'user1@example.com'", getTestTableName())).Scan(&duplicateCount)
+	suite.Require().NoError(err)
+	suite.Require().Equal(2, duplicateCount, "Should have 2 duplicate email entries")
+}
diff --git a/go/logic/my.cnf.test b/go/logic/my.cnf.test
new file mode 100644
index 000000000..2938c12cd
--- /dev/null
+++ b/go/logic/my.cnf.test
@@ -0,0 +1,27 @@
+# mysql server configuration for testcontainer
+[mysqld]
+max_connections = 200
+innodb_log_file_size = 64M
+innodb_flush_log_at_trx_commit = 2
+innodb_flush_method = O_DIRECT
+skip-name-resolve
+skip-ssl
+
+character-set-server = utf8mb4
+collation-server = utf8mb4_0900_ai_ci
+
+default-time-zone = '+00:00'
+
+sql_mode = STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION
+
+general_log = 0
+general_log_file = /var/log/mysql/general.log
+slow_query_log = 0
+slow_query_log_file = /var/log/mysql/slow.log
+long_query_time = 2
+
+gtid_mode=ON
+enforce_gtid_consistency=ON
+
+[client]
+default-character-set = utf8mb4
diff --git a/go/logic/server.go b/go/logic/server.go
index 45e5b2bd4..74097acb7 100644
--- a/go/logic/server.go
+++ b/go/logic/server.go
@@ -450,7 +450,8 @@ help # This message
 				return NoPrintStatusRule, err
 			}
 			err := fmt.Errorf("User commanded 'panic'. The migration will be aborted without cleanup. Please drop the gh-ost tables before trying again.")
-			this.migrationContext.PanicAbort <- err
+			// Use helper to prevent deadlock if listenOnPanicAbort already exited
+			_ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err)
 			return NoPrintStatusRule, err
 		}
 	default:
diff --git a/go/logic/streamer.go b/go/logic/streamer.go
index 63afc3f3d..1c2635138 100644
--- a/go/logic/streamer.go
+++ b/go/logic/streamer.go
@@ -186,7 +186,12 @@ func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error {
 	// The next should block and execute forever, unless there's a serious error.
 	var successiveFailures int
 	var reconnectCoords mysql.BinlogCoordinates
+	ctx := this.migrationContext.GetContext()
 	for {
+		// Check for context cancellation each iteration
+		if err := ctx.Err(); err != nil {
+			return err
+		}
 		if canStopStreaming() {
 			return nil
 		}
diff --git a/go/logic/streamer_test.go b/go/logic/streamer_test.go
index 2c5d3886b..8e0b57f80 100644
--- a/go/logic/streamer_test.go
+++ b/go/logic/streamer_test.go
@@ -13,7 +13,6 @@ import (
 	"github.com/testcontainers/testcontainers-go"
 	"github.com/testcontainers/testcontainers-go/modules/mysql"
-	"github.com/testcontainers/testcontainers-go/wait"
 	"golang.org/x/sync/errgroup"
 )
@@ -31,7 +30,6 @@ func (suite *EventsStreamerTestSuite) SetupSuite() {
 		mysql.WithDatabase(testMysqlDatabase),
 		mysql.WithUsername(testMysqlUser),
 		mysql.WithPassword(testMysqlPass),
-		testcontainers.WithWaitStrategy(wait.ForExposedPort()),
 	)
 	suite.Require().NoError(err)
diff --git a/go/logic/test_utils.go b/go/logic/test_utils.go
index d532e0920..f552cfc76 100644
--- a/go/logic/test_utils.go
+++ b/go/logic/test_utils.go
@@ -28,6 +28,14 @@ func getTestGhostTableName() string {
 	return fmt.Sprintf("`%s`.`_%s_gho`", testMysqlDatabase, testMysqlTableName)
 }
+func getTestRevertedTableName() string {
+	return fmt.Sprintf("`%s`.`_%s_rev_del`", testMysqlDatabase, testMysqlTableName)
+}
+
+func getTestOldTableName() string {
+	return fmt.Sprintf("`%s`.`_%s_del`", testMysqlDatabase, testMysqlTableName)
+}
+
 func getTestConnectionConfig(ctx context.Context, container testcontainers.Container) (*mysql.ConnectionConfig, error) {
 	host, err := container.Host(ctx)
 	if err != nil {
diff --git a/go/logic/throttler.go b/go/logic/throttler.go
index 7a6534baf..1ca40f957 100644
--- a/go/logic/throttler.go
+++ b/go/logic/throttler.go
@@ -362,7 +362,9 @@ func (this *Throttler) collectGeneralThrottleMetrics() error {
 	// Regardless of throttle, we take opportunity to check for panic-abort
 	if this.migrationContext.PanicFlagFile != "" {
 		if base.FileExists(this.migrationContext.PanicFlagFile) {
-			this.migrationContext.PanicAbort <- fmt.Errorf("Found panic-file %s. Aborting without cleanup", this.migrationContext.PanicFlagFile)
+			// Use helper to prevent deadlock if listenOnPanicAbort already exited
+			_ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, fmt.Errorf("Found panic-file %s. Aborting without cleanup", this.migrationContext.PanicFlagFile))
+			return nil
 		}
 	}
@@ -385,7 +387,9 @@ func (this *Throttler) collectGeneralThrottleMetrics() error {
 	}
 	if criticalLoadMet && this.migrationContext.CriticalLoadIntervalMilliseconds == 0 {
-		this.migrationContext.PanicAbort <- fmt.Errorf("critical-load met: %s=%d, >=%d", variableName, value, threshold)
+		// Use helper to prevent deadlock if listenOnPanicAbort already exited
+		_ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, fmt.Errorf("critical-load met: %s=%d, >=%d", variableName, value, threshold))
+		return nil
 	}
 	if criticalLoadMet && this.migrationContext.CriticalLoadIntervalMilliseconds > 0 {
 		this.migrationContext.Log.Errorf("critical-load met once: %s=%d, >=%d. Will check again in %d millis", variableName, value, threshold, this.migrationContext.CriticalLoadIntervalMilliseconds)
@@ -393,7 +397,8 @@ func (this *Throttler) collectGeneralThrottleMetrics() error {
 			timer := time.NewTimer(time.Millisecond * time.Duration(this.migrationContext.CriticalLoadIntervalMilliseconds))
 			<-timer.C
 			if criticalLoadMetAgain, variableName, value, threshold, _ := this.criticalLoadIsMet(); criticalLoadMetAgain {
-				this.migrationContext.PanicAbort <- fmt.Errorf("critical-load met again after %d millis: %s=%d, >=%d", this.migrationContext.CriticalLoadIntervalMilliseconds, variableName, value, threshold)
+				// Use helper to prevent deadlock if listenOnPanicAbort already exited
+				_ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, fmt.Errorf("critical-load met again after %d millis: %s=%d, >=%d", this.migrationContext.CriticalLoadIntervalMilliseconds, variableName, value, threshold))
 			}
 		}()
 	}
@@ -481,7 +486,16 @@ func (this *Throttler) initiateThrottlerChecks() {
 	ticker := time.NewTicker(100 * time.Millisecond)
 	defer ticker.Stop()
-	for range ticker.C {
+	for {
+		// Check for context cancellation each iteration
+		ctx := this.migrationContext.GetContext()
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			// Process throttle check
+		}
+
 		if atomic.LoadInt64(&this.finishedMigrating) > 0 {
 			return
 		}
diff --git a/go/mysql/binlog_file.go b/go/mysql/binlog_file.go
index b9df215bf..7ecf16654 100644
--- a/go/mysql/binlog_file.go
+++ b/go/mysql/binlog_file.go
@@ -7,19 +7,11 @@
 package mysql
 
 import (
-	"errors"
 	"fmt"
-	"regexp"
 	"strconv"
 	"strings"
 )
 
-var detachPattern *regexp.Regexp
-
-func init() {
-	detachPattern, _ = regexp.Compile(`//([^/:]+):([\d]+)`) // e.g. `//binlog.01234:567890`
-}
-
 // FileBinlogCoordinates described binary log coordinates in the form of a binlog file & log position.
 type FileBinlogCoordinates struct {
 	LogFile string
@@ -120,51 +112,6 @@ func (this *FileBinlogCoordinates) FileNumber() (int, int) {
 	return fileNum, numLen
 }
 
-// PreviousFileCoordinatesBy guesses the filename of the previous binlog/relaylog, by given offset (number of files back)
-func (this *FileBinlogCoordinates) PreviousFileCoordinatesBy(offset int) (BinlogCoordinates, error) {
-	result := &FileBinlogCoordinates{}
-
-	fileNum, numLen := this.FileNumber()
-	if fileNum == 0 {
-		return result, errors.New("Log file number is zero, cannot detect previous file")
-	}
-	newNumStr := fmt.Sprintf("%d", (fileNum - offset))
-	newNumStr = strings.Repeat("0", numLen-len(newNumStr)) + newNumStr
-
-	tokens := strings.Split(this.LogFile, ".")
-	tokens[len(tokens)-1] = newNumStr
-	result.LogFile = strings.Join(tokens, ".")
-	return result, nil
-}
-
-// PreviousFileCoordinates guesses the filename of the previous binlog/relaylog
-func (this *FileBinlogCoordinates) PreviousFileCoordinates() (BinlogCoordinates, error) {
-	return this.PreviousFileCoordinatesBy(1)
-}
-
-// PreviousFileCoordinates guesses the filename of the previous binlog/relaylog
-func (this *FileBinlogCoordinates) NextFileCoordinates() (BinlogCoordinates, error) {
-	result := &FileBinlogCoordinates{}
-
-	fileNum, numLen := this.FileNumber()
-	newNumStr := fmt.Sprintf("%d", (fileNum + 1))
-	newNumStr = strings.Repeat("0", numLen-len(newNumStr)) + newNumStr
-
-	tokens := strings.Split(this.LogFile, ".")
-	tokens[len(tokens)-1] = newNumStr
-	result.LogFile = strings.Join(tokens, ".")
-	return result, nil
-}
-
-// FileSmallerThan returns true if this coordinate's file is strictly smaller than the other's.
-func (this *FileBinlogCoordinates) DetachedCoordinates() (isDetached bool, detachedLogFile string, detachedLogPos string) {
-	detachedCoordinatesSubmatch := detachPattern.FindStringSubmatch(this.LogFile)
-	if len(detachedCoordinatesSubmatch) == 0 {
-		return false, "", ""
-	}
-	return true, detachedCoordinatesSubmatch[1], detachedCoordinatesSubmatch[2]
-}
-
 func (this *FileBinlogCoordinates) Clone() BinlogCoordinates {
 	return &FileBinlogCoordinates{
 		LogPos: this.LogPos,
diff --git a/go/mysql/utils.go b/go/mysql/utils.go
index 7d57afbf1..9619e1e9f 100644
--- a/go/mysql/utils.go
+++ b/go/mysql/utils.go
@@ -248,9 +248,9 @@ func Kill(db *gosql.DB, connectionID string) error {
 // GetTriggers reads trigger list from given table
 func GetTriggers(db *gosql.DB, databaseName, tableName string) (triggers []Trigger, err error) {
-	query := fmt.Sprintf(`select trigger_name as name, event_manipulation as event, action_statement as statement, action_timing as timing
-		from information_schema.triggers
-		where trigger_schema = '%s' and event_object_table = '%s'`, databaseName, tableName)
+	query := `select trigger_name as name, event_manipulation as event, action_statement as statement, action_timing as timing
+		from information_schema.triggers
+		where trigger_schema = ? and event_object_table = ?`
 	err = sqlutils.QueryRowsMap(db, query, func(rowMap sqlutils.RowMap) error {
 		triggers = append(triggers, Trigger{
@@ -260,7 +260,7 @@ func GetTriggers(db *gosql.DB, databaseName, tableName string) (triggers []Trigg
 			Timing: rowMap.GetString("timing"),
 		})
 		return nil
-	})
+	}, databaseName, tableName)
 	if err != nil {
 		return nil, err
 	}
diff --git a/go/sql/builder.go b/go/sql/builder.go
index 757d74910..940ca4ca3 100644
--- a/go/sql/builder.go
+++ b/go/sql/builder.go
@@ -142,11 +142,11 @@ func NewCheckpointQueryBuilder(databaseName, tableName string, uniqueKeyColumns
 		insert /* gh-ost */
 		into %s.%s
 			(gh_ost_chk_timestamp, gh_ost_chk_coords, gh_ost_chk_iteration,
-			gh_ost_rows_copied, gh_ost_dml_applied,
+			gh_ost_rows_copied, gh_ost_dml_applied, gh_ost_is_cutover,
 			%s, %s)
 		values (unix_timestamp(now()), ?, ?,
-			?, ?,
+			?, ?, ?,
 			%s, %s)`,
 		databaseName, tableName,
 		strings.Join(minUniqueColNames, ", "),
@@ -169,11 +169,11 @@ func (b *CheckpointInsertQueryBuilder) BuildQuery(uniqueKeyArgs []interface{}) (
 	}
 	convertedArgs := make([]interface{}, 0, 2*b.uniqueKeyColumns.Len())
 	for i, column := range b.uniqueKeyColumns.Columns() {
-		minArg := column.convertArg(uniqueKeyArgs[i], true)
+		minArg := column.convertArg(uniqueKeyArgs[i])
 		convertedArgs = append(convertedArgs, minArg)
 	}
 	for i, column := range b.uniqueKeyColumns.Columns() {
-		minArg := column.convertArg(uniqueKeyArgs[i+b.uniqueKeyColumns.Len()], true)
+		minArg := column.convertArg(uniqueKeyArgs[i+b.uniqueKeyColumns.Len()])
 		convertedArgs = append(convertedArgs, minArg)
 	}
 	return b.preparedStatement, convertedArgs, nil
@@ -533,7 +533,7 @@ func (b *DMLDeleteQueryBuilder) BuildQuery(args []interface{}) (string, []interf
 	uniqueKeyArgs := make([]interface{}, 0, b.uniqueKeyColumns.Len())
 	for _, column := range b.uniqueKeyColumns.Columns() {
 		tableOrdinal := b.tableColumns.Ordinals[column.Name]
-		arg := column.convertArg(args[tableOrdinal], true)
+		arg := column.convertArg(args[tableOrdinal])
 		uniqueKeyArgs = append(uniqueKeyArgs, arg)
 	}
 	return b.preparedStatement, uniqueKeyArgs, nil
@@ -566,7 +566,7 @@ func NewDMLInsertQueryBuilder(databaseName, tableName string, tableColumns, shar
 	preparedValues := buildColumnsPreparedValues(mappedSharedColumns)
 	stmt := fmt.Sprintf(`
-		replace /* gh-ost %s.%s */
+		insert /* gh-ost %s.%s */
 		ignore into
 			%s.%s
 			(%s)
@@ -595,7 +595,7 @@ func (b *DMLInsertQueryBuilder) BuildQuery(args []interface{}) (string, []interf
 	sharedArgs := make([]interface{}, 0, b.sharedColumns.Len())
 	for _, column := range b.sharedColumns.Columns() {
 		tableOrdinal := b.tableColumns.Ordinals[column.Name]
-		arg := column.convertArg(args[tableOrdinal], false)
+		arg := column.convertArg(args[tableOrdinal])
 		sharedArgs = append(sharedArgs, arg)
 	}
 	return b.preparedStatement, sharedArgs, nil
@@ -661,20 +661,18 @@ func NewDMLUpdateQueryBuilder(databaseName, tableName string, tableColumns, shar
 // BuildQuery builds the arguments array for a DML event UPDATE query.
 // It returns the query string, the shared arguments array, and the unique key arguments array.
-func (b *DMLUpdateQueryBuilder) BuildQuery(valueArgs, whereArgs []interface{}) (string, []interface{}, []interface{}, error) {
-	sharedArgs := make([]interface{}, 0, b.sharedColumns.Len())
+func (b *DMLUpdateQueryBuilder) BuildQuery(valueArgs, whereArgs []interface{}) (string, []interface{}, error) {
+	args := make([]interface{}, 0, b.sharedColumns.Len()+b.uniqueKeyColumns.Len())
 	for _, column := range b.sharedColumns.Columns() {
 		tableOrdinal := b.tableColumns.Ordinals[column.Name]
-		arg := column.convertArg(valueArgs[tableOrdinal], false)
-		sharedArgs = append(sharedArgs, arg)
+		arg := column.convertArg(valueArgs[tableOrdinal])
+		args = append(args, arg)
 	}
-
-	uniqueKeyArgs := make([]interface{}, 0, b.uniqueKeyColumns.Len())
 	for _, column := range b.uniqueKeyColumns.Columns() {
 		tableOrdinal := b.tableColumns.Ordinals[column.Name]
-		arg := column.convertArg(whereArgs[tableOrdinal], true)
-		uniqueKeyArgs = append(uniqueKeyArgs, arg)
+		arg := column.convertArg(whereArgs[tableOrdinal])
+		args = append(args, arg)
 	}
-	return b.preparedStatement, sharedArgs, uniqueKeyArgs, nil
+	return b.preparedStatement, args, nil
 }
diff --git a/go/sql/builder_test.go b/go/sql/builder_test.go
index 06e402c89..7f80005b0 100644
--- a/go/sql/builder_test.go
+++ b/go/sql/builder_test.go
@@ -538,7 +538,7 @@ func TestBuildDMLInsertQuery(t *testing.T) {
 		query, sharedArgs, err := builder.BuildQuery(args)
 		require.NoError(t, err)
 		expected := `
-			replace /* gh-ost mydb.tbl */
+			insert /* gh-ost mydb.tbl */
 			ignore into
 				mydb.tbl
 				(id, name, position, age)
 			values
@@ -554,7 +554,7 @@ func TestBuildDMLInsertQuery(t *testing.T) {
 		query, sharedArgs, err := builder.BuildQuery(args)
 		require.NoError(t, err)
 		expected := `
-			replace /* gh-ost mydb.tbl */
+			insert /* gh-ost mydb.tbl */
 			ignore into
 				mydb.tbl
 				(position, name, age, id)
 			values
@@ -589,7 +589,7 @@ func TestBuildDMLInsertQuerySignedUnsigned(t *testing.T) {
 		query, sharedArgs, err := builder.BuildQuery(args)
 		require.NoError(t, err)
 		expected := `
-			replace /* gh-ost mydb.tbl */
+			insert /* gh-ost mydb.tbl */
 			ignore into
 				mydb.tbl
 				(id, name, position, age)
 			values
@@ -607,7 +607,7 @@ func TestBuildDMLInsertQuerySignedUnsigned(t *testing.T) {
 		query, sharedArgs, err := builder.BuildQuery(args)
 		require.NoError(t, err)
 		expected := `
-			replace /* gh-ost mydb.tbl */
+			insert /* gh-ost mydb.tbl */
 			ignore into
 				mydb.tbl
 				(id, name, position, age)
 			values
@@ -625,7 +625,7 @@ func TestBuildDMLInsertQuerySignedUnsigned(t *testing.T) {
 		query, sharedArgs, err := builder.BuildQuery(args)
 		require.NoError(t, err)
 		expected := `
-			replace /* gh-ost mydb.tbl */
+			insert /* gh-ost mydb.tbl */
 			ignore into
 				mydb.tbl
 				(id, name, position, age)
 			values
@@ -647,7 +647,7 @@ func TestBuildDMLUpdateQuery(t *testing.T) {
 		uniqueKeyColumns := NewColumnList([]string{"position"})
 		builder, err := NewDMLUpdateQueryBuilder(databaseName, tableName, tableColumns, sharedColumns, sharedColumns, uniqueKeyColumns)
 		require.NoError(t, err)
-		query, sharedArgs, uniqueKeyArgs, err := builder.BuildQuery(valueArgs, whereArgs)
+		query, updateArgs, err := builder.BuildQuery(valueArgs, whereArgs)
 		require.NoError(t, err)
 		expected := `
 			update /* gh-ost mydb.tbl */
@@ -657,15 +657,14 @@ func TestBuildDMLUpdateQuery(t *testing.T) {
 				((position = ?))
 		`
 		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
-		require.Equal(t, []interface{}{3, "testname", 17, 23}, sharedArgs)
-		require.Equal(t, []interface{}{17}, uniqueKeyArgs)
+		require.Equal(t, []interface{}{3, "testname", 17, 23, 17}, updateArgs)
 	}
 	{
 		sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
 		uniqueKeyColumns := NewColumnList([]string{"position", "name"})
 		builder, err := NewDMLUpdateQueryBuilder(databaseName, tableName, tableColumns, sharedColumns, sharedColumns, uniqueKeyColumns)
 		require.NoError(t, err)
-		query, sharedArgs, uniqueKeyArgs, err := builder.BuildQuery(valueArgs, whereArgs)
+		query, updateArgs, err := builder.BuildQuery(valueArgs, whereArgs)
 		require.NoError(t, err)
 		expected := `
 			update /* gh-ost mydb.tbl */
@@ -675,15 +674,14 @@ func TestBuildDMLUpdateQuery(t *testing.T) {
 				((position = ?) and (name = ?))
 		`
 		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
-		require.Equal(t, []interface{}{3, "testname", 17, 23}, sharedArgs)
-		require.Equal(t, []interface{}{17, "testname"}, uniqueKeyArgs)
+		require.Equal(t, []interface{}{3, "testname", 17, 23, 17, "testname"}, updateArgs)
 	}
 	{
 		sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
 		uniqueKeyColumns := NewColumnList([]string{"age"})
 		builder, err := NewDMLUpdateQueryBuilder(databaseName, tableName, tableColumns, sharedColumns, sharedColumns, uniqueKeyColumns)
 		require.NoError(t, err)
-		query, sharedArgs, uniqueKeyArgs, err := builder.BuildQuery(valueArgs, whereArgs)
+		query, updateArgs, err := builder.BuildQuery(valueArgs, whereArgs)
 		require.NoError(t, err)
 		expected := `
 			update /* gh-ost mydb.tbl */
@@ -693,15 +691,14 @@ func TestBuildDMLUpdateQuery(t *testing.T) {
 				((age = ?))
 		`
 		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
-		require.Equal(t, []interface{}{3, "testname", 17, 23}, sharedArgs)
-		require.Equal(t, []interface{}{56}, uniqueKeyArgs)
+		require.Equal(t, []interface{}{3, "testname", 17, 23, 56}, updateArgs)
 	}
 	{
 		sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
 		uniqueKeyColumns := NewColumnList([]string{"age", "position", "id", "name"})
 		builder, err := NewDMLUpdateQueryBuilder(databaseName, tableName, tableColumns, sharedColumns, sharedColumns, uniqueKeyColumns)
 		require.NoError(t, err)
-		query, sharedArgs, uniqueKeyArgs, err := builder.BuildQuery(valueArgs, whereArgs)
+		query, updateArgs, err := builder.BuildQuery(valueArgs, whereArgs)
 		require.NoError(t, err)
 		expected := `
 			update /* gh-ost mydb.tbl */
@@ -711,8 +708,7 @@ func TestBuildDMLUpdateQuery(t *testing.T) {
 				((age = ?) and (position = ?) and (id = ?) and (name = ?))
 		`
 		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
-		require.Equal(t, []interface{}{3, "testname", 17, 23}, sharedArgs)
-		require.Equal(t, []interface{}{56, 17, 3, "testname"}, uniqueKeyArgs)
+		require.Equal(t, []interface{}{3, "testname", 17, 23, 56, 17, 3, "testname"}, updateArgs)
 	}
 	{
 		sharedColumns := NewColumnList([]string{"id", "name", "position", "age"})
@@ -732,7 +728,7 @@ func TestBuildDMLUpdateQuery(t *testing.T) {
 		uniqueKeyColumns := NewColumnList([]string{"id"})
 		builder, err := NewDMLUpdateQueryBuilder(databaseName, tableName, tableColumns, sharedColumns, mappedColumns, uniqueKeyColumns)
 		require.NoError(t, err)
-		query, sharedArgs, uniqueKeyArgs, err := builder.BuildQuery(valueArgs, whereArgs)
+		query, updateArgs, err := builder.BuildQuery(valueArgs, whereArgs)
 		require.NoError(t, err)
 		expected := `
 			update /* gh-ost mydb.tbl */
@@ -742,8 +738,7 @@ func TestBuildDMLUpdateQuery(t *testing.T) {
 				((id = ?))
 		`
 		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
-		require.Equal(t, []interface{}{3, "testname", 17, 23}, sharedArgs)
-		require.Equal(t, []interface{}{3}, uniqueKeyArgs)
+		require.Equal(t, []interface{}{3, "testname", 17, 23, 3}, updateArgs)
 	}
 }
@@ -759,7 +754,7 @@ func TestBuildDMLUpdateQuerySignedUnsigned(t *testing.T) {
 	require.NoError(t, err)
 	{
 		// test signed
-		query, sharedArgs, uniqueKeyArgs, err := builder.BuildQuery(valueArgs, whereArgs)
+		query, updateArgs, err := builder.BuildQuery(valueArgs, whereArgs)
 		require.NoError(t, err)
 		expected := `
 			update /* gh-ost mydb.tbl */
@@ -769,14 +764,13 @@ func TestBuildDMLUpdateQuerySignedUnsigned(t *testing.T) {
 				((position = ?))
 		`
 		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
-		require.Equal(t, []interface{}{3, "testname", int8(-17), int8(-2)}, sharedArgs)
-		require.Equal(t, []interface{}{int8(-3)}, uniqueKeyArgs)
+		require.Equal(t, []interface{}{3, "testname", int8(-17), int8(-2), int8(-3)}, updateArgs)
 	}
 	{
 		// test unsigned
 		sharedColumns.SetUnsigned("age")
 		uniqueKeyColumns.SetUnsigned("position")
-		query, sharedArgs, uniqueKeyArgs, err := builder.BuildQuery(valueArgs, whereArgs)
+		query, updateArgs, err := builder.BuildQuery(valueArgs, whereArgs)
 		require.NoError(t, err)
 		expected := `
 			update /* gh-ost mydb.tbl */
@@ -786,8 +780,7 @@ func TestBuildDMLUpdateQuerySignedUnsigned(t *testing.T) {
 				((position = ?))
 		`
 		require.Equal(t, normalizeQuery(expected), normalizeQuery(query))
-		require.Equal(t, []interface{}{3, "testname", int8(-17), uint8(254)}, sharedArgs)
-		require.Equal(t, []interface{}{uint8(253)}, uniqueKeyArgs)
+		require.Equal(t, []interface{}{3, "testname", int8(-17), uint8(254), uint8(253)}, updateArgs)
 	}
 }
@@ -803,12 +796,12 @@ func TestCheckpointQueryBuilder(t *testing.T) {
 	expected := `
 		insert /* gh-ost */
 		into mydb._tbl_ghk
 			(gh_ost_chk_timestamp, gh_ost_chk_coords, gh_ost_chk_iteration,
-			gh_ost_rows_copied, gh_ost_dml_applied,
+			gh_ost_rows_copied, gh_ost_dml_applied, gh_ost_is_cutover,
 			name_min, position_min, my_very_long_column_that_is_64_utf8_characters_long_很长很长很长很长_min,
 			name_max, position_max, my_very_long_column_that_is_64_utf8_characters_long_很长很长很长很长_max)
 		values (unix_timestamp(now()), ?, ?,
-			?, ?,
+			?, ?, ?,
 			?, ?, ?,
 			?, ?, ?)
 	`
diff --git a/go/sql/types.go b/go/sql/types.go
index a01fb8bff..9a50bc620 100644
--- a/go/sql/types.go
+++ b/go/sql/types.go
@@ -57,7 +57,7 @@ type Column struct {
 	MySQLType string
 }
-func (this *Column) convertArg(arg interface{}, isUniqueKeyColumn bool) interface{} {
+func (this *Column) convertArg(arg interface{}) interface{} {
 	var arg2Bytes []byte
 	if s, ok := arg.(string); ok {
 		arg2Bytes = []byte(s)
@@ -77,14 +77,14 @@ func (this *Column) convertArg(arg interface{}, isUniqueKeyColumn bool) interfac
 		}
 	}
-	if this.Type == BinaryColumnType && isUniqueKeyColumn {
+	if this.Type == BinaryColumnType {
 		size := len(arg2Bytes)
 		if uint(size) < this.BinaryOctetLength {
 			buf := bytes.NewBuffer(arg2Bytes)
 			for i := uint(0); i < (this.BinaryOctetLength - uint(size)); i++ {
 				buf.Write([]byte{0})
 			}
-			arg = buf.String()
+			arg = buf.Bytes()
 		}
 	}
diff --git a/go/sql/types_test.go b/go/sql/types_test.go
index 7b808e64f..83c74073a 100644
--- a/go/sql/types_test.go
+++ b/go/sql/types_test.go
@@ -62,6 +62,56 @@ func TestConvertArgCharsetDecoding(t *testing.T) {
 	}
 	// Should decode []uint8
-	str := col.convertArg(latin1Bytes, false)
+	str := col.convertArg(latin1Bytes)
 	require.Equal(t, "Garçon !", str)
 }
+
+func TestConvertArgBinaryColumnPadding(t *testing.T) {
+	// Test that binary columns are padded with trailing zeros to their declared length.
+	// This is needed because MySQL's binlog strips trailing 0x00 bytes from binary values.
+ // See https://github.com/github/gh-ost/issues/909 + + // Simulates a binary(20) column where binlog delivered only 18 bytes + // (trailing zeros were stripped) + truncatedValue := []uint8{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, // 18 bytes, missing 2 trailing zeros + } + + col := Column{ + Name: "bin_col", + Type: BinaryColumnType, + BinaryOctetLength: 20, + } + + result := col.convertArg(truncatedValue) + resultBytes := result.([]byte) + + require.Equal(t, 20, len(resultBytes), "binary column should be padded to declared length") + // First 18 bytes should be unchanged + require.Equal(t, truncatedValue, resultBytes[:18]) + // Last 2 bytes should be zeros + require.Equal(t, []byte{0x00, 0x00}, resultBytes[18:]) +} + +func TestConvertArgBinaryColumnNoPaddingWhenFull(t *testing.T) { + // When binary value is already at full length, no padding should occur + fullValue := []uint8{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, // 20 bytes + } + + col := Column{ + Name: "bin_col", + Type: BinaryColumnType, + BinaryOctetLength: 20, + } + + result := col.convertArg(fullValue) + resultBytes := result.([]byte) + + require.Equal(t, 20, len(resultBytes)) + require.Equal(t, fullValue, resultBytes) +} diff --git a/localtests/binary-to-varbinary/create.sql b/localtests/binary-to-varbinary/create.sql new file mode 100644 index 000000000..c2867f9ea --- /dev/null +++ b/localtests/binary-to-varbinary/create.sql @@ -0,0 +1,38 @@ +-- Test for https://github.com/github/gh-ost/issues/909 variant: +-- Binary columns with trailing zeros should preserve their values +-- when migrating from binary(N) to varbinary(M), even for rows +-- modified during migration via binlog events. 
+
+drop table if exists gh_ost_test;
+create table gh_ost_test (
+  id int NOT NULL AUTO_INCREMENT,
+  info varchar(255) NOT NULL,
+  data binary(20) NOT NULL,
+  PRIMARY KEY (id)
+) auto_increment=1;
+
+drop event if exists gh_ost_test;
+delimiter ;;
+create event gh_ost_test
+  on schedule every 1 second
+  starts current_timestamp
+  ends current_timestamp + interval 60 second
+  on completion not preserve
+  enable
+  do
+begin
+  -- Insert rows where data has trailing zeros (will be stripped by binlog)
+  INSERT INTO gh_ost_test (info, data) VALUES ('insert-during-1', X'aabbccdd00000000000000000000000000000000');
+  INSERT INTO gh_ost_test (info, data) VALUES ('insert-during-2', X'11223344556677889900000000000000000000ee');
+
+  -- Update existing rows to values with trailing zeros
+  UPDATE gh_ost_test SET data = X'ffeeddcc00000000000000000000000000000000' WHERE info = 'update-target-1';
+  UPDATE gh_ost_test SET data = X'aabbccdd11111111111111111100000000000000' WHERE info = 'update-target-2';
+end ;;
+delimiter ;
+-- Pre-existing rows (copied via rowcopy, not binlog - these should work fine)
+INSERT INTO gh_ost_test (info, data) VALUES
+  ('pre-existing-1', X'01020304050607080910111213141516171819ff'),
+  ('pre-existing-2', X'0102030405060708091011121314151617181900'),
+  ('update-target-1', X'ffffffffffffffffffffffffffffffffffffffff'),
+  ('update-target-2', X'eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee');
diff --git a/localtests/binary-to-varbinary/extra_args b/localtests/binary-to-varbinary/extra_args
new file mode 100644
index 000000000..7ebdacf6c
--- /dev/null
+++ b/localtests/binary-to-varbinary/extra_args
@@ -0,0 +1 @@
+--alter="MODIFY data varbinary(32)"
diff --git a/localtests/copy-retries-exhausted/after.sql b/localtests/copy-retries-exhausted/after.sql
new file mode 100644
index 000000000..c3d55e5b9
--- /dev/null
+++ b/localtests/copy-retries-exhausted/after.sql
@@ -0,0 +1 @@
+set global max_binlog_cache_size = 1073741824; -- 1GB
diff --git 
a/localtests/copy-retries-exhausted/before.sql b/localtests/copy-retries-exhausted/before.sql new file mode 100644 index 000000000..a3570171a --- /dev/null +++ b/localtests/copy-retries-exhausted/before.sql @@ -0,0 +1 @@ +set global max_binlog_cache_size = 1024; diff --git a/localtests/copy-retries-exhausted/create.sql b/localtests/copy-retries-exhausted/create.sql new file mode 100644 index 000000000..4e37938ec --- /dev/null +++ b/localtests/copy-retries-exhausted/create.sql @@ -0,0 +1,12 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + name mediumtext not null, + primary key (id) +) auto_increment=1; + +insert into gh_ost_test (name) +select repeat('a', 1500) +from information_schema.columns +cross join information_schema.tables +limit 1000; diff --git a/localtests/copy-retries-exhausted/expect_failure b/localtests/copy-retries-exhausted/expect_failure new file mode 100644 index 000000000..cd6a516ec --- /dev/null +++ b/localtests/copy-retries-exhausted/expect_failure @@ -0,0 +1 @@ +Multi-statement transaction required more than 'max_binlog_cache_size' bytes of storage diff --git a/localtests/copy-retries-exhausted/extra_args b/localtests/copy-retries-exhausted/extra_args new file mode 100644 index 000000000..e4f8a0104 --- /dev/null +++ b/localtests/copy-retries-exhausted/extra_args @@ -0,0 +1 @@ +--alter "modify column name mediumtext" --default-retries=1 --chunk-size=1000 diff --git a/localtests/panic-on-warnings-batch-middle/create.sql b/localtests/panic-on-warnings-batch-middle/create.sql new file mode 100644 index 000000000..e4883ca66 --- /dev/null +++ b/localtests/panic-on-warnings-batch-middle/create.sql @@ -0,0 +1,11 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + email varchar(255) not null, + primary key (id) +) auto_increment=1; + +-- Insert initial data - all unique emails +insert into gh_ost_test (email) values ('alice@example.com'); +insert into gh_ost_test 
(email) values ('bob@example.com'); +insert into gh_ost_test (email) values ('charlie@example.com'); diff --git a/localtests/panic-on-warnings-batch-middle/expect_failure b/localtests/panic-on-warnings-batch-middle/expect_failure new file mode 100644 index 000000000..11800efb0 --- /dev/null +++ b/localtests/panic-on-warnings-batch-middle/expect_failure @@ -0,0 +1 @@ +ERROR warnings detected in statement 1 of 2 diff --git a/localtests/panic-on-warnings-batch-middle/extra_args b/localtests/panic-on-warnings-batch-middle/extra_args new file mode 100644 index 000000000..55cdcd0e8 --- /dev/null +++ b/localtests/panic-on-warnings-batch-middle/extra_args @@ -0,0 +1 @@ +--default-retries=1 --panic-on-warnings --alter "add unique index email_idx(email)" --postpone-cut-over-flag-file=/tmp/gh-ost-test.postpone-cutover diff --git a/localtests/panic-on-warnings-batch-middle/test.sh b/localtests/panic-on-warnings-batch-middle/test.sh new file mode 100755 index 000000000..cbe3829c7 --- /dev/null +++ b/localtests/panic-on-warnings-batch-middle/test.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Custom test: inject batched DML events AFTER row copy completes +# Tests that warnings in the middle of a DML batch are detected + +# Create postpone flag file (referenced in extra_args) +postpone_flag_file=/tmp/gh-ost-test.postpone-cutover +touch $postpone_flag_file + +# Set table names (required by build_ghost_command) +table_name="gh_ost_test" +ghost_table_name="_gh_ost_test_gho" + +# Build gh-ost command using framework function +build_ghost_command + +# Run in background +echo_dot +echo > $test_logfile +bash -c "$cmd" >>$test_logfile 2>&1 & +ghost_pid=$! 
+ +# Wait for row copy to complete +echo_dot +for i in {1..30}; do + grep -q "Row copy complete" $test_logfile && break + ps -p $ghost_pid > /dev/null || { echo; echo "ERROR gh-ost exited early"; rm -f $postpone_flag_file; return 1; } + sleep 1; echo_dot +done + +# Inject batched DML events that will create warnings +# These must be in a single transaction to be batched during binlog replay +echo_dot +gh-ost-test-mysql-master test << 'EOF' +BEGIN; +-- INSERT with duplicate PRIMARY KEY - warning on migration key (filtered by gh-ost) +INSERT IGNORE INTO gh_ost_test (id, email) VALUES (1, 'duplicate_pk@example.com'); +-- INSERT with duplicate email - warning on unique index (should trigger failure) +INSERT IGNORE INTO gh_ost_test (email) VALUES ('alice@example.com'); +-- INSERT with unique data - would succeed if not for previous warning +INSERT IGNORE INTO gh_ost_test (email) VALUES ('new@example.com'); +COMMIT; +EOF + +# Wait for binlog events to replicate and be applied +sleep 10; echo_dot + +# Complete cutover by removing postpone flag +rm -f $postpone_flag_file + +# Wait for gh-ost to complete +wait $ghost_pid +execution_result=$? +rm -f $postpone_flag_file + +# Validate using framework function +validate_expected_failure +return $? 
diff --git a/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/create.sql b/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/create.sql new file mode 100644 index 000000000..444092cb2 --- /dev/null +++ b/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/create.sql @@ -0,0 +1,10 @@ +drop table if exists gh_ost_test; +create table gh_ost_test ( + id int auto_increment, + email varchar(100) not null, + primary key (id) +) auto_increment=1; + +insert into gh_ost_test (email) values ('alice@example.com'); +insert into gh_ost_test (email) values ('bob@example.com'); +insert into gh_ost_test (email) values ('charlie@example.com'); diff --git a/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/expect_failure b/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/expect_failure new file mode 100644 index 000000000..5a6e5411e --- /dev/null +++ b/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/expect_failure @@ -0,0 +1 @@ +ERROR warnings detected in statement diff --git a/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/extra_args b/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/extra_args new file mode 100644 index 000000000..04c41a471 --- /dev/null +++ b/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/extra_args @@ -0,0 +1 @@ +--panic-on-warnings --alter "ADD UNIQUE KEY email_unique (email)" --postpone-cut-over-flag-file=/tmp/gh-ost-test.postpone-cutover diff --git a/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/test.sh b/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/test.sh new file mode 100755 index 000000000..9f33be9a3 --- /dev/null +++ b/localtests/panic-on-warnings-update-pk-with-duplicate-on-new-unique-index/test.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Custom test: inject conflicting data AFTER row 
copy completes +# This tests the DML event application code path, not row copy + +# Create postpone flag file (referenced in extra_args) +postpone_flag_file=/tmp/gh-ost-test.postpone-cutover +touch $postpone_flag_file + +# Build gh-ost command using framework function +build_ghost_command + +# Run in background +echo_dot +# Clear log file before starting gh-ost +echo > $test_logfile +bash -c "$cmd" >>$test_logfile 2>&1 & +ghost_pid=$! + +# Wait for row copy to complete +echo_dot +row_copy_complete=false +for i in {1..30}; do + if grep -q "Row copy complete" $test_logfile; then + row_copy_complete=true + break + fi + ps -p $ghost_pid > /dev/null || { echo; echo "ERROR gh-ost exited early"; rm -f $postpone_flag_file; return 1; } + sleep 1; echo_dot +done + +if ! $row_copy_complete; then + echo; echo "ERROR row copy did not complete within expected time" + rm -f $postpone_flag_file + return 1 +fi + +# Inject conflicting SQL after row copy (UPDATE with PK change creates DELETE+INSERT in binlog) +echo_dot +gh-ost-test-mysql-master test -e "update gh_ost_test set id = 200, email = 'alice@example.com' where id = 2" + +# Wait for binlog event to replicate and be applied +sleep 10; echo_dot + +# Complete cutover by removing postpone flag +rm -f $postpone_flag_file + +# Wait for gh-ost to complete +wait $ghost_pid +execution_result=$? +rm -f $postpone_flag_file + +# Validate using framework function +validate_expected_failure +return $? diff --git a/localtests/test.sh b/localtests/test.sh index 6776d227a..d918d473b 100755 --- a/localtests/test.sh +++ b/localtests/test.sh @@ -30,6 +30,8 @@ replica_port= original_sql_mode= current_gtid_mode= sysbench_pid= +test_timeout=120 +test_failure_log_tail_lines=50 OPTIND=1 while getopts "b:s:dtg" OPTION; do @@ -107,12 +109,6 @@ verify_master_and_replica() { fi } -exec_cmd() { - echo "$@" - command "$@" 1>$test_logfile 2>&1 - return $? -} - echo_dot() { echo -n "." 
} @@ -141,6 +137,79 @@ start_replication() { done } +build_ghost_command() { + # Build gh-ost command with all standard options + # Expects: ghost_binary, replica_host, replica_port, master_host, master_port, + # table_name, storage_engine, throttle_flag_file, extra_args + cmd="GOTRACEBACK=crash $ghost_binary \ + --user=gh-ost \ + --password=gh-ost \ + --host=$replica_host \ + --port=$replica_port \ + --assume-master-host=${master_host}:${master_port} \ + --database=test \ + --table=${table_name} \ + --storage-engine=${storage_engine} \ + --alter='engine=${storage_engine}' \ + --exact-rowcount \ + --assume-rbr \ + --skip-metadata-lock-check \ + --initially-drop-old-table \ + --initially-drop-ghost-table \ + --throttle-query='select timestampdiff(second, min(last_update), now()) < 5 from _${table_name}_ghc' \ + --throttle-flag-file=$throttle_flag_file \ + --serve-socket-file=/tmp/gh-ost.test.sock \ + --initially-drop-socket-file \ + --test-on-replica \ + --default-retries=3 \ + --chunk-size=10 \ + --verbose \ + --debug \ + --stack \ + --checkpoint \ + --execute ${extra_args[@]}" +} + +print_log_excerpt() { + echo "=== Last $test_failure_log_tail_lines lines of $test_logfile ===" + tail -n $test_failure_log_tail_lines $test_logfile + echo "=== End log excerpt ===" +} + +validate_expected_failure() { + # Check if test expected to fail and validate error message + # Expects: tests_path, test_name, execution_result, test_logfile + if [ -f $tests_path/$test_name/expect_failure ]; then + if [ $execution_result -eq 0 ]; then + echo + echo "ERROR $test_name execution was expected to exit on error but did not." + print_log_excerpt + return 1 + fi + if [ -s $tests_path/$test_name/expect_failure ]; then + # 'expect_failure' file has content. We expect to find this content in the log. 
+ expected_error_message="$(cat $tests_path/$test_name/expect_failure)" + if grep -q "$expected_error_message" $test_logfile; then + return 0 + fi + echo + echo "ERROR $test_name execution was expected to exit with error message '${expected_error_message}' but did not." + print_log_excerpt + return 1 + fi + # 'expect_failure' file has no content. We generally agree that the failure is correct + return 0 + fi + + if [ $execution_result -ne 0 ]; then + echo + echo "ERROR $test_name execution failure. cat $test_logfile:" + cat $test_logfile + return 1 + fi + return 0 +} + sysbench_prepare() { local mysql_host="$1" local mysql_port="$2" @@ -226,6 +295,11 @@ test_single() { return 1 fi + if [ -f $tests_path/$test_name/before.sql ]; then + gh-ost-test-mysql-master --default-character-set=utf8mb4 test < $tests_path/$test_name/before.sql + gh-ost-test-mysql-replica --default-character-set=utf8mb4 test < $tests_path/$test_name/before.sql + fi + extra_args="" if [ -f $tests_path/$test_name/extra_args ]; then extra_args=$(cat $tests_path/$test_name/extra_args) @@ -254,6 +328,37 @@ test_single() { table_name="gh_ost_test" ghost_table_name="_gh_ost_test_gho" + + # Check for custom test script + if [ -f $tests_path/$test_name/test.sh ]; then + # Run the custom test script in a subshell with timeout monitoring + # The subshell inherits all functions and variables from the current shell + (source $tests_path/$test_name/test.sh) & + test_pid=$! + + # Monitor the test with timeout + timeout_counter=0 + while kill -0 $test_pid 2>/dev/null; do + if [ $timeout_counter -ge $test_timeout ]; then + kill -TERM $test_pid 2>/dev/null + sleep 1 + kill -KILL $test_pid 2>/dev/null + wait $test_pid 2>/dev/null + echo + echo "ERROR $test_name execution timed out" + print_log_excerpt + return 1 + fi + sleep 1 + ((timeout_counter++)) + done + + # Get the exit code + wait $test_pid 2>/dev/null + execution_result=$? 
+ return $execution_result + fi + # test with sysbench oltp write load if [[ "$test_name" == "sysbench" ]]; then if ! command -v sysbench &>/dev/null; then @@ -274,77 +379,49 @@ test_single() { fi trap cleanup SIGINT - # - cmd="GOTRACEBACK=crash $ghost_binary \ - --user=gh-ost \ - --password=gh-ost \ - --host=$replica_host \ - --port=$replica_port \ - --assume-master-host=${master_host}:${master_port} - --database=test \ - --table=${table_name} \ - --storage-engine=${storage_engine} \ - --alter='engine=${storage_engine}' \ - --exact-rowcount \ - --assume-rbr \ - --initially-drop-old-table \ - --initially-drop-ghost-table \ - --throttle-query='select timestampdiff(second, min(last_update), now()) < 5 from _${table_name}_ghc' \ - --throttle-flag-file=$throttle_flag_file \ - --serve-socket-file=/tmp/gh-ost.test.sock \ - --initially-drop-socket-file \ - --test-on-replica \ - --default-retries=3 \ - --chunk-size=10 \ - --verbose \ - --debug \ - --stack \ - --checkpoint \ - --execute ${extra_args[@]}" + # Build and execute gh-ost command + build_ghost_command echo_dot echo $cmd >$exec_command_file echo_dot - bash $exec_command_file 1>$test_logfile 2>&1 + timeout $test_timeout bash $exec_command_file >$test_logfile 2>&1 execution_result=$? 
cleanup + # Check for timeout (exit code 124) + if [ $execution_result -eq 124 ]; then + echo + echo "ERROR $test_name execution timed out" + print_log_excerpt + return 1 + fi + if [ -f $tests_path/$test_name/sql_mode ]; then gh-ost-test-mysql-master --default-character-set=utf8mb4 test -e "set @@global.sql_mode='${original_sql_mode}'" gh-ost-test-mysql-replica --default-character-set=utf8mb4 test -e "set @@global.sql_mode='${original_sql_mode}'" fi + if [ -f $tests_path/$test_name/after.sql ]; then + gh-ost-test-mysql-master --default-character-set=utf8mb4 test < $tests_path/$test_name/after.sql + gh-ost-test-mysql-replica --default-character-set=utf8mb4 test < $tests_path/$test_name/after.sql + fi + if [ -f $tests_path/$test_name/destroy.sql ]; then gh-ost-test-mysql-master --default-character-set=utf8mb4 test <$tests_path/$test_name/destroy.sql fi - if [ -f $tests_path/$test_name/expect_failure ]; then - if [ $execution_result -eq 0 ]; then - echo - echo "ERROR $test_name execution was expected to exit on error but did not. cat $test_logfile" - return 1 - fi - if [ -s $tests_path/$test_name/expect_failure ]; then - # 'expect_failure' file has content. We expect to find this content in the log. - expected_error_message="$(cat $tests_path/$test_name/expect_failure)" - if grep -q "$expected_error_message" $test_logfile; then - return 0 - fi - echo - echo "ERROR $test_name execution was expected to exit with error message '${expected_error_message}' but did not. cat $test_logfile" - return 1 - fi - # 'expect_failure' file has no content. We generally agree that the failure is correct - return 0 + # Validate expected failure or success + if ! validate_expected_failure; then + return 1 fi - if [ $execution_result -ne 0 ]; then - echo - echo "ERROR $test_name execution failure. 
cat $test_logfile:" - cat $test_logfile - return 1 + # If this was an expected failure test, we're done (no need to validate structure/checksums) + if [ -f $tests_path/$test_name/expect_failure ]; then + return 0 fi + # Test succeeded - now validate structure and checksums gh-ost-test-mysql-replica --default-character-set=utf8mb4 test -e "show create table ${ghost_table_name}\G" -ss >$ghost_structure_output_file if [ -f $tests_path/$test_name/expect_table_structure ]; then @@ -398,14 +475,19 @@ test_all() { test_dirs=$(find "$tests_path" -mindepth 1 -maxdepth 1 ! -path . -type d | grep "$test_pattern" | sort) while read -r test_dir; do test_name=$(basename "$test_dir") + local test_start_time=$(date +%s) if ! test_single "$test_name"; then + local test_end_time=$(date +%s) + local test_duration=$((test_end_time - test_start_time)) create_statement=$(gh-ost-test-mysql-replica test -t -e "show create table ${ghost_table_name} \G") echo "$create_statement" >>$test_logfile - echo "+ FAIL" + echo "+ FAIL (${test_duration}s)" return 1 else + local test_end_time=$(date +%s) + local test_duration=$((test_end_time - test_start_time)) echo - echo "+ pass" + echo "+ pass (${test_duration}s)" fi mysql_version="$(gh-ost-test-mysql-replica -e "select @@version")" replica_terminology="slave" diff --git a/localtests/trigger-ghost-name-conflict/create.sql b/localtests/trigger-ghost-name-conflict/create.sql new file mode 100644 index 000000000..32980daa0 --- /dev/null +++ b/localtests/trigger-ghost-name-conflict/create.sql @@ -0,0 +1,35 @@ +-- Bug #3 regression test: validateGhostTriggersDontExist must check whole schema +-- MySQL trigger names are unique per SCHEMA, not per table. +-- The validation must detect a trigger with the ghost name on ANY table in the schema. 
+ +drop trigger if exists gh_ost_test_ai_ght; +drop trigger if exists gh_ost_test_ai; +drop table if exists gh_ost_test_other; +drop table if exists gh_ost_test; + +create table gh_ost_test ( + id int auto_increment, + i int not null, + primary key(id) +) auto_increment=1; + +-- This trigger has the _ght suffix (simulating a previous migration left it). +-- Ghost name = "gh_ost_test_ai" (suffix removed). +create trigger gh_ost_test_ai_ght + after insert on gh_ost_test for each row + set @dummy = 1; + +-- Create ANOTHER table with a trigger named "gh_ost_test_ai" (the ghost name). +-- Validation must detect this conflict even though the trigger is on a different table. +create table gh_ost_test_other ( + id int auto_increment, + primary key(id) +); + +create trigger gh_ost_test_ai + after insert on gh_ost_test_other for each row + set @dummy = 1; + +insert into gh_ost_test values (null, 11); +insert into gh_ost_test values (null, 13); +insert into gh_ost_test values (null, 17); diff --git a/localtests/trigger-ghost-name-conflict/destroy.sql b/localtests/trigger-ghost-name-conflict/destroy.sql new file mode 100644 index 000000000..5b5e1e4f7 --- /dev/null +++ b/localtests/trigger-ghost-name-conflict/destroy.sql @@ -0,0 +1,2 @@ +drop trigger if exists gh_ost_test_ai; +drop table if exists gh_ost_test_other; diff --git a/localtests/trigger-ghost-name-conflict/expect_failure b/localtests/trigger-ghost-name-conflict/expect_failure new file mode 100644 index 000000000..1a4c997f3 --- /dev/null +++ b/localtests/trigger-ghost-name-conflict/expect_failure @@ -0,0 +1 @@ +Found gh-ost triggers \ No newline at end of file diff --git a/localtests/trigger-ghost-name-conflict/extra_args b/localtests/trigger-ghost-name-conflict/extra_args new file mode 100644 index 000000000..128ce0a66 --- /dev/null +++ b/localtests/trigger-ghost-name-conflict/extra_args @@ -0,0 +1 @@ +--include-triggers --trigger-suffix=_ght --remove-trigger-suffix-if-exists \ No newline at end of file diff --git 
a/localtests/trigger-long-name-validation/create.sql b/localtests/trigger-long-name-validation/create.sql new file mode 100644 index 000000000..526713b74 --- /dev/null +++ b/localtests/trigger-long-name-validation/create.sql @@ -0,0 +1,38 @@ +-- Bug #1: Double-transformation in trigger length validation +-- A trigger with a 60-char name should be valid: ghost name = 60 + 4 (_ght) = 64 chars (max allowed). +-- But validateGhostTriggersLength() applies GetGhostTriggerName() twice, +-- computing 60 + 4 + 4 = 68, which falsely exceeds the 64-char limit. + +drop table if exists gh_ost_test; + +create table gh_ost_test ( + id int auto_increment, + i int not null, + primary key(id) +) auto_increment=1; + +-- Trigger name is exactly 60 characters (padded to 60). +-- Ghost name with _ght suffix = 64 chars = exactly at the MySQL limit. +-- 60 chars: trigger_long_name_padding_aaaaaaaaaaaaaaaaaaaaa_60chars_xxxx +create trigger trigger_long_name_padding_aaaaaaaaaaaaaaaaaaaaa_60chars_xxxx + after insert on gh_ost_test for each row + set @dummy = 1; + +insert into gh_ost_test values (null, 11); +insert into gh_ost_test values (null, 13); +insert into gh_ost_test values (null, 17); + +drop event if exists gh_ost_test; +delimiter ;; +create event gh_ost_test + on schedule every 1 second + starts current_timestamp + ends current_timestamp + interval 60 second + on completion not preserve + enable + do +begin + insert into gh_ost_test values (null, 23); + update gh_ost_test set i=i+1 where id=1; +end ;; +delimiter ; diff --git a/localtests/trigger-long-name-validation/extra_args b/localtests/trigger-long-name-validation/extra_args new file mode 100644 index 000000000..128ce0a66 --- /dev/null +++ b/localtests/trigger-long-name-validation/extra_args @@ -0,0 +1 @@ +--include-triggers --trigger-suffix=_ght --remove-trigger-suffix-if-exists \ No newline at end of file diff --git a/script/bootstrap b/script/bootstrap index 573313a75..96bf45675 100755 --- a/script/bootstrap +++ 
b/script/bootstrap @@ -11,7 +11,9 @@ set -e # up so it points back to us and go is none the wiser set -x -rm -rf .gopath -mkdir -p .gopath/src/github.com/github -ln -s "$PWD" .gopath/src/github.com/github/gh-ost +if [ ! -L .gopath/src/github.com/github/gh-ost ]; then + rm -rf .gopath + mkdir -p .gopath/src/github.com/github + ln -s "$PWD" .gopath/src/github.com/github/gh-ost +fi export GOPATH=$PWD/.gopath:$GOPATH diff --git a/script/ensure-go-installed b/script/ensure-go-installed index 1504a5b74..3ed35c058 100755 --- a/script/ensure-go-installed +++ b/script/ensure-go-installed @@ -1,7 +1,7 @@ #!/bin/bash PREFERRED_GO_VERSION=go1.23.0 -SUPPORTED_GO_VERSIONS='go1.2[01234]' +SUPPORTED_GO_VERSIONS='go1.2[012345]' GO_PKG_DARWIN=${PREFERRED_GO_VERSION}.darwin-amd64.pkg GO_PKG_DARWIN_SHA=e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855