diff --git a/internal/auth/oidc/repository_account.go b/internal/auth/oidc/repository_account.go index 70ea38cce0..ec9fc011b7 100644 --- a/internal/auth/oidc/repository_account.go +++ b/internal/auth/oidc/repository_account.go @@ -24,6 +24,8 @@ import ( // // Both a.Name and a.Description are optional. If a.Name is set, it must be // unique within a.AuthMethodId. +// +// WithPublicId is currently the only valid option. func (r *Repository) CreateAccount(ctx context.Context, scopeId string, a *Account, opt ...Option) (*Account, error) { const op = "oidc.(Repository).CreateAccount" if a == nil { @@ -67,11 +69,20 @@ func (r *Repository) CreateAccount(ctx context.Context, scopeId string, a *Accou if a.Issuer == "" { return nil, errors.New(errors.InvalidParameter, op, "no issuer provided or defined in auth method") } - id, err := newAccountId(a.AuthMethodId, a.Issuer, a.Subject) - if err != nil { - return nil, errors.Wrap(err, op) + + opts := getOpts(opt...) + if opts.withPublicId != "" { + if !strings.HasPrefix(opts.withPublicId, AccountPrefix+"_") { + return nil, errors.New(errors.InvalidParameter, op, "chosen account id does not have a valid prefix") + } + a.PublicId = opts.withPublicId + } else { + id, err := newAccountId(a.AuthMethodId, a.Issuer, a.Subject) + if err != nil { + return nil, errors.Wrap(err, op) + } + a.PublicId = id } - a.PublicId = id oplogWrapper, err := r.kms.GetWrapper(ctx, scopeId, kms.KeyPurposeOplog) if err != nil { diff --git a/internal/auth/password/repository_account.go b/internal/auth/password/repository_account.go index 4c1a70936a..627c239305 100644 --- a/internal/auth/password/repository_account.go +++ b/internal/auth/password/repository_account.go @@ -14,14 +14,15 @@ import ( ) // CreateAccount inserts a into the repository and returns a new Account -// containing the account's PublicId. a is not changed. a must contain a -// valid AuthMethodId. a must not contain a PublicId. The PublicId is -// generated and assigned by this method. +// containing the account's PublicId. a is not changed. a must contain a valid +// AuthMethodId. a must not contain a PublicId. The PublicId is generated and +// assigned by this method. // // a must contain a valid LoginName. a.LoginName must be unique within // a.AuthMethodId. // -// WithPassword is the only valid option. All other options are ignored. +// WithPassword and WithPublicId are the only valid options. All other options +// are ignored. // // Both a.Name and a.Description are optional. If a.Name is set, it must be // unique within a.AuthMethodId. @@ -55,21 +56,28 @@ func (r *Repository) CreateAccount(ctx context.Context, scopeId string, a *Accou return nil, errors.New(errors.TooShort, op, fmt.Sprintf("username: %s, must be longer than %d", a.LoginName, cc.MinLoginNameLength)) } + opts := getOpts(opt...) + a = a.clone() - id, err := newAccountId() - if err != nil { - return nil, errors.Wrap(err, op) + if opts.withPublicId != "" { + if !strings.HasPrefix(opts.withPublicId, AccountPrefix+"_") { + return nil, errors.New(errors.InvalidParameter, op, "chosen account id does not have a valid prefix") + } + a.PublicId = opts.withPublicId + } else { + id, err := newAccountId() + if err != nil { + return nil, errors.Wrap(err, op) + } + a.PublicId = id } - a.PublicId = id - - opts := getOpts(opt...) var cred *Argon2Credential if opts.withPassword { if cc.MinPasswordLength > len(opts.password) { return nil, errors.New(errors.PasswordTooShort, op, fmt.Sprintf("must be longer than %v", cc.MinPasswordLength)) } - if cred, err = newArgon2Credential(id, opts.password, cc.argon2()); err != nil { + if cred, err = newArgon2Credential(a.PublicId, opts.password, cc.argon2()); err != nil { return nil, errors.Wrap(err, op) } } diff --git a/internal/cmd/base/dev.go b/internal/cmd/base/dev.go index 5d5a9dc0cd..4059eb934a 100644 --- a/internal/cmd/base/dev.go +++ b/internal/cmd/base/dev.go @@ -193,7 +193,8 @@ func (b *Server) CreateDevOidcAuthMethod(ctx context.Context) error { switch { case b.DevUnprivilegedLoginName == "", b.DevUnprivilegedPassword == "", - b.DevUnprivilegedUserId == "": + b.DevUnprivilegedUserId == "", + b.DevUnprivilegedOidcAccountId == "": default: b.DevOidcSetup.createUnpriv = true @@ -390,7 +391,7 @@ func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMet // Create accounts { - createAndLinkAccount := func(loginName, userId, typ string) error { + createAndLinkAccount := func(loginName, userId, accountId, typ string) error { acct, err := oidc.NewAccount( b.DevOidcSetup.authMethod.GetPublicId(), loginName, @@ -403,6 +404,7 @@ func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMet cancelCtx, b.DevOidcSetup.authMethod.GetScopeId(), acct, + oidc.WithPublicId(accountId), ) if err != nil { return fmt.Errorf("error creating %s oidc account: %w", typ, err) @@ -425,11 +427,11 @@ func (b *Server) createInitialOidcAuthMethod(ctx context.Context) (*oidc.AuthMet return nil } - if err := createAndLinkAccount(b.DevLoginName, b.DevUserId, "admin"); err != nil { + if err := createAndLinkAccount(b.DevLoginName, b.DevUserId, b.DevOidcAccountId, "admin"); err != nil { return nil, err } if b.DevOidcSetup.createUnpriv { - if err := createAndLinkAccount(b.DevUnprivilegedLoginName, b.DevUnprivilegedUserId, "unprivileged"); err != nil { + if err := createAndLinkAccount(b.DevUnprivilegedLoginName, b.DevUnprivilegedUserId, b.DevUnprivilegedOidcAccountId, "unprivileged"); err != nil { return nil, err } } diff --git a/internal/cmd/base/initial_resources.go b/internal/cmd/base/initial_resources.go index dfe5cfa3a7..a7f1f6c848 100644 --- a/internal/cmd/base/initial_resources.go +++ b/internal/cmd/base/initial_resources.go @@ -145,7 +145,7 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password return nil, nil, fmt.Errorf("unable to set primary auth method for global scope: %w", err) } - createUser := func(loginName, loginPassword, userId string, admin bool) (*iam.User, error) { + createUser := func(loginName, loginPassword, userId, accountId string, admin bool) (*iam.User, error) { // Create the dev admin user if loginName == "" { return nil, fmt.Errorf("empty login name") @@ -167,7 +167,13 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password if err != nil { return nil, fmt.Errorf("error creating new in memory password auth account: %w", err) } - acct, err = pwRepo.CreateAccount(cancelCtx, scope.Global.String(), acct, password.WithPassword(loginPassword)) + acct, err = pwRepo.CreateAccount( + cancelCtx, + scope.Global.String(), + acct, + password.WithPassword(loginPassword), + password.WithPublicId(accountId), + ) if err != nil { return nil, fmt.Errorf("error saving auth account to the db: %w", err) } @@ -233,7 +239,7 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password b.DevUnprivilegedPassword == "", b.DevUnprivilegedUserId == "": default: - _, err := createUser(b.DevUnprivilegedLoginName, b.DevUnprivilegedPassword, b.DevUnprivilegedUserId, false) + _, err := createUser(b.DevUnprivilegedLoginName, b.DevUnprivilegedPassword, b.DevUnprivilegedUserId, b.DevUnprivilegedPasswordAccountId, false) if err != nil { return nil, nil, err } @@ -257,7 +263,7 @@ func (b *Server) CreateInitialPasswordAuthMethod(ctx context.Context) (*password return nil, nil, fmt.Errorf("error generating initial user id: %w", err) } } - u, err := createUser(b.DevLoginName, b.DevPassword, b.DevUserId, true) + u, err := createUser(b.DevLoginName, b.DevPassword, b.DevUserId, b.DevPasswordAccountId, true) if err != nil { return nil, nil, err } diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index b6d1837219..9dcd195f58 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -67,24 +67,28 @@ type Server struct { Listeners []*ServerListener - DevPasswordAuthMethodId string - DevOidcAuthMethodId string - DevLoginName string - DevPassword string - DevUserId string - DevUnprivilegedLoginName string - DevUnprivilegedPassword string - DevUnprivilegedUserId string - DevOrgId string - DevProjectId string - DevHostCatalogId string - DevHostSetId string - DevHostId string - DevTargetId string - DevHostAddress string - DevTargetDefaultPort int - DevTargetSessionMaxSeconds int - DevTargetSessionConnectionLimit int + DevPasswordAuthMethodId string + DevOidcAuthMethodId string + DevLoginName string + DevPassword string + DevUserId string + DevPasswordAccountId string + DevOidcAccountId string + DevUnprivilegedLoginName string + DevUnprivilegedPassword string + DevUnprivilegedUserId string + DevUnprivilegedPasswordAccountId string + DevUnprivilegedOidcAccountId string + DevOrgId string + DevProjectId string + DevHostCatalogId string + DevHostSetId string + DevHostId string + DevTargetId string + DevHostAddress string + DevTargetDefaultPort int + DevTargetSessionMaxSeconds int + DevTargetSessionConnectionLimit int DevOidcSetup oidcSetup diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index 4ccfe45e73..704b603991 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -313,7 +313,11 @@ func (c *Command) Run(args []string) int { c.DevPasswordAuthMethodId = fmt.Sprintf("%s_%s", password.AuthMethodPrefix, c.flagIdSuffix) c.DevOidcAuthMethodId = fmt.Sprintf("%s_%s", oidc.AuthMethodPrefix, c.flagIdSuffix) c.DevUserId = fmt.Sprintf("%s_%s", iam.UserPrefix, c.flagIdSuffix) + c.DevPasswordAccountId = fmt.Sprintf("%s_%s", password.AccountPrefix, c.flagIdSuffix) + c.DevOidcAccountId = fmt.Sprintf("%s_%s", oidc.AccountPrefix, c.flagIdSuffix) c.DevUnprivilegedUserId = "u_" + strutil.Reverse(strings.TrimPrefix(c.DevUserId, "u_")) + c.DevUnprivilegedPasswordAccountId = fmt.Sprintf("%s_", password.AccountPrefix) + strutil.Reverse(strings.TrimPrefix(c.DevPasswordAccountId, fmt.Sprintf("%s_", password.AccountPrefix))) + c.DevUnprivilegedOidcAccountId = fmt.Sprintf("%s_", oidc.AccountPrefix) + strutil.Reverse(strings.TrimPrefix(c.DevOidcAccountId, fmt.Sprintf("%s_", oidc.AccountPrefix))) c.DevOrgId = fmt.Sprintf("%s_%s", scope.Org.Prefix(), c.flagIdSuffix) c.DevProjectId = fmt.Sprintf("%s_%s", scope.Project.Prefix(), c.flagIdSuffix) c.DevHostCatalogId = fmt.Sprintf("%s_%s", static.HostCatalogPrefix, c.flagIdSuffix) diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index 9e82686dcf..9c63a4a23b 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -27,12 +27,16 @@ import ( ) const ( - DefaultTestPasswordAuthMethodId = "ampw_1234567890" - DefaultTestOidcAuthMethodId = "amoidc_1234567890" - DefaultTestLoginName = "admin" - DefaultTestUnprivilegedLoginName = "user" - DefaultTestPassword = "passpass" - DefaultTestUserId = "u_1234567890" + DefaultTestPasswordAuthMethodId = "ampw_1234567890" + DefaultTestOidcAuthMethodId = "amoidc_1234567890" + DefaultTestLoginName = "admin" + DefaultTestUnprivilegedLoginName = "user" + DefaultTestPassword = "passpass" + DefaultTestUserId = "u_1234567890" + DefaultTestPasswordAccountId = "apw_1234567890" + DefaultTestOidcAccountId = "acctoidc_1234567890" + DefaultTestUnprivilegedPasswordAccountId = "apw_0987654321" + DefaultTestUnprivilegedOidcAccountId = "acctoidc_0987654321" ) // TestController wraps a base.Server and Controller to provide a @@ -429,6 +433,10 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { tc.b.DevPassword = DefaultTestPassword tc.b.DevUnprivilegedPassword = DefaultTestPassword } + tc.b.DevPasswordAccountId = DefaultTestPasswordAccountId + tc.b.DevOidcAccountId = DefaultTestOidcAccountId + tc.b.DevUnprivilegedPasswordAccountId = DefaultTestUnprivilegedPasswordAccountId + tc.b.DevUnprivilegedOidcAccountId = DefaultTestUnprivilegedOidcAccountId // Start a logger tc.b.Logger = opts.Logger