aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/sqlite3
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-05-20 18:03:06 +0100
committerGitHub <noreply@github.com>2020-05-20 18:03:06 +0100
commitf2c07437fe3d7f54977f7e645cd045a04e832020 (patch)
treefe0d5c92b1bc6ac5d5acb065f555de4ca704c1ad /roomserver/storage/sqlite3
parent6091bf044ff41e9f248b1077d5b05f1a4c694412 (diff)
Use memberships to determine whether to reset latest events/state on room join (#1047)
* Track local/remote memberships, re-scope some input stuff * Check if we're in the room already before resetting latest events/state * Fix postgres, fix lint * Review comments
Diffstat (limited to 'roomserver/storage/sqlite3')
-rw-r--r--roomserver/storage/sqlite3/membership_table.go53
-rw-r--r--roomserver/storage/sqlite3/storage.go17
2 files changed, 49 insertions, 21 deletions
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index 7ae28e4b..ca4d8fbe 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -38,6 +38,7 @@ const membershipSchema = `
sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0,
+ target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid)
);
`
@@ -45,8 +46,8 @@ const membershipSchema = `
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
- "INSERT INTO roomserver_membership (room_nid, target_nid)" +
- " VALUES ($1, $2)" +
+ "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
+ " VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" +
@@ -57,10 +58,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2"
+const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1 AND membership_nid = $2" +
+ " AND target_local = true"
+
const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1"
+const selectLocalMembershipsFromRoomSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1" +
+ " AND target_local = true"
+
const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2"
@@ -70,12 +81,14 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $4 AND target_nid = $5"
type membershipStatements struct {
- insertMembershipStmt *sql.Stmt
- selectMembershipForUpdateStmt *sql.Stmt
- selectMembershipFromRoomAndTargetStmt *sql.Stmt
- selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
- selectMembershipsFromRoomStmt *sql.Stmt
- updateMembershipStmt *sql.Stmt
+ insertMembershipStmt *sql.Stmt
+ selectMembershipForUpdateStmt *sql.Stmt
+ selectMembershipFromRoomAndTargetStmt *sql.Stmt
+ selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
+ selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
+ selectMembershipsFromRoomStmt *sql.Stmt
+ selectLocalMembershipsFromRoomStmt *sql.Stmt
+ updateMembershipStmt *sql.Stmt
}
func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@@ -89,7 +102,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
+ {&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
+ {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db)
}
@@ -97,9 +112,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
func (s *membershipStatements) insertMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ localTarget bool,
) error {
stmt := common.TxStmt(txn, s.insertMembershipStmt)
- _, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
+ _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
return err
}
@@ -127,9 +143,14 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget(
func (s *membershipStatements) selectMembershipsFromRoom(
ctx context.Context, txn *sql.Tx,
- roomNID types.RoomNID,
+ roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
- selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
+ var selectStmt *sql.Stmt
+ if localOnly {
+ selectStmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt)
+ } else {
+ selectStmt = common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
+ }
rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil {
return nil, err
@@ -145,11 +166,17 @@ func (s *membershipStatements) selectMembershipsFromRoom(
}
return
}
+
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, txn *sql.Tx,
- roomNID types.RoomNID, membership membershipState,
+ roomNID types.RoomNID, membership membershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
- stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
+ var stmt *sql.Stmt
+ if localOnly {
+ stmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt)
+ } else {
+ stmt = common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
+ }
rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index e77fea9c..209922fa 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -569,9 +569,9 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error
return err
}
-func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) {
+func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (mu types.MembershipUpdater, err error) {
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
- mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID)
+ mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID, targetLocal)
return err
})
return
@@ -680,7 +680,7 @@ func (d *Database) StateEntriesForTuples(
// MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,
- roomVersion gomatrixserverlib.RoomVersion,
+ targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
) (updater types.MembershipUpdater, err error) {
var txn *sql.Tx
txn, err = d.db.Begin()
@@ -716,7 +716,7 @@ func (d *Database) MembershipUpdater(
return nil, err
}
- updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
+ updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal)
if err != nil {
return nil, err
}
@@ -738,9 +738,10 @@ func (d *Database) membershipUpdaterTxn(
txn *sql.Tx,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
+ targetLocal bool,
) (types.MembershipUpdater, error) {
- if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
+ if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err
}
@@ -896,17 +897,17 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(
- ctx context.Context, roomNID types.RoomNID, joinOnly bool,
+ ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if joinOnly {
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
- ctx, txn, roomNID, membershipStateJoin,
+ ctx, txn, roomNID, membershipStateJoin, localOnly,
)
return nil
}
- eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID)
+ eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
return nil
})
return