aboutsummaryrefslogtreecommitdiff
path: root/federationapi/internal/perform_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'federationapi/internal/perform_test.go')
-rw-r--r--federationapi/internal/perform_test.go41
1 files changed, 41 insertions, 0 deletions
diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go
index e8e0d00a..e6e366f9 100644
--- a/federationapi/internal/perform_test.go
+++ b/federationapi/internal/perform_test.go
@@ -123,6 +123,47 @@ func TestQueryRelayServers(t *testing.T) {
assert.Equal(t, len(relayServers), len(res.RelayServers))
}
+func TestRemoveRelayServers(t *testing.T) {
+ testDB := test.NewInMemoryFederationDatabase()
+
+ server := gomatrixserverlib.ServerName("wakeup")
+ relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"}
+ err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers)
+ assert.NoError(t, err)
+
+ cfg := config.FederationAPI{
+ Matrix: &config.Global{
+ SigningIdentity: gomatrixserverlib.SigningIdentity{
+ ServerName: "relay",
+ },
+ },
+ }
+ fedClient := &testFedClient{}
+ stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
+ queues := queue.NewOutgoingQueues(
+ testDB, process.NewProcessContext(),
+ false,
+ cfg.Matrix.ServerName, fedClient, nil, &stats,
+ nil,
+ )
+ fedAPI := NewFederationInternalAPI(
+ testDB, &cfg, nil, fedClient, &stats, nil, queues, nil,
+ )
+
+ req := api.P2PRemoveRelayServersRequest{
+ Server: server,
+ RelayServers: []gomatrixserverlib.ServerName{"relay1"},
+ }
+ res := api.P2PRemoveRelayServersResponse{}
+ err = fedAPI.P2PRemoveRelayServers(context.Background(), &req, &res)
+ assert.NoError(t, err)
+
+ finalRelays, err := testDB.P2PGetRelayServersForServer(context.Background(), server)
+ assert.NoError(t, err)
+ assert.Equal(t, 1, len(finalRelays))
+ assert.Equal(t, gomatrixserverlib.ServerName("relay2"), finalRelays[0])
+}
+
func TestPerformDirectoryLookup(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()