diff --git a/internal/cmd/base/dev.go b/internal/cmd/base/dev.go new file mode 100644 index 0000000000..5195c7045b --- /dev/null +++ b/internal/cmd/base/dev.go @@ -0,0 +1,437 @@ +package base + +import ( + "context" + "crypto/ed25519" + "fmt" + "net" + "net/url" + "strings" + + "github.com/hashicorp/boundary/internal/auth/oidc" + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/schema" + "github.com/hashicorp/boundary/internal/docker" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/types/scope" + capoidc "github.com/hashicorp/cap/oidc" + "github.com/hashicorp/go-multierror" +) + +func (b *Server) CreateDevDatabase(ctx context.Context, opt ...Option) error { + var container, url, dialect string + var err error + var c func() error + + opts := getOpts(opt...) + + // We should only get back postgres for now, but laying the foundation for non-postgres + switch opts.withDialect { + case "": + b.Logger.Error("unsupported dialect. wanted: postgres, got: %v", opts.withDialect) + default: + dialect = opts.withDialect + } + + switch b.DatabaseUrl { + case "": + c, url, container, err = docker.StartDbInDocker(dialect, docker.WithContainerImage(opts.withContainerImage)) + // In case of an error, run the cleanup function. If we pass all errors, c should be set to a noop + // function before returning from this method + defer func() { + if !opts.withSkipDatabaseDestruction { + if c != nil { + if err := c(); err != nil { + b.Logger.Error("error cleaning up docker container", "error", err) + } + } + } + }() + if err == docker.ErrDockerUnsupported { + return err + } + if err != nil { + return fmt.Errorf("unable to start dev database with dialect %s: %w", dialect, err) + } + + // Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways. + _, err := schema.MigrateStore(ctx, dialect, url) + if err != nil { + err = fmt.Errorf("unable to initialize dev database with dialect %s: %w", dialect, err) + if c != nil { + err = multierror.Append(err, c()) + } + return err + } + + b.DevDatabaseCleanupFunc = c + b.DatabaseUrl = url + default: + // Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways. + if _, err := schema.MigrateStore(ctx, dialect, b.DatabaseUrl); err != nil { + err = fmt.Errorf("error initializing store: %w", err) + if c != nil { + err = multierror.Append(err, c()) + } + return err + } + } + + b.InfoKeys = append(b.InfoKeys, "dev database url") + b.Info["dev database url"] = b.DatabaseUrl + if container != "" { + b.InfoKeys = append(b.InfoKeys, "dev database container") + b.Info["dev database container"] = strings.TrimPrefix(container, "/") + } + + if err := b.ConnectToDatabase(dialect); err != nil { + if c != nil { + err = multierror.Append(err, c()) + } + return err + } + + b.Database.LogMode(true) + + if err := b.CreateGlobalKmsKeys(ctx); err != nil { + if c != nil { + err = multierror.Append(err, c()) + } + return err + } + + if _, err := b.CreateInitialLoginRole(ctx); err != nil { + if c != nil { + err = multierror.Append(err, c()) + } + return err + } + + if opts.withSkipAuthMethodCreation { + // now that we have passed all the error cases, reset c to be a noop so the + // defer doesn't do anything. + c = func() error { return nil } + return nil + } + + if _, _, err := b.CreateInitialPasswordAuthMethod(ctx); err != nil { + return err + } + + if err := b.CreateDevOidcAuthMethod(ctx); err != nil { + return err + } + + if opts.withSkipScopesCreation { + // now that we have passed all the error cases, reset c to be a noop so the + // defer doesn't do anything. + c = func() error { return nil } + return nil + } + + if _, _, err := b.CreateInitialScopes(ctx); err != nil { + return err + } + + if opts.withSkipHostResourcesCreation { + // now that we have passed all the error cases, reset c to be a noop so the + // defer doesn't do anything. + c = func() error { return nil } + return nil + } + + if _, _, _, err := b.CreateInitialHostResources(context.Background()); err != nil { + return err + } + + if opts.withSkipTargetCreation { + // now that we have passed all the error cases, reset c to be a noop so the + // defer doesn't do anything. + c = func() error { return nil } + return nil + } + + if _, err := b.CreateInitialTarget(ctx); err != nil { + return err + } + + // now that we have passed all the error cases, reset c to be a noop so the + // defer doesn't do anything. + c = func() error { return nil } + return nil +} + +type oidcSetup struct { + clientId string + clientSecret oidc.ClientSecret + oidcPort int + callbackPort string + hostAddr string + authMethod *oidc.AuthMethod + pubKey []byte + privKey []byte + testProvider *capoidc.TestProvider + createUnpriv bool + callbackUrl *url.URL +} + +func (b *Server) CreateDevOidcAuthMethod(ctx context.Context) error { + var err error + + if b.DevOidcAuthMethodId == "" { + b.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix) + if err != nil { + return fmt.Errorf("error generating initial oidc auth method id: %w", err) + } + } + b.InfoKeys = append(b.InfoKeys, "generated oidc auth method id") + b.Info["generated oidc auth method id"] = b.DevOidcAuthMethodId + + switch { + case b.DevUnprivilegedLoginName == "", + b.DevUnprivilegedPassword == "", + b.DevUnprivilegedUserId == "": + + default: + b.DevOidcSetup.createUnpriv = true + } + + // Trawl through the listeners and find the api listener so we can use the + // same host name/IP + { + for _, ln := range b.Listeners { + purpose := strings.ToLower(ln.Config.Purpose[0]) + if purpose != "api" { + continue + } + b.DevOidcSetup.hostAddr, b.DevOidcSetup.callbackPort, err = net.SplitHostPort(ln.Config.Address) + if err != nil { + if strings.Contains(err.Error(), "missing port") { + b.DevOidcSetup.hostAddr = ln.Config.Address + // Use the default API port in the callback + b.DevOidcSetup.callbackPort = "9200" + } else { + return fmt.Errorf("error splitting host/port: %w", err) + } + } + } + if b.DevOidcSetup.hostAddr == "" { + return fmt.Errorf("could not determine address to use for built-in oidc dev listener") + } + } + + // Find an available port -- allocate one, then close the listener, and + // re-use it. This is a sort of hacky way to get around the chicken and egg + // of the auth method needing to know the discovery URL and the test + // provider needing to know the callback URL. + l, err := net.Listen("tcp", fmt.Sprintf("%s:0", b.DevOidcSetup.hostAddr)) + if err != nil { + return fmt.Errorf("error finding port for oidc test provider: %w", err) + } + b.DevOidcSetup.oidcPort = l.(*net.TCPListener).Addr().(*net.TCPAddr).Port + if err := l.Close(); err != nil { + return fmt.Errorf("error closing initial test port: %w", err) + } + b.DevOidcSetup.callbackUrl, err = url.Parse(fmt.Sprintf("http://%s:%s", b.DevOidcSetup.hostAddr, b.DevOidcSetup.callbackPort)) + if err != nil { + return fmt.Errorf("error parsing oidc test provider callback url: %w", err) + } + + // Generate initial IDs/keys + { + b.DevOidcSetup.clientId, err = capoidc.NewID() + if err != nil { + return fmt.Errorf("unable to generate client id: %w", err) + } + clientSecret, err := capoidc.NewID() + if err != nil { + return fmt.Errorf("unable to generate client secret: %w", err) + } + b.DevOidcSetup.clientSecret = oidc.ClientSecret(clientSecret) + b.DevOidcSetup.pubKey, b.DevOidcSetup.privKey, err = ed25519.GenerateKey(nil) + if err != nil { + return fmt.Errorf("unable to generate signing key: %w", err) + } + } + + // Create the subject information and testing provider + { + logger, err := capoidc.NewTestingLogger(b.Logger.Named("dev-oidc")) + if err != nil { + return fmt.Errorf("unable to create logger: %w", err) + } + + subInfo := map[string]*capoidc.TestSubject{ + b.DevLoginName: { + Password: b.DevPassword, + UserInfo: map[string]interface{}{ + "email": "admin@localhost", + "name": "Admin User", + }, + }, + } + if b.DevOidcSetup.createUnpriv { + subInfo[b.DevUnprivilegedLoginName] = &capoidc.TestSubject{ + Password: b.DevUnprivilegedPassword, + UserInfo: map[string]interface{}{ + "email": "user@localhost", + "name": "Unprivileged User", + }, + } + } + + clientSecret := string(b.DevOidcSetup.clientSecret) + + b.DevOidcSetup.testProvider = capoidc.StartTestProvider( + logger, + capoidc.WithNoTLS(), + capoidc.WithTestHost(b.DevOidcSetup.hostAddr), + capoidc.WithTestPort(b.DevOidcSetup.oidcPort), + capoidc.WithTestDefaults(&capoidc.TestProviderDefaults{ + CustomClaims: map[string]interface{}{ + "mode": "dev", + }, + SubjectInfo: subInfo, + SigningKey: &capoidc.TestSigningKey{ + PrivKey: ed25519.PrivateKey(b.DevOidcSetup.privKey), + PubKey: ed25519.PublicKey(b.DevOidcSetup.pubKey), + Alg: capoidc.EdDSA, + }, + AllowedRedirectURIs: []string{fmt.Sprintf("%s/v1/auth-methods/oidc:authenticate:callback", b.DevOidcSetup.callbackUrl.String())}, + ClientID: &b.DevOidcSetup.clientId, + ClientSecret: &clientSecret, + })) + + b.ShutdownFuncs = append(b.ShutdownFuncs, func() error { + b.DevOidcSetup.testProvider.Stop() + return nil + }) + } + + // Create auth method and link accounts + { + b.DevOidcSetup.authMethod, err = b.createInitialOidcAuthMethod(ctx) + if err != nil { + return fmt.Errorf("error creating initial oidc auth method: %w", err) + } + } + + return nil +} + +func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMethod, error) { + rw := db.New(b.Database) + + kmsRepo, err := kms.NewRepository(rw, rw) + if err != nil { + return nil, fmt.Errorf("error creating kms repository: %w", err) + } + kmsCache, err := kms.NewKms(kmsRepo, kms.WithLogger(b.Logger.Named("kms"))) + if err != nil { + return nil, fmt.Errorf("error creating kms cache: %w", err) + } + if err := kmsCache.AddExternalWrappers( + kms.WithRootWrapper(b.RootKms), + ); err != nil { + return nil, fmt.Errorf("error adding config keys to kms: %w", err) + } + + discoveryUrl, err := url.Parse(fmt.Sprintf("http://%s:%d", b.DevOidcSetup.hostAddr, b.DevOidcSetup.oidcPort)) + if err != nil { + return nil, fmt.Errorf("error parsing oidc test provider address: %w", err) + } + + // Create the auth method + oidcRepo, err := oidc.NewRepository(rw, rw, kmsCache) + if err != nil { + return nil, fmt.Errorf("error creating oidc repo: %w", err) + } + + authMethod, err := oidc.NewAuthMethod( + scope.Global.String(), + b.DevOidcSetup.clientId, + b.DevOidcSetup.clientSecret, + oidc.WithName("Generated global scope initial oidc auth method"), + oidc.WithDescription("Provides initial administrative and unprivileged authentication into Boundary"), + oidc.WithIssuer(discoveryUrl), + oidc.WithApiUrl(b.DevOidcSetup.callbackUrl), + oidc.WithSigningAlgs(oidc.EdDSA), + oidc.WithOperationalState(oidc.ActivePublicState)) + if err != nil { + return nil, fmt.Errorf("error creating new in memory oidc auth method: %w", err) + } + if b.DevOidcAuthMethodId == "" { + b.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix) + if err != nil { + return nil, fmt.Errorf("error generating initial oidc auth method id: %w", err) + } + } + + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-b.ShutdownCh: + cancel() + case <-cancelCtx.Done(): + } + }() + + b.DevOidcSetup.authMethod, err = oidcRepo.CreateAuthMethod( + cancelCtx, + authMethod, + oidc.WithPublicId(b.DevOidcAuthMethodId)) + if err != nil { + return nil, fmt.Errorf("error saving oidc auth method to the db: %w", err) + } + + // Create accounts + { + createAndLinkAccount := func(loginName, userId, typ string) error { + acct, err := oidc.NewAccount( + b.DevOidcSetup.authMethod.GetPublicId(), + loginName, + oidc.WithDescription(fmt.Sprintf("Initial %s OIDC account", typ)), + ) + if err != nil { + return fmt.Errorf("error generating %s oidc account: %w", typ, err) + } + acct, err = oidcRepo.CreateAccount( + cancelCtx, + b.DevOidcSetup.authMethod.GetScopeId(), + acct, + ) + if err != nil { + return fmt.Errorf("error creating %s oidc account: %w", typ, err) + } + + // Link accounts to existing user + iamRepo, err := iam.NewRepository(rw, rw, kmsCache) + if err != nil { + return fmt.Errorf("unable to create iam repo: %w", err) + } + + u, _, err := iamRepo.LookupUser(cancelCtx, userId) + if err != nil { + return fmt.Errorf("error looking up %s user: %w", typ, err) + } + if _, err = iamRepo.AddUserAccounts(cancelCtx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil { + return fmt.Errorf("error associating initial %s user with account: %w", typ, err) + } + + return nil + } + + if err := createAndLinkAccount(b.DevLoginName, b.DevUserId, "admin"); err != nil { + return nil, err + } + if b.DevOidcSetup.createUnpriv { + if err := createAndLinkAccount(b.DevUnprivilegedLoginName, b.DevUnprivilegedUserId, "unprivileged"); err != nil { + return nil, err + } + } + } + + return nil, nil +} diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index 2c846412cf..b6d1837219 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -20,8 +20,6 @@ import ( "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/cmd/config" "github.com/hashicorp/boundary/internal/db" - "github.com/hashicorp/boundary/internal/db/schema" - "github.com/hashicorp/boundary/internal/docker" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/types/scope" "github.com/hashicorp/boundary/sdk/strutil" @@ -88,6 +86,8 @@ type Server struct { DevTargetSessionMaxSeconds int DevTargetSessionConnectionLimit int + DevOidcSetup oidcSetup + DatabaseUrl string DatabaseMaxOpenConnections int DevDatabaseCleanupFunc func() error @@ -467,145 +467,6 @@ func (b *Server) ConnectToDatabase(dialect string) error { return nil } -func (b *Server) CreateDevDatabase(ctx context.Context, opt ...Option) error { - var container, url, dialect string - var err error - var c func() error - - opts := getOpts(opt...) - - // We should only get back postgres for now, but laying the foundation for non-postgres - switch opts.withDialect { - case "": - b.Logger.Error("unsupported dialect. wanted: postgres, got: %v", opts.withDialect) - default: - dialect = opts.withDialect - } - - switch b.DatabaseUrl { - case "": - c, url, container, err = docker.StartDbInDocker(dialect, docker.WithContainerImage(opts.withContainerImage)) - // In case of an error, run the cleanup function. If we pass all errors, c should be set to a noop - // function before returning from this method - defer func() { - if !opts.withSkipDatabaseDestruction { - if c != nil { - if err := c(); err != nil { - b.Logger.Error("error cleaning up docker container", "error", err) - } - } - } - }() - if err == docker.ErrDockerUnsupported { - return err - } - if err != nil { - return fmt.Errorf("unable to start dev database with dialect %s: %w", dialect, err) - } - - // Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways. - _, err := schema.MigrateStore(ctx, dialect, url) - if err != nil { - err = fmt.Errorf("unable to initialize dev database with dialect %s: %w", dialect, err) - if c != nil { - err = multierror.Append(err, c()) - } - return err - } - - b.DevDatabaseCleanupFunc = c - b.DatabaseUrl = url - default: - // Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways. - if _, err := schema.MigrateStore(ctx, dialect, b.DatabaseUrl); err != nil { - err = fmt.Errorf("error initializing store: %w", err) - if c != nil { - err = multierror.Append(err, c()) - } - return err - } - } - - b.InfoKeys = append(b.InfoKeys, "dev database url") - b.Info["dev database url"] = b.DatabaseUrl - if container != "" { - b.InfoKeys = append(b.InfoKeys, "dev database container") - b.Info["dev database container"] = strings.TrimPrefix(container, "/") - } - - if err := b.ConnectToDatabase(dialect); err != nil { - if c != nil { - err = multierror.Append(err, c()) - } - return err - } - - b.Database.LogMode(true) - - if err := b.CreateGlobalKmsKeys(ctx); err != nil { - if c != nil { - err = multierror.Append(err, c()) - } - return err - } - - if _, err := b.CreateInitialLoginRole(ctx); err != nil { - if c != nil { - err = multierror.Append(err, c()) - } - return err - } - - if opts.withSkipAuthMethodCreation { - // now that we have passed all the error cases, reset c to be a noop so the - // defer doesn't do anything. - c = func() error { return nil } - return nil - } - - if _, _, err := b.CreateInitialPasswordAuthMethod(ctx); err != nil { - return err - } - - if opts.withSkipScopesCreation { - // now that we have passed all the error cases, reset c to be a noop so the - // defer doesn't do anything. - c = func() error { return nil } - return nil - } - - if _, _, err := b.CreateInitialScopes(ctx); err != nil { - return err - } - - if opts.withSkipHostResourcesCreation { - // now that we have passed all the error cases, reset c to be a noop so the - // defer doesn't do anything. - c = func() error { return nil } - return nil - } - - if _, _, _, err := b.CreateInitialHostResources(context.Background()); err != nil { - return err - } - - if opts.withSkipTargetCreation { - // now that we have passed all the error cases, reset c to be a noop so the - // defer doesn't do anything. - c = func() error { return nil } - return nil - } - - if _, err := b.CreateInitialTarget(ctx); err != nil { - return err - } - - // now that we have passed all the error cases, reset c to be a noop so the - // defer doesn't do anything. - c = func() error { return nil } - return nil -} - func (b *Server) CreateGlobalKmsKeys(ctx context.Context) error { rw := db.New(b.Database) diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index 6f4f14021e..4ccfe45e73 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -1,11 +1,8 @@ package dev import ( - "context" - "crypto/ed25519" "fmt" "net" - "net/url" "runtime" "strings" @@ -13,17 +10,14 @@ import ( "github.com/hashicorp/boundary/internal/auth/password" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/config" - "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/host/static" "github.com/hashicorp/boundary/internal/iam" - "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/servers/controller" "github.com/hashicorp/boundary/internal/servers/controller/handlers" "github.com/hashicorp/boundary/internal/servers/worker" "github.com/hashicorp/boundary/internal/target" "github.com/hashicorp/boundary/internal/types/scope" "github.com/hashicorp/boundary/sdk/strutil" - capoidc "github.com/hashicorp/cap/oidc" "github.com/mitchellh/cli" "github.com/posener/complete" ) @@ -45,8 +39,6 @@ type Command struct { controller *controller.Controller worker *worker.Worker - oidcSetup oidcSetup - flagLogLevel string flagLogFormat string flagCombineLogs bool @@ -497,11 +489,6 @@ func (c *Command) Run(args []string) int { } } - if err := c.startDevOidcAuthMethod(); err != nil { - c.UI.Error(fmt.Errorf("Error starting dev OIDC auth method: %w", err).Error()) - return base.CommandCliError - } - c.PrintInfo(c.UI) c.ReleaseLogGate() @@ -585,277 +572,3 @@ func (c *Command) Run(args []string) int { return base.CommandSuccess } - -type oidcSetup struct { - clientId string - clientSecret oidc.ClientSecret - oidcPort int - callbackPort string - hostAddr string - authMethod *oidc.AuthMethod - pubKey []byte - privKey []byte - testProvider *capoidc.TestProvider - createUnpriv bool - callbackUrl *url.URL -} - -func (c *Command) startDevOidcAuthMethod() error { - var err error - - if c.DevOidcAuthMethodId == "" { - c.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix) - if err != nil { - return fmt.Errorf("error generating initial oidc auth method id: %w", err) - } - } - c.InfoKeys = append(c.InfoKeys, "generated oidc auth method id") - c.Info["generated oidc auth method id"] = c.DevOidcAuthMethodId - - switch { - case c.DevUnprivilegedLoginName == "", - c.DevUnprivilegedPassword == "", - c.DevUnprivilegedUserId == "": - - default: - c.oidcSetup.createUnpriv = true - } - - // Trawl through the listeners and find the api listener so we can use the - // same host name/IP - { - for _, lnConfig := range c.Config.Listeners { - purpose := strings.ToLower(lnConfig.Purpose[0]) - if purpose != "api" { - continue - } - c.oidcSetup.hostAddr, c.oidcSetup.callbackPort, err = net.SplitHostPort(lnConfig.Address) - if err != nil { - if strings.Contains(err.Error(), "missing port") { - c.oidcSetup.hostAddr = lnConfig.Address - // Use the default API port in the callback - c.oidcSetup.callbackPort = "9200" - } else { - return fmt.Errorf("error splitting host/port: %w", err) - } - } - } - if c.oidcSetup.hostAddr == "" { - return fmt.Errorf("could not determine address to use for built-in oidc dev listener") - } - } - - // Find an available port -- allocate one, then close the listener, and - // re-use it. This is a sort of hacky way to get around the chicken and egg - // of the auth method needing to know the discovery URL and the test - // provider needing to know the callback URL. - l, err := net.Listen("tcp", fmt.Sprintf("%s:0", c.oidcSetup.hostAddr)) - if err != nil { - return fmt.Errorf("error finding port for oidc test provider: %w", err) - } - c.oidcSetup.oidcPort = l.(*net.TCPListener).Addr().(*net.TCPAddr).Port - if err := l.Close(); err != nil { - return fmt.Errorf("error closing initial test port: %w", err) - } - c.oidcSetup.callbackUrl, err = url.Parse(fmt.Sprintf("http://%s:%s", c.oidcSetup.hostAddr, c.oidcSetup.callbackPort)) - if err != nil { - return fmt.Errorf("error parsing oidc test provider callback url: %w", err) - } - - // Generate initial IDs/keys - { - c.oidcSetup.clientId, err = capoidc.NewID() - if err != nil { - return fmt.Errorf("unable to generate client id: %w", err) - } - clientSecret, err := capoidc.NewID() - if err != nil { - return fmt.Errorf("unable to generate client secret: %w", err) - } - c.oidcSetup.clientSecret = oidc.ClientSecret(clientSecret) - c.oidcSetup.pubKey, c.oidcSetup.privKey, err = ed25519.GenerateKey(nil) - if err != nil { - return fmt.Errorf("unable to generate signing key: %w", err) - } - } - - // Create the subject information and testing provider - { - logger, err := capoidc.NewTestingLogger(c.Logger.Named("dev-oidc")) - if err != nil { - return fmt.Errorf("unable to create logger: %w", err) - } - - subInfo := map[string]*capoidc.TestSubject{ - c.DevLoginName: { - Password: c.DevPassword, - UserInfo: map[string]interface{}{ - "email": "admin@localhost", - "name": "Admin User", - }, - }, - } - if c.oidcSetup.createUnpriv { - subInfo[c.DevUnprivilegedLoginName] = &capoidc.TestSubject{ - Password: c.DevUnprivilegedPassword, - UserInfo: map[string]interface{}{ - "email": "user@localhost", - "name": "Unprivileged User", - }, - } - } - - clientSecret := string(c.oidcSetup.clientSecret) - - c.oidcSetup.testProvider = capoidc.StartTestProvider( - logger, - capoidc.WithNoTLS(), - capoidc.WithTestHost(c.oidcSetup.hostAddr), - capoidc.WithTestPort(c.oidcSetup.oidcPort), - capoidc.WithTestDefaults(&capoidc.TestProviderDefaults{ - CustomClaims: map[string]interface{}{ - "mode": "dev", - }, - SubjectInfo: subInfo, - SigningKey: &capoidc.TestSigningKey{ - PrivKey: ed25519.PrivateKey(c.oidcSetup.privKey), - PubKey: ed25519.PublicKey(c.oidcSetup.pubKey), - Alg: capoidc.EdDSA, - }, - AllowedRedirectURIs: []string{fmt.Sprintf("%s/v1/auth-methods/oidc:authenticate:callback", c.oidcSetup.callbackUrl.String())}, - ClientID: &c.oidcSetup.clientId, - ClientSecret: &clientSecret, - })) - - c.ShutdownFuncs = append(c.ShutdownFuncs, func() error { - c.oidcSetup.testProvider.Stop() - return nil - }) - } - - // Create auth method and link accounts - { - c.oidcSetup.authMethod, err = c.createInitialOidcAuthMethod() - if err != nil { - return fmt.Errorf("error creating initial oidc auth method: %w", err) - } - } - - return nil -} - -func (c *Command) createInitialOidcAuthMethod() (*oidc.AuthMethod, error) { - rw := db.New(c.Database) - - kmsRepo, err := kms.NewRepository(rw, rw) - if err != nil { - return nil, fmt.Errorf("error creating kms repository: %w", err) - } - kmsCache, err := kms.NewKms(kmsRepo, kms.WithLogger(c.Logger.Named("kms"))) - if err != nil { - return nil, fmt.Errorf("error creating kms cache: %w", err) - } - if err := kmsCache.AddExternalWrappers( - kms.WithRootWrapper(c.RootKms), - ); err != nil { - return nil, fmt.Errorf("error adding config keys to kms: %w", err) - } - - discoveryUrl, err := url.Parse(fmt.Sprintf("http://%s:%d", c.oidcSetup.hostAddr, c.oidcSetup.oidcPort)) - if err != nil { - return nil, fmt.Errorf("error parsing oidc test provider address: %w", err) - } - - // Create the auth method - oidcRepo, err := oidc.NewRepository(rw, rw, kmsCache) - if err != nil { - return nil, fmt.Errorf("error creating oidc repo: %w", err) - } - - authMethod, err := oidc.NewAuthMethod( - scope.Global.String(), - c.oidcSetup.clientId, - c.oidcSetup.clientSecret, - oidc.WithName("Generated global scope initial oidc auth method"), - oidc.WithDescription("Provides initial administrative and unprivileged authentication into Boundary"), - oidc.WithIssuer(discoveryUrl), - oidc.WithApiUrl(c.oidcSetup.callbackUrl), - oidc.WithSigningAlgs(oidc.EdDSA), - oidc.WithOperationalState(oidc.ActivePublicState)) - if err != nil { - return nil, fmt.Errorf("error creating new in memory oidc auth method: %w", err) - } - if c.DevOidcAuthMethodId == "" { - c.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix) - if err != nil { - return nil, fmt.Errorf("error generating initial oidc auth method id: %w", err) - } - } - - cancelCtx, cancel := context.WithCancel(c.Context) - defer cancel() - go func() { - select { - case <-c.ShutdownCh: - cancel() - case <-cancelCtx.Done(): - } - }() - - c.oidcSetup.authMethod, err = oidcRepo.CreateAuthMethod( - cancelCtx, - authMethod, - oidc.WithPublicId(c.DevOidcAuthMethodId)) - if err != nil { - return nil, fmt.Errorf("error saving oidc auth method to the db: %w", err) - } - - // Create accounts - { - createAndLinkAccount := func(loginName, userId, typ string) error { - acct, err := oidc.NewAccount( - c.oidcSetup.authMethod.GetPublicId(), - loginName, - oidc.WithDescription(fmt.Sprintf("Initial %s OIDC account", typ)), - ) - if err != nil { - return fmt.Errorf("error generating %s oidc account: %w", typ, err) - } - acct, err = oidcRepo.CreateAccount( - cancelCtx, - c.oidcSetup.authMethod.GetScopeId(), - acct, - ) - if err != nil { - return fmt.Errorf("error creating %s oidc account: %w", typ, err) - } - - // Link accounts to existing user - iamRepo, err := iam.NewRepository(rw, rw, kmsCache) - if err != nil { - return fmt.Errorf("unable to create iam repo: %w", err) - } - - u, _, err := iamRepo.LookupUser(cancelCtx, userId) - if err != nil { - return fmt.Errorf("error looking up %s user: %w", typ, err) - } - if _, err = iamRepo.AddUserAccounts(cancelCtx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil { - return fmt.Errorf("error associating initial %s user with account: %w", typ, err) - } - - return nil - } - - if err := createAndLinkAccount(c.DevLoginName, c.DevUserId, "admin"); err != nil { - return nil, err - } - if c.oidcSetup.createUnpriv { - if err := createAndLinkAccount(c.DevUnprivilegedLoginName, c.DevUnprivilegedUserId, "unprivileged"); err != nil { - return nil, err - } - } - } - - return nil, nil -} diff --git a/internal/db/schema/migrations/postgres/9/01_managed_groups.up.sql b/internal/db/schema/migrations/postgres/9/01_managed_groups.up.sql new file mode 100644 index 0000000000..de1eb13197 --- /dev/null +++ b/internal/db/schema/migrations/postgres/9/01_managed_groups.up.sql @@ -0,0 +1,60 @@ +begin; + +-- The base abstract table +create table auth_managed_group ( + public_id wt_public_id + primary key, + auth_method_id wt_public_id + not null, + -- Ensure that if the auth method is deleted (which will also happen if the + -- scope is deleted) this is deleted too + constraint auth_method_fkey + foreign key (auth_method_id) -- fk1 + references auth_method(public_id) + on delete cascade + on update cascade, + constraint auth_managed_group_auth_method_id_public_id_uq + unique(auth_method_id, public_id) +); +comment on table auth_managed_group is +'auth_managed_group is the abstract base table for managed groups.'; + +-- Define the immutable fields of auth_managed_group +create trigger + immutable_columns +before +update on auth_managed_group + for each row execute procedure immutable_columns('public_id', 'auth_method_id'); + +-- Function to insert into the base table when values are inserted into a +-- concrete type table. This happens before inserts so the foreign keys in the +-- concrete type will be valid. +create or replace function + insert_managed_group_subtype() + returns trigger +as $$ +begin + + insert into auth_managed_group + (public_id, auth_method_id) + values + (new.public_id, new.auth_method_id); + + return new; + +end; +$$ language plpgsql; + +-- delete_managed_group_subtype() is an after delete trigger +-- function for subtypes of managed_group +create or replace function delete_managed_group_subtype() + returns trigger +as $$ +begin + delete from auth_managed_group + where public_id = old.public_id; + return null; -- result is ignored since this is an after trigger +end; +$$ language plpgsql; + +commit; diff --git a/internal/db/schema/migrations/postgres/9/02_oidc_managed_group.up.sql b/internal/db/schema/migrations/postgres/9/02_oidc_managed_group.up.sql new file mode 100644 index 0000000000..49b579d795 --- /dev/null +++ b/internal/db/schema/migrations/postgres/9/02_oidc_managed_group.up.sql @@ -0,0 +1,83 @@ +begin; + +create table auth_oidc_managed_group ( + public_id wt_public_id + primary key, + auth_method_id wt_public_id + not null, + name wt_name, + description wt_description, + create_time wt_timestamp, + update_time wt_timestamp, + version wt_version, + filter wt_bexprfilter + not null, + -- Ensure that this managed group relates to an oidc auth method, as opposed + -- to other types + constraint auth_oidc_method_fkey + foreign key (auth_method_id) -- fk1 + references auth_oidc_method (public_id) + on delete cascade + on update cascade, + -- Ensure it relates to an abstract managed group + constraint auth_managed_group_fkey + foreign key (auth_method_id, public_id) -- fk2 + references auth_managed_group (auth_method_id, public_id) + on delete cascade + on update cascade, + constraint auth_oidc_managed_group_auth_method_id_name_uq + unique(auth_method_id, name) +); +comment on table auth_oidc_managed_group is +'auth_oidc_managed_group entries are subtypes of auth_managed_group and represent an oidc managed group.'; + +-- Define the immutable fields of auth_oidc_managed_group +create trigger + immutable_columns +before +update on auth_oidc_managed_group + for each row execute procedure immutable_columns('public_id', 'auth_method_id', 'create_time'); + +-- Populate create time on insert +create trigger + default_create_time_column +before +insert on auth_oidc_managed_group + for each row execute procedure default_create_time(); + +-- Generate update time on update +create trigger + update_time_column +before +update on auth_oidc_managed_group + for each row execute procedure update_time_column(); + +-- Update version when something changes +create trigger + update_version_column +after +update on auth_oidc_managed_group + for each row execute procedure update_version_column(); + +-- Add into the base table when inserting into the concrete table +create trigger + insert_managed_group_subtype +before insert on auth_oidc_managed_group + for each row execute procedure insert_managed_group_subtype(); + +-- Ensure that deletions in the oidc subtype result in deletions to the base +-- table. +create trigger + delete_managed_group_subtype +after +delete on auth_oidc_managed_group + for each row execute procedure delete_managed_group_subtype(); + +-- The tickets for oplog are the subtypes not the base types because no updates +-- are done to any values in the base types. +insert into oplog_ticket + (name, version) +values + ('auth_oidc_managed_group', 1); + +commit; diff --git a/internal/db/schema/migrations/postgres/9/03_oidc_managed_group_member.up.sql b/internal/db/schema/migrations/postgres/9/03_oidc_managed_group_member.up.sql new file mode 100644 index 0000000000..a173dc14ad --- /dev/null +++ b/internal/db/schema/migrations/postgres/9/03_oidc_managed_group_member.up.sql @@ -0,0 +1,52 @@ +begin; + +-- Mappings of account to oidc managed groups. This is a non-abstract table with +-- a view (below) so that it is a natural aggregate for the oplog (also below). +create table auth_oidc_managed_group_member_account ( + create_time wt_timestamp, + managed_group_id wt_public_id + references auth_oidc_managed_group(public_id) + on delete cascade + on update cascade, + member_id wt_public_id + references auth_oidc_account(public_id) + on delete cascade + on update cascade, + primary key (managed_group_id, member_id) +); +comment on table auth_oidc_managed_group_member_account is +'auth_oidc_managed_group_member_account is the join table for managed oidc groups and accounts.'; + +-- auth_immutable_managed_oidc_group_member_account() ensures that group members are immutable. +create or replace function + auth_immutable_managed_oidc_group_member_account() + returns trigger +as $$ +begin + raise exception 'managed oidc group members are immutable'; +end; +$$ language plpgsql; + +create trigger + default_create_time_column +before +insert on auth_oidc_managed_group_member_account + for each row execute procedure default_create_time(); + +create trigger + auth_immutable_managed_oidc_group_member_account +before +update on auth_oidc_managed_group_member_account + for each row execute procedure auth_immutable_managed_oidc_group_member_account(); + +-- Initially create the view with just oidc; eventually we can replace this view +-- to union with other subtype tables. +create view auth_managed_group_member_account as +select + oidc.create_time, + oidc.managed_group_id, + oidc.member_id +from + auth_oidc_managed_group_member_account oidc; + +commit; diff --git a/internal/db/schema/migrations/postgres/9/managed_group_test.go b/internal/db/schema/migrations/postgres/9/managed_group_test.go new file mode 100644 index 0000000000..b6060519a8 --- /dev/null +++ b/internal/db/schema/migrations/postgres/9/managed_group_test.go @@ -0,0 +1,376 @@ +package migration + +import ( + "fmt" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/servers/controller" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ManagedGroupTable(t *testing.T) { + t.Parallel() + tc := controller.NewTestController(t, nil) + defer tc.Shutdown() + + db := tc.DbConn().DB() + var err error + + managedGroupId := "a_bcdefghijk" + defaultPasswordAuthMethodId := "ampw_1234567890" + defaultOidcAuthMethodId := "amoidc_1234567890" + + insertTests := []struct { + testName string + publicId string + authMethodId string + wantErr bool + }{ + { + testName: "invalid auth method", + publicId: managedGroupId, + authMethodId: "amoid_1234567890", + wantErr: true, + }, + { + testName: "valid", + publicId: managedGroupId, + authMethodId: defaultOidcAuthMethodId, + wantErr: false, + }, + } + for _, tt := range insertTests { + t.Run("insert: "+tt.testName, func(t *testing.T) { + require := require.New(t) + _, err = db.Exec("insert into auth_managed_group values ($1, $2)", + tt.publicId, + tt.authMethodId) + require.True(tt.wantErr == (err != nil)) + }) + } + + updateTests := []struct { + testName string + column string + value string + publicId string + wantErr bool + }{ + { + testName: "immutable public id", + column: "public_id", + value: "z_yxwvutsrqp", + publicId: managedGroupId, + wantErr: true, + }, + { + testName: "immutable auth method", + column: "auth_method_id", + value: defaultPasswordAuthMethodId, + publicId: managedGroupId, + wantErr: true, + }, + } + for _, tt := range updateTests { + t.Run("update: "+tt.testName, func(t *testing.T) { + assert := assert.New(t) + _, err = db.Exec(fmt.Sprintf("update auth_managed_group set %s = $1 where public_id = $2", tt.column), tt.value, tt.publicId) + assert.True(tt.wantErr == (err != nil)) + }) + } +} + +func Test_OidcManagedGroupTable(t *testing.T) { + t.Parallel() + tc := controller.NewTestController(t, nil) + defer tc.Shutdown() + + db := tc.DbConn().DB() + var err error + + managedGroupId := "a_bcdefghijk" + defaultPasswordAuthMethodId := "ampw_1234567890" + defaultOidcAuthMethodId := "amoidc_1234567890" + name := "this is the name" + filter := "this is a filter" + + // The first set of tests is for initial insertion + { + insertTests := []struct { + testName string + publicId string + authMethodId string + name string + filter string + wantErr bool + }{ + { + testName: "null filter", + publicId: managedGroupId, + authMethodId: "amoid_1234567890", + name: name, + filter: "", + wantErr: true, + }, + { + testName: "invalid auth method", + publicId: managedGroupId, + authMethodId: defaultPasswordAuthMethodId, + name: name, + filter: filter, + wantErr: true, + }, + { + testName: "valid", + publicId: managedGroupId, + authMethodId: defaultOidcAuthMethodId, + name: name, + filter: filter, + wantErr: false, + }, + { + testName: "duplicate public id", + publicId: managedGroupId, + authMethodId: defaultOidcAuthMethodId, + name: name, + filter: filter, + wantErr: true, + }, + { + testName: "duplicate name", + publicId: "z_yxwvutsrqp", + authMethodId: defaultOidcAuthMethodId, + name: name, + filter: filter, + wantErr: true, + }, + } + for _, tt := range insertTests { + t.Run("insert: "+tt.testName, func(t *testing.T) { + require := require.New(t) + _, err = db.Exec("insert into auth_oidc_managed_group (public_id, auth_method_id, name, filter) values ($1, $2, $3, $4)", + tt.publicId, + tt.authMethodId, + tt.name, + tt.filter) + require.True(tt.wantErr == (err != nil)) + }) + } + } + + // Read some values to validate that things were set automatically + rows, err := db.Query("select create_time, update_time, version from auth_oidc_managed_group") + require.NoError(t, err) + require.True(t, rows.Next()) + var create_time, update_time time.Time + var version int + require.NoError(t, rows.Scan(&create_time, &update_time, &version)) + assert.False(t, create_time.IsZero()) + assert.Equal(t, update_time, create_time) + assert.Equal(t, 1, version) + + // These update tests check immutability + { + updateTests := []struct { + testName string + column string + value interface{} + wantErr bool + }{ + { + testName: "immutable public id", + column: "public_id", + value: "z_yxwvutsrqp", + wantErr: true, + }, + { + testName: "immutable auth method", + column: "auth_method_id", + value: defaultPasswordAuthMethodId, + wantErr: true, + }, + { + testName: "immutable creation time", + column: "create_time", + value: time.Now(), + wantErr: true, + }, + { + testName: "valid", + column: "description", + value: "this is the description", + wantErr: false, + }, + } + for _, tt := range updateTests { + t.Run("update: "+tt.testName, func(t *testing.T) { + require := require.New(t) + _, err = db.Exec(fmt.Sprintf("update auth_oidc_managed_group set %s = $1 where public_id = $2", tt.column), tt.value, managedGroupId) + require.True(tt.wantErr == (err != nil)) + }) + } + } + + // Read values again to validate that things were updated automatically + rows, err = db.Query("select create_time, update_time, version from auth_oidc_managed_group") + require.NoError(t, err) + require.True(t, rows.Next()) + var updated_create_time, updated_update_time time.Time + require.NoError(t, rows.Scan(&updated_create_time, &updated_update_time, &version)) + assert.Equal(t, create_time, updated_create_time) + assert.NotEqual(t, update_time, updated_update_time) + assert.Equal(t, 2, version) + + // Read values from auth_managed_group to ensure it was populated automatically + rows, err = db.Query("select public_id, auth_method_id from auth_managed_group") + require.NoError(t, err) + require.True(t, rows.Next()) + var public_id, auth_method_id string + require.NoError(t, rows.Scan(&public_id, &auth_method_id)) + assert.Equal(t, managedGroupId, public_id) + assert.Equal(t, defaultOidcAuthMethodId, auth_method_id) + + // Delete the value from the subtype table + res, err := db.Exec("delete from auth_oidc_managed_group where public_id = $1", managedGroupId) + require.NoError(t, err) + affected, err := res.RowsAffected() + require.NoError(t, err) + require.EqualValues(t, 1, affected) + + // It should no longer be in the base table + rows, err = db.Query("select public_id, auth_method_id from auth_managed_group") + require.NoError(t, err) + require.False(t, rows.Next()) +} + +func Test_AuthManagedOidcGroupMemberAccountTable(t *testing.T) { + t.Parallel() + tc := controller.NewTestController(t, nil) + defer tc.Shutdown() + + db := tc.DbConn().DB() + var err error + + managedGroupId := "a_bcdefghijk" + defaultOidcAuthMethodId := "amoidc_1234567890" + name := "this is the name" + filter := "this is a filter" + + // Insert valid data in auth_oidc_managed_group to use for the following tests + _, err = db.Exec("insert into auth_oidc_managed_group (public_id, auth_method_id, name, filter) values ($1, $2, $3, $4)", + managedGroupId, + defaultOidcAuthMethodId, + name, + filter) + require.NoError(t, err) + + // Fetch a valid (oidc) account ID to use in insertion + rows, err := db.Query("select public_id from auth_oidc_account limit 1") + require.NoError(t, err) + require.True(t, rows.Next()) + var accountId string + require.NoError(t, rows.Scan(&accountId)) + require.NotEmpty(t, accountId) + + // The first set of tests is for initial insertion + { + insertTests := []struct { + testName string + managedGroupId string + memberId string + wantErr bool + }{ + { + testName: "invalid managed group id", + managedGroupId: "z_yxwvutsrqp", + memberId: accountId, + wantErr: true, + }, + { + testName: "invalid member id", + managedGroupId: managedGroupId, + memberId: "acct_1234567890", + wantErr: true, + }, + { + testName: "valid", + managedGroupId: managedGroupId, + memberId: accountId, + wantErr: false, + }, + { + testName: "duplicate values", + managedGroupId: managedGroupId, + memberId: accountId, + wantErr: true, + }, + } + for _, tt := range insertTests { + t.Run("insert: "+tt.testName, func(t *testing.T) { + assert := assert.New(t) + _, err = db.Exec("insert into auth_oidc_managed_group_member_account (managed_group_id, member_id) values ($1, $2)", + tt.managedGroupId, + tt.memberId) + assert.True(tt.wantErr == (err != nil)) + }) + } + } + + // Read some values to validate that things were set automatically + rows, err = db.Query("select create_time, managed_group_id, member_id from auth_oidc_managed_group_member_account") + require.NoError(t, err) + require.True(t, rows.Next()) + var create_time time.Time + var managed_group_id, member_id string + require.NoError(t, rows.Scan(&create_time, &managed_group_id, &member_id)) + assert.False(t, create_time.IsZero()) + + // These update tests check immutability + { + updateTests := []struct { + testName string + column string + value interface{} + wantErr bool + }{ + { + testName: "immutable managed group id", + column: "managed_group_id", + value: "z_yxwvutsrqp", + wantErr: true, + }, + { + testName: "immutable member_id", + column: "member_id", + value: "acct_1234567890", + wantErr: true, + }, + { + testName: "immutable creation time", + column: "create_time", + value: time.Now(), + wantErr: true, + }, + } + for _, tt := range updateTests { + t.Run("update: "+tt.testName, func(t *testing.T) { + assert := assert.New(t) + _, err = db.Exec(fmt.Sprintf("update auth_managed_group_member_account set %s = $1 where managed_group_id = $2 and member_id = $3", tt.column), managedGroupId, accountId) + assert.True(tt.wantErr == (err != nil)) + }) + } + } + + // Read from the view to ensure we see it there + rows, err = db.Query("select create_time, managed_group_id, member_id from auth_managed_group_member_account") + require.NoError(t, err) + require.True(t, rows.Next()) + var view_create_time time.Time + var view_managed_group_id, view_member_id string + require.NoError(t, rows.Scan(&view_create_time, &view_managed_group_id, &view_member_id)) + assert.Equal(t, create_time, view_create_time) + assert.Equal(t, managed_group_id, view_managed_group_id) + assert.Equal(t, member_id, view_member_id) +} diff --git a/internal/db/schema/postgres_migration.gen.go b/internal/db/schema/postgres_migration.gen.go index 4d5db009b1..d7623c6d76 100644 --- a/internal/db/schema/postgres_migration.gen.go +++ b/internal/db/schema/postgres_migration.gen.go @@ -4,7 +4,7 @@ package schema func init() { migrationStates["postgres"] = migrationState{ - binarySchemaVersion: 8001, + binarySchemaVersion: 9003, upMigrations: map[int][]byte{ 1: []byte(` create domain wt_public_id as text @@ -6373,6 +6373,195 @@ from session s where sc.session_id = s.public_id; +`), + 9001: []byte(` +-- The base abstract table +create table auth_managed_group ( + public_id wt_public_id + primary key, + auth_method_id wt_public_id + not null, + -- Ensure that if the auth method is deleted (which will also happen if the + -- scope is deleted) this is deleted too + constraint auth_method_fkey + foreign key (auth_method_id) -- fk1 + references auth_method(public_id) + on delete cascade + on update cascade, + constraint auth_managed_group_auth_method_id_public_id_uq + unique(auth_method_id, public_id) +); +comment on table auth_managed_group is +'auth_managed_group is the abstract base table for managed groups.'; + +-- Define the immutable fields of auth_managed_group +create trigger + immutable_columns +before +update on auth_managed_group + for each row execute procedure immutable_columns('public_id', 'auth_method_id'); + +-- Function to insert into the base table when values are inserted into a +-- concrete type table. This happens before inserts so the foreign keys in the +-- concrete type will be valid. +create or replace function + insert_managed_group_subtype() + returns trigger +as $$ +begin + + insert into auth_managed_group + (public_id, auth_method_id) + values + (new.public_id, new.auth_method_id); + + return new; + +end; +$$ language plpgsql; + +-- delete_managed_group_subtype() is an after delete trigger +-- function for subtypes of managed_group +create or replace function delete_managed_group_subtype() + returns trigger +as $$ +begin + delete from auth_managed_group + where public_id = old.public_id; + return null; -- result is ignored since this is an after trigger +end; +$$ language plpgsql; +`), + 9002: []byte(` +create table auth_oidc_managed_group ( + public_id wt_public_id + primary key, + auth_method_id wt_public_id + not null, + name wt_name, + description wt_description, + create_time wt_timestamp, + update_time wt_timestamp, + version wt_version, + filter wt_bexprfilter + not null, + -- Ensure that this managed group relates to an oidc auth method, as opposed + -- to other types + constraint auth_oidc_method_fkey + foreign key (auth_method_id) -- fk1 + references auth_oidc_method (public_id) + on delete cascade + on update cascade, + -- Ensure it relates to an abstract managed group + constraint auth_managed_group_fkey + foreign key (auth_method_id, public_id) -- fk2 + references auth_managed_group (auth_method_id, public_id) + on delete cascade + on update cascade, + constraint auth_oidc_managed_group_auth_method_id_name_uq + unique(auth_method_id, name) +); +comment on table auth_oidc_managed_group is +'auth_oidc_managed_group entries are subtypes of auth_managed_group and represent an oidc managed group.'; + +-- Define the immutable fields of auth_oidc_managed_group +create trigger + immutable_columns +before +update on auth_oidc_managed_group + for each row execute procedure immutable_columns('public_id', 'auth_method_id', 'create_time'); + +-- Populate create time on insert +create trigger + default_create_time_column +before +insert on auth_oidc_managed_group + for each row execute procedure default_create_time(); + +-- Generate update time on update +create trigger + update_time_column +before +update on auth_oidc_managed_group + for each row execute procedure update_time_column(); + +-- Update version when something changes +create trigger + update_version_column +after +update on auth_oidc_managed_group + for each row execute procedure update_version_column(); + +-- Add into the base table when inserting into the concrete table +create trigger + insert_managed_group_subtype +before insert on auth_oidc_managed_group + for each row execute procedure insert_managed_group_subtype(); + +-- Ensure that deletions in the oidc subtype result in deletions to the base +-- table. +create trigger + delete_managed_group_subtype +after +delete on auth_oidc_managed_group + for each row execute procedure delete_managed_group_subtype(); + +-- The tickets for oplog are the subtypes not the base types because no updates +-- are done to any values in the base types. +insert into oplog_ticket + (name, version) +values + ('auth_oidc_managed_group', 1); +`), + 9003: []byte(` +-- Mappings of account to oidc managed groups. This is a non-abstract table with +-- a view (below) so that it is a natural aggregate for the oplog (also below). +create table auth_oidc_managed_group_member_account ( + create_time wt_timestamp, + managed_group_id wt_public_id + references auth_oidc_managed_group(public_id) + on delete cascade + on update cascade, + member_id wt_public_id + references auth_oidc_account(public_id) + on delete cascade + on update cascade, + primary key (managed_group_id, member_id) +); +comment on table auth_oidc_managed_group_member_account is +'auth_oidc_managed_group_member_account is the join table for managed oidc groups and accounts.'; + +-- auth_immutable_managed_oidc_group_member_account() ensures that group members are immutable. +create or replace function + auth_immutable_managed_oidc_group_member_account() + returns trigger +as $$ +begin + raise exception 'managed oidc group members are immutable'; +end; +$$ language plpgsql; + +create trigger + default_create_time_column +before +insert on auth_oidc_managed_group_member_account + for each row execute procedure default_create_time(); + +create trigger + auth_immutable_managed_oidc_group_member_account +before +update on auth_oidc_managed_group_member_account + for each row execute procedure auth_immutable_managed_oidc_group_member_account(); + +-- Initially create the view with just oidc; eventually we can replace this view +-- to union with other subtype tables. +create view auth_managed_group_member_account as +select + oidc.create_time, + oidc.managed_group_id, + oidc.member_id +from + auth_oidc_managed_group_member_account oidc; `), }, } diff --git a/internal/servers/controller/handler_test.go b/internal/servers/controller/handler_test.go index 868db76171..8ed729817c 100644 --- a/internal/servers/controller/handler_test.go +++ b/internal/servers/controller/handler_test.go @@ -20,7 +20,7 @@ import ( func TestAuthenticationHandler(t *testing.T) { c := NewTestController(t, &TestControllerOpts{ DisableAuthorizationFailures: true, - DefaultAuthMethodId: "ampw_1234567890", + DefaultPasswordAuthMethodId: "ampw_1234567890", DefaultLoginName: "admin", DefaultPassword: "password123", }) diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index ebbcd10f32..10326bb4fb 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -27,7 +27,8 @@ import ( ) const ( - DefaultTestAuthMethodId = "ampw_1234567890" + DefaultTestPasswordAuthMethodId = "ampw_1234567890" + DefaultTestOidcAuthMethodId = "amoidc_1234567890" DefaultTestLoginName = "admin" DefaultTestUnprivilegedLoginName = "user" DefaultTestPassword = "passpass" @@ -280,8 +281,11 @@ type TestControllerOpts struct { // set. Config *config.Config - // DefaultAuthMethodId is the default auth method ID to use, if set. - DefaultAuthMethodId string + // DefaultPasswordAuthMethodId is the default password method ID to use, if set. + DefaultPasswordAuthMethodId string + + // DefaultOidcAuthMethodId is the default OIDC method ID to use, if set. + DefaultOidcAuthMethodId string // DefaultLoginName is the login name used when creating the default admin account. DefaultLoginName string @@ -394,10 +398,15 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { opts.Config.Controller.Name = opts.Name } - if opts.DefaultAuthMethodId != "" { - tc.b.DevPasswordAuthMethodId = opts.DefaultAuthMethodId + if opts.DefaultPasswordAuthMethodId != "" { + tc.b.DevPasswordAuthMethodId = opts.DefaultPasswordAuthMethodId + } else { + tc.b.DevPasswordAuthMethodId = DefaultTestPasswordAuthMethodId + } + if opts.DefaultOidcAuthMethodId != "" { + tc.b.DevOidcAuthMethodId = opts.DefaultOidcAuthMethodId } else { - tc.b.DevPasswordAuthMethodId = DefaultTestAuthMethodId + tc.b.DevOidcAuthMethodId = DefaultTestOidcAuthMethodId } if opts.DefaultLoginName != "" { tc.b.DevLoginName = opts.DefaultLoginName @@ -440,6 +449,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { if opts.InitialResourcesSuffix != "" { suffix := opts.InitialResourcesSuffix tc.b.DevPasswordAuthMethodId = "ampw_" + suffix + tc.b.DevOidcAuthMethodId = "amoidc_" + suffix tc.b.DevHostCatalogId = "hcst_" + suffix tc.b.DevHostId = "hst_" + suffix tc.b.DevHostSetId = "hsst_" + suffix @@ -500,6 +510,9 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { if _, _, err := tc.b.CreateInitialPasswordAuthMethod(ctx); err != nil { t.Fatal(err) } + if err := tc.b.CreateDevOidcAuthMethod(ctx); err != nil { + t.Fatal(err) + } if !opts.DisableScopesCreation { if _, _, err := tc.b.CreateInitialScopes(ctx); err != nil { t.Fatal(err) @@ -557,17 +570,18 @@ func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestCon opts = new(TestControllerOpts) } nextOpts := &TestControllerOpts{ - DatabaseUrl: tc.c.conf.DatabaseUrl, - DefaultAuthMethodId: tc.c.conf.DevPasswordAuthMethodId, - RootKms: tc.c.conf.RootKms, - WorkerAuthKms: tc.c.conf.WorkerAuthKms, - RecoveryKms: tc.c.conf.RecoveryKms, - Name: opts.Name, - Logger: tc.c.conf.Logger, - DefaultLoginName: tc.b.DevLoginName, - DefaultPassword: tc.b.DevPassword, - DisableKmsKeyCreation: true, - DisableAuthMethodCreation: true, + DatabaseUrl: tc.c.conf.DatabaseUrl, + DefaultPasswordAuthMethodId: tc.c.conf.DevPasswordAuthMethodId, + DefaultOidcAuthMethodId: tc.c.conf.DevOidcAuthMethodId, + RootKms: tc.c.conf.RootKms, + WorkerAuthKms: tc.c.conf.WorkerAuthKms, + RecoveryKms: tc.c.conf.RecoveryKms, + Name: opts.Name, + Logger: tc.c.conf.Logger, + DefaultLoginName: tc.b.DevLoginName, + DefaultPassword: tc.b.DevPassword, + DisableKmsKeyCreation: true, + DisableAuthMethodCreation: true, } if opts.Logger != nil { nextOpts.Logger = opts.Logger diff --git a/testing/controller/controller.go b/testing/controller/controller.go index 9a4198de23..cdd432d638 100644 --- a/testing/controller/controller.go +++ b/testing/controller/controller.go @@ -22,7 +22,7 @@ func getOpts(opt ...Option) (*controller.TestControllerOpts, error) { return nil, fmt.Errorf("Cannot provide both WithConfigFile and WithConfigText") } var setDbParams bool - if opts.setDefaultAuthMethodId || opts.setDefaultLoginName || opts.setDefaultPassword { + if opts.setDefaultPasswordAuthMethodId || opts.setDefaultOidcAuthMethodId || opts.setDefaultLoginName || opts.setDefaultPassword { setDbParams = true } if opts.setDisableAuthMethodCreation { @@ -39,19 +39,20 @@ func getOpts(opt ...Option) (*controller.TestControllerOpts, error) { } type option struct { - tcOptions *controller.TestControllerOpts - setWithConfigFile bool - setWithConfigText bool - setDisableAuthMethodCreation bool - setDisableDatabaseCreation bool - setDisableDatabaseDestruction bool - setDefaultAuthMethodId bool - setDefaultLoginName bool - setDefaultPassword bool - setRootKms bool - setWorkerAuthKms bool - setRecoveryKms bool - setDatabaseUrl bool + tcOptions *controller.TestControllerOpts + setWithConfigFile bool + setWithConfigText bool + setDisableAuthMethodCreation bool + setDisableDatabaseCreation bool + setDisableDatabaseDestruction bool + setDefaultPasswordAuthMethodId bool + setDefaultOidcAuthMethodId bool + setDefaultLoginName bool + setDefaultPassword bool + setRootKms bool + setWorkerAuthKms bool + setRecoveryKms bool + setDatabaseUrl bool } type Option func(*option) error @@ -119,10 +120,18 @@ func DisableDatabaseDestruction() Option { } } -func WithDefaultAuthMethodId(id string) Option { +func WithDefaultPasswordAuthMethodId(id string) Option { return func(c *option) error { - c.setDefaultAuthMethodId = true - c.tcOptions.DefaultAuthMethodId = id + c.setDefaultPasswordAuthMethodId = true + c.tcOptions.DefaultPasswordAuthMethodId = id + return nil + } +} + +func WithDefaultOidcAuthMethodId(id string) Option { + return func(c *option) error { + c.setDefaultOidcAuthMethodId = true + c.tcOptions.DefaultOidcAuthMethodId = id return nil } }