fix a minor defect in fakeFile Readdir() and unit tests (#53)

pull/56/head
Jim 6 years ago committed by GitHub
parent 477a47897e
commit 4a88d10ecc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,18 +21,7 @@ type migrationDriver struct {
// Open returns the given "file"
func (m *migrationDriver) Open(name string) (http.File, error) {
var ff *fakeFile
switch m.dialect {
case "postgres":
ff = postgresMigrations[name]
}
if ff == nil {
return nil, os.ErrNotExist
}
ff.name = strings.TrimPrefix(name, "migrations/")
ff.reader = bytes.NewReader(ff.bytes)
ff.dialect = m.dialect
return ff, nil
return newFakeFile(m.dialect, name)
}
// NewMigrationSource creates a source.Driver using httpfs with the given dialect
@ -53,6 +42,20 @@ type fakeFile struct {
dialect string
}
func newFakeFile(dialect string, name string) (*fakeFile, error) {
var ff *fakeFile
switch dialect {
case "postgres":
ff = postgresMigrations[name]
}
if ff == nil {
return nil, os.ErrNotExist
}
ff.name = strings.TrimPrefix(name, "migrations/")
ff.reader = bytes.NewReader(ff.bytes)
ff.dialect = dialect
return ff, nil
}
func (f *fakeFile) Read(p []byte) (n int, err error) {
return f.reader.Read(p)
}
@ -84,7 +87,7 @@ func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
// Create the slice of fileinfo objects to return
ret := make([]os.FileInfo, 0, len(migrationsMap))
for _, v := range keys {
for i, v := range keys {
// We need "migrations" in the map for the initial Open call but we
// should not return it as part of the "directory"'s "files".
if v == "migrations" {
@ -95,6 +98,9 @@ func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
return nil, err
}
ret = append(ret, stat)
if count > 0 && count == i {
break
}
}
return ret, nil
}

@ -0,0 +1,177 @@
package migrations
import (
"os"
"reflect"
"testing"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/httpfs"
"github.com/stretchr/testify/assert"
)
func TestNewMigrationSource(t *testing.T) {
type args struct {
dialect string
}
tests := []struct {
name string
args args
want source.Driver
wantErr bool
}{
{
name: "postgres",
args: args{dialect: "postgres"},
want: func() source.Driver {
d, err := httpfs.New(&migrationDriver{"postgres"}, "migrations")
if err != nil {
t.Errorf("NewMigrationSource() error creating httpfs = %w", err)
}
return d
}(),
wantErr: false,
},
{
name: "no-dialect",
args: args{dialect: ""},
want: nil,
wantErr: true,
},
{
name: "bad-dialect",
args: args{dialect: "rainbows-and-unicorns-db"},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewMigrationSource(tt.args.dialect)
if (err != nil) != tt.wantErr {
t.Errorf("NewMigrationSource() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewMigrationSource() = %v, want %v", got, tt.want)
}
})
}
}
func Test_migrationDriver_Open(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
dialect string
args args
wantErr bool
}{
{
name: "valid-file",
dialect: "postgres",
args: args{name: "migrations/01_domain_types.up.sql"},
wantErr: false,
},
{
name: "bad-file",
dialect: "postgres",
args: args{name: "migrations/unicorns-and-rainbows.up.sql"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &migrationDriver{
dialect: tt.dialect,
}
_, err := m.Open(tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("migrationDriver.Open() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_fakeFile_Read(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
ff, err := newFakeFile("postgres", "migrations/01_domain_types.up.sql")
assert.Nil(err)
buf := make([]byte, len(ff.bytes))
n, err := ff.Read(buf)
assert.Nil(err)
assert.Equal(len(buf), n)
})
}
func Test_fakeFile_Seek(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
ff, err := newFakeFile("postgres", "migrations/01_domain_types.up.sql")
assert.Nil(err)
buf := make([]byte, len(ff.bytes))
n, err := ff.Seek(10, 0)
assert.Nil(err)
assert.Equal(int64(10), n)
n2, err := ff.Read(buf)
assert.Nil(err)
assert.Equal(len(ff.bytes)-10, n2)
})
}
func Test_fakeFile_Close(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
m := &migrationDriver{
dialect: "postgres",
}
f, err := m.Open("migrations/01_domain_types.up.sql")
assert.Nil(err)
err = f.Close()
assert.Nil(err)
})
}
func Test_fakeFile_Stat(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
name := "migrations/01_domain_types.up.sql"
ff, err := newFakeFile("postgres", name)
assert.Nil(err)
info, err := ff.Stat()
assert.Nil(err)
assert.Equal(ff.name, info.Name())
assert.Equal(int64(len(ff.bytes)), info.Size())
assert.Equal(os.ModePerm, info.Mode())
assert.Equal(false, info.IsDir())
assert.Equal(nil, info.Sys())
})
}
func Test_fakeFile_Readdir(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
name := "migrations/01_domain_types.up.sql"
ff, err := newFakeFile("postgres", name)
assert.Nil(err)
info, err := ff.Readdir(0)
assert.Nil(err)
assert.True(info != nil)
info, err = ff.Readdir(1)
assert.Nil(err)
assert.True(info != nil)
assert.Equal(1, len(info))
info, err = ff.Readdir(0)
assert.Nil(err)
assert.True(info != nil)
// we don't want to count "migrations", so we're len - 1
assert.Equal(len(postgresMigrations)-1, len(info))
})
}
Loading…
Cancel
Save