diff --git a/internal/daemon/controller/tickers.go b/internal/daemon/controller/tickers.go index d15b8c2bd1..ba86bedcc9 100644 --- a/internal/daemon/controller/tickers.go +++ b/internal/daemon/controller/tickers.go @@ -36,7 +36,7 @@ func (c *Controller) startStatusTicking(cancelCtx context.Context) { return case <-timer.C: - if err := c.upsertController(cancelCtx); err != nil { + if err := c.updateController(cancelCtx); err != nil { event.WriteError(cancelCtx, op, err, event.WithInfoMsg("error fetching repository for status update")) } timer.Reset(statusInterval) @@ -57,17 +57,39 @@ func (c *Controller) upsertController(ctx context.Context) error { controller := server.NewController(c.conf.RawConfig.Controller.Name, opts...) repo, err := c.ServersRepoFn() if err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("error fetching repository for status update")) + return errors.Wrap(ctx, err, op, errors.WithMsg("error fetching repository for status upsert")) } _, err = repo.UpsertController(ctx, controller) if err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("error performing status update")) + return errors.Wrap(ctx, err, op, errors.WithMsg("error performing status upsert")) } return nil } +func (c *Controller) updateController(ctx context.Context) error { + const op = "controller.(Controller).updateController" + var opts []server.Option + if c.conf.RawConfig.Controller.Description != "" { + opts = append(opts, server.WithDescription(c.conf.RawConfig.Controller.Description)) + } + if c.conf.RawConfig.Controller.PublicClusterAddr != "" { + opts = append(opts, server.WithAddress(c.conf.RawConfig.Controller.PublicClusterAddr)) + } + controller := server.NewController(c.conf.RawConfig.Controller.Name, opts...) + repo, err := c.ServersRepoFn() + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("error fetching repository for status update")) + } + + _, err = repo.UpdateControllerStatus(ctx, controller) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("error performing status update")) + } + return nil +} + func (c *Controller) startNonceCleanupTicking(cancelCtx context.Context) { const op = "controller.(Controller).startNonceCleanupTicking" timer := time.NewTimer(0) @@ -76,7 +98,6 @@ func (c *Controller) startNonceCleanupTicking(cancelCtx context.Context) { case <-cancelCtx.Done(): event.WriteSysEvent(cancelCtx, op, "recovery nonce ticking shutting down") return - case <-timer.C: repo, err := c.ServersRepoFn() if err != nil { @@ -113,7 +134,6 @@ func (c *Controller) startTerminateCompletedSessionsTicking(cancelCtx context.Co case <-cancelCtx.Done(): event.WriteSysEvent(cancelCtx, op, "terminating completed sessions ticking shutting down") return - case <-timer.C: repo, err := c.SessionRepoFn() if err != nil { @@ -148,7 +168,6 @@ func (c *Controller) startCloseExpiredPendingTokens(cancelCtx context.Context) { case <-cancelCtx.Done(): event.WriteSysEvent(cancelCtx, op, "closing expired pending tokens ticking shutting down") return - case <-timer.C: repo, err := c.AuthTokenRepoFn() if err != nil { @@ -192,7 +211,6 @@ func (c *Controller) startWorkerConnectionMaintenanceTicking(cancelCtx context.C case <-cancelCtx.Done(): event.WriteSysEvent(cancelCtx, op, "context done, shutting down") return - case <-timer.C: connectionState := m.Connected() if len(connectionState.WorkerIds()) > 0 { @@ -208,7 +226,6 @@ func (c *Controller) startWorkerConnectionMaintenanceTicking(cancelCtx context.C } connectionState.DisconnectMissingWorkers(knownWorkers) } - if len(connectionState.UnmappedKeyIds()) > 0 { repo, err := c.WorkerAuthRepoStorageFn() if err != nil { @@ -223,7 +240,6 @@ func (c *Controller) startWorkerConnectionMaintenanceTicking(cancelCtx context.C connectionState.DisconnectMissingUnmappedKeyIds(authorized) } } - timer.Reset(getRandomInterval()) } }() diff --git a/internal/server/query.go b/internal/server/query.go index 732a1e7b41..7364a063f2 100644 --- a/internal/server/query.go +++ b/internal/server/query.go @@ -223,4 +223,12 @@ const ( where last_status_time > now() - interval '%d seconds' and operational_state = 'active'; ` + + updateController = ` + update server_controller + set address = @controller_address, + description = @controller_description, + update_time = now() + where private_id = @controller_private_id; +` ) diff --git a/internal/server/repository_controller.go b/internal/server/repository_controller.go index 238352c289..9912adf217 100644 --- a/internal/server/repository_controller.go +++ b/internal/server/repository_controller.go @@ -1,10 +1,10 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 - package server import ( "context" + "database/sql" "fmt" "github.com/hashicorp/boundary/internal/db" @@ -24,12 +24,10 @@ func (r *Repository) listControllersWithReader(ctx context.Context, reader db.Re if liveness == 0 { liveness = DefaultLiveness } - var where string if liveness > 0 { where = fmt.Sprintf("update_time > now() - interval '%d seconds'", uint32(liveness.Seconds())) } - var controllers []*Controller if err := reader.SearchWhere( ctx, @@ -40,17 +38,14 @@ func (r *Repository) listControllersWithReader(ctx context.Context, reader db.Re ); err != nil { return nil, errors.Wrap(ctx, err, "workers.listControllersWithReader") } - return controllers, nil } func (r *Repository) UpsertController(ctx context.Context, controller *Controller) (int, error) { const op = "server.(Repository).UpsertController" - if controller == nil { return db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "controller is nil") } - var rowsUpdated int64 _, err := r.writer.DoTx( ctx, @@ -66,7 +61,6 @@ func (r *Repository) UpsertController(ctx context.Context, controller *Controlle if err != nil { return errors.Wrap(ctx, err, op+":Upsert") } - return nil }, ) @@ -76,3 +70,52 @@ func (r *Repository) UpsertController(ctx context.Context, controller *Controlle return int(rowsUpdated), nil } + +// UpdateControllerStatus updates the controller's status in the repository. +// This includes updating the address or description of the controller as well +// as updating the update_time attribute, which is required for liveness checks +// as part of a controller's status ticking. +func (r *Repository) UpdateControllerStatus(ctx context.Context, controller *Controller) (int, error) { + const op = "server.(Repository).UpdateControllerStatus" + + if controller == nil { + return db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "controller is nil") + } + if controller.PrivateId == "" { + return db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "controller private_id is empty") + } + if controller.Address == "" { + return db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "controller address is empty") + } + + var rowsUpdated int + _, err := r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + var err error + rowsUpdated, err = w.Exec(ctx, updateController, + []any{ + sql.Named("controller_address", controller.Address), + sql.Named("controller_description", controller.Description), + sql.Named("controller_private_id", controller.PrivateId), + }) + switch { + case err != nil: + return errors.Wrap(ctx, err, op+":Update") + case rowsUpdated > 1: + return errors.New(ctx, errors.MultipleRecords, op, "more than 1 resource would have been updated") + case rowsUpdated == 0: + return errors.New(ctx, errors.RecordNotFound, op, "no resources would have been updated") + default: + return nil + } + }, + ) + if err != nil { + return db.NoRowsAffected, err + } + + return rowsUpdated, nil +} diff --git a/internal/server/repository_controller_test.go b/internal/server/repository_controller_test.go index cae413292e..d239d538b4 100644 --- a/internal/server/repository_controller_test.go +++ b/internal/server/repository_controller_test.go @@ -1,6 +1,5 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 - package server import ( @@ -14,6 +13,10 @@ import ( "github.com/stretchr/testify/require" ) +const ( + removeControllerSql = `delete from server_controller where private_id = $1` +) + func TestRepository_UpsertController(t *testing.T) { ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") @@ -22,10 +25,8 @@ func TestRepository_UpsertController(t *testing.T) { testKms := kms.TestKms(t, conn, wrapper) testRepo, err := NewRepository(ctx, rw, rw, testKms) require.NoError(t, err) - iamRepo := iam.TestRepo(t, conn, wrapper) iam.TestScopes(t, iamRepo) - tests := []struct { name string controller *Controller @@ -81,3 +82,147 @@ func TestRepository_UpsertController(t *testing.T) { }) } } + +func TestRepository_UpdateControllerStatus(t *testing.T) { + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + testKms := kms.TestKms(t, conn, wrapper) + testRepo, err := NewRepository(ctx, rw, rw, testKms) + require.NoError(t, err) + + iamRepo := iam.TestRepo(t, conn, wrapper) + iam.TestScopes(t, iamRepo) + + tests := map[string]struct { + originalController *Controller + updatedController *Controller + wantCount int + wantErr bool + cleanUpFunc func(t *testing.T, rw *db.Db, privateId string) + }{ + "nil-controller": { + wantErr: true, + }, + "empty-id": { + updatedController: NewController("", WithAddress("127.0.0.1")), + wantErr: true, + }, + "empty-address": { + updatedController: NewController("test-controller"), + wantErr: true, + }, + "controller-not-found": { + updatedController: NewController("test-controller", WithAddress("127.0.0.1"), WithDescription("new ipv4 description")), + wantErr: true, + }, + "valid-ipv4-controller": { + originalController: NewController("ipv4-controller", WithAddress("127.0.0.1"), WithDescription("ipv4 description")), + updatedController: NewController("ipv4-controller", WithAddress("127.0.0.2"), WithDescription("new ipv4 description")), + wantCount: 1, + cleanUpFunc: func(t *testing.T, rw *db.Db, privateId string) { + t.Helper() + c, err := rw.Exec(ctx, removeControllerSql, []any{privateId}) + require.NoError(t, err) + require.Equal(t, 1, c) + }, + }, + "valid-ipv6-controller": { + originalController: NewController("test-ipv6-controller", WithAddress("[2001:4860:4860:0:0:0:0:8888]"), WithDescription("ipv6 description")), + updatedController: NewController("test-ipv6-controller", WithAddress("[2001:4860:4860:0:0:0:0:9999]"), WithDescription("new ipv6 description")), + wantCount: 1, + cleanUpFunc: func(t *testing.T, rw *db.Db, privateId string) { + t.Helper() + c, err := rw.Exec(ctx, removeControllerSql, []any{privateId}) + require.NoError(t, err) + require.Equal(t, 1, c) + }, + }, + "valid-abbreviated-ipv6-controller": { + originalController: NewController("test-abbreviated-ipv6-controller", WithAddress("[2001:4860:4860::8888]"), WithDescription("abbreviated ipv6 description")), + updatedController: NewController("test-abbreviated-ipv6-controller", WithAddress("[2001:4860:4860::9999]"), WithDescription("new abbreviated ipv6 description")), + wantCount: 1, + cleanUpFunc: func(t *testing.T, rw *db.Db, privateId string) { + t.Helper() + c, err := rw.Exec(ctx, removeControllerSql, []any{privateId}) + require.NoError(t, err) + require.Equal(t, 1, c) + }, + }, + "valid-controller-short-name": { + originalController: NewController("test", WithAddress("127.0.0.1"), WithDescription("short name description")), + updatedController: NewController("test", WithAddress("127.0.0.2"), WithDescription("new short name description")), + wantCount: 1, + cleanUpFunc: func(t *testing.T, rw *db.Db, privateId string) { + t.Helper() + c, err := rw.Exec(ctx, removeControllerSql, []any{privateId}) + require.NoError(t, err) + require.Equal(t, 1, c) + }, + }, + // Test case for updating a controller with the same attributes validating update_time is updated + "duplicate-ipv4-controller-update": { + originalController: NewController("ipv4-controller", WithAddress("127.0.0.1"), WithDescription("new ipv4 description")), + updatedController: NewController("ipv4-controller", WithAddress("127.0.0.1"), WithDescription("new ipv4 description")), + wantCount: 1, + cleanUpFunc: func(t *testing.T, rw *db.Db, privateId string) { + t.Helper() + c, err := rw.Exec(ctx, removeControllerSql, []any{privateId}) + require.NoError(t, err) + require.Equal(t, 1, c) + }, + }, + "duplicate-ipv6-controller-update": { + originalController: NewController("test-ipv6-controller", WithAddress("[2001:4860:4860:0:0:0:0:8888]"), WithDescription("ipv6 description")), + updatedController: NewController("test-ipv6-controller", WithAddress("[2001:4860:4860:0:0:0:0:8888]"), WithDescription("ipv6 description")), + wantCount: 1, + cleanUpFunc: func(t *testing.T, rw *db.Db, privateId string) { + t.Helper() + c, err := rw.Exec(ctx, removeControllerSql, []any{privateId}) + require.NoError(t, err) + require.Equal(t, 1, c) + }, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + var originalControllerEntry *Controller + if tt.originalController != nil { + _, err := testRepo.UpsertController(ctx, tt.originalController) + require.NoError(err) + + controllerList, err := testRepo.ListControllers(ctx, []Option{}...) + require.NoError(err) + originalControllerEntry = controllerList[0] + } + + got, err := testRepo.UpdateControllerStatus(ctx, tt.updatedController) + if tt.wantErr { + require.Error(err) + assert.Equal(0, got) + if tt.cleanUpFunc != nil { + tt.cleanUpFunc(t, rw, tt.updatedController.PrivateId) + } + return + } + require.NoError(err) + assert.Equal(tt.wantCount, got) + + controllerList, err := testRepo.ListControllers(ctx, []Option{}...) + require.NoError(err) + require.Len(controllerList, 1) + + updatedControllerEntry := controllerList[0] + + assert.Equal(tt.updatedController.PrivateId, updatedControllerEntry.PrivateId) + assert.Equal(tt.updatedController.Address, updatedControllerEntry.Address) + assert.Equal(tt.updatedController.Description, updatedControllerEntry.Description) + assert.True(updatedControllerEntry.UpdateTime.AsTime().After(originalControllerEntry.UpdateTime.AsTime())) + tt.cleanUpFunc(t, rw, tt.updatedController.PrivateId) + }) + } +}