aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/postgres/storage.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/postgres/storage.go')
-rw-r--r--roomserver/storage/postgres/storage.go17
1 files changed, 9 insertions, 8 deletions
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 1d825ecc..d451d665 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -459,8 +459,8 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error
return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID)
}
-func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
- return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID)
+func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) {
+ return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal)
}
// RoomNID implements query.RoomserverQueryAPIDB
@@ -558,7 +558,7 @@ func (d *Database) StateEntriesForTuples(
// MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,
- roomVersion gomatrixserverlib.RoomVersion,
+ targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
) (types.MembershipUpdater, error) {
txn, err := d.db.Begin()
if err != nil {
@@ -581,7 +581,7 @@ func (d *Database) MembershipUpdater(
return nil, err
}
- updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
+ updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal)
if err != nil {
return nil, err
}
@@ -603,9 +603,10 @@ func (d *Database) membershipUpdaterTxn(
txn *sql.Tx,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
+ targetLocal bool,
) (types.MembershipUpdater, error) {
- if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
+ if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err
}
@@ -748,15 +749,15 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(
- ctx context.Context, roomNID types.RoomNID, joinOnly bool,
+ ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
if joinOnly {
return d.statements.selectMembershipsFromRoomAndMembership(
- ctx, roomNID, membershipStateJoin,
+ ctx, roomNID, membershipStateJoin, localOnly,
)
}
- return d.statements.selectMembershipsFromRoom(ctx, roomNID)
+ return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly)
}
// EventsFromIDs implements query.RoomserverQueryAPIEventDB