aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--roomserver/internal/input/input_latest_events.go2
-rw-r--r--roomserver/storage/postgres/events_table.go27
-rw-r--r--roomserver/storage/shared/membership_updater.go6
-rw-r--r--roomserver/storage/shared/room_updater.go14
-rw-r--r--roomserver/storage/shared/storage.go26
-rw-r--r--roomserver/storage/sqlite3/events_table.go33
-rw-r--r--roomserver/storage/tables/interface.go1
7 files changed, 90 insertions, 19 deletions
diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go
index 5173d3ab..ae28ebef 100644
--- a/roomserver/internal/input/input_latest_events.go
+++ b/roomserver/internal/input/input_latest_events.go
@@ -405,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro
if len(extraEventIDs) == 0 {
return nil, nil
}
- extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs)
+ extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go
index c136f039..8012174a 100644
--- a/roomserver/storage/postgres/events_table.go
+++ b/roomserver/storage/postgres/events_table.go
@@ -127,6 +127,9 @@ const bulkSelectEventIDSQL = "" +
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)"
+const bulkSelectUnsentEventNIDSQL = "" +
+ "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE"
+
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
@@ -147,6 +150,7 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
+ bulkSelectUnsentEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt
selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
@@ -173,6 +177,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
+ {&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
}.Prepare(db)
@@ -458,10 +463,28 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
return results, nil
}
-// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
+// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
- stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
+ return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
+}
+
+// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
+// only for events that haven't already been sent to the roomserver output.
+// If an event ID is not in the database then it is omitted from the map.
+func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
+ return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
+}
+
+// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
+// If an event ID is not in the database then it is omitted from the map.
+func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
+ var stmt *sql.Stmt
+ if onlyUnsent {
+ stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt)
+ } else {
+ stmt = sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
+ }
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go
index 66ac2f5b..8f3f3d63 100644
--- a/roomserver/storage/shared/membership_updater.go
+++ b/roomserver/storage/shared/membership_updater.go
@@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
}
// Look up the NID of the new join event
- nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID})
+ nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
@@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
}
// Look up the NID of the new leave event
- nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID})
+ nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
@@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er
}
if u.membership != tables.MembershipStateKnock {
// Look up the NID of the new knock event
- nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()})
+ nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go
index 89b878b9..810a18ef 100644
--- a/roomserver/storage/shared/room_updater.go
+++ b/roomserver/storage/shared/room_updater.go
@@ -215,7 +215,13 @@ func (u *RoomUpdater) EventIDs(
func (u *RoomUpdater) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
- return u.d.eventNIDs(ctx, u.txn, eventIDs)
+ return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter)
+}
+
+func (u *RoomUpdater) UnsentEventNIDs(
+ ctx context.Context, eventIDs []string,
+) (map[string]types.EventNID, error) {
+ return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly)
}
func (u *RoomUpdater) StateAtEventIDs(
@@ -231,7 +237,11 @@ func (u *RoomUpdater) StateEntriesForEventIDs(
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
- return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
+ return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false)
+}
+
+func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
+ return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
}
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index e96c77af..9f3b8b1d 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -238,13 +238,27 @@ func (d *Database) addState(
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
- return d.eventNIDs(ctx, nil, eventIDs)
+ return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
}
+type UnsentFilter bool
+
+const (
+ NoFilter UnsentFilter = false
+ FilterUnsentOnly UnsentFilter = true
+)
+
func (d *Database) eventNIDs(
- ctx context.Context, txn *sql.Tx, eventIDs []string,
+ ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
) (map[string]types.EventNID, error) {
- return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
+ switch filter {
+ case FilterUnsentOnly:
+ return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs)
+ case NoFilter:
+ return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
+ default:
+ panic("impossible case")
+ }
}
func (d *Database) SetState(
@@ -281,11 +295,11 @@ func (d *Database) EventIDs(
}
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
- return d.eventsFromIDs(ctx, nil, eventIDs)
+ return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter)
}
-func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
- nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
+func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
+ nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go
index cef09fe6..969a10ce 100644
--- a/roomserver/storage/sqlite3/events_table.go
+++ b/roomserver/storage/sqlite3/events_table.go
@@ -99,6 +99,9 @@ const bulkSelectEventIDSQL = "" +
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)"
+const bulkSelectUnsentEventNIDSQL = "" +
+ "SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)"
+
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
@@ -118,8 +121,9 @@ type eventStatements struct {
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt
- bulkSelectEventNIDStmt *sql.Stmt
- //selectRoomNIDsForEventNIDsStmt *sql.Stmt
+ //bulkSelectEventNIDStmt *sql.Stmt
+ //bulkSelectUnsentEventNIDStmt *sql.Stmt
+ //selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
func createEventsTable(db *sql.DB) error {
@@ -144,7 +148,8 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
- {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
+ //{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
+ //{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
}.Prepare(db)
}
@@ -494,15 +499,33 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
return results, nil
}
-// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
+// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
+ return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
+}
+
+// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
+// only for events that haven't already been sent to the roomserver output.
+// If an event ID is not in the database then it is omitted from the map.
+func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
+ return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
+}
+
+// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
+// If an event ID is not in the database then it is omitted from the map.
+func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
iEventIDs[k] = v
}
- selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
+ var selectOrig string
+ if onlyUnsent {
+ selectOrig = strings.Replace(bulkSelectUnsentEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
+ } else {
+ selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1)
+ }
selectStmt, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index fed39b94..e3fed700 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -59,6 +59,7 @@ type Events interface {
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
+ BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
}