aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-09-15 15:25:09 +0000
committerGitHub <noreply@github.com>2023-09-15 15:25:09 +0000
commitdb83789654ade3cf4f900e8fbcaa742b60c5dc6c (patch)
treeb68208908a4e73dba7fde6c72b91c52a47f8d018 /syncapi
parent8245b24100b0afaa046bb3fe52f0994f906c8ab1 (diff)
Move pseudoID ClientEvent hotswapping to a common location (#3199)
Fixes a variety of issues where clients were receiving pseudoIDs in places that should be userIDs. This change makes pseudoIDs work with sliding sync & element x. --------- Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com>
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/routing/getevent.go18
-rw-r--r--syncapi/routing/relations.go21
-rw-r--r--syncapi/routing/search.go20
-rw-r--r--syncapi/streams/stream_invite.go15
-rw-r--r--syncapi/streams/stream_pdu.go131
-rw-r--r--syncapi/synctypes/clientevent.go365
-rw-r--r--syncapi/synctypes/clientevent_test.go30
-rw-r--r--syncapi/types/types.go35
-rw-r--r--syncapi/types/types_test.go29
9 files changed, 386 insertions, 278 deletions
diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go
index 886b1167..c089539f 100644
--- a/syncapi/routing/getevent.go
+++ b/syncapi/routing/getevent.go
@@ -118,25 +118,19 @@ func GetEvent(
}
}
- senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, events[0].SenderID())
- if err != nil || senderUserID == nil {
- util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("QueryUserIDForSender errored or returned nil-user ID when user should be part of a room")
+ clientEvent, err := synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
- sk := events[0].StateKey()
- if sk != nil && *sk != "" {
- skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
- if err == nil && skUserID != nil {
- skString := skUserID.String()
- sk = &skString
- }
- }
return util.JSONResponse{
Code: http.StatusOK,
- JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, senderUserID.String(), sk, events[0].Unsigned()),
+ JSON: *clientEvent,
}
}
diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go
index b451a7e2..935ba83b 100644
--- a/syncapi/routing/relations.go
+++ b/syncapi/routing/relations.go
@@ -130,23 +130,16 @@ func Relations(
// type if it was specified.
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents {
- sender := spec.UserID{}
- userID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, event.SenderID())
- if err == nil && userID != nil {
- sender = *userID
- }
-
- sk := event.StateKey()
- if sk != nil && *sk != "" {
- skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, spec.SenderID(*event.StateKey()))
- if err == nil && skUserID != nil {
- skString := skUserID.String()
- sk = &skString
- }
+ clientEvent, err := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
+ })
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent")
+ continue
}
res.Chunk = append(
res.Chunk,
- synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender.String(), sk, event.Unsigned()),
+ *clientEvent,
)
}
diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go
index f574781a..4a8be9f4 100644
--- a/syncapi/routing/search.go
+++ b/syncapi/routing/search.go
@@ -230,20 +230,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos[userID.String()] = profile
}
- sender := spec.UserID{}
- userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
- if err == nil && userID != nil {
- sender = *userID
+ clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).WithField("senderID", event.SenderID()).Error("Failed converting to ClientEvent")
+ continue
}
- sk := event.StateKey()
- if sk != nil && *sk != "" {
- skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
- if err == nil && skUserID != nil {
- skString := skUserID.String()
- sk = &skString
- }
- }
results = append(results, Result{
Context: SearchContextResponse{
Start: startToken.String(),
@@ -257,7 +251,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
- Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender.String(), sk, event.Unsigned()),
+ Result: *clientEvent,
})
roomGroup := groups[event.RoomID().String()]
roomGroup.Results = append(roomGroup.Results, event.EventID())
diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go
index 1424dc2e..a3634c03 100644
--- a/syncapi/streams/stream_invite.go
+++ b/syncapi/streams/stream_invite.go
@@ -75,20 +75,15 @@ func (p *InviteStreamProvider) IncrementalSync(
user = *sender
}
- sk := inviteEvent.StateKey()
- if sk != nil && *sk != "" {
- skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
- if err == nil && skUserID != nil {
- skString := skUserID.String()
- sk = &skString
- }
- }
-
// skip ignored user events
if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue
}
- ir := types.NewInviteResponse(inviteEvent, user, sk, eventFormat)
+ ir, err := types.NewInviteResponse(ctx, p.rsAPI, inviteEvent, eventFormat)
+ if err != nil {
+ req.Log.WithError(err).Error("failed creating invite response")
+ continue
+ }
req.Response.Rooms.Invite[roomID] = ir
}
diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go
index eb1f0ef2..3abb0b3c 100644
--- a/syncapi/streams/stream_pdu.go
+++ b/syncapi/streams/stream_pdu.go
@@ -3,7 +3,6 @@ package streams
import (
"context"
"database/sql"
- "encoding/json"
"fmt"
"time"
@@ -16,8 +15,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/spec"
- "github.com/tidwall/gjson"
- "github.com/tidwall/sjson"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/gomatrixserverlib"
@@ -359,23 +356,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// Now that we've filtered the timeline, work out which state events are still
// left. Anything that appears in the filtered timeline will be removed from the
// "state" section and kept in "timeline".
-
- // update the powerlevel event for timeline events
- for i, ev := range events {
- if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
- continue
- }
- if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
- continue
- }
- var newEvent gomatrixserverlib.PDU
- newEvent, err = p.updatePowerLevelEvent(ctx, ev, eventFormat)
- if err != nil {
- return r.From, err
- }
- events[i] = &rstypes.HeaderedEvent{PDU: newEvent}
- }
-
sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering(
gomatrixserverlib.ToPDUs(removeDuplicates(delta.StateEvents, events)),
gomatrixserverlib.TopologicalOrderByAuthEvents,
@@ -390,15 +370,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
continue
}
delta.StateEvents[i-skipped] = he
- // update the powerlevel event for state events
- if ev.Version() == gomatrixserverlib.RoomVersionPseudoIDs && ev.Type() == spec.MRoomPowerLevels && ev.StateKeyEquals("") {
- var newEvent gomatrixserverlib.PDU
- newEvent, err = p.updatePowerLevelEvent(ctx, he, eventFormat)
- if err != nil {
- return r.From, err
- }
- delta.StateEvents[i-skipped] = &rstypes.HeaderedEvent{PDU: newEvent}
- }
}
delta.StateEvents = delta.StateEvents[:len(sEvents)-skipped]
@@ -468,79 +439,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
return latestPosition, nil
}
-func (p *PDUStreamProvider) updatePowerLevelEvent(ctx context.Context, ev *rstypes.HeaderedEvent, eventFormat synctypes.ClientEventFormat) (gomatrixserverlib.PDU, error) {
- pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev)
- if err != nil {
- return nil, err
- }
- newPls := make(map[string]int64)
- var userID *spec.UserID
- for user, level := range pls.Users {
- if eventFormat != synctypes.FormatSyncFederation {
- userID, err = p.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(user))
- if err != nil {
- return nil, err
- }
- user = userID.String()
- }
- newPls[user] = level
- }
- var newPlBytes, newEv []byte
- newPlBytes, err = json.Marshal(newPls)
- if err != nil {
- return nil, err
- }
- newEv, err = sjson.SetRawBytes(ev.JSON(), "content.users", newPlBytes)
- if err != nil {
- return nil, err
- }
-
- // do the same for prev content
- prevContent := gjson.GetBytes(ev.JSON(), "unsigned.prev_content")
- if !prevContent.Exists() {
- var evNew gomatrixserverlib.PDU
- evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSONWithEventID(ev.EventID(), newEv, false)
- if err != nil {
- return nil, err
- }
-
- return evNew, err
- }
- pls = gomatrixserverlib.PowerLevelContent{}
- err = json.Unmarshal([]byte(prevContent.Raw), &pls)
- if err != nil {
- return nil, err
- }
-
- newPls = make(map[string]int64)
- for user, level := range pls.Users {
- if eventFormat != synctypes.FormatSyncFederation {
- userID, err = p.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(user))
- if err != nil {
- return nil, err
- }
- user = userID.String()
- }
- newPls[user] = level
- }
- newPlBytes, err = json.Marshal(newPls)
- if err != nil {
- return nil, err
- }
- newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes)
- if err != nil {
- return nil, err
- }
-
- var evNew gomatrixserverlib.PDU
- evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSONWithEventID(ev.EventID(), newEv, false)
- if err != nil {
- return nil, err
- }
-
- return evNew, err
-}
-
// applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make
// sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter(
@@ -690,35 +588,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
prevBatch.Decrement()
}
- // Update powerlevel events for timeline events
- for i, ev := range events {
- if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
- continue
- }
- if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
- continue
- }
- newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat)
- if err != nil {
- return nil, err
- }
- events[i] = &rstypes.HeaderedEvent{PDU: newEvent}
- }
- // Update powerlevel events for state events
- for i, ev := range stateEvents {
- if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
- continue
- }
- if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
- continue
- }
- newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat)
- if err != nil {
- return nil, err
- }
- stateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent}
- }
-
jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go
index e0616e11..6812f833 100644
--- a/syncapi/synctypes/clientevent.go
+++ b/syncapi/synctypes/clientevent.go
@@ -22,6 +22,8 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
)
// PrevEventRef represents a reference to a previous event in a state event upgrade
@@ -78,59 +80,62 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if se == nil {
continue // TODO: shouldn't happen?
}
- if format == FormatSyncFederation {
- evs = append(evs, ToClientEvent(se, format, string(se.SenderID()), se.StateKey(), spec.RawJSON(se.Unsigned())))
+ ev, err := ToClientEvent(se, format, userIDForSender)
+ if err != nil {
+ logrus.WithError(err).Warn("Failed converting event to ClientEvent")
continue
}
+ evs = append(evs, *ev)
+ }
+ return evs
+}
- sender := spec.UserID{}
- userID, err := userIDForSender(se.RoomID(), se.SenderID())
- if err == nil && userID != nil {
- sender = *userID
- }
+// ToClientEventDefault converts a single server event to a client event.
+// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
+func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
+ ev, err := ToClientEvent(event, FormatAll, userIDQuery)
+ if err != nil {
+ return ClientEvent{}
+ }
+ return *ev
+}
- sk := se.StateKey()
- if sk != nil && *sk != "" {
- skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
- if err == nil && skUserID != nil {
- skString := skUserID.String()
- sk = &skString
- }
+// If provided state key is a user ID (state keys beginning with @ are reserved for this purpose)
+// fetch it's associated sender ID and use that instead. Otherwise returns the same state key back.
+//
+// # This function either returns the state key that should be used, or an error
+//
+// TODO: handle failure cases better (e.g. no sender ID)
+func FromClientStateKey(roomID spec.RoomID, stateKey string, senderIDQuery spec.SenderIDForUser) (*string, error) {
+ if len(stateKey) >= 1 && stateKey[0] == '@' {
+ parsedStateKey, err := spec.NewUserID(stateKey, true)
+ if err != nil {
+ // If invalid user ID, then there is no associated state event.
+ return nil, fmt.Errorf("Provided state key begins with @ but is not a valid user ID: %w", err)
}
-
- unsigned := se.Unsigned()
- var prev PrevEventRef
- if err := json.Unmarshal(se.Unsigned(), &prev); err == nil && prev.PrevSenderID != "" {
- prevUserID, err := userIDForSender(se.RoomID(), spec.SenderID(prev.PrevSenderID))
- if err == nil && userID != nil {
- prev.PrevSenderID = prevUserID.String()
- } else {
- errString := "userID unknown"
- if err != nil {
- errString = err.Error()
- }
- logrus.Warnf("Failed to find userID for prev_sender in ClientEvent: %s", errString)
- // NOTE: Not much can be done here, so leave the previous value in place.
- }
- unsigned, err = json.Marshal(prev)
- if err != nil {
- logrus.Errorf("Failed to marshal unsigned content for ClientEvent: %s", err.Error())
- continue
- }
+ senderID, err := senderIDQuery(roomID, *parsedStateKey)
+ if err != nil {
+ return nil, fmt.Errorf("Failed to query sender ID: %w", err)
+ }
+ if senderID == nil {
+ // If no sender ID, then there is no associated state event.
+ return nil, fmt.Errorf("No associated sender ID found.")
}
- evs = append(evs, ToClientEvent(se, format, sender.String(), sk, spec.RawJSON(unsigned)))
+ newStateKey := string(*senderID)
+ return &newStateKey, nil
+ } else {
+ return &stateKey, nil
}
- return evs
}
// ToClientEvent converts a single server event to a client event.
-func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender string, stateKey *string, unsigned spec.RawJSON) ClientEvent {
+func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) (*ClientEvent, error) {
ce := ClientEvent{
- Content: spec.RawJSON(se.Content()),
- Sender: sender,
+ Content: se.Content(),
+ Sender: string(se.SenderID()),
Type: se.Type(),
- StateKey: stateKey,
- Unsigned: unsigned,
+ StateKey: se.StateKey(),
+ Unsigned: se.Unsigned(),
OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(),
Redacts: se.Redacts(),
@@ -148,58 +153,268 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender st
// TODO: Set Signatures & Hashes fields
}
- if format != FormatSyncFederation {
- if se.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
- ce.SenderKey = se.SenderID()
+ if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
+ err := updatePseudoIDs(&ce, se, userIDForSender, format)
+ if err != nil {
+ return nil, err
}
}
- return ce
+
+ return &ce, nil
}
-// ToClientEvent converts a single server event to a client event.
-// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
-func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
- sender := spec.UserID{}
- userID, err := userIDQuery(event.RoomID(), event.SenderID())
+func updatePseudoIDs(ce *ClientEvent, se gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender, format ClientEventFormat) error {
+ ce.SenderKey = se.SenderID()
+
+ userID, err := userIDForSender(se.RoomID(), se.SenderID())
if err == nil && userID != nil {
- sender = *userID
+ ce.Sender = userID.String()
}
- sk := event.StateKey()
+ sk := se.StateKey()
if sk != nil && *sk != "" {
- skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey()))
+ skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
if err == nil && skUserID != nil {
skString := skUserID.String()
- sk = &skString
+ ce.StateKey = &skString
}
}
- return ToClientEvent(event, FormatAll, sender.String(), sk, event.Unsigned())
+
+ var prev PrevEventRef
+ if err := json.Unmarshal(se.Unsigned(), &prev); err == nil && prev.PrevSenderID != "" {
+ prevUserID, err := userIDForSender(se.RoomID(), spec.SenderID(prev.PrevSenderID))
+ if err == nil && userID != nil {
+ prev.PrevSenderID = prevUserID.String()
+ } else {
+ errString := "userID unknown"
+ if err != nil {
+ errString = err.Error()
+ }
+ logrus.Warnf("Failed to find userID for prev_sender in ClientEvent: %s", errString)
+ // NOTE: Not much can be done here, so leave the previous value in place.
+ }
+ ce.Unsigned, err = json.Marshal(prev)
+ if err != nil {
+ err = fmt.Errorf("Failed to marshal unsigned content for ClientEvent: %w", err)
+ return err
+ }
+ }
+
+ switch se.Type() {
+ case spec.MRoomCreate:
+ updatedContent, err := updateCreateEvent(se.Content(), userIDForSender, se.RoomID())
+ if err != nil {
+ err = fmt.Errorf("Failed to update m.room.create event for ClientEvent: %w", err)
+ return err
+ }
+ ce.Content = updatedContent
+ case spec.MRoomMember:
+ updatedEvent, err := updateInviteEvent(userIDForSender, se, format)
+ if err != nil {
+ err = fmt.Errorf("Failed to update m.room.member event for ClientEvent: %w", err)
+ return err
+ }
+ if updatedEvent != nil {
+ ce.Unsigned = updatedEvent.Unsigned()
+ }
+ case spec.MRoomPowerLevels:
+ updatedEvent, err := updatePowerLevelEvent(userIDForSender, se, format)
+ if err != nil {
+ err = fmt.Errorf("Failed update m.room.power_levels event for ClientEvent: %w", err)
+ return err
+ }
+ if updatedEvent != nil {
+ ce.Content = updatedEvent.Content()
+ ce.Unsigned = updatedEvent.Unsigned()
+ }
+ }
+
+ return nil
}
-// If provided state key is a user ID (state keys beginning with @ are reserved for this purpose)
-// fetch it's associated sender ID and use that instead. Otherwise returns the same state key back.
-//
-// # This function either returns the state key that should be used, or an error
-//
-// TODO: handle failure cases better (e.g. no sender ID)
-func FromClientStateKey(roomID spec.RoomID, stateKey string, senderIDQuery spec.SenderIDForUser) (*string, error) {
- if len(stateKey) >= 1 && stateKey[0] == '@' {
- parsedStateKey, err := spec.NewUserID(stateKey, true)
+func updateCreateEvent(content spec.RawJSON, userIDForSender spec.UserIDForSender, roomID spec.RoomID) (spec.RawJSON, error) {
+ if creator := gjson.GetBytes(content, "creator"); creator.Exists() {
+ oldCreator := creator.Str
+ userID, err := userIDForSender(roomID, spec.SenderID(oldCreator))
if err != nil {
- // If invalid user ID, then there is no associated state event.
- return nil, fmt.Errorf("Provided state key begins with @ but is not a valid user ID: %s", err.Error())
+ err = fmt.Errorf("Failed to find userID for creator in ClientEvent: %w", err)
+ return nil, err
}
- senderID, err := senderIDQuery(roomID, *parsedStateKey)
+
+ if userID != nil {
+ var newCreatorBytes, newContent []byte
+ newCreatorBytes, err = json.Marshal(userID.String())
+ if err != nil {
+ err = fmt.Errorf("Failed to marshal new creator for ClientEvent: %w", err)
+ return nil, err
+ }
+
+ newContent, err = sjson.SetRawBytes([]byte(content), "creator", newCreatorBytes)
+ if err != nil {
+ err = fmt.Errorf("Failed to set new creator for ClientEvent: %w", err)
+ return nil, err
+ }
+
+ return newContent, nil
+ }
+ }
+
+ return content, nil
+}
+
+func updateInviteEvent(userIDForSender spec.UserIDForSender, ev gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) {
+ if inviteRoomState := gjson.GetBytes(ev.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
+ userID, err := userIDForSender(ev.RoomID(), ev.SenderID())
+ if err != nil || userID == nil {
+ if err != nil {
+ err = fmt.Errorf("invalid userID found when updating invite_room_state: %w", err)
+ }
+ return nil, err
+ }
+
+ newState, err := GetUpdatedInviteRoomState(userIDForSender, inviteRoomState, ev, ev.RoomID(), eventFormat)
if err != nil {
- return nil, fmt.Errorf("Failed to query sender ID: %s", err.Error())
+ return nil, err
}
- if senderID == nil {
- // If no sender ID, then there is no associated state event.
- return nil, fmt.Errorf("No associated sender ID found.")
+
+ var newEv []byte
+ newEv, err = sjson.SetRawBytes(ev.JSON(), "unsigned.invite_room_state", newState)
+ if err != nil {
+ return nil, err
}
- newStateKey := string(*senderID)
- return &newStateKey, nil
- } else {
- return &stateKey, nil
+
+ return gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSON(newEv, false)
+ }
+
+ return ev, nil
+}
+
+type InviteRoomStateEvent struct {
+ Content spec.RawJSON `json:"content"`
+ SenderID string `json:"sender"`
+ StateKey *string `json:"state_key"`
+ Type string `json:"type"`
+}
+
+func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomState gjson.Result, event gomatrixserverlib.PDU, roomID spec.RoomID, eventFormat ClientEventFormat) (spec.RawJSON, error) {
+ var res spec.RawJSON
+ inviteStateEvents := []InviteRoomStateEvent{}
+ err := json.Unmarshal([]byte(inviteRoomState.Raw), &inviteStateEvents)
+ if err != nil {
+ return nil, err
}
+
+ if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation {
+ for i, ev := range inviteStateEvents {
+ userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID))
+ if userIDErr != nil {
+ return nil, userIDErr
+ }
+ if userID != nil {
+ inviteStateEvents[i].SenderID = userID.String()
+ }
+
+ if ev.StateKey != nil && *ev.StateKey != "" {
+ userID, senderErr := userIDForSender(roomID, spec.SenderID(*ev.StateKey))
+ if senderErr != nil {
+ return nil, senderErr
+ }
+ if userID != nil {
+ user := userID.String()
+ inviteStateEvents[i].StateKey = &user
+ }
+ }
+
+ updatedContent, updateErr := updateCreateEvent(ev.Content, userIDForSender, roomID)
+ if updateErr != nil {
+ updateErr = fmt.Errorf("Failed to update m.room.create event for ClientEvent: %w", userIDErr)
+ return nil, updateErr
+ }
+ inviteStateEvents[i].Content = updatedContent
+ }
+ }
+
+ res, err = json.Marshal(inviteStateEvents)
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+func updatePowerLevelEvent(userIDForSender spec.UserIDForSender, se gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) {
+ if !se.StateKeyEquals("") {
+ return se, nil
+ }
+
+ pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(se)
+ if err != nil {
+ return nil, err
+ }
+ newPls := make(map[string]int64)
+ var userID *spec.UserID
+ for user, level := range pls.Users {
+ if eventFormat != FormatSyncFederation {
+ userID, err = userIDForSender(se.RoomID(), spec.SenderID(user))
+ if err != nil {
+ return nil, err
+ }
+ user = userID.String()
+ }
+ newPls[user] = level
+ }
+ var newPlBytes, newEv []byte
+ newPlBytes, err = json.Marshal(newPls)
+ if err != nil {
+ return nil, err
+ }
+ newEv, err = sjson.SetRawBytes(se.JSON(), "content.users", newPlBytes)
+ if err != nil {
+ return nil, err
+ }
+
+ // do the same for prev content
+ prevContent := gjson.GetBytes(se.JSON(), "unsigned.prev_content")
+ if !prevContent.Exists() {
+ var evNew gomatrixserverlib.PDU
+ evNew, err = gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSON(newEv, false)
+ if err != nil {
+ return nil, err
+ }
+
+ return evNew, err
+ }
+ pls = gomatrixserverlib.PowerLevelContent{}
+ err = json.Unmarshal([]byte(prevContent.Raw), &pls)
+ if err != nil {
+ return nil, err
+ }
+
+ newPls = make(map[string]int64)
+ for user, level := range pls.Users {
+ if eventFormat != FormatSyncFederation {
+ userID, err = userIDForSender(se.RoomID(), spec.SenderID(user))
+ if err != nil {
+ return nil, err
+ }
+ user = userID.String()
+ }
+ newPls[user] = level
+ }
+ newPlBytes, err = json.Marshal(newPls)
+ if err != nil {
+ return nil, err
+ }
+ newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes)
+ if err != nil {
+ return nil, err
+ }
+
+ var evNew gomatrixserverlib.PDU
+ evNew, err = gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSONWithEventID(se.EventID(), newEv, false)
+ if err != nil {
+ return nil, err
+ }
+
+ return evNew, err
}
diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go
index 202c185f..662f9ea4 100644
--- a/syncapi/synctypes/clientevent_test.go
+++ b/syncapi/synctypes/clientevent_test.go
@@ -26,6 +26,14 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec"
)
+func queryUserIDForSender(senderID spec.SenderID) (*spec.UserID, error) {
+ if senderID == "" {
+ return nil, nil
+ }
+
+ return spec.NewUserID(string(senderID), true)
+}
+
const testSenderID = "testSenderID"
const testUserID = "@test:localhost"
@@ -106,7 +114,12 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
t.Fatalf("failed to create userID: %s", err)
}
sk := ""
- ce := ToClientEvent(ev, FormatAll, userID.String(), &sk, ev.Unsigned())
+ ce, err := ToClientEvent(ev, FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return queryUserIDForSender(senderID)
+ })
+ if err != nil {
+ t.Fatalf("failed to create ClientEvent: %s", err)
+ }
verifyEventFields(t,
EventFieldsToVerify{
@@ -161,12 +174,12 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil {
t.Fatalf("failed to create Event: %s", err)
}
- userID, err := spec.NewUserID("@test:localhost", true)
+ ce, err := ToClientEvent(ev, FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return queryUserIDForSender(senderID)
+ })
if err != nil {
- t.Fatalf("failed to create userID: %s", err)
+ t.Fatalf("failed to create ClientEvent: %s", err)
}
- sk := ""
- ce := ToClientEvent(ev, FormatSync, userID.String(), &sk, ev.Unsigned())
if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
}
@@ -206,7 +219,12 @@ func TestToClientEventFormatSyncFederation(t *testing.T) { // nolint: gocyclo
t.Fatalf("failed to create userID: %s", err)
}
sk := ""
- ce := ToClientEvent(ev, FormatSyncFederation, userID.String(), &sk, ev.Unsigned())
+ ce, err := ToClientEvent(ev, FormatSyncFederation, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return queryUserIDForSender(senderID)
+ })
+ if err != nil {
+ t.Fatalf("failed to create ClientEvent: %s", err)
+ }
verifyEventFields(t,
EventFieldsToVerify{
diff --git a/syncapi/types/types.go b/syncapi/types/types.go
index b90c128c..bca11855 100644
--- a/syncapi/types/types.go
+++ b/syncapi/types/types.go
@@ -15,6 +15,7 @@
package types
import (
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -532,7 +533,7 @@ type InviteResponse struct {
}
// NewInviteResponse creates an empty response with initialised arrays.
-func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string, eventFormat synctypes.ClientEventFormat) *InviteResponse {
+func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *types.HeaderedEvent, eventFormat synctypes.ClientEventFormat) (*InviteResponse, error) {
res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{}
@@ -540,18 +541,42 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey
// If there is then unmarshal it into the response. This will contain the
// partial room state such as join rules, room name etc.
if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
- _ = json.Unmarshal([]byte(inviteRoomState.Raw), &res.InviteState.Events)
+ if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation {
+ updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }, inviteRoomState, event.PDU, event.RoomID(), eventFormat)
+ if err != nil {
+ return nil, err
+ }
+ _ = json.Unmarshal(updatedInvite, &res.InviteState.Events)
+ } else {
+ _ = json.Unmarshal([]byte(inviteRoomState.Raw), &res.InviteState.Events)
+ }
+ }
+
+ // Clear unsigned so it doesn't have pseudoIDs converted during ToClientEvent
+ eventNoUnsigned, err := event.SetUnsigned(nil)
+ if err != nil {
+ return nil, err
}
// Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite.
- inviteEvent := synctypes.ToClientEvent(event.PDU, eventFormat, userID.String(), stateKey, event.Unsigned())
+ inviteEvent, err := synctypes.ToClientEvent(eventNoUnsigned, eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Ensure unsigned field is empty so it isn't marshalled into the final JSON
inviteEvent.Unsigned = nil
- if ev, err := json.Marshal(inviteEvent); err == nil {
+
+ if ev, err := json.Marshal(*inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev)
}
- return &res
+ return &res, nil
}
// LeaveResponse represents a /sync response for a room which is under the 'leave' key.
diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go
index a79b9fc5..35e1882c 100644
--- a/syncapi/types/types_test.go
+++ b/syncapi/types/types_test.go
@@ -1,6 +1,7 @@
package types
import (
+ "context"
"encoding/json"
"reflect"
"testing"
@@ -11,8 +12,19 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec"
)
-func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) {
- return spec.NewUserID(senderID, true)
+type FakeRoomserverAPI struct{}
+
+func (f *FakeRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ if senderID == "" {
+ return nil, nil
+ }
+
+ return spec.NewUserID(string(senderID), true)
+}
+
+func (f *FakeRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
+ sender := spec.SenderID(userID.String())
+ return &sender, nil
}
func TestSyncTokens(t *testing.T) {
@@ -61,25 +73,18 @@ func TestNewInviteResponse(t *testing.T) {
t.Fatal(err)
}
- sender, err := spec.NewUserID("@neilalexander:matrix.org", true)
+ rsAPI := FakeRoomserverAPI{}
+ res, err := NewInviteResponse(context.Background(), &rsAPI, &types.HeaderedEvent{PDU: ev}, synctypes.FormatSync)
if err != nil {
t.Fatal(err)
}
- skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true)
- if err != nil {
- t.Fatal(err)
- }
- skString := skUserID.String()
- sk := &skString
-
- res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk, synctypes.FormatSync)
j, err := json.Marshal(res)
if err != nil {
t.Fatal(err)
}
if string(j) != expected {
- t.Fatalf("Invite response didn't contain correct info")
+ t.Fatalf("Invite response didn't contain correct info, \nexpected: %s \ngot: %s", expected, string(j))
}
}