aboutsummaryrefslogtreecommitdiff
path: root/federationsender
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-08-04 11:32:14 +0100
committerGitHub <noreply@github.com>2020-08-04 11:32:14 +0100
commit0c4e8f6d4f0f39d2bd72807675295e4fad70479c (patch)
tree9560b73d1aa1a62d0bff700523b1ea9586cdb7cf /federationsender
parentfb56bbf0b7d4b21da3f55b066e71d24bf4599887 (diff)
Send device list updates to servers (outbound only) (#1237)
* Add QueryDeviceMessages to serve up device keys and stream IDs * Consume key change events in fedsender Don't yet send them to destinations as we haven't worked them out yet * Send device list updates to all required servers * Glue it all together
Diffstat (limited to 'federationsender')
-rw-r--r--federationsender/consumers/keychange.go135
-rw-r--r--federationsender/federationsender.go8
-rw-r--r--federationsender/storage/interface.go2
-rw-r--r--federationsender/storage/postgres/joined_hosts_table.go38
-rw-r--r--federationsender/storage/shared/storage.go4
-rw-r--r--federationsender/storage/sqlite3/joined_hosts_table.go45
-rw-r--r--federationsender/storage/tables/interface.go1
7 files changed, 222 insertions, 11 deletions
diff --git a/federationsender/consumers/keychange.go b/federationsender/consumers/keychange.go
new file mode 100644
index 00000000..4c3d23b5
--- /dev/null
+++ b/federationsender/consumers/keychange.go
@@ -0,0 +1,135 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package consumers
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ "github.com/Shopify/sarama"
+ stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
+ "github.com/matrix-org/dendrite/federationsender/queue"
+ "github.com/matrix-org/dendrite/federationsender/storage"
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/config"
+ "github.com/matrix-org/dendrite/keyserver/api"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+// KeyChangeConsumer consumes events that originate in key server.
+type KeyChangeConsumer struct {
+ consumer *internal.ContinualConsumer
+ db storage.Database
+ queues *queue.OutgoingQueues
+ serverName gomatrixserverlib.ServerName
+ stateAPI stateapi.CurrentStateInternalAPI
+}
+
+// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
+func NewKeyChangeConsumer(
+ cfg *config.Dendrite,
+ kafkaConsumer sarama.Consumer,
+ queues *queue.OutgoingQueues,
+ store storage.Database,
+ stateAPI stateapi.CurrentStateInternalAPI,
+) *KeyChangeConsumer {
+ c := &KeyChangeConsumer{
+ consumer: &internal.ContinualConsumer{
+ Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent),
+ Consumer: kafkaConsumer,
+ PartitionStore: store,
+ },
+ queues: queues,
+ db: store,
+ serverName: cfg.Matrix.ServerName,
+ stateAPI: stateAPI,
+ }
+ c.consumer.ProcessMessage = c.onMessage
+
+ return c
+}
+
+// Start consuming from key servers
+func (t *KeyChangeConsumer) Start() error {
+ if err := t.consumer.Start(); err != nil {
+ return fmt.Errorf("t.consumer.Start: %w", err)
+ }
+ return nil
+}
+
+// onMessage is called in response to a message received on the
+// key change events topic from the key server.
+func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
+ var m api.DeviceMessage
+ if err := json.Unmarshal(msg.Value, &m); err != nil {
+ log.WithError(err).Errorf("failed to read device message from key change topic")
+ return nil
+ }
+ logger := log.WithField("user_id", m.UserID)
+
+ // only send key change events which originated from us
+ _, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID)
+ if err != nil {
+ logger.WithError(err).Error("Failed to extract domain from key change event")
+ return nil
+ }
+ if originServerName != t.serverName {
+ return nil
+ }
+
+ var queryRes stateapi.QueryRoomsForUserResponse
+ err = t.stateAPI.QueryRoomsForUser(context.Background(), &stateapi.QueryRoomsForUserRequest{
+ UserID: m.UserID,
+ WantMembership: "join",
+ }, &queryRes)
+ if err != nil {
+ logger.WithError(err).Error("failed to calculate joined rooms for user")
+ return nil
+ }
+ // send this key change to all servers who share rooms with this user.
+ destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs)
+ if err != nil {
+ logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
+ return nil
+ }
+
+ // Pack the EDU and marshal it
+ edu := &gomatrixserverlib.EDU{
+ Type: gomatrixserverlib.MDeviceListUpdate,
+ Origin: string(t.serverName),
+ }
+ event := gomatrixserverlib.DeviceListUpdateEvent{
+ UserID: m.UserID,
+ DeviceID: m.DeviceID,
+ DeviceDisplayName: m.DisplayName,
+ StreamID: m.StreamID,
+ PrevID: prevID(m.StreamID),
+ Deleted: len(m.KeyJSON) == 0,
+ Keys: m.KeyJSON,
+ }
+ if edu.Content, err = json.Marshal(event); err != nil {
+ return err
+ }
+
+ log.Infof("Sending device list update message to %q", destinations)
+ return t.queues.SendEDU(edu, t.serverName, destinations)
+}
+
+func prevID(streamID int) []int {
+ if streamID <= 1 {
+ return nil
+ }
+ return []int{streamID - 1}
+}
diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go
index 9e14f6ec..fbf506aa 100644
--- a/federationsender/federationsender.go
+++ b/federationsender/federationsender.go
@@ -16,6 +16,7 @@ package federationsender
import (
"github.com/gorilla/mux"
+ stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/federationsender/consumers"
"github.com/matrix-org/dendrite/federationsender/internal"
@@ -41,6 +42,7 @@ func NewInternalAPI(
base *setup.BaseDendrite,
federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI,
+ stateAPI stateapi.CurrentStateInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
) api.FederationSenderInternalAPI {
federationSenderDB, err := storage.NewDatabase(string(base.Cfg.Database.FederationSender), base.Cfg.DbProperties())
@@ -76,6 +78,12 @@ func NewInternalAPI(
if err := tsConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start typing server consumer")
}
+ keyConsumer := consumers.NewKeyChangeConsumer(
+ base.Cfg, base.KafkaConsumer, queues, federationSenderDB, stateAPI,
+ )
+ if err := keyConsumer.Start(); err != nil {
+ logrus.WithError(err).Panic("failed to start key server consumer")
+ }
return internal.NewFederationSenderInternalAPI(federationSenderDB, base.Cfg, rsAPI, federation, keyRing, stats, queues)
}
diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go
index b79499d3..734b368f 100644
--- a/federationsender/storage/interface.go
+++ b/federationsender/storage/interface.go
@@ -30,6 +30,8 @@ type Database interface {
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
+ // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
+ GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go
index af0a5258..52865996 100644
--- a/federationsender/storage/postgres/joined_hosts_table.go
+++ b/federationsender/storage/postgres/joined_hosts_table.go
@@ -60,12 +60,16 @@ const selectJoinedHostsSQL = "" +
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
+const selectJoinedHostsForRoomsSQL = "" +
+ "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
+
type joinedHostsStatements struct {
- db *sql.DB
- insertJoinedHostsStmt *sql.Stmt
- deleteJoinedHostsStmt *sql.Stmt
- selectJoinedHostsStmt *sql.Stmt
- selectAllJoinedHostsStmt *sql.Stmt
+ db *sql.DB
+ insertJoinedHostsStmt *sql.Stmt
+ deleteJoinedHostsStmt *sql.Stmt
+ selectJoinedHostsStmt *sql.Stmt
+ selectAllJoinedHostsStmt *sql.Stmt
+ selectJoinedHostsForRoomsStmt *sql.Stmt
}
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@@ -88,6 +92,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
+ if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
+ return
+ }
return
}
@@ -144,6 +151,27 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
return result, rows.Err()
}
+func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
+ ctx context.Context, roomIDs []string,
+) ([]gomatrixserverlib.ServerName, error) {
+ rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
+
+ var result []gomatrixserverlib.ServerName
+ for rows.Next() {
+ var serverName string
+ if err = rows.Scan(&serverName); err != nil {
+ return nil, err
+ }
+ result = append(result, gomatrixserverlib.ServerName(serverName))
+ }
+
+ return result, rows.Err()
+}
+
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {
diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go
index 52f02a28..4a681de6 100644
--- a/federationsender/storage/shared/storage.go
+++ b/federationsender/storage/shared/storage.go
@@ -123,6 +123,10 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
return d.FederationSenderJoinedHosts.SelectAllJoinedHosts(ctx)
}
+func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) {
+ return d.FederationSenderJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
+}
+
// StoreJSON adds a JSON blob into the queue JSON table and returns
// a NID. The NID will then be used when inserting the per-destination
// metadata entries.
diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go
index bd917c61..4ae980d7 100644
--- a/federationsender/storage/sqlite3/joined_hosts_table.go
+++ b/federationsender/storage/sqlite3/joined_hosts_table.go
@@ -59,13 +59,17 @@ const selectJoinedHostsSQL = "" +
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
+const selectJoinedHostsForRoomsSQL = "" +
+ "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
+
type joinedHostsStatements struct {
- db *sql.DB
- writer *sqlutil.TransactionWriter
- insertJoinedHostsStmt *sql.Stmt
- deleteJoinedHostsStmt *sql.Stmt
- selectJoinedHostsStmt *sql.Stmt
- selectAllJoinedHostsStmt *sql.Stmt
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
+ insertJoinedHostsStmt *sql.Stmt
+ deleteJoinedHostsStmt *sql.Stmt
+ selectJoinedHostsStmt *sql.Stmt
+ selectAllJoinedHostsStmt *sql.Stmt
+ selectJoinedHostsForRoomsStmt *sql.Stmt
}
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@@ -89,6 +93,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error)
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
+ if s.selectJoinedHostsForRoomsStmt, err = db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
+ return
+ }
return
}
@@ -153,6 +160,32 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
return result, rows.Err()
}
+func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
+ ctx context.Context, roomIDs []string,
+) ([]gomatrixserverlib.ServerName, error) {
+ iRoomIDs := make([]interface{}, len(roomIDs))
+ for i := range roomIDs {
+ iRoomIDs[i] = roomIDs[i]
+ }
+
+ rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, iRoomIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
+
+ var result []gomatrixserverlib.ServerName
+ for rows.Next() {
+ var serverName string
+ if err = rows.Scan(&serverName); err != nil {
+ return nil, err
+ }
+ result = append(result, gomatrixserverlib.ServerName(serverName))
+ }
+
+ return result, rows.Err()
+}
+
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {
diff --git a/federationsender/storage/tables/interface.go b/federationsender/storage/tables/interface.go
index 2def48d0..c6f8a2d5 100644
--- a/federationsender/storage/tables/interface.go
+++ b/federationsender/storage/tables/interface.go
@@ -53,6 +53,7 @@ type FederationSenderJoinedHosts interface {
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
+ SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
}
type FederationSenderRooms interface {