From fdaaebf1f6988ef865108f2b96e9f275edb863fa Mon Sep 17 00:00:00 2001 From: Jim Date: Thu, 9 Jul 2020 06:59:25 -0400 Subject: [PATCH] add group member functions (#173) --- internal/db/migrations/postgres.gen.go | 121 +++-- .../db/migrations/postgres/02_oplog.up.sql | 2 +- .../db/migrations/postgres/06_iam.down.sql | 3 + internal/db/migrations/postgres/06_iam.up.sql | 73 ++- .../db/migrations/postgres/08_iam.down.sql | 5 - internal/db/migrations/postgres/08_iam.up.sql | 39 -- internal/iam/group_member.go | 137 ++---- internal/iam/group_member_test.go | 369 ++++++++++++--- internal/iam/repository_group.go | 306 ++++++++++++ internal/iam/repository_group_test.go | 439 ++++++++++++++++++ internal/iam/store/group.pb.go | 25 +- internal/iam/store/group_member.pb.go | 173 ++----- internal/iam/testing.go | 13 + internal/iam/testing_test.go | 21 + internal/iam/user_grants.go | 4 +- internal/iam/user_grants_test.go | 3 +- internal/iam/user_groups.go | 2 +- internal/iam/user_groups_test.go | 3 +- .../storage/iam/store/v1/group.proto | 5 + .../storage/iam/store/v1/group_member.proto | 17 +- 20 files changed, 1344 insertions(+), 416 deletions(-) diff --git a/internal/db/migrations/postgres.gen.go b/internal/db/migrations/postgres.gen.go index 5825754071..aa6ea9be3e 100644 --- a/internal/db/migrations/postgres.gen.go +++ b/internal/db/migrations/postgres.gen.go @@ -258,7 +258,7 @@ values ('iam_scope', 1), ('iam_user', 1), ('iam_group', 1), - ('iam_group_member_user', 1), + ('iam_group_member', 1), ('iam_role', 1), ('iam_role_grant', 1), ('iam_group_role', 1), @@ -438,6 +438,7 @@ drop table iam_role cascade; drop view iam_principal_role cascade; drop table iam_group_role cascade; drop table iam_user_role cascade; +drop table iam_group_member cascade; drop table iam_role_grant cascade; drop function iam_sub_names cascade; @@ -446,6 +447,7 @@ drop function iam_sub_scopes_func cascade; drop function iam_immutable_role cascade; drop function iam_user_role_scope_check cascade; drop function iam_group_role_scope_check cascade; +drop function iam_group_member_scope_check cascade; drop function grant_scope_id_valid cascade; drop function immutable_scope_id_func cascade; @@ -780,9 +782,6 @@ create table iam_role ( disabled boolean not null default false, -- version allows optimistic locking of the role when modifying the role -- itself and when modifying dependent items like principal roles. - -- TODO (jlambert 6/2020) add before update trigger to automatically - -- increment the version when needed. This trigger can be addded when PR - -- #126 is merged and update_version_column() is available. version bigint not null default 1, -- add unique index so a composite fk can be declared. @@ -851,10 +850,19 @@ create table iam_group ( scope_id wt_scope_id not null references iam_scope(public_id) on delete cascade on update cascade, unique(name, scope_id), disabled boolean not null default false, + -- version allows optimistic locking of the group when modifying the group + -- itself and when modifying dependent items like group members. + version bigint not null default 1, + -- add unique index so a composite fk can be declared. unique(scope_id, public_id) ); +create trigger + update_version_column +after update on iam_group + for each row execute procedure update_version_column(); + create trigger update_time_column before update on iam_group @@ -1009,6 +1017,67 @@ before insert on iam_group_role for each row execute procedure default_create_time(); +-- iam_group_member is an association table that represents group with +-- associated users. +create table iam_group_member ( + create_time wt_timestamp, + group_id wt_public_id references iam_group(public_id) on delete cascade on update cascade, + member_id wt_public_id references iam_user(public_id) on delete cascade on update cascade, + primary key (group_id, member_id) +); + +-- iam_group_member_scope_check() ensures that the user is only assigned +-- groups which are within its organization, or the group is within a project +-- within its organization. +create or replace function + iam_group_member_scope_check() + returns trigger +as $$ +declare cnt int; +begin + select count(*) into cnt + from iam_user + where + public_id = new.member_id and + scope_id in( + -- check to see if they have the same org scope + select s.public_id + from iam_scope s, iam_group g + where s.public_id = g.scope_id and g.public_id = new.group_id + union + -- check to see if the role has a parent that's the same org + select s.parent_id as public_id + from iam_group g, iam_scope s + where g.scope_id = s.public_id and g.public_id = new.role_id + ); + if cnt = 0 then + raise exception 'user and group do not belong to the same organization'; + end if; + return new; +end; +$$ language plpgsql; + +-- iam_immutable_group_member() ensures that group members are immutable. +create or replace function + iam_immutable_group_member() + returns trigger +as $$ +begin + raise exception 'group members are immutable'; +end; +$$ language plpgsql; + +create trigger + iam_group_member +before +insert on iam_group_role + for each row execute procedure default_create_time(); + +create trigger iam_immutable_group_member +before +update on iam_group_member + for each row execute procedure iam_immutable_group_member(); + commit; `), @@ -1110,13 +1179,8 @@ commit; BEGIN; drop table if exists iam_auth_method cascade; -drop table if exists iam_group_member_type_enm cascade; -drop table if exists iam_group cascade cascade; -drop table if exists iam_group_member_user cascade; -drop view if exists iam_group_member; drop table if exists iam_auth_method_type_enm cascade; drop table if exists iam_action_enm cascade; -drop view if exists iam_assigned_role; COMMIT; @@ -1127,45 +1191,6 @@ COMMIT; bytes: []byte(` BEGIN; -create table iam_group_member_user ( - create_time wt_timestamp, - group_id wt_public_id references iam_group(public_id) on delete cascade on update cascade, - member_id wt_public_id references iam_user(public_id) on delete cascade on update cascade, - primary key (group_id, member_id) -); - - --- iam_group_member_user_scope_check() ensures that the user is only assigned --- groups which are within its organization, or the group is within a project --- within its organization. -create or replace function - iam_group_member_user_scope_check() - returns trigger -as $$ -declare cnt int; -begin - select count(*) into cnt - from iam_user - where - public_id = new.member_id and - scope_id in( - -- check to see if they have the same org scope - select s.public_id - from iam_scope s, iam_group g - where s.public_id = g.scope_id and g.public_id = new.group_id - union - -- check to see if the role has a parent that's the same org - select s.parent_id as public_id - from iam_group g, iam_scope s - where g.scope_id = s.public_id and g.public_id = new.role_id - ); - if cnt = 0 then - raise exception 'user and group do not belong to the same organization'; - end if; - return new; -end; -$$ language plpgsql; - CREATE TABLE iam_auth_method ( public_id wt_public_id primary key, diff --git a/internal/db/migrations/postgres/02_oplog.up.sql b/internal/db/migrations/postgres/02_oplog.up.sql index ba84a1241f..fed8b13147 100644 --- a/internal/db/migrations/postgres/02_oplog.up.sql +++ b/internal/db/migrations/postgres/02_oplog.up.sql @@ -91,7 +91,7 @@ values ('iam_scope', 1), ('iam_user', 1), ('iam_group', 1), - ('iam_group_member_user', 1), + ('iam_group_member', 1), ('iam_role', 1), ('iam_role_grant', 1), ('iam_group_role', 1), diff --git a/internal/db/migrations/postgres/06_iam.down.sql b/internal/db/migrations/postgres/06_iam.down.sql index ef12a6125d..7cdb590ee1 100644 --- a/internal/db/migrations/postgres/06_iam.down.sql +++ b/internal/db/migrations/postgres/06_iam.down.sql @@ -11,6 +11,7 @@ drop table iam_role cascade; drop view iam_principal_role cascade; drop table iam_group_role cascade; drop table iam_user_role cascade; +drop table iam_group_member cascade; drop table iam_role_grant cascade; drop function iam_sub_names cascade; @@ -19,6 +20,8 @@ drop function iam_sub_scopes_func cascade; drop function iam_immutable_role cascade; drop function iam_user_role_scope_check cascade; drop function iam_group_role_scope_check cascade; +drop function iam_group_member_scope_check cascade; +drop function iam_immutable_group_member cascade; drop function grant_scope_id_valid cascade; drop function immutable_scope_id_func cascade; diff --git a/internal/db/migrations/postgres/06_iam.up.sql b/internal/db/migrations/postgres/06_iam.up.sql index b56bcc014f..317373a2f0 100644 --- a/internal/db/migrations/postgres/06_iam.up.sql +++ b/internal/db/migrations/postgres/06_iam.up.sql @@ -322,9 +322,6 @@ create table iam_role ( disabled boolean not null default false, -- version allows optimistic locking of the role when modifying the role -- itself and when modifying dependent items like principal roles. - -- TODO (jlambert 6/2020) add before update trigger to automatically - -- increment the version when needed. This trigger can be addded when PR - -- #126 is merged and update_version_column() is available. version bigint not null default 1, -- add unique index so a composite fk can be declared. @@ -393,10 +390,19 @@ create table iam_group ( scope_id wt_scope_id not null references iam_scope(public_id) on delete cascade on update cascade, unique(name, scope_id), disabled boolean not null default false, + -- version allows optimistic locking of the group when modifying the group + -- itself and when modifying dependent items like group members. + version bigint not null default 1, + -- add unique index so a composite fk can be declared. unique(scope_id, public_id) ); +create trigger + update_version_column +after update on iam_group + for each row execute procedure update_version_column(); + create trigger update_time_column before update on iam_group @@ -551,4 +557,65 @@ before insert on iam_group_role for each row execute procedure default_create_time(); +-- iam_group_member is an association table that represents group with +-- associated users. +create table iam_group_member ( + create_time wt_timestamp, + group_id wt_public_id references iam_group(public_id) on delete cascade on update cascade, + member_id wt_public_id references iam_user(public_id) on delete cascade on update cascade, + primary key (group_id, member_id) +); + +-- iam_group_member_scope_check() ensures that the user is only assigned +-- groups which are within its organization, or the group is within a project +-- within its organization. +create or replace function + iam_group_member_scope_check() + returns trigger +as $$ +declare cnt int; +begin + select count(*) into cnt + from iam_user + where + public_id = new.member_id and + scope_id in( + -- check to see if they have the same org scope + select s.public_id + from iam_scope s, iam_group g + where s.public_id = g.scope_id and g.public_id = new.group_id + union + -- check to see if the role has a parent that's the same org + select s.parent_id as public_id + from iam_group g, iam_scope s + where g.scope_id = s.public_id and g.public_id = new.role_id + ); + if cnt = 0 then + raise exception 'user and group do not belong to the same organization'; + end if; + return new; +end; +$$ language plpgsql; + +-- iam_immutable_group_member() ensures that group members are immutable. +create or replace function + iam_immutable_group_member() + returns trigger +as $$ +begin + raise exception 'group members are immutable'; +end; +$$ language plpgsql; + +create trigger + default_create_time_column +before +insert on iam_group_member + for each row execute procedure default_create_time(); + +create trigger iam_immutable_group_member +before +update on iam_group_member + for each row execute procedure iam_immutable_group_member(); + commit; diff --git a/internal/db/migrations/postgres/08_iam.down.sql b/internal/db/migrations/postgres/08_iam.down.sql index 6e78fbbf01..179ab9083a 100644 --- a/internal/db/migrations/postgres/08_iam.down.sql +++ b/internal/db/migrations/postgres/08_iam.down.sql @@ -1,13 +1,8 @@ BEGIN; drop table if exists iam_auth_method cascade; -drop table if exists iam_group_member_type_enm cascade; -drop table if exists iam_group cascade cascade; -drop table if exists iam_group_member_user cascade; -drop view if exists iam_group_member; drop table if exists iam_auth_method_type_enm cascade; drop table if exists iam_action_enm cascade; -drop view if exists iam_assigned_role; COMMIT; \ No newline at end of file diff --git a/internal/db/migrations/postgres/08_iam.up.sql b/internal/db/migrations/postgres/08_iam.up.sql index dc048fcc28..77b198adab 100644 --- a/internal/db/migrations/postgres/08_iam.up.sql +++ b/internal/db/migrations/postgres/08_iam.up.sql @@ -1,44 +1,5 @@ BEGIN; -create table iam_group_member_user ( - create_time wt_timestamp, - group_id wt_public_id references iam_group(public_id) on delete cascade on update cascade, - member_id wt_public_id references iam_user(public_id) on delete cascade on update cascade, - primary key (group_id, member_id) -); - - --- iam_group_member_user_scope_check() ensures that the user is only assigned --- groups which are within its organization, or the group is within a project --- within its organization. -create or replace function - iam_group_member_user_scope_check() - returns trigger -as $$ -declare cnt int; -begin - select count(*) into cnt - from iam_user - where - public_id = new.member_id and - scope_id in( - -- check to see if they have the same org scope - select s.public_id - from iam_scope s, iam_group g - where s.public_id = g.scope_id and g.public_id = new.group_id - union - -- check to see if the role has a parent that's the same org - select s.parent_id as public_id - from iam_group g, iam_scope s - where g.scope_id = s.public_id and g.public_id = new.role_id - ); - if cnt = 0 then - raise exception 'user and group do not belong to the same organization'; - end if; - return new; -end; -$$ language plpgsql; - CREATE TABLE iam_auth_method ( public_id wt_public_id primary key, diff --git a/internal/iam/group_member.go b/internal/iam/group_member.go index 3c645acfa9..eb2705cbaa 100644 --- a/internal/iam/group_member.go +++ b/internal/iam/group_member.go @@ -2,137 +2,78 @@ package iam import ( "context" - "errors" + "fmt" "github.com/hashicorp/watchtower/internal/db" "github.com/hashicorp/watchtower/internal/iam/store" "google.golang.org/protobuf/proto" ) -// MemberType defines the possible types for members -type MemberType uint32 - -const ( - UnknownMemberType MemberType = 0 - UserMemberType MemberType = 1 -) - -func (m MemberType) String() string { - return [...]string{ - "unknown", - "user", - }[m] -} - -// GroupMember declares a common interface for all members assigned to a group -type GroupMember interface { - GetGroupId() string - GetMemberId() string - GetType() string -} - -// groupMemberView provides a common way to return group members regardless of their underlying type -type groupMemberView struct { - *store.GroupMemberView -} - -// TableName provides an overridden gorm table name for group members -func (v *groupMemberView) TableName() string { return "iam_group_member" } - -// AddUser will add a user to the group in memory, returning a GroupMember that -// can be written to the db -func (g *Group) AddUser(userId string, opt ...db.Option) (GroupMember, error) { - gm, err := newGroupMemberUser(g.PublicId, userId) - if err != nil { - return nil, err - } - return gm, nil -} - -// Members returns the members of the group (Users) -func (g *Group) Members(ctx context.Context, r db.Reader) ([]GroupMember, error) { - const where = "group_id = ?" - viewMembers := []*GroupMemberUser{} - if err := r.SearchWhere(ctx, &viewMembers, where, []interface{}{g.PublicId}); err != nil { - return nil, err - } - - members := []GroupMember{} - for _, m := range viewMembers { - gm := &GroupMemberUser{ - GroupMemberUser: &store.GroupMemberUser{ - CreateTime: m.CreateTime, - GroupId: m.GroupId, - MemberId: m.MemberId, - }, - } - members = append(members, gm) - } - return members, nil -} - -// GroupMemberUser is a group member that's a User -type GroupMemberUser struct { - *store.GroupMemberUser +// GroupMember is a group member that's a User +type GroupMember struct { + *store.GroupMember tableName string `gorm:"-"` } -// ensure that GroupMemberUser implements the interfaces of: Clonable, GroupMember and db.VetForWriter -var _ Clonable = (*GroupMemberUser)(nil) -var _ GroupMember = (*GroupMemberUser)(nil) -var _ db.VetForWriter = (*GroupMemberUser)(nil) +// ensure that GroupMember implements the interfaces of: Clonable, db.VetForWriter +var _ Clonable = (*GroupMember)(nil) +var _ db.VetForWriter = (*GroupMember)(nil) -// newGroupMemberUser creates a new in memory user member of the group -// options include: withDescripion, WithName -func newGroupMemberUser(groupId, userId string, opt ...Option) (*GroupMemberUser, error) { - if userId == "" { - return nil, errors.New("error the user public id is unset") - } +// NewGroupMember creates a new in memory user member of the group. Users can +// be assigned to groups which are within its organization, or the group is +// within a project within its organization. This relationship will not be +// enforced until the group member is written to the database. No options are +// currently supported. +func NewGroupMember(groupId, userId string, opt ...Option) (*GroupMember, error) { if groupId == "" { - return nil, errors.New("error the user member group is unset") + return nil, fmt.Errorf("new group member: missing group id: %w", db.ErrInvalidParameter) } - gm := &GroupMemberUser{ - GroupMemberUser: &store.GroupMemberUser{ + if userId == "" { + return nil, fmt.Errorf("new group member: missing user id: %w", db.ErrInvalidParameter) + } + return &GroupMember{ + GroupMember: &store.GroupMember{ MemberId: userId, GroupId: groupId, }, - } - return gm, nil -} - -func (m *GroupMemberUser) GetType() string { - return UserMemberType.String() + }, nil } -func allocGroupMemberUser() GroupMemberUser { - return GroupMemberUser{ - GroupMemberUser: &store.GroupMemberUser{}, +func allocGroupMember() GroupMember { + return GroupMember{ + GroupMember: &store.GroupMember{}, } } -// Clone creates a clone of the GroupMemberUser -func (m *GroupMemberUser) Clone() interface{} { - cp := proto.Clone(m.GroupMemberUser) - return &GroupMemberUser{ - GroupMemberUser: cp.(*store.GroupMemberUser), +// Clone creates a clone of the GroupMember +func (m *GroupMember) Clone() interface{} { + cp := proto.Clone(m.GroupMember) + return &GroupMember{ + GroupMember: cp.(*store.GroupMember), } } -// VetForWrite implements db.VetForWrite() interface -func (m *GroupMemberUser) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error { +// VetForWrite implements db.VetForWrite() interface for group members. +func (m *GroupMember) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error { + if m.GroupId == "" { + return fmt.Errorf("group member: missing group id: %w", db.ErrInvalidParameter) + } + if m.MemberId == "" { + return fmt.Errorf("group member: missing member id: %w", db.ErrInvalidParameter) + } return nil } // TableName returns the tablename to override the default gorm table name -func (m *GroupMemberUser) TableName() string { +func (m *GroupMember) TableName() string { if m.tableName != "" { return m.tableName } - return "iam_group_member_user" + return "iam_group_member" } // SetTableName sets the tablename and satisfies the ReplayableMessage interface -func (m *GroupMemberUser) SetTableName(n string) { +func (m *GroupMember) SetTableName(n string) { if n != "" { m.tableName = n } diff --git a/internal/iam/group_member_test.go b/internal/iam/group_member_test.go index e221269353..0b9fa12e34 100644 --- a/internal/iam/group_member_test.go +++ b/internal/iam/group_member_test.go @@ -2,84 +2,337 @@ package iam import ( "context" + "errors" "testing" "github.com/hashicorp/watchtower/internal/db" + "github.com/hashicorp/watchtower/internal/iam/store" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" ) -func TestGroup_AddUser(t *testing.T) { +func Test_NewGroupMember(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") - t.Run("valid", func(t *testing.T) { - assert := assert.New(t) - w := db.New(conn) - org, _ := TestScopes(t, conn) - user := TestUser(t, conn, org.PublicId) - grp := TestGroup(t, conn, org.PublicId) - gm, err := grp.AddUser(user.PublicId) - assert.NoError(err) - assert.NotNil(gm) - assert.Equal(gm.(*GroupMemberUser).GroupId, grp.PublicId) - err = w.Create(context.Background(), gm) - assert.NoError(err) - assert.Equal("user", gm.GetType()) - }) + org, proj := TestScopes(t, conn) + orgGroup := TestGroup(t, conn, org.PublicId) + projGroup := TestGroup(t, conn, proj.PublicId) + user := TestUser(t, conn, org.PublicId) + + type args struct { + groupId string + userId string + opt []Option + } + tests := []struct { + name string + args args + want *GroupMember + wantErr bool + wantIsErr error + }{ + { + name: "valid-org", + args: args{ + groupId: orgGroup.PublicId, + userId: user.PublicId, + }, + want: func() *GroupMember { + gm := allocGroupMember() + gm.GroupId = orgGroup.PublicId + gm.MemberId = user.PublicId + return &gm + }(), + }, + { + name: "valid-proj", + args: args{ + groupId: projGroup.PublicId, + userId: user.PublicId, + }, + want: func() *GroupMember { + gm := allocGroupMember() + gm.GroupId = projGroup.PublicId + gm.MemberId = user.PublicId + return &gm + }(), + }, + { + name: "missing-group", + args: args{ + userId: user.PublicId, + }, + want: nil, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + { + name: "missing-user", + args: args{ + groupId: projGroup.PublicId, + }, + want: nil, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + got, err := NewGroupMember(tt.args.groupId, tt.args.userId, tt.args.opt...) + if tt.wantErr { + require.Error(err) + assert.True(errors.Is(err, tt.wantIsErr)) + return + } + require.NoError(err) + assert.Equal(tt.want, got) + }) + } } -func Test_NewGroupMember(t *testing.T) { +func Test_GroupMemberCreate(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") - t.Run("valid", func(t *testing.T) { - assert := assert.New(t) - w := db.New(conn) - s := testOrg(t, conn, "", "") - user := TestUser(t, conn, s.PublicId) - grp := TestGroup(t, conn, s.PublicId) + org, proj := TestScopes(t, conn) + type args struct { + gm *GroupMember + } + tests := []struct { + name string + args args + wantDup bool + wantErr bool + wantErrMsg string + wantIsErr error + }{ + { + name: "valid-with-org", + args: args{ + gm: func() *GroupMember { + g := TestGroup(t, conn, org.PublicId) + u := TestUser(t, conn, org.PublicId) + gm, err := NewGroupMember(g.PublicId, u.PublicId) + require.NoError(t, err) + return gm + }(), + }, + wantErr: false, + }, + { + name: "valid-with-proj", + args: args{ + gm: func() *GroupMember { + g := TestGroup(t, conn, proj.PublicId) + u := TestUser(t, conn, org.PublicId) + gm, err := NewGroupMember(g.PublicId, u.PublicId) + require.NoError(t, err) + return gm + }(), + }, + wantErr: false, + }, + { + name: "bad-group-id", + args: args{ + gm: func() *GroupMember { + id := testId(t) + u := TestUser(t, conn, org.PublicId) + gm, err := NewGroupMember(id, u.PublicId) + require.NoError(t, err) + return gm + }(), + }, + wantErr: true, + wantErrMsg: `create: failed: pq: insert or update on table "iam_group_member" violates foreign key constraint`, + }, + { + name: "bad-user-id", + args: args{ + gm: func() *GroupMember { + id := testId(t) + g := TestGroup(t, conn, proj.PublicId) + gm, err := NewGroupMember(g.PublicId, id) + require.NoError(t, err) + return gm + }(), + }, + wantErr: true, + wantErrMsg: `create: failed: pq: insert or update on table "iam_group_member" violates foreign key constraint`, + }, + { + name: "missing-group-id", + args: args{ + gm: func() *GroupMember { + u := TestUser(t, conn, org.PublicId) + return &GroupMember{ + GroupMember: &store.GroupMember{ + GroupId: "", + MemberId: u.PublicId, + }, + } + }(), + }, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + { + name: "missing-user-id", + args: args{ + gm: func() *GroupMember { + g := TestGroup(t, conn, org.PublicId) + return &GroupMember{ + GroupMember: &store.GroupMember{ + GroupId: g.PublicId, + MemberId: "", + }, + } + }(), + }, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + { + name: "dup-at-org", + args: args{ + gm: func() *GroupMember { + g := TestGroup(t, conn, org.PublicId) + u := TestUser(t, conn, org.PublicId) + gm, err := NewGroupMember(g.PublicId, u.PublicId) + require.NoError(t, err) + return gm + }(), + }, + wantDup: true, + wantErr: true, + wantErrMsg: `create: failed: pq: duplicate key value violates unique constraint "iam_group_member_pkey"`, + }, + } - gm, err := grp.AddUser(user.PublicId) - assert.NoError(err) - assert.NotNil(gm) - err = w.Create(context.Background(), gm) - assert.NoError(err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + w := db.New(conn) + if tt.wantDup { + gm := tt.args.gm.Clone().(*GroupMember) + err := w.Create(context.Background(), gm) + require.NoError(err) + } + gm := tt.args.gm.Clone().(*GroupMember) + err := w.Create(context.Background(), gm) + if tt.wantErr { + require.Error(err) + assert.Contains(err.Error(), tt.wantErrMsg) + if tt.wantIsErr != nil { + assert.True(errors.Is(err, tt.wantIsErr)) + } + return + } + assert.NoError(err) - members, err := grp.Members(context.Background(), w) - assert.NoError(err) - assert.Equal(1, len(members)) - assert.Equal(members[0].GetMemberId(), user.PublicId) - assert.Equal(members[0].GetGroupId(), grp.PublicId) + found := allocGroupMember() + err = w.LookupWhere(context.Background(), &found, "group_id = ? and member_id = ?", gm.GroupId, gm.MemberId) + require.NoError(err) + assert.Equal(gm, &found) + }) + } +} - rowsDeleted, err := w.Delete(context.Background(), gm) - assert.NoError(err) - assert.Equal(1, rowsDeleted) +func Test_GroupMemberUpdate(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + org, _ := TestScopes(t, conn) + rw := db.New(conn) - members, err = grp.Members(context.Background(), w) - assert.NoError(err) - assert.Equal(0, len(members)) + t.Run("updates not allowed", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + g := TestGroup(t, conn, org.PublicId) + u := TestUser(t, conn, org.PublicId) + u2 := TestUser(t, conn, org.PublicId) + gm := TestGroupMember(t, conn, g.PublicId, u.PublicId) + updateGrpMember := gm.Clone().(*GroupMember) + updateGrpMember.MemberId = u2.PublicId + updatedRows, err := rw.Update(context.Background(), updateGrpMember, []string{"MemberId"}, nil) + require.Error(err) + assert.Equal(0, updatedRows) }) - t.Run("bad-type", func(t *testing.T) { - assert := assert.New(t) - w := db.New(conn) - s := testOrg(t, conn, "", "") - role := TestRole(t, conn, s.PublicId) +} + +func Test_GroupMemberDelete(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + id := testId(t) + org, _ := TestScopes(t, conn) + u := TestUser(t, conn, org.PublicId) + g := TestGroup(t, conn, org.PublicId) - grp := TestGroup(t, conn, s.PublicId) - gm, err := grp.AddUser(role.PublicId) - assert.NoError(err) - assert.NotNil(gm) - err = w.Create(context.Background(), gm) - assert.Error(err) - assert.Equal(err.Error(), `create: failed: pq: insert or update on table "iam_group_member_user" violates foreign key constraint "iam_group_member_user_member_id_fkey"`) + tests := []struct { + name string + gm *GroupMember + wantRowsDeleted int + wantErr bool + wantErrMsg string + }{ + { + name: "valid", + gm: TestGroupMember(t, conn, g.PublicId, u.PublicId), + wantErr: false, + wantRowsDeleted: 1, + }, + { + name: "bad-id", + gm: func() *GroupMember { gm := allocGroupMember(); gm.MemberId = id; gm.GroupId = id; return &gm }(), + wantErr: false, + wantRowsDeleted: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + deleteGroup := allocGroupMember() + deleteGroup.GroupId = tt.gm.GetGroupId() + deleteGroup.MemberId = tt.gm.GetMemberId() + deletedRows, err := rw.Delete(context.Background(), &deleteGroup) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + if tt.wantRowsDeleted == 0 { + assert.Equal(tt.wantRowsDeleted, deletedRows) + return + } + assert.Equal(tt.wantRowsDeleted, deletedRows) + found := allocGroupMember() + err = rw.LookupWhere(context.Background(), &found, "group_id = ? and member_id = ?", tt.gm.GetGroupId(), tt.gm.GetMemberId()) + require.Error(err) + assert.True(errors.Is(db.ErrRecordNotFound, err)) + }) + } +} +func TestGroupMember_Clone(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + org, proj := TestScopes(t, conn) + user := TestUser(t, conn, org.PublicId) + t.Run("valid", func(t *testing.T) { + assert := assert.New(t) + group := TestGroup(t, conn, org.PublicId) + gm := TestGroupMember(t, conn, group.PublicId, user.PublicId) + cp := gm.Clone() + assert.True(proto.Equal(cp.(*GroupMember).GroupMember, gm.GroupMember)) }) - t.Run("nil-user", func(t *testing.T) { + t.Run("not-equal", func(t *testing.T) { assert := assert.New(t) - s := testOrg(t, conn, "", "") - grp := TestGroup(t, conn, s.PublicId) - - gm, err := grp.AddUser("") - assert.Error(err) - assert.Nil(gm) - assert.Equal(err.Error(), "error the user public id is unset") + g := TestGroup(t, conn, org.PublicId) + g2 := TestGroup(t, conn, proj.PublicId) + gm := TestGroupMember(t, conn, g.PublicId, user.PublicId) + gm2 := TestGroupMember(t, conn, g2.PublicId, user.PublicId) + cp := gm.Clone() + assert.True(!proto.Equal(cp.(*GroupMember).GroupMember, gm2.GroupMember)) }) } diff --git a/internal/iam/repository_group.go b/internal/iam/repository_group.go index 58f4480ff9..e881ef1845 100644 --- a/internal/iam/repository_group.go +++ b/internal/iam/repository_group.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/hashicorp/watchtower/internal/db" + "github.com/hashicorp/watchtower/internal/oplog" ) // CreateGroup will create a group in the repository and return the written @@ -127,3 +128,308 @@ func (r *Repository) ListGroups(ctx context.Context, withScopeId string, opt ... } return grps, nil } + +// ListGroupMembers of a group and supports WithLimit option. +func (r *Repository) ListGroupMembers(ctx context.Context, withGroupId string, opt ...Option) ([]*GroupMember, error) { + if withGroupId == "" { + return nil, fmt.Errorf("list group members: missing group id: %w", db.ErrInvalidParameter) + } + members := []*GroupMember{} + if err := r.list(ctx, &members, "group_id = ?", []interface{}{withGroupId}, opt...); err != nil { + return nil, fmt.Errorf("list group members: %w", err) + } + return members, nil +} + +// AddGroupMembers provides the ability to add members (userIds) to a group +// (groupId). The group's current db version must match the groupVersion or an +// error will be returned. The users and group must all be in the same +// organization or the user must be in the project group's parent organization. +func (r *Repository) AddGroupMembers(ctx context.Context, groupId string, groupVersion int, userIds []string, opt ...Option) ([]*GroupMember, error) { + + if groupId == "" { + return nil, fmt.Errorf("add group members: missing group id %w", db.ErrInvalidParameter) + } + if len(userIds) == 0 { + return nil, fmt.Errorf("add group members: missing user ids to add %w", db.ErrInvalidParameter) + } + + group := allocGroup() + group.PublicId = groupId + scope, err := group.GetScope(ctx, r.reader) + if err != nil { + return nil, fmt.Errorf("add group members: unable to get group %s scope: %w", groupId, err) + } + + newGroupMembers := make([]interface{}, 0, len(userIds)) + for _, id := range userIds { + gm, err := NewGroupMember(groupId, id) + if err != nil { + return nil, fmt.Errorf("add group members: unable to create in memory group member: %w", err) + } + newGroupMembers = append(newGroupMembers, gm) + } + + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + msgs := make([]*oplog.Message, 0, 2) + groupTicket, err := w.GetTicket(&group) + if err != nil { + return fmt.Errorf("add group members: unable to get ticket: %w", err) + } + updatedGroup := allocGroup() + updatedGroup.PublicId = groupId + updatedGroup.Version = uint32(groupVersion) + 1 + var groupOplogMsg oplog.Message + rowsUpdated, err := w.Update(ctx, &updatedGroup, []string{"Version"}, nil, db.NewOplogMsg(&groupOplogMsg), db.WithVersion(groupVersion)) + if err != nil { + return fmt.Errorf("add group members: unable to update group version: %w", err) + } + if rowsUpdated != 1 { + return fmt.Errorf("add group members: updated group and %d rows updated", rowsUpdated) + } + msgs = append(msgs, &groupOplogMsg) + memberOplogMsgs := make([]*oplog.Message, 0, len(newGroupMembers)) + if err := w.CreateItems(ctx, newGroupMembers, db.NewOplogMsgs(&memberOplogMsgs)); err != nil { + return fmt.Errorf("add group members: unable to add users: %w", err) + } + msgs = append(msgs, memberOplogMsgs...) + metadata := oplog.Metadata{ + "op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()}, + "scope-id": []string{scope.PublicId}, + "scope-type": []string{scope.Type}, + "resource-public-id": []string{groupId}, + } + if err := w.WriteOplogEntryWith(ctx, r.wrapper, groupTicket, metadata, msgs); err != nil { + return fmt.Errorf("add group members: unable to write oplog: %w", err) + } + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("add group members: error adding members: %w", err) + } + members := make([]*GroupMember, 0, len(newGroupMembers)+len(newGroupMembers)) + for _, m := range newGroupMembers { + members = append(members, m.(*GroupMember)) + } + return members, nil +} + +// DeleteGroupMembers (userIds) from a group (groupId). The group's current db version +// must match the groupVersion or an error will be returned. +func (r *Repository) DeleteGroupMembers(ctx context.Context, groupId string, groupVersion int, userIds []string, opt ...Option) (int, error) { + if groupId == "" { + return db.NoRowsAffected, fmt.Errorf("delete group members: missing group id: %w", db.ErrInvalidParameter) + } + if len(userIds) == 0 { + return db.NoRowsAffected, fmt.Errorf("delete group members: missing either user or groups to delete %w", db.ErrInvalidParameter) + } + group := allocGroup() + group.PublicId = groupId + scope, err := group.GetScope(ctx, r.reader) + if err != nil { + return db.NoRowsAffected, fmt.Errorf("delete group members: unable to get group %s scope: %w", groupId, err) + } + + deleteMembers := make([]interface{}, 0, len(userIds)) + for _, id := range userIds { + member, err := NewGroupMember(groupId, id) + if err != nil { + return db.NoRowsAffected, fmt.Errorf("delete group members: unable to create in memory group member: %w", err) + } + deleteMembers = append(deleteMembers, member) + } + + var totalRowsDeleted int + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + msgs := make([]*oplog.Message, 0, 2) + groupTicket, err := w.GetTicket(&group) + if err != nil { + return fmt.Errorf("delete group members: unable to get ticket: %w", err) + } + updatedGroup := allocGroup() + updatedGroup.PublicId = groupId + updatedGroup.Version = uint32(groupVersion) + 1 + var groupOplogMsg oplog.Message + rowsUpdated, err := w.Update(ctx, &updatedGroup, []string{"Version"}, nil, db.NewOplogMsg(&groupOplogMsg), db.WithVersion(groupVersion)) + if err != nil { + return fmt.Errorf("delete group members: unable to update group version: %w", err) + } + if rowsUpdated != 1 { + return fmt.Errorf("delete group members: updated group and %d rows updated", rowsUpdated) + } + msgs = append(msgs, &groupOplogMsg) + userOplogMsgs := make([]*oplog.Message, 0, len(deleteMembers)) + rowsDeleted, err := w.DeleteItems(ctx, deleteMembers, db.NewOplogMsgs(&userOplogMsgs)) + if err != nil { + return fmt.Errorf("delete group members: unable to delete group members: %w", err) + } + if rowsDeleted != len(deleteMembers) { + return fmt.Errorf("delete group members: group members deleted %d did not match request for %d", rowsDeleted, len(deleteMembers)) + } + totalRowsDeleted += rowsDeleted + msgs = append(msgs, userOplogMsgs...) + metadata := oplog.Metadata{ + "op-type": []string{oplog.OpType_OP_TYPE_DELETE.String()}, + "scope-id": []string{scope.PublicId}, + "scope-type": []string{scope.Type}, + "resource-public-id": []string{groupId}, + } + if err := w.WriteOplogEntryWith(ctx, r.wrapper, groupTicket, metadata, msgs); err != nil { + return fmt.Errorf("delete group members: unable to write oplog: %w", err) + } + return nil + }, + ) + if err != nil { + return db.NoRowsAffected, fmt.Errorf("delete group members: error deleting members: %w", err) + } + return totalRowsDeleted, nil +} + +// SetGroupMembers will set the group's members. If userIds is empty, the +// members will be cleared. +func (r *Repository) SetGroupMembers(ctx context.Context, groupId string, groupVersion int, userIds []string, opt ...Option) ([]*GroupMember, int, error) { + // NOTE - we are intentionally not going to check that the scopes are + // correct for the userIds, given the groupId. We are going to + // rely on the database constraints and triggers to maintain the integrity + // of these scope relationships. The users and group need to either be in + // the same organization or the group needs to be in a project of the user's + // org. There are constraints and triggers to enforce these relationships. + if groupId == "" { + return nil, db.NoRowsAffected, fmt.Errorf("set group members: missing role id: %w", db.ErrInvalidParameter) + } + group := allocGroup() + group.PublicId = groupId + scope, err := group.GetScope(ctx, r.reader) + if err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to get role %s scope: %w", groupId, err) + } + + // find existing members (since we're using groupVersion, we can safely do + // this here, outside the TxHandler) + members := []*GroupMember{} + if err := r.reader.SearchWhere(ctx, &members, "group_id = ?", []interface{}{groupId}); err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to search for existing members of group %s: %w", groupId, err) + } + found := map[string]*GroupMember{} + for _, m := range members { + found[m.GroupId+m.MemberId] = m + } + currentMembers := make([]*GroupMember, 0, len(userIds)+len(found)) + addMembers := make([]interface{}, 0, len(userIds)) + deleteMembers := make([]interface{}, 0, len(userIds)) + + for _, usrId := range userIds { + m, ok := found[groupId+usrId] + if ok { + // we have a match, so do nada since we want to keep it, but remove + // it from found. + currentMembers = append(currentMembers, m) + delete(found, groupId+usrId) + continue + } + // not found, so we add it + gm, err := NewGroupMember(groupId, usrId) + if err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("add group members: unable to create in memory group member: %w", err) + } + addMembers = append(addMembers, gm) + currentMembers = append(currentMembers, gm) + } + if len(found) > 0 { + for _, gm := range found { + deleteMembers = append(deleteMembers, gm) + } + } + + // handle no change to existing group members + if len(addMembers) == 0 && len(deleteMembers) == 0 { + return currentMembers, db.NoRowsAffected, nil + } + + var totalRowsAffected int + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + msgs := make([]*oplog.Message, 0, 2) + metadata := oplog.Metadata{ + "op-type": []string{oplog.OpType_OP_TYPE_UPDATE.String()}, + "scope-id": []string{scope.PublicId}, + "scope-type": []string{scope.Type}, + "resource-public-id": []string{groupId}, + } + // we need a group, which won't be redeemed until all the other + // writes are successful. We can't just use a single ticket because + // we need to write oplog entries for deletes and adds + groupTicket, err := w.GetTicket(&group) + if err != nil { + return fmt.Errorf("set group members: unable to get ticket for group: %w", err) + } + updatedGroup := allocGroup() + updatedGroup.PublicId = groupId + updatedGroup.Version = uint32(groupVersion) + 1 + var groupOplogMsg oplog.Message + rowsUpdated, err := w.Update(ctx, &updatedGroup, []string{"Version"}, nil, db.NewOplogMsg(&groupOplogMsg), db.WithVersion(groupVersion)) + if err != nil { + return fmt.Errorf("set group members: unable to update group version: %w", err) + } + if rowsUpdated != 1 { + return fmt.Errorf("set group members: updated group and %d rows updated", rowsUpdated) + } + if len(deleteMembers) > 0 { + userOplogMsgs := make([]*oplog.Message, 0, len(deleteMembers)) + rowsDeleted, err := w.DeleteItems(ctx, deleteMembers, db.NewOplogMsgs(&userOplogMsgs)) + if err != nil { + return fmt.Errorf("set group members: unable to delete user roles: %w", err) + } + if rowsDeleted != len(deleteMembers) { + return fmt.Errorf("set group members: members deleted %d did not match request for %d", rowsDeleted, len(deleteMembers)) + } + totalRowsAffected += rowsDeleted + msgs = append(msgs, userOplogMsgs...) + metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_DELETE.String()) + } + if len(addMembers) > 0 { + userOplogMsgs := make([]*oplog.Message, 0, len(addMembers)) + if err := w.CreateItems(ctx, addMembers, db.NewOplogMsgs(&userOplogMsgs)); err != nil { + return fmt.Errorf("set group members: unable to add users: %w", err) + } + totalRowsAffected += len(addMembers) + msgs = append(msgs, userOplogMsgs...) + metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_CREATE.String()) + + } + // we're done with all the principal writes, so let's write the + // role's update oplog message + if err := w.WriteOplogEntryWith(ctx, r.wrapper, groupTicket, metadata, msgs); err != nil { + return fmt.Errorf("set group members: unable to write oplog for additions: %w", err) + } + // we need a new repo, that's using the same reader/writer as this TxHandler + txRepo := Repository{ + reader: reader, + writer: w, + wrapper: r.wrapper, + defaultLimit: r.defaultLimit, + } + currentMembers, err = txRepo.ListGroupMembers(ctx, groupId) + if err != nil { + return fmt.Errorf("set group members: unable to retrieve current group members after sets: %w", err) + } + return nil + }) + if err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to set group members: %w", err) + } + return currentMembers, totalRowsAffected, nil +} diff --git a/internal/iam/repository_group_test.go b/internal/iam/repository_group_test.go index 10e76cd6f3..070c0ac6f6 100644 --- a/internal/iam/repository_group_test.go +++ b/internal/iam/repository_group_test.go @@ -3,6 +3,7 @@ package iam import ( "context" "errors" + "sort" "testing" "time" @@ -644,3 +645,441 @@ func TestRepository_ListGroups(t *testing.T) { }) } } + +func TestRepository_ListMembers(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + const testLimit = 10 + rw := db.New(conn) + wrapper := db.TestWrapper(t) + repo, err := NewRepository(rw, rw, wrapper, WithLimit(testLimit)) + require.NoError(t, err) + org, proj := TestScopes(t, conn) + pg := TestGroup(t, conn, proj.PublicId) + og := TestGroup(t, conn, org.PublicId) + + type args struct { + withGroupId string + opt []Option + } + tests := []struct { + name string + createCnt int + args args + wantCnt int + wantErr bool + }{ + { + name: "no-limit-pg-group", + createCnt: repo.defaultLimit + 1, + args: args{ + withGroupId: pg.PublicId, + opt: []Option{WithLimit(-1)}, + }, + wantCnt: repo.defaultLimit + 1, + wantErr: false, + }, + { + name: "no-limit-org-group", + createCnt: repo.defaultLimit + 1, + args: args{ + withGroupId: og.PublicId, + opt: []Option{WithLimit(-1)}, + }, + wantCnt: repo.defaultLimit + 1, + wantErr: false, + }, + { + name: "default-limit", + createCnt: repo.defaultLimit + 1, + args: args{ + withGroupId: pg.PublicId, + }, + wantCnt: repo.defaultLimit, + wantErr: false, + }, + { + name: "custom-limit", + createCnt: repo.defaultLimit + 1, + args: args{ + withGroupId: pg.PublicId, + opt: []Option{WithLimit(3)}, + }, + wantCnt: 3, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + require.NoError(conn.Where("1=1").Delete(allocGroupMember()).Error) + gm := []*GroupMember{} + for i := 0; i < tt.createCnt; i++ { + u := TestUser(t, conn, org.PublicId) + gm = append(gm, TestGroupMember(t, conn, tt.args.withGroupId, u.PublicId)) + } + assert.Equal(tt.createCnt, len(gm)) + + got, err := repo.ListGroupMembers(context.Background(), tt.args.withGroupId, tt.args.opt...) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + assert.Equal(tt.wantCnt, len(got)) + }) + } + t.Run("missing-id", func(t *testing.T) { + require := require.New(t) + got, err := repo.ListGroupMembers(context.Background(), "") + require.Error(err) + require.Nil(got) + require.Truef(errors.Is(err, db.ErrInvalidParameter), "unexpected error %s", err.Error()) + + }) +} + +func TestRepository_AddGroupMembers(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + repo, err := NewRepository(rw, rw, wrapper) + require.NoError(t, err) + org, proj := TestScopes(t, conn) + group := TestGroup(t, conn, proj.PublicId) + createUsersFn := func() []string { + results := []string{} + for i := 0; i < 5; i++ { + u := TestUser(t, conn, org.PublicId) + results = append(results, u.PublicId) + } + return results + } + groupVersion := 0 + type args struct { + groupId string + groupVersion int + userIds []string + opt []Option + } + tests := []struct { + name string + args args + wantErr bool + wantErrIs error + }{ + { + name: "valid-members", + args: args{ + groupId: group.PublicId, + userIds: createUsersFn(), + }, + wantErr: false, + }, + { + name: "valid-next-version", + args: args{ + groupId: group.PublicId, + userIds: createUsersFn(), + }, + wantErr: false, + }, + { + name: "bad-version", + args: args{ + groupId: group.PublicId, + groupVersion: 1000, + userIds: createUsersFn(), + }, + wantErr: true, + }, + { + name: "no-members", + args: args{ + groupId: group.PublicId, + userIds: nil, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + var version int + if tt.args.groupVersion == 0 { + groupVersion += 1 + version = groupVersion + } else { + version = tt.args.groupVersion + } + require.NoError(conn.Where("1=1").Delete(allocGroupMember()).Error) + got, err := repo.AddGroupMembers(context.Background(), tt.args.groupId, version, tt.args.userIds, tt.args.opt...) + if tt.wantErr { + require.Error(err) + if tt.wantErrIs != nil { + assert.Truef(errors.Is(err, tt.wantErrIs), "unexpected error %s", err.Error()) + } + return + } + require.NoError(err) + gotMembers := map[string]*GroupMember{} + for _, m := range got { + gotMembers[m.MemberId] = m + } + for _, id := range tt.args.userIds { + assert.NotEmpty(gotMembers[id]) + u, err := repo.LookupUser(context.Background(), id) + assert.NoError(err) + assert.Equal(id, u.PublicId) + } + err = db.TestVerifyOplog(t, rw, group.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_CREATE), db.WithCreateNotBefore(10*time.Second)) + assert.NoError(err) + + foundMembers, err := repo.ListGroupMembers(context.Background(), group.PublicId) + require.NoError(err) + for _, m := range foundMembers { + assert.NotEmpty(gotMembers[m.MemberId]) + assert.Equal(gotMembers[m.MemberId].GetGroupId(), m.GroupId) + } + + }) + } +} + +func TestRepository_DeleteGroupMembers(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + repo, err := NewRepository(rw, rw, wrapper) + require.NoError(t, err) + org, _ := TestScopes(t, conn) + + type args struct { + group *Group + groupIdOverride *string + groupVersion int + createUserCnt int + deleteUserCnt int + opt []Option + } + tests := []struct { + name string + args args + wantRowsDeleted int + wantErr bool + wantIsErr error + }{ + { + name: "valid", + args: args{ + group: TestGroup(t, conn, org.PublicId), + createUserCnt: 5, + deleteUserCnt: 5, + groupVersion: 2, + }, + wantRowsDeleted: 5, + wantErr: false, + }, + { + name: "valid-keeping-some", + args: args{ + group: TestGroup(t, conn, org.PublicId), + createUserCnt: 5, + deleteUserCnt: 2, + groupVersion: 2, + }, + wantRowsDeleted: 2, + wantErr: false, + }, + { + name: "not-found", + args: args{ + group: TestGroup(t, conn, org.PublicId), + groupVersion: 2, + groupIdOverride: func() *string { id := testId(t); return &id }(), + createUserCnt: 5, + deleteUserCnt: 5, + }, + wantRowsDeleted: 0, + wantErr: true, + }, + { + name: "missing-group-id", + args: args{ + group: TestGroup(t, conn, org.PublicId), + groupVersion: 2, + groupIdOverride: func() *string { id := ""; return &id }(), + createUserCnt: 5, + deleteUserCnt: 5, + }, + wantRowsDeleted: 0, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + userIds := make([]string, 0, tt.args.createUserCnt) + for i := 0; i < tt.args.createUserCnt; i++ { + u := TestUser(t, conn, org.PublicId) + userIds = append(userIds, u.PublicId) + } + members, err := repo.AddGroupMembers(context.Background(), tt.args.group.PublicId, 1, userIds, tt.args.opt...) + require.NoError(err) + assert.Equal(tt.args.createUserCnt, len(members)) + + deleteUserIds := make([]string, 0, tt.args.deleteUserCnt) + for i := 0; i < tt.args.deleteUserCnt; i++ { + deleteUserIds = append(deleteUserIds, userIds[i]) + } + var groupId string + switch { + case tt.args.groupIdOverride != nil: + groupId = *tt.args.groupIdOverride + default: + groupId = tt.args.group.PublicId + } + deletedRows, err := repo.DeleteGroupMembers(context.Background(), groupId, tt.args.groupVersion, deleteUserIds, tt.args.opt...) + if tt.wantErr { + assert.Error(err) + assert.Equal(0, deletedRows) + if tt.wantIsErr != nil { + assert.Truef(errors.Is(err, tt.wantIsErr), "unexpected error %s", err.Error()) + } + err = db.TestVerifyOplog(t, rw, tt.args.group.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_DELETE), db.WithCreateNotBefore(10*time.Second)) + assert.Error(err) + assert.True(errors.Is(db.ErrRecordNotFound, err)) + return + } + require.NoError(err) + assert.Equal(tt.wantRowsDeleted, deletedRows) + + err = db.TestVerifyOplog(t, rw, tt.args.group.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_DELETE), db.WithCreateNotBefore(10*time.Second)) + assert.NoError(err) + }) + } +} + +func TestRepository_SetGroupMembers(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + + repo, err := NewRepository(rw, rw, wrapper) + require.NoError(t, err) + + org, proj := TestScopes(t, conn) + testUser := TestUser(t, conn, org.PublicId) + + createUsersFn := func() []string { + results := []string{} + for i := 0; i < 5; i++ { + u := TestUser(t, conn, org.PublicId) + results = append(results, u.PublicId) + } + return results + } + setupFn := func(groupId string) []string { + users := createUsersFn() + _, err := repo.AddGroupMembers(context.Background(), groupId, 1, users) + require.NoError(t, err) + return users + } + type args struct { + group *Group + groupVersion int + userIds []string + addToOrigUsers bool + opt []Option + } + tests := []struct { + name string + setup func(string) []string + args args + wantAffectedRows int + wantErr bool + }{ + { + name: "clear", + setup: setupFn, + args: args{ + group: TestGroup(t, conn, proj.PublicId), + groupVersion: 2, // yep, since setupFn will increment it to 2 + userIds: []string{}, + }, + wantErr: false, + wantAffectedRows: 5, + }, + { + name: "no change", + setup: setupFn, + args: args{ + group: TestGroup(t, conn, proj.PublicId), + groupVersion: 2, // yep, since setupFn will increment it to 2 + userIds: []string{}, + addToOrigUsers: true, + }, + wantErr: false, + wantAffectedRows: 0, + }, + { + name: "add users", + setup: setupFn, + args: args{ + group: TestGroup(t, conn, proj.PublicId), + groupVersion: 2, // yep, since setupFn will increment it to 2 + userIds: []string{testUser.PublicId}, + addToOrigUsers: true, + }, + wantErr: false, + wantAffectedRows: 1, + }, + { + name: "remove existing and add users", + setup: setupFn, + args: args{ + group: TestGroup(t, conn, proj.PublicId), + groupVersion: 2, // yep, since setupFn will increment it to 2 + userIds: []string{testUser.PublicId}, + addToOrigUsers: false, + }, + wantErr: false, + wantAffectedRows: 6, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + var origUsers []string + if tt.setup != nil { + origUsers = tt.setup(tt.args.group.PublicId) + } + setUsers := tt.args.userIds + if tt.args.addToOrigUsers { + setUsers = append(setUsers, origUsers...) + } + + got, affectedRows, err := repo.SetGroupMembers(context.Background(), tt.args.group.PublicId, tt.args.groupVersion, setUsers, tt.args.opt...) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + assert.Equal(tt.wantAffectedRows, affectedRows) + var gotIds []string + for _, r := range got { + gotIds = append(gotIds, r.GetMemberId()) + } + var wantIds []string + wantIds = append(wantIds, tt.args.userIds...) + sort.Strings(wantIds) + sort.Strings(gotIds) + assert.Equal(wantIds, wantIds) + }) + } +} diff --git a/internal/iam/store/group.pb.go b/internal/iam/store/group.pb.go index 496c16ab80..c79ce3dfc4 100644 --- a/internal/iam/store/group.pb.go +++ b/internal/iam/store/group.pb.go @@ -53,6 +53,10 @@ type Group struct { // disabled is by default false and allows a Group to be marked disabled. // @inject_tag: `gorm:"default:null"` Disabled bool `protobuf:"varint,7,opt,name=disabled,proto3" json:"disabled,omitempty" gorm:"default:null"` + // version allows optimistic locking of the group when modifying the group + // itself and when modifying dependent items like group members. + // @inject_tag: `gorm:"default:null"` + Version uint32 `protobuf:"varint,8,opt,name=version,proto3" json:"version,omitempty" gorm:"default:null"` } func (x *Group) Reset() { @@ -136,6 +140,13 @@ func (x *Group) GetDisabled() bool { return false } +func (x *Group) GetVersion() uint32 { + if x != nil { + return x.Version + } + return 0 +} + var File_controller_storage_iam_store_v1_group_proto protoreflect.FileDescriptor var file_controller_storage_iam_store_v1_group_proto_rawDesc = []byte{ @@ -149,7 +160,7 @@ var file_controller_storage_iam_store_v1_group_proto_rawDesc = []byte{ 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x2b, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2f, 0x73, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x2f, 0x69, 0x61, 0x6d, 0x2f, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x31, - 0x2f, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xab, 0x02, 0x0a, + 0x2f, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xc5, 0x02, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x1b, 0x0a, 0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x49, 0x64, 0x12, 0x4b, 0x0a, 0x0b, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x5f, 0x74, 0x69, @@ -168,11 +179,13 @@ var file_controller_storage_iam_store_v1_group_proto_rawDesc = []byte{ 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x0a, 0x08, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x49, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x08, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x3a, 0x5a, 0x38, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, - 0x72, 0x70, 0x2f, 0x77, 0x61, 0x74, 0x63, 0x68, 0x74, 0x6f, 0x77, 0x65, 0x72, 0x2f, 0x69, 0x6e, - 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x69, 0x61, 0x6d, 0x2f, 0x73, 0x74, 0x6f, 0x72, 0x65, - 0x3b, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x52, 0x08, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x3a, 0x5a, 0x38, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x77, 0x61, 0x74, + 0x63, 0x68, 0x74, 0x6f, 0x77, 0x65, 0x72, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, + 0x2f, 0x69, 0x61, 0x6d, 0x2f, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x3b, 0x73, 0x74, 0x6f, 0x72, 0x65, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/internal/iam/store/group_member.pb.go b/internal/iam/store/group_member.pb.go index 71b82a4df2..40bdd1bbde 100644 --- a/internal/iam/store/group_member.pb.go +++ b/internal/iam/store/group_member.pb.go @@ -26,7 +26,7 @@ const ( // of the legacy proto package is being used. const _ = proto.ProtoPackageIsVersion4 -type GroupMemberUser struct { +type GroupMember struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields @@ -41,8 +41,8 @@ type GroupMemberUser struct { MemberId string `protobuf:"bytes,3,opt,name=member_id,json=memberId,proto3" json:"member_id,omitempty" gorm:"primary_key"` } -func (x *GroupMemberUser) Reset() { - *x = GroupMemberUser{} +func (x *GroupMember) Reset() { + *x = GroupMember{} if protoimpl.UnsafeEnabled { mi := &file_controller_storage_iam_store_v1_group_member_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -50,13 +50,13 @@ func (x *GroupMemberUser) Reset() { } } -func (x *GroupMemberUser) String() string { +func (x *GroupMember) String() string { return protoimpl.X.MessageStringOf(x) } -func (*GroupMemberUser) ProtoMessage() {} +func (*GroupMember) ProtoMessage() {} -func (x *GroupMemberUser) ProtoReflect() protoreflect.Message { +func (x *GroupMember) ProtoReflect() protoreflect.Message { mi := &file_controller_storage_iam_store_v1_group_member_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -68,108 +68,32 @@ func (x *GroupMemberUser) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use GroupMemberUser.ProtoReflect.Descriptor instead. -func (*GroupMemberUser) Descriptor() ([]byte, []int) { +// Deprecated: Use GroupMember.ProtoReflect.Descriptor instead. +func (*GroupMember) Descriptor() ([]byte, []int) { return file_controller_storage_iam_store_v1_group_member_proto_rawDescGZIP(), []int{0} } -func (x *GroupMemberUser) GetCreateTime() *timestamp.Timestamp { +func (x *GroupMember) GetCreateTime() *timestamp.Timestamp { if x != nil { return x.CreateTime } return nil } -func (x *GroupMemberUser) GetGroupId() string { +func (x *GroupMember) GetGroupId() string { if x != nil { return x.GroupId } return "" } -func (x *GroupMemberUser) GetMemberId() string { +func (x *GroupMember) GetMemberId() string { if x != nil { return x.MemberId } return "" } -type GroupMemberView struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // create_time from the RDBMS - // @inject_tag: `gorm:"default:current_timestamp"` - CreateTime *timestamp.Timestamp `protobuf:"bytes,1,opt,name=create_time,json=createTime,proto3" json:"create_time,omitempty" gorm:"default:current_timestamp"` - // group_id is the Group of this GroupMember. - GroupId string `protobuf:"bytes,2,opt,name=group_id,json=groupId,proto3" json:"group_id,omitempty"` - // member_id of the GroupMember - MemberId string `protobuf:"bytes,3,opt,name=member_id,json=memberId,proto3" json:"member_id,omitempty"` - // GroupMember type (User) - Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"` -} - -func (x *GroupMemberView) Reset() { - *x = GroupMemberView{} - if protoimpl.UnsafeEnabled { - mi := &file_controller_storage_iam_store_v1_group_member_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GroupMemberView) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GroupMemberView) ProtoMessage() {} - -func (x *GroupMemberView) ProtoReflect() protoreflect.Message { - mi := &file_controller_storage_iam_store_v1_group_member_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GroupMemberView.ProtoReflect.Descriptor instead. -func (*GroupMemberView) Descriptor() ([]byte, []int) { - return file_controller_storage_iam_store_v1_group_member_proto_rawDescGZIP(), []int{1} -} - -func (x *GroupMemberView) GetCreateTime() *timestamp.Timestamp { - if x != nil { - return x.CreateTime - } - return nil -} - -func (x *GroupMemberView) GetGroupId() string { - if x != nil { - return x.GroupId - } - return "" -} - -func (x *GroupMemberView) GetMemberId() string { - if x != nil { - return x.MemberId - } - return "" -} - -func (x *GroupMemberView) GetType() string { - if x != nil { - return x.Type - } - return "" -} - var File_controller_storage_iam_store_v1_group_member_proto protoreflect.FileDescriptor var file_controller_storage_iam_store_v1_group_member_proto_rawDesc = []byte{ @@ -184,31 +108,20 @@ var file_controller_storage_iam_store_v1_group_member_proto_rawDesc = []byte{ 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x2b, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2f, 0x73, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x2f, 0x69, 0x61, 0x6d, 0x2f, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x22, 0x96, 0x01, 0x0a, 0x0f, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4d, 0x65, 0x6d, - 0x62, 0x65, 0x72, 0x55, 0x73, 0x65, 0x72, 0x12, 0x4b, 0x0a, 0x0b, 0x63, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x63, - 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x73, 0x74, 0x6f, 0x72, 0x61, 0x67, - 0x65, 0x2e, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x54, - 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x54, 0x69, 0x6d, 0x65, 0x12, 0x19, 0x0a, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x12, - 0x1b, 0x0a, 0x09, 0x6d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x08, 0x6d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x49, 0x64, 0x22, 0xaa, 0x01, 0x0a, - 0x0f, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x56, 0x69, 0x65, 0x77, - 0x12, 0x4b, 0x0a, 0x0b, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, - 0x65, 0x72, 0x2e, 0x73, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x2e, 0x74, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, - 0x70, 0x52, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x69, 0x6d, 0x65, 0x12, 0x19, 0x0a, - 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x6d, 0x65, 0x6d, 0x62, - 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6d, 0x65, 0x6d, - 0x62, 0x65, 0x72, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x42, 0x3a, 0x5a, 0x38, 0x67, 0x69, 0x74, - 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, - 0x70, 0x2f, 0x77, 0x61, 0x74, 0x63, 0x68, 0x74, 0x6f, 0x77, 0x65, 0x72, 0x2f, 0x69, 0x6e, 0x74, - 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x69, 0x61, 0x6d, 0x2f, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x3b, - 0x73, 0x74, 0x6f, 0x72, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6f, 0x74, 0x6f, 0x22, 0x92, 0x01, 0x0a, 0x0b, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4d, 0x65, 0x6d, + 0x62, 0x65, 0x72, 0x12, 0x4b, 0x0a, 0x0b, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x5f, 0x74, 0x69, + 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, + 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x73, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x2e, 0x74, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x76, 0x31, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, + 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x54, 0x69, 0x6d, 0x65, + 0x12, 0x19, 0x0a, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x6d, + 0x65, 0x6d, 0x62, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x6d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x49, 0x64, 0x42, 0x3a, 0x5a, 0x38, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, + 0x2f, 0x77, 0x61, 0x74, 0x63, 0x68, 0x74, 0x6f, 0x77, 0x65, 0x72, 0x2f, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x69, 0x61, 0x6d, 0x2f, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x3b, 0x73, + 0x74, 0x6f, 0x72, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -223,20 +136,18 @@ func file_controller_storage_iam_store_v1_group_member_proto_rawDescGZIP() []byt return file_controller_storage_iam_store_v1_group_member_proto_rawDescData } -var file_controller_storage_iam_store_v1_group_member_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_controller_storage_iam_store_v1_group_member_proto_msgTypes = make([]protoimpl.MessageInfo, 1) var file_controller_storage_iam_store_v1_group_member_proto_goTypes = []interface{}{ - (*GroupMemberUser)(nil), // 0: controller.storage.iam.store.v1.GroupMemberUser - (*GroupMemberView)(nil), // 1: controller.storage.iam.store.v1.GroupMemberView - (*timestamp.Timestamp)(nil), // 2: controller.storage.timestamp.v1.Timestamp + (*GroupMember)(nil), // 0: controller.storage.iam.store.v1.GroupMember + (*timestamp.Timestamp)(nil), // 1: controller.storage.timestamp.v1.Timestamp } var file_controller_storage_iam_store_v1_group_member_proto_depIdxs = []int32{ - 2, // 0: controller.storage.iam.store.v1.GroupMemberUser.create_time:type_name -> controller.storage.timestamp.v1.Timestamp - 2, // 1: controller.storage.iam.store.v1.GroupMemberView.create_time:type_name -> controller.storage.timestamp.v1.Timestamp - 2, // [2:2] is the sub-list for method output_type - 2, // [2:2] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 1, // 0: controller.storage.iam.store.v1.GroupMember.create_time:type_name -> controller.storage.timestamp.v1.Timestamp + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name } func init() { file_controller_storage_iam_store_v1_group_member_proto_init() } @@ -247,19 +158,7 @@ func file_controller_storage_iam_store_v1_group_member_proto_init() { file_controller_storage_iam_store_v1_scope_proto_init() if !protoimpl.UnsafeEnabled { file_controller_storage_iam_store_v1_group_member_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GroupMemberUser); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_controller_storage_iam_store_v1_group_member_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GroupMemberView); i { + switch v := v.(*GroupMember); i { case 0: return &v.state case 1: @@ -277,7 +176,7 @@ func file_controller_storage_iam_store_v1_group_member_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_controller_storage_iam_store_v1_group_member_proto_rawDesc, NumEnums: 0, - NumMessages: 2, + NumMessages: 1, NumExtensions: 0, NumServices: 0, }, diff --git a/internal/iam/testing.go b/internal/iam/testing.go index 314d80b1c3..d75a0df343 100644 --- a/internal/iam/testing.go +++ b/internal/iam/testing.go @@ -154,6 +154,19 @@ func TestGroup(t *testing.T, conn *gorm.DB, scopeId string, opt ...Option) *Grou return grp } +func TestGroupMember(t *testing.T, conn *gorm.DB, groupId, userId string, opt ...Option) *GroupMember { + t.Helper() + require := require.New(t) + rw := db.New(conn) + gm, err := NewGroupMember(groupId, userId) + require.NoError(err) + require.NotNil(gm) + err = rw.Create(context.Background(), gm) + require.NoError(err) + require.NotEmpty(gm.CreateTime) + return gm +} + func TestUserRole(t *testing.T, conn *gorm.DB, scopeId, roleId, userId string, opt ...Option) *UserRole { t.Helper() require := require.New(t) diff --git a/internal/iam/testing_test.go b/internal/iam/testing_test.go index 7d4b5ec985..c2fa977dd9 100644 --- a/internal/iam/testing_test.go +++ b/internal/iam/testing_test.go @@ -141,3 +141,24 @@ func Test_TestGroupRole(t *testing.T) { require.Equal(projRole.PublicId, groupRole.RoleId) require.Equal(projGroup.PublicId, groupRole.PrincipalId) } + +func Test_TestGroupMember(t *testing.T) { + t.Helper() + require := require.New(t) + conn, _ := db.TestSetup(t, "postgres") + org, proj := TestScopes(t, conn) + og := TestGroup(t, conn, org.PublicId) + pg := TestGroup(t, conn, proj.PublicId) + u := TestUser(t, conn, org.PublicId) + + gm := TestGroupMember(t, conn, og.PublicId, u.PublicId) + require.NotNil(gm) + require.Equal(og.PublicId, gm.GroupId) + require.Equal(u.PublicId, gm.MemberId) + + gm = TestGroupMember(t, conn, pg.PublicId, u.PublicId) + require.NotNil(gm) + require.Equal(pg.PublicId, gm.GroupId) + require.Equal(u.PublicId, gm.MemberId) + +} diff --git a/internal/iam/user_grants.go b/internal/iam/user_grants.go index 5e9b66b389..9f1e72aa19 100644 --- a/internal/iam/user_grants.go +++ b/internal/iam/user_grants.go @@ -22,12 +22,12 @@ from iam_role_grant rg, iam_principal_role ipr, iam_group grp, - iam_group_member_user gm + iam_group_member gm where rg.role_id = ipr.role_id and ipr.principal_id = grp.public_id and grp.public_id = gm.group_id and - gm.member_id = $1 and + gm.member_id = $1 and ipr."type" = 'group' union select diff --git a/internal/iam/user_grants_test.go b/internal/iam/user_grants_test.go index 2b1768768d..78e8f97c57 100644 --- a/internal/iam/user_grants_test.go +++ b/internal/iam/user_grants_test.go @@ -44,7 +44,8 @@ func Test_RawGrants(t *testing.T) { grp := TestGroup(t, conn, org.PublicId) - gm, err := grp.AddUser(user.PublicId) + gm, err := NewGroupMember(grp.PublicId, user.PublicId) + assert.NoError(err) assert.NotNil(gm) err = w.Create(context.Background(), gm) diff --git a/internal/iam/user_groups.go b/internal/iam/user_groups.go index 5c293e1df1..ee4ef1f88f 100644 --- a/internal/iam/user_groups.go +++ b/internal/iam/user_groups.go @@ -9,7 +9,7 @@ import ( // Groups will get the user's groups func (u *User) Groups(ctx context.Context, r db.Reader) ([]*Group, error) { - const where = "public_id in (select distinct group_id from iam_group_member_user where member_id = ?)" + const where = "public_id in (select distinct group_id from iam_group_member where member_id = ?)" if r == nil { return nil, errors.New("error reader is nil for getting the user's groups") diff --git a/internal/iam/user_groups_test.go b/internal/iam/user_groups_test.go index 4bfc139ae6..b97962ae30 100644 --- a/internal/iam/user_groups_test.go +++ b/internal/iam/user_groups_test.go @@ -19,7 +19,8 @@ func Test_UserGroups(t *testing.T) { grp := TestGroup(t, conn, org.PublicId) - gm, err := grp.AddUser(user.PublicId) + gm, err := NewGroupMember(grp.PublicId, user.PublicId) + assert.NoError(err) assert.NotNil(gm) err = w.Create(context.Background(), gm) diff --git a/internal/proto/local/controller/storage/iam/store/v1/group.proto b/internal/proto/local/controller/storage/iam/store/v1/group.proto index 861ca8491d..9b3ecdc68d 100644 --- a/internal/proto/local/controller/storage/iam/store/v1/group.proto +++ b/internal/proto/local/controller/storage/iam/store/v1/group.proto @@ -35,4 +35,9 @@ message Group { // disabled is by default false and allows a Group to be marked disabled. // @inject_tag: `gorm:"default:null"` bool disabled = 7; + + // version allows optimistic locking of the group when modifying the group + // itself and when modifying dependent items like group members. + // @inject_tag: `gorm:"default:null"` + uint32 version = 8; } \ No newline at end of file diff --git a/internal/proto/local/controller/storage/iam/store/v1/group_member.proto b/internal/proto/local/controller/storage/iam/store/v1/group_member.proto index c393ed5bce..d486c58077 100644 --- a/internal/proto/local/controller/storage/iam/store/v1/group_member.proto +++ b/internal/proto/local/controller/storage/iam/store/v1/group_member.proto @@ -6,7 +6,7 @@ option go_package = "github.com/hashicorp/watchtower/internal/iam/store;store"; import "controller/storage/timestamp/v1/timestamp.proto"; import "controller/storage/iam/store/v1/scope.proto"; -message GroupMemberUser { +message GroupMember { // create_time from the RDBMS // @inject_tag: `gorm:"default:current_timestamp"` timestamp.v1.Timestamp create_time = 1; @@ -18,18 +18,3 @@ message GroupMemberUser { // @inject_tag: gorm:"primary_key" string member_id = 3; } - -message GroupMemberView { - // create_time from the RDBMS - // @inject_tag: `gorm:"default:current_timestamp"` - timestamp.v1.Timestamp create_time = 1; - - // group_id is the Group of this GroupMember. - string group_id = 2; - - // member_id of the GroupMember - string member_id = 3; - - // GroupMember type (User) - string type = 4; -} \ No newline at end of file