aboutsummaryrefslogtreecommitdiff
path: root/federationapi/queue/queue_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'federationapi/queue/queue_test.go')
-rw-r--r--federationapi/queue/queue_test.go1047
1 files changed, 1047 insertions, 0 deletions
diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go
new file mode 100644
index 00000000..6da9e6b3
--- /dev/null
+++ b/federationapi/queue/queue_test.go
@@ -0,0 +1,1047 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package queue
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ "go.uber.org/atomic"
+ "gotest.tools/v3/poll"
+
+ "github.com/matrix-org/dendrite/federationapi/api"
+ "github.com/matrix-org/dendrite/federationapi/statistics"
+ "github.com/matrix-org/dendrite/federationapi/storage"
+ "github.com/matrix-org/dendrite/federationapi/storage/shared"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/pkg/errors"
+ "github.com/stretchr/testify/assert"
+)
+
+func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) {
+ if realDatabase {
+ // Real Database/s
+ b, baseClose := testrig.CreateBaseDendrite(t, dbType)
+ connStr, dbClose := test.PrepareDBConnectionString(t, dbType)
+ db, err := storage.NewDatabase(b, &config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, b.Caches, b.Cfg.Global.ServerName)
+ if err != nil {
+ t.Fatalf("NewDatabase returned %s", err)
+ }
+ return db, b.ProcessContext, func() {
+ dbClose()
+ baseClose()
+ }
+ } else {
+ // Fake Database
+ db := createDatabase()
+ b := struct {
+ ProcessContext *process.ProcessContext
+ }{ProcessContext: process.NewProcessContext()}
+ return db, b.ProcessContext, func() {}
+ }
+}
+
+func createDatabase() storage.Database {
+ return &fakeDatabase{
+ pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
+ pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
+ blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}),
+ pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent),
+ pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU),
+ associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
+ associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
+ }
+}
+
+type fakeDatabase struct {
+ storage.Database
+ dbMutex sync.Mutex
+ pendingPDUServers map[gomatrixserverlib.ServerName]struct{}
+ pendingEDUServers map[gomatrixserverlib.ServerName]struct{}
+ blacklistedServers map[gomatrixserverlib.ServerName]struct{}
+ pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent
+ pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU
+ associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
+ associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
+}
+
+var nidMutex sync.Mutex
+var nid = int64(0)
+
+func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ var event gomatrixserverlib.HeaderedEvent
+ if err := json.Unmarshal([]byte(js), &event); err == nil {
+ nidMutex.Lock()
+ defer nidMutex.Unlock()
+ nid++
+ receipt := shared.NewReceipt(nid)
+ d.pendingPDUs[&receipt] = &event
+ return &receipt, nil
+ }
+
+ var edu gomatrixserverlib.EDU
+ if err := json.Unmarshal([]byte(js), &edu); err == nil {
+ nidMutex.Lock()
+ defer nidMutex.Unlock()
+ nid++
+ receipt := shared.NewReceipt(nid)
+ d.pendingEDUs[&receipt] = &edu
+ return &receipt, nil
+ }
+
+ return nil, errors.New("Failed to determine type of json to store")
+}
+
+func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ pduCount := 0
+ pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent)
+ if receipts, ok := d.associatedPDUs[serverName]; ok {
+ for receipt := range receipts {
+ if event, ok := d.pendingPDUs[receipt]; ok {
+ pdus[receipt] = event
+ pduCount++
+ if pduCount == limit {
+ break
+ }
+ }
+ }
+ }
+ return pdus, nil
+}
+
+func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ eduCount := 0
+ edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU)
+ if receipts, ok := d.associatedEDUs[serverName]; ok {
+ for receipt := range receipts {
+ if event, ok := d.pendingEDUs[receipt]; ok {
+ edus[receipt] = event
+ eduCount++
+ if eduCount == limit {
+ break
+ }
+ }
+ }
+ }
+ return edus, nil
+}
+
+func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ if _, ok := d.pendingPDUs[receipt]; ok {
+ if _, ok := d.associatedPDUs[serverName]; !ok {
+ d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{})
+ }
+ d.associatedPDUs[serverName][receipt] = struct{}{}
+ return nil
+ } else {
+ return errors.New("PDU doesn't exist")
+ }
+}
+
+func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ if _, ok := d.pendingEDUs[receipt]; ok {
+ if _, ok := d.associatedEDUs[serverName]; !ok {
+ d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{})
+ }
+ d.associatedEDUs[serverName][receipt] = struct{}{}
+ return nil
+ } else {
+ return errors.New("EDU doesn't exist")
+ }
+}
+
+func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ if pdus, ok := d.associatedPDUs[serverName]; ok {
+ for _, receipt := range receipts {
+ delete(pdus, receipt)
+ }
+ }
+
+ return nil
+}
+
+func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ if edus, ok := d.associatedEDUs[serverName]; ok {
+ for _, receipt := range receipts {
+ delete(edus, receipt)
+ }
+ }
+
+ return nil
+}
+
+func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ var count int64
+ if pdus, ok := d.associatedPDUs[serverName]; ok {
+ count = int64(len(pdus))
+ }
+ return count, nil
+}
+
+func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ var count int64
+ if edus, ok := d.associatedEDUs[serverName]; ok {
+ count = int64(len(edus))
+ }
+ return count, nil
+}
+
+func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ servers := []gomatrixserverlib.ServerName{}
+ for server := range d.pendingPDUServers {
+ servers = append(servers, server)
+ }
+ return servers, nil
+}
+
+func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ servers := []gomatrixserverlib.ServerName{}
+ for server := range d.pendingEDUServers {
+ servers = append(servers, server)
+ }
+ return servers, nil
+}
+
+func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ d.blacklistedServers[serverName] = struct{}{}
+ return nil
+}
+
+func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ delete(d.blacklistedServers, serverName)
+ return nil
+}
+
+func (d *fakeDatabase) RemoveAllServersFromBlacklist() error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{})
+ return nil
+}
+
+func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ isBlacklisted := false
+ if _, ok := d.blacklistedServers[serverName]; ok {
+ isBlacklisted = true
+ }
+
+ return isBlacklisted, nil
+}
+
+type stubFederationRoomServerAPI struct {
+ rsapi.FederationRoomserverAPI
+}
+
+func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error {
+ res.Banned = false
+ return nil
+}
+
+type stubFederationClient struct {
+ api.FederationClient
+ shouldTxSucceed bool
+ txCount atomic.Uint32
+}
+
+func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
+ var result error
+ if !f.shouldTxSucceed {
+ result = fmt.Errorf("transaction failed")
+ }
+
+ f.txCount.Add(1)
+ return gomatrixserverlib.RespSend{}, result
+}
+
+func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent {
+ t.Helper()
+ content := `{"type":"m.room.message"}`
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10)
+ if err != nil {
+ t.Fatalf("failed to create event: %v", err)
+ }
+ return ev.Headered(gomatrixserverlib.RoomVersionV10)
+}
+
+func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU {
+ t.Helper()
+ return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}
+}
+
+func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) {
+ db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase)
+
+ fc := &stubFederationClient{
+ shouldTxSucceed: shouldTxSucceed,
+ txCount: *atomic.NewUint32(0),
+ }
+ rs := &stubFederationRoomServerAPI{}
+ stats := statistics.NewStatistics(db, failuresUntilBlacklist)
+ signingInfo := &SigningInfo{
+ KeyID: "ed21019:auto",
+ PrivateKey: test.PrivateKeyA,
+ ServerName: "localhost",
+ }
+ queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)
+
+ return db, fc, queues, processContext, close
+}
+
+func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == 1 {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == 1 {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUOnFailStoredInDB(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
+ if fc.txCount.Load() >= 2 {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUOnFailStoredInDB(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
+ if fc.txCount.Load() >= 2 {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
+ if fc.txCount.Load() >= 2 {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ ev = mustCreatePDU(t)
+ err = queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ pollEnd := time.Now().Add(1 * time.Second)
+ immediateCheck := func(log poll.LogT) poll.Result {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Error(fmt.Errorf("The backoff was interrupted early"))
+ }
+ if time.Now().After(pollEnd) {
+ // Allow more than enough time for the backoff to be interrupted before
+ // reporting that it wasn't.
+ return poll.Success()
+ }
+ return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d", len(data))
+ }
+ poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
+ if fc.txCount.Load() >= 2 {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ ev = mustCreateEDU(t)
+ err = queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ pollEnd := time.Now().Add(1 * time.Second)
+ immediateCheck := func(log poll.LogT) poll.Result {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Error(fmt.Errorf("The backoff was interrupted early"))
+ }
+ if time.Now().After(pollEnd) {
+ // Allow more than enough time for the backoff to be interrupted before
+ // reporting that it wasn't.
+ return poll.Success()
+ }
+ return poll.Continue("waiting for events to be removed from database. Currently present EDU: %d", len(data))
+ }
+ poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(2)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(2)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(2)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ queues.statistics.ForServer(destination).Failure()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(2)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ queues.statistics.ForServer(destination).Failure()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestRetryServerSendsPDUSuccessfully(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(1)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // NOTE : getQueue before sending event to ensure we grab the same queue reference
+ // before it is blacklisted and deleted.
+ dest := queues.getQueue(destination)
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ checkBlacklisted := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ if !dest.running.Load() {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for queue to stop completely")
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ db.RemoveServerFromBlacklist(destination)
+ queues.RetryServer(destination)
+ checkRetry := func(log poll.LogT) poll.Result {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data))
+ }
+ poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestRetryServerSendsEDUSuccessfully(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(1)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // NOTE : getQueue before sending event to ensure we grab the same queue reference
+ // before it is blacklisted and deleted.
+ dest := queues.getQueue(destination)
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ checkBlacklisted := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ if !dest.running.Load() {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for queue to stop completely")
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ db.RemoveServerFromBlacklist(destination)
+ queues.RetryServer(destination)
+ checkRetry := func(log poll.LogT) poll.Result {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data))
+ }
+ poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUBatches(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+
+ // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // Populate database with > maxPDUsPerTransaction
+ pduMultiplier := uint32(3)
+ for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ {
+ ev := mustCreatePDU(t)
+ headeredJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
+ now := gomatrixserverlib.AsTimestamp(time.Now())
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i))
+ db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid)
+ }
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == pduMultiplier+1 { // +1 for the extra SendEvent()
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+ // })
+}
+
+func TestSendEDUBatches(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+
+ // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // Populate database with > maxEDUsPerTransaction
+ eduMultiplier := uint32(3)
+ for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ {
+ ev := mustCreateEDU(t)
+ ephemeralJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
+ db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil)
+ }
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == eduMultiplier+1 { // +1 for the extra SendEvent()
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for all events to be removed from database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+ // })
+}
+
+func TestSendPDUAndEDUBatches(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+
+ // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // Populate database with > maxEDUsPerTransaction
+ multiplier := uint32(3)
+
+ for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ {
+ ev := mustCreatePDU(t)
+ headeredJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
+ now := gomatrixserverlib.AsTimestamp(time.Now())
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i))
+ db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid)
+ }
+
+ for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ {
+ ev := mustCreateEDU(t)
+ ephemeralJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
+ db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil)
+ }
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == multiplier+1 { // +1 for the extra SendEvent()
+ pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrPDU)
+ eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrEDU)
+ if len(pduData) == 0 && len(eduData) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData))
+ }
+ return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+ // })
+}
+
+func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ dest := queues.getQueue(destination)
+ queues.statistics.ForServer(destination).Failure()
+
+ ev := mustCreatePDU(t)
+ headeredJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
+ now := gomatrixserverlib.AsTimestamp(time.Now())
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, 1))
+ db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid)
+
+ pollEnd := time.Now().Add(3 * time.Second)
+ runningCheck := func(log poll.LogT) poll.Result {
+ if dest.running.Load() || fc.txCount.Load() > 0 {
+ return poll.Error(fmt.Errorf("The queue was started"))
+ }
+ if time.Now().After(pollEnd) {
+ // Allow more than enough time for the queue to be started in the case
+ // of backoff triggering it to start.
+ return poll.Success()
+ }
+ return poll.Continue("waiting to ensure queue doesn't start.")
+ }
+ poll.WaitOn(t, runningCheck, poll.WithTimeout(4*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
+ // NOTE : Only one test case against real databases can be run at a time.
+ t.Parallel()
+ failuresUntilBlacklist := uint32(1)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true)
+ // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up.
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // NOTE : getQueue before sending event to ensure we grab the same queue reference
+ // before it is blacklisted and deleted.
+ dest := queues.getQueue(destination)
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ edu := mustCreateEDU(t)
+ errEDU := queues.SendEDU(edu, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, errEDU)
+
+ checkBlacklisted := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrPDU)
+ eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrEDU)
+ if len(pduData) == 1 && len(eduData) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ if !dest.running.Load() {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for queue to stop completely")
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for events to be added to database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ db.RemoveServerFromBlacklist(destination)
+ queues.RetryServer(destination)
+ checkRetry := func(log poll.LogT) poll.Result {
+ pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrPDU)
+ eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrEDU)
+ if len(pduData) == 0 && len(eduData) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData))
+ }
+ poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond))
+ })
+}