aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-06-12 11:19:25 +0000
committerGitHub <noreply@github.com>2023-06-12 11:19:25 +0000
commit77d9e4e93dd01f6baa82bd6236850c1007346cac (patch)
tree20be66224646cc82199028cf89f4cd7fab80b97f
parent832ccc32f6a023665e250eee44b5f678e985d50e (diff)
Cleanup remaining statekey usage for senderIDs (#3106)
-rw-r--r--clientapi/routing/account_data.go10
-rw-r--r--clientapi/routing/aliases.go9
-rw-r--r--clientapi/routing/createroom.go1
-rw-r--r--clientapi/routing/directory.go33
-rw-r--r--clientapi/routing/leaveroom.go10
-rw-r--r--clientapi/routing/membership.go131
-rw-r--r--clientapi/routing/redaction.go34
-rw-r--r--clientapi/routing/sendtyping.go10
-rw-r--r--clientapi/routing/server_notices.go13
-rw-r--r--clientapi/routing/state.go53
-rw-r--r--clientapi/routing/upgrade_room.go10
-rw-r--r--federationapi/routing/eventauth.go2
-rw-r--r--federationapi/routing/events.go12
-rw-r--r--federationapi/routing/state.go2
-rw-r--r--go.mod2
-rw-r--r--go.sum4
-rw-r--r--roomserver/api/api.go21
-rw-r--r--roomserver/api/output.go6
-rw-r--r--roomserver/api/perform.go4
-rw-r--r--roomserver/api/query.go20
-rw-r--r--roomserver/auth/auth.go14
-rw-r--r--roomserver/auth/auth_test.go12
-rw-r--r--roomserver/internal/helpers/helpers.go37
-rw-r--r--roomserver/internal/helpers/helpers_test.go5
-rw-r--r--roomserver/internal/input/input_events.go12
-rw-r--r--roomserver/internal/input/input_membership.go21
-rw-r--r--roomserver/internal/perform/perform_admin.go6
-rw-r--r--roomserver/internal/perform/perform_backfill.go2
-rw-r--r--roomserver/internal/perform/perform_create_room.go15
-rw-r--r--roomserver/internal/perform/perform_invite.go8
-rw-r--r--roomserver/internal/perform/perform_join.go35
-rw-r--r--roomserver/internal/perform/perform_leave.go77
-rw-r--r--roomserver/internal/perform/perform_upgrade.go116
-rw-r--r--roomserver/internal/query/query.go70
-rw-r--r--roomserver/roomserver_test.go19
-rw-r--r--roomserver/storage/interface.go2
-rw-r--r--roomserver/storage/shared/storage.go7
-rw-r--r--setup/mscs/msc2836/msc2836.go11
-rw-r--r--setup/mscs/msc2836/msc2836_test.go6
-rw-r--r--syncapi/consumers/roomserver.go29
-rw-r--r--syncapi/internal/history_visibility.go14
-rw-r--r--syncapi/internal/keychange.go16
-rw-r--r--syncapi/internal/keychange_test.go4
-rw-r--r--syncapi/notifier/notifier.go45
-rw-r--r--syncapi/notifier/notifier_test.go22
-rw-r--r--syncapi/routing/context.go18
-rw-r--r--syncapi/routing/getevent.go11
-rw-r--r--syncapi/routing/memberships.go13
-rw-r--r--syncapi/routing/messages.go6
-rw-r--r--syncapi/routing/relations.go11
-rw-r--r--syncapi/routing/search.go11
-rw-r--r--syncapi/storage/shared/storage_consumer.go16
-rw-r--r--syncapi/storage/shared/storage_sync.go4
-rw-r--r--syncapi/streams/stream_invite.go11
-rw-r--r--syncapi/streams/stream_pdu.go12
-rw-r--r--syncapi/syncapi.go2
-rw-r--r--syncapi/synctypes/clientevent.go35
-rw-r--r--syncapi/synctypes/clientevent_test.go6
-rw-r--r--syncapi/types/types.go4
-rw-r--r--syncapi/types/types_test.go8
-rw-r--r--userapi/consumers/roomserver.go36
-rw-r--r--userapi/util/notify_test.go3
62 files changed, 752 insertions, 447 deletions
diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go
index 7eacf9cc..81afc3b1 100644
--- a/clientapi/routing/account_data.go
+++ b/clientapi/routing/account_data.go
@@ -145,8 +145,16 @@ func SaveReadMarker(
userAPI api.ClientUserAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string,
) util.JSONResponse {
+ deviceUserID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("userID for this device is invalid"),
+ }
+ }
+
// Verify that the user is a member of this room
- resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
+ resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil {
return *resErr
}
diff --git a/clientapi/routing/aliases.go b/clientapi/routing/aliases.go
index f6603be8..2d6b72d3 100644
--- a/clientapi/routing/aliases.go
+++ b/clientapi/routing/aliases.go
@@ -55,9 +55,16 @@ func GetAliases(
visibility = content.HistoryVisibility
}
if visibility != spec.WorldReadable {
+ deviceUserID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
+ }
+ }
queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: device.UserID,
+ UserID: *deviceUserID,
}
var queryRes api.QueryMembershipForUserResponse
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go
index 799fc797..320f236c 100644
--- a/clientapi/routing/createroom.go
+++ b/clientapi/routing/createroom.go
@@ -224,6 +224,7 @@ func createRoom(
PrivateKey: privateKey,
EventTime: evTime,
}
+
roomAlias, createRes := rsAPI.PerformCreateRoom(ctx, *userID, *roomID, &req)
if createRes != nil {
return *createRes
diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go
index 034296f4..f01e24ec 100644
--- a/clientapi/routing/directory.go
+++ b/clientapi/routing/directory.go
@@ -314,7 +314,22 @@ func SetVisibility(
req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device,
roomID string,
) util.JSONResponse {
- resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID)
+ deviceUserID, err := spec.NewUserID(dev.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("userID for this device is invalid"),
+ }
+ }
+ senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.Unknown("failed to find senderID for this user"),
+ }
+ }
+
+ resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil {
return *resErr
}
@@ -327,7 +342,7 @@ func SetVisibility(
}},
}
var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse
- err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes)
+ err = rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes)
if err != nil || len(queryEventsRes.StateEvents) == 0 {
util.GetLogger(req.Context()).WithError(err).Error("could not query events from room")
return util.JSONResponse{
@@ -338,20 +353,6 @@ func SetVisibility(
// NOTSPEC: Check if the user's power is greater than power required to change m.room.canonical_alias event
power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].PDU)
- fullUserID, err := spec.NewUserID(dev.UserID, true)
- if err != nil {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
- }
- }
- senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
- if err != nil {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
- }
- }
if power.UserLevel(senderID) < power.EventLevel(spec.MRoomCanonicalAlias, true) {
return util.JSONResponse{
Code: http.StatusForbidden,
diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go
index fbf14826..7e8c066e 100644
--- a/clientapi/routing/leaveroom.go
+++ b/clientapi/routing/leaveroom.go
@@ -29,10 +29,18 @@ func LeaveRoomByID(
rsAPI roomserverAPI.ClientRoomserverAPI,
roomID string,
) util.JSONResponse {
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.Unknown("device userID is invalid"),
+ }
+ }
+
// Prepare to ask the roomserver to perform the room join.
leaveReq := roomserverAPI.PerformLeaveRequest{
RoomID: roomID,
- UserID: device.UserID,
+ Leaver: *userID,
}
leaveRes := roomserverAPI.PerformLeaveResponse{}
diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go
index 78829bec..03e85edb 100644
--- a/clientapi/routing/membership.go
+++ b/clientapi/routing/membership.go
@@ -57,29 +57,30 @@ func SendBan(
}
}
- errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
- if errRes != nil {
- return *errRes
- }
-
- pl, errRes := getPowerlevels(req, rsAPI, roomID)
- if errRes != nil {
- return *errRes
- }
- fullUserID, err := spec.NewUserID(device.UserID, true)
+ deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
}
}
- senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
+ senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to ban this user, unknown senderID"),
}
}
+
+ errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
+ if errRes != nil {
+ return *errRes
+ }
+
+ pl, errRes := getPowerlevels(req, rsAPI, roomID)
+ if errRes != nil {
+ return *errRes
+ }
allowedToBan := pl.UserLevel(senderID) >= pl.Ban
if !allowedToBan {
return util.JSONResponse{
@@ -147,29 +148,30 @@ func SendKick(
}
}
- errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
- if errRes != nil {
- return *errRes
- }
-
- pl, errRes := getPowerlevels(req, rsAPI, roomID)
- if errRes != nil {
- return *errRes
- }
- fullUserID, err := spec.NewUserID(device.UserID, true)
+ deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
}
}
- senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
+ senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
if err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
}
}
+
+ errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
+ if errRes != nil {
+ return *errRes
+ }
+
+ pl, errRes := getPowerlevels(req, rsAPI, roomID)
+ if errRes != nil {
+ return *errRes
+ }
allowedToKick := pl.UserLevel(senderID) >= pl.Kick
if !allowedToKick {
return util.JSONResponse{
@@ -178,10 +180,17 @@ func SendKick(
}
}
+ bodyUserID, err := spec.NewUserID(body.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("body userID is invalid"),
+ }
+ }
var queryRes roomserverAPI.QueryMembershipForUserResponse
err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: body.UserID,
+ UserID: *bodyUserID,
}, &queryRes)
if err != nil {
return util.ErrorResponse(err)
@@ -213,15 +222,30 @@ func SendUnban(
}
}
- errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
+ deviceUserID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
+ }
+ }
+
+ errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil {
return *errRes
}
+ bodyUserID, err := spec.NewUserID(body.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("body userID is invalid"),
+ }
+ }
var queryRes roomserverAPI.QueryMembershipForUserResponse
- err := rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
+ err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: body.UserID,
+ UserID: *bodyUserID,
}, &queryRes)
if err != nil {
return util.ErrorResponse(err)
@@ -272,7 +296,15 @@ func SendInvite(
}
}
- errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
+ deviceUserID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
+ }
+ }
+
+ errRes := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if errRes != nil {
return *errRes
}
@@ -340,17 +372,18 @@ func sendInvite(
func buildMembershipEventDirect(
ctx context.Context,
- targetUserID, reason string, userDisplayName, userAvatarURL string,
- sender string, senderDomain spec.ServerName,
+ targetSenderID spec.SenderID, reason string, userDisplayName, userAvatarURL string,
+ sender spec.SenderID, senderDomain spec.ServerName,
membership, roomID string, isDirect bool,
keyID gomatrixserverlib.KeyID, privateKey ed25519.PrivateKey, evTime time.Time,
rsAPI roomserverAPI.ClientRoomserverAPI,
) (*types.HeaderedEvent, error) {
+ targetSenderString := string(targetSenderID)
proto := gomatrixserverlib.ProtoEvent{
- SenderID: sender,
+ SenderID: string(sender),
RoomID: roomID,
Type: "m.room.member",
- StateKey: &targetUserID,
+ StateKey: &targetSenderString,
}
content := gomatrixserverlib.MemberContent{
@@ -391,8 +424,25 @@ func buildMembershipEvent(
return nil, err
}
- return buildMembershipEventDirect(ctx, targetUserID, reason, profile.DisplayName, profile.AvatarURL,
- device.UserID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI)
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return nil, err
+ }
+ senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
+ if err != nil {
+ return nil, err
+ }
+
+ targetID, err := spec.NewUserID(targetUserID, true)
+ if err != nil {
+ return nil, err
+ }
+ targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID)
+ if err != nil {
+ return nil, err
+ }
+ return buildMembershipEventDirect(ctx, targetSenderID, reason, profile.DisplayName, profile.AvatarURL,
+ senderID, device.UserDomain(), membership, roomID, isDirect, identity.KeyID, identity.PrivateKey, evTime, rsAPI)
}
// loadProfile lookups the profile of a given user from the database and returns
@@ -490,7 +540,7 @@ func checkAndProcessThreepid(
return
}
-func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID, roomID string) *util.JSONResponse {
+func checkMemberInRoom(ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, userID spec.UserID, roomID string) *util.JSONResponse {
var membershipRes roomserverAPI.QueryMembershipForUserResponse
err := rsAPI.QueryMembershipForUser(ctx, &roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID,
@@ -518,12 +568,21 @@ func SendForget(
) util.JSONResponse {
ctx := req.Context()
logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID)
+
+ deviceUserID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
+ }
+ }
+
var membershipRes roomserverAPI.QueryMembershipForUserResponse
membershipReq := roomserverAPI.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: device.UserID,
+ UserID: *deviceUserID,
}
- err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes)
+ err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes)
if err != nil {
logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user")
return util.JSONResponse{
diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go
index 22474fc0..da48e84d 100644
--- a/clientapi/routing/redaction.go
+++ b/clientapi/routing/redaction.go
@@ -47,7 +47,22 @@ func SendRedaction(
txnID *string,
txnCache *transactions.Cache,
) util.JSONResponse {
- resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID)
+ deviceUserID, userIDErr := spec.NewUserID(device.UserID, true)
+ if userIDErr != nil {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("userID doesn't have power level to redact"),
+ }
+ }
+ senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
+ if queryErr != nil {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("userID doesn't have power level to redact"),
+ }
+ }
+
+ resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil {
return *resErr
}
@@ -73,25 +88,10 @@ func SendRedaction(
}
}
- fullUserID, userIDErr := spec.NewUserID(device.UserID, true)
- if userIDErr != nil {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("userID doesn't have power level to redact"),
- }
- }
- senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *fullUserID)
- if queryErr != nil {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("userID doesn't have power level to redact"),
- }
- }
-
// "Users may redact their own events, and any user with a power level greater than or equal
// to the redact power level of the room may redact events there"
// https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid
- allowedToRedact := ev.SenderID() == senderID // TODO: Should replace device.UserID with device...PerRoomKey
+ allowedToRedact := ev.SenderID() == senderID
if !allowedToRedact {
plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels,
diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go
index c5b29297..979bced3 100644
--- a/clientapi/routing/sendtyping.go
+++ b/clientapi/routing/sendtyping.go
@@ -43,8 +43,16 @@ func SendTyping(
}
}
+ deviceUserID, err := spec.NewUserID(userID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
+ }
+ }
+
// Verify that the user is a member of this room
- resErr := checkMemberInRoom(req.Context(), rsAPI, userID, roomID)
+ resErr := checkMemberInRoom(req.Context(), rsAPI, *deviceUserID, roomID)
if resErr != nil {
return *resErr
}
diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go
index 06714ed1..7006ced4 100644
--- a/clientapi/routing/server_notices.go
+++ b/clientapi/routing/server_notices.go
@@ -52,6 +52,7 @@ type sendServerNoticeRequest struct {
StateKey string `json:"state_key,omitempty"`
}
+// nolint:gocyclo
// SendServerNotice sends a message to a specific user. It can only be invoked by an admin.
func SendServerNotice(
req *http.Request,
@@ -187,9 +188,17 @@ func SendServerNotice(
}
} else {
// we've found a room in common, check the membership
+ deviceUserID, err := spec.NewUserID(r.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("userID doesn't have power level to change visibility"),
+ }
+ }
+
roomID = commonRooms[0]
membershipRes := api.QueryMembershipForUserResponse{}
- err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: r.UserID, RoomID: roomID}, &membershipRes)
+ err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("unable to query membership for user")
return util.JSONResponse{
@@ -234,7 +243,7 @@ func SendServerNotice(
ctx, rsAPI,
api.KindNew,
[]*types.HeaderedEvent{
- &types.HeaderedEvent{PDU: e},
+ {PDU: e},
},
device.UserDomain(),
cfgClient.Matrix.ServerName,
diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go
index 13f30899..e3a209b6 100644
--- a/clientapi/routing/state.go
+++ b/clientapi/routing/state.go
@@ -99,9 +99,17 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
if !worldReadable {
// The room isn't world-readable so try to work out based on the
// user's membership if we want the latest state or not.
- err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("UserID is invalid")
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.Unknown("Device UserID is invalid"),
+ }
+ }
+ err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: device.UserID,
+ UserID: *userID,
}, &membershipRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")
@@ -140,14 +148,11 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
// use the result of the previous QueryLatestEventsAndState response
// to find the state event, if provided.
for _, ev := range stateRes.StateEvents {
- sender := spec.UserID{}
- userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
- if err == nil && userID != nil {
- sender = *userID
- }
stateEvents = append(
stateEvents,
- synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
+ synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }, ev),
)
}
} else {
@@ -172,9 +177,18 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
if err == nil && userID != nil {
sender = *userID
}
+
+ sk := ev.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
stateEvents = append(
stateEvents,
- synctypes.ToClientEvent(ev, synctypes.FormatAll, sender),
+ synctypes.ToClientEvent(ev, synctypes.FormatAll, sender, sk),
)
}
}
@@ -259,11 +273,19 @@ func OnIncomingStateTypeRequest(
// membershipRes will only be populated if the room is not world-readable.
var membershipRes api.QueryMembershipForUserResponse
if !worldReadable {
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("UserID is invalid")
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.Unknown("Device UserID is invalid"),
+ }
+ }
// The room isn't world-readable so try to work out based on the
// user's membership if we want the latest state or not.
- err := rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
+ err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: device.UserID,
+ UserID: *userID,
}, &membershipRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryMembershipForUser")
@@ -344,13 +366,10 @@ func OnIncomingStateTypeRequest(
}
}
- sender := spec.UserID{}
- userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
- if err == nil && userID != nil {
- sender = *userID
- }
stateEvent := stateEventInStateResp{
- ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
+ ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }, event),
}
var res interface{}
diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go
index a0b28078..03c0230e 100644
--- a/clientapi/routing/upgrade_room.go
+++ b/clientapi/routing/upgrade_room.go
@@ -59,7 +59,15 @@ func UpgradeRoom(
}
}
- newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, device.UserID, gomatrixserverlib.RoomVersion(r.NewVersion))
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("device UserID is invalid")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+ newRoomID, err := rsAPI.PerformRoomUpgrade(req.Context(), roomID, *userID, gomatrixserverlib.RoomVersion(r.NewVersion))
switch e := err.(type) {
case nil:
case roomserverAPI.ErrNotAllowed:
diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go
index ca279ac2..c26aa2f1 100644
--- a/federationapi/routing/eventauth.go
+++ b/federationapi/routing/eventauth.go
@@ -45,7 +45,7 @@ func GetEventAuth(
if event.RoomID() != roomID {
return util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
}
- resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID)
+ resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
if resErr != nil {
return *resErr
}
diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go
index 196a54db..d3f0e81c 100644
--- a/federationapi/routing/events.go
+++ b/federationapi/routing/events.go
@@ -35,10 +35,6 @@ func GetEvent(
eventID string,
origin spec.ServerName,
) util.JSONResponse {
- err := allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID)
- if err != nil {
- return *err
- }
// /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string,
// which results in `QueryEventsByID` to first get the event and use that to determine the roomID.
event, err := fetchEvent(ctx, rsAPI, "", eventID)
@@ -46,6 +42,11 @@ func GetEvent(
return *err
}
+ err = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
+ if err != nil {
+ return *err
+ }
+
return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.Transaction{
Origin: origin,
OriginServerTS: spec.AsTimestamp(time.Now()),
@@ -62,8 +63,9 @@ func allowedToSeeEvent(
origin spec.ServerName,
rsAPI api.FederationRoomserverAPI,
eventID string,
+ roomID string,
) *util.JSONResponse {
- allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID)
+ allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID, roomID)
if err != nil {
resErr := util.ErrorResponse(err)
return &resErr
diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go
index fa0e9351..11ad1ebf 100644
--- a/federationapi/routing/state.go
+++ b/federationapi/routing/state.go
@@ -116,7 +116,7 @@ func getState(
if event.RoomID() != roomID {
return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: spec.NotFound("event does not belong to this room")}
}
- resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID)
+ resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID, event.RoomID())
if resErr != nil {
return nil, nil, resErr
}
diff --git a/go.mod b/go.mod
index 3621428c..2fbae314 100644
--- a/go.mod
+++ b/go.mod
@@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
- github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d
+ github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16
diff --git a/go.sum b/go.sum
index 1ee0261f..ef8c298a 100644
--- a/go.sum
+++ b/go.sum
@@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d h1:MjL8SXRzhO61aXDFL+gA3Bx1SicqLGL9gCWXDv8jkD8=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20230607161930-ea5ef168992d/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
diff --git a/roomserver/api/api.go b/roomserver/api/api.go
index 8c2cbd6b..bafde91c 100644
--- a/roomserver/api/api.go
+++ b/roomserver/api/api.go
@@ -34,11 +34,11 @@ func (e ErrNotAllowed) Error() string {
type RestrictedJoinAPI interface {
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
- InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error)
- RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error)
+ InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error)
+ RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error)
QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error)
QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error
- UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error)
+ UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, senderID spec.SenderID) (bool, error)
LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
}
@@ -191,7 +191,7 @@ type ClientRoomserverAPI interface {
PerformCreateRoom(ctx context.Context, userID spec.UserID, roomID spec.RoomID, createRequest *PerformCreateRoomRequest) (string, *util.JSONResponse)
// PerformRoomUpgrade upgrades a room to a newer version
- PerformRoomUpgrade(ctx context.Context, roomID, userID string, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
+ PerformRoomUpgrade(ctx context.Context, roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion) (newRoomID string, err error)
PerformAdminEvacuateRoom(ctx context.Context, roomID string) (affected []string, err error)
PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error)
PerformAdminPurgeRoom(ctx context.Context, roomID string) error
@@ -228,6 +228,7 @@ type FederationRoomserverAPI interface {
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
+ QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
@@ -238,15 +239,13 @@ type FederationRoomserverAPI interface {
// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
// the state and auth chain to return.
QueryStateAndAuthChain(ctx context.Context, req *QueryStateAndAuthChainRequest, res *QueryStateAndAuthChainResponse) error
- // Query if we think we're still in a room.
- QueryServerJoinedToRoom(ctx context.Context, req *QueryServerJoinedToRoomRequest, res *QueryServerJoinedToRoomResponse) error
QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error
// Query missing events for a room from roomserver
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event
- QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string) (allowed bool, err error)
+ QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
- QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error)
+ QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
@@ -254,12 +253,6 @@ type FederationRoomserverAPI interface {
// Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
- CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error)
- InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error)
- QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types.RoomInfo, error)
- UserJoinedToRoom(ctx context.Context, roomID types.RoomNID, userID spec.UserID) (bool, error)
- LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomNID types.RoomNID) ([]gomatrixserverlib.PDU, error)
-
IsKnownRoom(ctx context.Context, roomID spec.RoomID) (bool, error)
StateQuerier() gomatrixserverlib.StateQuerier
}
diff --git a/roomserver/api/output.go b/roomserver/api/output.go
index 16b50495..852b6420 100644
--- a/roomserver/api/output.go
+++ b/roomserver/api/output.go
@@ -215,8 +215,10 @@ type OutputNewInviteEvent struct {
type OutputRetireInviteEvent struct {
// The ID of the "m.room.member" invite event.
EventID string
- // The target user ID of the "m.room.member" invite event that was retired.
- TargetUserID string
+ // The room ID of the "m.room.member" invite event.
+ RoomID string
+ // The target sender ID of the "m.room.member" invite event that was retired.
+ TargetSenderID spec.SenderID
// Optional event ID of the event that replaced the invite.
// This can be empty if the invite was rejected locally and we were unable
// to reach the server that originally sent the invite.
diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go
index 6cbaf5b1..b466b7ba 100644
--- a/roomserver/api/perform.go
+++ b/roomserver/api/perform.go
@@ -41,8 +41,8 @@ type PerformJoinRequest struct {
}
type PerformLeaveRequest struct {
- RoomID string `json:"room_id"`
- UserID string `json:"user_id"`
+ RoomID string
+ Leaver spec.UserID
}
type PerformLeaveResponse struct {
diff --git a/roomserver/api/query.go b/roomserver/api/query.go
index d79dcebb..684a5b0e 100644
--- a/roomserver/api/query.go
+++ b/roomserver/api/query.go
@@ -113,9 +113,9 @@ type QueryEventsByIDResponse struct {
// QueryMembershipForUserRequest is a request to QueryMembership
type QueryMembershipForUserRequest struct {
// ID of the room to fetch membership from
- RoomID string `json:"room_id"`
+ RoomID string
// ID of the user for whom membership is requested
- UserID string `json:"user_id"`
+ UserID spec.UserID
}
// QueryMembershipForUserResponse is a response to QueryMembership
@@ -145,7 +145,7 @@ type QueryMembershipsForRoomRequest struct {
// Optional - ID of the user sending the request, for checking if the
// user is allowed to see the memberships. If not specified then all
// room memberships will be returned.
- Sender string `json:"sender"`
+ SenderID spec.SenderID `json:"sender"`
}
// QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom
@@ -448,11 +448,11 @@ func (rq *JoinRoomQuerier) CurrentStateEvent(ctx context.Context, roomID spec.Ro
return rq.Roomserver.CurrentStateEvent(ctx, roomID, eventType, stateKey)
}
-func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) {
- return rq.Roomserver.InvitePending(ctx, roomID, userID)
+func (rq *JoinRoomQuerier) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) {
+ return rq.Roomserver.InvitePending(ctx, roomID, senderID)
}
-func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
+func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := rq.Roomserver.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err
@@ -468,7 +468,7 @@ func (rq *JoinRoomQuerier) RestrictedRoomJoinInfo(ctx context.Context, roomID sp
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
}
- userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID)
+ userJoinedToRoom, err := rq.Roomserver.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err)
@@ -492,12 +492,8 @@ type MembershipQuerier struct {
}
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
- req := QueryMembershipForUserRequest{
- RoomID: roomID.String(),
- UserID: string(senderID),
- }
res := QueryMembershipForUserResponse{}
- err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res)
+ err := mq.Roomserver.QueryMembershipForSenderID(ctx, roomID, senderID, &res)
membership := ""
if err == nil {
diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go
index b6168d38..ba10a433 100644
--- a/roomserver/auth/auth.go
+++ b/roomserver/auth/auth.go
@@ -13,6 +13,9 @@
package auth
import (
+ "context"
+
+ "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
@@ -22,6 +25,7 @@ import (
// IsServerAllowed returns true if the server is allowed to see events in the room
// at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87
func IsServerAllowed(
+ ctx context.Context, db storage.RoomDatabase,
serverName spec.ServerName,
serverCurrentlyInRoom bool,
authEvents []gomatrixserverlib.PDU,
@@ -37,7 +41,7 @@ func IsServerAllowed(
return true
}
// 2. If the user's membership was join, allow.
- joinedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Join)
+ joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join)
if joinedUserExists {
return true
}
@@ -46,7 +50,7 @@ func IsServerAllowed(
return true
}
// 4. If the user's membership was invite, and the history_visibility was set to invited, allow.
- invitedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Invite)
+ invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite)
if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited {
return true
}
@@ -70,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver
return visibility
}
-func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
+func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
for _, ev := range authEvents {
if ev.Type() != spec.MRoomMember {
continue
@@ -85,12 +89,12 @@ func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []go
continue
}
- _, domain, err := gomatrixserverlib.SplitID('@', *stateKey)
+ userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey))
if err != nil {
continue
}
- if domain == serverName {
+ if userID.Domain() == serverName {
return true
}
}
diff --git a/roomserver/auth/auth_test.go b/roomserver/auth/auth_test.go
index e3eea5d8..192d9e5d 100644
--- a/roomserver/auth/auth_test.go
+++ b/roomserver/auth/auth_test.go
@@ -1,13 +1,23 @@
package auth
import (
+ "context"
"testing"
+ "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
+type FakeStorageDB struct {
+ storage.RoomDatabase
+}
+
+func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return spec.NewUserID(string(senderID), true)
+}
+
func TestIsServerAllowed(t *testing.T) {
alice := test.NewUser(t)
@@ -77,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) {
authEvents = append(authEvents, ev.PDU)
}
- if got := IsServerAllowed(tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
+ if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
}
})
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index 95397cd5..263cb9f8 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"sort"
- "strings"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
@@ -55,9 +54,10 @@ func UpdateToInviteMembership(
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID,
+ RoomID: add.RoomID(),
Membership: spec.Join,
RetiredByEventID: add.EventID(),
- TargetUserID: *add.StateKey(),
+ TargetSenderID: spec.SenderID(*add.StateKey()),
},
})
}
@@ -94,13 +94,13 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
for i := range events {
gmslEvents[i] = events[i].PDU
}
- return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, spec.Join), nil
+ return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil
}
func IsInvitePending(
ctx context.Context, db storage.Database,
- roomID, userID string,
-) (bool, string, string, gomatrixserverlib.PDU, error) {
+ roomID string, senderID spec.SenderID,
+) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) {
// Look up the room NID for the supplied room ID.
info, err := db.RoomInfo(ctx, roomID)
if err != nil {
@@ -111,13 +111,13 @@ func IsInvitePending(
}
// Look up the state key NID for the supplied user ID.
- targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
+ targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)})
if err != nil {
return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
}
- targetUserNID, targetUserFound := targetUserNIDs[userID]
+ targetUserNID, targetUserFound := targetUserNIDs[string(senderID)]
if !targetUserFound {
- return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
+ return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs)
}
// Let's see if we have an event active for the user in the room. If
@@ -156,7 +156,7 @@ func IsInvitePending(
event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false)
- return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err
+ return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, err
}
// GetMembershipsAtState filters the state events to
@@ -264,7 +264,7 @@ func LoadStateEvents(
}
func CheckServerAllowedToSeeEvent(
- ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, isServerInRoom bool,
+ ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool,
) (bool, error) {
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
switch err {
@@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
case tables.OptimisationNotSupportedError:
// The database engine didn't support this optimisation, so fall back to using
// the old and slow method
- stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName)
+ stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName)
if err != nil {
return false, err
}
@@ -288,11 +288,11 @@ func CheckServerAllowedToSeeEvent(
return false, err
}
}
- return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
+ return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil
}
func slowGetHistoryVisibilityState(
- ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName,
+ ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName,
) ([]gomatrixserverlib.PDU, error) {
roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
@@ -319,8 +319,13 @@ func slowGetHistoryVisibilityState(
// then we'll filter it out. This does preserve state keys that
// are "" since these will contain history visibility etc.
for nid, key := range stateKeys {
- if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) {
- delete(stateKeys, nid)
+ if key != "" {
+ userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key))
+ if err == nil && userID != nil {
+ if userID.Domain() != serverName {
+ delete(stateKeys, nid)
+ }
+ }
}
}
@@ -410,7 +415,7 @@ BFSLoop:
// hasn't been seen before.
if !visited[pre] {
visited[pre] = true
- allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
+ allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom)
if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event",
diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go
index f1896277..1cef83df 100644
--- a/roomserver/internal/helpers/helpers_test.go
+++ b/roomserver/internal/helpers/helpers_test.go
@@ -8,6 +8,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/roomserver/types"
@@ -58,12 +59,12 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
}
// Alice should have no pending invites and should have a NID
- pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID)
+ pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, spec.SenderID(alice.ID))
assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite")
// Bob should have no pending invites and receive a new NID
- pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID)
+ pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, spec.SenderID(bob.ID))
assert.NoError(t, err, "failed to get pending invites")
assert.False(t, pendingInvite, "unexpected pending invite")
})
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 1f273da0..7bb40163 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -842,17 +842,15 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
continue
}
- // TODO: pseudoIDs: get userID for room using state key (which is now senderID)
- localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
+ memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey()))
if err != nil {
continue
}
- // TODO: pseudoIDs: query account by state key (which is now senderID)
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
- Localpart: localpart,
- ServerName: senderDomain,
+ Localpart: memberUserID.Local(),
+ ServerName: memberUserID.Domain(),
}, accountRes); err != nil {
return err
}
@@ -896,8 +894,8 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindNew,
Event: event,
- Origin: senderDomain,
- SendAsServer: string(senderDomain),
+ Origin: memberUserID.Domain(),
+ SendAsServer: string(memberUserID.Domain()),
})
prevEvents = []string{event.EventID()}
}
diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go
index 98d7d13b..09c65dfe 100644
--- a/roomserver/internal/input/input_membership.go
+++ b/roomserver/internal/input/input_membership.go
@@ -18,7 +18,6 @@ import (
"context"
"fmt"
- "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/internal"
@@ -72,7 +71,7 @@ func (r *Inputer) updateMemberships(
if change.addedEventNID != 0 {
ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID)
}
- if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
+ if updates, err = r.updateMembership(ctx, updater, targetUserNID, re, ae, updates); err != nil {
return nil, err
}
}
@@ -80,6 +79,7 @@ func (r *Inputer) updateMemberships(
}
func (r *Inputer) updateMembership(
+ ctx context.Context,
updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID,
remove, add *types.Event,
@@ -97,7 +97,7 @@ func (r *Inputer) updateMembership(
var targetLocal bool
if add != nil {
- targetLocal = r.isLocalTarget(add)
+ targetLocal = r.isLocalTarget(ctx, add)
}
mu, err := updater.MembershipUpdater(targetUserNID, targetLocal)
@@ -136,11 +136,14 @@ func (r *Inputer) updateMembership(
}
}
-func (r *Inputer) isLocalTarget(event *types.Event) bool {
+func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil {
- _, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
- isTargetLocalUser = domain == r.ServerName
+ userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey))
+ if err != nil || userID == nil {
+ return isTargetLocalUser
+ }
+ isTargetLocalUser = userID.Domain() == r.ServerName
}
return isTargetLocalUser
}
@@ -161,9 +164,10 @@ func updateToJoinMembership(
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID,
+ RoomID: add.RoomID(),
Membership: spec.Join,
RetiredByEventID: add.EventID(),
- TargetUserID: *add.StateKey(),
+ TargetSenderID: spec.SenderID(*add.StateKey()),
},
})
}
@@ -187,9 +191,10 @@ func updateToLeaveMembership(
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID,
+ RoomID: add.RoomID(),
Membership: newMembership,
RetiredByEventID: add.EventID(),
- TargetUserID: *add.StateKey(),
+ TargetSenderID: spec.SenderID(*add.StateKey()),
},
})
}
diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go
index eeb1ac40..ec13bff8 100644
--- a/roomserver/internal/perform/perform_admin.go
+++ b/roomserver/internal/perform/perform_admin.go
@@ -149,11 +149,11 @@ func (r *Admin) PerformAdminEvacuateUser(
ctx context.Context,
userID string,
) (affected []string, err error) {
- _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
- if !r.Cfg.Matrix.IsLocalServerName(domain) {
+ if !r.Cfg.Matrix.IsLocalServerName(fullUserID.Domain()) {
return nil, fmt.Errorf("can only evacuate local users using this endpoint")
}
@@ -172,7 +172,7 @@ func (r *Admin) PerformAdminEvacuateUser(
for _, roomID := range allRooms {
leaveReq := &api.PerformLeaveRequest{
RoomID: roomID,
- UserID: userID,
+ Leaver: *fullUserID,
}
leaveRes := &api.PerformLeaveResponse{}
outputEvents, err := r.Leaver.PerformLeave(ctx, leaveReq, leaveRes)
diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go
index 38815093..8e87359a 100644
--- a/roomserver/internal/perform/perform_backfill.go
+++ b/roomserver/internal/perform/perform_backfill.go
@@ -582,7 +582,7 @@ func joinEventsFromHistoryVisibility(
}
// Can we see events in the room?
- canSeeEvents := auth.IsServerAllowed(thisServer, true, events)
+ canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events)
visibility := auth.HistoryVisibilityForRoom(events)
if !canSeeEvents {
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go
index a3ba20f7..475418aa 100644
--- a/roomserver/internal/perform/perform_create_room.go
+++ b/roomserver/internal/perform/perform_create_room.go
@@ -63,9 +63,17 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
}
}
- createContent["creator"] = userID.String()
+ senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
+ return "", &util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+ createContent["creator"] = senderID
createContent["room_version"] = createRequest.RoomVersion
- powerLevelContent := eventutil.InitialPowerLevelsContent(userID.String())
+ powerLevelContent := eventutil.InitialPowerLevelsContent(string(senderID))
joinRuleContent := gomatrixserverlib.JoinRuleContent{
JoinRule: spec.Invite,
}
@@ -121,7 +129,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
membershipEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomMember,
- StateKey: userID.String(),
+ StateKey: string(senderID),
Content: gomatrixserverlib.MemberContent{
Membership: spec.Join,
DisplayName: createRequest.UserDisplayName,
@@ -270,7 +278,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil)
- senderID, err := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{
diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index 56ee1606..1440daad 100644
--- a/roomserver/internal/perform/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -134,12 +134,12 @@ func (r *Inviter) PerformInvite(
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
}
- if event.StateKey() == nil {
+ if event.StateKey() == nil || *event.StateKey() == "" {
return fmt.Errorf("invite must be a state event")
}
- invitedUser, err := spec.NewUserID(*event.StateKey(), true)
- if err != nil {
- return spec.InvalidParam("The user ID is invalid")
+ invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
+ if err != nil || invitedUser == nil {
+ return spec.InvalidParam("Could not find the matching senderID for this user")
}
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go
index d41cc214..83c3b7c3 100644
--- a/roomserver/internal/perform/perform_join.go
+++ b/roomserver/internal/perform/perform_join.go
@@ -162,7 +162,7 @@ func (r *Joiner) performJoinRoomByID(
}
// Get the domain part of the room ID.
- _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias)
+ roomID, err := spec.NewRoomID(req.RoomIDOrAlias)
if err != nil {
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("room ID %q is invalid: %w", req.RoomIDOrAlias, err)}
}
@@ -170,8 +170,8 @@ func (r *Joiner) performJoinRoomByID(
// If the server name in the room ID isn't ours then it's a
// possible candidate for finding the room via federation. Add
// it to the list of servers to try.
- if !r.Cfg.Matrix.IsLocalServerName(domain) {
- req.ServerNames = append(req.ServerNames, domain)
+ if !r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
+ req.ServerNames = append(req.ServerNames, roomID.Domain())
}
// Prepare the template for the join event.
@@ -203,7 +203,7 @@ func (r *Joiner) performJoinRoomByID(
req.Content = map[string]interface{}{}
}
req.Content["membership"] = spec.Join
- if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req); aerr != nil {
+ if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
return "", "", aerr
} else if authorisedVia != "" {
req.Content["join_authorised_via_users_server"] = authorisedVia
@@ -226,17 +226,17 @@ func (r *Joiner) performJoinRoomByID(
// Force a federated join if we're dealing with a pending invite
// and we aren't in the room.
- isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID)
+ isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
if err == nil && !serverInRoom && isInvitePending {
- _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
- if ierr != nil {
- return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
+ inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender)
+ if queryErr != nil {
+ return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
}
// If we were invited by someone from another server then we can
// assume they are in the room so we can join via them.
- if !r.Cfg.Matrix.IsLocalServerName(inviterDomain) {
- req.ServerNames = append(req.ServerNames, inviterDomain)
+ if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
+ req.ServerNames = append(req.ServerNames, inviter.Domain())
forceFederatedJoin = true
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
// only set unsigned if we've got a content.membership, which we _should_
@@ -298,12 +298,8 @@ func (r *Joiner) performJoinRoomByID(
// a member of the room. This is best-effort (as in we won't
// fail if we can't find the existing membership) because there
// is really no harm in just sending another membership event.
- membershipReq := &api.QueryMembershipForUserRequest{
- RoomID: req.RoomIDOrAlias,
- UserID: userID.String(),
- }
membershipRes := &api.QueryMembershipForUserResponse{}
- _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes)
+ _ = r.Queryer.QueryMembershipForSenderID(ctx, *roomID, senderID, membershipRes)
// If we haven't already joined the room then send an event
// into the room changing our membership status.
@@ -328,7 +324,7 @@ func (r *Joiner) performJoinRoomByID(
// The room doesn't exist locally. If the room ID looks like it should
// be ours then this probably means that we've nuked our database at
// some point.
- if r.Cfg.Matrix.IsLocalServerName(domain) {
+ if r.Cfg.Matrix.IsLocalServerName(roomID.Domain()) {
// If there are no more server names to try then give up here.
// Otherwise we'll try a federated join as normal, since it's quite
// possible that the room still exists on other servers.
@@ -376,15 +372,12 @@ func (r *Joiner) performFederatedJoinRoomByID(
func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin(
ctx context.Context,
joinReq *rsAPI.PerformJoinRequest,
+ senderID spec.SenderID,
) (string, error) {
roomID, err := spec.NewRoomID(joinReq.RoomIDOrAlias)
if err != nil {
return "", err
}
- userID, err := spec.NewUserID(joinReq.UserID, true)
- if err != nil {
- return "", err
- }
- return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, *userID)
+ return r.Queryer.QueryRestrictedJoinAllowed(ctx, *roomID, senderID)
}
diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go
index 094537f8..1b23cc1f 100644
--- a/roomserver/internal/perform/perform_leave.go
+++ b/roomserver/internal/perform/perform_leave.go
@@ -53,16 +53,12 @@ func (r *Leaver) PerformLeave(
req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse,
) ([]api.OutputEvent, error) {
- _, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
- if err != nil {
- return nil, fmt.Errorf("supplied user ID %q in incorrect format", req.UserID)
- }
- if !r.Cfg.Matrix.IsLocalServerName(domain) {
- return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID)
+ if !r.Cfg.Matrix.IsLocalServerName(req.Leaver.Domain()) {
+ return nil, fmt.Errorf("user %q does not belong to this homeserver", req.Leaver.String())
}
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomID,
- "user_id": req.UserID,
+ "user_id": req.Leaver.String(),
})
logger.Info("User requested to leave join")
if strings.HasPrefix(req.RoomID, "!") {
@@ -82,21 +78,26 @@ func (r *Leaver) performLeaveRoomByID(
req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam
) ([]api.OutputEvent, error) {
+ leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver)
+ if err != nil {
+ return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
+ }
+
// If there's an invite outstanding for the room then respond to
// that.
- isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID)
+ isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
if err == nil && isInvitePending {
- _, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser)
- if serr != nil {
- return nil, fmt.Errorf("sender %q is invalid", senderUser)
+ sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser)
+ if serr != nil || sender == nil {
+ return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
}
- if !r.Cfg.Matrix.IsLocalServerName(senderDomain) {
- return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID)
+ if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
+ return r.performFederatedRejectInvite(ctx, req, res, *sender, eventID, leaver)
}
// check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{}
if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{
- UserID: req.UserID,
+ UserID: req.Leaver.String(),
RoomID: req.RoomID,
DataType: "m.tag",
}, accData); err != nil {
@@ -127,7 +128,7 @@ func (r *Leaver) performLeaveRoomByID(
StateToFetch: []gomatrixserverlib.StateKeyTuple{
{
EventType: spec.MRoomMember,
- StateKey: req.UserID,
+ StateKey: string(leaver),
},
},
}
@@ -141,26 +142,18 @@ func (r *Leaver) performLeaveRoomByID(
// Now let's see if the user is in the room.
if len(latestRes.StateEvents) == 0 {
- return nil, fmt.Errorf("user %q is not a member of room %q", req.UserID, req.RoomID)
+ return nil, fmt.Errorf("user %q is not a member of room %q", req.Leaver.String(), req.RoomID)
}
membership, err := latestRes.StateEvents[0].Membership()
if err != nil {
return nil, fmt.Errorf("error getting membership: %w", err)
}
if membership != spec.Join && membership != spec.Invite {
- return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.UserID, membership)
+ return nil, fmt.Errorf("user %q is not joined to the room (membership is %q)", req.Leaver.String(), membership)
}
// Prepare the template for the leave event.
- fullUserID, err := spec.NewUserID(req.UserID, true)
- if err != nil {
- return nil, err
- }
- senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, *fullUserID)
- if err != nil {
- return nil, err
- }
- senderIDString := string(senderID)
+ senderIDString := string(leaver)
proto := gomatrixserverlib.ProtoEvent{
Type: spec.MRoomMember,
SenderID: senderIDString,
@@ -175,16 +168,13 @@ func (r *Leaver) performLeaveRoomByID(
return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
}
- // Get the sender domain.
- senderDomain := fullUserID.Domain()
-
// We know that the user is in the room at this point so let's build
// a leave event.
// TODO: Check what happens if the room exists on the server
// but everyone has since left. I suspect it does the wrong thing.
var buildRes rsAPI.QueryLatestEventsAndStateResponse
- identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
+ identity, err := r.Cfg.Matrix.SigningIdentityFor(req.Leaver.Domain())
if err != nil {
return nil, fmt.Errorf("SigningIdentityFor: %w", err)
}
@@ -201,8 +191,8 @@ func (r *Leaver) performLeaveRoomByID(
{
Kind: api.KindNew,
Event: event,
- Origin: senderDomain,
- SendAsServer: string(senderDomain),
+ Origin: req.Leaver.Domain(),
+ SendAsServer: string(req.Leaver.Domain()),
},
},
}
@@ -219,21 +209,17 @@ func (r *Leaver) performFederatedRejectInvite(
ctx context.Context,
req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam
- senderUser, eventID string,
+ inviteSender spec.UserID, eventID string,
+ leaver spec.SenderID,
) ([]api.OutputEvent, error) {
- _, domain, err := gomatrixserverlib.SplitID('@', senderUser)
- if err != nil {
- return nil, fmt.Errorf("user ID %q invalid: %w", senderUser, err)
- }
-
// Ask the federation sender to perform a federated leave for us.
leaveReq := fsAPI.PerformLeaveRequest{
RoomID: req.RoomID,
- UserID: req.UserID,
- ServerNames: []spec.ServerName{domain},
+ UserID: req.Leaver.String(),
+ ServerNames: []spec.ServerName{inviteSender.Domain()},
}
leaveRes := fsAPI.PerformLeaveResponse{}
- if err = r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
+ if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
// failures in PerformLeave should NEVER stop us from telling other components like the
// sync API that the invite was withdrawn. Otherwise we can end up with stuck invites.
util.GetLogger(ctx).WithError(err).Errorf("failed to PerformLeave, still retiring invite event")
@@ -244,7 +230,7 @@ func (r *Leaver) performFederatedRejectInvite(
util.GetLogger(ctx).WithError(err).Errorf("failed to get RoomInfo, still retiring invite event")
}
- updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, req.UserID, true, info.RoomVersion)
+ updater, err := r.DB.MembershipUpdater(ctx, req.RoomID, string(leaver), true, info.RoomVersion)
if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event")
}
@@ -267,9 +253,10 @@ func (r *Leaver) performFederatedRejectInvite(
{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
- EventID: eventID,
- Membership: "leave",
- TargetUserID: req.UserID,
+ EventID: eventID,
+ RoomID: req.RoomID,
+ Membership: "leave",
+ TargetSenderID: leaver,
},
},
}, nil
diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go
index 5710352b..1aaa42c9 100644
--- a/roomserver/internal/perform/perform_upgrade.go
+++ b/roomserver/internal/perform/perform_upgrade.go
@@ -38,19 +38,15 @@ type Upgrader struct {
// PerformRoomUpgrade upgrades a room from one version to another
func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context,
- roomID, userID string, roomVersion gomatrixserverlib.RoomVersion,
+ roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion,
) (newRoomID string, err error) {
return r.performRoomUpgrade(ctx, roomID, userID, roomVersion)
}
func (r *Upgrader) performRoomUpgrade(
ctx context.Context,
- roomID, userID string, roomVersion gomatrixserverlib.RoomVersion,
+ roomID string, userID spec.UserID, roomVersion gomatrixserverlib.RoomVersion,
) (string, error) {
- _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID)
- if err != nil {
- return "", api.ErrNotAllowed{Err: fmt.Errorf("error validating the user ID")}
- }
evTime := time.Now()
// Return an immediate error if the room does not exist
@@ -58,14 +54,20 @@ func (r *Upgrader) performRoomUpgrade(
return "", err
}
+ senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
+ return "", err
+ }
+
// 1. Check if the user is authorized to actually perform the upgrade (can send m.room.tombstone)
- if !r.userIsAuthorized(ctx, userID, roomID) {
+ if !r.userIsAuthorized(ctx, senderID, roomID) {
return "", api.ErrNotAllowed{Err: fmt.Errorf("You don't have permission to upgrade the room, power level too low.")}
}
// TODO (#267): Check room ID doesn't clash with an existing one, and we
// probably shouldn't be using pseudo-random strings, maybe GUIDs?
- newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userDomain)
+ newRoomID := fmt.Sprintf("!%s:%s", util.RandomString(16), userID.Domain())
// Get the existing room state for the old room.
oldRoomReq := &api.QueryLatestEventsAndStateRequest{
@@ -77,25 +79,25 @@ func (r *Upgrader) performRoomUpgrade(
}
// Make the tombstone event
- tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, userID, roomID, newRoomID)
+ tombstoneEvent, pErr := r.makeTombstoneEvent(ctx, evTime, senderID, userID.Domain(), roomID, newRoomID)
if pErr != nil {
return "", pErr
}
// Generate the initial events we need to send into the new room. This includes copied state events and bans
// as well as the power level events needed to set up the room
- eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, userID, roomID, roomVersion, tombstoneEvent)
+ eventsToMake, pErr := r.generateInitialEvents(ctx, oldRoomRes, senderID, roomID, roomVersion, tombstoneEvent)
if pErr != nil {
return "", pErr
}
// Send the setup events to the new room
- if pErr = r.sendInitialEvents(ctx, evTime, userID, userDomain, newRoomID, roomVersion, eventsToMake); pErr != nil {
+ if pErr = r.sendInitialEvents(ctx, evTime, senderID, userID.Domain(), newRoomID, roomVersion, eventsToMake); pErr != nil {
return "", pErr
}
// 5. Send the tombstone event to the old room
- if pErr = r.sendHeaderedEvent(ctx, userDomain, tombstoneEvent, string(userDomain)); pErr != nil {
+ if pErr = r.sendHeaderedEvent(ctx, userID.Domain(), tombstoneEvent, string(userID.Domain())); pErr != nil {
return "", pErr
}
@@ -105,17 +107,17 @@ func (r *Upgrader) performRoomUpgrade(
}
// If the old room had a canonical alias event, it should be deleted in the old room
- if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, userID, userDomain, roomID); pErr != nil {
+ if pErr = r.clearOldCanonicalAliasEvent(ctx, oldRoomRes, evTime, senderID, userID.Domain(), roomID); pErr != nil {
return "", pErr
}
// 4. Move local aliases to the new room
- if pErr = moveLocalAliases(ctx, roomID, newRoomID, userID, r.URSAPI); pErr != nil {
+ if pErr = moveLocalAliases(ctx, roomID, newRoomID, senderID, userID, r.URSAPI); pErr != nil {
return "", pErr
}
// 6. Restrict power levels in the old room
- if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, userID, userDomain, roomID); pErr != nil {
+ if pErr = r.restrictOldRoomPowerLevels(ctx, evTime, senderID, userID.Domain(), roomID); pErr != nil {
return "", pErr
}
@@ -130,7 +132,7 @@ func (r *Upgrader) getRoomPowerLevels(ctx context.Context, roomID string) (*goma
return oldPowerLevelsEvent.PowerLevels()
}
-func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error {
+func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error {
restrictedPowerLevelContent, pErr := r.getRoomPowerLevels(ctx, roomID)
if pErr != nil {
return pErr
@@ -147,7 +149,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
restrictedPowerLevelContent.EventsDefault = restrictedDefaultPowerLevel
restrictedPowerLevelContent.Invite = restrictedDefaultPowerLevel
- restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{
+ restrictedPowerLevelsHeadered, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{
Type: spec.MRoomPowerLevels,
StateKey: "",
Content: restrictedPowerLevelContent,
@@ -165,7 +167,7 @@ func (r *Upgrader) restrictOldRoomPowerLevels(ctx context.Context, evTime time.T
}
func moveLocalAliases(ctx context.Context,
- roomID, newRoomID, userID string,
+ roomID, newRoomID string, senderID spec.SenderID, userID spec.UserID,
URSAPI api.RoomserverInternalAPI,
) (err error) {
@@ -175,14 +177,6 @@ func moveLocalAliases(ctx context.Context,
return fmt.Errorf("Failed to get old room aliases: %w", err)
}
- fullUserID, err := spec.NewUserID(userID, true)
- if err != nil {
- return fmt.Errorf("Failed to get userID: %w", err)
- }
- senderID, err := URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
- if err != nil {
- return fmt.Errorf("Failed to get senderID: %w", err)
- }
for _, alias := range aliasRes.Aliases {
removeAliasReq := api.RemoveRoomAliasRequest{SenderID: senderID, Alias: alias}
removeAliasRes := api.RemoveRoomAliasResponse{}
@@ -190,7 +184,7 @@ func moveLocalAliases(ctx context.Context,
return fmt.Errorf("Failed to remove old room alias: %w", err)
}
- setAliasReq := api.SetRoomAliasRequest{UserID: userID, Alias: alias, RoomID: newRoomID}
+ setAliasReq := api.SetRoomAliasRequest{UserID: userID.String(), Alias: alias, RoomID: newRoomID}
setAliasRes := api.SetRoomAliasResponse{}
if err = URSAPI.SetRoomAlias(ctx, &setAliasReq, &setAliasRes); err != nil {
return fmt.Errorf("Failed to set new room alias: %w", err)
@@ -199,7 +193,7 @@ func moveLocalAliases(ctx context.Context,
return nil
}
-func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, userID string, userDomain spec.ServerName, roomID string) error {
+func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, roomID string) error {
for _, event := range oldRoom.StateEvents {
if event.Type() != spec.MRoomCanonicalAlias || !event.StateKeyEquals("") {
continue
@@ -217,7 +211,7 @@ func (r *Upgrader) clearOldCanonicalAliasEvent(ctx context.Context, oldRoom *api
}
}
- emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, userID, roomID, gomatrixserverlib.FledglingEvent{
+ emptyCanonicalAliasEvent, resErr := r.makeHeaderedEvent(ctx, evTime, senderID, userDomain, roomID, gomatrixserverlib.FledglingEvent{
Type: spec.MRoomCanonicalAlias,
Content: map[string]interface{}{},
})
@@ -280,7 +274,7 @@ func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error
return nil
}
-func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
+func (r *Upgrader) userIsAuthorized(ctx context.Context, senderID spec.SenderID, roomID string,
) bool {
plEvent := api.GetStateEvent(ctx, r.URSAPI, roomID, gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomPowerLevels,
@@ -295,26 +289,18 @@ func (r *Upgrader) userIsAuthorized(ctx context.Context, userID, roomID string,
}
// Check for power level required to send tombstone event (marks the current room as obsolete),
// if not found, use the StateDefault power level
- fullUserID, err := spec.NewUserID(userID, true)
- if err != nil {
- return false
- }
- senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
- if err != nil {
- return false
- }
return pl.UserLevel(senderID) >= pl.EventLevel("m.room.tombstone", true)
}
// nolint:gocyclo
-func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, userID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) {
+func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.QueryLatestEventsAndStateResponse, senderID spec.SenderID, roomID string, newVersion gomatrixserverlib.RoomVersion, tombstoneEvent *types.HeaderedEvent) ([]gomatrixserverlib.FledglingEvent, error) {
state := make(map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent, len(oldRoom.StateEvents))
for _, event := range oldRoom.StateEvents {
if event.StateKey() == nil {
// This shouldn't ever happen, but better to be safe than sorry.
continue
}
- if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) {
+ if event.Type() == spec.MRoomMember && !event.StateKeyEquals(string(senderID)) {
// With the exception of bans which we do want to copy, we
// should ignore membership events that aren't our own, as event auth will
// prevent us from being able to create membership events on behalf of other
@@ -330,6 +316,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
}
}
// skip events that rely on a specific user being present
+ // TODO: What to do here for pseudoIDs? It's checking non-member events for state keys with userIDs.
sKey := *event.StateKey()
if event.Type() != spec.MRoomMember && len(sKey) > 0 && sKey[:1] == "@" {
continue
@@ -340,10 +327,10 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
// The following events are ones that we are going to override manually
// in the following section.
override := map[gomatrixserverlib.StateKeyTuple]struct{}{
- {EventType: spec.MRoomCreate, StateKey: ""}: {},
- {EventType: spec.MRoomMember, StateKey: userID}: {},
- {EventType: spec.MRoomPowerLevels, StateKey: ""}: {},
- {EventType: spec.MRoomJoinRules, StateKey: ""}: {},
+ {EventType: spec.MRoomCreate, StateKey: ""}: {},
+ {EventType: spec.MRoomMember, StateKey: string(senderID)}: {},
+ {EventType: spec.MRoomPowerLevels, StateKey: ""}: {},
+ {EventType: spec.MRoomJoinRules, StateKey: ""}: {},
}
// The overridden events are essential events that must be present in the
@@ -355,7 +342,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
}
oldCreateEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate, StateKey: ""}]
- oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: userID}]
+ oldMembershipEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomMember, StateKey: string(senderID)}]
oldPowerLevelsEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomPowerLevels, StateKey: ""}]
oldJoinRulesEvent := state[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomJoinRules, StateKey: ""}]
@@ -364,7 +351,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
// in the create event (such as for the room types MSC).
newCreateContent := map[string]interface{}{}
_ = json.Unmarshal(oldCreateEvent.Content(), &newCreateContent)
- newCreateContent["creator"] = userID
+ newCreateContent["creator"] = string(senderID)
newCreateContent["room_version"] = newVersion
newCreateContent["predecessor"] = gomatrixserverlib.PreviousRoom{
EventID: tombstoneEvent.EventID(),
@@ -385,7 +372,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
newMembershipContent["membership"] = spec.Join
newMembershipEvent := gomatrixserverlib.FledglingEvent{
Type: spec.MRoomMember,
- StateKey: userID,
+ StateKey: string(senderID),
Content: newMembershipContent,
}
@@ -400,14 +387,6 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
return nil, fmt.Errorf("Power level event content was invalid")
}
- fullUserID, err := spec.NewUserID(userID, true)
- if err != nil {
- return nil, err
- }
- senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
- if err != nil {
- return nil, err
- }
tempPowerLevelsEvent, powerLevelsOverridden := createTemporaryPowerLevels(powerLevelContent, senderID)
// Now do the join rules event, same as the create and membership
@@ -470,21 +449,13 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
return eventsToMake, nil
}
-func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, userID string, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error {
+func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, senderID spec.SenderID, userDomain spec.ServerName, newRoomID string, newVersion gomatrixserverlib.RoomVersion, eventsToMake []gomatrixserverlib.FledglingEvent) error {
var err error
var builtEvents []*types.HeaderedEvent
authEvents := gomatrixserverlib.NewAuthEvents(nil)
for i, e := range eventsToMake {
depth := i + 1 // depth starts at 1
- fullUserID, userIDErr := spec.NewUserID(userID, true)
- if userIDErr != nil {
- return userIDErr
- }
- senderID, queryErr := r.URSAPI.QuerySenderIDForUser(ctx, newRoomID, *fullUserID)
- if queryErr != nil {
- return queryErr
- }
proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID),
RoomID: newRoomID,
@@ -549,7 +520,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
func (r *Upgrader) makeTombstoneEvent(
ctx context.Context,
evTime time.Time,
- userID, roomID, newRoomID string,
+ senderID spec.SenderID, senderDomain spec.ServerName, roomID, newRoomID string,
) (*types.HeaderedEvent, error) {
content := map[string]interface{}{
"body": "This room has been replaced",
@@ -559,30 +530,21 @@ func (r *Upgrader) makeTombstoneEvent(
Type: "m.room.tombstone",
Content: content,
}
- return r.makeHeaderedEvent(ctx, evTime, userID, roomID, event)
+ return r.makeHeaderedEvent(ctx, evTime, senderID, senderDomain, roomID, event)
}
-func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, userID, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) {
- fullUserID, err := spec.NewUserID(userID, true)
- if err != nil {
- return nil, err
- }
- senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
- if err != nil {
- return nil, err
- }
+func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, senderID spec.SenderID, senderDomain spec.ServerName, roomID string, event gomatrixserverlib.FledglingEvent) (*types.HeaderedEvent, error) {
proto := gomatrixserverlib.ProtoEvent{
SenderID: string(senderID),
RoomID: roomID,
Type: event.Type,
StateKey: &event.StateKey,
}
- err = proto.SetContent(event.Content)
+ err := proto.SetContent(event.Content)
if err != nil {
return nil, fmt.Errorf("failed to set new %q event content: %w", proto.Type, err)
}
// Get the sender domain.
- senderDomain := fullUserID.Domain()
identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain)
if err != nil {
return nil, fmt.Errorf("failed to get signing identity for %q: %w", senderDomain, err)
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index ae2b7cf5..caea6b52 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -48,7 +48,7 @@ type Queryer struct {
Cfg *config.Dendrite
}
-func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
+func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := r.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err
@@ -64,7 +64,7 @@ func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
}
- userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID)
+ userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err)
@@ -220,13 +220,14 @@ func (r *Queryer) QueryEventsByID(
return nil
}
-// QueryMembershipForUser implements api.RoomserverInternalAPI
-func (r *Queryer) QueryMembershipForUser(
+// QueryMembershipForSenderID implements api.RoomserverInternalAPI
+func (r *Queryer) QueryMembershipForSenderID(
ctx context.Context,
- request *api.QueryMembershipForUserRequest,
+ roomID spec.RoomID,
+ senderID spec.SenderID,
response *api.QueryMembershipForUserResponse,
) error {
- info, err := r.DB.RoomInfo(ctx, request.RoomID)
+ info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil {
return err
}
@@ -236,7 +237,7 @@ func (r *Queryer) QueryMembershipForUser(
}
response.RoomExists = true
- membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID)
+ membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID)
if err != nil {
return err
}
@@ -264,6 +265,24 @@ func (r *Queryer) QueryMembershipForUser(
return err
}
+// QueryMembershipForUser implements api.RoomserverInternalAPI
+func (r *Queryer) QueryMembershipForUser(
+ ctx context.Context,
+ request *api.QueryMembershipForUserRequest,
+ response *api.QueryMembershipForUserResponse,
+) error {
+ senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
+ if err != nil {
+ return err
+ }
+
+ roomID, err := spec.NewRoomID(request.RoomID)
+ if err != nil {
+ return err
+ }
+ return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
+}
+
// QueryMembershipAtEvent returns the known memberships at a given event.
// If the state before an event is not known, an empty list will be returned
// for that event instead.
@@ -373,7 +392,7 @@ func (r *Queryer) QueryMembershipsForRoom(
// If no sender is specified then we will just return the entire
// set of memberships for the room, regardless of whether a specific
// user is allowed to see them or not.
- if request.Sender == "" {
+ if request.SenderID == "" {
var events []types.Event
var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly)
@@ -388,18 +407,15 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
- sender := spec.UserID{}
- userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
- if queryErr == nil && userID != nil {
- sender = *userID
- }
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
+ clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.QueryUserIDForSender(ctx, roomID, senderID)
+ }, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
}
- membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender)
+ membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID)
if err != nil {
return err
}
@@ -442,12 +458,9 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
- sender := spec.UserID{}
- userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
- if err == nil && userID != nil {
- sender = *userID
- }
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
+ clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.QueryUserIDForSender(ctx, roomID, senderID)
+ }, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
@@ -489,6 +502,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
ctx context.Context,
serverName spec.ServerName,
eventID string,
+ roomID string,
) (allowed bool, err error) {
events, err := r.DB.EventNIDs(ctx, []string{eventID})
if err != nil {
@@ -518,7 +532,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
}
return helpers.CheckServerAllowedToSeeEvent(
- ctx, r.DB, info, eventID, serverName, isInRoom,
+ ctx, r.DB, info, roomID, eventID, serverName, isInRoom,
)
}
@@ -909,8 +923,8 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
return nil
}
-func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) {
- pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), userID.String())
+func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) {
+ pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), senderID)
return pending, err
}
@@ -926,8 +940,8 @@ func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eve
return res, err
}
-func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) {
- _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, userID.String())
+func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) {
+ _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID)
return isIn, err
}
@@ -957,7 +971,7 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse
}
// nolint:gocyclo
-func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) {
+func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
// Look up if we know anything about the room. If it doesn't exist
// or is a stub entry then we can't do anything.
roomInfo, err := r.DB.RoomInfo(ctx, roomID.String())
@@ -972,7 +986,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return "", err
}
- return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID)
+ return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
}
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go
index 5e6ba7d4..90c94bbc 100644
--- a/roomserver/roomserver_test.go
+++ b/roomserver/roomserver_test.go
@@ -722,7 +722,7 @@ func TestQueryRestrictedJoinAllowed(t *testing.T) {
roomID, _ := spec.NewRoomID(testRoom.ID)
userID, _ := spec.NewUserID(bob.ID, true)
- got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, *userID)
+ got, err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), *roomID, spec.SenderID(userID.String()))
if tc.wantError && err == nil {
t.Fatal("expected error, got none")
}
@@ -822,17 +822,6 @@ func TestUpgrade(t *testing.T) {
wantNewRoom bool
}{
{
- name: "invalid userID",
- upgradeUser: "!notvalid:test",
- roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
- room := test.NewRoom(t, alice)
- if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
- t.Errorf("failed to send events: %v", err)
- }
- return room.ID
- },
- },
- {
name: "invalid roomID",
upgradeUser: alice.ID,
roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
@@ -1049,7 +1038,11 @@ func TestUpgrade(t *testing.T) {
}
roomID := tc.roomFunc(rsAPI)
- newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, tc.upgradeUser, version.DefaultRoomVersion())
+ userID, err := spec.NewUserID(tc.upgradeUser, true)
+ if err != nil {
+ t.Fatalf("upgrade userID is invalid")
+ }
+ newRoomID, err := rsAPI.PerformRoomUpgrade(processCtx.Context(), roomID, *userID, version.DefaultRoomVersion())
if err != nil && tc.wantNewRoom {
t.Fatal(err)
}
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index 2d27d799..ef446378 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -131,7 +131,7 @@ type Database interface {
// in this room, along a boolean set to true if the user is still in this room,
// false if not.
// Returns an error if there was a problem talking to the database.
- GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error)
+ GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error)
// Lookup the membership event numeric IDs for all user that are or have
// been members of a given room. Only lookup events of "join" membership if
// joinOnly is set to true.
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index cb12b3f5..85a1ba7a 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -490,10 +490,10 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
})
}
-func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) {
+func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderID spec.SenderID) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) {
var requestSenderUserNID types.EventStateKeyNID
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID)
+ requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, string(requestSenderID))
return err
})
if err != nil {
@@ -936,6 +936,7 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.PDU) (
return roomVersion, err
}
+// nolint:gocyclo
// MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec:
// "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid."
// https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
@@ -1014,7 +1015,7 @@ func (d *EventDatabase) MaybeRedactEvent(
switch {
case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact:
// 1. The power level of the redaction event’s sender is greater than or equal to the redact level.
- case sender1Domain == sender2Domain:
+ case sender1Domain != "" && sender2Domain != "" && sender1Domain == sender2Domain:
// 2. The domain of the redaction event’s sender matches that of the original event’s sender.
default:
ignoreRedaction = true
diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go
index 47eb544e..d3f1c9dd 100644
--- a/setup/mscs/msc2836/msc2836.go
+++ b/setup/mscs/msc2836/msc2836.go
@@ -154,7 +154,7 @@ type reqCtx struct {
rsAPI roomserver.RoomserverInternalAPI
db Database
req *EventRelationshipRequest
- userID string
+ userID spec.UserID
roomVersion gomatrixserverlib.RoomVersion
// federated request args
@@ -173,10 +173,17 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP
JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
}
}
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: 400,
+ JSON: spec.BadJSON(fmt.Sprintf("invalid json: %s", err)),
+ }
+ }
rc := reqCtx{
ctx: req.Context(),
req: relation,
- userID: device.UserID,
+ userID: *userID,
rsAPI: rsAPI,
fsAPI: fsAPI,
isFederatedRequest: false,
diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go
index 551d7ad4..e32d6a9f 100644
--- a/setup/mscs/msc2836/msc2836_test.go
+++ b/setup/mscs/msc2836/msc2836_test.go
@@ -529,6 +529,10 @@ func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID str
return spec.NewUserID(string(senderID), true)
}
+func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
+ return spec.SenderID(userID.String()), nil
+}
+
func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error {
for _, eventID := range req.EventIDs {
ev := r.events[eventID]
@@ -540,7 +544,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver
}
func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error {
- rooms := r.userToJoinedRooms[req.UserID]
+ rooms := r.userToJoinedRooms[req.UserID.String()]
for _, roomID := range rooms {
if roomID == req.RoomID {
res.IsInRoom = true
diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go
index 8a2a0b1f..c5f2db9c 100644
--- a/syncapi/consumers/roomserver.go
+++ b/syncapi/consumers/roomserver.go
@@ -373,7 +373,15 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
// TODO: check that it's a join and not a profile change (means unmarshalling prev_content)
if membership == spec.Join {
// check it's a local join
- if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil {
+ if ev.StateKey() == nil {
+ return sp, fmt.Errorf("unexpected nil state_key")
+ }
+
+ userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
+ if err != nil || userID == nil {
+ return sp, fmt.Errorf("failed getting userID for sender: %w", err)
+ }
+ if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
return sp, nil
}
@@ -395,9 +403,15 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
if msg.Event.StateKey() == nil {
return
}
- if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil {
+
+ userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey()))
+ if err != nil || userID == nil {
+ return
+ }
+ if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
return
}
+
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil {
sentry.CaptureException(err)
@@ -440,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
// Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos)
- s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)
+ userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID)
+ if err != nil || userID == nil {
+ log.WithFields(log.Fields{
+ "event_id": msg.EventID,
+ "sender_id": msg.TargetSenderID,
+ log.ErrorKey: err,
+ }).Errorf("failed to find userID for sender")
+ return
+ }
+ s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String())
}
func (s *OutputRoomEventConsumer) onNewPeek(
diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go
index 7449b464..ab1a7f83 100644
--- a/syncapi/internal/history_visibility.go
+++ b/syncapi/internal/history_visibility.go
@@ -134,9 +134,17 @@ func ApplyHistoryVisibilityFilter(
}
}
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
- if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(userID) {
- eventsFiltered = append(eventsFiltered, ev)
- continue
+
+ user, err := spec.NewUserID(userID, true)
+ if err != nil {
+ return nil, err
+ }
+ senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user)
+ if err == nil {
+ if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
+ eventsFiltered = append(eventsFiltered, ev)
+ continue
+ }
}
// Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive
diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go
index ad5935cd..f4b6ace5 100644
--- a/syncapi/internal/keychange.go
+++ b/syncapi/internal/keychange.go
@@ -169,12 +169,16 @@ func TrackChangedUsers(
if err != nil {
return nil, nil, err
}
- for _, state := range stateRes.Rooms {
+ for roomID, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != spec.Join {
continue
}
- queryRes.UserIDsToCount[tuple.StateKey]--
+ user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
+ if queryErr != nil || user == nil {
+ continue
+ }
+ queryRes.UserIDsToCount[user.String()]--
}
}
@@ -211,14 +215,18 @@ func TrackChangedUsers(
if err != nil {
return nil, left, err
}
- for _, state := range stateRes.Rooms {
+ for roomID, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != spec.Join {
continue
}
// new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
- changed = append(changed, tuple.StateKey) // changed is returned
+ user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
+ if err != nil || user == nil {
+ continue
+ }
+ changed = append(changed, user.String()) // changed is returned
}
}
}
diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go
index 23c2ecba..efa64147 100644
--- a/syncapi/internal/keychange_test.go
+++ b/syncapi/internal/keychange_test.go
@@ -64,6 +64,10 @@ type mockRoomserverAPI struct {
roomIDToJoinedMembers map[string][]string
}
+func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return spec.NewUserID(string(senderID), true)
+}
+
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
return nil
diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go
index f7645685..4ee7c860 100644
--- a/syncapi/notifier/notifier.go
+++ b/syncapi/notifier/notifier.go
@@ -20,6 +20,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
@@ -36,7 +37,8 @@ import (
// the event, but the token has already advanced by the time they fetch it, resulting
// in missed events.
type Notifier struct {
- lock *sync.RWMutex
+ lock *sync.RWMutex
+ rsAPI api.SyncRoomserverAPI
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToJoinedUsers map[string]*userIDSet
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
@@ -55,8 +57,9 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
-func NewNotifier() *Notifier {
+func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier {
return &Notifier{
+ rsAPI: rsAPI,
roomIDToJoinedUsers: make(map[string]*userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
@@ -104,26 +107,32 @@ func (n *Notifier) OnNewEvent(
peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
// If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
- targetUserID := *ev.StateKey()
- membership, err := ev.Membership()
+ targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
- "Notifier.OnNewEvent: Failed to unmarshal member event",
+ "Notifier.OnNewEvent: Failed to find the userID for this event",
)
} else {
- // Keep the joined user map up-to-date
- switch membership {
- case spec.Invite:
- usersToNotify = append(usersToNotify, targetUserID)
- case spec.Join:
- // Manually append the new user's ID so they get notified
- // along all members in the room
- usersToNotify = append(usersToNotify, targetUserID)
- n._addJoinedUser(ev.RoomID(), targetUserID)
- case spec.Leave:
- fallthrough
- case spec.Ban:
- n._removeJoinedUser(ev.RoomID(), targetUserID)
+ membership, err := ev.Membership()
+ if err != nil {
+ log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
+ "Notifier.OnNewEvent: Failed to unmarshal member event",
+ )
+ } else {
+ // Keep the joined user map up-to-date
+ switch membership {
+ case spec.Invite:
+ usersToNotify = append(usersToNotify, targetUserID.String())
+ case spec.Join:
+ // Manually append the new user's ID so they get notified
+ // along all members in the room
+ usersToNotify = append(usersToNotify, targetUserID.String())
+ n._addJoinedUser(ev.RoomID(), targetUserID.String())
+ case spec.Leave:
+ fallthrough
+ case spec.Ban:
+ n._removeJoinedUser(ev.RoomID(), targetUserID.String())
+ }
}
}
}
diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go
index 36577a0e..7076f713 100644
--- a/syncapi/notifier/notifier_test.go
+++ b/syncapi/notifier/notifier_test.go
@@ -22,9 +22,11 @@ import (
"testing"
"time"
+ "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
)
@@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
}
}
+type TestRoomServer struct{ api.SyncRoomserverAPI }
+
+func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return spec.NewUserID(string(senderID), true)
+}
+
// Test that the current position is returned if a request is already behind.
func TestImmediateNotification(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld))
if err != nil {
@@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) {
// Test that new events to a joined room unblocks the request.
func TestNewEventAndJoinedToRoom(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
@@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
}
func TestCorrectStream(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
stream := lockedFetchUserStream(n, bob, bobDev)
if stream.UserID != bob {
@@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) {
}
func TestCorrectStreamWakeup(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
awoken := make(chan string)
@@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
// Test that an invite unblocks the request
func TestNewInviteEventForUser(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
@@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) {
// Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
@@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// listen as bob. Make bob leave room. Make alice send event to room.
// Make sure alice gets woken up only and not bob as well.
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go
index 7fb88faa..55fd3c5a 100644
--- a/syncapi/routing/context.go
+++ b/syncapi/routing/context.go
@@ -85,9 +85,16 @@ func Context(
*filter.Rooms = append(*filter.Rooms, roomID)
}
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("Device UserID is invalid"),
+ }
+ }
ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{}
- membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID}
+ membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID}
if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil {
logrus.WithError(err).Error("unable to query membership")
return util.JSONResponse{
@@ -217,12 +224,9 @@ func Context(
}
}
- sender := spec.UserID{}
- userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID())
- if err == nil && userID != nil {
- sender = *userID
- }
- ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
+ ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }, requestedEvent)
response := ContextRespsonse{
Event: &ev,
EventsAfter: eventsAfterClient,
diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go
index 63df7e83..de790e5c 100644
--- a/syncapi/routing/getevent.go
+++ b/syncapi/routing/getevent.go
@@ -106,8 +106,17 @@ func GetEvent(
if err == nil && senderUserID != nil {
sender = *senderUserID
}
+
+ sk := events[0].StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
return util.JSONResponse{
Code: http.StatusOK,
- JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
+ JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk),
}
}
diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go
index 813167a5..cf6769ba 100644
--- a/syncapi/routing/memberships.go
+++ b/syncapi/routing/memberships.go
@@ -59,14 +59,21 @@ func GetMemberships(
syncDB storage.Database, rsAPI api.SyncRoomserverAPI,
joinedOnly bool, membership, notMembership *string, at string,
) util.JSONResponse {
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("Device UserID is invalid"),
+ }
+ }
queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: device.UserID,
+ UserID: *userID,
}
var queryRes api.QueryMembershipForUserResponse
- if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed")
+ if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil {
+ util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go
index 781fd53e..6784a27b 100644
--- a/syncapi/routing/messages.go
+++ b/syncapi/routing/messages.go
@@ -296,9 +296,13 @@ func OnIncomingMessagesRequest(
}
func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) {
+ fullUserID, err := spec.NewUserID(userID, true)
+ if err != nil {
+ return resp, err
+ }
req := api.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: userID,
+ UserID: *fullUserID,
}
if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil {
return api.QueryMembershipForUserResponse{}, err
diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go
index f21c684c..6efa065a 100644
--- a/syncapi/routing/relations.go
+++ b/syncapi/routing/relations.go
@@ -119,9 +119,18 @@ func Relations(
if err == nil && userID != nil {
sender = *userID
}
+
+ sk := event.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
res.Chunk = append(
res.Chunk,
- synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
+ synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk),
)
}
diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go
index add50b18..7d9182f4 100644
--- a/syncapi/routing/search.go
+++ b/syncapi/routing/search.go
@@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
if err == nil && userID != nil {
sender = *userID
}
+
+ sk := event.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
results = append(results, Result{
Context: SearchContextResponse{
Start: startToken.String(),
@@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
- Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
+ Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
})
roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID())
diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go
index 5bd3b1f0..799e3d16 100644
--- a/syncapi/storage/shared/storage_consumer.go
+++ b/syncapi/storage/shared/storage_consumer.go
@@ -507,8 +507,20 @@ func (d *Database) CleanSendToDeviceUpdates(
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
-func getMembershipFromEvent(ev gomatrixserverlib.PDU, userID string) (string, string) {
- if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
+func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userID string, rsAPI api.SyncRoomserverAPI) (string, string) {
+ if ev.StateKey() == nil || *ev.StateKey() == "" {
+ return "", ""
+ }
+ fullUser, err := spec.NewUserID(userID, true)
+ if err != nil {
+ return "", ""
+ }
+ senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser)
+ if err != nil {
+ return "", ""
+ }
+
+ if ev.Type() != "m.room.member" || !ev.StateKeyEquals(string(senderID)) {
return "", ""
}
membership, err := ev.Membership()
diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go
index df961385..8e79b71d 100644
--- a/syncapi/storage/shared/storage_sync.go
+++ b/syncapi/storage/shared/storage_sync.go
@@ -430,7 +430,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, ev := range stateStreamEvents {
// Look for our membership in the state events and skip over any
// membership events that are not related to us.
- membership, prevMembership := getMembershipFromEvent(ev.PDU, userID)
+ membership, prevMembership := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI)
if membership == "" {
continue
}
@@ -556,7 +556,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents {
- if membership, _ := getMembershipFromEvent(ev.PDU, userID); membership != "" {
+ if membership, _ := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI); membership != "" {
if membership != spec.Join { // We've already added full state for all joined rooms above.
deltas[roomID] = types.StateDelta{
Membership: membership,
diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go
index a8b0a7b6..3a5badd9 100644
--- a/syncapi/streams/stream_invite.go
+++ b/syncapi/streams/stream_invite.go
@@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync(
user = *sender
}
+ sk := inviteEvent.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
+
// skip ignored user events
if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue
}
- ir := types.NewInviteResponse(inviteEvent, user)
+ ir := types.NewInviteResponse(inviteEvent, user, sk)
req.Response.Rooms.Invite[roomID] = ir
}
diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go
index d214980b..f728d4ae 100644
--- a/syncapi/streams/stream_pdu.go
+++ b/syncapi/streams/stream_pdu.go
@@ -605,13 +605,17 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// If this is a gapped incremental sync, we still want this membership
isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list
- stateKey := *event.StateKey()
- if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID {
+ userID := ""
+ stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey()))
+ if err == nil && stateKeyUserID != nil {
+ userID = stateKeyUserID.String()
+ }
+ if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {
newStateEvents = append(newStateEvents, event)
if !stateFilter.IncludeRedundantMembers {
- p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID())
+ p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, userID, event.EventID())
}
- delete(timelineUsers, stateKey)
+ delete(timelineUsers, userID)
}
} else {
newStateEvents = append(newStateEvents, event)
diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go
index ecbe05dd..64a4af75 100644
--- a/syncapi/syncapi.go
+++ b/syncapi/syncapi.go
@@ -60,7 +60,7 @@ func AddPublicRoutes(
}
eduCache := caching.NewTypingCache()
- notifier := notifier.NewNotifier()
+ notifier := notifier.NewNotifier(rsAPI)
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil {
diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go
index 66fb1d01..358a0c97 100644
--- a/syncapi/synctypes/clientevent.go
+++ b/syncapi/synctypes/clientevent.go
@@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if err == nil && userID != nil {
sender = *userID
}
- evs = append(evs, ToClientEvent(se, format, sender))
+
+ sk := se.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
+ evs = append(evs, ToClientEvent(se, format, sender, sk))
}
return evs
}
// ToClientEvent converts a single server event to a client event.
-func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
+func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent {
ce := ClientEvent{
Content: spec.RawJSON(se.Content()),
Sender: sender.String(),
Type: se.Type(),
- StateKey: se.StateKey(),
+ StateKey: stateKey,
Unsigned: spec.RawJSON(se.Unsigned()),
OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(),
@@ -77,3 +86,23 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
}
return ce
}
+
+// ToClientEvent converts a single server event to a client event.
+// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
+func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
+ sender := spec.UserID{}
+ userID, err := userIDQuery(event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
+
+ sk := event.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
+ return ToClientEvent(event, FormatAll, sender, sk)
+}
diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go
index 34179508..63c65b2a 100644
--- a/syncapi/synctypes/clientevent_test.go
+++ b/syncapi/synctypes/clientevent_test.go
@@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
- ce := ToClientEvent(ev, FormatAll, *userID)
+ sk := ""
+ ce := ToClientEvent(ev, FormatAll, *userID, &sk)
if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
}
@@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
- ce := ToClientEvent(ev, FormatSync, *userID)
+ sk := ""
+ ce := ToClientEvent(ev, FormatSync, *userID, &sk)
if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
}
diff --git a/syncapi/types/types.go b/syncapi/types/types.go
index a3dc7f54..cb3c362d 100644
--- a/syncapi/types/types.go
+++ b/syncapi/types/types.go
@@ -539,7 +539,7 @@ type InviteResponse struct {
}
// NewInviteResponse creates an empty response with initialised arrays.
-func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
+func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse {
res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{}
@@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe
// Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite.
- inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
+ inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey)
inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev)
diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go
index a79ce541..c1b7f70b 100644
--- a/syncapi/types/types_test.go
+++ b/syncapi/types/types_test.go
@@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true)
+ if err != nil {
+ t.Fatal(err)
+ }
+ skString := skUserID.String()
+ sk := &skString
- res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
+ res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk)
j, err := json.Marshal(res)
if err != nil {
t.Fatal(err)
diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go
index df507eb2..b2dc477a 100644
--- a/userapi/consumers/roomserver.go
+++ b/userapi/consumers/roomserver.go
@@ -306,7 +306,16 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
if queryErr == nil && userID != nil {
sender = *userID
}
- cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
+
+ sk := event.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
+ if queryErr == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
+ cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk)
var member *localMembership
member, err = newLocalMembership(&cevent)
if err != nil {
@@ -539,12 +548,21 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
if err == nil && userID != nil {
sender = *userID
}
+
+ sk := event.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
+ if queryErr == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
n := &api.Notification{
Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync.
- Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender),
+ Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender, sk),
// TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it
@@ -792,10 +810,20 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
Type: event.Type(),
},
}
- if mem, err := event.Membership(); err == nil {
+ if mem, memberErr := event.Membership(); memberErr == nil {
req.Notification.Membership = mem
}
- if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) {
+ userID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName), true)
+ if err != nil {
+ logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart)
+ return nil, err
+ }
+ localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID)
+ if err != nil {
+ logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID())
+ return nil, err
+ }
+ if event.StateKey() != nil && *event.StateKey() == string(localSender) {
req.Notification.UserIsTarget = true
}
}
diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go
index 27dd373c..3017069b 100644
--- a/userapi/util/notify_test.go
+++ b/userapi/util/notify_test.go
@@ -104,8 +104,9 @@ func TestNotifyUserCountsAsync(t *testing.T) {
if err != nil {
t.Error(err)
}
+ sk := ""
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
- Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender),
+ Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender, &sk),
}); err != nil {
t.Error(err)
}