From 072b950540a6adeb030d6bba71591eed9a4f6ded Mon Sep 17 00:00:00 2001 From: Todd Knight Date: Tue, 13 Apr 2021 16:42:25 -0700 Subject: [PATCH] Filter inactive and private oidc auth methods from unauthenticated list requests (#1110) --- internal/auth/auth.go | 10 +- internal/auth/oidc/service_callback_test.go | 46 ++++---- internal/auth/option.go | 4 +- internal/auth/options_test.go | 4 +- internal/cmd/base/initial_resources.go | 3 +- .../authmethods/authmethod_service.go | 6 +- .../authmethods/authmethod_service_test.go | 12 +- .../handlers/authmethods/oidc_test.go | 106 ++++++++++++++++++ 8 files changed, 155 insertions(+), 36 deletions(-) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 7b213b9662..0799a0d2da 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -42,6 +42,10 @@ const ( AuthTokenTypeRecoveryKms ) +const ( + AnonymousUserId = "u_anon" +) + type key int var verifierKey key @@ -238,7 +242,7 @@ func Verify(ctx context.Context, opt ...Option) (ret VerifyResults) { // If the anon user was used (either no token, or invalid (perhaps // expired) token), return a 401. That way if it's an authn'd user // that is not authz'd we'll return 403 to be explicit. - if ret.UserId == "u_anon" { + if ret.UserId == AnonymousUserId { ret.Error = handlers.UnauthenticatedError() } return @@ -381,7 +385,7 @@ func (v verifier) performAuthCheck() (aclResults perms.ACLResults, userId string // Make the linter happy _ = retErr scopeInfo = new(scopes.ScopeInfo) - userId = "u_anon" + userId = AnonymousUserId var accountId string // Validate the token and fetch the corresponding user ID @@ -416,7 +420,7 @@ func (v verifier) performAuthCheck() (aclResults perms.ACLResults, userId string userId = at.GetIamUserId() if userId == "" { v.logger.Warn("perform auth check: valid token did not map to a user, likely because no account is associated with the user any longer; continuing as u_anon", "token_id", at.GetPublicId()) - userId = "u_anon" + userId = AnonymousUserId accountId = "" } } diff --git a/internal/auth/oidc/service_callback_test.go b/internal/auth/oidc/service_callback_test.go index 9114dc409e..687098a30b 100644 --- a/internal/auth/oidc/service_callback_test.go +++ b/internal/auth/oidc/service_callback_test.go @@ -114,20 +114,20 @@ func Test_Callback(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - setup func() // provide a simple way to do some prework before the test. - oidcRepoFn OidcRepoFactory // returns a new oidc repo - iamRepoFn IamRepoFactory // returns a new iam repo - atRepoFn AuthTokenRepoFactory // returns a new auth token repo - am *AuthMethod // the authmethod for the test - state string // state parameter for test provider and Callback(...) - code string // code parameter for test provider and Callback(...) - wantSubject string // sub claim from id token - wantInfoName string // name claim from userinfo - wantInfoEmail string // email claim from userinfo - wantFinalRedirect string // final redirect from Callback(...) - wantErrMatch *errors.Template // error template to match - wantErrContains string // error string should contain + name string + setup func() // provide a simple way to do some prework before the test. + oidcRepoFn OidcRepoFactory // returns a new oidc repo + iamRepoFn IamRepoFactory // returns a new iam repo + atRepoFn AuthTokenRepoFactory // returns a new auth token repo + am *AuthMethod // the authmethod for the test + state string // state parameter for test provider and Callback(...) + code string // code parameter for test provider and Callback(...) + wantSubject string // sub claim from id token + wantInfoName string // name claim from userinfo + wantInfoEmail string // email claim from userinfo + wantFinalRedirect string // final redirect from Callback(...) + wantErrMatch *errors.Template // error template to match + wantErrContains string // error string should contain }{ { name: "simple", // must remain the first test @@ -234,15 +234,15 @@ func Test_Callback(t *testing.T) { wantErrContains: "missing auth method", }, { - name: "mismatch-auth-method", - oidcRepoFn: repoFn, - iamRepoFn: iamRepoFn, - atRepoFn: atRepoFn, - am: testAuthMethod, - state: testState(t, testAuthMethod2, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce), - code: "simple", - wantErrMatch: errors.T(errors.InvalidParameter), - wantErrContains: "auth method id does not match request wrapper auth method id", + name: "mismatch-auth-method", + oidcRepoFn: repoFn, + iamRepoFn: iamRepoFn, + atRepoFn: atRepoFn, + am: testAuthMethod, + state: testState(t, testAuthMethod2, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce), + code: "simple", + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "auth method id does not match request wrapper auth method id", }, { name: "bad-state-encoding", diff --git a/internal/auth/option.go b/internal/auth/option.go index 939a2f0d01..9b55b75767 100644 --- a/internal/auth/option.go +++ b/internal/auth/option.go @@ -33,7 +33,9 @@ type options struct { } func getDefaultOptions() options { - return options{} + return options{ + withUserId: AnonymousUserId, + } } func WithScopeId(id string) Option { diff --git a/internal/auth/options_test.go b/internal/auth/options_test.go index e68a38b46d..fe1b3a7d4e 100644 --- a/internal/auth/options_test.go +++ b/internal/auth/options_test.go @@ -14,7 +14,9 @@ import ( func Test_GetOpts(t *testing.T) { t.Parallel() opts := getOpts() - assert.Equal(t, options{}, opts) + assert.Equal(t, options{ + withUserId: AnonymousUserId, + }, opts) withKms := new(kms.Kms) res := new(perms.Resource) diff --git a/internal/cmd/base/initial_resources.go b/internal/cmd/base/initial_resources.go index 8c55d9e581..b2f2adc1fc 100644 --- a/internal/cmd/base/initial_resources.go +++ b/internal/cmd/base/initial_resources.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" + "github.com/hashicorp/boundary/internal/auth" "github.com/hashicorp/boundary/internal/auth/password" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/host/static" @@ -64,7 +65,7 @@ func (b *Server) CreateInitialLoginRole(ctx context.Context) (*iam.Role, error) }); err != nil { return nil, fmt.Errorf("error creating grant for default generated grants: %w", err) } - if _, err := iamRepo.AddPrincipalRoles(cancelCtx, role.PublicId, role.Version+1, []string{"u_anon"}, nil); err != nil { + if _, err := iamRepo.AddPrincipalRoles(cancelCtx, role.PublicId, role.Version+1, []string{auth.AnonymousUserId}, nil); err != nil { return nil, fmt.Errorf("error adding principal to role for default generated grants: %w", err) } diff --git a/internal/servers/controller/handlers/authmethods/authmethod_service.go b/internal/servers/controller/handlers/authmethods/authmethod_service.go index 1c7e458f50..7b79cd2bed 100644 --- a/internal/servers/controller/handlers/authmethods/authmethod_service.go +++ b/internal/servers/controller/handlers/authmethods/authmethod_service.go @@ -157,7 +157,7 @@ func (s Service) ListAuthMethods(ctx context.Context, req *pbs.ListAuthMethodsRe return &pbs.ListAuthMethodsResponse{}, nil } - ul, err := s.listFromRepo(ctx, scopeIds) + ul, err := s.listFromRepo(ctx, scopeIds, authResults.UserId == auth.AnonymousUserId) if err != nil { return nil, err } @@ -376,12 +376,12 @@ func (s Service) getFromRepo(ctx context.Context, id string) (*pb.AuthMethod, er return toAuthMethodProto(am) } -func (s Service) listFromRepo(ctx context.Context, scopeIds []string) ([]*pb.AuthMethod, error) { +func (s Service) listFromRepo(ctx context.Context, scopeIds []string, unauthn bool) ([]*pb.AuthMethod, error) { oidcRepo, err := s.oidcRepoFn() if err != nil { return nil, err } - ol, err := oidcRepo.ListAuthMethods(ctx, scopeIds) + ol, err := oidcRepo.ListAuthMethods(ctx, scopeIds, oidc.WithUnauthenticatedUser(unauthn)) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/authmethods/authmethod_service_test.go b/internal/servers/controller/handlers/authmethods/authmethod_service_test.go index 06098a1cc0..0d67b2b41d 100644 --- a/internal/servers/controller/handlers/authmethods/authmethod_service_test.go +++ b/internal/servers/controller/handlers/authmethods/authmethod_service_test.go @@ -226,8 +226,8 @@ func TestList(t *testing.T) { var wantSomeAuthMethods []*pb.AuthMethod databaseWrapper, err := kmsCache.GetWrapper(context.Background(), oWithAuthMethods.GetPublicId(), kms.KeyPurposeDatabase) require.NoError(t, err) - oidcam := oidc.TestAuthMethod(t, conn, databaseWrapper, oWithAuthMethods.GetPublicId(), oidc.InactiveState, "alice_rp", "secret", - oidc.WithIssuer(oidc.TestConvertToUrls(t, "https://alice.com")[0]), oidc.WithApiUrl(oidc.TestConvertToUrls(t, "https://api.com")[0])) + oidcam := oidc.TestAuthMethod(t, conn, databaseWrapper, oWithAuthMethods.GetPublicId(), oidc.ActivePublicState, "alice_rp", "secret", + oidc.WithIssuer(oidc.TestConvertToUrls(t, "https://alice.com")[0]), oidc.WithApiUrl(oidc.TestConvertToUrls(t, "https://api.com")[0]), oidc.WithSigningAlgs(oidc.EdDSA)) iam.TestSetPrimaryAuthMethod(t, iamRepo, oWithAuthMethods, oidcam.GetPublicId()) wantSomeAuthMethods = append(wantSomeAuthMethods, &pb.AuthMethod{ @@ -236,15 +236,19 @@ func TestList(t *testing.T) { CreatedTime: oidcam.GetCreateTime().GetTimestamp(), UpdatedTime: oidcam.GetUpdateTime().GetTimestamp(), Scope: &scopepb.ScopeInfo{Id: oWithAuthMethods.GetPublicId(), Type: scope.Org.String(), ParentScopeId: scope.Global.String()}, - Version: 1, + Version: 2, Type: auth.OidcSubtype.String(), Attributes: &structpb.Struct{Fields: map[string]*structpb.Value{ "issuer": structpb.NewStringValue("https://alice.com"), "client_id": structpb.NewStringValue("alice_rp"), "client_secret_hmac": structpb.NewStringValue(""), - "state": structpb.NewStringValue(string(oidc.InactiveState)), + "state": structpb.NewStringValue(string(oidc.ActivePublicState)), "api_url_prefix": structpb.NewStringValue("https://api.com"), "callback_url": structpb.NewStringValue(fmt.Sprintf(oidc.CallbackEndpoint, "https://api.com", oidcam.GetPublicId())), + "signing_algorithms": func() *structpb.Value { + lv, _ := structpb.NewList([]interface{}{string(oidc.EdDSA)}) + return structpb.NewListValue(lv) + }(), }}, IsPrimary: true, AuthorizedActions: oidcAuthorizedActions, diff --git a/internal/servers/controller/handlers/authmethods/oidc_test.go b/internal/servers/controller/handlers/authmethods/oidc_test.go index 5b85a1dad2..30d2591b91 100644 --- a/internal/servers/controller/handlers/authmethods/oidc_test.go +++ b/internal/servers/controller/handlers/authmethods/oidc_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "regexp" + "sort" "testing" "time" @@ -22,6 +23,7 @@ import ( pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/servers" "github.com/hashicorp/boundary/internal/servers/controller/common" "github.com/hashicorp/boundary/internal/servers/controller/handlers" "github.com/hashicorp/boundary/internal/servers/controller/handlers/authmethods" @@ -127,6 +129,110 @@ func getSetup(t *testing.T) setup { return ret } +func TestList_FilterNonPublic(t *testing.T) { + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + kmsCache := kms.TestKms(t, conn, wrapper) + iamRepoFn := func() (*iam.Repository, error) { + return iam.TestRepo(t, conn, wrapper), nil + } + oidcRepoFn := func() (*oidc.Repository, error) { + return oidc.NewRepository(rw, rw, kmsCache) + } + pwRepoFn := func() (*password.Repository, error) { + return password.NewRepository(rw, rw, kmsCache) + } + atRepoFn := func() (*authtoken.Repository, error) { + return authtoken.NewRepository(rw, rw, kmsCache) + } + authTokenRepoFn := func() (*authtoken.Repository, error) { + return authtoken.NewRepository(rw, rw, kmsCache) + } + serversRepoFn := func() (*servers.Repository, error) { + return servers.NewRepository(rw, rw, kmsCache) + } + iamRepo := iam.TestRepo(t, conn, wrapper) + + o, _ := iam.TestScopes(t, iamRepo) + + databaseWrapper, err := kmsCache.GetWrapper(context.Background(), o.GetPublicId(), kms.KeyPurposeDatabase) + require.NoError(t, err) + + // 1 Public + i := 0 + oidcam := oidc.TestAuthMethod(t, conn, databaseWrapper, o.GetPublicId(), oidc.ActivePublicState, "alice_rp", "secret", oidc.WithDescription(fmt.Sprintf("%d", i)), + oidc.WithIssuer(oidc.TestConvertToUrls(t, fmt.Sprintf("https://alice%d.com", i))[0]), oidc.WithApiUrl(oidc.TestConvertToUrls(t, "https://api.com")[0]), oidc.WithSigningAlgs(oidc.EdDSA)) + i++ + iam.TestSetPrimaryAuthMethod(t, iamRepo, o, oidcam.GetPublicId()) + + // 4 private + for ; i < 4; i++ { + _ = oidc.TestAuthMethod(t, conn, databaseWrapper, o.GetPublicId(), oidc.ActivePrivateState, "alice_rp", "secret", oidc.WithDescription(fmt.Sprintf("%d", i)), + oidc.WithIssuer(oidc.TestConvertToUrls(t, fmt.Sprintf("https://alice%d.com", i))[0]), oidc.WithApiUrl(oidc.TestConvertToUrls(t, "https://api.com")[0]), oidc.WithSigningAlgs(oidc.EdDSA)) + } + + // 5 inactive + for ; i < 10; i++ { + _ = oidc.TestAuthMethod(t, conn, databaseWrapper, o.GetPublicId(), oidc.InactiveState, "alice_rp", "secret", oidc.WithDescription(fmt.Sprintf("%d", i)), + oidc.WithIssuer(oidc.TestConvertToUrls(t, fmt.Sprintf("https://alice%d.com", i))[0]), oidc.WithApiUrl(oidc.TestConvertToUrls(t, "https://api.com")[0])) + } + + s, err := authmethods.NewService(kmsCache, pwRepoFn, oidcRepoFn, iamRepoFn, atRepoFn) + require.NoError(t, err, "Couldn't create new auth_method service.") + + req := &pbs.ListAuthMethodsRequest{ + ScopeId: o.GetPublicId(), + Filter: `"/item/type"=="oidc"`, // We are concerned about OIDC auth methods being filtered by authn state + } + + cases := []struct { + name string + reqCtx context.Context + respCount int + }{ + { + name: "unauthenticated", + reqCtx: auth.DisabledAuthTestContext(iamRepoFn, o.GetPublicId()), + respCount: 1, + }, + { + name: "authenticated", + reqCtx: func() context.Context { + at := authtoken.TestAuthToken(t, conn, kmsCache, o.GetPublicId()) + return auth.NewVerifierContext(context.Background(), + nil, + iamRepoFn, + authTokenRepoFn, + serversRepoFn, + kmsCache, + auth.RequestInfo{ + Token: at.GetToken(), + TokenFormat: auth.AuthTokenTypeBearer, + PublicId: at.GetPublicId(), + }) + }(), + respCount: 10, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := s.ListAuthMethods(tc.reqCtx, req) + assert.NoError(t, err) + require.NotNil(t, got) + require.Len(t, got.GetItems(), tc.respCount) + + gotSorted := got.GetItems() + sort.Slice(gotSorted, func(i, j int) bool { + return gotSorted[i].GetDescription().GetValue() < gotSorted[j].GetDescription().GetValue() + }) + for i := 0; i < tc.respCount; i++ { + assert.Equal(t, fmt.Sprintf("%d", i), gotSorted[i].GetDescription().GetValue(), "Auth method with description '%d' missing", i) + } + }) + } +} + func TestUpdate_OIDC(t *testing.T) { conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn)