From 5b73592f5a4dddf64184fcbe33f4c1835c656480 Mon Sep 17 00:00:00 2001 From: devonh Date: Mon, 23 Jan 2023 17:55:12 +0000 Subject: Initial Store & Forward Implementation (#2917) This adds store & forward relays into dendrite for p2p. A few things have changed: - new relay api serves new http endpoints for s&f federation - updated outbound federation queueing which will attempt to forward using s&f if appropriate - database entries to track s&f relays for other nodes --- relayapi/api/api.go | 56 +++++ relayapi/internal/api.go | 53 +++++ relayapi/internal/perform.go | 141 +++++++++++++ relayapi/internal/perform_test.go | 121 +++++++++++ relayapi/relayapi.go | 74 +++++++ relayapi/relayapi_test.go | 154 ++++++++++++++ relayapi/routing/relaytxn.go | 74 +++++++ relayapi/routing/relaytxn_test.go | 220 ++++++++++++++++++++ relayapi/routing/routing.go | 123 +++++++++++ relayapi/routing/sendrelay.go | 77 +++++++ relayapi/routing/sendrelay_test.go | 209 +++++++++++++++++++ relayapi/storage/interface.go | 47 +++++ .../storage/postgres/relay_queue_json_table.go | 113 ++++++++++ relayapi/storage/postgres/relay_queue_table.go | 156 ++++++++++++++ relayapi/storage/postgres/storage.go | 64 ++++++ relayapi/storage/shared/storage.go | 170 +++++++++++++++ relayapi/storage/sqlite3/relay_queue_json_table.go | 137 ++++++++++++ relayapi/storage/sqlite3/relay_queue_table.go | 168 +++++++++++++++ relayapi/storage/sqlite3/storage.go | 64 ++++++ relayapi/storage/storage.go | 46 +++++ relayapi/storage/tables/interface.go | 66 ++++++ .../storage/tables/relay_queue_json_table_test.go | 173 ++++++++++++++++ relayapi/storage/tables/relay_queue_table_test.go | 229 +++++++++++++++++++++ 23 files changed, 2735 insertions(+) create mode 100644 relayapi/api/api.go create mode 100644 relayapi/internal/api.go create mode 100644 relayapi/internal/perform.go create mode 100644 relayapi/internal/perform_test.go create mode 100644 relayapi/relayapi.go create mode 100644 relayapi/relayapi_test.go create mode 100644 relayapi/routing/relaytxn.go create mode 100644 relayapi/routing/relaytxn_test.go create mode 100644 relayapi/routing/routing.go create mode 100644 relayapi/routing/sendrelay.go create mode 100644 relayapi/routing/sendrelay_test.go create mode 100644 relayapi/storage/interface.go create mode 100644 relayapi/storage/postgres/relay_queue_json_table.go create mode 100644 relayapi/storage/postgres/relay_queue_table.go create mode 100644 relayapi/storage/postgres/storage.go create mode 100644 relayapi/storage/shared/storage.go create mode 100644 relayapi/storage/sqlite3/relay_queue_json_table.go create mode 100644 relayapi/storage/sqlite3/relay_queue_table.go create mode 100644 relayapi/storage/sqlite3/storage.go create mode 100644 relayapi/storage/storage.go create mode 100644 relayapi/storage/tables/interface.go create mode 100644 relayapi/storage/tables/relay_queue_json_table_test.go create mode 100644 relayapi/storage/tables/relay_queue_table_test.go (limited to 'relayapi') diff --git a/relayapi/api/api.go b/relayapi/api/api.go new file mode 100644 index 00000000..9db39322 --- /dev/null +++ b/relayapi/api/api.go @@ -0,0 +1,56 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayInternalAPI is used to query information from the relay server. +type RelayInternalAPI interface { + RelayServerAPI + + // Retrieve from external relay server all transactions stored for us and process them. + PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, + ) error +} + +// RelayServerAPI exposes the store & query transaction functionality of a relay server. +type RelayServerAPI interface { + // Store transactions for forwarding to the destination at a later time. + PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, + ) error + + // Obtain the oldest stored transaction for the specified userID. + QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, + ) (QueryRelayTransactionsResponse, error) +} + +type QueryRelayTransactionsResponse struct { + Transaction gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id"` + EntriesQueued bool `json:"entries_queued"` +} diff --git a/relayapi/internal/api.go b/relayapi/internal/api.go new file mode 100644 index 00000000..3ff8c2ad --- /dev/null +++ b/relayapi/internal/api.go @@ -0,0 +1,53 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type RelayInternalAPI struct { + db storage.Database + fedClient fedAPI.FederationClient + rsAPI rsAPI.RoomserverInternalAPI + keyRing *gomatrixserverlib.KeyRing + producer *producers.SyncAPIProducer + presenceEnabledInbound bool + serverName gomatrixserverlib.ServerName +} + +func NewRelayInternalAPI( + db storage.Database, + fedClient fedAPI.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, + presenceEnabledInbound bool, + serverName gomatrixserverlib.ServerName, +) *RelayInternalAPI { + return &RelayInternalAPI{ + db: db, + fedClient: fedClient, + rsAPI: rsAPI, + keyRing: keyRing, + producer: producer, + presenceEnabledInbound: presenceEnabledInbound, + serverName: serverName, + } +} diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go new file mode 100644 index 00000000..59429933 --- /dev/null +++ b/relayapi/internal/perform.go @@ -0,0 +1,141 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// PerformRelayServerSync implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, +) error { + // Providing a default RelayEntry (EntryID = 0) is done to ask the relay if there are any + // transactions available for this node. + prevEntry := gomatrixserverlib.RelayEntry{} + asyncResponse, err := r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Txn) + + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + for asyncResponse.EntriesQueued { + // There are still more entries available for this node from the relay. + logrus.Infof("Retrieving next entry from relay, previous: %v", prevEntry) + asyncResponse, err = r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Txn) + } + + return nil +} + +// PerformStoreTransaction implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, +) error { + logrus.Warnf("Storing transaction for %v", userID) + receipt, err := r.db.StoreTransaction(ctx, transaction) + if err != nil { + logrus.Errorf("db.StoreTransaction: %s", err.Error()) + return err + } + err = r.db.AssociateTransactionWithDestinations( + ctx, + map[gomatrixserverlib.UserID]struct{}{ + userID: {}, + }, + transaction.TransactionID, + receipt) + + return err +} + +// QueryTransactions implements api.RelayInternalAPI +func (r *RelayInternalAPI) QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, +) (api.QueryRelayTransactionsResponse, error) { + logrus.Infof("QueryTransactions for %s", userID.Raw()) + if previousEntry.EntryID > 0 { + logrus.Infof("Cleaning previous entry (%v) from db for %s", + previousEntry.EntryID, + userID.Raw(), + ) + prevReceipt := receipt.NewReceipt(previousEntry.EntryID) + err := r.db.CleanTransactions(ctx, userID, []*receipt.Receipt{&prevReceipt}) + if err != nil { + logrus.Errorf("db.CleanTransactions: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + } + + transaction, receipt, err := r.db.GetTransaction(ctx, userID) + if err != nil { + logrus.Errorf("db.GetTransaction: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + + response := api.QueryRelayTransactionsResponse{} + if transaction != nil && receipt != nil { + logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, userID.Raw()) + response.Transaction = *transaction + response.EntryID = receipt.GetNID() + response.EntriesQueued = true + } else { + logrus.Infof("No more entries in the queue for %s", userID.Raw()) + response.EntryID = 0 + response.EntriesQueued = false + } + + return response, nil +} + +func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) { + logrus.Warn("Processing transaction from relay server") + mu := internal.NewMutexByRoom() + t := internal.NewTxnReq( + r.rsAPI, + nil, + r.serverName, + r.keyRing, + mu, + r.producer, + r.presenceEnabledInbound, + txn.PDUs, + txn.EDUs, + txn.Origin, + txn.TransactionID, + txn.Destination) + + t.ProcessTransaction(context.TODO()) +} diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go new file mode 100644 index 00000000..fb71b7d0 --- /dev/null +++ b/relayapi/internal/perform_test.go @@ -0,0 +1,121 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + fedAPI.FederationClient + shouldFail bool + queryCount uint + queueDepth uint +} + +func (f *testFedClient) P2PGetTransactionFromRelay( + ctx context.Context, + u gomatrixserverlib.UserID, + prev gomatrixserverlib.RelayEntry, + relayServer gomatrixserverlib.ServerName, +) (res gomatrixserverlib.RespGetRelayTransaction, err error) { + f.queryCount++ + if f.shouldFail { + return res, fmt.Errorf("Error") + } + + res = gomatrixserverlib.RespGetRelayTransaction{ + Txn: gomatrixserverlib.Transaction{}, + EntryID: 0, + } + if f.queueDepth > 0 { + res.EntriesQueued = true + } else { + res.EntriesQueued = false + } + f.queueDepth -= 1 + + return +} + +func TestPerformRelayServerSync(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) +} + +func TestPerformRelayServerSyncFedError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{shouldFail: true} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.Error(t, err) +} + +func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{queueDepth: 2} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) + assert.Equal(t, uint(3), fedClient.queryCount) +} diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go new file mode 100644 index 00000000..f9f9d4ff --- /dev/null +++ b/relayapi/relayapi.go @@ -0,0 +1,74 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi + +import ( + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. +func AddPublicRoutes( + base *base.BaseDendrite, + keyRing gomatrixserverlib.JSONVerifier, + relayAPI api.RelayInternalAPI, +) { + fedCfg := &base.Cfg.FederationAPI + + relay, ok := relayAPI.(*internal.RelayInternalAPI) + if !ok { + panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " + + "RelayInternalAPI. This is a programming error.") + } + + routing.Setup( + base.PublicFederationAPIMux, + fedCfg, + relay, + keyRing, + ) +} + +func NewRelayInternalAPI( + base *base.BaseDendrite, + fedClient *gomatrixserverlib.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, +) api.RelayInternalAPI { + cfg := &base.Cfg.RelayAPI + + relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) + if err != nil { + logrus.WithError(err).Panic("failed to connect to relay db") + } + + return internal.NewRelayInternalAPI( + relayDB, + fedClient, + rsAPI, + keyRing, + producer, + base.Cfg.Global.Presence.EnableInbound, + base.Cfg.Global.ServerName, + ) +} diff --git a/relayapi/relayapi_test.go b/relayapi/relayapi_test.go new file mode 100644 index 00000000..dfa06811 --- /dev/null +++ b/relayapi/relayapi_test.go @@ -0,0 +1,154 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi_test + +import ( + "crypto/ed25519" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/relayapi" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func TestCreateNewRelayInternalAPI(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + assert.NotNil(t, relayAPI) + }) +} + +func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + if dbType == test.DBTypeSQLite { + base.Cfg.RelayAPI.Database.ConnectionString = "file:" + } else { + base.Cfg.RelayAPI.Database.ConnectionString = "test" + } + defer close() + + assert.Panics(t, func() { + relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + }) + }) +} + +func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + assert.Panics(t, func() { + relayapi.AddPublicRoutes(base, nil, nil) + }) + }) +} + +func createGetRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID) + content := gomatrixserverlib.RelayEntry{EntryID: 0} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +type sendRelayContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} + +func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnID string, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID) + content := sendRelayContent{} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID, "txnID": txnID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +func TestCreateRelayPublicRoutes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + assert.NotNil(t, relayAPI) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + relayapi.AddPublicRoutes(base, keyRing, relayAPI) + + testCases := []struct { + name string + req *http.Request + wantCode int + wantJoinedRooms []string + }{ + { + name: "relay_txn invalid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"), + wantCode: 400, + }, + { + name: "relay_txn valid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), + wantCode: 200, + }, + { + name: "send_relay invalid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"), + wantCode: 400, + }, + { + name: "send_relay valid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), + wantCode: 200, + }, + } + + for _, tc := range testCases { + w := httptest.NewRecorder() + base.PublicFederationAPIMux.ServeHTTP(w, tc.req) + if w.Code != tc.wantCode { + t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) + } + } + }) +} diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go new file mode 100644 index 00000000..1b11b0ec --- /dev/null +++ b/relayapi/routing/relaytxn.go @@ -0,0 +1,74 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type RelayTransactionResponse struct { + Transaction gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id,omitempty"` + EntriesQueued bool `json:"entries_queued"` +} + +// GetTransactionFromRelay implements /_matrix/federation/v1/relay_txn/{userID} +// This endpoint can be extracted into a separate relay server service. +func GetTransactionFromRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + logrus.Infof("Handling relay_txn for %s", userID.Raw()) + + previousEntry := gomatrixserverlib.RelayEntry{} + if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("invalid json provided"), + } + } + if previousEntry.EntryID < 0 { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("Invalid entry id provided. Must be >= 0."), + } + } + logrus.Infof("Previous entry provided: %v", previousEntry.EntryID) + + response, err := relayAPI.QueryTransactions(httpReq.Context(), userID, previousEntry) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: RelayTransactionResponse{ + Transaction: response.Transaction, + EntryID: response.EntryID, + EntriesQueued: response.EntriesQueued, + }, + } +} diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go new file mode 100644 index 00000000..a47fdb19 --- /dev/null +++ b/relayapi/routing/relaytxn_test.go @@ -0,0 +1,220 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func createQuery( + userID gomatrixserverlib.UserID, + prevEntry gomatrixserverlib.RelayEntry, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/relay_txn/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("GET", userID.Domain(), "relay", path) + request.SetContent(prevEntry) + + return request +} + +func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetInvalidPrevEntryFails(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestGetReturnsSavedTransaction(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetReturnsMultipleSavedTransactions(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + transaction2 := createTransaction() + receipt2, err := db.StoreTransaction(context.Background(), transaction2) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction2.TransactionID, + receipt2) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction2, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go new file mode 100644 index 00000000..6df0cdc5 --- /dev/null +++ b/relayapi/routing/routing.go @@ -0,0 +1,123 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "fmt" + "net/http" + "time" + + "github.com/getsentry/sentry-go" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/httputil" + relayInternal "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// Setup registers HTTP handlers with the given ServeMux. +// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly +// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled +// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz] +// +// Due to Setup being used to call many other functions, a gocyclo nolint is +// applied: +// nolint: gocyclo +func Setup( + fedMux *mux.Router, + cfg *config.FederationAPI, + relayAPI *relayInternal.RelayInternalAPI, + keys gomatrixserverlib.JSONVerifier, +) { + v1fedmux := fedMux.PathPrefix("/v1").Subrouter() + + v1fedmux.Handle("/send_relay/{txnID}/{userID}", MakeRelayAPI( + "send_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return SendTransactionToRelay( + httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]), + *userID, + ) + }, + )).Methods(http.MethodPut, http.MethodOptions) + + v1fedmux.Handle("/relay_txn/{userID}", MakeRelayAPI( + "get_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return GetTransactionFromRelay(httpReq, request, relayAPI, *userID) + }, + )).Methods(http.MethodGet, http.MethodOptions) +} + +// MakeRelayAPI makes an http.Handler that checks matrix relay authentication. +func MakeRelayAPI( + metricsName string, serverName gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, + keyRing gomatrixserverlib.JSONVerifier, + f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + req, time.Now(), serverName, isLocalServerName, keyRing, + ) + if fedReq == nil { + return errResp + } + // add the user to Sentry, if enabled + hub := sentry.GetHubFromContext(req.Context()) + if hub != nil { + hub.Scope().SetTag("origin", string(fedReq.Origin())) + hub.Scope().SetTag("uri", fedReq.RequestURI()) + } + defer func() { + if r := recover(); r != nil { + if hub != nil { + hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) + } + // re-panic to return the 500 + panic(r) + } + }() + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + } + + jsonRes := f(req, fedReq, vars) + // do not log 4xx as errors as they are client fails, not server fails + if hub != nil && jsonRes.Code >= 500 { + hub.Scope().SetExtra("response", jsonRes) + hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) + } + return jsonRes + } + return httputil.MakeExternalAPI(metricsName, h) +} diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go new file mode 100644 index 00000000..a7027f29 --- /dev/null +++ b/relayapi/routing/sendrelay.go @@ -0,0 +1,77 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// SendTransactionToRelay implements PUT /_matrix/federation/v1/relay_txn/{txnID}/{userID} +// This endpoint can be extracted into a separate relay server service. +func SendTransactionToRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + txnID gomatrixserverlib.TransactionID, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + var txnEvents struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` + } + + if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { + logrus.Info("The request body could not be decoded into valid JSON." + err.Error()) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()), + } + } + + // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. + // https://matrix.org/docs/spec/server_server/latest#transactions + if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + } + } + + t := gomatrixserverlib.Transaction{} + t.PDUs = txnEvents.PDUs + t.EDUs = txnEvents.EDUs + t.Origin = fedReq.Origin() + t.TransactionID = txnID + t.Destination = userID.Domain() + + util.GetLogger(httpReq.Context()).Warnf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, fedReq.Origin(), len(t.PDUs), len(t.EDUs)) + + err := relayAPI.PerformStoreTransaction(httpReq.Context(), t, userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("could not store the transaction for forwarding"), + } + } + + return util.JSONResponse{Code: 200} +} diff --git a/relayapi/routing/sendrelay_test.go b/relayapi/routing/sendrelay_test.go new file mode 100644 index 00000000..d9ed7500 --- /dev/null +++ b/relayapi/routing/sendrelay_test.go @@ -0,0 +1,209 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func createTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + return txn +} + +func createFederationRequest( + userID gomatrixserverlib.UserID, + txnID gomatrixserverlib.TransactionID, + origin gomatrixserverlib.ServerName, + destination gomatrixserverlib.ServerName, + content interface{}, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path) + request.SetContent(content) + + return request +} + +func TestForwardEmptyReturnsOk(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.Equal(t, 200, response.Code) +} + +func TestForwardBadJSONReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field bool `json:"pdus"` + } + content := BadData{ + Field: false, + } + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyPDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []json.RawMessage `json:"pdus"` + } + content := BadData{ + Field: []json.RawMessage{}, + } + for i := 0; i < 51; i++ { + content.Field = append(content.Field, []byte{}) + } + assert.Greater(t, len(content.Field), 50) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyEDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []gomatrixserverlib.EDU `json:"edus"` + } + content := BadData{ + Field: []gomatrixserverlib.EDU{}, + } + for i := 0; i < 101; i++ { + content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}) + } + assert.Greater(t, len(content.Field), 100) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestUniqueTransactionStoredInDatabase(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay( + httpReq, &request, relayAPI, txn.TransactionID, *userID) + transaction, _, err := db.GetTransaction(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction") + + transactionCount, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction count") + + assert.Equal(t, 200, response.Code) + assert.Equal(t, int64(1), transactionCount) + assert.Equal(t, txn.TransactionID, transaction.TransactionID) +} diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go new file mode 100644 index 00000000..f5f9a06e --- /dev/null +++ b/relayapi/storage/interface.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database interface { + // Adds a new transaction to the queue json table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error) + + // Adds a new transaction_id: server_name mapping with associated json table nid to the queue + // entry table for each provided destination. + AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error + + // Removes every server_name: receipt pair provided from the queue entries table. + // Will then remove every entry for each receipt provided from the queue json table. + // If any of the entries don't exist in either table, nothing will happen for that entry and + // an error will not be generated. + CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error + + // Gets the oldest transaction for the provided server_name. + // If no transactions exist, returns nil and no error. + GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) + + // Gets the number of transactions being stored for the provided server_name. + // If the server doesn't exist in the database then 0 is returned with no error. + GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) +} diff --git a/relayapi/storage/postgres/relay_queue_json_table.go b/relayapi/storage/postgres/relay_queue_json_table.go new file mode 100644 index 00000000..74410fc8 --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_json_table.go @@ -0,0 +1,113 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid BIGSERIAL, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + + " RETURNING json_nid" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid = ANY($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + deleteJSONStmt *sql.Stmt + selectJSONStmt *sql.Stmt +} + +func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + {&s.deleteJSONStmt, deleteQueueJSONSQL}, + {&s.selectJSONStmt, selectQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (int64, error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + var lastid int64 + if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil { + return 0, err + } + return lastid, nil +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) + _, err := stmt.ExecContext(ctx, pq.Int64Array(nids)) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, s.selectJSONStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, err + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/postgres/relay_queue_table.go b/relayapi/storage/postgres/relay_queue_table.go new file mode 100644 index 00000000..e97cf8cc --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_table.go @@ -0,0 +1,156 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The destination server that we will send the event to. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + deleteQueueEntriesStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt +} + +func NewPostgresRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.deleteQueueEntriesStmt, deleteQueueEntriesSQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go new file mode 100644 index 00000000..1042beba --- /dev/null +++ b/relayapi/storage/postgres/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the relayapi +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + return nil, err + } + queue, err := NewPostgresRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewPostgresRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go new file mode 100644 index 00000000..0993707b --- /dev/null +++ b/relayapi/storage/shared/storage.go @@ -0,0 +1,170 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + IsLocalServerName func(gomatrixserverlib.ServerName) bool + Cache caching.FederationCache + Writer sqlutil.Writer + RelayQueue tables.RelayQueue + RelayQueueJSON tables.RelayQueueJSON +} + +func (d *Database) StoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, +) (*receipt.Receipt, error) { + var err error + jsonTransaction, err := json.Marshal(transaction) + if err != nil { + return nil, fmt.Errorf("failed to marshal: %w", err) + } + + var nid int64 + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(jsonTransaction)) + return err + }) + if err != nil { + return nil, fmt.Errorf("d.insertQueueJSON: %w", err) + } + + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil +} + +func (d *Database) AssociateTransactionWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.UserID]struct{}, + transactionID gomatrixserverlib.TransactionID, + dbReceipt *receipt.Receipt, +) error { + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var lastErr error + for destination := range destinations { + destination := destination + err := d.RelayQueue.InsertQueueEntry( + ctx, + txn, + transactionID, + destination.Domain(), + dbReceipt.GetNID(), + ) + if err != nil { + lastErr = fmt.Errorf("d.insertQueueEntry: %w", err) + } + } + return lastErr + }) + + return err +} + +func (d *Database) CleanTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + receipts []*receipt.Receipt, +) error { + nids := make([]int64, len(receipts)) + for i, dbReceipt := range receipts { + nids[i] = dbReceipt.GetNID() + } + + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + deleteEntryErr := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids) + // TODO : If there are still queue entries for any of these nids for other destinations + // then we shouldn't delete the json entries. + // But this can't happen with the current api design. + // There will only ever be one server entry for each nid since each call to send_relay + // only accepts a single server name and inside there we create a new json entry. + // So for multiple destinations we would call send_relay multiple times and have multiple + // json entries of the same transaction. + // + // TLDR; this works as expected right now but can easily be optimised in the future. + deleteJSONErr := d.RelayQueueJSON.DeleteQueueJSON(ctx, txn, nids) + + if deleteEntryErr != nil { + return fmt.Errorf("d.deleteQueueEntries: %w", deleteEntryErr) + } + if deleteJSONErr != nil { + return fmt.Errorf("d.deleteQueueJSON: %w", deleteJSONErr) + } + return nil + }) + + return err +} + +func (d *Database) GetTransaction( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) { + entriesRequested := 1 + nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), entriesRequested) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err) + } + if len(nids) == 0 { + return nil, nil, nil + } + firstNID := nids[0] + + txns := map[int64][]byte{} + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids) + return err + }) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err) + } + + transaction := &gomatrixserverlib.Transaction{} + if _, ok := txns[firstNID]; !ok { + return nil, nil, fmt.Errorf("Failed retrieving json blob for transaction: %d", firstNID) + } + + err = json.Unmarshal(txns[firstNID], transaction) + if err != nil { + return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err) + } + + newReceipt := receipt.NewReceipt(firstNID) + return transaction, &newReceipt, nil +} + +func (d *Database) GetTransactionCount( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (int64, error) { + count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain()) + if err != nil { + return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err) + } + return count, nil +} diff --git a/relayapi/storage/sqlite3/relay_queue_json_table.go b/relayapi/storage/sqlite3/relay_queue_json_table.go new file mode 100644 index 00000000..502da3b0 --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_json_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid IN ($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic + //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (lastid int64, err error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("res.LastInsertId: %w", err) + } + return +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(nids)) + for k, v := range nids { + iNIDs[k] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(jsonNIDs)) + for k, v := range jsonNIDs { + iNIDs[k] = v + } + + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, selectStmt) + rows, err := stmt.QueryContext(ctx, iNIDs...) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/sqlite3/relay_queue_table.go b/relayapi/storage/sqlite3/relay_queue_table.go new file mode 100644 index 00000000..49c6b4de --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_table.go @@ -0,0 +1,168 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The domain part of the user ID the m.room.member event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt + // deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go new file mode 100644 index 00000000..3ed4ab04 --- /dev/null +++ b/relayapi/storage/sqlite3/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the federation sender +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + return nil, err + } + queue, err := NewSQLiteRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/storage.go b/relayapi/storage/storage.go new file mode 100644 index 00000000..16ecbcfb --- /dev/null +++ b/relayapi/storage/storage.go @@ -0,0 +1,46 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !wasm +// +build !wasm + +package storage + +import ( + "fmt" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (Database, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/relayapi/storage/tables/interface.go b/relayapi/storage/tables/interface.go new file mode 100644 index 00000000..9056a567 --- /dev/null +++ b/relayapi/storage/tables/interface.go @@ -0,0 +1,66 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayQueue table contains a mapping of server name to transaction id and the corresponding nid. +// These are the transactions being stored for the given destination server. +// The nids correspond to entries in the RelayQueueJSON table. +type RelayQueue interface { + // Adds a new transaction_id: server_name mapping with associated json table nid to the table. + // Will ensure only one transaction id is present for each server_name: nid mapping. + // Adding duplicates will silently do nothing. + InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + + // Removes multiple entries from the table corresponding the the list of nids provided. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + + // Get a list of nids associated with the provided server name. + // Returns up to `limit` nids. The entries are returned oldest first. + // Will return an empty list if no matches were found. + SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + + // Get the number of entries in the table associated with the provided server name. + // If there are no matching rows, a count of 0 is returned with err set to nil. + SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +// RelayQueueJSON table contains a map of nid to the raw transaction json. +type RelayQueueJSON interface { + // Adds a new transaction to the table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + + // Removes multiple nids from the table. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + + // Get the transaction json corresponding to the provided nids. + // Will return a partial result containing any matching nid from the table. + // Will return an empty map if no matches were found. + // It is the caller's responsibility to deal with the results appropriately. + // return: map indexed by nid of each matching transaction json. + SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} diff --git a/relayapi/storage/tables/relay_queue_json_table_test.go b/relayapi/storage/tables/relay_queue_json_table_test.go new file mode 100644 index 00000000..efa3363e --- /dev/null +++ b/relayapi/storage/tables/relay_queue_json_table_test.go @@ -0,0 +1,173 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func mustCreateTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + + return txn +} + +type RelayQueueJSONDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueueJSON +} + +func mustCreateQueueJSONTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueJSONDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueueJSON + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueJSONTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueJSONDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + _, err = db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + var storedJSON map[int64][]byte + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 1, len(storedJSON)) + + var storedTx gomatrixserverlib.Transaction + json.Unmarshal(storedJSON[1], &storedTx) + + assert.Equal(t, transaction, storedTx) + }) +} + +func TestShouldDeleteTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + storedJSON := map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + storedJSON = map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 0, len(storedJSON)) + }) +} diff --git a/relayapi/storage/tables/relay_queue_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go new file mode 100644 index 00000000..99f9922c --- /dev/null +++ b/relayapi/storage/tables/relay_queue_table_test.go @@ -0,0 +1,229 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type RelayQueueDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueue +} + +func mustCreateQueueTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueue + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, nid, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + }) +} + +func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(2) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName = gomatrixserverlib.ServerName("domain") + oldestNID := int64(1) + err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 1) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + + retrievedNids, err = db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, nid, retrievedNids[1]) + assert.Equal(t, 2, len(retrievedNids)) + }) +} + +func TestShouldDeleteQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(0), count) + }) +} + +func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano())) + serverName2 := gomatrixserverlib.ServerName("domain2") + nid2 := int64(2) + transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano())) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + + count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + }) +} -- cgit v1.2.3