aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2022-07-05 14:50:56 +0200
committerGitHub <noreply@github.com>2022-07-05 14:50:56 +0200
commit5087b36af035bcf82a8655e35a2c661d7be72048 (patch)
tree8646eaf67d02dbf3619770438f07da68e6157780
parentf29cdb26f6ca14b0f533ecdabda81aa7d9439db2 (diff)
Fix QuerySharedUsers for the SyncAPI keychange consumer (#2554)
* Make more use of base.BaseDendrite * Fix QuerySharedUsers if no UserIDs are supplied
-rw-r--r--roomserver/internal/api.go36
-rw-r--r--roomserver/internal/input/input.go6
-rw-r--r--roomserver/roomserver.go14
-rw-r--r--roomserver/roomserver_test.go69
-rw-r--r--roomserver/storage/postgres/membership_table.go22
-rw-r--r--roomserver/storage/shared/storage.go7
-rw-r--r--roomserver/storage/sqlite3/membership_table.go14
7 files changed, 137 insertions, 31 deletions
diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go
index d59b8be7..1a11586a 100644
--- a/roomserver/internal/api.go
+++ b/roomserver/internal/api.go
@@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/producers"
"github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
@@ -39,6 +40,7 @@ type RoomserverInternalAPI struct {
*perform.Upgrader
*perform.Admin
ProcessContext *process.ProcessContext
+ Base *base.BaseDendrite
DB storage.Database
Cfg *config.RoomServer
Cache caching.RoomServerCaches
@@ -56,33 +58,38 @@ type RoomserverInternalAPI struct {
}
func NewRoomserverAPI(
- processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database,
- js nats.JetStreamContext, nc *nats.Conn, inputRoomEventTopic string,
- caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName,
+ base *base.BaseDendrite, roomserverDB storage.Database,
+ js nats.JetStreamContext, nc *nats.Conn,
) *RoomserverInternalAPI {
+ var perspectiveServerNames []gomatrixserverlib.ServerName
+ for _, kp := range base.Cfg.FederationAPI.KeyPerspectives {
+ perspectiveServerNames = append(perspectiveServerNames, kp.ServerName)
+ }
+
serverACLs := acls.NewServerACLs(roomserverDB)
producer := &producers.RoomEventProducer{
- Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent)),
+ Topic: string(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)),
JetStream: js,
ACLs: serverACLs,
}
a := &RoomserverInternalAPI{
- ProcessContext: processCtx,
+ ProcessContext: base.ProcessContext,
DB: roomserverDB,
- Cfg: cfg,
- Cache: caches,
- ServerName: cfg.Matrix.ServerName,
+ Base: base,
+ Cfg: &base.Cfg.RoomServer,
+ Cache: base.Caches,
+ ServerName: base.Cfg.Global.ServerName,
PerspectiveServerNames: perspectiveServerNames,
- InputRoomEventTopic: inputRoomEventTopic,
+ InputRoomEventTopic: base.Cfg.Global.JetStream.Prefixed(jetstream.InputRoomEvent),
OutputProducer: producer,
JetStream: js,
NATSClient: nc,
- Durable: cfg.Matrix.JetStream.Durable("RoomserverInputConsumer"),
+ Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"),
ServerACLs: serverACLs,
Queryer: &query.Queryer{
DB: roomserverDB,
- Cache: caches,
- ServerName: cfg.Matrix.ServerName,
+ Cache: base.Caches,
+ ServerName: base.Cfg.Global.ServerName,
ServerACLs: serverACLs,
},
// perform-er structs get initialised when we have a federation sender to use
@@ -98,8 +105,9 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.KeyRing = keyRing
r.Inputer = &input.Inputer{
- Cfg: r.Cfg,
- ProcessContext: r.ProcessContext,
+ Cfg: &r.Base.Cfg.RoomServer,
+ Base: r.Base,
+ ProcessContext: r.Base.ProcessContext,
DB: r.DB,
InputRoomEventTopic: r.InputRoomEventTopic,
OutputProducer: r.OutputProducer,
diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go
index fa07c1d2..ecd4ecbb 100644
--- a/roomserver/internal/input/input.go
+++ b/roomserver/internal/input/input.go
@@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/producers"
"github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
@@ -69,6 +70,7 @@ import (
// or C.
type Inputer struct {
Cfg *config.RoomServer
+ Base *base.BaseDendrite
ProcessContext *process.ProcessContext
DB storage.Database
NATSClient *nats.Conn
@@ -160,7 +162,9 @@ func (r *Inputer) startWorkerForRoom(roomID string) {
// will look to see if we have a worker for that room which has its
// own consumer. If we don't, we'll start one.
func (r *Inputer) Start() error {
- prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
+ if r.Base.EnableMetrics {
+ prometheus.MustRegister(roomserverInputBackpressure, processRoomEventDuration)
+ }
_, err := r.JetStream.Subscribe(
"", // This is blank because we specified it in BindStream.
func(m *nats.Msg) {
diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go
index eb68100f..1f707735 100644
--- a/roomserver/roomserver.go
+++ b/roomserver/roomserver.go
@@ -17,13 +17,10 @@ package roomserver
import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/roomserver/inthttp"
- "github.com/matrix-org/gomatrixserverlib"
-
"github.com/matrix-org/dendrite/roomserver/internal"
+ "github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/sirupsen/logrus"
)
@@ -40,11 +37,6 @@ func NewInternalAPI(
) api.RoomserverInternalAPI {
cfg := &base.Cfg.RoomServer
- var perspectiveServerNames []gomatrixserverlib.ServerName
- for _, kp := range base.Cfg.FederationAPI.KeyPerspectives {
- perspectiveServerNames = append(perspectiveServerNames, kp.ServerName)
- }
-
roomserverDB, err := storage.Open(base, &cfg.Database, base.Caches)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to room server db")
@@ -53,8 +45,6 @@ func NewInternalAPI(
js, nc := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
return internal.NewRoomserverAPI(
- base.ProcessContext, cfg, roomserverDB, js, nc,
- cfg.Matrix.JetStream.Prefixed(jetstream.InputRoomEvent),
- base.Caches, perspectiveServerNames,
+ base, roomserverDB, js, nc,
)
}
diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go
new file mode 100644
index 00000000..4e98af85
--- /dev/null
+++ b/roomserver/roomserver_test.go
@@ -0,0 +1,69 @@
+package roomserver_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/matrix-org/dendrite/roomserver"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/dendrite/setup/base"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches)
+ if err != nil {
+ t.Fatalf("failed to create Database: %v", err)
+ }
+ return base, db, close
+}
+
+func Test_SharedUsers(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
+
+ // Invite and join Bob
+ room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "invite",
+ }, test.WithStateKey(bob.ID))
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ base, _, close := mustCreateDatabase(t, dbType)
+ defer close()
+
+ rsAPI := roomserver.NewInternalAPI(base)
+ // SetFederationAPI starts the room event input consumer
+ rsAPI.SetFederationAPI(nil, nil)
+ // Create the room
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil {
+ t.Fatalf("failed to send events: %v", err)
+ }
+
+ // Query the shared users for Alice, there should only be Bob.
+ // This is used by the SyncAPI keychange consumer.
+ res := &api.QuerySharedUsersResponse{}
+ if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
+ t.Fatalf("unable to query known users: %v", err)
+ }
+ if _, ok := res.UserIDsToCount[bob.ID]; !ok {
+ t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
+ }
+ // Also verify that we get the expected result when specifying OtherUserIDs.
+ // This is used by the SyncAPI when getting device list changes.
+ if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
+ t.Fatalf("unable to query known users: %v", err)
+ }
+ if _, ok := res.UserIDsToCount[bob.ID]; !ok {
+ t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
+ }
+ })
+}
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index c01753c3..ce626ad1 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -65,12 +65,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
);
`
-var selectJoinedUsersSetForRoomsSQL = "" +
+var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
+var selectJoinedUsersSetForRoomsSQL = "" +
+ "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
+ " WHERE room_nid = ANY($1) AND" +
+ " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
+ " GROUP BY target_nid"
+
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@@ -153,6 +159,7 @@ type membershipStatements struct {
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
selectRoomsWithMembershipStmt *sql.Stmt
+ selectJoinedUsersSetForRoomsAndUserStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
updateMembershipForgetRoomStmt *sql.Stmt
@@ -178,6 +185,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
+ {&s.selectJoinedUsersSetForRoomsAndUserStmt, selectJoinedUsersSetForRoomsAndUserSQL},
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
@@ -313,8 +321,18 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) {
+ var (
+ rows *sql.Rows
+ err error
+ )
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
- rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
+ if len(userNIDs) > 0 {
+ stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
+ rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
+ } else {
+ rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
+ }
+
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 67dcfdf3..5c633122 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -1214,6 +1214,13 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
stateKeyNIDs[i] = nid
i++
}
+ // If we didn't have any userIDs to look up, get the UserIDs for the returned userNIDToCount now
+ if len(userIDs) == 0 {
+ nidToUserID, err = d.EventStateKeys(ctx, stateKeyNIDs)
+ if err != nil {
+ return nil, err
+ }
+ }
result := make(map[string]int, len(userNIDToCount))
for nid, count := range userNIDToCount {
result[nidToUserID[nid]] = count
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index 6f0fe8b6..570d3919 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -41,12 +41,18 @@ const membershipSchema = `
);
`
-var selectJoinedUsersSetForRoomsSQL = "" +
+var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
+var selectJoinedUsersSetForRoomsSQL = "" +
+ "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
+ " WHERE room_nid IN ($1) AND " +
+ " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
+ " GROUP BY target_nid"
+
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@@ -293,8 +299,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
for _, v := range userNIDs {
params = append(params, v)
}
+
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
- query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
+ if len(userNIDs) > 0 {
+ query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
+ query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
+ }
var rows *sql.Rows
var err error
if txn != nil {