diff --git a/internal/db/schema/internal/edition/edition.go b/internal/db/schema/internal/edition/edition.go index 6df4f5ce19..cd61a762bb 100644 --- a/internal/db/schema/internal/edition/edition.go +++ b/internal/db/schema/internal/edition/edition.go @@ -78,6 +78,9 @@ func (e Editions) Sort() { // 01_add_new_table.up.sql // 02_refactor_views.up.sql func New(name string, dialect Dialect, m embed.FS, priority int, opt ...Option) (Edition, error) { + const beginStmt = "begin;" + const commitStmt = "commit;" + var largestSchemaVersion int migrations := make(migration.Migrations) @@ -121,11 +124,11 @@ func New(name string, dialect Dialect, m embed.FS, priority int, opt ...Option) } contents := strings.TrimSpace(string(cbts)) - if strings.ToLower(contents[:len("begin;")]) == "begin;" { - contents = contents[len("begin;"):] + if beginIdx := strings.Index(contents, beginStmt); beginIdx != -1 { + contents = contents[beginIdx+len(beginStmt):] } - if strings.ToLower(contents[len(contents)-len("commit;"):]) == "commit;" { - contents = contents[:len(contents)-len("commit;")] + if strings.ToLower(contents[len(contents)-len(commitStmt):]) == commitStmt { + contents = contents[:len(contents)-len(commitStmt)] } contents = strings.TrimSpace(contents) diff --git a/internal/db/schema/internal/edition/edition_test.go b/internal/db/schema/internal/edition/edition_test.go index bbf98dc318..4073c1977b 100644 --- a/internal/db/schema/internal/edition/edition_test.go +++ b/internal/db/schema/internal/edition/edition_test.go @@ -62,6 +62,10 @@ func TestNew(t *testing.T) { assert.Equal(t, e.LatestVersion, tt.expectedVersion, "Version") assert.Equal(t, e.Priority, tt.priority, "Priority") assert.Equal(t, len(e.Migrations), tt.expectedMigrationCount, "Number of migrations") + for _, m := range e.Migrations { + assert.NotContains(t, string(m.Statements), "begin;") + assert.NotContains(t, string(m.Statements), "commit;") + } }) } }