aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-01-12 10:06:03 +0100
committerGitHub <noreply@github.com>2023-01-12 10:06:03 +0100
commit0491a8e3436bc17535a4c57d26376af83685a97c (patch)
tree91c8ba9810d413b59d0a4b4b9bf935058fb6b660 /syncapi/storage
parent25dfbc6ec3991ba04f317cbae4a4dd51bab6013e (diff)
Fix room summary returning wrong heroes (#2930)
This should fix #2910. Probably makes Sytest/Complement a bit upset, since this not using `sort.Strings` anymore.
Diffstat (limited to 'syncapi/storage')
-rw-r--r--syncapi/storage/interface.go2
-rw-r--r--syncapi/storage/postgres/current_room_state_table.go92
-rw-r--r--syncapi/storage/postgres/memberships_table.go27
-rw-r--r--syncapi/storage/shared/storage_sync.go58
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go97
-rw-r--r--syncapi/storage/sqlite3/memberships_table.go39
-rw-r--r--syncapi/storage/storage_test.go179
-rw-r--r--syncapi/storage/tables/interface.go4
-rw-r--r--syncapi/storage/tables/memberships_test.go36
9 files changed, 369 insertions, 165 deletions
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index 75afbce1..4e22f8a6 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -45,7 +45,7 @@ type DatabaseTransaction interface {
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
- GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)
+ GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go
index 48ed2002..3caafa14 100644
--- a/syncapi/storage/postgres/current_room_state_table.go
+++ b/syncapi/storage/postgres/current_room_state_table.go
@@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"encoding/json"
+ "errors"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
@@ -110,6 +111,15 @@ const selectSharedUsersSQL = "" +
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');"
+const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2`
+
+const selectRoomHeroes = `
+SELECT state_key FROM syncapi_current_room_state
+WHERE type = 'm.room.member' AND room_id = $1 AND membership = ANY($2) AND state_key != $3
+ORDER BY added_at, state_key
+LIMIT 5
+`
+
type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
@@ -122,6 +132,8 @@ type currentRoomStateStatements struct {
selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt
selectSharedUsersStmt *sql.Stmt
+ selectMembershipCountStmt *sql.Stmt
+ selectRoomHeroesStmt *sql.Stmt
}
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@@ -141,40 +153,21 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
return nil, err
}
- if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
- return nil, err
- }
- if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
- return nil, err
- }
- if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil {
- return nil, err
- }
- if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
- return nil, err
- }
- if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil {
- return nil, err
- }
- if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
- return nil, err
- }
- if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
- return nil, err
- }
- if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
- return nil, err
- }
- if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
- return nil, err
- }
- if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
- return nil, err
- }
- if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil {
- return nil, err
- }
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertRoomStateStmt, upsertRoomStateSQL},
+ {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL},
+ {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL},
+ {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL},
+ {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL},
+ {&s.selectCurrentStateStmt, selectCurrentStateSQL},
+ {&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
+ {&s.selectJoinedUsersInRoomStmt, selectJoinedUsersInRoomSQL},
+ {&s.selectEventsWithEventIDsStmt, selectEventsWithEventIDsSQL},
+ {&s.selectStateEventStmt, selectStateEventSQL},
+ {&s.selectSharedUsersStmt, selectSharedUsersSQL},
+ {&s.selectMembershipCountStmt, selectMembershipCount},
+ {&s.selectRoomHeroesStmt, selectRoomHeroes},
+ }.Prepare(db)
}
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
@@ -447,3 +440,34 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
}
return result, rows.Err()
}
+
+func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectRoomHeroesStmt)
+ rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(memberships), excludeUserID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroesStmt: rows.close() failed")
+
+ var stateKey string
+ result := make([]string, 0, 5)
+ for rows.Next() {
+ if err = rows.Scan(&stateKey); err != nil {
+ return nil, err
+ }
+ result = append(result, stateKey)
+ }
+ return result, rows.Err()
+}
+
+func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) {
+ stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
+ err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return 0, nil
+ }
+ return 0, err
+ }
+ return count, nil
+}
diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go
index b555e845..ac44b235 100644
--- a/syncapi/storage/postgres/memberships_table.go
+++ b/syncapi/storage/postgres/memberships_table.go
@@ -19,10 +19,8 @@ import (
"database/sql"
"fmt"
- "github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
@@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" +
" SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" +
") t WHERE t.membership = $3"
-const selectHeroesSQL = "" +
- "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
-
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
@@ -81,7 +76,6 @@ WHERE ($3::text IS NULL OR t.membership = $3)
type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
- selectHeroesStmt *sql.Stmt
selectMembershipForUserStmt *sql.Stmt
selectMembersStmt *sql.Stmt
}
@@ -95,7 +89,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
return s, sqlutil.StatementList{
{&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
- {&s.selectHeroesStmt, selectHeroesSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
{&s.selectMembersStmt, selectMembersSQL},
}.Prepare(db)
@@ -129,26 +122,6 @@ func (s *membershipsStatements) SelectMembershipCount(
return
}
-func (s *membershipsStatements) SelectHeroes(
- ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
-) (heroes []string, err error) {
- stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt)
- var rows *sql.Rows
- rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships))
- if err != nil {
- return
- }
- defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
- var hero string
- for rows.Next() {
- if err = rows.Scan(&hero); err != nil {
- return
- }
- heroes = append(heroes, hero)
- }
- return heroes, rows.Err()
-}
-
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go
index 77afa029..c6933486 100644
--- a/syncapi/storage/shared/storage_sync.go
+++ b/syncapi/storage/shared/storage_sync.go
@@ -6,6 +6,7 @@ import (
"fmt"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/syncapi/types"
@@ -92,8 +93,61 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe
return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
}
-func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
- return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships)
+func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) {
+ summary := &types.Summary{Heroes: []string{}}
+
+ joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join)
+ if err != nil {
+ return summary, err
+ }
+ inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite)
+ if err != nil {
+ return summary, err
+ }
+ summary.InvitedMemberCount = &inviteCount
+ summary.JoinedMemberCount = &joinCount
+
+ // Get the room name and canonical alias, if any
+ filter := gomatrixserverlib.DefaultStateFilter()
+ filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias}
+ filterRooms := []string{roomID}
+
+ filter.Types = &filterTypes
+ filter.Rooms = &filterRooms
+ evs, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, &filter, nil)
+ if err != nil {
+ return summary, err
+ }
+
+ for _, ev := range evs {
+ switch ev.Type() {
+ case gomatrixserverlib.MRoomName:
+ if gjson.GetBytes(ev.Content(), "name").Str != "" {
+ return summary, nil
+ }
+ case gomatrixserverlib.MRoomCanonicalAlias:
+ if gjson.GetBytes(ev.Content(), "alias").Str != "" {
+ return summary, nil
+ }
+ }
+ }
+
+ // If there's no room name or canonical alias, get the room heroes, excluding the user
+ heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite})
+ if err != nil {
+ return summary, err
+ }
+
+ // "When no joined or invited members are available, this should consist of the banned and left users"
+ if len(heroes) == 0 {
+ heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban})
+ if err != nil {
+ return summary, err
+ }
+ }
+ summary.Heroes = heroes
+
+ return summary, nil
}
func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index 7a381f68..6bc1b267 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"encoding/json"
+ "errors"
"fmt"
"strings"
@@ -95,6 +96,15 @@ const selectSharedUsersSQL = "" +
" SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');"
+const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2`
+
+const selectRoomHeroes = `
+SELECT state_key FROM syncapi_current_room_state
+WHERE type = 'm.room.member' AND room_id = $1 AND state_key != $2 AND membership IN ($3)
+ORDER BY added_at, state_key
+LIMIT 5
+`
+
type currentRoomStateStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
@@ -107,6 +117,8 @@ type currentRoomStateStatements struct {
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
selectStateEventStmt *sql.Stmt
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
+ selectMembershipCountStmt *sql.Stmt
+ //selectRoomHeroes *sql.Stmt - prepared at runtime due to variadic
}
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
@@ -129,31 +141,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
return nil, err
}
- if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
- return nil, err
- }
- if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
- return nil, err
- }
- if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil {
- return nil, err
- }
- if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
- return nil, err
- }
- if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil {
- return nil, err
- }
- if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
- return nil, err
- }
- //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
- // return nil, err
- //}
- if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
- return nil, err
- }
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertRoomStateStmt, upsertRoomStateSQL},
+ {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL},
+ {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL},
+ {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL},
+ {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL},
+ {&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
+ {&s.selectStateEventStmt, selectStateEventSQL},
+ {&s.selectMembershipCountStmt, selectMembershipCount},
+ }.Prepare(db)
}
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
@@ -485,3 +482,53 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
return result, err
}
+
+func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) {
+ params := make([]interface{}, len(memberships)+2)
+ params[0] = roomID
+ params[1] = excludeUserID
+ for k, v := range memberships {
+ params[k+2] = v
+ }
+
+ query := strings.Replace(selectRoomHeroes, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
+ var stmt *sql.Stmt
+ var err error
+ if txn != nil {
+ stmt, err = txn.Prepare(query)
+ } else {
+ stmt, err = s.db.Prepare(query)
+ }
+ if err != nil {
+ return []string{}, err
+ }
+ defer internal.CloseAndLogIfError(ctx, stmt, "selectRoomHeroes: stmt.close() failed")
+
+ rows, err := stmt.QueryContext(ctx, params...)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroes: rows.close() failed")
+
+ var stateKey string
+ result := make([]string, 0, 5)
+ for rows.Next() {
+ if err = rows.Scan(&stateKey); err != nil {
+ return nil, err
+ }
+ result = append(result, stateKey)
+ }
+ return result, rows.Err()
+}
+
+func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) {
+ stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt)
+ err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return 0, nil
+ }
+ return 0, err
+ }
+ return count, nil
+}
diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go
index 7e54fac1..905a1e1a 100644
--- a/syncapi/storage/sqlite3/memberships_table.go
+++ b/syncapi/storage/sqlite3/memberships_table.go
@@ -18,11 +18,9 @@ import (
"context"
"database/sql"
"fmt"
- "strings"
"github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
@@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" +
" SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" +
") t WHERE t.membership = $3"
-const selectHeroesSQL = "" +
- "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5"
-
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
@@ -99,7 +94,6 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
{&s.selectMembersStmt, selectMembersSQL},
- // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic
}.Prepare(db)
}
@@ -131,39 +125,6 @@ func (s *membershipsStatements) SelectMembershipCount(
return
}
-func (s *membershipsStatements) SelectHeroes(
- ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
-) (heroes []string, err error) {
- stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1)
- stmt, err := s.db.PrepareContext(ctx, stmtSQL)
- if err != nil {
- return
- }
- defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed")
- params := []interface{}{
- roomID, userID,
- }
- for _, membership := range memberships {
- params = append(params, membership)
- }
-
- stmt = sqlutil.TxStmt(txn, stmt)
- var rows *sql.Rows
- rows, err = stmt.QueryContext(ctx, params...)
- if err != nil {
- return
- }
- defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed")
- var hero string
- for rows.Next() {
- if err = rows.Scan(&hero); err != nil {
- return
- }
- heroes = append(heroes, hero)
- }
- return heroes, rows.Err()
-}
-
// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
// string as the membership.
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index 5ff185a3..166ddd23 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/stretchr/testify/assert"
)
var ctx = context.Background()
@@ -664,3 +665,181 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
return &tok
}
*/
+
+func pointer[t any](s t) *t {
+ return &s
+}
+
+func TestRoomSummary(t *testing.T) {
+
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ charlie := test.NewUser(t)
+
+ // Create some dummy users
+ moreUsers := []*test.User{}
+ moreUserIDs := []string{}
+ for i := 0; i < 10; i++ {
+ u := test.NewUser(t)
+ moreUsers = append(moreUsers, u)
+ moreUserIDs = append(moreUserIDs, u.ID)
+ }
+
+ testCases := []struct {
+ name string
+ wantSummary *types.Summary
+ additionalEvents func(t *testing.T, room *test.Room)
+ }{
+ {
+ name: "after initial creation",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{}},
+ },
+ {
+ name: "invited user",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "invite",
+ }, test.WithStateKey(bob.ID))
+ },
+ },
+ {
+ name: "invited user, but declined",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "invite",
+ }, test.WithStateKey(bob.ID))
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "leave",
+ }, test.WithStateKey(bob.ID))
+ },
+ },
+ {
+ name: "joined user after invitation",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "invite",
+ }, test.WithStateKey(bob.ID))
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+ },
+ },
+ {
+ name: "multiple joined user",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(charlie.ID))
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+ },
+ },
+ {
+ name: "multiple joined/invited user",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "invite",
+ }, test.WithStateKey(charlie.ID))
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+ },
+ },
+ {
+ name: "multiple joined/invited/left user",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "invite",
+ }, test.WithStateKey(charlie.ID))
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "leave",
+ }, test.WithStateKey(bob.ID))
+ },
+ },
+ {
+ name: "leaving user after joining",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "leave",
+ }, test.WithStateKey(bob.ID))
+ },
+ },
+ {
+ name: "many users", // heroes ordered by stream id
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ for _, x := range moreUsers {
+ room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(x.ID))
+ }
+ },
+ },
+ {
+ name: "canonical alias set",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+ room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{
+ "alias": "myalias",
+ }, test.WithStateKey(""))
+ },
+ },
+ {
+ name: "room name set",
+ wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}},
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+ room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{
+ "name": "my room name",
+ }, test.WithStateKey(""))
+ },
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close, closeBase := MustCreateDatabase(t, dbType)
+ defer close()
+ defer closeBase()
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+
+ r := test.NewRoom(t, alice)
+
+ if tc.additionalEvents != nil {
+ tc.additionalEvents(t, r)
+ }
+
+ // write the room before creating a transaction
+ MustWriteEvents(t, db, r.Events())
+
+ transaction, err := db.NewDatabaseTransaction(ctx)
+ assert.NoError(t, err)
+ defer transaction.Rollback()
+
+ summary, err := transaction.GetRoomSummary(ctx, r.ID, alice.ID)
+ assert.NoError(t, err)
+ assert.Equal(t, tc.wantSummary, summary)
+ })
+ }
+ })
+}
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index a0574b25..c02e4ecc 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -115,6 +115,9 @@ type CurrentRoomState interface {
SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error)
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
+
+ SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error)
+ SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (int, error)
}
// BackwardsExtremities keeps track of backwards extremities for a room.
@@ -185,7 +188,6 @@ type Receipts interface {
type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
- SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error)
SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
SelectMemberships(
ctx context.Context, txn *sql.Tx,
diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go
index 0cee7f5a..df593ae7 100644
--- a/syncapi/storage/tables/memberships_test.go
+++ b/syncapi/storage/tables/memberships_test.go
@@ -3,8 +3,6 @@ package tables_test
import (
"context"
"database/sql"
- "reflect"
- "sort"
"testing"
"time"
@@ -88,43 +86,9 @@ func TestMembershipsTable(t *testing.T) {
testUpsert(t, ctx, table, userEvents[0], alice, room)
testMembershipCount(t, ctx, table, room)
- testHeroes(t, ctx, table, alice, room, users)
})
}
-func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) {
-
- // Re-slice and sort the expected users
- users = users[1:]
- sort.Strings(users)
- type testCase struct {
- name string
- memberships []string
- wantHeroes []string
- }
-
- testCases := []testCase{
- {name: "no memberships queried", memberships: []string{}},
- {name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]},
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships)
- if err != nil {
- t.Fatalf("unable to select heroes: %s", err)
- }
- if gotLen := len(got); gotLen != len(tc.wantHeroes) {
- t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen)
- }
-
- if !reflect.DeepEqual(got, tc.wantHeroes) {
- t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got)
- }
- })
- }
-}
-
func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) {
t.Run("membership counts are correct", func(t *testing.T) {
// After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users)