From e455e31f9e397c26ad454c49178566d2ff1e6323 Mon Sep 17 00:00:00 2001 From: Hugo Date: Thu, 2 Feb 2023 18:51:39 +0000 Subject: [PATCH] fix(target): Incorrectly allowing whitespace on Target's address field (#2862) --- .../handlers/targets/target_service.go | 4 +- internal/target/address.go | 2 + internal/target/repository.go | 2 + .../target/tcp/repository_tcp_target_test.go | 44 ++++++++++++++++++- internal/tests/api/targets/target_test.go | 40 +++++++++++++++++ 5 files changed, 88 insertions(+), 4 deletions(-) diff --git a/internal/daemon/controller/handlers/targets/target_service.go b/internal/daemon/controller/handlers/targets/target_service.go index a1735b1795..e6062712f9 100644 --- a/internal/daemon/controller/handlers/targets/target_service.go +++ b/internal/daemon/controller/handlers/targets/target_service.go @@ -1069,7 +1069,7 @@ func (s Service) createInRepo(ctx context.Context, item *pb.Target) (target.Targ opts = append(opts, target.WithIngressWorkerFilter(item.GetIngressWorkerFilter().GetValue())) } if item.GetAddress() != nil { - opts = append(opts, target.WithAddress(item.GetAddress().GetValue())) + opts = append(opts, target.WithAddress(strings.TrimSpace(item.GetAddress().GetValue()))) } attr, err := subtypeRegistry.newAttribute(target.SubtypeFromType(item.GetType()), item.GetAttrs()) @@ -1124,7 +1124,7 @@ func (s Service) updateInRepo(ctx context.Context, scopeId, id string, mask []st } if item.GetAddress() != nil { dbMask = append(dbMask, "Address") - opts = append(opts, target.WithAddress(item.GetAddress().GetValue())) + opts = append(opts, target.WithAddress(strings.TrimSpace(item.GetAddress().GetValue()))) } subtype := target.SubtypeFromId(id) diff --git a/internal/target/address.go b/internal/target/address.go index 72d05584a5..304099b477 100644 --- a/internal/target/address.go +++ b/internal/target/address.go @@ -2,6 +2,7 @@ package target import ( "context" + "strings" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" @@ -37,6 +38,7 @@ func NewAddress(targetId, address string, _ ...Option) (*Address, error) { if address == "" { return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing address") } + address = strings.TrimSpace(address) t := &Address{ TargetAddress: &store.TargetAddress{ TargetId: targetId, diff --git a/internal/target/repository.go b/internal/target/repository.go index cad16b4b0f..c1b42d970f 100644 --- a/internal/target/repository.go +++ b/internal/target/repository.go @@ -411,6 +411,7 @@ func (r *Repository) CreateTarget(ctx context.Context, target Target, opt ...Opt var address *Address var err error if t.GetAddress() != "" { + t.SetAddress(strings.TrimSpace(t.GetAddress())) address, err = NewAddress(t.GetPublicId(), t.GetAddress()) if err != nil { return nil, nil, nil, errors.Wrap(ctx, err, op) @@ -504,6 +505,7 @@ func (r *Repository) UpdateTarget(ctx context.Context, target Target, version ui case strings.EqualFold("egressworkerfilter", f): case strings.EqualFold("ingressworkerfilter", f): case strings.EqualFold("address", f): + target.SetAddress(strings.TrimSpace(target.GetAddress())) addressEndpoint = target.GetAddress() default: return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidFieldMask, op, fmt.Sprintf("invalid field mask: %s", f)) diff --git a/internal/target/tcp/repository_tcp_target_test.go b/internal/target/tcp/repository_tcp_target_test.go index 919ae895fb..9116dd5ad1 100644 --- a/internal/target/tcp/repository_tcp_target_test.go +++ b/internal/target/tcp/repository_tcp_target_test.go @@ -48,6 +48,7 @@ func TestRepository_CreateTarget(t *testing.T) { name string args args wantErr bool + wantAddress string wantIsError errors.Code }{ { @@ -79,6 +80,22 @@ func TestRepository_CreateTarget(t *testing.T) { }, wantErr: false, }, + { + name: "with-address-whitespace", + args: args{ + target: func() target.Target { + target, err := target.New(ctx, tcp.Subtype, proj.PublicId, + target.WithName("with-address-whitespace"), + target.WithDescription("with-address-whitespace"), + target.WithDefaultPort(80), + target.WithAddress(" 8.8.8.8 ")) + require.NoError(t, err) + return target + }(), + }, + wantErr: false, + wantAddress: "8.8.8.8", + }, { name: "nil-target", args: args{ @@ -202,7 +219,11 @@ func TestRepository_CreateTarget(t *testing.T) { foundTarget, foundHostSources, foundCredLibs, err := repo.LookupTarget(context.Background(), tar.GetPublicId()) assert.NoError(err) - assert.Equal(tt.args.target.GetAddress(), tar.GetAddress()) + if len(tt.wantAddress) != 0 { + assert.Equal(tt.wantAddress, tar.GetAddress()) + } else { + assert.Equal(tt.args.target.GetAddress(), tar.GetAddress()) + } assert.True(proto.Equal(tar.(*tcp.Target), foundTarget.(*tcp.Target))) assert.Equal(hostSources, foundHostSources) assert.Equal(credSources, foundCredLibs) @@ -259,6 +280,7 @@ func TestRepository_UpdateTcpTarget(t *testing.T) { wantIsError errors.Code wantDup bool wantHostSources bool + wantAddress string }{ { name: "valid", @@ -299,6 +321,20 @@ func TestRepository_UpdateTcpTarget(t *testing.T) { wantRowsUpdate: 1, wantHostSources: false, }, + { + name: "address-with-whitespace", + args: args{ + fieldMaskPaths: []string{"Address"}, + ProjectId: proj.PublicId, + address: " 127.0.0.1 ", + }, + newProjectId: proj.PublicId, + newTargetOpts: []target.Option{target.WithAddress("10.0.0.1")}, + wantErr: false, + wantRowsUpdate: 1, + wantHostSources: false, + wantAddress: "127.0.0.1", + }, { name: "delete-address", args: args{ @@ -558,7 +594,11 @@ func TestRepository_UpdateTcpTarget(t *testing.T) { afterUpdateIds = append(afterUpdateIds, hs.Id()) } assert.Equal(testHostSetIds, afterUpdateIds) - assert.Equal(tt.args.address, targetAfterUpdate.GetAddress()) + if len(tt.wantAddress) != 0 { + assert.Equal(tt.wantAddress, targetAfterUpdate.GetAddress()) + } else { + assert.Equal(tt.args.address, targetAfterUpdate.GetAddress()) + } afterUpdateIds = make([]string, 0, len(credSources)) for _, cl := range credSources { diff --git a/internal/tests/api/targets/target_test.go b/internal/tests/api/targets/target_test.go index dd4ee9a26a..ba84da87f7 100644 --- a/internal/tests/api/targets/target_test.go +++ b/internal/tests/api/targets/target_test.go @@ -693,3 +693,43 @@ func TestSet_Errors(t *testing.T) { assert.NotNil(apiErr) assert.EqualValues(http.StatusBadRequest, apiErr.Response().StatusCode()) } + +func TestCreateTarget_WhitespaceInAddress(t *testing.T) { + require := require.New(t) + tc := controller.NewTestController(t, nil) + defer tc.Shutdown() + + client := tc.Client() + token := tc.Token() + client.SetToken(token.Token) + _, proj := iam.TestScopes(t, tc.IamRepo(), iam.WithUserId(token.UserId)) + + tarClient := targets.NewClient(client) + + tar, err := tarClient.Create(tc.Context(), "tcp", proj.GetPublicId(), targets.WithName("foo"), targets.WithTcpTargetDefaultPort(2), targets.WithAddress(" 127.0.0.1 ")) + require.NoError(err) + require.NotNil(tar) + require.Equal("127.0.0.1", tar.GetItem().Address) +} + +func TestUpdateTarget_WhitespaceInAddress(t *testing.T) { + require := require.New(t) + tc := controller.NewTestController(t, nil) + defer tc.Shutdown() + + client := tc.Client() + token := tc.Token() + client.SetToken(token.Token) + _, proj := iam.TestScopes(t, tc.IamRepo(), iam.WithUserId(token.UserId)) + + tarClient := targets.NewClient(client) + + tar, err := tarClient.Create(tc.Context(), "tcp", proj.GetPublicId(), targets.WithName("foo"), targets.WithTcpTargetDefaultPort(2), targets.WithAddress("127.0.0.1")) + require.NoError(err) + require.NotNil(tar) + + updateResult, err := tarClient.Update(tc.Context(), tar.Item.Id, tar.Item.Version, targets.WithAddress(" 10.0.0.1 ")) + require.NoError(err) + require.NotNil(updateResult) + require.Equal("10.0.0.1", updateResult.Item.Address) +}