aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage/postgres
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi/storage/postgres')
-rw-r--r--syncapi/storage/postgres/memberships_table.go35
1 files changed, 34 insertions, 1 deletions
diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go
index 939d6b3f..b555e845 100644
--- a/syncapi/storage/postgres/memberships_table.go
+++ b/syncapi/storage/postgres/memberships_table.go
@@ -20,11 +20,12 @@ import (
"fmt"
"github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
- "github.com/matrix-org/gomatrixserverlib"
)
// The memberships table is designed to track the last time that
@@ -69,11 +70,20 @@ const selectHeroesSQL = "" +
const selectMembershipBeforeSQL = "" +
"SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
+const selectMembersSQL = `
+SELECT event_id FROM (
+ SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC
+) t
+WHERE ($3::text IS NULL OR t.membership = $3)
+ AND ($4::text IS NULL OR t.membership <> $4)
+`
+
type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
selectHeroesStmt *sql.Stmt
selectMembershipForUserStmt *sql.Stmt
+ selectMembersStmt *sql.Stmt
}
func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@@ -87,6 +97,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectHeroesStmt, selectHeroesSQL},
{&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
+ {&s.selectMembersStmt, selectMembersSQL},
}.Prepare(db)
}
@@ -154,3 +165,25 @@ func (s *membershipsStatements) SelectMembershipForUser(
}
return membership, topologyPos, nil
}
+
+func (s *membershipsStatements) SelectMemberships(
+ ctx context.Context, txn *sql.Tx,
+ roomID string, pos types.TopologyToken,
+ membership, notMembership *string,
+) (eventIDs []string, err error) {
+ stmt := sqlutil.TxStmt(txn, s.selectMembersStmt)
+ rows, err := stmt.QueryContext(ctx, roomID, pos.Depth, membership, notMembership)
+ if err != nil {
+ return
+ }
+ var (
+ eventID string
+ )
+ for rows.Next() {
+ if err = rows.Scan(&eventID); err != nil {
+ return
+ }
+ eventIDs = append(eventIDs, eventID)
+ }
+ return eventIDs, rows.Err()
+}