diff options
author | Till <2353100+S7evinK@users.noreply.github.com> | 2022-10-21 12:50:51 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-21 12:50:51 +0200 |
commit | 9e4c3171da4e2d6d7b95731e702891513d081b49 (patch) | |
tree | f5c5826d628d6672655e3195bbb4717916043025 /federationapi/queue/queue_test.go | |
parent | e98d75fd63103243c5af2a63f2f547e4300adc4d (diff) |
Optimize inserting pending PDUs/EDUs (#2821)
This optimizes the association of PDUs/EDUs to their destination by
inserting all destinations in one transaction.
Diffstat (limited to 'federationapi/queue/queue_test.go')
-rw-r--r-- | federationapi/queue/queue_test.go | 60 |
1 files changed, 35 insertions, 25 deletions
diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 40419b91..a1b28010 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -25,6 +25,10 @@ import ( "go.uber.org/atomic" "gotest.tools/v3/poll" + "github.com/matrix-org/gomatrixserverlib" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" @@ -34,9 +38,6 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/gomatrixserverlib" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) { @@ -158,30 +159,36 @@ func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixse return edus, nil } -func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error { +func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() if _, ok := d.pendingPDUs[receipt]; ok { - if _, ok := d.associatedPDUs[serverName]; !ok { - d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{}) + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) + } + d.associatedPDUs[destination][receipt] = struct{}{} } - d.associatedPDUs[serverName][receipt] = struct{}{} + return nil } else { return errors.New("PDU doesn't exist") } } -func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { +func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() if _, ok := d.pendingEDUs[receipt]; ok { - if _, ok := d.associatedEDUs[serverName]; !ok { - d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{}) + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) + } + d.associatedEDUs[destination][receipt] = struct{}{} } - d.associatedEDUs[serverName][receipt] = struct{}{} + return nil } else { return errors.New("EDU doesn't exist") @@ -821,15 +828,15 @@ func TestSendPDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} // Populate database with > maxPDUsPerTransaction pduMultiplier := uint32(3) for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ { ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) - db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") } ev := mustCreatePDU(t) @@ -865,13 +872,15 @@ func TestSendEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction eduMultiplier := uint32(3) for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { ev := mustCreateEDU(t) ephemeralJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) - db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") } ev := mustCreateEDU(t) @@ -907,23 +916,23 @@ func TestSendPDUAndEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction multiplier := uint32(3) - for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ { ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) - db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") } for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ { ev := mustCreateEDU(t) ephemeralJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) - db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") } ev := mustCreateEDU(t) @@ -960,13 +969,12 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { dest := queues.getQueue(destination) queues.statistics.ForServer(destination).Failure() - + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, 1)) - db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") pollEnd := time.Now().Add(3 * time.Second) runningCheck := func(log poll.LogT) poll.Result { @@ -988,6 +996,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. @@ -1009,7 +1018,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { edu := mustCreateEDU(t) ephemeralJSON, _ := json.Marshal(edu) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) - db.AssociateEDUWithDestination(pc.Context(), destination, nid, edu.Type, nil) + err = db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, edu.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") checkBlacklisted := func(log poll.LogT) poll.Result { if fc.txCount.Load() == failuresUntilBlacklist { |