aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--federationapi/routing/send.go225
1 files changed, 149 insertions, 76 deletions
diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go
index a514127c..ae9a63fc 100644
--- a/federationapi/routing/send.go
+++ b/federationapi/routing/send.go
@@ -16,7 +16,6 @@ package routing
import (
"context"
- "database/sql"
"encoding/json"
"errors"
"fmt"
@@ -24,7 +23,6 @@ import (
"sync"
"time"
- "github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
@@ -36,6 +34,7 @@ import (
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
+ "go.uber.org/atomic"
)
const (
@@ -90,6 +89,67 @@ func init() {
)
}
+type sendFIFOQueue struct {
+ tasks []*inputTask
+ count int
+ mutex sync.Mutex
+ notifs chan struct{}
+}
+
+func newSendFIFOQueue() *sendFIFOQueue {
+ q := &sendFIFOQueue{
+ notifs: make(chan struct{}, 1),
+ }
+ return q
+}
+
+func (q *sendFIFOQueue) push(frame *inputTask) {
+ q.mutex.Lock()
+ defer q.mutex.Unlock()
+ q.tasks = append(q.tasks, frame)
+ q.count++
+ select {
+ case q.notifs <- struct{}{}:
+ default:
+ }
+}
+
+// pop returns the first item of the queue, if there is one.
+// The second return value will indicate if a task was returned.
+func (q *sendFIFOQueue) pop() (*inputTask, bool) {
+ q.mutex.Lock()
+ defer q.mutex.Unlock()
+ if q.count == 0 {
+ return nil, false
+ }
+ frame := q.tasks[0]
+ q.tasks[0] = nil
+ q.tasks = q.tasks[1:]
+ q.count--
+ if q.count == 0 {
+ // Force a GC of the underlying array, since it might have
+ // grown significantly if the queue was hammered for some reason
+ q.tasks = nil
+ }
+ return frame, true
+}
+
+type inputTask struct {
+ ctx context.Context
+ t *txnReq
+ event *gomatrixserverlib.Event
+ wg *sync.WaitGroup
+ err error // written back by worker, only safe to read when all tasks are done
+ duration time.Duration // written back by worker, only safe to read when all tasks are done
+}
+
+type inputWorker struct {
+ running atomic.Bool
+ input *sendFIFOQueue
+}
+
+var inputWorkers sync.Map // room ID -> *inputWorker
+
// Send implements /_matrix/federation/v1/send/{txnID}
func Send(
httpReq *http.Request,
@@ -193,8 +253,12 @@ type txnFederationClient interface {
func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
results := make(map[string]gomatrixserverlib.PDUResult)
+ //var resultsMutex sync.Mutex
+
+ var wg sync.WaitGroup
+ var tasks []*inputTask
+ wg.Add(1) // for processEDUs
- pdus := []*gomatrixserverlib.HeaderedEvent{}
for _, pdu := range t.PDUs {
pduCountTotal.WithLabelValues("total").Inc()
var header struct {
@@ -245,83 +309,97 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
}
continue
}
- pdus = append(pdus, event.Headered(verRes.RoomVersion))
- }
-
- // Process the events.
- for _, e := range pdus {
- evStart := time.Now()
- if err := t.processEvent(ctx, e.Unwrap()); err != nil {
- // If the error is due to the event itself being bad then we skip
- // it and move onto the next event. We report an error so that the
- // sender knows that we have skipped processing it.
- //
- // However if the event is due to a temporary failure in our server
- // such as a database being unavailable then we should bail, and
- // hope that the sender will retry when we are feeling better.
- //
- // It is uncertain what we should do if an event fails because
- // we failed to fetch more information from the sending server.
- // For example if a request to /state fails.
- // If we skip the event then we risk missing the event until we
- // receive another event referencing it.
- // If we bail and stop processing then we risk wedging incoming
- // transactions from that server forever.
- if isProcessingErrorFatal(err) {
- sentry.CaptureException(err)
- // Any other error should be the result of a temporary error in
- // our server so we should bail processing the transaction entirely.
- util.GetLogger(ctx).Warnf("Processing %s failed fatally: %s", e.EventID(), err)
- jsonErr := util.ErrorResponse(err)
- processEventSummary.WithLabelValues(t.work, MetricsOutcomeFatal).Observe(
- float64(time.Since(evStart).Nanoseconds()) / 1000.,
- )
- return nil, &jsonErr
- } else {
- // Auth errors mean the event is 'rejected' which have to be silent to appease sytest
- errMsg := ""
- outcome := MetricsOutcomeRejected
- _, rejected := err.(*gomatrixserverlib.NotAllowed)
- if !rejected {
- errMsg = err.Error()
- outcome = MetricsOutcomeFail
- }
- util.GetLogger(ctx).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn(
- "Failed to process incoming federation event, skipping",
- )
- processEventSummary.WithLabelValues(t.work, outcome).Observe(
- float64(time.Since(evStart).Nanoseconds()) / 1000.,
- )
- results[e.EventID()] = gomatrixserverlib.PDUResult{
- Error: errMsg,
- }
+ v, _ := inputWorkers.LoadOrStore(event.RoomID(), &inputWorker{
+ input: newSendFIFOQueue(),
+ })
+ worker := v.(*inputWorker)
+ if !worker.running.Load() {
+ go worker.run()
+ }
+ wg.Add(1)
+ task := &inputTask{
+ ctx: ctx,
+ t: t,
+ event: event,
+ wg: &wg,
+ }
+ tasks = append(tasks, task)
+ worker.input.push(task)
+ }
+
+ go func() {
+ defer wg.Done()
+ t.processEDUs(ctx)
+ }()
+
+ wg.Wait()
+
+ for _, task := range tasks {
+ if task.err != nil {
+ results[task.event.EventID()] = gomatrixserverlib.PDUResult{
+ Error: task.err.Error(),
}
} else {
- results[e.EventID()] = gomatrixserverlib.PDUResult{}
- pduCountTotal.WithLabelValues("success").Inc()
- processEventSummary.WithLabelValues(t.work, MetricsOutcomeOK).Observe(
- float64(time.Since(evStart).Nanoseconds()) / 1000.,
- )
+ results[task.event.EventID()] = gomatrixserverlib.PDUResult{}
}
}
- t.processEDUs(ctx)
if c := len(results); c > 0 {
util.GetLogger(ctx).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID)
}
return &gomatrixserverlib.RespSend{PDUs: results}, nil
}
-// isProcessingErrorFatal returns true if the error is really bad and
-// we should stop processing the transaction, and returns false if it
-// is just some less serious error about a specific event.
-func isProcessingErrorFatal(err error) bool {
- switch err {
- case sql.ErrConnDone:
- case sql.ErrTxDone:
- return true
+func (t *inputWorker) run() {
+ if !t.running.CAS(false, true) {
+ return
+ }
+ defer t.running.Store(false)
+ for {
+ task, ok := t.input.pop()
+ if !ok {
+ return
+ }
+ if task == nil {
+ continue
+ }
+ func() {
+ defer task.wg.Done()
+ select {
+ case <-task.ctx.Done():
+ task.err = context.DeadlineExceeded
+ return
+ default:
+ evStart := time.Now()
+ task.err = task.t.processEvent(task.ctx, task.event)
+ task.duration = time.Since(evStart)
+ if err := task.err; err != nil {
+ switch err.(type) {
+ case *gomatrixserverlib.NotAllowed:
+ processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeRejected).Observe(
+ float64(time.Since(evStart).Nanoseconds()) / 1000.,
+ )
+ util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", true).Warn(
+ "Failed to process incoming federation event, skipping",
+ )
+ task.err = nil // make "rejected" failures silent
+ default:
+ processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeFail).Observe(
+ float64(time.Since(evStart).Nanoseconds()) / 1000.,
+ )
+ util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", false).Warn(
+ "Failed to process incoming federation event, skipping",
+ )
+ }
+ } else {
+ pduCountTotal.WithLabelValues("success").Inc()
+ processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeOK).Observe(
+ float64(time.Since(evStart).Nanoseconds()) / 1000.,
+ )
+ }
+ }
+ }()
}
- return false
}
type roomNotFoundError struct {
@@ -633,11 +711,6 @@ func (t *txnReq) processEventWithMissingState(
processEventWithMissingStateMutexes.Lock(e.RoomID())
defer processEventWithMissingStateMutexes.Unlock(e.RoomID())
- // Do this with a fresh context, so that we keep working even if the
- // original request times out. With any luck, by the time the remote
- // side retries, we'll have fetched the missing state.
- gmectx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
- defer cancel()
// We are missing the previous events for this events.
// This means that there is a gap in our view of the history of the
// room. There two ways that we can handle such a gap:
@@ -658,7 +731,7 @@ func (t *txnReq) processEventWithMissingState(
// - fill in the gap completely then process event `e` returning no backwards extremity
// - fail to fill in the gap and tell us to terminate the transaction err=not nil
// - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction
- newEvents, err := t.getMissingEvents(gmectx, e, roomVersion)
+ newEvents, err := t.getMissingEvents(ctx, e, roomVersion)
if err != nil {
return err
}
@@ -685,7 +758,7 @@ func (t *txnReq) processEventWithMissingState(
// Look up what the state is after the backward extremity. This will either
// come from the roomserver, if we know all the required events, or it will
// come from a remote server via /state_ids if not.
- prevState, trustworthy, lerr := t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
+ prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID)
if lerr != nil {
util.GetLogger(ctx).WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID)
return lerr
@@ -729,7 +802,7 @@ func (t *txnReq) processEventWithMissingState(
}
// There's more than one previous state - run them all through state res
t.roomsMu.Lock(e.RoomID())
- resolvedState, err = t.resolveStatesAndCheck(gmectx, roomVersion, respStates, backwardsExtremity)
+ resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity)
t.roomsMu.Unlock(e.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID())