fix (API): check attributes missing appropriately. (#2219)

* fix (API): check attributes missing appropriately.
pull/2220/head
Jim 4 years ago committed by GitHub
parent 08c2e71272
commit aaf669a044
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -963,31 +963,41 @@ func validateCreateRequest(req *pbs.CreateAccountRequest) error {
badFields[typeField] = "Doesn't match the parent resource's type."
}
attrs := req.GetItem().GetPasswordAccountAttributes()
if attrs.GetLoginName() == "" {
badFields[loginNameKey] = "This is a required field for this type."
switch {
case attrs == nil:
badFields["attributes"] = "This is a required field."
default:
if attrs.GetLoginName() == "" {
badFields[loginNameKey] = "This is a required field for this type."
}
}
case oidc.Subtype:
if req.GetItem().GetType() != "" && req.GetItem().GetType() != oidc.Subtype.String() {
badFields[typeField] = "Doesn't match the parent resource's type."
}
attrs := req.GetItem().GetOidcAccountAttributes()
if attrs.GetSubject() == "" {
badFields[subjectField] = "This is a required field for this type."
}
if attrs.GetIssuer() != "" {
du, err := url.Parse(attrs.GetIssuer())
if err != nil {
badFields[issuerField] = fmt.Sprintf("Cannot be parsed as a url. %v", err)
switch {
case attrs == nil:
badFields["attributes"] = "This is a required field."
default:
if attrs.GetSubject() == "" {
badFields[subjectField] = "This is a required field for this type."
}
if trimmed := strings.TrimSuffix(strings.TrimSuffix(du.RawPath, "/"), "/.well-known/openid-configuration"); trimmed != "" {
badFields[issuerField] = "The path segment of the url should be empty."
if attrs.GetIssuer() != "" {
du, err := url.Parse(attrs.GetIssuer())
if err != nil {
badFields[issuerField] = fmt.Sprintf("Cannot be parsed as a url. %v", err)
}
if trimmed := strings.TrimSuffix(strings.TrimSuffix(du.RawPath, "/"), "/.well-known/openid-configuration"); trimmed != "" {
badFields[issuerField] = "The path segment of the url should be empty."
}
}
if attrs.GetFullName() != "" {
badFields[nameClaimField] = "This is a read only field."
}
if attrs.GetEmail() != "" {
badFields[emailClaimField] = "This is a read only field."
}
}
if attrs.GetFullName() != "" {
badFields[nameClaimField] = "This is a read only field."
}
if attrs.GetEmail() != "" {
badFields[emailClaimField] = "This is a read only field."
}
default:
badFields[authMethodIdField] = "Unknown auth method type from ID."

@ -49,11 +49,22 @@ func TestValidateCreateRequest(t *testing.T) {
},
errContains: fieldError(typeField, "Doesn't match the parent resource's type."),
},
{
name: "missing oidc attributes",
item: &pb.Account{
Type: oidc.Subtype.String(),
AuthMethodId: oidc.AuthMethodPrefix + "_1234567890",
},
errContains: fieldError(attributesField, "This is a required field."),
},
{
name: "missing oidc subject",
item: &pb.Account{
Type: oidc.Subtype.String(),
AuthMethodId: oidc.AuthMethodPrefix + "_1234567890",
Attrs: &pb.Account_OidcAccountAttributes{
OidcAccountAttributes: &pb.OidcAccountAttributes{},
},
},
errContains: fieldError(subjectField, "This is a required field for this type."),
},
@ -79,11 +90,22 @@ func TestValidateCreateRequest(t *testing.T) {
},
errContains: fieldError(emailClaimField, "This is a read only field."),
},
{
name: "missing password attributes",
item: &pb.Account{
Type: password.Subtype.String(),
AuthMethodId: password.AuthMethodPrefix + "_1234567890",
},
errContains: fieldError(attributesField, "This is a required field."),
},
{
name: "missing login name for password type",
item: &pb.Account{
Type: password.Subtype.String(),
AuthMethodId: password.AuthMethodPrefix + "_1234567890",
Attrs: &pb.Account_PasswordAccountAttributes{
PasswordAccountAttributes: &pb.PasswordAccountAttributes{},
},
},
errContains: fieldError(loginNameKey, "This is a required field for this type."),
},

@ -333,33 +333,42 @@ func validateAuthenticateOidcRequest(req *pbs.AuthenticateRequest) error {
case startCommand:
if req.GetOidcStartAttributes() != nil {
attrs := req.GetOidcStartAttributes()
// Ensure we pay no attention to cache information provided by the client
attrs.CachedRoundtripPayload = ""
payload := attrs.GetRoundtripPayload()
if payload == nil {
break
}
m, err := json.Marshal(payload.AsMap())
if err != nil {
// We don't know what's in this payload so we swallow the
// error, as it could be something sensitive.
badFields[roundtripPayloadAttributesField] = "Unable to marshal given value as JSON."
} else {
// Cache for later
attrs.CachedRoundtripPayload = string(m)
switch {
case attrs == nil:
badFields["attributes"] = "Attributes field not supplied request"
default:
// Ensure we pay no attention to cache information provided by the client
attrs.CachedRoundtripPayload = ""
payload := attrs.GetRoundtripPayload()
if payload == nil {
break
}
m, err := json.Marshal(payload.AsMap())
if err != nil {
// We don't know what's in this payload so we swallow the
// error, as it could be something sensitive.
badFields[roundtripPayloadAttributesField] = "Unable to marshal given value as JSON."
} else {
// Cache for later
attrs.CachedRoundtripPayload = string(m)
}
}
}
case callbackCommand:
attrs := req.GetOidcAuthMethodAuthenticateCallbackRequest()
switch {
case attrs == nil:
badFields["attributes"] = "Attributes field not supplied request"
return handlers.InvalidArgumentErrorf("This is a required field.", badFields)
default:
if attrs.GetCode() == "" && attrs.GetError() == "" {
badFields[codeField] = "Code field not supplied in callback request."
}
if attrs.GetCode() == "" && attrs.GetError() == "" {
badFields[codeField] = "Code field not supplied in callback request."
}
if attrs.GetState() == "" {
badFields[stateField] = "State field not supplied in callback request."
if attrs.GetState() == "" {
badFields[stateField] = "State field not supplied in callback request."
}
}
case tokenCommand:

@ -137,27 +137,32 @@ func validateAuthenticatePasswordRequest(req *pbs.AuthenticateRequest) error {
badFields := make(map[string]string)
attrs := req.GetPasswordLoginAttributes()
if attrs.LoginName == "" {
badFields["attributes.login_name"] = "This is a required field."
}
if attrs.Password == "" {
badFields["attributes.password"] = "This is a required field."
}
if req.GetCommand() == "" {
// TODO: Eventually, require a command. For now, fall back to "login" for backwards compat.
req.Command = loginCommand
}
if req.Command != loginCommand {
badFields[commandField] = "Invalid command for this auth method type."
}
tokenType := req.GetType()
if tokenType == "" {
// Fall back to deprecated field if type is not set
tokenType = req.GetTokenType()
}
tType := strings.ToLower(strings.TrimSpace(tokenType))
if tType != "" && tType != "token" && tType != "cookie" {
badFields[tokenTypeField] = `The only accepted types are "token" and "cookie".`
switch {
case attrs == nil:
badFields["attributes"] = "This is a required field."
default:
if attrs.LoginName == "" {
badFields["attributes.login_name"] = "This is a required field."
}
if attrs.Password == "" {
badFields["attributes.password"] = "This is a required field."
}
if req.GetCommand() == "" {
// TODO: Eventually, require a command. For now, fall back to "login" for backwards compat.
req.Command = loginCommand
}
if req.Command != loginCommand {
badFields[commandField] = "Invalid command for this auth method type."
}
tokenType := req.GetType()
if tokenType == "" {
// Fall back to deprecated field if type is not set
tokenType = req.GetTokenType()
}
tType := strings.ToLower(strings.TrimSpace(tokenType))
if tType != "" && tType != "token" && tType != "cookie" {
badFields[tokenTypeField] = `The only accepted types are "token" and "cookie".`
}
}
if len(badFields) > 0 {

@ -498,10 +498,11 @@ func TestAuthenticate_Password(t *testing.T) {
require.NotNil(t, acct)
cases := []struct {
name string
request *pbs.AuthenticateRequest
wantType string
wantErr error
name string
request *pbs.AuthenticateRequest
wantType string
wantErr error
wantErrContains string
}{
{
name: "basic",
@ -597,6 +598,16 @@ func TestAuthenticate_Password(t *testing.T) {
},
wantErr: handlers.ApiErrorWithCode(codes.Unauthenticated),
},
{
name: "no-attributes",
request: &pbs.AuthenticateRequest{
AuthMethodId: am.GetPublicId(),
TokenType: "token",
Attrs: &pbs.AuthenticateRequest_PasswordLoginAttributes{},
},
wantErr: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: `Details: {{name: "attributes", desc: "This is a required field."}}`,
},
}
for _, tc := range cases {
@ -609,6 +620,9 @@ func TestAuthenticate_Password(t *testing.T) {
if tc.wantErr != nil {
assert.Error(err)
assert.Truef(errors.Is(err, tc.wantErr), "Got %#v, wanted %#v", err, tc.wantErr)
if tc.wantErrContains != "" {
assert.Contains(err.Error(), tc.wantErrContains)
}
return
}
require.NoError(err)

@ -635,19 +635,24 @@ func validateCreateRequest(req *pbs.CreateHostRequest) error {
badFields[globals.DnsNamesField] = "This field is not supported for this host type."
}
attrs := req.GetItem().GetStaticHostAttributes()
if attrs.GetAddress() == nil ||
len(attrs.GetAddress().GetValue()) < static.MinHostAddressLength ||
len(attrs.GetAddress().GetValue()) > static.MaxHostAddressLength {
badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength)
}
_, _, err := net.SplitHostPort(attrs.GetAddress().GetValue())
switch {
case err == nil:
badFields["attributes.address"] = "Address for static hosts does not support a port."
case strings.Contains(err.Error(), "missing port in address"):
// Bare hostname, which we want
case attrs == nil:
badFields["attributes"] = "This is a required field."
default:
badFields["attributes.address"] = fmt.Sprintf("Error parsing address: %v.", err)
if attrs.GetAddress() == nil ||
len(attrs.GetAddress().GetValue()) < static.MinHostAddressLength ||
len(attrs.GetAddress().GetValue()) > static.MaxHostAddressLength {
badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength)
}
_, _, err := net.SplitHostPort(attrs.GetAddress().GetValue())
switch {
case err == nil:
badFields["attributes.address"] = "Address for static hosts does not support a port."
case strings.Contains(err.Error(), "missing port in address"):
// Bare hostname, which we want
default:
badFields["attributes.address"] = fmt.Sprintf("Error parsing address: %v.", err)
}
}
case plugin.Subtype:
badFields[globals.HostCatalogIdField] = "Cannot manually create hosts for this type of catalog."
@ -666,10 +671,15 @@ func validateUpdateRequest(req *pbs.UpdateHostRequest) error {
attrs := req.GetItem().GetStaticHostAttributes()
if handlers.MaskContains(req.GetUpdateMask().GetPaths(), "attributes.address") {
if attrs.GetAddress() == nil ||
len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength ||
len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength {
badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength)
switch {
case attrs == nil:
badFields["attributes"] = "Attributes field not supplied request"
default:
if attrs.GetAddress() == nil ||
len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength ||
len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength {
badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength)
}
}
}
}

@ -624,10 +624,11 @@ func TestCreate(t *testing.T) {
defaultHcCreated := hc.GetCreateTime().GetTimestamp().AsTime()
cases := []struct {
name string
req *pbs.CreateHostRequest
res *pbs.CreateHostResponse
err error
name string
req *pbs.CreateHostRequest
res *pbs.CreateHostResponse
err error
wantErrContains string
}{
{
name: "Create a valid Host",
@ -659,6 +660,17 @@ func TestCreate(t *testing.T) {
},
},
},
{
name: "no-attributes",
req: &pbs.CreateHostRequest{Item: &pb.Host{
HostCatalogId: hc.GetPublicId(),
Name: &wrappers.StringValue{Value: "name"},
Description: &wrappers.StringValue{Value: "desc"},
Type: "static",
}},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: `Details: {{name: "attributes", desc: "This is a required field."}}`,
},
{
name: "Create a plugin Host",
req: &pbs.CreateHostRequest{Item: &pb.Host{
@ -773,6 +785,9 @@ func TestCreate(t *testing.T) {
if tc.err != nil {
require.Error(gErr)
assert.True(errors.Is(gErr, tc.err), "CreateHost(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
if tc.wantErrContains != "" {
assert.Contains(gErr.Error(), tc.wantErrContains)
}
}
if got != nil {
assert.Contains(got.GetUri(), tc.res.GetUri())

@ -630,11 +630,16 @@ func validateUpdateRequest(req *pbs.UpdateManagedGroupRequest) error {
}
attrs := req.GetItem().GetOidcManagedGroupAttributes()
if handlers.MaskContains(req.GetUpdateMask().GetPaths(), attrFilterField) {
if attrs.Filter == "" {
badFields[attrFilterField] = "Field cannot be empty."
} else {
if _, err := bexpr.CreateEvaluator(attrs.Filter); err != nil {
badFields[attrFilterField] = fmt.Sprintf("Error evaluating submitted filter expression: %v.", err)
switch {
case attrs == nil:
badFields["attributes"] = "Attributes field not supplied request"
default:
if attrs.Filter == "" {
badFields[attrFilterField] = "Field cannot be empty."
} else {
if _, err := bexpr.CreateEvaluator(attrs.Filter); err != nil {
badFields[attrFilterField] = fmt.Sprintf("Error evaluating submitted filter expression: %v.", err)
}
}
}
}

Loading…
Cancel
Save