aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--build/dendritejs-pinecone/main.go6
-rw-r--r--build/gobind-pinecone/monolith.go2
-rw-r--r--build/gobind-yggdrasil/monolith.go2
-rw-r--r--cmd/dendrite-demo-pinecone/main.go2
-rw-r--r--cmd/dendrite-demo-yggdrasil/main.go5
-rw-r--r--cmd/dendrite-monolith-server/main.go2
-rw-r--r--cmd/dendrite-polylith-multi/personalities/keyserver.go3
-rw-r--r--keyserver/internal/device_list_update.go29
-rw-r--r--keyserver/internal/device_list_update_test.go86
-rw-r--r--keyserver/keyserver.go12
-rw-r--r--keyserver/keyserver_test.go29
-rw-r--r--keyserver/storage/interface.go5
-rw-r--r--keyserver/storage/postgres/stale_device_lists.go33
-rw-r--r--keyserver/storage/shared/storage.go10
-rw-r--r--keyserver/storage/sqlite3/stale_device_lists.go44
-rw-r--r--keyserver/storage/tables/interface.go1
-rw-r--r--keyserver/storage/tables/stale_device_lists_test.go94
-rw-r--r--roomserver/api/api.go5
-rw-r--r--roomserver/api/api_trace.go6
-rw-r--r--roomserver/api/query.go12
-rw-r--r--roomserver/internal/query/query.go6
-rw-r--r--roomserver/inthttp/client.go8
-rw-r--r--roomserver/inthttp/server.go5
-rw-r--r--roomserver/roomserver_test.go77
-rw-r--r--roomserver/storage/interface.go1
-rw-r--r--roomserver/storage/postgres/membership_table.go34
-rw-r--r--roomserver/storage/shared/storage.go37
-rw-r--r--roomserver/storage/shared/storage_test.go96
-rw-r--r--roomserver/storage/sqlite3/membership_table.go47
-rw-r--r--roomserver/storage/tables/interface.go1
-rw-r--r--roomserver/storage/tables/membership_table_test.go6
31 files changed, 666 insertions, 40 deletions
diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go
index e070173a..f44a7748 100644
--- a/build/dendritejs-pinecone/main.go
+++ b/build/dendritejs-pinecone/main.go
@@ -180,14 +180,14 @@ func startup() {
base := base.NewBaseDendrite(cfg, "Monolith")
defer base.Close() // nolint: errcheck
+ rsAPI := roomserver.NewInternalAPI(base)
+
federation := conn.CreateFederationClient(base, pSessions)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
+ keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
- rsAPI := roomserver.NewInternalAPI(base)
-
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go
index e8ed8fe8..b8f8111d 100644
--- a/build/gobind-pinecone/monolith.go
+++ b/build/gobind-pinecone/monolith.go
@@ -350,7 +350,7 @@ func (m *DendriteMonolith) Start() {
base, federation, rsAPI, base.Caches, keyRing, true,
)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
+ keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(m.userAPI)
diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go
index 9a3ac5d7..8c2d0a00 100644
--- a/build/gobind-yggdrasil/monolith.go
+++ b/build/gobind-yggdrasil/monolith.go
@@ -165,7 +165,7 @@ func (m *DendriteMonolith) Start() {
base, federation, rsAPI, base.Caches, keyRing, true,
)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
+ keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go
index 2f647a41..3f627b41 100644
--- a/cmd/dendrite-demo-pinecone/main.go
+++ b/cmd/dendrite-demo-pinecone/main.go
@@ -213,7 +213,7 @@ func main() {
base, federation, rsAPI, base.Caches, keyRing, true,
)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
+ keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsComponent)
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go
index 5dd61b1b..3ea4a08b 100644
--- a/cmd/dendrite-demo-yggdrasil/main.go
+++ b/cmd/dendrite-demo-yggdrasil/main.go
@@ -157,11 +157,12 @@ func main() {
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
-
rsComponent := roomserver.NewInternalAPI(
base,
)
+
+ keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsComponent)
+
rsAPI := rsComponent
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go
index 2d2f32b0..6836b642 100644
--- a/cmd/dendrite-monolith-server/main.go
+++ b/cmd/dendrite-monolith-server/main.go
@@ -95,7 +95,7 @@ func main() {
}
keyRing := fsAPI.KeyRing()
- keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
+ keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
keyAPI := keyImpl
if base.UseHTTPAPIs {
keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics)
diff --git a/cmd/dendrite-polylith-multi/personalities/keyserver.go b/cmd/dendrite-polylith-multi/personalities/keyserver.go
index d2924b89..ad0bd0e5 100644
--- a/cmd/dendrite-polylith-multi/personalities/keyserver.go
+++ b/cmd/dendrite-polylith-multi/personalities/keyserver.go
@@ -22,7 +22,8 @@ import (
func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
fsAPI := base.FederationAPIHTTPClient()
- intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
+ rsAPI := base.RoomserverHTTPClient()
+ intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
intAPI.SetUserAPI(base.UserAPIClient())
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics)
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
index 8ff9dfc3..c7bf8da5 100644
--- a/keyserver/internal/device_list_update.go
+++ b/keyserver/internal/device_list_update.go
@@ -24,6 +24,8 @@ import (
"sync"
"time"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
+
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -102,6 +104,7 @@ type DeviceListUpdater struct {
// block on or timeout via a select.
userIDToChan map[string]chan bool
userIDToChanMu *sync.Mutex
+ rsAPI rsapi.KeyserverRoomserverAPI
}
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
@@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface {
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+
+ DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
}
type DeviceListUpdaterAPI interface {
@@ -140,7 +145,7 @@ func NewDeviceListUpdater(
process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
- thisServer gomatrixserverlib.ServerName,
+ rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName,
) *DeviceListUpdater {
return &DeviceListUpdater{
process: process,
@@ -154,6 +159,7 @@ func NewDeviceListUpdater(
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{},
+ rsAPI: rsAPI,
}
}
@@ -168,7 +174,7 @@ func (u *DeviceListUpdater) Start() error {
go u.worker(ch)
}
- staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{})
+ staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
if err != nil {
return err
}
@@ -186,6 +192,25 @@ func (u *DeviceListUpdater) Start() error {
return nil
}
+// CleanUp removes stale device entries for users we don't share a room with anymore
+func (u *DeviceListUpdater) CleanUp() error {
+ staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
+ if err != nil {
+ return err
+ }
+
+ res := rsapi.QueryLeftUsersResponse{}
+ if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil {
+ return err
+ }
+
+ if len(res.LeftUsers) == 0 {
+ return nil
+ }
+ logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers))
+ return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers)
+}
+
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
u.mu.Lock()
defer u.mu.Unlock()
diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go
index a374c951..60a2c2f3 100644
--- a/keyserver/internal/device_list_update_test.go
+++ b/keyserver/internal/device_list_update_test.go
@@ -30,7 +30,12 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/keyserver/api"
+ "github.com/matrix-org/dendrite/keyserver/storage"
+ roomserver "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
)
var (
@@ -53,6 +58,10 @@ type mockDeviceListUpdaterDatabase struct {
mu sync.Mutex // protect staleUsers
}
+func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
+ return nil
+}
+
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
@@ -153,7 +162,7 @@ func TestUpdateHavePrevID(t *testing.T) {
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
- updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost")
+ updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost")
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar",
Deleted: false,
@@ -225,7 +234,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)),
}, nil
})
- updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test")
+ updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test")
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
@@ -239,6 +248,7 @@ func TestUpdateNoPrevID(t *testing.T) {
UserID: remoteUserID,
}
err := updater.Update(ctx, event)
+
if err != nil {
t.Fatalf("Update returned an error: %s", err)
}
@@ -294,7 +304,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq)
return <-fedCh, nil
})
- updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost")
+ updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost")
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
@@ -349,3 +359,73 @@ func TestDebounce(t *testing.T) {
t.Errorf("user %s is marked as stale", userID)
}
}
+
+func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+ t.Helper()
+
+ base, _, _ := testrig.Base(nil)
+ connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
+ db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return db, clearDB
+}
+
+type mockKeyserverRoomserverAPI struct {
+ leftUsers []string
+}
+
+func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
+ res.LeftUsers = m.leftUsers
+ return nil
+}
+
+func TestDeviceListUpdater_CleanUp(t *testing.T) {
+ processCtx := process.NewProcessContext()
+
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+
+ // Bob is not joined to any of our rooms
+ rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}}
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clearDB := mustCreateKeyserverDB(t, dbType)
+ defer clearDB()
+
+ // This should not get deleted
+ if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil {
+ t.Error(err)
+ }
+
+ // this one should get deleted
+ if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil {
+ t.Error(err)
+ }
+
+ updater := NewDeviceListUpdater(processCtx, db, nil,
+ nil, nil,
+ 0, rsAPI, "test")
+ if err := updater.CleanUp(); err != nil {
+ t.Error(err)
+ }
+
+ // check that we still have Alice in our stale list
+ staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Error(err)
+ }
+
+ // There should only be Alice
+ wantCount := 1
+ if count := len(staleUsers); count != wantCount {
+ t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count)
+ }
+
+ if staleUsers[0] != alice.ID {
+ t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID)
+ }
+ })
+}
diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go
index 5360c06f..27557677 100644
--- a/keyserver/keyserver.go
+++ b/keyserver/keyserver.go
@@ -18,6 +18,8 @@ import (
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
+
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/consumers"
@@ -40,6 +42,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetr
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
+ rsAPI rsapi.KeyserverRoomserverAPI,
) api.KeyInternalAPI {
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
@@ -47,6 +50,7 @@ func NewInternalAPI(
if err != nil {
logrus.WithError(err).Panicf("failed to connect to key server database")
}
+
keyChangeProducer := &producers.KeyChange{
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
JetStream: js,
@@ -58,8 +62,14 @@ func NewInternalAPI(
FedClient: fedClient,
Producer: keyChangeProducer,
}
- updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable
+ updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
ap.Updater = updater
+
+ // Remove users which we don't share a room with anymore
+ if err := updater.CleanUp(); err != nil {
+ logrus.WithError(err).Error("failed to cleanup stale device lists")
+ }
+
go func() {
if err := updater.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start device list updater")
diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go
new file mode 100644
index 00000000..159b280f
--- /dev/null
+++ b/keyserver/keyserver_test.go
@@ -0,0 +1,29 @@
+package keyserver
+
+import (
+ "context"
+ "testing"
+
+ roomserver "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+)
+
+type mockKeyserverRoomserverAPI struct {
+ leftUsers []string
+}
+
+func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
+ res.LeftUsers = m.leftUsers
+ return nil
+}
+
+// Merely tests that we can create an internal keyserver API
+func Test_NewInternalAPI(t *testing.T) {
+ rsAPI := &mockKeyserverRoomserverAPI{}
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ base, closeBase := testrig.CreateBaseDendrite(t, dbType)
+ defer closeBase()
+ _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
+ })
+}
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index 242e16a0..c6a8f44c 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -85,4 +85,9 @@ type Database interface {
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
+
+ DeleteStaleDeviceLists(
+ ctx context.Context,
+ userIDs []string,
+ ) error
}
diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go
index d0fe50d0..248ddfb4 100644
--- a/keyserver/storage/postgres/stale_device_lists.go
+++ b/keyserver/storage/postgres/stale_device_lists.go
@@ -19,6 +19,10 @@ import (
"database/sql"
"time"
+ "github.com/lib/pq"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
@@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
+const deleteStaleDevicesSQL = "" +
+ "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
+
type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
+ deleteStaleDeviceListsStmt *sql.Stmt
}
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
@@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro
if err != nil {
return nil, err
}
- if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
- return nil, err
- }
- if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
- return nil, err
- }
- if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
- return nil, err
- }
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
+ {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
+ {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
+ {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
+ }.Prepare(db)
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
@@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil
}
+// DeleteStaleDeviceLists removes users from stale device lists
+func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
+ ctx context.Context, txn *sql.Tx, userIDs []string,
+) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
+ _, err := stmt.ExecContext(ctx, pq.Array(userIDs))
+ return err
+}
+
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 5beeed0f..54dd6ddc 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget(
return nil
})
}
+
+// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
+func (d *Database) DeleteStaleDeviceLists(
+ ctx context.Context,
+ userIDs []string,
+) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
+ })
+}
diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go
index 1e08b266..fd76a6e3 100644
--- a/keyserver/storage/sqlite3/stale_device_lists.go
+++ b/keyserver/storage/sqlite3/stale_device_lists.go
@@ -17,8 +17,11 @@ package sqlite3
import (
"context"
"database/sql"
+ "strings"
"time"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
@@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
+const deleteStaleDevicesSQL = "" +
+ "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
+
type staleDeviceListsStatements struct {
db *sql.DB
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
+ // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
}
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
@@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error)
if err != nil {
return nil, err
}
- if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
- return nil, err
- }
- if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
- return nil, err
- }
- if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
- return nil, err
- }
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
+ {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
+ {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
+ // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
+ }.Prepare(db)
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
@@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
return result, nil
}
+// DeleteStaleDeviceLists removes users from stale device lists
+func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
+ ctx context.Context, txn *sql.Tx, userIDs []string,
+) error {
+ qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
+ stmt, err := s.db.Prepare(qry)
+ if err != nil {
+ return err
+ }
+ defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
+ stmt = sqlutil.TxStmt(txn, stmt)
+
+ params := make([]any, len(userIDs))
+ for i := range userIDs {
+ params[i] = userIDs[i]
+ }
+
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
+}
+
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index 37a010a7..24da1125 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -56,6 +56,7 @@ type KeyChanges interface {
type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+ DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
}
type CrossSigningKeys interface {
diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/keyserver/storage/tables/stale_device_lists_test.go
new file mode 100644
index 00000000..76d3badd
--- /dev/null
+++ b/keyserver/storage/tables/stale_device_lists_test.go
@@ -0,0 +1,94 @@
+package tables_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/matrix-org/gomatrixserverlib"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
+ "github.com/matrix-org/dendrite/setup/config"
+
+ "github.com/matrix-org/dendrite/keyserver/storage/postgres"
+ "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/dendrite/test"
+)
+
+func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, nil)
+ if err != nil {
+ t.Fatalf("failed to open database: %s", err)
+ }
+ switch dbType {
+ case test.DBTypePostgres:
+ tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
+ case test.DBTypeSQLite:
+ tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
+ }
+ if err != nil {
+ t.Fatalf("failed to create new table: %s", err)
+ }
+ return tab, close
+}
+
+func TestStaleDeviceLists(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ charlie := "@charlie:localhost"
+ ctx := context.Background()
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, closeDB := mustCreateTable(t, dbType)
+ defer closeDB()
+
+ if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
+ t.Fatalf("failed to insert stale device: %s", err)
+ }
+ if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
+ t.Fatalf("failed to insert stale device: %s", err)
+ }
+ if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
+ t.Fatalf("failed to insert stale device: %s", err)
+ }
+
+ // Query one server
+ wantStaleUsers := []string{alice.ID, bob.ID}
+ gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+ if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
+ t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
+ }
+
+ // Query all servers
+ wantStaleUsers = []string{alice.ID, bob.ID, charlie}
+ gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+ if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
+ t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
+ }
+
+ // Delete stale devices
+ deleteUsers := []string{alice.ID, bob.ID}
+ if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
+ t.Fatalf("failed to delete stale device lists: %s", err)
+ }
+
+ // Verify we don't get anything back after deleting
+ gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+
+ if gotCount := len(gotStaleUsers); gotCount > 0 {
+ t.Fatalf("expected no stale users, got %d", gotCount)
+ }
+ })
+}
diff --git a/roomserver/api/api.go b/roomserver/api/api.go
index 01e87ec8..420ef278 100644
--- a/roomserver/api/api.go
+++ b/roomserver/api/api.go
@@ -17,6 +17,7 @@ type RoomserverInternalAPI interface {
ClientRoomserverAPI
UserRoomserverAPI
FederationRoomserverAPI
+ KeyserverRoomserverAPI
// needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs
@@ -199,3 +200,7 @@ type FederationRoomserverAPI interface {
// Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
}
+
+type KeyserverRoomserverAPI interface {
+ QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error
+}
diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go
index 342a3904..b23263d1 100644
--- a/roomserver/api/api_trace.go
+++ b/roomserver/api/api_trace.go
@@ -19,6 +19,12 @@ type RoomserverInternalAPITrace struct {
Impl RoomserverInternalAPI
}
+func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error {
+ err := t.Impl.QueryLeftUsers(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("QueryLeftUsers req=%+v res=%+v", js(req), js(res))
+ return err
+}
+
func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) {
t.Impl.SetFederationAPI(fsAPI, keyRing)
}
diff --git a/roomserver/api/query.go b/roomserver/api/query.go
index b62907f3..76f8298c 100644
--- a/roomserver/api/query.go
+++ b/roomserver/api/query.go
@@ -447,3 +447,15 @@ type QueryMembershipAtEventResponse struct {
// do not have known state will return an empty array here.
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
}
+
+// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a
+// a room with anymore. This is used to cleanup stale device list entries, where we would
+// otherwise keep on trying to get device lists.
+type QueryLeftUsersRequest struct {
+ StaleDeviceListUsers []string `json:"user_ids"`
+}
+
+// QueryLeftUsersResponse is the response to QueryLeftUsersRequest.
+type QueryLeftUsersResponse struct {
+ LeftUsers []string `json:"user_ids"`
+}
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index d8456fb4..69d841dd 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS
return nil
}
+func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error {
+ var err error
+ res.LeftUsers, err = r.DB.GetLeftUsers(ctx, req.StaleDeviceListUsers)
+ return err
+}
+
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
if err != nil {
diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go
index 1bd1b3fb..8a2e0a03 100644
--- a/roomserver/inthttp/client.go
+++ b/roomserver/inthttp/client.go
@@ -63,6 +63,7 @@ const (
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
+ RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers"
)
type httpRoomserverInternalAPI struct {
@@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context,
h.httpClient, ctx, request, response,
)
}
+
+func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error {
+ return httputil.CallInternalRPCAPI(
+ "RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath,
+ h.httpClient, ctx, request, response,
+ )
+}
diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go
index 6e7c2d98..4d21909b 100644
--- a/roomserver/inthttp/server.go
+++ b/roomserver/inthttp/server.go
@@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe
RoomserverQueryMembershipAtEventPath,
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent),
)
+
+ internalAPIMux.Handle(
+ RoomserverQueryLeftMembersPath,
+ httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", enableMetrics, r.QueryLeftUsers),
+ )
}
diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go
index 24b5515e..518bb372 100644
--- a/roomserver/roomserver_test.go
+++ b/roomserver/roomserver_test.go
@@ -2,20 +2,27 @@ package roomserver_test
import (
"context"
+ "net/http"
"testing"
+ "time"
+ "github.com/gorilla/mux"
+ "github.com/matrix-org/gomatrixserverlib"
+
+ "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
- "github.com/matrix-org/gomatrixserverlib"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
+ t.Helper()
base, close := testrig.CreateBaseDendrite(t, dbType)
- db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches)
+ db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches)
if err != nil {
t.Fatalf("failed to create Database: %v", err)
}
@@ -67,3 +74,69 @@ func Test_SharedUsers(t *testing.T) {
}
})
}
+
+func Test_QueryLeftUsers(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
+
+ // Invite and join Bob
+ room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "invite",
+ }, test.WithStateKey(bob.ID))
+ room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ base, _, close := mustCreateDatabase(t, dbType)
+ defer close()
+
+ rsAPI := roomserver.NewInternalAPI(base)
+ // SetFederationAPI starts the room event input consumer
+ rsAPI.SetFederationAPI(nil, nil)
+ // Create the room
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Fatalf("failed to send events: %v", err)
+ }
+
+ // Query the left users, there should only be "@idontexist:test",
+ // as Alice and Bob are still joined.
+ res := &api.QueryLeftUsersResponse{}
+ leftUserID := "@idontexist:test"
+ getLeftUsersList := []string{alice.ID, bob.ID, leftUserID}
+
+ testCase := func(rsAPI api.RoomserverInternalAPI) {
+ if err := rsAPI.QueryLeftUsers(ctx, &api.QueryLeftUsersRequest{StaleDeviceListUsers: getLeftUsersList}, res); err != nil {
+ t.Fatalf("unable to query left users: %v", err)
+ }
+ wantCount := 1
+ if count := len(res.LeftUsers); count > wantCount {
+ t.Fatalf("unexpected left users count: want %d, got %d", wantCount, count)
+ }
+ if res.LeftUsers[0] != leftUserID {
+ t.Fatalf("unexpected left users : want %s, got %s", leftUserID, res.LeftUsers[0])
+ }
+ }
+
+ t.Run("HTTP API", func(t *testing.T) {
+ router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
+ roomserver.AddInternalRoutes(router, rsAPI, false)
+ apiURL, cancel := test.ListenAndServe(t, router, false)
+ defer cancel()
+ httpAPI, err := inthttp.NewRoomserverClient(apiURL, &http.Client{Timeout: time.Second * 5}, nil)
+ if err != nil {
+ t.Fatalf("failed to create HTTP client")
+ }
+ testCase(httpAPI)
+ })
+ t.Run("Monolith", func(t *testing.T) {
+ testCase(rsAPI)
+ // also test tracing
+ traceAPI := &api.RoomserverInternalAPITrace{Impl: rsAPI}
+ testCase(traceAPI)
+ })
+
+ })
+}
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index 06db4b2d..92bc2e66 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -172,5 +172,6 @@ type Database interface {
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
+ GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
}
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index 0150534e..d774b789 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -21,12 +21,13 @@ import (
"fmt"
"github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
- "github.com/matrix-org/gomatrixserverlib"
)
const membershipSchema = `
@@ -157,6 +158,12 @@ const selectServerInRoomSQL = "" +
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
+const selectJoinedUsersSQL = `
+SELECT DISTINCT target_nid
+FROM roomserver_membership m
+WHERE membership_nid > $1 AND target_nid = ANY($2)
+`
+
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
@@ -174,6 +181,7 @@ type membershipStatements struct {
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
+ selectJoinedUsersStmt *sql.Stmt
}
func CreateMembershipTable(db *sql.DB) error {
@@ -209,9 +217,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.deleteMembershipStmt, deleteMembershipSQL},
+ {&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
}.Prepare(db)
}
+func (s *membershipStatements) SelectJoinedUsers(
+ ctx context.Context, txn *sql.Tx,
+ targetUserNIDs []types.EventStateKeyNID,
+) ([]types.EventStateKeyNID, error) {
+ result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
+
+ stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt)
+ rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed")
+ var targetNID types.EventStateKeyNID
+ for rows.Next() {
+ if err = rows.Scan(&targetNID); err != nil {
+ return nil, err
+ }
+ result = append(result, targetNID)
+ }
+
+ return result, rows.Err()
+}
+
func (s *membershipStatements) InsertMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 16898bcb..725cc5bc 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -1365,6 +1365,43 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
return result, nil
}
+// GetLeftUsers calculates users we (the server) don't share a room with anymore.
+func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) {
+ // Get the userNID for all users with a stale device list
+ stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap))
+ userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap))
+ // Create a map from userNID -> userID
+ for userID, nid := range stateKeyNIDMap {
+ userNIDs = append(userNIDs, nid)
+ userNIDtoUserID[nid] = userID
+ }
+
+ // Get all users whose membership is still join, knock or invite.
+ stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // Remove joined users from the "user with stale devices" list, which contains left AND joined users
+ for _, joinedUser := range stillJoinedUsersNIDs {
+ delete(userNIDtoUserID, joinedUser)
+ }
+
+ // The users still in our userNIDtoUserID map are the users we don't share a room with anymore,
+ // and the return value we are looking for.
+ leftUsers := make([]string, 0, len(userNIDtoUserID))
+ for _, userID := range userNIDtoUserID {
+ leftUsers = append(leftUsers, userID)
+ }
+
+ return leftUsers, nil
+}
+
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go
new file mode 100644
index 00000000..58724340
--- /dev/null
+++ b/roomserver/storage/shared/storage_test.go
@@ -0,0 +1,96 @@
+package shared_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/storage/postgres"
+ "github.com/matrix-org/dendrite/roomserver/storage/shared"
+ "github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+)
+
+func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Database, func()) {
+ t.Helper()
+
+ connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
+ base, _, _ := testrig.Base(nil)
+ dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}
+
+ db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter())
+ assert.NoError(t, err)
+
+ var membershipTable tables.Membership
+ var stateKeyTable tables.EventStateKeys
+ switch dbType {
+ case test.DBTypePostgres:
+ err = postgres.CreateEventStateKeysTable(db)
+ assert.NoError(t, err)
+ err = postgres.CreateMembershipTable(db)
+ assert.NoError(t, err)
+ membershipTable, err = postgres.PrepareMembershipTable(db)
+ assert.NoError(t, err)
+ stateKeyTable, err = postgres.PrepareEventStateKeysTable(db)
+ case test.DBTypeSQLite:
+ err = sqlite3.CreateEventStateKeysTable(db)
+ assert.NoError(t, err)
+ err = sqlite3.CreateMembershipTable(db)
+ assert.NoError(t, err)
+ membershipTable, err = sqlite3.PrepareMembershipTable(db)
+ assert.NoError(t, err)
+ stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db)
+ }
+ assert.NoError(t, err)
+
+ return &shared.Database{
+ DB: db,
+ EventStateKeysTable: stateKeyTable,
+ MembershipTable: membershipTable,
+ Writer: sqlutil.NewExclusiveWriter(),
+ }, func() {
+ err := base.Close()
+ assert.NoError(t, err)
+ clearDB()
+ err = db.Close()
+ assert.NoError(t, err)
+ }
+}
+
+func Test_GetLeftUsers(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ charlie := test.NewUser(t)
+
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateRoomserverDatabase(t, dbType)
+ defer close()
+
+ // Create dummy entries
+ for _, user := range []*test.User{alice, bob, charlie} {
+ nid, err := db.EventStateKeysTable.InsertEventStateKeyNID(ctx, nil, user.ID)
+ assert.NoError(t, err)
+ err = db.MembershipTable.InsertMembership(ctx, nil, 1, nid, true)
+ assert.NoError(t, err)
+ // We must update the membership with a non-zero event NID or it will get filtered out in later queries
+ membershipNID := tables.MembershipStateLeaveOrBan
+ if user == alice {
+ membershipNID = tables.MembershipStateJoin
+ }
+ _, err = db.MembershipTable.UpdateMembership(ctx, nil, 1, nid, nid, membershipNID, 1, false)
+ assert.NoError(t, err)
+ }
+
+ // Now try to get the left users, this should be Bob and Charlie, since they have a "leave" membership
+ expectedUserIDs := []string{bob.ID, charlie.ID}
+ leftUsers, err := db.GetLeftUsers(context.Background(), []string{alice.ID, bob.ID, charlie.ID})
+ assert.NoError(t, err)
+ assert.ElementsMatch(t, expectedUserIDs, leftUsers)
+ })
+}
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index cd149f0e..8a60b359 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -21,12 +21,13 @@ import (
"fmt"
"strings"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
- "github.com/matrix-org/gomatrixserverlib"
)
const membershipSchema = `
@@ -133,6 +134,12 @@ const selectServerInRoomSQL = "" +
const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
+const selectJoinedUsersSQL = `
+SELECT DISTINCT target_nid
+FROM roomserver_membership m
+WHERE membership_nid > $1 AND target_nid IN ($2)
+`
+
type membershipStatements struct {
db *sql.DB
insertMembershipStmt *sql.Stmt
@@ -149,6 +156,7 @@ type membershipStatements struct {
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
+ // selectJoinedUsersStmt *sql.Stmt // Prepared at runtime
}
func CreateMembershipTable(db *sql.DB) error {
@@ -412,3 +420,40 @@ func (s *membershipStatements) DeleteMembership(
)
return err
}
+
+func (s *membershipStatements) SelectJoinedUsers(
+ ctx context.Context, txn *sql.Tx,
+ targetUserNIDs []types.EventStateKeyNID,
+) ([]types.EventStateKeyNID, error) {
+ result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
+
+ qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1)
+
+ stmt, err := s.db.Prepare(qry)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed")
+
+ params := make([]any, len(targetUserNIDs)+1)
+ params[0] = tables.MembershipStateLeaveOrBan
+ for i := range targetUserNIDs {
+ params[i+1] = targetUserNIDs[i]
+ }
+
+ stmt = sqlutil.TxStmt(txn, stmt)
+ rows, err := stmt.QueryContext(ctx, params...)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed")
+ var targetNID types.EventStateKeyNID
+ for rows.Next() {
+ if err = rows.Scan(&targetNID); err != nil {
+ return nil, err
+ }
+ result = append(result, targetNID)
+ }
+
+ return result, rows.Err()
+}
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index 50d27c75..80fcf72d 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -144,6 +144,7 @@ type Membership interface {
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
+ SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error)
}
type Published interface {
diff --git a/roomserver/storage/tables/membership_table_test.go b/roomserver/storage/tables/membership_table_test.go
index c9541d9d..c4524ee4 100644
--- a/roomserver/storage/tables/membership_table_test.go
+++ b/roomserver/storage/tables/membership_table_test.go
@@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) {
knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2)
assert.NoError(t, err)
assert.Equal(t, 1, len(knownUsers))
+
+ // get users we share a room with, given their userNID
+ joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs)
+ assert.NoError(t, err)
+ // Only userNIDs[0] is actually joined, so we only expect this userNID
+ assert.Equal(t, userNIDs[:1], joinedUsers)
})
}