diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2020-07-22 17:01:29 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-22 17:01:29 +0100 |
commit | 1e71fd645ed9bbac87627434b303659a195512c7 (patch) | |
tree | af64865891c09990c65a7658a75d0aa5c34102fd /federationsender | |
parent | 470933789b4ae08cd33c0a1de3656673eb0ebe70 (diff) |
Persistent federation sender blacklist (#1214)
* Initial persistence of blacklists
* Move statistics folder
* Make MaxFederationRetries configurable
* Set lower failure thresholds for Yggdrasil demos
* Still write events into database for blacklisted hosts (they can be tidied up later)
* Review comments
Diffstat (limited to 'federationsender')
-rw-r--r-- | federationsender/federationsender.go | 12 | ||||
-rw-r--r-- | federationsender/internal/api.go | 6 | ||||
-rw-r--r-- | federationsender/queue/destinationqueue.go | 50 | ||||
-rw-r--r-- | federationsender/queue/queue.go | 10 | ||||
-rw-r--r-- | federationsender/statistics/statistics.go (renamed from federationsender/types/statistics.go) | 46 | ||||
-rw-r--r-- | federationsender/storage/interface.go | 4 | ||||
-rw-r--r-- | federationsender/storage/postgres/blacklist_table.go | 112 | ||||
-rw-r--r-- | federationsender/storage/postgres/storage.go | 5 | ||||
-rw-r--r-- | federationsender/storage/shared/storage.go | 13 | ||||
-rw-r--r-- | federationsender/storage/sqlite3/blacklist_table.go | 112 | ||||
-rw-r--r-- | federationsender/storage/sqlite3/storage.go | 5 | ||||
-rw-r--r-- | federationsender/storage/tables/interface.go | 6 |
12 files changed, 330 insertions, 51 deletions
diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go index 79a2c084..9e14f6ec 100644 --- a/federationsender/federationsender.go +++ b/federationsender/federationsender.go @@ -21,8 +21,8 @@ import ( "github.com/matrix-org/dendrite/federationsender/internal" "github.com/matrix-org/dendrite/federationsender/inthttp" "github.com/matrix-org/dendrite/federationsender/queue" + "github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/storage" - "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal/setup" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" @@ -48,9 +48,13 @@ func NewInternalAPI( logrus.WithError(err).Panic("failed to connect to federation sender db") } - statistics := &types.Statistics{} + stats := &statistics.Statistics{ + DB: federationSenderDB, + FailuresUntilBlacklist: base.Cfg.Matrix.FederationMaxRetries, + } + queues := queue.NewOutgoingQueues( - federationSenderDB, base.Cfg.Matrix.ServerName, federation, rsAPI, statistics, + federationSenderDB, base.Cfg.Matrix.ServerName, federation, rsAPI, stats, &queue.SigningInfo{ KeyID: base.Cfg.Matrix.KeyID, PrivateKey: base.Cfg.Matrix.PrivateKey, @@ -73,5 +77,5 @@ func NewInternalAPI( logrus.WithError(err).Panic("failed to start typing server consumer") } - return internal.NewFederationSenderInternalAPI(federationSenderDB, base.Cfg, rsAPI, federation, keyRing, statistics, queues) + return internal.NewFederationSenderInternalAPI(federationSenderDB, base.Cfg, rsAPI, federation, keyRing, stats, queues) } diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go index 0dca32fc..9a9880ce 100644 --- a/federationsender/internal/api.go +++ b/federationsender/internal/api.go @@ -2,8 +2,8 @@ package internal import ( "github.com/matrix-org/dendrite/federationsender/queue" + "github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/storage" - "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" @@ -13,7 +13,7 @@ import ( type FederationSenderInternalAPI struct { db storage.Database cfg *config.Dendrite - statistics *types.Statistics + statistics *statistics.Statistics rsAPI api.RoomserverInternalAPI federation *gomatrixserverlib.FederationClient keyRing *gomatrixserverlib.KeyRing @@ -25,7 +25,7 @@ func NewFederationSenderInternalAPI( rsAPI api.RoomserverInternalAPI, federation *gomatrixserverlib.FederationClient, keyRing *gomatrixserverlib.KeyRing, - statistics *types.Statistics, + statistics *statistics.Statistics, queues *queue.OutgoingQueues, ) *FederationSenderInternalAPI { return &FederationSenderInternalAPI{ diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index b7582bf9..dc2d4091 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -21,9 +21,9 @@ import ( "sync" "time" + "github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/storage/shared" - "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" @@ -51,7 +51,7 @@ type destinationQueue struct { destination gomatrixserverlib.ServerName // destination of requests running atomic.Bool // is the queue worker running? backingOff atomic.Bool // true if we're backing off - statistics *types.ServerStatistics // statistics about this remote server + statistics *statistics.ServerStatistics // statistics about this remote server incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send transactionIDMutex sync.Mutex // protects transactionID transactionID gomatrixserverlib.TransactionID // last transaction ID @@ -66,11 +66,6 @@ type destinationQueue struct { // If the queue is empty then it starts a background goroutine to // start sending events to that destination. func (oq *destinationQueue) sendEvent(receipt *shared.Receipt) { - if oq.statistics.Blacklisted() { - // If the destination is blacklisted then drop the event. - log.Infof("%s is blacklisted; dropping event", oq.destination) - return - } // Create a transaction ID. We'll either do this if we don't have // one made up yet, or if we've exceeded the number of maximum // events allowed in a single tranaction. We'll reset the counter @@ -97,13 +92,17 @@ func (oq *destinationQueue) sendEvent(receipt *shared.Receipt) { // We've successfully added a PDU to the transaction so increase // the counter. oq.transactionCount.Add(1) - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() - // If we're blocking on waiting PDUs then tell the queue that we - // have work to do. - select { - case oq.notifyPDUs <- true: - default: + // Check if the destination is blacklisted. If it isn't then wake + // up the queue. + if !oq.statistics.Blacklisted() { + // Wake up the queue if it's asleep. + oq.wakeQueueIfNeeded() + // If we're blocking on waiting PDUs then tell the queue that we + // have work to do. + select { + case oq.notifyPDUs <- true: + default: + } } } @@ -111,11 +110,6 @@ func (oq *destinationQueue) sendEvent(receipt *shared.Receipt) { // If the queue is empty then it starts a background goroutine to // start sending events to that destination. func (oq *destinationQueue) sendEDU(receipt *shared.Receipt) { - if oq.statistics.Blacklisted() { - // If the destination is blacklisted then drop the event. - log.Infof("%s is blacklisted; dropping ephemeral event", oq.destination) - return - } // Create a database entry that associates the given PDU NID with // this destination queue. We'll then be able to retrieve the PDU // later. @@ -130,13 +124,17 @@ func (oq *destinationQueue) sendEDU(receipt *shared.Receipt) { // We've successfully added an EDU to the transaction so increase // the counter. oq.transactionCount.Add(1) - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() - // If we're blocking on waiting PDUs then tell the queue that we - // have work to do. - select { - case oq.notifyEDUs <- true: - default: + // Check if the destination is blacklisted. If it isn't then wake + // up the queue. + if !oq.statistics.Blacklisted() { + // Wake up the queue if it's asleep. + oq.wakeQueueIfNeeded() + // If we're blocking on waiting EDUs then tell the queue that we + // have work to do. + select { + case oq.notifyEDUs <- true: + default: + } } } diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index e488a34a..5651fba2 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -21,8 +21,8 @@ import ( "fmt" "sync" + "github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/storage" - "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -36,7 +36,7 @@ type OutgoingQueues struct { rsAPI api.RoomserverInternalAPI origin gomatrixserverlib.ServerName client *gomatrixserverlib.FederationClient - statistics *types.Statistics + statistics *statistics.Statistics signing *SigningInfo queuesMutex sync.Mutex // protects the below queues map[gomatrixserverlib.ServerName]*destinationQueue @@ -48,7 +48,7 @@ func NewOutgoingQueues( origin gomatrixserverlib.ServerName, client *gomatrixserverlib.FederationClient, rsAPI api.RoomserverInternalAPI, - statistics *types.Statistics, + statistics *statistics.Statistics, signing *SigningInfo, ) *OutgoingQueues { queues := &OutgoingQueues{ @@ -77,7 +77,9 @@ func NewOutgoingQueues( log.WithError(err).Error("Failed to get EDU server names for destination queue hydration") } for serverName := range serverNames { - queues.getQueue(serverName).wakeQueueIfNeeded() + if !queues.getQueue(serverName).statistics.Blacklisted() { + queues.getQueue(serverName).wakeQueueIfNeeded() + } } return queues } diff --git a/federationsender/types/statistics.go b/federationsender/statistics/statistics.go index 63f82756..17dd896d 100644 --- a/federationsender/types/statistics.go +++ b/federationsender/statistics/statistics.go @@ -1,27 +1,28 @@ -package types +package statistics import ( "math" "sync" "time" + "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "go.uber.org/atomic" ) -const ( - // How many times should we tolerate consecutive failures before we - // just blacklist the host altogether? Bear in mind that the backoff - // is exponential, so the max time here to attempt is 2**failures. - FailuresUntilBlacklist = 16 // 16 equates to roughly 18 hours. -) - // Statistics contains information about all of the remote federated // hosts that we have interacted with. It is basically a threadsafe // wrapper. type Statistics struct { + DB storage.Database servers map[gomatrixserverlib.ServerName]*ServerStatistics mutex sync.RWMutex + + // How many times should we tolerate consecutive failures before we + // just blacklist the host altogether? The backoff is exponential, + // so the max time here to attempt is 2**failures seconds. + FailuresUntilBlacklist uint32 } // ForServer returns server statistics for the given server name. If it @@ -40,9 +41,18 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS // If we don't, then make one. if !found { s.mutex.Lock() - server = &ServerStatistics{} + server = &ServerStatistics{ + statistics: s, + serverName: serverName, + } s.servers[serverName] = server s.mutex.Unlock() + blacklisted, err := s.DB.IsServerBlacklisted(serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get blacklist entry %q", serverName) + } else { + server.blacklisted.Store(blacklisted) + } } return server } @@ -52,10 +62,12 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - blacklisted atomic.Bool // is the remote side dead? - backoffUntil atomic.Value // time.Time to wait until before sending requests - failCounter atomic.Uint32 // how many times have we failed? - successCounter atomic.Uint32 // how many times have we succeeded? + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + backoffUntil atomic.Value // time.Time to wait until before sending requests + failCounter atomic.Uint32 // how many times have we failed? + successCounter atomic.Uint32 // how many times have we succeeded? } // Success updates the server statistics with a new successful @@ -66,6 +78,9 @@ func (s *ServerStatistics) Success() { s.successCounter.Add(1) s.failCounter.Store(0) s.blacklisted.Store(false) + if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + } } // Failure marks a failure and works out when to backoff until. It @@ -77,12 +92,15 @@ func (s *ServerStatistics) Failure() bool { failCounter := s.failCounter.Add(1) // Check that we haven't failed more times than is acceptable. - if failCounter >= FailuresUntilBlacklist { + if failCounter >= s.statistics.FailuresUntilBlacklist { // We've exceeded the maximum amount of times we're willing // to back off, which is probably in the region of hours by // now. Mark the host as blacklisted and tell the caller to // give up. s.blacklisted.Store(true) + if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) + } return true } diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index 1bea83e2..b79499d3 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -47,4 +47,8 @@ type Database interface { GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + + AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error + RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error + IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) } diff --git a/federationsender/storage/postgres/blacklist_table.go b/federationsender/storage/postgres/blacklist_table.go new file mode 100644 index 00000000..f1db9fae --- /dev/null +++ b/federationsender/storage/postgres/blacklist_table.go @@ -0,0 +1,112 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const blacklistSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_blacklist ( + -- The blacklisted server name + server_name TEXT NOT NULL +); +` + +const insertBlacklistSQL = "" + + "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectBlacklistSQL = "" + + "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1" + +const deleteBlacklistSQL = "" + + "DELETE FROM federationsender_blacklist WHERE server_name = $1" + +type blacklistStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter + insertBlacklistStmt *sql.Stmt + selectBlacklistStmt *sql.Stmt + deleteBlacklistStmt *sql.Stmt +} + +func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { + s = &blacklistStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } + _, err = db.Exec(blacklistSchema) + if err != nil { + return + } + + if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil { + return + } + if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil { + return + } + if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil { + return + } + return +} + +// insertRoom inserts the room if it didn't already exist. +// If the room didn't exist then last_event_id is set to the empty string. +func (s *blacklistStatements) InsertBlacklist( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err + }) +} + +// selectRoomForUpdate locks the row for the room and returns the last_event_id. +// The row must already exist in the table. Callers can ensure that the row +// exists by calling insertRoom first. +func (s *blacklistStatements) SelectBlacklist( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is blacklisted, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +// updateRoom updates the last_event_id for the room. selectRoomForUpdate should +// have already been called earlier within the transaction. +func (s *blacklistStatements) DeleteBlacklist( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err + }) +} diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index 66388bfe..a3094bda 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -56,6 +56,10 @@ func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (*Dat if err != nil { return nil, err } + blacklist, err := NewPostgresBlacklistTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, FederationSenderJoinedHosts: joinedHosts, @@ -63,6 +67,7 @@ func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (*Dat FederationSenderQueueEDUs: queueEDUs, FederationSenderQueueJSON: queueJSON, FederationSenderRooms: rooms, + FederationSenderBlacklist: blacklist, } if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { return nil, err diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go index 75681ea3..52f02a28 100644 --- a/federationsender/storage/shared/storage.go +++ b/federationsender/storage/shared/storage.go @@ -33,6 +33,7 @@ type Database struct { FederationSenderQueueJSON tables.FederationSenderQueueJSON FederationSenderJoinedHosts tables.FederationSenderJoinedHosts FederationSenderRooms tables.FederationSenderRooms + FederationSenderBlacklist tables.FederationSenderBlacklist } // An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. @@ -136,3 +137,15 @@ func (d *Database) StoreJSON( nids: []int64{nid}, }, nil } + +func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { + return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), nil, serverName) +} + +func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { + return d.FederationSenderBlacklist.DeleteBlacklist(context.TODO(), nil, serverName) +} + +func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { + return d.FederationSenderBlacklist.SelectBlacklist(context.TODO(), nil, serverName) +} diff --git a/federationsender/storage/sqlite3/blacklist_table.go b/federationsender/storage/sqlite3/blacklist_table.go new file mode 100644 index 00000000..3e302906 --- /dev/null +++ b/federationsender/storage/sqlite3/blacklist_table.go @@ -0,0 +1,112 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const blacklistSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_blacklist ( + -- The blacklisted server name + server_name TEXT NOT NULL +); +` + +const insertBlacklistSQL = "" + + "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectBlacklistSQL = "" + + "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1" + +const deleteBlacklistSQL = "" + + "DELETE FROM federationsender_blacklist WHERE server_name = $1" + +type blacklistStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter + insertBlacklistStmt *sql.Stmt + selectBlacklistStmt *sql.Stmt + deleteBlacklistStmt *sql.Stmt +} + +func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { + s = &blacklistStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } + _, err = db.Exec(blacklistSchema) + if err != nil { + return + } + + if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil { + return + } + if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil { + return + } + if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil { + return + } + return +} + +// insertRoom inserts the room if it didn't already exist. +// If the room didn't exist then last_event_id is set to the empty string. +func (s *blacklistStatements) InsertBlacklist( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err + }) +} + +// selectRoomForUpdate locks the row for the room and returns the last_event_id. +// The row must already exist in the table. Callers can ensure that the row +// exists by calling insertRoom first. +func (s *blacklistStatements) SelectBlacklist( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is blacklisted, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +// updateRoom updates the last_event_id for the room. selectRoomForUpdate should +// have already been called earlier within the transaction. +func (s *blacklistStatements) DeleteBlacklist( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err + }) +} diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 545a229c..c303d094 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -62,6 +62,10 @@ func NewDatabase(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + blacklist, err := NewSQLiteBlacklistTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, FederationSenderJoinedHosts: joinedHosts, @@ -69,6 +73,7 @@ func NewDatabase(dataSourceName string) (*Database, error) { FederationSenderQueueEDUs: queueEDUs, FederationSenderQueueJSON: queueJSON, FederationSenderRooms: rooms, + FederationSenderBlacklist: blacklist, } if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { return nil, err diff --git a/federationsender/storage/tables/interface.go b/federationsender/storage/tables/interface.go index 55d9119f..2def48d0 100644 --- a/federationsender/storage/tables/interface.go +++ b/federationsender/storage/tables/interface.go @@ -60,3 +60,9 @@ type FederationSenderRooms interface { SelectRoomForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (string, error) UpdateRoom(ctx context.Context, txn *sql.Tx, roomID, lastEventID string) error } + +type FederationSenderBlacklist interface { + InsertBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + SelectBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) + DeleteBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error +} |