diff options
Diffstat (limited to 'roomserver/storage/postgres/storage.go')
-rw-r--r-- | roomserver/storage/postgres/storage.go | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 1d825ecc..d451d665 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -459,8 +459,8 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) } -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) } // RoomNID implements query.RoomserverQueryAPIDB @@ -558,7 +558,7 @@ func (d *Database) StateEntriesForTuples( // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, - roomVersion gomatrixserverlib.RoomVersion, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, ) (types.MembershipUpdater, error) { txn, err := d.db.Begin() if err != nil { @@ -581,7 +581,7 @@ func (d *Database) MembershipUpdater( return nil, err } - updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) if err != nil { return nil, err } @@ -603,9 +603,10 @@ func (d *Database) membershipUpdaterTxn( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } @@ -748,15 +749,15 @@ func (d *Database) GetMembership( // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) ([]types.EventNID, error) { if joinOnly { return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, + ctx, roomNID, membershipStateJoin, localOnly, ) } - return d.statements.selectMembershipsFromRoom(ctx, roomNID) + return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) } // EventsFromIDs implements query.RoomserverQueryAPIEventDB |