diff options
author | devonh <devon.dmytro@gmail.com> | 2023-06-07 17:14:35 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-07 17:14:35 +0000 |
commit | 8ea1a11105ea7e66aa459537bcbef0de606147cd (patch) | |
tree | 62a539b761d078978226752c8d3ef23f1a80e677 /federationapi | |
parent | 7a1fd7f512ce06a472a2051ee63eae4a270eb71a (diff) |
Use SenderID Type (#3105)
Diffstat (limited to 'federationapi')
-rw-r--r-- | federationapi/federationapi_test.go | 13 | ||||
-rw-r--r-- | federationapi/internal/perform.go | 28 | ||||
-rw-r--r-- | federationapi/routing/invite.go | 4 | ||||
-rw-r--r-- | federationapi/routing/join.go | 32 | ||||
-rw-r--r-- | federationapi/routing/leave.go | 32 | ||||
-rw-r--r-- | federationapi/routing/threepid.go | 14 |
6 files changed, 80 insertions, 43 deletions
diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index a97bcdea..17390843 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -36,8 +36,12 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } -func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + +func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { + return spec.SenderID(userID.String()), nil } // PerformJoin will call this function @@ -115,12 +119,13 @@ func (f *fedClient) MakeJoin(ctx context.Context, origin, s spec.ServerName, roo defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { if r.ID == roomID { + senderIDString := userID res.RoomVersion = r.Version res.JoinEvent = gomatrixserverlib.ProtoEvent{ - Sender: userID, + SenderID: senderIDString, RoomID: roomID, Type: "m.room.member", - StateKey: &userID, + StateKey: &senderIDString, Content: spec.RawJSON([]byte(`{"membership":"join"}`)), PrevEvents: r.ForwardExtremities(), } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 2d59d0f9..485b79a0 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -154,9 +154,14 @@ func (r *FederationInternalAPI) performJoinUsingServer( if err != nil { return err } + senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, roomID, *user) + if err != nil { + return err + } joinInput := gomatrixserverlib.PerformJoinInput{ UserID: user, + SenderID: senderID, RoomID: room, ServerName: serverName, Content: content, @@ -164,10 +169,10 @@ func (r *FederationInternalAPI) performJoinUsingServer( PrivateKey: r.cfg.Matrix.PrivateKey, KeyID: r.cfg.Matrix.KeyID, KeyRing: r.keyRing, - EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID, senderID string) (*spec.UserID, error) { + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, } @@ -363,7 +368,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - userIDProvider := func(roomID, senderID string) (*spec.UserID, error) { + userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) } authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( @@ -414,7 +419,7 @@ func (r *FederationInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) (err error) { - _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID) + userID, err := spec.NewUserID(request.UserID, true) if err != nil { return err } @@ -433,7 +438,7 @@ func (r *FederationInternalAPI) PerformLeave( // request. respMakeLeave, err := r.federation.MakeLeave( ctx, - origin, + userID.Domain(), serverName, request.RoomID, request.UserID, @@ -454,9 +459,14 @@ func (r *FederationInternalAPI) PerformLeave( // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" + senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID) + if err != nil { + return err + } + senderIDString := string(senderID) respMakeLeave.LeaveEvent.Type = spec.MRoomMember - respMakeLeave.LeaveEvent.Sender = request.UserID - respMakeLeave.LeaveEvent.StateKey = &request.UserID + respMakeLeave.LeaveEvent.SenderID = senderIDString + respMakeLeave.LeaveEvent.StateKey = &senderIDString respMakeLeave.LeaveEvent.RoomID = request.RoomID respMakeLeave.LeaveEvent.Redacts = "" leaveEB := verImpl.NewEventBuilderFromProtoEvent(&respMakeLeave.LeaveEvent) @@ -478,7 +488,7 @@ func (r *FederationInternalAPI) PerformLeave( // Build the leave event. event, err := leaveEB.Build( time.Now(), - origin, + userID.Domain(), r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, ) @@ -490,7 +500,7 @@ func (r *FederationInternalAPI) PerformLeave( // Try to perform a send_leave using the newly built event. err = r.federation.SendLeave( ctx, - origin, + userID.Domain(), serverName, event, ) diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index d792335b..5b15f810 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -95,7 +95,7 @@ func InviteV2( StateQuerier: rsAPI.StateQuerier(), InviteEvent: inviteReq.Event(), StrippedState: inviteReq.InviteRoomState(), - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } @@ -188,7 +188,7 @@ func InviteV1( StateQuerier: rsAPI.StateQuerier(), InviteEvent: event, StrippedState: strippedState, - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 9da05918..d1480192 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -55,7 +55,7 @@ func MakeJoin( RoomID: roomID.String(), } res := api.QueryServerJoinedToRoomResponse{} - if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { + if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -64,26 +64,26 @@ func MakeJoin( } createJoinTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { - identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) + identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination()) + if signErr != nil { + util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination()) return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) } queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: roomVersion, } - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) - switch e := err.(type) { + event, signErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) + switch e := signErr.(type) { case nil: case eventutil.ErrRoomNoExists: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.NotFound("Room does not exist") case gomatrixserverlib.BadJSONError: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.BadJSON(e.Error()) default: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(signErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.InternalServerError{} } @@ -98,9 +98,19 @@ func MakeJoin( Roomserver: rsAPI, } + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + input := gomatrixserverlib.HandleMakeJoinInput{ Context: httpReq.Context(), UserID: userID, + SenderID: senderID, RoomID: roomID, RoomVersion: roomVersion, RemoteVersions: remoteVersions, @@ -108,7 +118,7 @@ func MakeJoin( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, RoomQuerier: &roomQuerier, - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, BuildEventTemplate: createJoinTemplate, @@ -205,7 +215,7 @@ func SendJoin( PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 30e99c4f..716276be 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -50,7 +50,7 @@ func MakeLeave( RoomID: roomID.String(), } res := api.QueryServerJoinedToRoomResponse{} - if err := rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { + if err = rsAPI.QueryServerJoinedToRoom(httpReq.Context(), &req, &res); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryServerJoinedToRoom failed") return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -59,24 +59,24 @@ func MakeLeave( } createLeaveTemplate := func(proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, []gomatrixserverlib.PDU, error) { - identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Errorf("obtaining signing identity for %s failed", request.Destination()) + identity, signErr := cfg.Matrix.SigningIdentityFor(request.Destination()) + if signErr != nil { + util.GetLogger(httpReq.Context()).WithError(signErr).Errorf("obtaining signing identity for %s failed", request.Destination()) return nil, nil, spec.NotFound(fmt.Sprintf("Server name %q does not exist", request.Destination())) } queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) - switch e := err.(type) { + event, buildErr := eventutil.QueryAndBuildEvent(httpReq.Context(), proto, identity, time.Now(), rsAPI, &queryRes) + switch e := buildErr.(type) { case nil: case eventutil.ErrRoomNoExists: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.NotFound("Room does not exist") case gomatrixserverlib.BadJSONError: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.BadJSON(e.Error()) default: - util.GetLogger(httpReq.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(httpReq.Context()).WithError(buildErr).Error("eventutil.BuildEvent failed") return nil, nil, spec.InternalServerError{} } @@ -87,15 +87,25 @@ func MakeLeave( return event, stateEvents, nil } + senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + input := gomatrixserverlib.HandleMakeLeaveInput{ UserID: userID, + SenderID: senderID, RoomID: roomID, RoomVersion: roomVersion, RequestOrigin: request.Origin(), LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, BuildEventTemplate: createLeaveTemplate, - UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) }, } @@ -216,7 +226,7 @@ func SendLeave( JSON: spec.BadJSON("No state key was provided in the leave event."), } } - if !event.StateKeyEquals(event.SenderID()) { + if !event.StateKeyEquals(string(event.SenderID())) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("Event state key must match the event sender."), diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 76a2f3d5..360802de 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -140,22 +140,24 @@ func ExchangeThirdPartyInvite( } } - _, senderDomain, err := cfg.Matrix.SplitLocalID('@', proto.Sender) - if err != nil { + userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID)) + if err != nil || userID == nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid sender ID: " + err.Error()), + JSON: spec.BadJSON("Invalid sender ID"), } } + senderDomain := userID.Domain() // Check that the state key is correct. - _, targetDomain, err := gomatrixserverlib.SplitID('@', *proto.StateKey) - if err != nil { + targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey)) + if err != nil || targetUserID == nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("The event's state key isn't a Matrix user ID"), } } + targetDomain := targetUserID.Domain() // Check that the target user is from the requesting homeserver. if targetDomain != request.Origin() { @@ -271,7 +273,7 @@ func createInviteFrom3PIDInvite( // Build the event proto := &gomatrixserverlib.ProtoEvent{ Type: "m.room.member", - Sender: inv.Sender, + SenderID: inv.Sender, RoomID: inv.RoomID, StateKey: &inv.MXID, } |