diff options
Diffstat (limited to 'syncapi/notifier/notifier.go')
-rw-r--r-- | syncapi/notifier/notifier.go | 58 |
1 files changed, 51 insertions, 7 deletions
diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 6a641e6f..d2b79b63 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -43,22 +43,30 @@ type Notifier struct { userDeviceStreams map[string]map[string]*UserDeviceStream // The last time we cleaned out stale entries from the userStreams map lastCleanUpTime time.Time + // Protects roomIDToJoinedUsers and roomIDToPeekingDevices + mapLock *sync.RWMutex } // NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier(currPos types.StreamingToken) *Notifier { +func NewNotifier() *Notifier { return &Notifier{ - currPos: currPos, roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet), userDeviceStreams: make(map[string]map[string]*UserDeviceStream), streamLock: &sync.Mutex{}, + mapLock: &sync.RWMutex{}, lastCleanUpTime: time.Now(), } } +// SetCurrentPosition sets the current streaming positions. +// This must be called directly after NewNotifier and initialising the streams. +func (n *Notifier) SetCurrentPosition(currPos types.StreamingToken) { + n.currPos = currPos +} + // OnNewEvent is called when a new event is received from the room server. Must only be // called from a single goroutine, to avoid races between updates which could set the // current sync position incorrectly. @@ -83,7 +91,7 @@ func (n *Notifier) OnNewEvent( if ev != nil { // Map this event's room_id to a list of joined users, and wake them up. - usersToNotify := n.joinedUsers(ev.RoomID()) + usersToNotify := n.JoinedUsers(ev.RoomID()) // Map this event's room_id to a list of peeking devices, and wake them up. peekingDevicesToNotify := n.PeekingDevices(ev.RoomID()) // If this is an invite, also add in the invitee to this list. @@ -114,7 +122,7 @@ func (n *Notifier) OnNewEvent( n.wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos) } else if roomID != "" { - n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), n.currPos) + n.wakeupUsers(n.JoinedUsers(roomID), n.PeekingDevices(roomID), n.currPos) } else if len(userIDs) > 0 { n.wakeupUsers(userIDs, nil, n.currPos) } else { @@ -182,7 +190,7 @@ func (n *Notifier) OnNewTyping( defer n.streamLock.Unlock() n.currPos.ApplyUpdates(posUpdate) - n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos) + n.wakeupUsers(n.JoinedUsers(roomID), nil, n.currPos) } // OnNewReceipt updates the current position @@ -194,7 +202,7 @@ func (n *Notifier) OnNewReceipt( defer n.streamLock.Unlock() n.currPos.ApplyUpdates(posUpdate) - n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos) + n.wakeupUsers(n.JoinedUsers(roomID), nil, n.currPos) } func (n *Notifier) OnNewKeyChange( @@ -228,6 +236,28 @@ func (n *Notifier) OnNewNotificationData( n.wakeupUsers([]string{userID}, nil, n.currPos) } +func (n *Notifier) OnNewPresence( + posUpdate types.StreamingToken, userID string, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + sharedUsers := n.SharedUsers(userID) + sharedUsers = append(sharedUsers, userID) + + n.wakeupUsers(sharedUsers, nil, n.currPos) +} + +func (n *Notifier) SharedUsers(userID string) (sharedUsers []string) { + for roomID, users := range n.roomIDToJoinedUsers { + if _, ok := users[userID]; ok { + sharedUsers = append(sharedUsers, n.JoinedUsers(roomID)...) + } + } + return sharedUsers +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos @@ -250,6 +280,8 @@ func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener { // Load the membership states required to notify users correctly. func (n *Notifier) Load(ctx context.Context, db storage.Database) error { + n.mapLock.Lock() + defer n.mapLock.Unlock() roomToUsers, err := db.AllJoinedUsersInRooms(ctx) if err != nil { return err @@ -377,6 +409,8 @@ func (n *Notifier) fetchUserStreams(userID string) []*UserDeviceStream { // Not thread-safe: must be called on the OnNewEvent goroutine only func (n *Notifier) addJoinedUser(roomID, userID string) { + n.mapLock.Lock() + defer n.mapLock.Unlock() if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { n.roomIDToJoinedUsers[roomID] = make(userIDSet) } @@ -385,6 +419,8 @@ func (n *Notifier) addJoinedUser(roomID, userID string) { // Not thread-safe: must be called on the OnNewEvent goroutine only func (n *Notifier) removeJoinedUser(roomID, userID string) { + n.mapLock.Lock() + defer n.mapLock.Unlock() if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { n.roomIDToJoinedUsers[roomID] = make(userIDSet) } @@ -392,7 +428,9 @@ func (n *Notifier) removeJoinedUser(roomID, userID string) { } // Not thread-safe: must be called on the OnNewEvent goroutine only -func (n *Notifier) joinedUsers(roomID string) (userIDs []string) { +func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) { + n.mapLock.RLock() + defer n.mapLock.RUnlock() if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { return } @@ -401,6 +439,8 @@ func (n *Notifier) joinedUsers(roomID string) (userIDs []string) { // Not thread-safe: must be called on the OnNewEvent goroutine only func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) { + n.mapLock.Lock() + defer n.mapLock.Unlock() if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) } @@ -410,6 +450,8 @@ func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) { // Not thread-safe: must be called on the OnNewEvent goroutine only // nolint:unused func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) { + n.mapLock.Lock() + defer n.mapLock.Unlock() if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) } @@ -419,6 +461,8 @@ func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) { // Not thread-safe: must be called on the OnNewEvent goroutine only func (n *Notifier) PeekingDevices(roomID string) (peekingDevices []types.PeekingDevice) { + n.mapLock.RLock() + defer n.mapLock.RUnlock() if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { return } |