diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2020-05-20 18:03:06 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-20 18:03:06 +0100 |
commit | f2c07437fe3d7f54977f7e645cd045a04e832020 (patch) | |
tree | fe0d5c92b1bc6ac5d5acb065f555de4ca704c1ad /roomserver/storage/sqlite3 | |
parent | 6091bf044ff41e9f248b1077d5b05f1a4c694412 (diff) |
Use memberships to determine whether to reset latest events/state on room join (#1047)
* Track local/remote memberships, re-scope some input stuff
* Check if we're in the room already before resetting latest events/state
* Fix postgres, fix lint
* Review comments
Diffstat (limited to 'roomserver/storage/sqlite3')
-rw-r--r-- | roomserver/storage/sqlite3/membership_table.go | 53 | ||||
-rw-r--r-- | roomserver/storage/sqlite3/storage.go | 17 |
2 files changed, 49 insertions, 21 deletions
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 7ae28e4b..ca4d8fbe 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -38,6 +38,7 @@ const membershipSchema = ` sender_nid INTEGER NOT NULL DEFAULT 0, membership_nid INTEGER NOT NULL DEFAULT 1, event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` @@ -45,8 +46,8 @@ const membershipSchema = ` // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + - "INSERT INTO roomserver_membership (room_nid, target_nid)" + - " VALUES ($1, $2)" + + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" + + " VALUES ($1, $2, $3)" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + @@ -57,10 +58,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" +const selectLocalMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + + " AND target_local = true" + const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" +const selectLocalMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + + " AND target_local = true" + const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" @@ -70,12 +81,14 @@ const updateMembershipSQL = "" + " WHERE room_nid = $4 AND target_nid = $5" type membershipStatements struct { - insertMembershipStmt *sql.Stmt - selectMembershipForUpdateStmt *sql.Stmt - selectMembershipFromRoomAndTargetStmt *sql.Stmt - selectMembershipsFromRoomAndMembershipStmt *sql.Stmt - selectMembershipsFromRoomStmt *sql.Stmt - updateMembershipStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + selectLocalMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -89,7 +102,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, }.prepare(db) } @@ -97,9 +112,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) insertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + localTarget bool, ) error { stmt := common.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) return err } @@ -127,9 +143,14 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget( func (s *membershipStatements) selectMembershipsFromRoom( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, + roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt) + var selectStmt *sql.Stmt + if localOnly { + selectStmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt) + } else { + selectStmt = common.TxStmt(txn, s.selectMembershipsFromRoomStmt) + } rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { return nil, err @@ -145,11 +166,17 @@ func (s *membershipStatements) selectMembershipsFromRoom( } return } + func (s *membershipStatements) selectMembershipsFromRoomAndMembership( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, membership membershipState, + roomNID types.RoomNID, membership membershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { - stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + var stmt *sql.Stmt + if localOnly { + stmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt) + } else { + stmt = common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + } rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { return diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index e77fea9c..209922fa 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -569,9 +569,9 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error return err } -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) { +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (mu types.MembershipUpdater, err error) { err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { - mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID) + mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID, targetLocal) return err }) return @@ -680,7 +680,7 @@ func (d *Database) StateEntriesForTuples( // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, - roomVersion gomatrixserverlib.RoomVersion, + targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, ) (updater types.MembershipUpdater, err error) { var txn *sql.Tx txn, err = d.db.Begin() @@ -716,7 +716,7 @@ func (d *Database) MembershipUpdater( return nil, err } - updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) if err != nil { return nil, err } @@ -738,9 +738,10 @@ func (d *Database) membershipUpdaterTxn( txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } @@ -896,17 +897,17 @@ func (d *Database) GetMembership( // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) (eventNIDs []types.EventNID, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { if joinOnly { eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( - ctx, txn, roomNID, membershipStateJoin, + ctx, txn, roomNID, membershipStateJoin, localOnly, ) return nil } - eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID) + eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly) return nil }) return |