From aaf669a0449d73137bef055d98258f70cc6ede71 Mon Sep 17 00:00:00 2001 From: Jim Date: Sun, 19 Jun 2022 10:55:37 -0400 Subject: [PATCH] fix (API): check attributes missing appropriately. (#2219) * fix (API): check attributes missing appropriately. --- .../handlers/accounts/account_service.go | 44 +++++++++------ .../handlers/accounts/validate_test.go | 22 ++++++++ .../controller/handlers/authmethods/oidc.go | 53 +++++++++++-------- .../handlers/authmethods/password.go | 47 ++++++++-------- .../handlers/authmethods/password_test.go | 22 ++++++-- .../controller/handlers/hosts/host_service.go | 40 ++++++++------ .../handlers/hosts/host_service_test.go | 23 ++++++-- .../managed_groups/managed_group_service.go | 15 ++++-- 8 files changed, 178 insertions(+), 88 deletions(-) diff --git a/internal/daemon/controller/handlers/accounts/account_service.go b/internal/daemon/controller/handlers/accounts/account_service.go index fb3bcf8564..a5d3fd5aa9 100644 --- a/internal/daemon/controller/handlers/accounts/account_service.go +++ b/internal/daemon/controller/handlers/accounts/account_service.go @@ -963,31 +963,41 @@ func validateCreateRequest(req *pbs.CreateAccountRequest) error { badFields[typeField] = "Doesn't match the parent resource's type." } attrs := req.GetItem().GetPasswordAccountAttributes() - if attrs.GetLoginName() == "" { - badFields[loginNameKey] = "This is a required field for this type." + switch { + case attrs == nil: + badFields["attributes"] = "This is a required field." + default: + if attrs.GetLoginName() == "" { + badFields[loginNameKey] = "This is a required field for this type." + } } case oidc.Subtype: if req.GetItem().GetType() != "" && req.GetItem().GetType() != oidc.Subtype.String() { badFields[typeField] = "Doesn't match the parent resource's type." } attrs := req.GetItem().GetOidcAccountAttributes() - if attrs.GetSubject() == "" { - badFields[subjectField] = "This is a required field for this type." - } - if attrs.GetIssuer() != "" { - du, err := url.Parse(attrs.GetIssuer()) - if err != nil { - badFields[issuerField] = fmt.Sprintf("Cannot be parsed as a url. %v", err) + switch { + case attrs == nil: + badFields["attributes"] = "This is a required field." + default: + if attrs.GetSubject() == "" { + badFields[subjectField] = "This is a required field for this type." } - if trimmed := strings.TrimSuffix(strings.TrimSuffix(du.RawPath, "/"), "/.well-known/openid-configuration"); trimmed != "" { - badFields[issuerField] = "The path segment of the url should be empty." + if attrs.GetIssuer() != "" { + du, err := url.Parse(attrs.GetIssuer()) + if err != nil { + badFields[issuerField] = fmt.Sprintf("Cannot be parsed as a url. %v", err) + } + if trimmed := strings.TrimSuffix(strings.TrimSuffix(du.RawPath, "/"), "/.well-known/openid-configuration"); trimmed != "" { + badFields[issuerField] = "The path segment of the url should be empty." + } + } + if attrs.GetFullName() != "" { + badFields[nameClaimField] = "This is a read only field." + } + if attrs.GetEmail() != "" { + badFields[emailClaimField] = "This is a read only field." } - } - if attrs.GetFullName() != "" { - badFields[nameClaimField] = "This is a read only field." - } - if attrs.GetEmail() != "" { - badFields[emailClaimField] = "This is a read only field." } default: badFields[authMethodIdField] = "Unknown auth method type from ID." diff --git a/internal/daemon/controller/handlers/accounts/validate_test.go b/internal/daemon/controller/handlers/accounts/validate_test.go index e4190d6fb8..8c55b495a5 100644 --- a/internal/daemon/controller/handlers/accounts/validate_test.go +++ b/internal/daemon/controller/handlers/accounts/validate_test.go @@ -49,11 +49,22 @@ func TestValidateCreateRequest(t *testing.T) { }, errContains: fieldError(typeField, "Doesn't match the parent resource's type."), }, + { + name: "missing oidc attributes", + item: &pb.Account{ + Type: oidc.Subtype.String(), + AuthMethodId: oidc.AuthMethodPrefix + "_1234567890", + }, + errContains: fieldError(attributesField, "This is a required field."), + }, { name: "missing oidc subject", item: &pb.Account{ Type: oidc.Subtype.String(), AuthMethodId: oidc.AuthMethodPrefix + "_1234567890", + Attrs: &pb.Account_OidcAccountAttributes{ + OidcAccountAttributes: &pb.OidcAccountAttributes{}, + }, }, errContains: fieldError(subjectField, "This is a required field for this type."), }, @@ -79,11 +90,22 @@ func TestValidateCreateRequest(t *testing.T) { }, errContains: fieldError(emailClaimField, "This is a read only field."), }, + { + name: "missing password attributes", + item: &pb.Account{ + Type: password.Subtype.String(), + AuthMethodId: password.AuthMethodPrefix + "_1234567890", + }, + errContains: fieldError(attributesField, "This is a required field."), + }, { name: "missing login name for password type", item: &pb.Account{ Type: password.Subtype.String(), AuthMethodId: password.AuthMethodPrefix + "_1234567890", + Attrs: &pb.Account_PasswordAccountAttributes{ + PasswordAccountAttributes: &pb.PasswordAccountAttributes{}, + }, }, errContains: fieldError(loginNameKey, "This is a required field for this type."), }, diff --git a/internal/daemon/controller/handlers/authmethods/oidc.go b/internal/daemon/controller/handlers/authmethods/oidc.go index e52da632d3..87df9bfe3d 100644 --- a/internal/daemon/controller/handlers/authmethods/oidc.go +++ b/internal/daemon/controller/handlers/authmethods/oidc.go @@ -333,33 +333,42 @@ func validateAuthenticateOidcRequest(req *pbs.AuthenticateRequest) error { case startCommand: if req.GetOidcStartAttributes() != nil { attrs := req.GetOidcStartAttributes() - - // Ensure we pay no attention to cache information provided by the client - attrs.CachedRoundtripPayload = "" - - payload := attrs.GetRoundtripPayload() - if payload == nil { - break - } - m, err := json.Marshal(payload.AsMap()) - if err != nil { - // We don't know what's in this payload so we swallow the - // error, as it could be something sensitive. - badFields[roundtripPayloadAttributesField] = "Unable to marshal given value as JSON." - } else { - // Cache for later - attrs.CachedRoundtripPayload = string(m) + switch { + case attrs == nil: + badFields["attributes"] = "Attributes field not supplied request" + default: + // Ensure we pay no attention to cache information provided by the client + attrs.CachedRoundtripPayload = "" + + payload := attrs.GetRoundtripPayload() + if payload == nil { + break + } + m, err := json.Marshal(payload.AsMap()) + if err != nil { + // We don't know what's in this payload so we swallow the + // error, as it could be something sensitive. + badFields[roundtripPayloadAttributesField] = "Unable to marshal given value as JSON." + } else { + // Cache for later + attrs.CachedRoundtripPayload = string(m) + } } } case callbackCommand: attrs := req.GetOidcAuthMethodAuthenticateCallbackRequest() + switch { + case attrs == nil: + badFields["attributes"] = "Attributes field not supplied request" + return handlers.InvalidArgumentErrorf("This is a required field.", badFields) + default: + if attrs.GetCode() == "" && attrs.GetError() == "" { + badFields[codeField] = "Code field not supplied in callback request." + } - if attrs.GetCode() == "" && attrs.GetError() == "" { - badFields[codeField] = "Code field not supplied in callback request." - } - - if attrs.GetState() == "" { - badFields[stateField] = "State field not supplied in callback request." + if attrs.GetState() == "" { + badFields[stateField] = "State field not supplied in callback request." + } } case tokenCommand: diff --git a/internal/daemon/controller/handlers/authmethods/password.go b/internal/daemon/controller/handlers/authmethods/password.go index a5badec255..f7904aa79a 100644 --- a/internal/daemon/controller/handlers/authmethods/password.go +++ b/internal/daemon/controller/handlers/authmethods/password.go @@ -137,27 +137,32 @@ func validateAuthenticatePasswordRequest(req *pbs.AuthenticateRequest) error { badFields := make(map[string]string) attrs := req.GetPasswordLoginAttributes() - if attrs.LoginName == "" { - badFields["attributes.login_name"] = "This is a required field." - } - if attrs.Password == "" { - badFields["attributes.password"] = "This is a required field." - } - if req.GetCommand() == "" { - // TODO: Eventually, require a command. For now, fall back to "login" for backwards compat. - req.Command = loginCommand - } - if req.Command != loginCommand { - badFields[commandField] = "Invalid command for this auth method type." - } - tokenType := req.GetType() - if tokenType == "" { - // Fall back to deprecated field if type is not set - tokenType = req.GetTokenType() - } - tType := strings.ToLower(strings.TrimSpace(tokenType)) - if tType != "" && tType != "token" && tType != "cookie" { - badFields[tokenTypeField] = `The only accepted types are "token" and "cookie".` + switch { + case attrs == nil: + badFields["attributes"] = "This is a required field." + default: + if attrs.LoginName == "" { + badFields["attributes.login_name"] = "This is a required field." + } + if attrs.Password == "" { + badFields["attributes.password"] = "This is a required field." + } + if req.GetCommand() == "" { + // TODO: Eventually, require a command. For now, fall back to "login" for backwards compat. + req.Command = loginCommand + } + if req.Command != loginCommand { + badFields[commandField] = "Invalid command for this auth method type." + } + tokenType := req.GetType() + if tokenType == "" { + // Fall back to deprecated field if type is not set + tokenType = req.GetTokenType() + } + tType := strings.ToLower(strings.TrimSpace(tokenType)) + if tType != "" && tType != "token" && tType != "cookie" { + badFields[tokenTypeField] = `The only accepted types are "token" and "cookie".` + } } if len(badFields) > 0 { diff --git a/internal/daemon/controller/handlers/authmethods/password_test.go b/internal/daemon/controller/handlers/authmethods/password_test.go index 0acff20589..3f1688fb4d 100644 --- a/internal/daemon/controller/handlers/authmethods/password_test.go +++ b/internal/daemon/controller/handlers/authmethods/password_test.go @@ -498,10 +498,11 @@ func TestAuthenticate_Password(t *testing.T) { require.NotNil(t, acct) cases := []struct { - name string - request *pbs.AuthenticateRequest - wantType string - wantErr error + name string + request *pbs.AuthenticateRequest + wantType string + wantErr error + wantErrContains string }{ { name: "basic", @@ -597,6 +598,16 @@ func TestAuthenticate_Password(t *testing.T) { }, wantErr: handlers.ApiErrorWithCode(codes.Unauthenticated), }, + { + name: "no-attributes", + request: &pbs.AuthenticateRequest{ + AuthMethodId: am.GetPublicId(), + TokenType: "token", + Attrs: &pbs.AuthenticateRequest_PasswordLoginAttributes{}, + }, + wantErr: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: `Details: {{name: "attributes", desc: "This is a required field."}}`, + }, } for _, tc := range cases { @@ -609,6 +620,9 @@ func TestAuthenticate_Password(t *testing.T) { if tc.wantErr != nil { assert.Error(err) assert.Truef(errors.Is(err, tc.wantErr), "Got %#v, wanted %#v", err, tc.wantErr) + if tc.wantErrContains != "" { + assert.Contains(err.Error(), tc.wantErrContains) + } return } require.NoError(err) diff --git a/internal/daemon/controller/handlers/hosts/host_service.go b/internal/daemon/controller/handlers/hosts/host_service.go index 59ee484e22..24fd32e201 100644 --- a/internal/daemon/controller/handlers/hosts/host_service.go +++ b/internal/daemon/controller/handlers/hosts/host_service.go @@ -635,19 +635,24 @@ func validateCreateRequest(req *pbs.CreateHostRequest) error { badFields[globals.DnsNamesField] = "This field is not supported for this host type." } attrs := req.GetItem().GetStaticHostAttributes() - if attrs.GetAddress() == nil || - len(attrs.GetAddress().GetValue()) < static.MinHostAddressLength || - len(attrs.GetAddress().GetValue()) > static.MaxHostAddressLength { - badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) - } - _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) switch { - case err == nil: - badFields["attributes.address"] = "Address for static hosts does not support a port." - case strings.Contains(err.Error(), "missing port in address"): - // Bare hostname, which we want + case attrs == nil: + badFields["attributes"] = "This is a required field." default: - badFields["attributes.address"] = fmt.Sprintf("Error parsing address: %v.", err) + if attrs.GetAddress() == nil || + len(attrs.GetAddress().GetValue()) < static.MinHostAddressLength || + len(attrs.GetAddress().GetValue()) > static.MaxHostAddressLength { + badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + } + _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) + switch { + case err == nil: + badFields["attributes.address"] = "Address for static hosts does not support a port." + case strings.Contains(err.Error(), "missing port in address"): + // Bare hostname, which we want + default: + badFields["attributes.address"] = fmt.Sprintf("Error parsing address: %v.", err) + } } case plugin.Subtype: badFields[globals.HostCatalogIdField] = "Cannot manually create hosts for this type of catalog." @@ -666,10 +671,15 @@ func validateUpdateRequest(req *pbs.UpdateHostRequest) error { attrs := req.GetItem().GetStaticHostAttributes() if handlers.MaskContains(req.GetUpdateMask().GetPaths(), "attributes.address") { - if attrs.GetAddress() == nil || - len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength || - len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength { - badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + switch { + case attrs == nil: + badFields["attributes"] = "Attributes field not supplied request" + default: + if attrs.GetAddress() == nil || + len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength || + len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength { + badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + } } } } diff --git a/internal/daemon/controller/handlers/hosts/host_service_test.go b/internal/daemon/controller/handlers/hosts/host_service_test.go index 50f12d4925..7e41152365 100644 --- a/internal/daemon/controller/handlers/hosts/host_service_test.go +++ b/internal/daemon/controller/handlers/hosts/host_service_test.go @@ -624,10 +624,11 @@ func TestCreate(t *testing.T) { defaultHcCreated := hc.GetCreateTime().GetTimestamp().AsTime() cases := []struct { - name string - req *pbs.CreateHostRequest - res *pbs.CreateHostResponse - err error + name string + req *pbs.CreateHostRequest + res *pbs.CreateHostResponse + err error + wantErrContains string }{ { name: "Create a valid Host", @@ -659,6 +660,17 @@ func TestCreate(t *testing.T) { }, }, }, + { + name: "no-attributes", + req: &pbs.CreateHostRequest{Item: &pb.Host{ + HostCatalogId: hc.GetPublicId(), + Name: &wrappers.StringValue{Value: "name"}, + Description: &wrappers.StringValue{Value: "desc"}, + Type: "static", + }}, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: `Details: {{name: "attributes", desc: "This is a required field."}}`, + }, { name: "Create a plugin Host", req: &pbs.CreateHostRequest{Item: &pb.Host{ @@ -773,6 +785,9 @@ func TestCreate(t *testing.T) { if tc.err != nil { require.Error(gErr) assert.True(errors.Is(gErr, tc.err), "CreateHost(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) + if tc.wantErrContains != "" { + assert.Contains(gErr.Error(), tc.wantErrContains) + } } if got != nil { assert.Contains(got.GetUri(), tc.res.GetUri()) diff --git a/internal/daemon/controller/handlers/managed_groups/managed_group_service.go b/internal/daemon/controller/handlers/managed_groups/managed_group_service.go index 042b2e07d0..5b5307abf0 100644 --- a/internal/daemon/controller/handlers/managed_groups/managed_group_service.go +++ b/internal/daemon/controller/handlers/managed_groups/managed_group_service.go @@ -630,11 +630,16 @@ func validateUpdateRequest(req *pbs.UpdateManagedGroupRequest) error { } attrs := req.GetItem().GetOidcManagedGroupAttributes() if handlers.MaskContains(req.GetUpdateMask().GetPaths(), attrFilterField) { - if attrs.Filter == "" { - badFields[attrFilterField] = "Field cannot be empty." - } else { - if _, err := bexpr.CreateEvaluator(attrs.Filter); err != nil { - badFields[attrFilterField] = fmt.Sprintf("Error evaluating submitted filter expression: %v.", err) + switch { + case attrs == nil: + badFields["attributes"] = "Attributes field not supplied request" + default: + if attrs.Filter == "" { + badFields[attrFilterField] = "Field cannot be empty." + } else { + if _, err := bexpr.CreateEvaluator(attrs.Filter); err != nil { + badFields[attrFilterField] = fmt.Sprintf("Error evaluating submitted filter expression: %v.", err) + } } } }