diff options
-rw-r--r-- | roomserver/storage/shared/membership_updater.go | 144 |
1 files changed, 79 insertions, 65 deletions
diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 329813bf..834af606 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -79,89 +79,103 @@ func (u *MembershipUpdater) IsLeave() bool { // SetToInvite implements types.MembershipUpdater func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) - if err != nil { - return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) - } - inserted, err := u.d.InvitesTable.InsertInviteEvent( - u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), - ) - if err != nil { - return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) - } - if u.membership != tables.MembershipStateInvite { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, - ); err != nil { - return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + var inserted bool + err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + if err != nil { + return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } - } - return inserted, nil + inserted, err = u.d.InvitesTable.InsertInviteEvent( + u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) + } + if u.membership != tables.MembershipStateInvite { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, + ); err != nil { + return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + } + } + return nil + }) + return inserted, err } // SetToJoin implements types.MembershipUpdater func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { var inviteEventIDs []string - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) - } - - // If this is a join event update, there is no invite to update - if !isUpdate { - inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) + err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { - return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) + return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } - } - // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, fmt.Errorf("u.d.EventNIDs: %w", err) - } + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) + } + } - if u.membership != tables.MembershipStateJoin || isUpdate { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - tables.MembershipStateJoin, nIDs[eventID], - ); err != nil { - return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return fmt.Errorf("u.d.EventNIDs: %w", err) } - } - return inviteEventIDs, nil + if u.membership != tables.MembershipStateJoin || isUpdate { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + tables.MembershipStateJoin, nIDs[eventID], + ); err != nil { + return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + } + } + + return nil + }) + + return inviteEventIDs, err } // SetToLeave implements types.MembershipUpdater func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) - } - inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err) - } + var inviteEventIDs []string - // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, fmt.Errorf("u.d.EventNIDs: %w", err) - } + err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) + } + inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err) + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return fmt.Errorf("u.d.EventNIDs: %w", err) + } - if u.membership != tables.MembershipStateLeaveOrBan { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - tables.MembershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { - return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + if u.membership != tables.MembershipStateLeaveOrBan { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + tables.MembershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + } } - } - return inviteEventIDs, nil + + return nil + }) + return inviteEventIDs, err } |