diff options
author | Kegsay <kegan@matrix.org> | 2020-09-02 17:13:15 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-02 17:13:15 +0100 |
commit | 9d9e854fe042cd2c83cf694d6b3e4c8e7046cde1 (patch) | |
tree | 75f9247df1d00b140c4249ef14b711b2839806bd /roomserver/internal/input | |
parent | f06637435b2124c89dfdd96cd723f54cc7055602 (diff) |
Add Queryer and Inputer and factor out more RSAPI stuff (#1382)
* Add Queryer and use embedded structs
* Add Inputer and factor out more RS API stuff
This neatly splits up the RS API based on the functionality it provides,
whilst providing a useful place for code sharing via the `helpers` package.
Diffstat (limited to 'roomserver/internal/input')
-rw-r--r-- | roomserver/internal/input/input.go | 91 | ||||
-rw-r--r-- | roomserver/internal/input/input_events.go | 185 | ||||
-rw-r--r-- | roomserver/internal/input/input_latest_events.go | 390 | ||||
-rw-r--r-- | roomserver/internal/input/input_membership.go | 267 |
4 files changed, 933 insertions, 0 deletions
diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go new file mode 100644 index 00000000..87bdc5db --- /dev/null +++ b/roomserver/internal/input/input.go @@ -0,0 +1,91 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package input contains the code processes new room events +package input + +import ( + "context" + "encoding/json" + "sync" + + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type Inputer struct { + DB storage.Database + Producer sarama.SyncProducer + ServerName gomatrixserverlib.ServerName + OutputRoomEventTopic string + + mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent +} + +// WriteOutputEvents implements OutputRoomEventWriter +func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) error { + messages := make([]*sarama.ProducerMessage, len(updates)) + for i := range updates { + value, err := json.Marshal(updates[i]) + if err != nil { + return err + } + logger := log.WithFields(log.Fields{ + "room_id": roomID, + "type": updates[i].Type, + }) + if updates[i].NewRoomEvent != nil { + logger = logger.WithFields(log.Fields{ + "event_type": updates[i].NewRoomEvent.Event.Type(), + "event_id": updates[i].NewRoomEvent.Event.EventID(), + "adds_state": len(updates[i].NewRoomEvent.AddsStateEventIDs), + "removes_state": len(updates[i].NewRoomEvent.RemovesStateEventIDs), + "send_as_server": updates[i].NewRoomEvent.SendAsServer, + "sender": updates[i].NewRoomEvent.Event.Sender(), + }) + } + logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic) + messages[i] = &sarama.ProducerMessage{ + Topic: r.OutputRoomEventTopic, + Key: sarama.StringEncoder(roomID), + Value: sarama.ByteEncoder(value), + } + } + return r.Producer.SendMessages(messages) +} + +// InputRoomEvents implements api.RoomserverInternalAPI +func (r *Inputer) InputRoomEvents( + ctx context.Context, + request *api.InputRoomEventsRequest, + response *api.InputRoomEventsResponse, +) (err error) { + for i, e := range request.InputRoomEvents { + roomID := "global" + if r.DB.SupportsConcurrentRoomInputs() { + roomID = e.Event.RoomID() + } + mutex, _ := r.mutexes.LoadOrStore(roomID, &sync.Mutex{}) + mutex.(*sync.Mutex).Lock() + if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil { + mutex.(*sync.Mutex).Unlock() + return err + } + mutex.(*sync.Mutex).Unlock() + } + return nil +} diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go new file mode 100644 index 00000000..69f51f4b --- /dev/null +++ b/roomserver/internal/input/input_events.go @@ -0,0 +1,185 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package input + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// processRoomEvent can only be called once at a time +// +// TODO(#375): This should be rewritten to allow concurrent calls. The +// difficulty is in ensuring that we correctly annotate events with the correct +// state deltas when sending to kafka streams +// TODO: Break up function - we should probably do transaction ID checks before calling this. +// nolint:gocyclo +func (r *Inputer) processRoomEvent( + ctx context.Context, + input api.InputRoomEvent, +) (eventID string, err error) { + // Parse and validate the event JSON + headered := input.Event + event := headered.Unwrap() + + // Check that the event passes authentication checks and work out + // the numeric IDs for the auth events. + authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) + if err != nil { + logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") + return + } + + // If we don't have a transaction ID then get one. + if input.TransactionID != nil { + tdID := input.TransactionID + eventID, err = r.DB.GetTransactionEventID( + ctx, tdID.TransactionID, tdID.SessionID, event.Sender(), + ) + // On error OR event with the transaction already processed/processesing + if err != nil || eventID != "" { + return + } + } + + // Store the event. + _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + if err != nil { + return "", fmt.Errorf("r.DB.StoreEvent: %w", err) + } + // if storing this event results in it being redacted then do so. + if redactedEventID == event.EventID() { + r, rerr := eventutil.RedactEvent(redactionEvent, &event) + if rerr != nil { + return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr) + } + event = *r + } + + // For outliers we can stop after we've stored the event itself as it + // doesn't have any associated state to store and we don't need to + // notify anyone about it. + if input.Kind == api.KindOutlier { + logrus.WithFields(logrus.Fields{ + "event_id": event.EventID(), + "type": event.Type(), + "room": event.RoomID(), + }).Info("Stored outlier") + return event.EventID(), nil + } + + roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) + if err != nil { + return "", fmt.Errorf("r.DB.RoomInfo: %w", err) + } + if roomInfo == nil { + return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) + } + + if stateAtEvent.BeforeStateSnapshotNID == 0 { + // We haven't calculated a state for this event yet. + // Lets calculate one. + err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event) + if err != nil { + return "", fmt.Errorf("r.calculateAndSetState: %w", err) + } + } + + if err = r.updateLatestEvents( + ctx, // context + roomInfo, // room info for the room being updated + stateAtEvent, // state at event (below) + event, // event + input.SendAsServer, // send as server + input.TransactionID, // transaction ID + ); err != nil { + return "", fmt.Errorf("r.updateLatestEvents: %w", err) + } + + // processing this event resulted in an event (which may not be the one we're processing) + // being redacted. We are guaranteed to have both sides (the redaction/redacted event), + // so notify downstream components to redact this event - they should have it if they've + // been tracking our output log. + if redactedEventID != "" { + err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{ + { + Type: api.OutputTypeRedactedEvent, + RedactedEvent: &api.OutputRedactedEvent{ + RedactedEventID: redactedEventID, + RedactedBecause: redactionEvent.Headered(headered.RoomVersion), + }, + }, + }) + if err != nil { + return "", fmt.Errorf("r.WriteOutputEvents: %w", err) + } + } + + // Update the extremities of the event graph for the room + return event.EventID(), nil +} + +func (r *Inputer) calculateAndSetState( + ctx context.Context, + input api.InputRoomEvent, + roomInfo types.RoomInfo, + stateAtEvent *types.StateAtEvent, + event gomatrixserverlib.Event, +) error { + var err error + roomState := state.NewStateResolution(r.DB, roomInfo) + + if input.HasState { + // Check here if we think we're in the room already. + stateAtEvent.Overwrite = true + var joinEventNIDs []types.EventNID + // Request join memberships only for local users only. + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { + // If we have no local users that are joined to the room then any state about + // the room that we have is quite possibly out of date. Therefore in that case + // we should overwrite it rather than merge it. + stateAtEvent.Overwrite = len(joinEventNIDs) == 0 + } + + // We've been told what the state at the event is so we don't need to calculate it. + // Check that those state events are in the database and store the state. + var entries []types.StateEntry + if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + return err + } + + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { + return err + } + } else { + stateAtEvent.Overwrite = false + + // We haven't been told what the state at the event is so we need to calculate it from the prev_events + if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { + return err + } + } + return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) +} diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go new file mode 100644 index 00000000..67a7d8a4 --- /dev/null +++ b/roomserver/internal/input/input_latest_events.go @@ -0,0 +1,390 @@ +// Copyright 2017 Vector Creations Ltd +// Copyright 2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package input + +import ( + "bytes" + "context" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// updateLatestEvents updates the list of latest events for this room in the database and writes the +// event to the output log. +// The latest events are the events that aren't referenced by another event in the database: +// +// Time goes down the page. 1 is the m.room.create event (root). +// +// 1 After storing 1 the latest events are {1} +// | After storing 2 the latest events are {2} +// 2 After storing 3 the latest events are {3} +// / \ After storing 4 the latest events are {3,4} +// 3 4 After storing 5 the latest events are {5,4} +// | | After storing 6 the latest events are {5,6} +// 5 6 <--- latest After storing 7 the latest events are {6,7} +// | +// 7 <----- latest +// +// Can only be called once at a time +func (r *Inputer) updateLatestEvents( + ctx context.Context, + roomInfo *types.RoomInfo, + stateAtEvent types.StateAtEvent, + event gomatrixserverlib.Event, + sendAsServer string, + transactionID *api.TransactionID, +) (err error) { + updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) + if err != nil { + return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) + } + succeeded := false + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) + + u := latestEventsUpdater{ + ctx: ctx, + api: r, + updater: updater, + roomInfo: roomInfo, + stateAtEvent: stateAtEvent, + event: event, + sendAsServer: sendAsServer, + transactionID: transactionID, + } + + if err = u.doUpdateLatestEvents(); err != nil { + return fmt.Errorf("u.doUpdateLatestEvents: %w", err) + } + + succeeded = true + return +} + +// latestEventsUpdater tracks the state used to update the latest events in the +// room. It mostly just ferries state between the various function calls. +// The state could be passed using function arguments, but it becomes impractical +// when there are so many variables to pass around. +type latestEventsUpdater struct { + ctx context.Context + api *Inputer + updater *shared.LatestEventsUpdater + roomInfo *types.RoomInfo + stateAtEvent types.StateAtEvent + event gomatrixserverlib.Event + transactionID *api.TransactionID + // Which server to send this event as. + sendAsServer string + // The eventID of the event that was processed before this one. + lastEventIDSent string + // The latest events in the room after processing this event. + latest []types.StateAtEventAndReference + // The state entries removed from and added to the current state of the + // room as a result of processing this event. They are sorted lists. + removed []types.StateEntry + added []types.StateEntry + // The state entries that are removed and added to recover the state before + // the event being processed. They are sorted lists. + stateBeforeEventRemoves []types.StateEntry + stateBeforeEventAdds []types.StateEntry + // The snapshots of current state before and after processing this event + oldStateNID types.StateSnapshotNID + newStateNID types.StateSnapshotNID +} + +func (u *latestEventsUpdater) doUpdateLatestEvents() error { + prevEvents := u.event.PrevEvents() + u.lastEventIDSent = u.updater.LastEventIDSent() + u.oldStateNID = u.updater.CurrentStateSnapshotNID() + + // If we are doing a regular event update then we will get the + // previous latest events to use as a part of the calculation. If + // we are overwriting the latest events because we have a complete + // state snapshot from somewhere else, e.g. a federated room join, + // then start with an empty set - none of the forward extremities + // that we knew about before matter anymore. + oldLatest := []types.StateAtEventAndReference{} + if !u.stateAtEvent.Overwrite { + oldLatest = u.updater.LatestEvents() + } + + // If the event has already been written to the output log then we + // don't need to do anything, as we've handled it already. + hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID) + if err != nil { + return fmt.Errorf("u.updater.HasEventBeenSent: %w", err) + } else if hasBeenSent { + return nil + } + + // Update the roomserver_previous_events table with references. This + // is effectively tracking the structure of the DAG. + if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil { + return fmt.Errorf("u.updater.StorePreviousEvents: %w", err) + } + + // Get the event reference for our new event. This will be used when + // determining if the event is referenced by an existing event. + eventReference := u.event.EventReference() + + // Check if our new event is already referenced by an existing event + // in the room. If it is then it isn't a latest event. + alreadyReferenced, err := u.updater.IsReferenced(eventReference) + if err != nil { + return fmt.Errorf("u.updater.IsReferenced: %w", err) + } + + // Work out what the latest events are. + u.latest = calculateLatest( + oldLatest, + alreadyReferenced, + prevEvents, + types.StateAtEventAndReference{ + EventReference: eventReference, + StateAtEvent: u.stateAtEvent, + }, + ) + + // Now that we know what the latest events are, it's time to get the + // latest state. + if err = u.latestState(); err != nil { + return fmt.Errorf("u.latestState: %w", err) + } + + // If we need to generate any output events then here's where we do it. + // TODO: Move this! + updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added) + if err != nil { + return fmt.Errorf("u.api.updateMemberships: %w", err) + } + + update, err := u.makeOutputNewRoomEvent() + if err != nil { + return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) + } + updates = append(updates, *update) + + // Send the event to the output logs. + // We do this inside the database transaction to ensure that we only mark an event as sent if we sent it. + // (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but + // the write to the output log succeeds) + // TODO: This assumes that writing the event to the output log is synchronous. It should be possible to + // send the event asynchronously but we would need to ensure that 1) the events are written to the log in + // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the + // necessary bookkeeping we'll keep the event sending synchronous for now. + if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { + return fmt.Errorf("u.api.WriteOutputEvents: %w", err) + } + + if err = u.updater.SetLatestEvents(u.roomInfo.RoomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { + return fmt.Errorf("u.updater.SetLatestEvents: %w", err) + } + + if err = u.updater.MarkEventAsSent(u.stateAtEvent.EventNID); err != nil { + return fmt.Errorf("u.updater.MarkEventAsSent: %w", err) + } + + return nil +} + +func (u *latestEventsUpdater) latestState() error { + var err error + roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) + + // Get a list of the current latest events. + latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) + for i := range u.latest { + latestStateAtEvents[i] = u.latest[i].StateAtEvent + } + + // Takes the NIDs of the latest events and creates a state snapshot + // of the state after the events. The snapshot state will be resolved + // using the correct state resolution algorithm for the room. + u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( + u.ctx, latestStateAtEvents, + ) + if err != nil { + return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) + } + + // If we are overwriting the state then we should make sure that we + // don't send anything out over federation again, it will very likely + // be a repeat. + if u.stateAtEvent.Overwrite { + u.sendAsServer = "" + } + + // Now that we have a new state snapshot based on the latest events, + // we can compare that new snapshot to the previous one and see what + // has changed. This gives us one list of removed state events and + // another list of added ones. Replacing a value for a state-key tuple + // will result one removed (the old event) and one added (the new event). + u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots( + u.ctx, u.oldStateNID, u.newStateNID, + ) + if err != nil { + return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err) + } + + // Also work out the state before the event removes and the event + // adds. + u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots( + u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, + ) + if err != nil { + return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err) + } + + return nil +} + +func calculateLatest( + oldLatest []types.StateAtEventAndReference, + alreadyReferenced bool, + prevEvents []gomatrixserverlib.EventReference, + newEvent types.StateAtEventAndReference, +) []types.StateAtEventAndReference { + var alreadyInLatest bool + var newLatest []types.StateAtEventAndReference + for _, l := range oldLatest { + keep := true + for _, prevEvent := range prevEvents { + if l.EventID == prevEvent.EventID && bytes.Equal(l.EventSHA256, prevEvent.EventSHA256) { + // This event can be removed from the latest events cause we've found an event that references it. + // (If an event is referenced by another event then it can't be one of the latest events in the room + // because we have an event that comes after it) + keep = false + break + } + } + if l.EventNID == newEvent.EventNID { + alreadyInLatest = true + } + if keep { + // Keep the event in the latest events. + newLatest = append(newLatest, l) + } + } + + if !alreadyReferenced && !alreadyInLatest { + // This event is not referenced by any of the events in the room + // and the event is not already in the latest events. + // Add it to the latest events + newLatest = append(newLatest, newEvent) + } + + return newLatest +} + +func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) { + + latestEventIDs := make([]string, len(u.latest)) + for i := range u.latest { + latestEventIDs[i] = u.latest[i].EventID + } + + ore := api.OutputNewRoomEvent{ + Event: u.event.Headered(u.roomInfo.RoomVersion), + LastSentEventID: u.lastEventIDSent, + LatestEventIDs: latestEventIDs, + TransactionID: u.transactionID, + } + + eventIDMap, err := u.stateEventMap() + if err != nil { + return nil, err + } + + for _, entry := range u.added { + ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID]) + } + for _, entry := range u.removed { + ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID]) + } + for _, entry := range u.stateBeforeEventRemoves { + ore.StateBeforeRemovesEventIDs = append(ore.StateBeforeRemovesEventIDs, eventIDMap[entry.EventNID]) + } + for _, entry := range u.stateBeforeEventAdds { + ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID]) + } + ore.SendAsServer = u.sendAsServer + + // include extra state events if they were added as nearly every downstream component will care about it + // and we'd rather not have them all hit QueryEventsByID at the same time! + if len(ore.AddsStateEventIDs) > 0 { + ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs) + if err != nil { + return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) + } + } + + return &api.OutputEvent{ + Type: api.OutputTypeNewRoomEvent, + NewRoomEvent: &ore, + }, nil +} + +// extraEventsForIDs returns the full events for the event IDs given, but does not include the current event being +// updated. +func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { + var extraEventIDs []string + for _, e := range eventIDs { + if e == u.event.EventID() { + continue + } + extraEventIDs = append(extraEventIDs, e) + } + if len(extraEventIDs) == 0 { + return nil, nil + } + extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs) + if err != nil { + return nil, err + } + var h []gomatrixserverlib.HeaderedEvent + for _, e := range extraEvents { + h = append(h, e.Headered(roomVersion)) + } + return h, nil +} + +// retrieve an event nid -> event ID map for all events that need updating +func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) { + var stateEventNIDs []types.EventNID + var allStateEntries []types.StateEntry + allStateEntries = append(allStateEntries, u.added...) + allStateEntries = append(allStateEntries, u.removed...) + allStateEntries = append(allStateEntries, u.stateBeforeEventRemoves...) + allStateEntries = append(allStateEntries, u.stateBeforeEventAdds...) + for _, entry := range allStateEntries { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] + return u.api.DB.EventIDs(u.ctx, stateEventNIDs) +} + +type eventNIDSorter []types.EventNID + +func (s eventNIDSorter) Len() int { return len(s) } +func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go new file mode 100644 index 00000000..8befcd64 --- /dev/null +++ b/roomserver/internal/input/input_membership.go @@ -0,0 +1,267 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package input + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// updateMembership updates the current membership and the invites for each +// user affected by a change in the current state of the room. +// Returns a list of output events to write to the kafka log to inform the +// consumers about the invites added or retired by the change in current state. +func (r *Inputer) updateMemberships( + ctx context.Context, + updater *shared.LatestEventsUpdater, + removed, added []types.StateEntry, +) ([]api.OutputEvent, error) { + changes := membershipChanges(removed, added) + var eventNIDs []types.EventNID + for _, change := range changes { + if change.addedEventNID != 0 { + eventNIDs = append(eventNIDs, change.addedEventNID) + } + if change.removedEventNID != 0 { + eventNIDs = append(eventNIDs, change.removedEventNID) + } + } + + // Load the event JSON so we can look up the "membership" key. + // TODO: Maybe add a membership key to the events table so we can load that + // key without having to load the entire event JSON? + events, err := r.DB.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + + var updates []api.OutputEvent + + for _, change := range changes { + var ae *gomatrixserverlib.Event + var re *gomatrixserverlib.Event + targetUserNID := change.EventStateKeyNID + if change.removedEventNID != 0 { + ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID) + if ev != nil { + re = &ev.Event + } + } + if change.addedEventNID != 0 { + ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID) + if ev != nil { + ae = &ev.Event + } + } + if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { + return nil, err + } + } + return updates, nil +} + +func (r *Inputer) updateMembership( + updater *shared.LatestEventsUpdater, + targetUserNID types.EventStateKeyNID, + remove, add *gomatrixserverlib.Event, + updates []api.OutputEvent, +) ([]api.OutputEvent, error) { + var err error + // Default the membership to Leave if no event was added or removed. + oldMembership := gomatrixserverlib.Leave + newMembership := gomatrixserverlib.Leave + + if remove != nil { + oldMembership, err = remove.Membership() + if err != nil { + return nil, err + } + } + if add != nil { + newMembership, err = add.Membership() + if err != nil { + return nil, err + } + } + if oldMembership == newMembership && newMembership != gomatrixserverlib.Join { + // If the membership is the same then nothing changed and we can return + // immediately, unless it's a Join update (e.g. profile update). + return updates, nil + } + + if add == nil { + // This can happen when we have rejoined a room and suddenly we have a + // divergence between the former state and the new one. We don't want to + // act on removals and apparently there are no adds, so stop here. + return updates, nil + } + + mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add)) + if err != nil { + return nil, err + } + + switch newMembership { + case gomatrixserverlib.Invite: + return helpers.UpdateToInviteMembership(mu, add, updates, updater.RoomVersion()) + case gomatrixserverlib.Join: + return updateToJoinMembership(mu, add, updates) + case gomatrixserverlib.Leave, gomatrixserverlib.Ban: + return updateToLeaveMembership(mu, add, newMembership, updates) + default: + panic(fmt.Errorf( + "input: membership %q is not one of the allowed values", newMembership, + )) + } +} + +func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool { + isTargetLocalUser := false + if statekey := event.StateKey(); statekey != nil { + _, domain, _ := gomatrixserverlib.SplitID('@', *statekey) + isTargetLocalUser = domain == r.ServerName + } + return isTargetLocalUser +} + +func updateToJoinMembership( + mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, +) ([]api.OutputEvent, error) { + // If the user is already marked as being joined, we call SetToJoin to update + // the event ID then we can return immediately. Retired is ignored as there + // is no invite event to retire. + if mu.IsJoin() { + _, err := mu.SetToJoin(add.Sender(), add.EventID(), true) + if err != nil { + return nil, err + } + return updates, nil + } + // When we mark a user as being joined we will invalidate any invites that + // are active for that user. We notify the consumers that the invites have + // been retired using a special event, even though they could infer this + // by studying the state changes in the room event stream. + retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false) + if err != nil { + return nil, err + } + for _, eventID := range retired { + orie := api.OutputRetireInviteEvent{ + EventID: eventID, + Membership: gomatrixserverlib.Join, + RetiredByEventID: add.EventID(), + TargetUserID: *add.StateKey(), + } + updates = append(updates, api.OutputEvent{ + Type: api.OutputTypeRetireInviteEvent, + RetireInviteEvent: &orie, + }) + } + return updates, nil +} + +func updateToLeaveMembership( + mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, + newMembership string, updates []api.OutputEvent, +) ([]api.OutputEvent, error) { + // If the user is already neither joined, nor invited to the room then we + // can return immediately. + if mu.IsLeave() { + return updates, nil + } + // When we mark a user as having left we will invalidate any invites that + // are active for that user. We notify the consumers that the invites have + // been retired using a special event, even though they could infer this + // by studying the state changes in the room event stream. + retired, err := mu.SetToLeave(add.Sender(), add.EventID()) + if err != nil { + return nil, err + } + for _, eventID := range retired { + orie := api.OutputRetireInviteEvent{ + EventID: eventID, + Membership: newMembership, + RetiredByEventID: add.EventID(), + TargetUserID: *add.StateKey(), + } + updates = append(updates, api.OutputEvent{ + Type: api.OutputTypeRetireInviteEvent, + RetireInviteEvent: &orie, + }) + } + return updates, nil +} + +// membershipChanges pairs up the membership state changes. +func membershipChanges(removed, added []types.StateEntry) []stateChange { + changes := pairUpChanges(removed, added) + var result []stateChange + for _, c := range changes { + if c.EventTypeNID == types.MRoomMemberNID { + result = append(result, c) + } + } + return result +} + +type stateChange struct { + types.StateKeyTuple + removedEventNID types.EventNID + addedEventNID types.EventNID +} + +// pairUpChanges pairs up the state events added and removed for each type, +// state key tuple. +func pairUpChanges(removed, added []types.StateEntry) []stateChange { + tuples := make(map[types.StateKeyTuple]stateChange) + changes := []stateChange{} + + // First, go through the newly added state entries. + for _, add := range added { + if change, ok := tuples[add.StateKeyTuple]; ok { + // If we already have an entry, update it. + change.addedEventNID = add.EventNID + tuples[add.StateKeyTuple] = change + } else { + // Otherwise, create a new entry. + tuples[add.StateKeyTuple] = stateChange{add.StateKeyTuple, 0, add.EventNID} + } + } + + // Now go through the removed state entries. + for _, remove := range removed { + if change, ok := tuples[remove.StateKeyTuple]; ok { + // If we already have an entry, update it. + change.removedEventNID = remove.EventNID + tuples[remove.StateKeyTuple] = change + } else { + // Otherwise, create a new entry. + tuples[remove.StateKeyTuple] = stateChange{remove.StateKeyTuple, remove.EventNID, 0} + } + } + + // Now return the changes as an array. + for _, change := range tuples { + changes = append(changes, change) + } + + return changes +} |