aboutsummaryrefslogtreecommitdiff
path: root/relayapi
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-01-23 17:55:12 +0000
committerGitHub <noreply@github.com>2023-01-23 17:55:12 +0000
commit5b73592f5a4dddf64184fcbe33f4c1835c656480 (patch)
treeb6dac51b6be7a1e591f24881ee1bfae1b92088e9 /relayapi
parent48fa869fa3578741d1d5775d30f24f6b097ab995 (diff)
Initial Store & Forward Implementation (#2917)
This adds store & forward relays into dendrite for p2p. A few things have changed: - new relay api serves new http endpoints for s&f federation - updated outbound federation queueing which will attempt to forward using s&f if appropriate - database entries to track s&f relays for other nodes
Diffstat (limited to 'relayapi')
-rw-r--r--relayapi/api/api.go56
-rw-r--r--relayapi/internal/api.go53
-rw-r--r--relayapi/internal/perform.go141
-rw-r--r--relayapi/internal/perform_test.go121
-rw-r--r--relayapi/relayapi.go74
-rw-r--r--relayapi/relayapi_test.go154
-rw-r--r--relayapi/routing/relaytxn.go74
-rw-r--r--relayapi/routing/relaytxn_test.go220
-rw-r--r--relayapi/routing/routing.go123
-rw-r--r--relayapi/routing/sendrelay.go77
-rw-r--r--relayapi/routing/sendrelay_test.go209
-rw-r--r--relayapi/storage/interface.go47
-rw-r--r--relayapi/storage/postgres/relay_queue_json_table.go113
-rw-r--r--relayapi/storage/postgres/relay_queue_table.go156
-rw-r--r--relayapi/storage/postgres/storage.go64
-rw-r--r--relayapi/storage/shared/storage.go170
-rw-r--r--relayapi/storage/sqlite3/relay_queue_json_table.go137
-rw-r--r--relayapi/storage/sqlite3/relay_queue_table.go168
-rw-r--r--relayapi/storage/sqlite3/storage.go64
-rw-r--r--relayapi/storage/storage.go46
-rw-r--r--relayapi/storage/tables/interface.go66
-rw-r--r--relayapi/storage/tables/relay_queue_json_table_test.go173
-rw-r--r--relayapi/storage/tables/relay_queue_table_test.go229
23 files changed, 2735 insertions, 0 deletions
diff --git a/relayapi/api/api.go b/relayapi/api/api.go
new file mode 100644
index 00000000..9db39322
--- /dev/null
+++ b/relayapi/api/api.go
@@ -0,0 +1,56 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package api
+
+import (
+ "context"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// RelayInternalAPI is used to query information from the relay server.
+type RelayInternalAPI interface {
+ RelayServerAPI
+
+ // Retrieve from external relay server all transactions stored for us and process them.
+ PerformRelayServerSync(
+ ctx context.Context,
+ userID gomatrixserverlib.UserID,
+ relayServer gomatrixserverlib.ServerName,
+ ) error
+}
+
+// RelayServerAPI exposes the store & query transaction functionality of a relay server.
+type RelayServerAPI interface {
+ // Store transactions for forwarding to the destination at a later time.
+ PerformStoreTransaction(
+ ctx context.Context,
+ transaction gomatrixserverlib.Transaction,
+ userID gomatrixserverlib.UserID,
+ ) error
+
+ // Obtain the oldest stored transaction for the specified userID.
+ QueryTransactions(
+ ctx context.Context,
+ userID gomatrixserverlib.UserID,
+ previousEntry gomatrixserverlib.RelayEntry,
+ ) (QueryRelayTransactionsResponse, error)
+}
+
+type QueryRelayTransactionsResponse struct {
+ Transaction gomatrixserverlib.Transaction `json:"transaction"`
+ EntryID int64 `json:"entry_id"`
+ EntriesQueued bool `json:"entries_queued"`
+}
diff --git a/relayapi/internal/api.go b/relayapi/internal/api.go
new file mode 100644
index 00000000..3ff8c2ad
--- /dev/null
+++ b/relayapi/internal/api.go
@@ -0,0 +1,53 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ fedAPI "github.com/matrix-org/dendrite/federationapi/api"
+ "github.com/matrix-org/dendrite/federationapi/producers"
+ "github.com/matrix-org/dendrite/relayapi/storage"
+ rsAPI "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type RelayInternalAPI struct {
+ db storage.Database
+ fedClient fedAPI.FederationClient
+ rsAPI rsAPI.RoomserverInternalAPI
+ keyRing *gomatrixserverlib.KeyRing
+ producer *producers.SyncAPIProducer
+ presenceEnabledInbound bool
+ serverName gomatrixserverlib.ServerName
+}
+
+func NewRelayInternalAPI(
+ db storage.Database,
+ fedClient fedAPI.FederationClient,
+ rsAPI rsAPI.RoomserverInternalAPI,
+ keyRing *gomatrixserverlib.KeyRing,
+ producer *producers.SyncAPIProducer,
+ presenceEnabledInbound bool,
+ serverName gomatrixserverlib.ServerName,
+) *RelayInternalAPI {
+ return &RelayInternalAPI{
+ db: db,
+ fedClient: fedClient,
+ rsAPI: rsAPI,
+ keyRing: keyRing,
+ producer: producer,
+ presenceEnabledInbound: presenceEnabledInbound,
+ serverName: serverName,
+ }
+}
diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go
new file mode 100644
index 00000000..59429933
--- /dev/null
+++ b/relayapi/internal/perform.go
@@ -0,0 +1,141 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+
+ "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/relayapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+// PerformRelayServerSync implements api.RelayInternalAPI
+func (r *RelayInternalAPI) PerformRelayServerSync(
+ ctx context.Context,
+ userID gomatrixserverlib.UserID,
+ relayServer gomatrixserverlib.ServerName,
+) error {
+ // Providing a default RelayEntry (EntryID = 0) is done to ask the relay if there are any
+ // transactions available for this node.
+ prevEntry := gomatrixserverlib.RelayEntry{}
+ asyncResponse, err := r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer)
+ if err != nil {
+ logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error())
+ return err
+ }
+ r.processTransaction(&asyncResponse.Txn)
+
+ prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID}
+ for asyncResponse.EntriesQueued {
+ // There are still more entries available for this node from the relay.
+ logrus.Infof("Retrieving next entry from relay, previous: %v", prevEntry)
+ asyncResponse, err = r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer)
+ prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID}
+ if err != nil {
+ logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error())
+ return err
+ }
+ r.processTransaction(&asyncResponse.Txn)
+ }
+
+ return nil
+}
+
+// PerformStoreTransaction implements api.RelayInternalAPI
+func (r *RelayInternalAPI) PerformStoreTransaction(
+ ctx context.Context,
+ transaction gomatrixserverlib.Transaction,
+ userID gomatrixserverlib.UserID,
+) error {
+ logrus.Warnf("Storing transaction for %v", userID)
+ receipt, err := r.db.StoreTransaction(ctx, transaction)
+ if err != nil {
+ logrus.Errorf("db.StoreTransaction: %s", err.Error())
+ return err
+ }
+ err = r.db.AssociateTransactionWithDestinations(
+ ctx,
+ map[gomatrixserverlib.UserID]struct{}{
+ userID: {},
+ },
+ transaction.TransactionID,
+ receipt)
+
+ return err
+}
+
+// QueryTransactions implements api.RelayInternalAPI
+func (r *RelayInternalAPI) QueryTransactions(
+ ctx context.Context,
+ userID gomatrixserverlib.UserID,
+ previousEntry gomatrixserverlib.RelayEntry,
+) (api.QueryRelayTransactionsResponse, error) {
+ logrus.Infof("QueryTransactions for %s", userID.Raw())
+ if previousEntry.EntryID > 0 {
+ logrus.Infof("Cleaning previous entry (%v) from db for %s",
+ previousEntry.EntryID,
+ userID.Raw(),
+ )
+ prevReceipt := receipt.NewReceipt(previousEntry.EntryID)
+ err := r.db.CleanTransactions(ctx, userID, []*receipt.Receipt{&prevReceipt})
+ if err != nil {
+ logrus.Errorf("db.CleanTransactions: %s", err.Error())
+ return api.QueryRelayTransactionsResponse{}, err
+ }
+ }
+
+ transaction, receipt, err := r.db.GetTransaction(ctx, userID)
+ if err != nil {
+ logrus.Errorf("db.GetTransaction: %s", err.Error())
+ return api.QueryRelayTransactionsResponse{}, err
+ }
+
+ response := api.QueryRelayTransactionsResponse{}
+ if transaction != nil && receipt != nil {
+ logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, userID.Raw())
+ response.Transaction = *transaction
+ response.EntryID = receipt.GetNID()
+ response.EntriesQueued = true
+ } else {
+ logrus.Infof("No more entries in the queue for %s", userID.Raw())
+ response.EntryID = 0
+ response.EntriesQueued = false
+ }
+
+ return response, nil
+}
+
+func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
+ logrus.Warn("Processing transaction from relay server")
+ mu := internal.NewMutexByRoom()
+ t := internal.NewTxnReq(
+ r.rsAPI,
+ nil,
+ r.serverName,
+ r.keyRing,
+ mu,
+ r.producer,
+ r.presenceEnabledInbound,
+ txn.PDUs,
+ txn.EDUs,
+ txn.Origin,
+ txn.TransactionID,
+ txn.Destination)
+
+ t.ProcessTransaction(context.TODO())
+}
diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go
new file mode 100644
index 00000000..fb71b7d0
--- /dev/null
+++ b/relayapi/internal/perform_test.go
@@ -0,0 +1,121 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "fmt"
+ "testing"
+
+ fedAPI "github.com/matrix-org/dendrite/federationapi/api"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/relayapi/storage/shared"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/stretchr/testify/assert"
+)
+
+type testFedClient struct {
+ fedAPI.FederationClient
+ shouldFail bool
+ queryCount uint
+ queueDepth uint
+}
+
+func (f *testFedClient) P2PGetTransactionFromRelay(
+ ctx context.Context,
+ u gomatrixserverlib.UserID,
+ prev gomatrixserverlib.RelayEntry,
+ relayServer gomatrixserverlib.ServerName,
+) (res gomatrixserverlib.RespGetRelayTransaction, err error) {
+ f.queryCount++
+ if f.shouldFail {
+ return res, fmt.Errorf("Error")
+ }
+
+ res = gomatrixserverlib.RespGetRelayTransaction{
+ Txn: gomatrixserverlib.Transaction{},
+ EntryID: 0,
+ }
+ if f.queueDepth > 0 {
+ res.EntriesQueued = true
+ } else {
+ res.EntriesQueued = false
+ }
+ f.queueDepth -= 1
+
+ return
+}
+
+func TestPerformRelayServerSync(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.Nil(t, err, "Invalid userID")
+
+ fedClient := &testFedClient{}
+ relayAPI := NewRelayInternalAPI(
+ &db, fedClient, nil, nil, nil, false, "",
+ )
+
+ err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay"))
+ assert.NoError(t, err)
+}
+
+func TestPerformRelayServerSyncFedError(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.Nil(t, err, "Invalid userID")
+
+ fedClient := &testFedClient{shouldFail: true}
+ relayAPI := NewRelayInternalAPI(
+ &db, fedClient, nil, nil, nil, false, "",
+ )
+
+ err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay"))
+ assert.Error(t, err)
+}
+
+func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.Nil(t, err, "Invalid userID")
+
+ fedClient := &testFedClient{queueDepth: 2}
+ relayAPI := NewRelayInternalAPI(
+ &db, fedClient, nil, nil, nil, false, "",
+ )
+
+ err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay"))
+ assert.NoError(t, err)
+ assert.Equal(t, uint(3), fedClient.queryCount)
+}
diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go
new file mode 100644
index 00000000..f9f9d4ff
--- /dev/null
+++ b/relayapi/relayapi.go
@@ -0,0 +1,74 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package relayapi
+
+import (
+ "github.com/matrix-org/dendrite/federationapi/producers"
+ "github.com/matrix-org/dendrite/relayapi/api"
+ "github.com/matrix-org/dendrite/relayapi/internal"
+ "github.com/matrix-org/dendrite/relayapi/routing"
+ "github.com/matrix-org/dendrite/relayapi/storage"
+ rsAPI "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/setup/base"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
+func AddPublicRoutes(
+ base *base.BaseDendrite,
+ keyRing gomatrixserverlib.JSONVerifier,
+ relayAPI api.RelayInternalAPI,
+) {
+ fedCfg := &base.Cfg.FederationAPI
+
+ relay, ok := relayAPI.(*internal.RelayInternalAPI)
+ if !ok {
+ panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " +
+ "RelayInternalAPI. This is a programming error.")
+ }
+
+ routing.Setup(
+ base.PublicFederationAPIMux,
+ fedCfg,
+ relay,
+ keyRing,
+ )
+}
+
+func NewRelayInternalAPI(
+ base *base.BaseDendrite,
+ fedClient *gomatrixserverlib.FederationClient,
+ rsAPI rsAPI.RoomserverInternalAPI,
+ keyRing *gomatrixserverlib.KeyRing,
+ producer *producers.SyncAPIProducer,
+) api.RelayInternalAPI {
+ cfg := &base.Cfg.RelayAPI
+
+ relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName)
+ if err != nil {
+ logrus.WithError(err).Panic("failed to connect to relay db")
+ }
+
+ return internal.NewRelayInternalAPI(
+ relayDB,
+ fedClient,
+ rsAPI,
+ keyRing,
+ producer,
+ base.Cfg.Global.Presence.EnableInbound,
+ base.Cfg.Global.ServerName,
+ )
+}
diff --git a/relayapi/relayapi_test.go b/relayapi/relayapi_test.go
new file mode 100644
index 00000000..dfa06811
--- /dev/null
+++ b/relayapi/relayapi_test.go
@@ -0,0 +1,154 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package relayapi_test
+
+import (
+ "crypto/ed25519"
+ "encoding/hex"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gorilla/mux"
+ "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
+ "github.com/matrix-org/dendrite/relayapi"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestCreateNewRelayInternalAPI(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ defer close()
+
+ relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil)
+ assert.NotNil(t, relayAPI)
+ })
+}
+
+func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ if dbType == test.DBTypeSQLite {
+ base.Cfg.RelayAPI.Database.ConnectionString = "file:"
+ } else {
+ base.Cfg.RelayAPI.Database.ConnectionString = "test"
+ }
+ defer close()
+
+ assert.Panics(t, func() {
+ relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil)
+ })
+ })
+}
+
+func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ defer close()
+
+ assert.Panics(t, func() {
+ relayapi.AddPublicRoutes(base, nil, nil)
+ })
+ })
+}
+
+func createGetRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, userID string) *http.Request {
+ _, sk, _ := ed25519.GenerateKey(nil)
+ keyID := signing.KeyID
+ pk := sk.Public().(ed25519.PublicKey)
+ origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
+ req := gomatrixserverlib.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID)
+ content := gomatrixserverlib.RelayEntry{EntryID: 0}
+ req.SetContent(content)
+ req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk)
+ httpreq, _ := req.HTTPRequest()
+ vars := map[string]string{"userID": userID}
+ httpreq = mux.SetURLVars(httpreq, vars)
+ return httpreq
+}
+
+type sendRelayContent struct {
+ PDUs []json.RawMessage `json:"pdus"`
+ EDUs []gomatrixserverlib.EDU `json:"edus"`
+}
+
+func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnID string, userID string) *http.Request {
+ _, sk, _ := ed25519.GenerateKey(nil)
+ keyID := signing.KeyID
+ pk := sk.Public().(ed25519.PublicKey)
+ origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
+ req := gomatrixserverlib.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID)
+ content := sendRelayContent{}
+ req.SetContent(content)
+ req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk)
+ httpreq, _ := req.HTTPRequest()
+ vars := map[string]string{"userID": userID, "txnID": txnID}
+ httpreq = mux.SetURLVars(httpreq, vars)
+ return httpreq
+}
+
+func TestCreateRelayPublicRoutes(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ defer close()
+
+ relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil)
+ assert.NotNil(t, relayAPI)
+
+ serverKeyAPI := &signing.YggdrasilKeys{}
+ keyRing := serverKeyAPI.KeyRing()
+ relayapi.AddPublicRoutes(base, keyRing, relayAPI)
+
+ testCases := []struct {
+ name string
+ req *http.Request
+ wantCode int
+ wantJoinedRooms []string
+ }{
+ {
+ name: "relay_txn invalid user id",
+ req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"),
+ wantCode: 400,
+ },
+ {
+ name: "relay_txn valid user id",
+ req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"),
+ wantCode: 200,
+ },
+ {
+ name: "send_relay invalid user id",
+ req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"),
+ wantCode: 400,
+ },
+ {
+ name: "send_relay valid user id",
+ req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"),
+ wantCode: 200,
+ },
+ }
+
+ for _, tc := range testCases {
+ w := httptest.NewRecorder()
+ base.PublicFederationAPIMux.ServeHTTP(w, tc.req)
+ if w.Code != tc.wantCode {
+ t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
+ }
+ }
+ })
+}
diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go
new file mode 100644
index 00000000..1b11b0ec
--- /dev/null
+++ b/relayapi/routing/relaytxn.go
@@ -0,0 +1,74 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ "github.com/matrix-org/dendrite/relayapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
+)
+
+type RelayTransactionResponse struct {
+ Transaction gomatrixserverlib.Transaction `json:"transaction"`
+ EntryID int64 `json:"entry_id,omitempty"`
+ EntriesQueued bool `json:"entries_queued"`
+}
+
+// GetTransactionFromRelay implements /_matrix/federation/v1/relay_txn/{userID}
+// This endpoint can be extracted into a separate relay server service.
+func GetTransactionFromRelay(
+ httpReq *http.Request,
+ fedReq *gomatrixserverlib.FederationRequest,
+ relayAPI api.RelayInternalAPI,
+ userID gomatrixserverlib.UserID,
+) util.JSONResponse {
+ logrus.Infof("Handling relay_txn for %s", userID.Raw())
+
+ previousEntry := gomatrixserverlib.RelayEntry{}
+ if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: jsonerror.BadJSON("invalid json provided"),
+ }
+ }
+ if previousEntry.EntryID < 0 {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: jsonerror.BadJSON("Invalid entry id provided. Must be >= 0."),
+ }
+ }
+ logrus.Infof("Previous entry provided: %v", previousEntry.EntryID)
+
+ response, err := relayAPI.QueryTransactions(httpReq.Context(), userID, previousEntry)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ }
+ }
+
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: RelayTransactionResponse{
+ Transaction: response.Transaction,
+ EntryID: response.EntryID,
+ EntriesQueued: response.EntriesQueued,
+ },
+ }
+}
diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go
new file mode 100644
index 00000000..a47fdb19
--- /dev/null
+++ b/relayapi/routing/relaytxn_test.go
@@ -0,0 +1,220 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing_test
+
+import (
+ "context"
+ "net/http"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/relayapi/internal"
+ "github.com/matrix-org/dendrite/relayapi/routing"
+ "github.com/matrix-org/dendrite/relayapi/storage/shared"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/stretchr/testify/assert"
+)
+
+func createQuery(
+ userID gomatrixserverlib.UserID,
+ prevEntry gomatrixserverlib.RelayEntry,
+) gomatrixserverlib.FederationRequest {
+ var federationPathPrefixV1 = "/_matrix/federation/v1"
+ path := federationPathPrefixV1 + "/relay_txn/" + userID.Raw()
+ request := gomatrixserverlib.NewFederationRequest("GET", userID.Domain(), "relay", path)
+ request.SetContent(prevEntry)
+
+ return request
+}
+
+func TestGetEmptyDatabaseReturnsNothing(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+ httpReq := &http.Request{}
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.NoError(t, err, "Invalid userID")
+
+ transaction := createTransaction()
+
+ _, err = db.StoreTransaction(context.Background(), transaction)
+ assert.NoError(t, err, "Failed to store transaction")
+
+ relayAPI := internal.NewRelayInternalAPI(
+ &db, nil, nil, nil, nil, false, "",
+ )
+
+ request := createQuery(*userID, gomatrixserverlib.RelayEntry{})
+ response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
+ assert.Equal(t, http.StatusOK, response.Code)
+
+ jsonResponse := response.JSON.(routing.RelayTransactionResponse)
+ assert.Equal(t, false, jsonResponse.EntriesQueued)
+ assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction)
+
+ count, err := db.GetTransactionCount(context.Background(), *userID)
+ assert.NoError(t, err)
+ assert.Zero(t, count)
+}
+
+func TestGetInvalidPrevEntryFails(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+ httpReq := &http.Request{}
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.NoError(t, err, "Invalid userID")
+
+ transaction := createTransaction()
+
+ _, err = db.StoreTransaction(context.Background(), transaction)
+ assert.NoError(t, err, "Failed to store transaction")
+
+ relayAPI := internal.NewRelayInternalAPI(
+ &db, nil, nil, nil, nil, false, "",
+ )
+
+ request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1})
+ response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
+ assert.Equal(t, http.StatusInternalServerError, response.Code)
+}
+
+func TestGetReturnsSavedTransaction(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+ httpReq := &http.Request{}
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.NoError(t, err, "Invalid userID")
+
+ transaction := createTransaction()
+ receipt, err := db.StoreTransaction(context.Background(), transaction)
+ assert.NoError(t, err, "Failed to store transaction")
+
+ err = db.AssociateTransactionWithDestinations(
+ context.Background(),
+ map[gomatrixserverlib.UserID]struct{}{
+ *userID: {},
+ },
+ transaction.TransactionID,
+ receipt)
+ assert.NoError(t, err, "Failed to associate transaction with user")
+
+ relayAPI := internal.NewRelayInternalAPI(
+ &db, nil, nil, nil, nil, false, "",
+ )
+
+ request := createQuery(*userID, gomatrixserverlib.RelayEntry{})
+ response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
+ assert.Equal(t, http.StatusOK, response.Code)
+
+ jsonResponse := response.JSON.(routing.RelayTransactionResponse)
+ assert.True(t, jsonResponse.EntriesQueued)
+ assert.Equal(t, transaction, jsonResponse.Transaction)
+
+ // And once more to clear the queue
+ request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID})
+ response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
+ assert.Equal(t, http.StatusOK, response.Code)
+
+ jsonResponse = response.JSON.(routing.RelayTransactionResponse)
+ assert.False(t, jsonResponse.EntriesQueued)
+ assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction)
+
+ count, err := db.GetTransactionCount(context.Background(), *userID)
+ assert.NoError(t, err)
+ assert.Zero(t, count)
+}
+
+func TestGetReturnsMultipleSavedTransactions(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+ httpReq := &http.Request{}
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.NoError(t, err, "Invalid userID")
+
+ transaction := createTransaction()
+ receipt, err := db.StoreTransaction(context.Background(), transaction)
+ assert.NoError(t, err, "Failed to store transaction")
+
+ err = db.AssociateTransactionWithDestinations(
+ context.Background(),
+ map[gomatrixserverlib.UserID]struct{}{
+ *userID: {},
+ },
+ transaction.TransactionID,
+ receipt)
+ assert.NoError(t, err, "Failed to associate transaction with user")
+
+ transaction2 := createTransaction()
+ receipt2, err := db.StoreTransaction(context.Background(), transaction2)
+ assert.NoError(t, err, "Failed to store transaction")
+
+ err = db.AssociateTransactionWithDestinations(
+ context.Background(),
+ map[gomatrixserverlib.UserID]struct{}{
+ *userID: {},
+ },
+ transaction2.TransactionID,
+ receipt2)
+ assert.NoError(t, err, "Failed to associate transaction with user")
+
+ relayAPI := internal.NewRelayInternalAPI(
+ &db, nil, nil, nil, nil, false, "",
+ )
+
+ request := createQuery(*userID, gomatrixserverlib.RelayEntry{})
+ response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
+ assert.Equal(t, http.StatusOK, response.Code)
+
+ jsonResponse := response.JSON.(routing.RelayTransactionResponse)
+ assert.True(t, jsonResponse.EntriesQueued)
+ assert.Equal(t, transaction, jsonResponse.Transaction)
+
+ request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID})
+ response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
+ assert.Equal(t, http.StatusOK, response.Code)
+
+ jsonResponse = response.JSON.(routing.RelayTransactionResponse)
+ assert.True(t, jsonResponse.EntriesQueued)
+ assert.Equal(t, transaction2, jsonResponse.Transaction)
+
+ // And once more to clear the queue
+ request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID})
+ response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID)
+ assert.Equal(t, http.StatusOK, response.Code)
+
+ jsonResponse = response.JSON.(routing.RelayTransactionResponse)
+ assert.False(t, jsonResponse.EntriesQueued)
+ assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction)
+
+ count, err := db.GetTransactionCount(context.Background(), *userID)
+ assert.NoError(t, err)
+ assert.Zero(t, count)
+}
diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go
new file mode 100644
index 00000000..6df0cdc5
--- /dev/null
+++ b/relayapi/routing/routing.go
@@ -0,0 +1,123 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing
+
+import (
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/getsentry/sentry-go"
+ "github.com/gorilla/mux"
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ "github.com/matrix-org/dendrite/internal/httputil"
+ relayInternal "github.com/matrix-org/dendrite/relayapi/internal"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+// Setup registers HTTP handlers with the given ServeMux.
+// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly
+// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled
+// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz]
+//
+// Due to Setup being used to call many other functions, a gocyclo nolint is
+// applied:
+// nolint: gocyclo
+func Setup(
+ fedMux *mux.Router,
+ cfg *config.FederationAPI,
+ relayAPI *relayInternal.RelayInternalAPI,
+ keys gomatrixserverlib.JSONVerifier,
+) {
+ v1fedmux := fedMux.PathPrefix("/v1").Subrouter()
+
+ v1fedmux.Handle("/send_relay/{txnID}/{userID}", MakeRelayAPI(
+ "send_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys,
+ func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
+ userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidUsername("Username was invalid"),
+ }
+ }
+ return SendTransactionToRelay(
+ httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]),
+ *userID,
+ )
+ },
+ )).Methods(http.MethodPut, http.MethodOptions)
+
+ v1fedmux.Handle("/relay_txn/{userID}", MakeRelayAPI(
+ "get_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys,
+ func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
+ userID, err := gomatrixserverlib.NewUserID(vars["userID"], false)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidUsername("Username was invalid"),
+ }
+ }
+ return GetTransactionFromRelay(httpReq, request, relayAPI, *userID)
+ },
+ )).Methods(http.MethodGet, http.MethodOptions)
+}
+
+// MakeRelayAPI makes an http.Handler that checks matrix relay authentication.
+func MakeRelayAPI(
+ metricsName string, serverName gomatrixserverlib.ServerName,
+ isLocalServerName func(gomatrixserverlib.ServerName) bool,
+ keyRing gomatrixserverlib.JSONVerifier,
+ f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse,
+) http.Handler {
+ h := func(req *http.Request) util.JSONResponse {
+ fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest(
+ req, time.Now(), serverName, isLocalServerName, keyRing,
+ )
+ if fedReq == nil {
+ return errResp
+ }
+ // add the user to Sentry, if enabled
+ hub := sentry.GetHubFromContext(req.Context())
+ if hub != nil {
+ hub.Scope().SetTag("origin", string(fedReq.Origin()))
+ hub.Scope().SetTag("uri", fedReq.RequestURI())
+ }
+ defer func() {
+ if r := recover(); r != nil {
+ if hub != nil {
+ hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
+ }
+ // re-panic to return the 500
+ panic(r)
+ }
+ }()
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params")
+ }
+
+ jsonRes := f(req, fedReq, vars)
+ // do not log 4xx as errors as they are client fails, not server fails
+ if hub != nil && jsonRes.Code >= 500 {
+ hub.Scope().SetExtra("response", jsonRes)
+ hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code))
+ }
+ return jsonRes
+ }
+ return httputil.MakeExternalAPI(metricsName, h)
+}
diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go
new file mode 100644
index 00000000..a7027f29
--- /dev/null
+++ b/relayapi/routing/sendrelay.go
@@ -0,0 +1,77 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ "github.com/matrix-org/dendrite/relayapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
+)
+
+// SendTransactionToRelay implements PUT /_matrix/federation/v1/relay_txn/{txnID}/{userID}
+// This endpoint can be extracted into a separate relay server service.
+func SendTransactionToRelay(
+ httpReq *http.Request,
+ fedReq *gomatrixserverlib.FederationRequest,
+ relayAPI api.RelayInternalAPI,
+ txnID gomatrixserverlib.TransactionID,
+ userID gomatrixserverlib.UserID,
+) util.JSONResponse {
+ var txnEvents struct {
+ PDUs []json.RawMessage `json:"pdus"`
+ EDUs []gomatrixserverlib.EDU `json:"edus"`
+ }
+
+ if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil {
+ logrus.Info("The request body could not be decoded into valid JSON." + err.Error())
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()),
+ }
+ }
+
+ // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs.
+ // https://matrix.org/docs/spec/server_server/latest#transactions
+ if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"),
+ }
+ }
+
+ t := gomatrixserverlib.Transaction{}
+ t.PDUs = txnEvents.PDUs
+ t.EDUs = txnEvents.EDUs
+ t.Origin = fedReq.Origin()
+ t.TransactionID = txnID
+ t.Destination = userID.Domain()
+
+ util.GetLogger(httpReq.Context()).Warnf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, fedReq.Origin(), len(t.PDUs), len(t.EDUs))
+
+ err := relayAPI.PerformStoreTransaction(httpReq.Context(), t, userID)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: jsonerror.BadJSON("could not store the transaction for forwarding"),
+ }
+ }
+
+ return util.JSONResponse{Code: 200}
+}
diff --git a/relayapi/routing/sendrelay_test.go b/relayapi/routing/sendrelay_test.go
new file mode 100644
index 00000000..d9ed7500
--- /dev/null
+++ b/relayapi/routing/sendrelay_test.go
@@ -0,0 +1,209 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing_test
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/relayapi/internal"
+ "github.com/matrix-org/dendrite/relayapi/routing"
+ "github.com/matrix-org/dendrite/relayapi/storage/shared"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/stretchr/testify/assert"
+)
+
+const (
+ testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
+)
+
+func createTransaction() gomatrixserverlib.Transaction {
+ txn := gomatrixserverlib.Transaction{}
+ txn.PDUs = []json.RawMessage{
+ []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
+ }
+ txn.Origin = testOrigin
+ return txn
+}
+
+func createFederationRequest(
+ userID gomatrixserverlib.UserID,
+ txnID gomatrixserverlib.TransactionID,
+ origin gomatrixserverlib.ServerName,
+ destination gomatrixserverlib.ServerName,
+ content interface{},
+) gomatrixserverlib.FederationRequest {
+ var federationPathPrefixV1 = "/_matrix/federation/v1"
+ path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.Raw()
+ request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path)
+ request.SetContent(content)
+
+ return request
+}
+
+func TestForwardEmptyReturnsOk(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+ httpReq := &http.Request{}
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.NoError(t, err, "Invalid userID")
+
+ txn := createTransaction()
+ request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn)
+
+ relayAPI := internal.NewRelayInternalAPI(
+ &db, nil, nil, nil, nil, false, "",
+ )
+
+ response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID)
+
+ assert.Equal(t, 200, response.Code)
+}
+
+func TestForwardBadJSONReturnsError(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+ httpReq := &http.Request{}
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.NoError(t, err, "Invalid userID")
+
+ type BadData struct {
+ Field bool `json:"pdus"`
+ }
+ content := BadData{
+ Field: false,
+ }
+ txn := createTransaction()
+ request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content)
+
+ relayAPI := internal.NewRelayInternalAPI(
+ &db, nil, nil, nil, nil, false, "",
+ )
+
+ response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID)
+
+ assert.NotEqual(t, 200, response.Code)
+}
+
+func TestForwardTooManyPDUsReturnsError(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+ httpReq := &http.Request{}
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.NoError(t, err, "Invalid userID")
+
+ type BadData struct {
+ Field []json.RawMessage `json:"pdus"`
+ }
+ content := BadData{
+ Field: []json.RawMessage{},
+ }
+ for i := 0; i < 51; i++ {
+ content.Field = append(content.Field, []byte{})
+ }
+ assert.Greater(t, len(content.Field), 50)
+
+ txn := createTransaction()
+ request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content)
+
+ relayAPI := internal.NewRelayInternalAPI(
+ &db, nil, nil, nil, nil, false, "",
+ )
+
+ response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID)
+
+ assert.NotEqual(t, 200, response.Code)
+}
+
+func TestForwardTooManyEDUsReturnsError(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+ httpReq := &http.Request{}
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.NoError(t, err, "Invalid userID")
+
+ type BadData struct {
+ Field []gomatrixserverlib.EDU `json:"edus"`
+ }
+ content := BadData{
+ Field: []gomatrixserverlib.EDU{},
+ }
+ for i := 0; i < 101; i++ {
+ content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping})
+ }
+ assert.Greater(t, len(content.Field), 100)
+
+ txn := createTransaction()
+ request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content)
+
+ relayAPI := internal.NewRelayInternalAPI(
+ &db, nil, nil, nil, nil, false, "",
+ )
+
+ response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID)
+
+ assert.NotEqual(t, 200, response.Code)
+}
+
+func TestUniqueTransactionStoredInDatabase(t *testing.T) {
+ testDB := test.NewInMemoryRelayDatabase()
+ db := shared.Database{
+ Writer: sqlutil.NewDummyWriter(),
+ RelayQueue: testDB,
+ RelayQueueJSON: testDB,
+ }
+ httpReq := &http.Request{}
+ userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
+ assert.NoError(t, err, "Invalid userID")
+
+ txn := createTransaction()
+ request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn)
+
+ relayAPI := internal.NewRelayInternalAPI(
+ &db, nil, nil, nil, nil, false, "",
+ )
+
+ response := routing.SendTransactionToRelay(
+ httpReq, &request, relayAPI, txn.TransactionID, *userID)
+ transaction, _, err := db.GetTransaction(context.Background(), *userID)
+ assert.NoError(t, err, "Failed retrieving transaction")
+
+ transactionCount, err := db.GetTransactionCount(context.Background(), *userID)
+ assert.NoError(t, err, "Failed retrieving transaction count")
+
+ assert.Equal(t, 200, response.Code)
+ assert.Equal(t, int64(1), transactionCount)
+ assert.Equal(t, txn.TransactionID, transaction.TransactionID)
+}
diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go
new file mode 100644
index 00000000..f5f9a06e
--- /dev/null
+++ b/relayapi/storage/interface.go
@@ -0,0 +1,47 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+
+ "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type Database interface {
+ // Adds a new transaction to the queue json table.
+ // Adding a duplicate transaction will result in a new row being added and a new unique nid.
+ // return: unique nid representing this entry.
+ StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error)
+
+ // Adds a new transaction_id: server_name mapping with associated json table nid to the queue
+ // entry table for each provided destination.
+ AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error
+
+ // Removes every server_name: receipt pair provided from the queue entries table.
+ // Will then remove every entry for each receipt provided from the queue json table.
+ // If any of the entries don't exist in either table, nothing will happen for that entry and
+ // an error will not be generated.
+ CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error
+
+ // Gets the oldest transaction for the provided server_name.
+ // If no transactions exist, returns nil and no error.
+ GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error)
+
+ // Gets the number of transactions being stored for the provided server_name.
+ // If the server doesn't exist in the database then 0 is returned with no error.
+ GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
+}
diff --git a/relayapi/storage/postgres/relay_queue_json_table.go b/relayapi/storage/postgres/relay_queue_json_table.go
new file mode 100644
index 00000000..74410fc8
--- /dev/null
+++ b/relayapi/storage/postgres/relay_queue_json_table.go
@@ -0,0 +1,113 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+)
+
+const relayQueueJSONSchema = `
+-- The relayapi_queue_json table contains event contents that
+-- we are storing for future forwarding.
+CREATE TABLE IF NOT EXISTS relayapi_queue_json (
+ -- The JSON NID. This allows cross-referencing to find the JSON blob.
+ json_nid BIGSERIAL,
+ -- The JSON body. Text so that we preserve UTF-8.
+ json_body TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
+ ON relayapi_queue_json (json_nid);
+`
+
+const insertQueueJSONSQL = "" +
+ "INSERT INTO relayapi_queue_json (json_body)" +
+ " VALUES ($1)" +
+ " RETURNING json_nid"
+
+const deleteQueueJSONSQL = "" +
+ "DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)"
+
+const selectQueueJSONSQL = "" +
+ "SELECT json_nid, json_body FROM relayapi_queue_json" +
+ " WHERE json_nid = ANY($1)"
+
+type relayQueueJSONStatements struct {
+ db *sql.DB
+ insertJSONStmt *sql.Stmt
+ deleteJSONStmt *sql.Stmt
+ selectJSONStmt *sql.Stmt
+}
+
+func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
+ s = &relayQueueJSONStatements{
+ db: db,
+ }
+ _, err = s.db.Exec(relayQueueJSONSchema)
+ if err != nil {
+ return
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.insertJSONStmt, insertQueueJSONSQL},
+ {&s.deleteJSONStmt, deleteQueueJSONSQL},
+ {&s.selectJSONStmt, selectQueueJSONSQL},
+ }.Prepare(db)
+}
+
+func (s *relayQueueJSONStatements) InsertQueueJSON(
+ ctx context.Context, txn *sql.Tx, json string,
+) (int64, error) {
+ stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
+ var lastid int64
+ if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil {
+ return 0, err
+ }
+ return lastid, nil
+}
+
+func (s *relayQueueJSONStatements) DeleteQueueJSON(
+ ctx context.Context, txn *sql.Tx, nids []int64,
+) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt)
+ _, err := stmt.ExecContext(ctx, pq.Int64Array(nids))
+ return err
+}
+
+func (s *relayQueueJSONStatements) SelectQueueJSON(
+ ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
+) (map[int64][]byte, error) {
+ blobs := map[int64][]byte{}
+ stmt := sqlutil.TxStmt(txn, s.selectJSONStmt)
+ rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
+ for rows.Next() {
+ var nid int64
+ var blob []byte
+ if err = rows.Scan(&nid, &blob); err != nil {
+ return nil, err
+ }
+ blobs[nid] = blob
+ }
+ return blobs, err
+}
diff --git a/relayapi/storage/postgres/relay_queue_table.go b/relayapi/storage/postgres/relay_queue_table.go
new file mode 100644
index 00000000..e97cf8cc
--- /dev/null
+++ b/relayapi/storage/postgres/relay_queue_table.go
@@ -0,0 +1,156 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const relayQueueSchema = `
+CREATE TABLE IF NOT EXISTS relayapi_queue (
+ -- The transaction ID that was generated before persisting the event.
+ transaction_id TEXT NOT NULL,
+ -- The destination server that we will send the event to.
+ server_name TEXT NOT NULL,
+ -- The JSON NID from the relayapi_queue_json table.
+ json_nid BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
+ ON relayapi_queue (json_nid, server_name);
+CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
+ ON relayapi_queue (json_nid);
+CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
+ ON relayapi_queue (server_name);
+`
+
+const insertQueueEntrySQL = "" +
+ "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
+ " VALUES ($1, $2, $3)"
+
+const deleteQueueEntriesSQL = "" +
+ "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)"
+
+const selectQueueEntriesSQL = "" +
+ "SELECT json_nid FROM relayapi_queue" +
+ " WHERE server_name = $1" +
+ " ORDER BY json_nid" +
+ " LIMIT $2"
+
+const selectQueueEntryCountSQL = "" +
+ "SELECT COUNT(*) FROM relayapi_queue" +
+ " WHERE server_name = $1"
+
+type relayQueueStatements struct {
+ db *sql.DB
+ insertQueueEntryStmt *sql.Stmt
+ deleteQueueEntriesStmt *sql.Stmt
+ selectQueueEntriesStmt *sql.Stmt
+ selectQueueEntryCountStmt *sql.Stmt
+}
+
+func NewPostgresRelayQueueTable(
+ db *sql.DB,
+) (s *relayQueueStatements, err error) {
+ s = &relayQueueStatements{
+ db: db,
+ }
+ _, err = s.db.Exec(relayQueueSchema)
+ if err != nil {
+ return
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.insertQueueEntryStmt, insertQueueEntrySQL},
+ {&s.deleteQueueEntriesStmt, deleteQueueEntriesSQL},
+ {&s.selectQueueEntriesStmt, selectQueueEntriesSQL},
+ {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL},
+ }.Prepare(db)
+}
+
+func (s *relayQueueStatements) InsertQueueEntry(
+ ctx context.Context,
+ txn *sql.Tx,
+ transactionID gomatrixserverlib.TransactionID,
+ serverName gomatrixserverlib.ServerName,
+ nid int64,
+) error {
+ stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ transactionID, // the transaction ID that we initially attempted
+ serverName, // destination server name
+ nid, // JSON blob NID
+ )
+ return err
+}
+
+func (s *relayQueueStatements) DeleteQueueEntries(
+ ctx context.Context,
+ txn *sql.Tx,
+ serverName gomatrixserverlib.ServerName,
+ jsonNIDs []int64,
+) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt)
+ _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
+ return err
+}
+
+func (s *relayQueueStatements) SelectQueueEntries(
+ ctx context.Context,
+ txn *sql.Tx,
+ serverName gomatrixserverlib.ServerName,
+ limit int,
+) ([]int64, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
+ rows, err := stmt.QueryContext(ctx, serverName, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
+ var result []int64
+ for rows.Next() {
+ var nid int64
+ if err = rows.Scan(&nid); err != nil {
+ return nil, err
+ }
+ result = append(result, nid)
+ }
+
+ return result, rows.Err()
+}
+
+func (s *relayQueueStatements) SelectQueueEntryCount(
+ ctx context.Context,
+ txn *sql.Tx,
+ serverName gomatrixserverlib.ServerName,
+) (int64, error) {
+ var count int64
+ stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
+ err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
+ if err == sql.ErrNoRows {
+ // It's acceptable for there to be no rows referencing a given
+ // JSON NID but it's not an error condition. Just return as if
+ // there's a zero count.
+ return 0, nil
+ }
+ return count, err
+}
diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go
new file mode 100644
index 00000000..1042beba
--- /dev/null
+++ b/relayapi/storage/postgres/storage.go
@@ -0,0 +1,64 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/internal/caching"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/relayapi/storage/shared"
+ "github.com/matrix-org/dendrite/setup/base"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// Database stores information needed by the relayapi
+type Database struct {
+ shared.Database
+ db *sql.DB
+ writer sqlutil.Writer
+}
+
+// NewDatabase opens a new database
+func NewDatabase(
+ base *base.BaseDendrite,
+ dbProperties *config.DatabaseOptions,
+ cache caching.FederationCache,
+ isLocalServerName func(gomatrixserverlib.ServerName) bool,
+) (*Database, error) {
+ var d Database
+ var err error
+ if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
+ return nil, err
+ }
+ queue, err := NewPostgresRelayQueueTable(d.db)
+ if err != nil {
+ return nil, err
+ }
+ queueJSON, err := NewPostgresRelayQueueJSONTable(d.db)
+ if err != nil {
+ return nil, err
+ }
+ d.Database = shared.Database{
+ DB: d.db,
+ IsLocalServerName: isLocalServerName,
+ Cache: cache,
+ Writer: d.writer,
+ RelayQueue: queue,
+ RelayQueueJSON: queueJSON,
+ }
+ return &d, nil
+}
diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go
new file mode 100644
index 00000000..0993707b
--- /dev/null
+++ b/relayapi/storage/shared/storage.go
@@ -0,0 +1,170 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shared
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
+ "github.com/matrix-org/dendrite/internal/caching"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/relayapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type Database struct {
+ DB *sql.DB
+ IsLocalServerName func(gomatrixserverlib.ServerName) bool
+ Cache caching.FederationCache
+ Writer sqlutil.Writer
+ RelayQueue tables.RelayQueue
+ RelayQueueJSON tables.RelayQueueJSON
+}
+
+func (d *Database) StoreTransaction(
+ ctx context.Context,
+ transaction gomatrixserverlib.Transaction,
+) (*receipt.Receipt, error) {
+ var err error
+ jsonTransaction, err := json.Marshal(transaction)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal: %w", err)
+ }
+
+ var nid int64
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(jsonTransaction))
+ return err
+ })
+ if err != nil {
+ return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
+ }
+
+ newReceipt := receipt.NewReceipt(nid)
+ return &newReceipt, nil
+}
+
+func (d *Database) AssociateTransactionWithDestinations(
+ ctx context.Context,
+ destinations map[gomatrixserverlib.UserID]struct{},
+ transactionID gomatrixserverlib.TransactionID,
+ dbReceipt *receipt.Receipt,
+) error {
+ err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ var lastErr error
+ for destination := range destinations {
+ destination := destination
+ err := d.RelayQueue.InsertQueueEntry(
+ ctx,
+ txn,
+ transactionID,
+ destination.Domain(),
+ dbReceipt.GetNID(),
+ )
+ if err != nil {
+ lastErr = fmt.Errorf("d.insertQueueEntry: %w", err)
+ }
+ }
+ return lastErr
+ })
+
+ return err
+}
+
+func (d *Database) CleanTransactions(
+ ctx context.Context,
+ userID gomatrixserverlib.UserID,
+ receipts []*receipt.Receipt,
+) error {
+ nids := make([]int64, len(receipts))
+ for i, dbReceipt := range receipts {
+ nids[i] = dbReceipt.GetNID()
+ }
+
+ err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ deleteEntryErr := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids)
+ // TODO : If there are still queue entries for any of these nids for other destinations
+ // then we shouldn't delete the json entries.
+ // But this can't happen with the current api design.
+ // There will only ever be one server entry for each nid since each call to send_relay
+ // only accepts a single server name and inside there we create a new json entry.
+ // So for multiple destinations we would call send_relay multiple times and have multiple
+ // json entries of the same transaction.
+ //
+ // TLDR; this works as expected right now but can easily be optimised in the future.
+ deleteJSONErr := d.RelayQueueJSON.DeleteQueueJSON(ctx, txn, nids)
+
+ if deleteEntryErr != nil {
+ return fmt.Errorf("d.deleteQueueEntries: %w", deleteEntryErr)
+ }
+ if deleteJSONErr != nil {
+ return fmt.Errorf("d.deleteQueueJSON: %w", deleteJSONErr)
+ }
+ return nil
+ })
+
+ return err
+}
+
+func (d *Database) GetTransaction(
+ ctx context.Context,
+ userID gomatrixserverlib.UserID,
+) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) {
+ entriesRequested := 1
+ nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), entriesRequested)
+ if err != nil {
+ return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err)
+ }
+ if len(nids) == 0 {
+ return nil, nil, nil
+ }
+ firstNID := nids[0]
+
+ txns := map[int64][]byte{}
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids)
+ return err
+ })
+ if err != nil {
+ return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err)
+ }
+
+ transaction := &gomatrixserverlib.Transaction{}
+ if _, ok := txns[firstNID]; !ok {
+ return nil, nil, fmt.Errorf("Failed retrieving json blob for transaction: %d", firstNID)
+ }
+
+ err = json.Unmarshal(txns[firstNID], transaction)
+ if err != nil {
+ return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err)
+ }
+
+ newReceipt := receipt.NewReceipt(firstNID)
+ return transaction, &newReceipt, nil
+}
+
+func (d *Database) GetTransactionCount(
+ ctx context.Context,
+ userID gomatrixserverlib.UserID,
+) (int64, error) {
+ count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain())
+ if err != nil {
+ return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err)
+ }
+ return count, nil
+}
diff --git a/relayapi/storage/sqlite3/relay_queue_json_table.go b/relayapi/storage/sqlite3/relay_queue_json_table.go
new file mode 100644
index 00000000..502da3b0
--- /dev/null
+++ b/relayapi/storage/sqlite3/relay_queue_json_table.go
@@ -0,0 +1,137 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+)
+
+const relayQueueJSONSchema = `
+-- The relayapi_queue_json table contains event contents that
+-- we are storing for future forwarding.
+CREATE TABLE IF NOT EXISTS relayapi_queue_json (
+ -- The JSON NID. This allows cross-referencing to find the JSON blob.
+ json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The JSON body. Text so that we preserve UTF-8.
+ json_body TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx
+ ON relayapi_queue_json (json_nid);
+`
+
+const insertQueueJSONSQL = "" +
+ "INSERT INTO relayapi_queue_json (json_body)" +
+ " VALUES ($1)"
+
+const deleteQueueJSONSQL = "" +
+ "DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)"
+
+const selectQueueJSONSQL = "" +
+ "SELECT json_nid, json_body FROM relayapi_queue_json" +
+ " WHERE json_nid IN ($1)"
+
+type relayQueueJSONStatements struct {
+ db *sql.DB
+ insertJSONStmt *sql.Stmt
+ //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
+ //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
+}
+
+func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) {
+ s = &relayQueueJSONStatements{
+ db: db,
+ }
+ _, err = db.Exec(relayQueueJSONSchema)
+ if err != nil {
+ return
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.insertJSONStmt, insertQueueJSONSQL},
+ }.Prepare(db)
+}
+
+func (s *relayQueueJSONStatements) InsertQueueJSON(
+ ctx context.Context, txn *sql.Tx, json string,
+) (lastid int64, err error) {
+ stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
+ res, err := stmt.ExecContext(ctx, json)
+ if err != nil {
+ return 0, fmt.Errorf("stmt.QueryContext: %w", err)
+ }
+ lastid, err = res.LastInsertId()
+ if err != nil {
+ return 0, fmt.Errorf("res.LastInsertId: %w", err)
+ }
+ return
+}
+
+func (s *relayQueueJSONStatements) DeleteQueueJSON(
+ ctx context.Context, txn *sql.Tx, nids []int64,
+) error {
+ deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
+ deleteStmt, err := txn.Prepare(deleteSQL)
+ if err != nil {
+ return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
+ }
+
+ iNIDs := make([]interface{}, len(nids))
+ for k, v := range nids {
+ iNIDs[k] = v
+ }
+
+ stmt := sqlutil.TxStmt(txn, deleteStmt)
+ _, err = stmt.ExecContext(ctx, iNIDs...)
+ return err
+}
+
+func (s *relayQueueJSONStatements) SelectQueueJSON(
+ ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
+) (map[int64][]byte, error) {
+ selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
+ selectStmt, err := txn.Prepare(selectSQL)
+ if err != nil {
+ return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
+ }
+
+ iNIDs := make([]interface{}, len(jsonNIDs))
+ for k, v := range jsonNIDs {
+ iNIDs[k] = v
+ }
+
+ blobs := map[int64][]byte{}
+ stmt := sqlutil.TxStmt(txn, selectStmt)
+ rows, err := stmt.QueryContext(ctx, iNIDs...)
+ if err != nil {
+ return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed")
+ for rows.Next() {
+ var nid int64
+ var blob []byte
+ if err = rows.Scan(&nid, &blob); err != nil {
+ return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
+ }
+ blobs[nid] = blob
+ }
+ return blobs, err
+}
diff --git a/relayapi/storage/sqlite3/relay_queue_table.go b/relayapi/storage/sqlite3/relay_queue_table.go
new file mode 100644
index 00000000..49c6b4de
--- /dev/null
+++ b/relayapi/storage/sqlite3/relay_queue_table.go
@@ -0,0 +1,168 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const relayQueueSchema = `
+CREATE TABLE IF NOT EXISTS relayapi_queue (
+ -- The transaction ID that was generated before persisting the event.
+ transaction_id TEXT NOT NULL,
+ -- The domain part of the user ID the m.room.member event is for.
+ server_name TEXT NOT NULL,
+ -- The JSON NID from the relayapi_queue_json table.
+ json_nid BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx
+ ON relayapi_queue (json_nid, server_name);
+CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx
+ ON relayapi_queue (json_nid);
+CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx
+ ON relayapi_queue (server_name);
+`
+
+const insertQueueEntrySQL = "" +
+ "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" +
+ " VALUES ($1, $2, $3)"
+
+const deleteQueueEntriesSQL = "" +
+ "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)"
+
+const selectQueueEntriesSQL = "" +
+ "SELECT json_nid FROM relayapi_queue" +
+ " WHERE server_name = $1" +
+ " ORDER BY json_nid" +
+ " LIMIT $2"
+
+const selectQueueEntryCountSQL = "" +
+ "SELECT COUNT(*) FROM relayapi_queue" +
+ " WHERE server_name = $1"
+
+type relayQueueStatements struct {
+ db *sql.DB
+ insertQueueEntryStmt *sql.Stmt
+ selectQueueEntriesStmt *sql.Stmt
+ selectQueueEntryCountStmt *sql.Stmt
+ // deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic
+}
+
+func NewSQLiteRelayQueueTable(
+ db *sql.DB,
+) (s *relayQueueStatements, err error) {
+ s = &relayQueueStatements{
+ db: db,
+ }
+ _, err = db.Exec(relayQueueSchema)
+ if err != nil {
+ return
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.insertQueueEntryStmt, insertQueueEntrySQL},
+ {&s.selectQueueEntriesStmt, selectQueueEntriesSQL},
+ {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL},
+ }.Prepare(db)
+}
+
+func (s *relayQueueStatements) InsertQueueEntry(
+ ctx context.Context,
+ txn *sql.Tx,
+ transactionID gomatrixserverlib.TransactionID,
+ serverName gomatrixserverlib.ServerName,
+ nid int64,
+) error {
+ stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ transactionID, // the transaction ID that we initially attempted
+ serverName, // destination server name
+ nid, // JSON blob NID
+ )
+ return err
+}
+
+func (s *relayQueueStatements) DeleteQueueEntries(
+ ctx context.Context,
+ txn *sql.Tx,
+ serverName gomatrixserverlib.ServerName,
+ jsonNIDs []int64,
+) error {
+ deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
+ deleteStmt, err := txn.Prepare(deleteSQL)
+ if err != nil {
+ return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err)
+ }
+
+ params := make([]interface{}, len(jsonNIDs)+1)
+ params[0] = serverName
+ for k, v := range jsonNIDs {
+ params[k+1] = v
+ }
+
+ stmt := sqlutil.TxStmt(txn, deleteStmt)
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
+}
+
+func (s *relayQueueStatements) SelectQueueEntries(
+ ctx context.Context,
+ txn *sql.Tx,
+ serverName gomatrixserverlib.ServerName,
+ limit int,
+) ([]int64, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
+ rows, err := stmt.QueryContext(ctx, serverName, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
+ var result []int64
+ for rows.Next() {
+ var nid int64
+ if err = rows.Scan(&nid); err != nil {
+ return nil, err
+ }
+ result = append(result, nid)
+ }
+
+ return result, rows.Err()
+}
+
+func (s *relayQueueStatements) SelectQueueEntryCount(
+ ctx context.Context,
+ txn *sql.Tx,
+ serverName gomatrixserverlib.ServerName,
+) (int64, error) {
+ var count int64
+ stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)
+ err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
+ if err == sql.ErrNoRows {
+ // It's acceptable for there to be no rows referencing a given
+ // JSON NID but it's not an error condition. Just return as if
+ // there's a zero count.
+ return 0, nil
+ }
+ return count, err
+}
diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go
new file mode 100644
index 00000000..3ed4ab04
--- /dev/null
+++ b/relayapi/storage/sqlite3/storage.go
@@ -0,0 +1,64 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/internal/caching"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/relayapi/storage/shared"
+ "github.com/matrix-org/dendrite/setup/base"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// Database stores information needed by the federation sender
+type Database struct {
+ shared.Database
+ db *sql.DB
+ writer sqlutil.Writer
+}
+
+// NewDatabase opens a new database
+func NewDatabase(
+ base *base.BaseDendrite,
+ dbProperties *config.DatabaseOptions,
+ cache caching.FederationCache,
+ isLocalServerName func(gomatrixserverlib.ServerName) bool,
+) (*Database, error) {
+ var d Database
+ var err error
+ if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
+ return nil, err
+ }
+ queue, err := NewSQLiteRelayQueueTable(d.db)
+ if err != nil {
+ return nil, err
+ }
+ queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db)
+ if err != nil {
+ return nil, err
+ }
+ d.Database = shared.Database{
+ DB: d.db,
+ IsLocalServerName: isLocalServerName,
+ Cache: cache,
+ Writer: d.writer,
+ RelayQueue: queue,
+ RelayQueueJSON: queueJSON,
+ }
+ return &d, nil
+}
diff --git a/relayapi/storage/storage.go b/relayapi/storage/storage.go
new file mode 100644
index 00000000..16ecbcfb
--- /dev/null
+++ b/relayapi/storage/storage.go
@@ -0,0 +1,46 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !wasm
+// +build !wasm
+
+package storage
+
+import (
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal/caching"
+ "github.com/matrix-org/dendrite/relayapi/storage/postgres"
+ "github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
+ "github.com/matrix-org/dendrite/setup/base"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// NewDatabase opens a new database
+func NewDatabase(
+ base *base.BaseDendrite,
+ dbProperties *config.DatabaseOptions,
+ cache caching.FederationCache,
+ isLocalServerName func(gomatrixserverlib.ServerName) bool,
+) (Database, error) {
+ switch {
+ case dbProperties.ConnectionString.IsSQLite():
+ return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName)
+ case dbProperties.ConnectionString.IsPostgres():
+ return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName)
+ default:
+ return nil, fmt.Errorf("unexpected database type")
+ }
+}
diff --git a/relayapi/storage/tables/interface.go b/relayapi/storage/tables/interface.go
new file mode 100644
index 00000000..9056a567
--- /dev/null
+++ b/relayapi/storage/tables/interface.go
@@ -0,0 +1,66 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tables
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// RelayQueue table contains a mapping of server name to transaction id and the corresponding nid.
+// These are the transactions being stored for the given destination server.
+// The nids correspond to entries in the RelayQueueJSON table.
+type RelayQueue interface {
+ // Adds a new transaction_id: server_name mapping with associated json table nid to the table.
+ // Will ensure only one transaction id is present for each server_name: nid mapping.
+ // Adding duplicates will silently do nothing.
+ InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
+
+ // Removes multiple entries from the table corresponding the the list of nids provided.
+ // If any of the provided nids don't match a row in the table, that deletion is considered
+ // successful.
+ DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
+
+ // Get a list of nids associated with the provided server name.
+ // Returns up to `limit` nids. The entries are returned oldest first.
+ // Will return an empty list if no matches were found.
+ SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
+
+ // Get the number of entries in the table associated with the provided server name.
+ // If there are no matching rows, a count of 0 is returned with err set to nil.
+ SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
+}
+
+// RelayQueueJSON table contains a map of nid to the raw transaction json.
+type RelayQueueJSON interface {
+ // Adds a new transaction to the table.
+ // Adding a duplicate transaction will result in a new row being added and a new unique nid.
+ // return: unique nid representing this entry.
+ InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
+
+ // Removes multiple nids from the table.
+ // If any of the provided nids don't match a row in the table, that deletion is considered
+ // successful.
+ DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
+
+ // Get the transaction json corresponding to the provided nids.
+ // Will return a partial result containing any matching nid from the table.
+ // Will return an empty map if no matches were found.
+ // It is the caller's responsibility to deal with the results appropriately.
+ // return: map indexed by nid of each matching transaction json.
+ SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
+}
diff --git a/relayapi/storage/tables/relay_queue_json_table_test.go b/relayapi/storage/tables/relay_queue_json_table_test.go
new file mode 100644
index 00000000..efa3363e
--- /dev/null
+++ b/relayapi/storage/tables/relay_queue_json_table_test.go
@@ -0,0 +1,173 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tables_test
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/relayapi/storage/postgres"
+ "github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
+ "github.com/matrix-org/dendrite/relayapi/storage/tables"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/stretchr/testify/assert"
+)
+
+const (
+ testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
+)
+
+func mustCreateTransaction() gomatrixserverlib.Transaction {
+ txn := gomatrixserverlib.Transaction{}
+ txn.PDUs = []json.RawMessage{
+ []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
+ }
+ txn.Origin = testOrigin
+
+ return txn
+}
+
+type RelayQueueJSONDatabase struct {
+ DB *sql.DB
+ Writer sqlutil.Writer
+ Table tables.RelayQueueJSON
+}
+
+func mustCreateQueueJSONTable(
+ t *testing.T,
+ dbType test.DBType,
+) (database RelayQueueJSONDatabase, close func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, sqlutil.NewExclusiveWriter())
+ assert.NoError(t, err)
+ var tab tables.RelayQueueJSON
+ switch dbType {
+ case test.DBTypePostgres:
+ tab, err = postgres.NewPostgresRelayQueueJSONTable(db)
+ assert.NoError(t, err)
+ case test.DBTypeSQLite:
+ tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db)
+ assert.NoError(t, err)
+ }
+ assert.NoError(t, err)
+
+ database = RelayQueueJSONDatabase{
+ DB: db,
+ Writer: sqlutil.NewDummyWriter(),
+ Table: tab,
+ }
+ return database, close
+}
+
+func TestShoudInsertTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueJSONTable(t, dbType)
+ defer close()
+
+ transaction := mustCreateTransaction()
+ tx, err := json.Marshal(transaction)
+ if err != nil {
+ t.Fatalf("Invalid transaction: %s", err.Error())
+ }
+
+ _, err = db.Table.InsertQueueJSON(ctx, nil, string(tx))
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+ })
+}
+
+func TestShouldRetrieveInsertedTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueJSONTable(t, dbType)
+ defer close()
+
+ transaction := mustCreateTransaction()
+ tx, err := json.Marshal(transaction)
+ if err != nil {
+ t.Fatalf("Invalid transaction: %s", err.Error())
+ }
+
+ nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ var storedJSON map[int64][]byte
+ _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
+ storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
+ return err
+ })
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction: %s", err.Error())
+ }
+
+ assert.Equal(t, 1, len(storedJSON))
+
+ var storedTx gomatrixserverlib.Transaction
+ json.Unmarshal(storedJSON[1], &storedTx)
+
+ assert.Equal(t, transaction, storedTx)
+ })
+}
+
+func TestShouldDeleteTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueJSONTable(t, dbType)
+ defer close()
+
+ transaction := mustCreateTransaction()
+ tx, err := json.Marshal(transaction)
+ if err != nil {
+ t.Fatalf("Invalid transaction: %s", err.Error())
+ }
+
+ nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx))
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ storedJSON := map[int64][]byte{}
+ _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
+ err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid})
+ return err
+ })
+ if err != nil {
+ t.Fatalf("Failed deleting transaction: %s", err.Error())
+ }
+
+ storedJSON = map[int64][]byte{}
+ _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
+ storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid})
+ return err
+ })
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction: %s", err.Error())
+ }
+
+ assert.Equal(t, 0, len(storedJSON))
+ })
+}
diff --git a/relayapi/storage/tables/relay_queue_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go
new file mode 100644
index 00000000..99f9922c
--- /dev/null
+++ b/relayapi/storage/tables/relay_queue_table_test.go
@@ -0,0 +1,229 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tables_test
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/relayapi/storage/postgres"
+ "github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
+ "github.com/matrix-org/dendrite/relayapi/storage/tables"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/stretchr/testify/assert"
+)
+
+type RelayQueueDatabase struct {
+ DB *sql.DB
+ Writer sqlutil.Writer
+ Table tables.RelayQueue
+}
+
+func mustCreateQueueTable(
+ t *testing.T,
+ dbType test.DBType,
+) (database RelayQueueDatabase, close func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, sqlutil.NewExclusiveWriter())
+ assert.NoError(t, err)
+ var tab tables.RelayQueue
+ switch dbType {
+ case test.DBTypePostgres:
+ tab, err = postgres.NewPostgresRelayQueueTable(db)
+ assert.NoError(t, err)
+ case test.DBTypeSQLite:
+ tab, err = sqlite3.NewSQLiteRelayQueueTable(db)
+ assert.NoError(t, err)
+ }
+ assert.NoError(t, err)
+
+ database = RelayQueueDatabase{
+ DB: db,
+ Writer: sqlutil.NewDummyWriter(),
+ Table: tab,
+ }
+ return database, close
+}
+
+func TestShoudInsertQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(1)
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+ })
+}
+
+func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(1)
+
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction: %s", err.Error())
+ }
+
+ assert.Equal(t, nid, retrievedNids[0])
+ assert.Equal(t, 1, len(retrievedNids))
+ })
+}
+
+func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(2)
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName = gomatrixserverlib.ServerName("domain")
+ oldestNID := int64(1)
+ err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 1)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction: %s", err.Error())
+ }
+
+ assert.Equal(t, oldestNID, retrievedNids[0])
+ assert.Equal(t, 1, len(retrievedNids))
+
+ retrievedNids, err = db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction: %s", err.Error())
+ }
+
+ assert.Equal(t, oldestNID, retrievedNids[0])
+ assert.Equal(t, nid, retrievedNids[1])
+ assert.Equal(t, 2, len(retrievedNids))
+ })
+}
+
+func TestShouldDeleteQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(1)
+
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
+ err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
+ return err
+ })
+ if err != nil {
+ t.Fatalf("Failed deleting transaction: %s", err.Error())
+ }
+
+ count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction count: %s", err.Error())
+ }
+ assert.Equal(t, int64(0), count)
+ })
+}
+
+func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(1)
+ transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano()))
+ serverName2 := gomatrixserverlib.ServerName("domain2")
+ nid2 := int64(2)
+ transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano()))
+
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+ err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+ err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
+ err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
+ return err
+ })
+ if err != nil {
+ t.Fatalf("Failed deleting transaction: %s", err.Error())
+ }
+
+ count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction count: %s", err.Error())
+ }
+ assert.Equal(t, int64(1), count)
+
+ count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction count: %s", err.Error())
+ }
+ assert.Equal(t, int64(1), count)
+ })
+}