aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--clientapi/clientapi.go3
-rw-r--r--clientapi/routing/joinroom.go10
-rw-r--r--clientapi/routing/joinroom_test.go158
-rw-r--r--roomserver/api/perform.go1
-rw-r--r--roomserver/internal/api.go20
-rw-r--r--roomserver/internal/input/input.go4
-rw-r--r--roomserver/internal/input/input_events.go105
-rw-r--r--roomserver/internal/perform/perform_join.go23
-rw-r--r--roomserver/roomserver_test.go139
-rw-r--r--setup/config/config_global.go2
-rw-r--r--setup/config/config_test.go54
-rw-r--r--sytest-blacklist5
-rw-r--r--sytest-whitelist5
-rw-r--r--test/room.go22
-rw-r--r--userapi/api/api.go10
-rw-r--r--userapi/api/api_trace.go6
-rw-r--r--userapi/internal/api.go5
-rw-r--r--userapi/inthttp/client.go12
-rw-r--r--userapi/inthttp/server.go5
-rw-r--r--userapi/userapi_test.go61
20 files changed, 606 insertions, 44 deletions
diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go
index 62ffa615..2d17e092 100644
--- a/clientapi/clientapi.go
+++ b/clientapi/clientapi.go
@@ -15,6 +15,8 @@
package clientapi
import (
+ "github.com/matrix-org/gomatrixserverlib"
+
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/producers"
@@ -26,7 +28,6 @@ import (
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrixserverlib"
)
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go
index c50e552b..e371d921 100644
--- a/clientapi/routing/joinroom.go
+++ b/clientapi/routing/joinroom.go
@@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias(
joinReq := roomserverAPI.PerformJoinRequest{
RoomIDOrAlias: roomIDOrAlias,
UserID: device.UserID,
+ IsGuest: device.AccountType == api.AccountTypeGuest,
Content: map[string]interface{}{},
}
joinRes := roomserverAPI.PerformJoinResponse{}
@@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias(
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
done <- jsonerror.InternalAPIError(req.Context(), err)
} else if joinRes.Error != nil {
- done <- joinRes.Error.JSONResponse()
+ if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest {
+ done <- util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
+ }
+ } else {
+ done <- joinRes.Error.JSONResponse()
+ }
} else {
done <- util.JSONResponse{
Code: http.StatusOK,
diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go
new file mode 100644
index 00000000..9e8208e6
--- /dev/null
+++ b/clientapi/routing/joinroom_test.go
@@ -0,0 +1,158 @@
+package routing
+
+import (
+ "bytes"
+ "context"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+
+ "github.com/matrix-org/dendrite/appservice"
+ "github.com/matrix-org/dendrite/keyserver"
+ "github.com/matrix-org/dendrite/roomserver"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/dendrite/userapi"
+ uapi "github.com/matrix-org/dendrite/userapi/api"
+)
+
+func TestJoinRoomByIDOrAlias(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest))
+
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ base, baseClose := testrig.CreateBaseDendrite(t, dbType)
+ defer baseClose()
+
+ rsAPI := roomserver.NewInternalAPI(base)
+ keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
+ userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
+ asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
+ rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc
+
+ // Create the users in the userapi
+ for _, u := range []*test.User{alice, bob, charlie} {
+ localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
+ userRes := &uapi.PerformAccountCreationResponse{}
+ if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
+ AccountType: u.AccountType,
+ Localpart: localpart,
+ ServerName: serverName,
+ Password: "someRandomPassword",
+ }, userRes); err != nil {
+ t.Errorf("failed to create account: %s", err)
+ }
+
+ }
+
+ aliceDev := &uapi.Device{UserID: alice.ID}
+ bobDev := &uapi.Device{UserID: bob.ID}
+ charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest}
+
+ // create a room with disabled guest access and invite Bob
+ resp := createRoom(ctx, createRoomRequest{
+ Name: "testing",
+ IsDirect: true,
+ Topic: "testing",
+ Visibility: "public",
+ Preset: presetPublicChat,
+ RoomAliasName: "alias",
+ Invite: []string{bob.ID},
+ GuestCanJoin: false,
+ }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
+ crResp, ok := resp.JSON.(createRoomResponse)
+ if !ok {
+ t.Fatalf("response is not a createRoomResponse: %+v", resp)
+ }
+
+ // create a room with guest access enabled and invite Charlie
+ resp = createRoom(ctx, createRoomRequest{
+ Name: "testing",
+ IsDirect: true,
+ Topic: "testing",
+ Visibility: "public",
+ Preset: presetPublicChat,
+ Invite: []string{charlie.ID},
+ GuestCanJoin: true,
+ }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
+ crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
+ if !ok {
+ t.Fatalf("response is not a createRoomResponse: %+v", resp)
+ }
+
+ // Dummy request
+ body := &bytes.Buffer{}
+ req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ testCases := []struct {
+ name string
+ device *uapi.Device
+ roomID string
+ wantHTTP200 bool
+ }{
+ {
+ name: "User can join successfully by alias",
+ device: bobDev,
+ roomID: crResp.RoomAlias,
+ wantHTTP200: true,
+ },
+ {
+ name: "User can join successfully by roomID",
+ device: bobDev,
+ roomID: crResp.RoomID,
+ wantHTTP200: true,
+ },
+ {
+ name: "join is forbidden if user is guest",
+ device: charlieDev,
+ roomID: crResp.RoomID,
+ },
+ {
+ name: "room does not exist",
+ device: aliceDev,
+ roomID: "!doesnotexist:test",
+ },
+ {
+ name: "user from different server",
+ device: &uapi.Device{UserID: "@wrong:server"},
+ roomID: crResp.RoomAlias,
+ },
+ {
+ name: "user doesn't exist locally",
+ device: &uapi.Device{UserID: "@doesnotexist:test"},
+ roomID: crResp.RoomAlias,
+ },
+ {
+ name: "invalid room ID",
+ device: aliceDev,
+ roomID: "invalidRoomID",
+ },
+ {
+ name: "roomAlias does not exist",
+ device: aliceDev,
+ roomID: "#doesnotexist:test",
+ },
+ {
+ name: "room with guest_access event",
+ device: charlieDev,
+ roomID: crRespWithGuestAccess.RoomID,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID)
+ if tc.wantHTTP200 && !joinResp.Is2xx() {
+ t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp)
+ }
+ })
+ }
+ })
+}
diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go
index e70e5ea9..e789b956 100644
--- a/roomserver/api/perform.go
+++ b/roomserver/api/perform.go
@@ -78,6 +78,7 @@ const (
type PerformJoinRequest struct {
RoomIDOrAlias string `json:"room_id_or_alias"`
UserID string `json:"user_id"`
+ IsGuest bool `json:"is_guest"`
Content map[string]interface{} `json:"content"`
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
Unsigned map[string]interface{} `json:"unsigned"`
diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go
index 1a362660..451b3769 100644
--- a/roomserver/internal/api.go
+++ b/roomserver/internal/api.go
@@ -4,6 +4,10 @@ import (
"context"
"github.com/getsentry/sentry-go"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ "github.com/sirupsen/logrus"
+
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching"
@@ -19,9 +23,6 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/nats-io/nats.go"
- "github.com/sirupsen/logrus"
)
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
@@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.fsAPI = fsAPI
r.KeyRing = keyRing
+ identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName)
+ if err != nil {
+ logrus.Panic(err)
+ }
+
r.Inputer = &input.Inputer{
Cfg: &r.Base.Cfg.RoomServer,
Base: r.Base,
@@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
JetStream: r.JetStream,
NATSClient: r.NATSClient,
Durable: nats.Durable(r.Durable),
- ServerName: r.Cfg.Matrix.ServerName,
+ ServerName: r.ServerName,
+ SigningIdentity: identity,
FSAPI: fsAPI,
KeyRing: keyRing,
ACLs: r.ServerACLs,
@@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Queryer: r.Queryer,
}
r.Peeker = &perform.Peeker{
- ServerName: r.Cfg.Matrix.ServerName,
+ ServerName: r.ServerName,
Cfg: r.Cfg,
DB: r.DB,
FSAPI: r.fsAPI,
@@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Inputer: r.Inputer,
}
r.Unpeeker = &perform.Unpeeker{
- ServerName: r.Cfg.Matrix.ServerName,
+ ServerName: r.ServerName,
Cfg: r.Cfg,
DB: r.DB,
FSAPI: r.fsAPI,
@@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
r.Leaver.UserAPI = userAPI
+ r.Inputer.UserAPI = userAPI
}
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go
index e965691c..94131103 100644
--- a/roomserver/internal/input/input.go
+++ b/roomserver/internal/input/input.go
@@ -23,6 +23,8 @@ import (
"sync"
"time"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+
"github.com/Arceliar/phony"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
@@ -79,6 +81,7 @@ type Inputer struct {
JetStream nats.JetStreamContext
Durable nats.SubOpt
ServerName gomatrixserverlib.ServerName
+ SigningIdentity *gomatrixserverlib.SigningIdentity
FSAPI fedapi.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier
ACLs *acls.ServerACLs
@@ -87,6 +90,7 @@ type Inputer struct {
workers sync.Map // room ID -> *worker
Queryer *query.Queryer
+ UserAPI userapi.RoomserverUserAPI
}
// If a room consumer is inactive for a while then we will allow NATS
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 10b8ee27..4179fc1e 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -19,6 +19,7 @@ package input
import (
"context"
"database/sql"
+ "encoding/json"
"errors"
"fmt"
"time"
@@ -31,6 +32,8 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
+ userAPI "github.com/matrix-org/dendrite/userapi/api"
+
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
@@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent(
}
}
+ // If guest_access changed and is not can_join, kick all guest users.
+ if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
+ if err = r.kickGuests(ctx, event, roomInfo); err != nil {
+ logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
+ }
+ }
+
// Everything was OK — the latest events updater didn't error and
// we've sent output events. Finally, generate a hook call.
hooks.Run(hooks.KindNewEventPersisted, headered)
@@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState(
succeeded = true
return nil
}
+
+// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited.
+func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error {
+ membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
+ if err != nil {
+ return err
+ }
+
+ memberEvents, err := r.DB.Events(ctx, membershipNIDs)
+ if err != nil {
+ return err
+ }
+
+ inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
+ latestReq := &api.QueryLatestEventsAndStateRequest{
+ RoomID: event.RoomID(),
+ }
+ latestRes := &api.QueryLatestEventsAndStateResponse{}
+ if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
+ return err
+ }
+
+ prevEvents := latestRes.LatestEvents
+ for _, memberEvent := range memberEvents {
+ if memberEvent.StateKey() == nil {
+ continue
+ }
+
+ localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
+ if err != nil {
+ continue
+ }
+
+ accountRes := &userAPI.QueryAccountByLocalpartResponse{}
+ if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
+ Localpart: localpart,
+ ServerName: senderDomain,
+ }, accountRes); err != nil {
+ return err
+ }
+ if accountRes.Account == nil {
+ continue
+ }
+
+ if accountRes.Account.AccountType != userAPI.AccountTypeGuest {
+ continue
+ }
+
+ var memberContent gomatrixserverlib.MemberContent
+ if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
+ return err
+ }
+ memberContent.Membership = gomatrixserverlib.Leave
+
+ stateKey := *memberEvent.StateKey()
+ fledglingEvent := &gomatrixserverlib.EventBuilder{
+ RoomID: event.RoomID(),
+ Type: gomatrixserverlib.MRoomMember,
+ StateKey: &stateKey,
+ Sender: stateKey,
+ PrevEvents: prevEvents,
+ }
+
+ if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
+ return err
+ }
+
+ eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
+ if err != nil {
+ return err
+ }
+
+ event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes)
+ if err != nil {
+ return err
+ }
+
+ inputEvents = append(inputEvents, api.InputRoomEvent{
+ Kind: api.KindNew,
+ Event: event,
+ Origin: senderDomain,
+ SendAsServer: string(senderDomain),
+ })
+ prevEvents = []gomatrixserverlib.EventReference{
+ event.EventReference(),
+ }
+ }
+
+ inputReq := &api.InputRoomEventsRequest{
+ InputRoomEvents: inputEvents,
+ Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
+ }
+ inputRes := &api.InputRoomEventsResponse{}
+ return r.InputRoomEvents(ctx, inputReq, inputRes)
+}
diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go
index 4de008c6..fc7ba940 100644
--- a/roomserver/internal/perform/perform_join.go
+++ b/roomserver/internal/perform/perform_join.go
@@ -16,6 +16,7 @@ package perform
import (
"context"
+ "database/sql"
"errors"
"fmt"
"strings"
@@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID(
}
}
+ // If a guest is trying to join a room, check that the room has a m.room.guest_access event
+ if req.IsGuest {
+ var guestAccessEvent *gomatrixserverlib.HeaderedEvent
+ guestAccess := "forbidden"
+ guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "")
+ if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
+ logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
+ }
+ if guestAccessEvent != nil {
+ guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
+ }
+
+ // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
+ // is present on the room and has the guest_access value can_join.
+ if guestAccess != "can_join" {
+ return "", "", &rsAPI.PerformError{
+ Code: rsAPI.PerformErrorNotAllowed,
+ Msg: "Guest access is forbidden",
+ }
+ }
+ }
+
// If we should do a forced federated join then do that.
var joinedVia gomatrixserverlib.ServerName
if forceFederatedJoin {
diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go
index 518bb372..595ceb52 100644
--- a/roomserver/roomserver_test.go
+++ b/roomserver/roomserver_test.go
@@ -3,18 +3,23 @@ package roomserver_test
import (
"context"
"net/http"
+ "reflect"
"testing"
"time"
"github.com/gorilla/mux"
+ "github.com/matrix-org/dendrite/internal/httputil"
+ "github.com/matrix-org/dendrite/setup/base"
+ "github.com/matrix-org/dendrite/userapi"
+
+ userAPI "github.com/matrix-org/dendrite/userapi/api"
+
"github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/roomserver/storage"
- "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
)
@@ -29,7 +34,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s
return base, db, close
}
-func Test_SharedUsers(t *testing.T) {
+func TestUsers(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ defer close()
+ rsAPI := roomserver.NewInternalAPI(base)
+ // SetFederationAPI starts the room event input consumer
+ rsAPI.SetFederationAPI(nil, nil)
+
+ t.Run("shared users", func(t *testing.T) {
+ testSharedUsers(t, rsAPI)
+ })
+
+ t.Run("kick users", func(t *testing.T) {
+ usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
+ rsAPI.SetUserAPI(usrAPI)
+ testKickUsers(t, rsAPI, usrAPI)
+ })
+ })
+
+}
+
+func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) {
alice := test.NewUser(t)
bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
@@ -43,36 +69,93 @@ func Test_SharedUsers(t *testing.T) {
}, test.WithStateKey(bob.ID))
ctx := context.Background()
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- base, _, close := mustCreateDatabase(t, dbType)
- defer close()
- rsAPI := roomserver.NewInternalAPI(base)
- // SetFederationAPI starts the room event input consumer
- rsAPI.SetFederationAPI(nil, nil)
- // Create the room
- if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
- t.Fatalf("failed to send events: %v", err)
- }
+ // Create the room
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
- // Query the shared users for Alice, there should only be Bob.
- // This is used by the SyncAPI keychange consumer.
- res := &api.QuerySharedUsersResponse{}
- if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
- t.Fatalf("unable to query known users: %v", err)
- }
- if _, ok := res.UserIDsToCount[bob.ID]; !ok {
- t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
+ // Query the shared users for Alice, there should only be Bob.
+ // This is used by the SyncAPI keychange consumer.
+ res := &api.QuerySharedUsersResponse{}
+ if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
+ t.Errorf("unable to query known users: %v", err)
+ }
+ if _, ok := res.UserIDsToCount[bob.ID]; !ok {
+ t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
+ }
+ // Also verify that we get the expected result when specifying OtherUserIDs.
+ // This is used by the SyncAPI when getting device list changes.
+ if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
+ t.Errorf("unable to query known users: %v", err)
+ }
+ if _, ok := res.UserIDsToCount[bob.ID]; !ok {
+ t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
+ }
+}
+
+func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) {
+ // Create users and room; Bob is going to be the guest and kicked on revocation of guest access
+ alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser))
+ bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest))
+
+ room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true))
+
+ // Join with the guest user
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+
+ ctx := context.Background()
+
+ // Create the users in the userapi, so the RSAPI can query the account type later
+ for _, u := range []*test.User{alice, bob} {
+ localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
+ userRes := &userAPI.PerformAccountCreationResponse{}
+ if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{
+ AccountType: u.AccountType,
+ Localpart: localpart,
+ ServerName: serverName,
+ Password: "someRandomPassword",
+ }, userRes); err != nil {
+ t.Errorf("failed to create account: %s", err)
}
- // Also verify that we get the expected result when specifying OtherUserIDs.
- // This is used by the SyncAPI when getting device list changes.
- if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
- t.Fatalf("unable to query known users: %v", err)
+ }
+
+ // Create the room in the database
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+
+ // Get the membership events BEFORE revoking guest access
+ membershipRes := &api.QueryMembershipsForRoomResponse{}
+ if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil {
+ t.Errorf("failed to query membership for room: %s", err)
+ }
+
+ // revoke guest access
+ revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey(""))
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+
+ // TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need
+ // to loop and wait for the events to be processed by the roomserver.
+ for i := 0; i <= 20; i++ {
+ // Get the membership events AFTER revoking guest access
+ membershipRes2 := &api.QueryMembershipsForRoomResponse{}
+ if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil {
+ t.Errorf("failed to query membership for room: %s", err)
}
- if _, ok := res.UserIDsToCount[bob.ID]; !ok {
- t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
+
+ // The membership events should NOT match, as Bob (guest user) should now be kicked from the room
+ if !reflect.DeepEqual(membershipRes, membershipRes2) {
+ return
}
- })
+ time.Sleep(time.Millisecond * 10)
+ }
+
+ t.Errorf("memberships didn't change in time")
}
func Test_QueryLeftUsers(t *testing.T) {
diff --git a/setup/config/config_global.go b/setup/config/config_global.go
index 511951fe..804eb1a2 100644
--- a/setup/config/config_global.go
+++ b/setup/config/config_global.go
@@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g
return id, nil
}
}
- return nil, fmt.Errorf("no signing identity %q", serverName)
+ return nil, fmt.Errorf("no signing identity for %q", serverName)
}
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
diff --git a/setup/config/config_test.go b/setup/config/config_test.go
index ee7e7389..3408bf46 100644
--- a/setup/config/config_test.go
+++ b/setup/config/config_test.go
@@ -16,8 +16,10 @@ package config
import (
"fmt"
+ "reflect"
"testing"
+ "github.com/matrix-org/gomatrixserverlib"
"gopkg.in/yaml.v2"
)
@@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) {
}
}
}
+
+func Test_SigningIdentityFor(t *testing.T) {
+ tests := []struct {
+ name string
+ virtualHosts []*VirtualHost
+ serverName gomatrixserverlib.ServerName
+ want *gomatrixserverlib.SigningIdentity
+ wantErr bool
+ }{
+ {
+ name: "no virtual hosts defined",
+ wantErr: true,
+ },
+ {
+ name: "no identity found",
+ serverName: gomatrixserverlib.ServerName("doesnotexist"),
+ wantErr: true,
+ },
+ {
+ name: "found identity",
+ serverName: gomatrixserverlib.ServerName("main"),
+ want: &gomatrixserverlib.SigningIdentity{ServerName: "main"},
+ },
+ {
+ name: "identity found on virtual hosts",
+ serverName: gomatrixserverlib.ServerName("vh2"),
+ virtualHosts: []*VirtualHost{
+ {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}},
+ {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}},
+ },
+ want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ c := &Global{
+ VirtualHosts: tt.virtualHosts,
+ SigningIdentity: gomatrixserverlib.SigningIdentity{
+ ServerName: "main",
+ },
+ }
+ got, err := c.SigningIdentityFor(tt.serverName)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/sytest-blacklist b/sytest-blacklist
index c35b03bd..99cfbabc 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one
Leaves are present in non-gapped incremental syncs
# Below test was passing for the wrong reason, failing correctly since #2858
-New federated private chats get full presence information (SYN-115) \ No newline at end of file
+New federated private chats get full presence information (SYN-115)
+
+# We don't have any state to calculate m.room.guest_access when accepting invites
+Guest users can accept invites to private rooms over federation \ No newline at end of file
diff --git a/sytest-whitelist b/sytest-whitelist
index 49ffb8fe..215889a4 100644
--- a/sytest-whitelist
+++ b/sytest-whitelist
@@ -763,4 +763,7 @@ AS and main public room lists are separate
local user has tags copied to the new room
remote user has tags copied to the new room
/upgrade moves remote aliases to the new room
-Local and remote users' homeservers remove a room from their public directory on upgrade \ No newline at end of file
+Local and remote users' homeservers remove a room from their public directory on upgrade
+Guest users denied access over federation if guest access prohibited
+Guest users are kicked from guest_access rooms on revocation of guest_access
+Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file
diff --git a/test/room.go b/test/room.go
index 4328bf84..685876cb 100644
--- a/test/room.go
+++ b/test/room.go
@@ -38,11 +38,12 @@ var (
)
type Room struct {
- ID string
- Version gomatrixserverlib.RoomVersion
- preset Preset
- visibility gomatrixserverlib.HistoryVisibility
- creator *User
+ ID string
+ Version gomatrixserverlib.RoomVersion
+ preset Preset
+ guestCanJoin bool
+ visibility gomatrixserverlib.HistoryVisibility
+ creator *User
authEvents gomatrixserverlib.AuthEvents
currentState map[string]*gomatrixserverlib.HeaderedEvent
@@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) {
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
+ if r.guestCanJoin {
+ r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{
+ "guest_access": "can_join",
+ }, WithStateKey(""))
+ }
}
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
@@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
r.Version = ver
}
}
+
+func GuestsCanJoin(canJoin bool) roomModifier {
+ return func(t *testing.T, r *Room) {
+ r.guestCanJoin = canJoin
+ }
+}
diff --git a/userapi/api/api.go b/userapi/api/api.go
index d3f5aefc..4ea2e91c 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -50,6 +50,7 @@ type KeyserverUserAPI interface {
type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
+ QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
}
// api functions required by the media api
@@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct {
ServerName gomatrixserverlib.ServerName
Medium string
}
+
+type QueryAccountByLocalpartRequest struct {
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
+}
+
+type QueryAccountByLocalpartResponse struct {
+ Account *Account
+}
diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go
index ce661770..d10b5767 100644
--- a/userapi/api/api_trace.go
+++ b/userapi/api/api_trace.go
@@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
return err
}
+func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error {
+ err := t.Impl.QueryAccountByLocalpart(ctx, req, res)
+ util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res))
+ return err
+}
+
func js(thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index 3f256457..0bb480da 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil
}
+func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) {
+ res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName)
+ return
+}
+
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
// creating a 'device'.
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index 87ae058c..51b0fe3e 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -60,6 +60,7 @@ const (
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
+ QueryAccountByLocalpartPath = "/userapi/queryAccountType"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
h.httpClient, ctx, request, response,
)
}
+
+func (h *httpUserInternalAPI) QueryAccountByLocalpart(
+ ctx context.Context,
+ req *api.QueryAccountByLocalpartRequest,
+ res *api.QueryAccountByLocalpartResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath,
+ h.httpClient, ctx, req, res,
+ )
+}
diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go
index f0579079..b40b507c 100644
--- a/userapi/inthttp/server.go
+++ b/userapi/inthttp/server.go
@@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics
PerformSaveThreePIDAssociationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
)
+
+ internalAPIMux.Handle(
+ QueryAccountByLocalpartPath,
+ httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart),
+ )
}
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 8a19af19..dada56de 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) {
})
})
}
+
+func TestQueryAccountByLocalpart(t *testing.T) {
+ alice := test.NewUser(t)
+
+ localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID)
+
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
+ defer close()
+
+ createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
+ if err != nil {
+ t.Error(err)
+ }
+
+ testCases := func(t *testing.T, internalAPI api.UserInternalAPI) {
+ // Query existing account
+ queryAccResp := &api.QueryAccountByLocalpartResponse{}
+ if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
+ Localpart: localpart,
+ ServerName: userServername,
+ }, queryAccResp); err != nil {
+ t.Error(err)
+ }
+ if !reflect.DeepEqual(createdAcc, queryAccResp.Account) {
+ t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account)
+ }
+
+ // Query non-existent account, this should result in an error
+ err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
+ Localpart: "doesnotexist",
+ ServerName: userServername,
+ }, queryAccResp)
+
+ if err == nil {
+ t.Fatalf("expected an error, but got none: %+v", queryAccResp)
+ }
+ }
+
+ t.Run("Monolith", func(t *testing.T) {
+ testCases(t, intAPI)
+ // also test tracing
+ testCases(t, &api.UserInternalAPITrace{Impl: intAPI})
+ })
+
+ t.Run("HTTP API", func(t *testing.T) {
+ router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
+ userapi.AddInternalRoutes(router, intAPI, false)
+ apiURL, cancel := test.ListenAndServe(t, router, false)
+ defer cancel()
+
+ userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5})
+ if err != nil {
+ t.Fatalf("failed to create HTTP client: %s", err)
+ }
+ testCases(t, userHTTPApi)
+
+ })
+ })
+}