diff options
author | devonh <devon.dmytro@gmail.com> | 2023-01-28 23:27:53 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-28 23:27:53 +0000 |
commit | 63df85db6d5bc528a784dc52e550fc64385c5f67 (patch) | |
tree | 80da0f2cbcf9f4473974e600f90f20aed9803707 /build | |
parent | 2debabf0f09bb6e55063bbaa00dfb77090789abc (diff) |
Relay integration to pinecone demos (#2955)
This extends the dendrite monolith for pinecone to integrate the s&f
features into the mobile apps.
Also makes a few tweaks to federation queueing/statistics to make some
edge cases more robust.
Diffstat (limited to 'build')
-rw-r--r-- | build/gobind-pinecone/monolith.go | 287 | ||||
-rw-r--r-- | build/gobind-pinecone/monolith_test.go | 124 |
2 files changed, 360 insertions, 51 deletions
diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index ff61ea6c..5e8e5875 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -78,19 +78,19 @@ const ( ) type DendriteMonolith struct { - logger logrus.Logger - baseDendrite *base.BaseDendrite - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - PineconeManager *pineconeConnections.ConnectionManager - StorageDirectory string - CacheDirectory string - listener net.Listener - httpServer *http.Server - userAPI userapiAPI.UserInternalAPI - federationAPI api.FederationInternalAPI - relayServersQueried map[gomatrixserverlib.ServerName]bool + logger logrus.Logger + baseDendrite *base.BaseDendrite + PineconeRouter *pineconeRouter.Router + PineconeMulticast *pineconeMulticast.Multicast + PineconeQUIC *pineconeSessions.Sessions + PineconeManager *pineconeConnections.ConnectionManager + StorageDirectory string + CacheDirectory string + listener net.Listener + httpServer *http.Server + userAPI userapiAPI.UserInternalAPI + federationAPI api.FederationInternalAPI + relayRetriever RelayServerRetriever } func (m *DendriteMonolith) PublicKey() string { @@ -167,6 +167,152 @@ func (m *DendriteMonolith) SetStaticPeer(uri string) { } } +func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) { + var nodeKey gomatrixserverlib.ServerName + if userID, err := gomatrixserverlib.NewUserID(nodeID, false); err == nil { + hexKey, decodeErr := hex.DecodeString(string(userID.Domain())) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("UserID domain is not a valid ed25519 public key: %v", userID.Domain()) + } else { + nodeKey = userID.Domain() + } + } else { + hexKey, decodeErr := hex.DecodeString(nodeID) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("Relay server uri is not a valid ed25519 public key: %v", nodeID) + } else { + nodeKey = gomatrixserverlib.ServerName(nodeID) + } + } + + return nodeKey, nil +} + +func updateNodeRelayServers( + node gomatrixserverlib.ServerName, + relays []gomatrixserverlib.ServerName, + ctx context.Context, + fedAPI api.FederationInternalAPI, +) { + // Get the current relay list + request := api.P2PQueryRelayServersRequest{Server: node} + response := api.P2PQueryRelayServersResponse{} + err := fedAPI.P2PQueryRelayServers(ctx, &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of relay servers for %s: %s", node, err.Error()) + } + + // Remove old, non-matching relays + var serversToRemove []gomatrixserverlib.ServerName + for _, existingServer := range response.RelayServers { + shouldRemove := true + for _, newServer := range relays { + if newServer == existingServer { + shouldRemove = false + break + } + } + + if shouldRemove { + serversToRemove = append(serversToRemove, existingServer) + } + } + removeRequest := api.P2PRemoveRelayServersRequest{ + Server: node, + RelayServers: serversToRemove, + } + removeResponse := api.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(ctx, &removeRequest, &removeResponse) + if err != nil { + logrus.Warnf("Failed removing old relay servers for %s: %s", node, err.Error()) + } + + // Add new relays + addRequest := api.P2PAddRelayServersRequest{ + Server: node, + RelayServers: relays, + } + addResponse := api.P2PAddRelayServersResponse{} + err = fedAPI.P2PAddRelayServers(ctx, &addRequest, &addResponse) + if err != nil { + logrus.Warnf("Failed adding relay servers for %s: %s", node, err.Error()) + } +} + +func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { + relays := []gomatrixserverlib.ServerName{} + for _, uri := range strings.Split(uris, ",") { + uri = strings.TrimSpace(uri) + if len(uri) == 0 { + continue + } + + nodeKey, err := getServerKeyFromString(uri) + if err != nil { + logrus.Errorf(err.Error()) + continue + } + relays = append(relays, nodeKey) + } + + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return + } + + if string(nodeKey) == m.PublicKey() { + logrus.Infof("Setting own relay servers to: %v", relays) + m.relayRetriever.SetRelayServers(relays) + } else { + updateNodeRelayServers( + gomatrixserverlib.ServerName(nodeKey), + relays, + m.baseDendrite.Context(), + m.federationAPI, + ) + } +} + +func (m *DendriteMonolith) GetRelayServers(nodeID string) string { + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return "" + } + + relaysString := "" + if string(nodeKey) == m.PublicKey() { + relays := m.relayRetriever.GetRelayServers() + + for i, relay := range relays { + if i != 0 { + // Append a comma to the previous entry if there is one. + relaysString += "," + } + relaysString += string(relay) + } + } else { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(nodeKey)} + response := api.P2PQueryRelayServersResponse{} + err := m.federationAPI.P2PQueryRelayServers(m.baseDendrite.Context(), &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + return "" + } + + for i, relay := range response.RelayServers { + if i != 0 { + // Append a comma to the previous entry if there is one. + relaysString += "," + } + relaysString += string(relay) + } + } + + return relaysString +} + func (m *DendriteMonolith) DisconnectType(peertype int) { for _, p := range m.PineconeRouter.Peers() { if int(peertype) == p.PeerType { @@ -454,28 +600,28 @@ func (m *DendriteMonolith) Start() { } }() + stopRelayServerSync := make(chan bool) + + eLog := logrus.WithField("pinecone", "events") + m.relayRetriever = RelayServerRetriever{ + Context: context.Background(), + ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), + FederationAPI: m.federationAPI, + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + RelayAPI: monolith.RelayAPI, + running: *atomic.NewBool(false), + quit: stopRelayServerSync, + } + m.relayRetriever.InitializeRelayServers(eLog) + go func(ch <-chan pineconeEvents.Event) { - eLog := logrus.WithField("pinecone", "events") - stopRelayServerSync := make(chan bool) - - relayRetriever := RelayServerRetriever{ - Context: context.Background(), - ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), - FederationAPI: m.federationAPI, - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - RelayAPI: monolith.RelayAPI, - running: *atomic.NewBool(false), - } - relayRetriever.InitializeRelayServers(eLog) for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: - if !relayRetriever.running.Load() { - go relayRetriever.SyncRelayServers(stopRelayServerSync) - } + m.relayRetriever.StartSync() case pineconeEvents.PeerRemoved: - if relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { + if m.relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { stopRelayServerSync <- true } case pineconeEvents.BroadcastReceived: @@ -495,7 +641,7 @@ func (m *DendriteMonolith) Start() { } func (m *DendriteMonolith) Stop() { - m.baseDendrite.Close() + _ = m.baseDendrite.Close() m.baseDendrite.WaitForShutdown() _ = m.listener.Close() m.PineconeMulticast.Stop() @@ -511,32 +657,68 @@ type RelayServerRetriever struct { relayServersQueried map[gomatrixserverlib.ServerName]bool queriedServersMutex sync.Mutex running atomic.Bool + quit <-chan bool } -func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { - request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} +func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(r.ServerName)} response := api.P2PQueryRelayServersResponse{} - err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) + err := r.FederationAPI.P2PQueryRelayServers(r.Context, &request, &response) if err != nil { eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) } + + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() for _, server := range response.RelayServers { - m.relayServersQueried[server] = false + r.relayServersQueried[server] = false } eLog.Infof("Registered relay servers: %v", response.RelayServers) } -func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { - defer m.running.Store(false) +func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { + updateNodeRelayServers(r.ServerName, servers, r.Context, r.FederationAPI) + + // Replace list of servers to sync with and mark them all as unsynced. + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) + for _, server := range servers { + r.relayServersQueried[server] = false + } + + r.StartSync() +} + +func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + relayServers := []gomatrixserverlib.ServerName{} + for server := range r.relayServersQueried { + relayServers = append(relayServers, server) + } + + return relayServers +} + +func (r *RelayServerRetriever) StartSync() { + if !r.running.Load() { + logrus.Info("Starting relay server sync") + go r.SyncRelayServers(r.quit) + } +} + +func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { + defer r.running.Store(false) t := time.NewTimer(relayServerRetryInterval) for { relayServersToQuery := []gomatrixserverlib.ServerName{} func() { - m.queriedServersMutex.Lock() - defer m.queriedServersMutex.Unlock() - for server, complete := range m.relayServersQueried { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + for server, complete := range r.relayServersQueried { if !complete { relayServersToQuery = append(relayServersToQuery, server) } @@ -544,9 +726,10 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { }() if len(relayServersToQuery) == 0 { // All relay servers have been synced. + logrus.Info("Finished syncing with all known relays") return } - m.queryRelayServers(relayServersToQuery) + r.queryRelayServers(relayServersToQuery) t.Reset(relayServerRetryInterval) select { @@ -560,30 +743,32 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { } } -func (m *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { - m.queriedServersMutex.Lock() - defer m.queriedServersMutex.Unlock() +func (r *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() result := map[gomatrixserverlib.ServerName]bool{} - for server, queried := range m.relayServersQueried { + for server, queried := range r.relayServersQueried { result[server] = queried } return result } -func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { - logrus.Info("querying relay servers for any available transactions") +func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("Querying relay servers for any available transactions") for _, server := range relayServers { - userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) + userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.ServerName), false) if err != nil { return } - err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) + + logrus.Infof("Syncing with relay: %s", string(server)) + err = r.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) if err == nil { func() { - m.queriedServersMutex.Lock() - defer m.queriedServersMutex.Unlock() - m.relayServersQueried[server] = true + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried[server] = true }() // TODO : What happens if your relay receives new messages after this point? // Should you continue to check with them, or should they try and contact you? diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index edcf22bb..3c8873e0 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "net" + "strings" "testing" "time" @@ -196,3 +197,126 @@ func TestMonolithStarts(t *testing.T) { monolith.PublicKey() monolith.Stop() } + +func TestMonolithSetRelayServers(t *testing.T) { + testCases := []struct { + name string + nodeID string + relays string + expectedRelays string + expectSelf bool + }{ + { + name: "assorted valid, invalid, empty & self keys", + nodeID: "@valid:abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectSelf: true, + }, + { + name: "invalid node key", + nodeID: "@invalid:notakey", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "", + expectSelf: false, + }, + { + name: "node is self", + nodeID: "self", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectSelf: false, + }, + } + + for _, tc := range testCases { + monolith := DendriteMonolith{} + monolith.Start() + + inputRelays := tc.relays + expectedRelays := tc.expectedRelays + if tc.expectSelf { + inputRelays += "," + monolith.PublicKey() + expectedRelays += "," + monolith.PublicKey() + } + nodeID := tc.nodeID + if nodeID == "self" { + nodeID = monolith.PublicKey() + } + + monolith.SetRelayServers(nodeID, inputRelays) + relays := monolith.GetRelayServers(nodeID) + monolith.Stop() + + if !containSameKeys(strings.Split(relays, ","), strings.Split(expectedRelays, ",")) { + t.Fatalf("%s: expected %s got %s", tc.name, expectedRelays, relays) + } + } +} + +func containSameKeys(expected []string, actual []string) bool { + if len(expected) != len(actual) { + return false + } + + for _, expectedKey := range expected { + hasMatch := false + for _, actualKey := range actual { + if actualKey == expectedKey { + hasMatch = true + } + } + + if !hasMatch { + return false + } + } + + return true +} + +func TestParseServerKey(t *testing.T) { + testCases := []struct { + name string + serverKey string + expectedErr bool + expectedKey gomatrixserverlib.ServerName + }{ + { + name: "valid userid as key", + serverKey: "@valid:abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectedErr: false, + expectedKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + }, + { + name: "valid key", + serverKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectedErr: false, + expectedKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + }, + { + name: "invalid userid key", + serverKey: "@invalid:notakey", + expectedErr: true, + expectedKey: "", + }, + { + name: "invalid key", + serverKey: "@invalid:notakey", + expectedErr: true, + expectedKey: "", + }, + } + + for _, tc := range testCases { + key, err := getServerKeyFromString(tc.serverKey) + if tc.expectedErr && err == nil { + t.Fatalf("%s: expected an error", tc.name) + } else if !tc.expectedErr && err != nil { + t.Fatalf("%s: didn't expect an error: %s", tc.name, err.Error()) + } + if tc.expectedKey != key { + t.Fatalf("%s: keys not equal. expected: %s got: %s", tc.name, tc.expectedKey, key) + } + } +} |