aboutsummaryrefslogtreecommitdiff
path: root/roomserver
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-09-02 13:47:31 +0100
committerGitHub <noreply@github.com>2020-09-02 13:47:31 +0100
commite473320e733484b1cc6da0588fd2ccf4affb3d24 (patch)
tree51385110bdfc89b82a8d005d77c9951f8db15e4e /roomserver
parent02a73f29f861c637f30df4a2bb1fce400e481a3c (diff)
Refactor roomserver/internal - split perform stuff out (#1380)
- New package `perform` which contains all `Perform` functions - New package `helpers` which contains helper functions used by both perform and query/input functions. - Perform invite/leave have no idea how to `WriteOutputEvents` and this is now returned from `PerformInvite` or `PerformLeave` respectively. Still to do: - RSAPI is fed into the inviter/joiner/leaver - this introduces circular logic so will need to be removed. - Put query operations in a `query` package. - Put input operations (and output) in an `input` package. - Factor out helper functions as much as possible, possibly rejigging the storage layer in the process.
Diffstat (limited to 'roomserver')
-rw-r--r--roomserver/internal/api.go120
-rw-r--r--roomserver/internal/helpers/auth.go (renamed from roomserver/internal/input_authevents.go)16
-rw-r--r--roomserver/internal/helpers/auth_test.go (renamed from roomserver/internal/input_authevents_test.go)6
-rw-r--r--roomserver/internal/helpers/helpers.go326
-rw-r--r--roomserver/internal/input.go9
-rw-r--r--roomserver/internal/input_events.go3
-rw-r--r--roomserver/internal/input_membership.go37
-rw-r--r--roomserver/internal/perform/perform_backfill.go (renamed from roomserver/internal/perform_backfill.go)237
-rw-r--r--roomserver/internal/perform/perform_invite.go (renamed from roomserver/internal/perform_invite.go)59
-rw-r--r--roomserver/internal/perform/perform_join.go (renamed from roomserver/internal/perform_join.go)37
-rw-r--r--roomserver/internal/perform/perform_leave.go (renamed from roomserver/internal/perform_leave.go)126
-rw-r--r--roomserver/internal/perform/perform_publish.go (renamed from roomserver/internal/perform_publish.go)9
-rw-r--r--roomserver/internal/query.go466
-rw-r--r--roomserver/roomserver.go14
14 files changed, 819 insertions, 646 deletions
diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go
index f94c72f0..1897f7a5 100644
--- a/roomserver/internal/api.go
+++ b/roomserver/internal/api.go
@@ -1,12 +1,15 @@
package internal
import (
+ "context"
"sync"
"github.com/Shopify/sarama"
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/config"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/internal/perform"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -20,7 +23,122 @@ type RoomserverInternalAPI struct {
ServerName gomatrixserverlib.ServerName
KeyRing gomatrixserverlib.JSONVerifier
FedClient *gomatrixserverlib.FederationClient
- OutputRoomEventTopic string // Kafka topic for new output room events
+ OutputRoomEventTopic string // Kafka topic for new output room events
+ Inviter *perform.Inviter
+ Joiner *perform.Joiner
+ Leaver *perform.Leaver
+ Publisher *perform.Publisher
+ Backfiller *perform.Backfiller
mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent
fsAPI fsAPI.FederationSenderInternalAPI
}
+
+func NewRoomserverAPI(
+ cfg *config.RoomServer, roomserverDB storage.Database, producer sarama.SyncProducer,
+ outputRoomEventTopic string, caches caching.RoomServerCaches, fedClient *gomatrixserverlib.FederationClient,
+ keyRing gomatrixserverlib.JSONVerifier,
+) *RoomserverInternalAPI {
+ a := &RoomserverInternalAPI{
+ DB: roomserverDB,
+ Cfg: cfg,
+ Producer: producer,
+ Cache: caches,
+ ServerName: cfg.Matrix.ServerName,
+ KeyRing: keyRing,
+ FedClient: fedClient,
+ OutputRoomEventTopic: outputRoomEventTopic,
+ // perform-er structs get initialised when we have a federation sender to use
+ }
+ return a
+}
+
+// SetFederationSenderInputAPI passes in a federation sender input API reference
+// so that we can avoid the chicken-and-egg problem of both the roomserver input API
+// and the federation sender input API being interdependent.
+func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {
+ r.fsAPI = fsAPI
+
+ r.Inviter = &perform.Inviter{
+ DB: r.DB,
+ Cfg: r.Cfg,
+ FSAPI: r.fsAPI,
+ RSAPI: r,
+ }
+ r.Joiner = &perform.Joiner{
+ ServerName: r.Cfg.Matrix.ServerName,
+ Cfg: r.Cfg,
+ DB: r.DB,
+ FSAPI: r.fsAPI,
+ RSAPI: r,
+ }
+ r.Leaver = &perform.Leaver{
+ Cfg: r.Cfg,
+ DB: r.DB,
+ FSAPI: r.fsAPI,
+ RSAPI: r,
+ }
+ r.Publisher = &perform.Publisher{
+ DB: r.DB,
+ }
+ r.Backfiller = &perform.Backfiller{
+ ServerName: r.ServerName,
+ DB: r.DB,
+ FedClient: r.FedClient,
+ KeyRing: r.KeyRing,
+ }
+}
+
+func (r *RoomserverInternalAPI) PerformInvite(
+ ctx context.Context,
+ req *api.PerformInviteRequest,
+ res *api.PerformInviteResponse,
+) error {
+ outputEvents, err := r.Inviter.PerformInvite(ctx, req, res)
+ if err != nil {
+ return err
+ }
+ if len(outputEvents) == 0 {
+ return nil
+ }
+ return r.WriteOutputEvents(req.Event.RoomID(), outputEvents)
+}
+
+func (r *RoomserverInternalAPI) PerformJoin(
+ ctx context.Context,
+ req *api.PerformJoinRequest,
+ res *api.PerformJoinResponse,
+) {
+ r.Joiner.PerformJoin(ctx, req, res)
+}
+
+func (r *RoomserverInternalAPI) PerformLeave(
+ ctx context.Context,
+ req *api.PerformLeaveRequest,
+ res *api.PerformLeaveResponse,
+) error {
+ outputEvents, err := r.Leaver.PerformLeave(ctx, req, res)
+ if err != nil {
+ return err
+ }
+ if len(outputEvents) == 0 {
+ return nil
+ }
+ return r.WriteOutputEvents(req.RoomID, outputEvents)
+}
+
+func (r *RoomserverInternalAPI) PerformPublish(
+ ctx context.Context,
+ req *api.PerformPublishRequest,
+ res *api.PerformPublishResponse,
+) {
+ r.Publisher.PerformPublish(ctx, req, res)
+}
+
+// Query a given amount (or less) of events prior to a given set of events.
+func (r *RoomserverInternalAPI) PerformBackfill(
+ ctx context.Context,
+ request *api.PerformBackfillRequest,
+ response *api.PerformBackfillResponse,
+) error {
+ return r.Backfiller.PerformBackfill(ctx, request, response)
+}
diff --git a/roomserver/internal/input_authevents.go b/roomserver/internal/helpers/auth.go
index e3828f56..060f0a0e 100644
--- a/roomserver/internal/input_authevents.go
+++ b/roomserver/internal/helpers/auth.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package internal
+package helpers
import (
"context"
@@ -23,9 +23,9 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
-// checkAuthEvents checks that the event passes authentication checks
+// CheckAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events.
-func checkAuthEvents(
+func CheckAuthEvents(
ctx context.Context,
db storage.Database,
event gomatrixserverlib.HeaderedEvent,
@@ -63,7 +63,7 @@ func checkAuthEvents(
type authEvents struct {
stateKeyNIDMap map[string]types.EventStateKeyNID
state stateEntryMap
- events eventMap
+ events EventMap
}
// Create implements gomatrixserverlib.AuthEventProvider
@@ -99,7 +99,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *
if !ok {
return nil
}
- event, ok := ae.events.lookup(eventNID)
+ event, ok := ae.events.Lookup(eventNID)
if !ok {
return nil
}
@@ -118,7 +118,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
if !ok {
return nil
}
- event, ok := ae.events.lookup(eventNID)
+ event, ok := ae.events.Lookup(eventNID)
if !ok {
return nil
}
@@ -224,10 +224,10 @@ func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.Even
// Map from numeric event ID to event.
// Implemented using binary search on a sorted array.
-type eventMap []types.Event
+type EventMap []types.Event
// lookup an entry in the event map.
-func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
+func (m EventMap) Lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
diff --git a/roomserver/internal/input_authevents_test.go b/roomserver/internal/helpers/auth_test.go
index 6b981571..2a1c3ea4 100644
--- a/roomserver/internal/input_authevents_test.go
+++ b/roomserver/internal/helpers/auth_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package internal
+package helpers
import (
"testing"
@@ -95,7 +95,7 @@ func TestStateEntryMap(t *testing.T) {
}
func TestEventMap(t *testing.T) {
- events := eventMap([]types.Event{
+ events := EventMap([]types.Event{
{EventNID: 1},
{EventNID: 2},
{EventNID: 3},
@@ -123,7 +123,7 @@ func TestEventMap(t *testing.T) {
}
for _, testCase := range testCases {
- gotEvent, gotOK := events.lookup(testCase.inputEventNID)
+ gotEvent, gotOK := events.Lookup(testCase.inputEventNID)
if testCase.wantOK != gotOK {
t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
}
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
new file mode 100644
index 00000000..d7bb40af
--- /dev/null
+++ b/roomserver/internal/helpers/helpers.go
@@ -0,0 +1,326 @@
+package helpers
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/auth"
+ "github.com/matrix-org/dendrite/roomserver/state"
+ "github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/dendrite/roomserver/storage/shared"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+// TODO: temporary package which has helper functions used by both internal/perform packages.
+// Move these to a more sensible place.
+
+func UpdateToInviteMembership(
+ mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
+ roomVersion gomatrixserverlib.RoomVersion,
+) ([]api.OutputEvent, error) {
+ // We may have already sent the invite to the user, either because we are
+ // reprocessing this event, or because the we received this invite from a
+ // remote server via the federation invite API. In those cases we don't need
+ // to send the event.
+ needsSending, err := mu.SetToInvite(*add)
+ if err != nil {
+ return nil, err
+ }
+ if needsSending {
+ // We notify the consumers using a special event even though we will
+ // notify them about the change in current state as part of the normal
+ // room event stream. This ensures that the consumers only have to
+ // consider a single stream of events when determining whether a user
+ // is invited, rather than having to combine multiple streams themselves.
+ onie := api.OutputNewInviteEvent{
+ Event: add.Headered(roomVersion),
+ RoomVersion: roomVersion,
+ }
+ updates = append(updates, api.OutputEvent{
+ Type: api.OutputTypeNewInviteEvent,
+ NewInviteEvent: &onie,
+ })
+ }
+ return updates, nil
+}
+
+func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
+ info, err := db.RoomInfo(ctx, roomID)
+ if err != nil {
+ return false, err
+ }
+ if info == nil {
+ return false, fmt.Errorf("unknown room %s", roomID)
+ }
+
+ eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
+ if err != nil {
+ return false, err
+ }
+
+ events, err := db.Events(ctx, eventNIDs)
+ if err != nil {
+ return false, err
+ }
+ gmslEvents := make([]gomatrixserverlib.Event, len(events))
+ for i := range events {
+ gmslEvents[i] = events[i].Event
+ }
+ return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
+}
+
+func IsInvitePending(
+ ctx context.Context, db storage.Database,
+ roomID, userID string,
+) (bool, string, string, error) {
+ // Look up the room NID for the supplied room ID.
+ info, err := db.RoomInfo(ctx, roomID)
+ if err != nil {
+ return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
+ }
+ if info == nil {
+ return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
+ }
+
+ // Look up the state key NID for the supplied user ID.
+ targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
+ if err != nil {
+ return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
+ }
+ targetUserNID, targetUserFound := targetUserNIDs[userID]
+ if !targetUserFound {
+ return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
+ }
+
+ // Let's see if we have an event active for the user in the room. If
+ // we do then it will contain a server name that we can direct the
+ // send_leave to.
+ senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
+ if err != nil {
+ return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
+ }
+ if len(senderUserNIDs) == 0 {
+ return false, "", "", nil
+ }
+ userNIDToEventID := make(map[types.EventStateKeyNID]string)
+ for i, nid := range senderUserNIDs {
+ userNIDToEventID[nid] = eventIDs[i]
+ }
+
+ // Look up the user ID from the NID.
+ senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
+ if err != nil {
+ return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
+ }
+ if len(senderUsers) == 0 {
+ return false, "", "", fmt.Errorf("no senderUsers")
+ }
+
+ senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
+ if !senderUserFound {
+ return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
+ }
+
+ return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
+}
+
+// GetMembershipsAtState filters the state events to
+// only keep the "m.room.member" events with a "join" membership. These events are returned.
+// Returns an error if there was an issue fetching the events.
+func GetMembershipsAtState(
+ ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
+) ([]types.Event, error) {
+
+ var eventNIDs []types.EventNID
+ for _, entry := range stateEntries {
+ // Filter the events to retrieve to only keep the membership events
+ if entry.EventTypeNID == types.MRoomMemberNID {
+ eventNIDs = append(eventNIDs, entry.EventNID)
+ }
+ }
+
+ // Get all of the events in this state
+ stateEvents, err := db.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ if !joinedOnly {
+ return stateEvents, nil
+ }
+
+ // Filter the events to only keep the "join" membership events
+ var events []types.Event
+ for _, event := range stateEvents {
+ membership, err := event.Membership()
+ if err != nil {
+ return nil, err
+ }
+
+ if membership == gomatrixserverlib.Join {
+ events = append(events, event)
+ }
+ }
+
+ return events, nil
+}
+
+func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
+ roomState := state.NewStateResolution(db, info)
+ // Lookup the event NID
+ eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
+ if err != nil {
+ return nil, err
+ }
+ eventIDs := []string{eIDs[eventNID]}
+
+ prevState, err := db.StateAtEventIDs(ctx, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // Fetch the state as it was when this event was fired
+ return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
+}
+
+func LoadEvents(
+ ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
+) ([]gomatrixserverlib.Event, error) {
+ stateEvents, err := db.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ result := make([]gomatrixserverlib.Event, len(stateEvents))
+ for i := range stateEvents {
+ result[i] = stateEvents[i].Event
+ }
+ return result, nil
+}
+
+func LoadStateEvents(
+ ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
+) ([]gomatrixserverlib.Event, error) {
+ eventNIDs := make([]types.EventNID, len(stateEntries))
+ for i := range stateEntries {
+ eventNIDs[i] = stateEntries[i].EventNID
+ }
+ return LoadEvents(ctx, db, eventNIDs)
+}
+
+func CheckServerAllowedToSeeEvent(
+ ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
+) (bool, error) {
+ roomState := state.NewStateResolution(db, info)
+ stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
+ if err != nil {
+ return false, err
+ }
+
+ // TODO: We probably want to make it so that we don't have to pull
+ // out all the state if possible.
+ stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries)
+ if err != nil {
+ return false, err
+ }
+
+ return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
+}
+
+// TODO: Remove this when we have tests to assert correctness of this function
+// nolint:gocyclo
+func ScanEventTree(
+ ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int,
+ serverName gomatrixserverlib.ServerName,
+) ([]types.EventNID, error) {
+ var resultNIDs []types.EventNID
+ var err error
+ var allowed bool
+ var events []types.Event
+ var next []string
+ var pre string
+
+ // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
+ // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
+ // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
+ // duplicate events being sent in response to /backfill requests.
+ initialIgnoreList := make(map[string]bool, len(visited))
+ for k, v := range visited {
+ initialIgnoreList[k] = v
+ }
+
+ resultNIDs = make([]types.EventNID, 0, limit)
+
+ var checkedServerInRoom bool
+ var isServerInRoom bool
+
+ // Loop through the event IDs to retrieve the requested events and go
+ // through the whole tree (up to the provided limit) using the events'
+ // "prev_event" key.
+BFSLoop:
+ for len(front) > 0 {
+ // Prevent unnecessary allocations: reset the slice only when not empty.
+ if len(next) > 0 {
+ next = make([]string, 0)
+ }
+ // Retrieve the events to process from the database.
+ events, err = db.EventsFromIDs(ctx, front)
+ if err != nil {
+ return resultNIDs, err
+ }
+
+ if !checkedServerInRoom && len(events) > 0 {
+ // It's nasty that we have to extract the room ID from an event, but many federation requests
+ // only talk in event IDs, no room IDs at all (!!!)
+ ev := events[0]
+ isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
+ }
+ checkedServerInRoom = true
+ }
+
+ for _, ev := range events {
+ // Break out of the loop if the provided limit is reached.
+ if len(resultNIDs) == limit {
+ break BFSLoop
+ }
+
+ if !initialIgnoreList[ev.EventID()] {
+ // Update the list of events to retrieve.
+ resultNIDs = append(resultNIDs, ev.EventNID)
+ }
+ // Loop through the event's parents.
+ for _, pre = range ev.PrevEventIDs() {
+ // Only add an event to the list of next events to process if it
+ // hasn't been seen before.
+ if !visited[pre] {
+ visited[pre] = true
+ allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
+ if err != nil {
+ util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
+ "Error checking if allowed to see event",
+ )
+ return resultNIDs, err
+ }
+
+ // If the event hasn't been seen before and the HS
+ // requesting to retrieve it is allowed to do so, add it to
+ // the list of events to retrieve.
+ if allowed {
+ next = append(next, pre)
+ } else {
+ util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
+ }
+ }
+ }
+ }
+ // Repeat the same process with the parent events we just processed.
+ front = next
+ }
+
+ return resultNIDs, err
+}
diff --git a/roomserver/internal/input.go b/roomserver/internal/input.go
index e85e9830..dbf67b79 100644
--- a/roomserver/internal/input.go
+++ b/roomserver/internal/input.go
@@ -23,17 +23,8 @@ import (
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/roomserver/api"
log "github.com/sirupsen/logrus"
-
- fsAPI "github.com/matrix-org/dendrite/federationsender/api"
)
-// SetFederationSenderInputAPI passes in a federation sender input API reference
-// so that we can avoid the chicken-and-egg problem of both the roomserver input API
-// and the federation sender input API being interdependent.
-func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {
- r.fsAPI = fsAPI
-}
-
// WriteOutputEvents implements OutputRoomEventWriter
func (r *RoomserverInternalAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error {
messages := make([]*sarama.ProducerMessage, len(updates))
diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go
index 287db1af..edc8b416 100644
--- a/roomserver/internal/input_events.go
+++ b/roomserver/internal/input_events.go
@@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -45,7 +46,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
// Check that the event passes authentication checks and work out
// the numeric IDs for the auth events.
- authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
+ authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
if err != nil {
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event")
return
diff --git a/roomserver/internal/input_membership.go b/roomserver/internal/input_membership.go
index bcecfca0..57a94596 100644
--- a/roomserver/internal/input_membership.go
+++ b/roomserver/internal/input_membership.go
@@ -19,6 +19,7 @@ import (
"fmt"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -59,13 +60,13 @@ func (r *RoomserverInternalAPI) updateMemberships(
var re *gomatrixserverlib.Event
targetUserNID := change.EventStateKeyNID
if change.removedEventNID != 0 {
- ev, _ := eventMap(events).lookup(change.removedEventNID)
+ ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID)
if ev != nil {
re = &ev.Event
}
}
if change.addedEventNID != 0 {
- ev, _ := eventMap(events).lookup(change.addedEventNID)
+ ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID)
if ev != nil {
ae = &ev.Event
}
@@ -120,7 +121,7 @@ func (r *RoomserverInternalAPI) updateMembership(
switch newMembership {
case gomatrixserverlib.Invite:
- return updateToInviteMembership(mu, add, updates, updater.RoomVersion())
+ return helpers.UpdateToInviteMembership(mu, add, updates, updater.RoomVersion())
case gomatrixserverlib.Join:
return updateToJoinMembership(mu, add, updates)
case gomatrixserverlib.Leave, gomatrixserverlib.Ban:
@@ -141,36 +142,6 @@ func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bo
return isTargetLocalUser
}
-func updateToInviteMembership(
- mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
- roomVersion gomatrixserverlib.RoomVersion,
-) ([]api.OutputEvent, error) {
- // We may have already sent the invite to the user, either because we are
- // reprocessing this event, or because the we received this invite from a
- // remote server via the federation invite API. In those cases we don't need
- // to send the event.
- needsSending, err := mu.SetToInvite(*add)
- if err != nil {
- return nil, err
- }
- if needsSending {
- // We notify the consumers using a special event even though we will
- // notify them about the change in current state as part of the normal
- // room event stream. This ensures that the consumers only have to
- // consider a single stream of events when determining whether a user
- // is invited, rather than having to combine multiple streams themselves.
- onie := api.OutputNewInviteEvent{
- Event: add.Headered(roomVersion),
- RoomVersion: roomVersion,
- }
- updates = append(updates, api.OutputEvent{
- Type: api.OutputTypeNewInviteEvent,
- NewInviteEvent: &onie,
- })
- }
- return updates, nil
-}
-
func updateToJoinMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
diff --git a/roomserver/internal/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go
index 721f6610..ebb66ef4 100644
--- a/roomserver/internal/perform_backfill.go
+++ b/roomserver/internal/perform/perform_backfill.go
@@ -1,9 +1,13 @@
-package internal
+package perform
import (
"context"
+ "fmt"
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -11,6 +15,189 @@ import (
"github.com/sirupsen/logrus"
)
+type Backfiller struct {
+ ServerName gomatrixserverlib.ServerName
+ DB storage.Database
+ FedClient *gomatrixserverlib.FederationClient
+ KeyRing gomatrixserverlib.JSONVerifier
+}
+
+// PerformBackfill implements api.RoomServerQueryAPI
+func (r *Backfiller) PerformBackfill(
+ ctx context.Context,
+ request *api.PerformBackfillRequest,
+ response *api.PerformBackfillResponse,
+) error {
+ // if we are requesting the backfill then we need to do a federation hit
+ // TODO: we could be more sensible and fetch as many events we already have then request the rest
+ // which is what the syncapi does already.
+ if request.ServerName == r.ServerName {
+ return r.backfillViaFederation(ctx, request, response)
+ }
+ // someone else is requesting the backfill, try to service their request.
+ var err error
+ var front []string
+
+ // The limit defines the maximum number of events to retrieve, so it also
+ // defines the highest number of elements in the map below.
+ visited := make(map[string]bool, request.Limit)
+
+ // this will include these events which is what we want
+ front = request.PrevEventIDs()
+
+ info, err := r.DB.RoomInfo(ctx, request.RoomID)
+ if err != nil {
+ return err
+ }
+ if info == nil || info.IsStub {
+ return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID)
+ }
+
+ // Scan the event tree for events to send back.
+ resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName)
+ if err != nil {
+ return err
+ }
+
+ // Retrieve events from the list that was filled previously.
+ var loadedEvents []gomatrixserverlib.Event
+ loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs)
+ if err != nil {
+ return err
+ }
+
+ for _, event := range loadedEvents {
+ response.Events = append(response.Events, event.Headered(info.RoomVersion))
+ }
+
+ return err
+}
+
+func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error {
+ info, err := r.DB.RoomInfo(ctx, req.RoomID)
+ if err != nil {
+ return err
+ }
+ if info == nil || info.IsStub {
+ return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
+ }
+ requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities)
+ // Request 100 items regardless of what the query asks for.
+ // We don't want to go much higher than this.
+ // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
+ // (so we don't need to hit /state_ids which the test has no listener for)
+ // Specifically the test "Outbound federation can backfill events"
+ events, err := gomatrixserverlib.RequestBackfill(
+ ctx, requester,
+ r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100)
+ if err != nil {
+ return err
+ }
+ logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
+
+ // persist these new events - auth checks have already been done
+ roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
+ if err != nil {
+ return err
+ }
+
+ for _, ev := range backfilledEventMap {
+ // now add state for these events
+ stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()]
+ if !ok {
+ // this should be impossible as all events returned must have pass Step 5 of the PDU checks
+ // which requires a list of state IDs.
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
+ continue
+ }
+ var entries []types.StateEntry
+ if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
+ // attempt to fetch the missing events
+ r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs)
+ // try again
+ entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
+ if err != nil {
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
+ return err
+ }
+ }
+
+ var beforeStateSnapshotNID types.StateSnapshotNID
+ if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
+ return err
+ }
+ if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
+ }
+ }
+
+ // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
+
+ res.Events = events
+ return nil
+}
+
+// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
+// best effort.
+func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
+ backfillRequester *backfillRequester, stateIDs []string) {
+
+ servers := backfillRequester.servers
+
+ // work out which are missing
+ nidMap, err := r.DB.EventNIDs(ctx, stateIDs)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
+ return
+ }
+ missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event
+ for _, id := range stateIDs {
+ if _, ok := nidMap[id]; !ok {
+ missingMap[id] = nil
+ }
+ }
+ util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers))
+
+ // fetch the events from federation. Loop the servers first so if we find one that works we stick with them
+ for _, srv := range servers {
+ for id, ev := range missingMap {
+ if ev != nil {
+ continue // already found
+ }
+ logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
+ res, err := r.FedClient.GetEvent(ctx, srv, id)
+ if err != nil {
+ logger.WithError(err).Warn("failed to get event from server")
+ continue
+ }
+ loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
+ result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
+ if err != nil {
+ logger.WithError(err).Warn("failed to load and verify event")
+ continue
+ }
+ logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result)
+ for _, res := range result {
+ if res.Error != nil {
+ logger.WithError(err).Warn("event failed PDU checks")
+ continue
+ }
+ missingMap[id] = res.Event
+ }
+ }
+ }
+
+ var newEvents []gomatrixserverlib.HeaderedEvent
+ for _, ev := range missingMap {
+ if ev != nil {
+ newEvents = append(newEvents, *ev)
+ }
+ }
+ util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
+ persistEvents(ctx, r.DB, newEvents)
+}
+
// backfillRequester implements gomatrixserverlib.BackfillRequester
type backfillRequester struct {
db storage.Database
@@ -200,7 +387,7 @@ FindSuccessor:
return nil
}
- stateEntries, err := stateBeforeEvent(ctx, b.db, *info, NIDs[eventID])
+ stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, *info, NIDs[eventID])
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
return nil
@@ -217,7 +404,7 @@ FindSuccessor:
// Retrieve all "m.room.member" state events of "join" membership, which
// contains the list of users in the room before the event, therefore all
// the servers in it at that moment.
- memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true)
+ memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, stateEntries, true)
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
return nil
@@ -314,3 +501,47 @@ func joinEventsFromHistoryVisibility(
}
return db.Events(ctx, joinEventNIDs)
}
+
+func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
+ var roomNID types.RoomNID
+ backfilledEventMap := make(map[string]types.Event)
+ for j, ev := range events {
+ nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs())
+ if err != nil { // this shouldn't happen as RequestBackfill already found them
+ logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events")
+ continue
+ }
+ authNids := make([]types.EventNID, len(nidMap))
+ i := 0
+ for _, nid := range nidMap {
+ authNids[i] = nid
+ i++
+ }
+ var stateAtEvent types.StateAtEvent
+ var redactedEventID string
+ var redactionEvent *gomatrixserverlib.Event
+ roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
+ if err != nil {
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
+ continue
+ }
+ // If storing this event results in it being redacted, then do so.
+ // It's also possible for this event to be a redaction which results in another event being
+ // redacted, which we don't care about since we aren't returning it in this backfill.
+ if redactedEventID == ev.EventID() {
+ eventToRedact := ev.Unwrap()
+ redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact)
+ if err != nil {
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
+ continue
+ }
+ ev = redactedEvent.Headered(ev.RoomVersion)
+ events[j] = ev
+ }
+ backfilledEventMap[ev.EventID()] = types.Event{
+ EventNID: stateAtEvent.StateEntry.EventNID,
+ Event: ev.Unwrap(),
+ }
+ }
+ return roomNID, backfilledEventMap
+}
diff --git a/roomserver/internal/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index 6690de05..7320388e 100644
--- a/roomserver/internal/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -1,11 +1,13 @@
-package internal
+package perform
import (
"context"
"fmt"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
+ "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
@@ -13,22 +15,31 @@ import (
log "github.com/sirupsen/logrus"
)
+type Inviter struct {
+ DB storage.Database
+ Cfg *config.RoomServer
+ FSAPI federationSenderAPI.FederationSenderInternalAPI
+
+ // TODO FIXME: Remove this
+ RSAPI api.RoomserverInternalAPI
+}
+
// nolint:gocyclo
-func (r *RoomserverInternalAPI) PerformInvite(
+func (r *Inviter) PerformInvite(
ctx context.Context,
req *api.PerformInviteRequest,
res *api.PerformInviteResponse,
-) error {
+) ([]api.OutputEvent, error) {
event := req.Event
if event.StateKey() == nil {
- return fmt.Errorf("invite must be a state event")
+ return nil, fmt.Errorf("invite must be a state event")
}
roomID := event.RoomID()
targetUserID := *event.StateKey()
info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
- return fmt.Errorf("Failed to load RoomInfo: %w", err)
+ return nil, fmt.Errorf("Failed to load RoomInfo: %w", err)
}
log.WithFields(log.Fields{
@@ -52,11 +63,11 @@ func (r *RoomserverInternalAPI) PerformInvite(
}
if len(inviteState) == 0 {
if err = event.SetUnsignedField("invite_room_state", struct{}{}); err != nil {
- return fmt.Errorf("event.SetUnsignedField: %w", err)
+ return nil, fmt.Errorf("event.SetUnsignedField: %w", err)
}
} else {
if err = event.SetUnsignedField("invite_room_state", inviteState); err != nil {
- return fmt.Errorf("event.SetUnsignedField: %w", err)
+ return nil, fmt.Errorf("event.SetUnsignedField: %w", err)
}
}
@@ -64,7 +75,7 @@ func (r *RoomserverInternalAPI) PerformInvite(
if info != nil {
_, isAlreadyJoined, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
if err != nil {
- return fmt.Errorf("r.DB.GetMembership: %w", err)
+ return nil, fmt.Errorf("r.DB.GetMembership: %w", err)
}
}
if isAlreadyJoined {
@@ -99,7 +110,7 @@ func (r *RoomserverInternalAPI) PerformInvite(
Code: api.PerformErrorNotAllowed,
Msg: "User is already joined to room",
}
- return nil
+ return nil, nil
}
if isOriginLocal {
@@ -107,7 +118,7 @@ func (r *RoomserverInternalAPI) PerformInvite(
// try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on
// trust.
- _, err = checkAuthEvents(ctx, r.DB, event, event.AuthEventIDs())
+ _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs())
if err != nil {
log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event",
@@ -117,9 +128,9 @@ func (r *RoomserverInternalAPI) PerformInvite(
Msg: err.Error(),
Code: api.PerformErrorNotAllowed,
}
- return nil
+ return nil, nil
}
- return fmt.Errorf("checkAuthEvents: %w", err)
+ return nil, fmt.Errorf("checkAuthEvents: %w", err)
}
// If the invite originated from us and the target isn't local then we
@@ -133,13 +144,13 @@ func (r *RoomserverInternalAPI) PerformInvite(
InviteRoomState: inviteState,
}
fsRes := &federationSenderAPI.PerformInviteResponse{}
- if err = r.fsAPI.PerformInvite(ctx, fsReq, fsRes); err != nil {
+ if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
Code: api.PerformErrorNoOperation,
}
- log.WithError(err).WithField("event_id", event.EventID()).Error("r.fsAPI.PerformInvite failed")
- return nil
+ log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed")
+ return nil, nil
}
event = fsRes.Event
}
@@ -159,8 +170,8 @@ func (r *RoomserverInternalAPI) PerformInvite(
},
}
inputRes := &api.InputRoomEventsResponse{}
- if err = r.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil {
- return fmt.Errorf("r.InputRoomEvents: %w", err)
+ if err = r.RSAPI.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil {
+ return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
}
} else {
// The invite originated over federation. Process the membership
@@ -168,25 +179,23 @@ func (r *RoomserverInternalAPI) PerformInvite(
// invite.
updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
if err != nil {
- return fmt.Errorf("r.DB.MembershipUpdater: %w", err)
+ return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
unwrapped := event.Unwrap()
- outputUpdates, err := updateToInviteMembership(updater, &unwrapped, nil, req.Event.RoomVersion)
+ outputUpdates, err := helpers.UpdateToInviteMembership(updater, &unwrapped, nil, req.Event.RoomVersion)
if err != nil {
- return fmt.Errorf("updateToInviteMembership: %w", err)
+ return nil, fmt.Errorf("updateToInviteMembership: %w", err)
}
if err = updater.Commit(); err != nil {
- return fmt.Errorf("updater.Commit: %w", err)
+ return nil, fmt.Errorf("updater.Commit: %w", err)
}
- if err = r.WriteOutputEvents(roomID, outputUpdates); err != nil {
- return fmt.Errorf("r.WriteOutputEvents: %w", err)
- }
+ return outputUpdates, nil
}
- return nil
+ return nil, nil
}
func buildInviteStrippedState(
diff --git a/roomserver/internal/perform_join.go b/roomserver/internal/perform/perform_join.go
index 3b9b1b3c..c8e6e8e6 100644
--- a/roomserver/internal/perform_join.go
+++ b/roomserver/internal/perform/perform_join.go
@@ -1,4 +1,4 @@
-package internal
+package perform
import (
"context"
@@ -8,14 +8,27 @@ import (
"time"
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
+ "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
+ "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
+type Joiner struct {
+ ServerName gomatrixserverlib.ServerName
+ Cfg *config.RoomServer
+ FSAPI fsAPI.FederationSenderInternalAPI
+ DB storage.Database
+
+ // TODO FIXME: Remove this
+ RSAPI api.RoomserverInternalAPI
+}
+
// PerformJoin handles joining matrix rooms, including over federation by talking to the federationsender.
-func (r *RoomserverInternalAPI) PerformJoin(
+func (r *Joiner) PerformJoin(
ctx context.Context,
req *api.PerformJoinRequest,
res *api.PerformJoinResponse,
@@ -34,7 +47,7 @@ func (r *RoomserverInternalAPI) PerformJoin(
res.RoomID = roomID
}
-func (r *RoomserverInternalAPI) performJoin(
+func (r *Joiner) performJoin(
ctx context.Context,
req *api.PerformJoinRequest,
) (string, error) {
@@ -63,7 +76,7 @@ func (r *RoomserverInternalAPI) performJoin(
}
}
-func (r *RoomserverInternalAPI) performJoinRoomByAlias(
+func (r *Joiner) performJoinRoomByAlias(
ctx context.Context,
req *api.PerformJoinRequest,
) (string, error) {
@@ -85,7 +98,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias(
ServerName: domain, // the server to ask
}
dirRes := fsAPI.PerformDirectoryLookupResponse{}
- err = r.fsAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
+ err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
if err != nil {
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
@@ -112,7 +125,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias(
// TODO: Break this function up a bit
// nolint:gocyclo
-func (r *RoomserverInternalAPI) performJoinRoomByID(
+func (r *Joiner) performJoinRoomByID(
ctx context.Context,
req *api.PerformJoinRequest,
) (string, error) {
@@ -161,8 +174,8 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
// where we might think we know about a room in the following
// section but don't know the latest state as all of our users
// have left.
- serverInRoom, _ := r.isServerCurrentlyInRoom(ctx, r.ServerName, req.RoomIDOrAlias)
- isInvitePending, inviteSender, _, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID)
+ serverInRoom, _ := helpers.IsServerCurrentlyInRoom(ctx, r.DB, r.ServerName, req.RoomIDOrAlias)
+ isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID)
if err == nil && isInvitePending && !serverInRoom {
// Check if there's an invite pending.
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
@@ -194,7 +207,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
&eb, // the template join event
r.Cfg.Matrix, // the server configuration
time.Now(), // the event timestamp to use
- r, // the roomserver API to use
+ r.RSAPI, // the roomserver API to use
&buildRes, // the query response
)
@@ -228,7 +241,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
},
}
inputRes := api.InputRoomEventsResponse{}
- if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
+ if err = r.RSAPI.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
var notAllowed *gomatrixserverlib.NotAllowed
if errors.As(err, &notAllowed) {
return "", &api.PerformError{
@@ -271,7 +284,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
return req.RoomIDOrAlias, nil
}
-func (r *RoomserverInternalAPI) performFederatedJoinRoomByID(
+func (r *Joiner) performFederatedJoinRoomByID(
ctx context.Context,
req *api.PerformJoinRequest,
) error {
@@ -283,7 +296,7 @@ func (r *RoomserverInternalAPI) performFederatedJoinRoomByID(
Content: req.Content, // the membership event content
}
fedRes := fsAPI.PerformJoinResponse{}
- r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes)
+ r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)
if fedRes.LastError != nil {
return &api.PerformError{
Code: api.PerformErrRemote,
diff --git a/roomserver/internal/perform_leave.go b/roomserver/internal/perform/perform_leave.go
index b8603147..b4053eed 100644
--- a/roomserver/internal/perform_leave.go
+++ b/roomserver/internal/perform/perform_leave.go
@@ -1,4 +1,4 @@
-package internal
+package perform
import (
"context"
@@ -7,39 +7,50 @@ import (
"time"
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
+ "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
+ "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/gomatrixserverlib"
)
+type Leaver struct {
+ Cfg *config.RoomServer
+ DB storage.Database
+ FSAPI fsAPI.FederationSenderInternalAPI
+
+ // TODO FIXME: Remove this
+ RSAPI api.RoomserverInternalAPI
+}
+
// WriteOutputEvents implements OutputRoomEventWriter
-func (r *RoomserverInternalAPI) PerformLeave(
+func (r *Leaver) PerformLeave(
ctx context.Context,
req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse,
-) error {
+) ([]api.OutputEvent, error) {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
- return fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID)
+ return nil, fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID)
}
if domain != r.Cfg.Matrix.ServerName {
- return fmt.Errorf("User %q does not belong to this homeserver", req.UserID)
+ return nil, fmt.Errorf("User %q does not belong to this homeserver", req.UserID)
}
if strings.HasPrefix(req.RoomID, "!") {
return r.performLeaveRoomByID(ctx, req, res)
}
- return fmt.Errorf("Room ID %q is invalid", req.RoomID)
+ return nil, fmt.Errorf("Room ID %q is invalid", req.RoomID)
}
-func (r *RoomserverInternalAPI) performLeaveRoomByID(
+func (r *Leaver) performLeaveRoomByID(
ctx context.Context,
req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam
-) error {
+) ([]api.OutputEvent, error) {
// If there's an invite outstanding for the room then respond to
// that.
- isInvitePending, senderUser, eventID, err := r.isInvitePending(ctx, req.RoomID, req.UserID)
+ isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID)
if err == nil && isInvitePending {
return r.performRejectInvite(ctx, req, res, senderUser, eventID)
}
@@ -56,25 +67,25 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID(
},
}
latestRes := api.QueryLatestEventsAndStateResponse{}
- if err = r.QueryLatestEventsAndState(ctx, &latestReq, &latestRes); err != nil {
- return err
+ if err = r.RSAPI.QueryLatestEventsAndState(ctx, &latestReq, &latestRes); err != nil {
+ return nil, err
}
if !latestRes.RoomExists {
- return fmt.Errorf("Room %q does not exist", req.RoomID)
+ return nil, fmt.Errorf("Room %q does not exist", req.RoomID)
}
// Now let's see if the user is in the room.
if len(latestRes.StateEvents) == 0 {
- return fmt.Errorf("User %q is not a member of room %q", req.UserID, req.RoomID)
+ return nil, fmt.Errorf("User %q is not a member of room %q", req.UserID, req.RoomID)
}
membership, err := latestRes.StateEvents[0].Membership()
if err != nil {
- return fmt.Errorf("Error getting membership: %w", err)
+ return nil, fmt.Errorf("Error getting membership: %w", err)
}
if membership != gomatrixserverlib.Join {
// TODO: should be able to handle "invite" in this case too, if
// it's a case of kicking or banning or such
- return fmt.Errorf("User %q is not joined to the room (membership is %q)", req.UserID, membership)
+ return nil, fmt.Errorf("User %q is not joined to the room (membership is %q)", req.UserID, membership)
}
// Prepare the template for the leave event.
@@ -87,10 +98,10 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID(
Redacts: "",
}
if err = eb.SetContent(map[string]interface{}{"membership": "leave"}); err != nil {
- return fmt.Errorf("eb.SetContent: %w", err)
+ return nil, fmt.Errorf("eb.SetContent: %w", err)
}
if err = eb.SetUnsigned(struct{}{}); err != nil {
- return fmt.Errorf("eb.SetUnsigned: %w", err)
+ return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
}
// We know that the user is in the room at this point so let's build
@@ -103,11 +114,11 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID(
&eb, // the template leave event
r.Cfg.Matrix, // the server configuration
time.Now(), // the event timestamp to use
- r, // the roomserver API to use
+ r.RSAPI, // the roomserver API to use
&buildRes, // the query response
)
if err != nil {
- return fmt.Errorf("eventutil.BuildEvent: %w", err)
+ return nil, fmt.Errorf("eventutil.BuildEvent: %w", err)
}
// Give our leave event to the roomserver input stream. The
@@ -124,22 +135,22 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID(
},
}
inputRes := api.InputRoomEventsResponse{}
- if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
- return fmt.Errorf("r.InputRoomEvents: %w", err)
+ if err = r.RSAPI.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
+ return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
}
- return nil
+ return nil, nil
}
-func (r *RoomserverInternalAPI) performRejectInvite(
+func (r *Leaver) performRejectInvite(
ctx context.Context,
req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam
senderUser, eventID string,
-) error {
+) ([]api.OutputEvent, error) {
_, domain, err := gomatrixserverlib.SplitID('@', senderUser)
if err != nil {
- return fmt.Errorf("User ID %q invalid: %w", senderUser, err)
+ return nil, fmt.Errorf("User ID %q invalid: %w", senderUser, err)
}
// Ask the federation sender to perform a federated leave for us.
@@ -149,13 +160,13 @@ func (r *RoomserverInternalAPI) performRejectInvite(
ServerNames: []gomatrixserverlib.ServerName{domain},
}
leaveRes := fsAPI.PerformLeaveResponse{}
- if err := r.fsAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
- return err
+ if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
+ return nil, err
}
// Withdraw the invite, so that the sync API etc are
// notified that we rejected it.
- return r.WriteOutputEvents(req.RoomID, []api.OutputEvent{
+ return []api.OutputEvent{
{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
@@ -164,60 +175,5 @@ func (r *RoomserverInternalAPI) performRejectInvite(
TargetUserID: req.UserID,
},
},
- })
-}
-
-func (r *RoomserverInternalAPI) isInvitePending(
- ctx context.Context,
- roomID, userID string,
-) (bool, string, string, error) {
- // Look up the room NID for the supplied room ID.
- info, err := r.DB.RoomInfo(ctx, roomID)
- if err != nil {
- return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
- }
- if info == nil {
- return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
- }
-
- // Look up the state key NID for the supplied user ID.
- targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID})
- if err != nil {
- return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
- }
- targetUserNID, targetUserFound := targetUserNIDs[userID]
- if !targetUserFound {
- return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
- }
-
- // Let's see if we have an event active for the user in the room. If
- // we do then it will contain a server name that we can direct the
- // send_leave to.
- senderUserNIDs, eventIDs, err := r.DB.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
- if err != nil {
- return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
- }
- if len(senderUserNIDs) == 0 {
- return false, "", "", nil
- }
- userNIDToEventID := make(map[types.EventStateKeyNID]string)
- for i, nid := range senderUserNIDs {
- userNIDToEventID[nid] = eventIDs[i]
- }
-
- // Look up the user ID from the NID.
- senderUsers, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
- if err != nil {
- return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
- }
- if len(senderUsers) == 0 {
- return false, "", "", fmt.Errorf("no senderUsers")
- }
-
- senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
- if !senderUserFound {
- return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
- }
-
- return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
+ }, nil
}
diff --git a/roomserver/internal/perform_publish.go b/roomserver/internal/perform/perform_publish.go
index d7863620..aab282f3 100644
--- a/roomserver/internal/perform_publish.go
+++ b/roomserver/internal/perform/perform_publish.go
@@ -1,12 +1,17 @@
-package internal
+package perform
import (
"context"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/storage"
)
-func (r *RoomserverInternalAPI) PerformPublish(
+type Publisher struct {
+ DB storage.Database
+}
+
+func (r *Publisher) PerformPublish(
ctx context.Context,
req *api.PerformPublishRequest,
res *api.PerformPublishResponse,
diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go
index f8e8ba04..26b22c74 100644
--- a/roomserver/internal/query.go
+++ b/roomserver/internal/query.go
@@ -20,11 +20,9 @@ import (
"context"
"fmt"
- "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/roomserver/auth"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state"
- "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrixserverlib"
@@ -74,7 +72,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState(
return err
}
- stateEvents, err := r.loadStateEvents(ctx, stateEntries)
+ stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
if err != nil {
return err
}
@@ -123,7 +121,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents(
return err
}
- stateEvents, err := r.loadStateEvents(ctx, stateEntries)
+ stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
if err != nil {
return err
}
@@ -151,7 +149,7 @@ func (r *RoomserverInternalAPI) QueryEventsByID(
eventNIDs = append(eventNIDs, nid)
}
- events, err := r.loadEvents(ctx, eventNIDs)
+ events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs)
if err != nil {
return err
}
@@ -168,31 +166,6 @@ func (r *RoomserverInternalAPI) QueryEventsByID(
return nil
}
-func (r *RoomserverInternalAPI) loadStateEvents(
- ctx context.Context, stateEntries []types.StateEntry,
-) ([]gomatrixserverlib.Event, error) {
- eventNIDs := make([]types.EventNID, len(stateEntries))
- for i := range stateEntries {
- eventNIDs[i] = stateEntries[i].EventNID
- }
- return r.loadEvents(ctx, eventNIDs)
-}
-
-func (r *RoomserverInternalAPI) loadEvents(
- ctx context.Context, eventNIDs []types.EventNID,
-) ([]gomatrixserverlib.Event, error) {
- stateEvents, err := r.DB.Events(ctx, eventNIDs)
- if err != nil {
- return nil, err
- }
-
- result := make([]gomatrixserverlib.Event, len(stateEvents))
- for i := range stateEvents {
- result[i] = stateEvents[i].Event
- }
- return result, nil
-}
-
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryMembershipForUser(
ctx context.Context,
@@ -266,12 +239,12 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, eventNIDs)
} else {
- stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID)
+ stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
}
- events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
+ events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
}
if err != nil {
@@ -286,65 +259,6 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
return nil
}
-func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
- roomState := state.NewStateResolution(db, info)
- // Lookup the event NID
- eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
- if err != nil {
- return nil, err
- }
- eventIDs := []string{eIDs[eventNID]}
-
- prevState, err := db.StateAtEventIDs(ctx, eventIDs)
- if err != nil {
- return nil, err
- }
-
- // Fetch the state as it was when this event was fired
- return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
-}
-
-// getMembershipsAtState filters the state events to
-// only keep the "m.room.member" events with a "join" membership. These events are returned.
-// Returns an error if there was an issue fetching the events.
-func getMembershipsAtState(
- ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
-) ([]types.Event, error) {
-
- var eventNIDs []types.EventNID
- for _, entry := range stateEntries {
- // Filter the events to retrieve to only keep the membership events
- if entry.EventTypeNID == types.MRoomMemberNID {
- eventNIDs = append(eventNIDs, entry.EventNID)
- }
- }
-
- // Get all of the events in this state
- stateEvents, err := db.Events(ctx, eventNIDs)
- if err != nil {
- return nil, err
- }
-
- if !joinedOnly {
- return stateEvents, nil
- }
-
- // Filter the events to only keep the "join" membership events
- var events []types.Event
- for _, event := range stateEvents {
- membership, err := event.Membership()
- if err != nil {
- return nil, err
- }
-
- if membership == gomatrixserverlib.Join {
- events = append(events, event)
- }
- }
-
- return events, nil
-}
-
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
ctx context.Context,
@@ -360,7 +274,7 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
return
}
roomID := events[0].RoomID()
- isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID)
+ isServerInRoom, err := helpers.IsServerCurrentlyInRoom(ctx, r.DB, request.ServerName, roomID)
if err != nil {
return
}
@@ -371,31 +285,12 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
if info == nil {
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
}
- response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent(
- ctx, *info, request.EventID, request.ServerName, isServerInRoom,
+ response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
+ ctx, r.DB, *info, request.EventID, request.ServerName, isServerInRoom,
)
return
}
-func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent(
- ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
-) (bool, error) {
- roomState := state.NewStateResolution(r.DB, info)
- stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
- if err != nil {
- return false, err
- }
-
- // TODO: We probably want to make it so that we don't have to pull
- // out all the state if possible.
- stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
- if err != nil {
- return false, err
- }
-
- return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
-}
-
// QueryMissingEvents implements api.RoomserverInternalAPI
// nolint:gocyclo
func (r *RoomserverInternalAPI) QueryMissingEvents(
@@ -431,12 +326,12 @@ func (r *RoomserverInternalAPI) QueryMissingEvents(
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
}
- resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
+ resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName)
if err != nil {
return err
}
- loadedEvents, err := r.loadEvents(ctx, resultNIDs)
+ loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs)
if err != nil {
return err
}
@@ -456,299 +351,6 @@ func (r *RoomserverInternalAPI) QueryMissingEvents(
return err
}
-// PerformBackfill implements api.RoomServerQueryAPI
-func (r *RoomserverInternalAPI) PerformBackfill(
- ctx context.Context,
- request *api.PerformBackfillRequest,
- response *api.PerformBackfillResponse,
-) error {
- // if we are requesting the backfill then we need to do a federation hit
- // TODO: we could be more sensible and fetch as many events we already have then request the rest
- // which is what the syncapi does already.
- if request.ServerName == r.ServerName {
- return r.backfillViaFederation(ctx, request, response)
- }
- // someone else is requesting the backfill, try to service their request.
- var err error
- var front []string
-
- // The limit defines the maximum number of events to retrieve, so it also
- // defines the highest number of elements in the map below.
- visited := make(map[string]bool, request.Limit)
-
- // this will include these events which is what we want
- front = request.PrevEventIDs()
-
- info, err := r.DB.RoomInfo(ctx, request.RoomID)
- if err != nil {
- return err
- }
- if info == nil || info.IsStub {
- return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID)
- }
-
- // Scan the event tree for events to send back.
- resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
- if err != nil {
- return err
- }
-
- // Retrieve events from the list that was filled previously.
- var loadedEvents []gomatrixserverlib.Event
- loadedEvents, err = r.loadEvents(ctx, resultNIDs)
- if err != nil {
- return err
- }
-
- for _, event := range loadedEvents {
- response.Events = append(response.Events, event.Headered(info.RoomVersion))
- }
-
- return err
-}
-
-func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error {
- roomVer, err := r.roomVersion(req.RoomID)
- if err != nil {
- return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err)
- }
- requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities)
- // Request 100 items regardless of what the query asks for.
- // We don't want to go much higher than this.
- // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
- // (so we don't need to hit /state_ids which the test has no listener for)
- // Specifically the test "Outbound federation can backfill events"
- events, err := gomatrixserverlib.RequestBackfill(
- ctx, requester,
- r.KeyRing, req.RoomID, roomVer, req.PrevEventIDs(), 100)
- if err != nil {
- return err
- }
- logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
-
- // persist these new events - auth checks have already been done
- roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
- if err != nil {
- return err
- }
-
- for _, ev := range backfilledEventMap {
- // now add state for these events
- stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()]
- if !ok {
- // this should be impossible as all events returned must have pass Step 5 of the PDU checks
- // which requires a list of state IDs.
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
- continue
- }
- var entries []types.StateEntry
- if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
- // attempt to fetch the missing events
- r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs)
- // try again
- entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
- if err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
- return err
- }
- }
-
- var beforeStateSnapshotNID types.StateSnapshotNID
- if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
- return err
- }
- if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
- }
- }
-
- // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
-
- res.Events = events
- return nil
-}
-
-func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
- info, err := r.DB.RoomInfo(ctx, roomID)
- if err != nil {
- return false, err
- }
- if info == nil {
- return false, fmt.Errorf("unknown room %s", roomID)
- }
-
- eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
- if err != nil {
- return false, err
- }
-
- events, err := r.DB.Events(ctx, eventNIDs)
- if err != nil {
- return false, err
- }
- gmslEvents := make([]gomatrixserverlib.Event, len(events))
- for i := range events {
- gmslEvents[i] = events[i].Event
- }
- return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
-}
-
-// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
-// best effort.
-func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
- backfillRequester *backfillRequester, stateIDs []string) {
-
- servers := backfillRequester.servers
-
- // work out which are missing
- nidMap, err := r.DB.EventNIDs(ctx, stateIDs)
- if err != nil {
- util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
- return
- }
- missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event
- for _, id := range stateIDs {
- if _, ok := nidMap[id]; !ok {
- missingMap[id] = nil
- }
- }
- util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers))
-
- // fetch the events from federation. Loop the servers first so if we find one that works we stick with them
- for _, srv := range servers {
- for id, ev := range missingMap {
- if ev != nil {
- continue // already found
- }
- logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
- res, err := r.FedClient.GetEvent(ctx, srv, id)
- if err != nil {
- logger.WithError(err).Warn("failed to get event from server")
- continue
- }
- loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
- result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
- if err != nil {
- logger.WithError(err).Warn("failed to load and verify event")
- continue
- }
- logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result)
- for _, res := range result {
- if res.Error != nil {
- logger.WithError(err).Warn("event failed PDU checks")
- continue
- }
- missingMap[id] = res.Event
- }
- }
- }
-
- var newEvents []gomatrixserverlib.HeaderedEvent
- for _, ev := range missingMap {
- if ev != nil {
- newEvents = append(newEvents, *ev)
- }
- }
- util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
- persistEvents(ctx, r.DB, newEvents)
-}
-
-// TODO: Remove this when we have tests to assert correctness of this function
-// nolint:gocyclo
-func (r *RoomserverInternalAPI) scanEventTree(
- ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int,
- serverName gomatrixserverlib.ServerName,
-) ([]types.EventNID, error) {
- var resultNIDs []types.EventNID
- var err error
- var allowed bool
- var events []types.Event
- var next []string
- var pre string
-
- // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
- // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
- // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
- // duplicate events being sent in response to /backfill requests.
- initialIgnoreList := make(map[string]bool, len(visited))
- for k, v := range visited {
- initialIgnoreList[k] = v
- }
-
- resultNIDs = make([]types.EventNID, 0, limit)
-
- var checkedServerInRoom bool
- var isServerInRoom bool
-
- // Loop through the event IDs to retrieve the requested events and go
- // through the whole tree (up to the provided limit) using the events'
- // "prev_event" key.
-BFSLoop:
- for len(front) > 0 {
- // Prevent unnecessary allocations: reset the slice only when not empty.
- if len(next) > 0 {
- next = make([]string, 0)
- }
- // Retrieve the events to process from the database.
- events, err = r.DB.EventsFromIDs(ctx, front)
- if err != nil {
- return resultNIDs, err
- }
-
- if !checkedServerInRoom && len(events) > 0 {
- // It's nasty that we have to extract the room ID from an event, but many federation requests
- // only talk in event IDs, no room IDs at all (!!!)
- ev := events[0]
- isServerInRoom, err = r.isServerCurrentlyInRoom(ctx, serverName, ev.RoomID())
- if err != nil {
- util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
- }
- checkedServerInRoom = true
- }
-
- for _, ev := range events {
- // Break out of the loop if the provided limit is reached.
- if len(resultNIDs) == limit {
- break BFSLoop
- }
-
- if !initialIgnoreList[ev.EventID()] {
- // Update the list of events to retrieve.
- resultNIDs = append(resultNIDs, ev.EventNID)
- }
- // Loop through the event's parents.
- for _, pre = range ev.PrevEventIDs() {
- // Only add an event to the list of next events to process if it
- // hasn't been seen before.
- if !visited[pre] {
- visited[pre] = true
- allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom)
- if err != nil {
- util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
- "Error checking if allowed to see event",
- )
- return resultNIDs, err
- }
-
- // If the event hasn't been seen before and the HS
- // requesting to retrieve it is allowed to do so, add it to
- // the list of events to retrieve.
- if allowed {
- next = append(next, pre)
- } else {
- util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
- }
- }
- }
- }
- // Repeat the same process with the parent events we just processed.
- front = next
- }
-
- return resultNIDs, err
-}
-
// QueryStateAndAuthChain implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryStateAndAuthChain(
ctx context.Context,
@@ -823,7 +425,7 @@ func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInf
return nil, err
}
- return r.loadStateEvents(ctx, stateEntries)
+ return helpers.LoadStateEvents(ctx, r.DB, stateEntries)
}
type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
@@ -879,50 +481,6 @@ func getAuthChain(
return authEvents, nil
}
-func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
- var roomNID types.RoomNID
- backfilledEventMap := make(map[string]types.Event)
- for j, ev := range events {
- nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs())
- if err != nil { // this shouldn't happen as RequestBackfill already found them
- logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events")
- continue
- }
- authNids := make([]types.EventNID, len(nidMap))
- i := 0
- for _, nid := range nidMap {
- authNids[i] = nid
- i++
- }
- var stateAtEvent types.StateAtEvent
- var redactedEventID string
- var redactionEvent *gomatrixserverlib.Event
- roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
- if err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
- continue
- }
- // If storing this event results in it being redacted, then do so.
- // It's also possible for this event to be a redaction which results in another event being
- // redacted, which we don't care about since we aren't returning it in this backfill.
- if redactedEventID == ev.EventID() {
- eventToRedact := ev.Unwrap()
- redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact)
- if err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
- continue
- }
- ev = redactedEvent.Headered(ev.RoomVersion)
- events[j] = ev
- }
- backfilledEventMap[ev.EventID()] = types.Event{
- EventNID: stateAtEvent.StateEntry.EventNID,
- Event: ev.Unwrap(),
- }
- }
- return roomNID, backfilledEventMap
-}
-
// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities(
ctx context.Context,
diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go
index 21af5f32..a428ad57 100644
--- a/roomserver/roomserver.go
+++ b/roomserver/roomserver.go
@@ -47,14 +47,8 @@ func NewInternalAPI(
logrus.WithError(err).Panicf("failed to connect to room server db")
}
- return &internal.RoomserverInternalAPI{
- DB: roomserverDB,
- Cfg: cfg,
- Producer: base.KafkaProducer,
- OutputRoomEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)),
- Cache: base.Caches,
- ServerName: cfg.Matrix.ServerName,
- FedClient: fedClient,
- KeyRing: keyRing,
- }
+ return internal.NewRoomserverAPI(
+ cfg, roomserverDB, base.KafkaProducer, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)),
+ base.Caches, fedClient, keyRing,
+ )
}