diff options
author | kegsay <kegan@matrix.org> | 2022-05-17 13:23:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-17 13:23:35 +0100 |
commit | 6de29c1cd23d218f04d2e570932db8967d6adc4f (patch) | |
tree | b95fa478ef9ecd2c21963868a3626063bdff7cbc | |
parent | cd82460513d5abf04e56c01667d56499d4c354be (diff) |
bugfix: E2EE device keys could sometimes not be sent to remote servers (#2466)
* Fix flakey sytest 'Local device key changes get to remote servers'
* Debug logs
* Remove internal/test and use /test only
Remove a lot of ancient code too.
* Use FederationRoomserverAPI in more places
* Use more interfaces in federationapi; begin adding regression test
* Linting
* Add regression test
* Unbreak tests
* ALL THE LOGS
* Fix a race condition which could cause events to not be sent to servers
If a new room event which rewrites state arrives, we remove all joined hosts
then re-calculate them. This wasn't done in a transaction so for a brief period
we would have no joined hosts. During this interim, key change events which arrive
would not be sent to destination servers. This would sporadically fail on sytest.
* Unbreak new tests
* Linting
48 files changed, 565 insertions, 617 deletions
diff --git a/cmd/generate-keys/main.go b/cmd/generate-keys/main.go index bddf219d..8acd28be 100644 --- a/cmd/generate-keys/main.go +++ b/cmd/generate-keys/main.go @@ -20,7 +20,7 @@ import ( "log" "os" - "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/dendrite/test" ) const usage = `Usage: %s diff --git a/federationapi/api/api.go b/federationapi/api/api.go index fc25194e..53d4701f 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -12,12 +12,16 @@ import ( // FederationInternalAPI is used to query information from the federation sender. type FederationInternalAPI interface { - FederationClient + gomatrixserverlib.FederatedStateClient + KeyserverFederationAPI gomatrixserverlib.KeyDatabase ClientFederationAPI RoomserverFederationAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error + LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) + MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. PerformBroadcastEDU( @@ -60,17 +64,43 @@ type RoomserverFederationAPI interface { LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } -// FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender +// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError -type FederationClient interface { - gomatrixserverlib.FederatedStateClient +type KeyserverFederationAPI interface { GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) +} + +// an interface for gmsl.FederationClient - contains functions called by federationapi only. +type FederationClient interface { + gomatrixserverlib.KeyClient + SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) + + // Perform operations + LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) + Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) + MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) + SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) + MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) + SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) + SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) + + GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + + GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) + ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) + QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) + Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) - LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) + + ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) + LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) + LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 0ece18e9..95c9a7fd 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -39,7 +39,7 @@ type KeyChangeConsumer struct { db storage.Database queues *queue.OutgoingQueues serverName gomatrixserverlib.ServerName - rsAPI roomserverAPI.RoomserverInternalAPI + rsAPI roomserverAPI.FederationRoomserverAPI topic string } @@ -50,7 +50,7 @@ func NewKeyChangeConsumer( js nats.JetStreamContext, queues *queue.OutgoingQueues, store storage.Database, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.FederationRoomserverAPI, ) *KeyChangeConsumer { return &KeyChangeConsumer{ ctx: process.Context(), @@ -120,6 +120,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { logger.WithError(err).Error("failed to calculate joined rooms for user") return true } + logrus.Infof("DEBUG: %v joined rooms for user %v", queryRes.RoomIDs, m.UserID) // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { @@ -128,6 +129,9 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } if len(destinations) == 0 { + logger.WithField("num_rooms", len(queryRes.RoomIDs)).Debug("user is in no federated rooms") + destinations, err = t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, false) + logrus.Infof("GetJoinedHostsForRooms exclude self=false -> %v %v", destinations, err) return true } // Pack the EDU and marshal it diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 80317ee6..7a0816ff 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/queue" @@ -36,7 +37,7 @@ import ( type OutputRoomEventConsumer struct { ctx context.Context cfg *config.FederationAPI - rsAPI api.RoomserverInternalAPI + rsAPI api.FederationRoomserverAPI jetstream nats.JetStreamContext durable string db storage.Database @@ -51,7 +52,7 @@ func NewOutputRoomEventConsumer( js nats.JetStreamContext, queues *queue.OutgoingQueues, store storage.Database, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -89,15 +90,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) switch output.Type { case api.OutputTypeNewRoomEvent: ev := output.NewRoomEvent.Event - - if output.NewRoomEvent.RewritesState { - if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil { - log.WithError(err).Errorf("roomserver output log: purge room state failure") - return false - } - } - - if err := s.processMessage(*output.NewRoomEvent); err != nil { + if err := s.processMessage(*output.NewRoomEvent, output.NewRoomEvent.RewritesState); err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ "event_id": ev.EventID(), @@ -145,7 +138,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. -func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error { +func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error { addsStateEvents, missingEventIDs := ore.NeededStateEventIDs() // Ask the roomserver and add in the rest of the results into the set. @@ -164,7 +157,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err addsStateEvents = append(addsStateEvents, eventsRes.Events...) } - addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents)) + addsJoinedHosts, err := JoinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents)) if err != nil { return err } @@ -173,13 +166,13 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err // expressed as a delta against the current state. // TODO(#290): handle EventIDMismatchError and recover the current state by // talking to the roomserver + logrus.Infof("room %s adds joined hosts: %v removes %v", ore.Event.RoomID(), addsJoinedHosts, ore.RemovesStateEventIDs) oldJoinedHosts, err := s.db.UpdateRoom( s.ctx, ore.Event.RoomID(), - ore.LastSentEventID, - ore.Event.EventID(), addsJoinedHosts, ore.RemovesStateEventIDs, + rewritesState, // if we're re-writing state, nuke all joined hosts before adding ) if err != nil { return err @@ -238,7 +231,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( return nil, err } - combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents) + combinedAddsJoinedHosts, err := JoinedHostsFromEvents(combinedAddsEvents) if err != nil { return nil, err } @@ -284,10 +277,10 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( return result, nil } -// joinedHostsFromEvents turns a list of state events into a list of joined hosts. +// JoinedHostsFromEvents turns a list of state events into a list of joined hosts. // This errors if one of the events was invalid. // It should be impossible for an invalid event to get this far in the pipeline. -func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) { +func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) { var joinedHosts []types.JoinedHost for _, ev := range evs { if ev.Type() != "m.room.member" || ev.StateKey() == nil { diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index bec9ac77..ff159bee 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -93,8 +93,8 @@ func AddPublicRoutes( // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, - federation *gomatrixserverlib.FederationClient, - rsAPI roomserverAPI.RoomserverInternalAPI, + federation api.FederationClient, + rsAPI roomserverAPI.FederationRoomserverAPI, caches *caching.Caches, keyRing *gomatrixserverlib.KeyRing, resetBlacklist bool, diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index eedebc6c..ae244c56 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -3,18 +3,250 @@ package federationapi_test import ( "context" "crypto/ed25519" + "encoding/json" + "fmt" "strings" "testing" + "time" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/internal" - "github.com/matrix-org/dendrite/internal/test" + keyapi "github.com/matrix-org/dendrite/keyserver/api" + rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" ) +type fedRoomserverAPI struct { + rsapi.FederationRoomserverAPI + inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) + queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error +} + +// PerformJoin will call this function +func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { + if f.inputRoomEvents == nil { + return + } + f.inputRoomEvents(ctx, req, res) +} + +// keychange consumer calls this +func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { + if f.queryRoomsForUser == nil { + return nil + } + return f.queryRoomsForUser(ctx, req, res) +} + +// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate +type fedClient struct { + api.FederationClient + allowJoins []*test.Room + keys map[gomatrixserverlib.ServerName]struct { + key ed25519.PrivateKey + keyID gomatrixserverlib.KeyID + } + t *testing.T + sentTxn bool +} + +func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) { + fmt.Println("GetServerKeys:", matrixServer) + var keys gomatrixserverlib.ServerKeys + var keyID gomatrixserverlib.KeyID + var pkey ed25519.PrivateKey + for srv, data := range f.keys { + if srv == matrixServer { + pkey = data.key + keyID = data.keyID + break + } + } + if pkey == nil { + return keys, nil + } + + keys.ServerName = matrixServer + keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour)) + publicKey := pkey.Public().(ed25519.PublicKey) + keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ + keyID: { + Key: gomatrixserverlib.Base64Bytes(publicKey), + }, + } + toSign, err := json.Marshal(keys.ServerKeyFields) + if err != nil { + return keys, err + } + + keys.Raw, err = gomatrixserverlib.SignJSON( + string(matrixServer), keyID, pkey, toSign, + ) + if err != nil { + return keys, err + } + + return keys, nil +} + +func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { + for _, r := range f.allowJoins { + if r.ID == roomID { + res.RoomVersion = r.Version + res.JoinEvent = gomatrixserverlib.EventBuilder{ + Sender: userID, + RoomID: roomID, + Type: "m.room.member", + StateKey: &userID, + Content: gomatrixserverlib.RawJSON([]byte(`{"membership":"join"}`)), + PrevEvents: r.ForwardExtremities(), + } + var needed gomatrixserverlib.StateNeeded + needed, err = gomatrixserverlib.StateNeededForEventBuilder(&res.JoinEvent) + if err != nil { + f.t.Errorf("StateNeededForEventBuilder: %v", err) + return + } + res.JoinEvent.AuthEvents = r.MustGetAuthEventRefsForEvent(f.t, needed) + return + } + } + return +} +func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { + for _, r := range f.allowJoins { + if r.ID == event.RoomID() { + r.InsertEvent(f.t, event.Headered(r.Version)) + f.t.Logf("Join event: %v", event.EventID()) + res.StateEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.CurrentState()) + res.AuthEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.Events()) + } + } + return +} + +func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { + for _, edu := range t.EDUs { + if edu.Type == gomatrixserverlib.MDeviceListUpdate { + f.sentTxn = true + } + } + f.t.Logf("got /send") + return +} + +// Regression test to make sure that /send_join is updating the destination hosts synchronously and +// isn't relying on the roomserver. +func TestFederationAPIJoinThenKeyUpdate(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + testFederationAPIJoinThenKeyUpdate(t, dbType) + }) +} + +func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + base.Cfg.FederationAPI.PreferDirectFetch = true + defer close() + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + serverA := gomatrixserverlib.ServerName("server.a") + serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera") + serverAPrivKey := test.PrivateKeyA + creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey)) + + myServer := base.Cfg.Global.ServerName + myServerKeyID := base.Cfg.Global.KeyID + myServerPrivKey := base.Cfg.Global.PrivateKey + joiningUser := test.NewUser(t, test.WithSigningServer(myServer, myServerKeyID, myServerPrivKey)) + fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID) + room := test.NewRoom(t, creator) + + rsapi := &fedRoomserverAPI{ + inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { + if req.Asynchronous { + t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous") + } + }, + queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { + if req.UserID == joiningUser.ID && req.WantMembership == "join" { + res.RoomIDs = []string{room.ID} + return nil + } + return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req) + }, + } + fc := &fedClient{ + allowJoins: []*test.Room{room}, + t: t, + keys: map[gomatrixserverlib.ServerName]struct { + key ed25519.PrivateKey + keyID gomatrixserverlib.KeyID + }{ + serverA: { + key: serverAPrivKey, + keyID: serverAKeyID, + }, + myServer: { + key: myServerPrivKey, + keyID: myServerKeyID, + }, + }, + } + fsapi := federationapi.NewInternalAPI(base, fc, rsapi, base.Caches, nil, false) + + var resp api.PerformJoinResponse + fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{ + RoomID: room.ID, + UserID: joiningUser.ID, + ServerNames: []gomatrixserverlib.ServerName{serverA}, + }, &resp) + if resp.JoinedVia != serverA { + t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA) + } + if resp.LastError != nil { + t.Fatalf("PerformJoin: returned error: %+v", *resp.LastError) + } + + // Inject a keyserver key change event and ensure we try to send it out. If we don't, then the + // federationapi is incorrectly waiting for an output room event to arrive to update the joined + // hosts table. + key := keyapi.DeviceMessage{ + Type: keyapi.TypeDeviceKeyUpdate, + DeviceKeys: &keyapi.DeviceKeys{ + UserID: joiningUser.ID, + DeviceID: "MY_DEVICE", + DisplayName: "BLARGLE", + KeyJSON: []byte(`{}`), + }, + } + b, err := json.Marshal(key) + if err != nil { + t.Fatalf("Failed to marshal device message: %s", err) + } + + msg := &nats.Msg{ + Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + Header: nats.Header{}, + Data: b, + } + msg.Header.Set(jetstream.UserID, key.UserID) + + testrig.MustPublishMsgs(t, jsctx, msg) + time.Sleep(500 * time.Millisecond) + if !fc.sentTxn { + t.Fatalf("did not send device list update") + } +} + // Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404. // Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated. func TestRoomsV3URLEscapeDoNot404(t *testing.T) { @@ -86,7 +318,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { } gerr, ok := err.(gomatrix.HTTPError) if !ok { - t.Errorf("failed to cast response error as gomatrix.HTTPError") + t.Errorf("failed to cast response error as gomatrix.HTTPError: %s", err) continue } t.Logf("Error: %+v", gerr) diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 4e9fa841..14056eaf 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -25,8 +25,8 @@ type FederationInternalAPI struct { db storage.Database cfg *config.FederationAPI statistics *statistics.Statistics - rsAPI roomserverAPI.RoomserverInternalAPI - federation *gomatrixserverlib.FederationClient + rsAPI roomserverAPI.FederationRoomserverAPI + federation api.FederationClient keyRing *gomatrixserverlib.KeyRing queues *queue.OutgoingQueues joins sync.Map // joins currently in progress @@ -34,8 +34,8 @@ type FederationInternalAPI struct { func NewFederationInternalAPI( db storage.Database, cfg *config.FederationAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, - federation *gomatrixserverlib.FederationClient, + rsAPI roomserverAPI.FederationRoomserverAPI, + federation api.FederationClient, statistics *statistics.Statistics, caches *caching.Caches, queues *queue.OutgoingQueues, diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 577cb70e..7ccd68ef 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -8,6 +8,7 @@ import ( "time" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/consumers" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrix" @@ -235,6 +236,21 @@ func (r *FederationInternalAPI) performJoinUsingServer( return fmt.Errorf("respSendJoin.Check: %w", err) } + // We need to immediately update our list of joined hosts for this room now as we are technically + // joined. We must do this synchronously: we cannot rely on the roomserver output events as they + // will happen asyncly. If we don't update this table, you can end up with bad failure modes like + // joining a room, waiting for 200 OK then changing device keys and have those keys not be sent + // to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers") + // The events are trusted now as we performed auth checks above. + joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false)) + if err != nil { + return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err) + } + logrus.WithField("hosts", joinedHosts).WithField("room", roomID).Info("Joined federated room with hosts") + if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil { + return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err) + } + // If we successfully performed a send_join above then the other // server now thinks we're a part of the room. Send the newly // returned state to the roomserver to update our local view. @@ -650,7 +666,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder // FederatedAuthProvider is an auth chain provider which fetches events from the server provided func federatedAuthProvider( - ctx context.Context, federation *gomatrixserverlib.FederationClient, + ctx context.Context, federation api.FederationClient, keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, ) gomatrixserverlib.AuthChainProvider { // A list of events that we have retried, if they were not included in diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 74794040..b6edec5d 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -21,6 +21,7 @@ import ( "sync" "time" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared" @@ -49,21 +50,21 @@ type destinationQueue struct { db storage.Database process *process.ProcessContext signing *SigningInfo - rsAPI api.RoomserverInternalAPI - client *gomatrixserverlib.FederationClient // federation client - origin gomatrixserverlib.ServerName // origin of requests - destination gomatrixserverlib.ServerName // destination of requests - running atomic.Bool // is the queue worker running? - backingOff atomic.Bool // true if we're backing off - overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more - statistics *statistics.ServerStatistics // statistics about this remote server - transactionIDMutex sync.Mutex // protects transactionID - transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful - notify chan struct{} // interrupts idle wait pending PDUs/EDUs - pendingPDUs []*queuedPDU // PDUs waiting to be sent - pendingEDUs []*queuedEDU // EDUs waiting to be sent - pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs - interruptBackoff chan bool // interrupts backoff + rsAPI api.FederationRoomserverAPI + client fedapi.FederationClient // federation client + origin gomatrixserverlib.ServerName // origin of requests + destination gomatrixserverlib.ServerName // destination of requests + running atomic.Bool // is the queue worker running? + backingOff atomic.Bool // true if we're backing off + overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more + statistics *statistics.ServerStatistics // statistics about this remote server + transactionIDMutex sync.Mutex // protects transactionID + transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful + notify chan struct{} // interrupts idle wait pending PDUs/EDUs + pendingPDUs []*queuedPDU // PDUs waiting to be sent + pendingEDUs []*queuedEDU // EDUs waiting to be sent + pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs + interruptBackoff chan bool // interrupts backoff } // Send event adds the event to the pending queue for the destination. diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index d152886f..4c25c4ce 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -26,6 +26,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared" @@ -39,9 +40,9 @@ type OutgoingQueues struct { db storage.Database process *process.ProcessContext disabled bool - rsAPI api.RoomserverInternalAPI + rsAPI api.FederationRoomserverAPI origin gomatrixserverlib.ServerName - client *gomatrixserverlib.FederationClient + client fedapi.FederationClient statistics *statistics.Statistics signing *SigningInfo queuesMutex sync.Mutex // protects the below @@ -85,8 +86,8 @@ func NewOutgoingQueues( process *process.ProcessContext, disabled bool, origin gomatrixserverlib.ServerName, - client *gomatrixserverlib.FederationClient, - rsAPI api.RoomserverInternalAPI, + client fedapi.FederationClient, + rsAPI api.FederationRoomserverAPI, statistics *statistics.Statistics, signing *SigningInfo, ) *OutgoingQueues { diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 707b7b01..316c61a1 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -30,7 +30,7 @@ import ( // RoomAliasToID converts the queried alias into a room ID and returns it func RoomAliasToID( httpReq *http.Request, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, cfg *config.FederationAPI, rsAPI roomserverAPI.FederationRoomserverAPI, senderAPI federationAPI.FederationInternalAPI, diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 9f95ed07..e25f9866 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -54,7 +54,7 @@ func Setup( rsAPI roomserverAPI.FederationRoomserverAPI, fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, keyAPI keyserverAPI.FederationKeyAPI, mscCfg *config.MSCs, diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 55a11367..c25dabce 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -85,7 +85,7 @@ func Send( rsAPI api.FederationRoomserverAPI, keyAPI keyapi.FederationKeyAPI, keys gomatrixserverlib.JSONVerifier, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, mu *internal.MutexByRoom, servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 011d4e34..a111580c 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -8,8 +8,8 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 16f245ce..ccde9168 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -57,7 +58,7 @@ var ( func CreateInvitesFrom3PIDInvites( req *http.Request, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, ) util.JSONResponse { var body invites @@ -107,7 +108,7 @@ func ExchangeThirdPartyInvite( roomID string, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, ) util.JSONResponse { var builder gomatrixserverlib.EventBuilder if err := json.Unmarshal(request.Content(), &builder); err != nil { @@ -165,7 +166,12 @@ func ExchangeThirdPartyInvite( // Ask the requesting server to sign the newly created event so we know it // acknowledged it - signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), event) + inviteReq, err := gomatrixserverlib.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") + return jsonerror.InternalServerError() + } + signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() @@ -205,7 +211,7 @@ func ExchangeThirdPartyInvite( func createInviteFrom3PIDInvite( ctx context.Context, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, - inv invite, federation *gomatrixserverlib.FederationClient, + inv invite, federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, ) (*gomatrixserverlib.Event, error) { verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} @@ -335,7 +341,7 @@ func buildMembershipEvent( // them responded with an error. func sendToRemoteServer( ctx context.Context, inv invite, - federation *gomatrixserverlib.FederationClient, _ *config.FederationAPI, + federation federationAPI.FederationClient, _ *config.FederationAPI, builder gomatrixserverlib.EventBuilder, ) (err error) { remoteServers := make([]gomatrixserverlib.ServerName, 2) diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index e3038651..29254948 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -25,13 +25,12 @@ import ( type Database interface { gomatrixserverlib.KeyDatabase - UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) + UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) - PurgeRoomState(ctx context.Context, roomID string) error StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationapi/storage/postgres/joined_hosts_table.go b/federationapi/storage/postgres/joined_hosts_table.go index 5c95b72a..bb6f6bfa 100644 --- a/federationapi/storage/postgres/joined_hosts_table.go +++ b/federationapi/storage/postgres/joined_hosts_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) const joinedHostsSchema = ` @@ -111,6 +112,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { + logrus.Debugf("FederationJoinedHosts: INSERT %v %v %v", roomID, eventID, serverName) stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) return err @@ -119,6 +121,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { + logrus.Debugf("FederationJoinedHosts: DELETE WITH EVENTS %v", eventIDs) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) return err @@ -127,6 +130,7 @@ func (s *joinedHostsStatements) DeleteJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { + logrus.Debugf("FederationJoinedHosts: DELETE ALL IN ROOM %v", roomID) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) _, err := stmt.ExecContext(ctx, roomID) return err @@ -207,6 +211,7 @@ func joinedHostsFromStmt( ServerName: gomatrixserverlib.ServerName(serverName), }) } + logrus.Debugf("FederationJoinedHosts: SELECT %v => %+v", roomID, result) return result, rows.Err() } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 160c7f6f..a00d782f 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -63,11 +63,21 @@ func (r *Receipt) String() string { // this isn't a duplicate message. func (d *Database) UpdateRoom( ctx context.Context, - roomID, oldEventID, newEventID string, + roomID string, addHosts []types.JoinedHost, removeHosts []string, + purgeRoomFirst bool, ) (joinedHosts []types.JoinedHost, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if purgeRoomFirst { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err = d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err) + } + } + joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID) if err != nil { return err @@ -138,20 +148,6 @@ func (d *Database) StoreJSON( }, nil } -func (d *Database) PurgeRoomState( - ctx context.Context, roomID string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. - if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { - return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err) - } - return nil - }) -} - func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) diff --git a/internal/caching/cache_typing_test.go b/internal/caching/cache_typing_test.go index c03d89bc..2cef32d3 100644 --- a/internal/caching/cache_typing_test.go +++ b/internal/caching/cache_typing_test.go @@ -20,7 +20,7 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/dendrite/test" ) func TestEDUCache(t *testing.T) { diff --git a/internal/test/client.go b/internal/test/client.go deleted file mode 100644 index a38540ac..00000000 --- a/internal/test/client.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package test - -import ( - "crypto/tls" - "fmt" - "io" - "io/ioutil" - "net/http" - "sync" - "time" - - "github.com/matrix-org/gomatrixserverlib" -) - -// Request contains the information necessary to issue a request and test its result -type Request struct { - Req *http.Request - WantedBody string - WantedStatusCode int - LastErr *LastRequestErr -} - -// LastRequestErr is a synchronised error wrapper -// Useful for obtaining the last error from a set of requests -type LastRequestErr struct { - sync.Mutex - Err error -} - -// Set sets the error -func (r *LastRequestErr) Set(err error) { - r.Lock() - defer r.Unlock() - r.Err = err -} - -// Get gets the error -func (r *LastRequestErr) Get() error { - r.Lock() - defer r.Unlock() - return r.Err -} - -// CanonicalJSONInput canonicalises a slice of JSON strings -// Useful for test input -func CanonicalJSONInput(jsonData []string) []string { - for i := range jsonData { - jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i])) - if err != nil && err != io.EOF { - panic(err) - } - jsonData[i] = string(jsonBytes) - } - return jsonData -} - -// Do issues a request and checks the status code and body of the response -func (r *Request) Do() (err error) { - client := &http.Client{ - Timeout: 5 * time.Second, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - } - res, err := client.Do(r.Req) - if err != nil { - return err - } - defer (func() { err = res.Body.Close() })() - - if res.StatusCode != r.WantedStatusCode { - return fmt.Errorf("incorrect status code. Expected: %d Got: %d", r.WantedStatusCode, res.StatusCode) - } - - if r.WantedBody != "" { - resBytes, err := ioutil.ReadAll(res.Body) - if err != nil { - return err - } - jsonBytes, err := gomatrixserverlib.CanonicalJSON(resBytes) - if err != nil { - return err - } - if string(jsonBytes) != r.WantedBody { - return fmt.Errorf("returned wrong bytes. Expected:\n%s\n\nGot:\n%s", r.WantedBody, string(jsonBytes)) - } - } - - return nil -} - -// DoUntilSuccess blocks and repeats the same request until the response returns the desired status code and body. -// It then closes the given channel and returns. -func (r *Request) DoUntilSuccess(done chan error) { - r.LastErr = &LastRequestErr{} - for { - if err := r.Do(); err != nil { - r.LastErr.Set(err) - time.Sleep(1 * time.Second) // don't tightloop - continue - } - close(done) - return - } -} - -// Run repeatedly issues a request until success, error or a timeout is reached -func (r *Request) Run(label string, timeout time.Duration, serverCmdChan chan error) { - fmt.Printf("==TESTING== %v (timeout: %v)\n", label, timeout) - done := make(chan error, 1) - - // We need to wait for the server to: - // - have connected to the database - // - have created the tables - // - be listening on the given port - go r.DoUntilSuccess(done) - - // wait for one of: - // - the test to pass (done channel is closed) - // - the server to exit with an error (error sent on serverCmdChan) - // - our test timeout to expire - // We don't need to clean up since the main() function handles that in the event we panic - select { - case <-time.After(timeout): - fmt.Printf("==TESTING== %v TIMEOUT\n", label) - if reqErr := r.LastErr.Get(); reqErr != nil { - fmt.Println("Last /sync request error:") - fmt.Println(reqErr) - } - panic(fmt.Sprintf("%v server timed out", label)) - case err := <-serverCmdChan: - if err != nil { - fmt.Println("=============================================================================================") - fmt.Printf("%v server failed to run. If failing with 'pq: password authentication failed for user' try:", label) - fmt.Println(" export PGHOST=/var/run/postgresql") - fmt.Println("=============================================================================================") - panic(err) - } - case <-done: - fmt.Printf("==TESTING== %v PASSED\n", label) - } -} diff --git a/internal/test/kafka.go b/internal/test/kafka.go deleted file mode 100644 index cbf24630..00000000 --- a/internal/test/kafka.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package test - -import ( - "io" - "os/exec" - "path/filepath" - "strings" -) - -// KafkaExecutor executes kafka scripts. -type KafkaExecutor struct { - // The location of Zookeeper. Typically this is `localhost:2181`. - ZookeeperURI string - // The directory where Kafka is installed to. Used to locate kafka scripts. - KafkaDirectory string - // The location of the Kafka logs. Typically this is `localhost:9092`. - KafkaURI string - // Where stdout and stderr should be written to. Typically this is `os.Stderr`. - OutputWriter io.Writer -} - -// CreateTopic creates a new kafka topic. This is created with a single partition. -func (e *KafkaExecutor) CreateTopic(topic string) error { - cmd := exec.Command( - filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"), - "--create", - "--zookeeper", e.ZookeeperURI, - "--replication-factor", "1", - "--partitions", "1", - "--topic", topic, - ) - cmd.Stdout = e.OutputWriter - cmd.Stderr = e.OutputWriter - return cmd.Run() -} - -// WriteToTopic writes data to a kafka topic. -func (e *KafkaExecutor) WriteToTopic(topic string, data []string) error { - cmd := exec.Command( - filepath.Join(e.KafkaDirectory, "bin", "kafka-console-producer.sh"), - "--broker-list", e.KafkaURI, - "--topic", topic, - ) - cmd.Stdout = e.OutputWriter - cmd.Stderr = e.OutputWriter - cmd.Stdin = strings.NewReader(strings.Join(data, "\n")) - return cmd.Run() -} - -// DeleteTopic deletes a given kafka topic if it exists. -func (e *KafkaExecutor) DeleteTopic(topic string) error { - cmd := exec.Command( - filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"), - "--delete", - "--if-exists", - "--zookeeper", e.ZookeeperURI, - "--topic", topic, - ) - cmd.Stderr = e.OutputWriter - cmd.Stdout = e.OutputWriter - return cmd.Run() -} diff --git a/internal/test/server.go b/internal/test/server.go deleted file mode 100644 index ca14ea1b..00000000 --- a/internal/test/server.go +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package test - -import ( - "context" - "fmt" - "net" - "net/http" - "os" - "os/exec" - "path/filepath" - "strings" - "sync" - "testing" - - "github.com/matrix-org/dendrite/setup/config" -) - -// Defaulting allows assignment of string variables with a fallback default value -// Useful for use with os.Getenv() for example -func Defaulting(value, defaultValue string) string { - if value == "" { - value = defaultValue - } - return value -} - -// CreateDatabase creates a new database, dropping it first if it exists -func CreateDatabase(command string, args []string, database string) error { - cmd := exec.Command(command, args...) - cmd.Stdin = strings.NewReader( - fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", database, database), - ) - // Send stdout and stderr to our stderr so that we see error messages from - // the psql process - cmd.Stdout = os.Stderr - cmd.Stderr = os.Stderr - return cmd.Run() -} - -// CreateBackgroundCommand creates an executable command -// The Cmd being executed is returned. A channel is also returned, -// which will have any termination errors sent down it, followed immediately by the channel being closed. -func CreateBackgroundCommand(command string, args []string) (*exec.Cmd, chan error) { - cmd := exec.Command(command, args...) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stderr - - if err := cmd.Start(); err != nil { - panic("failed to start server: " + err.Error()) - } - cmdChan := make(chan error, 1) - go func() { - cmdChan <- cmd.Wait() - close(cmdChan) - }() - return cmd, cmdChan -} - -// InitDatabase creates the database and config file needed for the server to run -func InitDatabase(postgresDatabase, postgresContainerName string, databases []string) { - if len(databases) > 0 { - var dbCmd string - var dbArgs []string - if postgresContainerName == "" { - dbCmd = "psql" - dbArgs = []string{postgresDatabase} - } else { - dbCmd = "docker" - dbArgs = []string{ - "exec", "-i", postgresContainerName, "psql", "-U", "postgres", postgresDatabase, - } - } - for _, database := range databases { - if err := CreateDatabase(dbCmd, dbArgs, database); err != nil { - panic(err) - } - } - } -} - -// StartProxy creates a reverse proxy -func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) { - proxyArgs := []string{ - "--bind-address", bindAddr, - "--sync-api-server-url", "http://" + string(cfg.SyncAPI.InternalAPI.Connect), - "--client-api-server-url", "http://" + string(cfg.ClientAPI.InternalAPI.Connect), - "--media-api-server-url", "http://" + string(cfg.MediaAPI.InternalAPI.Connect), - "--tls-cert", "server.crt", - "--tls-key", "server.key", - } - return CreateBackgroundCommand( - filepath.Join(filepath.Dir(os.Args[0]), "client-api-proxy"), - proxyArgs, - ) -} - -// ListenAndServe will listen on a random high-numbered port and attach the given router. -// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed. -func ListenAndServe(t *testing.T, router http.Handler, useTLS bool) (apiURL string, cancel func()) { - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("failed to listen: %s", err) - } - port := listener.Addr().(*net.TCPAddr).Port - srv := http.Server{} - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - srv.Handler = router - var err error - if useTLS { - certFile := filepath.Join(os.TempDir(), "dendrite.cert") - keyFile := filepath.Join(os.TempDir(), "dendrite.key") - err = NewTLSKey(keyFile, certFile) - if err != nil { - t.Logf("failed to generate tls key/cert: %s", err) - return - } - err = srv.ServeTLS(listener, certFile, keyFile) - } else { - err = srv.Serve(listener) - } - if err != nil && err != http.ErrServerClosed { - t.Logf("Listen failed: %s", err) - } - }() - - secure := "" - if useTLS { - secure = "s" - } - return fmt.Sprintf("http%s://localhost:%d", secure, port), func() { - _ = srv.Shutdown(context.Background()) - wg.Wait() - } -} diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 561c9a16..23f3e1a6 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -84,7 +84,7 @@ type DeviceListUpdater struct { db DeviceListUpdaterDatabase api DeviceListUpdaterAPI producer KeyChangeProducer - fedClient fedsenderapi.FederationClient + fedClient fedsenderapi.KeyserverFederationAPI workerChans []chan gomatrixserverlib.ServerName // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will @@ -127,7 +127,7 @@ type KeyChangeProducer interface { // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. func NewDeviceListUpdater( db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, - fedClient fedsenderapi.FederationClient, numWorkers int, + fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, ) *DeviceListUpdater { return &DeviceListUpdater{ userIDToMutex: make(map[string]*sync.Mutex), diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index be71e575..f8d0d69c 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -37,7 +37,7 @@ import ( type KeyInternalAPI struct { DB storage.Database ThisServer gomatrixserverlib.ServerName - FedClient fedsenderapi.FederationClient + FedClient fedsenderapi.KeyserverFederationAPI UserAPI userapi.KeyserverUserAPI Producer *producers.KeyChange Updater *DeviceListUpdater diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 47d7f57f..3ffd3ba1 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -37,7 +37,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, + base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, ) api.KeyInternalAPI { js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index cbb4cebc..80e7aed6 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -183,6 +183,7 @@ type FederationRoomserverAPI interface { QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error + QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error // Query a given amount (or less) of events prior to a given set of events. diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go index a95c1355..7c65f9ea 100644 --- a/roomserver/internal/input/input_test.go +++ b/roomserver/internal/input/input_test.go @@ -12,7 +12,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" ) @@ -22,7 +22,7 @@ var jc *nats.Conn func TestMain(m *testing.M) { var b *base.BaseDendrite - b, js, jc = test.Base(nil) + b, js, jc = testrig.Base(nil) code := m.Run() b.ShutdownDendrite() b.WaitForComponentsToFinish() diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index ba5bb9f5..03627ea9 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -19,8 +19,8 @@ import ( "encoding/json" "testing" - "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/roomserver/storage/tables/events_table_test.go b/roomserver/storage/tables/events_table_test.go index d5d699c4..6f72a59b 100644 --- a/roomserver/storage/tables/events_table_test.go +++ b/roomserver/storage/tables/events_table_test.go @@ -39,7 +39,7 @@ func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, fun } func Test_EventsTable(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { diff --git a/roomserver/storage/tables/previous_events_table_test.go b/roomserver/storage/tables/previous_events_table_test.go index 96d7bfed..63d54069 100644 --- a/roomserver/storage/tables/previous_events_table_test.go +++ b/roomserver/storage/tables/previous_events_table_test.go @@ -38,7 +38,7 @@ func mustCreatePreviousEventsTable(t *testing.T, dbType test.DBType) (tab tables func TestPreviousEventsTable(t *testing.T) { ctx := context.Background() - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { tab, close := mustCreatePreviousEventsTable(t, dbType) diff --git a/roomserver/storage/tables/published_table_test.go b/roomserver/storage/tables/published_table_test.go index 87662ed4..fff6dc18 100644 --- a/roomserver/storage/tables/published_table_test.go +++ b/roomserver/storage/tables/published_table_test.go @@ -38,7 +38,7 @@ func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Publ func TestPublishedTable(t *testing.T) { ctx := context.Background() - alice := test.NewUser() + alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { tab, close := mustCreatePublishedTable(t, dbType) diff --git a/roomserver/storage/tables/room_aliases_table_test.go b/roomserver/storage/tables/room_aliases_table_test.go index 8fb57d5a..624d92ae 100644 --- a/roomserver/storage/tables/room_aliases_table_test.go +++ b/roomserver/storage/tables/room_aliases_table_test.go @@ -36,7 +36,7 @@ func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.Ro } func TestRoomAliasesTable(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) ctx := context.Background() diff --git a/roomserver/storage/tables/rooms_table_test.go b/roomserver/storage/tables/rooms_table_test.go index 9872fb80..0a02369a 100644 --- a/roomserver/storage/tables/rooms_table_test.go +++ b/roomserver/storage/tables/rooms_table_test.go @@ -38,7 +38,7 @@ func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, c } func TestRoomsTable(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 1150c2f3..563c92e3 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -47,7 +47,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver func TestWriteEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - alice := test.NewUser() + alice := test.NewUser(t) r := test.NewRoom(t, alice) db, close := MustCreateDatabase(t, dbType) defer close() @@ -60,7 +60,7 @@ func TestRecentEventsPDU(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := MustCreateDatabase(t, dbType) defer close() - alice := test.NewUser() + alice := test.NewUser(t) // dummy room to make sure SQL queries are filtering on room ID MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) @@ -163,7 +163,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := MustCreateDatabase(t, dbType) defer close() - alice := test.NewUser() + alice := test.NewUser(t) r := test.NewRoom(t, alice) for i := 0; i < 10; i++ { r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)}) diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go index 8bbf879d..69bbd04c 100644 --- a/syncapi/storage/tables/output_room_events_test.go +++ b/syncapi/storage/tables/output_room_events_test.go @@ -45,7 +45,7 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, func TestOutputRoomEventsTable(t *testing.T) { ctx := context.Background() - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { tab, db, close := newOutputRoomEventsTable(t, dbType) diff --git a/syncapi/storage/tables/topology_test.go b/syncapi/storage/tables/topology_test.go index 2334aae2..f4f75bdf 100644 --- a/syncapi/storage/tables/topology_test.go +++ b/syncapi/storage/tables/topology_test.go @@ -40,7 +40,7 @@ func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.D func TestTopologyTable(t *testing.T) { ctx := context.Background() - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { tab, db, close := newTopologyTable(t, dbType) diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index d3d89839..5ecfd877 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -15,6 +15,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -86,7 +87,7 @@ func TestSyncAPIAccessTokens(t *testing.T) { } func testSyncAccessTokens(t *testing.T, dbType test.DBType) { - user := test.NewUser() + user := test.NewUser(t) room := test.NewRoom(t, user) alice := userapi.Device{ ID: "ALICEID", @@ -96,14 +97,14 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, close := test.CreateBaseDendrite(t, dbType) + base, close := testrig.CreateBaseDendrite(t, dbType) defer close() jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) msgs := toNATSMsgs(t, base, room.Events()) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) - test.MustPublishMsgs(t, jsctx, msgs...) + testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { name string @@ -173,7 +174,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) { } func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { - user := test.NewUser() + user := test.NewUser(t) room := test.NewRoom(t, user) alice := userapi.Device{ ID: "ALICEID", @@ -183,7 +184,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, close := test.CreateBaseDendrite(t, dbType) + base, close := testrig.CreateBaseDendrite(t, dbType) defer close() jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) @@ -198,7 +199,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { sinceTokens := make([]string, len(msgs)) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) for i, msg := range msgs { - test.MustPublishMsgs(t, jsctx, msg) + testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) w := httptest.NewRecorder() base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ @@ -262,7 +263,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli if ev.StateKey() != nil { addsStateIDs = append(addsStateIDs, ev.EventID()) } - result[i] = test.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{ + result[i] = testrig.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{ Type: rsapi.OutputTypeNewRoomEvent, NewRoomEvent: &rsapi.OutputNewRoomEvent{ Event: ev, diff --git a/test/event.go b/test/event.go index 40cb8f0e..73fc656b 100644 --- a/test/event.go +++ b/test/event.go @@ -52,6 +52,24 @@ func WithUnsigned(unsigned interface{}) eventModifier { } } +func WithKeyID(keyID gomatrixserverlib.KeyID) eventModifier { + return func(e *eventMods) { + e.keyID = keyID + } +} + +func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier { + return func(e *eventMods) { + e.privKey = pkey + } +} + +func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier { + return func(e *eventMods) { + e.origin = origin + } +} + // Reverse a list of events func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) diff --git a/test/http.go b/test/http.go index a458a338..37b3648f 100644 --- a/test/http.go +++ b/test/http.go @@ -2,10 +2,15 @@ package test import ( "bytes" + "context" "encoding/json" + "fmt" "io" + "net" "net/http" "net/url" + "path/filepath" + "sync" "testing" ) @@ -43,3 +48,45 @@ func NewRequest(t *testing.T, method, path string, opts ...HTTPRequestOpt) *http } return req } + +// ListenAndServe will listen on a random high-numbered port and attach the given router. +// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed. +func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL string, cancel func()) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to listen: %s", err) + } + port := listener.Addr().(*net.TCPAddr).Port + srv := http.Server{} + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + srv.Handler = router + var err error + if withTLS { + certFile := filepath.Join(t.TempDir(), "dendrite.cert") + keyFile := filepath.Join(t.TempDir(), "dendrite.key") + err = NewTLSKey(keyFile, certFile) + if err != nil { + t.Errorf("failed to make TLS key: %s", err) + return + } + err = srv.ServeTLS(listener, certFile, keyFile) + } else { + err = srv.Serve(listener) + } + if err != nil && err != http.ErrServerClosed { + t.Logf("Listen failed: %s", err) + } + }() + s := "" + if withTLS { + s = "s" + } + return fmt.Sprintf("http%s://localhost:%d", s, port), func() { + _ = srv.Shutdown(context.Background()) + wg.Wait() + } +} diff --git a/internal/test/keyring.go b/test/keyring.go index ed9c3484..ed9c3484 100644 --- a/internal/test/keyring.go +++ b/test/keyring.go diff --git a/internal/test/config.go b/test/keys.go index d8e0c453..75e3800e 100644 --- a/internal/test/config.go +++ b/test/keys.go @@ -25,103 +25,19 @@ import ( "io/ioutil" "math/big" "os" - "path/filepath" "strings" "time" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" - "gopkg.in/yaml.v2" ) const ( - // ConfigFile is the name of the config file for a server. - ConfigFile = "dendrite.yaml" // ServerKeyFile is the name of the file holding the matrix server private key. ServerKeyFile = "server_key.pem" // TLSCertFile is the name of the file holding the TLS certificate used for federation. TLSCertFile = "tls_cert.pem" // TLSKeyFile is the name of the file holding the TLS key used for federation. TLSKeyFile = "tls_key.pem" - // MediaDir is the name of the directory used to store media. - MediaDir = "media" ) -// MakeConfig makes a config suitable for running integration tests. -// Generates new matrix and TLS keys for the server. -func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*config.Dendrite, int, error) { - var cfg config.Dendrite - cfg.Defaults(true) - - port := startPort - assignAddress := func() config.HTTPAddress { - result := config.HTTPAddress(fmt.Sprintf("http://%s:%d", host, port)) - port++ - return result - } - - serverKeyPath := filepath.Join(configDir, ServerKeyFile) - tlsCertPath := filepath.Join(configDir, TLSKeyFile) - tlsKeyPath := filepath.Join(configDir, TLSCertFile) - mediaBasePath := filepath.Join(configDir, MediaDir) - - if err := NewMatrixKey(serverKeyPath); err != nil { - return nil, 0, err - } - - if err := NewTLSKey(tlsKeyPath, tlsCertPath); err != nil { - return nil, 0, err - } - - cfg.Version = config.Version - - cfg.Global.ServerName = gomatrixserverlib.ServerName(assignAddress()) - cfg.Global.PrivateKeyPath = config.Path(serverKeyPath) - - cfg.MediaAPI.BasePath = config.Path(mediaBasePath) - - cfg.Global.JetStream.Addresses = []string{kafkaURI} - - // TODO: Use different databases for the different schemas. - // Using the same database for every schema currently works because - // the table names are globally unique. But we might not want to - // rely on that in the future. - cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(database) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(database) - cfg.KeyServer.Database.ConnectionString = config.DataSource(database) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(database) - cfg.RoomServer.Database.ConnectionString = config.DataSource(database) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(database) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database) - - cfg.AppServiceAPI.InternalAPI.Listen = assignAddress() - cfg.FederationAPI.InternalAPI.Listen = assignAddress() - cfg.KeyServer.InternalAPI.Listen = assignAddress() - cfg.MediaAPI.InternalAPI.Listen = assignAddress() - cfg.RoomServer.InternalAPI.Listen = assignAddress() - cfg.SyncAPI.InternalAPI.Listen = assignAddress() - cfg.UserAPI.InternalAPI.Listen = assignAddress() - - cfg.AppServiceAPI.InternalAPI.Connect = cfg.AppServiceAPI.InternalAPI.Listen - cfg.FederationAPI.InternalAPI.Connect = cfg.FederationAPI.InternalAPI.Listen - cfg.KeyServer.InternalAPI.Connect = cfg.KeyServer.InternalAPI.Listen - cfg.MediaAPI.InternalAPI.Connect = cfg.MediaAPI.InternalAPI.Listen - cfg.RoomServer.InternalAPI.Connect = cfg.RoomServer.InternalAPI.Listen - cfg.SyncAPI.InternalAPI.Connect = cfg.SyncAPI.InternalAPI.Listen - cfg.UserAPI.InternalAPI.Connect = cfg.UserAPI.InternalAPI.Listen - - return &cfg, port, nil -} - -// WriteConfig writes the config file to the directory. -func WriteConfig(cfg *config.Dendrite, configDir string) error { - data, err := yaml.Marshal(cfg) - if err != nil { - return err - } - return ioutil.WriteFile(filepath.Join(configDir, ConfigFile), data, 0666) -} - // NewMatrixKey generates a new ed25519 matrix server key and writes it to a file. func NewMatrixKey(matrixKeyPath string) (err error) { var data [35]byte diff --git a/test/room.go b/test/room.go index 619cb5c9..6ae403b3 100644 --- a/test/room.go +++ b/test/room.go @@ -15,7 +15,6 @@ package test import ( - "crypto/ed25519" "encoding/json" "fmt" "sync/atomic" @@ -35,12 +34,6 @@ var ( PresetTrustedPrivateChat Preset = 3 roomIDCounter = int64(0) - - testKeyID = gomatrixserverlib.KeyID("ed25519:test") - testPrivateKey = ed25519.NewKeyFromSeed([]byte{ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, - }) ) type Room struct { @@ -49,22 +42,25 @@ type Room struct { preset Preset creator *User - authEvents gomatrixserverlib.AuthEvents - events []*gomatrixserverlib.HeaderedEvent + authEvents gomatrixserverlib.AuthEvents + currentState map[string]*gomatrixserverlib.HeaderedEvent + events []*gomatrixserverlib.HeaderedEvent } // Create a new test room. Automatically creates the initial create events. func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { t.Helper() counter := atomic.AddInt64(&roomIDCounter, 1) - - // set defaults then let roomModifiers override + if creator.srvName == "" { + t.Fatalf("NewRoom: creator doesn't belong to a server: %+v", *creator) + } r := &Room{ - ID: fmt.Sprintf("!%d:localhost", counter), - creator: creator, - authEvents: gomatrixserverlib.NewAuthEvents(nil), - preset: PresetPublicChat, - Version: gomatrixserverlib.RoomVersionV9, + ID: fmt.Sprintf("!%d:%s", counter, creator.srvName), + creator: creator, + authEvents: gomatrixserverlib.NewAuthEvents(nil), + preset: PresetPublicChat, + Version: gomatrixserverlib.RoomVersionV9, + currentState: make(map[string]*gomatrixserverlib.HeaderedEvent), } for _, m := range modifiers { m(t, r) @@ -73,6 +69,24 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { return r } +func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []gomatrixserverlib.EventReference { + t.Helper() + a, err := needed.AuthEventReferences(&r.authEvents) + if err != nil { + t.Fatalf("MustGetAuthEvents: %v", err) + } + return a +} + +func (r *Room) ForwardExtremities() []string { + if len(r.events) == 0 { + return nil + } + return []string{ + r.events[len(r.events)-1].EventID(), + } +} + func (r *Room) insertCreateEvents(t *testing.T) { t.Helper() var joinRule gomatrixserverlib.JoinRuleContent @@ -88,6 +102,7 @@ func (r *Room) insertCreateEvents(t *testing.T) { joinRule.JoinRule = "public" hisVis.HistoryVisibility = "shared" } + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ "creator": r.creator.ID, "room_version": r.Version, @@ -112,16 +127,16 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten } if mod.privKey == nil { - mod.privKey = testPrivateKey + mod.privKey = creator.privKey } if mod.keyID == "" { - mod.keyID = testKeyID + mod.keyID = creator.keyID } if mod.originServerTS.IsZero() { mod.originServerTS = time.Now() } if mod.origin == "" { - mod.origin = gomatrixserverlib.ServerName("localhost") + mod.origin = creator.srvName } var unsigned gomatrixserverlib.RawJSON @@ -174,13 +189,14 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten // Add a new event to this room DAG. Not thread-safe. func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) { t.Helper() - // Add the event to the list of auth events + // Add the event to the list of auth/state events r.events = append(r.events, he) if he.StateKey() != nil { err := r.authEvents.AddEvent(he.Unwrap()) if err != nil { t.Fatalf("InsertEvent: failed to add event to auth events: %s", err) } + r.currentState[he.Type()+" "+*he.StateKey()] = he } } @@ -188,6 +204,16 @@ func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent { return r.events } +func (r *Room) CurrentState() []*gomatrixserverlib.HeaderedEvent { + events := make([]*gomatrixserverlib.HeaderedEvent, len(r.currentState)) + i := 0 + for _, e := range r.currentState { + events[i] = e + i++ + } + return events +} + func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent { t.Helper() he := r.CreateEvent(t, creator, eventType, content, mods...) diff --git a/internal/test/slice.go b/test/slice.go index 00c740db..00c740db 100644 --- a/internal/test/slice.go +++ b/test/slice.go diff --git a/test/base.go b/test/testrig/base.go index 664442c0..facb49f3 100644 --- a/test/base.go +++ b/test/testrig/base.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package test +package testrig import ( "errors" @@ -24,22 +24,23 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" "github.com/nats-io/nats.go" ) -func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()) { +func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) { var cfg config.Dendrite cfg.Defaults(false) cfg.Global.JetStream.InMemory = true switch dbType { - case DBTypePostgres: + case test.DBTypePostgres: cfg.Global.Defaults(true) // autogen a signing key cfg.MediaAPI.Defaults(true) // autogen a media path // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) - connStr, close := PrepareDBConnectionString(t, dbType) + connStr, close := test.PrepareDBConnectionString(t, dbType) cfg.Global.DatabaseOptions = config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), MaxOpenConnections: 10, @@ -47,7 +48,7 @@ func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func() ConnMaxLifetimeSeconds: 60, } return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close - case DBTypeSQLite: + case test.DBTypeSQLite: cfg.Defaults(true) // sets a sqlite db per component // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( diff --git a/test/jetstream.go b/test/testrig/jetstream.go index 488c22be..74cf9506 100644 --- a/test/jetstream.go +++ b/test/testrig/jetstream.go @@ -1,4 +1,4 @@ -package test +package testrig import ( "encoding/json" diff --git a/test/user.go b/test/user.go index 41a66e1c..0020098a 100644 --- a/test/user.go +++ b/test/user.go @@ -15,22 +15,64 @@ package test import ( + "crypto/ed25519" "fmt" "sync/atomic" + "testing" + + "github.com/matrix-org/gomatrixserverlib" ) var ( userIDCounter = int64(0) + + serverName = gomatrixserverlib.ServerName("test") + keyID = gomatrixserverlib.KeyID("ed25519:test") + privateKey = ed25519.NewKeyFromSeed([]byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + }) + + // private keys that tests can use + PrivateKeyA = ed25519.NewKeyFromSeed([]byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 77, + }) + PrivateKeyB = ed25519.NewKeyFromSeed([]byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 66, + }) ) type User struct { ID string + // key ID and private key of the server who has this user, if known. + keyID gomatrixserverlib.KeyID + privKey ed25519.PrivateKey + srvName gomatrixserverlib.ServerName +} + +type UserOpt func(*User) + +func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt { + return func(u *User) { + u.keyID = keyID + u.privKey = privKey + u.srvName = srvName + } } -func NewUser() *User { +func NewUser(t *testing.T, opts ...UserOpt) *User { counter := atomic.AddInt64(&userIDCounter, 1) - u := &User{ - ID: fmt.Sprintf("@%d:localhost", counter), + var u User + for _, opt := range opts { + opt(&u) + } + if u.keyID == "" || u.srvName == "" || u.privKey == nil { + t.Logf("NewUser: missing signing server credentials; using default.") + WithSigningServer(serverName, keyID, privateKey)(&u) } - return u + u.ID = fmt.Sprintf("@%d:%s", counter, u.srvName) + t.Logf("NewUser: created user %s", u.ID) + return &u } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 5683fe06..5bee880d 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -43,7 +43,7 @@ func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - alice := test.NewUser() + alice := test.NewUser(t) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -74,7 +74,7 @@ func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -128,7 +128,7 @@ func Test_Accounts(t *testing.T) { } func Test_Devices(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) deviceID := util.RandomString(8) @@ -212,7 +212,7 @@ func Test_Devices(t *testing.T) { } func Test_KeyBackup(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -291,7 +291,7 @@ func Test_KeyBackup(t *testing.T) { } func Test_LoginToken(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() @@ -321,7 +321,7 @@ func Test_LoginToken(t *testing.T) { } func Test_OpenID(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -341,7 +341,7 @@ func Test_OpenID(t *testing.T) { } func Test_Profile(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -379,7 +379,7 @@ func Test_Profile(t *testing.T) { } func Test_Pusher(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -430,7 +430,7 @@ func Test_Pusher(t *testing.T) { } func Test_ThreePID(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -467,7 +467,7 @@ func Test_ThreePID(t *testing.T) { } func Test_Notification(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index e614765a..40e37c5d 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -24,7 +24,6 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" - internalTest "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/inthttp" @@ -135,7 +134,7 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() userapi.AddInternalRoutes(router, userAPI) - apiURL, cancel := internalTest.ListenAndServe(t, router, false) + apiURL, cancel := test.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) if err != nil { |