aboutsummaryrefslogtreecommitdiff
path: root/setup
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-12-04 14:11:01 +0000
committerGitHub <noreply@github.com>2020-12-04 14:11:01 +0000
commitb507312d4cf9d35b5d4eaaa01a7f74d095b825f8 (patch)
tree2812bd7453da07a3c2850fb0b27a740c950af212 /setup
parentc052edafdd765d821f9732add4f5d33962ba5ba4 (diff)
MSC2836 threading: part 2 (#1596)
* Update GMSL * Add MSC2836EventRelationships to fedsender * Call MSC2836EventRelationships in reqCtx * auth remote servers * Extract room ID and servers from previous events; refactor a bit * initial cut of federated threading * Use the right client/fed struct in the response * Add QueryAuthChain for use with MSC2836 * Add auth chain to federated response * Fix pointers * under CI: more logging and enable mscs, nil fix * Handle direction: up * Actually send message events to the roomserver.. * Add children and children_hash to unsigned, with tests * Add logic for exploring threads and tracking children; missing storage functions * Implement storage functions for children * Add fetchUnknownEvent * Do federated hits for include_children if we have unexplored children * Use /ev_rel rather than /event as the former includes child metadata * Remove cross-room threading impl * Enable MSC2836 in the p2p demo * Namespace mscs db * Enable msc2836 for ygg Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'setup')
-rw-r--r--setup/mscs/msc2836/msc2836.go598
-rw-r--r--setup/mscs/msc2836/msc2836_test.go123
-rw-r--r--setup/mscs/msc2836/storage.go159
-rw-r--r--setup/mscs/mscs.go3
4 files changed, 706 insertions, 177 deletions
diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go
index 33a65c8f..95473f97 100644
--- a/setup/mscs/msc2836/msc2836.go
+++ b/setup/mscs/msc2836/msc2836.go
@@ -18,10 +18,13 @@ package msc2836
import (
"bytes"
"context"
+ "crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
+ "sort"
+ "strings"
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
@@ -36,13 +39,12 @@ import (
)
const (
- constRelType = "m.reference"
- constRoomIDKey = "relationship_room_id"
- constRoomServers = "relationship_servers"
+ constRelType = "m.reference"
)
type EventRelationshipRequest struct {
EventID string `json:"event_id"`
+ RoomID string `json:"room_id"`
MaxDepth int `json:"max_depth"`
MaxBreadth int `json:"max_breadth"`
Limit int `json:"limit"`
@@ -52,7 +54,6 @@ type EventRelationshipRequest struct {
IncludeChildren bool `json:"include_children"`
Direction string `json:"direction"`
Batch string `json:"batch"`
- AutoJoin bool `json:"auto_join"`
}
func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) {
@@ -81,8 +82,16 @@ type EventRelationshipResponse struct {
Limited bool `json:"limited"`
}
+func toClientResponse(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) *EventRelationshipResponse {
+ out := &EventRelationshipResponse{
+ Events: gomatrixserverlib.ToClientEvents(res.Events, gomatrixserverlib.FormatAll),
+ Limited: res.Limited,
+ NextBatch: res.NextBatch,
+ }
+ return out
+}
+
// Enable this MSC
-// nolint:gocyclo
func Enable(
base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI,
userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier,
@@ -96,63 +105,22 @@ func Enable(
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
hookErr := db.StoreRelation(context.Background(), he)
if hookErr != nil {
- util.GetLogger(context.Background()).WithError(hookErr).Error(
+ util.GetLogger(context.Background()).WithError(hookErr).WithField("event_id", he.EventID()).Error(
"failed to StoreRelation",
)
}
- })
- hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
- he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
- ctx := context.Background()
- // we only inject metadata for events our server sends
- userID := he.Sender()
- _, domain, err := gomatrixserverlib.SplitID('@', userID)
- if err != nil {
- return
- }
- if domain != base.Cfg.Global.ServerName {
- return
- }
- // if this event has an m.relationship, add on the room_id and servers to unsigned
- parent, child, relType := parentChildEventIDs(he)
- if parent == "" || child == "" || relType == "" {
- return
- }
- event, joinedToRoom := getEventIfVisible(ctx, rsAPI, parent, userID)
- if !joinedToRoom {
- return
- }
- err = he.SetUnsignedField(constRoomIDKey, event.RoomID())
- if err != nil {
- util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField")
- return
- }
-
- var servers []gomatrixserverlib.ServerName
- if fsAPI != nil {
- var res fs.QueryJoinedHostServerNamesInRoomResponse
- err = fsAPI.QueryJoinedHostServerNamesInRoom(ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
- RoomID: event.RoomID(),
- }, &res)
- if err != nil {
- util.GetLogger(context.Background()).WithError(err).Warn("Failed to QueryJoinedHostServerNamesInRoom")
- return
- }
- servers = res.ServerNames
- } else {
- servers = []gomatrixserverlib.ServerName{
- base.Cfg.Global.ServerName,
- }
- }
- err = he.SetUnsignedField(constRoomServers, servers)
- if err != nil {
- util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField")
- return
+ // we need to update child metadata here as well as after doing remote /event_relationships requests
+ // so we catch child metadata originating from /send transactions
+ hookErr = db.UpdateChildMetadata(context.Background(), he)
+ if hookErr != nil {
+ util.GetLogger(context.Background()).WithError(err).WithField("event_id", he.EventID()).Warn(
+ "failed to update child metadata for event",
+ )
}
})
base.PublicClientAPIMux.Handle("/unstable/event_relationships",
- httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI)),
+ httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI, fsAPI)),
).Methods(http.MethodPost, http.MethodOptions)
base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI(
@@ -163,22 +131,27 @@ func Enable(
if fedReq == nil {
return errResp
}
- return federatedEventRelationship(req.Context(), fedReq, db, rsAPI)
+ return federatedEventRelationship(req.Context(), fedReq, db, rsAPI, fsAPI)
},
)).Methods(http.MethodPost, http.MethodOptions)
return nil
}
type reqCtx struct {
- ctx context.Context
- rsAPI roomserver.RoomserverInternalAPI
- db Database
- req *EventRelationshipRequest
- userID string
+ ctx context.Context
+ rsAPI roomserver.RoomserverInternalAPI
+ db Database
+ req *EventRelationshipRequest
+ userID string
+ roomVersion gomatrixserverlib.RoomVersion
+
+ // federated request args
isFederatedRequest bool
+ serverName gomatrixserverlib.ServerName
+ fsAPI fs.FederationSenderInternalAPI
}
-func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
+func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse {
return func(req *http.Request, device *userapi.Device) util.JSONResponse {
relation, err := NewEventRelationshipRequest(req.Body)
if err != nil {
@@ -193,6 +166,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
req: relation,
userID: device.UserID,
rsAPI: rsAPI,
+ fsAPI: fsAPI,
isFederatedRequest: false,
db: db,
}
@@ -203,12 +177,14 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
return util.JSONResponse{
Code: 200,
- JSON: res,
+ JSON: toClientResponse(res),
}
}
}
-func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI) util.JSONResponse {
+func federatedEventRelationship(
+ ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI,
+) util.JSONResponse {
relation, err := NewEventRelationshipRequest(bytes.NewBuffer(fedReq.Content()))
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to decode HTTP request as JSON")
@@ -218,17 +194,43 @@ func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.F
}
}
rc := reqCtx{
- ctx: ctx,
- req: relation,
- userID: "",
- rsAPI: rsAPI,
+ ctx: ctx,
+ req: relation,
+ rsAPI: rsAPI,
+ db: db,
+ // federation args
isFederatedRequest: true,
- db: db,
+ fsAPI: fsAPI,
+ serverName: fedReq.Origin(),
}
res, resErr := rc.process()
if resErr != nil {
return *resErr
}
+ // add auth chain information
+ requiredAuthEventsSet := make(map[string]bool)
+ var requiredAuthEvents []string
+ for _, ev := range res.Events {
+ for _, a := range ev.AuthEventIDs() {
+ if requiredAuthEventsSet[a] {
+ continue
+ }
+ requiredAuthEvents = append(requiredAuthEvents, a)
+ requiredAuthEventsSet[a] = true
+ }
+ }
+ var queryRes roomserver.QueryAuthChainResponse
+ err = rsAPI.QueryAuthChain(ctx, &roomserver.QueryAuthChainRequest{
+ EventIDs: requiredAuthEvents,
+ }, &queryRes)
+ if err != nil {
+ // they may already have the auth events so don't fail this request
+ util.GetLogger(ctx).WithError(err).Error("Failed to QueryAuthChain")
+ }
+ res.AuthChain = make([]*gomatrixserverlib.Event, len(queryRes.AuthChain))
+ for i := range queryRes.AuthChain {
+ res.AuthChain[i] = queryRes.AuthChain[i].Unwrap()
+ }
return util.JSONResponse{
Code: 200,
@@ -236,18 +238,25 @@ func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.F
}
}
-func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
- var res EventRelationshipResponse
+// nolint:gocyclo
+func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) {
+ var res gomatrixserverlib.MSC2836EventRelationshipsResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
- // We should have the event being referenced so don't give any claimed room ID / servers
- event := rc.getEventIfVisible(rc.req.EventID, "", nil)
+ event := rc.getLocalEvent(rc.req.EventID)
if event == nil {
+ event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID)
+ }
+ if rc.req.RoomID == "" && event != nil {
+ rc.req.RoomID = event.RoomID()
+ }
+ if event == nil || !rc.authorisedToSeeEvent(event) {
return nil, &util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"),
}
}
+ rc.roomVersion = event.Version()
// Retrieve the event. Add it to response array.
returnEvents = append(returnEvents, event)
@@ -282,29 +291,122 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
)
returnEvents = append(returnEvents, events...)
}
- res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents))
+ res.Events = make([]*gomatrixserverlib.Event, len(returnEvents))
for i, ev := range returnEvents {
- res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll)
+ // for each event, extract the children_count | hash and add it as unsigned data.
+ rc.addChildMetadata(ev)
+ res.Events[i] = ev.Unwrap()
}
res.Limited = remaining == 0 || walkLimited
return &res, nil
}
+// fetchUnknownEvent retrieves an unknown event from the room specified. This server must
+// be joined to the room in question. This has the side effect of injecting surround threaded
+// events into the roomserver.
+func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.HeaderedEvent {
+ if rc.isFederatedRequest || roomID == "" {
+ // we don't do fed hits for fed requests, and we can't ask servers without a room ID!
+ return nil
+ }
+ logger := util.GetLogger(rc.ctx).WithField("room_id", roomID)
+ // if they supplied a room_id, check the room exists.
+ var queryVerRes roomserver.QueryRoomVersionForRoomResponse
+ err := rc.rsAPI.QueryRoomVersionForRoom(rc.ctx, &roomserver.QueryRoomVersionForRoomRequest{
+ RoomID: roomID,
+ }, &queryVerRes)
+ if err != nil {
+ logger.WithError(err).Warn("failed to query room version for room, does this room exist?")
+ return nil
+ }
+
+ // check the user is joined to that room
+ var queryMemRes roomserver.QueryMembershipForUserResponse
+ err = rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
+ RoomID: roomID,
+ UserID: rc.userID,
+ }, &queryMemRes)
+ if err != nil {
+ logger.WithError(err).Warn("failed to query membership for user in room")
+ return nil
+ }
+ if !queryMemRes.IsInRoom {
+ return nil
+ }
+
+ // ask one of the servers in the room for the event
+ var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
+ err = rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
+ RoomID: roomID,
+ }, &queryRes)
+ if err != nil {
+ logger.WithError(err).Error("failed to QueryJoinedHostServerNamesInRoom")
+ return nil
+ }
+ // query up to 5 servers
+ serversToQuery := queryRes.ServerNames
+ if len(serversToQuery) > 5 {
+ serversToQuery = serversToQuery[:5]
+ }
+
+ // fetch the event, along with some of the surrounding thread (if it's threaded) and the auth chain.
+ // Inject the response into the roomserver to remember the event across multiple calls and to set
+ // unexplored flags correctly.
+ for _, srv := range serversToQuery {
+ res, err := rc.MSC2836EventRelationships(eventID, srv, queryVerRes.RoomVersion)
+ if err != nil {
+ continue
+ }
+ rc.injectResponseToRoomserver(res)
+ for _, ev := range res.Events {
+ if ev.EventID() == eventID {
+ return ev.Headered(ev.Version())
+ }
+ }
+ }
+ logger.WithField("servers", serversToQuery).Warn("failed to query event relationships")
+ return nil
+}
+
// If include_parent: true and there is a valid m.relationship field in the event,
// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array.
-func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) {
- parentID, _, _ := parentChildEventIDs(event)
+func (rc *reqCtx) includeParent(childEvent *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) {
+ parentID, _, _ := parentChildEventIDs(childEvent)
if parentID == "" {
return nil
}
- claimedRoomID, claimedServers := roomIDAndServers(event)
- return rc.getEventIfVisible(parentID, claimedRoomID, claimedServers)
+ return rc.lookForEvent(parentID)
}
// If include_children: true, lookup all events which have event_id as an m.relationship
// Apply history visibility checks to all these events and add the ones which pass into the response array,
// honouring the recent_first flag and the limit.
func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) {
+ if rc.hasUnexploredChildren(parentID) {
+ // we need to do a remote request to pull in the children as we are missing them locally.
+ serversToQuery := rc.getServersForEventID(parentID)
+ var result *gomatrixserverlib.MSC2836EventRelationshipsResponse
+ for _, srv := range serversToQuery {
+ res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
+ EventID: parentID,
+ Direction: "down",
+ Limit: 100,
+ MaxBreadth: -1,
+ MaxDepth: 1, // we just want the children from this parent
+ RecentFirst: true,
+ }, rc.roomVersion)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("includeChildren: failed to call MSC2836EventRelationships")
+ } else {
+ result = &res
+ break
+ }
+ }
+ if result != nil {
+ rc.injectResponseToRoomserver(result)
+ }
+ // fallthrough to pull these new events from the DB
+ }
children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent")
@@ -313,8 +415,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
}
var childEvents []*gomatrixserverlib.HeaderedEvent
for _, child := range children {
- // in order for us to even know about the children the server must be joined to those rooms, hence pass no claimed room ID or servers.
- childEvent := rc.getEventIfVisible(child.EventID, "", nil)
+ childEvent := rc.lookForEvent(child.EventID)
if childEvent != nil {
childEvents = append(childEvents, childEvent)
}
@@ -327,14 +428,9 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen
// Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag,
// honouring the limit, max_depth and max_breadth values according to the following rules
-// nolint: unparam
func walkThread(
ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int,
) ([]*gomatrixserverlib.HeaderedEvent, bool) {
- if rc.req.Direction != "down" {
- util.GetLogger(ctx).Error("not implemented: direction=up")
- return nil, false
- }
var result []*gomatrixserverlib.HeaderedEvent
eventWalker := walker{
ctx: ctx,
@@ -352,8 +448,11 @@ func walkThread(
}
// Process the event.
- // TODO: Include edge information: room ID and servers
- event := rc.getEventIfVisible(wi.EventID, "", nil)
+ // if event is not found, use remoteEventRelationships to explore that part of the thread remotely.
+ // This will probably be easiest if the event relationships response is directly pumped into the database
+ // so the next walk will do the right thing. This requires those events to be authed and likely injected as
+ // outliers into the roomserver DB, which will de-dupe appropriately.
+ event := rc.lookForEvent(wi.EventID)
if event != nil {
result = append(result, event)
}
@@ -368,74 +467,280 @@ func walkThread(
return result, limited
}
-func (rc *reqCtx) getEventIfVisible(eventID string, claimedRoomID string, claimedServers []string) *gomatrixserverlib.HeaderedEvent {
- event, joinedToRoom := getEventIfVisible(rc.ctx, rc.rsAPI, eventID, rc.userID)
- if event != nil && joinedToRoom {
- return event
+// MSC2836EventRelationships performs an /event_relationships request to a remote server
+func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
+ res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
+ EventID: eventID,
+ DepthFirst: rc.req.DepthFirst,
+ Direction: rc.req.Direction,
+ Limit: rc.req.Limit,
+ MaxBreadth: rc.req.MaxBreadth,
+ MaxDepth: rc.req.MaxDepth,
+ RecentFirst: rc.req.RecentFirst,
+ }, ver)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Error("Failed to call MSC2836EventRelationships")
+ return nil, err
+ }
+ return &res, nil
+
+}
+
+// authorisedToSeeEvent checks that the user or server is allowed to see this event. Returns true if allowed to
+// see this request. This only needs to be done once per room at present as we just check for joined status.
+func (rc *reqCtx) authorisedToSeeEvent(event *gomatrixserverlib.HeaderedEvent) bool {
+ if rc.isFederatedRequest {
+ // make sure the server is in this room
+ var res fs.QueryJoinedHostServerNamesInRoomResponse
+ err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
+ RoomID: event.RoomID(),
+ }, &res)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryJoinedHostServerNamesInRoom")
+ return false
+ }
+ for _, srv := range res.ServerNames {
+ if srv == rc.serverName {
+ return true
+ }
+ }
+ return false
}
- // either we don't have the event or we aren't joined to the room, regardless we should try joining if auto join is enabled
- if !rc.req.AutoJoin {
+ // make sure the user is in this room
+ // Allow events if the member is in the room
+ // TODO: This does not honour history_visibility
+ // TODO: This does not honour m.room.create content
+ var queryMembershipRes roomserver.QueryMembershipForUserResponse
+ err := rc.rsAPI.QueryMembershipForUser(rc.ctx, &roomserver.QueryMembershipForUserRequest{
+ RoomID: event.RoomID(),
+ UserID: rc.userID,
+ }, &queryMembershipRes)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Error("authorisedToSeeEvent: failed to QueryMembershipForUser")
+ return false
+ }
+ return queryMembershipRes.IsInRoom
+}
+
+func (rc *reqCtx) getServersForEventID(eventID string) []gomatrixserverlib.ServerName {
+ if rc.req.RoomID == "" {
+ util.GetLogger(rc.ctx).WithField("event_id", eventID).Error(
+ "getServersForEventID: event exists in unknown room",
+ )
return nil
}
- // if we're doing this on behalf of a random server don't auto-join rooms regardless of what the request says
- if rc.isFederatedRequest {
+ if rc.roomVersion == "" {
+ util.GetLogger(rc.ctx).WithField("event_id", eventID).Errorf(
+ "getServersForEventID: event exists in %s with unknown room version", rc.req.RoomID,
+ )
return nil
}
- roomID := claimedRoomID
- var servers []gomatrixserverlib.ServerName
- if event != nil {
- roomID = event.RoomID()
- }
- for _, s := range claimedServers {
- servers = append(servers, gomatrixserverlib.ServerName(s))
- }
- var joinRes roomserver.PerformJoinResponse
- rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{
- UserID: rc.userID,
- Content: map[string]interface{}{},
- RoomIDOrAlias: roomID,
- ServerNames: servers,
- }, &joinRes)
- if joinRes.Error != nil {
- util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", roomID).Error("Failed to auto-join room")
+ var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
+ err := rc.fsAPI.QueryJoinedHostServerNamesInRoom(rc.ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
+ RoomID: rc.req.RoomID,
+ }, &queryRes)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Error("getServersForEventID: failed to QueryJoinedHostServerNamesInRoom")
return nil
}
- if event != nil {
+ // query up to 5 servers
+ serversToQuery := queryRes.ServerNames
+ if len(serversToQuery) > 5 {
+ serversToQuery = serversToQuery[:5]
+ }
+ return serversToQuery
+}
+
+func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MSC2836EventRelationshipsResponse {
+ if rc.isFederatedRequest {
+ return nil // we don't query remote servers for remote requests
+ }
+ serversToQuery := rc.getServersForEventID(eventID)
+ var res *gomatrixserverlib.MSC2836EventRelationshipsResponse
+ var err error
+ for _, srv := range serversToQuery {
+ res, err = rc.MSC2836EventRelationships(eventID, srv, rc.roomVersion)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("remoteEventRelationships: failed to call MSC2836EventRelationships")
+ } else {
+ break
+ }
+ }
+ return res
+}
+
+// lookForEvent returns the event for the event ID given, by trying to query remote servers
+// if the event ID is unknown via /event_relationships.
+func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
+ event := rc.getLocalEvent(eventID)
+ if event == nil {
+ queryRes := rc.remoteEventRelationships(eventID)
+ if queryRes != nil {
+ // inject all the events into the roomserver then return the event in question
+ rc.injectResponseToRoomserver(queryRes)
+ for _, ev := range queryRes.Events {
+ if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() {
+ return ev.Headered(ev.Version())
+ }
+ }
+ }
+ } else if rc.hasUnexploredChildren(eventID) {
+ // we have the local event but we may need to do a remote hit anyway if we are exploring the thread and have unknown children.
+ // If we don't do this then we risk never fetching the children.
+ queryRes := rc.remoteEventRelationships(eventID)
+ if queryRes != nil {
+ rc.injectResponseToRoomserver(queryRes)
+ err := rc.db.MarkChildrenExplored(context.Background(), eventID)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Warnf("failed to mark children of %s as explored", eventID)
+ }
+ }
+ }
+ if rc.req.RoomID == event.RoomID() {
return event
}
- // TODO: hit /event_relationships on the server we joined via
- util.GetLogger(rc.ctx).Infof("joined room but need to fetch event TODO")
return nil
}
-func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) (*gomatrixserverlib.HeaderedEvent, bool) {
+func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
var queryEventsRes roomserver.QueryEventsByIDResponse
- err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{
+ err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
EventIDs: []string{eventID},
}, &queryEventsRes)
if err != nil {
- util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID")
- return nil, false
+ util.GetLogger(rc.ctx).WithError(err).Error("getLocalEvent: failed to QueryEventsByID")
+ return nil
}
if len(queryEventsRes.Events) == 0 {
- util.GetLogger(ctx).Infof("event does not exist")
- return nil, false // event does not exist
+ util.GetLogger(rc.ctx).WithField("event_id", eventID).Infof("getLocalEvent: event does not exist")
+ return nil // event does not exist
}
- event := queryEventsRes.Events[0]
+ return queryEventsRes.Events[0]
+}
- // Allow events if the member is in the room
- // TODO: This does not honour history_visibility
- // TODO: This does not honour m.room.create content
- var queryMembershipRes roomserver.QueryMembershipForUserResponse
- err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{
- RoomID: event.RoomID(),
- UserID: userID,
- }, &queryMembershipRes)
+// injectResponseToRoomserver injects the events
+// into the roomserver as KindOutlier, with auth chains.
+func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) {
+ var stateEvents []*gomatrixserverlib.Event
+ var messageEvents []*gomatrixserverlib.Event
+ for _, ev := range res.Events {
+ if ev.StateKey() != nil {
+ stateEvents = append(stateEvents, ev)
+ } else {
+ messageEvents = append(messageEvents, ev)
+ }
+ }
+ respState := gomatrixserverlib.RespState{
+ AuthEvents: res.AuthChain,
+ StateEvents: stateEvents,
+ }
+ eventsInOrder, err := respState.Events()
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse")
+ return
+ }
+ // everything gets sent as an outlier because auth chain events may be disjoint from the DAG
+ // as may the threaded events.
+ var ires []roomserver.InputRoomEvent
+ for _, outlier := range append(eventsInOrder, messageEvents...) {
+ ires = append(ires, roomserver.InputRoomEvent{
+ Kind: roomserver.KindOutlier,
+ Event: outlier.Headered(outlier.Version()),
+ AuthEventIDs: outlier.AuthEventIDs(),
+ })
+ }
+ // we've got the data by this point so use a background context
+ err = roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver")
+ }
+ // update the child count / hash columns for these nodes. We need to do this here because not all events will make it
+ // through to the KindNewEventPersisted hook because the roomserver will ignore duplicates. Duplicates have meaning though
+ // as the `unsigned` field may differ (if the number of children changes).
+ for _, ev := range ires {
+ err = rc.db.UpdateChildMetadata(context.Background(), ev.Event)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).WithField("event_id", ev.Event.EventID()).Warn("failed to update child metadata for event")
+ }
+ }
+}
+
+func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) {
+ count, hash := rc.getChildMetadata(ev.EventID())
+ if count == 0 {
+ return
+ }
+ err := ev.SetUnsignedField("children_hash", gomatrixserverlib.Base64Bytes(hash))
if err != nil {
- util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
- return nil, false
+ util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children_hash")
+ }
+ err = ev.SetUnsignedField("children", map[string]int{
+ constRelType: count,
+ })
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children count")
}
- return event, queryMembershipRes.IsInRoom
+}
+
+func (rc *reqCtx) getChildMetadata(eventID string) (count int, hash []byte) {
+ children, err := rc.db.ChildrenForParent(rc.ctx, eventID, constRelType, false)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).Warn("Failed to get ChildrenForParent for getting child metadata")
+ return
+ }
+ if len(children) == 0 {
+ return
+ }
+ // sort it lexiographically
+ sort.Slice(children, func(i, j int) bool {
+ return children[i].EventID < children[j].EventID
+ })
+ // hash it
+ var eventIDs strings.Builder
+ for _, c := range children {
+ _, _ = eventIDs.WriteString(c.EventID)
+ }
+ hashValBytes := sha256.Sum256([]byte(eventIDs.String()))
+
+ count = len(children)
+ hash = hashValBytes[:]
+ return
+}
+
+// hasUnexploredChildren returns true if this event has unexplored children.
+// "An event has unexplored children if the `unsigned` child count on the parent does not match
+// how many children the server believes the parent to have. In addition, if the counts match but
+// the hashes do not match, then the event is unexplored."
+func (rc *reqCtx) hasUnexploredChildren(eventID string) bool {
+ if rc.isFederatedRequest {
+ return false // we only explore children for clients, not servers.
+ }
+ // extract largest child count from event
+ eventCount, eventHash, explored, err := rc.db.ChildMetadata(rc.ctx, eventID)
+ if err != nil {
+ util.GetLogger(rc.ctx).WithError(err).WithField("event_id", eventID).Warn(
+ "failed to get ChildMetadata from db",
+ )
+ return false
+ }
+ // if there are no recorded children then we know we have >= children.
+ // if the event has already been explored (read: we hit /event_relationships successfully)
+ // then don't do it again. We'll only re-do this if we get an even bigger children count,
+ // see Database.UpdateChildMetadata
+ if eventCount == 0 || explored {
+ return false // short-circuit
+ }
+
+ // calculate child count for event
+ calcCount, calcHash := rc.getChildMetadata(eventID)
+
+ if eventCount < calcCount {
+ return false // we have more children
+ } else if eventCount > calcCount {
+ return true // the event has more children than we know about
+ }
+ // we have the same count, so a mismatched hash means some children are different
+ return !bytes.Equal(eventHash, calcHash)
}
type walkInfo struct {
@@ -453,9 +758,9 @@ type walker struct {
// WalkFrom the event ID given
func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
- children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
+ children, err := w.childrenForParent(eventID)
if err != nil {
- util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
+ util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk")
return false, err
}
var next *walkInfo
@@ -467,9 +772,9 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
return true, nil
}
// find the children's children
- children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst)
+ children, err = w.childrenForParent(next.EventID)
if err != nil {
- util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
+ util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() childrenForParent failed, cannot walk")
return false, err
}
toWalk = w.addChildren(toWalk, children, next.Depth+1)
@@ -528,3 +833,20 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) {
child, toWalk = toWalk[0], toWalk[1:]
return &child, toWalk
}
+
+// childrenForParent returns the children events for this event ID, honouring the direction: up|down flags
+// meaning this can actually be returning the parent for the event instead of the children.
+func (w *walker) childrenForParent(eventID string) ([]eventInfo, error) {
+ if w.req.Direction == "down" {
+ return w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
+ }
+ // find the event to pull out the parent
+ ei, err := w.db.ParentForChild(w.ctx, eventID, constRelType)
+ if err != nil {
+ return nil, err
+ }
+ if ei != nil {
+ return []eventInfo{*ei}, nil
+ }
+ return nil, nil
+}
diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go
index 996cc79f..4eb5708c 100644
--- a/setup/mscs/msc2836/msc2836_test.go
+++ b/setup/mscs/msc2836/msc2836_test.go
@@ -4,10 +4,14 @@ import (
"bytes"
"context"
"crypto/ed25519"
+ "crypto/sha256"
+ "encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
+ "sort"
+ "strings"
"testing"
"time"
@@ -43,9 +47,7 @@ func TestMSC2836(t *testing.T) {
alice := "@alice:localhost"
bob := "@bob:localhost"
charlie := "@charlie:localhost"
- roomIDA := "!alice:localhost"
- roomIDB := "!bob:localhost"
- roomIDC := "!charlie:localhost"
+ roomID := "!alice:localhost"
// give access tokens to all three users
nopUserAPI := &testUserAPI{
accessTokens: make(map[string]userapi.Device),
@@ -66,7 +68,7 @@ func TestMSC2836(t *testing.T) {
UserID: charlie,
}
eventA := mustCreateEvent(t, fledglingEvent{
- RoomID: roomIDA,
+ RoomID: roomID,
Sender: alice,
Type: "m.room.message",
Content: map[string]interface{}{
@@ -74,7 +76,7 @@ func TestMSC2836(t *testing.T) {
},
})
eventB := mustCreateEvent(t, fledglingEvent{
- RoomID: roomIDB,
+ RoomID: roomID,
Sender: bob,
Type: "m.room.message",
Content: map[string]interface{}{
@@ -86,7 +88,7 @@ func TestMSC2836(t *testing.T) {
},
})
eventC := mustCreateEvent(t, fledglingEvent{
- RoomID: roomIDB,
+ RoomID: roomID,
Sender: bob,
Type: "m.room.message",
Content: map[string]interface{}{
@@ -98,7 +100,7 @@ func TestMSC2836(t *testing.T) {
},
})
eventD := mustCreateEvent(t, fledglingEvent{
- RoomID: roomIDA,
+ RoomID: roomID,
Sender: alice,
Type: "m.room.message",
Content: map[string]interface{}{
@@ -110,7 +112,7 @@ func TestMSC2836(t *testing.T) {
},
})
eventE := mustCreateEvent(t, fledglingEvent{
- RoomID: roomIDB,
+ RoomID: roomID,
Sender: bob,
Type: "m.room.message",
Content: map[string]interface{}{
@@ -122,7 +124,7 @@ func TestMSC2836(t *testing.T) {
},
})
eventF := mustCreateEvent(t, fledglingEvent{
- RoomID: roomIDC,
+ RoomID: roomID,
Sender: charlie,
Type: "m.room.message",
Content: map[string]interface{}{
@@ -134,7 +136,7 @@ func TestMSC2836(t *testing.T) {
},
})
eventG := mustCreateEvent(t, fledglingEvent{
- RoomID: roomIDA,
+ RoomID: roomID,
Sender: alice,
Type: "m.room.message",
Content: map[string]interface{}{
@@ -146,7 +148,7 @@ func TestMSC2836(t *testing.T) {
},
})
eventH := mustCreateEvent(t, fledglingEvent{
- RoomID: roomIDB,
+ RoomID: roomID,
Sender: bob,
Type: "m.room.message",
Content: map[string]interface{}{
@@ -160,9 +162,9 @@ func TestMSC2836(t *testing.T) {
// make everyone joined to each other's rooms
nopRsAPI := &testRoomserverAPI{
userToJoinedRooms: map[string][]string{
- alice: []string{roomIDA, roomIDB, roomIDC},
- bob: []string{roomIDA, roomIDB, roomIDC},
- charlie: []string{roomIDA, roomIDB, roomIDC},
+ alice: []string{roomID},
+ bob: []string{roomID},
+ charlie: []string{roomID},
},
events: map[string]*gomatrixserverlib.HeaderedEvent{
eventA.EventID(): eventA,
@@ -198,21 +200,6 @@ func TestMSC2836(t *testing.T) {
"include_parent": true,
}))
})
- t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) {
- nopUserAPI.accessTokens["frank2"] = userapi.Device{
- AccessToken: "frank2",
- DisplayName: "Frank2 Not In Room",
- UserID: "@frank2:localhost",
- }
- // Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB
- nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB}
- body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{
- "event_id": eventB.EventID(),
- "limit": 1,
- "include_parent": true,
- }))
- assertContains(t, body, []string{eventB.EventID()})
- })
t.Run("returns the parent if include_parent is true", func(t *testing.T) {
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
@@ -349,6 +336,39 @@ func TestMSC2836(t *testing.T) {
}))
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
})
+ t.Run("can navigate up the graph with direction: up", func(t *testing.T) {
+ // A4
+ // |
+ // B3
+ // / \
+ // C D2
+ // /| \
+ // E F1 G
+ // |
+ // H
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventF.EventID(),
+ "recent_first": false,
+ "depth_first": true,
+ "direction": "up",
+ }))
+ assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()})
+ })
+ t.Run("includes children and children_hash in unsigned", func(t *testing.T) {
+ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
+ "event_id": eventB.EventID(),
+ "recent_first": false,
+ "depth_first": false,
+ "limit": 3,
+ }))
+ // event B has C,D as children
+ // event C has no children
+ // event D has 3 children (not included in response)
+ assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID()})
+ assertUnsignedChildren(t, body.Events[0], "m.reference", 2, []string{eventC.EventID(), eventD.EventID()})
+ assertUnsignedChildren(t, body.Events[1], "", 0, nil)
+ assertUnsignedChildren(t, body.Events[2], "m.reference", 3, []string{eventE.EventID(), eventF.EventID(), eventG.EventID()})
+ })
}
// TODO: TestMSC2836TerminatesLoops (short and long)
@@ -411,8 +431,12 @@ func postRelationships(t *testing.T, expectCode int, accessToken string, req *ms
}
if res.StatusCode == 200 {
var result msc2836.EventRelationshipResponse
- if err := json.NewDecoder(res.Body).Decode(&result); err != nil {
- t.Fatalf("response 200 OK but failed to deserialise JSON : %s", err)
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("response 200 OK but failed to read response body: %s", err)
+ }
+ if err := json.Unmarshal(body, &result); err != nil {
+ t.Fatalf("response 200 OK but failed to deserialise JSON : %s\nbody: %s", err, string(body))
}
return &result
}
@@ -435,6 +459,43 @@ func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wan
}
}
+func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relType string, wantCount int, childrenEventIDs []string) {
+ t.Helper()
+ unsigned := struct {
+ Children map[string]int `json:"children"`
+ Hash string `json:"children_hash"`
+ }{}
+ if err := json.Unmarshal(ev.Unsigned, &unsigned); err != nil {
+ if wantCount == 0 {
+ return // no children so possible there is no unsigned field at all
+ }
+ t.Fatalf("Failed to unmarshal unsigned field: %s", err)
+ }
+ // zero checks
+ if wantCount == 0 {
+ if len(unsigned.Children) != 0 || unsigned.Hash != "" {
+ t.Fatalf("want 0 children but got unsigned fields %+v", unsigned)
+ }
+ return
+ }
+ gotCount := unsigned.Children[relType]
+ if gotCount != wantCount {
+ t.Errorf("Got %d count, want %d count for rel_type %s", gotCount, wantCount, relType)
+ }
+ // work out the hash
+ sort.Strings(childrenEventIDs)
+ var b strings.Builder
+ for _, s := range childrenEventIDs {
+ b.WriteString(s)
+ }
+ t.Logf("hashing %s", b.String())
+ hashValBytes := sha256.Sum256([]byte(b.String()))
+ wantHash := base64.RawStdEncoding.EncodeToString(hashValBytes[:])
+ if wantHash != unsigned.Hash {
+ t.Errorf("Got unsigned hash %s want hash %s", unsigned.Hash, wantHash)
+ }
+}
+
type testUserAPI struct {
accessTokens map[string]userapi.Device
}
diff --git a/setup/mscs/msc2836/storage.go b/setup/mscs/msc2836/storage.go
index 72ea5195..72523916 100644
--- a/setup/mscs/msc2836/storage.go
+++ b/setup/mscs/msc2836/storage.go
@@ -1,20 +1,22 @@
package msc2836
import (
+ "bytes"
"context"
"database/sql"
+ "encoding/base64"
"encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
)
type eventInfo struct {
EventID string
OriginServerTS gomatrixserverlib.Timestamp
RoomID string
- Servers []string
}
type Database interface {
@@ -25,6 +27,21 @@ type Database interface {
// provided `relType`. The returned slice is sorted by origin_server_ts according to whether
// `recentFirst` is true or false.
ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
+ // ParentForChild returns the parent event for the given child `eventID`. The eventInfo should be nil if
+ // there is no parent for this child event, with no error. The parent eventInfo can be missing the
+ // timestamp if the event is not known to the server.
+ ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error)
+ // UpdateChildMetadata persists the children_count and children_hash from this event if and only if
+ // the count is greater than what was previously there. If the count is updated, the event will be
+ // updated to be unexplored.
+ UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error
+ // ChildMetadata returns the children_count and children_hash for the event ID in question.
+ // Also returns the `explored` flag, which is set to true when MarkChildrenExplored is called and is set
+ // back to `false` when a larger count is inserted via UpdateChildMetadata.
+ // Returns nil error if the event ID does not exist.
+ ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error)
+ // MarkChildrenExplored sets the 'explored' flag on this event to `true`.
+ MarkChildrenExplored(ctx context.Context, eventID string) error
}
type DB struct {
@@ -34,6 +51,10 @@ type DB struct {
insertNodeStmt *sql.Stmt
selectChildrenForParentOldestFirstStmt *sql.Stmt
selectChildrenForParentRecentFirstStmt *sql.Stmt
+ selectParentForChildStmt *sql.Stmt
+ updateChildMetadataStmt *sql.Stmt
+ selectChildMetadataStmt *sql.Stmt
+ updateChildMetadataExploredStmt *sql.Stmt
}
// NewDatabase loads the database for msc2836
@@ -65,19 +86,26 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
CREATE TABLE IF NOT EXISTS msc2836_nodes (
event_id TEXT PRIMARY KEY NOT NULL,
origin_server_ts BIGINT NOT NULL,
- room_id TEXT NOT NULL
+ room_id TEXT NOT NULL,
+ unsigned_children_count BIGINT NOT NULL,
+ unsigned_children_hash TEXT NOT NULL,
+ explored SMALLINT NOT NULL
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
- INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING
+ INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
+ VALUES($1, $2, $3, $4, $5)
+ ON CONFLICT DO NOTHING
`); err != nil {
return nil, err
}
if d.insertNodeStmt, err = d.db.Prepare(`
- INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING
+ INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
+ VALUES($1, $2, $3, $4, $5, $6)
+ ON CONFLICT DO NOTHING
`); err != nil {
return nil, err
}
@@ -93,6 +121,27 @@ func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
return nil, err
}
+ if d.selectParentForChildStmt, err = d.db.Prepare(`
+ SELECT parent_event_id, parent_room_id FROM msc2836_edges
+ WHERE child_event_id = $1 AND rel_type = $2
+ `); err != nil {
+ return nil, err
+ }
+ if d.updateChildMetadataStmt, err = d.db.Prepare(`
+ UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
+ `); err != nil {
+ return nil, err
+ }
+ if d.selectChildMetadataStmt, err = d.db.Prepare(`
+ SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
+ `); err != nil {
+ return nil, err
+ }
+ if d.updateChildMetadataExploredStmt, err = d.db.Prepare(`
+ UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
+ `); err != nil {
+ return nil, err
+ }
return &d, err
}
@@ -117,19 +166,26 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
CREATE TABLE IF NOT EXISTS msc2836_nodes (
event_id TEXT PRIMARY KEY NOT NULL,
origin_server_ts BIGINT NOT NULL,
- room_id TEXT NOT NULL
+ room_id TEXT NOT NULL,
+ unsigned_children_count BIGINT NOT NULL,
+ unsigned_children_hash TEXT NOT NULL,
+ explored SMALLINT NOT NULL
);
`)
if err != nil {
return nil, err
}
if d.insertEdgeStmt, err = d.db.Prepare(`
- INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING
+ INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
+ VALUES($1, $2, $3, $4, $5)
+ ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING
`); err != nil {
return nil, err
}
if d.insertNodeStmt, err = d.db.Prepare(`
- INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING
+ INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
+ VALUES($1, $2, $3, $4, $5, $6)
+ ON CONFLICT DO NOTHING
`); err != nil {
return nil, err
}
@@ -145,6 +201,27 @@ func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) {
if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil {
return nil, err
}
+ if d.selectParentForChildStmt, err = d.db.Prepare(`
+ SELECT parent_event_id, parent_room_id FROM msc2836_edges
+ WHERE child_event_id = $1 AND rel_type = $2
+ `); err != nil {
+ return nil, err
+ }
+ if d.updateChildMetadataStmt, err = d.db.Prepare(`
+ UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
+ `); err != nil {
+ return nil, err
+ }
+ if d.selectChildMetadataStmt, err = d.db.Prepare(`
+ SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
+ `); err != nil {
+ return nil, err
+ }
+ if d.updateChildMetadataExploredStmt, err = d.db.Prepare(`
+ UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
+ `); err != nil {
+ return nil, err
+ }
return &d, nil
}
@@ -158,16 +235,55 @@ func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEv
if err != nil {
return err
}
+ count, hash := extractChildMetadata(ev)
return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
_, err := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parent, child, relType, relationRoomID, string(relationServersJSON))
if err != nil {
return err
}
- _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID())
+ util.GetLogger(ctx).Infof("StoreRelation child=%s parent=%s rel_type=%s", child, parent, relType)
+ _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID(), count, base64.RawStdEncoding.EncodeToString(hash), 0)
return err
})
}
+func (p *DB) UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error {
+ eventCount, eventHash := extractChildMetadata(ev)
+ if eventCount == 0 {
+ return nil // nothing to update with
+ }
+
+ // extract current children count/hash, if they are less than the current event then update the columns and set to unexplored
+ count, hash, _, err := p.ChildMetadata(ctx, ev.EventID())
+ if err != nil {
+ return err
+ }
+ if eventCount > count || (eventCount == count && !bytes.Equal(hash, eventHash)) {
+ _, err = p.updateChildMetadataStmt.ExecContext(ctx, eventCount, base64.RawStdEncoding.EncodeToString(eventHash), 0, ev.EventID())
+ return err
+ }
+ return nil
+}
+
+func (p *DB) ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error) {
+ var b64hash string
+ var exploredInt int
+ if err = p.selectChildMetadataStmt.QueryRowContext(ctx, eventID).Scan(&count, &b64hash, &exploredInt); err != nil {
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ return
+ }
+ hash, err = base64.RawStdEncoding.DecodeString(b64hash)
+ explored = exploredInt > 0
+ return
+}
+
+func (p *DB) MarkChildrenExplored(ctx context.Context, eventID string) error {
+ _, err := p.updateChildMetadataExploredStmt.ExecContext(ctx, 1, eventID)
+ return err
+}
+
func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) {
var rows *sql.Rows
var err error
@@ -191,6 +307,17 @@ func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, rec
return children, nil
}
+func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) {
+ var ei eventInfo
+ err := p.selectParentForChildStmt.QueryRowContext(ctx, eventID, relType).Scan(&ei.EventID, &ei.RoomID)
+ if err == sql.ErrNoRows {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ return &ei, nil
+}
+
func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) {
if ev == nil {
return
@@ -224,3 +351,19 @@ func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, serve
}
return body.RoomID, body.Servers
}
+
+func extractChildMetadata(ev *gomatrixserverlib.HeaderedEvent) (count int, hash []byte) {
+ unsigned := struct {
+ Counts map[string]int `json:"children"`
+ Hash gomatrixserverlib.Base64Bytes `json:"children_hash"`
+ }{}
+ if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil {
+ // expected if there is no unsigned field at all
+ return
+ }
+ for _, c := range unsigned.Counts {
+ count += c
+ }
+ hash = unsigned.Hash
+ return
+}
diff --git a/setup/mscs/mscs.go b/setup/mscs/mscs.go
index 8b0498ce..a8e5668e 100644
--- a/setup/mscs/mscs.go
+++ b/setup/mscs/mscs.go
@@ -16,15 +16,18 @@
package mscs
import (
+ "context"
"fmt"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/mscs/msc2836"
+ "github.com/matrix-org/util"
)
// Enable MSCs - returns an error on unknown MSCs
func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error {
for _, msc := range base.Cfg.MSCs.MSCs {
+ util.GetLogger(context.Background()).WithField("msc", msc).Info("Enabling MSC")
if err := EnableMSC(base, monolith, msc); err != nil {
return err
}