diff options
Diffstat (limited to 'syncapi/routing/relations.go')
-rw-r--r-- | syncapi/routing/relations.go | 33 |
1 files changed, 22 insertions, 11 deletions
diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 17933b2f..e3d1069a 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -43,9 +43,25 @@ func Relations( req *http.Request, device *userapi.Device, syncDB storage.Database, rsAPI api.SyncRoomserverAPI, - roomID, eventID, relType, eventType string, + rawRoomID, eventID, relType, eventType string, ) util.JSONResponse { - var err error + roomID, err := spec.NewRoomID(rawRoomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("invalid room ID"), + } + } + + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("internal server error"), + } + } + var from, to types.StreamPosition var limit int dir := req.URL.Query().Get("dir") @@ -93,7 +109,7 @@ func Relations( } var events []types.StreamEvent events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor( - req.Context(), roomID, eventID, relType, eventType, from, to, dir == "b", limit, + req.Context(), roomID.String(), eventID, relType, eventType, from, to, dir == "b", limit, ) if err != nil { return util.ErrorResponse(err) @@ -105,12 +121,7 @@ func Relations( } // Apply history visibility to the result events. - filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, device.UserID, "relations") - if err != nil { - return util.ErrorResponse(err) - } - - validRoomID, err := spec.NewRoomID(roomID) + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, *userID, "relations") if err != nil { return util.ErrorResponse(err) } @@ -120,14 +131,14 @@ func Relations( res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) + userID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, event.SenderID()) if err == nil && userID != nil { sender = *userID } sk := event.StateKey() if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) + skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, spec.SenderID(*event.StateKey())) if err == nil && skUserID != nil { skString := skUserID.String() sk = &skString |