diff options
Diffstat (limited to 'federationsender')
-rw-r--r-- | federationsender/consumers/roomserver.go | 6 | ||||
-rw-r--r-- | federationsender/storage/interface.go | 1 | ||||
-rw-r--r-- | federationsender/storage/postgres/joined_hosts_table.go | 15 | ||||
-rw-r--r-- | federationsender/storage/shared/storage.go | 14 | ||||
-rw-r--r-- | federationsender/storage/sqlite3/joined_hosts_table.go | 25 | ||||
-rw-r--r-- | federationsender/storage/tables/interface.go | 1 |
6 files changed, 57 insertions, 5 deletions
diff --git a/federationsender/consumers/roomserver.go b/federationsender/consumers/roomserver.go index efeb53fa..ef945694 100644 --- a/federationsender/consumers/roomserver.go +++ b/federationsender/consumers/roomserver.go @@ -87,6 +87,12 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { case api.OutputTypeNewRoomEvent: ev := &output.NewRoomEvent.Event + if output.NewRoomEvent.RewritesState { + if err := s.db.PurgeRoomState(context.TODO(), ev.RoomID()); err != nil { + return fmt.Errorf("s.db.PurgeRoom: %w", err) + } + } + if err := s.processMessage(*output.NewRoomEvent); err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index 734b368f..a3f5073f 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -32,6 +32,7 @@ type Database interface { GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) + PurgeRoomState(ctx context.Context, roomID string) error StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index cc112b7f..0bc9335d 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -53,6 +53,9 @@ const insertJoinedHostsSQL = "" + const deleteJoinedHostsSQL = "" + "DELETE FROM federationsender_joined_hosts WHERE event_id = ANY($1)" +const deleteJoinedHostsForRoomSQL = "" + + "DELETE FROM federationsender_joined_hosts WHERE room_id = $1" + const selectJoinedHostsSQL = "" + "SELECT event_id, server_name FROM federationsender_joined_hosts" + " WHERE room_id = $1" @@ -67,6 +70,7 @@ type joinedHostsStatements struct { db *sql.DB insertJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt + deleteJoinedHostsForRoomStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt selectJoinedHostsForRoomsStmt *sql.Stmt @@ -86,6 +90,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil { return } + if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil { + return + } if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil { return } @@ -117,6 +124,14 @@ func (s *joinedHostsStatements) DeleteJoinedHosts( return err } +func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) + _, err := stmt.ExecContext(ctx, roomID) + return err +} + func (s *joinedHostsStatements) SelectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go index 4c80c079..d5731f31 100644 --- a/federationsender/storage/shared/storage.go +++ b/federationsender/storage/shared/storage.go @@ -148,6 +148,20 @@ func (d *Database) StoreJSON( }, nil } +func (d *Database) PurgeRoomState( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.FederationSenderJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.FederationSenderJoinedHosts.DeleteJoinedHosts: %w", err) + } + return nil + }) +} + func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName) diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 3bc45e7d..1f906808 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -53,6 +53,9 @@ const insertJoinedHostsSQL = "" + const deleteJoinedHostsSQL = "" + "DELETE FROM federationsender_joined_hosts WHERE event_id = $1" +const deleteJoinedHostsForRoomSQL = "" + + "DELETE FROM federationsender_joined_hosts WHERE room_id = $1" + const selectJoinedHostsSQL = "" + "SELECT event_id, server_name FROM federationsender_joined_hosts" + " WHERE room_id = $1" @@ -64,11 +67,12 @@ const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" type joinedHostsStatements struct { - db *sql.DB - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt + db *sql.DB + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + deleteJoinedHostsForRoomStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic } @@ -86,6 +90,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil { return } + if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil { + return + } if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { return } @@ -118,6 +125,14 @@ func (s *joinedHostsStatements) DeleteJoinedHosts( return nil } +func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) + _, err := stmt.ExecContext(ctx, roomID) + return err +} + func (s *joinedHostsStatements) SelectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/tables/interface.go b/federationsender/storage/tables/interface.go index c6f8a2d5..1167a212 100644 --- a/federationsender/storage/tables/interface.go +++ b/federationsender/storage/tables/interface.go @@ -50,6 +50,7 @@ type FederationSenderQueueJSON interface { type FederationSenderJoinedHosts interface { InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error + DeleteJoinedHostsForRoom(ctx context.Context, txn *sql.Tx, roomID string) error SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) |