aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/input
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-09-02 17:13:15 +0100
committerGitHub <noreply@github.com>2020-09-02 17:13:15 +0100
commit9d9e854fe042cd2c83cf694d6b3e4c8e7046cde1 (patch)
tree75f9247df1d00b140c4249ef14b711b2839806bd /roomserver/internal/input
parentf06637435b2124c89dfdd96cd723f54cc7055602 (diff)
Add Queryer and Inputer and factor out more RSAPI stuff (#1382)
* Add Queryer and use embedded structs * Add Inputer and factor out more RS API stuff This neatly splits up the RS API based on the functionality it provides, whilst providing a useful place for code sharing via the `helpers` package.
Diffstat (limited to 'roomserver/internal/input')
-rw-r--r--roomserver/internal/input/input.go91
-rw-r--r--roomserver/internal/input/input_events.go185
-rw-r--r--roomserver/internal/input/input_latest_events.go390
-rw-r--r--roomserver/internal/input/input_membership.go267
4 files changed, 933 insertions, 0 deletions
diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go
new file mode 100644
index 00000000..87bdc5db
--- /dev/null
+++ b/roomserver/internal/input/input.go
@@ -0,0 +1,91 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package input contains the code processes new room events
+package input
+
+import (
+ "context"
+ "encoding/json"
+ "sync"
+
+ "github.com/Shopify/sarama"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+type Inputer struct {
+ DB storage.Database
+ Producer sarama.SyncProducer
+ ServerName gomatrixserverlib.ServerName
+ OutputRoomEventTopic string
+
+ mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent
+}
+
+// WriteOutputEvents implements OutputRoomEventWriter
+func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) error {
+ messages := make([]*sarama.ProducerMessage, len(updates))
+ for i := range updates {
+ value, err := json.Marshal(updates[i])
+ if err != nil {
+ return err
+ }
+ logger := log.WithFields(log.Fields{
+ "room_id": roomID,
+ "type": updates[i].Type,
+ })
+ if updates[i].NewRoomEvent != nil {
+ logger = logger.WithFields(log.Fields{
+ "event_type": updates[i].NewRoomEvent.Event.Type(),
+ "event_id": updates[i].NewRoomEvent.Event.EventID(),
+ "adds_state": len(updates[i].NewRoomEvent.AddsStateEventIDs),
+ "removes_state": len(updates[i].NewRoomEvent.RemovesStateEventIDs),
+ "send_as_server": updates[i].NewRoomEvent.SendAsServer,
+ "sender": updates[i].NewRoomEvent.Event.Sender(),
+ })
+ }
+ logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic)
+ messages[i] = &sarama.ProducerMessage{
+ Topic: r.OutputRoomEventTopic,
+ Key: sarama.StringEncoder(roomID),
+ Value: sarama.ByteEncoder(value),
+ }
+ }
+ return r.Producer.SendMessages(messages)
+}
+
+// InputRoomEvents implements api.RoomserverInternalAPI
+func (r *Inputer) InputRoomEvents(
+ ctx context.Context,
+ request *api.InputRoomEventsRequest,
+ response *api.InputRoomEventsResponse,
+) (err error) {
+ for i, e := range request.InputRoomEvents {
+ roomID := "global"
+ if r.DB.SupportsConcurrentRoomInputs() {
+ roomID = e.Event.RoomID()
+ }
+ mutex, _ := r.mutexes.LoadOrStore(roomID, &sync.Mutex{})
+ mutex.(*sync.Mutex).Lock()
+ if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil {
+ mutex.(*sync.Mutex).Unlock()
+ return err
+ }
+ mutex.(*sync.Mutex).Unlock()
+ }
+ return nil
+}
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
new file mode 100644
index 00000000..69f51f4b
--- /dev/null
+++ b/roomserver/internal/input/input_events.go
@@ -0,0 +1,185 @@
+// Copyright 2017 Vector Creations Ltd
+// Copyright 2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package input
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
+ "github.com/matrix-org/dendrite/roomserver/state"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+// processRoomEvent can only be called once at a time
+//
+// TODO(#375): This should be rewritten to allow concurrent calls. The
+// difficulty is in ensuring that we correctly annotate events with the correct
+// state deltas when sending to kafka streams
+// TODO: Break up function - we should probably do transaction ID checks before calling this.
+// nolint:gocyclo
+func (r *Inputer) processRoomEvent(
+ ctx context.Context,
+ input api.InputRoomEvent,
+) (eventID string, err error) {
+ // Parse and validate the event JSON
+ headered := input.Event
+ event := headered.Unwrap()
+
+ // Check that the event passes authentication checks and work out
+ // the numeric IDs for the auth events.
+ authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
+ if err != nil {
+ logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event")
+ return
+ }
+
+ // If we don't have a transaction ID then get one.
+ if input.TransactionID != nil {
+ tdID := input.TransactionID
+ eventID, err = r.DB.GetTransactionEventID(
+ ctx, tdID.TransactionID, tdID.SessionID, event.Sender(),
+ )
+ // On error OR event with the transaction already processed/processesing
+ if err != nil || eventID != "" {
+ return
+ }
+ }
+
+ // Store the event.
+ _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
+ if err != nil {
+ return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
+ }
+ // if storing this event results in it being redacted then do so.
+ if redactedEventID == event.EventID() {
+ r, rerr := eventutil.RedactEvent(redactionEvent, &event)
+ if rerr != nil {
+ return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr)
+ }
+ event = *r
+ }
+
+ // For outliers we can stop after we've stored the event itself as it
+ // doesn't have any associated state to store and we don't need to
+ // notify anyone about it.
+ if input.Kind == api.KindOutlier {
+ logrus.WithFields(logrus.Fields{
+ "event_id": event.EventID(),
+ "type": event.Type(),
+ "room": event.RoomID(),
+ }).Info("Stored outlier")
+ return event.EventID(), nil
+ }
+
+ roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID())
+ if err != nil {
+ return "", fmt.Errorf("r.DB.RoomInfo: %w", err)
+ }
+ if roomInfo == nil {
+ return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID())
+ }
+
+ if stateAtEvent.BeforeStateSnapshotNID == 0 {
+ // We haven't calculated a state for this event yet.
+ // Lets calculate one.
+ err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event)
+ if err != nil {
+ return "", fmt.Errorf("r.calculateAndSetState: %w", err)
+ }
+ }
+
+ if err = r.updateLatestEvents(
+ ctx, // context
+ roomInfo, // room info for the room being updated
+ stateAtEvent, // state at event (below)
+ event, // event
+ input.SendAsServer, // send as server
+ input.TransactionID, // transaction ID
+ ); err != nil {
+ return "", fmt.Errorf("r.updateLatestEvents: %w", err)
+ }
+
+ // processing this event resulted in an event (which may not be the one we're processing)
+ // being redacted. We are guaranteed to have both sides (the redaction/redacted event),
+ // so notify downstream components to redact this event - they should have it if they've
+ // been tracking our output log.
+ if redactedEventID != "" {
+ err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{
+ {
+ Type: api.OutputTypeRedactedEvent,
+ RedactedEvent: &api.OutputRedactedEvent{
+ RedactedEventID: redactedEventID,
+ RedactedBecause: redactionEvent.Headered(headered.RoomVersion),
+ },
+ },
+ })
+ if err != nil {
+ return "", fmt.Errorf("r.WriteOutputEvents: %w", err)
+ }
+ }
+
+ // Update the extremities of the event graph for the room
+ return event.EventID(), nil
+}
+
+func (r *Inputer) calculateAndSetState(
+ ctx context.Context,
+ input api.InputRoomEvent,
+ roomInfo types.RoomInfo,
+ stateAtEvent *types.StateAtEvent,
+ event gomatrixserverlib.Event,
+) error {
+ var err error
+ roomState := state.NewStateResolution(r.DB, roomInfo)
+
+ if input.HasState {
+ // Check here if we think we're in the room already.
+ stateAtEvent.Overwrite = true
+ var joinEventNIDs []types.EventNID
+ // Request join memberships only for local users only.
+ if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
+ // If we have no local users that are joined to the room then any state about
+ // the room that we have is quite possibly out of date. Therefore in that case
+ // we should overwrite it rather than merge it.
+ stateAtEvent.Overwrite = len(joinEventNIDs) == 0
+ }
+
+ // We've been told what the state at the event is so we don't need to calculate it.
+ // Check that those state events are in the database and store the state.
+ var entries []types.StateEntry
+ if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
+ return err
+ }
+
+ if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
+ return err
+ }
+ } else {
+ stateAtEvent.Overwrite = false
+
+ // We haven't been told what the state at the event is so we need to calculate it from the prev_events
+ if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil {
+ return err
+ }
+ }
+ return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
+}
diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go
new file mode 100644
index 00000000..67a7d8a4
--- /dev/null
+++ b/roomserver/internal/input/input_latest_events.go
@@ -0,0 +1,390 @@
+// Copyright 2017 Vector Creations Ltd
+// Copyright 2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package input
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/state"
+ "github.com/matrix-org/dendrite/roomserver/storage/shared"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+// updateLatestEvents updates the list of latest events for this room in the database and writes the
+// event to the output log.
+// The latest events are the events that aren't referenced by another event in the database:
+//
+// Time goes down the page. 1 is the m.room.create event (root).
+//
+// 1 After storing 1 the latest events are {1}
+// | After storing 2 the latest events are {2}
+// 2 After storing 3 the latest events are {3}
+// / \ After storing 4 the latest events are {3,4}
+// 3 4 After storing 5 the latest events are {5,4}
+// | | After storing 6 the latest events are {5,6}
+// 5 6 <--- latest After storing 7 the latest events are {6,7}
+// |
+// 7 <----- latest
+//
+// Can only be called once at a time
+func (r *Inputer) updateLatestEvents(
+ ctx context.Context,
+ roomInfo *types.RoomInfo,
+ stateAtEvent types.StateAtEvent,
+ event gomatrixserverlib.Event,
+ sendAsServer string,
+ transactionID *api.TransactionID,
+) (err error) {
+ updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
+ if err != nil {
+ return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
+ }
+ succeeded := false
+ defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
+
+ u := latestEventsUpdater{
+ ctx: ctx,
+ api: r,
+ updater: updater,
+ roomInfo: roomInfo,
+ stateAtEvent: stateAtEvent,
+ event: event,
+ sendAsServer: sendAsServer,
+ transactionID: transactionID,
+ }
+
+ if err = u.doUpdateLatestEvents(); err != nil {
+ return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
+ }
+
+ succeeded = true
+ return
+}
+
+// latestEventsUpdater tracks the state used to update the latest events in the
+// room. It mostly just ferries state between the various function calls.
+// The state could be passed using function arguments, but it becomes impractical
+// when there are so many variables to pass around.
+type latestEventsUpdater struct {
+ ctx context.Context
+ api *Inputer
+ updater *shared.LatestEventsUpdater
+ roomInfo *types.RoomInfo
+ stateAtEvent types.StateAtEvent
+ event gomatrixserverlib.Event
+ transactionID *api.TransactionID
+ // Which server to send this event as.
+ sendAsServer string
+ // The eventID of the event that was processed before this one.
+ lastEventIDSent string
+ // The latest events in the room after processing this event.
+ latest []types.StateAtEventAndReference
+ // The state entries removed from and added to the current state of the
+ // room as a result of processing this event. They are sorted lists.
+ removed []types.StateEntry
+ added []types.StateEntry
+ // The state entries that are removed and added to recover the state before
+ // the event being processed. They are sorted lists.
+ stateBeforeEventRemoves []types.StateEntry
+ stateBeforeEventAdds []types.StateEntry
+ // The snapshots of current state before and after processing this event
+ oldStateNID types.StateSnapshotNID
+ newStateNID types.StateSnapshotNID
+}
+
+func (u *latestEventsUpdater) doUpdateLatestEvents() error {
+ prevEvents := u.event.PrevEvents()
+ u.lastEventIDSent = u.updater.LastEventIDSent()
+ u.oldStateNID = u.updater.CurrentStateSnapshotNID()
+
+ // If we are doing a regular event update then we will get the
+ // previous latest events to use as a part of the calculation. If
+ // we are overwriting the latest events because we have a complete
+ // state snapshot from somewhere else, e.g. a federated room join,
+ // then start with an empty set - none of the forward extremities
+ // that we knew about before matter anymore.
+ oldLatest := []types.StateAtEventAndReference{}
+ if !u.stateAtEvent.Overwrite {
+ oldLatest = u.updater.LatestEvents()
+ }
+
+ // If the event has already been written to the output log then we
+ // don't need to do anything, as we've handled it already.
+ hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID)
+ if err != nil {
+ return fmt.Errorf("u.updater.HasEventBeenSent: %w", err)
+ } else if hasBeenSent {
+ return nil
+ }
+
+ // Update the roomserver_previous_events table with references. This
+ // is effectively tracking the structure of the DAG.
+ if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil {
+ return fmt.Errorf("u.updater.StorePreviousEvents: %w", err)
+ }
+
+ // Get the event reference for our new event. This will be used when
+ // determining if the event is referenced by an existing event.
+ eventReference := u.event.EventReference()
+
+ // Check if our new event is already referenced by an existing event
+ // in the room. If it is then it isn't a latest event.
+ alreadyReferenced, err := u.updater.IsReferenced(eventReference)
+ if err != nil {
+ return fmt.Errorf("u.updater.IsReferenced: %w", err)
+ }
+
+ // Work out what the latest events are.
+ u.latest = calculateLatest(
+ oldLatest,
+ alreadyReferenced,
+ prevEvents,
+ types.StateAtEventAndReference{
+ EventReference: eventReference,
+ StateAtEvent: u.stateAtEvent,
+ },
+ )
+
+ // Now that we know what the latest events are, it's time to get the
+ // latest state.
+ if err = u.latestState(); err != nil {
+ return fmt.Errorf("u.latestState: %w", err)
+ }
+
+ // If we need to generate any output events then here's where we do it.
+ // TODO: Move this!
+ updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added)
+ if err != nil {
+ return fmt.Errorf("u.api.updateMemberships: %w", err)
+ }
+
+ update, err := u.makeOutputNewRoomEvent()
+ if err != nil {
+ return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
+ }
+ updates = append(updates, *update)
+
+ // Send the event to the output logs.
+ // We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
+ // (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
+ // the write to the output log succeeds)
+ // TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
+ // send the event asynchronously but we would need to ensure that 1) the events are written to the log in
+ // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
+ // necessary bookkeeping we'll keep the event sending synchronous for now.
+ if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
+ return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
+ }
+
+ if err = u.updater.SetLatestEvents(u.roomInfo.RoomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
+ return fmt.Errorf("u.updater.SetLatestEvents: %w", err)
+ }
+
+ if err = u.updater.MarkEventAsSent(u.stateAtEvent.EventNID); err != nil {
+ return fmt.Errorf("u.updater.MarkEventAsSent: %w", err)
+ }
+
+ return nil
+}
+
+func (u *latestEventsUpdater) latestState() error {
+ var err error
+ roomState := state.NewStateResolution(u.api.DB, *u.roomInfo)
+
+ // Get a list of the current latest events.
+ latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
+ for i := range u.latest {
+ latestStateAtEvents[i] = u.latest[i].StateAtEvent
+ }
+
+ // Takes the NIDs of the latest events and creates a state snapshot
+ // of the state after the events. The snapshot state will be resolved
+ // using the correct state resolution algorithm for the room.
+ u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents(
+ u.ctx, latestStateAtEvents,
+ )
+ if err != nil {
+ return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)
+ }
+
+ // If we are overwriting the state then we should make sure that we
+ // don't send anything out over federation again, it will very likely
+ // be a repeat.
+ if u.stateAtEvent.Overwrite {
+ u.sendAsServer = ""
+ }
+
+ // Now that we have a new state snapshot based on the latest events,
+ // we can compare that new snapshot to the previous one and see what
+ // has changed. This gives us one list of removed state events and
+ // another list of added ones. Replacing a value for a state-key tuple
+ // will result one removed (the old event) and one added (the new event).
+ u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots(
+ u.ctx, u.oldStateNID, u.newStateNID,
+ )
+ if err != nil {
+ return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err)
+ }
+
+ // Also work out the state before the event removes and the event
+ // adds.
+ u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots(
+ u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
+ )
+ if err != nil {
+ return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err)
+ }
+
+ return nil
+}
+
+func calculateLatest(
+ oldLatest []types.StateAtEventAndReference,
+ alreadyReferenced bool,
+ prevEvents []gomatrixserverlib.EventReference,
+ newEvent types.StateAtEventAndReference,
+) []types.StateAtEventAndReference {
+ var alreadyInLatest bool
+ var newLatest []types.StateAtEventAndReference
+ for _, l := range oldLatest {
+ keep := true
+ for _, prevEvent := range prevEvents {
+ if l.EventID == prevEvent.EventID && bytes.Equal(l.EventSHA256, prevEvent.EventSHA256) {
+ // This event can be removed from the latest events cause we've found an event that references it.
+ // (If an event is referenced by another event then it can't be one of the latest events in the room
+ // because we have an event that comes after it)
+ keep = false
+ break
+ }
+ }
+ if l.EventNID == newEvent.EventNID {
+ alreadyInLatest = true
+ }
+ if keep {
+ // Keep the event in the latest events.
+ newLatest = append(newLatest, l)
+ }
+ }
+
+ if !alreadyReferenced && !alreadyInLatest {
+ // This event is not referenced by any of the events in the room
+ // and the event is not already in the latest events.
+ // Add it to the latest events
+ newLatest = append(newLatest, newEvent)
+ }
+
+ return newLatest
+}
+
+func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {
+
+ latestEventIDs := make([]string, len(u.latest))
+ for i := range u.latest {
+ latestEventIDs[i] = u.latest[i].EventID
+ }
+
+ ore := api.OutputNewRoomEvent{
+ Event: u.event.Headered(u.roomInfo.RoomVersion),
+ LastSentEventID: u.lastEventIDSent,
+ LatestEventIDs: latestEventIDs,
+ TransactionID: u.transactionID,
+ }
+
+ eventIDMap, err := u.stateEventMap()
+ if err != nil {
+ return nil, err
+ }
+
+ for _, entry := range u.added {
+ ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID])
+ }
+ for _, entry := range u.removed {
+ ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID])
+ }
+ for _, entry := range u.stateBeforeEventRemoves {
+ ore.StateBeforeRemovesEventIDs = append(ore.StateBeforeRemovesEventIDs, eventIDMap[entry.EventNID])
+ }
+ for _, entry := range u.stateBeforeEventAdds {
+ ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID])
+ }
+ ore.SendAsServer = u.sendAsServer
+
+ // include extra state events if they were added as nearly every downstream component will care about it
+ // and we'd rather not have them all hit QueryEventsByID at the same time!
+ if len(ore.AddsStateEventIDs) > 0 {
+ ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
+ }
+ }
+
+ return &api.OutputEvent{
+ Type: api.OutputTypeNewRoomEvent,
+ NewRoomEvent: &ore,
+ }, nil
+}
+
+// extraEventsForIDs returns the full events for the event IDs given, but does not include the current event being
+// updated.
+func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) {
+ var extraEventIDs []string
+ for _, e := range eventIDs {
+ if e == u.event.EventID() {
+ continue
+ }
+ extraEventIDs = append(extraEventIDs, e)
+ }
+ if len(extraEventIDs) == 0 {
+ return nil, nil
+ }
+ extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs)
+ if err != nil {
+ return nil, err
+ }
+ var h []gomatrixserverlib.HeaderedEvent
+ for _, e := range extraEvents {
+ h = append(h, e.Headered(roomVersion))
+ }
+ return h, nil
+}
+
+// retrieve an event nid -> event ID map for all events that need updating
+func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) {
+ var stateEventNIDs []types.EventNID
+ var allStateEntries []types.StateEntry
+ allStateEntries = append(allStateEntries, u.added...)
+ allStateEntries = append(allStateEntries, u.removed...)
+ allStateEntries = append(allStateEntries, u.stateBeforeEventRemoves...)
+ allStateEntries = append(allStateEntries, u.stateBeforeEventAdds...)
+ for _, entry := range allStateEntries {
+ stateEventNIDs = append(stateEventNIDs, entry.EventNID)
+ }
+ stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
+ return u.api.DB.EventIDs(u.ctx, stateEventNIDs)
+}
+
+type eventNIDSorter []types.EventNID
+
+func (s eventNIDSorter) Len() int { return len(s) }
+func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
+func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go
new file mode 100644
index 00000000..8befcd64
--- /dev/null
+++ b/roomserver/internal/input/input_membership.go
@@ -0,0 +1,267 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package input
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
+ "github.com/matrix-org/dendrite/roomserver/storage/shared"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// updateMembership updates the current membership and the invites for each
+// user affected by a change in the current state of the room.
+// Returns a list of output events to write to the kafka log to inform the
+// consumers about the invites added or retired by the change in current state.
+func (r *Inputer) updateMemberships(
+ ctx context.Context,
+ updater *shared.LatestEventsUpdater,
+ removed, added []types.StateEntry,
+) ([]api.OutputEvent, error) {
+ changes := membershipChanges(removed, added)
+ var eventNIDs []types.EventNID
+ for _, change := range changes {
+ if change.addedEventNID != 0 {
+ eventNIDs = append(eventNIDs, change.addedEventNID)
+ }
+ if change.removedEventNID != 0 {
+ eventNIDs = append(eventNIDs, change.removedEventNID)
+ }
+ }
+
+ // Load the event JSON so we can look up the "membership" key.
+ // TODO: Maybe add a membership key to the events table so we can load that
+ // key without having to load the entire event JSON?
+ events, err := r.DB.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ var updates []api.OutputEvent
+
+ for _, change := range changes {
+ var ae *gomatrixserverlib.Event
+ var re *gomatrixserverlib.Event
+ targetUserNID := change.EventStateKeyNID
+ if change.removedEventNID != 0 {
+ ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID)
+ if ev != nil {
+ re = &ev.Event
+ }
+ }
+ if change.addedEventNID != 0 {
+ ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID)
+ if ev != nil {
+ ae = &ev.Event
+ }
+ }
+ if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
+ return nil, err
+ }
+ }
+ return updates, nil
+}
+
+func (r *Inputer) updateMembership(
+ updater *shared.LatestEventsUpdater,
+ targetUserNID types.EventStateKeyNID,
+ remove, add *gomatrixserverlib.Event,
+ updates []api.OutputEvent,
+) ([]api.OutputEvent, error) {
+ var err error
+ // Default the membership to Leave if no event was added or removed.
+ oldMembership := gomatrixserverlib.Leave
+ newMembership := gomatrixserverlib.Leave
+
+ if remove != nil {
+ oldMembership, err = remove.Membership()
+ if err != nil {
+ return nil, err
+ }
+ }
+ if add != nil {
+ newMembership, err = add.Membership()
+ if err != nil {
+ return nil, err
+ }
+ }
+ if oldMembership == newMembership && newMembership != gomatrixserverlib.Join {
+ // If the membership is the same then nothing changed and we can return
+ // immediately, unless it's a Join update (e.g. profile update).
+ return updates, nil
+ }
+
+ if add == nil {
+ // This can happen when we have rejoined a room and suddenly we have a
+ // divergence between the former state and the new one. We don't want to
+ // act on removals and apparently there are no adds, so stop here.
+ return updates, nil
+ }
+
+ mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add))
+ if err != nil {
+ return nil, err
+ }
+
+ switch newMembership {
+ case gomatrixserverlib.Invite:
+ return helpers.UpdateToInviteMembership(mu, add, updates, updater.RoomVersion())
+ case gomatrixserverlib.Join:
+ return updateToJoinMembership(mu, add, updates)
+ case gomatrixserverlib.Leave, gomatrixserverlib.Ban:
+ return updateToLeaveMembership(mu, add, newMembership, updates)
+ default:
+ panic(fmt.Errorf(
+ "input: membership %q is not one of the allowed values", newMembership,
+ ))
+ }
+}
+
+func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool {
+ isTargetLocalUser := false
+ if statekey := event.StateKey(); statekey != nil {
+ _, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
+ isTargetLocalUser = domain == r.ServerName
+ }
+ return isTargetLocalUser
+}
+
+func updateToJoinMembership(
+ mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
+) ([]api.OutputEvent, error) {
+ // If the user is already marked as being joined, we call SetToJoin to update
+ // the event ID then we can return immediately. Retired is ignored as there
+ // is no invite event to retire.
+ if mu.IsJoin() {
+ _, err := mu.SetToJoin(add.Sender(), add.EventID(), true)
+ if err != nil {
+ return nil, err
+ }
+ return updates, nil
+ }
+ // When we mark a user as being joined we will invalidate any invites that
+ // are active for that user. We notify the consumers that the invites have
+ // been retired using a special event, even though they could infer this
+ // by studying the state changes in the room event stream.
+ retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false)
+ if err != nil {
+ return nil, err
+ }
+ for _, eventID := range retired {
+ orie := api.OutputRetireInviteEvent{
+ EventID: eventID,
+ Membership: gomatrixserverlib.Join,
+ RetiredByEventID: add.EventID(),
+ TargetUserID: *add.StateKey(),
+ }
+ updates = append(updates, api.OutputEvent{
+ Type: api.OutputTypeRetireInviteEvent,
+ RetireInviteEvent: &orie,
+ })
+ }
+ return updates, nil
+}
+
+func updateToLeaveMembership(
+ mu *shared.MembershipUpdater, add *gomatrixserverlib.Event,
+ newMembership string, updates []api.OutputEvent,
+) ([]api.OutputEvent, error) {
+ // If the user is already neither joined, nor invited to the room then we
+ // can return immediately.
+ if mu.IsLeave() {
+ return updates, nil
+ }
+ // When we mark a user as having left we will invalidate any invites that
+ // are active for that user. We notify the consumers that the invites have
+ // been retired using a special event, even though they could infer this
+ // by studying the state changes in the room event stream.
+ retired, err := mu.SetToLeave(add.Sender(), add.EventID())
+ if err != nil {
+ return nil, err
+ }
+ for _, eventID := range retired {
+ orie := api.OutputRetireInviteEvent{
+ EventID: eventID,
+ Membership: newMembership,
+ RetiredByEventID: add.EventID(),
+ TargetUserID: *add.StateKey(),
+ }
+ updates = append(updates, api.OutputEvent{
+ Type: api.OutputTypeRetireInviteEvent,
+ RetireInviteEvent: &orie,
+ })
+ }
+ return updates, nil
+}
+
+// membershipChanges pairs up the membership state changes.
+func membershipChanges(removed, added []types.StateEntry) []stateChange {
+ changes := pairUpChanges(removed, added)
+ var result []stateChange
+ for _, c := range changes {
+ if c.EventTypeNID == types.MRoomMemberNID {
+ result = append(result, c)
+ }
+ }
+ return result
+}
+
+type stateChange struct {
+ types.StateKeyTuple
+ removedEventNID types.EventNID
+ addedEventNID types.EventNID
+}
+
+// pairUpChanges pairs up the state events added and removed for each type,
+// state key tuple.
+func pairUpChanges(removed, added []types.StateEntry) []stateChange {
+ tuples := make(map[types.StateKeyTuple]stateChange)
+ changes := []stateChange{}
+
+ // First, go through the newly added state entries.
+ for _, add := range added {
+ if change, ok := tuples[add.StateKeyTuple]; ok {
+ // If we already have an entry, update it.
+ change.addedEventNID = add.EventNID
+ tuples[add.StateKeyTuple] = change
+ } else {
+ // Otherwise, create a new entry.
+ tuples[add.StateKeyTuple] = stateChange{add.StateKeyTuple, 0, add.EventNID}
+ }
+ }
+
+ // Now go through the removed state entries.
+ for _, remove := range removed {
+ if change, ok := tuples[remove.StateKeyTuple]; ok {
+ // If we already have an entry, update it.
+ change.removedEventNID = remove.EventNID
+ tuples[remove.StateKeyTuple] = change
+ } else {
+ // Otherwise, create a new entry.
+ tuples[remove.StateKeyTuple] = stateChange{remove.StateKeyTuple, remove.EventNID, 0}
+ }
+ }
+
+ // Now return the changes as an array.
+ for _, change := range tuples {
+ changes = append(changes, change)
+ }
+
+ return changes
+}