aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/shared/storage.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/shared/storage.go')
-rw-r--r--roomserver/storage/shared/storage.go39
1 files changed, 30 insertions, 9 deletions
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index ba937ba3..692af1f6 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -439,8 +439,18 @@ func (d *Database) Events(
}
func (d *Database) events(
- ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
+ ctx context.Context, txn *sql.Tx, inputEventNIDs types.EventNIDs,
) ([]types.Event, error) {
+ sort.Sort(inputEventNIDs)
+ events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs))
+ eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
+ for _, nid := range inputEventNIDs {
+ if event, ok := d.Cache.GetRoomServerEvent(nid); ok && event != nil {
+ events[nid] = event
+ } else {
+ eventNIDs = append(eventNIDs, nid)
+ }
+ }
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil {
return nil, err
@@ -476,18 +486,29 @@ func (d *Database) events(
for n, v := range dbRoomVersions {
roomVersions[n] = v
}
- results := make([]types.Event, len(eventJSONs))
- for i, eventJSON := range eventJSONs {
- result := &results[i]
- result.EventNID = eventJSON.EventNID
- roomNID := roomNIDs[result.EventNID]
+ for _, eventJSON := range eventJSONs {
+ roomNID := roomNIDs[eventJSON.EventNID]
roomVersion := roomVersions[roomNID]
- result.Event, err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
+ events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion,
)
if err != nil {
return nil, err
}
+ if event := events[eventJSON.EventNID]; event != nil {
+ d.Cache.StoreRoomServerEvent(eventJSON.EventNID, event)
+ }
+ }
+ results := make([]types.Event, 0, len(inputEventNIDs))
+ for _, nid := range inputEventNIDs {
+ event, ok := events[nid]
+ if !ok || event == nil {
+ return nil, fmt.Errorf("event %d missing", nid)
+ }
+ results = append(results, types.Event{
+ EventNID: nid,
+ Event: event,
+ })
}
if !redactionsArePermanent {
d.applyRedactions(results)
@@ -854,7 +875,7 @@ func (d *Database) handleRedactions(
// mark the event as redacted
if redactionsArePermanent {
- redactedEvent.Event = redactedEvent.Redact()
+ redactedEvent.Redact()
}
err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
@@ -926,7 +947,7 @@ func (d *Database) loadRedactionPair(
func (d *Database) applyRedactions(events []types.Event) {
for i := range events {
if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() {
- events[i].Event = events[i].Redact()
+ events[i].Redact()
}
}
}