diff options
author | Till <2353100+S7evinK@users.noreply.github.com> | 2023-03-01 17:06:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-01 17:06:47 +0100 |
commit | 6c20f8f742a7e03710fae81df6ef98bac31da2b1 (patch) | |
tree | 202e962951dc41c949a71c7f5c1deb6d1da78843 | |
parent | 1aa70b0f56825a4a5f92c38cabb1fe841cec6e18 (diff) |
Refactor `StoreEvent`, add `MaybeRedactEvent`, create an `EventDatabase` (#2989)
This PR changes the following:
- `StoreEvent` now only stores an event (and possibly prev event),
instead of also doing redactions
- Adds a `MaybeRedactEvent` (pulled out from `StoreEvent`), which should
be called after storing events
- a few other things
34 files changed, 486 insertions, 418 deletions
diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index ac68f4bd..528de63e 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -122,6 +122,7 @@ func (s *OutputRoomEventConsumer) onMessage( if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { newEventID := output.NewRoomEvent.Event.EventID() eventsReq := &api.QueryEventsByIDRequest{ + RoomID: output.NewRoomEvent.Event.RoomID(), EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)), } eventsRes := &api.QueryEventsByIDResponse{} diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 7841b3b0..f86bbc8f 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -57,7 +57,7 @@ func SendRedaction( } } - ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID) + ev := roomserverAPI.GetEvent(req.Context(), rsAPI, roomID, eventID) if ev == nil { return util.JSONResponse{ Code: 400, diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index e3840bbc..a9cc80cb 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -62,9 +62,10 @@ func main() { panic(err) } - stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{ + roomInfo := &types.RoomInfo{ RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), - }) + } + stateres := state.NewStateResolution(roomserverDB, roomInfo) if *difference { if len(snapshotNIDs) != 2 { @@ -87,7 +88,7 @@ func main() { } var eventEntries []types.Event - eventEntries, err = roomserverDB.Events(ctx, 0, eventNIDs) + eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs) if err != nil { panic(err) } @@ -145,7 +146,7 @@ func main() { } fmt.Println("Fetching", len(eventNIDMap), "state events") - eventEntries, err := roomserverDB.Events(ctx, 0, eventNIDs) + eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs) if err != nil { panic(err) } @@ -165,7 +166,7 @@ func main() { } fmt.Println("Fetching", len(authEventIDs), "auth events") - authEventEntries, err := roomserverDB.EventsFromIDs(ctx, 0, authEventIDs) + authEventEntries, err := roomserverDB.EventsFromIDs(ctx, roomInfo, authEventIDs) if err != nil { panic(err) } diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 82a4db3f..378b96ba 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -173,6 +173,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew // Finally, work out if there are any more events missing. if len(missingEventIDs) > 0 { eventsReq := &api.QueryEventsByIDRequest{ + RoomID: ore.Event.RoomID(), EventIDs: missingEventIDs, } eventsRes := &api.QueryEventsByIDResponse{} @@ -483,7 +484,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( // At this point the missing events are neither the event itself nor are // they present in our local database. Our only option is to fetch them // from the roomserver using the query API. - eventReq := api.QueryEventsByIDRequest{EventIDs: missing} + eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID()} var eventResp api.QueryEventsByIDResponse if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil { return nil, err diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index 868785a9..2f1f3baf 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -36,7 +36,7 @@ func GetEventAuth( return *err } - event, resErr := fetchEvent(ctx, rsAPI, eventID) + event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID) if resErr != nil { return *resErr } diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index 6168912b..b4129241 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -20,10 +20,11 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/roomserver/api" ) // GetEvent returns the requested event @@ -38,7 +39,9 @@ func GetEvent( if err != nil { return *err } - event, err := fetchEvent(ctx, rsAPI, eventID) + // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string, + // which results in `QueryEventsByID` to first get the event and use that to determine the roomID. + event, err := fetchEvent(ctx, rsAPI, "", eventID) if err != nil { return *err } @@ -60,21 +63,13 @@ func allowedToSeeEvent( rsAPI api.FederationRoomserverAPI, eventID string, ) *util.JSONResponse { - var authResponse api.QueryServerAllowedToSeeEventResponse - err := rsAPI.QueryServerAllowedToSeeEvent( - ctx, - &api.QueryServerAllowedToSeeEventRequest{ - EventID: eventID, - ServerName: origin, - }, - &authResponse, - ) + allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID) if err != nil { resErr := util.ErrorResponse(err) return &resErr } - if !authResponse.AllowedToSeeEvent { + if !allowed { resErr := util.MessageResponse(http.StatusForbidden, "server not allowed to see event") return &resErr } @@ -83,11 +78,11 @@ func allowedToSeeEvent( } // fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found. -func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { +func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) { var eventsResponse api.QueryEventsByIDResponse err := rsAPI.QueryEventsByID( ctx, - &api.QueryEventsByIDRequest{EventIDs: []string{eventID}}, + &api.QueryEventsByIDRequest{EventIDs: []string{eventID}, RoomID: roomID}, &eventsResponse, ) if err != nil { diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 1d08d0a8..1120cf26 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -107,7 +107,7 @@ func getState( return nil, nil, err } - event, resErr := fetchEvent(ctx, rsAPI, eventID) + event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID) if resErr != nil { return nil, nil, resErr } diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index 223282a2..d6c79e98 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -16,7 +16,9 @@ // Hooks can only be run in monolith mode. package hooks -import "sync" +import ( + "sync" +) const ( // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 73732ae3..f6d003a4 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -54,7 +54,8 @@ type QueryBulkStateContentAPI interface { } type QueryEventsAPI interface { - // Query a list of events by event ID. + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, @@ -71,7 +72,8 @@ type SyncRoomserverAPI interface { QueryBulkStateContentAPI // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error - // Query a list of events by event ID. + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, @@ -108,7 +110,8 @@ type SyncRoomserverAPI interface { } type AppserviceRoomserverAPI interface { - // Query a list of events by event ID. + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, @@ -182,6 +185,8 @@ type FederationRoomserverAPI interface { QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error + // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine + // which room to use by querying the first events roomID. QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error // Query to get state and auth chain for a (potentially hypothetical) event. // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate @@ -193,7 +198,7 @@ type FederationRoomserverAPI interface { // Query missing events for a room from roomserver QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event - QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error + QueryServerAllowedToSeeEvent(ctx context.Context, serverName gomatrixserverlib.ServerName, eventID string) (allowed bool, err error) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 4ef548e1..24722db0 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -86,6 +86,9 @@ type QueryStateAfterEventsResponse struct { // QueryEventsByIDRequest is a request to QueryEventsByID type QueryEventsByIDRequest struct { + // The roomID to query events for. If this is empty, we first try to fetch the roomID from the database + // as this is needed for further processing/parsing events. + RoomID string `json:"room_id"` // The event IDs to look up. EventIDs []string `json:"event_ids"` } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 252be557..f220560e 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -108,9 +108,10 @@ func SendInputRoomEvents( } // GetEvent returns the event or nil, even on errors. -func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, eventID string) *gomatrixserverlib.HeaderedEvent { +func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *gomatrixserverlib.HeaderedEvent { var res QueryEventsByIDResponse err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{ + RoomID: roomID, EventIDs: []string{eventID}, }, &res) if err != nil { diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 27c8dd8f..9defe794 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -67,7 +67,7 @@ func CheckForSoftFail( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) if err != nil { return true, fmt.Errorf("loadAuthEvents: %w", err) } @@ -85,7 +85,7 @@ func CheckForSoftFail( func CheckAuthEvents( ctx context.Context, db storage.RoomDatabase, - roomNID types.RoomNID, + roomInfo *types.RoomInfo, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -100,7 +100,7 @@ func CheckAuthEvents( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries) if err != nil { return nil, fmt.Errorf("loadAuthEvents: %w", err) } @@ -193,7 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * func loadAuthEvents( ctx context.Context, db state.StateResolutionStorage, - roomNID types.RoomNID, + roomInfo *types.RoomInfo, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { @@ -216,7 +216,7 @@ func loadAuthEvents( eventNIDs = append(eventNIDs, eventNID) } } - if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil { + if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil { return } roomID := "" diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index ee1610cf..9a70bcc9 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, err } - events, err := db.Events(ctx, info.RoomNID, eventNIDs) + events, err := db.Events(ctx, info, eventNIDs) if err != nil { return false, err } @@ -157,7 +157,7 @@ func IsInvitePending( // only keep the "m.room.member" events with a "join" membership. These events are returned. // Returns an error if there was an issue fetching the events. func GetMembershipsAtState( - ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, joinedOnly bool, ) ([]types.Event, error) { var eventNIDs types.EventNIDs @@ -177,7 +177,7 @@ func GetMembershipsAtState( util.Unique(eventNIDs) // Get all of the events in this state - stateEvents, err := db.Events(ctx, roomNID, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if err != nil { return nil, err } @@ -227,9 +227,9 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types } func LoadEvents( - ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID, ) ([]*gomatrixserverlib.Event, error) { - stateEvents, err := db.Events(ctx, roomNID, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if err != nil { return nil, err } @@ -242,13 +242,13 @@ func LoadEvents( } func LoadStateEvents( - ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, ) ([]*gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID } - return LoadEvents(ctx, db, roomNID, eventNIDs) + return LoadEvents(ctx, db, roomInfo, eventNIDs) } func CheckServerAllowedToSeeEvent( @@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState( return nil, nil } - return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries) + return LoadStateEvents(ctx, db, info, filteredEntries) } // TODO: Remove this when we have tests to assert correctness of this function @@ -366,7 +366,7 @@ BFSLoop: next = make([]string, 0) } // Retrieve the events to process from the database. - events, err = db.EventsFromIDs(ctx, info.RoomNID, front) + events, err = db.EventsFromIDs(ctx, info, front) if err != nil { return resultNIDs, redactEventIDs, err } @@ -467,7 +467,7 @@ func QueryLatestEventsAndState( return err } - stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries) + stateEvents, err := LoadStateEvents(ctx, db, roomInfo, stateEntries) if err != nil { return err } diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index 62730df1..c056e704 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -4,9 +4,10 @@ import ( "context" "testing" - "github.com/matrix-org/dendrite/roomserver/types" "github.com/stretchr/testify/assert" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" @@ -38,9 +39,9 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { var authNIDs []types.EventNID for _, x := range room.Events() { - roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap()) + roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.Unwrap()) assert.NoError(t, err) - assert.Greater(t, roomNID, types.RoomNID(0)) + assert.NotNil(t, roomInfo) eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type()) assert.NoError(t, err) @@ -49,7 +50,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey()) assert.NoError(t, err) - evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false) + evNID, _, err := db.StoreEvent(context.Background(), x.Event, roomInfo, eventTypeNID, eventStateKeyNID, authNIDs, false) assert.NoError(t, err) authNIDs = append(authNIDs, evNID) } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index fe35efb2..ede345a9 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -24,9 +24,10 @@ import ( "fmt" "time" - "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/tidwall/gjson" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" @@ -274,8 +275,10 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. + redactAllowed := true if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { isRejected = true + redactAllowed = false rejectionErr = err logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) } @@ -323,7 +326,7 @@ func (r *Inputer) processRoomEvent( // burning CPU time. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent { - historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev) + historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo, input, missingPrev) if err != nil { return fmt.Errorf("r.processStateBefore: %w", err) } @@ -332,9 +335,11 @@ func (r *Inputer) processRoomEvent( } } - roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event) - if err != nil { - return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) + if roomInfo == nil { + roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, event) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err) + } } eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type()) @@ -348,15 +353,24 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) + eventNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. - if !isRejected && redactedEventID == event.EventID() { - if err = eventutil.RedactEvent(redactionEvent, event); err != nil { - return fmt.Errorf("eventutil.RedactEvent: %w", rerr) + var ( + redactedEventID string + redactionEvent *gomatrixserverlib.Event + redactedEvent *gomatrixserverlib.Event + ) + if !isRejected && !isCreateEvent { + redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, redactAllowed) + if err != nil { + return err + } + if redactedEvent != nil { + redactedEventID = redactedEvent.EventID() } } @@ -489,7 +503,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse // nolint:nakedret func (r *Inputer) processStateBefore( ctx context.Context, - roomNID types.RoomNID, + roomInfo *types.RoomInfo, input *api.InputRoomEvent, missingPrev bool, ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { @@ -505,7 +519,7 @@ func (r *Inputer) processStateBefore( case input.HasState: // If we're overriding the state then we need to go and retrieve // them from the database. It's a hard error if they are missing. - stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs) + stateEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, input.StateEventIDs) if err != nil { return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err) } @@ -604,7 +618,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID}) + authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -690,9 +704,11 @@ nextAuthEvent: logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } - roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent) - if err != nil { - return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) + if roomInfo == nil { + roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, authEvent) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err) + } } eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type()) @@ -706,7 +722,7 @@ nextAuthEvent: } // Finally, store the event in the database. - eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) + eventNID, _, err := r.DB.StoreEvent(ctx, authEvent, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -782,7 +798,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event return err } - memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs) if err != nil { return err } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 99a01255..e1dfa6cf 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := updater.Events(ctx, 0, eventNIDs) + events, err := updater.Events(ctx, nil, eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index c8b7d31d..9627f15a 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even for _, entry := range stateEntries { stateEventNIDs = append(stateEventNIDs, entry.EventNID) } - stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs) + stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs) if err != nil { t.log.WithError(err).Warnf("failed to load state events locally") return nil @@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even missingEventList = append(missingEventList, evID) } t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") - events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList) if err != nil { return nil } @@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } t.haveEventsMutex.Unlock() - events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList) if err != nil { return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) } @@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs if localFirst { // fetch from the roomserver - events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID}) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo, []string{missingEventID}) if err != nil { t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) } else if len(events) == 1 { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 2efe2255..45089bdd 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -70,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil } - memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 3a3a049d..411f4202 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -23,7 +23,6 @@ import ( "github.com/sirupsen/logrus" federationAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/internal/helpers" @@ -86,7 +85,7 @@ func (r *Backfiller) PerformBackfill( // Retrieve events from the list that was filled previously. If we fail to get // events from the database then attempt once to get them from federation instead. var loadedEvents []*gomatrixserverlib.Event - loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) + loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info, resultNIDs) if err != nil { if _, ok := err.(types.MissingEventError); ok { return r.backfillViaFederation(ctx, request, response) @@ -473,7 +472,7 @@ FindSuccessor: // Retrieve all "m.room.member" state events of "join" membership, which // contains the list of users in the room before the event, therefore all // the servers in it at that moment. - memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true) + memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info, stateEntries, true) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") return nil @@ -532,7 +531,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, roomNID = nid.RoomNID } } - eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs) + eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs) if err != nil { logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err @@ -562,7 +561,7 @@ func joinEventsFromHistoryVisibility( } // Get all of the events in this state - stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo, eventNIDs) if err != nil { // even though the default should be shared, restricting the visibility to joined // feels more secure here. @@ -585,7 +584,7 @@ func joinEventsFromHistoryVisibility( if err != nil { return nil, visibility, err } - evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs) + evs, err := db.Events(ctx, roomInfo, joinEventNIDs) return evs, visibility, err } @@ -606,7 +605,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs i++ } - roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap()) + roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev.Unwrap()) if err != nil { logrus.WithError(err).Error("failed to get or create roomNID") continue @@ -624,23 +623,22 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs continue } - var redactedEventID string - var redactionEvent *gomatrixserverlib.Event - eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false) + eventNID, _, err = db.StoreEvent(ctx, ev.Unwrap(), roomInfo, eventTypeNID, eventStateKeyNID, authNids, false) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue } + + _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Unwrap(), true) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") + continue + } // If storing this event results in it being redacted, then do so. // It's also possible for this event to be a redaction which results in another event being // redacted, which we don't care about since we aren't returning it in this backfill. - if redactedEventID == ev.EventID() { - eventToRedact := ev.Unwrap() - if err := eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") - continue - } - ev = eventToRedact.Headered(ev.RoomVersion) + if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() { + ev = redactedEvent.Headered(ev.RoomVersion) events[j] = ev } backfilledEventMap[ev.EventID()] = types.Event{ diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 9ac9edc4..1fb6eb43 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID}) + latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0].EventID}) if err != nil { return err } @@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) + stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info, stateEntries) if err != nil { return err } @@ -100,7 +100,7 @@ func (r *InboundPeeker) PerformInboundPeek( } authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe - authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 118e1b87..13d13f7b 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite( // try and see if the user is allowed to make this invite. We can't do // this for invites coming in over federation - we have to take those on // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs()) + _, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs()) if err != nil { logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( "processInviteEvent.checkAuthEvents failed for event", @@ -291,7 +291,7 @@ func buildInviteStrippedState( for _, stateNID := range stateEntries { stateNIDs = append(stateNIDs, stateNID.EventNID) } - stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs) + stateEvents, err := db.Events(ctx, info, stateNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index ac34e0ff..c5b74422 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -21,11 +21,12 @@ import ( "errors" "fmt" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/acls" @@ -102,7 +103,7 @@ func (r *Queryer) QueryStateAfterEvents( return err } - stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries) if err != nil { return err } @@ -114,7 +115,7 @@ func (r *Queryer) QueryStateAfterEvents( } authEventIDs = util.UniqueStrings(authEventIDs) - authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs) if err != nil { return fmt.Errorf("getAuthChain: %w", err) } @@ -132,24 +133,46 @@ func (r *Queryer) QueryStateAfterEvents( return nil } -// QueryEventsByID implements api.RoomserverInternalAPI +// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine +// which room to use by querying the first events roomID. func (r *Queryer) QueryEventsByID( ctx context.Context, request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { - events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs) + if len(request.EventIDs) == 0 { + return nil + } + var err error + // We didn't receive a room ID, we need to fetch it first before we can continue. + // This happens for e.g. ` /_matrix/federation/v1/event/{eventId}` + var roomInfo *types.RoomInfo + if request.RoomID == "" { + var eventNIDs map[string]types.EventMetadata + eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]}) + if err != nil { + return err + } + if len(eventNIDs) == 0 { + return nil + } + roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID) + } else { + roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID) + } + if err != nil { + return err + } + if roomInfo == nil { + return nil + } + events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs) if err != nil { return err } for _, event := range events { - roomVersion, verr := r.roomVersion(event.RoomID()) - if verr != nil { - return verr - } - - response.Events = append(response.Events, event.Headered(roomVersion)) + response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion)) } return nil @@ -186,7 +209,7 @@ func (r *Queryer) QueryMembershipForUser( response.IsInRoom = stillInRoom response.HasBeenInRoom = true - evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID}) + evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID}) if err != nil { return err } @@ -268,10 +291,10 @@ func (r *Queryer) QueryMembershipAtEvent( // once. If we have more than one membership event, we need to get the state for each state entry. if canShortCircuit { if len(memberships) == 0 { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false) } } else { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false) } if err != nil { return fmt.Errorf("unable to get memberships at state: %w", err) @@ -318,7 +341,7 @@ func (r *Queryer) QueryMembershipsForRoom( } return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } - events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) + events, err = r.DB.Events(ctx, info, eventNIDs) if err != nil { return fmt.Errorf("r.DB.Events: %w", err) } @@ -357,14 +380,14 @@ func (r *Queryer) QueryMembershipsForRoom( return err } - events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) + events, err = r.DB.Events(ctx, info, eventNIDs) } else { stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err } - events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly) + events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly) } if err != nil { @@ -412,39 +435,39 @@ func (r *Queryer) QueryServerJoinedToRoom( // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI func (r *Queryer) QueryServerAllowedToSeeEvent( ctx context.Context, - request *api.QueryServerAllowedToSeeEventRequest, - response *api.QueryServerAllowedToSeeEventResponse, -) (err error) { - events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID}) + serverName gomatrixserverlib.ServerName, + eventID string, +) (allowed bool, err error) { + events, err := r.DB.EventNIDs(ctx, []string{eventID}) if err != nil { return } if len(events) == 0 { - response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see - return - } - roomID := events[0].RoomID() - - inRoomReq := &api.QueryServerJoinedToRoomRequest{ - RoomID: roomID, - ServerName: request.ServerName, + return allowed, nil } - inRoomRes := &api.QueryServerJoinedToRoomResponse{} - if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil { - return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err) - } - - info, err := r.DB.RoomInfo(ctx, roomID) + info, err := r.DB.RoomInfoByNID(ctx, events[eventID].RoomNID) if err != nil { - return err + return allowed, err } if info == nil || info.IsStub() { - return nil + return allowed, nil } - response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom, + var isInRoom bool + if r.IsLocalServerName(serverName) || serverName == "" { + isInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) + if err != nil { + return allowed, fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) + } + } else { + isInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, serverName) + if err != nil { + return allowed, fmt.Errorf("r.DB.GetServerInRoom: %w", err) + } + } + + return helpers.CheckServerAllowedToSeeEvent( + ctx, r.DB, info, eventID, serverName, isInRoom, ) - return } // QueryMissingEvents implements api.RoomserverInternalAPI @@ -466,19 +489,22 @@ func (r *Queryer) QueryMissingEvents( eventsToFilter[id] = true } } - events, err := r.DB.EventsFromIDs(ctx, 0, front) + if len(front) == 0 { + return nil // no events to query, give up. + } + events, err := r.DB.EventNIDs(ctx, []string{front[0]}) if err != nil { return err } if len(events) == 0 { return nil // we are missing the events being asked to search from, give up. } - info, err := r.DB.RoomInfo(ctx, events[0].RoomID()) + info, err := r.DB.RoomInfoByNID(ctx, events[front[0]].RoomNID) if err != nil { return err } if info == nil || info.IsStub() { - return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) + return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID) } resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) @@ -486,7 +512,7 @@ func (r *Queryer) QueryMissingEvents( return err } - loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) + loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs) if err != nil { return err } @@ -529,7 +555,7 @@ func (r *Queryer) QueryStateAndAuthChain( // TODO: this probably means it should be a different query operation... if request.OnlyFetchAuthChain { var authEvents []*gomatrixserverlib.Event - authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs) + authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs) if err != nil { return err } @@ -556,7 +582,7 @@ func (r *Queryer) QueryStateAndAuthChain( } authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe - authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs) if err != nil { return err } @@ -611,18 +637,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI return nil, rejected, false, err } - events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries) + events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries) return events, rejected, false, err } -type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error) +type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error) // GetAuthChain fetches the auth chain for the given auth events. An auth chain // is the list of all events that are referenced in the auth_events section, and // all their auth_events, recursively. The returned set of events contain the // given events. Will *not* error if we don't have all auth events. func GetAuthChain( - ctx context.Context, fn eventsFromIDs, authEventIDs []string, + ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string, ) ([]*gomatrixserverlib.Event, error) { // List of event IDs to fetch. On each pass, these events will be requested // from the database and the `eventsToFetch` will be updated with any new @@ -633,7 +659,7 @@ func GetAuthChain( for len(eventsToFetch) > 0 { // Try to retrieve the events from the database. - events, err := fn(ctx, 0, eventsToFetch) + events, err := fn(ctx, roomInfo, eventsToFetch) if err != nil { return nil, err } @@ -852,7 +878,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS } func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error { - chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs) + chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs) if err != nil { return err } @@ -971,7 +997,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query // For each of the joined users, let's see if we can get a valid // membership event. for _, joinNID := range joinNIDs { - events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID}) + events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID}) if err != nil || len(events) != 1 { continue } diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 16761157..265f326d 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error { } // EventsFromIDs implements RoomserverInternalAPIEventDB -func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) { +func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) (res []types.Event, err error) { for _, evID := range eventIDs { res = append(res, types.Event{ EventNID: 0, @@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) { t.Fatalf("Failed to add events to db: %v", err) } - result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"}) + result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e"}) if err != nil { t.Fatalf("getAuthChain failed: %v", err) } @@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) { t.Fatalf("Failed to add events to db: %v", err) } - result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"}) + result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e", "f"}) if err != nil { t.Fatalf("getAuthChain failed: %v", err) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 304311c4..cfa27e54 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -278,6 +278,16 @@ func TestPurgeRoom(t *testing.T) { if roomInfo == nil { t.Fatalf("room does not exist") } + + // + roomInfo2, err := db.RoomInfoByNID(ctx, roomInfo.RoomNID) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(roomInfo, roomInfo2) { + t.Fatalf("expected roomInfos to be the same, but they aren't") + } + // remember the roomInfo before purging existingRoomInfo := roomInfo @@ -333,6 +343,10 @@ func TestPurgeRoom(t *testing.T) { if roomInfo != nil { t.Fatalf("room should not exist after purging: %+v", roomInfo) } + roomInfo2, err = db.RoomInfoByNID(ctx, existingRoomInfo.RoomNID) + if err == nil { + t.Fatalf("expected room to not exist, but it does: %#v", roomInfo2) + } // validation below diff --git a/roomserver/state/state.go b/roomserver/state/state.go index cec542d7..9af2f857 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -41,8 +41,8 @@ type StateResolutionStorage interface { StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) - EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) + Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) } type StateResolution struct { @@ -975,7 +975,7 @@ func (v *StateResolution) resolveConflictsV2( // Store the newly found auth events in the auth set for this event. var authEventMap map[string]types.StateEntry - authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo.RoomNID, conflictedEvent, knownAuthEvents) + authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo, conflictedEvent, knownAuthEvents) if err != nil { return err } @@ -1091,7 +1091,7 @@ func (v *StateResolution) loadStateEvents( eventNIDs = append(eventNIDs, entry.EventNID) } } - events, err := v.db.Events(ctx, v.roomInfo.RoomNID, eventNIDs) + events, err := v.db.Events(ctx, v.roomInfo, eventNIDs) if err != nil { return nil, nil, err } @@ -1120,7 +1120,7 @@ type authEventLoader struct { // loadAuthEvents loads all of the auth events for a given event recursively, // along with a map that contains state entries for all of the auth events. func (l *authEventLoader) loadAuthEvents( - ctx context.Context, roomNID types.RoomNID, event *gomatrixserverlib.Event, eventMap map[string]types.Event, + ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { l.Lock() defer l.Unlock() @@ -1155,7 +1155,7 @@ func (l *authEventLoader) loadAuthEvents( // If we need to get events from the database, go and fetch // those now. if len(l.lookupFromDB) > 0 { - eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomNID, l.lookupFromDB) + eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomInfo, l.lookupFromDB) if err != nil { return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 88ec5667..a41a8a9b 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -29,6 +29,7 @@ type Database interface { SupportsConcurrentRoomInputs() bool // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) // Store the room state at an event in the database AddState( ctx context.Context, @@ -69,12 +70,12 @@ type Database interface { ) ([]types.StateEntryList, error) // Look up the Events for a list of numeric event IDs. // Returns a sorted list of events. - Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) + Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) - // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. - StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) + // Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error. + StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database // Returns a types.MissingEventError if the event IDs aren't in the database. @@ -135,7 +136,7 @@ type Database interface { // EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was // not found. // Returns an error if the retrieval went wrong. - EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) + EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) // Publish or unpublish a room from the room directory. PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error // Returns a list of room IDs for rooms which are published. @@ -179,36 +180,53 @@ type Database interface { GetMembershipForHistoryVisibility( ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, ) (map[string]*gomatrixserverlib.HeaderedEvent, error) - GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error) + GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) + MaybeRedactEvent( + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool, + ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) } type RoomDatabase interface { + EventDatabase // RoomInfo returns room information for the given room ID, or nil if there is no room. RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) // IsEventRejected returns true if the event is known and rejected. IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error) MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error) - // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. - StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) - StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error) StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) - SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) - StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) - Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) - EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) - EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) - EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error) + GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) +} + +type EventDatabase interface { + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error) + EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error) + SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) + Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) + // MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error + // (nil if there was nothing to do) + MaybeRedactEvent( + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool, + ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) + StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error) } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 87208438..d98a5cf9 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -194,23 +194,28 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room return err } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: writer, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - EventsTable: events, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, - Purge: purge, + DB: db, + EventDatabase: shared.EventDatabase{ + DB: db, + Cache: cache, + Writer: writer, + EventsTable: events, + EventJSONTable: eventJSON, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + PrevEventsTable: prevEvents, + RedactionsTable: redactions, + }, + Cache: cache, + Writer: writer, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + Purge: purge, } return nil } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 5006c3c5..dc1db082 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -116,8 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent }) } -func (u *RoomUpdater) Events(ctx context.Context, _ types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) { - return u.d.events(ctx, u.txn, u.roomInfo.RoomNID, eventNIDs) +func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { + return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs) } func (u *RoomUpdater) SnapshotNIDFromEventID( @@ -195,8 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs( return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) } -func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, roomNID, eventIDs, NoFilter) +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, u.roomInfo, eventIDs, NoFilter) } // IsReferenced implements types.RoomRecentEventsUpdater diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index aac5bc36..be3f228d 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -9,7 +9,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/caching" @@ -28,6 +27,23 @@ import ( const redactionsArePermanent = true type Database struct { + DB *sql.DB + EventDatabase + Cache caching.RoomServerCaches + Writer sqlutil.Writer + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + Purge tables.Purge + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) +} + +// EventDatabase contains all tables needed to work with events +type EventDatabase struct { DB *sql.DB Cache caching.RoomServerCaches Writer sqlutil.Writer @@ -35,17 +51,8 @@ type Database struct { EventJSONTable tables.EventJSON EventTypesTable tables.EventTypes EventStateKeysTable tables.EventStateKeys - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases PrevEventsTable tables.PreviousEvents - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published RedactionsTable tables.Redactions - Purge tables.Purge - GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } func (d *Database) SupportsConcurrentRoomInputs() bool { @@ -58,13 +65,13 @@ func (d *Database) GetMembershipForHistoryVisibility( return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...) } -func (d *Database) EventTypeNIDs( +func (d *EventDatabase) EventTypeNIDs( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { return d.eventTypeNIDs(ctx, nil, eventTypes) } -func (d *Database) eventTypeNIDs( +func (d *EventDatabase) eventTypeNIDs( ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) @@ -91,7 +98,7 @@ func (d *Database) eventTypeNIDs( return result, nil } -func (d *Database) EventStateKeys( +func (d *EventDatabase) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) @@ -116,13 +123,13 @@ func (d *Database) EventStateKeys( return result, nil } -func (d *Database) EventStateKeyNIDs( +func (d *EventDatabase) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { return d.eventStateKeyNIDs(ctx, nil, eventStateKeys) } -func (d *Database) eventStateKeyNIDs( +func (d *EventDatabase) eventStateKeyNIDs( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) @@ -174,7 +181,7 @@ func (d *Database) eventStateKeyNIDs( return result, nil } -func (d *Database) StateEntriesForEventIDs( +func (d *EventDatabase) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, excludeRejected bool, ) ([]types.StateEntry, error) { return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected) @@ -213,6 +220,17 @@ func (d *Database) stateEntriesForTuples( return lists, nil } +func (d *Database) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) { + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{roomNID}) + if err != nil { + return nil, err + } + if len(roomIDs) == 0 { + return nil, fmt.Errorf("room does not exist") + } + return d.roomInfo(ctx, nil, roomIDs[0]) +} + func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { return d.roomInfo(ctx, nil, roomID) } @@ -292,7 +310,7 @@ func (d *Database) addState( return } -func (d *Database) EventNIDs( +func (d *EventDatabase) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventMetadata, error) { return d.eventNIDs(ctx, nil, eventIDs, NoFilter) @@ -305,7 +323,7 @@ const ( FilterUnsentOnly UnsentFilter = true ) -func (d *Database) eventNIDs( +func (d *EventDatabase) eventNIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, ) (map[string]types.EventMetadata, error) { switch filter { @@ -318,7 +336,7 @@ func (d *Database) eventNIDs( } } -func (d *Database) SetState( +func (d *EventDatabase) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -326,19 +344,19 @@ func (d *Database) SetState( }) } -func (d *Database) StateAtEventIDs( +func (d *EventDatabase) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs) } -func (d *Database) SnapshotNIDFromEventID( +func (d *EventDatabase) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (types.StateSnapshotNID, error) { return d.snapshotNIDFromEventID(ctx, nil, eventID) } -func (d *Database) snapshotNIDFromEventID( +func (d *EventDatabase) snapshotNIDFromEventID( ctx context.Context, txn *sql.Tx, eventID string, ) (types.StateSnapshotNID, error) { _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) @@ -351,17 +369,17 @@ func (d *Database) snapshotNIDFromEventID( return stateNID, err } -func (d *Database) EventIDs( +func (d *EventDatabase) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) } -func (d *Database) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) { - return d.eventsFromIDs(ctx, nil, roomNID, eventIDs, NoFilter) +func (d *EventDatabase) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) { + return d.eventsFromIDs(ctx, nil, roomInfo, eventIDs, NoFilter) } -func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { +func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) if err != nil { return nil, err @@ -370,15 +388,9 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types var nids []types.EventNID for _, nid := range nidMap { nids = append(nids, nid.EventNID) - if roomNID != 0 && roomNID != nid.RoomNID { - logrus.Errorf("expected events from room %d, but also found %d", roomNID, nid.RoomNID) - } - if roomNID == 0 { - roomNID = nid.RoomNID - } } - return d.events(ctx, txn, roomNID, nids) + return d.events(ctx, txn, roomInfo, nids) } func (d *Database) LatestEventIDs( @@ -517,19 +529,17 @@ func (d *Database) GetInvitesForUser( return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) } -func (d *Database) Events( - ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID, -) ([]types.Event, error) { - return d.events(ctx, nil, roomNID, eventNIDs) +func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) { + return d.events(ctx, nil, roomInfo, eventNIDs) } -func (d *Database) events( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, inputEventNIDs types.EventNIDs, +func (d *EventDatabase) events( + ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs, ) ([]types.Event, error) { - if roomNID == 0 { - // No need to go further, as we won't find any events for this room. - return nil, nil + if roomInfo == nil { // this should never happen + return nil, fmt.Errorf("unable to parse events without roomInfo") } + sort.Sort(inputEventNIDs) events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs)) eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs)) @@ -566,31 +576,9 @@ func (d *Database) events( eventIDs = map[types.EventNID]string{} } - var roomVersion gomatrixserverlib.RoomVersion - var fetchRoomVersion bool - var ok bool - var roomID string - if roomID, ok = d.Cache.GetRoomServerRoomID(roomNID); ok { - roomVersion, ok = d.Cache.GetRoomVersion(roomID) - if !ok { - fetchRoomVersion = true - } - } - - if roomVersion == "" || fetchRoomVersion { - var dbRoomVersions map[types.RoomNID]gomatrixserverlib.RoomVersion - dbRoomVersions, err = d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, []types.RoomNID{roomNID}) - if err != nil { - return nil, err - } - if roomVersion, ok = dbRoomVersions[roomNID]; !ok { - return nil, fmt.Errorf("unable to find roomversion for room %d", roomNID) - } - } - for _, eventJSON := range eventJSONs { events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID( - eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion, + eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomInfo.RoomVersion, ) if err != nil { return nil, err @@ -660,8 +648,8 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID) } -// GetOrCreateRoomNID gets or creates a new roomNID for the given event -func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, err error) { +// GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID. +func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (roomInfo *types.RoomInfo, err error) { // Get the default room version. If the client doesn't supply a room_version // then we will use our configured default to create the room. // https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom @@ -670,8 +658,9 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver // room. var roomVersion gomatrixserverlib.RoomVersion if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return 0, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) + return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) } + var roomNID types.RoomNID err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) if err != nil { @@ -679,7 +668,10 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver } return nil }) - return roomNID, err + return &types.RoomInfo{ + RoomVersion: roomVersion, + RoomNID: roomNID, + }, err } func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) { @@ -710,25 +702,22 @@ func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKe return eventStateKeyNID, nil } -func (d *Database) StoreEvent( +func (d *EventDatabase) StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, - roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, + roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool, -) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { +) (types.EventNID, types.StateAtEvent, error) { var ( - eventNID types.EventNID - stateNID types.StateSnapshotNID - redactionEvent *gomatrixserverlib.Event - redactedEventID string - err error + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error ) - // Second writer is using the database-provided transaction, probably from the - // room updater, for easy roll-back if required. + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if eventNID, stateNID, err = d.EventsTable.InsertEvent( ctx, txn, - roomNID, + roomInfo.RoomNID, eventTypeNID, eventStateKeyNID, event.EventID(), @@ -751,16 +740,26 @@ func (d *Database) StoreEvent( if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } - if !isRejected { // ignore rejected redaction events - redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, roomNID, eventNID, event) - if err != nil { - return fmt.Errorf("d.handleRedactions: %w", err) + + if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { + // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of + // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This + // function only does SELECTs though so the created txn (at this point) is just a read txn like + // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater + // to do writes however then this will need to go inside `Writer.Do`. + + // The following is a copy of RoomUpdater.StorePreviousEvents + for _, ref := range prevEvents { + if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } } } + return nil }) if err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) + return 0, types.StateAtEvent{}, fmt.Errorf("d.Writer.Do: %w", err) } // We should attempt to update the previous events table with any @@ -768,33 +767,6 @@ func (d *Database) StoreEvent( // events updater because it somewhat works as a mutex, ensuring // that there's a row-level lock on the latest room events (well, // on Postgres at least). - if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { - // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of - // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This - // function only does SELECTs though so the created txn (at this point) is just a read txn like - // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater - // to do writes however then this will need to go inside `Writer.Do`. - succeeded := false - var roomInfo *types.RoomInfo - roomInfo, err = d.roomInfo(ctx, nil, event.RoomID()) - if err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) - } - if roomInfo == nil && len(prevEvents) > 0 { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) - } - var updater *RoomUpdater - updater, err = d.GetRoomUpdater(ctx, roomInfo) - if err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) - } - defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - - if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) - } - succeeded = true - } return eventNID, types.StateAtEvent{ BeforeStateSnapshotNID: stateNID, @@ -805,7 +777,7 @@ func (d *Database) StoreEvent( }, EventNID: eventNID, }, - }, redactionEvent, redactedEventID, err + }, err } func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error { @@ -893,7 +865,7 @@ func (d *Database) assignEventTypeNID( return eventTypeNID, nil } -func (d *Database) assignStateKeyNID( +func (d *EventDatabase) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey) @@ -937,7 +909,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( return roomVersion, err } -// handleRedactions manages the redacted status of events. There's two cases to consider in order to comply with the spec: +// MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec: // "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid." // https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events // These cases are: @@ -952,95 +924,95 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( // when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need // to cross-reference with other tables when loading. // -// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. -func (d *Database) handleRedactions( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, -) (*gomatrixserverlib.Event, string, error) { - var err error - isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil - if isRedactionEvent { - // an event which redacts itself should be ignored - if event.EventID() == event.Redacts() { - return nil, "", nil - } +// Returns the redaction event and the redacted event if this call resulted in a redaction. +func (d *EventDatabase) MaybeRedactEvent( + ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool, +) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) { + var ( + redactionEvent, redactedEvent *types.Event + err error + validated bool + ignoreRedaction bool + ) - err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{ - Validated: false, - RedactionEventID: event.EventID(), - RedactsEventID: event.Redacts(), - }) - if err != nil { - return nil, "", fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err) + wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil + if isRedactionEvent { + // an event which redacts itself should be ignored + if event.EventID() == event.Redacts() { + return nil + } + + err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{ + Validated: false, + RedactionEventID: event.EventID(), + RedactsEventID: event.Redacts(), + }) + if err != nil { + return fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err) + } } - } - redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, roomNID, eventNID, event) - if err != nil { - return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err) - } - if validated || redactedEvent == nil || redactionEvent == nil { - // we've seen this redaction before or there is nothing to redact - return nil, "", nil - } - if redactedEvent.RoomID() != redactionEvent.RoomID() { - // redactions across rooms aren't allowed - return nil, "", nil - } + redactionEvent, redactedEvent, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event) + switch { + case err != nil: + return fmt.Errorf("d.loadRedactionPair: %w", err) + case validated || redactedEvent == nil || redactionEvent == nil: + // we've seen this redaction before or there is nothing to redact + return nil + case redactedEvent.RoomID() != redactionEvent.RoomID(): + // redactions across rooms aren't allowed + ignoreRedaction = true + return nil + } - // Get the power level from the database, so we can verify the user is allowed to redact the event - powerLevels, err := d.GetStateEvent(ctx, event.RoomID(), gomatrixserverlib.MRoomPowerLevels, "") - if err != nil { - return nil, "", fmt.Errorf("d.GetStateEvent: %w", err) - } - if powerLevels == nil { - return nil, "", fmt.Errorf("unable to fetch m.room.power_levels event from database for room %s", event.RoomID()) - } - pl, err := powerLevels.PowerLevels() - if err != nil { - return nil, "", fmt.Errorf("unable to get powerlevels for room: %w", err) - } + // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. (redactAllowed) + // 2. The domain of the redaction event’s sender matches that of the original event’s sender. + _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender()) + _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender()) + if !redactAllowed || sender1 != sender2 { + ignoreRedaction = true + return nil + } - redactUser := pl.UserLevel(redactionEvent.Sender()) - switch { - case redactUser >= pl.Redact: - // The power level of the redaction event’s sender is greater than or equal to the redact level. - case redactedEvent.Sender() == redactionEvent.Sender(): - // The domain of the redaction event’s sender matches that of the original event’s sender. - default: - return nil, "", nil - } + // mark the event as redacted + if redactionsArePermanent { + redactedEvent.Redact() + } - // mark the event as redacted - if redactionsArePermanent { - redactedEvent.Redact() - } + err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) + if err != nil { + return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) + } + // NOTSPEC: sytest relies on this unspecced field existing :( + err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID()) + if err != nil { + return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) + } + // overwrite the eventJSON table + err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) + if err != nil { + return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) + } - err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent) - if err != nil { - return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) - } - // NOTSPEC: sytest relies on this unspecced field existing :( - err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID()) - if err != nil { - return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err) - } - // overwrite the eventJSON table - err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON()) - if err != nil { - return nil, "", fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) + err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) + if err != nil { + return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) + } + return nil + }) + if wErr != nil { + return nil, nil, err } - - err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true) - if err != nil { - err = fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) + if ignoreRedaction || redactionEvent == nil || redactedEvent == nil { + return nil, nil, nil } - - return redactionEvent.Event, redactedEvent.EventID(), err + return redactionEvent.Event, redactedEvent.Event, nil } // loadRedactionPair returns both the redaction event and the redacted event, else nil. -func (d *Database) loadRedactionPair( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event, +func (d *EventDatabase) loadRedactionPair( + ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, ) (*types.Event, *types.Event, bool, error) { var redactionEvent, redactedEvent *types.Event var info *tables.RedactionInfo @@ -1072,16 +1044,16 @@ func (d *Database) loadRedactionPair( } if isRedactionEvent { - redactedEvent = d.loadEvent(ctx, roomNID, info.RedactsEventID) + redactedEvent = d.loadEvent(ctx, roomInfo, info.RedactsEventID) } else { - redactionEvent = d.loadEvent(ctx, roomNID, info.RedactionEventID) + redactionEvent = d.loadEvent(ctx, roomInfo, info.RedactionEventID) } return redactionEvent, redactedEvent, info.Validated, nil } // applyRedactions will redact events that have an `unsigned.redacted_because` field. -func (d *Database) applyRedactions(events []types.Event) { +func (d *EventDatabase) applyRedactions(events []types.Event) { for i := range events { if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() { events[i].Redact() @@ -1090,7 +1062,7 @@ func (d *Database) applyRedactions(events []types.Event) { } // loadEvent loads a single event or returns nil on any problems/missing event -func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID string) *types.Event { +func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, eventID string) *types.Event { nids, err := d.EventNIDs(ctx, []string{eventID}) if err != nil { return nil @@ -1098,7 +1070,7 @@ func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID if len(nids) == 0 { return nil } - evs, err := d.Events(ctx, roomNID, []types.EventNID{nids[eventID].EventNID}) + evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID}) if err != nil { return nil } @@ -1144,7 +1116,7 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { - roomInfo, err := d.RoomInfo(ctx, roomID) + roomInfo, err := d.roomInfo(ctx, nil, roomID) if err != nil { return nil, err } @@ -1209,7 +1181,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s // Same as GetStateEvent but returns all matching state events with this event type. Returns no error // if there are no events with this event type. func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) { - roomInfo, err := d.RoomInfo(ctx, roomID) + roomInfo, err := d.roomInfo(ctx, nil, roomID) if err != nil { return nil, err } @@ -1340,7 +1312,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion) // TODO: This feels like this is going to be really slow... for _, roomID := range roomIDs { - roomInfo, err2 := d.RoomInfo(ctx, roomID) + roomInfo, err2 := d.roomInfo(ctx, nil, roomID) if err2 != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2) } diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 3acb55a3..684e80b8 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -52,12 +52,14 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) + evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache} + return &shared.Database{ - DB: db, - EventStateKeysTable: stateKeyTable, - MembershipTable: membershipTable, - Writer: sqlutil.NewExclusiveWriter(), - Cache: cache, + DB: db, + EventDatabase: evDb, + MembershipTable: membershipTable, + Writer: sqlutil.NewExclusiveWriter(), + Cache: cache, }, func() { err := base.Close() assert.NoError(t, err) diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 392edd28..2adedd2d 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -203,24 +203,29 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: writer, - EventsTable: events, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, - GetRoomUpdaterFn: d.GetRoomUpdater, - Purge: purge, + DB: db, + EventDatabase: shared.EventDatabase{ + DB: db, + Cache: cache, + Writer: writer, + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + PrevEventsTable: prevEvents, + RedactionsTable: redactions, + }, + Cache: cache, + Writer: writer, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + GetRoomUpdaterFn: d.GetRoomUpdater, + Purge: purge, } return nil } diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index bc369c16..4bb6a5ee 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -253,7 +253,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo var res MSC2836EventRelationshipsResponse var returnEvents []*gomatrixserverlib.HeaderedEvent // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. - event := rc.getLocalEvent(rc.req.EventID) + event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID) if event == nil { event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID) } @@ -592,7 +592,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelation // lookForEvent returns the event for the event ID given, by trying to query remote servers // if the event ID is unknown via /event_relationships. func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent { - event := rc.getLocalEvent(eventID) + event := rc.getLocalEvent(rc.req.RoomID, eventID) if event == nil { queryRes := rc.remoteEventRelationships(eventID) if queryRes != nil { @@ -622,9 +622,10 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent return nil } -func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent { +func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.HeaderedEvent { var queryEventsRes roomserver.QueryEventsByIDResponse err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{ + RoomID: roomID, EventIDs: []string{eventID}, }, &queryEventsRes) if err != nil { diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 21838039..a8d4d2b2 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -212,6 +212,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( // Finally, work out if there are any more events missing. if len(missingEventIDs) > 0 { eventsReq := &api.QueryEventsByIDRequest{ + RoomID: ev.RoomID(), EventIDs: missingEventIDs, } eventsRes := &api.QueryEventsByIDResponse{} diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 9ffdf513..8efd77ce 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -109,7 +109,7 @@ func GetMemberships( } qryRes := &api.QueryEventsByIDResponse{} - if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs}, qryRes); err != nil { + if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil { util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed") return jsonerror.InternalServerError() } |