aboutsummaryrefslogtreecommitdiff
path: root/relayapi/storage/tables/relay_queue_table_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'relayapi/storage/tables/relay_queue_table_test.go')
-rw-r--r--relayapi/storage/tables/relay_queue_table_test.go229
1 files changed, 229 insertions, 0 deletions
diff --git a/relayapi/storage/tables/relay_queue_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go
new file mode 100644
index 00000000..99f9922c
--- /dev/null
+++ b/relayapi/storage/tables/relay_queue_table_test.go
@@ -0,0 +1,229 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tables_test
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/relayapi/storage/postgres"
+ "github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
+ "github.com/matrix-org/dendrite/relayapi/storage/tables"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/stretchr/testify/assert"
+)
+
+type RelayQueueDatabase struct {
+ DB *sql.DB
+ Writer sqlutil.Writer
+ Table tables.RelayQueue
+}
+
+func mustCreateQueueTable(
+ t *testing.T,
+ dbType test.DBType,
+) (database RelayQueueDatabase, close func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, sqlutil.NewExclusiveWriter())
+ assert.NoError(t, err)
+ var tab tables.RelayQueue
+ switch dbType {
+ case test.DBTypePostgres:
+ tab, err = postgres.NewPostgresRelayQueueTable(db)
+ assert.NoError(t, err)
+ case test.DBTypeSQLite:
+ tab, err = sqlite3.NewSQLiteRelayQueueTable(db)
+ assert.NoError(t, err)
+ }
+ assert.NoError(t, err)
+
+ database = RelayQueueDatabase{
+ DB: db,
+ Writer: sqlutil.NewDummyWriter(),
+ Table: tab,
+ }
+ return database, close
+}
+
+func TestShoudInsertQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(1)
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+ })
+}
+
+func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(1)
+
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction: %s", err.Error())
+ }
+
+ assert.Equal(t, nid, retrievedNids[0])
+ assert.Equal(t, 1, len(retrievedNids))
+ })
+}
+
+func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(2)
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName = gomatrixserverlib.ServerName("domain")
+ oldestNID := int64(1)
+ err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 1)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction: %s", err.Error())
+ }
+
+ assert.Equal(t, oldestNID, retrievedNids[0])
+ assert.Equal(t, 1, len(retrievedNids))
+
+ retrievedNids, err = db.Table.SelectQueueEntries(ctx, nil, serverName, 10)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction: %s", err.Error())
+ }
+
+ assert.Equal(t, oldestNID, retrievedNids[0])
+ assert.Equal(t, nid, retrievedNids[1])
+ assert.Equal(t, 2, len(retrievedNids))
+ })
+}
+
+func TestShouldDeleteQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(1)
+
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
+ err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
+ return err
+ })
+ if err != nil {
+ t.Fatalf("Failed deleting transaction: %s", err.Error())
+ }
+
+ count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction count: %s", err.Error())
+ }
+ assert.Equal(t, int64(0), count)
+ })
+}
+
+func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateQueueTable(t, dbType)
+ defer close()
+
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
+ serverName := gomatrixserverlib.ServerName("domain")
+ nid := int64(1)
+ transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano()))
+ serverName2 := gomatrixserverlib.ServerName("domain2")
+ nid2 := int64(2)
+ transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano()))
+
+ err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+ err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+ err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2)
+ if err != nil {
+ t.Fatalf("Failed inserting transaction: %s", err.Error())
+ }
+
+ _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
+ err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid})
+ return err
+ })
+ if err != nil {
+ t.Fatalf("Failed deleting transaction: %s", err.Error())
+ }
+
+ count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction count: %s", err.Error())
+ }
+ assert.Equal(t, int64(1), count)
+
+ count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2)
+ if err != nil {
+ t.Fatalf("Failed retrieving transaction count: %s", err.Error())
+ }
+ assert.Equal(t, int64(1), count)
+ })
+}