diff --git a/go.mod b/go.mod index 0ac8763466..2864f17342 100644 --- a/go.mod +++ b/go.mod @@ -100,7 +100,6 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 github.com/miekg/dns v1.1.58 github.com/mikesmitty/edkey v0.0.0-20170222072505-3356ea4e686a - github.com/mitchellh/go-homedir v1.1.0 github.com/sevlyar/go-daemon v0.1.6 golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 golang.org/x/net v0.25.0 @@ -127,6 +126,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/sys/user v0.1.0 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect diff --git a/internal/clientcache/cmd/cache/start.go b/internal/clientcache/cmd/cache/start.go index 93a914c5f7..c94d507e09 100644 --- a/internal/clientcache/cmd/cache/start.go +++ b/internal/clientcache/cmd/cache/start.go @@ -19,7 +19,6 @@ import ( "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/errors" "github.com/mitchellh/cli" - "github.com/mitchellh/go-homedir" "github.com/posener/complete" "gopkg.in/natefinch/lumberjack.v2" ) @@ -243,7 +242,7 @@ func (c *StartCommand) Run(args []string) int { // DefaultDotDirectory returns the default path to the boundary dot directory. func DefaultDotDirectory(ctx context.Context) (string, error) { const op = "cache.DefaultDotDirectory" - homeDir, err := homedir.Dir() + homeDir, err := os.UserHomeDir() if err != nil { return "", errors.Wrap(ctx, err, op) } diff --git a/internal/clientcache/internal/daemon/options.go b/internal/clientcache/internal/daemon/options.go index ba20725409..daed203df3 100644 --- a/internal/clientcache/internal/daemon/options.go +++ b/internal/clientcache/internal/daemon/options.go @@ -20,8 +20,9 @@ type options struct { WithReadyToServeNotificationCh chan struct{} withBoundaryTokenReaderFunc cache.BoundaryTokenReaderFn - withUrl string - withLogger hclog.Logger + withUrl string + withLogger hclog.Logger + withHomeDir string } // Option - how options are passed as args @@ -42,6 +43,14 @@ func getOpts(opt ...Option) (options, error) { return opts, nil } +// WithHomeDir provides an optional home directory to use. +func WithHomeDir(_ context.Context, dir string) Option { + return func(o *options) error { + o.withHomeDir = dir + return nil + } +} + // withRefreshInterval provides an optional refresh interval. func withRefreshInterval(_ context.Context, d time.Duration) Option { return func(o *options) error { diff --git a/internal/clientcache/internal/daemon/options_test.go b/internal/clientcache/internal/daemon/options_test.go index 1025aa0381..0afcbab1b5 100644 --- a/internal/clientcache/internal/daemon/options_test.go +++ b/internal/clientcache/internal/daemon/options_test.go @@ -101,6 +101,13 @@ func Test_GetOpts(t *testing.T) { testOpts := getDefaultOptions() assert.Equal(t, opts, testOpts) }) + t.Run("WithHomeDir", func(t *testing.T) { + opts, err := getOpts(WithHomeDir(ctx, "/tmp")) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withHomeDir = "/tmp" + assert.Equal(t, opts, testOpts) + }) t.Run("WithReadyToServeNotificationCh", func(t *testing.T) { ch := make(chan struct{}) opts, err := getOpts(WithReadyToServeNotificationCh(ctx, ch)) @@ -109,6 +116,5 @@ func Test_GetOpts(t *testing.T) { testOpts := getDefaultOptions() assert.Nil(t, testOpts.WithReadyToServeNotificationCh) testOpts.WithReadyToServeNotificationCh = ch - assert.Equal(t, opts, testOpts) }) } diff --git a/internal/clientcache/internal/daemon/server.go b/internal/clientcache/internal/daemon/server.go index 79a80e67a8..cd0f13e77c 100644 --- a/internal/clientcache/internal/daemon/server.go +++ b/internal/clientcache/internal/daemon/server.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "os" + "path/filepath" "sort" "strings" "sync" @@ -499,6 +500,10 @@ func setupEventing(ctx context.Context, logger hclog.Logger, serializationLock * return nil } +// openStore will open the underlying store for the db. If no options are +// provided, it will default to an on disk store using the user's home dir + +// ".boundary/cache.db". If a url is provided, it will use that as the store. +// Supported options: WithUrl, WithLogger, WithHomeDir func openStore(ctx context.Context, opt ...Option) (*db.DB, error) { const op = "daemon.openStore" opts, err := getOpts(opt...) @@ -514,6 +519,12 @@ func openStore(ctx context.Context, opt ...Option) (*db.DB, error) { return nil, errors.Wrap(ctx, err, op) } dbOpts = append(dbOpts, cachedb.WithUrl(url)) + default: + url, err := defaultDbUrl(ctx, opt...) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + dbOpts = append(dbOpts, cachedb.WithUrl(url)) } if !util.IsNil(opts.withLogger) { dbOpts = append(dbOpts, cachedb.WithGormFormatter(opts.withLogger)) @@ -524,3 +535,31 @@ func openStore(ctx context.Context, opt ...Option) (*db.DB, error) { } return store, nil } + +// defaultDbUrl returns the default db name including the path. It will ensure +// the directory exists by creating it if it doesn't. +func defaultDbUrl(ctx context.Context, opt ...Option) (string, error) { + const op = "daemon.DefaultDotDirectory" + opts, err := getOpts(opt...) + if err != nil { + return "", errors.Wrap(ctx, err, op) + } + if opts.withHomeDir == "" { + opts.withHomeDir, err = os.UserHomeDir() + if err != nil { + return "", errors.Wrap(ctx, err, op) + } + } + dotDir := filepath.Join(opts.withHomeDir, dotDirname) + if err := os.MkdirAll(dotDir, 0o700); err != nil { + return "", errors.Wrap(ctx, err, op) + } + fileName := filepath.Join(dotDir, dbFileName) + return fmt.Sprintf("%s%s", fileName, fkPragma), nil +} + +const ( + dotDirname = ".boundary" + dbFileName = "cache.db" + fkPragma = "?_pragma=foreign_keys(1)" +) diff --git a/internal/clientcache/internal/daemon/server_test.go b/internal/clientcache/internal/daemon/server_test.go index 1afafedcaf..2f7e947c90 100644 --- a/internal/clientcache/internal/daemon/server_test.go +++ b/internal/clientcache/internal/daemon/server_test.go @@ -15,6 +15,24 @@ import ( "github.com/stretchr/testify/require" ) +func Test_openStore(t *testing.T) { + ctx := context.Background() + t.Run("success", func(t *testing.T) { + tmpDir := t.TempDir() + db, err := openStore(ctx, WithUrl(ctx, tmpDir+"/test.db"+fkPragma)) + require.NoError(t, err) + require.NotNil(t, db) + assert.FileExists(t, tmpDir+"/test.db") + }) + t.Run("homedir", func(t *testing.T) { + tmpDir := t.TempDir() + db, err := openStore(ctx, WithHomeDir(ctx, tmpDir)) + require.NoError(t, err) + require.NotNil(t, db) + assert.FileExists(t, tmpDir+"/"+dotDirname+"/"+dbFileName) + }) +} + // Note: the name of this test must remain short because the temp dir created // includes the name of the test and there is a 108 character limit in allowed // unix socket path names. diff --git a/internal/clientcache/internal/db/db.go b/internal/clientcache/internal/db/db.go index 31907a3b0f..974faa7fbf 100644 --- a/internal/clientcache/internal/db/db.go +++ b/internal/clientcache/internal/db/db.go @@ -7,6 +7,8 @@ import ( "context" _ "embed" "fmt" + "strings" + "time" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" @@ -17,11 +19,16 @@ import ( //go:embed schema.sql var cacheSchema string +//go:embed schema_reset.sql +var cacheSchemaReset string + // DefaultStoreUrl uses a temp in-memory sqlite database see: https://www.sqlite.org/inmemorydb.html const DefaultStoreUrl = "file::memory:?_pragma=foreign_keys(1)" // Open creates a database connection. WithUrl is supported, but by default it // uses an in memory sqlite table. Sqlite is the only supported dbtype. +// Supported options: WithUrl, WithGormFormatter, WithDebug, +// WithTestValidSchemaVersion (for testing purposes) func Open(ctx context.Context, opt ...Option) (*db.DB, error) { const op = "db.Open" opts, err := getOpts(opt...) @@ -50,16 +57,38 @@ func Open(ctx context.Context, opt ...Option) (*db.DB, error) { conn.Debug(opts.withDebug) switch { - case opts.withDbType == dbw.Sqlite: + case opts.withDbType == dbw.Sqlite && url == DefaultStoreUrl: if err := createTables(ctx, conn); err != nil { return nil, errors.Wrap(ctx, err, op) } + case opts.withDbType == dbw.Sqlite && url != DefaultStoreUrl: + ok, err := validSchema(ctx, conn, opt...) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + if !ok { + if err := resetSchema(ctx, conn); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + if err := createTables(ctx, conn); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + } default: return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%q is not a supported cache store type", opts.withDbType)) } return conn, nil } +func resetSchema(ctx context.Context, conn *db.DB) error { + const op = "db.resetSchema" + rw := db.New(conn) + if _, err := rw.Exec(ctx, cacheSchemaReset, nil); err != nil { + return errors.Wrap(ctx, err, op) + } + return nil +} + func createTables(ctx context.Context, conn *db.DB) error { const op = "db.createTables" rw := db.New(conn) @@ -68,3 +97,56 @@ func createTables(ctx context.Context, conn *db.DB) error { } return nil } + +// validSchema checks of the schema is valid based on its version. Options +// supported: withTestValidSchemaVersion (for testing purposes) +func validSchema(ctx context.Context, conn *db.DB, opt ...Option) (bool, error) { + const op = "validateSchema" + switch { + case conn == nil: + return false, errors.New(ctx, errors.InvalidParameter, op, "conn is missing") + } + opts, err := getOpts(opt...) + if err != nil { + return false, errors.Wrap(ctx, err, op) + } + if opts.withSchemaVersion == "" { + opts.withSchemaVersion = schemaCurrentVersion + } + + rw := db.New(conn) + s := schema{} + err = rw.LookupWhere(ctx, &s, "1=1", nil) + switch { + case err != nil && strings.Contains(err.Error(), "no such table: schema_version"): + return false, nil + case err != nil: + // not sure if we should return the error or just return false so the + // schema is recreated... for now return the error. + return false, fmt.Errorf("%s: unable to get version: %w", op, err) + case s.Version != opts.withSchemaVersion: + return false, nil + default: + return true, nil + } +} + +// schema represents the current schema in the database +type schema struct { + // Version of the schema + Version string + // UpdateTime is the last update of the version + UpdateTime time.Time + // CreateTime is the create time of the initial version + CreateTime time.Time +} + +const ( + schemaTableName = "schema_version" + schemaCurrentVersion = "v0.0.1" +) + +// TableName returns the table name +func (s *schema) TableName() string { + return schemaTableName +} diff --git a/internal/clientcache/internal/db/db_test.go b/internal/clientcache/internal/db/db_test.go new file mode 100644 index 0000000000..062f55caa1 --- /dev/null +++ b/internal/clientcache/internal/db/db_test.go @@ -0,0 +1,66 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package db + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpen(t *testing.T) { + ctx := context.Background() + t.Run("success-file-url-with-reopening", func(t *testing.T) { + tmpDir := t.TempDir() + db, err := Open(ctx, WithUrl(tmpDir+"/test.db"+fkPragma)) + require.NoError(t, err) + require.NotNil(t, db) + assert.FileExists(t, tmpDir+"/test.db") + + info, err := os.Stat(tmpDir + "/test.db") + require.NoError(t, err) + origCreatedAt := info.ModTime() + + // Reopen the db and make sure the file is not recreated + db, err = Open(ctx, WithUrl(tmpDir+"/test.db"+fkPragma)) + require.NoError(t, err) + require.NotNil(t, db) + info, err = os.Stat(tmpDir + "/test.db") + require.NoError(t, err) + assert.Equal(t, origCreatedAt, info.ModTime()) + }) + t.Run("success-mem-default-url", func(t *testing.T) { + db, err := Open(ctx) + require.NoError(t, err) + require.NotNil(t, db) + }) + t.Run("recreate-on-version-mismatch", func(t *testing.T) { + tmpDir := t.TempDir() + db, err := Open(ctx, WithUrl(tmpDir+"/test.db"+fkPragma)) + require.NoError(t, err) + require.NotNil(t, db) + assert.FileExists(t, tmpDir+"/test.db") + info, err := os.Stat(tmpDir + "/test.db") + require.NoError(t, err) + origCreatedAt := info.ModTime() + + // Reopen the db with a different schema version: forcing the db to be recreated + db, err = Open(ctx, WithUrl(tmpDir+"/test.db"+fkPragma), withTestValidSchemaVersion("2")) + require.NoError(t, err) + require.NotNil(t, db) + info, err = os.Stat(tmpDir + "/test.db") + require.NoError(t, err) + // The file should have been recreated with a new timestamp + assert.NotEqual(t, origCreatedAt, info.ModTime()) + }) +} + +const ( + dotDirname = ".boundary" + dbFileName = "cache.db" + fkPragma = "?_pragma=foreign_keys(1)" +) diff --git a/internal/clientcache/internal/db/options.go b/internal/clientcache/internal/db/options.go index ba8c3382bd..30186c90b1 100644 --- a/internal/clientcache/internal/db/options.go +++ b/internal/clientcache/internal/db/options.go @@ -9,6 +9,7 @@ import ( ) type options struct { + withSchemaVersion string withDebug bool withUrl string withDbType dbw.DbType @@ -42,6 +43,15 @@ func WithGormFormatter(logger hclog.Logger) Option { } } +// withTestValidSchemaVersion provides optional valid schema version for testing +// purposes. This is used to simulate a schema version that is valid/invalid. +func withTestValidSchemaVersion(useVersion string) Option { + return func(o *options) error { + o.withSchemaVersion = useVersion + return nil + } +} + // WithUrls provides optional url func WithUrl(url string) Option { return func(o *options) error { diff --git a/internal/clientcache/internal/db/options_test.go b/internal/clientcache/internal/db/options_test.go index f57591ef60..b67096b081 100644 --- a/internal/clientcache/internal/db/options_test.go +++ b/internal/clientcache/internal/db/options_test.go @@ -37,4 +37,12 @@ func Test_GetOpts(t *testing.T) { testOpts.withDebug = true assert.Equal(t, opts, testOpts) }) + t.Run("withTestValidSchemaVersion", func(t *testing.T) { + version := "v1" + opts, err := getOpts(withTestValidSchemaVersion(version)) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withSchemaVersion = version + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/clientcache/internal/db/schema.sql b/internal/clientcache/internal/db/schema.sql index bf5023b196..ddd635b2d1 100644 --- a/internal/clientcache/internal/db/schema.sql +++ b/internal/clientcache/internal/db/schema.sql @@ -2,6 +2,40 @@ -- SPDX-License-Identifier: BUSL-1.1 begin; + +-- schema_version is a one row table to keep the version +create table if not exists schema_version ( + version text not null, + create_time timestamp not null default current_timestamp, + update_time timestamp not null default current_timestamp +); + +-- ensure that it's only ever one row +create unique index schema_version_one_row +ON schema_version((version is not null)); + +create trigger immutable_columns_schema_version +before update on schema_version +for each row + when + new.create_time <> old.create_time + begin + select raise(abort, 'immutable column'); + end; + + +create trigger update_time_column_schema_version +before update on schema_version +for each row +when + new.version <> old.version + begin + update schema_version set update_time = datetime('now','localtime') where rowid == new.rowid; + end; + + +insert into schema_version(version) values('v0.0.1'); + -- user contains the boundary user information for the boundary user that owns -- the information in the cache. create table if not exists user ( diff --git a/internal/clientcache/internal/db/schema_reset.sql b/internal/clientcache/internal/db/schema_reset.sql new file mode 100644 index 0000000000..8c58d56d4b --- /dev/null +++ b/internal/clientcache/internal/db/schema_reset.sql @@ -0,0 +1,11 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +-- cannot vacuum from within a transaction, so we're not using a transaction +-- when running these statements +PRAGMA writable_schema = 1; +DELETE FROM sqlite_master; +PRAGMA writable_schema = 0; +VACUUM; +PRAGMA integrity_check; +