Add initial SQL for managed groups (#1253)

pull/1254/head
Jeff Mitchell 5 years ago committed by GitHub
parent ed5e34082c
commit 15bd1a5245
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,437 @@
package base
import (
"context"
"crypto/ed25519"
"fmt"
"net"
"net/url"
"strings"
"github.com/hashicorp/boundary/internal/auth/oidc"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
capoidc "github.com/hashicorp/cap/oidc"
"github.com/hashicorp/go-multierror"
)
func (b *Server) CreateDevDatabase(ctx context.Context, opt ...Option) error {
var container, url, dialect string
var err error
var c func() error
opts := getOpts(opt...)
// We should only get back postgres for now, but laying the foundation for non-postgres
switch opts.withDialect {
case "":
b.Logger.Error("unsupported dialect. wanted: postgres, got: %v", opts.withDialect)
default:
dialect = opts.withDialect
}
switch b.DatabaseUrl {
case "":
c, url, container, err = docker.StartDbInDocker(dialect, docker.WithContainerImage(opts.withContainerImage))
// In case of an error, run the cleanup function. If we pass all errors, c should be set to a noop
// function before returning from this method
defer func() {
if !opts.withSkipDatabaseDestruction {
if c != nil {
if err := c(); err != nil {
b.Logger.Error("error cleaning up docker container", "error", err)
}
}
}
}()
if err == docker.ErrDockerUnsupported {
return err
}
if err != nil {
return fmt.Errorf("unable to start dev database with dialect %s: %w", dialect, err)
}
// Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways.
_, err := schema.MigrateStore(ctx, dialect, url)
if err != nil {
err = fmt.Errorf("unable to initialize dev database with dialect %s: %w", dialect, err)
if c != nil {
err = multierror.Append(err, c())
}
return err
}
b.DevDatabaseCleanupFunc = c
b.DatabaseUrl = url
default:
// Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways.
if _, err := schema.MigrateStore(ctx, dialect, b.DatabaseUrl); err != nil {
err = fmt.Errorf("error initializing store: %w", err)
if c != nil {
err = multierror.Append(err, c())
}
return err
}
}
b.InfoKeys = append(b.InfoKeys, "dev database url")
b.Info["dev database url"] = b.DatabaseUrl
if container != "" {
b.InfoKeys = append(b.InfoKeys, "dev database container")
b.Info["dev database container"] = strings.TrimPrefix(container, "/")
}
if err := b.ConnectToDatabase(dialect); err != nil {
if c != nil {
err = multierror.Append(err, c())
}
return err
}
b.Database.LogMode(true)
if err := b.CreateGlobalKmsKeys(ctx); err != nil {
if c != nil {
err = multierror.Append(err, c())
}
return err
}
if _, err := b.CreateInitialLoginRole(ctx); err != nil {
if c != nil {
err = multierror.Append(err, c())
}
return err
}
if opts.withSkipAuthMethodCreation {
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
if _, _, err := b.CreateInitialPasswordAuthMethod(ctx); err != nil {
return err
}
if err := b.CreateDevOidcAuthMethod(ctx); err != nil {
return err
}
if opts.withSkipScopesCreation {
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
if _, _, err := b.CreateInitialScopes(ctx); err != nil {
return err
}
if opts.withSkipHostResourcesCreation {
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
if _, _, _, err := b.CreateInitialHostResources(context.Background()); err != nil {
return err
}
if opts.withSkipTargetCreation {
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
if _, err := b.CreateInitialTarget(ctx); err != nil {
return err
}
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
type oidcSetup struct {
clientId string
clientSecret oidc.ClientSecret
oidcPort int
callbackPort string
hostAddr string
authMethod *oidc.AuthMethod
pubKey []byte
privKey []byte
testProvider *capoidc.TestProvider
createUnpriv bool
callbackUrl *url.URL
}
func (b *Server) CreateDevOidcAuthMethod(ctx context.Context) error {
var err error
if b.DevOidcAuthMethodId == "" {
b.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix)
if err != nil {
return fmt.Errorf("error generating initial oidc auth method id: %w", err)
}
}
b.InfoKeys = append(b.InfoKeys, "generated oidc auth method id")
b.Info["generated oidc auth method id"] = b.DevOidcAuthMethodId
switch {
case b.DevUnprivilegedLoginName == "",
b.DevUnprivilegedPassword == "",
b.DevUnprivilegedUserId == "":
default:
b.DevOidcSetup.createUnpriv = true
}
// Trawl through the listeners and find the api listener so we can use the
// same host name/IP
{
for _, ln := range b.Listeners {
purpose := strings.ToLower(ln.Config.Purpose[0])
if purpose != "api" {
continue
}
b.DevOidcSetup.hostAddr, b.DevOidcSetup.callbackPort, err = net.SplitHostPort(ln.Config.Address)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
b.DevOidcSetup.hostAddr = ln.Config.Address
// Use the default API port in the callback
b.DevOidcSetup.callbackPort = "9200"
} else {
return fmt.Errorf("error splitting host/port: %w", err)
}
}
}
if b.DevOidcSetup.hostAddr == "" {
return fmt.Errorf("could not determine address to use for built-in oidc dev listener")
}
}
// Find an available port -- allocate one, then close the listener, and
// re-use it. This is a sort of hacky way to get around the chicken and egg
// of the auth method needing to know the discovery URL and the test
// provider needing to know the callback URL.
l, err := net.Listen("tcp", fmt.Sprintf("%s:0", b.DevOidcSetup.hostAddr))
if err != nil {
return fmt.Errorf("error finding port for oidc test provider: %w", err)
}
b.DevOidcSetup.oidcPort = l.(*net.TCPListener).Addr().(*net.TCPAddr).Port
if err := l.Close(); err != nil {
return fmt.Errorf("error closing initial test port: %w", err)
}
b.DevOidcSetup.callbackUrl, err = url.Parse(fmt.Sprintf("http://%s:%s", b.DevOidcSetup.hostAddr, b.DevOidcSetup.callbackPort))
if err != nil {
return fmt.Errorf("error parsing oidc test provider callback url: %w", err)
}
// Generate initial IDs/keys
{
b.DevOidcSetup.clientId, err = capoidc.NewID()
if err != nil {
return fmt.Errorf("unable to generate client id: %w", err)
}
clientSecret, err := capoidc.NewID()
if err != nil {
return fmt.Errorf("unable to generate client secret: %w", err)
}
b.DevOidcSetup.clientSecret = oidc.ClientSecret(clientSecret)
b.DevOidcSetup.pubKey, b.DevOidcSetup.privKey, err = ed25519.GenerateKey(nil)
if err != nil {
return fmt.Errorf("unable to generate signing key: %w", err)
}
}
// Create the subject information and testing provider
{
logger, err := capoidc.NewTestingLogger(b.Logger.Named("dev-oidc"))
if err != nil {
return fmt.Errorf("unable to create logger: %w", err)
}
subInfo := map[string]*capoidc.TestSubject{
b.DevLoginName: {
Password: b.DevPassword,
UserInfo: map[string]interface{}{
"email": "admin@localhost",
"name": "Admin User",
},
},
}
if b.DevOidcSetup.createUnpriv {
subInfo[b.DevUnprivilegedLoginName] = &capoidc.TestSubject{
Password: b.DevUnprivilegedPassword,
UserInfo: map[string]interface{}{
"email": "user@localhost",
"name": "Unprivileged User",
},
}
}
clientSecret := string(b.DevOidcSetup.clientSecret)
b.DevOidcSetup.testProvider = capoidc.StartTestProvider(
logger,
capoidc.WithNoTLS(),
capoidc.WithTestHost(b.DevOidcSetup.hostAddr),
capoidc.WithTestPort(b.DevOidcSetup.oidcPort),
capoidc.WithTestDefaults(&capoidc.TestProviderDefaults{
CustomClaims: map[string]interface{}{
"mode": "dev",
},
SubjectInfo: subInfo,
SigningKey: &capoidc.TestSigningKey{
PrivKey: ed25519.PrivateKey(b.DevOidcSetup.privKey),
PubKey: ed25519.PublicKey(b.DevOidcSetup.pubKey),
Alg: capoidc.EdDSA,
},
AllowedRedirectURIs: []string{fmt.Sprintf("%s/v1/auth-methods/oidc:authenticate:callback", b.DevOidcSetup.callbackUrl.String())},
ClientID: &b.DevOidcSetup.clientId,
ClientSecret: &clientSecret,
}))
b.ShutdownFuncs = append(b.ShutdownFuncs, func() error {
b.DevOidcSetup.testProvider.Stop()
return nil
})
}
// Create auth method and link accounts
{
b.DevOidcSetup.authMethod, err = b.createInitialOidcAuthMethod(ctx)
if err != nil {
return fmt.Errorf("error creating initial oidc auth method: %w", err)
}
}
return nil
}
func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMethod, error) {
rw := db.New(b.Database)
kmsRepo, err := kms.NewRepository(rw, rw)
if err != nil {
return nil, fmt.Errorf("error creating kms repository: %w", err)
}
kmsCache, err := kms.NewKms(kmsRepo, kms.WithLogger(b.Logger.Named("kms")))
if err != nil {
return nil, fmt.Errorf("error creating kms cache: %w", err)
}
if err := kmsCache.AddExternalWrappers(
kms.WithRootWrapper(b.RootKms),
); err != nil {
return nil, fmt.Errorf("error adding config keys to kms: %w", err)
}
discoveryUrl, err := url.Parse(fmt.Sprintf("http://%s:%d", b.DevOidcSetup.hostAddr, b.DevOidcSetup.oidcPort))
if err != nil {
return nil, fmt.Errorf("error parsing oidc test provider address: %w", err)
}
// Create the auth method
oidcRepo, err := oidc.NewRepository(rw, rw, kmsCache)
if err != nil {
return nil, fmt.Errorf("error creating oidc repo: %w", err)
}
authMethod, err := oidc.NewAuthMethod(
scope.Global.String(),
b.DevOidcSetup.clientId,
b.DevOidcSetup.clientSecret,
oidc.WithName("Generated global scope initial oidc auth method"),
oidc.WithDescription("Provides initial administrative and unprivileged authentication into Boundary"),
oidc.WithIssuer(discoveryUrl),
oidc.WithApiUrl(b.DevOidcSetup.callbackUrl),
oidc.WithSigningAlgs(oidc.EdDSA),
oidc.WithOperationalState(oidc.ActivePublicState))
if err != nil {
return nil, fmt.Errorf("error creating new in memory oidc auth method: %w", err)
}
if b.DevOidcAuthMethodId == "" {
b.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix)
if err != nil {
return nil, fmt.Errorf("error generating initial oidc auth method id: %w", err)
}
}
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-b.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
b.DevOidcSetup.authMethod, err = oidcRepo.CreateAuthMethod(
cancelCtx,
authMethod,
oidc.WithPublicId(b.DevOidcAuthMethodId))
if err != nil {
return nil, fmt.Errorf("error saving oidc auth method to the db: %w", err)
}
// Create accounts
{
createAndLinkAccount := func(loginName, userId, typ string) error {
acct, err := oidc.NewAccount(
b.DevOidcSetup.authMethod.GetPublicId(),
loginName,
oidc.WithDescription(fmt.Sprintf("Initial %s OIDC account", typ)),
)
if err != nil {
return fmt.Errorf("error generating %s oidc account: %w", typ, err)
}
acct, err = oidcRepo.CreateAccount(
cancelCtx,
b.DevOidcSetup.authMethod.GetScopeId(),
acct,
)
if err != nil {
return fmt.Errorf("error creating %s oidc account: %w", typ, err)
}
// Link accounts to existing user
iamRepo, err := iam.NewRepository(rw, rw, kmsCache)
if err != nil {
return fmt.Errorf("unable to create iam repo: %w", err)
}
u, _, err := iamRepo.LookupUser(cancelCtx, userId)
if err != nil {
return fmt.Errorf("error looking up %s user: %w", typ, err)
}
if _, err = iamRepo.AddUserAccounts(cancelCtx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil {
return fmt.Errorf("error associating initial %s user with account: %w", typ, err)
}
return nil
}
if err := createAndLinkAccount(b.DevLoginName, b.DevUserId, "admin"); err != nil {
return nil, err
}
if b.DevOidcSetup.createUnpriv {
if err := createAndLinkAccount(b.DevUnprivilegedLoginName, b.DevUnprivilegedUserId, "unprivileged"); err != nil {
return nil, err
}
}
}
return nil, nil
}

@ -20,8 +20,6 @@ import (
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/sdk/strutil"
@ -88,6 +86,8 @@ type Server struct {
DevTargetSessionMaxSeconds int
DevTargetSessionConnectionLimit int
DevOidcSetup oidcSetup
DatabaseUrl string
DatabaseMaxOpenConnections int
DevDatabaseCleanupFunc func() error
@ -467,145 +467,6 @@ func (b *Server) ConnectToDatabase(dialect string) error {
return nil
}
func (b *Server) CreateDevDatabase(ctx context.Context, opt ...Option) error {
var container, url, dialect string
var err error
var c func() error
opts := getOpts(opt...)
// We should only get back postgres for now, but laying the foundation for non-postgres
switch opts.withDialect {
case "":
b.Logger.Error("unsupported dialect. wanted: postgres, got: %v", opts.withDialect)
default:
dialect = opts.withDialect
}
switch b.DatabaseUrl {
case "":
c, url, container, err = docker.StartDbInDocker(dialect, docker.WithContainerImage(opts.withContainerImage))
// In case of an error, run the cleanup function. If we pass all errors, c should be set to a noop
// function before returning from this method
defer func() {
if !opts.withSkipDatabaseDestruction {
if c != nil {
if err := c(); err != nil {
b.Logger.Error("error cleaning up docker container", "error", err)
}
}
}
}()
if err == docker.ErrDockerUnsupported {
return err
}
if err != nil {
return fmt.Errorf("unable to start dev database with dialect %s: %w", dialect, err)
}
// Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways.
_, err := schema.MigrateStore(ctx, dialect, url)
if err != nil {
err = fmt.Errorf("unable to initialize dev database with dialect %s: %w", dialect, err)
if c != nil {
err = multierror.Append(err, c())
}
return err
}
b.DevDatabaseCleanupFunc = c
b.DatabaseUrl = url
default:
// Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways.
if _, err := schema.MigrateStore(ctx, dialect, b.DatabaseUrl); err != nil {
err = fmt.Errorf("error initializing store: %w", err)
if c != nil {
err = multierror.Append(err, c())
}
return err
}
}
b.InfoKeys = append(b.InfoKeys, "dev database url")
b.Info["dev database url"] = b.DatabaseUrl
if container != "" {
b.InfoKeys = append(b.InfoKeys, "dev database container")
b.Info["dev database container"] = strings.TrimPrefix(container, "/")
}
if err := b.ConnectToDatabase(dialect); err != nil {
if c != nil {
err = multierror.Append(err, c())
}
return err
}
b.Database.LogMode(true)
if err := b.CreateGlobalKmsKeys(ctx); err != nil {
if c != nil {
err = multierror.Append(err, c())
}
return err
}
if _, err := b.CreateInitialLoginRole(ctx); err != nil {
if c != nil {
err = multierror.Append(err, c())
}
return err
}
if opts.withSkipAuthMethodCreation {
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
if _, _, err := b.CreateInitialPasswordAuthMethod(ctx); err != nil {
return err
}
if opts.withSkipScopesCreation {
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
if _, _, err := b.CreateInitialScopes(ctx); err != nil {
return err
}
if opts.withSkipHostResourcesCreation {
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
if _, _, _, err := b.CreateInitialHostResources(context.Background()); err != nil {
return err
}
if opts.withSkipTargetCreation {
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
if _, err := b.CreateInitialTarget(ctx); err != nil {
return err
}
// now that we have passed all the error cases, reset c to be a noop so the
// defer doesn't do anything.
c = func() error { return nil }
return nil
}
func (b *Server) CreateGlobalKmsKeys(ctx context.Context) error {
rw := db.New(b.Database)

@ -1,11 +1,8 @@
package dev
import (
"context"
"crypto/ed25519"
"fmt"
"net"
"net/url"
"runtime"
"strings"
@ -13,17 +10,14 @@ import (
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/boundary/internal/servers/controller/handlers"
"github.com/hashicorp/boundary/internal/servers/worker"
"github.com/hashicorp/boundary/internal/target"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/sdk/strutil"
capoidc "github.com/hashicorp/cap/oidc"
"github.com/mitchellh/cli"
"github.com/posener/complete"
)
@ -45,8 +39,6 @@ type Command struct {
controller *controller.Controller
worker *worker.Worker
oidcSetup oidcSetup
flagLogLevel string
flagLogFormat string
flagCombineLogs bool
@ -497,11 +489,6 @@ func (c *Command) Run(args []string) int {
}
}
if err := c.startDevOidcAuthMethod(); err != nil {
c.UI.Error(fmt.Errorf("Error starting dev OIDC auth method: %w", err).Error())
return base.CommandCliError
}
c.PrintInfo(c.UI)
c.ReleaseLogGate()
@ -585,277 +572,3 @@ func (c *Command) Run(args []string) int {
return base.CommandSuccess
}
type oidcSetup struct {
clientId string
clientSecret oidc.ClientSecret
oidcPort int
callbackPort string
hostAddr string
authMethod *oidc.AuthMethod
pubKey []byte
privKey []byte
testProvider *capoidc.TestProvider
createUnpriv bool
callbackUrl *url.URL
}
func (c *Command) startDevOidcAuthMethod() error {
var err error
if c.DevOidcAuthMethodId == "" {
c.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix)
if err != nil {
return fmt.Errorf("error generating initial oidc auth method id: %w", err)
}
}
c.InfoKeys = append(c.InfoKeys, "generated oidc auth method id")
c.Info["generated oidc auth method id"] = c.DevOidcAuthMethodId
switch {
case c.DevUnprivilegedLoginName == "",
c.DevUnprivilegedPassword == "",
c.DevUnprivilegedUserId == "":
default:
c.oidcSetup.createUnpriv = true
}
// Trawl through the listeners and find the api listener so we can use the
// same host name/IP
{
for _, lnConfig := range c.Config.Listeners {
purpose := strings.ToLower(lnConfig.Purpose[0])
if purpose != "api" {
continue
}
c.oidcSetup.hostAddr, c.oidcSetup.callbackPort, err = net.SplitHostPort(lnConfig.Address)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
c.oidcSetup.hostAddr = lnConfig.Address
// Use the default API port in the callback
c.oidcSetup.callbackPort = "9200"
} else {
return fmt.Errorf("error splitting host/port: %w", err)
}
}
}
if c.oidcSetup.hostAddr == "" {
return fmt.Errorf("could not determine address to use for built-in oidc dev listener")
}
}
// Find an available port -- allocate one, then close the listener, and
// re-use it. This is a sort of hacky way to get around the chicken and egg
// of the auth method needing to know the discovery URL and the test
// provider needing to know the callback URL.
l, err := net.Listen("tcp", fmt.Sprintf("%s:0", c.oidcSetup.hostAddr))
if err != nil {
return fmt.Errorf("error finding port for oidc test provider: %w", err)
}
c.oidcSetup.oidcPort = l.(*net.TCPListener).Addr().(*net.TCPAddr).Port
if err := l.Close(); err != nil {
return fmt.Errorf("error closing initial test port: %w", err)
}
c.oidcSetup.callbackUrl, err = url.Parse(fmt.Sprintf("http://%s:%s", c.oidcSetup.hostAddr, c.oidcSetup.callbackPort))
if err != nil {
return fmt.Errorf("error parsing oidc test provider callback url: %w", err)
}
// Generate initial IDs/keys
{
c.oidcSetup.clientId, err = capoidc.NewID()
if err != nil {
return fmt.Errorf("unable to generate client id: %w", err)
}
clientSecret, err := capoidc.NewID()
if err != nil {
return fmt.Errorf("unable to generate client secret: %w", err)
}
c.oidcSetup.clientSecret = oidc.ClientSecret(clientSecret)
c.oidcSetup.pubKey, c.oidcSetup.privKey, err = ed25519.GenerateKey(nil)
if err != nil {
return fmt.Errorf("unable to generate signing key: %w", err)
}
}
// Create the subject information and testing provider
{
logger, err := capoidc.NewTestingLogger(c.Logger.Named("dev-oidc"))
if err != nil {
return fmt.Errorf("unable to create logger: %w", err)
}
subInfo := map[string]*capoidc.TestSubject{
c.DevLoginName: {
Password: c.DevPassword,
UserInfo: map[string]interface{}{
"email": "admin@localhost",
"name": "Admin User",
},
},
}
if c.oidcSetup.createUnpriv {
subInfo[c.DevUnprivilegedLoginName] = &capoidc.TestSubject{
Password: c.DevUnprivilegedPassword,
UserInfo: map[string]interface{}{
"email": "user@localhost",
"name": "Unprivileged User",
},
}
}
clientSecret := string(c.oidcSetup.clientSecret)
c.oidcSetup.testProvider = capoidc.StartTestProvider(
logger,
capoidc.WithNoTLS(),
capoidc.WithTestHost(c.oidcSetup.hostAddr),
capoidc.WithTestPort(c.oidcSetup.oidcPort),
capoidc.WithTestDefaults(&capoidc.TestProviderDefaults{
CustomClaims: map[string]interface{}{
"mode": "dev",
},
SubjectInfo: subInfo,
SigningKey: &capoidc.TestSigningKey{
PrivKey: ed25519.PrivateKey(c.oidcSetup.privKey),
PubKey: ed25519.PublicKey(c.oidcSetup.pubKey),
Alg: capoidc.EdDSA,
},
AllowedRedirectURIs: []string{fmt.Sprintf("%s/v1/auth-methods/oidc:authenticate:callback", c.oidcSetup.callbackUrl.String())},
ClientID: &c.oidcSetup.clientId,
ClientSecret: &clientSecret,
}))
c.ShutdownFuncs = append(c.ShutdownFuncs, func() error {
c.oidcSetup.testProvider.Stop()
return nil
})
}
// Create auth method and link accounts
{
c.oidcSetup.authMethod, err = c.createInitialOidcAuthMethod()
if err != nil {
return fmt.Errorf("error creating initial oidc auth method: %w", err)
}
}
return nil
}
func (c *Command) createInitialOidcAuthMethod() (*oidc.AuthMethod, error) {
rw := db.New(c.Database)
kmsRepo, err := kms.NewRepository(rw, rw)
if err != nil {
return nil, fmt.Errorf("error creating kms repository: %w", err)
}
kmsCache, err := kms.NewKms(kmsRepo, kms.WithLogger(c.Logger.Named("kms")))
if err != nil {
return nil, fmt.Errorf("error creating kms cache: %w", err)
}
if err := kmsCache.AddExternalWrappers(
kms.WithRootWrapper(c.RootKms),
); err != nil {
return nil, fmt.Errorf("error adding config keys to kms: %w", err)
}
discoveryUrl, err := url.Parse(fmt.Sprintf("http://%s:%d", c.oidcSetup.hostAddr, c.oidcSetup.oidcPort))
if err != nil {
return nil, fmt.Errorf("error parsing oidc test provider address: %w", err)
}
// Create the auth method
oidcRepo, err := oidc.NewRepository(rw, rw, kmsCache)
if err != nil {
return nil, fmt.Errorf("error creating oidc repo: %w", err)
}
authMethod, err := oidc.NewAuthMethod(
scope.Global.String(),
c.oidcSetup.clientId,
c.oidcSetup.clientSecret,
oidc.WithName("Generated global scope initial oidc auth method"),
oidc.WithDescription("Provides initial administrative and unprivileged authentication into Boundary"),
oidc.WithIssuer(discoveryUrl),
oidc.WithApiUrl(c.oidcSetup.callbackUrl),
oidc.WithSigningAlgs(oidc.EdDSA),
oidc.WithOperationalState(oidc.ActivePublicState))
if err != nil {
return nil, fmt.Errorf("error creating new in memory oidc auth method: %w", err)
}
if c.DevOidcAuthMethodId == "" {
c.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix)
if err != nil {
return nil, fmt.Errorf("error generating initial oidc auth method id: %w", err)
}
}
cancelCtx, cancel := context.WithCancel(c.Context)
defer cancel()
go func() {
select {
case <-c.ShutdownCh:
cancel()
case <-cancelCtx.Done():
}
}()
c.oidcSetup.authMethod, err = oidcRepo.CreateAuthMethod(
cancelCtx,
authMethod,
oidc.WithPublicId(c.DevOidcAuthMethodId))
if err != nil {
return nil, fmt.Errorf("error saving oidc auth method to the db: %w", err)
}
// Create accounts
{
createAndLinkAccount := func(loginName, userId, typ string) error {
acct, err := oidc.NewAccount(
c.oidcSetup.authMethod.GetPublicId(),
loginName,
oidc.WithDescription(fmt.Sprintf("Initial %s OIDC account", typ)),
)
if err != nil {
return fmt.Errorf("error generating %s oidc account: %w", typ, err)
}
acct, err = oidcRepo.CreateAccount(
cancelCtx,
c.oidcSetup.authMethod.GetScopeId(),
acct,
)
if err != nil {
return fmt.Errorf("error creating %s oidc account: %w", typ, err)
}
// Link accounts to existing user
iamRepo, err := iam.NewRepository(rw, rw, kmsCache)
if err != nil {
return fmt.Errorf("unable to create iam repo: %w", err)
}
u, _, err := iamRepo.LookupUser(cancelCtx, userId)
if err != nil {
return fmt.Errorf("error looking up %s user: %w", typ, err)
}
if _, err = iamRepo.AddUserAccounts(cancelCtx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil {
return fmt.Errorf("error associating initial %s user with account: %w", typ, err)
}
return nil
}
if err := createAndLinkAccount(c.DevLoginName, c.DevUserId, "admin"); err != nil {
return nil, err
}
if c.oidcSetup.createUnpriv {
if err := createAndLinkAccount(c.DevUnprivilegedLoginName, c.DevUnprivilegedUserId, "unprivileged"); err != nil {
return nil, err
}
}
}
return nil, nil
}

@ -0,0 +1,60 @@
begin;
-- The base abstract table
create table auth_managed_group (
public_id wt_public_id
primary key,
auth_method_id wt_public_id
not null,
-- Ensure that if the auth method is deleted (which will also happen if the
-- scope is deleted) this is deleted too
constraint auth_method_fkey
foreign key (auth_method_id) -- fk1
references auth_method(public_id)
on delete cascade
on update cascade,
constraint auth_managed_group_auth_method_id_public_id_uq
unique(auth_method_id, public_id)
);
comment on table auth_managed_group is
'auth_managed_group is the abstract base table for managed groups.';
-- Define the immutable fields of auth_managed_group
create trigger
immutable_columns
before
update on auth_managed_group
for each row execute procedure immutable_columns('public_id', 'auth_method_id');
-- Function to insert into the base table when values are inserted into a
-- concrete type table. This happens before inserts so the foreign keys in the
-- concrete type will be valid.
create or replace function
insert_managed_group_subtype()
returns trigger
as $$
begin
insert into auth_managed_group
(public_id, auth_method_id)
values
(new.public_id, new.auth_method_id);
return new;
end;
$$ language plpgsql;
-- delete_managed_group_subtype() is an after delete trigger
-- function for subtypes of managed_group
create or replace function delete_managed_group_subtype()
returns trigger
as $$
begin
delete from auth_managed_group
where public_id = old.public_id;
return null; -- result is ignored since this is an after trigger
end;
$$ language plpgsql;
commit;

@ -0,0 +1,83 @@
begin;
create table auth_oidc_managed_group (
public_id wt_public_id
primary key,
auth_method_id wt_public_id
not null,
name wt_name,
description wt_description,
create_time wt_timestamp,
update_time wt_timestamp,
version wt_version,
filter wt_bexprfilter
not null,
-- Ensure that this managed group relates to an oidc auth method, as opposed
-- to other types
constraint auth_oidc_method_fkey
foreign key (auth_method_id) -- fk1
references auth_oidc_method (public_id)
on delete cascade
on update cascade,
-- Ensure it relates to an abstract managed group
constraint auth_managed_group_fkey
foreign key (auth_method_id, public_id) -- fk2
references auth_managed_group (auth_method_id, public_id)
on delete cascade
on update cascade,
constraint auth_oidc_managed_group_auth_method_id_name_uq
unique(auth_method_id, name)
);
comment on table auth_oidc_managed_group is
'auth_oidc_managed_group entries are subtypes of auth_managed_group and represent an oidc managed group.';
-- Define the immutable fields of auth_oidc_managed_group
create trigger
immutable_columns
before
update on auth_oidc_managed_group
for each row execute procedure immutable_columns('public_id', 'auth_method_id', 'create_time');
-- Populate create time on insert
create trigger
default_create_time_column
before
insert on auth_oidc_managed_group
for each row execute procedure default_create_time();
-- Generate update time on update
create trigger
update_time_column
before
update on auth_oidc_managed_group
for each row execute procedure update_time_column();
-- Update version when something changes
create trigger
update_version_column
after
update on auth_oidc_managed_group
for each row execute procedure update_version_column();
-- Add into the base table when inserting into the concrete table
create trigger
insert_managed_group_subtype
before insert on auth_oidc_managed_group
for each row execute procedure insert_managed_group_subtype();
-- Ensure that deletions in the oidc subtype result in deletions to the base
-- table.
create trigger
delete_managed_group_subtype
after
delete on auth_oidc_managed_group
for each row execute procedure delete_managed_group_subtype();
-- The tickets for oplog are the subtypes not the base types because no updates
-- are done to any values in the base types.
insert into oplog_ticket
(name, version)
values
('auth_oidc_managed_group', 1);
commit;

@ -0,0 +1,52 @@
begin;
-- Mappings of account to oidc managed groups. This is a non-abstract table with
-- a view (below) so that it is a natural aggregate for the oplog (also below).
create table auth_oidc_managed_group_member_account (
create_time wt_timestamp,
managed_group_id wt_public_id
references auth_oidc_managed_group(public_id)
on delete cascade
on update cascade,
member_id wt_public_id
references auth_oidc_account(public_id)
on delete cascade
on update cascade,
primary key (managed_group_id, member_id)
);
comment on table auth_oidc_managed_group_member_account is
'auth_oidc_managed_group_member_account is the join table for managed oidc groups and accounts.';
-- auth_immutable_managed_oidc_group_member_account() ensures that group members are immutable.
create or replace function
auth_immutable_managed_oidc_group_member_account()
returns trigger
as $$
begin
raise exception 'managed oidc group members are immutable';
end;
$$ language plpgsql;
create trigger
default_create_time_column
before
insert on auth_oidc_managed_group_member_account
for each row execute procedure default_create_time();
create trigger
auth_immutable_managed_oidc_group_member_account
before
update on auth_oidc_managed_group_member_account
for each row execute procedure auth_immutable_managed_oidc_group_member_account();
-- Initially create the view with just oidc; eventually we can replace this view
-- to union with other subtype tables.
create view auth_managed_group_member_account as
select
oidc.create_time,
oidc.managed_group_id,
oidc.member_id
from
auth_oidc_managed_group_member_account oidc;
commit;

@ -0,0 +1,376 @@
package migration
import (
"fmt"
"testing"
"time"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_ManagedGroupTable(t *testing.T) {
t.Parallel()
tc := controller.NewTestController(t, nil)
defer tc.Shutdown()
db := tc.DbConn().DB()
var err error
managedGroupId := "a_bcdefghijk"
defaultPasswordAuthMethodId := "ampw_1234567890"
defaultOidcAuthMethodId := "amoidc_1234567890"
insertTests := []struct {
testName string
publicId string
authMethodId string
wantErr bool
}{
{
testName: "invalid auth method",
publicId: managedGroupId,
authMethodId: "amoid_1234567890",
wantErr: true,
},
{
testName: "valid",
publicId: managedGroupId,
authMethodId: defaultOidcAuthMethodId,
wantErr: false,
},
}
for _, tt := range insertTests {
t.Run("insert: "+tt.testName, func(t *testing.T) {
require := require.New(t)
_, err = db.Exec("insert into auth_managed_group values ($1, $2)",
tt.publicId,
tt.authMethodId)
require.True(tt.wantErr == (err != nil))
})
}
updateTests := []struct {
testName string
column string
value string
publicId string
wantErr bool
}{
{
testName: "immutable public id",
column: "public_id",
value: "z_yxwvutsrqp",
publicId: managedGroupId,
wantErr: true,
},
{
testName: "immutable auth method",
column: "auth_method_id",
value: defaultPasswordAuthMethodId,
publicId: managedGroupId,
wantErr: true,
},
}
for _, tt := range updateTests {
t.Run("update: "+tt.testName, func(t *testing.T) {
assert := assert.New(t)
_, err = db.Exec(fmt.Sprintf("update auth_managed_group set %s = $1 where public_id = $2", tt.column), tt.value, tt.publicId)
assert.True(tt.wantErr == (err != nil))
})
}
}
func Test_OidcManagedGroupTable(t *testing.T) {
t.Parallel()
tc := controller.NewTestController(t, nil)
defer tc.Shutdown()
db := tc.DbConn().DB()
var err error
managedGroupId := "a_bcdefghijk"
defaultPasswordAuthMethodId := "ampw_1234567890"
defaultOidcAuthMethodId := "amoidc_1234567890"
name := "this is the name"
filter := "this is a filter"
// The first set of tests is for initial insertion
{
insertTests := []struct {
testName string
publicId string
authMethodId string
name string
filter string
wantErr bool
}{
{
testName: "null filter",
publicId: managedGroupId,
authMethodId: "amoid_1234567890",
name: name,
filter: "",
wantErr: true,
},
{
testName: "invalid auth method",
publicId: managedGroupId,
authMethodId: defaultPasswordAuthMethodId,
name: name,
filter: filter,
wantErr: true,
},
{
testName: "valid",
publicId: managedGroupId,
authMethodId: defaultOidcAuthMethodId,
name: name,
filter: filter,
wantErr: false,
},
{
testName: "duplicate public id",
publicId: managedGroupId,
authMethodId: defaultOidcAuthMethodId,
name: name,
filter: filter,
wantErr: true,
},
{
testName: "duplicate name",
publicId: "z_yxwvutsrqp",
authMethodId: defaultOidcAuthMethodId,
name: name,
filter: filter,
wantErr: true,
},
}
for _, tt := range insertTests {
t.Run("insert: "+tt.testName, func(t *testing.T) {
require := require.New(t)
_, err = db.Exec("insert into auth_oidc_managed_group (public_id, auth_method_id, name, filter) values ($1, $2, $3, $4)",
tt.publicId,
tt.authMethodId,
tt.name,
tt.filter)
require.True(tt.wantErr == (err != nil))
})
}
}
// Read some values to validate that things were set automatically
rows, err := db.Query("select create_time, update_time, version from auth_oidc_managed_group")
require.NoError(t, err)
require.True(t, rows.Next())
var create_time, update_time time.Time
var version int
require.NoError(t, rows.Scan(&create_time, &update_time, &version))
assert.False(t, create_time.IsZero())
assert.Equal(t, update_time, create_time)
assert.Equal(t, 1, version)
// These update tests check immutability
{
updateTests := []struct {
testName string
column string
value interface{}
wantErr bool
}{
{
testName: "immutable public id",
column: "public_id",
value: "z_yxwvutsrqp",
wantErr: true,
},
{
testName: "immutable auth method",
column: "auth_method_id",
value: defaultPasswordAuthMethodId,
wantErr: true,
},
{
testName: "immutable creation time",
column: "create_time",
value: time.Now(),
wantErr: true,
},
{
testName: "valid",
column: "description",
value: "this is the description",
wantErr: false,
},
}
for _, tt := range updateTests {
t.Run("update: "+tt.testName, func(t *testing.T) {
require := require.New(t)
_, err = db.Exec(fmt.Sprintf("update auth_oidc_managed_group set %s = $1 where public_id = $2", tt.column), tt.value, managedGroupId)
require.True(tt.wantErr == (err != nil))
})
}
}
// Read values again to validate that things were updated automatically
rows, err = db.Query("select create_time, update_time, version from auth_oidc_managed_group")
require.NoError(t, err)
require.True(t, rows.Next())
var updated_create_time, updated_update_time time.Time
require.NoError(t, rows.Scan(&updated_create_time, &updated_update_time, &version))
assert.Equal(t, create_time, updated_create_time)
assert.NotEqual(t, update_time, updated_update_time)
assert.Equal(t, 2, version)
// Read values from auth_managed_group to ensure it was populated automatically
rows, err = db.Query("select public_id, auth_method_id from auth_managed_group")
require.NoError(t, err)
require.True(t, rows.Next())
var public_id, auth_method_id string
require.NoError(t, rows.Scan(&public_id, &auth_method_id))
assert.Equal(t, managedGroupId, public_id)
assert.Equal(t, defaultOidcAuthMethodId, auth_method_id)
// Delete the value from the subtype table
res, err := db.Exec("delete from auth_oidc_managed_group where public_id = $1", managedGroupId)
require.NoError(t, err)
affected, err := res.RowsAffected()
require.NoError(t, err)
require.EqualValues(t, 1, affected)
// It should no longer be in the base table
rows, err = db.Query("select public_id, auth_method_id from auth_managed_group")
require.NoError(t, err)
require.False(t, rows.Next())
}
func Test_AuthManagedOidcGroupMemberAccountTable(t *testing.T) {
t.Parallel()
tc := controller.NewTestController(t, nil)
defer tc.Shutdown()
db := tc.DbConn().DB()
var err error
managedGroupId := "a_bcdefghijk"
defaultOidcAuthMethodId := "amoidc_1234567890"
name := "this is the name"
filter := "this is a filter"
// Insert valid data in auth_oidc_managed_group to use for the following tests
_, err = db.Exec("insert into auth_oidc_managed_group (public_id, auth_method_id, name, filter) values ($1, $2, $3, $4)",
managedGroupId,
defaultOidcAuthMethodId,
name,
filter)
require.NoError(t, err)
// Fetch a valid (oidc) account ID to use in insertion
rows, err := db.Query("select public_id from auth_oidc_account limit 1")
require.NoError(t, err)
require.True(t, rows.Next())
var accountId string
require.NoError(t, rows.Scan(&accountId))
require.NotEmpty(t, accountId)
// The first set of tests is for initial insertion
{
insertTests := []struct {
testName string
managedGroupId string
memberId string
wantErr bool
}{
{
testName: "invalid managed group id",
managedGroupId: "z_yxwvutsrqp",
memberId: accountId,
wantErr: true,
},
{
testName: "invalid member id",
managedGroupId: managedGroupId,
memberId: "acct_1234567890",
wantErr: true,
},
{
testName: "valid",
managedGroupId: managedGroupId,
memberId: accountId,
wantErr: false,
},
{
testName: "duplicate values",
managedGroupId: managedGroupId,
memberId: accountId,
wantErr: true,
},
}
for _, tt := range insertTests {
t.Run("insert: "+tt.testName, func(t *testing.T) {
assert := assert.New(t)
_, err = db.Exec("insert into auth_oidc_managed_group_member_account (managed_group_id, member_id) values ($1, $2)",
tt.managedGroupId,
tt.memberId)
assert.True(tt.wantErr == (err != nil))
})
}
}
// Read some values to validate that things were set automatically
rows, err = db.Query("select create_time, managed_group_id, member_id from auth_oidc_managed_group_member_account")
require.NoError(t, err)
require.True(t, rows.Next())
var create_time time.Time
var managed_group_id, member_id string
require.NoError(t, rows.Scan(&create_time, &managed_group_id, &member_id))
assert.False(t, create_time.IsZero())
// These update tests check immutability
{
updateTests := []struct {
testName string
column string
value interface{}
wantErr bool
}{
{
testName: "immutable managed group id",
column: "managed_group_id",
value: "z_yxwvutsrqp",
wantErr: true,
},
{
testName: "immutable member_id",
column: "member_id",
value: "acct_1234567890",
wantErr: true,
},
{
testName: "immutable creation time",
column: "create_time",
value: time.Now(),
wantErr: true,
},
}
for _, tt := range updateTests {
t.Run("update: "+tt.testName, func(t *testing.T) {
assert := assert.New(t)
_, err = db.Exec(fmt.Sprintf("update auth_managed_group_member_account set %s = $1 where managed_group_id = $2 and member_id = $3", tt.column), managedGroupId, accountId)
assert.True(tt.wantErr == (err != nil))
})
}
}
// Read from the view to ensure we see it there
rows, err = db.Query("select create_time, managed_group_id, member_id from auth_managed_group_member_account")
require.NoError(t, err)
require.True(t, rows.Next())
var view_create_time time.Time
var view_managed_group_id, view_member_id string
require.NoError(t, rows.Scan(&view_create_time, &view_managed_group_id, &view_member_id))
assert.Equal(t, create_time, view_create_time)
assert.Equal(t, managed_group_id, view_managed_group_id)
assert.Equal(t, member_id, view_member_id)
}

@ -4,7 +4,7 @@ package schema
func init() {
migrationStates["postgres"] = migrationState{
binarySchemaVersion: 8001,
binarySchemaVersion: 9003,
upMigrations: map[int][]byte{
1: []byte(`
create domain wt_public_id as text
@ -6373,6 +6373,195 @@ from
session s
where
sc.session_id = s.public_id;
`),
9001: []byte(`
-- The base abstract table
create table auth_managed_group (
public_id wt_public_id
primary key,
auth_method_id wt_public_id
not null,
-- Ensure that if the auth method is deleted (which will also happen if the
-- scope is deleted) this is deleted too
constraint auth_method_fkey
foreign key (auth_method_id) -- fk1
references auth_method(public_id)
on delete cascade
on update cascade,
constraint auth_managed_group_auth_method_id_public_id_uq
unique(auth_method_id, public_id)
);
comment on table auth_managed_group is
'auth_managed_group is the abstract base table for managed groups.';
-- Define the immutable fields of auth_managed_group
create trigger
immutable_columns
before
update on auth_managed_group
for each row execute procedure immutable_columns('public_id', 'auth_method_id');
-- Function to insert into the base table when values are inserted into a
-- concrete type table. This happens before inserts so the foreign keys in the
-- concrete type will be valid.
create or replace function
insert_managed_group_subtype()
returns trigger
as $$
begin
insert into auth_managed_group
(public_id, auth_method_id)
values
(new.public_id, new.auth_method_id);
return new;
end;
$$ language plpgsql;
-- delete_managed_group_subtype() is an after delete trigger
-- function for subtypes of managed_group
create or replace function delete_managed_group_subtype()
returns trigger
as $$
begin
delete from auth_managed_group
where public_id = old.public_id;
return null; -- result is ignored since this is an after trigger
end;
$$ language plpgsql;
`),
9002: []byte(`
create table auth_oidc_managed_group (
public_id wt_public_id
primary key,
auth_method_id wt_public_id
not null,
name wt_name,
description wt_description,
create_time wt_timestamp,
update_time wt_timestamp,
version wt_version,
filter wt_bexprfilter
not null,
-- Ensure that this managed group relates to an oidc auth method, as opposed
-- to other types
constraint auth_oidc_method_fkey
foreign key (auth_method_id) -- fk1
references auth_oidc_method (public_id)
on delete cascade
on update cascade,
-- Ensure it relates to an abstract managed group
constraint auth_managed_group_fkey
foreign key (auth_method_id, public_id) -- fk2
references auth_managed_group (auth_method_id, public_id)
on delete cascade
on update cascade,
constraint auth_oidc_managed_group_auth_method_id_name_uq
unique(auth_method_id, name)
);
comment on table auth_oidc_managed_group is
'auth_oidc_managed_group entries are subtypes of auth_managed_group and represent an oidc managed group.';
-- Define the immutable fields of auth_oidc_managed_group
create trigger
immutable_columns
before
update on auth_oidc_managed_group
for each row execute procedure immutable_columns('public_id', 'auth_method_id', 'create_time');
-- Populate create time on insert
create trigger
default_create_time_column
before
insert on auth_oidc_managed_group
for each row execute procedure default_create_time();
-- Generate update time on update
create trigger
update_time_column
before
update on auth_oidc_managed_group
for each row execute procedure update_time_column();
-- Update version when something changes
create trigger
update_version_column
after
update on auth_oidc_managed_group
for each row execute procedure update_version_column();
-- Add into the base table when inserting into the concrete table
create trigger
insert_managed_group_subtype
before insert on auth_oidc_managed_group
for each row execute procedure insert_managed_group_subtype();
-- Ensure that deletions in the oidc subtype result in deletions to the base
-- table.
create trigger
delete_managed_group_subtype
after
delete on auth_oidc_managed_group
for each row execute procedure delete_managed_group_subtype();
-- The tickets for oplog are the subtypes not the base types because no updates
-- are done to any values in the base types.
insert into oplog_ticket
(name, version)
values
('auth_oidc_managed_group', 1);
`),
9003: []byte(`
-- Mappings of account to oidc managed groups. This is a non-abstract table with
-- a view (below) so that it is a natural aggregate for the oplog (also below).
create table auth_oidc_managed_group_member_account (
create_time wt_timestamp,
managed_group_id wt_public_id
references auth_oidc_managed_group(public_id)
on delete cascade
on update cascade,
member_id wt_public_id
references auth_oidc_account(public_id)
on delete cascade
on update cascade,
primary key (managed_group_id, member_id)
);
comment on table auth_oidc_managed_group_member_account is
'auth_oidc_managed_group_member_account is the join table for managed oidc groups and accounts.';
-- auth_immutable_managed_oidc_group_member_account() ensures that group members are immutable.
create or replace function
auth_immutable_managed_oidc_group_member_account()
returns trigger
as $$
begin
raise exception 'managed oidc group members are immutable';
end;
$$ language plpgsql;
create trigger
default_create_time_column
before
insert on auth_oidc_managed_group_member_account
for each row execute procedure default_create_time();
create trigger
auth_immutable_managed_oidc_group_member_account
before
update on auth_oidc_managed_group_member_account
for each row execute procedure auth_immutable_managed_oidc_group_member_account();
-- Initially create the view with just oidc; eventually we can replace this view
-- to union with other subtype tables.
create view auth_managed_group_member_account as
select
oidc.create_time,
oidc.managed_group_id,
oidc.member_id
from
auth_oidc_managed_group_member_account oidc;
`),
},
}

@ -20,7 +20,7 @@ import (
func TestAuthenticationHandler(t *testing.T) {
c := NewTestController(t, &TestControllerOpts{
DisableAuthorizationFailures: true,
DefaultAuthMethodId: "ampw_1234567890",
DefaultPasswordAuthMethodId: "ampw_1234567890",
DefaultLoginName: "admin",
DefaultPassword: "password123",
})

@ -27,7 +27,8 @@ import (
)
const (
DefaultTestAuthMethodId = "ampw_1234567890"
DefaultTestPasswordAuthMethodId = "ampw_1234567890"
DefaultTestOidcAuthMethodId = "amoidc_1234567890"
DefaultTestLoginName = "admin"
DefaultTestUnprivilegedLoginName = "user"
DefaultTestPassword = "passpass"
@ -280,8 +281,11 @@ type TestControllerOpts struct {
// set.
Config *config.Config
// DefaultAuthMethodId is the default auth method ID to use, if set.
DefaultAuthMethodId string
// DefaultPasswordAuthMethodId is the default password method ID to use, if set.
DefaultPasswordAuthMethodId string
// DefaultOidcAuthMethodId is the default OIDC method ID to use, if set.
DefaultOidcAuthMethodId string
// DefaultLoginName is the login name used when creating the default admin account.
DefaultLoginName string
@ -394,10 +398,15 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
opts.Config.Controller.Name = opts.Name
}
if opts.DefaultAuthMethodId != "" {
tc.b.DevPasswordAuthMethodId = opts.DefaultAuthMethodId
if opts.DefaultPasswordAuthMethodId != "" {
tc.b.DevPasswordAuthMethodId = opts.DefaultPasswordAuthMethodId
} else {
tc.b.DevPasswordAuthMethodId = DefaultTestPasswordAuthMethodId
}
if opts.DefaultOidcAuthMethodId != "" {
tc.b.DevOidcAuthMethodId = opts.DefaultOidcAuthMethodId
} else {
tc.b.DevPasswordAuthMethodId = DefaultTestAuthMethodId
tc.b.DevOidcAuthMethodId = DefaultTestOidcAuthMethodId
}
if opts.DefaultLoginName != "" {
tc.b.DevLoginName = opts.DefaultLoginName
@ -440,6 +449,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
if opts.InitialResourcesSuffix != "" {
suffix := opts.InitialResourcesSuffix
tc.b.DevPasswordAuthMethodId = "ampw_" + suffix
tc.b.DevOidcAuthMethodId = "amoidc_" + suffix
tc.b.DevHostCatalogId = "hcst_" + suffix
tc.b.DevHostId = "hst_" + suffix
tc.b.DevHostSetId = "hsst_" + suffix
@ -500,6 +510,9 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
if _, _, err := tc.b.CreateInitialPasswordAuthMethod(ctx); err != nil {
t.Fatal(err)
}
if err := tc.b.CreateDevOidcAuthMethod(ctx); err != nil {
t.Fatal(err)
}
if !opts.DisableScopesCreation {
if _, _, err := tc.b.CreateInitialScopes(ctx); err != nil {
t.Fatal(err)
@ -557,17 +570,18 @@ func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestCon
opts = new(TestControllerOpts)
}
nextOpts := &TestControllerOpts{
DatabaseUrl: tc.c.conf.DatabaseUrl,
DefaultAuthMethodId: tc.c.conf.DevPasswordAuthMethodId,
RootKms: tc.c.conf.RootKms,
WorkerAuthKms: tc.c.conf.WorkerAuthKms,
RecoveryKms: tc.c.conf.RecoveryKms,
Name: opts.Name,
Logger: tc.c.conf.Logger,
DefaultLoginName: tc.b.DevLoginName,
DefaultPassword: tc.b.DevPassword,
DisableKmsKeyCreation: true,
DisableAuthMethodCreation: true,
DatabaseUrl: tc.c.conf.DatabaseUrl,
DefaultPasswordAuthMethodId: tc.c.conf.DevPasswordAuthMethodId,
DefaultOidcAuthMethodId: tc.c.conf.DevOidcAuthMethodId,
RootKms: tc.c.conf.RootKms,
WorkerAuthKms: tc.c.conf.WorkerAuthKms,
RecoveryKms: tc.c.conf.RecoveryKms,
Name: opts.Name,
Logger: tc.c.conf.Logger,
DefaultLoginName: tc.b.DevLoginName,
DefaultPassword: tc.b.DevPassword,
DisableKmsKeyCreation: true,
DisableAuthMethodCreation: true,
}
if opts.Logger != nil {
nextOpts.Logger = opts.Logger

@ -22,7 +22,7 @@ func getOpts(opt ...Option) (*controller.TestControllerOpts, error) {
return nil, fmt.Errorf("Cannot provide both WithConfigFile and WithConfigText")
}
var setDbParams bool
if opts.setDefaultAuthMethodId || opts.setDefaultLoginName || opts.setDefaultPassword {
if opts.setDefaultPasswordAuthMethodId || opts.setDefaultOidcAuthMethodId || opts.setDefaultLoginName || opts.setDefaultPassword {
setDbParams = true
}
if opts.setDisableAuthMethodCreation {
@ -39,19 +39,20 @@ func getOpts(opt ...Option) (*controller.TestControllerOpts, error) {
}
type option struct {
tcOptions *controller.TestControllerOpts
setWithConfigFile bool
setWithConfigText bool
setDisableAuthMethodCreation bool
setDisableDatabaseCreation bool
setDisableDatabaseDestruction bool
setDefaultAuthMethodId bool
setDefaultLoginName bool
setDefaultPassword bool
setRootKms bool
setWorkerAuthKms bool
setRecoveryKms bool
setDatabaseUrl bool
tcOptions *controller.TestControllerOpts
setWithConfigFile bool
setWithConfigText bool
setDisableAuthMethodCreation bool
setDisableDatabaseCreation bool
setDisableDatabaseDestruction bool
setDefaultPasswordAuthMethodId bool
setDefaultOidcAuthMethodId bool
setDefaultLoginName bool
setDefaultPassword bool
setRootKms bool
setWorkerAuthKms bool
setRecoveryKms bool
setDatabaseUrl bool
}
type Option func(*option) error
@ -119,10 +120,18 @@ func DisableDatabaseDestruction() Option {
}
}
func WithDefaultAuthMethodId(id string) Option {
func WithDefaultPasswordAuthMethodId(id string) Option {
return func(c *option) error {
c.setDefaultAuthMethodId = true
c.tcOptions.DefaultAuthMethodId = id
c.setDefaultPasswordAuthMethodId = true
c.tcOptions.DefaultPasswordAuthMethodId = id
return nil
}
}
func WithDefaultOidcAuthMethodId(id string) Option {
return func(c *option) error {
c.setDefaultOidcAuthMethodId = true
c.tcOptions.DefaultOidcAuthMethodId = id
return nil
}
}

Loading…
Cancel
Save