aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--appservice/consumers/roomserver.go12
-rw-r--r--clientapi/routing/directory.go30
-rw-r--r--clientapi/routing/redaction.go2
-rw-r--r--clientapi/routing/sendevent.go4
-rw-r--r--clientapi/routing/state.go21
-rw-r--r--cmd/resolve-state/main.go5
-rw-r--r--federationapi/federationapi_test.go4
-rw-r--r--federationapi/internal/perform.go35
-rw-r--r--federationapi/routing/invite.go6
-rw-r--r--federationapi/routing/join.go24
-rw-r--r--federationapi/routing/leave.go13
-rw-r--r--go.mod4
-rw-r--r--go.sum8
-rw-r--r--internal/pushrules/evaluate.go16
-rw-r--r--internal/pushrules/evaluate_test.go17
-rw-r--r--internal/transactionrequest.go4
-rw-r--r--internal/transactionrequest_test.go8
-rw-r--r--roomserver/api/alias.go2
-rw-r--r--roomserver/api/api.go12
-rw-r--r--roomserver/api/query.go4
-rw-r--r--roomserver/internal/alias.go21
-rw-r--r--roomserver/internal/helpers/auth.go4
-rw-r--r--roomserver/internal/input/input_events.go32
-rw-r--r--roomserver/internal/input/input_events_test.go2
-rw-r--r--roomserver/internal/input/input_missing.go24
-rw-r--r--roomserver/internal/perform/perform_admin.go8
-rw-r--r--roomserver/internal/perform/perform_backfill.go12
-rw-r--r--roomserver/internal/perform/perform_create_room.go4
-rw-r--r--roomserver/internal/perform/perform_invite.go12
-rw-r--r--roomserver/internal/perform/perform_upgrade.go10
-rw-r--r--roomserver/internal/query/query.go30
-rw-r--r--roomserver/producers/roomevent.go2
-rw-r--r--roomserver/state/state.go9
-rw-r--r--roomserver/storage/interface.go5
-rw-r--r--roomserver/storage/shared/membership_updater.go2
-rw-r--r--roomserver/storage/shared/room_updater.go5
-rw-r--r--roomserver/storage/shared/storage.go28
-rw-r--r--setup/mscs/msc2836/msc2836.go8
-rw-r--r--setup/mscs/msc2836/msc2836_test.go4
-rw-r--r--setup/mscs/msc2946/msc2946.go2
-rw-r--r--syncapi/consumers/roomserver.go2
-rw-r--r--syncapi/routing/context.go23
-rw-r--r--syncapi/routing/getevent.go7
-rw-r--r--syncapi/routing/memberships.go6
-rw-r--r--syncapi/routing/messages.go12
-rw-r--r--syncapi/routing/relations.go7
-rw-r--r--syncapi/routing/routing.go2
-rw-r--r--syncapi/routing/search.go46
-rw-r--r--syncapi/routing/search_test.go10
-rw-r--r--syncapi/storage/postgres/current_room_state_table.go2
-rw-r--r--syncapi/storage/postgres/output_room_events_table.go2
-rw-r--r--syncapi/storage/shared/storage_consumer.go21
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go2
-rw-r--r--syncapi/storage/sqlite3/output_room_events_table.go2
-rw-r--r--syncapi/streams/stream_invite.go12
-rw-r--r--syncapi/streams/stream_pdu.go38
-rw-r--r--syncapi/streams/streams.go1
-rw-r--r--syncapi/syncapi_test.go4
-rw-r--r--syncapi/synctypes/clientevent.go13
-rw-r--r--syncapi/synctypes/clientevent_test.go17
-rw-r--r--syncapi/types/types.go4
-rw-r--r--syncapi/types/types_test.go12
-rw-r--r--test/room.go6
-rw-r--r--userapi/consumers/roomserver.go43
-rw-r--r--userapi/consumers/roomserver_test.go11
-rw-r--r--userapi/util/notify_test.go9
66 files changed, 580 insertions, 189 deletions
diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go
index c02d9040..06625ad7 100644
--- a/appservice/consumers/roomserver.go
+++ b/appservice/consumers/roomserver.go
@@ -181,7 +181,9 @@ func (s *OutputRoomEventConsumer) sendEvents(
// Create the transaction body.
transaction, err := json.Marshal(
ApplicationServiceTransaction{
- Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll),
+ Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
+ return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }),
},
)
if err != nil {
@@ -233,10 +235,16 @@ func (s *appserviceState) backoffAndPause(err error) error {
//
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
+ user := ""
+ userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err == nil {
+ user = userID.String()
+ }
+
switch {
case appservice.URL == "":
return false
- case appservice.IsInterestedInUserID(event.Sender()):
+ case appservice.IsInterestedInUserID(user):
return true
case appservice.IsInterestedInRoomID(event.RoomID()):
return true
diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go
index c786f8cc..0c842e6a 100644
--- a/clientapi/routing/directory.go
+++ b/clientapi/routing/directory.go
@@ -215,9 +215,35 @@ func RemoveLocalAlias(
alias string,
rsAPI roomserverAPI.ClientRoomserverAPI,
) util.JSONResponse {
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{Err: "UserID for device is invalid"},
+ }
+ }
+
+ roomIDReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: alias}
+ roomIDRes := roomserverAPI.GetRoomIDForAliasResponse{}
+ err = rsAPI.GetRoomIDForAlias(req.Context(), &roomIDReq, &roomIDRes)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusNotFound,
+ JSON: spec.NotFound("The alias does not exist."),
+ }
+ }
+
+ deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"},
+ }
+ }
+
queryReq := roomserverAPI.RemoveRoomAliasRequest{
- Alias: alias,
- UserID: device.UserID,
+ Alias: alias,
+ SenderID: deviceSenderID,
}
var queryRes roomserverAPI.RemoveRoomAliasResponse
if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil {
diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go
index 88312642..e94c7748 100644
--- a/clientapi/routing/redaction.go
+++ b/clientapi/routing/redaction.go
@@ -76,7 +76,7 @@ func SendRedaction(
// "Users may redact their own events, and any user with a power level greater than or equal
// to the redact power level of the room may redact events there"
// https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid
- allowedToRedact := ev.Sender() == device.UserID
+ allowedToRedact := ev.SenderID() == device.UserID // TODO: Should replace device.UserID with device...PerRoomKey
if !allowedToRedact {
plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels,
diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go
index 1a2e25c9..8b09f399 100644
--- a/clientapi/routing/sendevent.go
+++ b/clientapi/routing/sendevent.go
@@ -331,7 +331,9 @@ func generateSendEvent(
stateEvents[i] = queryRes.StateEvents[i].PDU
}
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
- if err = gomatrixserverlib.Allowed(e.PDU, &provider); err != nil {
+ if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
return nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client?
diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go
index 319f4eba..13f30899 100644
--- a/clientapi/routing/state.go
+++ b/clientapi/routing/state.go
@@ -140,9 +140,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
// use the result of the previous QueryLatestEventsAndState response
// to find the state event, if provided.
for _, ev := range stateRes.StateEvents {
+ sender := spec.UserID{}
+ userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
stateEvents = append(
stateEvents,
- synctypes.ToClientEvent(ev, synctypes.FormatAll),
+ synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
)
}
} else {
@@ -162,9 +167,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
}
}
for _, ev := range stateAfterRes.StateEvents {
+ sender := spec.UserID{}
+ userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
stateEvents = append(
stateEvents,
- synctypes.ToClientEvent(ev, synctypes.FormatAll),
+ synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
)
}
}
@@ -334,8 +344,13 @@ func OnIncomingStateTypeRequest(
}
}
+ sender := spec.UserID{}
+ userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
stateEvent := stateEventInStateResp{
- ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll),
+ ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
}
var res interface{}
diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go
index 3a4255ba..36040309 100644
--- a/cmd/resolve-state/main.go
+++ b/cmd/resolve-state/main.go
@@ -18,6 +18,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
)
// This is a utility for inspecting state snapshots and running state resolution
@@ -182,7 +183,9 @@ func main() {
fmt.Println("Resolving state")
var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts(
- gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents,
+ gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return roomserverDB.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
if err != nil {
panic(err)
diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go
index beb648a4..a97bcdea 100644
--- a/federationapi/federationapi_test.go
+++ b/federationapi/federationapi_test.go
@@ -36,6 +36,10 @@ type fedRoomserverAPI struct {
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
}
+func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
// PerformJoin will call this function
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if f.inputRoomEvents == nil {
diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go
index ed800d03..2d59d0f9 100644
--- a/federationapi/internal/perform.go
+++ b/federationapi/internal/perform.go
@@ -156,15 +156,20 @@ func (r *FederationInternalAPI) performJoinUsingServer(
}
joinInput := gomatrixserverlib.PerformJoinInput{
- UserID: user,
- RoomID: room,
- ServerName: serverName,
- Content: content,
- Unsigned: unsigned,
- PrivateKey: r.cfg.Matrix.PrivateKey,
- KeyID: r.cfg.Matrix.KeyID,
- KeyRing: r.keyRing,
- EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName),
+ UserID: user,
+ RoomID: room,
+ ServerName: serverName,
+ Content: content,
+ Unsigned: unsigned,
+ PrivateKey: r.cfg.Matrix.PrivateKey,
+ KeyID: r.cfg.Matrix.KeyID,
+ KeyRing: r.keyRing,
+ EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }),
+ UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) {
+ return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ },
}
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
@@ -358,8 +363,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
// authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse()
+ userIDProvider := func(roomID, senderID string) (*spec.UserID, error) {
+ return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
- ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName),
+ ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider,
)
if err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err)
@@ -509,7 +517,7 @@ func (r *FederationInternalAPI) SendInvite(
event gomatrixserverlib.PDU,
strippedState []gomatrixserverlib.InviteStrippedState,
) (gomatrixserverlib.PDU, error) {
- _, origin, err := r.cfg.Matrix.SplitLocalID('@', event.Sender())
+ inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
return nil, err
}
@@ -542,7 +550,7 @@ func (r *FederationInternalAPI) SendInvite(
return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err)
}
- inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq)
+ inviteRes, err := r.federation.SendInviteV2(ctx, inviter.Domain(), destination, inviteReq)
if err != nil {
return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err)
}
@@ -635,6 +643,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error {
func federatedEventProvider(
ctx context.Context, federation fclient.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName,
+ userIDForSender spec.UserIDForSender,
) gomatrixserverlib.EventProvider {
// A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join.
@@ -684,7 +693,7 @@ func federatedEventProvider(
}
// Check the signatures of the event.
- if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing); err != nil {
+ if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
}
diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go
index 78a09d94..d792335b 100644
--- a/federationapi/routing/invite.go
+++ b/federationapi/routing/invite.go
@@ -95,6 +95,9 @@ func InviteV2(
StateQuerier: rsAPI.StateQuerier(),
InviteEvent: inviteReq.Event(),
StrippedState: inviteReq.InviteRoomState(),
+ UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
+ },
}
event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI)
if jsonErr != nil {
@@ -185,6 +188,9 @@ func InviteV1(
StateQuerier: rsAPI.StateQuerier(),
InviteEvent: event,
StrippedState: strippedState,
+ UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
+ },
}
event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI)
if jsonErr != nil {
diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go
index 2980c2af..9da05918 100644
--- a/federationapi/routing/join.go
+++ b/federationapi/routing/join.go
@@ -99,15 +99,18 @@ func MakeJoin(
}
input := gomatrixserverlib.HandleMakeJoinInput{
- Context: httpReq.Context(),
- UserID: userID,
- RoomID: roomID,
- RoomVersion: roomVersion,
- RemoteVersions: remoteVersions,
- RequestOrigin: request.Origin(),
- LocalServerName: cfg.Matrix.ServerName,
- LocalServerInRoom: res.RoomExists && res.IsInRoom,
- RoomQuerier: &roomQuerier,
+ Context: httpReq.Context(),
+ UserID: userID,
+ RoomID: roomID,
+ RoomVersion: roomVersion,
+ RemoteVersions: remoteVersions,
+ RequestOrigin: request.Origin(),
+ LocalServerName: cfg.Matrix.ServerName,
+ LocalServerInRoom: res.RoomExists && res.IsInRoom,
+ RoomQuerier: &roomQuerier,
+ UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
+ },
BuildEventTemplate: createJoinTemplate,
}
response, internalErr := gomatrixserverlib.HandleMakeJoin(input)
@@ -202,6 +205,9 @@ func SendJoin(
PrivateKey: cfg.Matrix.PrivateKey,
Verifier: keys,
MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI},
+ UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
+ },
}
response, joinErr := gomatrixserverlib.HandleSendJoin(input)
switch e := joinErr.(type) {
diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go
index d7d5b599..30e99c4f 100644
--- a/federationapi/routing/leave.go
+++ b/federationapi/routing/leave.go
@@ -95,6 +95,9 @@ func MakeLeave(
LocalServerName: cfg.Matrix.ServerName,
LocalServerInRoom: res.RoomExists && res.IsInRoom,
BuildEventTemplate: createLeaveTemplate,
+ UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
+ },
}
response, internalErr := gomatrixserverlib.HandleMakeLeave(input)
@@ -213,7 +216,7 @@ func SendLeave(
JSON: spec.BadJSON("No state key was provided in the leave event."),
}
}
- if !event.StateKeyEquals(event.Sender()) {
+ if !event.StateKeyEquals(event.SenderID()) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Event state key must match the event sender."),
@@ -223,13 +226,13 @@ func SendLeave(
// Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both.
- var serverName spec.ServerName
- if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
+ sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID())
+ if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender of the join is invalid"),
}
- } else if serverName != request.Origin() {
+ } else if sender.Domain() != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender does not match the server that originated the request"),
@@ -291,7 +294,7 @@ func SendLeave(
}
}
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
- ServerName: serverName,
+ ServerName: sender.Domain(),
Message: redacted,
AtTS: event.OriginServerTS(),
ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck,
diff --git a/go.mod b/go.mod
index a49dfa0c..10551f70 100644
--- a/go.mod
+++ b/go.mod
@@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
- github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e
+ github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16
@@ -34,7 +34,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0
- github.com/sirupsen/logrus v1.9.2
+ github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.2
github.com/tidwall/gjson v1.14.4
github.com/tidwall/sjson v1.2.5
diff --git a/go.sum b/go.sum
index 79154624..3ec1c115 100644
--- a/go.sum
+++ b/go.sum
@@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e h1:I3Sfr8gZvVtLHOeI8lgc62kgLuzpMhBZ6EQOMyexXEA=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 h1:6SixhMmB5Ir10xUJ6zh3A4NBxSaZCSz2s5U63Wg0eEU=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
@@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
-github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y=
-github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
+github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
+github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go
index 7c98efd3..da33d386 100644
--- a/internal/pushrules/evaluate.go
+++ b/internal/pushrules/evaluate.go
@@ -6,6 +6,7 @@ import (
"strings"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
)
// A RuleSetEvaluator encapsulates context to evaluate an event
@@ -53,7 +54,7 @@ func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluat
// MatchEvent returns the first matching rule. Returns nil if there
// was no match rule.
-func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, error) {
+func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) (*Rule, error) {
// TODO: server-default rules have lower priority than user rules,
// but they are stored together with the user rules. It's a bit
// unclear what the specification (11.14.1.4 Predefined rules)
@@ -68,7 +69,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err
if rule.Default != defRules {
continue
}
- ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec)
+ ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec, userIDForSender)
if err != nil {
return nil, err
}
@@ -83,7 +84,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err
return nil, nil
}
-func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) {
+func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext, userIDForSender spec.UserIDForSender) (bool, error) {
if !rule.Enabled {
return false, nil
}
@@ -113,7 +114,12 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati
return rule.RuleID == event.RoomID(), nil
case SenderKind:
- return rule.RuleID == event.Sender(), nil
+ userID := ""
+ sender, err := userIDForSender(event.RoomID(), event.SenderID())
+ if err == nil {
+ userID = sender.String()
+ }
+ return rule.RuleID == userID, nil
default:
return false, nil
@@ -143,7 +149,7 @@ func conditionMatches(cond *Condition, event gomatrixserverlib.PDU, ec Evaluatio
return cmp(n), nil
case SenderNotificationPermissionCondition:
- return ec.HasPowerLevel(event.Sender(), cond.Key)
+ return ec.HasPowerLevel(event.SenderID(), cond.Key)
default:
return false, nil
diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go
index 5045a864..34c1436f 100644
--- a/internal/pushrules/evaluate_test.go
+++ b/internal/pushrules/evaluate_test.go
@@ -5,8 +5,13 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
)
+func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
ev := mustEventFromJSON(t, `{}`)
defaultEnabled := &Rule{
@@ -45,7 +50,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet)
- got, err := rse.MatchEvent(tst.Event)
+ got, err := rse.MatchEvent(tst.Event, UserIDForSender)
if err != nil {
t.Fatalf("MatchEvent failed: %v", err)
}
@@ -82,15 +87,15 @@ func TestRuleMatches(t *testing.T) {
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true},
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false},
- {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
- {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
+ {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true},
+ {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false},
- {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true},
- {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false},
+ {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true},
+ {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
- got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil)
+ got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil, UserIDForSender)
if err != nil {
t.Fatalf("ruleMatches failed: %v", err)
}
diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go
index c9d321f2..0bbe0720 100644
--- a/internal/transactionrequest.go
+++ b/internal/transactionrequest.go
@@ -167,7 +167,9 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
}
continue
}
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil {
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) {
+ return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = fclient.PDUResult{
Error: err.Error(),
diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go
index fb30d410..6f3ce0b3 100644
--- a/internal/transactionrequest_test.go
+++ b/internal/transactionrequest_test.go
@@ -70,6 +70,10 @@ type FakeRsAPI struct {
bannedFromRoom bool
}
+func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
func (r *FakeRsAPI) QueryRoomVersionForRoom(
ctx context.Context,
roomID string,
@@ -638,6 +642,10 @@ type testRoomserverAPI struct {
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
}
+func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context,
request *rsAPI.InputRoomEventsRequest,
diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go
index 37892a44..1b947540 100644
--- a/roomserver/api/alias.go
+++ b/roomserver/api/alias.go
@@ -62,7 +62,7 @@ type GetAliasesForRoomIDResponse struct {
// RemoveRoomAliasRequest is a request to RemoveRoomAlias
type RemoveRoomAliasRequest struct {
// ID of the user removing the alias
- UserID string `json:"user_id"`
+ SenderID string `json:"user_id"`
// The room alias to remove
Alias string `json:"alias"`
}
diff --git a/roomserver/api/api.go b/roomserver/api/api.go
index a37ade3a..d61a0553 100644
--- a/roomserver/api/api.go
+++ b/roomserver/api/api.go
@@ -49,6 +49,7 @@ type RoomserverInternalAPI interface {
ClientRoomserverAPI
UserRoomserverAPI
FederationRoomserverAPI
+ QuerySenderIDAPI
// needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs
@@ -75,6 +76,11 @@ type InputRoomEventsAPI interface {
)
}
+type QuerySenderIDAPI interface {
+ QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error)
+ QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error)
+}
+
// Query the latest events and state for a room from the room server.
type QueryLatestEventsAndStateAPI interface {
QueryLatestEventsAndState(ctx context.Context, req *QueryLatestEventsAndStateRequest, res *QueryLatestEventsAndStateResponse) error
@@ -102,6 +108,7 @@ type QueryEventsAPI interface {
type SyncRoomserverAPI interface {
QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI
+ QuerySenderIDAPI
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
@@ -142,6 +149,7 @@ type SyncRoomserverAPI interface {
}
type AppserviceRoomserverAPI interface {
+ QuerySenderIDAPI
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID.
QueryEventsByID(
@@ -168,6 +176,7 @@ type ClientRoomserverAPI interface {
QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI
QueryEventsAPI
+ QuerySenderIDAPI
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
@@ -200,6 +209,7 @@ type ClientRoomserverAPI interface {
}
type UserRoomserverAPI interface {
+ QuerySenderIDAPI
QueryLatestEventsAndStateAPI
KeyserverRoomserverAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
@@ -213,6 +223,8 @@ type FederationRoomserverAPI interface {
InputRoomEventsAPI
QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI
+ QuerySenderIDAPI
+
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
diff --git a/roomserver/api/query.go b/roomserver/api/query.go
index e741c140..d79dcebb 100644
--- a/roomserver/api/query.go
+++ b/roomserver/api/query.go
@@ -491,10 +491,10 @@ type MembershipQuerier struct {
Roomserver FederationRoomserverAPI
}
-func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) {
+func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
req := QueryMembershipForUserRequest{
RoomID: roomID.String(),
- UserID: userID.String(),
+ UserID: string(senderID),
}
res := QueryMembershipForUserResponse{}
err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res)
diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go
index 52b90cf4..dcfb26b8 100644
--- a/roomserver/internal/alias.go
+++ b/roomserver/internal/alias.go
@@ -119,11 +119,6 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse,
) error {
- _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.UserID)
- if err != nil {
- return err
- }
-
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil {
return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err)
@@ -134,13 +129,19 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return nil
}
+ sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID)
+ if err != nil {
+ return fmt.Errorf("r.QueryUserIDForSender: %w", err)
+ }
+ virtualHost := sender.Domain()
+
response.Found = true
creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias)
if err != nil {
return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err)
}
- if creatorID != request.UserID {
+ if creatorID != request.SenderID {
var plEvent *types.HeaderedEvent
var pls *gomatrixserverlib.PowerLevelContent
@@ -154,7 +155,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return fmt.Errorf("plEvent.PowerLevels: %w", err)
}
- if pls.UserLevel(request.UserID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) {
+ if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) {
response.Removed = false
return nil
}
@@ -172,9 +173,9 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return err
}
- sender := request.UserID
- if request.UserID != ev.Sender() {
- sender = ev.Sender()
+ sender := request.SenderID
+ if request.SenderID != ev.SenderID() {
+ sender = ev.SenderID()
}
_, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender)
diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go
index 7ec0892e..932ce615 100644
--- a/roomserver/internal/helpers/auth.go
+++ b/roomserver/internal/helpers/auth.go
@@ -76,7 +76,9 @@ func CheckForSoftFail(
}
// Check if the event is allowed.
- if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil {
+ if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return db.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
// return true, nil
return true, err
}
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 386083f6..764bdfe2 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -128,9 +128,13 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
}
- _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender())
+ sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
- return fmt.Errorf("event has invalid sender %q", input.Event.Sender())
+ return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
+ }
+ senderDomain := spec.ServerName("")
+ if sender != nil {
+ senderDomain = sender.Domain()
}
// If we already know about this outlier and it hasn't been rejected
@@ -193,7 +197,9 @@ func (r *Inputer) processRoomEvent(
serverRes.ServerNames = append(serverRes.ServerNames, input.Origin)
delete(servers, input.Origin)
}
- if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName {
+ // Only perform this check if the sender mxid_mapping can be resolved.
+ // Don't fail processing the event if we have no mxid_maping.
+ if sender != nil && senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName {
serverRes.ServerNames = append(serverRes.ServerNames, senderDomain)
delete(servers, senderDomain)
}
@@ -276,7 +282,9 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted.
- if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
+ if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
isRejected = true
rejectionErr = err
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
@@ -493,7 +501,7 @@ func (r *Inputer) processRoomEvent(
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error {
oldRoomID := event.RoomID()
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
- return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender())
+ return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.SenderID())
}
// processStateBefore works out what the state is before the event and
@@ -579,7 +587,9 @@ func (r *Inputer) processStateBefore(
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent),
)
- if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil {
+ if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
return
}
@@ -690,7 +700,9 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway.
- if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil {
+ if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
continue nextAuthEvent
}
@@ -706,7 +718,9 @@ nextAuthEvent:
}
// Check if the auth event should be rejected.
- err := gomatrixserverlib.Allowed(authEvent, auth)
+ err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ })
if isRejected = err != nil; isRejected {
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
}
@@ -828,11 +842,13 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
continue
}
+ // TODO: pseudoIDs: get userID for room using state key (which is now senderID)
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
if err != nil {
continue
}
+ // TODO: pseudoIDs: query account by state key (which is now senderID)
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
Localpart: localpart,
diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go
index 56803813..0ba7d19f 100644
--- a/roomserver/internal/input/input_events_test.go
+++ b/roomserver/internal/input/input_events_test.go
@@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
}
// Finally check that the event is NOT allowed
- if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil {
+ if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil {
t.Fatalf("event should not be allowed, but it was")
}
}
diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go
index 10486138..ac0670fc 100644
--- a/roomserver/internal/input/input_missing.go
+++ b/roomserver/internal/input/input_missing.go
@@ -473,14 +473,18 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
stateEventList = append(stateEventList, state.StateEvents...)
}
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
- roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList),
+ roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
if err != nil {
return nil, err
}
// apply the current event
retryAllowedState:
- if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil {
+ if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError:
h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true)
@@ -565,7 +569,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
// will be added and duplicates will be removed.
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil {
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
continue
}
missingEvents = append(missingEvents, t.cacheAndReturn(ev))
@@ -654,7 +660,9 @@ func (t *missingStateReq) lookupMissingStateViaState(
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(),
- }, roomVersion, t.keys, nil)
+ }, roomVersion, t.keys, nil, func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ })
if err != nil {
return nil, err
}
@@ -889,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
}
- if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil {
+ if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err}
}
return t.cacheAndReturn(event), nil
}
-func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error {
+func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents {
err := authUsingState.AddEvent(stateEvents[i])
@@ -904,7 +914,7 @@ func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverli
return err
}
}
- return gomatrixserverlib.Allowed(e, &authUsingState)
+ return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender)
}
func (t *missingStateReq) hadEvent(eventID string) {
diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go
index 575525e2..ca736cb6 100644
--- a/roomserver/internal/perform/perform_admin.go
+++ b/roomserver/internal/perform/perform_admin.go
@@ -262,13 +262,17 @@ func (r *Admin) PerformAdminDownloadState(
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
}
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil {
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
continue
}
authEventMap[authEvent.EventID()] = authEvent
}
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil {
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
continue
}
stateEventMap[stateEvent.EventID()] = stateEvent
diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go
index fb579f03..0f743f4e 100644
--- a/roomserver/internal/perform/perform_backfill.go
+++ b/roomserver/internal/perform/perform_backfill.go
@@ -121,7 +121,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill(
ctx, req.VirtualHost, requester,
- r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100,
+ r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
// Only return an error if we really couldn't get any events.
if err != nil && len(events) == 0 {
@@ -210,7 +212,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue
}
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
- result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
+ result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ })
if err != nil {
logger.WithError(err).Warn("failed to load and verify event")
continue
@@ -484,8 +488,8 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[spec.ServerName]bool)
for _, event := range memberEvents {
- if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil {
- serverSet[senderDomain] = true
+ if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
+ serverSet[sender.Domain()] = true
}
}
var servers []spec.ServerName
diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go
index 41194832..897bd3a0 100644
--- a/roomserver/internal/perform/perform_create_room.go
+++ b/roomserver/internal/perform/perform_create_room.go
@@ -308,7 +308,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
}
- if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil {
+ if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return c.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index 1930b5ac..e8e20ede 100644
--- a/roomserver/internal/perform/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -97,11 +97,12 @@ func (r *Inviter) ProcessInviteMembership(
) ([]api.OutputEvent, error) {
var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater
- _, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey())
+
+ userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey())
if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
}
- isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain)
+ isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain())
if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
@@ -125,9 +126,9 @@ func (r *Inviter) PerformInvite(
) error {
event := req.Event
- sender, err := spec.NewUserID(event.Sender(), true)
+ sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
- return spec.InvalidParam("The user ID is invalid")
+ return spec.InvalidParam("The sender user ID is invalid")
}
if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
@@ -155,6 +156,9 @@ func (r *Inviter) PerformInvite(
StrippedState: req.InviteRoomState,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
StateQuerier: &QueryState{r.DB},
+ UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
}
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
if err != nil {
diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go
index ff4a6a1d..8c0df1c4 100644
--- a/roomserver/internal/perform/perform_upgrade.go
+++ b/roomserver/internal/perform/perform_upgrade.go
@@ -176,7 +176,7 @@ func moveLocalAliases(ctx context.Context,
}
for _, alias := range aliasRes.Aliases {
- removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias}
+ removeAliasReq := api.RemoveRoomAliasRequest{SenderID: userID, Alias: alias}
removeAliasRes := api.RemoveRoomAliasResponse{}
if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil {
return fmt.Errorf("Failed to remove old room alias: %w", err)
@@ -484,7 +484,9 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
}
- if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
+ if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
}
@@ -567,7 +569,9 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
stateEvents[i] = queryRes.StateEvents[i].PDU
}
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
- if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider); err != nil {
+ if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
}
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 6d898e8a..707e95b2 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -159,7 +159,9 @@ func (r *Queryer) QueryStateAfterEvents(
}
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
@@ -386,7 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
+ sender := spec.UserID{}
+ userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if queryErr == nil && userID != nil {
+ sender = *userID
+ }
+ clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
@@ -435,7 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
+ sender := spec.UserID{}
+ userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
+ clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
@@ -625,7 +637,9 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
if err != nil {
return err
@@ -960,3 +974,11 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID)
}
+
+func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) {
+ return r.DB.GetSenderIDForUser(ctx, roomID, userID)
+}
+
+func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+}
diff --git a/roomserver/producers/roomevent.go b/roomserver/producers/roomevent.go
index febe8ddf..165304d4 100644
--- a/roomserver/producers/roomevent.go
+++ b/roomserver/producers/roomevent.go
@@ -60,7 +60,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu
"adds_state": len(update.NewRoomEvent.AddsStateEventIDs),
"removes_state": len(update.NewRoomEvent.RemovesStateEventIDs),
"send_as_server": update.NewRoomEvent.SendAsServer,
- "sender": update.NewRoomEvent.Event.Sender(),
+ "sender": update.NewRoomEvent.Event.SenderID(),
})
if update.NewRoomEvent.Event.StateKey() != nil {
logger = logger.WithField("state_key", *update.NewRoomEvent.Event.StateKey())
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index f38d8f96..3131cbff 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -24,6 +24,7 @@ import (
"time"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
@@ -43,6 +44,7 @@ type StateResolutionStorage interface {
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
+ GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error)
}
type StateResolution struct {
@@ -945,7 +947,9 @@ func (v *StateResolution) resolveConflictsV1(
}
// Resolve the conflicts.
- resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents)
+ resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return v.db.GetUserIDForSender(ctx, roomID, senderID)
+ })
// Map from the full events back to numeric state entries.
for _, resolvedEvent := range resolvedEvents {
@@ -1057,6 +1061,9 @@ func (v *StateResolution) resolveConflictsV2(
conflictedEvents,
nonConflictedEvents,
authEvents,
+ func(roomID, senderID string) (*spec.UserID, error) {
+ return v.db.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
}()
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index 7d22df00..2d007bed 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -166,6 +166,10 @@ type Database interface {
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
// GetKnownUsers searches all users that userID knows about.
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
+ // GetKnownUsers tries to obtain the current mxid for a given user.
+ GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error)
+ // GetKnownUsers tries to obtain the current senderID for a given user.
+ GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error)
// GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
@@ -211,6 +215,7 @@ type RoomDatabase interface {
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
+ GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error)
}
type EventDatabase interface {
diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go
index f9c889cb..105e61df 100644
--- a/roomserver/storage/shared/membership_updater.go
+++ b/roomserver/storage/shared/membership_updater.go
@@ -101,7 +101,7 @@ func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event *
var inserted bool // Did the query result in a membership change?
var retired []string // Did we retire any updates in the process?
return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
- senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.SenderID())
if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go
index 70672a33..73500138 100644
--- a/roomserver/storage/shared/room_updater.go
+++ b/roomserver/storage/shared/room_updater.go
@@ -6,6 +6,7 @@ import (
"fmt"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/roomserver/types"
)
@@ -250,3 +251,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}
+
+func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return u.d.GetUserIDForSender(ctx, roomID, senderID)
+}
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index cefa58a3..406d7cf1 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -988,8 +988,18 @@ func (d *EventDatabase) MaybeRedactEvent(
return nil
}
- _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender())
- _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender())
+ // TODO: Don't hack senderID into userID here (pseudoIDs)
+ sender1Domain := ""
+ sender1, err1 := spec.NewUserID(redactedEvent.SenderID(), true)
+ if err1 == nil {
+ sender1Domain = string(sender1.Domain())
+ }
+ // TODO: Don't hack senderID into userID here (pseudoIDs)
+ sender2Domain := ""
+ sender2, err2 := spec.NewUserID(redactionEvent.SenderID(), true)
+ if err2 == nil {
+ sender2Domain = string(sender2.Domain())
+ }
var powerlevels *gomatrixserverlib.PowerLevelContent
powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID())
if err != nil {
@@ -997,9 +1007,9 @@ func (d *EventDatabase) MaybeRedactEvent(
}
switch {
- case powerlevels.UserLevel(redactionEvent.Sender()) >= powerlevels.Redact:
+ case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact:
// 1. The power level of the redaction event’s sender is greater than or equal to the redact level.
- case sender1 == sender2:
+ case sender1Domain == sender2Domain:
// 2. The domain of the redaction event’s sender matches that of the original event’s sender.
default:
ignoreRedaction = true
@@ -1514,6 +1524,16 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
}
+func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ // TODO: Use real logic once DB for pseudoIDs is in place
+ return spec.NewUserID(senderID, true)
+}
+
+func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) {
+ // TODO: Use real logic once DB for pseudoIDs is in place
+ return userID.String(), nil
+}
+
// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go
index f468b048..5ce3b430 100644
--- a/setup/mscs/msc2836/msc2836.go
+++ b/setup/mscs/msc2836/msc2836.go
@@ -92,9 +92,11 @@ type MSC2836EventRelationshipsResponse struct {
ParsedAuthChain []gomatrixserverlib.PDU
}
-func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse {
+func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
out := &EventRelationshipResponse{
- Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll),
+ Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }),
Limited: res.Limited,
NextBatch: res.NextBatch,
}
@@ -187,7 +189,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
return util.JSONResponse{
Code: 200,
- JSON: toClientResponse(res),
+ JSON: toClientResponse(req.Context(), res, rsAPI),
}
}
}
diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go
index 2c6f63d4..c463fd72 100644
--- a/setup/mscs/msc2836/msc2836_test.go
+++ b/setup/mscs/msc2836/msc2836_test.go
@@ -525,6 +525,10 @@ type testRoomserverAPI struct {
events map[string]*types.HeaderedEvent
}
+func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
for _, eventID := range req.EventIDs {
ev := r.events[eventID]
diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go
index 291e0f3b..f380d3d4 100644
--- a/setup/mscs/msc2946/msc2946.go
+++ b/setup/mscs/msc2946/msc2946.go
@@ -730,7 +730,7 @@ func stripped(ev gomatrixserverlib.PDU) *fclient.MSC2946StrippedEvent {
Type: ev.Type(),
StateKey: *ev.StateKey(),
Content: ev.Content(),
- Sender: ev.Sender(),
+ Sender: ev.SenderID(),
OriginServerTS: ev.OriginServerTS(),
}
}
diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go
index 56285dbf..c0836465 100644
--- a/syncapi/consumers/roomserver.go
+++ b/syncapi/consumers/roomserver.go
@@ -523,7 +523,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent)
prev := types.PrevEventRef{
PrevContent: prevEvent.Content(),
ReplacesState: prevEvent.EventID(),
- PrevSender: prevEvent.Sender(),
+ PrevSender: prevEvent.SenderID(),
}
event.PDU, err = event.SetUnsigned(prev)
diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go
index ac17d39d..27e99a35 100644
--- a/syncapi/routing/context.go
+++ b/syncapi/routing/context.go
@@ -193,14 +193,20 @@ func Context(
}
}
- eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll)
- eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll)
+ eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
+ eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
newState := state
if filter.LazyLoadMembers {
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
allEvents = append(allEvents, &requestedEvent)
- evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll)
+ evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
if err != nil {
logrus.WithError(err).Error("unable to load membership events")
@@ -211,12 +217,19 @@ func Context(
}
}
- ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll)
+ sender := spec.UserID{}
+ userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
+ ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
response := ContextRespsonse{
Event: &ev,
EventsAfter: eventsAfterClient,
EventsBefore: eventsBeforeClient,
- State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll),
+ State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }),
}
if len(response.State) > filter.Limit {
diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go
index 0d3d412f..63df7e83 100644
--- a/syncapi/routing/getevent.go
+++ b/syncapi/routing/getevent.go
@@ -101,8 +101,13 @@ func GetEvent(
}
}
+ sender := spec.UserID{}
+ senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID())
+ if err == nil && senderUserID != nil {
+ sender = *senderUserID
+ }
return util.JSONResponse{
Code: http.StatusOK,
- JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll),
+ JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
}
}
diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go
index 7d2e137d..9c2319dd 100644
--- a/syncapi/routing/memberships.go
+++ b/syncapi/routing/memberships.go
@@ -144,7 +144,7 @@ func GetMemberships(
JSON: spec.InternalServerError{},
}
}
- res.Joined[ev.Sender()] = joinedMember(content)
+ res.Joined[ev.SenderID()] = joinedMember(content)
}
return util.JSONResponse{
Code: http.StatusOK,
@@ -153,6 +153,8 @@ func GetMemberships(
}
return util.JSONResponse{
Code: http.StatusOK,
- JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll)},
+ JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
+ })},
}
}
diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go
index aeaec699..879739d0 100644
--- a/syncapi/routing/messages.go
+++ b/syncapi/routing/messages.go
@@ -241,7 +241,7 @@ func OnIncomingMessagesRequest(
device: device,
}
- clientEvents, start, end, err := mReq.retrieveEvents()
+ clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed")
return util.JSONResponse{
@@ -273,7 +273,9 @@ func OnIncomingMessagesRequest(
JSON: spec.InternalServerError{},
}
}
- res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll)...)
+ res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
+ })...)
}
// If we didn't return any events, set the end to an empty string, so it will be omitted
@@ -310,7 +312,7 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.
// homeserver in the room for older events.
// Returns an error if there was an issue talking to the database or with the
// remote homeserver.
-func (r *messagesReq) retrieveEvents() (
+func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserverAPI) (
clientEvents []synctypes.ClientEvent, start,
end types.TopologyToken, err error,
) {
@@ -383,7 +385,9 @@ func (r *messagesReq) retrieveEvents() (
"events_before": len(events),
"events_after": len(filteredEvents),
}).Debug("applied history visibility (messages)")
- return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll), start, end, err
+ return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }), start, end, err
}
func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) {
diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go
index 8374bf5b..f21c684c 100644
--- a/syncapi/routing/relations.go
+++ b/syncapi/routing/relations.go
@@ -114,9 +114,14 @@ func Relations(
// type if it was specified.
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents {
+ sender := spec.UserID{}
+ userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
res.Chunk = append(
res.Chunk,
- synctypes.ToClientEvent(event.PDU, synctypes.FormatAll),
+ synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
)
}
diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go
index 9ad0c047..8542c0b7 100644
--- a/syncapi/routing/routing.go
+++ b/syncapi/routing/routing.go
@@ -171,7 +171,7 @@ func Setup(
nb := req.FormValue("next_batch")
nextBatch = &nb
}
- return Search(req, device, syncDB, fts, nextBatch)
+ return Search(req, device, syncDB, fts, nextBatch, rsAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go
index b7191873..9cf3eabe 100644
--- a/syncapi/routing/search.go
+++ b/syncapi/routing/search.go
@@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/synctypes"
@@ -38,7 +39,7 @@ import (
)
// nolint:gocyclo
-func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string) util.JSONResponse {
+func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string, rsAPI roomserverAPI.SyncRoomserverAPI) util.JSONResponse {
start := time.Now()
var (
searchReq SearchRequest
@@ -204,11 +205,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos := make(map[string]ProfileInfoResponse)
for _, ev := range append(eventsBefore, eventsAfter...) {
- profile, ok := knownUsersProfiles[event.Sender()]
+ userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
+ if queryErr != nil {
+ logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
+ continue
+ }
+
+ profile, ok := knownUsersProfiles[userID.String()]
if !ok {
- stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.Sender())
- if err != nil {
- logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile")
+ stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID())
+ if stateErr != nil {
+ logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
continue
}
if stateEvent == nil {
@@ -218,21 +225,30 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str,
DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str,
}
- knownUsersProfiles[event.Sender()] = profile
+ knownUsersProfiles[userID.String()] = profile
}
- profileInfos[ev.Sender()] = profile
+ profileInfos[userID.String()] = profile
}
+ sender := spec.UserID{}
+ userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
results = append(results, Result{
Context: SearchContextResponse{
- Start: startToken.String(),
- End: endToken.String(),
- EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync),
- EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync),
- ProfileInfo: profileInfos,
+ Start: startToken.String(),
+ End: endToken.String(),
+ EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
+ }),
+ EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
+ }),
+ ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
- Result: synctypes.ToClientEvent(event, synctypes.FormatAll),
+ Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
})
roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID())
@@ -247,7 +263,9 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
JSON: spec.InternalServerError{},
}
}
- stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync)
+ stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
+ })
}
}
diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go
index 1cc95a87..b36be823 100644
--- a/syncapi/routing/search_test.go
+++ b/syncapi/routing/search_test.go
@@ -2,6 +2,7 @@ package routing
import (
"bytes"
+ "context"
"encoding/json"
"net/http"
"net/http/httptest"
@@ -9,6 +10,7 @@ import (
"github.com/matrix-org/dendrite/internal/fulltext"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/synctypes"
@@ -21,6 +23,12 @@ import (
"github.com/stretchr/testify/assert"
)
+type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
+
+func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
func TestSearch(t *testing.T) {
alice := test.NewUser(t)
aliceDevice := userapi.Device{UserID: alice.ID}
@@ -247,7 +255,7 @@ func TestSearch(t *testing.T) {
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", reqBody)
- res := Search(req, tc.device, db, fts, tc.from)
+ res := Search(req, tc.device, db, fts, tc.from, &FakeSyncRoomserverAPI{})
if !tc.wantOK && !res.Is2xx() {
return
}
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go
index 0cc96373..bfe5e9bd 100644
--- a/syncapi/storage/postgres/current_room_state_table.go
+++ b/syncapi/storage/postgres/current_room_state_table.go
@@ -343,7 +343,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
event.RoomID(),
event.EventID(),
event.Type(),
- event.Sender(),
+ event.SenderID(),
containsURL,
*event.StateKey(),
headeredJSON,
diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go
index 3aadbccf..e068afab 100644
--- a/syncapi/storage/postgres/output_room_events_table.go
+++ b/syncapi/storage/postgres/output_room_events_table.go
@@ -407,7 +407,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
event.EventID(),
headeredJSON,
event.Type(),
- event.Sender(),
+ event.SenderID(),
containsURL,
pq.StringArray(addState),
pq.StringArray(removeState),
diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go
index ecfd418f..17a6a69c 100644
--- a/syncapi/storage/shared/storage_consumer.go
+++ b/syncapi/storage/shared/storage_consumer.go
@@ -195,7 +195,21 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea
for i := 0; i < len(in); i++ {
out[i] = in[i].HeaderedEvent
if device != nil && in[i].TransactionID != nil {
- if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ "event_id": out[i].EventID(),
+ }).WithError(err).Warnf("Failed to add transaction ID to event")
+ continue
+ }
+ deviceSenderID, err := d.getSenderIDForUser(in[i].RoomID(), *userID)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ "event_id": out[i].EventID(),
+ }).WithError(err).Warnf("Failed to add transaction ID to event")
+ continue
+ }
+ if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].TransactionID.TransactionID,
)
@@ -210,6 +224,11 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea
return out
}
+func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (string, error) { // nolint
+ // TODO: Repalce with actual logic for pseudoIDs
+ return userID.String(), nil
+}
+
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index 1b8632eb..e432e483 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -342,7 +342,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
event.RoomID(),
event.EventID(),
event.Type(),
- event.Sender(),
+ event.SenderID(),
containsURL,
*event.StateKey(),
headeredJSON,
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
index d63e7606..5a47aec4 100644
--- a/syncapi/storage/sqlite3/output_room_events_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -348,7 +348,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
event.EventID(),
headeredJSON,
event.Type(),
- event.Sender(),
+ event.SenderID(),
containsURL,
string(addStateJSON),
string(removeStateJSON),
diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go
index becd863a..a8b0a7b6 100644
--- a/syncapi/streams/stream_invite.go
+++ b/syncapi/streams/stream_invite.go
@@ -10,6 +10,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec"
+ "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/dendrite/syncapi/types"
@@ -17,6 +18,7 @@ import (
type InviteStreamProvider struct {
DefaultStreamProvider
+ rsAPI api.SyncRoomserverAPI
}
func (p *InviteStreamProvider) Setup(
@@ -62,11 +64,17 @@ func (p *InviteStreamProvider) IncrementalSync(
}
for roomID, inviteEvent := range invites {
+ user := spec.UserID{}
+ sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID())
+ if err == nil && sender != nil {
+ user = *sender
+ }
+
// skip ignored user events
- if _, ok := req.IgnoredUsers.List[inviteEvent.Sender()]; ok {
+ if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue
}
- ir := types.NewInviteResponse(inviteEvent)
+ ir := types.NewInviteResponse(inviteEvent, user)
req.Response.Rooms.Invite[roomID] = ir
}
diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go
index 0ea48a9d..8f83a089 100644
--- a/syncapi/streams/stream_pdu.go
+++ b/syncapi/streams/stream_pdu.go
@@ -376,20 +376,28 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
}
jr.Timeline.PrevBatch = &prevBatch
- jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync)
+ jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
- jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync)
+ jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
req.Response.Rooms.Join[delta.RoomID] = jr
case spec.Peek:
jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch
// TODO: Apply history visibility on peeked rooms
- jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync)
+ jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
jr.Timeline.Limited = limited
- jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync)
+ jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
req.Response.Rooms.Peek[delta.RoomID] = jr
case spec.Leave:
@@ -398,11 +406,15 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case spec.Ban:
lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch
- lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync)
+ lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
- lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync)
+ lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
req.Response.Rooms.Leave[delta.RoomID] = lr
}
@@ -425,7 +437,7 @@ func applyHistoryVisibilityFilter(
for _, ev := range recentEvents {
if ev.StateKey() != nil {
stateTypes = append(stateTypes, ev.Type())
- senders = append(senders, ev.Sender())
+ senders = append(senders, ev.SenderID())
}
}
@@ -552,11 +564,15 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
}
jr.Timeline.PrevBatch = prevBatch
- jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync)
+ jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
// If we are limited by the filter AND the history visibility filter
// didn't "remove" events, return that the response is limited.
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
- jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync)
+ jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) {
+ return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
return jr, nil
}
@@ -577,8 +593,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// Add all users the client doesn't know about yet to a list
for _, event := range timelineEvents {
// Membership is not yet cached, add it to the list
- if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok {
- timelineUsers[event.Sender()] = struct{}{}
+ if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.SenderID()); !ok {
+ timelineUsers[event.SenderID()] = struct{}{}
}
}
// Preallocate with the same amount, even if it will end up with fewer values
diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go
index a35491ac..f25bc978 100644
--- a/syncapi/streams/streams.go
+++ b/syncapi/streams/streams.go
@@ -45,6 +45,7 @@ func NewSyncStreamProviders(
},
InviteStreamProvider: &InviteStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d},
+ rsAPI: rsAPI,
},
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d},
diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go
index bc766e66..78c857ab 100644
--- a/syncapi/syncapi_test.go
+++ b/syncapi/syncapi_test.go
@@ -40,6 +40,10 @@ type syncRoomserverAPI struct {
rooms []*test.Room
}
+func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error {
var room *test.Room
for _, r := range s.rooms {
diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go
index c722fe60..66fb1d01 100644
--- a/syncapi/synctypes/clientevent.go
+++ b/syncapi/synctypes/clientevent.go
@@ -44,22 +44,27 @@ type ClientEvent struct {
}
// ToClientEvents converts server events to client events.
-func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) []ClientEvent {
+func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) []ClientEvent {
evs := make([]ClientEvent, 0, len(serverEvs))
for _, se := range serverEvs {
if se == nil {
continue // TODO: shouldn't happen?
}
- evs = append(evs, ToClientEvent(se, format))
+ sender := spec.UserID{}
+ userID, err := userIDForSender(se.RoomID(), se.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
+ evs = append(evs, ToClientEvent(se, format, sender))
}
return evs
}
// ToClientEvent converts a single server event to a client event.
-func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent {
+func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
ce := ClientEvent{
Content: spec.RawJSON(se.Content()),
- Sender: se.Sender(),
+ Sender: sender.String(),
Type: se.Type(),
StateKey: se.StateKey(),
Unsigned: spec.RawJSON(se.Unsigned()),
diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go
index b914e64f..34179508 100644
--- a/syncapi/synctypes/clientevent_test.go
+++ b/syncapi/synctypes/clientevent_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
)
func TestToClientEvent(t *testing.T) { // nolint: gocyclo
@@ -43,7 +44,11 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil {
t.Fatalf("failed to create Event: %s", err)
}
- ce := ToClientEvent(ev, FormatAll)
+ userID, err := spec.NewUserID("@test:localhost", true)
+ if err != nil {
+ t.Fatalf("failed to create userID: %s", err)
+ }
+ ce := ToClientEvent(ev, FormatAll, *userID)
if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
}
@@ -62,8 +67,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if !bytes.Equal(ce.Unsigned, ev.Unsigned()) {
t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned))
}
- if ce.Sender != ev.Sender() {
- t.Errorf("ClientEvent.Sender: wanted %s, got %s", ev.Sender(), ce.Sender)
+ if ce.Sender != userID.String() {
+ t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender)
}
j, err := json.Marshal(ce)
if err != nil {
@@ -98,7 +103,11 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil {
t.Fatalf("failed to create Event: %s", err)
}
- ce := ToClientEvent(ev, FormatSync)
+ userID, err := spec.NewUserID("@test:localhost", true)
+ if err != nil {
+ t.Fatalf("failed to create userID: %s", err)
+ }
+ ce := ToClientEvent(ev, FormatSync, *userID)
if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
}
diff --git a/syncapi/types/types.go b/syncapi/types/types.go
index 22c27fea..526a120d 100644
--- a/syncapi/types/types.go
+++ b/syncapi/types/types.go
@@ -539,7 +539,7 @@ type InviteResponse struct {
}
// NewInviteResponse creates an empty response with initialised arrays.
-func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse {
+func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{}
@@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse {
// Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite.
- inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync)
+ inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev)
diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go
index 8e0448fe..a79ce541 100644
--- a/syncapi/types/types_test.go
+++ b/syncapi/types/types_test.go
@@ -8,8 +8,13 @@ import (
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
)
+func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
func TestSyncTokens(t *testing.T) {
shouldPass := map[string]string{
"s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(),
@@ -56,7 +61,12 @@ func TestNewInviteResponse(t *testing.T) {
t.Fatal(err)
}
- res := NewInviteResponse(&types.HeaderedEvent{PDU: ev})
+ sender, err := spec.NewUserID("@neilalexander:matrix.org", true)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
j, err := json.Marshal(res)
if err != nil {
t.Fatal(err)
diff --git a/test/room.go b/test/room.go
index 852e3153..4cdb73aa 100644
--- a/test/room.go
+++ b/test/room.go
@@ -39,6 +39,10 @@ var (
roomIDCounter = int64(0)
)
+func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
type Room struct {
ID string
Version gomatrixserverlib.RoomVersion
@@ -195,7 +199,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
if err != nil {
t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
}
- if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
+ if err = gomatrixserverlib.Allowed(ev, &r.authEvents, UserIDForSender); err != nil {
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
}
headeredEvent := &rstypes.HeaderedEvent{PDU: ev}
diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go
index 3cfdc0ce..c025deee 100644
--- a/userapi/consumers/roomserver.go
+++ b/userapi/consumers/roomserver.go
@@ -108,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
}
if s.cfg.Matrix.ReportStats.Enabled {
- go s.storeMessageStats(ctx, event.Type(), event.Sender(), event.RoomID())
+ go s.storeMessageStats(ctx, event.Type(), event.SenderID(), event.RoomID())
}
log.WithFields(log.Fields{
@@ -301,7 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch {
case event.Type() == spec.MRoomMember:
- cevent := synctypes.ToClientEvent(event, synctypes.FormatAll)
+ sender := spec.UserID{}
+ userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if queryErr == nil && userID != nil {
+ sender = *userID
+ }
+ cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
var member *localMembership
member, err = newLocalMembership(&cevent)
if err != nil {
@@ -529,12 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
return fmt.Errorf("s.localPushDevices: %w", err)
}
+ sender := spec.UserID{}
+ userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
n := &api.Notification{
Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync.
- Event: synctypes.ToClientEvent(event, synctypes.FormatSync),
+ Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender),
// TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it
@@ -615,7 +625,12 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
// evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify).
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
- if event.Sender() == mem.UserID {
+ user := ""
+ sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err == nil {
+ user = sender.String()
+ }
+ if user == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves.
return nil, nil
@@ -632,9 +647,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
if err != nil {
return nil, err
}
- sender := event.Sender()
- if _, ok := ignored.List[sender]; ok {
- return nil, fmt.Errorf("user %s is ignored", sender)
+ if _, ok := ignored.List[sender.String()]; ok {
+ return nil, fmt.Errorf("user %s is ignored", sender.String())
}
}
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
@@ -650,7 +664,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
roomSize: roomSize,
}
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
- rule, err := eval.MatchEvent(event.PDU)
+ rule, err := eval.MatchEvent(event.PDU, func(roomID, senderID string) (*spec.UserID, error) {
+ return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
if err != nil {
return nil, err
}
@@ -682,7 +698,7 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display
func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
-func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) {
+func (rse *ruleSetEvalContext) HasPowerLevel(senderID, levelKey string) (bool, error) {
req := &rsapi.QueryLatestEventsAndStateRequest{
RoomID: rse.roomID,
StateToFetch: []gomatrixserverlib.StateKeyTuple{
@@ -702,7 +718,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
if err != nil {
return false, err
}
- return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil
+ return plc.UserLevel(senderID) >= plc.NotificationLevel(levelKey), nil
}
return true, nil
}
@@ -756,6 +772,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
}
default:
+ sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err != nil {
+ logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID())
+ return nil, err
+ }
req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{
Content: event.Content(),
@@ -767,7 +788,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
ID: event.EventID(),
RoomID: event.RoomID(),
RoomName: roomName,
- Sender: event.Sender(),
+ Sender: sender.String(),
Type: event.Type(),
},
}
diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go
index 53977206..899a5aaf 100644
--- a/userapi/consumers/roomserver_test.go
+++ b/userapi/consumers/roomserver_test.go
@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/pushrules"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/storage"
@@ -44,13 +45,19 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
return &types.HeaderedEvent{PDU: ev}
}
+type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
+
+func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
func Test_evaluatePushRules(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
- consumer := OutputRoomEventConsumer{db: db}
+ consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}}
testCases := []struct {
name string
@@ -86,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) {
},
{
name: "m.room.message highlights",
- eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`,
+ eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`,
wantNotify: true,
wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{
diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go
index e1c88d47..27dd373c 100644
--- a/userapi/util/notify_test.go
+++ b/userapi/util/notify_test.go
@@ -11,6 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"golang.org/x/crypto/bcrypt"
@@ -87,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
}
// Prepare pusher with our test server URL
- if err := db.UpsertPusher(ctx, api.Pusher{
+ if err = db.UpsertPusher(ctx, api.Pusher{
Kind: api.HTTPKind,
AppID: appID,
PushKey: pushKey,
@@ -99,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) {
}
// Insert a dummy event
+ sender, err := spec.NewUserID(alice.ID, true)
+ if err != nil {
+ t.Error(err)
+ }
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
- Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll),
+ Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender),
}); err != nil {
t.Error(err)
}