diff options
author | Devon Hudson <devonhudson@librem.one> | 2023-01-31 15:51:08 -0700 |
---|---|---|
committer | Devon Hudson <devonhudson@librem.one> | 2023-02-01 13:41:36 -0700 |
commit | be43b9c0eae5fd45e38f954c23cfadbdfa77cb32 (patch) | |
tree | fd7b4a61f21f63512d40ecea818d53e7d73995e1 | |
parent | 4738fe656fc0975b5faa4fcdf6c11a776b9563a0 (diff) |
Refactor common relay sync struct to remove duplication
-rw-r--r-- | build/gobind-pinecone/monolith.go | 218 | ||||
-rw-r--r-- | build/gobind-pinecone/monolith_test.go | 69 | ||||
-rw-r--r-- | cmd/dendrite-demo-pinecone/main.go | 112 | ||||
-rw-r--r-- | cmd/dendrite-demo-pinecone/relay/retriever.go | 237 | ||||
-rw-r--r-- | cmd/dendrite-demo-pinecone/relay/retriever_test.go | 99 |
5 files changed, 366 insertions, 369 deletions
diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index a219621b..f44eed89 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -36,6 +36,7 @@ import ( "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" @@ -70,11 +71,10 @@ import ( ) const ( - PeerTypeRemote = pineconeRouter.PeerTypeRemote - PeerTypeMulticast = pineconeRouter.PeerTypeMulticast - PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth - PeerTypeBonjour = pineconeRouter.PeerTypeBonjour - relayServerRetryInterval = time.Second * 30 + PeerTypeRemote = pineconeRouter.PeerTypeRemote + PeerTypeMulticast = pineconeRouter.PeerTypeMulticast + PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth + PeerTypeBonjour = pineconeRouter.PeerTypeBonjour ) type DendriteMonolith struct { @@ -91,7 +91,7 @@ type DendriteMonolith struct { userAPI userapiAPI.UserInternalAPI federationAPI api.FederationInternalAPI relayAPI relayServerAPI.RelayInternalAPI - relayRetriever RelayServerRetriever + relayRetriever relay.RelayServerRetriever } func (m *DendriteMonolith) PublicKey() string { @@ -189,57 +189,6 @@ func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) return nodeKey, nil } -func updateNodeRelayServers( - node gomatrixserverlib.ServerName, - relays []gomatrixserverlib.ServerName, - ctx context.Context, - fedAPI api.FederationInternalAPI, -) { - // Get the current relay list - request := api.P2PQueryRelayServersRequest{Server: node} - response := api.P2PQueryRelayServersResponse{} - err := fedAPI.P2PQueryRelayServers(ctx, &request, &response) - if err != nil { - logrus.Warnf("Failed obtaining list of relay servers for %s: %s", node, err.Error()) - } - - // Remove old, non-matching relays - var serversToRemove []gomatrixserverlib.ServerName - for _, existingServer := range response.RelayServers { - shouldRemove := true - for _, newServer := range relays { - if newServer == existingServer { - shouldRemove = false - break - } - } - - if shouldRemove { - serversToRemove = append(serversToRemove, existingServer) - } - } - removeRequest := api.P2PRemoveRelayServersRequest{ - Server: node, - RelayServers: serversToRemove, - } - removeResponse := api.P2PRemoveRelayServersResponse{} - err = fedAPI.P2PRemoveRelayServers(ctx, &removeRequest, &removeResponse) - if err != nil { - logrus.Warnf("Failed removing old relay servers for %s: %s", node, err.Error()) - } - - // Add new relays - addRequest := api.P2PAddRelayServersRequest{ - Server: node, - RelayServers: relays, - } - addResponse := api.P2PAddRelayServersResponse{} - err = fedAPI.P2PAddRelayServers(ctx, &addRequest, &addResponse) - if err != nil { - logrus.Warnf("Failed adding relay servers for %s: %s", node, err.Error()) - } -} - func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { relays := []gomatrixserverlib.ServerName{} for _, uri := range strings.Split(uris, ",") { @@ -266,7 +215,7 @@ func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { logrus.Infof("Setting own relay servers to: %v", relays) m.relayRetriever.SetRelayServers(relays) } else { - updateNodeRelayServers( + relay.UpdateNodeRelayServers( gomatrixserverlib.ServerName(nodeKey), relays, m.baseDendrite.Context(), @@ -610,27 +559,23 @@ func (m *DendriteMonolith) Start() { }() stopRelayServerSync := make(chan bool) - eLog := logrus.WithField("pinecone", "events") - m.relayRetriever = RelayServerRetriever{ - Context: context.Background(), - ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), - FederationAPI: m.federationAPI, - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - RelayAPI: monolith.RelayAPI, - running: *atomic.NewBool(false), - quit: stopRelayServerSync, - } + m.relayRetriever = relay.NewRelayServerRetriever( + context.Background(), + gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), + m.federationAPI, + monolith.RelayAPI, + stopRelayServerSync, + ) m.relayRetriever.InitializeRelayServers(eLog) go func(ch <-chan pineconeEvents.Event) { - for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: m.relayRetriever.StartSync() case pineconeEvents.PeerRemoved: - if m.relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { + if m.relayRetriever.IsRunning() && m.PineconeRouter.TotalPeerCount() == 0 { stopRelayServerSync <- true } case pineconeEvents.BroadcastReceived: @@ -658,139 +603,6 @@ func (m *DendriteMonolith) Stop() { _ = m.PineconeRouter.Close() } -type RelayServerRetriever struct { - Context context.Context - ServerName gomatrixserverlib.ServerName - FederationAPI api.FederationInternalAPI - RelayAPI relayServerAPI.RelayInternalAPI - relayServersQueried map[gomatrixserverlib.ServerName]bool - queriedServersMutex sync.Mutex - running atomic.Bool - quit <-chan bool -} - -func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { - request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(r.ServerName)} - response := api.P2PQueryRelayServersResponse{} - err := r.FederationAPI.P2PQueryRelayServers(r.Context, &request, &response) - if err != nil { - eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) - } - - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - for _, server := range response.RelayServers { - r.relayServersQueried[server] = false - } - - eLog.Infof("Registered relay servers: %v", response.RelayServers) -} - -func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { - updateNodeRelayServers(r.ServerName, servers, r.Context, r.FederationAPI) - - // Replace list of servers to sync with and mark them all as unsynced. - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) - for _, server := range servers { - r.relayServersQueried[server] = false - } - - r.StartSync() -} - -func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName { - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - relayServers := []gomatrixserverlib.ServerName{} - for server := range r.relayServersQueried { - relayServers = append(relayServers, server) - } - - return relayServers -} - -func (r *RelayServerRetriever) StartSync() { - if !r.running.Load() { - logrus.Info("Starting relay server sync") - go r.SyncRelayServers(r.quit) - } -} - -func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { - defer r.running.Store(false) - - t := time.NewTimer(relayServerRetryInterval) - for { - relayServersToQuery := []gomatrixserverlib.ServerName{} - func() { - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - for server, complete := range r.relayServersQueried { - if !complete { - relayServersToQuery = append(relayServersToQuery, server) - } - } - }() - if len(relayServersToQuery) == 0 { - // All relay servers have been synced. - logrus.Info("Finished syncing with all known relays") - return - } - r.queryRelayServers(relayServersToQuery) - t.Reset(relayServerRetryInterval) - - select { - case <-stop: - if !t.Stop() { - <-t.C - } - return - case <-t.C: - } - } -} - -func (r *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - - result := map[gomatrixserverlib.ServerName]bool{} - for server, queried := range r.relayServersQueried { - result[server] = queried - } - return result -} - -func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { - logrus.Info("Querying relay servers for any available transactions") - for _, server := range relayServers { - userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.ServerName), false) - if err != nil { - return - } - - logrus.Infof("Syncing with relay: %s", string(server)) - err = r.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) - if err == nil { - func() { - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - r.relayServersQueried[server] = true - }() - // TODO : What happens if your relay receives new messages after this point? - // Should you continue to check with them, or should they try and contact you? - // They could send a "new_async_events" message your way maybe? - // Then you could mark them as needing to be queried again. - // What if you miss this message? - // Maybe you should try querying them again after a certain period of time as a backup? - } else { - logrus.Errorf("Failed querying relay server: %s", err.Error()) - } - } -} - const MaxFrameSize = types.MaxFrameSize type Conduit struct { diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index 3c8873e0..f6bf2ef0 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -15,19 +15,13 @@ package gobind import ( - "context" "fmt" "net" "strings" "testing" - "time" - "github.com/matrix-org/dendrite/federationapi/api" - relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "gotest.tools/v3/poll" ) var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} @@ -128,69 +122,6 @@ func TestConduitReadCopyFails(t *testing.T) { assert.Error(t, err) } -var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"} - -type FakeFedAPI struct { - api.FederationInternalAPI -} - -func (f *FakeFedAPI) P2PQueryRelayServers(ctx context.Context, req *api.P2PQueryRelayServersRequest, res *api.P2PQueryRelayServersResponse) error { - res.RelayServers = testRelayServers - return nil -} - -type FakeRelayAPI struct { - relayServerAPI.RelayInternalAPI -} - -func (r *FakeRelayAPI) PerformRelayServerSync(ctx context.Context, userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error { - return nil -} - -func TestRelayRetrieverInitialization(t *testing.T) { - retriever := RelayServerRetriever{ - Context: context.Background(), - ServerName: "server", - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - FederationAPI: &FakeFedAPI{}, - RelayAPI: &FakeRelayAPI{}, - } - - retriever.InitializeRelayServers(logrus.WithField("test", "relay")) - relayServers := retriever.GetQueriedServerStatus() - assert.Equal(t, 2, len(relayServers)) -} - -func TestRelayRetrieverSync(t *testing.T) { - retriever := RelayServerRetriever{ - Context: context.Background(), - ServerName: "server", - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - FederationAPI: &FakeFedAPI{}, - RelayAPI: &FakeRelayAPI{}, - } - - retriever.InitializeRelayServers(logrus.WithField("test", "relay")) - relayServers := retriever.GetQueriedServerStatus() - assert.Equal(t, 2, len(relayServers)) - - stopRelayServerSync := make(chan bool) - go retriever.SyncRelayServers(stopRelayServerSync) - - check := func(log poll.LogT) poll.Result { - relayServers := retriever.GetQueriedServerStatus() - for _, queried := range relayServers { - if !queried { - return poll.Continue("waiting for all servers to be queried") - } - } - - stopRelayServerSync <- true - return poll.Success() - } - poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) -} - func TestMonolithStarts(t *testing.T) { monolith := DendriteMonolith{} monolith.Start() diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 3f63ab65..abffbb67 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -33,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" @@ -43,7 +44,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/relayapi" - relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -52,7 +52,6 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" - "go.uber.org/atomic" pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" @@ -72,8 +71,6 @@ var ( instanceRelayingEnabled = flag.Bool("relay", false, "whether to enable store & forward relaying for other nodes") ) -const relayServerRetryInterval = time.Second * 30 - // nolint:gocyclo func main() { flag.Parse() @@ -328,28 +325,24 @@ func main() { logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) }() - go func(ch <-chan pineconeEvents.Event) { - eLog := logrus.WithField("pinecone", "events") - relayServerSyncRunning := atomic.NewBool(false) - stopRelayServerSync := make(chan bool) - - m := RelayServerRetriever{ - Context: context.Background(), - ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()), - FederationAPI: fsAPI, - RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - RelayAPI: monolith.RelayAPI, - } - m.InitializeRelayServers(eLog) + stopRelayServerSync := make(chan bool) + eLog := logrus.WithField("pinecone", "events") + relayRetriever := relay.NewRelayServerRetriever( + context.Background(), + gomatrixserverlib.ServerName(pRouter.PublicKey().String()), + monolith.FederationAPI, + monolith.RelayAPI, + stopRelayServerSync, + ) + relayRetriever.InitializeRelayServers(eLog) + go func(ch <-chan pineconeEvents.Event) { for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: - if !relayServerSyncRunning.Load() { - go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning) - } + relayRetriever.StartSync() case pineconeEvents.PeerRemoved: - if relayServerSyncRunning.Load() && pRouter.TotalPeerCount() == 0 { + if relayRetriever.IsRunning() && pRouter.TotalPeerCount() == 0 { stopRelayServerSync <- true } case pineconeEvents.BroadcastReceived: @@ -359,7 +352,7 @@ func main() { ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, } res := &api.PerformWakeupServersResponse{} - if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + if err := monolith.FederationAPI.PerformWakeupServers(base.Context(), req, res); err != nil { eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } default: @@ -369,78 +362,3 @@ func main() { base.WaitForShutdown() } - -type RelayServerRetriever struct { - Context context.Context - ServerName gomatrixserverlib.ServerName - FederationAPI api.FederationInternalAPI - RelayServersQueried map[gomatrixserverlib.ServerName]bool - RelayAPI relayServerAPI.RelayInternalAPI -} - -func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { - request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} - response := api.P2PQueryRelayServersResponse{} - err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) - if err != nil { - eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) - } - for _, server := range response.RelayServers { - m.RelayServersQueried[server] = false - } - - eLog.Infof("Registered relay servers: %v", response.RelayServers) -} - -func (m *RelayServerRetriever) syncRelayServers(stop <-chan bool, running atomic.Bool) { - defer running.Store(false) - - t := time.NewTimer(relayServerRetryInterval) - for { - relayServersToQuery := []gomatrixserverlib.ServerName{} - for server, complete := range m.RelayServersQueried { - if !complete { - relayServersToQuery = append(relayServersToQuery, server) - } - } - if len(relayServersToQuery) == 0 { - // All relay servers have been synced. - return - } - m.queryRelayServers(relayServersToQuery) - t.Reset(relayServerRetryInterval) - - select { - case <-stop: - // We have been asked to stop syncing, drain the timer and return. - if !t.Stop() { - <-t.C - } - return - case <-t.C: - // The timer has expired. Continue to the next loop iteration. - } - } -} - -func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { - logrus.Info("querying relay servers for any available transactions") - for _, server := range relayServers { - userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) - if err != nil { - return - } - err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) - if err == nil { - m.RelayServersQueried[server] = true - // TODO : What happens if your relay receives new messages after this point? - // Should you continue to check with them, or should they try and contact you? - // They could send a "new_async_events" message your way maybe? - // Then you could mark them as needing to be queried again. - // What if you miss this message? - // Maybe you should try querying them again after a certain period of time as a backup? - } else { - logrus.Errorf("Failed querying relay server: %s", err.Error()) - } - } -} diff --git a/cmd/dendrite-demo-pinecone/relay/retriever.go b/cmd/dendrite-demo-pinecone/relay/retriever.go new file mode 100644 index 00000000..1b5c617e --- /dev/null +++ b/cmd/dendrite-demo-pinecone/relay/retriever.go @@ -0,0 +1,237 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relay + +import ( + "context" + "sync" + "time" + + federationAPI "github.com/matrix-org/dendrite/federationapi/api" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "go.uber.org/atomic" +) + +const ( + relayServerRetryInterval = time.Second * 30 +) + +type RelayServerRetriever struct { + ctx context.Context + serverName gomatrixserverlib.ServerName + federationAPI federationAPI.FederationInternalAPI + relayAPI relayServerAPI.RelayInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool + queriedServersMutex sync.Mutex + running atomic.Bool + quit <-chan bool +} + +func NewRelayServerRetriever( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + federationAPI federationAPI.FederationInternalAPI, + relayAPI relayServerAPI.RelayInternalAPI, + quit <-chan bool, +) RelayServerRetriever { + return RelayServerRetriever{ + ctx: ctx, + serverName: serverName, + federationAPI: federationAPI, + relayAPI: relayAPI, + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + running: *atomic.NewBool(false), + quit: quit, + } +} + +func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := federationAPI.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(r.serverName)} + response := federationAPI.P2PQueryRelayServersResponse{} + err := r.federationAPI.P2PQueryRelayServers(r.ctx, &request, &response) + if err != nil { + eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + } + + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + for _, server := range response.RelayServers { + r.relayServersQueried[server] = false + } + + eLog.Infof("Registered relay servers: %v", response.RelayServers) +} + +func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { + UpdateNodeRelayServers(r.serverName, servers, r.ctx, r.federationAPI) + + // Replace list of servers to sync with and mark them all as unsynced. + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) + for _, server := range servers { + r.relayServersQueried[server] = false + } + + r.StartSync() +} + +func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + relayServers := []gomatrixserverlib.ServerName{} + for server := range r.relayServersQueried { + relayServers = append(relayServers, server) + } + + return relayServers +} + +func (r *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + + result := map[gomatrixserverlib.ServerName]bool{} + for server, queried := range r.relayServersQueried { + result[server] = queried + } + return result +} + +func (r *RelayServerRetriever) StartSync() { + if !r.running.Load() { + logrus.Info("Starting relay server sync") + go r.SyncRelayServers(r.quit) + } +} + +func (r *RelayServerRetriever) IsRunning() bool { + return r.running.Load() +} + +func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { + defer r.running.Store(false) + + t := time.NewTimer(relayServerRetryInterval) + for { + relayServersToQuery := []gomatrixserverlib.ServerName{} + func() { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + for server, complete := range r.relayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } + } + }() + if len(relayServersToQuery) == 0 { + // All relay servers have been synced. + logrus.Info("Finished syncing with all known relays") + return + } + r.queryRelayServers(relayServersToQuery) + t.Reset(relayServerRetryInterval) + + select { + case <-stop: + if !t.Stop() { + <-t.C + } + return + case <-t.C: + } + } +} + +func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("Querying relay servers for any available transactions") + for _, server := range relayServers { + userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.serverName), false) + if err != nil { + return + } + + logrus.Infof("Syncing with relay: %s", string(server)) + err = r.relayAPI.PerformRelayServerSync(context.Background(), *userID, server) + if err == nil { + func() { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried[server] = true + }() + // TODO : What happens if your relay receives new messages after this point? + // Should you continue to check with them, or should they try and contact you? + // They could send a "new_async_events" message your way maybe? + // Then you could mark them as needing to be queried again. + // What if you miss this message? + // Maybe you should try querying them again after a certain period of time as a backup? + } else { + logrus.Errorf("Failed querying relay server: %s", err.Error()) + } + } +} + +func UpdateNodeRelayServers( + node gomatrixserverlib.ServerName, + relays []gomatrixserverlib.ServerName, + ctx context.Context, + fedAPI federationAPI.FederationInternalAPI, +) { + // Get the current relay list + request := federationAPI.P2PQueryRelayServersRequest{Server: node} + response := federationAPI.P2PQueryRelayServersResponse{} + err := fedAPI.P2PQueryRelayServers(ctx, &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of relay servers for %s: %s", node, err.Error()) + } + + // Remove old, non-matching relays + var serversToRemove []gomatrixserverlib.ServerName + for _, existingServer := range response.RelayServers { + shouldRemove := true + for _, newServer := range relays { + if newServer == existingServer { + shouldRemove = false + break + } + } + + if shouldRemove { + serversToRemove = append(serversToRemove, existingServer) + } + } + removeRequest := federationAPI.P2PRemoveRelayServersRequest{ + Server: node, + RelayServers: serversToRemove, + } + removeResponse := federationAPI.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(ctx, &removeRequest, &removeResponse) + if err != nil { + logrus.Warnf("Failed removing old relay servers for %s: %s", node, err.Error()) + } + + // Add new relays + addRequest := federationAPI.P2PAddRelayServersRequest{ + Server: node, + RelayServers: relays, + } + addResponse := federationAPI.P2PAddRelayServersResponse{} + err = fedAPI.P2PAddRelayServers(ctx, &addRequest, &addResponse) + if err != nil { + logrus.Warnf("Failed adding relay servers for %s: %s", node, err.Error()) + } +} diff --git a/cmd/dendrite-demo-pinecone/relay/retriever_test.go b/cmd/dendrite-demo-pinecone/relay/retriever_test.go new file mode 100644 index 00000000..8f86a377 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/relay/retriever_test.go @@ -0,0 +1,99 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relay + +import ( + "context" + "testing" + "time" + + federationAPI "github.com/matrix-org/dendrite/federationapi/api" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "gotest.tools/v3/poll" +) + +var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"} + +type FakeFedAPI struct { + federationAPI.FederationInternalAPI +} + +func (f *FakeFedAPI) P2PQueryRelayServers( + ctx context.Context, + req *federationAPI.P2PQueryRelayServersRequest, + res *federationAPI.P2PQueryRelayServersResponse, +) error { + res.RelayServers = testRelayServers + return nil +} + +type FakeRelayAPI struct { + relayServerAPI.RelayInternalAPI +} + +func (r *FakeRelayAPI) PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, +) error { + return nil +} + +func TestRelayRetrieverInitialization(t *testing.T) { + retriever := NewRelayServerRetriever( + context.Background(), + "server", + &FakeFedAPI{}, + &FakeRelayAPI{}, + make(<-chan bool), + ) + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) +} + +func TestRelayRetrieverSync(t *testing.T) { + retriever := NewRelayServerRetriever( + context.Background(), + "server", + &FakeFedAPI{}, + &FakeRelayAPI{}, + make(<-chan bool), + ) + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) + + stopRelayServerSync := make(chan bool) + go retriever.SyncRelayServers(stopRelayServerSync) + + check := func(log poll.LogT) poll.Result { + relayServers := retriever.GetQueriedServerStatus() + for _, queried := range relayServers { + if !queried { + return poll.Continue("waiting for all servers to be queried") + } + } + + stopRelayServerSync <- true + return poll.Success() + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} |