aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/shared
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2021-07-22 12:26:58 +0100
committerGitHub <noreply@github.com>2021-07-22 12:26:58 +0100
commit39e8d1cc6f798ac842e0d8bc0ba2491e3ad3876e (patch)
tree8f14562a07a3996c16a75fc80e81e326cd7e74bb /roomserver/storage/shared
parent43ac66e0b487f96101f96298e303d02d6ca5654e (diff)
Track knocking in membership updater (#1935)
* Topologically sort outliers in SendEventWithState * Knock in membership updater * Update gomatrixserverlib * Update gomatrixserverlib * Get the NID of the knock event properly for the membership updater
Diffstat (limited to 'roomserver/storage/shared')
-rw-r--r--roomserver/storage/shared/membership_updater.go29
1 files changed, 29 insertions, 0 deletions
diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go
index 57f3a520..f1f589a3 100644
--- a/roomserver/storage/shared/membership_updater.go
+++ b/roomserver/storage/shared/membership_updater.go
@@ -86,6 +86,11 @@ func (u *MembershipUpdater) IsLeave() bool {
return u.membership == tables.MembershipStateLeaveOrBan
}
+// IsKnock implements types.MembershipUpdater
+func (u *MembershipUpdater) IsKnock() bool {
+ return u.membership == tables.MembershipStateKnock
+}
+
// SetToInvite implements types.MembershipUpdater
func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
var inserted bool
@@ -180,3 +185,27 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
})
return inviteEventIDs, err
}
+
+// SetToKnock implements types.MembershipUpdater
+func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, error) {
+ var inserted bool
+ err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
+ if err != nil {
+ return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
+ }
+ if u.membership != tables.MembershipStateKnock {
+ // Look up the NID of the new knock event
+ nIDs, err := u.d.EventNIDs(u.ctx, []string{event.EventID()})
+ if err != nil {
+ return fmt.Errorf("u.d.EventNIDs: %w", err)
+ }
+
+ if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil {
+ return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
+ }
+ }
+ return nil
+ })
+ return inserted, err
+}