use CTE within transaction to calc SetGroupMembers changes (#454)

pull/476/head
Jim 6 years ago committed by GitHub
parent cf3fa4522d
commit acd28e5372
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -64,4 +64,48 @@ const (
select * from final
order by action, account_id;
`
grpMemberChangesQuery = `
with
final_members (member_id) as (
-- returns the SET list
select public_id
from iam_user
where
public_id in (%s)
),
current_members (member_id) as (
-- returns the current list
select member_id
from iam_group_member
where group_id = $1
),
keep_members (member_id) as (
-- returns the KEEP list
select member_id
from current_members
where member_id in (select * from final_members)
),
delete_members (member_id) as (
-- returns the DELETE list
select member_id
from current_members
where member_id not in (select * from final_members)
),
insert_members (member_id) as (
-- returns the ADD list
select member_id
from final_members
where member_id not in (select * from keep_members)
),
final (action, member_id) as (
select 'delete', member_id
from delete_members
union
select 'add', member_id
from insert_members
)
select * from final
order by action, member_id;
`
)

@ -386,63 +386,39 @@ func (r *Repository) SetGroupMembers(ctx context.Context, groupId string, groupV
return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to get members %s scope: %w", groupId, err)
}
// TODO(mgaffney) 08/2020: Use SQL to calculate changes.
// find existing members (since we're using groupVersion, we can safely do
// this here, outside the TxHandler)
currentMembers := []*GroupMember{}
if err := r.reader.SearchWhere(ctx, &currentMembers, "group_id = ?", []interface{}{groupId}); err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to search for existing members of group %s: %w", groupId, err)
}
found := map[string]*GroupMember{}
for _, m := range currentMembers {
found[m.GroupId+m.MemberId] = m
}
addMembers := make([]interface{}, 0, len(userIds))
deleteMembers := make([]interface{}, 0, len(userIds))
for _, usrId := range userIds {
_, ok := found[groupId+usrId]
if ok {
// we have a match, so do nada since we want to keep it, but remove
// it from found.
delete(found, groupId+usrId)
continue
}
// not found, so we add it
gm, err := NewGroupMemberUser(groupId, usrId)
if err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to create in memory group member: %w", err)
}
addMembers = append(addMembers, gm)
}
if len(found) > 0 {
for _, fgm := range found {
// not found, so we add it
gm, err := NewGroupMemberUser(fgm.GroupId, fgm.MemberId)
if err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to create in memory group member: %w", err)
}
deleteMembers = append(deleteMembers, gm)
}
}
// handle no change to existing group members
if len(addMembers) == 0 && len(deleteMembers) == 0 {
return currentMembers, db.NoRowsAffected, nil
}
oplogWrapper, err := r.kms.GetWrapper(ctx, scope.GetPublicId(), kms.KeyPurposeOplog)
if err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("add group members: unable to get oplog wrapper: %w", err)
}
var currentMembers []*GroupMember
var totalRowsAffected int
_, err = r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
// we need a new repo, that's using the same reader/writer as this TxHandler
txRepo := Repository{
reader: reader,
writer: w,
kms: r.kms,
// intentionally not setting the defaultLimit, so we'll get all
// the members without a limit
}
addMembers, deleteMembers, err := groupMemberChanges(ctx, reader, groupId, userIds)
if err != nil {
return fmt.Errorf("set associated accounts: unable to determine changes: %w", err)
}
// handle no change to existing group members
if len(addMembers) == 0 && len(deleteMembers) == 0 {
currentMembers, err = txRepo.ListGroupMembers(ctx, groupId)
if err != nil {
return fmt.Errorf("set group members: unable to retrieve current group members after sets: %w", err)
}
return nil
}
msgs := make([]*oplog.Message, 0, 2)
metadata := oplog.Metadata{
"op-type": []string{oplog.OpType_OP_TYPE_UPDATE.String()},
@ -496,14 +472,6 @@ func (r *Repository) SetGroupMembers(ctx context.Context, groupId string, groupV
if err := w.WriteOplogEntryWith(ctx, oplogWrapper, groupTicket, metadata, msgs); err != nil {
return fmt.Errorf("set group members: unable to write oplog for additions: %w", err)
}
// we need a new repo, that's using the same reader/writer as this TxHandler
txRepo := Repository{
reader: reader,
writer: w,
kms: r.kms,
// intentionally not setting the defaultLimit, so we'll get all
// the members without a limit
}
currentMembers, err = txRepo.ListGroupMembers(ctx, groupId)
if err != nil {
return fmt.Errorf("set group members: unable to retrieve current group members after sets: %w", err)
@ -515,3 +483,67 @@ func (r *Repository) SetGroupMembers(ctx context.Context, groupId string, groupV
}
return currentMembers, totalRowsAffected, nil
}
// associationChanges returns two slices: accounts to associate and disassociate
func groupMemberChanges(ctx context.Context, reader db.Reader, groupId string, userIds []string) ([]interface{}, []interface{}, error) {
var inClauseSpots []string
// starts at 2 because there is already a $1 in the query
for i := 2; i < len(userIds)+2; i++ {
inClauseSpots = append(inClauseSpots, fmt.Sprintf("$%d", i))
}
inClause := strings.Join(inClauseSpots, ",")
if inClause == "" {
inClause = "''"
}
query := fmt.Sprintf(grpMemberChangesQuery, inClause)
var params []interface{}
params = append(params, groupId)
for _, v := range userIds {
params = append(params, v)
}
// fmt.Println(query, params)
rows, err := reader.Query(query, params)
if err != nil {
return nil, nil, fmt.Errorf("changes: query failed: %w", err)
}
defer rows.Close()
type change struct {
Action string
MemberId string
}
var changes []*change
for rows.Next() {
var chg change
if err := reader.ScanRows(rows, &chg); err != nil {
return nil, nil, fmt.Errorf("changes: scan row failed: %w", err)
}
changes = append(changes, &chg)
}
addMembers := []interface{}{}
deleteMembers := []interface{}{}
for _, c := range changes {
if c.MemberId == "" {
return nil, nil, fmt.Errorf("changes: missing user id in change result")
}
switch c.Action {
case "add":
gm, err := NewGroupMemberUser(groupId, c.MemberId)
if err != nil {
return nil, nil, fmt.Errorf("set group members: unable to create in memory group member for add: %w", err)
}
addMembers = append(addMembers, gm)
case "delete":
gm, err := NewGroupMemberUser(groupId, c.MemberId)
if err != nil {
return nil, nil, fmt.Errorf("set group members: unable to create in memory group member for delete: %w", err)
}
deleteMembers = append(deleteMembers, gm)
default:
return nil, nil, fmt.Errorf("changes: unknown action %s for %s", c.Action, c.MemberId)
}
}
return addMembers, deleteMembers, nil
}

Loading…
Cancel
Save