aboutsummaryrefslogtreecommitdiff
path: root/federationapi/routing
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-05-19 16:27:01 +0000
committerGitHub <noreply@github.com>2023-05-19 16:27:01 +0000
commit2eae8dc489f056df5aec0ee4ace1b8ba8260e18e (patch)
treeb820c15eec6d80812661f5f0df6d0e0db7dfd630 /federationapi/routing
parent027a9b8ce0a7e2d577e2c41f9de7a6fe42ace655 (diff)
Move SendJoin logic to GMSL (#3084)
Moves the core matrix logic for handling the send_join endpoint over to gmsl.
Diffstat (limited to 'federationapi/routing')
-rw-r--r--federationapi/routing/join.go321
-rw-r--r--federationapi/routing/routing.go26
2 files changed, 130 insertions, 217 deletions
diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go
index cc22690a..cbdeca51 100644
--- a/federationapi/routing/join.go
+++ b/federationapi/routing/join.go
@@ -16,7 +16,6 @@ package routing
import (
"context"
- "encoding/json"
"fmt"
"net/http"
"sort"
@@ -160,45 +159,43 @@ func MakeJoin(
BuildEventTemplate: createJoinTemplate,
}
response, internalErr := gomatrixserverlib.HandleMakeJoin(input)
- if internalErr != nil {
- switch e := internalErr.(type) {
- case nil:
- case spec.InternalServerError:
- util.GetLogger(httpReq.Context()).WithError(internalErr)
- return util.JSONResponse{
- Code: http.StatusInternalServerError,
- JSON: spec.InternalServerError{},
- }
- case spec.MatrixError:
- util.GetLogger(httpReq.Context()).WithError(internalErr)
- code := http.StatusInternalServerError
- switch e.ErrCode {
- case spec.ErrorForbidden:
- code = http.StatusForbidden
- case spec.ErrorNotFound:
- code = http.StatusNotFound
- case spec.ErrorUnableToAuthoriseJoin:
- code = http.StatusBadRequest
- case spec.ErrorBadJSON:
- code = http.StatusBadRequest
- }
+ switch e := internalErr.(type) {
+ case nil:
+ case spec.InternalServerError:
+ util.GetLogger(httpReq.Context()).WithError(internalErr)
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ case spec.MatrixError:
+ util.GetLogger(httpReq.Context()).WithError(internalErr)
+ code := http.StatusInternalServerError
+ switch e.ErrCode {
+ case spec.ErrorForbidden:
+ code = http.StatusForbidden
+ case spec.ErrorNotFound:
+ code = http.StatusNotFound
+ case spec.ErrorUnableToAuthoriseJoin:
+ fallthrough // http.StatusBadRequest
+ case spec.ErrorBadJSON:
+ code = http.StatusBadRequest
+ }
- return util.JSONResponse{
- Code: code,
- JSON: e,
- }
- case spec.IncompatibleRoomVersionError:
- util.GetLogger(httpReq.Context()).WithError(internalErr)
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: e,
- }
- default:
- util.GetLogger(httpReq.Context()).WithError(internalErr)
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.Unknown("unknown error"),
- }
+ return util.JSONResponse{
+ Code: code,
+ JSON: e,
+ }
+ case spec.IncompatibleRoomVersionError:
+ util.GetLogger(httpReq.Context()).WithError(internalErr)
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: e,
+ }
+ default:
+ util.GetLogger(httpReq.Context()).WithError(internalErr)
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.Unknown("unknown error"),
}
}
@@ -219,6 +216,25 @@ func MakeJoin(
}
}
+type MembershipQuerier struct {
+ roomserver api.FederationRoomserverAPI
+}
+
+func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) {
+ req := api.QueryMembershipForUserRequest{
+ RoomID: roomID.String(),
+ UserID: userID.String(),
+ }
+ res := api.QueryMembershipForUserResponse{}
+ err := mq.roomserver.QueryMembershipForUser(ctx, &req, &res)
+
+ membership := ""
+ if err == nil {
+ membership = res.Membership
+ }
+ return membership, err
+}
+
// SendJoin implements the /send_join API
// The make-join send-join dance makes much more sense as a single
// flow so the cyclomatic complexity is high:
@@ -229,9 +245,10 @@ func SendJoin(
cfg *config.FederationAPI,
rsAPI api.FederationRoomserverAPI,
keys gomatrixserverlib.JSONVerifier,
- roomID, eventID string,
+ roomID spec.RoomID,
+ eventID string,
) util.JSONResponse {
- roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID)
+ roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String())
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed")
return util.JSONResponse{
@@ -239,132 +256,71 @@ func SendJoin(
JSON: spec.InternalServerError{},
}
}
- verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion)
- if err != nil {
- return util.JSONResponse{
- Code: http.StatusInternalServerError,
- JSON: spec.UnsupportedRoomVersion(
- fmt.Sprintf("QueryRoomVersionForRoom returned unknown room version: %s", roomVersion),
- ),
- }
- }
- event, err := verImpl.NewEventFromUntrustedJSON(request.Content())
- if err != nil {
+ input := gomatrixserverlib.HandleSendJoinInput{
+ Context: httpReq.Context(),
+ RoomID: roomID,
+ EventID: eventID,
+ JoinEvent: request.Content(),
+ RoomVersion: roomVersion,
+ RequestOrigin: request.Origin(),
+ LocalServerName: cfg.Matrix.ServerName,
+ KeyID: cfg.Matrix.KeyID,
+ PrivateKey: cfg.Matrix.PrivateKey,
+ Verifier: keys,
+ MembershipQuerier: &MembershipQuerier{roomserver: rsAPI},
+ }
+ response, joinErr := gomatrixserverlib.HandleSendJoin(input)
+ switch e := joinErr.(type) {
+ case nil:
+ case spec.InternalServerError:
+ util.GetLogger(httpReq.Context()).WithError(joinErr)
return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON("The request body could not be decoded into valid JSON: " + err.Error()),
- }
- }
-
- // Check that a state key is provided.
- if event.StateKey() == nil || event.StateKeyEquals("") {
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON("No state key was provided in the join event."),
- }
- }
- if !event.StateKeyEquals(event.Sender()) {
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON("Event state key must match the event sender."),
- }
- }
-
- // Check that the sender belongs to the server that is sending us
- // the request. By this point we've already asserted that the sender
- // and the state key are equal so we don't need to check both.
- var serverName spec.ServerName
- if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("The sender of the join is invalid"),
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
}
- } else if serverName != request.Origin() {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("The sender does not match the server that originated the request"),
+ case spec.MatrixError:
+ util.GetLogger(httpReq.Context()).WithError(joinErr)
+ code := http.StatusInternalServerError
+ switch e.ErrCode {
+ case spec.ErrorForbidden:
+ code = http.StatusForbidden
+ case spec.ErrorNotFound:
+ code = http.StatusNotFound
+ case spec.ErrorUnsupportedRoomVersion:
+ code = http.StatusInternalServerError
+ case spec.ErrorBadJSON:
+ code = http.StatusBadRequest
}
- }
- // Check that the room ID is correct.
- if event.RoomID() != roomID {
return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON(
- fmt.Sprintf(
- "The room ID in the request path (%q) must match the room ID in the join event JSON (%q)",
- roomID, event.RoomID(),
- ),
- ),
+ Code: code,
+ JSON: e,
}
- }
-
- // Check that the event ID is correct.
- if event.EventID() != eventID {
+ default:
+ util.GetLogger(httpReq.Context()).WithError(joinErr)
return util.JSONResponse{
Code: http.StatusBadRequest,
- JSON: spec.BadJSON(
- fmt.Sprintf(
- "The event ID in the request path (%q) must match the event ID in the join event JSON (%q)",
- eventID, event.EventID(),
- ),
- ),
+ JSON: spec.Unknown("unknown error"),
}
}
- // Check that this is in fact a join event
- membership, err := event.Membership()
- if err != nil {
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON("missing content.membership key"),
- }
- }
- if membership != spec.Join {
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON("membership must be 'join'"),
- }
- }
-
- // Check that the event is signed by the server sending the request.
- redacted, err := verImpl.RedactEventJSON(event.JSON())
- if err != nil {
- logrus.WithError(err).Errorf("XXX: join.go")
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON("The event JSON could not be redacted"),
- }
- }
- verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
- ServerName: serverName,
- Message: redacted,
- AtTS: event.OriginServerTS(),
- StrictValidityChecking: true,
- }}
- verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
- if err != nil {
- util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed")
+ if response == nil {
+ util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeJoin returned invalid response")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
- }
- if verifyResults[0].Error != nil {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("Signature check failed: " + verifyResults[0].Error.Error()),
- }
+
}
// Fetch the state and auth chain. We do this before we send the events
// on, in case this fails.
var stateAndAuthChainResponse api.QueryStateAndAuthChainResponse
err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{
- PrevEventIDs: event.PrevEventIDs(),
- AuthEventIDs: event.AuthEventIDs(),
- RoomID: roomID,
+ PrevEventIDs: response.JoinEvent.PrevEventIDs(),
+ AuthEventIDs: response.JoinEvent.AuthEventIDs(),
+ RoomID: roomID.String(),
ResolveState: true,
}, &stateAndAuthChainResponse)
if err != nil {
@@ -388,84 +344,27 @@ func SendJoin(
}
}
- // Check if the user is already in the room. If they're already in then
- // there isn't much point in sending another join event into the room.
- // Also check to see if they are banned: if they are then we reject them.
- alreadyJoined := false
- isBanned := false
- for _, se := range stateAndAuthChainResponse.StateEvents {
- if !se.StateKeyEquals(*event.StateKey()) {
- continue
- }
- if membership, merr := se.Membership(); merr == nil {
- alreadyJoined = (membership == spec.Join)
- isBanned = (membership == spec.Ban)
- break
- }
- }
-
- if isBanned {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("user is banned"),
- }
- }
-
- // If the membership content contains a user ID for a server that is not
- // ours then we should kick it back.
- var memberContent gomatrixserverlib.MemberContent
- if err := json.Unmarshal(event.Content(), &memberContent); err != nil {
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON(err.Error()),
- }
- }
- if memberContent.AuthorisedVia != "" {
- _, domain, err := gomatrixserverlib.SplitID('@', memberContent.AuthorisedVia)
- if err != nil {
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q is invalid.", memberContent.AuthorisedVia)),
- }
- }
- if domain != cfg.Matrix.ServerName {
- return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q does not belong to this server.", memberContent.AuthorisedVia)),
- }
- }
- }
-
- // Sign the membership event. This is required for restricted joins to work
- // in the case that the authorised via user is one of our own users. It also
- // doesn't hurt to do it even if it isn't a restricted join.
- signed := event.Sign(
- string(cfg.Matrix.ServerName),
- cfg.Matrix.KeyID,
- cfg.Matrix.PrivateKey,
- )
-
// Send the events to the room server.
// We are responsible for notifying other servers that the user has joined
// the room, so set SendAsServer to cfg.Matrix.ServerName
- if !alreadyJoined {
- var response api.InputRoomEventsResponse
+ if !response.AlreadyJoined {
+ var rsResponse api.InputRoomEventsResponse
rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
- Event: &types.HeaderedEvent{PDU: signed},
+ Event: &types.HeaderedEvent{PDU: response.JoinEvent},
SendAsServer: string(cfg.Matrix.ServerName),
TransactionID: nil,
},
},
- }, &response)
- if response.ErrMsg != "" {
- util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed")
- if response.NotAllowed {
+ }, &rsResponse)
+ if rsResponse.ErrMsg != "" {
+ util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, rsResponse.ErrMsg).Error("SendEvents failed")
+ if rsResponse.NotAllowed {
return util.JSONResponse{
Code: http.StatusBadRequest,
- JSON: spec.Forbidden(response.ErrMsg),
+ JSON: spec.Forbidden(rsResponse.ErrMsg),
}
}
return util.JSONResponse{
@@ -488,7 +387,7 @@ func SendJoin(
StateEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents),
AuthEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents),
Origin: cfg.Matrix.ServerName,
- Event: signed.JSON(),
+ Event: response.JoinEvent.JSON(),
},
}
}
diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go
index 44faad91..7be0857a 100644
--- a/federationapi/routing/routing.go
+++ b/federationapi/routing/routing.go
@@ -331,14 +331,14 @@ func Setup(
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
- JSON: spec.BadJSON("Invalid UserID"),
+ JSON: spec.InvalidParam("Invalid UserID"),
}
}
roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
- JSON: spec.BadJSON("Invalid RoomID"),
+ JSON: spec.InvalidParam("Invalid RoomID"),
}
}
@@ -358,10 +358,17 @@ func Setup(
JSON: spec.Forbidden("Forbidden by server ACLs"),
}
}
- roomID := vars["roomID"]
eventID := vars["eventID"]
+ roomID, err := spec.NewRoomID(vars["roomID"])
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("Invalid RoomID"),
+ }
+ }
+
res := SendJoin(
- httpReq, request, cfg, rsAPI, keys, roomID, eventID,
+ httpReq, request, cfg, rsAPI, keys, *roomID, eventID,
)
// not all responses get wrapped in [code, body]
var body interface{}
@@ -390,10 +397,17 @@ func Setup(
JSON: spec.Forbidden("Forbidden by server ACLs"),
}
}
- roomID := vars["roomID"]
eventID := vars["eventID"]
+ roomID, err := spec.NewRoomID(vars["roomID"])
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("Invalid RoomID"),
+ }
+ }
+
return SendJoin(
- httpReq, request, cfg, rsAPI, keys, roomID, eventID,
+ httpReq, request, cfg, rsAPI, keys, *roomID, eventID,
)
},
)).Methods(http.MethodPut)