aboutsummaryrefslogtreecommitdiff
path: root/roomserver
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-05-20 18:03:06 +0100
committerGitHub <noreply@github.com>2020-05-20 18:03:06 +0100
commitf2c07437fe3d7f54977f7e645cd045a04e832020 (patch)
treefe0d5c92b1bc6ac5d5acb065f555de4ca704c1ad /roomserver
parent6091bf044ff41e9f248b1077d5b05f1a4c694412 (diff)
Use memberships to determine whether to reset latest events/state on room join (#1047)
* Track local/remote memberships, re-scope some input stuff * Check if we're in the room already before resetting latest events/state * Fix postgres, fix lint * Review comments
Diffstat (limited to 'roomserver')
-rw-r--r--roomserver/internal/input.go4
-rw-r--r--roomserver/internal/input_events.go53
-rw-r--r--roomserver/internal/input_latest_events.go23
-rw-r--r--roomserver/internal/input_membership.go24
-rw-r--r--roomserver/internal/query.go4
-rw-r--r--roomserver/internal/query_backfill.go2
-rw-r--r--roomserver/storage/interface.go4
-rw-r--r--roomserver/storage/postgres/membership_table.go59
-rw-r--r--roomserver/storage/postgres/storage.go17
-rw-r--r--roomserver/storage/sqlite3/membership_table.go53
-rw-r--r--roomserver/storage/sqlite3/storage.go17
-rw-r--r--roomserver/types/types.go2
12 files changed, 162 insertions, 100 deletions
diff --git a/roomserver/internal/input.go b/roomserver/internal/input.go
index ab3d7516..932b4df4 100644
--- a/roomserver/internal/input.go
+++ b/roomserver/internal/input.go
@@ -60,7 +60,7 @@ func (r *RoomserverInternalAPI) InputRoomEvents(
defer r.mutex.Unlock()
for i := range request.InputInviteEvents {
var loopback *api.InputRoomEvent
- if loopback, err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil {
+ if loopback, err = r.processInviteEvent(ctx, r, request.InputInviteEvents[i]); err != nil {
return err
}
// The processInviteEvent function can optionally return a
@@ -71,7 +71,7 @@ func (r *RoomserverInternalAPI) InputRoomEvents(
}
}
for i := range request.InputRoomEvents {
- if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil {
+ if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil {
return err
}
}
diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go
index f5c678ca..a4167714 100644
--- a/roomserver/internal/input_events.go
+++ b/roomserver/internal/input_events.go
@@ -31,21 +31,13 @@ import (
log "github.com/sirupsen/logrus"
)
-// OutputRoomEventWriter has the APIs needed to write an event to the output logs.
-type OutputRoomEventWriter interface {
- // Write a list of events for a room
- WriteOutputEvents(roomID string, updates []api.OutputEvent) error
-}
-
// processRoomEvent can only be called once at a time
//
// TODO(#375): This should be rewritten to allow concurrent calls. The
// difficulty is in ensuring that we correctly annotate events with the correct
// state deltas when sending to kafka streams
-func processRoomEvent(
+func (r *RoomserverInternalAPI) processRoomEvent(
ctx context.Context,
- db storage.Database,
- ow OutputRoomEventWriter,
input api.InputRoomEvent,
) (eventID string, err error) {
// Parse and validate the event JSON
@@ -54,7 +46,7 @@ func processRoomEvent(
// Check that the event passes authentication checks and work out
// the numeric IDs for the auth events.
- authEventNIDs, err := checkAuthEvents(ctx, db, headered, input.AuthEventIDs)
+ authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
if err != nil {
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event")
return
@@ -63,7 +55,7 @@ func processRoomEvent(
// If we don't have a transaction ID then get one.
if input.TransactionID != nil {
tdID := input.TransactionID
- eventID, err = db.GetTransactionEventID(
+ eventID, err = r.DB.GetTransactionEventID(
ctx, tdID.TransactionID, tdID.SessionID, event.Sender(),
)
// On error OR event with the transaction already processed/processesing
@@ -73,7 +65,7 @@ func processRoomEvent(
}
// Store the event.
- roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
+ roomNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
if err != nil {
return
}
@@ -93,16 +85,14 @@ func processRoomEvent(
if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Lets calculate one.
- err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event)
+ err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event)
if err != nil {
return
}
}
- if err = updateLatestEvents(
+ if err = r.updateLatestEvents(
ctx, // context
- db, // roomserver database
- ow, // output event writer
roomNID, // room NID to update
stateAtEvent, // state at event (below)
event, // event
@@ -116,29 +106,36 @@ func processRoomEvent(
return event.EventID(), nil
}
-func calculateAndSetState(
+func (r *RoomserverInternalAPI) calculateAndSetState(
ctx context.Context,
- db storage.Database,
input api.InputRoomEvent,
roomNID types.RoomNID,
stateAtEvent *types.StateAtEvent,
event gomatrixserverlib.Event,
) error {
var err error
- roomState := state.NewStateResolution(db)
+ roomState := state.NewStateResolution(r.DB)
if input.HasState {
- // TODO: Check here if we think we're in the room already.
+ // Check here if we think we're in the room already.
stateAtEvent.Overwrite = true
+ var joinEventNIDs []types.EventNID
+ // Request join memberships only for local users only.
+ if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil {
+ // If we have no local users that are joined to the room then any state about
+ // the room that we have is quite possibly out of date. Therefore in that case
+ // we should overwrite it rather than merge it.
+ stateAtEvent.Overwrite = len(joinEventNIDs) == 0
+ }
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
var entries []types.StateEntry
- if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
+ if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return err
}
- if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil {
+ if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
return err
}
} else {
@@ -149,12 +146,11 @@ func calculateAndSetState(
return err
}
}
- return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
+ return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
}
-func processInviteEvent(
+func (r *RoomserverInternalAPI) processInviteEvent(
ctx context.Context,
- db storage.Database,
ow *RoomserverInternalAPI,
input api.InputInviteEvent,
) (*api.InputRoomEvent, error) {
@@ -172,7 +168,10 @@ func processInviteEvent(
"target_user_id": targetUserID,
}).Info("processing invite event")
- updater, err := db.MembershipUpdater(ctx, roomID, targetUserID, input.RoomVersion)
+ _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID)
+ isTargetLocalUser := domain == r.Cfg.Matrix.ServerName
+
+ updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocalUser, input.RoomVersion)
if err != nil {
return nil, err
}
@@ -239,7 +238,7 @@ func processInviteEvent(
// up from local data (which is most likely to be if the event came
// from the CS API). If we know about the room then we can insert
// the invite room state, if we don't then we just fail quietly.
- if irs, ierr := buildInviteStrippedState(ctx, db, input); ierr == nil {
+ if irs, ierr := buildInviteStrippedState(ctx, r.DB, input); ierr == nil {
if err = event.SetUnsignedField("invite_room_state", irs); err != nil {
return nil, err
}
diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go
index 6eeeedab..d7c9a5cb 100644
--- a/roomserver/internal/input_latest_events.go
+++ b/roomserver/internal/input_latest_events.go
@@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
- "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -46,17 +45,15 @@ import (
// 7 <----- latest
//
// Can only be called once at a time
-func updateLatestEvents(
+func (r *RoomserverInternalAPI) updateLatestEvents(
ctx context.Context,
- db storage.Database,
- ow OutputRoomEventWriter,
roomNID types.RoomNID,
stateAtEvent types.StateAtEvent,
event gomatrixserverlib.Event,
sendAsServer string,
transactionID *api.TransactionID,
) (err error) {
- updater, err := db.GetLatestEventsForUpdate(ctx, roomNID)
+ updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
if err != nil {
return
}
@@ -70,9 +67,8 @@ func updateLatestEvents(
u := latestEventsUpdater{
ctx: ctx,
- db: db,
+ api: r,
updater: updater,
- ow: ow,
roomNID: roomNID,
stateAtEvent: stateAtEvent,
event: event,
@@ -94,9 +90,8 @@ func updateLatestEvents(
// when there are so many variables to pass around.
type latestEventsUpdater struct {
ctx context.Context
- db storage.Database
+ api *RoomserverInternalAPI
updater types.RoomRecentEventsUpdater
- ow OutputRoomEventWriter
roomNID types.RoomNID
stateAtEvent types.StateAtEvent
event gomatrixserverlib.Event
@@ -181,7 +176,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// If we need to generate any output events then here's where we do it.
// TODO: Move this!
- updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added)
+ updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added)
if err != nil {
return err
}
@@ -200,7 +195,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
// necessary bookkeeping we'll keep the event sending synchronous for now.
- if err = u.ow.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
+ if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
return err
}
@@ -213,7 +208,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
func (u *latestEventsUpdater) latestState() error {
var err error
- roomState := state.NewStateResolution(u.db)
+ roomState := state.NewStateResolution(u.api.DB)
// Get a list of the current latest events.
latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
@@ -303,7 +298,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
latestEventIDs[i] = u.latest[i].EventID
}
- roomVersion, err := u.db.GetRoomVersionForRoom(u.ctx, u.event.RoomID())
+ roomVersion, err := u.api.DB.GetRoomVersionForRoom(u.ctx, u.event.RoomID())
if err != nil {
return nil, err
}
@@ -329,7 +324,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
- eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs)
+ eventIDMap, err := u.api.DB.EventIDs(u.ctx, stateEventNIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/input_membership.go b/roomserver/internal/input_membership.go
index 666e7ebc..af0c7f8b 100644
--- a/roomserver/internal/input_membership.go
+++ b/roomserver/internal/input_membership.go
@@ -19,7 +19,6 @@ import (
"fmt"
"github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -28,9 +27,8 @@ import (
// user affected by a change in the current state of the room.
// Returns a list of output events to write to the kafka log to inform the
// consumers about the invites added or retired by the change in current state.
-func updateMemberships(
+func (r *RoomserverInternalAPI) updateMemberships(
ctx context.Context,
- db storage.Database,
updater types.RoomRecentEventsUpdater,
removed, added []types.StateEntry,
) ([]api.OutputEvent, error) {
@@ -48,7 +46,7 @@ func updateMemberships(
// Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON?
- events, err := db.Events(ctx, eventNIDs)
+ events, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
@@ -71,15 +69,16 @@ func updateMemberships(
ae = &ev.Event
}
}
- if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
+ if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
return nil, err
}
}
return updates, nil
}
-func updateMembership(
- updater types.RoomRecentEventsUpdater, targetUserNID types.EventStateKeyNID,
+func (r *RoomserverInternalAPI) updateMembership(
+ updater types.RoomRecentEventsUpdater,
+ targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
@@ -113,7 +112,7 @@ func updateMembership(
return updates, nil
}
- mu, err := updater.MembershipUpdater(targetUserNID)
+ mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add))
if err != nil {
return nil, err
}
@@ -132,6 +131,15 @@ func updateMembership(
}
}
+func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bool {
+ isTargetLocalUser := false
+ if statekey := event.StateKey(); statekey != nil {
+ _, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
+ isTargetLocalUser = domain == r.Cfg.Matrix.ServerName
+ }
+ return isTargetLocalUser
+}
+
func updateToInviteMembership(
mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
roomVersion gomatrixserverlib.RoomVersion,
diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go
index 2d1c21c5..fce2ae90 100644
--- a/roomserver/internal/query.go
+++ b/roomserver/internal/query.go
@@ -267,7 +267,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
var stateEntries []types.StateEntry
if stillInRoom {
var eventNIDs []types.EventNID
- eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
+ eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly, false)
if err != nil {
return err
}
@@ -591,7 +591,7 @@ func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, ser
return false, err
}
- eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true)
+ eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false)
if err != nil {
return false, err
}
diff --git a/roomserver/internal/query_backfill.go b/roomserver/internal/query_backfill.go
index 49e0af34..23ae9455 100644
--- a/roomserver/internal/query_backfill.go
+++ b/roomserver/internal/query_backfill.go
@@ -297,7 +297,7 @@ func joinEventsFromHistoryVisibility(
if err != nil {
return nil, err
}
- joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true)
+ joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index fb39eca6..1e0232d2 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -83,9 +83,9 @@ type Database interface {
GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
GetCreatorIDForAlias(ctx context.Context, alias string) (string, error)
RemoveRoomAlias(ctx context.Context, alias string) error
- MembershipUpdater(ctx context.Context, roomID, targetUserID string, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error)
+ MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error)
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error)
- GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error)
+ GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
}
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index 9c8a4c25..820ef4e7 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -59,6 +59,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
-- This NID is updated if the join event gets updated (e.g. profile update),
-- or if the user leaves/joins the room.
event_nid BIGINT NOT NULL DEFAULT 0,
+ -- Local target is true if the target_nid refers to a local user rather than
+ -- a federated one. This is an optimisation for resetting state on federated
+ -- room joins.
+ target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
`
@@ -66,8 +70,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
- "INSERT INTO roomserver_membership (room_nid, target_nid)" +
- " VALUES ($1, $2)" +
+ "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
+ " VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" +
@@ -78,10 +82,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2"
+const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1 AND membership_nid = $2" +
+ " AND target_local = true"
+
const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1"
+const selectLocalMembershipsFromRoomSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1" +
+ " AND target_local = true"
+
const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE"
@@ -91,12 +105,14 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $1 AND target_nid = $2"
type membershipStatements struct {
- insertMembershipStmt *sql.Stmt
- selectMembershipForUpdateStmt *sql.Stmt
- selectMembershipFromRoomAndTargetStmt *sql.Stmt
- selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
- selectMembershipsFromRoomStmt *sql.Stmt
- updateMembershipStmt *sql.Stmt
+ insertMembershipStmt *sql.Stmt
+ selectMembershipForUpdateStmt *sql.Stmt
+ selectMembershipFromRoomAndTargetStmt *sql.Stmt
+ selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
+ selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
+ selectMembershipsFromRoomStmt *sql.Stmt
+ selectLocalMembershipsFromRoomStmt *sql.Stmt
+ updateMembershipStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@@ -110,7 +126,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
+ {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
+ {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db)
}
@@ -118,9 +136,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
func (s *membershipStatements) insertMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ localTarget bool,
) error {
stmt := common.TxStmt(txn, s.insertMembershipStmt)
- _, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
+ _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
return err
}
@@ -145,9 +164,15 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget(
}
func (s *membershipStatements) selectMembershipsFromRoom(
- ctx context.Context, roomNID types.RoomNID,
+ ctx context.Context, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
- rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID)
+ var stmt *sql.Stmt
+ if localOnly {
+ stmt = s.selectLocalMembershipsFromRoomStmt
+ } else {
+ stmt = s.selectMembershipsFromRoomStmt
+ }
+ rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil {
return
}
@@ -165,10 +190,16 @@ func (s *membershipStatements) selectMembershipsFromRoom(
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context,
- roomNID types.RoomNID, membership membershipState,
+ roomNID types.RoomNID, membership membershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
- stmt := s.selectMembershipsFromRoomAndMembershipStmt
- rows, err := stmt.QueryContext(ctx, roomNID, membership)
+ var rows *sql.Rows
+ var stmt *sql.Stmt
+ if localOnly {
+ stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
+ } else {
+ stmt = s.selectMembershipsFromRoomAndMembershipStmt
+ }
+ rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
}
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 1d825ecc..d451d665 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -459,8 +459,8 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error
return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID)
}
-func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
- return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID)
+func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) {
+ return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal)
}
// RoomNID implements query.RoomserverQueryAPIDB
@@ -558,7 +558,7 @@ func (d *Database) StateEntriesForTuples(
// MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,
- roomVersion gomatrixserverlib.RoomVersion,
+ targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
) (types.MembershipUpdater, error) {
txn, err := d.db.Begin()
if err != nil {
@@ -581,7 +581,7 @@ func (d *Database) MembershipUpdater(
return nil, err
}
- updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
+ updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal)
if err != nil {
return nil, err
}
@@ -603,9 +603,10 @@ func (d *Database) membershipUpdaterTxn(
txn *sql.Tx,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
+ targetLocal bool,
) (types.MembershipUpdater, error) {
- if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
+ if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err
}
@@ -748,15 +749,15 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(
- ctx context.Context, roomNID types.RoomNID, joinOnly bool,
+ ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
if joinOnly {
return d.statements.selectMembershipsFromRoomAndMembership(
- ctx, roomNID, membershipStateJoin,
+ ctx, roomNID, membershipStateJoin, localOnly,
)
}
- return d.statements.selectMembershipsFromRoom(ctx, roomNID)
+ return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly)
}
// EventsFromIDs implements query.RoomserverQueryAPIEventDB
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index 7ae28e4b..ca4d8fbe 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -38,6 +38,7 @@ const membershipSchema = `
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
+ target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
`
@@ -45,8 +46,8 @@ const membershipSchema = `
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
- "INSERT INTO roomserver_membership (room_nid, target_nid)" +
- " VALUES ($1, $2)" +
+ "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
+ " VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" +
@@ -57,10 +58,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2"
+const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1 AND membership_nid = $2" +
+ " AND target_local = true"
+
const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1"
+const selectLocalMembershipsFromRoomSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1" +
+ " AND target_local = true"
+
const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2"
@@ -70,12 +81,14 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $4 AND target_nid = $5"
type membershipStatements struct {
- insertMembershipStmt *sql.Stmt
- selectMembershipForUpdateStmt *sql.Stmt
- selectMembershipFromRoomAndTargetStmt *sql.Stmt
- selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
- selectMembershipsFromRoomStmt *sql.Stmt
- updateMembershipStmt *sql.Stmt
+ insertMembershipStmt *sql.Stmt
+ selectMembershipForUpdateStmt *sql.Stmt
+ selectMembershipFromRoomAndTargetStmt *sql.Stmt
+ selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
+ selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
+ selectMembershipsFromRoomStmt *sql.Stmt
+ selectLocalMembershipsFromRoomStmt *sql.Stmt
+ updateMembershipStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@@ -89,7 +102,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
+ {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
+ {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db)
}
@@ -97,9 +112,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
func (s *membershipStatements) insertMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ localTarget bool,
) error {
stmt := common.TxStmt(txn, s.insertMembershipStmt)
- _, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
+ _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
return err
}
@@ -127,9 +143,14 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget(
func (s *membershipStatements) selectMembershipsFromRoom(
ctx context.Context, txn *sql.Tx,
- roomNID types.RoomNID,
+ roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
- selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
+ var selectStmt *sql.Stmt
+ if localOnly {
+ selectStmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt)
+ } else {
+ selectStmt = common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
+ }
rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil {
return nil, err
@@ -145,11 +166,17 @@ func (s *membershipStatements) selectMembershipsFromRoom(
}
return
}
+
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, txn *sql.Tx,
- roomNID types.RoomNID, membership membershipState,
+ roomNID types.RoomNID, membership membershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
- stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
+ var stmt *sql.Stmt
+ if localOnly {
+ stmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt)
+ } else {
+ stmt = common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
+ }
rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index e77fea9c..209922fa 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -569,9 +569,9 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error
return err
}
-func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) {
+func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (mu types.MembershipUpdater, err error) {
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
- mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID)
+ mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID, targetLocal)
return err
})
return
@@ -680,7 +680,7 @@ func (d *Database) StateEntriesForTuples(
// MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,
- roomVersion gomatrixserverlib.RoomVersion,
+ targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
) (updater types.MembershipUpdater, err error) {
var txn *sql.Tx
txn, err = d.db.Begin()
@@ -716,7 +716,7 @@ func (d *Database) MembershipUpdater(
return nil, err
}
- updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
+ updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal)
if err != nil {
return nil, err
}
@@ -738,9 +738,10 @@ func (d *Database) membershipUpdaterTxn(
txn *sql.Tx,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
+ targetLocal bool,
) (types.MembershipUpdater, error) {
- if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
+ if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err
}
@@ -896,17 +897,17 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(
- ctx context.Context, roomNID types.RoomNID, joinOnly bool,
+ ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if joinOnly {
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
- ctx, txn, roomNID, membershipStateJoin,
+ ctx, txn, roomNID, membershipStateJoin, localOnly,
)
return nil
}
- eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID)
+ eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
return nil
})
return
diff --git a/roomserver/types/types.go b/roomserver/types/types.go
index da83f614..74e6b078 100644
--- a/roomserver/types/types.go
+++ b/roomserver/types/types.go
@@ -172,7 +172,7 @@ type RoomRecentEventsUpdater interface {
MarkEventAsSent(eventNID EventNID) error
// Build a membership updater for the target user in this room.
// It will share the same transaction as this updater.
- MembershipUpdater(targetUserNID EventStateKeyNID) (MembershipUpdater, error)
+ MembershipUpdater(targetUserNID EventStateKeyNID, isTargetLocalUser bool) (MembershipUpdater, error)
// Implements Transaction so it can be committed or rolledback
common.Transaction
}