diff options
author | devonh <devon.dmytro@gmail.com> | 2023-05-19 16:27:01 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-19 16:27:01 +0000 |
commit | 2eae8dc489f056df5aec0ee4ace1b8ba8260e18e (patch) | |
tree | b820c15eec6d80812661f5f0df6d0e0db7dfd630 /federationapi/routing | |
parent | 027a9b8ce0a7e2d577e2c41f9de7a6fe42ace655 (diff) |
Move SendJoin logic to GMSL (#3084)
Moves the core matrix logic for handling the send_join endpoint over to
gmsl.
Diffstat (limited to 'federationapi/routing')
-rw-r--r-- | federationapi/routing/join.go | 321 | ||||
-rw-r--r-- | federationapi/routing/routing.go | 26 |
2 files changed, 130 insertions, 217 deletions
diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index cc22690a..cbdeca51 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -16,7 +16,6 @@ package routing import ( "context" - "encoding/json" "fmt" "net/http" "sort" @@ -160,45 +159,43 @@ func MakeJoin( BuildEventTemplate: createJoinTemplate, } response, internalErr := gomatrixserverlib.HandleMakeJoin(input) - if internalErr != nil { - switch e := internalErr.(type) { - case nil: - case spec.InternalServerError: - util.GetLogger(httpReq.Context()).WithError(internalErr) - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - case spec.MatrixError: - util.GetLogger(httpReq.Context()).WithError(internalErr) - code := http.StatusInternalServerError - switch e.ErrCode { - case spec.ErrorForbidden: - code = http.StatusForbidden - case spec.ErrorNotFound: - code = http.StatusNotFound - case spec.ErrorUnableToAuthoriseJoin: - code = http.StatusBadRequest - case spec.ErrorBadJSON: - code = http.StatusBadRequest - } + switch e := internalErr.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(httpReq.Context()).WithError(internalErr) + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + case spec.MatrixError: + util.GetLogger(httpReq.Context()).WithError(internalErr) + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorNotFound: + code = http.StatusNotFound + case spec.ErrorUnableToAuthoriseJoin: + fallthrough // http.StatusBadRequest + case spec.ErrorBadJSON: + code = http.StatusBadRequest + } - return util.JSONResponse{ - Code: code, - JSON: e, - } - case spec.IncompatibleRoomVersionError: - util.GetLogger(httpReq.Context()).WithError(internalErr) - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: e, - } - default: - util.GetLogger(httpReq.Context()).WithError(internalErr) - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.Unknown("unknown error"), - } + return util.JSONResponse{ + Code: code, + JSON: e, + } + case spec.IncompatibleRoomVersionError: + util.GetLogger(httpReq.Context()).WithError(internalErr) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: e, + } + default: + util.GetLogger(httpReq.Context()).WithError(internalErr) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.Unknown("unknown error"), } } @@ -219,6 +216,25 @@ func MakeJoin( } } +type MembershipQuerier struct { + roomserver api.FederationRoomserverAPI +} + +func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { + req := api.QueryMembershipForUserRequest{ + RoomID: roomID.String(), + UserID: userID.String(), + } + res := api.QueryMembershipForUserResponse{} + err := mq.roomserver.QueryMembershipForUser(ctx, &req, &res) + + membership := "" + if err == nil { + membership = res.Membership + } + return membership, err +} + // SendJoin implements the /send_join API // The make-join send-join dance makes much more sense as a single // flow so the cyclomatic complexity is high: @@ -229,9 +245,10 @@ func SendJoin( cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, keys gomatrixserverlib.JSONVerifier, - roomID, eventID string, + roomID spec.RoomID, + eventID string, ) util.JSONResponse { - roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID) + roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String()) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed") return util.JSONResponse{ @@ -239,132 +256,71 @@ func SendJoin( JSON: spec.InternalServerError{}, } } - verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion) - if err != nil { - return util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.UnsupportedRoomVersion( - fmt.Sprintf("QueryRoomVersionForRoom returned unknown room version: %s", roomVersion), - ), - } - } - event, err := verImpl.NewEventFromUntrustedJSON(request.Content()) - if err != nil { + input := gomatrixserverlib.HandleSendJoinInput{ + Context: httpReq.Context(), + RoomID: roomID, + EventID: eventID, + JoinEvent: request.Content(), + RoomVersion: roomVersion, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + Verifier: keys, + MembershipQuerier: &MembershipQuerier{roomserver: rsAPI}, + } + response, joinErr := gomatrixserverlib.HandleSendJoin(input) + switch e := joinErr.(type) { + case nil: + case spec.InternalServerError: + util.GetLogger(httpReq.Context()).WithError(joinErr) return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("The request body could not be decoded into valid JSON: " + err.Error()), - } - } - - // Check that a state key is provided. - if event.StateKey() == nil || event.StateKeyEquals("") { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("No state key was provided in the join event."), - } - } - if !event.StateKeyEquals(event.Sender()) { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("Event state key must match the event sender."), - } - } - - // Check that the sender belongs to the server that is sending us - // the request. By this point we've already asserted that the sender - // and the state key are equal so we don't need to check both. - var serverName spec.ServerName - if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("The sender of the join is invalid"), + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, } - } else if serverName != request.Origin() { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("The sender does not match the server that originated the request"), + case spec.MatrixError: + util.GetLogger(httpReq.Context()).WithError(joinErr) + code := http.StatusInternalServerError + switch e.ErrCode { + case spec.ErrorForbidden: + code = http.StatusForbidden + case spec.ErrorNotFound: + code = http.StatusNotFound + case spec.ErrorUnsupportedRoomVersion: + code = http.StatusInternalServerError + case spec.ErrorBadJSON: + code = http.StatusBadRequest } - } - // Check that the room ID is correct. - if event.RoomID() != roomID { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON( - fmt.Sprintf( - "The room ID in the request path (%q) must match the room ID in the join event JSON (%q)", - roomID, event.RoomID(), - ), - ), + Code: code, + JSON: e, } - } - - // Check that the event ID is correct. - if event.EventID() != eventID { + default: + util.GetLogger(httpReq.Context()).WithError(joinErr) return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON( - fmt.Sprintf( - "The event ID in the request path (%q) must match the event ID in the join event JSON (%q)", - eventID, event.EventID(), - ), - ), + JSON: spec.Unknown("unknown error"), } } - // Check that this is in fact a join event - membership, err := event.Membership() - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("missing content.membership key"), - } - } - if membership != spec.Join { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("membership must be 'join'"), - } - } - - // Check that the event is signed by the server sending the request. - redacted, err := verImpl.RedactEventJSON(event.JSON()) - if err != nil { - logrus.WithError(err).Errorf("XXX: join.go") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON("The event JSON could not be redacted"), - } - } - verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, - Message: redacted, - AtTS: event.OriginServerTS(), - StrictValidityChecking: true, - }} - verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests) - if err != nil { - util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed") + if response == nil { + util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeJoin returned invalid response") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}, } - } - if verifyResults[0].Error != nil { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("Signature check failed: " + verifyResults[0].Error.Error()), - } + } // Fetch the state and auth chain. We do this before we send the events // on, in case this fails. var stateAndAuthChainResponse api.QueryStateAndAuthChainResponse err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{ - PrevEventIDs: event.PrevEventIDs(), - AuthEventIDs: event.AuthEventIDs(), - RoomID: roomID, + PrevEventIDs: response.JoinEvent.PrevEventIDs(), + AuthEventIDs: response.JoinEvent.AuthEventIDs(), + RoomID: roomID.String(), ResolveState: true, }, &stateAndAuthChainResponse) if err != nil { @@ -388,84 +344,27 @@ func SendJoin( } } - // Check if the user is already in the room. If they're already in then - // there isn't much point in sending another join event into the room. - // Also check to see if they are banned: if they are then we reject them. - alreadyJoined := false - isBanned := false - for _, se := range stateAndAuthChainResponse.StateEvents { - if !se.StateKeyEquals(*event.StateKey()) { - continue - } - if membership, merr := se.Membership(); merr == nil { - alreadyJoined = (membership == spec.Join) - isBanned = (membership == spec.Ban) - break - } - } - - if isBanned { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("user is banned"), - } - } - - // If the membership content contains a user ID for a server that is not - // ours then we should kick it back. - var memberContent gomatrixserverlib.MemberContent - if err := json.Unmarshal(event.Content(), &memberContent); err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON(err.Error()), - } - } - if memberContent.AuthorisedVia != "" { - _, domain, err := gomatrixserverlib.SplitID('@', memberContent.AuthorisedVia) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q is invalid.", memberContent.AuthorisedVia)), - } - } - if domain != cfg.Matrix.ServerName { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q does not belong to this server.", memberContent.AuthorisedVia)), - } - } - } - - // Sign the membership event. This is required for restricted joins to work - // in the case that the authorised via user is one of our own users. It also - // doesn't hurt to do it even if it isn't a restricted join. - signed := event.Sign( - string(cfg.Matrix.ServerName), - cfg.Matrix.KeyID, - cfg.Matrix.PrivateKey, - ) - // Send the events to the room server. // We are responsible for notifying other servers that the user has joined // the room, so set SendAsServer to cfg.Matrix.ServerName - if !alreadyJoined { - var response api.InputRoomEventsResponse + if !response.AlreadyJoined { + var rsResponse api.InputRoomEventsResponse rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, - Event: &types.HeaderedEvent{PDU: signed}, + Event: &types.HeaderedEvent{PDU: response.JoinEvent}, SendAsServer: string(cfg.Matrix.ServerName), TransactionID: nil, }, }, - }, &response) - if response.ErrMsg != "" { - util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed") - if response.NotAllowed { + }, &rsResponse) + if rsResponse.ErrMsg != "" { + util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, rsResponse.ErrMsg).Error("SendEvents failed") + if rsResponse.NotAllowed { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.Forbidden(response.ErrMsg), + JSON: spec.Forbidden(rsResponse.ErrMsg), } } return util.JSONResponse{ @@ -488,7 +387,7 @@ func SendJoin( StateEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents), AuthEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents), Origin: cfg.Matrix.ServerName, - Event: signed.JSON(), + Event: response.JoinEvent.JSON(), }, } } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 44faad91..7be0857a 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -331,14 +331,14 @@ func Setup( if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid UserID"), + JSON: spec.InvalidParam("Invalid UserID"), } } roomID, err := spec.NewRoomID(vars["roomID"]) if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: spec.BadJSON("Invalid RoomID"), + JSON: spec.InvalidParam("Invalid RoomID"), } } @@ -358,10 +358,17 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] eventID := vars["eventID"] + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + res := SendJoin( - httpReq, request, cfg, rsAPI, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, *roomID, eventID, ) // not all responses get wrapped in [code, body] var body interface{} @@ -390,10 +397,17 @@ func Setup( JSON: spec.Forbidden("Forbidden by server ACLs"), } } - roomID := vars["roomID"] eventID := vars["eventID"] + roomID, err := spec.NewRoomID(vars["roomID"]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("Invalid RoomID"), + } + } + return SendJoin( - httpReq, request, cfg, rsAPI, keys, roomID, eventID, + httpReq, request, cfg, rsAPI, keys, *roomID, eventID, ) }, )).Methods(http.MethodPut) |