aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--appservice/appservice.go2
-rw-r--r--appservice/appservice_test.go2
-rw-r--r--build/dendritejs-pinecone/main.go6
-rw-r--r--build/gobind-yggdrasil/monolith.go6
-rw-r--r--clientapi/admin_test.go15
-rw-r--r--clientapi/clientapi.go6
-rw-r--r--clientapi/routing/admin.go23
-rw-r--r--clientapi/routing/joinroom_test.go4
-rw-r--r--clientapi/routing/key_crosssigning.go9
-rw-r--r--clientapi/routing/keys.go7
-rw-r--r--clientapi/routing/login_test.go7
-rw-r--r--clientapi/routing/register_test.go13
-rw-r--r--clientapi/routing/routing.go18
-rw-r--r--cmd/dendrite-demo-pinecone/monolith/monolith.go6
-rw-r--r--cmd/dendrite-demo-yggdrasil/main.go7
-rw-r--r--cmd/dendrite/main.go8
-rw-r--r--federationapi/consumers/keychange.go2
-rw-r--r--federationapi/federationapi.go6
-rw-r--r--federationapi/federationapi_test.go10
-rw-r--r--federationapi/producers/syncapi.go2
-rw-r--r--federationapi/routing/devices.go12
-rw-r--r--federationapi/routing/keys.go2
-rw-r--r--federationapi/routing/profile_test.go2
-rw-r--r--federationapi/routing/query_test.go2
-rw-r--r--federationapi/routing/routing.go10
-rw-r--r--federationapi/routing/send.go4
-rw-r--r--federationapi/routing/send_test.go2
-rw-r--r--internal/transactionrequest.go8
-rw-r--r--internal/transactionrequest_test.go2
-rw-r--r--keyserver/README.md19
-rw-r--r--keyserver/api/api.go346
-rw-r--r--keyserver/keyserver.go86
-rw-r--r--keyserver/keyserver_test.go29
-rw-r--r--keyserver/storage/interface.go93
-rw-r--r--keyserver/storage/postgres/storage.go69
-rw-r--r--keyserver/storage/shared/storage.go261
-rw-r--r--keyserver/storage/sqlite3/storage.go68
-rw-r--r--keyserver/storage/storage.go40
-rw-r--r--keyserver/storage/storage_test.go197
-rw-r--r--keyserver/storage/storage_wasm.go34
-rw-r--r--keyserver/storage/tables/interface.go71
-rw-r--r--roomserver/api/api.go2
-rw-r--r--roomserver/roomserver_test.go8
-rw-r--r--setup/monolith.go15
-rw-r--r--syncapi/consumers/keychange.go2
-rw-r--r--syncapi/consumers/sendtodevice.go10
-rw-r--r--syncapi/internal/keychange.go16
-rw-r--r--syncapi/internal/keychange_test.go23
-rw-r--r--syncapi/streams/stream_devicelist.go10
-rw-r--r--syncapi/streams/streams.go5
-rw-r--r--syncapi/sync/requestpool.go9
-rw-r--r--syncapi/syncapi.go8
-rw-r--r--syncapi/syncapi_test.go26
-rw-r--r--userapi/api/api.go329
-rw-r--r--userapi/consumers/clientapi.go4
-rw-r--r--userapi/consumers/devicelistupdate.go (renamed from keyserver/consumers/devicelistupdate.go)4
-rw-r--r--userapi/consumers/roomserver.go4
-rw-r--r--userapi/consumers/roomserver_test.go4
-rw-r--r--userapi/consumers/signingkeyupdate.go (renamed from keyserver/consumers/signingkeyupdate.go)21
-rw-r--r--userapi/internal/cross_signing.go (renamed from keyserver/internal/cross_signing.go)40
-rw-r--r--userapi/internal/device_list_update.go (renamed from keyserver/internal/device_list_update.go)2
-rw-r--r--userapi/internal/device_list_update_default.go (renamed from keyserver/internal/device_list_update_default.go)0
-rw-r--r--userapi/internal/device_list_update_sytest.go (renamed from keyserver/internal/device_list_update_sytest.go)0
-rw-r--r--userapi/internal/device_list_update_test.go (renamed from keyserver/internal/device_list_update_test.go)8
-rw-r--r--userapi/internal/key_api.go (renamed from keyserver/internal/internal.go)114
-rw-r--r--userapi/internal/key_api_test.go (renamed from keyserver/internal/internal_test.go)21
-rw-r--r--userapi/internal/user_api.go (renamed from userapi/internal/api.go)36
-rw-r--r--userapi/producers/keychange.go (renamed from keyserver/producers/keychange.go)8
-rw-r--r--userapi/producers/syncapi.go4
-rw-r--r--userapi/storage/interface.go76
-rw-r--r--userapi/storage/postgres/account_data_table.go8
-rw-r--r--userapi/storage/postgres/cross_signing_keys_table.go (renamed from keyserver/storage/postgres/cross_signing_keys_table.go)4
-rw-r--r--userapi/storage/postgres/cross_signing_sigs_table.go (renamed from keyserver/storage/postgres/cross_signing_sigs_table.go)6
-rw-r--r--userapi/storage/postgres/deltas/2022012016470000_key_changes.go (renamed from keyserver/storage/postgres/deltas/2022012016470000_key_changes.go)0
-rw-r--r--userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go (renamed from keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go)0
-rw-r--r--userapi/storage/postgres/device_keys_table.go (renamed from keyserver/storage/postgres/device_keys_table.go)39
-rw-r--r--userapi/storage/postgres/devices_table.go8
-rw-r--r--userapi/storage/postgres/key_backup_table.go4
-rw-r--r--userapi/storage/postgres/key_changes_table.go (renamed from keyserver/storage/postgres/key_changes_table.go)19
-rw-r--r--userapi/storage/postgres/one_time_keys_table.go (renamed from keyserver/storage/postgres/one_time_keys_table.go)33
-rw-r--r--userapi/storage/postgres/stale_device_lists.go (renamed from keyserver/storage/postgres/stale_device_lists.go)2
-rw-r--r--userapi/storage/postgres/storage.go41
-rw-r--r--userapi/storage/shared/storage.go235
-rw-r--r--userapi/storage/sqlite3/cross_signing_keys_table.go (renamed from keyserver/storage/sqlite3/cross_signing_keys_table.go)4
-rw-r--r--userapi/storage/sqlite3/cross_signing_sigs_table.go (renamed from keyserver/storage/sqlite3/cross_signing_sigs_table.go)6
-rw-r--r--userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go (renamed from keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go)0
-rw-r--r--userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go (renamed from keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go)0
-rw-r--r--userapi/storage/sqlite3/device_keys_table.go (renamed from keyserver/storage/sqlite3/device_keys_table.go)36
-rw-r--r--userapi/storage/sqlite3/devices_table.go17
-rw-r--r--userapi/storage/sqlite3/key_backup_table.go4
-rw-r--r--userapi/storage/sqlite3/key_changes_table.go (renamed from keyserver/storage/sqlite3/key_changes_table.go)19
-rw-r--r--userapi/storage/sqlite3/one_time_keys_table.go (renamed from keyserver/storage/sqlite3/one_time_keys_table.go)33
-rw-r--r--userapi/storage/sqlite3/stale_device_lists.go (renamed from keyserver/storage/sqlite3/stale_device_lists.go)2
-rw-r--r--userapi/storage/sqlite3/stats_table.go3
-rw-r--r--userapi/storage/sqlite3/storage.go45
-rw-r--r--userapi/storage/storage.go27
-rw-r--r--userapi/storage/storage_test.go210
-rw-r--r--userapi/storage/storage_wasm.go4
-rw-r--r--userapi/storage/tables/interface.go46
-rw-r--r--userapi/storage/tables/stale_device_lists_test.go (renamed from keyserver/storage/tables/stale_device_lists_test.go)6
-rw-r--r--userapi/types/storage.go (renamed from keyserver/types/storage.go)0
-rw-r--r--userapi/userapi.go62
-rw-r--r--userapi/userapi_test.go327
-rw-r--r--userapi/util/devices.go2
-rw-r--r--userapi/util/notify.go4
-rw-r--r--userapi/util/notify_test.go2
-rw-r--r--userapi/util/phonehomestats_test.go2
107 files changed, 1726 insertions, 1859 deletions
diff --git a/appservice/appservice.go b/appservice/appservice.go
index b950a821..5b1b93de 100644
--- a/appservice/appservice.go
+++ b/appservice/appservice.go
@@ -38,7 +38,7 @@ import (
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
base *base.BaseDendrite,
- userAPI userapi.UserInternalAPI,
+ userAPI userapi.AppserviceUserAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
) appserviceAPI.AppServiceInternalAPI {
client := &http.Client{
diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go
index 9e9940cd..de9f5aaf 100644
--- a/appservice/appservice_test.go
+++ b/appservice/appservice_test.go
@@ -125,7 +125,7 @@ func TestAppserviceInternalAPI(t *testing.T) {
// Create required internal APIs
rsAPI := roomserver.NewInternalAPI(base)
- usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
+ usrAPI := userapi.NewInternalAPI(base, rsAPI, nil)
asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI)
runCases(t, asAPI)
diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go
index f3fcb03e..44e52286 100644
--- a/build/dendritejs-pinecone/main.go
+++ b/build/dendritejs-pinecone/main.go
@@ -30,7 +30,6 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/internal/httputil"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
@@ -183,13 +182,11 @@ func startup() {
rsAPI := roomserver.NewInternalAPI(base)
federation := conn.CreateFederationClient(base, pSessions)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
- userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
- keyAPI.SetUserAPI(userAPI)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, federation)
asQuery := appservice.NewInternalAPI(
base, userAPI, rsAPI,
@@ -208,7 +205,6 @@ func startup() {
FederationAPI: fedSenderAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
- KeyAPI: keyAPI,
//ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(pRouter, pSessions, fedSenderAPI, federation),
}
diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go
index fad850a9..32af611a 100644
--- a/build/gobind-yggdrasil/monolith.go
+++ b/build/gobind-yggdrasil/monolith.go
@@ -20,7 +20,6 @@ import (
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
@@ -165,9 +164,7 @@ func (m *DendriteMonolith) Start() {
base, federation, rsAPI, base.Caches, keyRing, true,
)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
- userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
- keyAPI.SetUserAPI(userAPI)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, federation)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetAppserviceAPI(asAPI)
@@ -186,7 +183,6 @@ func (m *DendriteMonolith) Start() {
FederationAPI: fsAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
- KeyAPI: keyAPI,
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
ygg, fsAPI, federation,
),
diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go
index c7ca019f..300d3a88 100644
--- a/clientapi/admin_test.go
+++ b/clientapi/admin_test.go
@@ -8,7 +8,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/federationapi"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
@@ -40,11 +39,9 @@ func TestAdminResetPassword(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(base)
// Needed for changing the password/login
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
- userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
- keyAPI.SetUserAPI(userAPI)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
- AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil)
+ AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil)
// Create the users in the userapi and login
accessTokens := map[*test.User]string{
@@ -155,14 +152,12 @@ func TestPurgeRoom(t *testing.T) {
fedClient := base.CreateFederationClient()
rsAPI := roomserver.NewInternalAPI(base)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI)
- userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
// this starts the JetStream consumers
- syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI)
+ syncapi.AddPublicRoutes(base, userAPI, rsAPI)
federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true)
rsAPI.SetFederationAPI(nil, nil)
- keyAPI.SetUserAPI(userAPI)
// Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
@@ -170,7 +165,7 @@ func TestPurgeRoom(t *testing.T) {
}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
- AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil)
+ AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil)
// Create the users in the userapi and login
accessTokens := map[*test.User]string{
diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go
index 2d17e092..e9985d43 100644
--- a/clientapi/clientapi.go
+++ b/clientapi/clientapi.go
@@ -15,6 +15,7 @@
package clientapi
import (
+ userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
@@ -23,11 +24,9 @@ import (
"github.com/matrix-org/dendrite/clientapi/routing"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/transactions"
- keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
- userapi "github.com/matrix-org/dendrite/userapi/api"
)
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
@@ -40,7 +39,6 @@ func AddPublicRoutes(
fsAPI federationAPI.ClientFederationAPI,
userAPI userapi.ClientUserAPI,
userDirectoryProvider userapi.QuerySearchProfilesAPI,
- keyAPI keyserverAPI.ClientKeyAPI,
extRoomsProvider api.ExtraPublicRoomsProvider,
) {
cfg := &base.Cfg.ClientAPI
@@ -61,7 +59,7 @@ func AddPublicRoutes(
base,
cfg, rsAPI, asAPI,
userAPI, userDirectoryProvider, federation,
- syncProducer, transactionsCache, fsAPI, keyAPI,
+ syncProducer, transactionsCache, fsAPI,
extRoomsProvider, mscCfg, natsClient,
)
}
diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go
index 4b4dedfd..a01f6b94 100644
--- a/clientapi/routing/admin.go
+++ b/clientapi/routing/admin.go
@@ -16,14 +16,13 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil"
- "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
- userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/api"
)
-func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
+func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@@ -56,7 +55,7 @@ func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi
}
}
-func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
+func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@@ -99,7 +98,7 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi
}
}
-func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
+func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@@ -130,7 +129,7 @@ func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.De
}
}
-func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
+func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.Device, userAPI api.ClientUserAPI) util.JSONResponse {
if req.Body == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
@@ -150,8 +149,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
- accAvailableResp := &userapi.QueryAccountAvailabilityResponse{}
- if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
+ accAvailableResp := &api.QueryAccountAvailabilityResponse{}
+ if err = userAPI.QueryAccountAvailability(req.Context(), &api.QueryAccountAvailabilityRequest{
Localpart: localpart,
ServerName: serverName,
}, accAvailableResp); err != nil {
@@ -186,13 +185,13 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
return *internal.PasswordResponse(err)
}
- updateReq := &userapi.PerformPasswordUpdateRequest{
+ updateReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart,
ServerName: serverName,
Password: request.Password,
LogoutDevices: true,
}
- updateRes := &userapi.PerformPasswordUpdateResponse{}
+ updateRes := &api.PerformPasswordUpdateResponse{}
if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
@@ -209,7 +208,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
}
}
-func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse {
+func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *api.Device, natsClient *nats.Conn) util.JSONResponse {
_, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10)
if err != nil {
logrus.WithError(err).Error("failed to publish nats message")
@@ -255,7 +254,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien
}
}
-func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
+func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go
index 9e8208e6..1450ef4b 100644
--- a/clientapi/routing/joinroom_test.go
+++ b/clientapi/routing/joinroom_test.go
@@ -10,7 +10,6 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/appservice"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
@@ -29,8 +28,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
defer baseClose()
rsAPI := roomserver.NewInternalAPI(base)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
- userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc
diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go
index 2570db09..267ba1dc 100644
--- a/clientapi/routing/key_crosssigning.go
+++ b/clientapi/routing/key_crosssigning.go
@@ -21,9 +21,8 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
- "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
- userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
@@ -34,8 +33,8 @@ type crossSigningRequest struct {
func UploadCrossSigningDeviceKeys(
req *http.Request, userInteractiveAuth *auth.UserInteractive,
- keyserverAPI api.ClientKeyAPI, device *userapi.Device,
- accountAPI userapi.ClientUserAPI, cfg *config.ClientAPI,
+ keyserverAPI api.ClientKeyAPI, device *api.Device,
+ accountAPI api.ClientUserAPI, cfg *config.ClientAPI,
) util.JSONResponse {
uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{}
@@ -107,7 +106,7 @@ func UploadCrossSigningDeviceKeys(
}
}
-func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse {
+func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
uploadReq := &api.PerformUploadDeviceSignaturesRequest{}
uploadRes := &api.PerformUploadDeviceSignaturesResponse{}
diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go
index 0c12b111..3d60fcc3 100644
--- a/clientapi/routing/keys.go
+++ b/clientapi/routing/keys.go
@@ -23,8 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
- "github.com/matrix-org/dendrite/keyserver/api"
- userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/api"
)
type uploadKeysRequest struct {
@@ -32,7 +31,7 @@ type uploadKeysRequest struct {
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
}
-func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse {
+func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
var r uploadKeysRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
@@ -106,7 +105,7 @@ func (r *queryKeysRequest) GetTimeout() time.Duration {
return timeout
}
-func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse {
+func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
var r queryKeysRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go
index d429d7f8..b72db9d8 100644
--- a/clientapi/routing/login_test.go
+++ b/clientapi/routing/login_test.go
@@ -9,7 +9,6 @@ import (
"testing"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
@@ -39,12 +38,10 @@ func TestLogin(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(base)
// Needed for /login
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
- userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
- keyAPI.SetUserAPI(userAPI)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
- Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, keyAPI, nil, &base.Cfg.MSCs, nil)
+ Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, &base.Cfg.MSCs, nil)
// Create password
password := util.RandomString(8)
diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go
index 670c392b..651e3d3d 100644
--- a/clientapi/routing/register_test.go
+++ b/clientapi/routing/register_test.go
@@ -30,7 +30,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
@@ -409,9 +408,7 @@ func Test_register(t *testing.T) {
defer baseClose()
rsAPI := roomserver.NewInternalAPI(base)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
- userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
- keyAPI.SetUserAPI(userAPI)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@@ -582,9 +579,7 @@ func TestRegisterUserWithDisplayName(t *testing.T) {
base.Cfg.Global.ServerName = "server"
rsAPI := roomserver.NewInternalAPI(base)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
- userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
- keyAPI.SetUserAPI(userAPI)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
deviceName, deviceID := "deviceName", "deviceID"
expectedDisplayName := "DisplayName"
response := completeRegistration(
@@ -623,9 +618,7 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) {
base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret
rsAPI := roomserver.NewInternalAPI(base)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
- userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
- keyAPI.SetUserAPI(userAPI)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
expectedDisplayName := "rabbit"
jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`)
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index 66610c0a..028d02e9 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -22,6 +22,7 @@ import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/setup/base"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
@@ -37,11 +38,9 @@ import (
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/transactions"
- keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
- userapi "github.com/matrix-org/dendrite/userapi/api"
)
// Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client
@@ -61,7 +60,6 @@ func Setup(
syncProducer *producers.SyncAPIProducer,
transactionsCache *transactions.Cache,
federationSender federationAPI.ClientFederationAPI,
- keyAPI keyserverAPI.ClientKeyAPI,
extRoomsProvider api.ExtraPublicRoomsProvider,
mscCfg *config.MSCs, natsClient *nats.Conn,
) {
@@ -192,7 +190,7 @@ func Setup(
dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}",
httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
- return AdminMarkAsStale(req, cfg, keyAPI)
+ return AdminMarkAsStale(req, cfg, userAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
@@ -1372,11 +1370,11 @@ func Setup(
// Cross-signing device keys
postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
- return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg)
+ return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, userAPI, device, userAPI, cfg)
})
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
- return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
+ return UploadCrossSigningDeviceSignatures(req, userAPI, device)
}, httputil.WithAllowGuests())
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
@@ -1388,22 +1386,22 @@ func Setup(
// Supplying a device ID is deprecated.
v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
- return UploadKeys(req, keyAPI, device)
+ return UploadKeys(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
- return UploadKeys(req, keyAPI, device)
+ return UploadKeys(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
- return QueryKeys(req, keyAPI, device)
+ return QueryKeys(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/claim",
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
- return ClaimKeys(req, keyAPI)
+ return ClaimKeys(req, userAPI)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",
diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go
index fe19593c..27720369 100644
--- a/cmd/dendrite-demo-pinecone/monolith/monolith.go
+++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go
@@ -38,7 +38,6 @@ import (
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal/httputil"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/relayapi"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/roomserver"
@@ -139,9 +138,7 @@ func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelayi
p.BaseDendrite, federation, rsAPI, p.BaseDendrite.Caches, keyRing, true,
)
- keyAPI := keyserver.NewInternalAPI(p.BaseDendrite, &p.BaseDendrite.Cfg.KeyServer, fsAPI, rsComponent)
- userAPI := userapi.NewInternalAPI(p.BaseDendrite, &cfg.UserAPI, nil, keyAPI, rsAPI, p.BaseDendrite.PushGatewayHTTPClient())
- keyAPI.SetUserAPI(userAPI)
+ userAPI := userapi.NewInternalAPI(p.BaseDendrite, rsAPI, federation)
asAPI := appservice.NewInternalAPI(p.BaseDendrite, userAPI, rsAPI)
@@ -175,7 +172,6 @@ func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelayi
FederationAPI: fsAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
- KeyAPI: keyAPI,
RelayAPI: relayAPI,
ExtPublicRoomsProvider: roomProvider,
ExtUserDirectoryProvider: userProvider,
diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go
index 842682b4..d759c6a7 100644
--- a/cmd/dendrite-demo-yggdrasil/main.go
+++ b/cmd/dendrite-demo-yggdrasil/main.go
@@ -39,7 +39,6 @@ import (
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
@@ -161,10 +160,7 @@ func main() {
base,
)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
-
- userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
- keyAPI.SetUserAPI(userAPI)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, federation)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetAppserviceAPI(asAPI)
@@ -184,7 +180,6 @@ func main() {
FederationAPI: fsAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
- KeyAPI: keyAPI,
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
ygg, fsAPI, federation,
),
diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go
index 35bfc1d6..e8ff0a47 100644
--- a/cmd/dendrite/main.go
+++ b/cmd/dendrite/main.go
@@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/federationapi"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
basepkg "github.com/matrix-org/dendrite/setup/base"
@@ -56,10 +55,7 @@ func main() {
keyRing := fsAPI.KeyRing()
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
-
- pgClient := base.PushGatewayHTTPClient()
- userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, federation)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
@@ -69,7 +65,6 @@ func main() {
rsAPI.SetFederationAPI(fsAPI, keyRing)
rsAPI.SetAppserviceAPI(asAPI)
rsAPI.SetUserAPI(userAPI)
- keyAPI.SetUserAPI(userAPI)
monolith := setup.Monolith{
Config: base.Cfg,
@@ -83,7 +78,6 @@ func main() {
FederationAPI: fsAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
- KeyAPI: keyAPI,
}
monolith.AddAllPublicRoutes(base)
diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go
index 601257d4..7d9df3d7 100644
--- a/federationapi/consumers/keychange.go
+++ b/federationapi/consumers/keychange.go
@@ -26,11 +26,11 @@ import (
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/types"
- "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/userapi/api"
)
// KeyChangeConsumer consumes events that originate in key server.
diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go
index 10803916..ec482659 100644
--- a/federationapi/federationapi.go
+++ b/federationapi/federationapi.go
@@ -28,7 +28,6 @@ import (
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/internal/caching"
- keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
@@ -42,12 +41,11 @@ import (
// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
func AddPublicRoutes(
base *base.BaseDendrite,
- userAPI userapi.UserInternalAPI,
+ userAPI userapi.FederationUserAPI,
federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
rsAPI roomserverAPI.FederationRoomserverAPI,
fedAPI federationAPI.FederationInternalAPI,
- keyAPI keyserverAPI.FederationKeyAPI,
servers federationAPI.ServersInRoomProvider,
) {
cfg := &base.Cfg.FederationAPI
@@ -79,7 +77,7 @@ func AddPublicRoutes(
routing.Setup(
base,
rsAPI, f, keyRing,
- federation, userAPI, keyAPI, mscCfg,
+ federation, userAPI, mscCfg,
servers, producer,
)
}
diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go
index 8d1d8514..57d4b964 100644
--- a/federationapi/federationapi_test.go
+++ b/federationapi/federationapi_test.go
@@ -17,13 +17,13 @@ import (
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/internal"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
)
type fedRoomserverAPI struct {
@@ -230,9 +230,9 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
// Inject a keyserver key change event and ensure we try to send it out. If we don't, then the
// federationapi is incorrectly waiting for an output room event to arrive to update the joined
// hosts table.
- key := keyapi.DeviceMessage{
- Type: keyapi.TypeDeviceKeyUpdate,
- DeviceKeys: &keyapi.DeviceKeys{
+ key := userapi.DeviceMessage{
+ Type: userapi.TypeDeviceKeyUpdate,
+ DeviceKeys: &userapi.DeviceKeys{
UserID: joiningUser.ID,
DeviceID: "MY_DEVICE",
DisplayName: "BLARGLE",
@@ -277,7 +277,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
keyRing := &test.NopJSONVerifier{}
// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
// Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing.
- federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil)
+ federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil)
baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true)
defer cancel()
serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://"))
diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go
index 7cce13a7..6bcfafa3 100644
--- a/federationapi/producers/syncapi.go
+++ b/federationapi/producers/syncapi.go
@@ -41,7 +41,7 @@ type SyncAPIProducer struct {
TopicSigningKeyUpdate string
JetStream nats.JetStreamContext
Config *config.FederationAPI
- UserAPI userapi.UserInternalAPI
+ UserAPI userapi.FederationUserAPI
}
func (p *SyncAPIProducer) SendReceipt(
diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go
index ce8b06b7..871d26cd 100644
--- a/federationapi/routing/devices.go
+++ b/federationapi/routing/devices.go
@@ -17,7 +17,7 @@ import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
+ "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
@@ -26,11 +26,11 @@ import (
// GetUserDevices for the given user id
func GetUserDevices(
req *http.Request,
- keyAPI keyapi.FederationKeyAPI,
+ keyAPI api.FederationKeyAPI,
userID string,
) util.JSONResponse {
- var res keyapi.QueryDeviceMessagesResponse
- if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{
+ var res api.QueryDeviceMessagesResponse
+ if err := keyAPI.QueryDeviceMessages(req.Context(), &api.QueryDeviceMessagesRequest{
UserID: userID,
}, &res); err != nil {
return util.ErrorResponse(err)
@@ -40,12 +40,12 @@ func GetUserDevices(
return jsonerror.InternalServerError()
}
- sigReq := &keyapi.QuerySignaturesRequest{
+ sigReq := &api.QuerySignaturesRequest{
TargetIDs: map[string][]gomatrixserverlib.KeyID{
userID: {},
},
}
- sigRes := &keyapi.QuerySignaturesResponse{}
+ sigRes := &api.QuerySignaturesResponse{}
for _, dev := range res.Devices {
sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID))
}
diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go
index dc262cfd..2885cc91 100644
--- a/federationapi/routing/keys.go
+++ b/federationapi/routing/keys.go
@@ -22,8 +22,8 @@ import (
clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
- "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go
index 76365608..3b9d576b 100644
--- a/federationapi/routing/profile_test.go
+++ b/federationapi/routing/profile_test.go
@@ -62,7 +62,7 @@ func TestHandleQueryProfile(t *testing.T) {
if !ok {
panic("This is a programming error.")
}
- routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil)
+ routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)
diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go
index 21f35bf0..d839a16b 100644
--- a/federationapi/routing/query_test.go
+++ b/federationapi/routing/query_test.go
@@ -62,7 +62,7 @@ func TestHandleQueryDirectory(t *testing.T) {
if !ok {
panic("This is a programming error.")
}
- routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil)
+ routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)
diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go
index 5eb30c6e..324740dd 100644
--- a/federationapi/routing/routing.go
+++ b/federationapi/routing/routing.go
@@ -29,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil"
- keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
@@ -62,7 +61,6 @@ func Setup(
keys gomatrixserverlib.JSONVerifier,
federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI,
- keyAPI keyserverAPI.FederationKeyAPI,
mscCfg *config.MSCs,
servers federationAPI.ServersInRoomProvider,
producer *producers.SyncAPIProducer,
@@ -141,7 +139,7 @@ func Setup(
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return Send(
httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]),
- cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer,
+ cfg, rsAPI, userAPI, keys, federation, mu, servers, producer,
)
},
)).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName)
@@ -269,7 +267,7 @@ func Setup(
"federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return GetUserDevices(
- httpReq, keyAPI, vars["userID"],
+ httpReq, userAPI, vars["userID"],
)
},
)).Methods(http.MethodGet)
@@ -494,14 +492,14 @@ func Setup(
v1fedmux.Handle("/user/keys/claim", MakeFedAPI(
"federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
- return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
+ return ClaimOneTimeKeys(httpReq, request, userAPI, cfg.Matrix.ServerName)
},
)).Methods(http.MethodPost)
v1fedmux.Handle("/user/keys/query", MakeFedAPI(
"federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
- return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
+ return QueryDeviceKeys(httpReq, request, userAPI, cfg.Matrix.ServerName)
},
)).Methods(http.MethodPost)
diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go
index 67b513c9..82651719 100644
--- a/federationapi/routing/send.go
+++ b/federationapi/routing/send.go
@@ -28,9 +28,9 @@ import (
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
+ userAPI "github.com/matrix-org/dendrite/userapi/api"
)
const (
@@ -59,7 +59,7 @@ func Send(
txnID gomatrixserverlib.TransactionID,
cfg *config.FederationAPI,
rsAPI api.FederationRoomserverAPI,
- keyAPI keyapi.FederationKeyAPI,
+ keyAPI userAPI.FederationUserAPI,
keys gomatrixserverlib.JSONVerifier,
federation federationAPI.FederationClient,
mu *internal.MutexByRoom,
diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go
index d7feee0e..eed4e7e6 100644
--- a/federationapi/routing/send_test.go
+++ b/federationapi/routing/send_test.go
@@ -58,7 +58,7 @@ func TestHandleSend(t *testing.T) {
if !ok {
panic("This is a programming error.")
}
- routing.Setup(base, nil, r, keyRing, nil, nil, nil, &base.Cfg.MSCs, nil, nil)
+ routing.Setup(base, nil, r, keyRing, nil, nil, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)
diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go
index 95673fc1..13b00af5 100644
--- a/internal/transactionrequest.go
+++ b/internal/transactionrequest.go
@@ -24,9 +24,9 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/federationapi/types"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
syncTypes "github.com/matrix-org/dendrite/syncapi/types"
+ userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
@@ -56,7 +56,7 @@ var (
type TxnReq struct {
gomatrixserverlib.Transaction
rsAPI api.FederationRoomserverAPI
- keyAPI keyapi.FederationKeyAPI
+ userAPI userAPI.FederationUserAPI
ourServerName gomatrixserverlib.ServerName
keys gomatrixserverlib.JSONVerifier
roomsMu *MutexByRoom
@@ -66,7 +66,7 @@ type TxnReq struct {
func NewTxnReq(
rsAPI api.FederationRoomserverAPI,
- keyAPI keyapi.FederationKeyAPI,
+ userAPI userAPI.FederationUserAPI,
ourServerName gomatrixserverlib.ServerName,
keys gomatrixserverlib.JSONVerifier,
roomsMu *MutexByRoom,
@@ -80,7 +80,7 @@ func NewTxnReq(
) TxnReq {
t := TxnReq{
rsAPI: rsAPI,
- keyAPI: keyAPI,
+ userAPI: userAPI,
ourServerName: ourServerName,
keys: keys,
roomsMu: roomsMu,
diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go
index 93c6fb6f..8597ae24 100644
--- a/internal/transactionrequest_test.go
+++ b/internal/transactionrequest_test.go
@@ -23,13 +23,13 @@ import (
"time"
"github.com/matrix-org/dendrite/federationapi/producers"
- keyAPI "github.com/matrix-org/dendrite/keyserver/api"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
+ keyAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"
diff --git a/keyserver/README.md b/keyserver/README.md
deleted file mode 100644
index fd9f37d2..00000000
--- a/keyserver/README.md
+++ /dev/null
@@ -1,19 +0,0 @@
-## Key Server
-
-This is an internal component which manages E2E keys from clients. It handles all the [Key Management APIs](https://matrix.org/docs/spec/client_server/r0.6.1#key-management-api) with the exception of `/keys/changes` which is handled by Sync API. This component is designed to shard by user ID.
-
-Keys are uploaded and stored in this component, and key changes are emitted to a Kafka topic for downstream components such as Sync API.
-
-### Internal APIs
-- `PerformUploadKeys` stores identity keys and one-time public keys for given user(s).
-- `PerformClaimKeys` acquires one-time public keys for given user(s). This may involve outbound federation calls.
-- `QueryKeys` returns identity keys for given user(s). This may involve outbound federation calls. This component may then cache federated identity keys to avoid repeatedly hitting remote servers.
-- A topic which emits identity keys every time there is a change (addition or deletion).
-
-### Endpoint mappings
-- Client API maps `/keys/upload` to `PerformUploadKeys`.
-- Client API maps `/keys/query` to `QueryKeys`.
-- Client API maps `/keys/claim` to `PerformClaimKeys`.
-- Federation API maps `/user/keys/query` to `QueryKeys`.
-- Federation API maps `/user/keys/claim` to `PerformClaimKeys`.
-- Sync API maps `/keys/changes` to consuming from the Kafka topic.
diff --git a/keyserver/api/api.go b/keyserver/api/api.go
deleted file mode 100644
index 14fced3e..00000000
--- a/keyserver/api/api.go
+++ /dev/null
@@ -1,346 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package api
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "strings"
- "time"
-
- "github.com/matrix-org/gomatrixserverlib"
-
- "github.com/matrix-org/dendrite/keyserver/types"
- userapi "github.com/matrix-org/dendrite/userapi/api"
-)
-
-type KeyInternalAPI interface {
- SyncKeyAPI
- ClientKeyAPI
- FederationKeyAPI
- UserKeyAPI
-
- // SetUserAPI assigns a user API to query when extracting device names.
- SetUserAPI(i userapi.KeyserverUserAPI)
-}
-
-// API functions required by the clientapi
-type ClientKeyAPI interface {
- QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
- PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
- PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
- PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
- // PerformClaimKeys claims one-time keys for use in pre-key messages
- PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
- PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
-}
-
-// API functions required by the userapi
-type UserKeyAPI interface {
- PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
- PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error
-}
-
-// API functions required by the syncapi
-type SyncKeyAPI interface {
- QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
- QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
- PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
-}
-
-type FederationKeyAPI interface {
- QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
- QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
- QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
- PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
- PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
-}
-
-// KeyError is returned if there was a problem performing/querying the server
-type KeyError struct {
- Err string `json:"error"`
- IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE
- IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM
- IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM
-}
-
-func (k *KeyError) Error() string {
- return k.Err
-}
-
-type DeviceMessageType int
-
-const (
- TypeDeviceKeyUpdate DeviceMessageType = iota
- TypeCrossSigningUpdate
-)
-
-// DeviceMessage represents the message produced into Kafka by the key server.
-type DeviceMessage struct {
- Type DeviceMessageType `json:"Type,omitempty"`
- *DeviceKeys `json:"DeviceKeys,omitempty"`
- *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
- // A monotonically increasing number which represents device changes for this user.
- StreamID int64
- DeviceChangeID int64
-}
-
-// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log
-type OutputCrossSigningKeyUpdate struct {
- CrossSigningKeyUpdate `json:"signing_keys"`
-}
-
-type CrossSigningKeyUpdate struct {
- MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"`
- SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"`
- UserID string `json:"user_id"`
-}
-
-// DeviceKeysEqual returns true if the device keys updates contain the
-// same display name and key JSON. This will return false if either of
-// the updates is not a device keys update, or if the user ID/device ID
-// differ between the two.
-func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool {
- if m1.DeviceKeys == nil || m2.DeviceKeys == nil {
- return false
- }
- if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID {
- return false
- }
- if m1.DisplayName != m2.DisplayName {
- return false // different display names
- }
- if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 {
- return false // either is empty
- }
- return bytes.Equal(m1.KeyJSON, m2.KeyJSON)
-}
-
-// DeviceKeys represents a set of device keys for a single device
-// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
-type DeviceKeys struct {
- // The user who owns this device
- UserID string
- // The device ID of this device
- DeviceID string
- // The device display name
- DisplayName string
- // The raw device key JSON
- KeyJSON []byte
-}
-
-// WithStreamID returns a copy of this device message with the given stream ID
-func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
- return DeviceMessage{
- DeviceKeys: k,
- StreamID: streamID,
- }
-}
-
-// OneTimeKeys represents a set of one-time keys for a single device
-// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
-type OneTimeKeys struct {
- // The user who owns this device
- UserID string
- // The device ID of this device
- DeviceID string
- // A map of algorithm:key_id => key JSON
- KeyJSON map[string]json.RawMessage
-}
-
-// Split a key in KeyJSON into algorithm and key ID
-func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
- segments := strings.Split(keyIDWithAlgo, ":")
- return segments[0], segments[1]
-}
-
-// OneTimeKeysCount represents the counts of one-time keys for a single device
-type OneTimeKeysCount struct {
- // The user who owns this device
- UserID string
- // The device ID of this device
- DeviceID string
- // algorithm to count e.g:
- // {
- // "curve25519": 10,
- // "signed_curve25519": 20
- // }
- KeyCount map[string]int
-}
-
-// PerformUploadKeysRequest is the request to PerformUploadKeys
-type PerformUploadKeysRequest struct {
- UserID string // Required - User performing the request
- DeviceID string // Optional - Device performing the request, for fetching OTK count
- DeviceKeys []DeviceKeys
- OneTimeKeys []OneTimeKeys
- // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
- // the display name for their respective device, and NOT to modify the keys. The key
- // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
- // Without this flag, requests to modify device display names would delete device keys.
- OnlyDisplayNameUpdates bool
-}
-
-// PerformUploadKeysResponse is the response to PerformUploadKeys
-type PerformUploadKeysResponse struct {
- // A fatal error when processing e.g database failures
- Error *KeyError
- // A map of user_id -> device_id -> Error for tracking failures.
- KeyErrors map[string]map[string]*KeyError
- OneTimeKeyCounts []OneTimeKeysCount
-}
-
-// PerformDeleteKeysRequest asks the keyserver to forget about certain
-// keys, and signatures related to those keys.
-type PerformDeleteKeysRequest struct {
- UserID string
- KeyIDs []gomatrixserverlib.KeyID
-}
-
-// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest.
-type PerformDeleteKeysResponse struct {
- Error *KeyError
-}
-
-// KeyError sets a key error field on KeyErrors
-func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) {
- if r.KeyErrors[userID] == nil {
- r.KeyErrors[userID] = make(map[string]*KeyError)
- }
- r.KeyErrors[userID][deviceID] = err
-}
-
-type PerformClaimKeysRequest struct {
- // Map of user_id to device_id to algorithm name
- OneTimeKeys map[string]map[string]string
- Timeout time.Duration
-}
-
-type PerformClaimKeysResponse struct {
- // Map of user_id to device_id to algorithm:key_id to key JSON
- OneTimeKeys map[string]map[string]map[string]json.RawMessage
- // Map of remote server domain to error JSON
- Failures map[string]interface{}
- // Set if there was a fatal error processing this action
- Error *KeyError
-}
-
-type PerformUploadDeviceKeysRequest struct {
- gomatrixserverlib.CrossSigningKeys
- // The user that uploaded the key, should be populated by the clientapi.
- UserID string
-}
-
-type PerformUploadDeviceKeysResponse struct {
- Error *KeyError
-}
-
-type PerformUploadDeviceSignaturesRequest struct {
- Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice
- // The user that uploaded the sig, should be populated by the clientapi.
- UserID string
-}
-
-type PerformUploadDeviceSignaturesResponse struct {
- Error *KeyError
-}
-
-type QueryKeysRequest struct {
- // The user ID asking for the keys, e.g. if from a client API request.
- // Will not be populated if the key request came from federation.
- UserID string
- // Maps user IDs to a list of devices
- UserToDevices map[string][]string
- Timeout time.Duration
-}
-
-type QueryKeysResponse struct {
- // Map of remote server domain to error JSON
- Failures map[string]interface{}
- // Map of user_id to device_id to device_key
- DeviceKeys map[string]map[string]json.RawMessage
- // Maps of user_id to cross signing key
- MasterKeys map[string]gomatrixserverlib.CrossSigningKey
- SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
- UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
- // Set if there was a fatal error processing this query
- Error *KeyError
-}
-
-type QueryKeyChangesRequest struct {
- // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
- Offset int64
- // The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
- // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing).
- ToOffset int64
-}
-
-type QueryKeyChangesResponse struct {
- // The set of users who have had their keys change.
- UserIDs []string
- // The latest offset represented in this response.
- Offset int64
- // Set if there was a problem handling the request.
- Error *KeyError
-}
-
-type QueryOneTimeKeysRequest struct {
- // The local user to query OTK counts for
- UserID string
- // The device to query OTK counts for
- DeviceID string
-}
-
-type QueryOneTimeKeysResponse struct {
- // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
- Count OneTimeKeysCount
- Error *KeyError
-}
-
-type QueryDeviceMessagesRequest struct {
- UserID string
-}
-
-type QueryDeviceMessagesResponse struct {
- // The latest stream ID
- StreamID int64
- Devices []DeviceMessage
- Error *KeyError
-}
-
-type QuerySignaturesRequest struct {
- // A map of target user ID -> target key/device IDs to retrieve signatures for
- TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"`
-}
-
-type QuerySignaturesResponse struct {
- // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures
- Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap
- // A map of target user ID -> cross-signing master key
- MasterKeys map[string]gomatrixserverlib.CrossSigningKey
- // A map of target user ID -> cross-signing self-signing key
- SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
- // A map of target user ID -> cross-signing user-signing key
- UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
- // The request error, if any
- Error *KeyError
-}
-
-type PerformMarkAsStaleRequest struct {
- UserID string
- Domain gomatrixserverlib.ServerName
- DeviceID string
-}
diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go
deleted file mode 100644
index 2d143682..00000000
--- a/keyserver/keyserver.go
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package keyserver
-
-import (
- "github.com/sirupsen/logrus"
-
- rsapi "github.com/matrix-org/dendrite/roomserver/api"
-
- fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/consumers"
- "github.com/matrix-org/dendrite/keyserver/internal"
- "github.com/matrix-org/dendrite/keyserver/producers"
- "github.com/matrix-org/dendrite/keyserver/storage"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/setup/jetstream"
-)
-
-// NewInternalAPI returns a concerete implementation of the internal API. Callers
-// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
-func NewInternalAPI(
- base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
- rsAPI rsapi.KeyserverRoomserverAPI,
-) api.KeyInternalAPI {
- js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
-
- db, err := storage.NewDatabase(base, &cfg.Database)
- if err != nil {
- logrus.WithError(err).Panicf("failed to connect to key server database")
- }
-
- keyChangeProducer := &producers.KeyChange{
- Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
- JetStream: js,
- DB: db,
- }
- ap := &internal.KeyInternalAPI{
- DB: db,
- Cfg: cfg,
- FedClient: fedClient,
- Producer: keyChangeProducer,
- }
- updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
- ap.Updater = updater
-
- // Remove users which we don't share a room with anymore
- if err := updater.CleanUp(); err != nil {
- logrus.WithError(err).Error("failed to cleanup stale device lists")
- }
-
- go func() {
- if err := updater.Start(); err != nil {
- logrus.WithError(err).Panicf("failed to start device list updater")
- }
- }()
-
- dlConsumer := consumers.NewDeviceListUpdateConsumer(
- base.ProcessContext, cfg, js, updater,
- )
- if err := dlConsumer.Start(); err != nil {
- logrus.WithError(err).Panic("failed to start device list consumer")
- }
-
- sigConsumer := consumers.NewSigningKeyUpdateConsumer(
- base.ProcessContext, cfg, js, ap,
- )
- if err := sigConsumer.Start(); err != nil {
- logrus.WithError(err).Panic("failed to start signing key consumer")
- }
-
- return ap
-}
diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go
deleted file mode 100644
index 159b280f..00000000
--- a/keyserver/keyserver_test.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package keyserver
-
-import (
- "context"
- "testing"
-
- roomserver "github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/test"
- "github.com/matrix-org/dendrite/test/testrig"
-)
-
-type mockKeyserverRoomserverAPI struct {
- leftUsers []string
-}
-
-func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
- res.LeftUsers = m.leftUsers
- return nil
-}
-
-// Merely tests that we can create an internal keyserver API
-func Test_NewInternalAPI(t *testing.T) {
- rsAPI := &mockKeyserverRoomserverAPI{}
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- base, closeBase := testrig.CreateBaseDendrite(t, dbType)
- defer closeBase()
- _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
- })
-}
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
deleted file mode 100644
index c6a8f44c..00000000
--- a/keyserver/storage/interface.go
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package storage
-
-import (
- "context"
- "encoding/json"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-type Database interface {
- // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
- // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
- ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
-
- // StoreOneTimeKeys persists the given one-time keys.
- StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
-
- // OneTimeKeysCount returns a count of all OTKs for this device.
- OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
-
- // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
- DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
-
- // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
- // for this (user, device).
- // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
- // Returns an error if there was a problem storing the keys.
- StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
-
- // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
- // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
- // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
- StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
-
- // PrevIDsExists returns true if all prev IDs exist for this user.
- PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
-
- // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
- // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
- DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
-
- // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
- // cross-signing signatures relating to that device.
- DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
-
- // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
- // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
- ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
-
- // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
- // `userID` is the the user who has changed their keys in some way.
- StoreKeyChange(ctx context.Context, userID string) (int64, error)
-
- // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
- // A to offset of types.OffsetNewest means no upper limit.
- // Returns the offset of the latest key change.
- KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
-
- // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
- // If no domains are given, all user IDs with stale device lists are returned.
- StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
-
- // MarkDeviceListStale sets the stale bit for this user to isStale.
- MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
-
- CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
- CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
- CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
-
- StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
- StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
-
- DeleteStaleDeviceLists(
- ctx context.Context,
- userIDs []string,
- ) error
-}
diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go
deleted file mode 100644
index 35e63055..00000000
--- a/keyserver/storage/postgres/storage.go
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/shared"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
-)
-
-// NewDatabase creates a new sync server database
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
- var err error
- db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
- if err != nil {
- return nil, err
- }
- otk, err := NewPostgresOneTimeKeysTable(db)
- if err != nil {
- return nil, err
- }
- dk, err := NewPostgresDeviceKeysTable(db)
- if err != nil {
- return nil, err
- }
- kc, err := NewPostgresKeyChangesTable(db)
- if err != nil {
- return nil, err
- }
- sdl, err := NewPostgresStaleDeviceListsTable(db)
- if err != nil {
- return nil, err
- }
- csk, err := NewPostgresCrossSigningKeysTable(db)
- if err != nil {
- return nil, err
- }
- css, err := NewPostgresCrossSigningSigsTable(db)
- if err != nil {
- return nil, err
- }
- if err = kc.Prepare(); err != nil {
- return nil, err
- }
- d := &shared.Database{
- DB: db,
- Writer: writer,
- OneTimeKeysTable: otk,
- DeviceKeysTable: dk,
- KeyChangesTable: kc,
- StaleDeviceListsTable: sdl,
- CrossSigningKeysTable: csk,
- CrossSigningSigsTable: css,
- }
- return d, nil
-}
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
deleted file mode 100644
index 54dd6ddc..00000000
--- a/keyserver/storage/shared/storage.go
+++ /dev/null
@@ -1,261 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package shared
-
-import (
- "context"
- "database/sql"
- "encoding/json"
- "fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-type Database struct {
- DB *sql.DB
- Writer sqlutil.Writer
- OneTimeKeysTable tables.OneTimeKeys
- DeviceKeysTable tables.DeviceKeys
- KeyChangesTable tables.KeyChanges
- StaleDeviceListsTable tables.StaleDeviceLists
- CrossSigningKeysTable tables.CrossSigningKeys
- CrossSigningSigsTable tables.CrossSigningSigs
-}
-
-func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
- return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
-}
-
-func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
- _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
- return err
- })
- return
-}
-
-func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
- return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
-}
-
-func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
- return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
-}
-
-func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
- count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
- if err != nil {
- return false, err
- }
- return count == len(prevIDs), nil
-}
-
-func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for _, userID := range clearUserIDs {
- err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
- if err != nil {
- return err
- }
- }
- return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
- })
-}
-
-func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
- // work out the latest stream IDs for each user
- userIDToStreamID := make(map[string]int64)
- for _, k := range keys {
- userIDToStreamID[k.UserID] = 0
- }
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for userID := range userIDToStreamID {
- streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
- if err != nil {
- return err
- }
- userIDToStreamID[userID] = streamID
- }
- // set the stream IDs for each key
- for i := range keys {
- k := keys[i]
- userIDToStreamID[k.UserID]++ // start stream from 1
- k.StreamID = userIDToStreamID[k.UserID]
- keys[i] = k
- }
- return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
- })
-}
-
-func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
- return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
-}
-
-func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
- var result []api.OneTimeKeys
- err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for userID, deviceToAlgo := range userToDeviceToAlgorithm {
- for deviceID, algo := range deviceToAlgo {
- keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
- if err != nil {
- return err
- }
- if keyJSON != nil {
- result = append(result, api.OneTimeKeys{
- UserID: userID,
- DeviceID: deviceID,
- KeyJSON: keyJSON,
- })
- }
- }
- }
- return nil
- })
- return result, err
-}
-
-func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
- err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
- id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
- return err
- })
- return
-}
-
-func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
- return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
-}
-
-// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
-// If no domains are given, all user IDs with stale device lists are returned.
-func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
- return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
-}
-
-// MarkDeviceListStale sets the stale bit for this user to isStale.
-func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
- return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
- return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
- })
-}
-
-// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
-// cross-signing signatures relating to that device.
-func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for _, deviceID := range deviceIDs {
- if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
- return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
- }
- if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
- return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
- }
- if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
- return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
- }
- }
- return nil
- })
-}
-
-// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
-func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
- keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
- if err != nil {
- return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
- }
- results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
- for purpose, key := range keyMap {
- keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
- result := gomatrixserverlib.CrossSigningKey{
- UserID: userID,
- Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
- Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
- keyID: key,
- },
- }
- sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
- if err != nil {
- continue
- }
- for sigUserID, forSigUserID := range sigMap {
- if userID != sigUserID {
- continue
- }
- if result.Signatures == nil {
- result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- if _, ok := result.Signatures[sigUserID]; !ok {
- result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- for sigKeyID, sigBytes := range forSigUserID {
- result.Signatures[sigUserID][sigKeyID] = sigBytes
- }
- }
- results[purpose] = result
- }
- return results, nil
-}
-
-// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
-func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
- return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
-}
-
-// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
-func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
- return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
-}
-
-// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
-func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for keyType, keyData := range keyMap {
- if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
- return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
- }
- }
- return nil
- })
-}
-
-// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
-func (d *Database) StoreCrossSigningSigsForTarget(
- ctx context.Context,
- originUserID string, originKeyID gomatrixserverlib.KeyID,
- targetUserID string, targetKeyID gomatrixserverlib.KeyID,
- signature gomatrixserverlib.Base64Bytes,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
- return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
- }
- return nil
- })
-}
-
-// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
-func (d *Database) DeleteStaleDeviceLists(
- ctx context.Context,
- userIDs []string,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
- })
-}
diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go
deleted file mode 100644
index 873fe3e2..00000000
--- a/keyserver/storage/sqlite3/storage.go
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/shared"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
-)
-
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
- db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
- if err != nil {
- return nil, err
- }
- otk, err := NewSqliteOneTimeKeysTable(db)
- if err != nil {
- return nil, err
- }
- dk, err := NewSqliteDeviceKeysTable(db)
- if err != nil {
- return nil, err
- }
- kc, err := NewSqliteKeyChangesTable(db)
- if err != nil {
- return nil, err
- }
- sdl, err := NewSqliteStaleDeviceListsTable(db)
- if err != nil {
- return nil, err
- }
- csk, err := NewSqliteCrossSigningKeysTable(db)
- if err != nil {
- return nil, err
- }
- css, err := NewSqliteCrossSigningSigsTable(db)
- if err != nil {
- return nil, err
- }
-
- if err = kc.Prepare(); err != nil {
- return nil, err
- }
- d := &shared.Database{
- DB: db,
- Writer: writer,
- OneTimeKeysTable: otk,
- DeviceKeysTable: dk,
- KeyChangesTable: kc,
- StaleDeviceListsTable: sdl,
- CrossSigningKeysTable: csk,
- CrossSigningSigsTable: css,
- }
- return d, nil
-}
diff --git a/keyserver/storage/storage.go b/keyserver/storage/storage.go
deleted file mode 100644
index ab6a3540..00000000
--- a/keyserver/storage/storage.go
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build !wasm
-// +build !wasm
-
-package storage
-
-import (
- "fmt"
-
- "github.com/matrix-org/dendrite/keyserver/storage/postgres"
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
-)
-
-// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
-// and sets postgres connection parameters
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
- switch {
- case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties)
- case dbProperties.ConnectionString.IsPostgres():
- return postgres.NewDatabase(base, dbProperties)
- default:
- return nil, fmt.Errorf("unexpected database type")
- }
-}
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
deleted file mode 100644
index e7a2af7c..00000000
--- a/keyserver/storage/storage_test.go
+++ /dev/null
@@ -1,197 +0,0 @@
-package storage_test
-
-import (
- "context"
- "reflect"
- "sync"
- "testing"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/dendrite/test"
- "github.com/matrix-org/dendrite/test/testrig"
-)
-
-var ctx = context.Background()
-
-func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
- base, close := testrig.CreateBaseDendrite(t, dbType)
- db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database)
- if err != nil {
- t.Fatalf("failed to create new database: %v", err)
- }
- return db, close
-}
-
-func MustNotError(t *testing.T, err error) {
- t.Helper()
- if err == nil {
- return
- }
- t.Fatalf("operation failed: %s", err)
-}
-
-func TestKeyChanges(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- _, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
- MustNotError(t, err)
- deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
- MustNotError(t, err)
- userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
- if err != nil {
- t.Fatalf("Failed to KeyChanges: %s", err)
- }
- if latest != deviceChangeIDC {
- t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
- }
- if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
- t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
- }
- })
-}
-
-func TestKeyChangesNoDupes(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- if deviceChangeIDA == deviceChangeIDB {
- t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
- }
- deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
- if err != nil {
- t.Fatalf("Failed to KeyChanges: %s", err)
- }
- if latest != deviceChangeID {
- t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
- }
- if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
- t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
- }
- })
-}
-
-func TestKeyChangesUpperLimit(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
- MustNotError(t, err)
- _, err = db.StoreKeyChange(ctx, "@charlie:localhost")
- MustNotError(t, err)
- userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
- if err != nil {
- t.Fatalf("Failed to KeyChanges: %s", err)
- }
- if latest != deviceChangeIDB {
- t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
- }
- if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
- t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
- }
- })
-}
-
-var dbLock sync.Mutex
-var deviceArray = []string{"AAA", "another_device"}
-
-// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
-// and that they are returned correctly when querying for device keys.
-func TestDeviceKeysStreamIDGeneration(t *testing.T) {
- var err error
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- alice := "@alice:TestDeviceKeysStreamIDGeneration"
- bob := "@bob:TestDeviceKeysStreamIDGeneration"
- msgs := []api.DeviceMessage{
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "AAA",
- UserID: alice,
- KeyJSON: []byte(`{"key":"v1"}`),
- },
- // StreamID: 1
- },
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "AAA",
- UserID: bob,
- KeyJSON: []byte(`{"key":"v1"}`),
- },
- // StreamID: 1 as this is a different user
- },
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "another_device",
- UserID: alice,
- KeyJSON: []byte(`{"key":"v1"}`),
- },
- // StreamID: 2 as this is a 2nd device key
- },
- }
- MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
- if msgs[0].StreamID != 1 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
- }
- if msgs[1].StreamID != 1 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
- }
- if msgs[2].StreamID != 2 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
- }
-
- // updating a device sets the next stream ID for that user
- msgs = []api.DeviceMessage{
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "AAA",
- UserID: alice,
- KeyJSON: []byte(`{"key":"v2"}`),
- },
- // StreamID: 3
- },
- }
- MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
- if msgs[0].StreamID != 3 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
- }
-
- dbLock.Lock()
- defer dbLock.Unlock()
- // Querying for device keys returns the latest stream IDs
- msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
-
- if err != nil {
- t.Fatalf("DeviceKeysForUser returned error: %s", err)
- }
- wantStreamIDs := map[string]int64{
- "AAA": 3,
- "another_device": 2,
- }
- if len(msgs) != len(wantStreamIDs) {
- t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
- }
- for _, m := range msgs {
- if m.StreamID != wantStreamIDs[m.DeviceID] {
- t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
- }
- }
- })
-}
diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go
deleted file mode 100644
index 75c9053e..00000000
--- a/keyserver/storage/storage_wasm.go
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package storage
-
-import (
- "fmt"
-
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
-)
-
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
- switch {
- case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties)
- case dbProperties.ConnectionString.IsPostgres():
- return nil, fmt.Errorf("can't use Postgres implementation")
- default:
- return nil, fmt.Errorf("unexpected database type")
- }
-}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
deleted file mode 100644
index 24da1125..00000000
--- a/keyserver/storage/tables/interface.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tables
-
-import (
- "context"
- "database/sql"
- "encoding/json"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-type OneTimeKeys interface {
- SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
- CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
- InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
- // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
- // Returns an empty map if the key does not exist.
- SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
- DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
-}
-
-type DeviceKeys interface {
- SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
- InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
- SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
- CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
- SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
- DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
- DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
-}
-
-type KeyChanges interface {
- InsertKeyChange(ctx context.Context, userID string) (int64, error)
- // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
- // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
- SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
-
- Prepare() error
-}
-
-type StaleDeviceLists interface {
- InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
- SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
- DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
-}
-
-type CrossSigningKeys interface {
- SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
- UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
-}
-
-type CrossSigningSigs interface {
- SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
- UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
- DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
-}
diff --git a/roomserver/api/api.go b/roomserver/api/api.go
index a8228ae8..73732ae3 100644
--- a/roomserver/api/api.go
+++ b/roomserver/api/api.go
@@ -17,7 +17,6 @@ type RoomserverInternalAPI interface {
ClientRoomserverAPI
UserRoomserverAPI
FederationRoomserverAPI
- KeyserverRoomserverAPI
// needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs
@@ -167,6 +166,7 @@ type ClientRoomserverAPI interface {
type UserRoomserverAPI interface {
QueryLatestEventsAndStateAPI
+ KeyserverRoomserverAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go
index 7ba01e50..304311c4 100644
--- a/roomserver/roomserver_test.go
+++ b/roomserver/roomserver_test.go
@@ -12,7 +12,6 @@ import (
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/federationapi"
- "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi"
"github.com/matrix-org/gomatrixserverlib"
@@ -47,7 +46,7 @@ func TestUsers(t *testing.T) {
})
t.Run("kick users", func(t *testing.T) {
- usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
+ usrAPI := userapi.NewInternalAPI(base, rsAPI, nil)
rsAPI.SetUserAPI(usrAPI)
testKickUsers(t, rsAPI, usrAPI)
})
@@ -228,11 +227,10 @@ func TestPurgeRoom(t *testing.T) {
fedClient := base.CreateFederationClient()
rsAPI := roomserver.NewInternalAPI(base)
- keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI)
- userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
+ userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
// this starts the JetStream consumers
- syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI)
+ syncapi.AddPublicRoutes(base, userAPI, rsAPI)
federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true)
rsAPI.SetFederationAPI(nil, nil)
diff --git a/setup/monolith.go b/setup/monolith.go
index 5bbe4019..d8c65223 100644
--- a/setup/monolith.go
+++ b/setup/monolith.go
@@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/federationapi"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/transactions"
- keyAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/relayapi"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
@@ -45,7 +44,6 @@ type Monolith struct {
FederationAPI federationAPI.FederationInternalAPI
RoomserverAPI roomserverAPI.RoomserverInternalAPI
UserAPI userapi.UserInternalAPI
- KeyAPI keyAPI.KeyInternalAPI
RelayAPI relayAPI.RelayInternalAPI
// Optional
@@ -61,19 +59,14 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) {
}
clientapi.AddPublicRoutes(
base, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(),
- m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI,
+ m.FederationAPI, m.UserAPI, userDirectoryProvider,
m.ExtPublicRoomsProvider,
)
federationapi.AddPublicRoutes(
- base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI,
- m.KeyAPI, nil,
- )
- mediaapi.AddPublicRoutes(
- base, m.UserAPI, m.Client,
- )
- syncapi.AddPublicRoutes(
- base, m.UserAPI, m.RoomserverAPI, m.KeyAPI,
+ base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, nil,
)
+ mediaapi.AddPublicRoutes(base, m.UserAPI, m.Client)
+ syncapi.AddPublicRoutes(base, m.UserAPI, m.RoomserverAPI)
if m.RelayAPI != nil {
relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI)
diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go
index 92f08150..5faaefb8 100644
--- a/syncapi/consumers/keychange.go
+++ b/syncapi/consumers/keychange.go
@@ -19,7 +19,6 @@ import (
"encoding/json"
"github.com/getsentry/sentry-go"
- "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
@@ -28,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/dendrite/userapi/api"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
)
diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go
index 356e8326..32208c58 100644
--- a/syncapi/consumers/sendtodevice.go
+++ b/syncapi/consumers/sendtodevice.go
@@ -25,7 +25,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
@@ -33,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/dendrite/userapi/api"
)
// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server.
@@ -42,7 +42,7 @@ type OutputSendToDeviceEventConsumer struct {
durable string
topic string
db storage.Database
- keyAPI keyapi.SyncKeyAPI
+ userAPI api.SyncKeyAPI
isLocalServerName func(gomatrixserverlib.ServerName) bool
stream streams.StreamProvider
notifier *notifier.Notifier
@@ -55,7 +55,7 @@ func NewOutputSendToDeviceEventConsumer(
cfg *config.SyncAPI,
js nats.JetStreamContext,
store storage.Database,
- keyAPI keyapi.SyncKeyAPI,
+ userAPI api.SyncKeyAPI,
notifier *notifier.Notifier,
stream streams.StreamProvider,
) *OutputSendToDeviceEventConsumer {
@@ -65,7 +65,7 @@ func NewOutputSendToDeviceEventConsumer(
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"),
db: store,
- keyAPI: keyAPI,
+ userAPI: userAPI,
isLocalServerName: cfg.Matrix.IsLocalServerName,
notifier: notifier,
stream: stream,
@@ -116,7 +116,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []
_, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender)
if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) {
// Mark the requesting device as stale, if we don't know about it.
- if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{
+ if err = s.userAPI.PerformMarkAsStaleIfNeeded(ctx, &api.PerformMarkAsStaleRequest{
UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID,
}, &struct{}{}); err != nil {
logger.WithError(err).Errorf("failed to mark as stale if needed")
diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go
index 4867f7d9..e7f677c8 100644
--- a/syncapi/internal/keychange.go
+++ b/syncapi/internal/keychange.go
@@ -18,22 +18,22 @@ import (
"context"
"strings"
+ keytypes "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
- keytypes "github.com/matrix-org/dendrite/keyserver/types"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/dendrite/userapi/api"
)
// DeviceOTKCounts adds one-time key counts to the /sync response
-func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
- var queryRes keyapi.QueryOneTimeKeysResponse
- _ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{
+func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
+ var queryRes api.QueryOneTimeKeysResponse
+ _ = keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{
UserID: userID,
DeviceID: deviceID,
}, &queryRes)
@@ -48,7 +48,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.
func DeviceListCatchup(
- ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
+ ctx context.Context, db storage.SharedUsers, userAPI api.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
userID string, res *types.Response, from, to types.StreamPosition,
) (newPos types.StreamPosition, hasNew bool, err error) {
@@ -74,8 +74,8 @@ func DeviceListCatchup(
if from > 0 {
offset = int64(from)
}
- var queryRes keyapi.QueryKeyChangesResponse
- _ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
+ var queryRes api.QueryKeyChangesResponse
+ _ = userAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{
Offset: offset,
ToOffset: toOffset,
}, &queryRes)
diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go
index d64bea11..4bb85166 100644
--- a/syncapi/internal/keychange_test.go
+++ b/syncapi/internal/keychange_test.go
@@ -9,7 +9,6 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
@@ -22,44 +21,44 @@ var (
type mockKeyAPI struct{}
-func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *keyapi.PerformMarkAsStaleRequest, res *struct{}) error {
+func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *userapi.PerformMarkAsStaleRequest, res *struct{}) error {
return nil
}
-func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error {
+func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *userapi.PerformUploadKeysRequest, res *userapi.PerformUploadKeysResponse) error {
return nil
}
func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {}
// PerformClaimKeys claims one-time keys for use in pre-key messages
-func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error {
+func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) error {
return nil
}
-func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error {
+func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *userapi.PerformDeleteKeysRequest, res *userapi.PerformDeleteKeysResponse) error {
return nil
}
-func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error {
+func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) error {
return nil
}
-func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error {
+func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) error {
return nil
}
-func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error {
+func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) error {
return nil
}
-func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
+func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error {
return nil
}
-func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
+func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
return nil
}
-func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error {
+func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
return nil
}
-func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error {
+func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) error {
return nil
}
diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go
index 7996c203..e8189c35 100644
--- a/syncapi/streams/stream_devicelist.go
+++ b/syncapi/streams/stream_devicelist.go
@@ -3,17 +3,17 @@ package streams
import (
"context"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
)
type DeviceListStreamProvider struct {
DefaultStreamProvider
- rsAPI api.SyncRoomserverAPI
- keyAPI keyapi.SyncKeyAPI
+ rsAPI api.SyncRoomserverAPI
+ userAPI userapi.SyncKeyAPI
}
func (p *DeviceListStreamProvider) CompleteSync(
@@ -31,12 +31,12 @@ func (p *DeviceListStreamProvider) IncrementalSync(
from, to types.StreamPosition,
) types.StreamPosition {
var err error
- to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
+ to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.userAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from
}
- err = internal.DeviceOTKCounts(req.Context, p.keyAPI, req.Device.UserID, req.Device.ID, req.Response)
+ err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from
diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go
index dc854762..a35491ac 100644
--- a/syncapi/streams/streams.go
+++ b/syncapi/streams/streams.go
@@ -5,7 +5,6 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
@@ -27,7 +26,7 @@ type Streams struct {
func NewSyncStreamProviders(
d storage.Database, userAPI userapi.SyncUserAPI,
- rsAPI rsapi.SyncRoomserverAPI, keyAPI keyapi.SyncKeyAPI,
+ rsAPI rsapi.SyncRoomserverAPI,
eduCache *caching.EDUCache, lazyLoadCache caching.LazyLoadCache, notifier *notifier.Notifier,
) *Streams {
streams := &Streams{
@@ -60,7 +59,7 @@ func NewSyncStreamProviders(
DeviceListStreamProvider: &DeviceListStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d},
rsAPI: rsAPI,
- keyAPI: keyAPI,
+ userAPI: userAPI,
},
PresenceStreamProvider: &PresenceStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d},
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index b086567b..68f91387 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -32,7 +32,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/sqlutil"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/internal"
@@ -48,7 +47,6 @@ type RequestPool struct {
db storage.Database
cfg *config.SyncAPI
userAPI userapi.SyncUserAPI
- keyAPI keyapi.SyncKeyAPI
rsAPI roomserverAPI.SyncRoomserverAPI
lastseen *sync.Map
presence *sync.Map
@@ -69,7 +67,7 @@ type PresenceConsumer interface {
// NewRequestPool makes a new RequestPool
func NewRequestPool(
db storage.Database, cfg *config.SyncAPI,
- userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI,
+ userAPI userapi.SyncUserAPI,
rsAPI roomserverAPI.SyncRoomserverAPI,
streams *streams.Streams, notifier *notifier.Notifier,
producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool,
@@ -83,7 +81,6 @@ func NewRequestPool(
db: db,
cfg: cfg,
userAPI: userAPI,
- keyAPI: keyAPI,
rsAPI: rsAPI,
lastseen: &sync.Map{},
presence: &sync.Map{},
@@ -280,7 +277,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
// https://github.com/matrix-org/synapse/blob/29f06704b8871a44926f7c99e73cf4a978fb8e81/synapse/rest/client/sync.py#L276-L281
// Only try to get OTKs if the context isn't already done.
if syncReq.Context.Err() == nil {
- err = internal.DeviceOTKCounts(syncReq.Context, rp.keyAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response)
+ err = internal.DeviceOTKCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response)
if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
}
@@ -551,7 +548,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition)
_, _, err = internal.DeviceListCatchup(
- req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
+ req.Context(), snapshot, rp.userAPI, rp.rsAPI, syncReq.Device.UserID,
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
)
if err != nil {
diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go
index be19310f..153f7af5 100644
--- a/syncapi/syncapi.go
+++ b/syncapi/syncapi.go
@@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
@@ -42,7 +41,6 @@ func AddPublicRoutes(
base *base.BaseDendrite,
userAPI userapi.SyncUserAPI,
rsAPI api.SyncRoomserverAPI,
- keyAPI keyapi.SyncKeyAPI,
) {
cfg := &base.Cfg.SyncAPI
@@ -55,7 +53,7 @@ func AddPublicRoutes(
eduCache := caching.NewTypingCache()
notifier := notifier.NewNotifier()
- streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, base.Caches, notifier)
+ streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, base.Caches, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil {
logrus.WithError(err).Panicf("failed to load notifier ")
@@ -71,7 +69,7 @@ func AddPublicRoutes(
userAPI,
)
- requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics)
+ requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics)
if err = presenceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start presence consumer")
@@ -117,7 +115,7 @@ func AddPublicRoutes(
}
sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer(
- base.ProcessContext, cfg, js, syncDB, keyAPI, notifier, streams.SendToDeviceStreamProvider,
+ base.ProcessContext, cfg, js, syncDB, userAPI, notifier, streams.SendToDeviceStreamProvider,
)
if err = sendToDeviceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start send-to-device consumer")
diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go
index 643c3026..e748660f 100644
--- a/syncapi/syncapi_test.go
+++ b/syncapi/syncapi_test.go
@@ -16,7 +16,6 @@ import (
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/clientapi/producers"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
@@ -83,18 +82,15 @@ func (s *syncUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAc
return nil
}
-func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
+func (s *syncUserAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error {
return nil
}
-type syncKeyAPI struct {
- keyapi.SyncKeyAPI
-}
-
-func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
+func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
return nil
}
-func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
+
+func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
return nil
}
@@ -121,7 +117,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
msgs := toNATSMsgs(t, base, room.Events()...)
- AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
+ AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}})
testrig.MustPublishMsgs(t, jsctx, msgs...)
testCases := []struct {
@@ -220,7 +216,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
// m.room.history_visibility
msgs := toNATSMsgs(t, base, room.Events()...)
sinceTokens := make([]string, len(msgs))
- AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
+ AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}})
for i, msg := range msgs {
testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(100 * time.Millisecond)
@@ -304,7 +300,7 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
- AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{})
+ AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{})
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
"access_token": alice.AccessToken,
@@ -422,7 +418,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
- AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{})
+ AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI)
for _, tc := range testCases {
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
@@ -722,7 +718,7 @@ func TestGetMembership(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
- AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{})
+ AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@@ -789,7 +785,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
- AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{})
+ AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{})
producer := producers.SyncAPIProducer{
TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
@@ -1008,7 +1004,7 @@ func testContext(t *testing.T, dbType test.DBType) {
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
- AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, &syncKeyAPI{})
+ AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI)
room := test.NewRoom(t, user)
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 4ea2e91c..fa297f77 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -15,9 +15,13 @@
package api
import (
+ "bytes"
"context"
"encoding/json"
+ "strings"
+ "time"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@@ -26,15 +30,12 @@ import (
// UserInternalAPI is the internal API for information about users and devices.
type UserInternalAPI interface {
- AppserviceUserAPI
SyncUserAPI
ClientUserAPI
- MediaUserAPI
FederationUserAPI
- RoomserverUserAPI
- KeyserverUserAPI
QuerySearchProfilesAPI // used by p2p demos
+ QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
}
// api functions required by the appservice api
@@ -43,11 +44,6 @@ type AppserviceUserAPI interface {
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
}
-type KeyserverUserAPI interface {
- QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
- QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
-}
-
type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
@@ -60,13 +56,20 @@ type MediaUserAPI interface {
// api functions required by the federation api
type FederationUserAPI interface {
+ UploadDeviceKeysAPI
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
+ QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
+ QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
+ QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
+ QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
+ PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// api functions required by the sync api
type SyncUserAPI interface {
QueryAcccessTokenAPI
+ SyncKeyAPI
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@@ -79,6 +82,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI
LoginTokenInternalAPI
UserLoginAPI
+ ClientKeyAPI
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
@@ -681,3 +685,310 @@ type QueryAccountByLocalpartRequest struct {
type QueryAccountByLocalpartResponse struct {
Account *Account
}
+
+// API functions required by the clientapi
+type ClientKeyAPI interface {
+ UploadDeviceKeysAPI
+ QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
+ PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
+
+ PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
+ // PerformClaimKeys claims one-time keys for use in pre-key messages
+ PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
+ PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
+}
+
+type UploadDeviceKeysAPI interface {
+ PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
+}
+
+// API functions required by the syncapi
+type SyncKeyAPI interface {
+ QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
+ QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
+ PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
+}
+
+type FederationKeyAPI interface {
+ UploadDeviceKeysAPI
+ QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
+ QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
+ QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
+ PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
+}
+
+// KeyError is returned if there was a problem performing/querying the server
+type KeyError struct {
+ Err string `json:"error"`
+ IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE
+ IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM
+ IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM
+}
+
+func (k *KeyError) Error() string {
+ return k.Err
+}
+
+type DeviceMessageType int
+
+const (
+ TypeDeviceKeyUpdate DeviceMessageType = iota
+ TypeCrossSigningUpdate
+)
+
+// DeviceMessage represents the message produced into Kafka by the key server.
+type DeviceMessage struct {
+ Type DeviceMessageType `json:"Type,omitempty"`
+ *DeviceKeys `json:"DeviceKeys,omitempty"`
+ *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
+ // A monotonically increasing number which represents device changes for this user.
+ StreamID int64
+ DeviceChangeID int64
+}
+
+// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log
+type OutputCrossSigningKeyUpdate struct {
+ CrossSigningKeyUpdate `json:"signing_keys"`
+}
+
+type CrossSigningKeyUpdate struct {
+ MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"`
+ SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"`
+ UserID string `json:"user_id"`
+}
+
+// DeviceKeysEqual returns true if the device keys updates contain the
+// same display name and key JSON. This will return false if either of
+// the updates is not a device keys update, or if the user ID/device ID
+// differ between the two.
+func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool {
+ if m1.DeviceKeys == nil || m2.DeviceKeys == nil {
+ return false
+ }
+ if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID {
+ return false
+ }
+ if m1.DisplayName != m2.DisplayName {
+ return false // different display names
+ }
+ if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 {
+ return false // either is empty
+ }
+ return bytes.Equal(m1.KeyJSON, m2.KeyJSON)
+}
+
+// DeviceKeys represents a set of device keys for a single device
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
+type DeviceKeys struct {
+ // The user who owns this device
+ UserID string
+ // The device ID of this device
+ DeviceID string
+ // The device display name
+ DisplayName string
+ // The raw device key JSON
+ KeyJSON []byte
+}
+
+// WithStreamID returns a copy of this device message with the given stream ID
+func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
+ return DeviceMessage{
+ DeviceKeys: k,
+ StreamID: streamID,
+ }
+}
+
+// OneTimeKeys represents a set of one-time keys for a single device
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
+type OneTimeKeys struct {
+ // The user who owns this device
+ UserID string
+ // The device ID of this device
+ DeviceID string
+ // A map of algorithm:key_id => key JSON
+ KeyJSON map[string]json.RawMessage
+}
+
+// Split a key in KeyJSON into algorithm and key ID
+func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
+ segments := strings.Split(keyIDWithAlgo, ":")
+ return segments[0], segments[1]
+}
+
+// OneTimeKeysCount represents the counts of one-time keys for a single device
+type OneTimeKeysCount struct {
+ // The user who owns this device
+ UserID string
+ // The device ID of this device
+ DeviceID string
+ // algorithm to count e.g:
+ // {
+ // "curve25519": 10,
+ // "signed_curve25519": 20
+ // }
+ KeyCount map[string]int
+}
+
+// PerformUploadKeysRequest is the request to PerformUploadKeys
+type PerformUploadKeysRequest struct {
+ UserID string // Required - User performing the request
+ DeviceID string // Optional - Device performing the request, for fetching OTK count
+ DeviceKeys []DeviceKeys
+ OneTimeKeys []OneTimeKeys
+ // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
+ // the display name for their respective device, and NOT to modify the keys. The key
+ // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
+ // Without this flag, requests to modify device display names would delete device keys.
+ OnlyDisplayNameUpdates bool
+}
+
+// PerformUploadKeysResponse is the response to PerformUploadKeys
+type PerformUploadKeysResponse struct {
+ // A fatal error when processing e.g database failures
+ Error *KeyError
+ // A map of user_id -> device_id -> Error for tracking failures.
+ KeyErrors map[string]map[string]*KeyError
+ OneTimeKeyCounts []OneTimeKeysCount
+}
+
+// PerformDeleteKeysRequest asks the keyserver to forget about certain
+// keys, and signatures related to those keys.
+type PerformDeleteKeysRequest struct {
+ UserID string
+ KeyIDs []gomatrixserverlib.KeyID
+}
+
+// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest.
+type PerformDeleteKeysResponse struct {
+ Error *KeyError
+}
+
+// KeyError sets a key error field on KeyErrors
+func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) {
+ if r.KeyErrors[userID] == nil {
+ r.KeyErrors[userID] = make(map[string]*KeyError)
+ }
+ r.KeyErrors[userID][deviceID] = err
+}
+
+type PerformClaimKeysRequest struct {
+ // Map of user_id to device_id to algorithm name
+ OneTimeKeys map[string]map[string]string
+ Timeout time.Duration
+}
+
+type PerformClaimKeysResponse struct {
+ // Map of user_id to device_id to algorithm:key_id to key JSON
+ OneTimeKeys map[string]map[string]map[string]json.RawMessage
+ // Map of remote server domain to error JSON
+ Failures map[string]interface{}
+ // Set if there was a fatal error processing this action
+ Error *KeyError
+}
+
+type PerformUploadDeviceKeysRequest struct {
+ gomatrixserverlib.CrossSigningKeys
+ // The user that uploaded the key, should be populated by the clientapi.
+ UserID string
+}
+
+type PerformUploadDeviceKeysResponse struct {
+ Error *KeyError
+}
+
+type PerformUploadDeviceSignaturesRequest struct {
+ Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice
+ // The user that uploaded the sig, should be populated by the clientapi.
+ UserID string
+}
+
+type PerformUploadDeviceSignaturesResponse struct {
+ Error *KeyError
+}
+
+type QueryKeysRequest struct {
+ // The user ID asking for the keys, e.g. if from a client API request.
+ // Will not be populated if the key request came from federation.
+ UserID string
+ // Maps user IDs to a list of devices
+ UserToDevices map[string][]string
+ Timeout time.Duration
+}
+
+type QueryKeysResponse struct {
+ // Map of remote server domain to error JSON
+ Failures map[string]interface{}
+ // Map of user_id to device_id to device_key
+ DeviceKeys map[string]map[string]json.RawMessage
+ // Maps of user_id to cross signing key
+ MasterKeys map[string]gomatrixserverlib.CrossSigningKey
+ SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
+ UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
+ // Set if there was a fatal error processing this query
+ Error *KeyError
+}
+
+type QueryKeyChangesRequest struct {
+ // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
+ Offset int64
+ // The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
+ // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing).
+ ToOffset int64
+}
+
+type QueryKeyChangesResponse struct {
+ // The set of users who have had their keys change.
+ UserIDs []string
+ // The latest offset represented in this response.
+ Offset int64
+ // Set if there was a problem handling the request.
+ Error *KeyError
+}
+
+type QueryOneTimeKeysRequest struct {
+ // The local user to query OTK counts for
+ UserID string
+ // The device to query OTK counts for
+ DeviceID string
+}
+
+type QueryOneTimeKeysResponse struct {
+ // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
+ Count OneTimeKeysCount
+ Error *KeyError
+}
+
+type QueryDeviceMessagesRequest struct {
+ UserID string
+}
+
+type QueryDeviceMessagesResponse struct {
+ // The latest stream ID
+ StreamID int64
+ Devices []DeviceMessage
+ Error *KeyError
+}
+
+type QuerySignaturesRequest struct {
+ // A map of target user ID -> target key/device IDs to retrieve signatures for
+ TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"`
+}
+
+type QuerySignaturesResponse struct {
+ // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures
+ Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap
+ // A map of target user ID -> cross-signing master key
+ MasterKeys map[string]gomatrixserverlib.CrossSigningKey
+ // A map of target user ID -> cross-signing self-signing key
+ SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
+ // A map of target user ID -> cross-signing user-signing key
+ UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
+ // The request error, if any
+ Error *KeyError
+}
+
+type PerformMarkAsStaleRequest struct {
+ UserID string
+ Domain gomatrixserverlib.ServerName
+ DeviceID string
+}
diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go
index 42ae72e7..51bd2753 100644
--- a/userapi/consumers/clientapi.go
+++ b/userapi/consumers/clientapi.go
@@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct {
jetstream nats.JetStreamContext
durable string
topic string
- db storage.Database
+ db storage.UserDatabase
serverName gomatrixserverlib.ServerName
syncProducer *producers.SyncAPI
pgClient pushgateway.Client
@@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
- store storage.Database,
+ store storage.UserDatabase,
syncProducer *producers.SyncAPI,
pgClient pushgateway.Client,
) *OutputReceiptEventConsumer {
diff --git a/keyserver/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go
index cd911f8c..a65889fc 100644
--- a/keyserver/consumers/devicelistupdate.go
+++ b/userapi/consumers/devicelistupdate.go
@@ -18,11 +18,11 @@ import (
"context"
"encoding/json"
+ "github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
- "github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
@@ -41,7 +41,7 @@ type DeviceListUpdateConsumer struct {
// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
func NewDeviceListUpdateConsumer(
process *process.ProcessContext,
- cfg *config.KeyServer,
+ cfg *config.UserAPI,
js nats.JetStreamContext,
updater *internal.DeviceListUpdater,
) *DeviceListUpdateConsumer {
diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go
index 3ce5af62..47d33095 100644
--- a/userapi/consumers/roomserver.go
+++ b/userapi/consumers/roomserver.go
@@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct {
rsAPI rsapi.UserRoomserverAPI
jetstream nats.JetStreamContext
durable string
- db storage.Database
+ db storage.UserDatabase
topic string
pgClient pushgateway.Client
syncProducer *producers.SyncAPI
@@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
- store storage.Database,
+ store storage.UserDatabase,
pgClient pushgateway.Client,
rsAPI rsapi.UserRoomserverAPI,
syncProducer *producers.SyncAPI,
diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go
index 39f4aab4..bc5ae652 100644
--- a/userapi/consumers/roomserver_test.go
+++ b/userapi/consumers/roomserver_test.go
@@ -18,11 +18,11 @@ import (
userAPITypes "github.com/matrix-org/dendrite/userapi/types"
)
-func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "", 4, 0, 0, "")
if err != nil {
diff --git a/keyserver/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go
index bcceaad1..f4ff017d 100644
--- a/keyserver/consumers/signingkeyupdate.go
+++ b/userapi/consumers/signingkeyupdate.go
@@ -22,11 +22,10 @@ import (
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/userapi/api"
)
// SigningKeyUpdateConsumer consumes signing key updates that came in over federation.
@@ -35,24 +34,24 @@ type SigningKeyUpdateConsumer struct {
jetstream nats.JetStreamContext
durable string
topic string
- keyAPI *internal.KeyInternalAPI
- cfg *config.KeyServer
+ userAPI api.UploadDeviceKeysAPI
+ cfg *config.UserAPI
isLocalServerName func(gomatrixserverlib.ServerName) bool
}
// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
func NewSigningKeyUpdateConsumer(
process *process.ProcessContext,
- cfg *config.KeyServer,
+ cfg *config.UserAPI,
js nats.JetStreamContext,
- keyAPI *internal.KeyInternalAPI,
+ userAPI api.UploadDeviceKeysAPI,
) *SigningKeyUpdateConsumer {
return &SigningKeyUpdateConsumer{
ctx: process.Context(),
jetstream: js,
durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
- keyAPI: keyAPI,
+ userAPI: userAPI,
cfg: cfg,
isLocalServerName: cfg.Matrix.IsLocalServerName,
}
@@ -70,7 +69,7 @@ func (t *SigningKeyUpdateConsumer) Start() error {
// signing key update events topic from the key server.
func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
- var updatePayload keyapi.CrossSigningKeyUpdate
+ var updatePayload api.CrossSigningKeyUpdate
if err := json.Unmarshal(msg.Data, &updatePayload); err != nil {
logrus.WithError(err).Errorf("Failed to read from signing key update input topic")
return true
@@ -94,12 +93,12 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
if updatePayload.SelfSigningKey != nil {
keys.SelfSigningKey = *updatePayload.SelfSigningKey
}
- uploadReq := &keyapi.PerformUploadDeviceKeysRequest{
+ uploadReq := &api.PerformUploadDeviceKeysRequest{
CrossSigningKeys: keys,
UserID: updatePayload.UserID,
}
- uploadRes := &keyapi.PerformUploadDeviceKeysResponse{}
- if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
+ uploadRes := &api.PerformUploadDeviceKeysResponse{}
+ if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
logrus.WithError(err).Error("failed to upload device keys")
return false
}
diff --git a/keyserver/internal/cross_signing.go b/userapi/internal/cross_signing.go
index 99859dff..8b9704d1 100644
--- a/keyserver/internal/cross_signing.go
+++ b/userapi/internal/cross_signing.go
@@ -22,8 +22,8 @@ import (
"fmt"
"strings"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/types"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/curve25519"
@@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos
}
// nolint:gocyclo
-func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
+func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
// Find the keys to store.
byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
toStore := types.CrossSigningKeyMap{}
@@ -169,7 +169,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
// something if any of the specified keys in the request are different
// to what we've got in the database, to avoid generating key change
// notifications unnecessarily.
- existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
+ existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID)
if err != nil {
res.Error = &api.KeyError{
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
@@ -216,7 +216,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
}
// Store the keys.
- if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
+ if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
}
@@ -234,7 +234,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
continue
}
for sigKeyID, sigBytes := range forSigUserID {
- if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
+ if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
}
@@ -257,7 +257,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
if update.MasterKey == nil && update.SelfSigningKey == nil {
return nil
}
- if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
+ if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
@@ -266,7 +266,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
return nil
}
-func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
+func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
// Before we do anything, we need the master and self-signing keys for this user.
// Then we can verify the signatures make sense.
queryReq := &api.QueryKeysRequest{
@@ -342,7 +342,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
MasterKey: &masterKey,
SelfSigningKey: &selfSigningKey,
}
- if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
+ if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
@@ -352,7 +352,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
return nil
}
-func (a *KeyInternalAPI) processSelfSignatures(
+func (a *UserInternalAPI) processSelfSignatures(
ctx context.Context,
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
) error {
@@ -373,7 +373,7 @@ func (a *KeyInternalAPI) processSelfSignatures(
}
for originUserID, forOriginUserID := range sig.Signatures {
for originKeyID, originSig := range forOriginUserID {
- if err := a.DB.StoreCrossSigningSigsForTarget(
+ if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
); err != nil {
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
@@ -384,7 +384,7 @@ func (a *KeyInternalAPI) processSelfSignatures(
case *gomatrixserverlib.DeviceKeys:
for originUserID, forOriginUserID := range sig.Signatures {
for originKeyID, originSig := range forOriginUserID {
- if err := a.DB.StoreCrossSigningSigsForTarget(
+ if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
); err != nil {
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
@@ -401,7 +401,7 @@ func (a *KeyInternalAPI) processSelfSignatures(
return nil
}
-func (a *KeyInternalAPI) processOtherSignatures(
+func (a *UserInternalAPI) processOtherSignatures(
ctx context.Context, userID string, queryRes *api.QueryKeysResponse,
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
) error {
@@ -442,7 +442,7 @@ func (a *KeyInternalAPI) processOtherSignatures(
}
for originKeyID, originSig := range userSigs {
- if err := a.DB.StoreCrossSigningSigsForTarget(
+ if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
ctx, userID, originKeyID, targetUserID, targetKeyID, originSig,
); err != nil {
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
@@ -461,11 +461,11 @@ func (a *KeyInternalAPI) processOtherSignatures(
return nil
}
-func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
+func (a *UserInternalAPI) crossSigningKeysFromDatabase(
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
) {
for targetUserID := range req.UserToDevices {
- keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
+ keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil {
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
continue
@@ -478,7 +478,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
break
}
- sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
continue
@@ -522,9 +522,9 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
}
}
-func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
+func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
for targetUserID, forTargetUser := range req.TargetIDs {
- keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
+ keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil && err != sql.ErrNoRows {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err),
@@ -556,7 +556,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
for _, targetKeyID := range forTargetUser {
// Get own signatures only.
- sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
if err != nil && err != sql.ErrNoRows {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
diff --git a/keyserver/internal/device_list_update.go b/userapi/internal/device_list_update.go
index 1b00d1ee..3b4dcf98 100644
--- a/keyserver/internal/device_list_update.go
+++ b/userapi/internal/device_list_update.go
@@ -33,8 +33,8 @@ import (
"github.com/sirupsen/logrus"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
- "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/userapi/api"
)
var (
diff --git a/keyserver/internal/device_list_update_default.go b/userapi/internal/device_list_update_default.go
index 7d357c95..7d357c95 100644
--- a/keyserver/internal/device_list_update_default.go
+++ b/userapi/internal/device_list_update_default.go
diff --git a/keyserver/internal/device_list_update_sytest.go b/userapi/internal/device_list_update_sytest.go
index 1c60d2eb..1c60d2eb 100644
--- a/keyserver/internal/device_list_update_sytest.go
+++ b/userapi/internal/device_list_update_sytest.go
diff --git a/keyserver/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go
index 60a2c2f3..868fc9be 100644
--- a/keyserver/internal/device_list_update_test.go
+++ b/userapi/internal/device_list_update_test.go
@@ -29,13 +29,13 @@ import (
"github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage"
)
var (
@@ -360,12 +360,12 @@ func TestDebounce(t *testing.T) {
}
}
-func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
t.Helper()
base, _, _ := testrig.Base(nil)
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
+ db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
if err != nil {
t.Fatal(err)
}
diff --git a/keyserver/internal/internal.go b/userapi/internal/key_api.go
index 9a08a0bb..be816fe5 100644
--- a/keyserver/internal/internal.go
+++ b/userapi/internal/key_api.go
@@ -29,29 +29,11 @@ import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
- fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/producers"
- "github.com/matrix-org/dendrite/keyserver/storage"
- "github.com/matrix-org/dendrite/setup/config"
- userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/api"
)
-type KeyInternalAPI struct {
- DB storage.Database
- Cfg *config.KeyServer
- FedClient fedsenderapi.KeyserverFederationAPI
- UserAPI userapi.KeyserverUserAPI
- Producer *producers.KeyChange
- Updater *DeviceListUpdater
-}
-
-func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
- a.UserAPI = i
-}
-
-func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
- userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
+func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
+ userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset)
if err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
@@ -63,7 +45,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
return nil
}
-func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
+func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
res.KeyErrors = make(map[string]map[string]*api.KeyError)
if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res)
@@ -71,7 +53,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
}
- otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+ otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
return err
}
@@ -79,7 +61,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
return nil
}
-func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
+func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
// wrap request map in a top-level by-domain map
@@ -97,11 +79,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
domainToDeviceKeys[string(serverName)] = nested
}
for domain, local := range domainToDeviceKeys {
- if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
+ if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
// claim local keys
- keys, err := a.DB.ClaimKeys(ctx, local)
+ keys, err := a.KeyDatabase.ClaimKeys(ctx, local)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
@@ -129,7 +111,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
return nil
}
-func (a *KeyInternalAPI) claimRemoteKeys(
+func (a *UserInternalAPI) claimRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
) {
var wg sync.WaitGroup // Wait for fan-out goroutines to finish
@@ -146,7 +128,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel()
defer wg.Done()
- claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
+ claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
mu.Lock()
defer mu.Unlock()
@@ -177,8 +159,8 @@ func (a *KeyInternalAPI) claimRemoteKeys(
}).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
}
-func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
- if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
+func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
+ if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
}
@@ -186,8 +168,8 @@ func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.Perform
return nil
}
-func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
- count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
+ count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
@@ -198,8 +180,8 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
return nil
}
-func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
- msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
+func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
+ msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
@@ -225,8 +207,8 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present
// in our database.
-func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
- knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
+func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
+ knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
if err != nil {
return err
}
@@ -244,7 +226,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
}
// nolint:gocyclo
-func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
+func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
var respMu sync.Mutex
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@@ -262,8 +244,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
domain := string(serverName)
// query local devices
- if a.Cfg.Matrix.IsLocalServerName(serverName) {
- deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
+ if a.Config.Matrix.IsLocalServerName(serverName) {
+ deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
@@ -276,8 +258,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
for _, dk := range deviceKeys {
dids = append(dids, dk.DeviceID)
}
- var queryRes userapi.QueryDeviceInfosResponse
- err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{
+ var queryRes api.QueryDeviceInfosResponse
+ err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{
DeviceIDs: dids,
}, &queryRes)
if err != nil {
@@ -341,14 +323,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
for targetKeyID := range masterKey.Keys {
- sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
if err != nil {
// Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
- logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
+ logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
continue
}
if len(sigMap) == 0 {
@@ -367,14 +349,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
for targetUserID, forUserID := range res.DeviceKeys {
for targetKeyID, key := range forUserID {
- sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
if err != nil {
// Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
- logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
+ logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
continue
}
if len(sigMap) == 0 {
@@ -403,7 +385,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
return nil
}
-func (a *KeyInternalAPI) remoteKeysFromDatabase(
+func (a *UserInternalAPI) remoteKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string)
@@ -429,7 +411,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
return fetchRemote
}
-func (a *KeyInternalAPI) queryRemoteKeys(
+func (a *UserInternalAPI) queryRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
) {
@@ -441,13 +423,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
domains := map[string]struct{}{}
for domain := range domainToDeviceKeys {
- if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
+ if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
domains[domain] = struct{}{}
}
for domain := range domainToCrossSigningKeys {
- if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
+ if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
domains[domain] = struct{}{}
@@ -499,7 +481,7 @@ func (a *KeyInternalAPI) queryRemoteKeys(
}
}
-func (a *KeyInternalAPI) queryRemoteKeysOnServer(
+func (a *UserInternalAPI) queryRemoteKeysOnServer(
ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
res *api.QueryKeysResponse,
@@ -559,7 +541,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
if len(devKeys) == 0 {
return
}
- queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
+ queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
if err == nil {
resultCh <- &queryKeysResp
return
@@ -586,10 +568,10 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
respMu.Unlock()
}
-func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
+func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
) error {
- keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
+ keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
if err != nil {
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
@@ -621,11 +603,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
return nil
}
-func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
// get a list of devices from the user API that actually exist, as
// we won't store keys for devices that don't exist
- uapidevices := &userapi.QueryDevicesResponse{}
- if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
+ uapidevices := &api.QueryDevicesResponse{}
+ if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
@@ -643,7 +625,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}
// Get all of the user existing device keys so we can check for changes.
- existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
+ existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
@@ -662,7 +644,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}
if len(toClean) > 0 {
- if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
+ if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
} else {
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
@@ -693,7 +675,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
if err != nil {
continue // ignore invalid users
}
- if !a.Cfg.Matrix.IsLocalServerName(serverName) {
+ if !a.Config.Matrix.IsLocalServerName(serverName) {
continue // ignore remote users
}
if len(key.KeyJSON) == 0 {
@@ -722,30 +704,30 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}
// store the device keys and emit changes
- err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
+ err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
}
return
}
- err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
+ err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
if err != nil {
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
}
}
-func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
}
}
if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
- counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+ counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err),
+ Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err),
}
}
if counts != nil {
@@ -761,7 +743,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++
}
- existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
+ existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: "failed to query existing one-time keys: " + err.Error(),
@@ -778,7 +760,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
}
}
// store one-time keys
- counts, err := a.DB.StoreOneTimeKeys(ctx, key)
+ counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
diff --git a/keyserver/internal/internal_test.go b/userapi/internal/key_api_test.go
index 8a2c9c5d..fc7e7e0d 100644
--- a/keyserver/internal/internal_test.go
+++ b/userapi/internal/key_api_test.go
@@ -5,23 +5,28 @@ import (
"reflect"
"testing"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/internal"
- "github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/internal"
+ "github.com/matrix-org/dendrite/userapi/storage"
)
-func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewDatabase(nil, &config.DatabaseOptions{
+ base, _, _ := testrig.Base(nil)
+ db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to create new user db: %v", err)
}
- return db, close
+ return db, func() {
+ base.Close()
+ close()
+ }
}
func Test_QueryDeviceMessages(t *testing.T) {
@@ -140,8 +145,8 @@ func Test_QueryDeviceMessages(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- a := &internal.KeyInternalAPI{
- DB: db,
+ a := &internal.UserInternalAPI{
+ KeyDatabase: db,
}
if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)
diff --git a/userapi/internal/api.go b/userapi/internal/user_api.go
index 0bb480da..1cbd9719 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/user_api.go
@@ -23,6 +23,7 @@ import (
"strconv"
"time"
+ fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@@ -32,7 +33,6 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
synctypes "github.com/matrix-org/dendrite/syncapi/types"
@@ -44,17 +44,19 @@ import (
)
type UserInternalAPI struct {
- DB storage.Database
- SyncProducer *producers.SyncAPI
- Config *config.UserAPI
+ DB storage.UserDatabase
+ KeyDatabase storage.KeyDatabase
+ SyncProducer *producers.SyncAPI
+ KeyChangeProducer *producers.KeyChange
+ Config *config.UserAPI
DisableTLSValidation bool
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
- KeyAPI keyapi.UserKeyAPI
RSAPI rsapi.UserRoomserverAPI
PgClient pushgateway.Client
- Cfg *config.UserAPI
+ FedClient fedsenderapi.KeyserverFederationAPI
+ Updater *DeviceListUpdater
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
@@ -221,7 +223,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
}
- postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
+ postRegisterJoinRooms(a.Config, acc, a.RSAPI)
res.AccountCreated = true
res.Account = acc
@@ -293,14 +295,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
return err
}
// Ask the keyserver to delete device keys and signatures for those devices
- deleteReq := &keyapi.PerformDeleteKeysRequest{
+ deleteReq := &api.PerformDeleteKeysRequest{
UserID: req.UserID,
}
for _, keyID := range req.DeviceIDs {
deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID))
}
- deleteRes := &keyapi.PerformDeleteKeysResponse{}
- if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
+ deleteRes := &api.PerformDeleteKeysResponse{}
+ if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
return err
}
if err := deleteRes.Error; err != nil {
@@ -311,17 +313,17 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
}
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
- deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
+ deviceKeys := make([]api.DeviceKeys, len(deviceIDs))
for i, did := range deviceIDs {
- deviceKeys[i] = keyapi.DeviceKeys{
+ deviceKeys[i] = api.DeviceKeys{
UserID: userID,
DeviceID: did,
KeyJSON: nil,
}
}
- var uploadRes keyapi.PerformUploadKeysResponse
- if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
+ var uploadRes api.PerformUploadKeysResponse
+ if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
UserID: userID,
DeviceKeys: deviceKeys,
}, &uploadRes); err != nil {
@@ -385,10 +387,10 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
}
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
// display name has changed: update the device key
- var uploadRes keyapi.PerformUploadKeysResponse
- if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
+ var uploadRes api.PerformUploadKeysResponse
+ if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
UserID: req.RequestingUserID,
- DeviceKeys: []keyapi.DeviceKeys{
+ DeviceKeys: []api.DeviceKeys{
{
DeviceID: dev.ID,
DisplayName: *req.DisplayName,
diff --git a/keyserver/producers/keychange.go b/userapi/producers/keychange.go
index f86c3417..da6cea31 100644
--- a/keyserver/producers/keychange.go
+++ b/userapi/producers/keychange.go
@@ -18,9 +18,9 @@ import (
"context"
"encoding/json"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
)
@@ -28,8 +28,8 @@ import (
// KeyChange produces key change events for the sync API and federation sender to consume
type KeyChange struct {
Topic string
- JetStream nats.JetStreamContext
- DB storage.Database
+ JetStream JetStreamPublisher
+ DB storage.KeyChangeDatabase
}
// ProduceKeyChanges creates new change events for each key
diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go
index 51eaa985..165de899 100644
--- a/userapi/producers/syncapi.go
+++ b/userapi/producers/syncapi.go
@@ -19,13 +19,13 @@ type JetStreamPublisher interface {
// SyncAPI produces messages for the Sync API server to consume.
type SyncAPI struct {
- db storage.Database
+ db storage.Notification
producer JetStreamPublisher
clientDataTopic string
notificationDataTopic string
}
-func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
+func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
return &SyncAPI{
db: db,
producer: js,
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index c22b7658..27837886 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -90,7 +90,7 @@ type KeyBackup interface {
type LoginToken interface {
// CreateLoginToken generates a token, stores and returns it. The lifetime is
- // determined by the loginTokenLifetime given to the Database constructor.
+ // determined by the loginTokenLifetime given to the UserDatabase constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
@@ -130,7 +130,7 @@ type Notification interface {
DeleteOldNotifications(ctx context.Context) error
}
-type Database interface {
+type UserDatabase interface {
Account
AccountData
Device
@@ -144,6 +144,78 @@ type Database interface {
ThreePID
}
+type KeyChangeDatabase interface {
+ // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
+ // `userID` is the the user who has changed their keys in some way.
+ StoreKeyChange(ctx context.Context, userID string) (int64, error)
+}
+
+type KeyDatabase interface {
+ KeyChangeDatabase
+ // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
+ // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
+ ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
+
+ // StoreOneTimeKeys persists the given one-time keys.
+ StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
+
+ // OneTimeKeysCount returns a count of all OTKs for this device.
+ OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
+
+ // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
+ DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+
+ // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+ // for this (user, device).
+ // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
+ // Returns an error if there was a problem storing the keys.
+ StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
+
+ // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+ // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
+ // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
+ StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
+
+ // PrevIDsExists returns true if all prev IDs exist for this user.
+ PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
+
+ // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
+ // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
+ DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
+
+ // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
+ // cross-signing signatures relating to that device.
+ DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
+
+ // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
+ // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
+ ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
+
+ // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
+ // A to offset of types.OffsetNewest means no upper limit.
+ // Returns the offset of the latest key change.
+ KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
+
+ // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+ // If no domains are given, all user IDs with stale device lists are returned.
+ StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+
+ // MarkDeviceListStale sets the stale bit for this user to isStale.
+ MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
+
+ CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
+ CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
+ CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
+
+ StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
+ StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
+
+ DeleteStaleDeviceLists(
+ ctx context.Context,
+ userIDs []string,
+ ) error
+}
+
type Statistics interface {
UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error)
DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error)
diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go
index 2a4777d7..05716037 100644
--- a/userapi/storage/postgres/account_data_table.go
+++ b/userapi/storage/postgres/account_data_table.go
@@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData(
roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
- _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
+ // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage
+ // when passing the data to trigger "NOT NULL" constraint
+ var data *json.RawMessage
+ if len(content) > 0 {
+ data = &content
+ }
+ _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data)
return
}
diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go
index 1022157e..c0ecbd30 100644
--- a/keyserver/storage/postgres/cross_signing_keys_table.go
+++ b/userapi/storage/postgres/cross_signing_keys_table.go
@@ -21,8 +21,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go
index 4536b7d8..b0117145 100644
--- a/keyserver/storage/postgres/cross_signing_sigs_table.go
+++ b/userapi/storage/postgres/cross_signing_sigs_table.go
@@ -21,9 +21,9 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go
index 0cfe9e79..0cfe9e79 100644
--- a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go
+++ b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go
diff --git a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go
index 1a3d4fee..1a3d4fee 100644
--- a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go
+++ b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go
diff --git a/keyserver/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go
index 2aa11c52..a9203857 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/userapi/storage/postgres/device_keys_table.go
@@ -23,8 +23,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
var deviceKeysSchema = `
@@ -92,31 +92,16 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if err != nil {
return nil, err
}
- if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
- return nil, err
- }
- if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
- return nil, err
- }
- if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
- return nil, err
- }
- if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
- return nil, err
- }
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
+ {&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
+ {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
+ {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL},
+ {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
+ {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
+ }.Prepare(db)
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index 7481ac5b..88f8839c 100644
--- a/userapi/storage/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice(
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
}
- return &api.Device{
+ dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
- }, nil
+ }
+ if displayName != nil {
+ dev.DisplayName = *displayName
+ }
+ return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go
index 7b58f7ba..91a34c35 100644
--- a/userapi/storage/postgres/key_backup_table.go
+++ b/userapi/storage/postgres/key_backup_table.go
@@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
-const selectKeysSQL = "" +
+const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
- {&s.selectKeysStmt, selectKeysSQL},
+ {&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)
diff --git a/keyserver/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go
index c0e3429c..a0049414 100644
--- a/keyserver/storage/postgres/key_changes_table.go
+++ b/userapi/storage/postgres/key_changes_table.go
@@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
var keyChangesSchema = `
@@ -66,7 +66,10 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
if err = executeMigration(context.Background(), db); err != nil {
return nil, err
}
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
+ {&s.selectKeyChangesStmt, selectKeyChangesSQL},
+ }.Prepare(db)
}
func executeMigration(ctx context.Context, db *sql.DB) error {
@@ -95,16 +98,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error {
return m.Up(ctx)
}
-func (s *keyChangesStatements) Prepare() (err error) {
- if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
- return err
- }
- if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
- return err
- }
- return nil
-}
-
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
return
diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go
index 2117efca..972a5914 100644
--- a/keyserver/storage/postgres/one_time_keys_table.go
+++ b/userapi/storage/postgres/one_time_keys_table.go
@@ -23,8 +23,8 @@ import (
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimeKeysSchema = `
@@ -49,7 +49,7 @@ const upsertKeysSQL = "" +
" ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" +
" DO UPDATE SET key_json = $6"
-const selectKeysSQL = "" +
+const selectOneTimeKeysSQL = "" +
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
const selectKeysCountSQL = "" +
@@ -84,25 +84,14 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if err != nil {
return nil, err
}
- if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
- return nil, err
- }
- if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
- return nil, err
- }
- if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
- return nil, err
- }
- if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
- return nil, err
- }
- if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
- return nil, err
- }
- if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil {
- return nil, err
- }
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertKeysStmt, upsertKeysSQL},
+ {&s.selectKeysStmt, selectOneTimeKeysSQL},
+ {&s.selectKeysCountStmt, selectKeysCountSQL},
+ {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
+ {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
+ {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
+ }.Prepare(db)
}
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
diff --git a/keyserver/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go
index 248ddfb4..c823b58c 100644
--- a/keyserver/storage/postgres/stale_device_lists.go
+++ b/userapi/storage/postgres/stale_device_lists.go
@@ -24,7 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index 92dc4808..673d123b 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
+
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
+ db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
+ if err != nil {
+ return nil, err
+ }
+ otk, err := NewPostgresOneTimeKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ dk, err := NewPostgresDeviceKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ kc, err := NewPostgresKeyChangesTable(db)
+ if err != nil {
+ return nil, err
+ }
+ sdl, err := NewPostgresStaleDeviceListsTable(db)
+ if err != nil {
+ return nil, err
+ }
+ csk, err := NewPostgresCrossSigningKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ css, err := NewPostgresCrossSigningSigsTable(db)
+ if err != nil {
+ return nil, err
+ }
+
+ return &shared.KeyDatabase{
+ OneTimeKeysTable: otk,
+ DeviceKeysTable: dk,
+ KeyChangesTable: kc,
+ StaleDeviceListsTable: sdl,
+ CrossSigningKeysTable: csk,
+ CrossSigningSigsTable: css,
+ Writer: writer,
+ }, nil
+}
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index bf94f14d..d3272a03 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -59,6 +59,17 @@ type Database struct {
OpenIDTokenLifetimeMS int64
}
+type KeyDatabase struct {
+ OneTimeKeysTable tables.OneTimeKeys
+ DeviceKeysTable tables.DeviceKeys
+ KeyChangesTable tables.KeyChanges
+ StaleDeviceListsTable tables.StaleDeviceLists
+ CrossSigningKeysTable tables.CrossSigningKeys
+ CrossSigningSigsTable tables.CrossSigningSigs
+ DB *sql.DB
+ Writer sqlutil.Writer
+}
+
const (
// The length of generated device IDs
deviceIDByteLength = 6
@@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages(
) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) {
return d.Stats.DailyRoomsMessages(ctx, nil, serverName)
}
+
+//
+
+func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
+ return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
+}
+
+func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
+ return err
+ })
+ return
+}
+
+func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
+}
+
+func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
+ return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
+}
+
+func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
+ count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
+ if err != nil {
+ return false, err
+ }
+ return count == len(prevIDs), nil
+}
+
+func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for _, userID := range clearUserIDs {
+ err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
+ if err != nil {
+ return err
+ }
+ }
+ return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
+ })
+}
+
+func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+ // work out the latest stream IDs for each user
+ userIDToStreamID := make(map[string]int64)
+ for _, k := range keys {
+ userIDToStreamID[k.UserID] = 0
+ }
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for userID := range userIDToStreamID {
+ streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
+ if err != nil {
+ return err
+ }
+ userIDToStreamID[userID] = streamID
+ }
+ // set the stream IDs for each key
+ for i := range keys {
+ k := keys[i]
+ userIDToStreamID[k.UserID]++ // start stream from 1
+ k.StreamID = userIDToStreamID[k.UserID]
+ keys[i] = k
+ }
+ return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
+ })
+}
+
+func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
+}
+
+func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
+ var result []api.OneTimeKeys
+ err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for userID, deviceToAlgo := range userToDeviceToAlgorithm {
+ for deviceID, algo := range deviceToAlgo {
+ keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
+ if err != nil {
+ return err
+ }
+ if keyJSON != nil {
+ result = append(result, api.OneTimeKeys{
+ UserID: userID,
+ DeviceID: deviceID,
+ KeyJSON: keyJSON,
+ })
+ }
+ }
+ }
+ return nil
+ })
+ return result, err
+}
+
+func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
+ err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
+ return err
+ })
+ return
+}
+
+func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
+ return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
+}
+
+// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+// If no domains are given, all user IDs with stale device lists are returned.
+func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
+}
+
+// MarkDeviceListStale sets the stale bit for this user to isStale.
+func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
+ return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
+ })
+}
+
+// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
+// cross-signing signatures relating to that device.
+func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for _, deviceID := range deviceIDs {
+ if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
+ }
+ if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
+ }
+ if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
+ }
+ }
+ return nil
+ })
+}
+
+// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
+func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
+ keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
+ if err != nil {
+ return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
+ }
+ results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
+ for purpose, key := range keyMap {
+ keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
+ result := gomatrixserverlib.CrossSigningKey{
+ UserID: userID,
+ Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
+ Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
+ keyID: key,
+ },
+ }
+ sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
+ if err != nil {
+ continue
+ }
+ for sigUserID, forSigUserID := range sigMap {
+ if userID != sigUserID {
+ continue
+ }
+ if result.Signatures == nil {
+ result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ if _, ok := result.Signatures[sigUserID]; !ok {
+ result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ for sigKeyID, sigBytes := range forSigUserID {
+ result.Signatures[sigUserID][sigKeyID] = sigBytes
+ }
+ }
+ results[purpose] = result
+ }
+ return results, nil
+}
+
+// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
+func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
+ return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
+}
+
+// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
+func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
+ return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
+}
+
+// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
+func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for keyType, keyData := range keyMap {
+ if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
+ return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
+ }
+ }
+ return nil
+ })
+}
+
+// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
+func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
+ ctx context.Context,
+ originUserID string, originKeyID gomatrixserverlib.KeyID,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+ signature gomatrixserverlib.Base64Bytes,
+) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
+ return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
+ }
+ return nil
+ })
+}
+
+// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
+func (d *KeyDatabase) DeleteStaleDeviceLists(
+ ctx context.Context,
+ userIDs []string,
+) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
+ })
+}
diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go
index e103d988..10721fcc 100644
--- a/keyserver/storage/sqlite3/cross_signing_keys_table.go
+++ b/userapi/storage/sqlite3/cross_signing_keys_table.go
@@ -21,8 +21,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go
index 7a153e8f..2be00c9c 100644
--- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go
+++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go
@@ -21,9 +21,9 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go
index cd0f19df..cd0f19df 100644
--- a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go
+++ b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go
diff --git a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
index d4e38dea..d4e38dea 100644
--- a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
+++ b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/userapi/storage/sqlite3/device_keys_table.go
index 73768da5..15e69cc4 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/userapi/storage/sqlite3/device_keys_table.go
@@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
var deviceKeysSchema = `
@@ -86,28 +86,16 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if err != nil {
return nil, err
}
- if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
- return nil, err
- }
- if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
- return nil, err
- }
- if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
- return nil, err
- }
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
+ {&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
+ {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
+ // {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime
+ {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
+ {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
+ }.Prepare(db)
}
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index 449e4549..65e17527 100644
--- a/userapi/storage/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice(
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
- return &api.Device{
+ dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
- }, nil
+ }
+ if displayName != nil {
+ dev.DisplayName = *displayName
+ }
+ return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
@@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
- return &api.Device{
+ dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
- }, nil
+ }
+ if displayName != nil {
+ dev.DisplayName = *displayName
+ }
+ return dev, nil
}
func (s *devicesStatements) DeleteDevice(
@@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices(
if err != nil {
return err
}
+ defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+2)
params[0] = localpart
diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go
index 7883ffb1..ed274631 100644
--- a/userapi/storage/sqlite3/key_backup_table.go
+++ b/userapi/storage/sqlite3/key_backup_table.go
@@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
-const selectKeysSQL = "" +
+const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
- {&s.selectKeysStmt, selectKeysSQL},
+ {&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)
diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go
index 0c844d67..923bb57e 100644
--- a/keyserver/storage/sqlite3/key_changes_table.go
+++ b/userapi/storage/sqlite3/key_changes_table.go
@@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
var keyChangesSchema = `
@@ -65,7 +65,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
return nil, err
}
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
+ {&s.selectKeyChangesStmt, selectKeyChangesSQL},
+ }.Prepare(db)
}
func executeMigration(ctx context.Context, db *sql.DB) error {
@@ -93,16 +96,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error {
return m.Up(ctx)
}
-func (s *keyChangesStatements) Prepare() (err error) {
- if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
- return err
- }
- if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
- return err
- }
- return nil
-}
-
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
return
diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go
index 7a923d0e..a992d399 100644
--- a/keyserver/storage/sqlite3/one_time_keys_table.go
+++ b/userapi/storage/sqlite3/one_time_keys_table.go
@@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimeKeysSchema = `
@@ -48,7 +48,7 @@ const upsertKeysSQL = "" +
" ON CONFLICT (user_id, device_id, key_id, algorithm)" +
" DO UPDATE SET key_json = $6"
-const selectKeysSQL = "" +
+const selectOneTimeKeysSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
const selectKeysCountSQL = "" +
@@ -83,25 +83,14 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if err != nil {
return nil, err
}
- if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
- return nil, err
- }
- if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
- return nil, err
- }
- if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
- return nil, err
- }
- if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
- return nil, err
- }
- if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
- return nil, err
- }
- if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil {
- return nil, err
- }
- return s, nil
+ return s, sqlutil.StatementList{
+ {&s.upsertKeysStmt, upsertKeysSQL},
+ {&s.selectKeysStmt, selectOneTimeKeysSQL},
+ {&s.selectKeysCountStmt, selectKeysCountSQL},
+ {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
+ {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
+ {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
+ }.Prepare(db)
}
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go
index fd76a6e3..f078fc99 100644
--- a/keyserver/storage/sqlite3/stale_device_lists.go
+++ b/userapi/storage/sqlite3/stale_device_lists.go
@@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go
index a1365c94..72b3ba49 100644
--- a/userapi/storage/sqlite3/stats_table.go
+++ b/userapi/storage/sqlite3/stats_table.go
@@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int
if err != nil {
return 0, err
}
+ defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3, 4,
@@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res
if err != nil {
return 0, err
}
+ defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3,
@@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx)
if err != nil {
return nil, err
}
+ defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
registeredAfter := time.Now().AddDate(0, 0, -30)
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index 85a1f706..0f3eeed1 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -30,8 +30,8 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
-// NewDatabase creates a new accounts and profiles database
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
+// NewUserDatabase creates a new accounts and profiles database
+func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
if err != nil {
return nil, err
@@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
+
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
+ db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
+ if err != nil {
+ return nil, err
+ }
+ otk, err := NewSqliteOneTimeKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ dk, err := NewSqliteDeviceKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ kc, err := NewSqliteKeyChangesTable(db)
+ if err != nil {
+ return nil, err
+ }
+ sdl, err := NewSqliteStaleDeviceListsTable(db)
+ if err != nil {
+ return nil, err
+ }
+ csk, err := NewSqliteCrossSigningKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ css, err := NewSqliteCrossSigningSigsTable(db)
+ if err != nil {
+ return nil, err
+ }
+
+ return &shared.KeyDatabase{
+ OneTimeKeysTable: otk,
+ DeviceKeysTable: dk,
+ KeyChangesTable: kc,
+ StaleDeviceListsTable: sdl,
+ CrossSigningKeysTable: csk,
+ CrossSigningSigsTable: css,
+ Writer: writer,
+ }, nil
+}
diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go
index 42221e75..0329fb46 100644
--- a/userapi/storage/storage.go
+++ b/userapi/storage/storage.go
@@ -29,15 +29,36 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
)
-// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
+// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
-func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
+func NewUserDatabase(
+ base *base.BaseDendrite,
+ dbProperties *config.DatabaseOptions,
+ serverName gomatrixserverlib.ServerName,
+ bcryptCost int,
+ openIDTokenLifetimeMS int64,
+ loginTokenLifetime time.Duration,
+ serverNoticesLocalpart string,
+) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
+ return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
default:
return nil, fmt.Errorf("unexpected database type")
}
}
+
+// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme)
+// and sets postgres connection parameters.
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) {
+ switch {
+ case dbProperties.ConnectionString.IsSQLite():
+ return sqlite3.NewKeyDatabase(base, dbProperties)
+ case dbProperties.ConnectionString.IsPostgres():
+ return postgres.NewKeyDatabase(base, dbProperties)
+ default:
+ return nil, fmt.Errorf("unexpected database type")
+ }
+}
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index 23aafff0..f52e7e17 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -4,9 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
+ "reflect"
+ "sync"
"testing"
"time"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
@@ -29,14 +32,14 @@ var (
ctx = context.Background()
)
-func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil {
- t.Fatalf("NewUserAPIDatabase returned %s", err)
+ t.Fatalf("NewUserDatabase returned %s", err)
}
return db, func() {
close()
@@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
// Tests storing and getting account data
func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
@@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) {
// Tests the creation of accounts
func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
@@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) {
accessToken := util.RandomString(16)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
@@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) {
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
wantAuthData := json.RawMessage("my auth data")
@@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) {
func Test_LoginToken(t *testing.T) {
alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create a new token
@@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) {
token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
@@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create account, which also creates a profile
@@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
appID := util.RandomString(8)
@@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
threePID := util.RandomString(8)
medium := util.RandomString(8)
@@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) {
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
// generate some dummy notifications
for i := 0; i < 10; i++ {
@@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) {
assert.Equal(t, int64(0), total)
})
}
+
+func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
+ if err != nil {
+ t.Fatalf("failed to create new database: %v", err)
+ }
+ return db, close
+}
+
+func MustNotError(t *testing.T, err error) {
+ t.Helper()
+ if err == nil {
+ return
+ }
+ t.Fatalf("operation failed: %s", err)
+}
+
+func TestKeyChanges(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ _, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
+ MustNotError(t, err)
+ deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
+ MustNotError(t, err)
+ userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != deviceChangeIDC {
+ t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+ })
+}
+
+func TestKeyChangesNoDupes(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ if deviceChangeIDA == deviceChangeIDB {
+ t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
+ }
+ deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != deviceChangeID {
+ t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+ })
+}
+
+func TestKeyChangesUpperLimit(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
+ MustNotError(t, err)
+ _, err = db.StoreKeyChange(ctx, "@charlie:localhost")
+ MustNotError(t, err)
+ userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != deviceChangeIDB {
+ t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+ })
+}
+
+var dbLock sync.Mutex
+var deviceArray = []string{"AAA", "another_device"}
+
+// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
+// and that they are returned correctly when querying for device keys.
+func TestDeviceKeysStreamIDGeneration(t *testing.T) {
+ var err error
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ alice := "@alice:TestDeviceKeysStreamIDGeneration"
+ bob := "@bob:TestDeviceKeysStreamIDGeneration"
+ msgs := []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: bob,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1 as this is a different user
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "another_device",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 2 as this is a 2nd device key
+ },
+ }
+ MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 1 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
+ }
+ if msgs[1].StreamID != 1 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
+ }
+ if msgs[2].StreamID != 2 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
+ }
+
+ // updating a device sets the next stream ID for that user
+ msgs = []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v2"}`),
+ },
+ // StreamID: 3
+ },
+ }
+ MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 3 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
+ }
+
+ dbLock.Lock()
+ defer dbLock.Unlock()
+ // Querying for device keys returns the latest stream IDs
+ msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
+
+ if err != nil {
+ t.Fatalf("DeviceKeysForUser returned error: %s", err)
+ }
+ wantStreamIDs := map[string]int64{
+ "AAA": 3,
+ "another_device": 2,
+ }
+ if len(msgs) != len(wantStreamIDs) {
+ t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
+ }
+ for _, m := range msgs {
+ if m.StreamID != wantStreamIDs[m.DeviceID] {
+ t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
+ }
+ }
+ })
+}
diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go
index 5d5d292e..163e3e17 100644
--- a/userapi/storage/storage_wasm.go
+++ b/userapi/storage/storage_wasm.go
@@ -32,10 +32,10 @@ func NewUserAPIDatabase(
openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration,
serverNoticesLocalpart string,
-) (Database, error) {
+) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
+ return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 9221e571..693e7303 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -20,10 +20,10 @@ import (
"encoding/json"
"time"
+ "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
- "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/types"
)
@@ -145,3 +145,47 @@ const (
// uint32.
AllNotifications NotificationFilter = (1 << 31) - 1
)
+
+type OneTimeKeys interface {
+ SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
+ CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
+ InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
+ // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
+ // Returns an empty map if the key does not exist.
+ SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
+ DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
+}
+
+type DeviceKeys interface {
+ SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+ InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
+ SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
+ CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
+ SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
+ DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
+ DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
+}
+
+type KeyChanges interface {
+ InsertKeyChange(ctx context.Context, userID string) (int64, error)
+ // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
+ // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
+ SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
+}
+
+type StaleDeviceLists interface {
+ InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
+ SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+ DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
+}
+
+type CrossSigningKeys interface {
+ SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
+ UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
+}
+
+type CrossSigningSigs interface {
+ SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
+ UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
+ DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
+}
diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go
index 76d3badd..b9bdafda 100644
--- a/keyserver/storage/tables/stale_device_lists_test.go
+++ b/userapi/storage/tables/stale_device_lists_test.go
@@ -4,15 +4,15 @@ import (
"context"
"testing"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/keyserver/storage/postgres"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
diff --git a/keyserver/types/storage.go b/userapi/types/storage.go
index 7fb90454..7fb90454 100644
--- a/keyserver/types/storage.go
+++ b/userapi/types/storage.go
diff --git a/userapi/userapi.go b/userapi/userapi.go
index 2dd81d75..826bd721 100644
--- a/userapi/userapi.go
+++ b/userapi/userapi.go
@@ -17,13 +17,11 @@ package userapi
import (
"time"
+ fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/sirupsen/logrus"
- "github.com/matrix-org/dendrite/internal/pushgateway"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/consumers"
@@ -33,16 +31,20 @@ import (
"github.com/matrix-org/dendrite/userapi/util"
)
-// NewInternalAPI returns a concerete implementation of the internal API. Callers
+// NewInternalAPI returns a concrete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
- base *base.BaseDendrite, cfg *config.UserAPI,
- appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI,
- rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client,
-) api.UserInternalAPI {
+ base *base.BaseDendrite,
+ rsAPI rsapi.UserRoomserverAPI,
+ fedClient fedsenderapi.KeyserverFederationAPI,
+) *internal.UserInternalAPI {
+ cfg := &base.Cfg.UserAPI
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
+ appServices := base.Cfg.Derived.ApplicationServices
- db, err := storage.NewUserAPIDatabase(
+ pgClient := base.PushGatewayHTTPClient()
+
+ db, err := storage.NewUserDatabase(
base,
&cfg.AccountDatabase,
cfg.Matrix.ServerName,
@@ -55,6 +57,11 @@ func NewInternalAPI(
logrus.WithError(err).Panicf("failed to connect to accounts db")
}
+ keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
+ if err != nil {
+ logrus.WithError(err).Panicf("failed to connect to key db")
+ }
+
syncProducer := producers.NewSyncAPI(
db, js,
// TODO: user API should handle syncs for account data. Right now,
@@ -64,17 +71,50 @@ func NewInternalAPI(
cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData),
cfg.Matrix.JetStream.Prefixed(jetstream.OutputNotificationData),
)
+ keyChangeProducer := &producers.KeyChange{
+ Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
+ JetStream: js,
+ DB: keyDB,
+ }
userAPI := &internal.UserInternalAPI{
DB: db,
+ KeyDatabase: keyDB,
SyncProducer: syncProducer,
+ KeyChangeProducer: keyChangeProducer,
Config: cfg,
AppServices: appServices,
- KeyAPI: keyAPI,
RSAPI: rsAPI,
DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
PgClient: pgClient,
- Cfg: cfg,
+ FedClient: fedClient,
+ }
+
+ updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
+ userAPI.Updater = updater
+ // Remove users which we don't share a room with anymore
+ if err := updater.CleanUp(); err != nil {
+ logrus.WithError(err).Error("failed to cleanup stale device lists")
+ }
+
+ go func() {
+ if err := updater.Start(); err != nil {
+ logrus.WithError(err).Panicf("failed to start device list updater")
+ }
+ }()
+
+ dlConsumer := consumers.NewDeviceListUpdateConsumer(
+ base.ProcessContext, cfg, js, updater,
+ )
+ if err := dlConsumer.Start(); err != nil {
+ logrus.WithError(err).Panic("failed to start device list consumer")
+ }
+
+ sigConsumer := consumers.NewSigningKeyUpdateConsumer(
+ base.ProcessContext, cfg, js, userAPI,
+ )
+ if err := sigConsumer.Start(); err != nil {
+ logrus.WithError(err).Panic("failed to start signing key consumer")
}
receiptConsumer := consumers.NewOutputReceiptEventConsumer(
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 68d08c2f..08b1336b 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -21,7 +21,10 @@ import (
"testing"
"time"
+ "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/nats-io/nats.go"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/setup/config"
@@ -38,32 +41,55 @@ const (
type apiTestOpts struct {
loginTokenLifetime time.Duration
+ serverName string
}
-func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) {
+type dummyProducer struct{}
+
+func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) {
+ return &nats.PubAck{}, nil
+}
+
+func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) {
if opts.loginTokenLifetime == 0 {
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
}
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
- accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ sName := serverName
+ if opts.serverName != "" {
+ sName = gomatrixserverlib.ServerName(opts.serverName)
+ }
+ accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
- }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
+ }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
if err != nil {
t.Fatalf("failed to create account DB: %s", err)
}
+ keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ })
+ if err != nil {
+ t.Fatalf("failed to create key DB: %s", err)
+ }
+
cfg := &config.UserAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
- ServerName: serverName,
+ ServerName: sName,
},
},
}
+ syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "")
+ keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}}
return &internal.UserInternalAPI{
- DB: accountDB,
- Config: cfg,
+ DB: accountDB,
+ KeyDatabase: keyDB,
+ Config: cfg,
+ SyncProducer: syncProducer,
+ KeyChangeProducer: keyChangeProducer,
}, accountDB, func() {
close()
baseclose()
@@ -332,3 +358,292 @@ func TestQueryAccountByLocalpart(t *testing.T) {
testCases(t, intAPI)
})
}
+
+func TestAccountData(t *testing.T) {
+ ctx := context.Background()
+ alice := test.NewUser(t)
+
+ testCases := []struct {
+ name string
+ inputData *api.InputAccountDataRequest
+ wantErr bool
+ }{
+ {
+ name: "not a local user",
+ inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"},
+ wantErr: true,
+ },
+ {
+ name: "local user missing datatype",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID},
+ wantErr: true,
+ },
+ {
+ name: "missing json",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil},
+ wantErr: true,
+ },
+ {
+ name: "with json",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")},
+ },
+ {
+ name: "room data",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"},
+ },
+ {
+ name: "ignored users",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")},
+ },
+ {
+ name: "m.fully_read",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")},
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ defer close()
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ res := api.InputAccountDataResponse{}
+ err := intAPI.InputAccountData(ctx, tc.inputData, &res)
+ if tc.wantErr && err == nil {
+ t.Fatalf("expected an error, but got none")
+ }
+ if !tc.wantErr && err != nil {
+ t.Fatalf("expected no error, but got: %s", err)
+ }
+
+ // query the data again and compare
+ queryRes := api.QueryAccountDataResponse{}
+ queryReq := api.QueryAccountDataRequest{
+ UserID: tc.inputData.UserID,
+ DataType: tc.inputData.DataType,
+ RoomID: tc.inputData.RoomID,
+ }
+ err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes)
+ if err != nil && !tc.wantErr {
+ t.Fatal(err)
+ }
+ // verify global data
+ if tc.inputData.RoomID == "" {
+ if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) {
+ t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType]))
+ }
+ } else {
+ // verify room data
+ if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) {
+ t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]))
+ }
+ }
+ })
+ }
+ })
+}
+
+func TestDevices(t *testing.T) {
+ ctx := context.Background()
+
+ dupeAccessToken := util.RandomString(8)
+
+ displayName := "testing"
+
+ creationTests := []struct {
+ name string
+ inputData *api.PerformDeviceCreationRequest
+ wantErr bool
+ wantNewDevID bool
+ }{
+ {
+ name: "not a local user",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"},
+ wantErr: true,
+ },
+ {
+ name: "implicit local user",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName},
+ },
+ {
+ name: "explicit local user",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
+ },
+ {
+ name: "dupe token - ok",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
+ },
+ {
+ name: "dupe token - not ok",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
+ wantErr: true,
+ },
+ {
+ name: "test3 second device", // used to test deletion later
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
+ },
+ {
+ name: "test3 third device", // used to test deletion later
+ wantNewDevID: true,
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
+ },
+ }
+
+ deletionTests := []struct {
+ name string
+ inputData *api.PerformDeviceDeletionRequest
+ wantErr bool
+ wantDevices int
+ }{
+ {
+ name: "deletion - not a local user",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"},
+ wantErr: true,
+ },
+ {
+ name: "deleting not existing devices should not error",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}},
+ wantDevices: 1,
+ },
+ {
+ name: "delete all devices",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test"},
+ wantDevices: 0,
+ },
+ {
+ name: "delete all devices",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test3:test"},
+ wantDevices: 0,
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ defer close()
+
+ for _, tc := range creationTests {
+ t.Run(tc.name, func(t *testing.T) {
+ res := api.PerformDeviceCreationResponse{}
+ deviceID := util.RandomString(8)
+ tc.inputData.DeviceID = &deviceID
+ if tc.wantNewDevID {
+ tc.inputData.DeviceID = nil
+ }
+ err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res)
+ if tc.wantErr && err == nil {
+ t.Fatalf("expected an error, but got none")
+ }
+ if !tc.wantErr && err != nil {
+ t.Fatalf("expected no error, but got: %s", err)
+ }
+ if !res.DeviceCreated {
+ return
+ }
+
+ queryDevicesRes := api.QueryDevicesResponse{}
+ queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID}
+ if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil {
+ t.Fatal(err)
+ }
+ // We only want to verify one device
+ if len(queryDevicesRes.Devices) > 1 {
+ return
+ }
+ res.Device.AccessToken = ""
+
+ // At this point, there should only be one device
+ if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) {
+ t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0])
+ }
+
+ newDisplayName := "new name"
+ if tc.inputData.DeviceDisplayName == nil {
+ updateRes := api.PerformDeviceUpdateResponse{}
+ updateReq := api.PerformDeviceUpdateRequest{
+ RequestingUserID: fmt.Sprintf("@%s:%s", tc.inputData.Localpart, "test"),
+ DeviceID: deviceID,
+ DisplayName: &newDisplayName,
+ }
+
+ if err = intAPI.PerformDeviceUpdate(ctx, &updateReq, &updateRes); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ queryDeviceInfosRes := api.QueryDeviceInfosResponse{}
+ queryDeviceInfosReq := api.QueryDeviceInfosRequest{DeviceIDs: []string{*tc.inputData.DeviceID}}
+ if err = intAPI.QueryDeviceInfos(ctx, &queryDeviceInfosReq, &queryDeviceInfosRes); err != nil {
+ t.Fatal(err)
+ }
+ gotDisplayName := queryDeviceInfosRes.DeviceInfo[*tc.inputData.DeviceID].DisplayName
+ if tc.inputData.DeviceDisplayName != nil {
+ wantDisplayName := *tc.inputData.DeviceDisplayName
+ if wantDisplayName != gotDisplayName {
+ t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName)
+ }
+ } else {
+ wantDisplayName := newDisplayName
+ if wantDisplayName != gotDisplayName {
+ t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName)
+ }
+ }
+ })
+ }
+
+ for _, tc := range deletionTests {
+ t.Run(tc.name, func(t *testing.T) {
+ delRes := api.PerformDeviceDeletionResponse{}
+ err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes)
+ if tc.wantErr && err == nil {
+ t.Fatalf("expected an error, but got none")
+ }
+ if !tc.wantErr && err != nil {
+ t.Fatalf("expected no error, but got: %s", err)
+ }
+ if tc.wantErr {
+ return
+ }
+
+ queryDevicesRes := api.QueryDevicesResponse{}
+ queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID}
+ if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil {
+ t.Fatal(err)
+ }
+
+ if len(queryDevicesRes.Devices) != tc.wantDevices {
+ t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices))
+ }
+
+ })
+ }
+ })
+}
+
+// Tests that the session ID of a device is not reused when reusing the same device ID.
+func TestDeviceIDReuse(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ defer close()
+
+ res := api.PerformDeviceCreationResponse{}
+ // create a first device
+ deviceID := util.RandomString(8)
+ req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true}
+ err := intAPI.PerformDeviceCreation(ctx, &req, &res)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Do the same request again, we expect a different sessionID
+ res2 := api.PerformDeviceCreationResponse{}
+ err = intAPI.PerformDeviceCreation(ctx, &req, &res2)
+ if err != nil {
+ t.Fatalf("expected no error, but got: %v", err)
+ }
+
+ if res2.Device.SessionID == res.Device.SessionID {
+ t.Fatalf("expected a different session ID, but they are the same")
+ }
+ })
+}
diff --git a/userapi/util/devices.go b/userapi/util/devices.go
index c55fc799..31617d8c 100644
--- a/userapi/util/devices.go
+++ b/userapi/util/devices.go
@@ -19,7 +19,7 @@ type PusherDevice struct {
}
// GetPushDevices pushes to the configured devices of a local user.
-func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
+func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart, serverName)
if err != nil {
return nil, fmt.Errorf("db.GetPushers: %w", err)
diff --git a/userapi/util/notify.go b/userapi/util/notify.go
index fc0ab39b..08d1371d 100644
--- a/userapi/util/notify.go
+++ b/userapi/util/notify.go
@@ -13,11 +13,11 @@ import (
)
// NotifyUserCountsAsync sends notifications to a local user's
-// notification destinations. Database lookups run synchronously, but
+// notification destinations. UserDatabase lookups run synchronously, but
// a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has
// finished.
-func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
+func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error {
pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
if err != nil {
return err
diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go
index f1d20259..421852d3 100644
--- a/userapi/util/notify_test.go
+++ b/userapi/util/notify_test.go
@@ -79,7 +79,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
defer close()
base, _, _ := testrig.Base(nil)
defer base.Close()
- db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "test", bcrypt.MinCost, 0, 0, "")
if err != nil {
diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go
index 6e62210e..5f626b5b 100644
--- a/userapi/util/phonehomestats_test.go
+++ b/userapi/util/phonehomestats_test.go
@@ -21,7 +21,7 @@ func TestCollect(t *testing.T) {
b, _, _ := testrig.Base(nil)
connStr, closeDB := test.PrepareDBConnectionString(t, dbType)
defer closeDB()
- db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{
+ db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, 1000, 1000, "")
if err != nil {