aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--roomserver/storage/shared/membership_updater.go144
1 files changed, 79 insertions, 65 deletions
diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go
index 329813bf..834af606 100644
--- a/roomserver/storage/shared/membership_updater.go
+++ b/roomserver/storage/shared/membership_updater.go
@@ -79,89 +79,103 @@ func (u *MembershipUpdater) IsLeave() bool {
// SetToInvite implements types.MembershipUpdater
func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
- senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
- if err != nil {
- return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
- }
- inserted, err := u.d.InvitesTable.InsertInviteEvent(
- u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
- )
- if err != nil {
- return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
- }
- if u.membership != tables.MembershipStateInvite {
- if err = u.d.MembershipTable.UpdateMembership(
- u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
- ); err != nil {
- return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
+ var inserted bool
+ err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
+ if err != nil {
+ return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
- }
- return inserted, nil
+ inserted, err = u.d.InvitesTable.InsertInviteEvent(
+ u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
+ )
+ if err != nil {
+ return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
+ }
+ if u.membership != tables.MembershipStateInvite {
+ if err = u.d.MembershipTable.UpdateMembership(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
+ ); err != nil {
+ return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
+ }
+ }
+ return nil
+ })
+ return inserted, err
}
// SetToJoin implements types.MembershipUpdater
func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
var inviteEventIDs []string
- senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
- if err != nil {
- return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
- }
-
- // If this is a join event update, there is no invite to update
- if !isUpdate {
- inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
- u.ctx, u.txn, u.roomNID, u.targetUserNID,
- )
+ err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil {
- return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err)
+ return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
- }
- // Look up the NID of the new join event
- nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
- if err != nil {
- return nil, fmt.Errorf("u.d.EventNIDs: %w", err)
- }
+ // If this is a join event update, there is no invite to update
+ if !isUpdate {
+ inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID,
+ )
+ if err != nil {
+ return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err)
+ }
+ }
- if u.membership != tables.MembershipStateJoin || isUpdate {
- if err = u.d.MembershipTable.UpdateMembership(
- u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
- tables.MembershipStateJoin, nIDs[eventID],
- ); err != nil {
- return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
+ // Look up the NID of the new join event
+ nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
+ if err != nil {
+ return fmt.Errorf("u.d.EventNIDs: %w", err)
}
- }
- return inviteEventIDs, nil
+ if u.membership != tables.MembershipStateJoin || isUpdate {
+ if err = u.d.MembershipTable.UpdateMembership(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
+ tables.MembershipStateJoin, nIDs[eventID],
+ ); err != nil {
+ return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
+ }
+ }
+
+ return nil
+ })
+
+ return inviteEventIDs, err
}
// SetToLeave implements types.MembershipUpdater
func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
- senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
- if err != nil {
- return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
- }
- inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired(
- u.ctx, u.txn, u.roomNID, u.targetUserNID,
- )
- if err != nil {
- return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
- }
+ var inviteEventIDs []string
- // Look up the NID of the new leave event
- nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
- if err != nil {
- return nil, fmt.Errorf("u.d.EventNIDs: %w", err)
- }
+ err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
+ if err != nil {
+ return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
+ }
+ inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID,
+ )
+ if err != nil {
+ return fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
+ }
+
+ // Look up the NID of the new leave event
+ nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
+ if err != nil {
+ return fmt.Errorf("u.d.EventNIDs: %w", err)
+ }
- if u.membership != tables.MembershipStateLeaveOrBan {
- if err = u.d.MembershipTable.UpdateMembership(
- u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
- tables.MembershipStateLeaveOrBan, nIDs[eventID],
- ); err != nil {
- return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
+ if u.membership != tables.MembershipStateLeaveOrBan {
+ if err = u.d.MembershipTable.UpdateMembership(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
+ tables.MembershipStateLeaveOrBan, nIDs[eventID],
+ ); err != nil {
+ return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
+ }
}
- }
- return inviteEventIDs, nil
+
+ return nil
+ })
+ return inviteEventIDs, err
}