diff options
Diffstat (limited to 'roomserver/storage/shared/storage.go')
-rw-r--r-- | roomserver/storage/shared/storage.go | 137 |
1 files changed, 74 insertions, 63 deletions
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 8c7854e8..a9cb5782 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "encoding/json" - "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -345,13 +344,15 @@ func (d *Database) GetLatestEventsForUpdate( func (d *Database) StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, -) (types.RoomNID, types.StateAtEvent, error) { +) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { var ( roomNID types.RoomNID eventTypeNID types.EventTypeNID eventStateKeyNID types.EventStateKeyNID eventNID types.EventNID stateNID types.StateSnapshotNID + redactionEvent *gomatrixserverlib.Event + redactedEventID string err error ) @@ -419,11 +420,11 @@ func (d *Database) StoreEvent( if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { return err } - - return d.handleRedactions(ctx, txn, eventNID, event) + redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) + return err }) if err != nil { - return 0, types.StateAtEvent{}, err + return 0, types.StateAtEvent{}, nil, "", err } return roomNID, types.StateAtEvent{ @@ -435,7 +436,7 @@ func (d *Database) StoreEvent( }, EventNID: eventNID, }, - }, nil + }, redactionEvent, redactedEventID, nil } func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error { @@ -531,20 +532,42 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( // When an event is redacted, the redacted event JSON is modified to add an `unsigned.redacted_because` field. We use this field // when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need // to cross-reference with other tables when loading. -func (d *Database) handleRedactions(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event) error { +// +// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. +func (d *Database) handleRedactions( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event, +) (*gomatrixserverlib.Event, string, error) { + var err error + isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil + if isRedactionEvent { + // an event which redacts itself should be ignored + if event.EventID() == event.Redacts() { + return nil, "", nil + } + + err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{ + Validated: false, + RedactionEventID: event.EventID(), + RedactsEventID: event.Redacts(), + }) + if err != nil { + return nil, "", err + } + } + redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, eventNID, event) if err != nil { - return err + return nil, "", err } if validated || redactedEvent == nil || redactionEvent == nil { // we've seen this redaction before or there is nothing to redact - return nil + return nil, "", nil } // mark the event as redacted err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) if err != nil { - return err + return nil, "", err } if redactionsArePermanent { redactedEvent.Event = redactedEvent.Redact() @@ -552,82 +575,51 @@ func (d *Database) handleRedactions(ctx context.Context, txn *sql.Tx, eventNID t // overwrite the eventJSON table err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) if err != nil { - return err + return nil, "", err } - return d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) + return &redactionEvent.Event, redactedEvent.EventID(), d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) } // loadRedactionPair returns both the redaction event and the redacted event, else nil. -// nolint:gocyclo func (d *Database) loadRedactionPair( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event, ) (*types.Event, *types.Event, bool, error) { var redactionEvent, redactedEvent *types.Event var info *tables.RedactionInfo - var nids map[string]types.EventNID - var evs []types.Event var err error isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil + + var eventBeingRedacted string if isRedactionEvent { + eventBeingRedacted = event.Redacts() redactionEvent = &types.Event{ EventNID: eventNID, Event: event, } - // find the redacted event if one exists - info, err = d.RedactionsTable.SelectRedactedEvent(ctx, txn, event.EventID()) - if err != nil { - return nil, nil, false, err - } - if info == nil { - // we don't have the redacted event yet - return nil, nil, false, nil - } - nids, err = d.EventNIDs(ctx, []string{info.RedactsEventID}) - if err != nil { - return nil, nil, false, err - } - if len(nids) == 0 { - return nil, nil, false, fmt.Errorf("redaction: missing event NID being redacted: %+v", info) - } - evs, err = d.Events(ctx, []types.EventNID{nids[info.RedactsEventID]}) - if err != nil { - return nil, nil, false, err - } - if len(evs) != 1 { - return nil, nil, false, fmt.Errorf("redaction: missing event being redacted: %+v", info) - } - redactedEvent = &evs[0] } else { + eventBeingRedacted = event.EventID() // maybe, we'll see if we have info redactedEvent = &types.Event{ EventNID: eventNID, Event: event, } - // find the redaction event if one exists - info, err = d.RedactionsTable.SelectRedactionEvent(ctx, txn, event.EventID()) - if err != nil { - return nil, nil, false, err - } - if info == nil { - // this event is not redacted - return nil, nil, false, nil - } - nids, err = d.EventNIDs(ctx, []string{info.RedactionEventID}) - if err != nil { - return nil, nil, false, err - } - if len(nids) == 0 { - return nil, nil, false, fmt.Errorf("redaction: missing redaction event NID: %+v", info) - } - evs, err = d.Events(ctx, []types.EventNID{nids[info.RedactionEventID]}) - if err != nil { - return nil, nil, false, err - } - if len(evs) != 1 { - return nil, nil, false, fmt.Errorf("redaction: missing redaction event: %+v", info) - } - redactionEvent = &evs[0] } + + info, err = d.RedactionsTable.SelectRedactionInfoByEventBeingRedacted(ctx, txn, eventBeingRedacted) + if err != nil { + return nil, nil, false, err + } + if info == nil { + // this event hasn't been redacted or we don't have the redaction for it yet + return nil, nil, false, nil + } + + if isRedactionEvent { + redactedEvent = d.loadEvent(ctx, info.RedactsEventID) + } else { + redactionEvent = d.loadEvent(ctx, info.RedactionEventID) + } + return redactionEvent, redactedEvent, info.Validated, nil } @@ -639,3 +631,22 @@ func (d *Database) applyRedactions(events []types.Event) { } } } + +// loadEvent loads a single event or returns nil on any problems/missing event +func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { + nids, err := d.EventNIDs(ctx, []string{eventID}) + if err != nil { + return nil + } + if len(nids) == 0 { + return nil + } + evs, err := d.Events(ctx, []types.EventNID{nids[eventID]}) + if err != nil { + return nil + } + if len(evs) != 1 { + return nil + } + return &evs[0] +} |