aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/shared/storage.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/shared/storage.go')
-rw-r--r--roomserver/storage/shared/storage.go137
1 files changed, 74 insertions, 63 deletions
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 8c7854e8..a9cb5782 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"encoding/json"
- "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
@@ -345,13 +344,15 @@ func (d *Database) GetLatestEventsForUpdate(
func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
-) (types.RoomNID, types.StateAtEvent, error) {
+) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
eventStateKeyNID types.EventStateKeyNID
eventNID types.EventNID
stateNID types.StateSnapshotNID
+ redactionEvent *gomatrixserverlib.Event
+ redactedEventID string
err error
)
@@ -419,11 +420,11 @@ func (d *Database) StoreEvent(
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return err
}
-
- return d.handleRedactions(ctx, txn, eventNID, event)
+ redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event)
+ return err
})
if err != nil {
- return 0, types.StateAtEvent{}, err
+ return 0, types.StateAtEvent{}, nil, "", err
}
return roomNID, types.StateAtEvent{
@@ -435,7 +436,7 @@ func (d *Database) StoreEvent(
},
EventNID: eventNID,
},
- }, nil
+ }, redactionEvent, redactedEventID, nil
}
func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error {
@@ -531,20 +532,42 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) (
// When an event is redacted, the redacted event JSON is modified to add an `unsigned.redacted_because` field. We use this field
// when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need
// to cross-reference with other tables when loading.
-func (d *Database) handleRedactions(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event) error {
+//
+// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction.
+func (d *Database) handleRedactions(
+ ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event,
+) (*gomatrixserverlib.Event, string, error) {
+ var err error
+ isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
+ if isRedactionEvent {
+ // an event which redacts itself should be ignored
+ if event.EventID() == event.Redacts() {
+ return nil, "", nil
+ }
+
+ err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{
+ Validated: false,
+ RedactionEventID: event.EventID(),
+ RedactsEventID: event.Redacts(),
+ })
+ if err != nil {
+ return nil, "", err
+ }
+ }
+
redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, eventNID, event)
if err != nil {
- return err
+ return nil, "", err
}
if validated || redactedEvent == nil || redactionEvent == nil {
// we've seen this redaction before or there is nothing to redact
- return nil
+ return nil, "", nil
}
// mark the event as redacted
err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
if err != nil {
- return err
+ return nil, "", err
}
if redactionsArePermanent {
redactedEvent.Event = redactedEvent.Redact()
@@ -552,82 +575,51 @@ func (d *Database) handleRedactions(ctx context.Context, txn *sql.Tx, eventNID t
// overwrite the eventJSON table
err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON())
if err != nil {
- return err
+ return nil, "", err
}
- return d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
+ return &redactionEvent.Event, redactedEvent.EventID(), d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
}
// loadRedactionPair returns both the redaction event and the redacted event, else nil.
-// nolint:gocyclo
func (d *Database) loadRedactionPair(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event,
) (*types.Event, *types.Event, bool, error) {
var redactionEvent, redactedEvent *types.Event
var info *tables.RedactionInfo
- var nids map[string]types.EventNID
- var evs []types.Event
var err error
isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
+
+ var eventBeingRedacted string
if isRedactionEvent {
+ eventBeingRedacted = event.Redacts()
redactionEvent = &types.Event{
EventNID: eventNID,
Event: event,
}
- // find the redacted event if one exists
- info, err = d.RedactionsTable.SelectRedactedEvent(ctx, txn, event.EventID())
- if err != nil {
- return nil, nil, false, err
- }
- if info == nil {
- // we don't have the redacted event yet
- return nil, nil, false, nil
- }
- nids, err = d.EventNIDs(ctx, []string{info.RedactsEventID})
- if err != nil {
- return nil, nil, false, err
- }
- if len(nids) == 0 {
- return nil, nil, false, fmt.Errorf("redaction: missing event NID being redacted: %+v", info)
- }
- evs, err = d.Events(ctx, []types.EventNID{nids[info.RedactsEventID]})
- if err != nil {
- return nil, nil, false, err
- }
- if len(evs) != 1 {
- return nil, nil, false, fmt.Errorf("redaction: missing event being redacted: %+v", info)
- }
- redactedEvent = &evs[0]
} else {
+ eventBeingRedacted = event.EventID() // maybe, we'll see if we have info
redactedEvent = &types.Event{
EventNID: eventNID,
Event: event,
}
- // find the redaction event if one exists
- info, err = d.RedactionsTable.SelectRedactionEvent(ctx, txn, event.EventID())
- if err != nil {
- return nil, nil, false, err
- }
- if info == nil {
- // this event is not redacted
- return nil, nil, false, nil
- }
- nids, err = d.EventNIDs(ctx, []string{info.RedactionEventID})
- if err != nil {
- return nil, nil, false, err
- }
- if len(nids) == 0 {
- return nil, nil, false, fmt.Errorf("redaction: missing redaction event NID: %+v", info)
- }
- evs, err = d.Events(ctx, []types.EventNID{nids[info.RedactionEventID]})
- if err != nil {
- return nil, nil, false, err
- }
- if len(evs) != 1 {
- return nil, nil, false, fmt.Errorf("redaction: missing redaction event: %+v", info)
- }
- redactionEvent = &evs[0]
}
+
+ info, err = d.RedactionsTable.SelectRedactionInfoByEventBeingRedacted(ctx, txn, eventBeingRedacted)
+ if err != nil {
+ return nil, nil, false, err
+ }
+ if info == nil {
+ // this event hasn't been redacted or we don't have the redaction for it yet
+ return nil, nil, false, nil
+ }
+
+ if isRedactionEvent {
+ redactedEvent = d.loadEvent(ctx, info.RedactsEventID)
+ } else {
+ redactionEvent = d.loadEvent(ctx, info.RedactionEventID)
+ }
+
return redactionEvent, redactedEvent, info.Validated, nil
}
@@ -639,3 +631,22 @@ func (d *Database) applyRedactions(events []types.Event) {
}
}
}
+
+// loadEvent loads a single event or returns nil on any problems/missing event
+func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
+ nids, err := d.EventNIDs(ctx, []string{eventID})
+ if err != nil {
+ return nil
+ }
+ if len(nids) == 0 {
+ return nil
+ }
+ evs, err := d.Events(ctx, []types.EventNID{nids[eventID]})
+ if err != nil {
+ return nil
+ }
+ if len(evs) != 1 {
+ return nil
+ }
+ return &evs[0]
+}