-rw-r--r--  federationapi/federationapi.go                 9
-rw-r--r--  federationapi/queue/destinationqueue.go      348
-rw-r--r--  federationapi/queue/queue.go                  26
-rw-r--r--  federationapi/queue/queue_test.go           1047
-rw-r--r--  federationapi/statistics/statistics.go       131
-rw-r--r--  federationapi/statistics/statistics_test.go   19
-rw-r--r--  federationapi/storage/shared/storage.go        4
-rw-r--r--  go.mod                                         2
8 files changed, 1397 insertions, 189 deletions
diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go
index 4a13c9d9..f6dace70 100644
--- a/federationapi/federationapi.go
+++ b/federationapi/federationapi.go
@@ -116,17 +116,14 @@ func NewInternalAPI(
_ = federationDB.RemoveAllServersFromBlacklist()
}
- stats := &statistics.Statistics{
- DB: federationDB,
- FailuresUntilBlacklist: cfg.FederationMaxRetries,
- }
+ stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation,
- cfg.Matrix.ServerName, federation, rsAPI, stats,
+ cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
@@ -183,5 +180,5 @@ func NewInternalAPI(
}
time.AfterFunc(time.Minute, cleanExpiredEDUs)
- return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing)
+ return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, &stats, caches, queues, keyRing)
}
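
The hunk above swaps a hand-built statistics.Statistics literal for the new statistics.NewStatistics constructor and passes a pointer to the returned value. A minimal sketch of the motivation, using hypothetical trimmed-down types rather than the real ones: the constructor is where internal state such as the backoffTimers map gets initialised, so building the struct by hand would leave that map nil.

package main

import "fmt"

// stats mirrors, in hypothetical and heavily trimmed-down form, why the change
// above prefers a constructor: the constructor initialises internal maps such
// as backoffTimers, whereas a struct literal leaves them nil.
type stats struct {
	failuresUntilBlacklist uint32
	backoffTimers          map[string]struct{} // stand-in for the real timer map
}

func newStats(failuresUntilBlacklist uint32) stats {
	return stats{
		failuresUntilBlacklist: failuresUntilBlacklist,
		backoffTimers:          make(map[string]struct{}),
	}
}

func main() {
	s := newStats(16)
	s.backoffTimers["remotehost"] = struct{}{} // safe: the map was initialised
	fmt.Println(len(s.backoffTimers))

	var zero stats
	fmt.Println(zero.backoffTimers == nil) // true: writing to this map would panic
}
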
diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go
index 5cb8cae1..00e02b2d 100644
--- a/federationapi/queue/destinationqueue.go
+++ b/federationapi/queue/destinationqueue.go
@@ -35,7 +35,7 @@ import (
const (
maxPDUsPerTransaction = 50
- maxEDUsPerTransaction = 50
+ maxEDUsPerTransaction = 100
maxPDUsInMemory = 128
maxEDUsInMemory = 128
queueIdleTimeout = time.Second * 30
@@ -64,7 +64,6 @@ type destinationQueue struct {
pendingPDUs []*queuedPDU // PDUs waiting to be sent
pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
- interruptBackoff chan bool // interrupts backoff
}
// Send event adds the event to the pending queue for the destination.
@@ -75,6 +74,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return
}
+
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
@@ -102,12 +102,12 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
- // Wake up the queue if it's asleep.
- oq.wakeQueueIfNeeded()
- select {
- case oq.notify <- struct{}{}:
- default:
+
+ if !oq.backingOff.Load() {
+ oq.wakeQueueAndNotify()
}
+ } else {
+ oq.overflowed.Store(true)
}
}
@@ -147,12 +147,37 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
- // Wake up the queue if it's asleep.
- oq.wakeQueueIfNeeded()
- select {
- case oq.notify <- struct{}{}:
- default:
+
+ if !oq.backingOff.Load() {
+ oq.wakeQueueAndNotify()
}
+ } else {
+ oq.overflowed.Store(true)
+ }
+}
+
+// handleBackoffNotifier is registered as the backoff notification
+// callback with Statistics. It will wake up and notify the queue
+// if the queue is currently backing off.
+func (oq *destinationQueue) handleBackoffNotifier() {
+ // Only wake up the queue if it is backing off.
+ // Otherwise there is no pending work for the queue to handle
+ // so waking the queue would be a waste of resources.
+ if oq.backingOff.Load() {
+ oq.wakeQueueAndNotify()
+ }
+}
+
+// wakeQueueAndNotify ensures the destination queue is running and notifies it
+// that there is pending work.
+func (oq *destinationQueue) wakeQueueAndNotify() {
+ // Wake up the queue if it's asleep.
+ oq.wakeQueueIfNeeded()
+
+ // Notify the queue that there are events ready to send.
+ select {
+ case oq.notify <- struct{}{}:
+ default:
}
}
@@ -161,10 +186,11 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
// then we will interrupt the backoff, causing any federation
// requests to retry.
func (oq *destinationQueue) wakeQueueIfNeeded() {
- // If we are backing off then interrupt the backoff.
+ // Clear the backingOff flag and update the backoff metrics if it was set.
if oq.backingOff.CompareAndSwap(true, false) {
- oq.interruptBackoff <- true
+ destinationQueueBackingOff.Dec()
}
+
// If we aren't running then wake up the queue.
if !oq.running.Load() {
// Start the queue.
@@ -196,38 +222,54 @@ func (oq *destinationQueue) getPendingFromDatabase() {
gotEDUs[edu.receipt.String()] = struct{}{}
}
+ overflowed := false
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
// We have room in memory for some PDUs - let's request no more than that.
- if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil {
+ if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, maxPDUsInMemory); err == nil {
+ if len(pdus) == maxPDUsInMemory {
+ overflowed = true
+ }
for receipt, pdu := range pdus {
if _, ok := gotPDUs[receipt.String()]; ok {
continue
}
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
retrieved = true
+ if len(oq.pendingPDUs) == maxPDUsInMemory {
+ break
+ }
}
} else {
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
}
}
+
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
// We have room in memory for some EDUs - let's request no more than that.
- if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil {
+ if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, maxEDUsInMemory); err == nil {
+ if len(edus) == maxEDUsInMemory {
+ overflowed = true
+ }
for receipt, edu := range edus {
if _, ok := gotEDUs[receipt.String()]; ok {
continue
}
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
retrieved = true
+ if len(oq.pendingEDUs) == maxEDUsInMemory {
+ break
+ }
}
} else {
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
}
}
+
// If we've retrieved all of the events from the database with room to spare
// in memory then we'll no longer consider this queue to be overflowed.
- if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory {
+ if !overflowed {
oq.overflowed.Store(false)
+ } else {
}
// If we've retrieved some events then notify the destination queue goroutine.
if retrieved {
@@ -238,6 +280,24 @@ func (oq *destinationQueue) getPendingFromDatabase() {
}
}
+// checkNotificationsOnClose checks for any remaining notifications
+// and starts a new backgroundSend goroutine if any exist.
+func (oq *destinationQueue) checkNotificationsOnClose() {
+ // NOTE : If we are stopping the queue due to blacklist then it
+ // doesn't matter if we have been notified of new work since
+ // this queue instance will be deleted anyway.
+ if !oq.statistics.Blacklisted() {
+ select {
+ case <-oq.notify:
+ // We received a new notification in between the
+ // idle timeout firing and stopping the goroutine.
+ // Immediately restart the queue.
+ oq.wakeQueueAndNotify()
+ default:
+ }
+ }
+}
+
// backgroundSend is the worker goroutine for sending events.
func (oq *destinationQueue) backgroundSend() {
// Check if a worker is already running, and if it isn't, then
@@ -245,10 +305,17 @@ func (oq *destinationQueue) backgroundSend() {
if !oq.running.CompareAndSwap(false, true) {
return
}
+
+ // Register queue cleanup functions.
+ // NOTE : The ordering here is very intentional.
+ defer oq.checkNotificationsOnClose()
+ defer oq.running.Store(false)
+
destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec()
- defer oq.queues.clearQueue(oq)
- defer oq.running.Store(false)
+
+ idleTimeout := time.NewTimer(queueIdleTimeout)
+ defer idleTimeout.Stop()
// Mark the queue as overflowed, so we will consult the database
// to see if there's anything new to send.
@@ -261,59 +328,33 @@ func (oq *destinationQueue) backgroundSend() {
oq.getPendingFromDatabase()
}
+ // Reset the queue idle timeout.
+ if !idleTimeout.Stop() {
+ select {
+ case <-idleTimeout.C:
+ default:
+ }
+ }
+ idleTimeout.Reset(queueIdleTimeout)
+
// If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout.
select {
case <-oq.notify:
// There's work to do, either because getPendingFromDatabase
- // told us there is, or because a new event has come in via
- // sendEvent/sendEDU.
- case <-time.After(queueIdleTimeout):
+ // told us there is, a new event has come in via sendEvent/sendEDU,
+ // or we are backing off and it is time to retry.
+ case <-idleTimeout.C:
// The worker is idle so stop the goroutine. It'll get
// restarted automatically the next time we have an event to
// send.
return
case <-oq.process.Context().Done():
// The parent process is shutting down, so stop.
+ oq.statistics.ClearBackoff()
return
}
- // If we are backing off this server then wait for the
- // backoff duration to complete first, or until explicitly
- // told to retry.
- until, blacklisted := oq.statistics.BackoffInfo()
- if blacklisted {
- // It's been suggested that we should give up because the backoff
- // has exceeded a maximum allowable value. Clean up the in-memory
- // buffers at this point. The PDU clean-up is already on a defer.
- logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
- oq.pendingMutex.Lock()
- for i := range oq.pendingPDUs {
- oq.pendingPDUs[i] = nil
- }
- for i := range oq.pendingEDUs {
- oq.pendingEDUs[i] = nil
- }
- oq.pendingPDUs = nil
- oq.pendingEDUs = nil
- oq.pendingMutex.Unlock()
- return
- }
- if until != nil && until.After(time.Now()) {
- // We haven't backed off yet, so wait for the suggested amount of
- // time.
- duration := time.Until(*until)
- logrus.Debugf("Backing off %q for %s", oq.destination, duration)
- oq.backingOff.Store(true)
- destinationQueueBackingOff.Inc()
- select {
- case <-time.After(duration):
- case <-oq.interruptBackoff:
- }
- destinationQueueBackingOff.Dec()
- oq.backingOff.Store(false)
- }
-
// Work out which PDUs/EDUs to include in the next transaction.
oq.pendingMutex.RLock()
pduCount := len(oq.pendingPDUs)
@@ -328,38 +369,97 @@ func (oq *destinationQueue) backgroundSend() {
toSendEDUs := oq.pendingEDUs[:eduCount]
oq.pendingMutex.RUnlock()
+ // If we didn't get anything from the database and there are no
+ // pending EDUs then there's nothing to do - stop here.
+ if pduCount == 0 && eduCount == 0 {
+ continue
+ }
+
// If we have pending PDUs or EDUs then construct a transaction.
// Try sending the next transaction and see what happens.
- transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
+ terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
if terr != nil {
// We failed to send the transaction. Mark it as a failure.
- oq.statistics.Failure()
-
- } else if transaction {
- // If we successfully sent the transaction then clear out
- // the pending events and EDUs, and wipe our transaction ID.
- oq.statistics.Success()
- oq.pendingMutex.Lock()
- for i := range oq.pendingPDUs[:pc] {
- oq.pendingPDUs[i] = nil
+ _, blacklisted := oq.statistics.Failure()
+ if !blacklisted {
+ // Register the backoff state and exit the goroutine.
+ // It'll get restarted automatically when the backoff
+ // completes.
+ oq.backingOff.Store(true)
+ destinationQueueBackingOff.Inc()
+ return
+ } else {
+ // Immediately trigger the blacklist logic.
+ oq.blacklistDestination()
+ return
}
- for i := range oq.pendingEDUs[:ec] {
- oq.pendingEDUs[i] = nil
- }
- oq.pendingPDUs = oq.pendingPDUs[pc:]
- oq.pendingEDUs = oq.pendingEDUs[ec:]
- oq.pendingMutex.Unlock()
+ } else {
+ oq.handleTransactionSuccess(pduCount, eduCount)
}
}
}
// nextTransaction creates a new transaction from the pending event
-// queue and sends it. Returns true if a transaction was sent or
-// false otherwise.
+// queue and sends it.
+// Returns an error if the transaction wasn't sent.
func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
-) (bool, int, int, error) {
+) error {
+ // Create the transaction.
+ t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus)
+ logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
+
+ // Try to send the transaction to the destination server.
+ ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
+ defer cancel()
+ _, err := oq.client.SendTransaction(ctx, t)
+ switch errResponse := err.(type) {
+ case nil:
+ // Clean up the transaction in the database.
+ if pduReceipts != nil {
+ //logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
+ if err = oq.db.CleanPDUs(oq.process.Context(), oq.destination, pduReceipts); err != nil {
+ logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
+ }
+ }
+ if eduReceipts != nil {
+ //logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
+ if err = oq.db.CleanEDUs(oq.process.Context(), oq.destination, eduReceipts); err != nil {
+ logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
+ }
+ }
+ // Reset the transaction ID.
+ oq.transactionIDMutex.Lock()
+ oq.transactionID = ""
+ oq.transactionIDMutex.Unlock()
+ return nil
+ case gomatrix.HTTPError:
+ // Report that we failed to send the transaction and we
+ // will retry again, subject to backoff.
+
+ // TODO: we should check for 500-ish fails vs 400-ish here,
+ // since we shouldn't queue things indefinitely in response
+ // to a 400-ish error
+ code := errResponse.Code
+ logrus.Debug("Transaction failed with HTTP", code)
+ return err
+ default:
+ logrus.WithFields(logrus.Fields{
+ "destination": oq.destination,
+ logrus.ErrorKey: err,
+ }).Debugf("Failed to send transaction %q", t.TransactionID)
+ return err
+ }
+}
+
+// createTransaction generates a gomatrixserverlib.Transaction from the provided pdus and edus.
+// It also returns the associated event receipts so they can be cleaned from the database in
+// the case of a successful transaction.
+func (oq *destinationQueue) createTransaction(
+ pdus []*queuedPDU,
+ edus []*queuedEDU,
+) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
@@ -371,7 +471,6 @@ func (oq *destinationQueue) nextTransaction(
}
oq.transactionIDMutex.Unlock()
- // Create the transaction.
t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
@@ -381,18 +480,13 @@ func (oq *destinationQueue) nextTransaction(
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
- // If we didn't get anything from the database and there are no
- // pending EDUs then there's nothing to do - stop here.
- if len(pdus) == 0 && len(edus) == 0 {
- return false, 0, 0, nil
- }
-
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
+ // These should never be nil.
if pdu == nil || pdu.pdu == nil {
continue
}
@@ -404,6 +498,7 @@ func (oq *destinationQueue) nextTransaction(
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
+ // These should never be nil.
if edu == nil || edu.edu == nil {
continue
}
@@ -411,44 +506,55 @@ func (oq *destinationQueue) nextTransaction(
eduReceipts = append(eduReceipts, edu.receipt)
}
- logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
+ return t, pduReceipts, eduReceipts
+}
- // Try to send the transaction to the destination server.
- // TODO: we should check for 500-ish fails vs 400-ish here,
- // since we shouldn't queue things indefinitely in response
- // to a 400-ish error
- ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
- defer cancel()
- _, err := oq.client.SendTransaction(ctx, t)
- switch err.(type) {
- case nil:
- // Clean up the transaction in the database.
- if pduReceipts != nil {
- //logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
- if err = oq.db.CleanPDUs(oq.process.Context(), oq.destination, pduReceipts); err != nil {
- logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
- }
- }
- if eduReceipts != nil {
- //logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
- if err = oq.db.CleanEDUs(oq.process.Context(), oq.destination, eduReceipts); err != nil {
- logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
- }
+// blacklistDestination removes all pending PDUs and EDUs that have been cached
+// and deletes this queue.
+func (oq *destinationQueue) blacklistDestination() {
+ // It's been suggested that we should give up because the backoff
+ // has exceeded a maximum allowable value. Clean up the in-memory
+ // buffers at this point. The PDU clean-up is already on a defer.
+ logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
+
+ oq.pendingMutex.Lock()
+ for i := range oq.pendingPDUs {
+ oq.pendingPDUs[i] = nil
+ }
+ for i := range oq.pendingEDUs {
+ oq.pendingEDUs[i] = nil
+ }
+ oq.pendingPDUs = nil
+ oq.pendingEDUs = nil
+ oq.pendingMutex.Unlock()
+
+ // Delete this queue as no more messages will be sent to this
+ // destination until it is no longer blacklisted.
+ oq.statistics.AssignBackoffNotifier(nil)
+ oq.queues.clearQueue(oq)
+}
+
+// handleTransactionSuccess updates the cached event queues as well as the success and
+// backoff information for this server.
+func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) {
+ // If we successfully sent the transaction then clear out
+ // the pending events and EDUs, and wipe our transaction ID.
+ oq.statistics.Success()
+ oq.pendingMutex.Lock()
+ for i := range oq.pendingPDUs[:pduCount] {
+ oq.pendingPDUs[i] = nil
+ }
+ for i := range oq.pendingEDUs[:eduCount] {
+ oq.pendingEDUs[i] = nil
+ }
+ oq.pendingPDUs = oq.pendingPDUs[pduCount:]
+ oq.pendingEDUs = oq.pendingEDUs[eduCount:]
+ oq.pendingMutex.Unlock()
+
+ if len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 {
+ select {
+ case oq.notify <- struct{}{}:
+ default:
}
- // Reset the transaction ID.
- oq.transactionIDMutex.Lock()
- oq.transactionID = ""
- oq.transactionIDMutex.Unlock()
- return true, len(t.PDUs), len(t.EDUs), nil
- case gomatrix.HTTPError:
- // Report that we failed to send the transaction and we
- // will retry again, subject to backoff.
- return false, 0, 0, err
- default:
- logrus.WithFields(logrus.Fields{
- "destination": oq.destination,
- logrus.ErrorKey: err,
- }).Debugf("Failed to send transaction %q", t.TransactionID)
- return false, 0, 0, err
}
}
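
In the backgroundSend loop above, the one-shot time.After idle check is replaced by a reusable time.Timer that is stopped, drained and re-armed on each pass, and the in-loop backoff wait is removed in favour of exiting the goroutine and letting the backoff notifier restart it. The following standalone sketch (hypothetical names, not Dendrite code) shows the stop/drain/reset timer idiom on its own:

package main

import (
	"fmt"
	"time"
)

// resetIdleTimer re-arms the timer before the next select. If the timer has
// already fired but its value was never read, the channel is drained first so
// that a stale tick cannot trigger a spurious idle timeout.
func resetIdleTimer(t *time.Timer, d time.Duration) {
	if !t.Stop() {
		select {
		case <-t.C:
		default:
		}
	}
	t.Reset(d)
}

func main() {
	idle := time.NewTimer(50 * time.Millisecond)
	defer idle.Stop()

	work := make(chan struct{}, 1)
	work <- struct{}{}

	for {
		resetIdleTimer(idle, 50*time.Millisecond)
		select {
		case <-work:
			fmt.Println("processed pending work")
		case <-idle.C:
			fmt.Println("idle timeout, worker exits")
			return
		}
	}
}
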
diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go
index 8245aa5b..68f789e3 100644
--- a/federationapi/queue/queue.go
+++ b/federationapi/queue/queue.go
@@ -162,23 +162,25 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
if !ok || oq == nil {
destinationQueueTotal.Inc()
oq = &destinationQueue{
- queues: oqs,
- db: oqs.db,
- process: oqs.process,
- rsAPI: oqs.rsAPI,
- origin: oqs.origin,
- destination: destination,
- client: oqs.client,
- statistics: oqs.statistics.ForServer(destination),
- notify: make(chan struct{}, 1),
- interruptBackoff: make(chan bool),
- signing: oqs.signing,
+ queues: oqs,
+ db: oqs.db,
+ process: oqs.process,
+ rsAPI: oqs.rsAPI,
+ origin: oqs.origin,
+ destination: destination,
+ client: oqs.client,
+ statistics: oqs.statistics.ForServer(destination),
+ notify: make(chan struct{}, 1),
+ signing: oqs.signing,
}
+ oq.statistics.AssignBackoffNotifier(oq.handleBackoffNotifier)
oqs.queues[destination] = oq
}
return oq
}
+// clearQueue removes the queue for the provided destination from the
+// set of destination queues.
func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock()
@@ -332,7 +334,9 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled {
return
}
+ oqs.statistics.ForServer(srv).RemoveBlacklist()
if queue := oqs.getQueue(srv); queue != nil {
+ queue.statistics.ClearBackoff()
queue.wakeQueueIfNeeded()
}
}
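
With the interruptBackoff channel gone, getQueue now wires each queue into Statistics via AssignBackoffNotifier, and RetryServer clears the blacklist and any stored backoff before waking the queue. The sketch below (hypothetical, heavily simplified types) illustrates the wake pattern these changes rely on: a one-slot notify channel that coalesces notifications plus a CompareAndSwap guard so at most one worker goroutine is started at a time:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// queue is a stripped-down illustration of the wake/notify pattern used by
// destinationQueue: a buffered notify channel of size one plus an atomic
// running flag toggled with CompareAndSwap.
type queue struct {
	running atomic.Bool
	notify  chan struct{}
}

func (q *queue) wakeAndNotify() {
	// Start the worker only if it isn't already running.
	if q.running.CompareAndSwap(false, true) {
		go q.worker()
	}
	// Non-blocking send: repeated notifications coalesce into one.
	select {
	case q.notify <- struct{}{}:
	default:
	}
}

func (q *queue) worker() {
	defer q.running.Store(false)
	for {
		select {
		case <-q.notify:
			fmt.Println("worker: handling pending work")
		case <-time.After(100 * time.Millisecond):
			fmt.Println("worker: idle, exiting")
			return
		}
	}
}

func main() {
	q := &queue{notify: make(chan struct{}, 1)}
	q.wakeAndNotify()
	q.wakeAndNotify() // coalesced into the single buffered notification
	time.Sleep(300 * time.Millisecond)
}
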
diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go
new file mode 100644
index 00000000..6da9e6b3
--- /dev/null
+++ b/federationapi/queue/queue_test.go
@@ -0,0 +1,1047 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package queue
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ "go.uber.org/atomic"
+ "gotest.tools/v3/poll"
+
+ "github.com/matrix-org/dendrite/federationapi/api"
+ "github.com/matrix-org/dendrite/federationapi/statistics"
+ "github.com/matrix-org/dendrite/federationapi/storage"
+ "github.com/matrix-org/dendrite/federationapi/storage/shared"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/pkg/errors"
+ "github.com/stretchr/testify/assert"
+)
+
+func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) {
+ if realDatabase {
+ // Real Database/s
+ b, baseClose := testrig.CreateBaseDendrite(t, dbType)
+ connStr, dbClose := test.PrepareDBConnectionString(t, dbType)
+ db, err := storage.NewDatabase(b, &config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, b.Caches, b.Cfg.Global.ServerName)
+ if err != nil {
+ t.Fatalf("NewDatabase returned %s", err)
+ }
+ return db, b.ProcessContext, func() {
+ dbClose()
+ baseClose()
+ }
+ } else {
+ // Fake Database
+ db := createDatabase()
+ b := struct {
+ ProcessContext *process.ProcessContext
+ }{ProcessContext: process.NewProcessContext()}
+ return db, b.ProcessContext, func() {}
+ }
+}
+
+func createDatabase() storage.Database {
+ return &fakeDatabase{
+ pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
+ pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
+ blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}),
+ pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent),
+ pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU),
+ associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
+ associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
+ }
+}
+
+type fakeDatabase struct {
+ storage.Database
+ dbMutex sync.Mutex
+ pendingPDUServers map[gomatrixserverlib.ServerName]struct{}
+ pendingEDUServers map[gomatrixserverlib.ServerName]struct{}
+ blacklistedServers map[gomatrixserverlib.ServerName]struct{}
+ pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent
+ pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU
+ associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
+ associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
+}
+
+var nidMutex sync.Mutex
+var nid = int64(0)
+
+func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ var event gomatrixserverlib.HeaderedEvent
+ if err := json.Unmarshal([]byte(js), &event); err == nil {
+ nidMutex.Lock()
+ defer nidMutex.Unlock()
+ nid++
+ receipt := shared.NewReceipt(nid)
+ d.pendingPDUs[&receipt] = &event
+ return &receipt, nil
+ }
+
+ var edu gomatrixserverlib.EDU
+ if err := json.Unmarshal([]byte(js), &edu); err == nil {
+ nidMutex.Lock()
+ defer nidMutex.Unlock()
+ nid++
+ receipt := shared.NewReceipt(nid)
+ d.pendingEDUs[&receipt] = &edu
+ return &receipt, nil
+ }
+
+ return nil, errors.New("Failed to determine type of json to store")
+}
+
+func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ pduCount := 0
+ pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent)
+ if receipts, ok := d.associatedPDUs[serverName]; ok {
+ for receipt := range receipts {
+ if event, ok := d.pendingPDUs[receipt]; ok {
+ pdus[receipt] = event
+ pduCount++
+ if pduCount == limit {
+ break
+ }
+ }
+ }
+ }
+ return pdus, nil
+}
+
+func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ eduCount := 0
+ edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU)
+ if receipts, ok := d.associatedEDUs[serverName]; ok {
+ for receipt := range receipts {
+ if event, ok := d.pendingEDUs[receipt]; ok {
+ edus[receipt] = event
+ eduCount++
+ if eduCount == limit {
+ break
+ }
+ }
+ }
+ }
+ return edus, nil
+}
+
+func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ if _, ok := d.pendingPDUs[receipt]; ok {
+ if _, ok := d.associatedPDUs[serverName]; !ok {
+ d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{})
+ }
+ d.associatedPDUs[serverName][receipt] = struct{}{}
+ return nil
+ } else {
+ return errors.New("PDU doesn't exist")
+ }
+}
+
+func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ if _, ok := d.pendingEDUs[receipt]; ok {
+ if _, ok := d.associatedEDUs[serverName]; !ok {
+ d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{})
+ }
+ d.associatedEDUs[serverName][receipt] = struct{}{}
+ return nil
+ } else {
+ return errors.New("EDU doesn't exist")
+ }
+}
+
+func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ if pdus, ok := d.associatedPDUs[serverName]; ok {
+ for _, receipt := range receipts {
+ delete(pdus, receipt)
+ }
+ }
+
+ return nil
+}
+
+func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ if edus, ok := d.associatedEDUs[serverName]; ok {
+ for _, receipt := range receipts {
+ delete(edus, receipt)
+ }
+ }
+
+ return nil
+}
+
+func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ var count int64
+ if pdus, ok := d.associatedPDUs[serverName]; ok {
+ count = int64(len(pdus))
+ }
+ return count, nil
+}
+
+func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ var count int64
+ if edus, ok := d.associatedEDUs[serverName]; ok {
+ count = int64(len(edus))
+ }
+ return count, nil
+}
+
+func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ servers := []gomatrixserverlib.ServerName{}
+ for server := range d.pendingPDUServers {
+ servers = append(servers, server)
+ }
+ return servers, nil
+}
+
+func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ servers := []gomatrixserverlib.ServerName{}
+ for server := range d.pendingEDUServers {
+ servers = append(servers, server)
+ }
+ return servers, nil
+}
+
+func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ d.blacklistedServers[serverName] = struct{}{}
+ return nil
+}
+
+func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ delete(d.blacklistedServers, serverName)
+ return nil
+}
+
+func (d *fakeDatabase) RemoveAllServersFromBlacklist() error {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{})
+ return nil
+}
+
+func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
+ d.dbMutex.Lock()
+ defer d.dbMutex.Unlock()
+
+ isBlacklisted := false
+ if _, ok := d.blacklistedServers[serverName]; ok {
+ isBlacklisted = true
+ }
+
+ return isBlacklisted, nil
+}
+
+type stubFederationRoomServerAPI struct {
+ rsapi.FederationRoomserverAPI
+}
+
+func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error {
+ res.Banned = false
+ return nil
+}
+
+type stubFederationClient struct {
+ api.FederationClient
+ shouldTxSucceed bool
+ txCount atomic.Uint32
+}
+
+func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
+ var result error
+ if !f.shouldTxSucceed {
+ result = fmt.Errorf("transaction failed")
+ }
+
+ f.txCount.Add(1)
+ return gomatrixserverlib.RespSend{}, result
+}
+
+func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent {
+ t.Helper()
+ content := `{"type":"m.room.message"}`
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10)
+ if err != nil {
+ t.Fatalf("failed to create event: %v", err)
+ }
+ return ev.Headered(gomatrixserverlib.RoomVersionV10)
+}
+
+func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU {
+ t.Helper()
+ return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}
+}
+
+func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) {
+ db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase)
+
+ fc := &stubFederationClient{
+ shouldTxSucceed: shouldTxSucceed,
+ txCount: *atomic.NewUint32(0),
+ }
+ rs := &stubFederationRoomServerAPI{}
+ stats := statistics.NewStatistics(db, failuresUntilBlacklist)
+ signingInfo := &SigningInfo{
+ KeyID: "ed21019:auto",
+ PrivateKey: test.PrivateKeyA,
+ ServerName: "localhost",
+ }
+ queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)
+
+ return db, fc, queues, processContext, close
+}
+
+func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == 1 {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == 1 {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUOnFailStoredInDB(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
+ if fc.txCount.Load() >= 2 {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUOnFailStoredInDB(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
+ if fc.txCount.Load() >= 2 {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
+ if fc.txCount.Load() >= 2 {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ ev = mustCreatePDU(t)
+ err = queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ pollEnd := time.Now().Add(1 * time.Second)
+ immediateCheck := func(log poll.LogT) poll.Result {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Error(fmt.Errorf("The backoff was interrupted early"))
+ }
+ if time.Now().After(pollEnd) {
+ // Allow more than enough time for the backoff to be interrupted before
+ // reporting that it wasn't.
+ return poll.Success()
+ }
+ return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d", len(data))
+ }
+ poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
+ if fc.txCount.Load() >= 2 {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ ev = mustCreateEDU(t)
+ err = queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ pollEnd := time.Now().Add(1 * time.Second)
+ immediateCheck := func(log poll.LogT) poll.Result {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Error(fmt.Errorf("The backoff was interrupted early"))
+ }
+ if time.Now().After(pollEnd) {
+ // Allow more than enough time for the backoff to be interrupted before
+ // reporting that it wasn't.
+ return poll.Success()
+ }
+ return poll.Continue("waiting for events to be removed from database. Currently present EDU: %d", len(data))
+ }
+ poll.WaitOn(t, immediateCheck, poll.WithTimeout(2*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(2)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(2)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(2)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ queues.statistics.ForServer(destination).Failure()
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(2)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ queues.statistics.ForServer(destination).Failure()
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestRetryServerSendsPDUSuccessfully(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(1)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // NOTE : getQueue before sending event to ensure we grab the same queue reference
+ // before it is blacklisted and deleted.
+ dest := queues.getQueue(destination)
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ checkBlacklisted := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ if !dest.running.Load() {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for queue to stop completely")
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ db.RemoveServerFromBlacklist(destination)
+ queues.RetryServer(destination)
+ checkRetry := func(log poll.LogT) poll.Result {
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data))
+ }
+ poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestRetryServerSendsEDUSuccessfully(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(1)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // NOTE : getQueue before sending event to ensure we grab the same queue reference
+ // before it is blacklisted and deleted.
+ dest := queues.getQueue(destination)
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ checkBlacklisted := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ if !dest.running.Load() {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for queue to stop completely")
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ db.RemoveServerFromBlacklist(destination)
+ queues.RetryServer(destination)
+ checkRetry := func(log poll.LogT) poll.Result {
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data))
+ }
+ poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestSendPDUBatches(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+
+ // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // Populate database with > maxPDUsPerTransaction
+ pduMultiplier := uint32(3)
+ for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ {
+ ev := mustCreatePDU(t)
+ headeredJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
+ now := gomatrixserverlib.AsTimestamp(time.Now())
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i))
+ db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid)
+ }
+
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == pduMultiplier+1 { // +1 for the extra SendEvent()
+ data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d", len(data))
+ }
+ return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+ // })
+}
+
+func TestSendEDUBatches(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+
+ // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // Populate database with > maxEDUsPerTransaction
+ eduMultiplier := uint32(3)
+ for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ {
+ ev := mustCreateEDU(t)
+ ephemeralJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
+ db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil)
+ }
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == eduMultiplier+1 { // +1 for the extra SendEDU()
+ data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErr)
+ if len(data) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for all events to be removed from database. Currently present EDU: %d", len(data))
+ }
+ return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+ // })
+}
+
+func TestSendPDUAndEDUBatches(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+
+ // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ // Populate database with > maxPDUsPerTransaction PDUs and > maxEDUsPerTransaction EDUs
+ multiplier := uint32(3)
+
+ for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ {
+ ev := mustCreatePDU(t)
+ headeredJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
+ now := gomatrixserverlib.AsTimestamp(time.Now())
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i))
+ db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid)
+ }
+
+ for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ {
+ ev := mustCreateEDU(t)
+ ephemeralJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
+ db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil)
+ }
+
+ ev := mustCreateEDU(t)
+ err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ check := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == multiplier+1 { // +1 for the extra SendEDU()
+ pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrPDU)
+ eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrEDU)
+ if len(pduData) == 0 && len(eduData) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for all events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData))
+ }
+ return poll.Continue("waiting for the right amount of send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
+ // })
+}
+
+func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) {
+ t.Parallel()
+ failuresUntilBlacklist := uint32(16)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+ dest := queues.getQueue(destination)
+ queues.statistics.ForServer(destination).Failure()
+
+ ev := mustCreatePDU(t)
+ headeredJSON, _ := json.Marshal(ev)
+ nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
+ now := gomatrixserverlib.AsTimestamp(time.Now())
+ transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, 1))
+ db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid)
+
+ pollEnd := time.Now().Add(3 * time.Second)
+ runningCheck := func(log poll.LogT) poll.Result {
+ if dest.running.Load() || fc.txCount.Load() > 0 {
+ return poll.Error(fmt.Errorf("The queue was started"))
+ }
+ if time.Now().After(pollEnd) {
+ // Allow more than enough time for the queue to be started in the case
+ // of backoff triggering it to start.
+ return poll.Success()
+ }
+ return poll.Continue("waiting to ensure queue doesn't start.")
+ }
+ poll.WaitOn(t, runningCheck, poll.WithTimeout(4*time.Second), poll.WithDelay(100*time.Millisecond))
+}
+
+func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
+ // NOTE : Only one test case against real databases can be run at a time.
+ t.Parallel()
+ failuresUntilBlacklist := uint32(1)
+ destination := gomatrixserverlib.ServerName("remotehost")
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true)
+ // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up.
+ defer close()
+ defer func() {
+ pc.ShutdownDendrite()
+ <-pc.WaitForShutdown()
+ }()
+
+		// NOTE: Call getQueue before sending the event to ensure we grab the same
+		// queue reference before it is blacklisted and deleted.
+ dest := queues.getQueue(destination)
+ ev := mustCreatePDU(t)
+ err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, err)
+
+ edu := mustCreateEDU(t)
+ errEDU := queues.SendEDU(edu, "localhost", []gomatrixserverlib.ServerName{destination})
+ assert.NoError(t, errEDU)
+
+ checkBlacklisted := func(log poll.LogT) poll.Result {
+ if fc.txCount.Load() == failuresUntilBlacklist {
+ pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrPDU)
+ eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrEDU)
+ if len(pduData) == 1 && len(eduData) == 1 {
+ if val, _ := db.IsServerBlacklisted(destination); val {
+ if !dest.running.Load() {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for queue to stop completely")
+ }
+ return poll.Continue("waiting for server to be blacklisted")
+ }
+ return poll.Continue("waiting for events to be added to database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData))
+ }
+ return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
+ }
+ poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond))
+
+ fc.shouldTxSucceed = true
+ db.RemoveServerFromBlacklist(destination)
+ queues.RetryServer(destination)
+ checkRetry := func(log poll.LogT) poll.Result {
+ pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrPDU)
+ eduData, dbErrEDU := db.GetPendingEDUs(pc.Context(), destination, 200)
+ assert.NoError(t, dbErrEDU)
+ if len(pduData) == 0 && len(eduData) == 0 {
+ return poll.Success()
+ }
+ return poll.Continue("waiting for events to be removed from database. Currently present PDU: %d EDU: %d", len(pduData), len(eduData))
+ }
+ poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond))
+ })
+}
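
The new tests above all follow the same gotest.tools/v3 polling pattern: a check closure returns poll.Success(), poll.Continue(...) or poll.Error(...), and poll.WaitOn re-runs it until it succeeds, errors, or times out. The following stand-alone sketch (not part of the patch; the counter condition is purely illustrative) shows the shape of that pattern in isolation:

package example

import (
	"testing"
	"time"

	"gotest.tools/v3/poll"
)

// TestPollPattern illustrates the polling style used by the new queue tests.
// The attempts counter stands in for inspecting fc.txCount or the pending
// PDU/EDU rows in the database.
func TestPollPattern(t *testing.T) {
	attempts := 0
	check := func(log poll.LogT) poll.Result {
		attempts++
		if attempts >= 3 {
			return poll.Success()
		}
		return poll.Continue("waiting for 3 attempts, currently %d", attempts)
	}
	// Re-run check every 100ms until it succeeds or 5 seconds elapse.
	poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}
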
diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go
index db6d5c73..2ba99112 100644
--- a/federationapi/statistics/statistics.go
+++ b/federationapi/statistics/statistics.go
@@ -2,6 +2,7 @@ package statistics
import (
"math"
+ "math/rand"
"sync"
"time"
@@ -20,12 +21,23 @@ type Statistics struct {
servers map[gomatrixserverlib.ServerName]*ServerStatistics
mutex sync.RWMutex
+ backoffTimers map[gomatrixserverlib.ServerName]*time.Timer
+ backoffMutex sync.RWMutex
+
// How many times should we tolerate consecutive failures before we
// just blacklist the host altogether? The backoff is exponential,
// so the max time here to attempt is 2**failures seconds.
FailuresUntilBlacklist uint32
}
+func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics {
+ return Statistics{
+ DB: db,
+ FailuresUntilBlacklist: failuresUntilBlacklist,
+ backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer),
+ }
+}
+
// ForServer returns server statistics for the given server name. If it
// does not exist, it will create empty statistics and return those.
func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics {
@@ -45,7 +57,6 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
server = &ServerStatistics{
statistics: s,
serverName: serverName,
- interrupt: make(chan struct{}),
}
s.servers[serverName] = server
s.mutex.Unlock()
@@ -64,29 +75,43 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
// many times we failed etc. It also manages the backoff time and black-
// listing a remote host if it remains uncooperative.
type ServerStatistics struct {
- statistics *Statistics //
- serverName gomatrixserverlib.ServerName //
- blacklisted atomic.Bool // is the node blacklisted
- backoffStarted atomic.Bool // is the backoff started
- backoffUntil atomic.Value // time.Time until this backoff interval ends
- backoffCount atomic.Uint32 // number of times BackoffDuration has been called
- interrupt chan struct{} // interrupts the backoff goroutine
- successCounter atomic.Uint32 // how many times have we succeeded?
+ statistics *Statistics //
+ serverName gomatrixserverlib.ServerName //
+ blacklisted atomic.Bool // is the node blacklisted
+ backoffStarted atomic.Bool // is the backoff started
+ backoffUntil atomic.Value // time.Time until this backoff interval ends
+ backoffCount atomic.Uint32 // number of times BackoffDuration has been called
+ successCounter atomic.Uint32 // how many times have we succeeded?
+ backoffNotifier func() // notifies destination queue when backoff completes
+ notifierMutex sync.Mutex
}
+const maxJitterMultiplier = 1.4
+const minJitterMultiplier = 0.8
+
// duration returns how long the next backoff interval should be.
func (s *ServerStatistics) duration(count uint32) time.Duration {
- return time.Second * time.Duration(math.Exp2(float64(count)))
+ // Add some jitter to minimise the chance of having multiple backoffs
+ // ending at the same time.
+ jitter := rand.Float64()*(maxJitterMultiplier-minJitterMultiplier) + minJitterMultiplier
+ duration := time.Millisecond * time.Duration(math.Exp2(float64(count))*jitter*1000)
+ return duration
}
// cancel will interrupt the currently active backoff.
func (s *ServerStatistics) cancel() {
s.blacklisted.Store(false)
s.backoffUntil.Store(time.Time{})
- select {
- case s.interrupt <- struct{}{}:
- default:
- }
+
+ s.ClearBackoff()
+}
+
+// AssignBackoffNotifier sets the callback that will be invoked when
+// a backoff completes.
+func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) {
+ s.notifierMutex.Lock()
+ defer s.notifierMutex.Unlock()
+ s.backoffNotifier = notifier
}
// Success updates the server statistics with a new successful
@@ -95,8 +120,8 @@ func (s *ServerStatistics) cancel() {
// we will unblacklist it.
func (s *ServerStatistics) Success() {
s.cancel()
- s.successCounter.Inc()
s.backoffCount.Store(0)
+ s.successCounter.Inc()
if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
@@ -105,13 +130,17 @@ func (s *ServerStatistics) Success() {
}
// Failure marks a failure and starts backing off if needed.
-// The next call to BackoffIfRequired will do the right thing
-// after this. It will return the time that the current failure
+// It will return the time that the current failure
// will result in backoff waiting until, and a bool signalling
// whether we have blacklisted and therefore to give up.
func (s *ServerStatistics) Failure() (time.Time, bool) {
+ // Return immediately if we have blacklisted this node.
+ if s.blacklisted.Load() {
+ return time.Time{}, true
+ }
+
// If we aren't already backing off, this call will start
-	// a new backoff period. Increase the failure counter and
-	// start a goroutine which will wait out the backoff and
-	// unset the backoffStarted flag when done.
+	// a new backoff period, increase the failure counter and
+	// start a timer which will wait out the backoff and
+	// unset the backoffStarted flag when done.
if s.backoffStarted.CompareAndSwap(false, true) {
@@ -122,40 +151,48 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
}
}
+ s.ClearBackoff()
return time.Time{}, true
}
- go func() {
- until, ok := s.backoffUntil.Load().(time.Time)
- if ok && !until.IsZero() {
- select {
- case <-time.After(time.Until(until)):
- case <-s.interrupt:
- }
- s.backoffStarted.Store(false)
- }
- }()
- }
+		// We're starting a new backoff, so work out what the next interval
+		// will be.
+ count := s.backoffCount.Load()
+ until := time.Now().Add(s.duration(count))
+ s.backoffUntil.Store(until)
- // Check if we have blacklisted this node.
- if s.blacklisted.Load() {
- return time.Now(), true
+ s.statistics.backoffMutex.Lock()
+ defer s.statistics.backoffMutex.Unlock()
+ s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished)
}
- // If we're already backing off and we haven't yet surpassed
- // the deadline then return that. Repeated calls to Failure
- // within a single backoff interval will have no side effects.
- if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) {
- return until, false
+ return s.backoffUntil.Load().(time.Time), false
+}
+
+// ClearBackoff stops the backoff timer for this destination if it is running
+// and removes the timer from the backoffTimers map.
+func (s *ServerStatistics) ClearBackoff() {
+	// If the timer is still running, stop it so that its memory is cleaned up sooner.
+ s.statistics.backoffMutex.Lock()
+ defer s.statistics.backoffMutex.Unlock()
+ if timer, ok := s.statistics.backoffTimers[s.serverName]; ok {
+ timer.Stop()
}
+ delete(s.statistics.backoffTimers, s.serverName)
+
+ s.backoffStarted.Store(false)
+}
+
+// backoffFinished will clear the previous backoff and notify the destination queue.
+func (s *ServerStatistics) backoffFinished() {
+ s.ClearBackoff()
- // We're either backing off and have passed the deadline, or
- // we aren't backing off, so work out what the next interval
- // will be.
- count := s.backoffCount.Load()
- until := time.Now().Add(s.duration(count))
- s.backoffUntil.Store(until)
- return until, false
+	// Notify the destination queue, if a notifier has been assigned.
+ s.notifierMutex.Lock()
+ defer s.notifierMutex.Unlock()
+ if s.backoffNotifier != nil {
+ s.backoffNotifier()
+ }
}
// BackoffInfo returns information about the current or previous backoff.
@@ -174,6 +211,12 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load()
}
+// RemoveBlacklist removes the blacklisted status from the server.
+func (s *ServerStatistics) RemoveBlacklist() {
+ s.cancel()
+ s.backoffCount.Store(0)
+}
+
// SuccessCount returns the number of successful requests. This is
// usually useful in constructing transaction IDs.
func (s *ServerStatistics) SuccessCount() uint32 {
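
For reference (not part of the patch), here is roughly how the reworked statistics API is meant to be wired up: a destination queue registers a notifier once via AssignBackoffNotifier, records failures via Failure(), and the time.AfterFunc timer created inside Failure() fires the notifier when the backoff window ends. Per the jitter bounds above, the n-th consecutive failure backs off for 2^n seconds scaled by a factor between 0.8 and 1.4. The nil database and the literal server name below are stand-ins for illustration only.

package example

import (
	"fmt"
	"time"

	"github.com/matrix-org/dendrite/federationapi/statistics"
	"github.com/matrix-org/gomatrixserverlib"
)

func exampleBackoffWiring() {
	stats := statistics.NewStatistics(nil, 16) // nil DB: blacklist persistence is skipped
	server := stats.ForServer(gomatrixserverlib.ServerName("remotehost"))

	// The destination queue registers this once; backoffFinished invokes it
	// when the timer created by Failure() expires.
	server.AssignBackoffNotifier(func() {
		fmt.Println("backoff finished, the queue can wake up and retry")
	})

	until, blacklisted := server.Failure()
	if blacklisted {
		return // FailuresUntilBlacklist consecutive failures reached
	}
	fmt.Println("backing off until", until.Format(time.RFC3339))
}
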
diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go
index 225350b6..6aa997f4 100644
--- a/federationapi/statistics/statistics_test.go
+++ b/federationapi/statistics/statistics_test.go
@@ -7,9 +7,7 @@ import (
)
func TestBackoff(t *testing.T) {
- stats := Statistics{
- FailuresUntilBlacklist: 7,
- }
+ stats := NewStatistics(nil, 7)
server := ServerStatistics{
statistics: &stats,
serverName: "test.com",
@@ -36,7 +34,7 @@ func TestBackoff(t *testing.T) {
// Get the duration.
_, blacklist := server.BackoffInfo()
- duration := time.Until(until).Round(time.Second)
+ duration := time.Until(until)
// Unset the backoff, or otherwise our next call will think that
// there's a backoff in progress and return the same result.
@@ -57,8 +55,17 @@ func TestBackoff(t *testing.T) {
// Check if the duration is what we expect.
t.Logf("Backoff %d is for %s", i, duration)
- if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted {
- t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration)
+ roundingAllowance := 0.01
+ minDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*minJitterMultiplier*1000-roundingAllowance)
+ maxDuration := time.Millisecond * time.Duration(math.Exp2(float64(i))*maxJitterMultiplier*1000+roundingAllowance)
+		inJitterRange := duration >= minDuration && duration <= maxDuration
+ if !blacklist && !inJitterRange {
+ t.Fatalf("Backoff %d should have been between %s and %s but was %s", i, minDuration, maxDuration, duration)
}
}
}
diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go
index 9e40f311..6afb313a 100644
--- a/federationapi/storage/shared/storage.go
+++ b/federationapi/storage/shared/storage.go
@@ -52,6 +52,10 @@ type Receipt struct {
nid int64
}
+func NewReceipt(nid int64) Receipt {
+ return Receipt{nid: nid}
+}
+
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}
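
Receipt's nid field is unexported, so code outside this package, presumably including the test doubles used by the new queue tests, had no way to construct a Receipt; NewReceipt fills that gap. A trivial usage sketch (the literal NID is illustrative only):

package example

import (
	"fmt"

	"github.com/matrix-org/dendrite/federationapi/storage/shared"
)

func exampleReceipt() {
	receipt := shared.NewReceipt(1234)                    // wrap a database NID in a Receipt
	fmt.Println("queued with receipt", receipt.String()) // prints: queued with receipt 1234
}
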
diff --git a/go.mod b/go.mod
index 911d36c1..2248e73c 100644
--- a/go.mod
+++ b/go.mod
@@ -50,6 +50,7 @@ require (
golang.org/x/term v0.0.0-20220919170432-7a66f970e087
gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0
+ gotest.tools/v3 v3.4.0
nhooyr.io/websocket v1.8.7
)
@@ -127,7 +128,6 @@ require (
gopkg.in/macaroon.v2 v2.1.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
- gotest.tools/v3 v3.4.0 // indirect
)
go 1.18