aboutsummaryrefslogtreecommitdiff
path: root/federationapi
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-04-27 15:29:49 +0100
committerNeil Alexander <neilalexander@users.noreply.github.com>2022-04-27 15:29:49 +0100
commit923f789ca3174a685bd53ce5e64a5e86cabd38cb (patch)
tree77dedd2028e257e3c1c24f77e19d889189ec38ad /federationapi
parent103795d33a09728d7619e73014d507505ff121e2 (diff)
Fix graceful shutdown
Diffstat (limited to 'federationapi')
-rw-r--r--federationapi/queue/destinationqueue.go13
-rw-r--r--federationapi/queue/queue.go13
2 files changed, 14 insertions, 12 deletions
diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go
index a5f8c03b..74794040 100644
--- a/federationapi/queue/destinationqueue.go
+++ b/federationapi/queue/destinationqueue.go
@@ -78,7 +78,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociatePDUWithDestination(
- context.TODO(),
+ oq.process.Context(),
"", // TODO: remove this, as we don't need to persist the transaction ID
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
@@ -122,7 +122,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociateEDUWithDestination(
- context.TODO(),
+ oq.process.Context(),
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
event.Type,
@@ -177,7 +177,7 @@ func (oq *destinationQueue) getPendingFromDatabase() {
// Check to see if there's anything to do for this server
// in the database.
retrieved := false
- ctx := context.Background()
+ ctx := oq.process.Context()
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()
@@ -271,6 +271,9 @@ func (oq *destinationQueue) backgroundSend() {
// restarted automatically the next time we have an event to
// send.
return
+ case <-oq.process.Context().Done():
+ // The parent process is shutting down, so stop.
+ return
}
// If we are backing off this server then wait for the
@@ -420,13 +423,13 @@ func (oq *destinationQueue) nextTransaction(
// Clean up the transaction in the database.
if pduReceipts != nil {
//logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
- if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipts); err != nil {
+ if err = oq.db.CleanPDUs(oq.process.Context(), oq.destination, pduReceipts); err != nil {
logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
}
}
if eduReceipts != nil {
//logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
- if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipts); err != nil {
+ if err = oq.db.CleanEDUs(oq.process.Context(), oq.destination, eduReceipts); err != nil {
logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
}
}
diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go
index c45bbd1d..d152886f 100644
--- a/federationapi/queue/queue.go
+++ b/federationapi/queue/queue.go
@@ -15,7 +15,6 @@
package queue
import (
- "context"
"crypto/ed25519"
"encoding/json"
"fmt"
@@ -105,14 +104,14 @@ func NewOutgoingQueues(
// Look up which servers we have pending items for and then rehydrate those queues.
if !disabled {
serverNames := map[gomatrixserverlib.ServerName]struct{}{}
- if names, err := db.GetPendingPDUServerNames(context.Background()); err == nil {
+ if names, err := db.GetPendingPDUServerNames(process.Context()); err == nil {
for _, serverName := range names {
serverNames[serverName] = struct{}{}
}
} else {
log.WithError(err).Error("Failed to get PDU server names for destination queue hydration")
}
- if names, err := db.GetPendingEDUServerNames(context.Background()); err == nil {
+ if names, err := db.GetPendingEDUServerNames(process.Context()); err == nil {
for _, serverName := range names {
serverNames[serverName] = struct{}{}
}
@@ -215,7 +214,7 @@ func (oqs *OutgoingQueues) SendEvent(
// Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap {
if api.IsServerBannedFromRoom(
- context.TODO(),
+ oqs.process.Context(),
oqs.rsAPI,
ev.RoomID(),
destination,
@@ -238,7 +237,7 @@ func (oqs *OutgoingQueues) SendEvent(
return fmt.Errorf("json.Marshal: %w", err)
}
- nid, err := oqs.db.StoreJSON(context.TODO(), string(headeredJSON))
+ nid, err := oqs.db.StoreJSON(oqs.process.Context(), string(headeredJSON))
if err != nil {
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
}
@@ -286,7 +285,7 @@ func (oqs *OutgoingQueues) SendEDU(
if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() {
for destination := range destmap {
if api.IsServerBannedFromRoom(
- context.TODO(),
+ oqs.process.Context(),
oqs.rsAPI,
result.Str,
destination,
@@ -310,7 +309,7 @@ func (oqs *OutgoingQueues) SendEDU(
return fmt.Errorf("json.Marshal: %w", err)
}
- nid, err := oqs.db.StoreJSON(context.TODO(), string(ephemeralJSON))
+ nid, err := oqs.db.StoreJSON(oqs.process.Context(), string(ephemeralJSON))
if err != nil {
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
}