aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/helpers/helpers.go
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-09-02 13:47:31 +0100
committerGitHub <noreply@github.com>2020-09-02 13:47:31 +0100
commite473320e733484b1cc6da0588fd2ccf4affb3d24 (patch)
tree51385110bdfc89b82a8d005d77c9951f8db15e4e /roomserver/internal/helpers/helpers.go
parent02a73f29f861c637f30df4a2bb1fce400e481a3c (diff)
Refactor roomserver/internal - split perform stuff out (#1380)
- New package `perform` which contains all `Perform` functions - New package `helpers` which contains helper functions used by both perform and query/input functions. - Perform invite/leave have no idea how to `WriteOutputEvents` and this is now returned from `PerformInvite` or `PerformLeave` respectively. Still to do: - RSAPI is fed into the inviter/joiner/leaver - this introduces circular logic so will need to be removed. - Put query operations in a `query` package. - Put input operations (and output) in an `input` package. - Factor out helper functions as much as possible, possibly rejigging the storage layer in the process.
Diffstat (limited to 'roomserver/internal/helpers/helpers.go')
-rw-r--r--roomserver/internal/helpers/helpers.go326
1 files changed, 326 insertions, 0 deletions
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
new file mode 100644
index 00000000..d7bb40af
--- /dev/null
+++ b/roomserver/internal/helpers/helpers.go
@@ -0,0 +1,326 @@
+package helpers
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/auth"
+ "github.com/matrix-org/dendrite/roomserver/state"
+ "github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/dendrite/roomserver/storage/shared"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+// TODO: temporary package which has helper functions used by both internal/perform packages.
+// Move these to a more sensible place.
+
+func UpdateToInviteMembership(
+ mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
+ roomVersion gomatrixserverlib.RoomVersion,
+) ([]api.OutputEvent, error) {
+ // We may have already sent the invite to the user, either because we are
+ // reprocessing this event, or because the we received this invite from a
+ // remote server via the federation invite API. In those cases we don't need
+ // to send the event.
+ needsSending, err := mu.SetToInvite(*add)
+ if err != nil {
+ return nil, err
+ }
+ if needsSending {
+ // We notify the consumers using a special event even though we will
+ // notify them about the change in current state as part of the normal
+ // room event stream. This ensures that the consumers only have to
+ // consider a single stream of events when determining whether a user
+ // is invited, rather than having to combine multiple streams themselves.
+ onie := api.OutputNewInviteEvent{
+ Event: add.Headered(roomVersion),
+ RoomVersion: roomVersion,
+ }
+ updates = append(updates, api.OutputEvent{
+ Type: api.OutputTypeNewInviteEvent,
+ NewInviteEvent: &onie,
+ })
+ }
+ return updates, nil
+}
+
+func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
+ info, err := db.RoomInfo(ctx, roomID)
+ if err != nil {
+ return false, err
+ }
+ if info == nil {
+ return false, fmt.Errorf("unknown room %s", roomID)
+ }
+
+ eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
+ if err != nil {
+ return false, err
+ }
+
+ events, err := db.Events(ctx, eventNIDs)
+ if err != nil {
+ return false, err
+ }
+ gmslEvents := make([]gomatrixserverlib.Event, len(events))
+ for i := range events {
+ gmslEvents[i] = events[i].Event
+ }
+ return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
+}
+
+func IsInvitePending(
+ ctx context.Context, db storage.Database,
+ roomID, userID string,
+) (bool, string, string, error) {
+ // Look up the room NID for the supplied room ID.
+ info, err := db.RoomInfo(ctx, roomID)
+ if err != nil {
+ return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
+ }
+ if info == nil {
+ return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
+ }
+
+ // Look up the state key NID for the supplied user ID.
+ targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
+ if err != nil {
+ return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
+ }
+ targetUserNID, targetUserFound := targetUserNIDs[userID]
+ if !targetUserFound {
+ return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
+ }
+
+ // Let's see if we have an event active for the user in the room. If
+ // we do then it will contain a server name that we can direct the
+ // send_leave to.
+ senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
+ if err != nil {
+ return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
+ }
+ if len(senderUserNIDs) == 0 {
+ return false, "", "", nil
+ }
+ userNIDToEventID := make(map[types.EventStateKeyNID]string)
+ for i, nid := range senderUserNIDs {
+ userNIDToEventID[nid] = eventIDs[i]
+ }
+
+ // Look up the user ID from the NID.
+ senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
+ if err != nil {
+ return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
+ }
+ if len(senderUsers) == 0 {
+ return false, "", "", fmt.Errorf("no senderUsers")
+ }
+
+ senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
+ if !senderUserFound {
+ return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
+ }
+
+ return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
+}
+
+// GetMembershipsAtState filters the state events to
+// only keep the "m.room.member" events with a "join" membership. These events are returned.
+// Returns an error if there was an issue fetching the events.
+func GetMembershipsAtState(
+ ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
+) ([]types.Event, error) {
+
+ var eventNIDs []types.EventNID
+ for _, entry := range stateEntries {
+ // Filter the events to retrieve to only keep the membership events
+ if entry.EventTypeNID == types.MRoomMemberNID {
+ eventNIDs = append(eventNIDs, entry.EventNID)
+ }
+ }
+
+ // Get all of the events in this state
+ stateEvents, err := db.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ if !joinedOnly {
+ return stateEvents, nil
+ }
+
+ // Filter the events to only keep the "join" membership events
+ var events []types.Event
+ for _, event := range stateEvents {
+ membership, err := event.Membership()
+ if err != nil {
+ return nil, err
+ }
+
+ if membership == gomatrixserverlib.Join {
+ events = append(events, event)
+ }
+ }
+
+ return events, nil
+}
+
+func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
+ roomState := state.NewStateResolution(db, info)
+ // Lookup the event NID
+ eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
+ if err != nil {
+ return nil, err
+ }
+ eventIDs := []string{eIDs[eventNID]}
+
+ prevState, err := db.StateAtEventIDs(ctx, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // Fetch the state as it was when this event was fired
+ return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
+}
+
+func LoadEvents(
+ ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
+) ([]gomatrixserverlib.Event, error) {
+ stateEvents, err := db.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ result := make([]gomatrixserverlib.Event, len(stateEvents))
+ for i := range stateEvents {
+ result[i] = stateEvents[i].Event
+ }
+ return result, nil
+}
+
+func LoadStateEvents(
+ ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
+) ([]gomatrixserverlib.Event, error) {
+ eventNIDs := make([]types.EventNID, len(stateEntries))
+ for i := range stateEntries {
+ eventNIDs[i] = stateEntries[i].EventNID
+ }
+ return LoadEvents(ctx, db, eventNIDs)
+}
+
+func CheckServerAllowedToSeeEvent(
+ ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
+) (bool, error) {
+ roomState := state.NewStateResolution(db, info)
+ stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
+ if err != nil {
+ return false, err
+ }
+
+ // TODO: We probably want to make it so that we don't have to pull
+ // out all the state if possible.
+ stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries)
+ if err != nil {
+ return false, err
+ }
+
+ return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
+}
+
+// TODO: Remove this when we have tests to assert correctness of this function
+// nolint:gocyclo
+func ScanEventTree(
+ ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int,
+ serverName gomatrixserverlib.ServerName,
+) ([]types.EventNID, error) {
+ var resultNIDs []types.EventNID
+ var err error
+ var allowed bool
+ var events []types.Event
+ var next []string
+ var pre string
+
+ // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
+ // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
+ // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
+ // duplicate events being sent in response to /backfill requests.
+ initialIgnoreList := make(map[string]bool, len(visited))
+ for k, v := range visited {
+ initialIgnoreList[k] = v
+ }
+
+ resultNIDs = make([]types.EventNID, 0, limit)
+
+ var checkedServerInRoom bool
+ var isServerInRoom bool
+
+ // Loop through the event IDs to retrieve the requested events and go
+ // through the whole tree (up to the provided limit) using the events'
+ // "prev_event" key.
+BFSLoop:
+ for len(front) > 0 {
+ // Prevent unnecessary allocations: reset the slice only when not empty.
+ if len(next) > 0 {
+ next = make([]string, 0)
+ }
+ // Retrieve the events to process from the database.
+ events, err = db.EventsFromIDs(ctx, front)
+ if err != nil {
+ return resultNIDs, err
+ }
+
+ if !checkedServerInRoom && len(events) > 0 {
+ // It's nasty that we have to extract the room ID from an event, but many federation requests
+ // only talk in event IDs, no room IDs at all (!!!)
+ ev := events[0]
+ isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
+ }
+ checkedServerInRoom = true
+ }
+
+ for _, ev := range events {
+ // Break out of the loop if the provided limit is reached.
+ if len(resultNIDs) == limit {
+ break BFSLoop
+ }
+
+ if !initialIgnoreList[ev.EventID()] {
+ // Update the list of events to retrieve.
+ resultNIDs = append(resultNIDs, ev.EventNID)
+ }
+ // Loop through the event's parents.
+ for _, pre = range ev.PrevEventIDs() {
+ // Only add an event to the list of next events to process if it
+ // hasn't been seen before.
+ if !visited[pre] {
+ visited[pre] = true
+ allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
+ if err != nil {
+ util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
+ "Error checking if allowed to see event",
+ )
+ return resultNIDs, err
+ }
+
+ // If the event hasn't been seen before and the HS
+ // requesting to retrieve it is allowed to do so, add it to
+ // the list of events to retrieve.
+ if allowed {
+ next = append(next, pre)
+ } else {
+ util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
+ }
+ }
+ }
+ }
+ // Repeat the same process with the parent events we just processed.
+ front = next
+ }
+
+ return resultNIDs, err
+}