aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-11-19 11:34:59 +0000
committerGitHub <noreply@github.com>2020-11-19 11:34:59 +0000
commit6353b0b7e42d65d92368b93c021b3a744c03214b (patch)
tree6c8268bf6e133092aadf903bc989906d430f1c66
parent1cf9f20d2f740864a48bfb3465f686f4bfe61591 (diff)
MSC2836: Threading - part one (#1589)
* Add mscs/hooks package, begin work for msc2836 * Flesh out hooks and add SQL schema * Begin implementing core msc2836 logic * Add test harness * Linting * Implement visibility checks; stub out APIs for tests * Flesh out testing * Flesh out walkThread a bit * Persist the origin_server_ts as well * Edges table instead of relationships * Add nodes table for event metadata * LEFT JOIN to extract origin_server_ts for children * Add graph walking structs * Implement walking algorithm * Add more graph walking tests * Add auto_join for local rooms * Fix create table syntax on postgres * Add relationship_room_id|servers to the unsigned section of events * Persist the parent room_id/servers in edge metadata Other events cannot assert the true room_id/servers for the parent event, only make claims to them, hence why this is edge metadata. * guts to pass through room_id/servers * Refactor msc2836 to allow handling from federation * Add JoinedVia to PerformJoin responses * Fix tests; review comments
-rw-r--r--cmd/dendrite-monolith-server/main.go7
-rw-r--r--federationsender/api/api.go2
-rw-r--r--federationsender/internal/perform.go1
-rw-r--r--federationsender/internal/query.go11
-rw-r--r--internal/config/config.go6
-rw-r--r--internal/config/config_mscs.go19
-rw-r--r--internal/hooks/hooks.go74
-rw-r--r--internal/mscs/msc2836/msc2836.go530
-rw-r--r--internal/mscs/msc2836/msc2836_test.go574
-rw-r--r--internal/mscs/msc2836/storage.go226
-rw-r--r--internal/mscs/mscs.go42
-rw-r--r--roomserver/api/perform.go3
-rw-r--r--roomserver/internal/input/input.go5
-rw-r--r--roomserver/internal/perform/perform_join.go52
14 files changed, 1517 insertions, 35 deletions
diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go
index e935805f..70b81bbc 100644
--- a/cmd/dendrite-monolith-server/main.go
+++ b/cmd/dendrite-monolith-server/main.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/federationsender"
"github.com/matrix-org/dendrite/internal/config"
+ "github.com/matrix-org/dendrite/internal/mscs"
"github.com/matrix-org/dendrite/internal/setup"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
@@ -148,6 +149,12 @@ func main() {
base.PublicMediaAPIMux,
)
+ if len(base.Cfg.MSCs.MSCs) > 0 {
+ if err := mscs.Enable(base, &monolith); err != nil {
+ logrus.WithError(err).Fatalf("Failed to enable MSCs")
+ }
+ }
+
// Expose the matrix APIs directly rather than putting them under a /api path.
go func() {
base.SetupAndServeHTTP(
diff --git a/federationsender/api/api.go b/federationsender/api/api.go
index b0522516..a4d15f1f 100644
--- a/federationsender/api/api.go
+++ b/federationsender/api/api.go
@@ -48,6 +48,7 @@ type FederationSenderInternalAPI interface {
// Query the server names of the joined hosts in a room.
// Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice
// containing only the server names (without information for membership events).
+ // The response will include this server if they are joined to the room.
QueryJoinedHostServerNamesInRoom(
ctx context.Context,
request *QueryJoinedHostServerNamesInRoomRequest,
@@ -104,6 +105,7 @@ type PerformJoinRequest struct {
}
type PerformJoinResponse struct {
+ JoinedVia gomatrixserverlib.ServerName
LastError *gomatrix.HTTPError
}
diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go
index a7484476..45f33ff7 100644
--- a/federationsender/internal/perform.go
+++ b/federationsender/internal/perform.go
@@ -105,6 +105,7 @@ func (r *FederationSenderInternalAPI) PerformJoin(
}
// We're all good.
+ response.JoinedVia = serverName
return
}
diff --git a/federationsender/internal/query.go b/federationsender/internal/query.go
index 253400a2..8ba228d1 100644
--- a/federationsender/internal/query.go
+++ b/federationsender/internal/query.go
@@ -4,7 +4,6 @@ import (
"context"
"github.com/matrix-org/dendrite/federationsender/api"
- "github.com/matrix-org/gomatrixserverlib"
)
// QueryJoinedHostServerNamesInRoom implements api.FederationSenderInternalAPI
@@ -13,17 +12,11 @@ func (f *FederationSenderInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) {
- joinedHosts, err := f.db.GetJoinedHosts(ctx, request.RoomID)
+ joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID})
if err != nil {
return
}
-
- response.ServerNames = make([]gomatrixserverlib.ServerName, 0, len(joinedHosts))
- for _, host := range joinedHosts {
- response.ServerNames = append(response.ServerNames, host.ServerName)
- }
-
- // TODO: remove duplicates?
+ response.ServerNames = joinedHosts
return
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 9d9e2414..b8b12d0c 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -66,6 +66,8 @@ type Dendrite struct {
SyncAPI SyncAPI `yaml:"sync_api"`
UserAPI UserAPI `yaml:"user_api"`
+ MSCs MSCs `yaml:"mscs"`
+
// The config for tracing the dendrite servers.
Tracing struct {
// Set to true to enable tracer hooks. If false, no tracing is set up.
@@ -306,6 +308,7 @@ func (c *Dendrite) Defaults() {
c.SyncAPI.Defaults()
c.UserAPI.Defaults()
c.AppServiceAPI.Defaults()
+ c.MSCs.Defaults()
c.Wiring()
}
@@ -319,7 +322,7 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) {
&c.EDUServer, &c.FederationAPI, &c.FederationSender,
&c.KeyServer, &c.MediaAPI, &c.RoomServer,
&c.SigningKeyServer, &c.SyncAPI, &c.UserAPI,
- &c.AppServiceAPI,
+ &c.AppServiceAPI, &c.MSCs,
} {
c.Verify(configErrs, isMonolith)
}
@@ -337,6 +340,7 @@ func (c *Dendrite) Wiring() {
c.SyncAPI.Matrix = &c.Global
c.UserAPI.Matrix = &c.Global
c.AppServiceAPI.Matrix = &c.Global
+ c.MSCs.Matrix = &c.Global
c.ClientAPI.Derived = &c.Derived
c.AppServiceAPI.Derived = &c.Derived
diff --git a/internal/config/config_mscs.go b/internal/config/config_mscs.go
new file mode 100644
index 00000000..776d0b64
--- /dev/null
+++ b/internal/config/config_mscs.go
@@ -0,0 +1,19 @@
+package config
+
+type MSCs struct {
+ Matrix *Global `yaml:"-"`
+
+ // The MSCs to enable, currently only `msc2836` is supported.
+ MSCs []string `yaml:"mscs"`
+
+ Database DatabaseOptions `yaml:"database"`
+}
+
+func (c *MSCs) Defaults() {
+ c.Database.Defaults()
+ c.Database.ConnectionString = "file:mscs.db"
+}
+
+func (c *MSCs) Verify(configErrs *ConfigErrors, isMonolith bool) {
+ checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString))
+}
diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go
new file mode 100644
index 00000000..223282a2
--- /dev/null
+++ b/internal/hooks/hooks.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hooks exposes places in Dendrite where custom code can be executed, useful for MSCs.
+// Hooks can only be run in monolith mode.
+package hooks
+
+import "sync"
+
+const (
+ // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent
+ // It is run when a new event is persisted in the roomserver.
+ // Usage:
+ // hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... })
+ KindNewEventPersisted = "new_event_persisted"
+ // KindNewEventReceived is a hook which is called with *gomatrixserverlib.HeaderedEvent
+ // It is run before a new event is processed by the roomserver. This hook can be used
+ // to modify the event before it is persisted by adding data to `unsigned`.
+ // Usage:
+ // hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
+ // ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
+ // _ = ev.SetUnsignedField("key", "val")
+ // })
+ KindNewEventReceived = "new_event_received"
+)
+
+var (
+ hookMap = make(map[string][]func(interface{}))
+ hookMu = sync.Mutex{}
+ enabled = false
+)
+
+// Enable all hooks. This may slow down the server slightly. Required for MSCs to work.
+func Enable() {
+ enabled = true
+}
+
+// Run any hooks
+func Run(kind string, data interface{}) {
+ if !enabled {
+ return
+ }
+ cbs := callbacks(kind)
+ for _, cb := range cbs {
+ cb(data)
+ }
+}
+
+// Attach a hook
+func Attach(kind string, callback func(interface{})) {
+ if !enabled {
+ return
+ }
+ hookMu.Lock()
+ defer hookMu.Unlock()
+ hookMap[kind] = append(hookMap[kind], callback)
+}
+
+func callbacks(kind string) []func(interface{}) {
+ hookMu.Lock()
+ defer hookMu.Unlock()
+ return hookMap[kind]
+}
diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go
new file mode 100644
index 00000000..865bc311
--- /dev/null
+++ b/internal/mscs/msc2836/msc2836.go
@@ -0,0 +1,530 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package msc2836 'Threading' implements https://github.com/matrix-org/matrix-doc/pull/2836
+package msc2836
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ fs "github.com/matrix-org/dendrite/federationsender/api"
+ "github.com/matrix-org/dendrite/internal/hooks"
+ "github.com/matrix-org/dendrite/internal/httputil"
+ "github.com/matrix-org/dendrite/internal/setup"
+ roomserver "github.com/matrix-org/dendrite/roomserver/api"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+const (
+ constRelType = "m.reference"
+ constRoomIDKey = "relationship_room_id"
+ constRoomServers = "relationship_servers"
+)
+
+type EventRelationshipRequest struct {
+ EventID string `json:"event_id"`
+ MaxDepth int `json:"max_depth"`
+ MaxBreadth int `json:"max_breadth"`
+ Limit int `json:"limit"`
+ DepthFirst bool `json:"depth_first"`
+ RecentFirst bool `json:"recent_first"`
+ IncludeParent bool `json:"include_parent"`
+ IncludeChildren bool `json:"include_children"`
+ Direction string `json:"direction"`
+ Batch string `json:"batch"`
+ AutoJoin bool `json:"auto_join"`
+}
+
+func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) {
+ var relation EventRelationshipRequest
+ relation.Defaults()
+ if err := json.NewDecoder(body).Decode(&relation); err != nil {
+ return nil, err
+ }
+ return &relation, nil
+}
+
+func (r *EventRelationshipRequest) Defaults() {
+ r.Limit = 100
+ r.MaxBreadth = 10
+ r.MaxDepth = 3
+ r.DepthFirst = false
+ r.RecentFirst = true
+ r.IncludeParent = false
+ r.IncludeChildren = false
+ r.Direction = "down"
+}
+
+type EventRelationshipResponse struct {
+ Events []gomatrixserverlib.ClientEvent `json:"events"`
+ NextBatch string `json:"next_batch"`
+ Limited bool `json:"limited"`
+}
+
+// Enable this MSC
+// nolint:gocyclo
+func Enable(
+ base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI,
+ userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
+) error {
+ db, err := NewDatabase(&base.Cfg.MSCs.Database)
+ if err != nil {
+ return fmt.Errorf("Cannot enable MSC2836: %w", err)
+ }
+ hooks.Enable()
+ hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
+ he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
+ hookErr := db.StoreRelation(context.Background(), he)
+ if hookErr != nil {
+ util.GetLogger(context.Background()).WithError(hookErr).Error(
+ "failed to StoreRelation",
+ )
+ }
+ })
+ hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
+ he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
+ ctx := context.Background()
+ // we only inject metadata for events our server sends
+ userID := he.Sender()
+ _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return
+ }
+ if domain != base.Cfg.Global.ServerName {
+ return
+ }
+ // if this event has an m.relationship, add on the room_id and servers to unsigned
+ parent, child, relType := parentChildEventIDs(he)
+ if parent == "" || child == "" || relType == "" {
+ return
+ }
+ event, joinedToRoom := getEventIfVisible(ctx, rsAPI, parent, userID)
+ if !joinedToRoom {
+ return
+ }
+ err = he.SetUnsignedField(constRoomIDKey, event.RoomID())
+ if err != nil {
+ util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField")
+ return
+ }
+
+ var servers []gomatrixserverlib.ServerName
+ if fsAPI != nil {
+ var res fs.QueryJoinedHostServerNamesInRoomResponse
+ err = fsAPI.QueryJoinedHostServerNamesInRoom(ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
+ RoomID: event.RoomID(),
+ }, &res)
+ if err != nil {
+ util.GetLogger(context.Background()).WithError(err).Warn("Failed to QueryJoinedHostServerNamesInRoom")
+ return
+ }
+ servers = res.ServerNames
+ } else {
+ servers = []gomatrixserverlib.ServerName{
+ base.Cfg.Global.ServerName,
+ }
+ }
+ err = he.SetUnsignedField(constRoomServers, servers)
+ if err != nil {
+ util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField")
+ return
+ }
+ })
+
+ base.PublicClientAPIMux.Handle("/unstable/event_relationships",
+ httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI)),
+ ).Methods(http.MethodPost, http.MethodOptions)
+
+ base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI(
+ "msc2836_event_relationships", func(req *http.Request) util.JSONResponse {
+ fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
+ req, time.Now(), base.Cfg.Global.ServerName, keyRing,
+ )
+ if fedReq == nil {
+ return errResp
+ }
+ return federatedEventRelationship(req.Context(), fedReq, db, rsAPI)
+ },
+ )).Methods(http.MethodPost, http.MethodOptions)
+ return nil
+}
+
+type reqCtx struct {
+ ctx context.Context
+ rsAPI roomserver.RoomserverInternalAPI
+ db Database
+ req *EventRelationshipRequest
+ userID string
+ isFederatedRequest bool
+}
+
+func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
+ return func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ relation, err := NewEventRelationshipRequest(req.Body)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("failed to decode HTTP request as JSON")
+ return util.JSONResponse{
+ Code: 400,
+ JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)),
+ }
+ }
+ rc := reqCtx{
+ ctx: req.Context(),
+ req: relation,
+ userID: device.UserID,
+ rsAPI: rsAPI,
+ isFederatedRequest: false,
+ db: db,
+ }
+ res, resErr := rc.process()
+ if resErr != nil {
+ return *resErr
+ }
+
+ return util.JSONResponse{
+ Code: 200,
+ JSON: res,
+ }
+ }
+}
+
+func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI) util.JSONResponse {
+ relation, err := NewEventRelationshipRequest(bytes.NewBuffer(fedReq.Content()))
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("failed to decode HTTP request as JSON")
+ return util.JSONResponse{
+ Code: 400,
+ JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)),
+ }
+ }
+ rc := reqCtx{
+ ctx: ctx,
+ req: relation,
+ userID: "",
+ rsAPI: rsAPI,
+ isFederatedRequest: true,
+ db: db,
+ }
+ res, resErr := rc.process()
+ if resErr != nil {
+ return *resErr
+ }
+
+ return util.JSONResponse{
+ Code: 200,
+ JSON: res,
+ }
+}
+
+func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
+ var res EventRelationshipResponse
+ var returnEvents []*gomatrixserverlib.HeaderedEvent
+ // Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
+ // We should have the event being referenced so don't give any claimed room ID / servers
+ event := rc.getEventIfVisible(rc.req.EventID, "", nil)
+ if event == nil {
+ return nil, &util.JSONResponse{
+ Code: 403,
+ JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"),
+ }
+ }
+
+ // Retrieve the event. Add it to response array.
+ returnEvents = append(returnEvents, event)
+
+ if rc.req.IncludeParent {
+ if parentEvent := rc.includeParent(event); parentEvent != nil {
+ returnEvents = append(returnEvents, parentEvent)
+ }
+ }
+
+ if rc.req.IncludeChildren {
+ remaining := rc.req.Limit - len(returnEvents)
+ if remaining > 0 {
+ children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, rc.req.RecentFirst)
+ if resErr != nil {
+ return nil, resErr
+ }
+ returnEvents = append(returnEvents, children...)
+ }
+ }
+
+ remaining := rc.req.Limit - len(returnEvents)
+ var walkLimited bool
+ if remaining > 0 {
+ included := make(map[string]bool, len(returnEvents))
+ for _, ev := range returnEvents {
+ included[ev.EventID()] = true
+ }
+ var events []*gomatrixserverlib.HeaderedEvent
+ events, walkLimited = walkThread(
+ rc.ctx, rc.db, rc, included, remaining,
+ )
+ returnEvents = append(returnEvents, events...)
+ }
+ res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents))
+ for i, ev := range returnEvents {
+ res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll)
+ }
+ res.Limited = remaining == 0 || walkLimited
+ return &res, nil
+}
+
+// If include_parent: true and there is a valid m.relationship field in the event,
+// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array.
+func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) {
+ parentID, _, _ := parentChildEventIDs(event)
+ if parentID == "" {
+ return nil
+ }
+ claimedRoomID, claimedServers := roomIDAndServers(event)
+ return rc.getEventIfVisible(parentID, claimedRoomID, claimedServers)
+}
+
+// If include_children: true, lookup all events which have event_id as an m.relationship
+// Apply history visibility checks to all these events and add the ones which pass into the response array,
+// honouring the recent_first flag and the limit.
+func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
+ children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent")
+ resErr := jsonerror.InternalServerError()
+ return nil, &resErr
+ }
+ var childEvents []*gomatrixserverlib.HeaderedEvent
+ for _, child := range children {
+ // in order for us to even know about the children the server must be joined to those rooms, hence pass no claimed room ID or servers.
+ childEvent := rc.getEventIfVisible(child.EventID, "", nil)
+ if childEvent != nil {
+ childEvents = append(childEvents, childEvent)
+ }
+ }
+ if len(childEvents) > limit {
+ return childEvents[:limit], nil
+ }
+ return childEvents, nil
+}
+
+// Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag,
+// honouring the limit, max_depth and max_breadth values according to the following rules
+// nolint: unparam
+func walkThread(
+ ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
+) ([]*gomatrixserverlib.HeaderedEvent, bool) {
+ if rc.req.Direction != "down" {
+ util.GetLogger(ctx).Error("not implemented: direction=up")
+ return nil, false
+ }
+ var result []*gomatrixserverlib.HeaderedEvent
+ eventWalker := walker{
+ ctx: ctx,
+ req: rc.req,
+ db: db,
+ fn: func(wi *walkInfo) bool {
+ // If already processed event, skip.
+ if included[wi.EventID] {
+ return false
+ }
+
+ // If the response array is >= limit, stop.
+ if len(result) >= limit {
+ return true
+ }
+
+ // Process the event.
+ // TODO: Include edge information: room ID and servers
+ event := rc.getEventIfVisible(wi.EventID, "", nil)
+ if event != nil {
+ result = append(result, event)
+ }
+ included[wi.EventID] = true
+ return false
+ },
+ }
+ limited, err := eventWalker.WalkFrom(rc.req.EventID)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", rc.req.EventID)
+ }
+ return result, limited
+}
+
+func (rc *reqCtx) getEventIfVisible(eventID string, claimedRoomID string, claimedServers []string) *gomatrixserverlib.HeaderedEvent {
+ event, joinedToRoom := getEventIfVisible(rc.ctx, rc.rsAPI, eventID, rc.userID)
+ if event != nil && joinedToRoom {
+ return event
+ }
+ // either we don't have the event or we aren't joined to the room, regardless we should try joining if auto join is enabled
+ if !rc.req.AutoJoin {
+ return nil
+ }
+ // if we're doing this on behalf of a random server don't auto-join rooms regardless of what the request says
+ if rc.isFederatedRequest {
+ return nil
+ }
+ roomID := claimedRoomID
+ var servers []gomatrixserverlib.ServerName
+ if event != nil {
+ roomID = event.RoomID()
+ }
+ for _, s := range claimedServers {
+ servers = append(servers, gomatrixserverlib.ServerName(s))
+ }
+ var joinRes roomserver.PerformJoinResponse
+ rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{
+ UserID: rc.userID,
+ Content: map[string]interface{}{},
+ RoomIDOrAlias: roomID,
+ ServerNames: servers,
+ }, &joinRes)
+ if joinRes.Error != nil {
+ util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", roomID).Error("Failed to auto-join room")
+ return nil
+ }
+ if event != nil {
+ return event
+ }
+ // TODO: hit /event_relationships on the server we joined via
+ util.GetLogger(rc.ctx).Infof("joined room but need to fetch event TODO")
+ return nil
+}
+
+func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) (*gomatrixserverlib.HeaderedEvent, bool) {
+ var queryEventsRes roomserver.QueryEventsByIDResponse
+ err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{
+ EventIDs: []string{eventID},
+ }, &queryEventsRes)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID")
+ return nil, false
+ }
+ if len(queryEventsRes.Events) == 0 {
+ util.GetLogger(ctx).Infof("event does not exist")
+ return nil, false // event does not exist
+ }
+ event := queryEventsRes.Events[0]
+
+ // Allow events if the member is in the room
+ // TODO: This does not honour history_visibility
+ // TODO: This does not honour m.room.create content
+ var queryMembershipRes roomserver.QueryMembershipForUserResponse
+ err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{
+ RoomID: event.RoomID(),
+ UserID: userID,
+ }, &queryMembershipRes)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
+ return nil, false
+ }
+ return event, queryMembershipRes.IsInRoom
+}
+
+type walkInfo struct {
+ eventInfo
+ SiblingNumber int
+ Depth int
+}
+
+type walker struct {
+ ctx context.Context
+ req *EventRelationshipRequest
+ db Database
+ fn func(wi *walkInfo) bool // callback invoked for each event walked, return true to terminate the walk
+}
+
+// WalkFrom the event ID given
+func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
+ children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
+ if err != nil {
+ util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
+ return false, err
+ }
+ var next *walkInfo
+ toWalk := w.addChildren(nil, children, 1)
+ next, toWalk = w.nextChild(toWalk)
+ for next != nil {
+ stop := w.fn(next)
+ if stop {
+ return true, nil
+ }
+ // find the children's children
+ children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst)
+ if err != nil {
+ util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
+ return false, err
+ }
+ toWalk = w.addChildren(toWalk, children, next.Depth+1)
+ next, toWalk = w.nextChild(toWalk)
+ }
+
+ return false, nil
+}
+
+// addChildren adds an event's children to the to walk data structure
+func (w *walker) addChildren(toWalk []walkInfo, children []eventInfo, depthOfChildren int) []walkInfo {
+ // Check what number child this event is (ordered by recent_first) compared to its parent, does it exceed (greater than) max_breadth? If yes, skip.
+ if len(children) > w.req.MaxBreadth {
+ children = children[:w.req.MaxBreadth]
+ }
+ // Check how deep the event is compared to event_id, does it exceed (greater than) max_depth? If yes, skip.
+ if depthOfChildren > w.req.MaxDepth {
+ return toWalk
+ }
+
+ if w.req.DepthFirst {
+ // the slice is a stack so push them in reverse order so we pop them in the correct order
+ // e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3
+ for i := len(children) - 1; i >= 0; i-- {
+ toWalk = append(toWalk, walkInfo{
+ eventInfo: children[i],
+ SiblingNumber: i + 1, // index from 1
+ Depth: depthOfChildren,
+ })
+ }
+ } else {
+ // the slice is a queue so push them in normal order to we dequeue them in the correct order
+ // e.g [1,2,3] => 1, [2, 3] => 2 , [3] => 3, []
+ for i := range children {
+ toWalk = append(toWalk, walkInfo{
+ eventInfo: children[i],
+ SiblingNumber: i + 1, // index from 1
+ Depth: depthOfChildren,
+ })
+ }
+ }
+ return toWalk
+}
+
+func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) {
+ if len(toWalk) == 0 {
+ return nil, nil
+ }
+ var child walkInfo
+ if w.req.DepthFirst {
+ // toWalk is a stack so pop the child off
+ child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1]
+ return &child, toWalk
+ }
+ // toWalk is a queue so shift the child off
+ child, toWalk = toWalk[0], toWalk[1:]
+ return &child, toWalk
+}
diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go
new file mode 100644
index 00000000..cbf8b726
--- /dev/null
+++ b/internal/mscs/msc2836/msc2836_test.go
@@ -0,0 +1,574 @@
+package msc2836_test
+
+import (
+ "bytes"
+ "context"
+ "crypto/ed25519"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/gorilla/mux"
+ "github.com/matrix-org/dendrite/internal/config"
+ "github.com/matrix-org/dendrite/internal/hooks"
+ "github.com/matrix-org/dendrite/internal/httputil"
+ "github.com/matrix-org/dendrite/internal/mscs/msc2836"
+ "github.com/matrix-org/dendrite/internal/setup"
+ roomserver "github.com/matrix-org/dendrite/roomserver/api"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var (
+ client = &http.Client{
+ Timeout: 10 * time.Second,
+ }
+)
+
+// Basic sanity check of MSC2836 logic. Injects a thread that looks like:
+// A
+// |
+// B
+// / \
+// C D
+// /|\
+// E F G
+// |
+// H
+// And makes sure POST /event_relationships works with various parameters
+func TestMSC2836(t *testing.T) {
+ alice := "@alice:localhost"
+ bob := "@bob:localhost"
+ charlie := "@charlie:localhost"
+ roomIDA := "!alice:localhost"
+ roomIDB := "!bob:localhost"
+ roomIDC := "!charlie:localhost"
+ // give access tokens to all three users
+ nopUserAPI := &testUserAPI{
+ accessTokens: make(map[string]userapi.Device),
+ }
+ nopUserAPI.accessTokens["alice"] = userapi.Device{
+ AccessToken: "alice",
+ DisplayName: "Alice",
+ UserID: alice,
+ }
+ nopUserAPI.accessTokens["bob"] = userapi.Device{
+ AccessToken: "bob",
+ DisplayName: "Bob",
+ UserID: bob,
+ }
+ nopUserAPI.accessTokens["charlie"] = userapi.Device{
+ AccessToken: "charlie",
+ DisplayName: "Charles",
+ UserID: charlie,
+ }
+ eventA := mustCreateEvent(t, fledglingEvent{
+ RoomID: roomIDA,
+ Sender: alice,
+ Type: "m.room.message",
+ Content: map[string]interface{}{
+ "body": "[A] Do you know shelties?",
+ },
+ })
+ eventB := mustCreateEvent(t, fledglingEvent{
+ RoomID: roomIDB,
+ Sender: bob,
+ Type: "m.room.message",
+ Content: map[string]interface{}{
+ "body": "[B] I <3 shelties",
+ "m.relationship": map[string]string{
+ "rel_type": "m.reference",
+ "event_id": eventA.EventID(),
+ },
+ },
+ })
+ eventC := mustCreateEvent(t, fledglingEvent{
+ RoomID: roomIDB,
+ Sender: bob,
+ Type: "m.room.message",
+ Content: map[string]interface{}{
+ "body": "[C] like so much",
+ "m.relationship": map[string]string{
+ "rel_type": "m.reference",
+ "event_id": eventB.EventID(),
+ },
+ },
+ })
+ eventD := mustCreateEvent(t, fledglingEvent{
+ RoomID: roomIDA,
+ Sender: alice,
+ Type: "m.room.message",
+ Content: map[string]interface{}{
+ "body": "[D] but what are shelties???",
+ "m.relationship": map[string]string{
+ "rel_type": "m.reference",
+ "event_id": eventB.EventID(),
+ },
+ },
+ })
+ eventE := mustCreateEvent(t, fledglingEvent{
+ RoomID: roomIDB,
+ Sender: bob,
+ Type: "m.room.message",
+ Content: map[string]interface{}{
+ "body": "[E] seriously???",
+ "m.relationship": map[string]string{
+ "rel_type": "m.reference",
+ "event_id": eventD.EventID(),
+ },
+ },
+ })
+ eventF := mustCreateEvent(t, fledglingEvent{
+ RoomID: roomIDC,
+ Sender: charlie,
+ Type: "m.room.message",
+ Content: map[string]interface{}{
+ "body": "[F] omg how do you not know what shelties are",
+ "m.relationship": map[string]string{
+ "rel_type": "m.reference",
+ "event_id": eventD.EventID(),
+ },
+ },
+ })
+ eventG := mustCreateEvent(t, fledglingEvent{
+ RoomID: roomIDA,
+ Sender: alice,
+ Type: "m.room.message",
+ Content: map[string]interface{}{
+ "body": "[G] looked it up, it's a sheltered person?",
+ "m.relationship": map[string]string{
+ "rel_type": "m.reference",
+ "event_id": eventD.EventID(),
+ },
+ },
+ })
+ eventH := mustCreateEvent(t, fledglingEvent{
+ RoomID: roomIDB,
+ Sender: bob,
+ Type: "m.room.message",
+ Content: map[string]interface{}{
+ "body": "[H] it's a dog!!!!!",
+ "m.relationship": map[string]string{
+ "rel_type": "m.reference",
+ "event_id": eventE.EventID(),
+ },
+ },
+ })
+ // make everyone joined to each other's rooms
+ nopRsAPI := &testRoomserverAPI{
+ userToJoinedRooms: map[string][]string{
+ alice: []string{roomIDA, roomIDB, roomIDC},
+ bob: []string{roomIDA, roomIDB, roomIDC},
+ charlie: []string{roomIDA, roomIDB, roomIDC},
+ },
+ events: map[string]*gomatrixserverlib.HeaderedEvent{
+ eventA.EventID(): eventA,
+ eventB.EventID(): eventB,
+ eventC.EventID(): eventC,
+ eventD.EventID(): eventD,
+ eventE.EventID(): eventE,
+ eventF.EventID(): eventF,
+ eventG.EventID(): eventG,
+ eventH.EventID(): eventH,
+ },
+ }
+ router := injectEvents(t, nopUserAPI, nopRsAPI, []*gomatrixserverlib.HeaderedEvent{
+ eventA, eventB, eventC, eventD, eventE, eventF, eventG, eventH,
+ })
+ cancel := runServer(t, router)
+ defer cancel()
+
+ t.Run("returns 403 on invalid event IDs", func(t *testing.T) {
+ _ = postRelationships(t, 403, "alice", newReq(t, map[string]interface{}{
+ "event_id": "$invalid",
+ }))
+ })
+ t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) {
+ nopUserAPI.accessTokens["frank"] = userapi.Device{
+ AccessToken: "frank",
+ DisplayName: "Frank Not In Room",
+ UserID: "@frank:localhost",
+ }
+ _ = postRelationships(t, 403, "frank", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "limit": 1,
+ "include_parent": true,
+ }))
+ })
+ t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) {
+ nopUserAPI.accessTokens["frank2"] = userapi.Device{
+ AccessToken: "frank2",
+ DisplayName: "Frank2 Not In Room",
+ UserID: "@frank2:localhost",
+ }
+ // Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB
+ nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB}
+ body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "limit": 1,
+ "include_parent": true,
+ }))
+ assertContains(t, body, []string{eventB.EventID()})
+ })
+ t.Run("returns the parent if include_parent is true", func(t *testing.T) {
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "include_parent": true,
+ "limit": 2,
+ }))
+ assertContains(t, body, []string{eventB.EventID(), eventA.EventID()})
+ })
+ t.Run("returns the children in the right order if include_children is true", func(t *testing.T) {
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventD.EventID(),
+ "include_children": true,
+ "recent_first": true,
+ "limit": 4,
+ }))
+ assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
+ body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventD.EventID(),
+ "include_children": true,
+ "recent_first": false,
+ "limit": 4,
+ }))
+ assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
+ })
+ t.Run("walks the graph depth first", func(t *testing.T) {
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "recent_first": false,
+ "depth_first": true,
+ "limit": 6,
+ }))
+ // Oldest first so:
+ // A
+ // |
+ // B1
+ // / \
+ // C2 D3
+ // /| \
+ // 4E 6F G
+ // |
+ // 5H
+ assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()})
+ body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "recent_first": true,
+ "depth_first": true,
+ "limit": 6,
+ }))
+ // Recent first so:
+ // A
+ // |
+ // B1
+ // / \
+ // C D2
+ // /| \
+ // E5 F4 G3
+ // |
+ // H6
+ assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()})
+ })
+ t.Run("walks the graph breadth first", func(t *testing.T) {
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "recent_first": false,
+ "depth_first": false,
+ "limit": 6,
+ }))
+ // Oldest first so:
+ // A
+ // |
+ // B1
+ // / \
+ // C2 D3
+ // /| \
+ // E4 F5 G6
+ // |
+ // H
+ assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
+ body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "recent_first": true,
+ "depth_first": false,
+ "limit": 6,
+ }))
+ // Recent first so:
+ // A
+ // |
+ // B1
+ // / \
+ // C3 D2
+ // /| \
+ // E6 F5 G4
+ // |
+ // H
+ assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
+ })
+ t.Run("caps via max_breadth", func(t *testing.T) {
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "recent_first": false,
+ "depth_first": false,
+ "max_breadth": 2,
+ "limit": 10,
+ }))
+ // Event G gets omitted because of max_breadth
+ assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()})
+ })
+ t.Run("caps via max_depth", func(t *testing.T) {
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "recent_first": false,
+ "depth_first": false,
+ "max_depth": 2,
+ "limit": 10,
+ }))
+ // Event H gets omitted because of max_depth
+ assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
+ })
+ t.Run("terminates when reaching the limit", func(t *testing.T) {
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "recent_first": false,
+ "depth_first": false,
+ "limit": 4,
+ }))
+ assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()})
+ })
+ t.Run("returns all events with a high enough limit", func(t *testing.T) {
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "recent_first": false,
+ "depth_first": false,
+ "limit": 400,
+ }))
+ assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
+ })
+}
+
+// TODO: TestMSC2836TerminatesLoops (short and long)
+// TODO: TestMSC2836UnknownEventsSkipped
+// TODO: TestMSC2836SkipEventIfNotInRoom
+
+func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2836.EventRelationshipRequest {
+ t.Helper()
+ b, err := json.Marshal(jsonBody)
+ if err != nil {
+ t.Fatalf("Failed to marshal request: %s", err)
+ }
+ r, err := msc2836.NewEventRelationshipRequest(bytes.NewBuffer(b))
+ if err != nil {
+ t.Fatalf("Failed to NewEventRelationshipRequest: %s", err)
+ }
+ return r
+}
+
+func runServer(t *testing.T, router *mux.Router) func() {
+ t.Helper()
+ externalServ := &http.Server{
+ Addr: string(":8009"),
+ WriteTimeout: 60 * time.Second,
+ Handler: router,
+ }
+ go func() {
+ externalServ.ListenAndServe()
+ }()
+ // wait to listen on the port
+ time.Sleep(500 * time.Millisecond)
+ return func() {
+ externalServ.Shutdown(context.TODO())
+ }
+}
+
+func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse {
+ t.Helper()
+ var r msc2836.EventRelationshipRequest
+ r.Defaults()
+ data, err := json.Marshal(req)
+ if err != nil {
+ t.Fatalf("failed to marshal request: %s", err)
+ }
+ httpReq, err := http.NewRequest(
+ "POST", "http://localhost:8009/_matrix/client/unstable/event_relationships",
+ bytes.NewBuffer(data),
+ )
+ httpReq.Header.Set("Authorization", "Bearer "+accessToken)
+ if err != nil {
+ t.Fatalf("failed to prepare request: %s", err)
+ }
+ res, err := client.Do(httpReq)
+ if err != nil {
+ t.Fatalf("failed to do request: %s", err)
+ }
+ if res.StatusCode != expectCode {
+ body, _ := ioutil.ReadAll(res.Body)
+ t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
+ }
+ if res.StatusCode == 200 {
+ var result msc2836.EventRelationshipResponse
+ if err := json.NewDecoder(res.Body).Decode(&result); err != nil {
+ t.Fatalf("response 200 OK but failed to deserialise JSON : %s", err)
+ }
+ return &result
+ }
+ return nil
+}
+
+func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wantEventIDs []string) {
+ t.Helper()
+ gotEventIDs := make([]string, len(result.Events))
+ for i, ev := range result.Events {
+ gotEventIDs[i] = ev.EventID
+ }
+ if len(gotEventIDs) != len(wantEventIDs) {
+ t.Fatalf("length mismatch: got %v want %v", gotEventIDs, wantEventIDs)
+ }
+ for i := range gotEventIDs {
+ if gotEventIDs[i] != wantEventIDs[i] {
+ t.Errorf("wrong item in position %d - got %s want %s", i, gotEventIDs[i], wantEventIDs[i])
+ }
+ }
+}
+
+type testUserAPI struct {
+ accessTokens map[string]userapi.Device
+}
+
+func (u *testUserAPI) InputAccountData(ctx context.Context, req *userapi.InputAccountDataRequest, res *userapi.InputAccountDataResponse) error {
+ return nil
+}
+func (u *testUserAPI) PerformAccountCreation(ctx context.Context, req *userapi.PerformAccountCreationRequest, res *userapi.PerformAccountCreationResponse) error {
+ return nil
+}
+func (u *testUserAPI) PerformPasswordUpdate(ctx context.Context, req *userapi.PerformPasswordUpdateRequest, res *userapi.PerformPasswordUpdateResponse) error {
+ return nil
+}
+func (u *testUserAPI) PerformDeviceCreation(ctx context.Context, req *userapi.PerformDeviceCreationRequest, res *userapi.PerformDeviceCreationResponse) error {
+ return nil
+}
+func (u *testUserAPI) PerformDeviceDeletion(ctx context.Context, req *userapi.PerformDeviceDeletionRequest, res *userapi.PerformDeviceDeletionResponse) error {
+ return nil
+}
+func (u *testUserAPI) PerformDeviceUpdate(ctx context.Context, req *userapi.PerformDeviceUpdateRequest, res *userapi.PerformDeviceUpdateResponse) error {
+ return nil
+}
+func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
+ return nil
+}
+func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
+ return nil
+}
+func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error {
+ dev, ok := u.accessTokens[req.AccessToken]
+ if !ok {
+ res.Err = fmt.Errorf("unknown token")
+ return nil
+ }
+ res.Device = &dev
+ return nil
+}
+func (u *testUserAPI) QueryDevices(ctx context.Context, req *userapi.QueryDevicesRequest, res *userapi.QueryDevicesResponse) error {
+ return nil
+}
+func (u *testUserAPI) QueryAccountData(ctx context.Context, req *userapi.QueryAccountDataRequest, res *userapi.QueryAccountDataResponse) error {
+ return nil
+}
+func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDeviceInfosRequest, res *userapi.QueryDeviceInfosResponse) error {
+ return nil
+}
+func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
+ return nil
+}
+
+type testRoomserverAPI struct {
+ // use a trace API as it implements method stubs so we don't need to have them here.
+ // We'll override the functions we care about.
+ roomserver.RoomserverInternalAPITrace
+ userToJoinedRooms map[string][]string
+ events map[string]*gomatrixserverlib.HeaderedEvent
+}
+
+func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
+ for _, eventID := range req.EventIDs {
+ ev := r.events[eventID]
+ if ev != nil {
+ res.Events = append(res.Events, ev)
+ }
+ }
+ return nil
+}
+
+func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error {
+ rooms := r.userToJoinedRooms[req.UserID]
+ for _, roomID := range rooms {
+ if roomID == req.RoomID {
+ res.IsInRoom = true
+ res.HasBeenInRoom = true
+ res.Membership = "join"
+ break
+ }
+ }
+ return nil
+}
+
+func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router {
+ t.Helper()
+ cfg := &config.Dendrite{}
+ cfg.Defaults()
+ cfg.Global.ServerName = "localhost"
+ cfg.MSCs.Database.ConnectionString = "file:msc2836_test.db"
+ cfg.MSCs.MSCs = []string{"msc2836"}
+ base := &setup.BaseDendrite{
+ Cfg: cfg,
+ PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(),
+ PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(),
+ }
+
+ err := msc2836.Enable(base, rsAPI, nil, userAPI, nil)
+ if err != nil {
+ t.Fatalf("failed to enable MSC2836: %s", err)
+ }
+ for _, ev := range events {
+ hooks.Run(hooks.KindNewEventPersisted, ev)
+ }
+ return base.PublicClientAPIMux
+}
+
+type fledglingEvent struct {
+ Type string
+ StateKey *string
+ Content interface{}
+ Sender string
+ RoomID string
+}
+
+func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) {
+ t.Helper()
+ roomVer := gomatrixserverlib.RoomVersionV6
+ seed := make([]byte, ed25519.SeedSize) // zero seed
+ key := ed25519.NewKeyFromSeed(seed)
+ eb := gomatrixserverlib.EventBuilder{
+ Sender: ev.Sender,
+ Depth: 999,
+ Type: ev.Type,
+ StateKey: ev.StateKey,
+ RoomID: ev.RoomID,
+ }
+ err := eb.SetContent(ev.Content)
+ if err != nil {
+ t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
+ }
+ // make sure the origin_server_ts changes so we can test recency
+ time.Sleep(1 * time.Millisecond)
+ signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer)
+ if err != nil {
+ t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
+ }
+ h := signedEvent.Headered(roomVer)
+ return h
+}
diff --git a/internal/mscs/msc2836/storage.go b/internal/mscs/msc2836/storage.go
new file mode 100644
index 00000000..f524165f
--- /dev/null
+++ b/internal/mscs/msc2836/storage.go
@@ -0,0 +1,226 @@
+package msc2836
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal/config"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type eventInfo struct {
+ EventID string
+ OriginServerTS gomatrixserverlib.Timestamp
+ RoomID string
+ Servers []string
+}
+
+type Database interface {
+ // StoreRelation stores the parent->child and child->parent relationship for later querying.
+ // Also stores the event metadata e.g timestamp
+ StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error
+ // ChildrenForParent returns the events who have the given `eventID` as an m.relationship with the
+ // provided `relType`. The returned slice is sorted by origin_server_ts according to whether
+ // `recentFirst` is true or false.
+ ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
+}
+
+type DB struct {
+ db *sql.DB
+ writer sqlutil.Writer
+ insertEdgeStmt *sql.Stmt
+ insertNodeStmt *sql.Stmt
+ selectChildrenForParentOldestFirstStmt *sql.Stmt
+ selectChildrenForParentRecentFirstStmt *sql.Stmt
+}
+
+// NewDatabase loads the database for msc2836
+func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
+ if dbOpts.ConnectionString.IsPostgres() {
+ return newPostgresDatabase(dbOpts)
+ }
+ return newSQLiteDatabase(dbOpts)
+}
+
+func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
+ d := DB{
+ writer: sqlutil.NewDummyWriter(),
+ }
+ var err error
+ if d.db, err = sqlutil.Open(dbOpts); err != nil {
+ return nil, err
+ }
+ _, err = d.db.Exec(`
+ CREATE TABLE IF NOT EXISTS msc2836_edges (
+ parent_event_id TEXT NOT NULL,
+ child_event_id TEXT NOT NULL,
+ rel_type TEXT NOT NULL,
+ parent_room_id TEXT NOT NULL,
+ parent_servers TEXT NOT NULL,
+ CONSTRAINT msc2836_edges_uniq UNIQUE (parent_event_id, child_event_id, rel_type)
+ );
+
+ CREATE TABLE IF NOT EXISTS msc2836_nodes (
+ event_id TEXT PRIMARY KEY NOT NULL,
+ origin_server_ts BIGINT NOT NULL,
+ room_id TEXT NOT NULL
+ );
+ `)
+ if err != nil {
+ return nil, err
+ }
+ if d.insertEdgeStmt, err = d.db.Prepare(`
+ INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING
+ `); err != nil {
+ return nil, err
+ }
+ if d.insertNodeStmt, err = d.db.Prepare(`
+ INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING
+ `); err != nil {
+ return nil, err
+ }
+ selectChildrenQuery := `
+ SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
+ LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
+ WHERE parent_event_id = $1 AND rel_type = $2
+ ORDER BY origin_server_ts
+ `
+ if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil {
+ return nil, err
+ }
+ if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
+ return nil, err
+ }
+ return &d, err
+}
+
+func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
+ d := DB{
+ writer: sqlutil.NewExclusiveWriter(),
+ }
+ var err error
+ if d.db, err = sqlutil.Open(dbOpts); err != nil {
+ return nil, err
+ }
+ _, err = d.db.Exec(`
+ CREATE TABLE IF NOT EXISTS msc2836_edges (
+ parent_event_id TEXT NOT NULL,
+ child_event_id TEXT NOT NULL,
+ rel_type TEXT NOT NULL,
+ parent_room_id TEXT NOT NULL,
+ parent_servers TEXT NOT NULL,
+ UNIQUE (parent_event_id, child_event_id, rel_type)
+ );
+
+ CREATE TABLE IF NOT EXISTS msc2836_nodes (
+ event_id TEXT PRIMARY KEY NOT NULL,
+ origin_server_ts BIGINT NOT NULL,
+ room_id TEXT NOT NULL
+ );
+ `)
+ if err != nil {
+ return nil, err
+ }
+ if d.insertEdgeStmt, err = d.db.Prepare(`
+ INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING
+ `); err != nil {
+ return nil, err
+ }
+ if d.insertNodeStmt, err = d.db.Prepare(`
+ INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING
+ `); err != nil {
+ return nil, err
+ }
+ selectChildrenQuery := `
+ SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
+ LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
+ WHERE parent_event_id = $1 AND rel_type = $2
+ ORDER BY origin_server_ts
+ `
+ if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil {
+ return nil, err
+ }
+ if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
+ return nil, err
+ }
+ return &d, nil
+}
+
+func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error {
+ parent, child, relType := parentChildEventIDs(ev)
+ if parent == "" || child == "" {
+ return nil
+ }
+ relationRoomID, relationServers := roomIDAndServers(ev)
+ relationServersJSON, err := json.Marshal(relationServers)
+ if err != nil {
+ return err
+ }
+ return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
+ _, err := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parent, child, relType, relationRoomID, string(relationServersJSON))
+ if err != nil {
+ return err
+ }
+ _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID())
+ return err
+ })
+}
+
+func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) {
+ var rows *sql.Rows
+ var err error
+ if recentFirst {
+ rows, err = p.selectChildrenForParentRecentFirstStmt.QueryContext(ctx, eventID, relType)
+ } else {
+ rows, err = p.selectChildrenForParentOldestFirstStmt.QueryContext(ctx, eventID, relType)
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ var children []eventInfo
+ for rows.Next() {
+ var evInfo eventInfo
+ if err := rows.Scan(&evInfo.EventID, &evInfo.OriginServerTS, &evInfo.RoomID); err != nil {
+ return nil, err
+ }
+ children = append(children, evInfo)
+ }
+ return children, nil
+}
+
+func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) {
+ if ev == nil {
+ return
+ }
+ body := struct {
+ Relationship struct {
+ RelType string `json:"rel_type"`
+ EventID string `json:"event_id"`
+ } `json:"m.relationship"`
+ }{}
+ if err := json.Unmarshal(ev.Content(), &body); err != nil {
+ return
+ }
+ if body.Relationship.EventID == "" || body.Relationship.RelType == "" {
+ return
+ }
+ return body.Relationship.EventID, ev.EventID(), body.Relationship.RelType
+}
+
+func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, servers []string) {
+ servers = []string{}
+ if ev == nil {
+ return
+ }
+ body := struct {
+ RoomID string `json:"relationship_room_id"`
+ Servers []string `json:"relationship_servers"`
+ }{}
+ if err := json.Unmarshal(ev.Unsigned(), &body); err != nil {
+ return
+ }
+ return body.RoomID, body.Servers
+}
diff --git a/internal/mscs/mscs.go b/internal/mscs/mscs.go
new file mode 100644
index 00000000..0a896ab0
--- /dev/null
+++ b/internal/mscs/mscs.go
@@ -0,0 +1,42 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package mscs implements Matrix Spec Changes from https://github.com/matrix-org/matrix-doc
+package mscs
+
+import (
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal/mscs/msc2836"
+ "github.com/matrix-org/dendrite/internal/setup"
+)
+
+// Enable MSCs - returns an error on unknown MSCs
+func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error {
+ for _, msc := range base.Cfg.MSCs.MSCs {
+ if err := EnableMSC(base, monolith, msc); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func EnableMSC(base *setup.BaseDendrite, monolith *setup.Monolith, msc string) error {
+ switch msc {
+ case "msc2836":
+ return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationSenderAPI, monolith.UserAPI, monolith.KeyRing)
+ default:
+ return fmt.Errorf("EnableMSC: unknown msc '%s'", msc)
+ }
+}
diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go
index 29dbd25c..ec561f11 100644
--- a/roomserver/api/perform.go
+++ b/roomserver/api/perform.go
@@ -83,7 +83,8 @@ type PerformJoinRequest struct {
type PerformJoinResponse struct {
// The room ID, populated on success.
- RoomID string `json:"room_id"`
+ RoomID string `json:"room_id"`
+ JoinedVia gomatrixserverlib.ServerName
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go
index d7257539..79dc2fe1 100644
--- a/roomserver/internal/input/input.go
+++ b/roomserver/internal/input/input.go
@@ -22,6 +22,7 @@ import (
"time"
"github.com/Shopify/sarama"
+ "github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/roomserver/acls"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage"
@@ -61,7 +62,11 @@ func (w *inputWorker) start() {
for {
select {
case task := <-w.input:
+ hooks.Run(hooks.KindNewEventReceived, &task.event.Event)
_, task.err = w.r.processRoomEvent(task.ctx, task.event)
+ if task.err == nil {
+ hooks.Run(hooks.KindNewEventPersisted, &task.event.Event)
+ }
task.wg.Done()
case <-time.After(time.Second * 5):
return
diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go
index 56ae6d0b..f3745a7f 100644
--- a/roomserver/internal/perform/perform_join.go
+++ b/roomserver/internal/perform/perform_join.go
@@ -47,7 +47,7 @@ func (r *Joiner) PerformJoin(
req *api.PerformJoinRequest,
res *api.PerformJoinResponse,
) {
- roomID, err := r.performJoin(ctx, req)
+ roomID, joinedVia, err := r.performJoin(ctx, req)
if err != nil {
perr, ok := err.(*api.PerformError)
if ok {
@@ -59,21 +59,22 @@ func (r *Joiner) PerformJoin(
}
}
res.RoomID = roomID
+ res.JoinedVia = joinedVia
}
func (r *Joiner) performJoin(
ctx context.Context,
req *api.PerformJoinRequest,
-) (string, error) {
+) (string, gomatrixserverlib.ServerName, error) {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
- return "", &api.PerformError{
+ return "", "", &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
}
if domain != r.Cfg.Matrix.ServerName {
- return "", &api.PerformError{
+ return "", "", &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
}
@@ -84,7 +85,7 @@ func (r *Joiner) performJoin(
if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performJoinRoomByAlias(ctx, req)
}
- return "", &api.PerformError{
+ return "", "", &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
}
@@ -93,11 +94,11 @@ func (r *Joiner) performJoin(
func (r *Joiner) performJoinRoomByAlias(
ctx context.Context,
req *api.PerformJoinRequest,
-) (string, error) {
+) (string, gomatrixserverlib.ServerName, error) {
// Get the domain part of the room alias.
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
if err != nil {
- return "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias)
+ return "", "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias)
}
req.ServerNames = append(req.ServerNames, domain)
@@ -115,7 +116,7 @@ func (r *Joiner) performJoinRoomByAlias(
err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
if err != nil {
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
- return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
+ return "", "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
}
roomID = dirRes.RoomID
req.ServerNames = append(req.ServerNames, dirRes.ServerNames...)
@@ -123,13 +124,13 @@ func (r *Joiner) performJoinRoomByAlias(
// Otherwise, look up if we know this room alias locally.
roomID, err = r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias)
if err != nil {
- return "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err)
+ return "", "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err)
}
}
// If the room ID is empty then we failed to look up the alias.
if roomID == "" {
- return "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias)
+ return "", "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias)
}
// If we do, then pluck out the room ID and continue the join.
@@ -142,11 +143,11 @@ func (r *Joiner) performJoinRoomByAlias(
func (r *Joiner) performJoinRoomByID(
ctx context.Context,
req *api.PerformJoinRequest,
-) (string, error) {
+) (string, gomatrixserverlib.ServerName, error) {
// Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias)
if err != nil {
- return "", &api.PerformError{
+ return "", "", &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err),
}
@@ -169,7 +170,7 @@ func (r *Joiner) performJoinRoomByID(
Redacts: "",
}
if err = eb.SetUnsigned(struct{}{}); err != nil {
- return "", fmt.Errorf("eb.SetUnsigned: %w", err)
+ return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
}
// It is possible for the request to include some "content" for the
@@ -180,7 +181,7 @@ func (r *Joiner) performJoinRoomByID(
}
req.Content["membership"] = gomatrixserverlib.Join
if err = eb.SetContent(req.Content); err != nil {
- return "", fmt.Errorf("eb.SetContent: %w", err)
+ return "", "", fmt.Errorf("eb.SetContent: %w", err)
}
// Force a federated join if we aren't in the room and we've been
@@ -194,7 +195,7 @@ func (r *Joiner) performJoinRoomByID(
if err == nil && isInvitePending {
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
if ierr != nil {
- return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
+ return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
// If we were invited by someone from another server then we can
@@ -206,8 +207,10 @@ func (r *Joiner) performJoinRoomByID(
}
// If we should do a forced federated join then do that.
+ var joinedVia gomatrixserverlib.ServerName
if forceFederatedJoin {
- return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req)
+ joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
+ return req.RoomIDOrAlias, joinedVia, err
}
// Try to construct an actual join event from the template.
@@ -249,7 +252,7 @@ func (r *Joiner) performJoinRoomByID(
inputRes := api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
if err = inputRes.Err(); err != nil {
- return "", &api.PerformError{
+ return "", "", &api.PerformError{
Code: api.PerformErrorNotAllowed,
Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err),
}
@@ -265,7 +268,7 @@ func (r *Joiner) performJoinRoomByID(
// Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers.
if len(req.ServerNames) == 0 {
- return "", &api.PerformError{
+ return "", "", &api.PerformError{
Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room ID %q does not exist", req.RoomIDOrAlias),
}
@@ -273,24 +276,25 @@ func (r *Joiner) performJoinRoomByID(
}
// Perform a federated room join.
- return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req)
+ joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
+ return req.RoomIDOrAlias, joinedVia, err
default:
// Something else went wrong.
- return "", fmt.Errorf("Error joining local room: %q", err)
+ return "", "", fmt.Errorf("Error joining local room: %q", err)
}
// By this point, if req.RoomIDOrAlias contained an alias, then
// it will have been overwritten with a room ID by performJoinRoomByAlias.
// We should now include this in the response so that the CS API can
// return the right room ID.
- return req.RoomIDOrAlias, nil
+ return req.RoomIDOrAlias, r.Cfg.Matrix.ServerName, nil
}
func (r *Joiner) performFederatedJoinRoomByID(
ctx context.Context,
req *api.PerformJoinRequest,
-) error {
+) (gomatrixserverlib.ServerName, error) {
// Try joining by all of the supplied server names.
fedReq := fsAPI.PerformJoinRequest{
RoomID: req.RoomIDOrAlias, // the room ID to try and join
@@ -301,13 +305,13 @@ func (r *Joiner) performFederatedJoinRoomByID(
fedRes := fsAPI.PerformJoinResponse{}
r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)
if fedRes.LastError != nil {
- return &api.PerformError{
+ return "", &api.PerformError{
Code: api.PerformErrRemote,
Msg: fedRes.LastError.Message,
RemoteCode: fedRes.LastError.Code,
}
}
- return nil
+ return fedRes.JoinedVia, nil
}
func buildEvent(