aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2022-12-08 08:25:03 +0100
committerGitHub <noreply@github.com>2022-12-08 08:25:03 +0100
commitc136a450d5196cf22a91419f493bb73c29481122 (patch)
tree8a3dc463cb652c0f0beada3eba0553f336d329f0
parent0351618ff4e7d569e14a165be59a1a7e9e979684 (diff)
Fix newly joined users presence (#2854)
Fixes #2803 Also refactors the presence stream to not hit the database for every user, instead queries all users at once now.
-rw-r--r--syncapi/consumers/presence.go12
-rw-r--r--syncapi/storage/interface.go4
-rw-r--r--syncapi/storage/postgres/presence_table.go38
-rw-r--r--syncapi/storage/shared/storage_consumer.go4
-rw-r--r--syncapi/storage/shared/storage_sync.go4
-rw-r--r--syncapi/storage/sqlite3/presence_table.go48
-rw-r--r--syncapi/storage/tables/interface.go2
-rw-r--r--syncapi/storage/tables/presence_table_test.go136
-rw-r--r--syncapi/streams/stream_presence.go80
-rw-r--r--syncapi/sync/requestpool.go6
-rw-r--r--syncapi/sync/requestpool_test.go4
11 files changed, 263 insertions, 75 deletions
diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go
index 145059c2..6e3150c2 100644
--- a/syncapi/consumers/presence.go
+++ b/syncapi/consumers/presence.go
@@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error {
// Normal NATS subscription, used by Request/Reply
_, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) {
userID := msg.Header.Get(jetstream.UserID)
- presence, err := s.db.GetPresence(context.Background(), userID)
+ presences, err := s.db.GetPresences(context.Background(), []string{userID})
m := &nats.Msg{
Header: nats.Header{},
}
@@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error {
}
return
}
- if presence == nil {
- presence = &types.PresenceInternal{
- UserID: userID,
- }
+
+ presence := &types.PresenceInternal{
+ UserID: userID,
+ }
+ if len(presences) > 0 {
+ presence = presences[0]
}
deviceRes := api.QueryDevicesResponse{}
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index 97c2ced4..75afbce1 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -106,7 +106,7 @@ type DatabaseTransaction interface {
SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
- GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
+ GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
}
@@ -186,7 +186,7 @@ type Database interface {
}
type Presence interface {
- GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
+ GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error)
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
}
diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go
index 7194afea..a3f7c521 100644
--- a/syncapi/storage/postgres/presence_table.go
+++ b/syncapi/storage/postgres/presence_table.go
@@ -19,10 +19,12 @@ import (
"database/sql"
"time"
+ "github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types"
- "github.com/matrix-org/gomatrixserverlib"
)
const presenceSchema = `
@@ -63,9 +65,9 @@ const upsertPresenceFromSyncSQL = "" +
" RETURNING id"
const selectPresenceForUserSQL = "" +
- "SELECT presence, status_msg, last_active_ts" +
+ "SELECT user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" +
- " WHERE user_id = $1 LIMIT 1"
+ " WHERE user_id = ANY($1)"
const selectMaxPresenceSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
@@ -119,20 +121,28 @@ func (p *presenceStatements) UpsertPresence(
return
}
-// GetPresenceForUser returns the current presence of a user.
-func (p *presenceStatements) GetPresenceForUser(
+// GetPresenceForUsers returns the current presence for a list of users.
+// If the user doesn't have a presence status yet, it is omitted from the response.
+func (p *presenceStatements) GetPresenceForUsers(
ctx context.Context, txn *sql.Tx,
- userID string,
-) (*types.PresenceInternal, error) {
- result := &types.PresenceInternal{
- UserID: userID,
- }
+ userIDs []string,
+) ([]*types.PresenceInternal, error) {
+ result := make([]*types.PresenceInternal, 0, len(userIDs))
stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
- err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
- if err == sql.ErrNoRows {
- return nil, nil
+ rows, err := stmt.QueryContext(ctx, pq.Array(userIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
+
+ for rows.Next() {
+ presence := &types.PresenceInternal{}
+ if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
+ return nil, err
+ }
+ presence.ClientFields.Presence = presence.Presence.String()
+ result = append(result, presence)
}
- result.ClientFields.Presence = result.Presence.String()
return result, err
}
diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go
index f2064fb8..df2338cf 100644
--- a/syncapi/storage/shared/storage_consumer.go
+++ b/syncapi/storage/shared/storage_consumer.go
@@ -564,8 +564,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t
return pos, err
}
-func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
- return d.Presence.GetPresenceForUser(ctx, nil, userID)
+func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
+ return d.Presence.GetPresenceForUsers(ctx, nil, userIDs)
}
func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go
index c3763521..77afa029 100644
--- a/syncapi/storage/shared/storage_sync.go
+++ b/syncapi/storage/shared/storage_sync.go
@@ -596,8 +596,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex
return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
}
-func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
- return d.Presence.GetPresenceForUser(ctx, d.txn, userID)
+func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) {
+ return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs)
}
func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go
index b61a825d..7641de92 100644
--- a/syncapi/storage/sqlite3/presence_table.go
+++ b/syncapi/storage/sqlite3/presence_table.go
@@ -17,12 +17,14 @@ package sqlite3
import (
"context"
"database/sql"
+ "strings"
"time"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types"
- "github.com/matrix-org/gomatrixserverlib"
)
const presenceSchema = `
@@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" +
" RETURNING id"
const selectPresenceForUserSQL = "" +
- "SELECT presence, status_msg, last_active_ts" +
+ "SELECT user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" +
- " WHERE user_id = $1 LIMIT 1"
+ " WHERE user_id IN ($1)"
const selectMaxPresenceSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_presence"
@@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence(
return
}
-// GetPresenceForUser returns the current presence of a user.
-func (p *presenceStatements) GetPresenceForUser(
+// GetPresenceForUsers returns the current presence for a list of users.
+// If the user doesn't have a presence status yet, it is omitted from the response.
+func (p *presenceStatements) GetPresenceForUsers(
ctx context.Context, txn *sql.Tx,
- userID string,
-) (*types.PresenceInternal, error) {
- result := &types.PresenceInternal{
- UserID: userID,
+ userIDs []string,
+) ([]*types.PresenceInternal, error) {
+ qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
+ prepStmt, err := p.db.Prepare(qry)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed")
+
+ params := make([]interface{}, len(userIDs))
+ for i := range userIDs {
+ params[i] = userIDs[i]
}
- stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt)
- err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS)
- if err == sql.ErrNoRows {
- return nil, nil
+
+ rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed")
+ result := make([]*types.PresenceInternal, 0, len(userIDs))
+ for rows.Next() {
+ presence := &types.PresenceInternal{}
+ if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil {
+ return nil, err
+ }
+ presence.ClientFields.Presence = presence.Presence.String()
+ result = append(result, presence)
}
- result.ClientFields.Presence = result.Presence.String()
return result, err
}
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index 2c4f04ec..a0574b25 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -207,7 +207,7 @@ type Ignores interface {
type Presence interface {
UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error)
- GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error)
+ GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error)
GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error)
GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error)
}
diff --git a/syncapi/storage/tables/presence_table_test.go b/syncapi/storage/tables/presence_table_test.go
new file mode 100644
index 00000000..dce0c695
--- /dev/null
+++ b/syncapi/storage/tables/presence_table_test.go
@@ -0,0 +1,136 @@
+package tables_test
+
+import (
+ "context"
+ "database/sql"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/syncapi/storage/postgres"
+ "github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
+ "github.com/matrix-org/dendrite/syncapi/storage/tables"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/dendrite/test"
+)
+
+func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, sqlutil.NewExclusiveWriter())
+ if err != nil {
+ t.Fatalf("failed to open db: %s", err)
+ }
+
+ var tab tables.Presence
+ switch dbType {
+ case test.DBTypePostgres:
+ tab, err = postgres.NewPostgresPresenceTable(db)
+ case test.DBTypeSQLite:
+ var stream sqlite3.StreamIDStatements
+ if err = stream.Prepare(db); err != nil {
+ t.Fatalf("failed to prepare stream stmts: %s", err)
+ }
+ tab, err = sqlite3.NewSqlitePresenceTable(db, &stream)
+ }
+ if err != nil {
+ t.Fatalf("failed to make new table: %s", err)
+ }
+ return tab, close
+}
+
+func TestPresence(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ ctx := context.Background()
+
+ statusMsg := "Hello World!"
+ timestamp := gomatrixserverlib.AsTimestamp(time.Now())
+
+ var txn *sql.Tx
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, closeDB := mustPresenceTable(t, dbType)
+ defer closeDB()
+
+ // Insert some presences
+ pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false)
+ if err != nil {
+ t.Error(err)
+ }
+ wantPos := types.StreamPosition(1)
+ if pos != wantPos {
+ t.Errorf("expected pos to be %d, got %d", wantPos, pos)
+ }
+ pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false)
+ if err != nil {
+ t.Error(err)
+ }
+ wantPos = 2
+ if pos != wantPos {
+ t.Errorf("expected pos to be %d, got %d", wantPos, pos)
+ }
+
+ // verify the expected max presence ID
+ maxPos, err := tab.GetMaxPresenceID(ctx, txn)
+ if err != nil {
+ t.Error(err)
+ }
+ if maxPos != wantPos {
+ t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos)
+ }
+
+ // This should increment the position
+ pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true)
+ if err != nil {
+ t.Error(err)
+ }
+ wantPos = pos
+ if wantPos <= maxPos {
+ t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos)
+ }
+
+ // This should return only Bobs status
+ presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10})
+ if err != nil {
+ t.Error(err)
+ }
+
+ if c := len(presences); c > 1 {
+ t.Errorf("expected only one presence, got %d", c)
+ }
+
+ // Validate the response
+ wantPresence := &types.PresenceInternal{
+ UserID: bob.ID,
+ Presence: types.PresenceOnline,
+ StreamPos: wantPos,
+ LastActiveTS: timestamp,
+ ClientFields: types.PresenceClientResponse{
+ LastActiveAgo: 0,
+ Presence: types.PresenceOnline.String(),
+ StatusMsg: &statusMsg,
+ },
+ }
+ if !reflect.DeepEqual(wantPresence, presences[bob.ID]) {
+ t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence)
+ }
+
+ // Try getting presences for existing and non-existing users
+ getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"}
+ presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if len(presencesForUsers) >= len(getUsers) {
+ t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers))
+ }
+ })
+
+}
diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go
index 030b7c5d..445e46b3 100644
--- a/syncapi/streams/stream_presence.go
+++ b/syncapi/streams/stream_presence.go
@@ -17,6 +17,7 @@ package streams
import (
"context"
"encoding/json"
+ "fmt"
"sync"
"github.com/matrix-org/gomatrixserverlib"
@@ -70,39 +71,25 @@ func (p *PresenceStreamProvider) IncrementalSync(
return from
}
- if len(presences) == 0 {
+ getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences)
+ if err != nil {
+ req.Log.WithError(err).Error("getNeededUsersFromRequest failed")
+ return from
+ }
+
+ // Got no presence between range and no presence to get from the database
+ if len(getPresenceForUsers) == 0 && len(presences) == 0 {
return to
}
- // add newly joined rooms user presences
- newlyJoined := joinedRooms(req.Response, req.Device.UserID)
- if len(newlyJoined) > 0 {
- // TODO: Check if this is working better than before.
- if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
- req.Log.WithError(err).Error("unable to refresh notifier lists")
- return from
- }
- NewlyJoinedLoop:
- for _, roomID := range newlyJoined {
- roomUsers := p.notifier.JoinedUsers(roomID)
- for i := range roomUsers {
- // we already got a presence from this user
- if _, ok := presences[roomUsers[i]]; ok {
- continue
- }
- // Bear in mind that this might return nil, but at least populating
- // a nil means that there's a map entry so we won't repeat this call.
- presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i])
- if err != nil {
- req.Log.WithError(err).Error("unable to query presence for user")
- _ = snapshot.Rollback()
- return from
- }
- if len(presences) > req.Filter.Presence.Limit {
- break NewlyJoinedLoop
- }
- }
- }
+ dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers)
+ if err != nil {
+ req.Log.WithError(err).Error("unable to query presence for user")
+ _ = snapshot.Rollback()
+ return from
+ }
+ for _, presence := range dbPresences {
+ presences[presence.UserID] = presence
}
lastPos := from
@@ -164,6 +151,39 @@ func (p *PresenceStreamProvider) IncrementalSync(
return lastPos
}
+func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) {
+ getPresenceForUsers := []string{}
+ // Add presence for users which newly joined a room
+ for userID := range req.MembershipChanges {
+ if _, ok := presences[userID]; ok {
+ continue
+ }
+ getPresenceForUsers = append(getPresenceForUsers, userID)
+ }
+
+ // add newly joined rooms user presences
+ newlyJoined := joinedRooms(req.Response, req.Device.UserID)
+ if len(newlyJoined) == 0 {
+ return getPresenceForUsers, nil
+ }
+
+ // TODO: Check if this is working better than before.
+ if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
+ return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err)
+ }
+ for _, roomID := range newlyJoined {
+ roomUsers := p.notifier.JoinedUsers(roomID)
+ for i := range roomUsers {
+ // we already got a presence from this user
+ if _, ok := presences[roomUsers[i]]; ok {
+ continue
+ }
+ getPresenceForUsers = append(getPresenceForUsers, roomUsers[i])
+ }
+ }
+ return getPresenceForUsers, nil
+}
+
func joinedRooms(res *types.Response, userID string) []string {
var roomIDs []string
for roomID, join := range res.Rooms.Join {
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index 29d92b29..b086567b 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -145,12 +145,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user
}
// ensure we also send the current status_msg to federated servers and not nil
- dbPresence, err := db.GetPresence(context.Background(), userID)
+ dbPresence, err := db.GetPresences(context.Background(), []string{userID})
if err != nil && err != sql.ErrNoRows {
return
}
- if dbPresence != nil {
- newPresence.ClientFields = dbPresence.ClientFields
+ if len(dbPresence) > 0 && dbPresence[0] != nil {
+ newPresence.ClientFields = dbPresence[0].ClientFields
}
newPresence.ClientFields.Presence = presenceID.String()
diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go
index 3e5769d8..faa0b49c 100644
--- a/syncapi/sync/requestpool_test.go
+++ b/syncapi/sync/requestpool_test.go
@@ -29,8 +29,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ
return 0, nil
}
-func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
- return &types.PresenceInternal{}, nil
+func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) {
+ return []*types.PresenceInternal{}, nil
}
func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {