mirror of https://github.com/hashicorp/boundary
Add initial SQL for managed groups (#1253)
parent
ed5e34082c
commit
15bd1a5245
@ -0,0 +1,437 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/auth/oidc"
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/db/schema"
|
||||
"github.com/hashicorp/boundary/internal/docker"
|
||||
"github.com/hashicorp/boundary/internal/iam"
|
||||
"github.com/hashicorp/boundary/internal/kms"
|
||||
"github.com/hashicorp/boundary/internal/types/scope"
|
||||
capoidc "github.com/hashicorp/cap/oidc"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
func (b *Server) CreateDevDatabase(ctx context.Context, opt ...Option) error {
|
||||
var container, url, dialect string
|
||||
var err error
|
||||
var c func() error
|
||||
|
||||
opts := getOpts(opt...)
|
||||
|
||||
// We should only get back postgres for now, but laying the foundation for non-postgres
|
||||
switch opts.withDialect {
|
||||
case "":
|
||||
b.Logger.Error("unsupported dialect. wanted: postgres, got: %v", opts.withDialect)
|
||||
default:
|
||||
dialect = opts.withDialect
|
||||
}
|
||||
|
||||
switch b.DatabaseUrl {
|
||||
case "":
|
||||
c, url, container, err = docker.StartDbInDocker(dialect, docker.WithContainerImage(opts.withContainerImage))
|
||||
// In case of an error, run the cleanup function. If we pass all errors, c should be set to a noop
|
||||
// function before returning from this method
|
||||
defer func() {
|
||||
if !opts.withSkipDatabaseDestruction {
|
||||
if c != nil {
|
||||
if err := c(); err != nil {
|
||||
b.Logger.Error("error cleaning up docker container", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
if err == docker.ErrDockerUnsupported {
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to start dev database with dialect %s: %w", dialect, err)
|
||||
}
|
||||
|
||||
// Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways.
|
||||
_, err := schema.MigrateStore(ctx, dialect, url)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to initialize dev database with dialect %s: %w", dialect, err)
|
||||
if c != nil {
|
||||
err = multierror.Append(err, c())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
b.DevDatabaseCleanupFunc = c
|
||||
b.DatabaseUrl = url
|
||||
default:
|
||||
// Let migrate store manage the dirty bit since dev DBs should be ephemeral anyways.
|
||||
if _, err := schema.MigrateStore(ctx, dialect, b.DatabaseUrl); err != nil {
|
||||
err = fmt.Errorf("error initializing store: %w", err)
|
||||
if c != nil {
|
||||
err = multierror.Append(err, c())
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
b.InfoKeys = append(b.InfoKeys, "dev database url")
|
||||
b.Info["dev database url"] = b.DatabaseUrl
|
||||
if container != "" {
|
||||
b.InfoKeys = append(b.InfoKeys, "dev database container")
|
||||
b.Info["dev database container"] = strings.TrimPrefix(container, "/")
|
||||
}
|
||||
|
||||
if err := b.ConnectToDatabase(dialect); err != nil {
|
||||
if c != nil {
|
||||
err = multierror.Append(err, c())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
b.Database.LogMode(true)
|
||||
|
||||
if err := b.CreateGlobalKmsKeys(ctx); err != nil {
|
||||
if c != nil {
|
||||
err = multierror.Append(err, c())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := b.CreateInitialLoginRole(ctx); err != nil {
|
||||
if c != nil {
|
||||
err = multierror.Append(err, c())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if opts.withSkipAuthMethodCreation {
|
||||
// now that we have passed all the error cases, reset c to be a noop so the
|
||||
// defer doesn't do anything.
|
||||
c = func() error { return nil }
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, _, err := b.CreateInitialPasswordAuthMethod(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := b.CreateDevOidcAuthMethod(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if opts.withSkipScopesCreation {
|
||||
// now that we have passed all the error cases, reset c to be a noop so the
|
||||
// defer doesn't do anything.
|
||||
c = func() error { return nil }
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, _, err := b.CreateInitialScopes(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if opts.withSkipHostResourcesCreation {
|
||||
// now that we have passed all the error cases, reset c to be a noop so the
|
||||
// defer doesn't do anything.
|
||||
c = func() error { return nil }
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, _, _, err := b.CreateInitialHostResources(context.Background()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if opts.withSkipTargetCreation {
|
||||
// now that we have passed all the error cases, reset c to be a noop so the
|
||||
// defer doesn't do anything.
|
||||
c = func() error { return nil }
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := b.CreateInitialTarget(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// now that we have passed all the error cases, reset c to be a noop so the
|
||||
// defer doesn't do anything.
|
||||
c = func() error { return nil }
|
||||
return nil
|
||||
}
|
||||
|
||||
type oidcSetup struct {
|
||||
clientId string
|
||||
clientSecret oidc.ClientSecret
|
||||
oidcPort int
|
||||
callbackPort string
|
||||
hostAddr string
|
||||
authMethod *oidc.AuthMethod
|
||||
pubKey []byte
|
||||
privKey []byte
|
||||
testProvider *capoidc.TestProvider
|
||||
createUnpriv bool
|
||||
callbackUrl *url.URL
|
||||
}
|
||||
|
||||
func (b *Server) CreateDevOidcAuthMethod(ctx context.Context) error {
|
||||
var err error
|
||||
|
||||
if b.DevOidcAuthMethodId == "" {
|
||||
b.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating initial oidc auth method id: %w", err)
|
||||
}
|
||||
}
|
||||
b.InfoKeys = append(b.InfoKeys, "generated oidc auth method id")
|
||||
b.Info["generated oidc auth method id"] = b.DevOidcAuthMethodId
|
||||
|
||||
switch {
|
||||
case b.DevUnprivilegedLoginName == "",
|
||||
b.DevUnprivilegedPassword == "",
|
||||
b.DevUnprivilegedUserId == "":
|
||||
|
||||
default:
|
||||
b.DevOidcSetup.createUnpriv = true
|
||||
}
|
||||
|
||||
// Trawl through the listeners and find the api listener so we can use the
|
||||
// same host name/IP
|
||||
{
|
||||
for _, ln := range b.Listeners {
|
||||
purpose := strings.ToLower(ln.Config.Purpose[0])
|
||||
if purpose != "api" {
|
||||
continue
|
||||
}
|
||||
b.DevOidcSetup.hostAddr, b.DevOidcSetup.callbackPort, err = net.SplitHostPort(ln.Config.Address)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "missing port") {
|
||||
b.DevOidcSetup.hostAddr = ln.Config.Address
|
||||
// Use the default API port in the callback
|
||||
b.DevOidcSetup.callbackPort = "9200"
|
||||
} else {
|
||||
return fmt.Errorf("error splitting host/port: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if b.DevOidcSetup.hostAddr == "" {
|
||||
return fmt.Errorf("could not determine address to use for built-in oidc dev listener")
|
||||
}
|
||||
}
|
||||
|
||||
// Find an available port -- allocate one, then close the listener, and
|
||||
// re-use it. This is a sort of hacky way to get around the chicken and egg
|
||||
// of the auth method needing to know the discovery URL and the test
|
||||
// provider needing to know the callback URL.
|
||||
l, err := net.Listen("tcp", fmt.Sprintf("%s:0", b.DevOidcSetup.hostAddr))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error finding port for oidc test provider: %w", err)
|
||||
}
|
||||
b.DevOidcSetup.oidcPort = l.(*net.TCPListener).Addr().(*net.TCPAddr).Port
|
||||
if err := l.Close(); err != nil {
|
||||
return fmt.Errorf("error closing initial test port: %w", err)
|
||||
}
|
||||
b.DevOidcSetup.callbackUrl, err = url.Parse(fmt.Sprintf("http://%s:%s", b.DevOidcSetup.hostAddr, b.DevOidcSetup.callbackPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing oidc test provider callback url: %w", err)
|
||||
}
|
||||
|
||||
// Generate initial IDs/keys
|
||||
{
|
||||
b.DevOidcSetup.clientId, err = capoidc.NewID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to generate client id: %w", err)
|
||||
}
|
||||
clientSecret, err := capoidc.NewID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to generate client secret: %w", err)
|
||||
}
|
||||
b.DevOidcSetup.clientSecret = oidc.ClientSecret(clientSecret)
|
||||
b.DevOidcSetup.pubKey, b.DevOidcSetup.privKey, err = ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to generate signing key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the subject information and testing provider
|
||||
{
|
||||
logger, err := capoidc.NewTestingLogger(b.Logger.Named("dev-oidc"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create logger: %w", err)
|
||||
}
|
||||
|
||||
subInfo := map[string]*capoidc.TestSubject{
|
||||
b.DevLoginName: {
|
||||
Password: b.DevPassword,
|
||||
UserInfo: map[string]interface{}{
|
||||
"email": "admin@localhost",
|
||||
"name": "Admin User",
|
||||
},
|
||||
},
|
||||
}
|
||||
if b.DevOidcSetup.createUnpriv {
|
||||
subInfo[b.DevUnprivilegedLoginName] = &capoidc.TestSubject{
|
||||
Password: b.DevUnprivilegedPassword,
|
||||
UserInfo: map[string]interface{}{
|
||||
"email": "user@localhost",
|
||||
"name": "Unprivileged User",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
clientSecret := string(b.DevOidcSetup.clientSecret)
|
||||
|
||||
b.DevOidcSetup.testProvider = capoidc.StartTestProvider(
|
||||
logger,
|
||||
capoidc.WithNoTLS(),
|
||||
capoidc.WithTestHost(b.DevOidcSetup.hostAddr),
|
||||
capoidc.WithTestPort(b.DevOidcSetup.oidcPort),
|
||||
capoidc.WithTestDefaults(&capoidc.TestProviderDefaults{
|
||||
CustomClaims: map[string]interface{}{
|
||||
"mode": "dev",
|
||||
},
|
||||
SubjectInfo: subInfo,
|
||||
SigningKey: &capoidc.TestSigningKey{
|
||||
PrivKey: ed25519.PrivateKey(b.DevOidcSetup.privKey),
|
||||
PubKey: ed25519.PublicKey(b.DevOidcSetup.pubKey),
|
||||
Alg: capoidc.EdDSA,
|
||||
},
|
||||
AllowedRedirectURIs: []string{fmt.Sprintf("%s/v1/auth-methods/oidc:authenticate:callback", b.DevOidcSetup.callbackUrl.String())},
|
||||
ClientID: &b.DevOidcSetup.clientId,
|
||||
ClientSecret: &clientSecret,
|
||||
}))
|
||||
|
||||
b.ShutdownFuncs = append(b.ShutdownFuncs, func() error {
|
||||
b.DevOidcSetup.testProvider.Stop()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Create auth method and link accounts
|
||||
{
|
||||
b.DevOidcSetup.authMethod, err = b.createInitialOidcAuthMethod(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating initial oidc auth method: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMethod, error) {
|
||||
rw := db.New(b.Database)
|
||||
|
||||
kmsRepo, err := kms.NewRepository(rw, rw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating kms repository: %w", err)
|
||||
}
|
||||
kmsCache, err := kms.NewKms(kmsRepo, kms.WithLogger(b.Logger.Named("kms")))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating kms cache: %w", err)
|
||||
}
|
||||
if err := kmsCache.AddExternalWrappers(
|
||||
kms.WithRootWrapper(b.RootKms),
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("error adding config keys to kms: %w", err)
|
||||
}
|
||||
|
||||
discoveryUrl, err := url.Parse(fmt.Sprintf("http://%s:%d", b.DevOidcSetup.hostAddr, b.DevOidcSetup.oidcPort))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing oidc test provider address: %w", err)
|
||||
}
|
||||
|
||||
// Create the auth method
|
||||
oidcRepo, err := oidc.NewRepository(rw, rw, kmsCache)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating oidc repo: %w", err)
|
||||
}
|
||||
|
||||
authMethod, err := oidc.NewAuthMethod(
|
||||
scope.Global.String(),
|
||||
b.DevOidcSetup.clientId,
|
||||
b.DevOidcSetup.clientSecret,
|
||||
oidc.WithName("Generated global scope initial oidc auth method"),
|
||||
oidc.WithDescription("Provides initial administrative and unprivileged authentication into Boundary"),
|
||||
oidc.WithIssuer(discoveryUrl),
|
||||
oidc.WithApiUrl(b.DevOidcSetup.callbackUrl),
|
||||
oidc.WithSigningAlgs(oidc.EdDSA),
|
||||
oidc.WithOperationalState(oidc.ActivePublicState))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new in memory oidc auth method: %w", err)
|
||||
}
|
||||
if b.DevOidcAuthMethodId == "" {
|
||||
b.DevOidcAuthMethodId, err = db.NewPublicId(oidc.AuthMethodPrefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating initial oidc auth method id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
go func() {
|
||||
select {
|
||||
case <-b.ShutdownCh:
|
||||
cancel()
|
||||
case <-cancelCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
b.DevOidcSetup.authMethod, err = oidcRepo.CreateAuthMethod(
|
||||
cancelCtx,
|
||||
authMethod,
|
||||
oidc.WithPublicId(b.DevOidcAuthMethodId))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error saving oidc auth method to the db: %w", err)
|
||||
}
|
||||
|
||||
// Create accounts
|
||||
{
|
||||
createAndLinkAccount := func(loginName, userId, typ string) error {
|
||||
acct, err := oidc.NewAccount(
|
||||
b.DevOidcSetup.authMethod.GetPublicId(),
|
||||
loginName,
|
||||
oidc.WithDescription(fmt.Sprintf("Initial %s OIDC account", typ)),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating %s oidc account: %w", typ, err)
|
||||
}
|
||||
acct, err = oidcRepo.CreateAccount(
|
||||
cancelCtx,
|
||||
b.DevOidcSetup.authMethod.GetScopeId(),
|
||||
acct,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating %s oidc account: %w", typ, err)
|
||||
}
|
||||
|
||||
// Link accounts to existing user
|
||||
iamRepo, err := iam.NewRepository(rw, rw, kmsCache)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create iam repo: %w", err)
|
||||
}
|
||||
|
||||
u, _, err := iamRepo.LookupUser(cancelCtx, userId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error looking up %s user: %w", typ, err)
|
||||
}
|
||||
if _, err = iamRepo.AddUserAccounts(cancelCtx, u.GetPublicId(), u.GetVersion(), []string{acct.GetPublicId()}); err != nil {
|
||||
return fmt.Errorf("error associating initial %s user with account: %w", typ, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := createAndLinkAccount(b.DevLoginName, b.DevUserId, "admin"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if b.DevOidcSetup.createUnpriv {
|
||||
if err := createAndLinkAccount(b.DevUnprivilegedLoginName, b.DevUnprivilegedUserId, "unprivileged"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@ -0,0 +1,60 @@
|
||||
begin;
|
||||
|
||||
-- The base abstract table
|
||||
create table auth_managed_group (
|
||||
public_id wt_public_id
|
||||
primary key,
|
||||
auth_method_id wt_public_id
|
||||
not null,
|
||||
-- Ensure that if the auth method is deleted (which will also happen if the
|
||||
-- scope is deleted) this is deleted too
|
||||
constraint auth_method_fkey
|
||||
foreign key (auth_method_id) -- fk1
|
||||
references auth_method(public_id)
|
||||
on delete cascade
|
||||
on update cascade,
|
||||
constraint auth_managed_group_auth_method_id_public_id_uq
|
||||
unique(auth_method_id, public_id)
|
||||
);
|
||||
comment on table auth_managed_group is
|
||||
'auth_managed_group is the abstract base table for managed groups.';
|
||||
|
||||
-- Define the immutable fields of auth_managed_group
|
||||
create trigger
|
||||
immutable_columns
|
||||
before
|
||||
update on auth_managed_group
|
||||
for each row execute procedure immutable_columns('public_id', 'auth_method_id');
|
||||
|
||||
-- Function to insert into the base table when values are inserted into a
|
||||
-- concrete type table. This happens before inserts so the foreign keys in the
|
||||
-- concrete type will be valid.
|
||||
create or replace function
|
||||
insert_managed_group_subtype()
|
||||
returns trigger
|
||||
as $$
|
||||
begin
|
||||
|
||||
insert into auth_managed_group
|
||||
(public_id, auth_method_id)
|
||||
values
|
||||
(new.public_id, new.auth_method_id);
|
||||
|
||||
return new;
|
||||
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
-- delete_managed_group_subtype() is an after delete trigger
|
||||
-- function for subtypes of managed_group
|
||||
create or replace function delete_managed_group_subtype()
|
||||
returns trigger
|
||||
as $$
|
||||
begin
|
||||
delete from auth_managed_group
|
||||
where public_id = old.public_id;
|
||||
return null; -- result is ignored since this is an after trigger
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
commit;
|
||||
@ -0,0 +1,83 @@
|
||||
begin;
|
||||
|
||||
create table auth_oidc_managed_group (
|
||||
public_id wt_public_id
|
||||
primary key,
|
||||
auth_method_id wt_public_id
|
||||
not null,
|
||||
name wt_name,
|
||||
description wt_description,
|
||||
create_time wt_timestamp,
|
||||
update_time wt_timestamp,
|
||||
version wt_version,
|
||||
filter wt_bexprfilter
|
||||
not null,
|
||||
-- Ensure that this managed group relates to an oidc auth method, as opposed
|
||||
-- to other types
|
||||
constraint auth_oidc_method_fkey
|
||||
foreign key (auth_method_id) -- fk1
|
||||
references auth_oidc_method (public_id)
|
||||
on delete cascade
|
||||
on update cascade,
|
||||
-- Ensure it relates to an abstract managed group
|
||||
constraint auth_managed_group_fkey
|
||||
foreign key (auth_method_id, public_id) -- fk2
|
||||
references auth_managed_group (auth_method_id, public_id)
|
||||
on delete cascade
|
||||
on update cascade,
|
||||
constraint auth_oidc_managed_group_auth_method_id_name_uq
|
||||
unique(auth_method_id, name)
|
||||
);
|
||||
comment on table auth_oidc_managed_group is
|
||||
'auth_oidc_managed_group entries are subtypes of auth_managed_group and represent an oidc managed group.';
|
||||
|
||||
-- Define the immutable fields of auth_oidc_managed_group
|
||||
create trigger
|
||||
immutable_columns
|
||||
before
|
||||
update on auth_oidc_managed_group
|
||||
for each row execute procedure immutable_columns('public_id', 'auth_method_id', 'create_time');
|
||||
|
||||
-- Populate create time on insert
|
||||
create trigger
|
||||
default_create_time_column
|
||||
before
|
||||
insert on auth_oidc_managed_group
|
||||
for each row execute procedure default_create_time();
|
||||
|
||||
-- Generate update time on update
|
||||
create trigger
|
||||
update_time_column
|
||||
before
|
||||
update on auth_oidc_managed_group
|
||||
for each row execute procedure update_time_column();
|
||||
|
||||
-- Update version when something changes
|
||||
create trigger
|
||||
update_version_column
|
||||
after
|
||||
update on auth_oidc_managed_group
|
||||
for each row execute procedure update_version_column();
|
||||
|
||||
-- Add into the base table when inserting into the concrete table
|
||||
create trigger
|
||||
insert_managed_group_subtype
|
||||
before insert on auth_oidc_managed_group
|
||||
for each row execute procedure insert_managed_group_subtype();
|
||||
|
||||
-- Ensure that deletions in the oidc subtype result in deletions to the base
|
||||
-- table.
|
||||
create trigger
|
||||
delete_managed_group_subtype
|
||||
after
|
||||
delete on auth_oidc_managed_group
|
||||
for each row execute procedure delete_managed_group_subtype();
|
||||
|
||||
-- The tickets for oplog are the subtypes not the base types because no updates
|
||||
-- are done to any values in the base types.
|
||||
insert into oplog_ticket
|
||||
(name, version)
|
||||
values
|
||||
('auth_oidc_managed_group', 1);
|
||||
|
||||
commit;
|
||||
@ -0,0 +1,52 @@
|
||||
begin;
|
||||
|
||||
-- Mappings of account to oidc managed groups. This is a non-abstract table with
|
||||
-- a view (below) so that it is a natural aggregate for the oplog (also below).
|
||||
create table auth_oidc_managed_group_member_account (
|
||||
create_time wt_timestamp,
|
||||
managed_group_id wt_public_id
|
||||
references auth_oidc_managed_group(public_id)
|
||||
on delete cascade
|
||||
on update cascade,
|
||||
member_id wt_public_id
|
||||
references auth_oidc_account(public_id)
|
||||
on delete cascade
|
||||
on update cascade,
|
||||
primary key (managed_group_id, member_id)
|
||||
);
|
||||
comment on table auth_oidc_managed_group_member_account is
|
||||
'auth_oidc_managed_group_member_account is the join table for managed oidc groups and accounts.';
|
||||
|
||||
-- auth_immutable_managed_oidc_group_member_account() ensures that group members are immutable.
|
||||
create or replace function
|
||||
auth_immutable_managed_oidc_group_member_account()
|
||||
returns trigger
|
||||
as $$
|
||||
begin
|
||||
raise exception 'managed oidc group members are immutable';
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
create trigger
|
||||
default_create_time_column
|
||||
before
|
||||
insert on auth_oidc_managed_group_member_account
|
||||
for each row execute procedure default_create_time();
|
||||
|
||||
create trigger
|
||||
auth_immutable_managed_oidc_group_member_account
|
||||
before
|
||||
update on auth_oidc_managed_group_member_account
|
||||
for each row execute procedure auth_immutable_managed_oidc_group_member_account();
|
||||
|
||||
-- Initially create the view with just oidc; eventually we can replace this view
|
||||
-- to union with other subtype tables.
|
||||
create view auth_managed_group_member_account as
|
||||
select
|
||||
oidc.create_time,
|
||||
oidc.managed_group_id,
|
||||
oidc.member_id
|
||||
from
|
||||
auth_oidc_managed_group_member_account oidc;
|
||||
|
||||
commit;
|
||||
@ -0,0 +1,376 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/servers/controller"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ManagedGroupTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := controller.NewTestController(t, nil)
|
||||
defer tc.Shutdown()
|
||||
|
||||
db := tc.DbConn().DB()
|
||||
var err error
|
||||
|
||||
managedGroupId := "a_bcdefghijk"
|
||||
defaultPasswordAuthMethodId := "ampw_1234567890"
|
||||
defaultOidcAuthMethodId := "amoidc_1234567890"
|
||||
|
||||
insertTests := []struct {
|
||||
testName string
|
||||
publicId string
|
||||
authMethodId string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
testName: "invalid auth method",
|
||||
publicId: managedGroupId,
|
||||
authMethodId: "amoid_1234567890",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "valid",
|
||||
publicId: managedGroupId,
|
||||
authMethodId: defaultOidcAuthMethodId,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range insertTests {
|
||||
t.Run("insert: "+tt.testName, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
_, err = db.Exec("insert into auth_managed_group values ($1, $2)",
|
||||
tt.publicId,
|
||||
tt.authMethodId)
|
||||
require.True(tt.wantErr == (err != nil))
|
||||
})
|
||||
}
|
||||
|
||||
updateTests := []struct {
|
||||
testName string
|
||||
column string
|
||||
value string
|
||||
publicId string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
testName: "immutable public id",
|
||||
column: "public_id",
|
||||
value: "z_yxwvutsrqp",
|
||||
publicId: managedGroupId,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "immutable auth method",
|
||||
column: "auth_method_id",
|
||||
value: defaultPasswordAuthMethodId,
|
||||
publicId: managedGroupId,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range updateTests {
|
||||
t.Run("update: "+tt.testName, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
_, err = db.Exec(fmt.Sprintf("update auth_managed_group set %s = $1 where public_id = $2", tt.column), tt.value, tt.publicId)
|
||||
assert.True(tt.wantErr == (err != nil))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_OidcManagedGroupTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := controller.NewTestController(t, nil)
|
||||
defer tc.Shutdown()
|
||||
|
||||
db := tc.DbConn().DB()
|
||||
var err error
|
||||
|
||||
managedGroupId := "a_bcdefghijk"
|
||||
defaultPasswordAuthMethodId := "ampw_1234567890"
|
||||
defaultOidcAuthMethodId := "amoidc_1234567890"
|
||||
name := "this is the name"
|
||||
filter := "this is a filter"
|
||||
|
||||
// The first set of tests is for initial insertion
|
||||
{
|
||||
insertTests := []struct {
|
||||
testName string
|
||||
publicId string
|
||||
authMethodId string
|
||||
name string
|
||||
filter string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
testName: "null filter",
|
||||
publicId: managedGroupId,
|
||||
authMethodId: "amoid_1234567890",
|
||||
name: name,
|
||||
filter: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "invalid auth method",
|
||||
publicId: managedGroupId,
|
||||
authMethodId: defaultPasswordAuthMethodId,
|
||||
name: name,
|
||||
filter: filter,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "valid",
|
||||
publicId: managedGroupId,
|
||||
authMethodId: defaultOidcAuthMethodId,
|
||||
name: name,
|
||||
filter: filter,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
testName: "duplicate public id",
|
||||
publicId: managedGroupId,
|
||||
authMethodId: defaultOidcAuthMethodId,
|
||||
name: name,
|
||||
filter: filter,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "duplicate name",
|
||||
publicId: "z_yxwvutsrqp",
|
||||
authMethodId: defaultOidcAuthMethodId,
|
||||
name: name,
|
||||
filter: filter,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range insertTests {
|
||||
t.Run("insert: "+tt.testName, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
_, err = db.Exec("insert into auth_oidc_managed_group (public_id, auth_method_id, name, filter) values ($1, $2, $3, $4)",
|
||||
tt.publicId,
|
||||
tt.authMethodId,
|
||||
tt.name,
|
||||
tt.filter)
|
||||
require.True(tt.wantErr == (err != nil))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Read some values to validate that things were set automatically
|
||||
rows, err := db.Query("select create_time, update_time, version from auth_oidc_managed_group")
|
||||
require.NoError(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var create_time, update_time time.Time
|
||||
var version int
|
||||
require.NoError(t, rows.Scan(&create_time, &update_time, &version))
|
||||
assert.False(t, create_time.IsZero())
|
||||
assert.Equal(t, update_time, create_time)
|
||||
assert.Equal(t, 1, version)
|
||||
|
||||
// These update tests check immutability
|
||||
{
|
||||
updateTests := []struct {
|
||||
testName string
|
||||
column string
|
||||
value interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
testName: "immutable public id",
|
||||
column: "public_id",
|
||||
value: "z_yxwvutsrqp",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "immutable auth method",
|
||||
column: "auth_method_id",
|
||||
value: defaultPasswordAuthMethodId,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "immutable creation time",
|
||||
column: "create_time",
|
||||
value: time.Now(),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "valid",
|
||||
column: "description",
|
||||
value: "this is the description",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range updateTests {
|
||||
t.Run("update: "+tt.testName, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
_, err = db.Exec(fmt.Sprintf("update auth_oidc_managed_group set %s = $1 where public_id = $2", tt.column), tt.value, managedGroupId)
|
||||
require.True(tt.wantErr == (err != nil))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Read values again to validate that things were updated automatically
|
||||
rows, err = db.Query("select create_time, update_time, version from auth_oidc_managed_group")
|
||||
require.NoError(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var updated_create_time, updated_update_time time.Time
|
||||
require.NoError(t, rows.Scan(&updated_create_time, &updated_update_time, &version))
|
||||
assert.Equal(t, create_time, updated_create_time)
|
||||
assert.NotEqual(t, update_time, updated_update_time)
|
||||
assert.Equal(t, 2, version)
|
||||
|
||||
// Read values from auth_managed_group to ensure it was populated automatically
|
||||
rows, err = db.Query("select public_id, auth_method_id from auth_managed_group")
|
||||
require.NoError(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var public_id, auth_method_id string
|
||||
require.NoError(t, rows.Scan(&public_id, &auth_method_id))
|
||||
assert.Equal(t, managedGroupId, public_id)
|
||||
assert.Equal(t, defaultOidcAuthMethodId, auth_method_id)
|
||||
|
||||
// Delete the value from the subtype table
|
||||
res, err := db.Exec("delete from auth_oidc_managed_group where public_id = $1", managedGroupId)
|
||||
require.NoError(t, err)
|
||||
affected, err := res.RowsAffected()
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, affected)
|
||||
|
||||
// It should no longer be in the base table
|
||||
rows, err = db.Query("select public_id, auth_method_id from auth_managed_group")
|
||||
require.NoError(t, err)
|
||||
require.False(t, rows.Next())
|
||||
}
|
||||
|
||||
func Test_AuthManagedOidcGroupMemberAccountTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc := controller.NewTestController(t, nil)
|
||||
defer tc.Shutdown()
|
||||
|
||||
db := tc.DbConn().DB()
|
||||
var err error
|
||||
|
||||
managedGroupId := "a_bcdefghijk"
|
||||
defaultOidcAuthMethodId := "amoidc_1234567890"
|
||||
name := "this is the name"
|
||||
filter := "this is a filter"
|
||||
|
||||
// Insert valid data in auth_oidc_managed_group to use for the following tests
|
||||
_, err = db.Exec("insert into auth_oidc_managed_group (public_id, auth_method_id, name, filter) values ($1, $2, $3, $4)",
|
||||
managedGroupId,
|
||||
defaultOidcAuthMethodId,
|
||||
name,
|
||||
filter)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch a valid (oidc) account ID to use in insertion
|
||||
rows, err := db.Query("select public_id from auth_oidc_account limit 1")
|
||||
require.NoError(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var accountId string
|
||||
require.NoError(t, rows.Scan(&accountId))
|
||||
require.NotEmpty(t, accountId)
|
||||
|
||||
// The first set of tests is for initial insertion
|
||||
{
|
||||
insertTests := []struct {
|
||||
testName string
|
||||
managedGroupId string
|
||||
memberId string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
testName: "invalid managed group id",
|
||||
managedGroupId: "z_yxwvutsrqp",
|
||||
memberId: accountId,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "invalid member id",
|
||||
managedGroupId: managedGroupId,
|
||||
memberId: "acct_1234567890",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "valid",
|
||||
managedGroupId: managedGroupId,
|
||||
memberId: accountId,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
testName: "duplicate values",
|
||||
managedGroupId: managedGroupId,
|
||||
memberId: accountId,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range insertTests {
|
||||
t.Run("insert: "+tt.testName, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
_, err = db.Exec("insert into auth_oidc_managed_group_member_account (managed_group_id, member_id) values ($1, $2)",
|
||||
tt.managedGroupId,
|
||||
tt.memberId)
|
||||
assert.True(tt.wantErr == (err != nil))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Read some values to validate that things were set automatically
|
||||
rows, err = db.Query("select create_time, managed_group_id, member_id from auth_oidc_managed_group_member_account")
|
||||
require.NoError(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var create_time time.Time
|
||||
var managed_group_id, member_id string
|
||||
require.NoError(t, rows.Scan(&create_time, &managed_group_id, &member_id))
|
||||
assert.False(t, create_time.IsZero())
|
||||
|
||||
// These update tests check immutability
|
||||
{
|
||||
updateTests := []struct {
|
||||
testName string
|
||||
column string
|
||||
value interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
testName: "immutable managed group id",
|
||||
column: "managed_group_id",
|
||||
value: "z_yxwvutsrqp",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "immutable member_id",
|
||||
column: "member_id",
|
||||
value: "acct_1234567890",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
testName: "immutable creation time",
|
||||
column: "create_time",
|
||||
value: time.Now(),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range updateTests {
|
||||
t.Run("update: "+tt.testName, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
_, err = db.Exec(fmt.Sprintf("update auth_managed_group_member_account set %s = $1 where managed_group_id = $2 and member_id = $3", tt.column), managedGroupId, accountId)
|
||||
assert.True(tt.wantErr == (err != nil))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Read from the view to ensure we see it there
|
||||
rows, err = db.Query("select create_time, managed_group_id, member_id from auth_managed_group_member_account")
|
||||
require.NoError(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var view_create_time time.Time
|
||||
var view_managed_group_id, view_member_id string
|
||||
require.NoError(t, rows.Scan(&view_create_time, &view_managed_group_id, &view_member_id))
|
||||
assert.Equal(t, create_time, view_create_time)
|
||||
assert.Equal(t, managed_group_id, view_managed_group_id)
|
||||
assert.Equal(t, member_id, view_member_id)
|
||||
}
|
||||
Loading…
Reference in new issue