diff --git a/internal/iam/query.go b/internal/iam/query.go index 3e9eca56db..56c0f92b4a 100644 --- a/internal/iam/query.go +++ b/internal/iam/query.go @@ -64,4 +64,48 @@ const ( select * from final order by action, account_id; ` + + grpMemberChangesQuery = ` + with + final_members (member_id) as ( + -- returns the SET list + select public_id + from iam_user + where + public_id in (%s) + ), + current_members (member_id) as ( + -- returns the current list + select member_id + from iam_group_member + where group_id = $1 + ), + keep_members (member_id) as ( + -- returns the KEEP list + select member_id + from current_members + where member_id in (select * from final_members) + ), + delete_members (member_id) as ( + -- returns the DELETE list + select member_id + from current_members + where member_id not in (select * from final_members) + ), + insert_members (member_id) as ( + -- returns the ADD list + select member_id + from final_members + where member_id not in (select * from keep_members) + ), + final (action, member_id) as ( + select 'delete', member_id + from delete_members + union + select 'add', member_id + from insert_members + ) + select * from final + order by action, member_id; + ` ) diff --git a/internal/iam/repository_group.go b/internal/iam/repository_group.go index 2da07706c6..d8acfee56d 100644 --- a/internal/iam/repository_group.go +++ b/internal/iam/repository_group.go @@ -386,63 +386,39 @@ func (r *Repository) SetGroupMembers(ctx context.Context, groupId string, groupV return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to get members %s scope: %w", groupId, err) } - // TODO(mgaffney) 08/2020: Use SQL to calculate changes. - - // find existing members (since we're using groupVersion, we can safely do - // this here, outside the TxHandler) - currentMembers := []*GroupMember{} - if err := r.reader.SearchWhere(ctx, ¤tMembers, "group_id = ?", []interface{}{groupId}); err != nil { - return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to search for existing members of group %s: %w", groupId, err) - } - found := map[string]*GroupMember{} - for _, m := range currentMembers { - found[m.GroupId+m.MemberId] = m - } - addMembers := make([]interface{}, 0, len(userIds)) - deleteMembers := make([]interface{}, 0, len(userIds)) - - for _, usrId := range userIds { - _, ok := found[groupId+usrId] - if ok { - // we have a match, so do nada since we want to keep it, but remove - // it from found. - delete(found, groupId+usrId) - continue - } - // not found, so we add it - gm, err := NewGroupMemberUser(groupId, usrId) - if err != nil { - return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to create in memory group member: %w", err) - } - addMembers = append(addMembers, gm) - } - if len(found) > 0 { - for _, fgm := range found { - // not found, so we add it - gm, err := NewGroupMemberUser(fgm.GroupId, fgm.MemberId) - if err != nil { - return nil, db.NoRowsAffected, fmt.Errorf("set group members: unable to create in memory group member: %w", err) - } - deleteMembers = append(deleteMembers, gm) - } - } - - // handle no change to existing group members - if len(addMembers) == 0 && len(deleteMembers) == 0 { - return currentMembers, db.NoRowsAffected, nil - } - oplogWrapper, err := r.kms.GetWrapper(ctx, scope.GetPublicId(), kms.KeyPurposeOplog) if err != nil { return nil, db.NoRowsAffected, fmt.Errorf("add group members: unable to get oplog wrapper: %w", err) } + var currentMembers []*GroupMember var totalRowsAffected int _, err = r.writer.DoTx( ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { + // we need a new repo, that's using the same reader/writer as this TxHandler + txRepo := Repository{ + reader: reader, + writer: w, + kms: r.kms, + // intentionally not setting the defaultLimit, so we'll get all + // the members without a limit + } + addMembers, deleteMembers, err := groupMemberChanges(ctx, reader, groupId, userIds) + if err != nil { + return fmt.Errorf("set associated accounts: unable to determine changes: %w", err) + } + // handle no change to existing group members + if len(addMembers) == 0 && len(deleteMembers) == 0 { + currentMembers, err = txRepo.ListGroupMembers(ctx, groupId) + if err != nil { + return fmt.Errorf("set group members: unable to retrieve current group members after sets: %w", err) + } + return nil + } + msgs := make([]*oplog.Message, 0, 2) metadata := oplog.Metadata{ "op-type": []string{oplog.OpType_OP_TYPE_UPDATE.String()}, @@ -496,14 +472,6 @@ func (r *Repository) SetGroupMembers(ctx context.Context, groupId string, groupV if err := w.WriteOplogEntryWith(ctx, oplogWrapper, groupTicket, metadata, msgs); err != nil { return fmt.Errorf("set group members: unable to write oplog for additions: %w", err) } - // we need a new repo, that's using the same reader/writer as this TxHandler - txRepo := Repository{ - reader: reader, - writer: w, - kms: r.kms, - // intentionally not setting the defaultLimit, so we'll get all - // the members without a limit - } currentMembers, err = txRepo.ListGroupMembers(ctx, groupId) if err != nil { return fmt.Errorf("set group members: unable to retrieve current group members after sets: %w", err) @@ -515,3 +483,67 @@ func (r *Repository) SetGroupMembers(ctx context.Context, groupId string, groupV } return currentMembers, totalRowsAffected, nil } + +// associationChanges returns two slices: accounts to associate and disassociate +func groupMemberChanges(ctx context.Context, reader db.Reader, groupId string, userIds []string) ([]interface{}, []interface{}, error) { + var inClauseSpots []string + // starts at 2 because there is already a $1 in the query + for i := 2; i < len(userIds)+2; i++ { + inClauseSpots = append(inClauseSpots, fmt.Sprintf("$%d", i)) + } + inClause := strings.Join(inClauseSpots, ",") + if inClause == "" { + inClause = "''" + } + query := fmt.Sprintf(grpMemberChangesQuery, inClause) + + var params []interface{} + params = append(params, groupId) + for _, v := range userIds { + params = append(params, v) + } + // fmt.Println(query, params) + rows, err := reader.Query(query, params) + if err != nil { + return nil, nil, fmt.Errorf("changes: query failed: %w", err) + } + defer rows.Close() + + type change struct { + Action string + MemberId string + } + var changes []*change + for rows.Next() { + var chg change + if err := reader.ScanRows(rows, &chg); err != nil { + return nil, nil, fmt.Errorf("changes: scan row failed: %w", err) + } + changes = append(changes, &chg) + } + addMembers := []interface{}{} + deleteMembers := []interface{}{} + for _, c := range changes { + if c.MemberId == "" { + return nil, nil, fmt.Errorf("changes: missing user id in change result") + } + switch c.Action { + case "add": + gm, err := NewGroupMemberUser(groupId, c.MemberId) + if err != nil { + return nil, nil, fmt.Errorf("set group members: unable to create in memory group member for add: %w", err) + } + addMembers = append(addMembers, gm) + case "delete": + gm, err := NewGroupMemberUser(groupId, c.MemberId) + if err != nil { + return nil, nil, fmt.Errorf("set group members: unable to create in memory group member for delete: %w", err) + } + deleteMembers = append(deleteMembers, gm) + default: + return nil, nil, fmt.Errorf("changes: unknown action %s for %s", c.Action, c.MemberId) + } + + } + return addMembers, deleteMembers, nil +}