From 4594233f89f8531fca8f696ab0ece36909130c2a Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 20 Feb 2023 14:58:03 +0100 Subject: Merge keyserver & userapi (#2972) As discussed yesterday, a first draft of merging the keyserver and the userapi. --- appservice/appservice.go | 2 +- appservice/appservice_test.go | 2 +- build/dendritejs-pinecone/main.go | 6 +- build/gobind-yggdrasil/monolith.go | 6 +- clientapi/admin_test.go | 15 +- clientapi/clientapi.go | 6 +- clientapi/routing/admin.go | 23 +- clientapi/routing/joinroom_test.go | 4 +- clientapi/routing/key_crosssigning.go | 9 +- clientapi/routing/keys.go | 7 +- clientapi/routing/login_test.go | 7 +- clientapi/routing/register_test.go | 13 +- clientapi/routing/routing.go | 18 +- cmd/dendrite-demo-pinecone/monolith/monolith.go | 6 +- cmd/dendrite-demo-yggdrasil/main.go | 7 +- cmd/dendrite/main.go | 8 +- federationapi/consumers/keychange.go | 2 +- federationapi/federationapi.go | 6 +- federationapi/federationapi_test.go | 10 +- federationapi/producers/syncapi.go | 2 +- federationapi/routing/devices.go | 12 +- federationapi/routing/keys.go | 2 +- federationapi/routing/profile_test.go | 2 +- federationapi/routing/query_test.go | 2 +- federationapi/routing/routing.go | 10 +- federationapi/routing/send.go | 4 +- federationapi/routing/send_test.go | 2 +- internal/transactionrequest.go | 8 +- internal/transactionrequest_test.go | 2 +- keyserver/README.md | 19 - keyserver/api/api.go | 346 -------- keyserver/consumers/devicelistupdate.go | 95 -- keyserver/consumers/signingkeyupdate.go | 112 --- keyserver/internal/cross_signing.go | 587 ------------- keyserver/internal/device_list_update.go | 579 ------------ keyserver/internal/device_list_update_default.go | 22 - keyserver/internal/device_list_update_sytest.go | 25 - keyserver/internal/device_list_update_test.go | 431 --------- keyserver/internal/internal.go | 816 ----------------- keyserver/internal/internal_test.go | 156 ---- keyserver/keyserver.go | 86 -- keyserver/keyserver_test.go | 29 - keyserver/producers/keychange.go | 107 --- keyserver/storage/interface.go | 93 -- .../storage/postgres/cross_signing_keys_table.go | 102 --- .../storage/postgres/cross_signing_sigs_table.go | 131 --- .../deltas/2022012016470000_key_changes.go | 69 -- .../deltas/2022042612000000_xsigning_idx.go | 47 - keyserver/storage/postgres/device_keys_table.go | 228 ----- keyserver/storage/postgres/key_changes_table.go | 134 --- keyserver/storage/postgres/one_time_keys_table.go | 205 ----- keyserver/storage/postgres/stale_device_lists.go | 131 --- keyserver/storage/postgres/storage.go | 69 -- keyserver/storage/shared/storage.go | 261 ------ .../storage/sqlite3/cross_signing_keys_table.go | 101 --- .../storage/sqlite3/cross_signing_sigs_table.go | 129 --- .../sqlite3/deltas/2022012016470000_key_changes.go | 66 -- .../deltas/2022042612000000_xsigning_idx.go | 71 -- keyserver/storage/sqlite3/device_keys_table.go | 225 ----- keyserver/storage/sqlite3/key_changes_table.go | 132 --- keyserver/storage/sqlite3/one_time_keys_table.go | 219 ----- keyserver/storage/sqlite3/stale_device_lists.go | 145 --- keyserver/storage/sqlite3/storage.go | 68 -- keyserver/storage/storage.go | 40 - keyserver/storage/storage_test.go | 197 ----- keyserver/storage/storage_wasm.go | 34 - keyserver/storage/tables/interface.go | 71 -- .../storage/tables/stale_device_lists_test.go | 94 -- keyserver/types/storage.go | 50 -- roomserver/api/api.go | 2 +- roomserver/roomserver_test.go | 8 +- setup/monolith.go | 15 +- syncapi/consumers/keychange.go | 2 +- syncapi/consumers/sendtodevice.go | 10 +- syncapi/internal/keychange.go | 16 +- syncapi/internal/keychange_test.go | 23 +- syncapi/streams/stream_devicelist.go | 10 +- syncapi/streams/streams.go | 5 +- syncapi/sync/requestpool.go | 9 +- syncapi/syncapi.go | 8 +- syncapi/syncapi_test.go | 26 +- userapi/api/api.go | 329 ++++++- userapi/consumers/clientapi.go | 4 +- userapi/consumers/devicelistupdate.go | 95 ++ userapi/consumers/roomserver.go | 4 +- userapi/consumers/roomserver_test.go | 4 +- userapi/consumers/signingkeyupdate.go | 111 +++ userapi/internal/api.go | 968 -------------------- userapi/internal/cross_signing.go | 587 +++++++++++++ userapi/internal/device_list_update.go | 579 ++++++++++++ userapi/internal/device_list_update_default.go | 22 + userapi/internal/device_list_update_sytest.go | 25 + userapi/internal/device_list_update_test.go | 431 +++++++++ userapi/internal/key_api.go | 798 +++++++++++++++++ userapi/internal/key_api_test.go | 161 ++++ userapi/internal/user_api.go | 970 +++++++++++++++++++++ userapi/producers/keychange.go | 107 +++ userapi/producers/syncapi.go | 4 +- userapi/storage/interface.go | 76 +- userapi/storage/postgres/account_data_table.go | 8 +- .../storage/postgres/cross_signing_keys_table.go | 102 +++ .../storage/postgres/cross_signing_sigs_table.go | 131 +++ .../deltas/2022012016470000_key_changes.go | 69 ++ .../deltas/2022042612000000_xsigning_idx.go | 47 + userapi/storage/postgres/device_keys_table.go | 213 +++++ userapi/storage/postgres/devices_table.go | 8 +- userapi/storage/postgres/key_backup_table.go | 4 +- userapi/storage/postgres/key_changes_table.go | 127 +++ userapi/storage/postgres/one_time_keys_table.go | 194 +++++ userapi/storage/postgres/stale_device_lists.go | 131 +++ userapi/storage/postgres/storage.go | 41 + userapi/storage/shared/storage.go | 235 +++++ .../storage/sqlite3/cross_signing_keys_table.go | 101 +++ .../storage/sqlite3/cross_signing_sigs_table.go | 129 +++ .../sqlite3/deltas/2022012016470000_key_changes.go | 66 ++ .../deltas/2022042612000000_xsigning_idx.go | 71 ++ userapi/storage/sqlite3/device_keys_table.go | 213 +++++ userapi/storage/sqlite3/devices_table.go | 17 +- userapi/storage/sqlite3/key_backup_table.go | 4 +- userapi/storage/sqlite3/key_changes_table.go | 125 +++ userapi/storage/sqlite3/one_time_keys_table.go | 208 +++++ userapi/storage/sqlite3/stale_device_lists.go | 145 +++ userapi/storage/sqlite3/stats_table.go | 3 + userapi/storage/sqlite3/storage.go | 45 +- userapi/storage/storage.go | 27 +- userapi/storage/storage_test.go | 210 ++++- userapi/storage/storage_wasm.go | 4 +- userapi/storage/tables/interface.go | 46 +- userapi/storage/tables/stale_device_lists_test.go | 94 ++ userapi/types/storage.go | 50 ++ userapi/userapi.go | 62 +- userapi/userapi_test.go | 327 ++++++- userapi/util/devices.go | 2 +- userapi/util/notify.go | 4 +- userapi/util/notify_test.go | 2 +- userapi/util/phonehomestats_test.go | 2 +- 136 files changed, 7634 insertions(+), 7767 deletions(-) delete mode 100644 keyserver/README.md delete mode 100644 keyserver/api/api.go delete mode 100644 keyserver/consumers/devicelistupdate.go delete mode 100644 keyserver/consumers/signingkeyupdate.go delete mode 100644 keyserver/internal/cross_signing.go delete mode 100644 keyserver/internal/device_list_update.go delete mode 100644 keyserver/internal/device_list_update_default.go delete mode 100644 keyserver/internal/device_list_update_sytest.go delete mode 100644 keyserver/internal/device_list_update_test.go delete mode 100644 keyserver/internal/internal.go delete mode 100644 keyserver/internal/internal_test.go delete mode 100644 keyserver/keyserver.go delete mode 100644 keyserver/keyserver_test.go delete mode 100644 keyserver/producers/keychange.go delete mode 100644 keyserver/storage/interface.go delete mode 100644 keyserver/storage/postgres/cross_signing_keys_table.go delete mode 100644 keyserver/storage/postgres/cross_signing_sigs_table.go delete mode 100644 keyserver/storage/postgres/deltas/2022012016470000_key_changes.go delete mode 100644 keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go delete mode 100644 keyserver/storage/postgres/device_keys_table.go delete mode 100644 keyserver/storage/postgres/key_changes_table.go delete mode 100644 keyserver/storage/postgres/one_time_keys_table.go delete mode 100644 keyserver/storage/postgres/stale_device_lists.go delete mode 100644 keyserver/storage/postgres/storage.go delete mode 100644 keyserver/storage/shared/storage.go delete mode 100644 keyserver/storage/sqlite3/cross_signing_keys_table.go delete mode 100644 keyserver/storage/sqlite3/cross_signing_sigs_table.go delete mode 100644 keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go delete mode 100644 keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go delete mode 100644 keyserver/storage/sqlite3/device_keys_table.go delete mode 100644 keyserver/storage/sqlite3/key_changes_table.go delete mode 100644 keyserver/storage/sqlite3/one_time_keys_table.go delete mode 100644 keyserver/storage/sqlite3/stale_device_lists.go delete mode 100644 keyserver/storage/sqlite3/storage.go delete mode 100644 keyserver/storage/storage.go delete mode 100644 keyserver/storage/storage_test.go delete mode 100644 keyserver/storage/storage_wasm.go delete mode 100644 keyserver/storage/tables/interface.go delete mode 100644 keyserver/storage/tables/stale_device_lists_test.go delete mode 100644 keyserver/types/storage.go create mode 100644 userapi/consumers/devicelistupdate.go create mode 100644 userapi/consumers/signingkeyupdate.go delete mode 100644 userapi/internal/api.go create mode 100644 userapi/internal/cross_signing.go create mode 100644 userapi/internal/device_list_update.go create mode 100644 userapi/internal/device_list_update_default.go create mode 100644 userapi/internal/device_list_update_sytest.go create mode 100644 userapi/internal/device_list_update_test.go create mode 100644 userapi/internal/key_api.go create mode 100644 userapi/internal/key_api_test.go create mode 100644 userapi/internal/user_api.go create mode 100644 userapi/producers/keychange.go create mode 100644 userapi/storage/postgres/cross_signing_keys_table.go create mode 100644 userapi/storage/postgres/cross_signing_sigs_table.go create mode 100644 userapi/storage/postgres/deltas/2022012016470000_key_changes.go create mode 100644 userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go create mode 100644 userapi/storage/postgres/device_keys_table.go create mode 100644 userapi/storage/postgres/key_changes_table.go create mode 100644 userapi/storage/postgres/one_time_keys_table.go create mode 100644 userapi/storage/postgres/stale_device_lists.go create mode 100644 userapi/storage/sqlite3/cross_signing_keys_table.go create mode 100644 userapi/storage/sqlite3/cross_signing_sigs_table.go create mode 100644 userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go create mode 100644 userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go create mode 100644 userapi/storage/sqlite3/device_keys_table.go create mode 100644 userapi/storage/sqlite3/key_changes_table.go create mode 100644 userapi/storage/sqlite3/one_time_keys_table.go create mode 100644 userapi/storage/sqlite3/stale_device_lists.go create mode 100644 userapi/storage/tables/stale_device_lists_test.go create mode 100644 userapi/types/storage.go diff --git a/appservice/appservice.go b/appservice/appservice.go index b950a821..5b1b93de 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -38,7 +38,7 @@ import ( // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, - userAPI userapi.UserInternalAPI, + userAPI userapi.AppserviceUserAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) appserviceAPI.AppServiceInternalAPI { client := &http.Client{ diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 9e9940cd..de9f5aaf 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -125,7 +125,7 @@ func TestAppserviceInternalAPI(t *testing.T) { // Create required internal APIs rsAPI := roomserver.NewInternalAPI(base) - usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) runCases(t, asAPI) diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go index f3fcb03e..44e52286 100644 --- a/build/dendritejs-pinecone/main.go +++ b/build/dendritejs-pinecone/main.go @@ -30,7 +30,6 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -183,13 +182,11 @@ func startup() { rsAPI := roomserver.NewInternalAPI(base) federation := conn.CreateFederationClient(base, pSessions) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, @@ -208,7 +205,6 @@ func startup() { FederationAPI: fedSenderAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, //ServerKeyAPI: serverKeyAPI, ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(pRouter, pSessions, fedSenderAPI, federation), } diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index fad850a9..32af611a 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -165,9 +164,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) @@ -186,7 +183,6 @@ func (m *DendriteMonolith) Start() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index c7ca019f..300d3a88 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -8,7 +8,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -40,11 +39,9 @@ func TestAdminResetPassword(t *testing.T) { rsAPI := roomserver.NewInternalAPI(base) // Needed for changing the password/login - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) // Create the users in the userapi and login accessTokens := map[*test.User]string{ @@ -155,14 +152,12 @@ func TestPurgeRoom(t *testing.T) { fedClient := base.CreateFederationClient() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // this starts the JetStream consumers - syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + syncapi.AddPublicRoutes(base, userAPI, rsAPI) federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) rsAPI.SetFederationAPI(nil, nil) - keyAPI.SetUserAPI(userAPI) // Create the room if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { @@ -170,7 +165,7 @@ func TestPurgeRoom(t *testing.T) { } // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil) // Create the users in the userapi and login accessTokens := map[*test.User]string{ diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 2d17e092..e9985d43 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,6 +15,7 @@ package clientapi import ( + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" @@ -23,11 +24,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/routing" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/transactions" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. @@ -40,7 +39,6 @@ func AddPublicRoutes( fsAPI federationAPI.ClientFederationAPI, userAPI userapi.ClientUserAPI, userDirectoryProvider userapi.QuerySearchProfilesAPI, - keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) { cfg := &base.Cfg.ClientAPI @@ -61,7 +59,7 @@ func AddPublicRoutes( base, cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, - syncProducer, transactionsCache, fsAPI, keyAPI, + syncProducer, transactionsCache, fsAPI, extRoomsProvider, mscCfg, natsClient, ) } diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 4b4dedfd..a01f6b94 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -16,14 +16,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) -func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -56,7 +55,7 @@ func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi } } -func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -99,7 +98,7 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } } -func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -130,7 +129,7 @@ func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.De } } -func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { +func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.Device, userAPI api.ClientUserAPI) util.JSONResponse { if req.Body == nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -150,8 +149,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.InvalidArgumentValue(err.Error()), } } - accAvailableResp := &userapi.QueryAccountAvailabilityResponse{} - if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ + accAvailableResp := &api.QueryAccountAvailabilityResponse{} + if err = userAPI.QueryAccountAvailability(req.Context(), &api.QueryAccountAvailabilityRequest{ Localpart: localpart, ServerName: serverName, }, accAvailableResp); err != nil { @@ -186,13 +185,13 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap return *internal.PasswordResponse(err) } - updateReq := &userapi.PerformPasswordUpdateRequest{ + updateReq := &api.PerformPasswordUpdateRequest{ Localpart: localpart, ServerName: serverName, Password: request.Password, LogoutDevices: true, } - updateRes := &userapi.PerformPasswordUpdateResponse{} + updateRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -209,7 +208,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } } -func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse { +func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *api.Device, natsClient *nats.Conn) util.JSONResponse { _, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10) if err != nil { logrus.WithError(err).Error("failed to publish nats message") @@ -255,7 +254,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien } } -func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { +func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go index 9e8208e6..1450ef4b 100644 --- a/clientapi/routing/joinroom_test.go +++ b/clientapi/routing/joinroom_test.go @@ -10,7 +10,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" @@ -29,8 +28,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { defer baseClose() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 2570db09..267ba1dc 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -21,9 +21,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) @@ -34,8 +33,8 @@ type crossSigningRequest struct { func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, - keyserverAPI api.ClientKeyAPI, device *userapi.Device, - accountAPI userapi.ClientUserAPI, cfg *config.ClientAPI, + keyserverAPI api.ClientKeyAPI, device *api.Device, + accountAPI api.ClientUserAPI, cfg *config.ClientAPI, ) util.JSONResponse { uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} @@ -107,7 +106,7 @@ func UploadCrossSigningDeviceKeys( } } -func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { uploadReq := &api.PerformUploadDeviceSignaturesRequest{} uploadRes := &api.PerformUploadDeviceSignaturesResponse{} diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 0c12b111..3d60fcc3 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -23,8 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/dendrite/keyserver/api" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/api" ) type uploadKeysRequest struct { @@ -32,7 +31,7 @@ type uploadKeysRequest struct { OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` } -func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { var r uploadKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -106,7 +105,7 @@ func (r *queryKeysRequest) GetTimeout() time.Duration { return timeout } -func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { +func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { var r queryKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go index d429d7f8..b72db9d8 100644 --- a/clientapi/routing/login_test.go +++ b/clientapi/routing/login_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" @@ -39,12 +38,10 @@ func TestLogin(t *testing.T) { rsAPI := roomserver.NewInternalAPI(base) // Needed for /login - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, keyAPI, nil, &base.Cfg.MSCs, nil) + Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, &base.Cfg.MSCs, nil) // Create password password := util.RandomString(8) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 670c392b..651e3d3d 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -30,7 +30,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" @@ -409,9 +408,7 @@ func Test_register(t *testing.T) { defer baseClose() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -582,9 +579,7 @@ func TestRegisterUserWithDisplayName(t *testing.T) { base.Cfg.Global.ServerName = "server" rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) deviceName, deviceID := "deviceName", "deviceID" expectedDisplayName := "DisplayName" response := completeRegistration( @@ -623,9 +618,7 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) { base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) expectedDisplayName := "rabbit" jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 66610c0a..028d02e9 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -22,6 +22,7 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/setup/base" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -37,11 +38,9 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/transactions" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - userapi "github.com/matrix-org/dendrite/userapi/api" ) // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client @@ -61,7 +60,6 @@ func Setup( syncProducer *producers.SyncAPIProducer, transactionsCache *transactions.Cache, federationSender federationAPI.ClientFederationAPI, - keyAPI keyserverAPI.ClientKeyAPI, extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, natsClient *nats.Conn, ) { @@ -192,7 +190,7 @@ func Setup( dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}", httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return AdminMarkAsStale(req, cfg, keyAPI) + return AdminMarkAsStale(req, cfg, userAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -1372,11 +1370,11 @@ func Setup( // Cross-signing device keys postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg) + return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, userAPI, device, userAPI, cfg) }) postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadCrossSigningDeviceSignatures(req, keyAPI, device) + return UploadCrossSigningDeviceSignatures(req, userAPI, device) }, httputil.WithAllowGuests()) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) @@ -1388,22 +1386,22 @@ func Setup( // Supplying a device ID is deprecated. v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadKeys(req, keyAPI, device) + return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadKeys(req, keyAPI, device) + return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return QueryKeys(req, keyAPI, device) + return QueryKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return ClaimKeys(req, keyAPI) + return ClaimKeys(req, userAPI) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go index fe19593c..27720369 100644 --- a/cmd/dendrite-demo-pinecone/monolith/monolith.go +++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go @@ -38,7 +38,6 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/relayapi" relayAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" @@ -139,9 +138,7 @@ func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelayi p.BaseDendrite, federation, rsAPI, p.BaseDendrite.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(p.BaseDendrite, &p.BaseDendrite.Cfg.KeyServer, fsAPI, rsComponent) - userAPI := userapi.NewInternalAPI(p.BaseDendrite, &cfg.UserAPI, nil, keyAPI, rsAPI, p.BaseDendrite.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(p.BaseDendrite, rsAPI, federation) asAPI := appservice.NewInternalAPI(p.BaseDendrite, userAPI, rsAPI) @@ -175,7 +172,6 @@ func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelayi FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, RelayAPI: relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 842682b4..d759c6a7 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -39,7 +39,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -161,10 +160,7 @@ func main() { base, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) - - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) @@ -184,7 +180,6 @@ func main() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, ), diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go index 35bfc1d6..e8ff0a47 100644 --- a/cmd/dendrite/main.go +++ b/cmd/dendrite/main.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" basepkg "github.com/matrix-org/dendrite/setup/base" @@ -56,10 +55,7 @@ func main() { keyRing := fsAPI.KeyRing() - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) - - pgClient := base.PushGatewayHTTPClient() - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) + userAPI := userapi.NewInternalAPI(base, rsAPI, federation) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) @@ -69,7 +65,6 @@ func main() { rsAPI.SetFederationAPI(fsAPI, keyRing) rsAPI.SetAppserviceAPI(asAPI) rsAPI.SetUserAPI(userAPI) - keyAPI.SetUserAPI(userAPI) monolith := setup.Monolith{ Config: base.Cfg, @@ -83,7 +78,6 @@ func main() { FederationAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - KeyAPI: keyAPI, } monolith.AddAllPublicRoutes(base) diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 601257d4..7d9df3d7 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -26,11 +26,11 @@ import ( "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" ) // KeyChangeConsumer consumes events that originate in key server. diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 10803916..ec482659 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/internal/caching" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" @@ -42,12 +41,11 @@ import ( // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. func AddPublicRoutes( base *base.BaseDendrite, - userAPI userapi.UserInternalAPI, + userAPI userapi.FederationUserAPI, federation *gomatrixserverlib.FederationClient, keyRing gomatrixserverlib.JSONVerifier, rsAPI roomserverAPI.FederationRoomserverAPI, fedAPI federationAPI.FederationInternalAPI, - keyAPI keyserverAPI.FederationKeyAPI, servers federationAPI.ServersInRoomProvider, ) { cfg := &base.Cfg.FederationAPI @@ -79,7 +77,7 @@ func AddPublicRoutes( routing.Setup( base, rsAPI, f, keyRing, - federation, userAPI, keyAPI, mscCfg, + federation, userAPI, mscCfg, servers, producer, ) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 8d1d8514..57d4b964 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -17,13 +17,13 @@ import ( "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/internal" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type fedRoomserverAPI struct { @@ -230,9 +230,9 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { // Inject a keyserver key change event and ensure we try to send it out. If we don't, then the // federationapi is incorrectly waiting for an output room event to arrive to update the joined // hosts table. - key := keyapi.DeviceMessage{ - Type: keyapi.TypeDeviceKeyUpdate, - DeviceKeys: &keyapi.DeviceKeys{ + key := userapi.DeviceMessage{ + Type: userapi.TypeDeviceKeyUpdate, + DeviceKeys: &userapi.DeviceKeys{ UserID: joiningUser.ID, DeviceID: "MY_DEVICE", DisplayName: "BLARGLE", @@ -277,7 +277,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { keyRing := &test.NopJSONVerifier{} // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) + federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil) baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/producers/syncapi.go b/federationapi/producers/syncapi.go index 7cce13a7..6bcfafa3 100644 --- a/federationapi/producers/syncapi.go +++ b/federationapi/producers/syncapi.go @@ -41,7 +41,7 @@ type SyncAPIProducer struct { TopicSigningKeyUpdate string JetStream nats.JetStreamContext Config *config.FederationAPI - UserAPI userapi.UserInternalAPI + UserAPI userapi.FederationUserAPI } func (p *SyncAPIProducer) SendReceipt( diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index ce8b06b7..871d26cd 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -17,7 +17,7 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -26,11 +26,11 @@ import ( // GetUserDevices for the given user id func GetUserDevices( req *http.Request, - keyAPI keyapi.FederationKeyAPI, + keyAPI api.FederationKeyAPI, userID string, ) util.JSONResponse { - var res keyapi.QueryDeviceMessagesResponse - if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ + var res api.QueryDeviceMessagesResponse + if err := keyAPI.QueryDeviceMessages(req.Context(), &api.QueryDeviceMessagesRequest{ UserID: userID, }, &res); err != nil { return util.ErrorResponse(err) @@ -40,12 +40,12 @@ func GetUserDevices( return jsonerror.InternalServerError() } - sigReq := &keyapi.QuerySignaturesRequest{ + sigReq := &api.QuerySignaturesRequest{ TargetIDs: map[string][]gomatrixserverlib.KeyID{ userID: {}, }, } - sigRes := &keyapi.QuerySignaturesResponse{} + sigRes := &api.QuerySignaturesResponse{} for _, dev := range res.Devices { sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID)) } diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index dc262cfd..2885cc91 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -22,8 +22,8 @@ import ( clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go index 76365608..3b9d576b 100644 --- a/federationapi/routing/profile_test.go +++ b/federationapi/routing/profile_test.go @@ -62,7 +62,7 @@ func TestHandleQueryProfile(t *testing.T) { if !ok { panic("This is a programming error.") } - routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go index 21f35bf0..d839a16b 100644 --- a/federationapi/routing/query_test.go +++ b/federationapi/routing/query_test.go @@ -62,7 +62,7 @@ func TestHandleQueryDirectory(t *testing.T) { if !ok { panic("This is a programming error.") } - routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil) handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 5eb30c6e..324740dd 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -29,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" - keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -62,7 +61,6 @@ func Setup( keys gomatrixserverlib.JSONVerifier, federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, - keyAPI keyserverAPI.FederationKeyAPI, mscCfg *config.MSCs, servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, @@ -141,7 +139,7 @@ func Setup( func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer, + cfg, rsAPI, userAPI, keys, federation, mu, servers, producer, ) }, )).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName) @@ -269,7 +267,7 @@ func Setup( "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return GetUserDevices( - httpReq, keyAPI, vars["userID"], + httpReq, userAPI, vars["userID"], ) }, )).Methods(http.MethodGet) @@ -494,14 +492,14 @@ func Setup( v1fedmux.Handle("/user/keys/claim", MakeFedAPI( "federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + return ClaimOneTimeKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) v1fedmux.Handle("/user/keys/query", MakeFedAPI( "federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName) + return QueryDeviceKeys(httpReq, request, userAPI, cfg.Matrix.ServerName) }, )).Methods(http.MethodPost) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 67b513c9..82651719 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -28,9 +28,9 @@ import ( federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + userAPI "github.com/matrix-org/dendrite/userapi/api" ) const ( @@ -59,7 +59,7 @@ func Send( txnID gomatrixserverlib.TransactionID, cfg *config.FederationAPI, rsAPI api.FederationRoomserverAPI, - keyAPI keyapi.FederationKeyAPI, + keyAPI userAPI.FederationUserAPI, keys gomatrixserverlib.JSONVerifier, federation federationAPI.FederationClient, mu *internal.MutexByRoom, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index d7feee0e..eed4e7e6 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -58,7 +58,7 @@ func TestHandleSend(t *testing.T) { if !ok { panic("This is a programming error.") } - routing.Setup(base, nil, r, keyRing, nil, nil, nil, &base.Cfg.MSCs, nil, nil) + routing.Setup(base, nil, r, keyRing, nil, nil, &base.Cfg.MSCs, nil, nil) handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP _, sk, _ := ed25519.GenerateKey(nil) diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index 95673fc1..13b00af5 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -24,9 +24,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/federationapi/types" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" syncTypes "github.com/matrix-org/dendrite/syncapi/types" + userAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -56,7 +56,7 @@ var ( type TxnReq struct { gomatrixserverlib.Transaction rsAPI api.FederationRoomserverAPI - keyAPI keyapi.FederationKeyAPI + userAPI userAPI.FederationUserAPI ourServerName gomatrixserverlib.ServerName keys gomatrixserverlib.JSONVerifier roomsMu *MutexByRoom @@ -66,7 +66,7 @@ type TxnReq struct { func NewTxnReq( rsAPI api.FederationRoomserverAPI, - keyAPI keyapi.FederationKeyAPI, + userAPI userAPI.FederationUserAPI, ourServerName gomatrixserverlib.ServerName, keys gomatrixserverlib.JSONVerifier, roomsMu *MutexByRoom, @@ -80,7 +80,7 @@ func NewTxnReq( ) TxnReq { t := TxnReq{ rsAPI: rsAPI, - keyAPI: keyAPI, + userAPI: userAPI, ourServerName: ourServerName, keys: keys, roomsMu: roomsMu, diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index 93c6fb6f..8597ae24 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -23,13 +23,13 @@ import ( "time" "github.com/matrix-org/dendrite/federationapi/producers" - keyAPI "github.com/matrix-org/dendrite/keyserver/api" rsAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + keyAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" diff --git a/keyserver/README.md b/keyserver/README.md deleted file mode 100644 index fd9f37d2..00000000 --- a/keyserver/README.md +++ /dev/null @@ -1,19 +0,0 @@ -## Key Server - -This is an internal component which manages E2E keys from clients. It handles all the [Key Management APIs](https://matrix.org/docs/spec/client_server/r0.6.1#key-management-api) with the exception of `/keys/changes` which is handled by Sync API. This component is designed to shard by user ID. - -Keys are uploaded and stored in this component, and key changes are emitted to a Kafka topic for downstream components such as Sync API. - -### Internal APIs -- `PerformUploadKeys` stores identity keys and one-time public keys for given user(s). -- `PerformClaimKeys` acquires one-time public keys for given user(s). This may involve outbound federation calls. -- `QueryKeys` returns identity keys for given user(s). This may involve outbound federation calls. This component may then cache federated identity keys to avoid repeatedly hitting remote servers. -- A topic which emits identity keys every time there is a change (addition or deletion). - -### Endpoint mappings -- Client API maps `/keys/upload` to `PerformUploadKeys`. -- Client API maps `/keys/query` to `QueryKeys`. -- Client API maps `/keys/claim` to `PerformClaimKeys`. -- Federation API maps `/user/keys/query` to `QueryKeys`. -- Federation API maps `/user/keys/claim` to `PerformClaimKeys`. -- Sync API maps `/keys/changes` to consuming from the Kafka topic. diff --git a/keyserver/api/api.go b/keyserver/api/api.go deleted file mode 100644 index 14fced3e..00000000 --- a/keyserver/api/api.go +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "bytes" - "context" - "encoding/json" - "strings" - "time" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/keyserver/types" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -type KeyInternalAPI interface { - SyncKeyAPI - ClientKeyAPI - FederationKeyAPI - UserKeyAPI - - // SetUserAPI assigns a user API to query when extracting device names. - SetUserAPI(i userapi.KeyserverUserAPI) -} - -// API functions required by the clientapi -type ClientKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error - // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -// API functions required by the userapi -type UserKeyAPI interface { - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error -} - -// API functions required by the syncapi -type SyncKeyAPI interface { - QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error - QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -type FederationKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error - QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error -} - -// KeyError is returned if there was a problem performing/querying the server -type KeyError struct { - Err string `json:"error"` - IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE - IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM - IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM -} - -func (k *KeyError) Error() string { - return k.Err -} - -type DeviceMessageType int - -const ( - TypeDeviceKeyUpdate DeviceMessageType = iota - TypeCrossSigningUpdate -) - -// DeviceMessage represents the message produced into Kafka by the key server. -type DeviceMessage struct { - Type DeviceMessageType `json:"Type,omitempty"` - *DeviceKeys `json:"DeviceKeys,omitempty"` - *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` - // A monotonically increasing number which represents device changes for this user. - StreamID int64 - DeviceChangeID int64 -} - -// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log -type OutputCrossSigningKeyUpdate struct { - CrossSigningKeyUpdate `json:"signing_keys"` -} - -type CrossSigningKeyUpdate struct { - MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` - SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` - UserID string `json:"user_id"` -} - -// DeviceKeysEqual returns true if the device keys updates contain the -// same display name and key JSON. This will return false if either of -// the updates is not a device keys update, or if the user ID/device ID -// differ between the two. -func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { - if m1.DeviceKeys == nil || m2.DeviceKeys == nil { - return false - } - if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { - return false - } - if m1.DisplayName != m2.DisplayName { - return false // different display names - } - if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { - return false // either is empty - } - return bytes.Equal(m1.KeyJSON, m2.KeyJSON) -} - -// DeviceKeys represents a set of device keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type DeviceKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // The device display name - DisplayName string - // The raw device key JSON - KeyJSON []byte -} - -// WithStreamID returns a copy of this device message with the given stream ID -func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { - return DeviceMessage{ - DeviceKeys: k, - StreamID: streamID, - } -} - -// OneTimeKeys represents a set of one-time keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type OneTimeKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // A map of algorithm:key_id => key JSON - KeyJSON map[string]json.RawMessage -} - -// Split a key in KeyJSON into algorithm and key ID -func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { - segments := strings.Split(keyIDWithAlgo, ":") - return segments[0], segments[1] -} - -// OneTimeKeysCount represents the counts of one-time keys for a single device -type OneTimeKeysCount struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // algorithm to count e.g: - // { - // "curve25519": 10, - // "signed_curve25519": 20 - // } - KeyCount map[string]int -} - -// PerformUploadKeysRequest is the request to PerformUploadKeys -type PerformUploadKeysRequest struct { - UserID string // Required - User performing the request - DeviceID string // Optional - Device performing the request, for fetching OTK count - DeviceKeys []DeviceKeys - OneTimeKeys []OneTimeKeys - // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update - // the display name for their respective device, and NOT to modify the keys. The key - // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. - // Without this flag, requests to modify device display names would delete device keys. - OnlyDisplayNameUpdates bool -} - -// PerformUploadKeysResponse is the response to PerformUploadKeys -type PerformUploadKeysResponse struct { - // A fatal error when processing e.g database failures - Error *KeyError - // A map of user_id -> device_id -> Error for tracking failures. - KeyErrors map[string]map[string]*KeyError - OneTimeKeyCounts []OneTimeKeysCount -} - -// PerformDeleteKeysRequest asks the keyserver to forget about certain -// keys, and signatures related to those keys. -type PerformDeleteKeysRequest struct { - UserID string - KeyIDs []gomatrixserverlib.KeyID -} - -// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. -type PerformDeleteKeysResponse struct { - Error *KeyError -} - -// KeyError sets a key error field on KeyErrors -func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { - if r.KeyErrors[userID] == nil { - r.KeyErrors[userID] = make(map[string]*KeyError) - } - r.KeyErrors[userID][deviceID] = err -} - -type PerformClaimKeysRequest struct { - // Map of user_id to device_id to algorithm name - OneTimeKeys map[string]map[string]string - Timeout time.Duration -} - -type PerformClaimKeysResponse struct { - // Map of user_id to device_id to algorithm:key_id to key JSON - OneTimeKeys map[string]map[string]map[string]json.RawMessage - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Set if there was a fatal error processing this action - Error *KeyError -} - -type PerformUploadDeviceKeysRequest struct { - gomatrixserverlib.CrossSigningKeys - // The user that uploaded the key, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceKeysResponse struct { - Error *KeyError -} - -type PerformUploadDeviceSignaturesRequest struct { - Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice - // The user that uploaded the sig, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceSignaturesResponse struct { - Error *KeyError -} - -type QueryKeysRequest struct { - // The user ID asking for the keys, e.g. if from a client API request. - // Will not be populated if the key request came from federation. - UserID string - // Maps user IDs to a list of devices - UserToDevices map[string][]string - Timeout time.Duration -} - -type QueryKeysResponse struct { - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Map of user_id to device_id to device_key - DeviceKeys map[string]map[string]json.RawMessage - // Maps of user_id to cross signing key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // Set if there was a fatal error processing this query - Error *KeyError -} - -type QueryKeyChangesRequest struct { - // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning - Offset int64 - // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. - // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). - ToOffset int64 -} - -type QueryKeyChangesResponse struct { - // The set of users who have had their keys change. - UserIDs []string - // The latest offset represented in this response. - Offset int64 - // Set if there was a problem handling the request. - Error *KeyError -} - -type QueryOneTimeKeysRequest struct { - // The local user to query OTK counts for - UserID string - // The device to query OTK counts for - DeviceID string -} - -type QueryOneTimeKeysResponse struct { - // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 - Count OneTimeKeysCount - Error *KeyError -} - -type QueryDeviceMessagesRequest struct { - UserID string -} - -type QueryDeviceMessagesResponse struct { - // The latest stream ID - StreamID int64 - Devices []DeviceMessage - Error *KeyError -} - -type QuerySignaturesRequest struct { - // A map of target user ID -> target key/device IDs to retrieve signatures for - TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` -} - -type QuerySignaturesResponse struct { - // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures - Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap - // A map of target user ID -> cross-signing master key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing self-signing key - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing user-signing key - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // The request error, if any - Error *KeyError -} - -type PerformMarkAsStaleRequest struct { - UserID string - Domain gomatrixserverlib.ServerName - DeviceID string -} diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go deleted file mode 100644 index cd911f8c..00000000 --- a/keyserver/consumers/devicelistupdate.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package consumers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" -) - -// DeviceListUpdateConsumer consumes device list updates that came in over federation. -type DeviceListUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - updater *internal.DeviceListUpdater - isLocalServerName func(gomatrixserverlib.ServerName) bool -} - -// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. -func NewDeviceListUpdateConsumer( - process *process.ProcessContext, - cfg *config.KeyServer, - js nats.JetStreamContext, - updater *internal.DeviceListUpdater, -) *DeviceListUpdateConsumer { - return &DeviceListUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - updater: updater, - isLocalServerName: cfg.Matrix.IsLocalServerName, - } -} - -// Start consuming from key servers -func (t *DeviceListUpdateConsumer) Start() error { - return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, 1, - t.onMessage, nats.DeliverAll(), nats.ManualAck(), - ) -} - -// onMessage is called in response to a message received on the -// key change events topic from the key server. -func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { - msg := msgs[0] // Guaranteed to exist if onMessage is called - var m gomatrixserverlib.DeviceListUpdateEvent - if err := json.Unmarshal(msg.Data, &m); err != nil { - logrus.WithError(err).Errorf("Failed to read from device list update input topic") - return true - } - origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) - if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { - return true - } else if t.isLocalServerName(serverName) { - return true - } else if serverName != origin { - return true - } - - err := t.updater.Update(ctx, m) - if err != nil { - logrus.WithFields(logrus.Fields{ - "user_id": m.UserID, - "device_id": m.DeviceID, - "stream_id": m.StreamID, - "prev_id": m.PrevID, - }).WithError(err).Errorf("Failed to update device list") - return false - } - return true -} diff --git a/keyserver/consumers/signingkeyupdate.go b/keyserver/consumers/signingkeyupdate.go deleted file mode 100644 index bcceaad1..00000000 --- a/keyserver/consumers/signingkeyupdate.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package consumers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" - - keyapi "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" -) - -// SigningKeyUpdateConsumer consumes signing key updates that came in over federation. -type SigningKeyUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - keyAPI *internal.KeyInternalAPI - cfg *config.KeyServer - isLocalServerName func(gomatrixserverlib.ServerName) bool -} - -// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. -func NewSigningKeyUpdateConsumer( - process *process.ProcessContext, - cfg *config.KeyServer, - js nats.JetStreamContext, - keyAPI *internal.KeyInternalAPI, -) *SigningKeyUpdateConsumer { - return &SigningKeyUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - keyAPI: keyAPI, - cfg: cfg, - isLocalServerName: cfg.Matrix.IsLocalServerName, - } -} - -// Start consuming from key servers -func (t *SigningKeyUpdateConsumer) Start() error { - return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, 1, - t.onMessage, nats.DeliverAll(), nats.ManualAck(), - ) -} - -// onMessage is called in response to a message received on the -// signing key update events topic from the key server. -func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { - msg := msgs[0] // Guaranteed to exist if onMessage is called - var updatePayload keyapi.CrossSigningKeyUpdate - if err := json.Unmarshal(msg.Data, &updatePayload); err != nil { - logrus.WithError(err).Errorf("Failed to read from signing key update input topic") - return true - } - origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) - if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { - logrus.WithError(err).Error("failed to split user id") - return true - } else if t.isLocalServerName(serverName) { - logrus.Warn("dropping device key update from ourself") - return true - } else if serverName != origin { - logrus.Warnf("dropping device key update, %s != %s", serverName, origin) - return true - } - - keys := gomatrixserverlib.CrossSigningKeys{} - if updatePayload.MasterKey != nil { - keys.MasterKey = *updatePayload.MasterKey - } - if updatePayload.SelfSigningKey != nil { - keys.SelfSigningKey = *updatePayload.SelfSigningKey - } - uploadReq := &keyapi.PerformUploadDeviceKeysRequest{ - CrossSigningKeys: keys, - UserID: updatePayload.UserID, - } - uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} - if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { - logrus.WithError(err).Error("failed to upload device keys") - return false - } - if uploadRes.Error != nil { - logrus.WithError(uploadRes.Error).Error("failed to upload device keys") - return true - } - - return true -} diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go deleted file mode 100644 index 99859dff..00000000 --- a/keyserver/internal/cross_signing.go +++ /dev/null @@ -1,587 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "bytes" - "context" - "crypto/ed25519" - "database/sql" - "fmt" - "strings" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - "golang.org/x/crypto/curve25519" -) - -func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error { - // Is there exactly one key? - if len(key.Keys) != 1 { - return fmt.Errorf("should contain exactly one key") - } - - // Does the key ID match the key value? Iterates exactly once - for keyID, keyData := range key.Keys { - b64 := keyData.Encode() - tokens := strings.Split(string(keyID), ":") - if len(tokens) != 2 { - return fmt.Errorf("key ID is incorrectly formatted") - } - if tokens[1] != b64 { - return fmt.Errorf("key ID isn't correct") - } - switch tokens[0] { - case "ed25519": - if len(keyData) != ed25519.PublicKeySize { - return fmt.Errorf("ed25519 key is not the correct length") - } - case "curve25519": - if len(keyData) != curve25519.PointSize { - return fmt.Errorf("curve25519 key is not the correct length") - } - default: - // We can't enforce the key length to be correct for an - // algorithm that we don't recognise, so instead we'll - // just make sure that it isn't incredibly excessive. - if l := len(keyData); l > 4096 { - return fmt.Errorf("unknown key type is too long (%d bytes)", l) - } - } - } - - // Check to see if the signatures make sense - for _, forOriginUser := range key.Signatures { - for originKeyID, originSignature := range forOriginUser { - switch strings.SplitN(string(originKeyID), ":", 1)[0] { - case "ed25519": - if len(originSignature) != ed25519.SignatureSize { - return fmt.Errorf("ed25519 signature is not the correct length") - } - case "curve25519": - return fmt.Errorf("curve25519 signatures are impossible") - default: - if l := len(originSignature); l > 4096 { - return fmt.Errorf("unknown signature type is too long (%d bytes)", l) - } - } - } - } - - // Does the key claim to be from the right user? - if userID != key.UserID { - return fmt.Errorf("key has a user ID mismatch") - } - - // Does the key contain the correct purpose? - useful := false - for _, usage := range key.Usage { - if usage == purpose { - useful = true - break - } - } - if !useful { - return fmt.Errorf("key does not contain correct usage purpose") - } - - return nil -} - -// nolint:gocyclo -func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { - // Find the keys to store. - byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} - toStore := types.CrossSigningKeyMap{} - hasMasterKey := false - - if len(req.MasterKey.Keys) > 0 { - if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil { - res.Error = &api.KeyError{ - Err: "Master key sanity check failed: " + err.Error(), - IsInvalidParam: true, - } - return nil - } - - byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey - for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey - toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key - } - hasMasterKey = true - } - - if len(req.SelfSigningKey.Keys) > 0 { - if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil { - res.Error = &api.KeyError{ - Err: "Self-signing key sanity check failed: " + err.Error(), - IsInvalidParam: true, - } - return nil - } - - byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey - for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key - } - } - - if len(req.UserSigningKey.Keys) > 0 { - if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil { - res.Error = &api.KeyError{ - Err: "User-signing key sanity check failed: " + err.Error(), - IsInvalidParam: true, - } - return nil - } - - byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey - for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key - } - } - - // If there's nothing to do then stop here. - if len(toStore) == 0 { - res.Error = &api.KeyError{ - Err: "No keys were supplied in the request", - IsMissingParam: true, - } - return nil - } - - // We can't have a self-signing or user-signing key without a master - // key, so make sure we have one of those. We will also only actually do - // something if any of the specified keys in the request are different - // to what we've got in the database, to avoid generating key change - // notifications unnecessarily. - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) - if err != nil { - res.Error = &api.KeyError{ - Err: "Retrieving cross-signing keys from database failed: " + err.Error(), - } - return nil - } - - // If we still can't find a master key for the user then stop the upload. - // This satisfies the "Fails to upload self-signing key without master key" test. - if !hasMasterKey { - if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { - res.Error = &api.KeyError{ - Err: "No master key was found", - IsMissingParam: true, - } - return nil - } - } - - // Check if anything actually changed compared to what we have in the database. - changed := false - for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ - gomatrixserverlib.CrossSigningKeyPurposeMaster, - gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, - gomatrixserverlib.CrossSigningKeyPurposeUserSigning, - } { - old, gotOld := existingKeys[purpose] - new, gotNew := toStore[purpose] - if gotOld != gotNew { - // A new key purpose has been specified that we didn't know before, - // or one has been removed. - changed = true - break - } - if !bytes.Equal(old, new) { - // One of the existing keys for a purpose we already knew about has - // changed. - changed = true - break - } - } - if !changed { - return nil - } - - // Store the keys. - if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), - } - return nil - } - - // Now upload any signatures that were included with the keys. - for _, key := range byPurpose { - var targetKeyID gomatrixserverlib.KeyID - for targetKey := range key.Keys { // iterates once, see sanityCheckKey - targetKeyID = targetKey - } - for sigUserID, forSigUserID := range key.Signatures { - if sigUserID != req.UserID { - continue - } - for sigKeyID, sigBytes := range forSigUserID { - if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), - } - return nil - } - } - } - } - - // Finally, generate a notification that we updated the keys. - update := api.CrossSigningKeyUpdate{ - UserID: req.UserID, - } - if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok { - update.MasterKey = &mk - } - if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok { - update.SelfSigningKey = &ssk - } - if update.MasterKey == nil && update.SelfSigningKey == nil { - return nil - } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), - } - return nil - } - return nil -} - -func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { - // Before we do anything, we need the master and self-signing keys for this user. - // Then we can verify the signatures make sense. - queryReq := &api.QueryKeysRequest{ - UserID: req.UserID, - UserToDevices: map[string][]string{}, - } - queryRes := &api.QueryKeysResponse{} - for userID := range req.Signatures { - queryReq.UserToDevices[userID] = []string{} - } - _ = a.QueryKeys(ctx, queryReq, queryRes) - - selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - - // Sort signatures into two groups: one where people have signed their own - // keys and one where people have signed someone elses - for userID, forUserID := range req.Signatures { - for keyID, keyOrDevice := range forUserID { - switch key := keyOrDevice.CrossSigningBody.(type) { - case *gomatrixserverlib.CrossSigningKey: - if key.UserID == req.UserID { - if _, ok := selfSignatures[userID]; !ok { - selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - } - selfSignatures[userID][keyID] = keyOrDevice - } else { - if _, ok := otherSignatures[userID]; !ok { - otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - } - otherSignatures[userID][keyID] = keyOrDevice - } - - case *gomatrixserverlib.DeviceKeys: - if key.UserID == req.UserID { - if _, ok := selfSignatures[userID]; !ok { - selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - } - selfSignatures[userID][keyID] = keyOrDevice - } else { - if _, ok := otherSignatures[userID]; !ok { - otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - } - otherSignatures[userID][keyID] = keyOrDevice - } - - default: - continue - } - } - } - - if err := a.processSelfSignatures(ctx, selfSignatures); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.processSelfSignatures: %s", err), - } - return nil - } - - if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.processOtherSignatures: %s", err), - } - return nil - } - - // Finally, generate a notification that we updated the signatures. - for userID := range req.Signatures { - masterKey := queryRes.MasterKeys[userID] - selfSigningKey := queryRes.SelfSigningKeys[userID] - update := api.CrossSigningKeyUpdate{ - UserID: userID, - MasterKey: &masterKey, - SelfSigningKey: &selfSigningKey, - } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), - } - return nil - } - } - return nil -} - -func (a *KeyInternalAPI) processSelfSignatures( - ctx context.Context, - signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, -) error { - // Here we will process: - // * The user signing their own devices using their self-signing key - // * The user signing their master key using one of their devices - - for targetUserID, forTargetUserID := range signatures { - for targetKeyID, signature := range forTargetUserID { - switch sig := signature.CrossSigningBody.(type) { - case *gomatrixserverlib.CrossSigningKey: - for keyID := range sig.Keys { - split := strings.SplitN(string(keyID), ":", 2) - if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID { - targetKeyID = keyID // contains the ed25519: or other scheme - break - } - } - for originUserID, forOriginUserID := range sig.Signatures { - for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( - ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, - ); err != nil { - return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) - } - } - } - - case *gomatrixserverlib.DeviceKeys: - for originUserID, forOriginUserID := range sig.Signatures { - for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( - ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, - ); err != nil { - return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) - } - } - } - - default: - return fmt.Errorf("unexpected type assertion") - } - } - } - - return nil -} - -func (a *KeyInternalAPI) processOtherSignatures( - ctx context.Context, userID string, queryRes *api.QueryKeysResponse, - signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, -) error { - // Here we will process: - // * A user signing someone else's master keys using their user-signing keys - - for targetUserID, forTargetUserID := range signatures { - for _, signature := range forTargetUserID { - switch sig := signature.CrossSigningBody.(type) { - case *gomatrixserverlib.CrossSigningKey: - // Find the local copy of the master key. We'll use this to be - // sure that the supplied stanza matches the key that we think it - // should be. - masterKey, ok := queryRes.MasterKeys[targetUserID] - if !ok { - return fmt.Errorf("failed to find master key for user %q", targetUserID) - } - - // For each key ID, write the signatures. Maybe there'll be more - // than one algorithm in the future so it's best not to focus on - // everything being ed25519:. - for targetKeyID, suppliedKeyData := range sig.Keys { - // The master key will be supplied in the request, but we should - // make sure that it matches what we think the master key should - // actually be. - localKeyData, lok := masterKey.Keys[targetKeyID] - if !lok { - return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID) - } else if !bytes.Equal(suppliedKeyData, localKeyData) { - return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID) - } - - // We only care about the signatures from the uploading user, so - // we will ignore anything that didn't originate from them. - userSigs, ok := sig.Signatures[userID] - if !ok { - return fmt.Errorf("there are no signatures on master key %q from uploading user %q", targetKeyID, userID) - } - - for originKeyID, originSig := range userSigs { - if err := a.DB.StoreCrossSigningSigsForTarget( - ctx, userID, originKeyID, targetUserID, targetKeyID, originSig, - ); err != nil { - return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) - } - } - } - - default: - // Users should only be signing another person's master key, - // so if we're here, it's probably because it's actually a - // gomatrixserverlib.DeviceKeys, which doesn't make sense. - } - } - } - - return nil -} - -func (a *KeyInternalAPI) crossSigningKeysFromDatabase( - ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, -) { - for targetUserID := range req.UserToDevices { - keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) - if err != nil { - logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) - continue - } - - for keyType, key := range keys { - var keyID gomatrixserverlib.KeyID - for id := range key.Keys { - keyID = id - break - } - - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) - if err != nil && err != sql.ErrNoRows { - logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) - continue - } - - appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) { - if key.Signatures == nil { - key.Signatures = types.CrossSigningSigMap{} - } - if _, ok := key.Signatures[originUserID]; !ok { - key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes) - } - key.Signatures[originUserID][originKeyID] = signature - } - - for originUserID, forOrigin := range sigMap { - for originKeyID, signature := range forOrigin { - switch { - case req.UserID != "" && originUserID == req.UserID: - // Include signatures that we created - appendSignature(originUserID, originKeyID, signature) - case originUserID == targetUserID: - // Include signatures that were created by the person whose key - // we are processing - appendSignature(originUserID, originKeyID, signature) - } - } - } - - switch keyType { - case gomatrixserverlib.CrossSigningKeyPurposeMaster: - res.MasterKeys[targetUserID] = key - - case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: - res.SelfSigningKeys[targetUserID] = key - - case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: - res.UserSigningKeys[targetUserID] = key - } - } - } -} - -func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { - for targetUserID, forTargetUser := range req.TargetIDs { - keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) - if err != nil && err != sql.ErrNoRows { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err), - } - continue - } - - for targetPurpose, targetKey := range keyMap { - switch targetPurpose { - case gomatrixserverlib.CrossSigningKeyPurposeMaster: - if res.MasterKeys == nil { - res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{} - } - res.MasterKeys[targetUserID] = targetKey - - case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: - if res.SelfSigningKeys == nil { - res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} - } - res.SelfSigningKeys[targetUserID] = targetKey - - case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: - if res.UserSigningKeys == nil { - res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} - } - res.UserSigningKeys[targetUserID] = targetKey - } - } - - for _, targetKeyID := range forTargetUser { - // Get own signatures only. - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) - if err != nil && err != sql.ErrNoRows { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), - } - return nil - } - - for sourceUserID, forSourceUser := range sigMap { - for sourceKeyID, sourceSig := range forSourceUser { - if res.Signatures == nil { - res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{} - } - if _, ok := res.Signatures[targetUserID]; !ok { - res.Signatures[targetUserID] = map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{} - } - if _, ok := res.Signatures[targetUserID][targetKeyID]; !ok { - res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{} - } - if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok { - res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig - } - } - } - } - return nil -} diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go deleted file mode 100644 index 1b00d1ee..00000000 --- a/keyserver/internal/device_list_update.go +++ /dev/null @@ -1,579 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "hash/fnv" - "net" - "sync" - "time" - - rsapi "github.com/matrix-org/dendrite/roomserver/api" - - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/setup/process" -) - -var ( - deviceListUpdateCount = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "keyserver", - Name: "device_list_update", - Help: "Number of times we have attempted to update device lists from this server", - }, - []string{"server"}, - ) -) - -const requestTimeout = time.Second * 30 - -func init() { - prometheus.MustRegister( - deviceListUpdateCount, - ) -} - -// DeviceListUpdater handles device list updates from remote servers. -// -// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock). -// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies -// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id -// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device: -// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the -// updater stores the latest list along with the latest stream ID. -// -// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers. -// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing -// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved -// from the database (which allows us to batch requests to the same server). This has a number of desirable properties: -// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible -// for that domain. -// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where -// we have many many servers) -// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers. -// -// The downsides are that: -// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free -// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts) -// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests -// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse -// than being stuck behind foo.bar -// -// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is -// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried. -type DeviceListUpdater struct { - process *process.ProcessContext - // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 - // request to the remote server and race. - // TODO: Put in an LRU cache to bound growth - userIDToMutex map[string]*sync.Mutex - mu *sync.Mutex // protects UserIDToMutex - - db DeviceListUpdaterDatabase - api DeviceListUpdaterAPI - producer KeyChangeProducer - fedClient fedsenderapi.KeyserverFederationAPI - workerChans []chan gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName - - // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will - // block on or timeout via a select. - userIDToChan map[string]chan bool - userIDToChanMu *sync.Mutex - rsAPI rsapi.KeyserverRoomserverAPI -} - -// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. -// Useful for testing. -type DeviceListUpdaterDatabase interface { - // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. - // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - - // MarkDeviceListStale sets the stale bit for this user to isStale. - MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - - // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior - // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error - - // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) - - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - - DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error -} - -type DeviceListUpdaterAPI interface { - PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error -} - -// KeyChangeProducer is the interface for producers.KeyChange useful for testing. -type KeyChangeProducer interface { - ProduceKeyChanges(keys []api.DeviceMessage) error -} - -// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. -func NewDeviceListUpdater( - process *process.ProcessContext, db DeviceListUpdaterDatabase, - api DeviceListUpdaterAPI, producer KeyChangeProducer, - fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, -) *DeviceListUpdater { - return &DeviceListUpdater{ - process: process, - userIDToMutex: make(map[string]*sync.Mutex), - mu: &sync.Mutex{}, - db: db, - api: api, - producer: producer, - fedClient: fedClient, - thisServer: thisServer, - workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), - userIDToChan: make(map[string]chan bool), - userIDToChanMu: &sync.Mutex{}, - rsAPI: rsAPI, - } -} - -// Start the device list updater, which will try to refresh any stale device lists. -func (u *DeviceListUpdater) Start() error { - for i := 0; i < len(u.workerChans); i++ { - // Allocate a small buffer per channel. - // If the buffer limit is reached, backpressure will cause the processing of EDUs - // to stop (in this transaction) until key requests can be made. - ch := make(chan gomatrixserverlib.ServerName, 10) - u.workerChans[i] = ch - go u.worker(ch) - } - - staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) - if err != nil { - return err - } - offset, step := time.Second*10, time.Second - if max := len(staleLists); max > 120 { - step = (time.Second * 120) / time.Duration(max) - } - for _, userID := range staleLists { - userID := userID // otherwise we are only sending the last entry - time.AfterFunc(offset, func() { - u.notifyWorkers(userID) - }) - offset += step - } - return nil -} - -// CleanUp removes stale device entries for users we don't share a room with anymore -func (u *DeviceListUpdater) CleanUp() error { - staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) - if err != nil { - return err - } - - res := rsapi.QueryLeftUsersResponse{} - if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { - return err - } - - if len(res.LeftUsers) == 0 { - return nil - } - logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers)) - return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers) -} - -func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { - u.mu.Lock() - defer u.mu.Unlock() - if u.userIDToMutex[userID] == nil { - u.userIDToMutex[userID] = &sync.Mutex{} - } - return u.userIDToMutex[userID] -} - -// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it. -// Blocks until the device list is synced or the timeout is reached. -func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error { - mu := u.mutex(userID) - mu.Lock() - err := u.db.MarkDeviceListStale(ctx, userID, true) - mu.Unlock() - if err != nil { - return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err) - } - u.notifyWorkers(userID) - return nil -} - -// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest, -// which assumes when /send 200 OKs that the device lists have been updated. -func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error { - isDeviceListStale, err := u.update(ctx, event) - if err != nil { - return err - } - if isDeviceListStale { - // poke workers to handle stale device lists - u.notifyWorkers(event.UserID) - } - return nil -} - -func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) { - mu := u.mutex(event.UserID) - mu.Lock() - defer mu.Unlock() - // check if we have the prev IDs - exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID) - if err != nil { - return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) - } - // if this is the first time we're hearing about this user, sync the device list manually. - if len(event.PrevID) == 0 { - exists = false - } - util.GetLogger(ctx).WithFields(logrus.Fields{ - "prev_ids_exist": exists, - "user_id": event.UserID, - "device_id": event.DeviceID, - "stream_id": event.StreamID, - "prev_ids": event.PrevID, - "display_name": event.DeviceDisplayName, - "deleted": event.Deleted, - }).Trace("DeviceListUpdater.Update") - - // if we haven't missed anything update the database and notify users - if exists || event.Deleted { - k := event.Keys - if event.Deleted { - k = nil - } - keys := []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: event.DeviceID, - DisplayName: event.DeviceDisplayName, - KeyJSON: k, - UserID: event.UserID, - }, - StreamID: event.StreamID, - }, - } - - // DeviceKeysJSON will side-effect modify this, so it needs - // to be a copy, not sharing any pointers with the above. - deviceKeysCopy := *keys[0].DeviceKeys - deviceKeysCopy.KeyJSON = nil - existingKeys := []api.DeviceMessage{ - { - Type: keys[0].Type, - DeviceKeys: &deviceKeysCopy, - StreamID: keys[0].StreamID, - }, - } - - // fetch what keys we had already and only emit changes - if err = u.db.DeviceKeysJSON(ctx, existingKeys); err != nil { - // non-fatal, log and continue - util.GetLogger(ctx).WithError(err).WithField("user_id", event.UserID).Errorf( - "failed to query device keys json for calculating diffs", - ) - } - - err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil) - if err != nil { - return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err) - } - - if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil { - return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err) - } - return false, nil - } - - err = u.db.MarkDeviceListStale(ctx, event.UserID, true) - if err != nil { - return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err) - } - - return true, nil -} - -func (u *DeviceListUpdater) notifyWorkers(userID string) { - _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return - } - hash := fnv.New32a() - _, _ = hash.Write([]byte(remoteServer)) - index := int(int64(hash.Sum32()) % int64(len(u.workerChans))) - - ch := u.assignChannel(userID) - u.workerChans[index] <- remoteServer - select { - case <-ch: - case <-time.After(10 * time.Second): - // we don't return an error in this case as it's not a failure condition. - // we mainly block for the benefit of sytest anyway - } -} - -func (u *DeviceListUpdater) assignChannel(userID string) chan bool { - u.userIDToChanMu.Lock() - defer u.userIDToChanMu.Unlock() - if ch, ok := u.userIDToChan[userID]; ok { - return ch - } - ch := make(chan bool) - u.userIDToChan[userID] = ch - return ch -} - -func (u *DeviceListUpdater) clearChannel(userID string) { - u.userIDToChanMu.Lock() - defer u.userIDToChanMu.Unlock() - if ch, ok := u.userIDToChan[userID]; ok { - close(ch) - delete(u.userIDToChan, userID) - } -} - -func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { - retries := make(map[gomatrixserverlib.ServerName]time.Time) - retriesMu := &sync.Mutex{} - // restarter goroutine which will inject failed servers into ch when it is time - go func() { - var serversToRetry []gomatrixserverlib.ServerName - for { - serversToRetry = serversToRetry[:0] // reuse memory - time.Sleep(time.Second) - retriesMu.Lock() - now := time.Now() - for srv, retryAt := range retries { - if now.After(retryAt) { - serversToRetry = append(serversToRetry, srv) - } - } - for _, srv := range serversToRetry { - delete(retries, srv) - } - retriesMu.Unlock() - for _, srv := range serversToRetry { - ch <- srv - } - } - }() - for serverName := range ch { - retriesMu.Lock() - _, exists := retries[serverName] - retriesMu.Unlock() - if exists { - // Don't retry a server that we're already waiting for. - continue - } - waitTime, shouldRetry := u.processServer(serverName) - if shouldRetry { - retriesMu.Lock() - if _, exists = retries[serverName]; !exists { - retries[serverName] = time.Now().Add(waitTime) - } - retriesMu.Unlock() - } - } -} - -func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { - ctx := u.process.Context() - logger := util.GetLogger(ctx).WithField("server_name", serverName) - deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() - - waitTime := defaultWaitTime // How long should we wait to try again? - successCount := 0 // How many user requests failed? - - userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) - if err != nil { - logger.WithError(err).Error("Failed to load stale device lists") - return waitTime, true - } - - defer func() { - for _, userID := range userIDs { - // always clear the channel to unblock Update calls regardless of success/failure - u.clearChannel(userID) - } - }() - - for _, userID := range userIDs { - userWait, err := u.processServerUser(ctx, serverName, userID) - if err != nil { - if userWait > waitTime { - waitTime = userWait - } - break - } - successCount++ - } - - allUsersSucceeded := successCount == len(userIDs) - if !allUsersSucceeded { - logger.WithFields(logrus.Fields{ - "total": len(userIDs), - "succeeded": successCount, - "failed": len(userIDs) - successCount, - "wait_time": waitTime, - }).Debug("Failed to query device keys for some users") - } - return waitTime, !allUsersSucceeded -} - -func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) { - ctx, cancel := context.WithTimeout(ctx, requestTimeout) - defer cancel() - logger := util.GetLogger(ctx).WithFields(logrus.Fields{ - "server_name": serverName, - "user_id": userID, - }) - res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID) - if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - return time.Minute * 10, err - } - switch e := err.(type) { - case *json.UnmarshalTypeError, *json.SyntaxError: - logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID) - return defaultWaitTime, nil - case *fedsenderapi.FederationClientError: - if e.RetryAfter > 0 { - return e.RetryAfter, err - } else if e.Blacklisted { - return time.Hour * 8, err - } - case net.Error: - // Use the default waitTime, if it's a timeout. - // It probably doesn't make sense to try further users. - if !e.Timeout() { - logger.WithError(e).Debug("GetUserDevices returned net.Error") - return time.Minute * 10, err - } - case gomatrix.HTTPError: - // The remote server returned an error, give it some time to recover. - // This is to avoid spamming remote servers, which may not be Matrix servers anymore. - if e.Code >= 300 { - logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") - return hourWaitTime, err - } - default: - // Something else failed - logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err) - return time.Minute * 10, err - } - } - if res.UserID != userID { - logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID) - return defaultWaitTime, nil - } - if res.MasterKey != nil || res.SelfSigningKey != nil { - uploadReq := &api.PerformUploadDeviceKeysRequest{ - UserID: userID, - } - uploadRes := &api.PerformUploadDeviceKeysResponse{} - if res.MasterKey != nil { - if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil { - uploadReq.MasterKey = *res.MasterKey - } - } - if res.SelfSigningKey != nil { - if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil { - uploadReq.SelfSigningKey = *res.SelfSigningKey - } - } - _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) - } - err = u.updateDeviceList(&res) - if err != nil { - logger.WithError(err).Error("Fetched device list but failed to store/emit it") - return defaultWaitTime, err - } - return defaultWaitTime, nil -} - -func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { - ctx := context.Background() // we've got the keys, don't time out when persisting them to the database. - keys := make([]api.DeviceMessage, len(res.Devices)) - existingKeys := make([]api.DeviceMessage, len(res.Devices)) - for i, device := range res.Devices { - keyJSON, err := json.Marshal(device.Keys) - if err != nil { - util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device") - continue - } - keys[i] = api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - StreamID: res.StreamID, - DeviceKeys: &api.DeviceKeys{ - DeviceID: device.DeviceID, - DisplayName: device.DisplayName, - UserID: res.UserID, - KeyJSON: keyJSON, - }, - } - existingKeys[i] = api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: res.UserID, - DeviceID: device.DeviceID, - }, - } - } - // fetch what keys we had already and only emit changes - if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil { - // non-fatal, log and continue - util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf( - "failed to query device keys json for calculating diffs", - ) - } - - err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID}) - if err != nil { - return fmt.Errorf("failed to store remote device keys: %w", err) - } - err = u.db.MarkDeviceListStale(ctx, res.UserID, false) - if err != nil { - return fmt.Errorf("failed to mark device list as fresh: %w", err) - } - err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false) - if err != nil { - return fmt.Errorf("failed to emit key changes for fresh device list: %w", err) - } - return nil -} diff --git a/keyserver/internal/device_list_update_default.go b/keyserver/internal/device_list_update_default.go deleted file mode 100644 index 7d357c95..00000000 --- a/keyserver/internal/device_list_update_default.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !vw - -package internal - -import "time" - -const defaultWaitTime = time.Minute -const hourWaitTime = time.Hour diff --git a/keyserver/internal/device_list_update_sytest.go b/keyserver/internal/device_list_update_sytest.go deleted file mode 100644 index 1c60d2eb..00000000 --- a/keyserver/internal/device_list_update_sytest.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build vw - -package internal - -import "time" - -// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite -// results in a one-hour wait time from a previous device so the test times out. This is fine for -// production, but makes an otherwise passing test fail. -const defaultWaitTime = time.Second -const hourWaitTime = time.Second diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go deleted file mode 100644 index 60a2c2f3..00000000 --- a/keyserver/internal/device_list_update_test.go +++ /dev/null @@ -1,431 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "crypto/ed25519" - "fmt" - "io" - "net/http" - "net/url" - "reflect" - "strings" - "sync" - "testing" - "time" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -var ( - ctx = context.Background() -) - -type mockKeyChangeProducer struct { - events []api.DeviceMessage -} - -func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error { - p.events = append(p.events, keys...) - return nil -} - -type mockDeviceListUpdaterDatabase struct { - staleUsers map[string]bool - prevIDsExist func(string, []int64) bool - storedKeys []api.DeviceMessage - mu sync.Mutex // protect staleUsers -} - -func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error { - return nil -} - -// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. -// If no domains are given, all user IDs with stale device lists are returned. -func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - d.mu.Lock() - defer d.mu.Unlock() - var result []string - for userID, isStale := range d.staleUsers { - if !isStale { - continue - } - _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return nil, err - } - if len(domains) == 0 { - result = append(result, userID) - continue - } - for _, d := range domains { - if remoteServer == d { - result = append(result, userID) - break - } - } - } - return result, nil -} - -// MarkDeviceListStale sets the stale bit for this user to isStale. -func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - d.mu.Lock() - defer d.mu.Unlock() - d.staleUsers[userID] = isStale - return nil -} - -func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool { - d.mu.Lock() - defer d.mu.Unlock() - return d.staleUsers[userID] -} - -// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key -// for this (user, device). Does not modify the stream ID for keys. -func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error { - d.storedKeys = append(d.storedKeys, keys...) - return nil -} - -// PrevIDsExists returns true if all prev IDs exist for this user. -func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { - return d.prevIDsExist(userID, prevIDs), nil -} - -func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - return nil -} - -type mockDeviceListUpdaterAPI struct { -} - -func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { - return nil -} - -type roundTripper struct { - fn func(*http.Request) (*http.Response, error) -} - -func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - return t.fn(req) -} - -func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { - _, pkey, _ := ed25519.GenerateKey(nil) - fedClient := gomatrixserverlib.NewFederationClient( - []*gomatrixserverlib.SigningIdentity{ - { - ServerName: gomatrixserverlib.ServerName("example.test"), - KeyID: gomatrixserverlib.KeyID("ed25519:test"), - PrivateKey: pkey, - }, - }, - ) - fedClient.Client = *gomatrixserverlib.NewClient( - gomatrixserverlib.WithTransport(&roundTripper{tripper}), - ) - return fedClient -} - -// Test that the device keys get persisted and emitted if we have the previous IDs. -func TestUpdateHavePrevID(t *testing.T) { - db := &mockDeviceListUpdaterDatabase{ - staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int64) bool { - return true - }, - } - ap := &mockDeviceListUpdaterAPI{} - producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost") - event := gomatrixserverlib.DeviceListUpdateEvent{ - DeviceDisplayName: "Foo Bar", - Deleted: false, - DeviceID: "FOO", - Keys: []byte(`{"key":"value"}`), - PrevID: []int64{0}, - StreamID: 1, - UserID: "@alice:localhost", - } - err := updater.Update(ctx, event) - if err != nil { - t.Fatalf("Update returned an error: %s", err) - } - want := api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - StreamID: event.StreamID, - DeviceKeys: &api.DeviceKeys{ - DeviceID: event.DeviceID, - DisplayName: event.DeviceDisplayName, - KeyJSON: event.Keys, - UserID: event.UserID, - }, - } - if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { - t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) - } - if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { - t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) - } - if db.isStale(event.UserID) { - t.Errorf("%s incorrectly marked as stale", event.UserID) - } -} - -// Test that device keys are fetched from the remote server if we are missing prev IDs -// and that the user's devices are marked as stale until it succeeds. -func TestUpdateNoPrevID(t *testing.T) { - db := &mockDeviceListUpdaterDatabase{ - staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int64) bool { - return false - }, - } - ap := &mockDeviceListUpdaterAPI{} - producer := &mockKeyChangeProducer{} - remoteUserID := "@alice:example.somewhere" - var wg sync.WaitGroup - wg.Add(1) - keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` - fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { - defer wg.Done() - if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) { - return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) - } - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(` - { - "user_id": "` + remoteUserID + `", - "stream_id": 5, - "devices": [ - { - "device_id": "JLAFKJWSCS", - "keys": ` + keyJSON + `, - "device_display_name": "Mobile Phone" - } - ] - } - `)), - }, nil - }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test") - if err := updater.Start(); err != nil { - t.Fatalf("failed to start updater: %s", err) - } - event := gomatrixserverlib.DeviceListUpdateEvent{ - DeviceDisplayName: "Mobile Phone", - Deleted: false, - DeviceID: "another_device_id", - Keys: []byte(`{"key":"value"}`), - PrevID: []int64{3}, - StreamID: 4, - UserID: remoteUserID, - } - err := updater.Update(ctx, event) - - if err != nil { - t.Fatalf("Update returned an error: %s", err) - } - t.Log("waiting for /users/devices to be called...") - wg.Wait() - // wait a bit for db to be updated... - time.Sleep(100 * time.Millisecond) - want := api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - StreamID: 5, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "JLAFKJWSCS", - DisplayName: "Mobile Phone", - UserID: remoteUserID, - KeyJSON: []byte(keyJSON), - }, - } - // Now we should have a fresh list and the keys and emitted something - if db.isStale(event.UserID) { - t.Errorf("%s still marked as stale", event.UserID) - } - if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { - t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON)) - t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) - } - if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { - t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) - } - -} - -// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the -// update is still ongoing. -func TestDebounce(t *testing.T) { - t.Skipf("panic on closed channel on GHA") - db := &mockDeviceListUpdaterDatabase{ - staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int64) bool { - return true - }, - } - ap := &mockDeviceListUpdaterAPI{} - producer := &mockKeyChangeProducer{} - fedCh := make(chan *http.Response, 1) - srv := gomatrixserverlib.ServerName("example.com") - userID := "@alice:example.com" - keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` - incomingFedReq := make(chan struct{}) - fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { - if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) { - return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) - } - close(incomingFedReq) - return <-fedCh, nil - }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost") - if err := updater.Start(); err != nil { - t.Fatalf("failed to start updater: %s", err) - } - - // hit this 5 times - var wg sync.WaitGroup - wg.Add(5) - for i := 0; i < 5; i++ { - go func() { - defer wg.Done() - if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil { - t.Errorf("ManualUpdate: %s", err) - } - }() - } - - // wait until the updater hits federation - select { - case <-incomingFedReq: - case <-time.After(time.Second): - t.Fatalf("timed out waiting for updater to hit federation") - } - - // user should be marked as stale - if !db.isStale(userID) { - t.Errorf("user %s not marked as stale", userID) - } - // now send the response over federation - fedCh <- &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(` - { - "user_id": "` + userID + `", - "stream_id": 5, - "devices": [ - { - "device_id": "JLAFKJWSCS", - "keys": ` + keyJSON + `, - "device_display_name": "Mobile Phone" - } - ] - } - `)), - } - close(fedCh) - // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response - // and should panic with read on a closed channel - wg.Wait() - - // user is no longer stale now - if db.isStale(userID) { - t.Errorf("user %s is marked as stale", userID) - } -} - -func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { - t.Helper() - - base, _, _ := testrig.Base(nil) - connStr, clearDB := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) - if err != nil { - t.Fatal(err) - } - - return db, clearDB -} - -type mockKeyserverRoomserverAPI struct { - leftUsers []string -} - -func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { - res.LeftUsers = m.leftUsers - return nil -} - -func TestDeviceListUpdater_CleanUp(t *testing.T) { - processCtx := process.NewProcessContext() - - alice := test.NewUser(t) - bob := test.NewUser(t) - - // Bob is not joined to any of our rooms - rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}} - - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clearDB := mustCreateKeyserverDB(t, dbType) - defer clearDB() - - // This should not get deleted - if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil { - t.Error(err) - } - - // this one should get deleted - if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil { - t.Error(err) - } - - updater := NewDeviceListUpdater(processCtx, db, nil, - nil, nil, - 0, rsAPI, "test") - if err := updater.CleanUp(); err != nil { - t.Error(err) - } - - // check that we still have Alice in our stale list - staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) - if err != nil { - t.Error(err) - } - - // There should only be Alice - wantCount := 1 - if count := len(staleUsers); count != wantCount { - t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count) - } - - if staleUsers[0] != alice.ID { - t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID) - } - }) -} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go deleted file mode 100644 index 9a08a0bb..00000000 --- a/keyserver/internal/internal.go +++ /dev/null @@ -1,816 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "sync" - "time" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -type KeyInternalAPI struct { - DB storage.Database - Cfg *config.KeyServer - FedClient fedsenderapi.KeyserverFederationAPI - UserAPI userapi.KeyserverUserAPI - Producer *producers.KeyChange - Updater *DeviceListUpdater -} - -func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { - a.UserAPI = i -} - -func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { - userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) - if err != nil { - res.Error = &api.KeyError{ - Err: err.Error(), - } - return nil - } - res.Offset = latest - res.UserIDs = userIDs - return nil -} - -func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { - res.KeyErrors = make(map[string]map[string]*api.KeyError) - if len(req.DeviceKeys) > 0 { - a.uploadLocalDeviceKeys(ctx, req, res) - } - if len(req.OneTimeKeys) > 0 { - a.uploadOneTimeKeys(ctx, req, res) - } - otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) - if err != nil { - return err - } - res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks} - return nil -} - -func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { - res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) - res.Failures = make(map[string]interface{}) - // wrap request map in a top-level by-domain map - domainToDeviceKeys := make(map[string]map[string]map[string]string) - for userID, val := range req.OneTimeKeys { - _, serverName, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - continue // ignore invalid users - } - nested, ok := domainToDeviceKeys[string(serverName)] - if !ok { - nested = make(map[string]map[string]string) - } - nested[userID] = val - domainToDeviceKeys[string(serverName)] = nested - } - for domain, local := range domainToDeviceKeys { - if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { - continue - } - // claim local keys - keys, err := a.DB.ClaimKeys(ctx, local) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), - } - } - util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys") - for _, key := range keys { - _, ok := res.OneTimeKeys[key.UserID] - if !ok { - res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage) - } - _, ok = res.OneTimeKeys[key.UserID][key.DeviceID] - if !ok { - res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) - } - for keyID, keyJSON := range key.KeyJSON { - res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON - } - } - delete(domainToDeviceKeys, domain) - } - if len(domainToDeviceKeys) > 0 { - a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) - } - return nil -} - -func (a *KeyInternalAPI) claimRemoteKeys( - ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, -) { - var wg sync.WaitGroup // Wait for fan-out goroutines to finish - var mu sync.Mutex // Protects the response struct - var claimed int // Number of keys claimed in total - var failures int // Number of servers we failed to ask - - util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys)) - wg.Add(len(domainToDeviceKeys)) - - for d, k := range domainToDeviceKeys { - go func(domain string, keysToClaim map[string]map[string]string) { - fedCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - defer wg.Done() - - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) - - mu.Lock() - defer mu.Unlock() - - if err != nil { - util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") - res.Failures[domain] = map[string]interface{}{ - "message": err.Error(), - } - failures++ - return - } - - for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys { - res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) - for deviceID, keys := range deviceIDToKeys { - res.OneTimeKeys[userID][deviceID] = keys - claimed += len(keys) - } - } - }(d, k) - } - - wg.Wait() - util.GetLogger(ctx).WithFields(logrus.Fields{ - "num_keys": claimed, - "num_failures": failures, - }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys)) -} - -func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { - if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("Failed to delete device keys: %s", err), - } - } - return nil -} - -func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { - count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("Failed to query OTK counts: %s", err), - } - return nil - } - res.Count = *count - return nil -} - -func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query DB for device keys: %s", err), - } - return nil - } - maxStreamID := int64(0) - // remove deleted devices - var result []api.DeviceMessage - for _, m := range msgs { - if m.StreamID > maxStreamID { - maxStreamID = m.StreamID - } - if m.KeyJSON == nil || len(m.KeyJSON) == 0 { - continue - } - result = append(result, m) - } - res.Devices = result - res.StreamID = maxStreamID - return nil -} - -// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present -// in our database. -func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { - knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true) - if err != nil { - return err - } - if len(knownDevices) == 0 { - return nil // fmt.Errorf("unknown user %s", req.UserID) - } - - for i := range knownDevices { - if knownDevices[i].DeviceID == req.DeviceID { - return nil // we already know about this device - } - } - - return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) -} - -// nolint:gocyclo -func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { - var respMu sync.Mutex - res.DeviceKeys = make(map[string]map[string]json.RawMessage) - res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) - res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) - res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) - res.Failures = make(map[string]interface{}) - - // make a map from domain to device keys - domainToDeviceKeys := make(map[string]map[string][]string) - domainToCrossSigningKeys := make(map[string]map[string]struct{}) - for userID, deviceIDs := range req.UserToDevices { - _, serverName, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - continue // ignore invalid users - } - domain := string(serverName) - // query local devices - if a.Cfg.Matrix.IsLocalServerName(serverName) { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query local device keys: %s", err), - } - return nil - } - - // pull out display names after we have the keys so we handle wildcards correctly - var dids []string - for _, dk := range deviceKeys { - dids = append(dids, dk.DeviceID) - } - var queryRes userapi.QueryDeviceInfosResponse - err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ - DeviceIDs: dids, - }, &queryRes) - if err != nil { - util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") - } - - if res.DeviceKeys[userID] == nil { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) - } - for _, dk := range deviceKeys { - if len(dk.KeyJSON) == 0 { - continue // don't include blank keys - } - // inject display name if known (either locally or remotely) - displayName := dk.DisplayName - if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { - displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName - } - dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { - DisplayName string `json:"device_display_name,omitempty"` - }{displayName}) - res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON - } - } else { - domainToDeviceKeys[domain] = make(map[string][]string) - domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) - } - // work out if our cross-signing request for this user was - // satisfied, if not add them to the list of things to fetch - if _, ok := res.MasterKeys[userID]; !ok { - if _, ok := domainToCrossSigningKeys[domain]; !ok { - domainToCrossSigningKeys[domain] = make(map[string]struct{}) - } - domainToCrossSigningKeys[domain][userID] = struct{}{} - } - if _, ok := res.SelfSigningKeys[userID]; !ok { - if _, ok := domainToCrossSigningKeys[domain]; !ok { - domainToCrossSigningKeys[domain] = make(map[string]struct{}) - } - domainToCrossSigningKeys[domain][userID] = struct{}{} - } - } - - // attempt to satisfy key queries from the local database first as we should get device updates pushed to us - domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys) - if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { - // perform key queries for remote devices - a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) - } - - // Now that we've done the potentially expensive work of asking the federation, - // try filling the cross-signing keys from the database that we know about. - a.crossSigningKeysFromDatabase(ctx, req, res) - - // Finally, append signatures that we know about - // TODO: This is horrible because we need to round-trip the signature from - // JSON, add the signatures and marshal it again, for some reason? - - for targetUserID, masterKey := range res.MasterKeys { - if masterKey.Signatures == nil { - masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - for targetKeyID := range masterKey.Keys { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) - if err != nil { - // Stop executing the function if the context was canceled/the deadline was exceeded, - // as we can't continue without a valid context. - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return nil - } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") - continue - } - if len(sigMap) == 0 { - continue - } - for sourceUserID, forSourceUser := range sigMap { - for sourceKeyID, sourceSig := range forSourceUser { - if _, ok := masterKey.Signatures[sourceUserID]; !ok { - masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig - } - } - } - } - - for targetUserID, forUserID := range res.DeviceKeys { - for targetKeyID, key := range forUserID { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) - if err != nil { - // Stop executing the function if the context was canceled/the deadline was exceeded, - // as we can't continue without a valid context. - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return nil - } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") - continue - } - if len(sigMap) == 0 { - continue - } - var deviceKey gomatrixserverlib.DeviceKeys - if err = json.Unmarshal(key, &deviceKey); err != nil { - continue - } - if deviceKey.Signatures == nil { - deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - for sourceUserID, forSourceUser := range sigMap { - for sourceKeyID, sourceSig := range forSourceUser { - if _, ok := deviceKey.Signatures[sourceUserID]; !ok { - deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig - } - } - if js, err := json.Marshal(deviceKey); err == nil { - res.DeviceKeys[targetUserID][targetKeyID] = js - } - } - } - return nil -} - -func (a *KeyInternalAPI) remoteKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, -) map[string]map[string][]string { - fetchRemote := make(map[string]map[string][]string) - for domain, userToDeviceMap := range domainToDeviceKeys { - for userID, deviceIDs := range userToDeviceMap { - // we can't safely return keys from the db when all devices are requested as we don't - // know if one has just been added. - if len(deviceIDs) > 0 { - err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs) - if err == nil { - continue - } - util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase") - } - // fetch device lists from remote - if _, ok := fetchRemote[domain]; !ok { - fetchRemote[domain] = make(map[string][]string) - } - fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...) - - } - } - return fetchRemote -} - -func (a *KeyInternalAPI) queryRemoteKeys( - ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, - domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{}, -) { - resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys)) - // allows us to wait until all federation servers have been poked - var wg sync.WaitGroup - // mutex for writing directly to res (e.g failures) - var respMu sync.Mutex - - domains := map[string]struct{}{} - for domain := range domainToDeviceKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { - continue - } - domains[domain] = struct{}{} - } - for domain := range domainToCrossSigningKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { - continue - } - domains[domain] = struct{}{} - } - wg.Add(len(domains)) - - // fan out - for domain := range domains { - go a.queryRemoteKeysOnServer( - ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain], - &wg, &respMu, timeout, resultCh, res, - ) - } - - // Close the result channel when the goroutines have quit so the for .. range exits - go func() { - wg.Wait() - close(resultCh) - }() - - processResult := func(result *gomatrixserverlib.RespQueryKeys) { - respMu.Lock() - defer respMu.Unlock() - for userID, nest := range result.DeviceKeys { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) - for deviceID, deviceKey := range nest { - keyJSON, err := json.Marshal(deviceKey) - if err != nil { - continue - } - res.DeviceKeys[userID][deviceID] = keyJSON - } - } - - for userID, body := range result.MasterKeys { - res.MasterKeys[userID] = body - } - - for userID, body := range result.SelfSigningKeys { - res.SelfSigningKeys[userID] = body - } - - // TODO: do we want to persist these somewhere now - // that we have fetched them? - } - - for result := range resultCh { - processResult(result) - } -} - -func (a *KeyInternalAPI) queryRemoteKeysOnServer( - ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{}, - wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, - res *api.QueryKeysResponse, -) { - defer wg.Done() - fedCtx := ctx - if timeout > 0 { - var cancel context.CancelFunc - fedCtx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - // for users who we do not have any knowledge about, try to start doing device list updates for them - // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but - // lack a stream ID. - userIDsForAllDevices := map[string]struct{}{} - for userID, deviceIDs := range devKeys { - if len(deviceIDs) == 0 { - userIDsForAllDevices[userID] = struct{}{} - } - } - // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing - // a device list update, so we'll populate those back into the /keys/query list if not - for userID := range crossSigningKeys { - if devKeys == nil { - devKeys = map[string][]string{} - } - if _, ok := userIDsForAllDevices[userID]; !ok { - devKeys[userID] = []string{} - } - } - for userID := range userIDsForAllDevices { - err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID) - if err != nil { - logrus.WithFields(logrus.Fields{ - logrus.ErrorKey: err, - "user_id": userID, - "server": serverName, - }).Error("Failed to manually update device lists for user") - // try to do it via /keys/query - devKeys[userID] = []string{} - continue - } - // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this - // user so the fact that we're populating all devices here isn't a problem so long as we have devices. - err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) - if err != nil { - logrus.WithFields(logrus.Fields{ - logrus.ErrorKey: err, - "user_id": userID, - "server": serverName, - }).Error("Failed to manually update device lists for user") - // try to do it via /keys/query - devKeys[userID] = []string{} - continue - } - } - if len(devKeys) == 0 { - return - } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) - if err == nil { - resultCh <- &queryKeysResp - return - } - respMu.Lock() - res.Failures[serverName] = map[string]interface{}{ - "message": err.Error(), - } - respMu.Unlock() - - // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server - // is down, better to return something than nothing at all. Clients can know about the failure by - // inspecting the failures map though so they can know it's a cached response. - for userID, dkeys := range devKeys { - // drop the error as it's already a failure at this point - _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys) - } - - // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache - respMu.Lock() - if len(res.DeviceKeys) > 0 { - delete(res.Failures, serverName) - } - respMu.Unlock() -} - -func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, -) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) - // if we can't query the db or there are fewer keys than requested, fetch from remote. - if err != nil { - return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) - } - if len(keys) < len(deviceIDs) { - return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID) - } - if len(deviceIDs) == 0 && len(keys) == 0 { - return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) - } - respMu.Lock() - if res.DeviceKeys[userID] == nil { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) - } - respMu.Unlock() - - for _, key := range keys { - if len(key.KeyJSON) == 0 { - continue // ignore deleted keys - } - // inject the display name - key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { - DisplayName string `json:"device_display_name,omitempty"` - }{key.DisplayName}) - respMu.Lock() - res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON - respMu.Unlock() - } - return nil -} - -func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { - // get a list of devices from the user API that actually exist, as - // we won't store keys for devices that don't exist - uapidevices := &userapi.QueryDevicesResponse{} - if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { - res.Error = &api.KeyError{ - Err: err.Error(), - } - return - } - if !uapidevices.UserExists { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("user %q does not exist", req.UserID), - } - return - } - existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices)) - for _, key := range uapidevices.Devices { - existingDeviceMap[key.ID] = struct{}{} - } - - // Get all of the user existing device keys so we can check for changes. - existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), - } - return - } - - // Work out whether we have device keys in the keyserver for devices that - // no longer exist in the user API. This is mostly an exercise to ensure - // that we keep some integrity between the two. - var toClean []gomatrixserverlib.KeyID - for _, k := range existingKeys { - if _, ok := existingDeviceMap[k.DeviceID]; !ok { - toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID)) - } - } - - if len(toClean) > 0 { - if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { - logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) - } else { - logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) - } - } - - var keysToStore []api.DeviceMessage - - if req.OnlyDisplayNameUpdates { - for _, existingKey := range existingKeys { - for _, newKey := range req.DeviceKeys { - switch { - case existingKey.UserID != newKey.UserID: - continue - case existingKey.DeviceID != newKey.DeviceID: - continue - case existingKey.DisplayName != newKey.DisplayName: - existingKey.DisplayName = newKey.DisplayName - } - } - keysToStore = append(keysToStore, existingKey) - } - } else { - // assert that the user ID / device ID are not lying for each key - for _, key := range req.DeviceKeys { - var serverName gomatrixserverlib.ServerName - _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) - if err != nil { - continue // ignore invalid users - } - if !a.Cfg.Matrix.IsLocalServerName(serverName) { - continue // ignore remote users - } - if len(key.KeyJSON) == 0 { - keysToStore = append(keysToStore, key.WithStreamID(0)) - continue // deleted keys don't need sanity checking - } - // check that the device in question actually exists in the user - // API before we try and store a key for it - if _, ok := existingDeviceMap[key.DeviceID]; !ok { - continue - } - gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str - gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str - if gotUserID == key.UserID && gotDeviceID == key.DeviceID { - keysToStore = append(keysToStore, key.WithStreamID(0)) - continue - } - - res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ - Err: fmt.Sprintf( - "user_id or device_id mismatch: users: %s - %s, devices: %s - %s", - gotUserID, key.UserID, gotDeviceID, key.DeviceID, - ), - }) - } - } - - // store the device keys and emit changes - err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), - } - return - } - err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) - if err != nil { - util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) - } -} - -func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { - if req.UserID == "" { - res.Error = &api.KeyError{ - Err: "user ID missing", - } - } - if req.DeviceID != "" && len(req.OneTimeKeys) == 0 { - counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err), - } - } - if counts != nil { - res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) - } - return - } - for _, key := range req.OneTimeKeys { - // grab existing keys based on (user/device/algorithm/key ID) - keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) - i := 0 - for keyIDWithAlgo := range key.KeyJSON { - keyIDsWithAlgorithms[i] = keyIDWithAlgo - i++ - } - existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) - if err != nil { - res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ - Err: "failed to query existing one-time keys: " + err.Error(), - }) - continue - } - for keyIDWithAlgo := range existingKeys { - // if keys exist and the JSON doesn't match, error out as the key already exists - if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { - res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ - Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo), - }) - continue - } - } - // store one-time keys - counts, err := a.DB.StoreOneTimeKeys(ctx, key) - if err != nil { - res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ - Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()), - }) - continue - } - // collect counts - res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) - } - -} - -func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { - // if we only want to update the display names, we can skip the checks below - if onlyUpdateDisplayName { - return producer.ProduceKeyChanges(new) - } - // find keys in new that are not in existing - var keysAdded []api.DeviceMessage - for _, newKey := range new { - exists := false - for _, existingKey := range existing { - // Do not treat the absence of keys as equal, or else we will not emit key changes - // when users delete devices which never had a key to begin with as both KeyJSONs are nil. - if existingKey.DeviceKeysEqual(&newKey) { - exists = true - break - } - } - if !exists { - keysAdded = append(keysAdded, newKey) - } - } - return producer.ProduceKeyChanges(keysAdded) -} diff --git a/keyserver/internal/internal_test.go b/keyserver/internal/internal_test.go deleted file mode 100644 index 8a2c9c5d..00000000 --- a/keyserver/internal/internal_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package internal_test - -import ( - "context" - "reflect" - "testing" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/test" -) - -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - t.Helper() - connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(nil, &config.DatabaseOptions{ - ConnectionString: config.DataSource(connStr), - }) - if err != nil { - t.Fatalf("failed to create new user db: %v", err) - } - return db, close -} - -func Test_QueryDeviceMessages(t *testing.T) { - alice := test.NewUser(t) - type args struct { - req *api.QueryDeviceMessagesRequest - res *api.QueryDeviceMessagesResponse - } - tests := []struct { - name string - args args - wantErr bool - want *api.QueryDeviceMessagesResponse - }{ - { - name: "no existing keys", - args: args{ - req: &api.QueryDeviceMessagesRequest{ - UserID: "@doesNotExist:localhost", - }, - res: &api.QueryDeviceMessagesResponse{}, - }, - want: &api.QueryDeviceMessagesResponse{}, - }, - { - name: "existing user returns devices", - args: args{ - req: &api.QueryDeviceMessagesRequest{ - UserID: alice.ID, - }, - res: &api.QueryDeviceMessagesResponse{}, - }, - want: &api.QueryDeviceMessagesResponse{ - StreamID: 6, - Devices: []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - DisplayName: "first device", - UserID: alice.ID, - KeyJSON: []byte("ghi"), - }, - }, - { - Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{ - DeviceID: "mySecondDevice", - DisplayName: "second device", - UserID: alice.ID, - KeyJSON: []byte("jkl"), - }, // streamID 6 - }, - }, - }, - }, - } - - deviceMessages := []api.DeviceMessage{ - { // not the user we're looking for - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - UserID: "@doesNotExist:localhost", - }, - // streamID 1 for this user - }, - { // empty keyJSON will be ignored - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - UserID: alice.ID, - }, // streamID 1 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - UserID: alice.ID, - KeyJSON: []byte("abc"), - }, // streamID 2 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - UserID: alice.ID, - KeyJSON: []byte("def"), - }, // streamID 3 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - UserID: alice.ID, - KeyJSON: []byte(""), - }, // streamID 4 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - DisplayName: "first device", - UserID: alice.ID, - KeyJSON: []byte("ghi"), - }, // streamID 5 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "mySecondDevice", - UserID: alice.ID, - KeyJSON: []byte("jkl"), - DisplayName: "second device", - }, // streamID 6 - }, - } - ctx := context.Background() - - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, closeDB := mustCreateDatabase(t, dbType) - defer closeDB() - if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil { - t.Fatalf("failed to store local devicesKeys") - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := &internal.KeyInternalAPI{ - DB: db, - } - if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr { - t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr) - } - got := tt.args.res - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want) - } - }) - } - }) -} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go deleted file mode 100644 index 2d143682..00000000 --- a/keyserver/keyserver.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package keyserver - -import ( - "github.com/sirupsen/logrus" - - rsapi "github.com/matrix-org/dendrite/roomserver/api" - - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/consumers" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" -) - -// NewInternalAPI returns a concerete implementation of the internal API. Callers -// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI( - base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, - rsAPI rsapi.KeyserverRoomserverAPI, -) api.KeyInternalAPI { - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - - db, err := storage.NewDatabase(base, &cfg.Database) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to key server database") - } - - keyChangeProducer := &producers.KeyChange{ - Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), - JetStream: js, - DB: db, - } - ap := &internal.KeyInternalAPI{ - DB: db, - Cfg: cfg, - FedClient: fedClient, - Producer: keyChangeProducer, - } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable - ap.Updater = updater - - // Remove users which we don't share a room with anymore - if err := updater.CleanUp(); err != nil { - logrus.WithError(err).Error("failed to cleanup stale device lists") - } - - go func() { - if err := updater.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start device list updater") - } - }() - - dlConsumer := consumers.NewDeviceListUpdateConsumer( - base.ProcessContext, cfg, js, updater, - ) - if err := dlConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start device list consumer") - } - - sigConsumer := consumers.NewSigningKeyUpdateConsumer( - base.ProcessContext, cfg, js, ap, - ) - if err := sigConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start signing key consumer") - } - - return ap -} diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go deleted file mode 100644 index 159b280f..00000000 --- a/keyserver/keyserver_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package keyserver - -import ( - "context" - "testing" - - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -type mockKeyserverRoomserverAPI struct { - leftUsers []string -} - -func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { - res.LeftUsers = m.leftUsers - return nil -} - -// Merely tests that we can create an internal keyserver API -func Test_NewInternalAPI(t *testing.T) { - rsAPI := &mockKeyserverRoomserverAPI{} - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, closeBase := testrig.CreateBaseDendrite(t, dbType) - defer closeBase() - _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - }) -} diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go deleted file mode 100644 index f86c3417..00000000 --- a/keyserver/producers/keychange.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" -) - -// KeyChange produces key change events for the sync API and federation sender to consume -type KeyChange struct { - Topic string - JetStream nats.JetStreamContext - DB storage.Database -} - -// ProduceKeyChanges creates new change events for each key -func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { - userToDeviceCount := make(map[string]int) - for _, key := range keys { - id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) - if err != nil { - return err - } - key.DeviceChangeID = id - value, err := json.Marshal(key) - if err != nil { - return err - } - - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.UserID, key.UserID) - m.Data = value - - _, err = p.JetStream.PublishMsg(m) - if err != nil { - return err - } - - userToDeviceCount[key.UserID]++ - } - for userID, count := range userToDeviceCount { - logrus.WithFields(logrus.Fields{ - "user_id": userID, - "num_key_changes": count, - }).Tracef("Produced to key change topic '%s'", p.Topic) - } - return nil -} - -func (p *KeyChange) ProduceSigningKeyUpdate(key api.CrossSigningKeyUpdate) error { - output := &api.DeviceMessage{ - Type: api.TypeCrossSigningUpdate, - OutputCrossSigningKeyUpdate: &api.OutputCrossSigningKeyUpdate{ - CrossSigningKeyUpdate: key, - }, - } - - id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) - if err != nil { - return err - } - output.DeviceChangeID = id - - value, err := json.Marshal(output) - if err != nil { - return err - } - - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.UserID, key.UserID) - m.Data = value - - _, err = p.JetStream.PublishMsg(m) - if err != nil { - return err - } - - logrus.WithFields(logrus.Fields{ - "user_id": key.UserID, - }).Tracef("Produced to cross-signing update topic '%s'", p.Topic) - return nil -} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go deleted file mode 100644 index c6a8f44c..00000000 --- a/keyserver/storage/interface.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database interface { - // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination - // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. - ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - - // StoreOneTimeKeys persists the given one-time keys. - StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - - // OneTimeKeysCount returns a count of all OTKs for this device. - OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - - // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). - // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. - // Returns an error if there was a problem storing the keys. - StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error - - // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior - // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error - - // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) - - // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. - // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - - // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying - // cross-signing signatures relating to that device. - DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error - - // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key - // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. - ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) - - // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. - // `userID` is the the user who has changed their keys in some way. - StoreKeyChange(ctx context.Context, userID string) (int64, error) - - // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). - // A to offset of types.OffsetNewest means no upper limit. - // Returns the offset of the latest key change. - KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. - // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - - // MarkDeviceListStale sets the stale bit for this user to isStale. - MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - - CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) - CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) - CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) - - StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error - StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error - - DeleteStaleDeviceLists( - ctx context.Context, - userIDs []string, - ) error -} diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/keyserver/storage/postgres/cross_signing_keys_table.go deleted file mode 100644 index 1022157e..00000000 --- a/keyserver/storage/postgres/cross_signing_keys_table.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -var crossSigningKeysSchema = ` -CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( - user_id TEXT NOT NULL, - key_type SMALLINT NOT NULL, - key_data TEXT NOT NULL, - PRIMARY KEY (user_id, key_type) -); -` - -const selectCrossSigningKeysForUserSQL = "" + - "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + - " WHERE user_id = $1" - -const upsertCrossSigningKeysForUserSQL = "" + - "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + - " VALUES($1, $2, $3)" + - " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" - -type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt -} - -func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { - s := &crossSigningKeysStatements{ - db: db, - } - _, err := db.Exec(crossSigningKeysSchema) - if err != nil { - return nil, err - } - return s, sqlutil.StatementList{ - {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, - {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, - }.Prepare(db) -} - -func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, -) (r types.CrossSigningKeyMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") - r = types.CrossSigningKeyMap{} - for rows.Next() { - var keyTypeInt int16 - var keyData gomatrixserverlib.Base64Bytes - if err := rows.Scan(&keyTypeInt, &keyData); err != nil { - return nil, err - } - keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] - if !ok { - return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) - } - r[keyType] = keyData - } - return -} - -func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, -) error { - keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] - if !ok { - return fmt.Errorf("unknown key purpose %q", keyType) - } - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { - return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) - } - return nil -} diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go deleted file mode 100644 index 4536b7d8..00000000 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -var crossSigningSigsSchema = ` -CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( - origin_user_id TEXT NOT NULL, - origin_key_id TEXT NOT NULL, - target_user_id TEXT NOT NULL, - target_key_id TEXT NOT NULL, - signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) -); - -CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); -` - -const selectCrossSigningSigsForTargetSQL = "" + - "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3" - -const upsertCrossSigningSigsForTargetSQL = "" + - "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + - " VALUES($1, $2, $3, $4, $5)" + - " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5" - -const deleteCrossSigningSigsForTargetSQL = "" + - "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" - -type crossSigningSigsStatements struct { - db *sql.DB - selectCrossSigningSigsForTargetStmt *sql.Stmt - upsertCrossSigningSigsForTargetStmt *sql.Stmt - deleteCrossSigningSigsForTargetStmt *sql.Stmt -} - -func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { - s := &crossSigningSigsStatements{ - db: db, - } - _, err := db.Exec(crossSigningSigsSchema) - if err != nil { - return nil, err - } - - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "keyserver: cross signing signature indexes", - Up: deltas.UpFixCrossSigningSignatureIndexes, - }) - if err = m.Up(context.Background()); err != nil { - return nil, err - } - - return s, sqlutil.StatementList{ - {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, - {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, - {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, - }.Prepare(db) -} - -func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) (r types.CrossSigningSigMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") - r = types.CrossSigningSigMap{} - for rows.Next() { - var userID string - var keyID gomatrixserverlib.KeyID - var signature gomatrixserverlib.Base64Bytes - if err := rows.Scan(&userID, &keyID, &signature); err != nil { - return nil, err - } - if _, ok := r[userID]; !ok { - r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - r[userID][keyID] = signature - } - return -} - -func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, - originUserID string, originKeyID gomatrixserverlib.KeyID, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, -) error { - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) - } - return nil -} - -func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) error { - if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { - return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) - } - return nil -} diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go deleted file mode 100644 index 0cfe9e79..00000000 --- a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { - // start counting from the last max offset, else 0. We need to do a count(*) first to see if there - // even are entries in this table to know if we can query for log_offset. Without the count then - // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't - // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/ - var count int - _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count) - if count > 0 { - var maxOffset int64 - _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset) - if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil { - return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err) - } - } - - _, err := tx.ExecContext(ctx, ` - -- make the new table - DROP TABLE IF EXISTS keyserver_key_changes; - CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), - user_id TEXT NOT NULL, - CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) - ); - `) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - return nil -} - -func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers - DROP SEQUENCE IF EXISTS keyserver_key_changes_seq; - DROP TABLE IF EXISTS keyserver_key_changes; - CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - partition BIGINT NOT NULL, - log_offset BIGINT NOT NULL, - user_id TEXT NOT NULL, - CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) - ); - `) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go deleted file mode 100644 index 1a3d4fee..00000000 --- a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; - ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id); - - CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); - `) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - return nil -} - -func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; - ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id); - - DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx; - `) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go deleted file mode 100644 index 2aa11c52..00000000 --- a/keyserver/storage/postgres/device_keys_table.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "time" - - "github.com/lib/pq" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var deviceKeysSchema = ` --- Stores device keys for users -CREATE TABLE IF NOT EXISTS keyserver_device_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - -- the stream ID of this key, scoped per-user. This gets updated when the device key changes. - -- This means we do not store an unbounded append-only log of device keys, which is not actually - -- required in the spec because in the event of a missed update the server fetches the entire - -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. - stream_id BIGINT NOT NULL, - display_name TEXT, - -- Clobber based on tuple of user/device. - CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) -); -` - -const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" - -const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" - -const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" - -const selectBatchDeviceKeysWithEmptiesSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" - -const selectMaxStreamForUserSQL = "" + - "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" - -const countStreamIDsForUserSQL = "" + - "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" - -const deleteDeviceKeysSQL = "" + - "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" - -const deleteAllDeviceKeysSQL = "" + - "DELETE FROM keyserver_device_keys WHERE user_id=$1" - -type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - countStreamIDsForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt -} - -func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { - s := &deviceKeysStatements{ - db: db, - } - _, err := db.Exec(deviceKeysSchema) - if err != nil { - return nil, err - } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - for i, key := range keys { - var keyJSONStr string - var streamID int64 - var displayName sql.NullString - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) - if err != nil && err != sql.ErrNoRows { - return err - } - // this will be '' when there is no device - keys[i].Type = api.TypeDeviceKeyUpdate - keys[i].KeyJSON = []byte(keyJSONStr) - keys[i].StreamID = streamID - if displayName.Valid { - keys[i].DisplayName = displayName.String - } - } - return nil -} - -func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { - // nullable if there are no results - var nullStream sql.NullInt64 - err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) - if err == sql.ErrNoRows { - err = nil - } - if nullStream.Valid { - streamID = nullStream.Int64 - } - return -} - -func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { - // nullable if there are no results - var count sql.NullInt32 - err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count) - if err != nil { - return 0, err - } - if count.Valid { - return int(count.Int32), nil - } - return 0, nil -} - -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { - for _, key := range keys { - now := time.Now().Unix() - _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, - ) - if err != nil { - return err - } - } - return nil -} - -func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) - return err -} - -func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) - return err -} - -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { - var stmt *sql.Stmt - if includeEmpty { - stmt = s.selectBatchDeviceKeysWithEmptiesStmt - } else { - stmt = s.selectBatchDeviceKeysStmt - } - rows, err := stmt.QueryContext(ctx, userID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") - deviceIDMap := make(map[string]bool) - for _, d := range deviceIDs { - deviceIDMap[d] = true - } - var result []api.DeviceMessage - var displayName sql.NullString - for rows.Next() { - dk := api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: userID, - }, - } - if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil { - return nil, err - } - if displayName.Valid { - dk.DisplayName = displayName.String - } - // include the key if we want all keys (no device) or it was asked - if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { - result = append(result, dk) - } - } - return result, rows.Err() -} diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go deleted file mode 100644 index c0e3429c..00000000 --- a/keyserver/storage/postgres/key_changes_table.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var keyChangesSchema = ` --- Stores key change information about users. Used to determine when to send updated device lists to clients. -CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq; -CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), - user_id TEXT NOT NULL, - CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) -); -` - -// Replace based on user ID. We don't care how many times the user's keys have changed, only that they -// have changed, hence we can just keep bumping the change ID for this user. -const upsertKeyChangeSQL = "" + - "INSERT INTO keyserver_key_changes (user_id)" + - " VALUES ($1)" + - " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" + - " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" + - " RETURNING change_id" - -const selectKeyChangesSQL = "" + - "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" - -type keyChangesStatements struct { - db *sql.DB - upsertKeyChangeStmt *sql.Stmt - selectKeyChangesStmt *sql.Stmt -} - -func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { - s := &keyChangesStatements{ - db: db, - } - _, err := db.Exec(keyChangesSchema) - if err != nil { - return s, err - } - - if err = executeMigration(context.Background(), db); err != nil { - return nil, err - } - return s, nil -} - -func executeMigration(ctx context.Context, db *sql.DB) error { - // TODO: Remove when we are sure we are not having goose artefacts in the db - // This forces an error, which indicates the migration is already applied, since the - // column partition was removed from the table - migrationName := "keyserver: refactor key changes" - - var cName string - err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed - if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { - return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) - } - return nil - } - return err - } - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: migrationName, - Up: deltas.UpRefactorKeyChanges, - }) - - return m.Up(ctx) -} - -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - -func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { - err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) - return -} - -func (s *keyChangesStatements) SelectKeyChanges( - ctx context.Context, fromOffset, toOffset int64, -) (userIDs []string, latestOffset int64, err error) { - latestOffset = fromOffset - rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) - if err != nil { - return nil, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") - for rows.Next() { - var userID string - var offset int64 - if err := rows.Scan(&userID, &offset); err != nil { - return nil, 0, err - } - if offset > latestOffset { - latestOffset = offset - } - userIDs = append(userIDs, userID) - } - return -} diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go deleted file mode 100644 index 2117efca..00000000 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "encoding/json" - "time" - - "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var oneTimeKeysSchema = ` --- Stores one-time public keys for users -CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - key_id TEXT NOT NULL, - algorithm TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - -- Clobber based on 4-uple of user/device/key/algorithm. - CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm) -); - -CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); -` - -const upsertKeysSQL = "" + - "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" + - " DO UPDATE SET key_json = $6" - -const selectKeysSQL = "" + - "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" - -const selectKeysCountSQL = "" + - "SELECT algorithm, COUNT(key_id) FROM " + - " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + - " x GROUP BY algorithm" - -const deleteOneTimeKeySQL = "" + - "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" - -const selectKeyByAlgorithmSQL = "" + - "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" - -const deleteOneTimeKeysSQL = "" + - "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" - -type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt - selectKeysCountStmt *sql.Stmt - selectKeyByAlgorithmStmt *sql.Stmt - deleteOneTimeKeyStmt *sql.Stmt - deleteOneTimeKeysStmt *sql.Stmt -} - -func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { - s := &oneTimeKeysStatements{ - db: db, - } - _, err := db.Exec(oneTimeKeysSchema) - if err != nil { - return nil, err - } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms)) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") - - result := make(map[string]json.RawMessage) - var ( - algorithmWithID string - keyJSONStr string - ) - for rows.Next() { - if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil { - return nil, err - } - result[algorithmWithID] = json.RawMessage(keyJSONStr) - } - return result, rows.Err() -} - -func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { - counts := &api.OneTimeKeysCount{ - DeviceID: deviceID, - UserID: userID, - KeyCount: make(map[string]int), - } - rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } - counts.KeyCount[algorithm] = count - } - return counts, nil -} - -func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { - now := time.Now().Unix() - counts := &api.OneTimeKeysCount{ - DeviceID: keys.DeviceID, - UserID: keys.UserID, - KeyCount: make(map[string]int), - } - for keyIDWithAlgo, keyJSON := range keys.KeyJSON { - algo, keyID := keys.Split(keyIDWithAlgo) - _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( - ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), - ) - if err != nil { - return nil, err - } - } - rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } - counts.KeyCount[algorithm] = count - } - - return counts, rows.Err() -} - -func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( - ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, -) (map[string]json.RawMessage, error) { - var keyID string - var keyJSON string - err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) - return map[string]json.RawMessage{ - algorithm + ":" + keyID: json.RawMessage(keyJSON), - }, err -} - -func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) - return err -} diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go deleted file mode 100644 index 248ddfb4..00000000 --- a/keyserver/storage/postgres/stale_device_lists.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "time" - - "github.com/lib/pq" - - "github.com/matrix-org/dendrite/internal/sqlutil" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/gomatrixserverlib" -) - -var staleDeviceListsSchema = ` --- Stores whether a user's device lists are stale or not. -CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( - user_id TEXT PRIMARY KEY NOT NULL, - domain TEXT NOT NULL, - is_stale BOOLEAN NOT NULL, - ts_added_secs BIGINT NOT NULL -); - -CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); -` - -const upsertStaleDeviceListSQL = "" + - "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + - " VALUES ($1, $2, $3, $4)" + - " ON CONFLICT (user_id)" + - " DO UPDATE SET is_stale = $3, ts_added_secs = $4" - -const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" - -const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" - -const deleteStaleDevicesSQL = "" + - "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" - -type staleDeviceListsStatements struct { - upsertStaleDeviceListStmt *sql.Stmt - selectStaleDeviceListsWithDomainsStmt *sql.Stmt - selectStaleDeviceListsStmt *sql.Stmt - deleteStaleDeviceListsStmt *sql.Stmt -} - -func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { - s := &staleDeviceListsStatements{} - _, err := db.Exec(staleDeviceListsSchema) - if err != nil { - return nil, err - } - return s, sqlutil.StatementList{ - {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, - {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, - {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, - {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, - }.Prepare(db) -} - -func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return err - } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) - return err -} - -func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - // we only query for 1 domain or all domains so optimise for those use cases - if len(domains) == 0 { - rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) - if err != nil { - return nil, err - } - return rowsToUserIDs(ctx, rows) - } - var result []string - for _, domain := range domains { - rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) - if err != nil { - return nil, err - } - userIDs, err := rowsToUserIDs(ctx, rows) - if err != nil { - return nil, err - } - result = append(result, userIDs...) - } - return result, nil -} - -// DeleteStaleDeviceLists removes users from stale device lists -func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( - ctx context.Context, txn *sql.Tx, userIDs []string, -) error { - stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) - _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) - return err -} - -func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { - defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - result = append(result, userID) - } - return result, rows.Err() -} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go deleted file mode 100644 index 35e63055..00000000 --- a/keyserver/storage/postgres/storage.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase creates a new sync server database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - var err error - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) - if err != nil { - return nil, err - } - otk, err := NewPostgresOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewPostgresDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewPostgresKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewPostgresStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewPostgresCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewPostgresCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go deleted file mode 100644 index 54dd6ddc..00000000 --- a/keyserver/storage/shared/storage.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shared - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database struct { - DB *sql.DB - Writer sqlutil.Writer - OneTimeKeysTable tables.OneTimeKeys - DeviceKeysTable tables.DeviceKeys - KeyChangesTable tables.KeyChanges - StaleDeviceListsTable tables.StaleDeviceLists - CrossSigningKeysTable tables.CrossSigningKeys - CrossSigningSigsTable tables.CrossSigningSigs -} - -func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) -} - -func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) - return err - }) - return -} - -func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { - return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) -} - -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) -} - -func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { - count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) - if err != nil { - return false, err - } - return count == len(prevIDs), nil -} - -func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, userID := range clearUserIDs { - err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) - if err != nil { - return err - } - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { - // work out the latest stream IDs for each user - userIDToStreamID := make(map[string]int64) - for _, k := range keys { - userIDToStreamID[k.UserID] = 0 - } - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID := range userIDToStreamID { - streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) - if err != nil { - return err - } - userIDToStreamID[userID] = streamID - } - // set the stream IDs for each key - for i := range keys { - k := keys[i] - userIDToStreamID[k.UserID]++ // start stream from 1 - k.StreamID = userIDToStreamID[k.UserID] - keys[i] = k - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { - return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) -} - -func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { - var result []api.OneTimeKeys - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID, deviceToAlgo := range userToDeviceToAlgorithm { - for deviceID, algo := range deviceToAlgo { - keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) - if err != nil { - return err - } - if keyJSON != nil { - result = append(result, api.OneTimeKeys{ - UserID: userID, - DeviceID: deviceID, - KeyJSON: keyJSON, - }) - } - } - } - return nil - }) - return result, err -} - -func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { - err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) - return err - }) - return -} - -func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { - return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) -} - -// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. -// If no domains are given, all user IDs with stale device lists are returned. -func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) -} - -// MarkDeviceListStale sets the stale bit for this user to isStale. -func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) - }) -} - -// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying -// cross-signing signatures relating to that device. -func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, deviceID := range deviceIDs { - if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) - } - if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) - } - if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) - } - } - return nil - }) -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { - keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) - if err != nil { - return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) - } - results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} - for purpose, key := range keyMap { - keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) - result := gomatrixserverlib.CrossSigningKey{ - UserID: userID, - Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, - Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ - keyID: key, - }, - } - sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) - if err != nil { - continue - } - for sigUserID, forSigUserID := range sigMap { - if userID != sigUserID { - continue - } - if result.Signatures == nil { - result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - if _, ok := result.Signatures[sigUserID]; !ok { - result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - for sigKeyID, sigBytes := range forSigUserID { - result.Signatures[sigUserID][sigKeyID] = sigBytes - } - } - results[purpose] = result - } - return results, nil -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { - return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) -} - -// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. -func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { - return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) -} - -// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. -func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for keyType, keyData := range keyMap { - if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { - return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) - } - } - return nil - }) -} - -// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. -func (d *Database) StoreCrossSigningSigsForTarget( - ctx context.Context, - originUserID string, originKeyID gomatrixserverlib.KeyID, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) - } - return nil - }) -} - -// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. -func (d *Database) DeleteStaleDeviceLists( - ctx context.Context, - userIDs []string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) - }) -} diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/keyserver/storage/sqlite3/cross_signing_keys_table.go deleted file mode 100644 index e103d988..00000000 --- a/keyserver/storage/sqlite3/cross_signing_keys_table.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -var crossSigningKeysSchema = ` -CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( - user_id TEXT NOT NULL, - key_type INTEGER NOT NULL, - key_data TEXT NOT NULL, - PRIMARY KEY (user_id, key_type) -); -` - -const selectCrossSigningKeysForUserSQL = "" + - "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + - " WHERE user_id = $1" - -const upsertCrossSigningKeysForUserSQL = "" + - "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + - " VALUES($1, $2, $3)" - -type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt -} - -func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { - s := &crossSigningKeysStatements{ - db: db, - } - _, err := db.Exec(crossSigningKeysSchema) - if err != nil { - return nil, err - } - return s, sqlutil.StatementList{ - {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, - {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, - }.Prepare(db) -} - -func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, -) (r types.CrossSigningKeyMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") - r = types.CrossSigningKeyMap{} - for rows.Next() { - var keyTypeInt int16 - var keyData gomatrixserverlib.Base64Bytes - if err := rows.Scan(&keyTypeInt, &keyData); err != nil { - return nil, err - } - keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] - if !ok { - return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) - } - r[keyType] = keyData - } - return -} - -func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, -) error { - keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] - if !ok { - return fmt.Errorf("unknown key purpose %q", keyType) - } - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { - return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) - } - return nil -} diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go deleted file mode 100644 index 7a153e8f..00000000 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -var crossSigningSigsSchema = ` -CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( - origin_user_id TEXT NOT NULL, - origin_key_id TEXT NOT NULL, - target_user_id TEXT NOT NULL, - target_key_id TEXT NOT NULL, - signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) -); - -CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); -` - -const selectCrossSigningSigsForTargetSQL = "" + - "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4" - -const upsertCrossSigningSigsForTargetSQL = "" + - "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + - " VALUES($1, $2, $3, $4, $5)" - -const deleteCrossSigningSigsForTargetSQL = "" + - "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" - -type crossSigningSigsStatements struct { - db *sql.DB - selectCrossSigningSigsForTargetStmt *sql.Stmt - upsertCrossSigningSigsForTargetStmt *sql.Stmt - deleteCrossSigningSigsForTargetStmt *sql.Stmt -} - -func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { - s := &crossSigningSigsStatements{ - db: db, - } - _, err := db.Exec(crossSigningSigsSchema) - if err != nil { - return nil, err - } - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "keyserver: cross signing signature indexes", - Up: deltas.UpFixCrossSigningSignatureIndexes, - }) - if err = m.Up(context.Background()); err != nil { - return nil, err - } - - return s, sqlutil.StatementList{ - {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, - {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, - {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, - }.Prepare(db) -} - -func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) (r types.CrossSigningSigMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed") - r = types.CrossSigningSigMap{} - for rows.Next() { - var userID string - var keyID gomatrixserverlib.KeyID - var signature gomatrixserverlib.Base64Bytes - if err := rows.Scan(&userID, &keyID, &signature); err != nil { - return nil, err - } - if _, ok := r[userID]; !ok { - r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - r[userID][keyID] = signature - } - return -} - -func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, - originUserID string, originKeyID gomatrixserverlib.KeyID, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, -) error { - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) - } - return nil -} - -func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) error { - if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { - return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) - } - return nil -} diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go deleted file mode 100644 index cd0f19df..00000000 --- a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { - // start counting from the last max offset, else 0. - var maxOffset int64 - var userID string - _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset) - - _, err := tx.ExecContext(ctx, ` - -- make the new table - DROP TABLE IF EXISTS keyserver_key_changes; - CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - change_id INTEGER PRIMARY KEY AUTOINCREMENT, - -- The key owner - user_id TEXT NOT NULL, - UNIQUE (user_id) - ); - `) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - // to start counting from maxOffset, insert a row with that value - if userID != "" { - _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID) - return err - } - return nil -} - -func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers - DROP TABLE IF EXISTS keyserver_key_changes; - CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - partition BIGINT NOT NULL, - offset BIGINT NOT NULL, - -- The key owner - user_id TEXT NOT NULL, - UNIQUE (partition, offset) - ); - `) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go deleted file mode 100644 index d4e38dea..00000000 --- a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( - origin_user_id TEXT NOT NULL, - origin_key_id TEXT NOT NULL, - target_user_id TEXT NOT NULL, - target_key_id TEXT NOT NULL, - signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) - ); - - INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) - SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; - - DROP TABLE keyserver_cross_signing_sigs; - ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; - - CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); - `) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - return nil -} - -func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( - origin_user_id TEXT NOT NULL, - origin_key_id TEXT NOT NULL, - target_user_id TEXT NOT NULL, - target_key_id TEXT NOT NULL, - signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, target_user_id, target_key_id) - ); - - INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) - SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; - - DROP TABLE keyserver_cross_signing_sigs; - ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; - - DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx; - `) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go deleted file mode 100644 index 73768da5..00000000 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "strings" - "time" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var deviceKeysSchema = ` --- Stores device keys for users -CREATE TABLE IF NOT EXISTS keyserver_device_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - stream_id BIGINT NOT NULL, - display_name TEXT, - -- Clobber based on tuple of user/device. - UNIQUE (user_id, device_id) -); -` - -const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" - -const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" - -const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" - -const selectBatchDeviceKeysWithEmptiesSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" - -const selectMaxStreamForUserSQL = "" + - "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" - -const countStreamIDsForUserSQL = "" + - "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" - -const deleteDeviceKeysSQL = "" + - "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" - -const deleteAllDeviceKeysSQL = "" + - "DELETE FROM keyserver_device_keys WHERE user_id=$1" - -type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt -} - -func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { - s := &deviceKeysStatements{ - db: db, - } - _, err := db.Exec(deviceKeysSchema) - if err != nil { - return nil, err - } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) - return err -} - -func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) - return err -} - -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { - deviceIDMap := make(map[string]bool) - for _, d := range deviceIDs { - deviceIDMap[d] = true - } - var stmt *sql.Stmt - if includeEmpty { - stmt = s.selectBatchDeviceKeysWithEmptiesStmt - } else { - stmt = s.selectBatchDeviceKeysStmt - } - rows, err := stmt.QueryContext(ctx, userID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") - var result []api.DeviceMessage - var displayName sql.NullString - for rows.Next() { - dk := api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: userID, - }, - } - if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil { - return nil, err - } - if displayName.Valid { - dk.DisplayName = displayName.String - } - // include the key if we want all keys (no device) or it was asked - if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { - result = append(result, dk) - } - } - return result, rows.Err() -} - -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - for i, key := range keys { - var keyJSONStr string - var streamID int64 - var displayName sql.NullString - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) - if err != nil && err != sql.ErrNoRows { - return err - } - // this will be '' when there is no device - keys[i].Type = api.TypeDeviceKeyUpdate - keys[i].KeyJSON = []byte(keyJSONStr) - keys[i].StreamID = streamID - if displayName.Valid { - keys[i].DisplayName = displayName.String - } - } - return nil -} - -func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { - // nullable if there are no results - var nullStream sql.NullInt64 - err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) - if err == sql.ErrNoRows { - err = nil - } - if nullStream.Valid { - streamID = nullStream.Int64 - } - return -} - -func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { - iStreamIDs := make([]interface{}, len(streamIDs)+1) - iStreamIDs[0] = userID - for i := range streamIDs { - iStreamIDs[i+1] = streamIDs[i] - } - query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) - // nullable if there are no results - var count sql.NullInt64 - err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) - if err != nil { - return 0, err - } - if count.Valid { - return int(count.Int64), nil - } - return 0, nil -} - -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { - for _, key := range keys { - now := time.Now().Unix() - _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, - ) - if err != nil { - return err - } - } - return nil -} diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go deleted file mode 100644 index 0c844d67..00000000 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var keyChangesSchema = ` --- Stores key change information about users. Used to determine when to send updated device lists to clients. -CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - change_id INTEGER PRIMARY KEY AUTOINCREMENT, - -- The key owner - user_id TEXT NOT NULL, - UNIQUE (user_id) -); -` - -// Replace based on user ID. We don't care how many times the user's keys have changed, only that they -// have changed, hence we can just keep bumping the change ID for this user. -const upsertKeyChangeSQL = "" + - "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" + - " VALUES ($1)" + - " RETURNING change_id" - -const selectKeyChangesSQL = "" + - "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" - -type keyChangesStatements struct { - db *sql.DB - upsertKeyChangeStmt *sql.Stmt - selectKeyChangesStmt *sql.Stmt -} - -func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { - s := &keyChangesStatements{ - db: db, - } - _, err := db.Exec(keyChangesSchema) - if err != nil { - return s, err - } - - if err = executeMigration(context.Background(), db); err != nil { - return nil, err - } - - return s, nil -} - -func executeMigration(ctx context.Context, db *sql.DB) error { - // TODO: Remove when we are sure we are not having goose artefacts in the db - // This forces an error, which indicates the migration is already applied, since the - // column partition was removed from the table - migrationName := "keyserver: refactor key changes" - - var cName string - err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed - if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { - return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) - } - return nil - } - return err - } - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: migrationName, - Up: deltas.UpRefactorKeyChanges, - }) - return m.Up(ctx) -} - -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - -func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { - err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) - return -} - -func (s *keyChangesStatements) SelectKeyChanges( - ctx context.Context, fromOffset, toOffset int64, -) (userIDs []string, latestOffset int64, err error) { - latestOffset = fromOffset - rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) - if err != nil { - return nil, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") - for rows.Next() { - var userID string - var offset int64 - if err := rows.Scan(&userID, &offset); err != nil { - return nil, 0, err - } - if offset > latestOffset { - latestOffset = offset - } - userIDs = append(userIDs, userID) - } - return -} diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go deleted file mode 100644 index 7a923d0e..00000000 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "encoding/json" - "time" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var oneTimeKeysSchema = ` --- Stores one-time public keys for users -CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - key_id TEXT NOT NULL, - algorithm TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - -- Clobber based on 4-uple of user/device/key/algorithm. - UNIQUE (user_id, device_id, key_id, algorithm) -); - -CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); -` - -const upsertKeysSQL = "" + - "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT (user_id, device_id, key_id, algorithm)" + - " DO UPDATE SET key_json = $6" - -const selectKeysSQL = "" + - "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" - -const selectKeysCountSQL = "" + - "SELECT algorithm, COUNT(key_id) FROM " + - " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + - " x GROUP BY algorithm" - -const deleteOneTimeKeySQL = "" + - "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" - -const selectKeyByAlgorithmSQL = "" + - "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" - -const deleteOneTimeKeysSQL = "" + - "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" - -type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt - selectKeysCountStmt *sql.Stmt - selectKeyByAlgorithmStmt *sql.Stmt - deleteOneTimeKeyStmt *sql.Stmt - deleteOneTimeKeysStmt *sql.Stmt -} - -func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { - s := &oneTimeKeysStatements{ - db: db, - } - _, err := db.Exec(oneTimeKeysSchema) - if err != nil { - return nil, err - } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") - - wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) - for _, ka := range keyIDsWithAlgorithms { - wantSet[ka] = true - } - - result := make(map[string]json.RawMessage) - for rows.Next() { - var keyID string - var algorithm string - var keyJSONStr string - if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { - return nil, err - } - keyIDWithAlgo := algorithm + ":" + keyID - if wantSet[keyIDWithAlgo] { - result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) - } - } - return result, rows.Err() -} - -func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { - counts := &api.OneTimeKeysCount{ - DeviceID: deviceID, - UserID: userID, - KeyCount: make(map[string]int), - } - rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } - counts.KeyCount[algorithm] = count - } - return counts, nil -} - -func (s *oneTimeKeysStatements) InsertOneTimeKeys( - ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys, -) (*api.OneTimeKeysCount, error) { - now := time.Now().Unix() - counts := &api.OneTimeKeysCount{ - DeviceID: keys.DeviceID, - UserID: keys.UserID, - KeyCount: make(map[string]int), - } - for keyIDWithAlgo, keyJSON := range keys.KeyJSON { - algo, keyID := keys.Split(keyIDWithAlgo) - _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( - ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), - ) - if err != nil { - return nil, err - } - } - rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } - counts.KeyCount[algorithm] = count - } - - return counts, rows.Err() -} - -func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( - ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, -) (map[string]json.RawMessage, error) { - var keyID string - var keyJSON string - err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) - if err != nil { - return nil, err - } - if keyJSON == "" { - return nil, nil - } - return map[string]json.RawMessage{ - algorithm + ":" + keyID: json.RawMessage(keyJSON), - }, err -} - -func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) - return err -} diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go deleted file mode 100644 index fd76a6e3..00000000 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "strings" - "time" - - "github.com/matrix-org/dendrite/internal/sqlutil" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/gomatrixserverlib" -) - -var staleDeviceListsSchema = ` --- Stores whether a user's device lists are stale or not. -CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( - user_id TEXT PRIMARY KEY NOT NULL, - domain TEXT NOT NULL, - is_stale BOOLEAN NOT NULL, - ts_added_secs BIGINT NOT NULL -); - -CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); -` - -const upsertStaleDeviceListSQL = "" + - "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + - " VALUES ($1, $2, $3, $4)" + - " ON CONFLICT (user_id)" + - " DO UPDATE SET is_stale = $3, ts_added_secs = $4" - -const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" - -const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" - -const deleteStaleDevicesSQL = "" + - "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" - -type staleDeviceListsStatements struct { - db *sql.DB - upsertStaleDeviceListStmt *sql.Stmt - selectStaleDeviceListsWithDomainsStmt *sql.Stmt - selectStaleDeviceListsStmt *sql.Stmt - // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime -} - -func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { - s := &staleDeviceListsStatements{ - db: db, - } - _, err := db.Exec(staleDeviceListsSchema) - if err != nil { - return nil, err - } - return s, sqlutil.StatementList{ - {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, - {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, - {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, - // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime - }.Prepare(db) -} - -func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return err - } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) - return err -} - -func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - // we only query for 1 domain or all domains so optimise for those use cases - if len(domains) == 0 { - rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) - if err != nil { - return nil, err - } - return rowsToUserIDs(ctx, rows) - } - var result []string - for _, domain := range domains { - rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) - if err != nil { - return nil, err - } - userIDs, err := rowsToUserIDs(ctx, rows) - if err != nil { - return nil, err - } - result = append(result, userIDs...) - } - return result, nil -} - -// DeleteStaleDeviceLists removes users from stale device lists -func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( - ctx context.Context, txn *sql.Tx, userIDs []string, -) error { - qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) - stmt, err := s.db.Prepare(qry) - if err != nil { - return err - } - defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") - stmt = sqlutil.TxStmt(txn, stmt) - - params := make([]any, len(userIDs)) - for i := range userIDs { - params[i] = userIDs[i] - } - - _, err = stmt.ExecContext(ctx, params...) - return err -} - -func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { - defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - result = append(result, userID) - } - return result, rows.Err() -} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go deleted file mode 100644 index 873fe3e2..00000000 --- a/keyserver/storage/sqlite3/storage.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) - if err != nil { - return nil, err - } - otk, err := NewSqliteOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewSqliteDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewSqliteKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewSqliteStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewSqliteCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewSqliteCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/storage.go b/keyserver/storage/storage.go deleted file mode 100644 index ab6a3540..00000000 --- a/keyserver/storage/storage.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/keyserver/storage/postgres" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go deleted file mode 100644 index e7a2af7c..00000000 --- a/keyserver/storage/storage_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package storage_test - -import ( - "context" - "reflect" - "sync" - "testing" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -var ctx = context.Background() - -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database) - if err != nil { - t.Fatalf("failed to create new database: %v", err) - } - return db, close -} - -func MustNotError(t *testing.T, err error) { - t.Helper() - if err == nil { - return - } - t.Fatalf("operation failed: %s", err) -} - -func TestKeyChanges(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - _, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDC { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) - } - if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesNoDupes(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - if deviceChangeIDA == deviceChangeIDB { - t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) - } - deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeID { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) - } - if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesUpperLimit(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - _, err = db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDB { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) - } - if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -var dbLock sync.Mutex -var deviceArray = []string{"AAA", "another_device"} - -// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, -// and that they are returned correctly when querying for device keys. -func TestDeviceKeysStreamIDGeneration(t *testing.T) { - var err error - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - alice := "@alice:TestDeviceKeysStreamIDGeneration" - bob := "@bob:TestDeviceKeysStreamIDGeneration" - msgs := []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: bob, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 as this is a different user - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "another_device", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 2 as this is a 2nd device key - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) - } - if msgs[1].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) - } - if msgs[2].StreamID != 2 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) - } - - // updating a device sets the next stream ID for that user - msgs = []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v2"}`), - }, - // StreamID: 3 - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 3 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) - } - - dbLock.Lock() - defer dbLock.Unlock() - // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) - - if err != nil { - t.Fatalf("DeviceKeysForUser returned error: %s", err) - } - wantStreamIDs := map[string]int64{ - "AAA": 3, - "another_device": 2, - } - if len(msgs) != len(wantStreamIDs) { - t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) - } - for _, m := range msgs { - if m.StreamID != wantStreamIDs[m.DeviceID] { - t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) - } - } - }) -} diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go deleted file mode 100644 index 75c9053e..00000000 --- a/keyserver/storage/storage_wasm.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go deleted file mode 100644 index 24da1125..00000000 --- a/keyserver/storage/tables/interface.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tables - -import ( - "context" - "database/sql" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type OneTimeKeys interface { - SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. - // Returns an empty map if the key does not exist. - SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) - DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error -} - -type DeviceKeys interface { - SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error - SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) - CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error - DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error -} - -type KeyChanges interface { - InsertKeyChange(ctx context.Context, userID string) (int64, error) - // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. - // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. - SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - Prepare() error -} - -type StaleDeviceLists interface { - InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error - SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error -} - -type CrossSigningKeys interface { - SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) - UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error -} - -type CrossSigningSigs interface { - SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) - UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error - DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error -} diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/keyserver/storage/tables/stale_device_lists_test.go deleted file mode 100644 index 76d3badd..00000000 --- a/keyserver/storage/tables/stale_device_lists_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package tables_test - -import ( - "context" - "testing" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/config" - - "github.com/matrix-org/dendrite/keyserver/storage/postgres" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/test" -) - -func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { - connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := sqlutil.Open(&config.DatabaseOptions{ - ConnectionString: config.DataSource(connStr), - }, nil) - if err != nil { - t.Fatalf("failed to open database: %s", err) - } - switch dbType { - case test.DBTypePostgres: - tab, err = postgres.NewPostgresStaleDeviceListsTable(db) - case test.DBTypeSQLite: - tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) - } - if err != nil { - t.Fatalf("failed to create new table: %s", err) - } - return tab, close -} - -func TestStaleDeviceLists(t *testing.T) { - alice := test.NewUser(t) - bob := test.NewUser(t) - charlie := "@charlie:localhost" - ctx := context.Background() - - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - tab, closeDB := mustCreateTable(t, dbType) - defer closeDB() - - if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { - t.Fatalf("failed to insert stale device: %s", err) - } - if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { - t.Fatalf("failed to insert stale device: %s", err) - } - if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { - t.Fatalf("failed to insert stale device: %s", err) - } - - // Query one server - wantStaleUsers := []string{alice.ID, bob.ID} - gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) - if err != nil { - t.Fatalf("failed to query stale device lists: %s", err) - } - if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { - t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) - } - - // Query all servers - wantStaleUsers = []string{alice.ID, bob.ID, charlie} - gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) - if err != nil { - t.Fatalf("failed to query stale device lists: %s", err) - } - if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { - t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) - } - - // Delete stale devices - deleteUsers := []string{alice.ID, bob.ID} - if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil { - t.Fatalf("failed to delete stale device lists: %s", err) - } - - // Verify we don't get anything back after deleting - gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) - if err != nil { - t.Fatalf("failed to query stale device lists: %s", err) - } - - if gotCount := len(gotStaleUsers); gotCount > 0 { - t.Fatalf("expected no stale users, got %d", gotCount) - } - }) -} diff --git a/keyserver/types/storage.go b/keyserver/types/storage.go deleted file mode 100644 index 7fb90454..00000000 --- a/keyserver/types/storage.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package types - -import ( - "math" - - "github.com/matrix-org/gomatrixserverlib" -) - -const ( - // OffsetNewest tells e.g. the database to get the most current data - OffsetNewest int64 = math.MaxInt64 - // OffsetOldest tells e.g. the database to get the oldest data - OffsetOldest int64 = 0 -) - -// KeyTypePurposeToInt maps a purpose to an integer, which is used in the -// database to reduce the amount of space taken up by this column. -var KeyTypePurposeToInt = map[gomatrixserverlib.CrossSigningKeyPurpose]int16{ - gomatrixserverlib.CrossSigningKeyPurposeMaster: 1, - gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: 2, - gomatrixserverlib.CrossSigningKeyPurposeUserSigning: 3, -} - -// KeyTypeIntToPurpose maps an integer to a purpose, which is used in the -// database to reduce the amount of space taken up by this column. -var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{ - 1: gomatrixserverlib.CrossSigningKeyPurposeMaster, - 2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, - 3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning, -} - -// Map of purpose -> public key -type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes - -// Map of user ID -> key ID -> signature -type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes diff --git a/roomserver/api/api.go b/roomserver/api/api.go index a8228ae8..73732ae3 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -17,7 +17,6 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI - KeyserverRoomserverAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -167,6 +166,7 @@ type ClientRoomserverAPI interface { type UserRoomserverAPI interface { QueryLatestEventsAndStateAPI + KeyserverRoomserverAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 7ba01e50..304311c4 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -12,7 +12,6 @@ import ( userAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" @@ -47,7 +46,7 @@ func TestUsers(t *testing.T) { }) t.Run("kick users", func(t *testing.T) { - usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(base, rsAPI, nil) rsAPI.SetUserAPI(usrAPI) testKickUsers(t, rsAPI, usrAPI) }) @@ -228,11 +227,10 @@ func TestPurgeRoom(t *testing.T) { fedClient := base.CreateFederationClient() rsAPI := roomserver.NewInternalAPI(base) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) - userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + userAPI := userapi.NewInternalAPI(base, rsAPI, nil) // this starts the JetStream consumers - syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + syncapi.AddPublicRoutes(base, userAPI, rsAPI) federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) rsAPI.SetFederationAPI(nil, nil) diff --git a/setup/monolith.go b/setup/monolith.go index 5bbe4019..d8c65223 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/federationapi" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/transactions" - keyAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/mediaapi" "github.com/matrix-org/dendrite/relayapi" relayAPI "github.com/matrix-org/dendrite/relayapi/api" @@ -45,7 +44,6 @@ type Monolith struct { FederationAPI federationAPI.FederationInternalAPI RoomserverAPI roomserverAPI.RoomserverInternalAPI UserAPI userapi.UserInternalAPI - KeyAPI keyAPI.KeyInternalAPI RelayAPI relayAPI.RelayInternalAPI // Optional @@ -61,19 +59,14 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { } clientapi.AddPublicRoutes( base, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), - m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI, + m.FederationAPI, m.UserAPI, userDirectoryProvider, m.ExtPublicRoomsProvider, ) federationapi.AddPublicRoutes( - base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, - m.KeyAPI, nil, - ) - mediaapi.AddPublicRoutes( - base, m.UserAPI, m.Client, - ) - syncapi.AddPublicRoutes( - base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, + base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, nil, ) + mediaapi.AddPublicRoutes(base, m.UserAPI, m.Client) + syncapi.AddPublicRoutes(base, m.UserAPI, m.RoomserverAPI) if m.RelayAPI != nil { relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI) diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 92f08150..5faaefb8 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -19,7 +19,6 @@ import ( "encoding/json" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" @@ -28,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 356e8326..32208c58 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -25,7 +25,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -33,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" ) // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. @@ -42,7 +42,7 @@ type OutputSendToDeviceEventConsumer struct { durable string topic string db storage.Database - keyAPI keyapi.SyncKeyAPI + userAPI api.SyncKeyAPI isLocalServerName func(gomatrixserverlib.ServerName) bool stream streams.StreamProvider notifier *notifier.Notifier @@ -55,7 +55,7 @@ func NewOutputSendToDeviceEventConsumer( cfg *config.SyncAPI, js nats.JetStreamContext, store storage.Database, - keyAPI keyapi.SyncKeyAPI, + userAPI api.SyncKeyAPI, notifier *notifier.Notifier, stream streams.StreamProvider, ) *OutputSendToDeviceEventConsumer { @@ -65,7 +65,7 @@ func NewOutputSendToDeviceEventConsumer( topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), db: store, - keyAPI: keyAPI, + userAPI: userAPI, isLocalServerName: cfg.Matrix.IsLocalServerName, notifier: notifier, stream: stream, @@ -116,7 +116,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) { // Mark the requesting device as stale, if we don't know about it. - if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ + if err = s.userAPI.PerformMarkAsStaleIfNeeded(ctx, &api.PerformMarkAsStaleRequest{ UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, }, &struct{}{}); err != nil { logger.WithError(err).Errorf("failed to mark as stale if needed") diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 4867f7d9..e7f677c8 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -18,22 +18,22 @@ import ( "context" "strings" + keytypes "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - keyapi "github.com/matrix-org/dendrite/keyserver/api" - keytypes "github.com/matrix-org/dendrite/keyserver/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" ) // DeviceOTKCounts adds one-time key counts to the /sync response -func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { - var queryRes keyapi.QueryOneTimeKeysResponse - _ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ +func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceID string, res *types.Response) error { + var queryRes api.QueryOneTimeKeysResponse + _ = keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{ UserID: userID, DeviceID: deviceID, }, &queryRes) @@ -48,7 +48,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. func DeviceListCatchup( - ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, + ctx context.Context, db storage.SharedUsers, userAPI api.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, res *types.Response, from, to types.StreamPosition, ) (newPos types.StreamPosition, hasNew bool, err error) { @@ -74,8 +74,8 @@ func DeviceListCatchup( if from > 0 { offset = int64(from) } - var queryRes keyapi.QueryKeyChangesResponse - _ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ + var queryRes api.QueryKeyChangesResponse + _ = userAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ Offset: offset, ToOffset: toOffset, }, &queryRes) diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index d64bea11..4bb85166 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -9,7 +9,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -22,44 +21,44 @@ var ( type mockKeyAPI struct{} -func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *keyapi.PerformMarkAsStaleRequest, res *struct{}) error { +func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *userapi.PerformMarkAsStaleRequest, res *struct{}) error { return nil } -func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error { +func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *userapi.PerformUploadKeysRequest, res *userapi.PerformUploadKeysResponse) error { return nil } func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} // PerformClaimKeys claims one-time keys for use in pre-key messages -func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error { +func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error { +func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *userapi.PerformDeleteKeysRequest, res *userapi.PerformDeleteKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error { +func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) error { return nil } -func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error { +func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) error { return nil } -func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error { +func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) error { return nil } -func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { +func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { return nil } -func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { +func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { return nil } -func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error { +func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error { return nil } -func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error { +func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) error { return nil } diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 7996c203..e8189c35 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -3,17 +3,17 @@ package streams import ( "context" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type DeviceListStreamProvider struct { DefaultStreamProvider - rsAPI api.SyncRoomserverAPI - keyAPI keyapi.SyncKeyAPI + rsAPI api.SyncRoomserverAPI + userAPI userapi.SyncKeyAPI } func (p *DeviceListStreamProvider) CompleteSync( @@ -31,12 +31,12 @@ func (p *DeviceListStreamProvider) IncrementalSync( from, to types.StreamPosition, ) types.StreamPosition { var err error - to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) + to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.userAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from } - err = internal.DeviceOTKCounts(req.Context, p.keyAPI, req.Device.UserID, req.Device.ID, req.Response) + err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index dc854762..a35491ac 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -5,7 +5,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" @@ -27,7 +26,7 @@ type Streams struct { func NewSyncStreamProviders( d storage.Database, userAPI userapi.SyncUserAPI, - rsAPI rsapi.SyncRoomserverAPI, keyAPI keyapi.SyncKeyAPI, + rsAPI rsapi.SyncRoomserverAPI, eduCache *caching.EDUCache, lazyLoadCache caching.LazyLoadCache, notifier *notifier.Notifier, ) *Streams { streams := &Streams{ @@ -60,7 +59,7 @@ func NewSyncStreamProviders( DeviceListStreamProvider: &DeviceListStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, rsAPI: rsAPI, - keyAPI: keyAPI, + userAPI: userAPI, }, PresenceStreamProvider: &PresenceStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index b086567b..68f91387 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -32,7 +32,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" @@ -48,7 +47,6 @@ type RequestPool struct { db storage.Database cfg *config.SyncAPI userAPI userapi.SyncUserAPI - keyAPI keyapi.SyncKeyAPI rsAPI roomserverAPI.SyncRoomserverAPI lastseen *sync.Map presence *sync.Map @@ -69,7 +67,7 @@ type PresenceConsumer interface { // NewRequestPool makes a new RequestPool func NewRequestPool( db storage.Database, cfg *config.SyncAPI, - userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI, + userAPI userapi.SyncUserAPI, rsAPI roomserverAPI.SyncRoomserverAPI, streams *streams.Streams, notifier *notifier.Notifier, producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool, @@ -83,7 +81,6 @@ func NewRequestPool( db: db, cfg: cfg, userAPI: userAPI, - keyAPI: keyAPI, rsAPI: rsAPI, lastseen: &sync.Map{}, presence: &sync.Map{}, @@ -280,7 +277,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. // https://github.com/matrix-org/synapse/blob/29f06704b8871a44926f7c99e73cf4a978fb8e81/synapse/rest/client/sync.py#L276-L281 // Only try to get OTKs if the context isn't already done. if syncReq.Context.Err() == nil { - err = internal.DeviceOTKCounts(syncReq.Context, rp.keyAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) + err = internal.DeviceOTKCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) if err != nil && err != context.Canceled { syncReq.Log.WithError(err).Warn("failed to get OTK counts") } @@ -551,7 +548,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition) _, _, err = internal.DeviceListCatchup( - req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, + req.Context(), snapshot, rp.userAPI, rp.rsAPI, syncReq.Device.UserID, syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition, ) if err != nil { diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index be19310f..153f7af5 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" @@ -42,7 +41,6 @@ func AddPublicRoutes( base *base.BaseDendrite, userAPI userapi.SyncUserAPI, rsAPI api.SyncRoomserverAPI, - keyAPI keyapi.SyncKeyAPI, ) { cfg := &base.Cfg.SyncAPI @@ -55,7 +53,7 @@ func AddPublicRoutes( eduCache := caching.NewTypingCache() notifier := notifier.NewNotifier() - streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, base.Caches, notifier) + streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, base.Caches, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil { logrus.WithError(err).Panicf("failed to load notifier ") @@ -71,7 +69,7 @@ func AddPublicRoutes( userAPI, ) - requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) + requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) if err = presenceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start presence consumer") @@ -117,7 +115,7 @@ func AddPublicRoutes( } sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( - base.ProcessContext, cfg, js, syncDB, keyAPI, notifier, streams.SendToDeviceStreamProvider, + base.ProcessContext, cfg, js, syncDB, userAPI, notifier, streams.SendToDeviceStreamProvider, ) if err = sendToDeviceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start send-to-device consumer") diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 643c3026..e748660f 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -16,7 +16,6 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/clientapi/producers" - keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -83,18 +82,15 @@ func (s *syncUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAc return nil } -func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { +func (s *syncUserAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error { return nil } -type syncKeyAPI struct { - keyapi.SyncKeyAPI -} - -func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { +func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { return nil } -func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { + +func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { return nil } @@ -121,7 +117,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) msgs := toNATSMsgs(t, base, room.Events()...) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { @@ -220,7 +216,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { // m.room.history_visibility msgs := toNATSMsgs(t, base, room.Events()...) sinceTokens := make([]string, len(msgs)) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}) for i, msg := range msgs { testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) @@ -304,7 +300,7 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) w := httptest.NewRecorder() base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ "access_token": alice.AccessToken, @@ -422,7 +418,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) for _, tc := range testCases { testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) @@ -722,7 +718,7 @@ func TestGetMembership(t *testing.T) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -789,7 +785,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}) producer := producers.SyncAPIProducer{ TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), @@ -1008,7 +1004,7 @@ func testContext(t *testing.T, dbType test.DBType) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(nil, nil) - AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, &syncKeyAPI{}) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI) room := test.NewRoom(t, user) diff --git a/userapi/api/api.go b/userapi/api/api.go index 4ea2e91c..fa297f77 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -15,9 +15,13 @@ package api import ( + "bytes" "context" "encoding/json" + "strings" + "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -26,15 +30,12 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { - AppserviceUserAPI SyncUserAPI ClientUserAPI - MediaUserAPI FederationUserAPI - RoomserverUserAPI - KeyserverUserAPI QuerySearchProfilesAPI // used by p2p demos + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the appservice api @@ -43,11 +44,6 @@ type AppserviceUserAPI interface { PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error } -type KeyserverUserAPI interface { - QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error - QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error -} - type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) @@ -60,13 +56,20 @@ type MediaUserAPI interface { // api functions required by the federation api type FederationUserAPI interface { + UploadDeviceKeysAPI QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error + QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error } // api functions required by the sync api type SyncUserAPI interface { QueryAcccessTokenAPI + SyncKeyAPI QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -79,6 +82,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI + ClientKeyAPI QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -681,3 +685,310 @@ type QueryAccountByLocalpartRequest struct { type QueryAccountByLocalpartResponse struct { Account *Account } + +// API functions required by the clientapi +type ClientKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + + PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error + // PerformClaimKeys claims one-time keys for use in pre-key messages + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type UploadDeviceKeysAPI interface { + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error +} + +// API functions required by the syncapi +type SyncKeyAPI interface { + QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type FederationKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error +} + +// KeyError is returned if there was a problem performing/querying the server +type KeyError struct { + Err string `json:"error"` + IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE + IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM + IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM +} + +func (k *KeyError) Error() string { + return k.Err +} + +type DeviceMessageType int + +const ( + TypeDeviceKeyUpdate DeviceMessageType = iota + TypeCrossSigningUpdate +) + +// DeviceMessage represents the message produced into Kafka by the key server. +type DeviceMessage struct { + Type DeviceMessageType `json:"Type,omitempty"` + *DeviceKeys `json:"DeviceKeys,omitempty"` + *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` + // A monotonically increasing number which represents device changes for this user. + StreamID int64 + DeviceChangeID int64 +} + +// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log +type OutputCrossSigningKeyUpdate struct { + CrossSigningKeyUpdate `json:"signing_keys"` +} + +type CrossSigningKeyUpdate struct { + MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` + SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` + UserID string `json:"user_id"` +} + +// DeviceKeysEqual returns true if the device keys updates contain the +// same display name and key JSON. This will return false if either of +// the updates is not a device keys update, or if the user ID/device ID +// differ between the two. +func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { + if m1.DeviceKeys == nil || m2.DeviceKeys == nil { + return false + } + if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { + return false + } + if m1.DisplayName != m2.DisplayName { + return false // different display names + } + if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { + return false // either is empty + } + return bytes.Equal(m1.KeyJSON, m2.KeyJSON) +} + +// DeviceKeys represents a set of device keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type DeviceKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // The device display name + DisplayName string + // The raw device key JSON + KeyJSON []byte +} + +// WithStreamID returns a copy of this device message with the given stream ID +func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { + return DeviceMessage{ + DeviceKeys: k, + StreamID: streamID, + } +} + +// OneTimeKeys represents a set of one-time keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type OneTimeKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // A map of algorithm:key_id => key JSON + KeyJSON map[string]json.RawMessage +} + +// Split a key in KeyJSON into algorithm and key ID +func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { + segments := strings.Split(keyIDWithAlgo, ":") + return segments[0], segments[1] +} + +// OneTimeKeysCount represents the counts of one-time keys for a single device +type OneTimeKeysCount struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // algorithm to count e.g: + // { + // "curve25519": 10, + // "signed_curve25519": 20 + // } + KeyCount map[string]int +} + +// PerformUploadKeysRequest is the request to PerformUploadKeys +type PerformUploadKeysRequest struct { + UserID string // Required - User performing the request + DeviceID string // Optional - Device performing the request, for fetching OTK count + DeviceKeys []DeviceKeys + OneTimeKeys []OneTimeKeys + // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update + // the display name for their respective device, and NOT to modify the keys. The key + // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. + // Without this flag, requests to modify device display names would delete device keys. + OnlyDisplayNameUpdates bool +} + +// PerformUploadKeysResponse is the response to PerformUploadKeys +type PerformUploadKeysResponse struct { + // A fatal error when processing e.g database failures + Error *KeyError + // A map of user_id -> device_id -> Error for tracking failures. + KeyErrors map[string]map[string]*KeyError + OneTimeKeyCounts []OneTimeKeysCount +} + +// PerformDeleteKeysRequest asks the keyserver to forget about certain +// keys, and signatures related to those keys. +type PerformDeleteKeysRequest struct { + UserID string + KeyIDs []gomatrixserverlib.KeyID +} + +// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. +type PerformDeleteKeysResponse struct { + Error *KeyError +} + +// KeyError sets a key error field on KeyErrors +func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { + if r.KeyErrors[userID] == nil { + r.KeyErrors[userID] = make(map[string]*KeyError) + } + r.KeyErrors[userID][deviceID] = err +} + +type PerformClaimKeysRequest struct { + // Map of user_id to device_id to algorithm name + OneTimeKeys map[string]map[string]string + Timeout time.Duration +} + +type PerformClaimKeysResponse struct { + // Map of user_id to device_id to algorithm:key_id to key JSON + OneTimeKeys map[string]map[string]map[string]json.RawMessage + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Set if there was a fatal error processing this action + Error *KeyError +} + +type PerformUploadDeviceKeysRequest struct { + gomatrixserverlib.CrossSigningKeys + // The user that uploaded the key, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceKeysResponse struct { + Error *KeyError +} + +type PerformUploadDeviceSignaturesRequest struct { + Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice + // The user that uploaded the sig, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceSignaturesResponse struct { + Error *KeyError +} + +type QueryKeysRequest struct { + // The user ID asking for the keys, e.g. if from a client API request. + // Will not be populated if the key request came from federation. + UserID string + // Maps user IDs to a list of devices + UserToDevices map[string][]string + Timeout time.Duration +} + +type QueryKeysResponse struct { + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Map of user_id to device_id to device_key + DeviceKeys map[string]map[string]json.RawMessage + // Maps of user_id to cross signing key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // Set if there was a fatal error processing this query + Error *KeyError +} + +type QueryKeyChangesRequest struct { + // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning + Offset int64 + // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. + // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). + ToOffset int64 +} + +type QueryKeyChangesResponse struct { + // The set of users who have had their keys change. + UserIDs []string + // The latest offset represented in this response. + Offset int64 + // Set if there was a problem handling the request. + Error *KeyError +} + +type QueryOneTimeKeysRequest struct { + // The local user to query OTK counts for + UserID string + // The device to query OTK counts for + DeviceID string +} + +type QueryOneTimeKeysResponse struct { + // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 + Count OneTimeKeysCount + Error *KeyError +} + +type QueryDeviceMessagesRequest struct { + UserID string +} + +type QueryDeviceMessagesResponse struct { + // The latest stream ID + StreamID int64 + Devices []DeviceMessage + Error *KeyError +} + +type QuerySignaturesRequest struct { + // A map of target user ID -> target key/device IDs to retrieve signatures for + TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` +} + +type QuerySignaturesResponse struct { + // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures + Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap + // A map of target user ID -> cross-signing master key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing self-signing key + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing user-signing key + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // The request error, if any + Error *KeyError +} + +type PerformMarkAsStaleRequest struct { + UserID string + Domain gomatrixserverlib.ServerName + DeviceID string +} diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 42ae72e7..51bd2753 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct { jetstream nats.JetStreamContext durable string topic string - db storage.Database + db storage.UserDatabase serverName gomatrixserverlib.ServerName syncProducer *producers.SyncAPI pgClient pushgateway.Client @@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, syncProducer *producers.SyncAPI, pgClient pushgateway.Client, ) *OutputReceiptEventConsumer { diff --git a/userapi/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go new file mode 100644 index 00000000..a65889fc --- /dev/null +++ b/userapi/consumers/devicelistupdate.go @@ -0,0 +1,95 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" +) + +// DeviceListUpdateConsumer consumes device list updates that came in over federation. +type DeviceListUpdateConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + updater *internal.DeviceListUpdater + isLocalServerName func(gomatrixserverlib.ServerName) bool +} + +// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. +func NewDeviceListUpdateConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + updater *internal.DeviceListUpdater, +) *DeviceListUpdateConsumer { + return &DeviceListUpdateConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + updater: updater, + isLocalServerName: cfg.Matrix.IsLocalServerName, + } +} + +// Start consuming from key servers +func (t *DeviceListUpdateConsumer) Start() error { + return jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), + ) +} + +// onMessage is called in response to a message received on the +// key change events topic from the key server. +func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + var m gomatrixserverlib.DeviceListUpdateEvent + if err := json.Unmarshal(msg.Data, &m); err != nil { + logrus.WithError(err).Errorf("Failed to read from device list update input topic") + return true + } + origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) + if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { + return true + } else if t.isLocalServerName(serverName) { + return true + } else if serverName != origin { + return true + } + + err := t.updater.Update(ctx, m) + if err != nil { + logrus.WithFields(logrus.Fields{ + "user_id": m.UserID, + "device_id": m.DeviceID, + "stream_id": m.StreamID, + "prev_id": m.PrevID, + }).WithError(err).Errorf("Failed to update device list") + return false + } + return true +} diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3ce5af62..47d33095 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct { rsAPI rsapi.UserRoomserverAPI jetstream nats.JetStreamContext durable string - db storage.Database + db storage.UserDatabase topic string pgClient pushgateway.Client syncProducer *producers.SyncAPI @@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, pgClient pushgateway.Client, rsAPI rsapi.UserRoomserverAPI, syncProducer *producers.SyncAPI, diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 39f4aab4..bc5ae652 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -18,11 +18,11 @@ import ( userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { diff --git a/userapi/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go new file mode 100644 index 00000000..f4ff017d --- /dev/null +++ b/userapi/consumers/signingkeyupdate.go @@ -0,0 +1,111 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" +) + +// SigningKeyUpdateConsumer consumes signing key updates that came in over federation. +type SigningKeyUpdateConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + userAPI api.UploadDeviceKeysAPI + cfg *config.UserAPI + isLocalServerName func(gomatrixserverlib.ServerName) bool +} + +// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. +func NewSigningKeyUpdateConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + userAPI api.UploadDeviceKeysAPI, +) *SigningKeyUpdateConsumer { + return &SigningKeyUpdateConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + userAPI: userAPI, + cfg: cfg, + isLocalServerName: cfg.Matrix.IsLocalServerName, + } +} + +// Start consuming from key servers +func (t *SigningKeyUpdateConsumer) Start() error { + return jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), + ) +} + +// onMessage is called in response to a message received on the +// signing key update events topic from the key server. +func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + var updatePayload api.CrossSigningKeyUpdate + if err := json.Unmarshal(msg.Data, &updatePayload); err != nil { + logrus.WithError(err).Errorf("Failed to read from signing key update input topic") + return true + } + origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) + if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { + logrus.WithError(err).Error("failed to split user id") + return true + } else if t.isLocalServerName(serverName) { + logrus.Warn("dropping device key update from ourself") + return true + } else if serverName != origin { + logrus.Warnf("dropping device key update, %s != %s", serverName, origin) + return true + } + + keys := gomatrixserverlib.CrossSigningKeys{} + if updatePayload.MasterKey != nil { + keys.MasterKey = *updatePayload.MasterKey + } + if updatePayload.SelfSigningKey != nil { + keys.SelfSigningKey = *updatePayload.SelfSigningKey + } + uploadReq := &api.PerformUploadDeviceKeysRequest{ + CrossSigningKeys: keys, + UserID: updatePayload.UserID, + } + uploadRes := &api.PerformUploadDeviceKeysResponse{} + if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { + logrus.WithError(err).Error("failed to upload device keys") + return false + } + if uploadRes.Error != nil { + logrus.WithError(uploadRes.Error).Error("failed to upload device keys") + return true + } + + return true +} diff --git a/userapi/internal/api.go b/userapi/internal/api.go deleted file mode 100644 index 0bb480da..00000000 --- a/userapi/internal/api.go +++ /dev/null @@ -1,968 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "strconv" - "time" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" - "golang.org/x/crypto/bcrypt" - - "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/internal/pushgateway" - "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" - rsapi "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" - synctypes "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/producers" - "github.com/matrix-org/dendrite/userapi/storage" - "github.com/matrix-org/dendrite/userapi/storage/tables" - userapiUtil "github.com/matrix-org/dendrite/userapi/util" -) - -type UserInternalAPI struct { - DB storage.Database - SyncProducer *producers.SyncAPI - Config *config.UserAPI - - DisableTLSValidation bool - // AppServices is the list of all registered AS - AppServices []config.ApplicationService - KeyAPI keyapi.UserKeyAPI - RSAPI rsapi.UserRoomserverAPI - PgClient pushgateway.Client - Cfg *config.UserAPI -} - -func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { - local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return err - } - if !a.Config.Matrix.IsLocalServerName(domain) { - return fmt.Errorf("cannot update account data of remote users (server name %s)", domain) - } - if req.DataType == "" { - return fmt.Errorf("data type must not be empty") - } - if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil { - util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed") - return fmt.Errorf("failed to save account data: %w", err) - } - var ignoredUsers *synctypes.IgnoredUsers - if req.DataType == "m.ignored_user_list" { - ignoredUsers = &synctypes.IgnoredUsers{} - _ = json.Unmarshal(req.AccountData, ignoredUsers) - } - if req.DataType == "m.fully_read" { - if err := a.setFullyRead(ctx, req); err != nil { - return err - } - } - if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ - RoomID: req.RoomID, - Type: req.DataType, - IgnoredUsers: ignoredUsers, - }); err != nil { - util.GetLogger(ctx).WithError(err).Error("a.SyncProducer.SendAccountData failed") - return fmt.Errorf("failed to send account data to output: %w", err) - } - return nil -} - -func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error { - var output eventutil.ReadMarkerJSON - - if err := json.Unmarshal(req.AccountData, &output); err != nil { - return err - } - localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") - return nil - } - if !a.Config.Matrix.IsLocalServerName(domain) { - return nil - } - - deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) - if err != nil { - logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") - return err - } - - if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil { - logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed") - return err - } - - // nothing changed, no need to notify the push gateway - if !deleted { - return nil - } - - if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil { - logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") - return err - } - return nil -} - -func postRegisterJoinRooms(cfg *config.UserAPI, acc *api.Account, rsAPI rsapi.UserRoomserverAPI) { - // POST register behaviour: check if the user is a normal user. - // If the user is a normal user, add user to room specified in the configuration "auto_join_rooms". - if acc.AccountType != api.AccountTypeAppService && acc.AppServiceID == "" { - for room := range cfg.AutoJoinRooms { - userID := userutil.MakeUserID(acc.Localpart, cfg.Matrix.ServerName) - err := addUserToRoom(context.Background(), rsAPI, cfg.AutoJoinRooms[room], acc.Localpart, userID) - if err != nil { - logrus.WithFields(logrus.Fields{ - "user_id": userID, - "room": cfg.AutoJoinRooms[room], - }).WithError(err).Errorf("user failed to auto-join room") - } - } - } -} - -// Add user to a room. This function currently working for auto_join_rooms config, -// which can add a newly registered user to a specified room. -func addUserToRoom( - ctx context.Context, - rsAPI rsapi.UserRoomserverAPI, - roomID string, - username string, - userID string, -) error { - addGroupContent := make(map[string]interface{}) - // This make sure the user's username can be displayed correctly. - // Because the newly-registered user doesn't have an avatar, the avatar_url is not needed. - addGroupContent["displayname"] = username - joinReq := rsapi.PerformJoinRequest{ - RoomIDOrAlias: roomID, - UserID: userID, - Content: addGroupContent, - } - joinRes := rsapi.PerformJoinResponse{} - return rsAPI.PerformJoin(ctx, &joinReq, &joinRes) -} - -func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { - serverName := req.ServerName - if serverName == "" { - serverName = a.Config.Matrix.ServerName - } - if !a.Config.Matrix.IsLocalServerName(serverName) { - return fmt.Errorf("server name %s is not local", serverName) - } - acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType) - if err != nil { - if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists - switch req.OnConflict { - case api.ConflictUpdate: - break - case api.ConflictAbort: - return &api.ErrorConflict{ - Message: err.Error(), - } - } - } - // account already exists - res.AccountCreated = false - res.Account = &api.Account{ - AppServiceID: req.AppServiceID, - Localpart: req.Localpart, - ServerName: serverName, - UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), - AccountType: req.AccountType, - } - return nil - } - - // Inform the SyncAPI about the newly created push_rules - if err = a.SyncProducer.SendAccountData(acc.UserID, eventutil.AccountData{ - Type: "m.push_rules", - }); err != nil { - util.GetLogger(ctx).WithFields(logrus.Fields{ - "user_id": acc.UserID, - }).WithError(err).Warn("failed to send account data to the SyncAPI") - } - - if req.AccountType == api.AccountTypeGuest { - res.AccountCreated = true - res.Account = acc - return nil - } - - if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil { - return fmt.Errorf("a.DB.SetDisplayName: %w", err) - } - - postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) - - res.AccountCreated = true - res.Account = acc - return nil -} - -func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { - if !a.Config.Matrix.IsLocalServerName(req.ServerName) { - return fmt.Errorf("server name %s is not local", req.ServerName) - } - if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil { - return err - } - if req.LogoutDevices { - if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil { - return err - } - } - res.PasswordUpdated = true - return nil -} - -func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { - serverName := req.ServerName - if serverName == "" { - serverName = a.Config.Matrix.ServerName - } - if !a.Config.Matrix.IsLocalServerName(serverName) { - return fmt.Errorf("server name %s is not local", serverName) - } - util.GetLogger(ctx).WithFields(logrus.Fields{ - "localpart": req.Localpart, - "device_id": req.DeviceID, - "display_name": req.DeviceDisplayName, - }).Info("PerformDeviceCreation") - dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) - if err != nil { - return err - } - res.DeviceCreated = true - res.Device = dev - if req.NoDeviceListUpdate { - return nil - } - // create empty device keys and upload them to trigger device list changes - return a.deviceListUpdate(dev.UserID, []string{dev.ID}) -} - -func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error { - util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("devices", req.DeviceIDs).Info("PerformDeviceDeletion") - local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return err - } - if !a.Config.Matrix.IsLocalServerName(domain) { - return fmt.Errorf("cannot PerformDeviceDeletion of remote users (server name %s)", domain) - } - deletedDeviceIDs := req.DeviceIDs - if len(req.DeviceIDs) == 0 { - var devices []api.Device - devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID) - for _, d := range devices { - deletedDeviceIDs = append(deletedDeviceIDs, d.ID) - } - } else { - err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs) - } - if err != nil { - return err - } - // Ask the keyserver to delete device keys and signatures for those devices - deleteReq := &keyapi.PerformDeleteKeysRequest{ - UserID: req.UserID, - } - for _, keyID := range req.DeviceIDs { - deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) - } - deleteRes := &keyapi.PerformDeleteKeysResponse{} - if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { - return err - } - if err := deleteRes.Error; err != nil { - return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err) - } - // create empty device keys and upload them to delete what was once there and trigger device list changes - return a.deviceListUpdate(req.UserID, deletedDeviceIDs) -} - -func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { - deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) - for i, did := range deviceIDs { - deviceKeys[i] = keyapi.DeviceKeys{ - UserID: userID, - DeviceID: did, - KeyJSON: nil, - } - } - - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ - UserID: userID, - DeviceKeys: deviceKeys, - }, &uploadRes); err != nil { - return err - } - if uploadRes.Error != nil { - return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error) - } - if len(uploadRes.KeyErrors) > 0 { - return fmt.Errorf("failed to delete device keys, key errors: %+v", uploadRes.KeyErrors) - } - return nil -} - -func (a *UserInternalAPI) PerformLastSeenUpdate( - ctx context.Context, - req *api.PerformLastSeenUpdateRequest, - res *api.PerformLastSeenUpdateResponse, -) error { - localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if !a.Config.Matrix.IsLocalServerName(domain) { - return fmt.Errorf("server name %s is not local", domain) - } - if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { - return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) - } - return nil -} - -func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { - localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") - return err - } - if !a.Config.Matrix.IsLocalServerName(domain) { - return fmt.Errorf("server name %s is not local", domain) - } - dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID) - if err == sql.ErrNoRows { - res.DeviceExists = false - return nil - } else if err != nil { - util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") - return err - } - res.DeviceExists = true - - if dev.UserID != req.RequestingUserID { - res.Forbidden = true - return nil - } - - err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") - return err - } - if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { - // display name has changed: update the device key - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ - UserID: req.RequestingUserID, - DeviceKeys: []keyapi.DeviceKeys{ - { - DeviceID: dev.ID, - DisplayName: *req.DisplayName, - KeyJSON: nil, - UserID: dev.UserID, - }, - }, - OnlyDisplayNameUpdates: true, - }, &uploadRes); err != nil { - return err - } - if uploadRes.Error != nil { - return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error) - } - if len(uploadRes.KeyErrors) > 0 { - return fmt.Errorf("failed to update device key display name, key errors: %+v", uploadRes.KeyErrors) - } - } - return nil -} - -func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { - local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return err - } - if !a.Config.Matrix.IsLocalServerName(domain) { - return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) - } - prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain) - if err != nil { - if err == sql.ErrNoRows { - return nil - } - return err - } - res.UserExists = true - res.AvatarURL = prof.AvatarURL - res.DisplayName = prof.DisplayName - return nil -} - -func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { - profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit) - if err != nil { - return err - } - res.Profiles = profiles - return nil -} - -func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { - devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs) - if err != nil { - return err - } - res.DeviceInfo = make(map[string]struct { - DisplayName string - UserID string - }) - for _, d := range devices { - res.DeviceInfo[d.ID] = struct { - DisplayName string - UserID string - }{ - DisplayName: d.DisplayName, - UserID: d.UserID, - } - } - return nil -} - -func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { - local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return err - } - if !a.Config.Matrix.IsLocalServerName(domain) { - return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) - } - devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain) - if err != nil { - return err - } - res.UserExists = true - res.Devices = devs - return nil -} - -func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { - local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return err - } - if !a.Config.Matrix.IsLocalServerName(domain) { - return fmt.Errorf("cannot query account data of remote users (server name %s)", domain) - } - if req.DataType != "" { - var data json.RawMessage - data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType) - if err != nil { - return err - } - res.RoomAccountData = make(map[string]map[string]json.RawMessage) - res.GlobalAccountData = make(map[string]json.RawMessage) - if data != nil { - if req.RoomID != "" { - if _, ok := res.RoomAccountData[req.RoomID]; !ok { - res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage) - } - res.RoomAccountData[req.RoomID][req.DataType] = data - } else { - res.GlobalAccountData[req.DataType] = data - } - } - return nil - } - global, rooms, err := a.DB.GetAccountData(ctx, local, domain) - if err != nil { - return err - } - res.RoomAccountData = rooms - res.GlobalAccountData = global - return nil -} - -func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAccessTokenRequest, res *api.QueryAccessTokenResponse) error { - if req.AppServiceUserID != "" { - appServiceDevice, err := a.queryAppServiceToken(ctx, req.AccessToken, req.AppServiceUserID) - if err != nil { - res.Err = err.Error() - } - res.Device = appServiceDevice - - return nil - } - device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken) - if err != nil { - if err == sql.ErrNoRows { - return nil - } - return err - } - localPart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) - if err != nil { - return err - } - if !a.Config.Matrix.IsLocalServerName(domain) { - return nil - } - acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain) - if err != nil { - return err - } - device.AccountType = acc.AccountType - res.Device = device - return nil -} - -func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { - res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) - return -} - -// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem -// creating a 'device'. -func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { - // Search for app service with given access_token - var appService *config.ApplicationService - for _, as := range a.AppServices { - if as.ASToken == token { - appService = &as - break - } - } - if appService == nil { - return nil, nil - } - - // Create a dummy device for AS user - dev := api.Device{ - // Use AS dummy device ID - ID: "AS_Device", - // AS dummy device has AS's token. - AccessToken: token, - AppserviceID: appService.ID, - AccountType: api.AccountTypeAppService, - } - - localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) - if err != nil { - return nil, err - } - - if localpart != "" { // AS is masquerading as another user - // Verify that the user is registered - account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain) - // Verify that the account exists and either appServiceID matches or - // it belongs to the appservice user namespaces - if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { - // Set the userID of dummy device - dev.UserID = appServiceUserID - return &dev, nil - } - return nil, &api.ErrorForbidden{Message: "appservice has not registered this user"} - } - - // AS is not masquerading as any user, so use AS's sender_localpart - dev.UserID = appService.SenderLocalpart - return &dev, nil -} - -// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. -func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { - serverName := req.ServerName - if serverName == "" { - serverName = a.Config.Matrix.ServerName - } - if !a.Config.Matrix.IsLocalServerName(serverName) { - return fmt.Errorf("server name %q not locally configured", serverName) - } - - evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), - } - evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} - if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { - return err - } - if err := evacuateRes.Error; err != nil { - logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation") - } - - deviceReq := &api.PerformDeviceDeletionRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), - } - deviceRes := &api.PerformDeviceDeletionResponse{} - if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { - return err - } - - pusherReq := &api.PerformPusherDeletionRequest{ - Localpart: req.Localpart, - } - if err := a.PerformPusherDeletion(ctx, pusherReq, &struct{}{}); err != nil { - return err - } - - err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName) - res.AccountDeactivated = err == nil - return err -} - -// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user -func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { - token := util.RandomString(24) - - exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID) - - res.Token = api.OpenIDToken{ - Token: token, - UserID: req.UserID, - ExpiresAtMS: exp, - } - - return err -} - -// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation -func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { - openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token) - if err != nil { - return err - } - - res.Sub = openIDTokenAttrs.UserID - res.ExpiresAtMS = openIDTokenAttrs.ExpiresAtMS - - return nil -} - -func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { - // Delete metadata - if req.DeleteBackup { - if req.Version == "" { - res.BadInput = true - res.Error = "must specify a version to delete" - return nil - } - exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) - if err != nil { - res.Error = fmt.Sprintf("failed to delete backup: %s", err) - } - res.Exists = exists - res.Version = req.Version - return nil - } - // Create metadata - if req.Version == "" { - version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) - if err != nil { - res.Error = fmt.Sprintf("failed to create backup: %s", err) - } - res.Exists = err == nil - res.Version = version - return nil - } - // Update metadata - if len(req.Keys.Rooms) == 0 { - err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) - if err != nil { - res.Error = fmt.Sprintf("failed to update backup: %s", err) - } - res.Exists = err == nil - res.Version = req.Version - return nil - } - // Upload Keys for a specific version metadata - a.uploadBackupKeys(ctx, req, res) - return nil -} - -func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { - // you can only upload keys for the CURRENT version - version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "") - if err != nil { - res.Error = fmt.Sprintf("failed to query version: %s", err) - return - } - if deleted { - res.Error = "backup was deleted" - return - } - if version != req.Version { - res.BadInput = true - res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version) - return - } - res.Exists = true - res.Version = version - - // map keys to a form we can upload more easily - the map ensures we have no duplicates. - var uploads []api.InternalKeyBackupSession - for roomID, data := range req.Keys.Rooms { - for sessionID, sessionData := range data.Sessions { - uploads = append(uploads, api.InternalKeyBackupSession{ - RoomID: roomID, - SessionID: sessionID, - KeyBackupSession: sessionData, - }) - } - } - count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads) - if err != nil { - res.Error = fmt.Sprintf("failed to upsert keys: %s", err) - return - } - res.KeyCount = count - res.KeyETag = etag -} - -func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error { - version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) - res.Version = version - if err != nil { - if err == sql.ErrNoRows { - res.Exists = false - return nil - } - res.Error = fmt.Sprintf("failed to query key backup: %s", err) - return nil - } - res.Algorithm = algorithm - res.AuthData = authData - res.ETag = etag - res.Exists = !deleted - - if !req.ReturnKeys { - res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID) - if err != nil { - res.Error = fmt.Sprintf("failed to count keys: %s", err) - } - return nil - } - - result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) - if err != nil { - res.Error = fmt.Sprintf("failed to query keys: %s", err) - return nil - } - res.Keys = result - return nil -} - -func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { - if req.Limit == 0 || req.Limit > 1000 { - req.Limit = 1000 - } - - var fromID int64 - var err error - if req.From != "" { - fromID, err = strconv.ParseInt(req.From, 10, 64) - if err != nil { - return fmt.Errorf("QueryNotifications: parsing 'from': %w", err) - } - } - var filter tables.NotificationFilter = tables.AllNotifications - if req.Only == "highlight" { - filter = tables.HighlightNotifications - } - notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter) - if err != nil { - return err - } - if notifs == nil { - // This ensures empty is JSON-encoded as [] instead of null. - notifs = []*api.Notification{} - } - res.Notifications = notifs - if lastID >= 0 { - res.NextToken = strconv.FormatInt(lastID+1, 10) - } - return nil -} - -func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error { - util.GetLogger(ctx).WithFields(logrus.Fields{ - "localpart": req.Localpart, - "pushkey": req.Pusher.PushKey, - "display_name": req.Pusher.AppDisplayName, - }).Info("PerformPusherCreation") - if !req.Append { - err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey) - if err != nil { - return err - } - } - if req.Pusher.Kind == "" { - return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName) - } - if req.Pusher.PushKeyTS == 0 { - req.Pusher.PushKeyTS = int64(time.Now().Unix()) - } - return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName) -} - -func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { - pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName) - if err != nil { - return err - } - for i := range pushers { - logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) - if pushers[i].SessionID != req.SessionID { - err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName) - if err != nil { - return err - } - } - } - return nil -} - -func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { - var err error - res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName) - return err -} - -func (a *UserInternalAPI) PerformPushRulesPut( - ctx context.Context, - req *api.PerformPushRulesPutRequest, - _ *struct{}, -) error { - bs, err := json.Marshal(&req.RuleSets) - if err != nil { - return err - } - userReq := api.InputAccountDataRequest{ - UserID: req.UserID, - DataType: pushRulesAccountDataType, - AccountData: json.RawMessage(bs), - } - var userRes api.InputAccountDataResponse // empty - if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { - return err - } - return nil -} - -func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { - localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) - } - pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain) - if err != nil { - return fmt.Errorf("failed to query push rules: %w", err) - } - res.RuleSets = pushRules - return nil -} - -func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL) - res.Profile = profile - res.Changed = changed - return err -} - -func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error { - id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName) - if err != nil { - return err - } - res.ID = id - return nil -} - -func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { - var err error - res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName) - return err -} - -func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { - acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword) - switch err { - case sql.ErrNoRows: // user does not exist - return nil - case bcrypt.ErrMismatchedHashAndPassword: // user exists, but password doesn't match - return nil - case bcrypt.ErrHashTooShort: // user exists, but probably a passwordless account - return nil - default: - res.Exists = true - res.Account = acc - return nil - } -} - -func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { - profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName) - res.Profile = profile - res.Changed = changed - return err -} - -func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { - localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) - if err != nil { - return err - } - res.Localpart = localpart - res.ServerName = domain - return nil -} - -func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { - r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName) - if err != nil { - return err - } - res.ThreePIDs = r - return nil -} - -func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error { - return a.DB.RemoveThreePIDAssociation(ctx, req.ThreePID, req.Medium) -} - -func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { - return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) -} - -const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go new file mode 100644 index 00000000..8b9704d1 --- /dev/null +++ b/userapi/internal/cross_signing.go @@ -0,0 +1,587 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "bytes" + "context" + "crypto/ed25519" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/curve25519" +) + +func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error { + // Is there exactly one key? + if len(key.Keys) != 1 { + return fmt.Errorf("should contain exactly one key") + } + + // Does the key ID match the key value? Iterates exactly once + for keyID, keyData := range key.Keys { + b64 := keyData.Encode() + tokens := strings.Split(string(keyID), ":") + if len(tokens) != 2 { + return fmt.Errorf("key ID is incorrectly formatted") + } + if tokens[1] != b64 { + return fmt.Errorf("key ID isn't correct") + } + switch tokens[0] { + case "ed25519": + if len(keyData) != ed25519.PublicKeySize { + return fmt.Errorf("ed25519 key is not the correct length") + } + case "curve25519": + if len(keyData) != curve25519.PointSize { + return fmt.Errorf("curve25519 key is not the correct length") + } + default: + // We can't enforce the key length to be correct for an + // algorithm that we don't recognise, so instead we'll + // just make sure that it isn't incredibly excessive. + if l := len(keyData); l > 4096 { + return fmt.Errorf("unknown key type is too long (%d bytes)", l) + } + } + } + + // Check to see if the signatures make sense + for _, forOriginUser := range key.Signatures { + for originKeyID, originSignature := range forOriginUser { + switch strings.SplitN(string(originKeyID), ":", 1)[0] { + case "ed25519": + if len(originSignature) != ed25519.SignatureSize { + return fmt.Errorf("ed25519 signature is not the correct length") + } + case "curve25519": + return fmt.Errorf("curve25519 signatures are impossible") + default: + if l := len(originSignature); l > 4096 { + return fmt.Errorf("unknown signature type is too long (%d bytes)", l) + } + } + } + } + + // Does the key claim to be from the right user? + if userID != key.UserID { + return fmt.Errorf("key has a user ID mismatch") + } + + // Does the key contain the correct purpose? + useful := false + for _, usage := range key.Usage { + if usage == purpose { + useful = true + break + } + } + if !useful { + return fmt.Errorf("key does not contain correct usage purpose") + } + + return nil +} + +// nolint:gocyclo +func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { + // Find the keys to store. + byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + toStore := types.CrossSigningKeyMap{} + hasMasterKey := false + + if len(req.MasterKey.Keys) > 0 { + if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil { + res.Error = &api.KeyError{ + Err: "Master key sanity check failed: " + err.Error(), + IsInvalidParam: true, + } + return nil + } + + byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey + for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey + toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key + } + hasMasterKey = true + } + + if len(req.SelfSigningKey.Keys) > 0 { + if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil { + res.Error = &api.KeyError{ + Err: "Self-signing key sanity check failed: " + err.Error(), + IsInvalidParam: true, + } + return nil + } + + byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey + for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey + toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key + } + } + + if len(req.UserSigningKey.Keys) > 0 { + if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil { + res.Error = &api.KeyError{ + Err: "User-signing key sanity check failed: " + err.Error(), + IsInvalidParam: true, + } + return nil + } + + byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey + for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey + toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key + } + } + + // If there's nothing to do then stop here. + if len(toStore) == 0 { + res.Error = &api.KeyError{ + Err: "No keys were supplied in the request", + IsMissingParam: true, + } + return nil + } + + // We can't have a self-signing or user-signing key without a master + // key, so make sure we have one of those. We will also only actually do + // something if any of the specified keys in the request are different + // to what we've got in the database, to avoid generating key change + // notifications unnecessarily. + existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID) + if err != nil { + res.Error = &api.KeyError{ + Err: "Retrieving cross-signing keys from database failed: " + err.Error(), + } + return nil + } + + // If we still can't find a master key for the user then stop the upload. + // This satisfies the "Fails to upload self-signing key without master key" test. + if !hasMasterKey { + if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { + res.Error = &api.KeyError{ + Err: "No master key was found", + IsMissingParam: true, + } + return nil + } + } + + // Check if anything actually changed compared to what we have in the database. + changed := false + for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ + gomatrixserverlib.CrossSigningKeyPurposeMaster, + gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, + gomatrixserverlib.CrossSigningKeyPurposeUserSigning, + } { + old, gotOld := existingKeys[purpose] + new, gotNew := toStore[purpose] + if gotOld != gotNew { + // A new key purpose has been specified that we didn't know before, + // or one has been removed. + changed = true + break + } + if !bytes.Equal(old, new) { + // One of the existing keys for a purpose we already knew about has + // changed. + changed = true + break + } + } + if !changed { + return nil + } + + // Store the keys. + if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), + } + return nil + } + + // Now upload any signatures that were included with the keys. + for _, key := range byPurpose { + var targetKeyID gomatrixserverlib.KeyID + for targetKey := range key.Keys { // iterates once, see sanityCheckKey + targetKeyID = targetKey + } + for sigUserID, forSigUserID := range key.Signatures { + if sigUserID != req.UserID { + continue + } + for sigKeyID, sigBytes := range forSigUserID { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), + } + return nil + } + } + } + } + + // Finally, generate a notification that we updated the keys. + update := api.CrossSigningKeyUpdate{ + UserID: req.UserID, + } + if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok { + update.MasterKey = &mk + } + if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok { + update.SelfSigningKey = &ssk + } + if update.MasterKey == nil && update.SelfSigningKey == nil { + return nil + } + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), + } + return nil + } + return nil +} + +func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { + // Before we do anything, we need the master and self-signing keys for this user. + // Then we can verify the signatures make sense. + queryReq := &api.QueryKeysRequest{ + UserID: req.UserID, + UserToDevices: map[string][]string{}, + } + queryRes := &api.QueryKeysResponse{} + for userID := range req.Signatures { + queryReq.UserToDevices[userID] = []string{} + } + _ = a.QueryKeys(ctx, queryReq, queryRes) + + selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + + // Sort signatures into two groups: one where people have signed their own + // keys and one where people have signed someone elses + for userID, forUserID := range req.Signatures { + for keyID, keyOrDevice := range forUserID { + switch key := keyOrDevice.CrossSigningBody.(type) { + case *gomatrixserverlib.CrossSigningKey: + if key.UserID == req.UserID { + if _, ok := selfSignatures[userID]; !ok { + selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + } + selfSignatures[userID][keyID] = keyOrDevice + } else { + if _, ok := otherSignatures[userID]; !ok { + otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + } + otherSignatures[userID][keyID] = keyOrDevice + } + + case *gomatrixserverlib.DeviceKeys: + if key.UserID == req.UserID { + if _, ok := selfSignatures[userID]; !ok { + selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + } + selfSignatures[userID][keyID] = keyOrDevice + } else { + if _, ok := otherSignatures[userID]; !ok { + otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + } + otherSignatures[userID][keyID] = keyOrDevice + } + + default: + continue + } + } + } + + if err := a.processSelfSignatures(ctx, selfSignatures); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.processSelfSignatures: %s", err), + } + return nil + } + + if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.processOtherSignatures: %s", err), + } + return nil + } + + // Finally, generate a notification that we updated the signatures. + for userID := range req.Signatures { + masterKey := queryRes.MasterKeys[userID] + selfSigningKey := queryRes.SelfSigningKeys[userID] + update := api.CrossSigningKeyUpdate{ + UserID: userID, + MasterKey: &masterKey, + SelfSigningKey: &selfSigningKey, + } + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), + } + return nil + } + } + return nil +} + +func (a *UserInternalAPI) processSelfSignatures( + ctx context.Context, + signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, +) error { + // Here we will process: + // * The user signing their own devices using their self-signing key + // * The user signing their master key using one of their devices + + for targetUserID, forTargetUserID := range signatures { + for targetKeyID, signature := range forTargetUserID { + switch sig := signature.CrossSigningBody.(type) { + case *gomatrixserverlib.CrossSigningKey: + for keyID := range sig.Keys { + split := strings.SplitN(string(keyID), ":", 2) + if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID { + targetKeyID = keyID // contains the ed25519: or other scheme + break + } + } + for originUserID, forOriginUserID := range sig.Signatures { + for originKeyID, originSig := range forOriginUserID { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( + ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, + ); err != nil { + return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) + } + } + } + + case *gomatrixserverlib.DeviceKeys: + for originUserID, forOriginUserID := range sig.Signatures { + for originKeyID, originSig := range forOriginUserID { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( + ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, + ); err != nil { + return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) + } + } + } + + default: + return fmt.Errorf("unexpected type assertion") + } + } + } + + return nil +} + +func (a *UserInternalAPI) processOtherSignatures( + ctx context.Context, userID string, queryRes *api.QueryKeysResponse, + signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, +) error { + // Here we will process: + // * A user signing someone else's master keys using their user-signing keys + + for targetUserID, forTargetUserID := range signatures { + for _, signature := range forTargetUserID { + switch sig := signature.CrossSigningBody.(type) { + case *gomatrixserverlib.CrossSigningKey: + // Find the local copy of the master key. We'll use this to be + // sure that the supplied stanza matches the key that we think it + // should be. + masterKey, ok := queryRes.MasterKeys[targetUserID] + if !ok { + return fmt.Errorf("failed to find master key for user %q", targetUserID) + } + + // For each key ID, write the signatures. Maybe there'll be more + // than one algorithm in the future so it's best not to focus on + // everything being ed25519:. + for targetKeyID, suppliedKeyData := range sig.Keys { + // The master key will be supplied in the request, but we should + // make sure that it matches what we think the master key should + // actually be. + localKeyData, lok := masterKey.Keys[targetKeyID] + if !lok { + return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID) + } else if !bytes.Equal(suppliedKeyData, localKeyData) { + return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID) + } + + // We only care about the signatures from the uploading user, so + // we will ignore anything that didn't originate from them. + userSigs, ok := sig.Signatures[userID] + if !ok { + return fmt.Errorf("there are no signatures on master key %q from uploading user %q", targetKeyID, userID) + } + + for originKeyID, originSig := range userSigs { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( + ctx, userID, originKeyID, targetUserID, targetKeyID, originSig, + ); err != nil { + return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) + } + } + } + + default: + // Users should only be signing another person's master key, + // so if we're here, it's probably because it's actually a + // gomatrixserverlib.DeviceKeys, which doesn't make sense. + } + } + } + + return nil +} + +func (a *UserInternalAPI) crossSigningKeysFromDatabase( + ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, +) { + for targetUserID := range req.UserToDevices { + keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) + if err != nil { + logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) + continue + } + + for keyType, key := range keys { + var keyID gomatrixserverlib.KeyID + for id := range key.Keys { + keyID = id + break + } + + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) + if err != nil && err != sql.ErrNoRows { + logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) + continue + } + + appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) { + if key.Signatures == nil { + key.Signatures = types.CrossSigningSigMap{} + } + if _, ok := key.Signatures[originUserID]; !ok { + key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes) + } + key.Signatures[originUserID][originKeyID] = signature + } + + for originUserID, forOrigin := range sigMap { + for originKeyID, signature := range forOrigin { + switch { + case req.UserID != "" && originUserID == req.UserID: + // Include signatures that we created + appendSignature(originUserID, originKeyID, signature) + case originUserID == targetUserID: + // Include signatures that were created by the person whose key + // we are processing + appendSignature(originUserID, originKeyID, signature) + } + } + } + + switch keyType { + case gomatrixserverlib.CrossSigningKeyPurposeMaster: + res.MasterKeys[targetUserID] = key + + case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: + res.SelfSigningKeys[targetUserID] = key + + case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: + res.UserSigningKeys[targetUserID] = key + } + } + } +} + +func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { + for targetUserID, forTargetUser := range req.TargetIDs { + keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) + if err != nil && err != sql.ErrNoRows { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err), + } + continue + } + + for targetPurpose, targetKey := range keyMap { + switch targetPurpose { + case gomatrixserverlib.CrossSigningKeyPurposeMaster: + if res.MasterKeys == nil { + res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{} + } + res.MasterKeys[targetUserID] = targetKey + + case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: + if res.SelfSigningKeys == nil { + res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} + } + res.SelfSigningKeys[targetUserID] = targetKey + + case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: + if res.UserSigningKeys == nil { + res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} + } + res.UserSigningKeys[targetUserID] = targetKey + } + } + + for _, targetKeyID := range forTargetUser { + // Get own signatures only. + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) + if err != nil && err != sql.ErrNoRows { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), + } + return nil + } + + for sourceUserID, forSourceUser := range sigMap { + for sourceKeyID, sourceSig := range forSourceUser { + if res.Signatures == nil { + res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{} + } + if _, ok := res.Signatures[targetUserID]; !ok { + res.Signatures[targetUserID] = map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{} + } + if _, ok := res.Signatures[targetUserID][targetKeyID]; !ok { + res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{} + } + if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok { + res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig + } + } + } + } + return nil +} diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go new file mode 100644 index 00000000..3b4dcf98 --- /dev/null +++ b/userapi/internal/device_list_update.go @@ -0,0 +1,579 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "hash/fnv" + "net" + "sync" + "time" + + rsapi "github.com/matrix-org/dendrite/roomserver/api" + + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" +) + +var ( + deviceListUpdateCount = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "keyserver", + Name: "device_list_update", + Help: "Number of times we have attempted to update device lists from this server", + }, + []string{"server"}, + ) +) + +const requestTimeout = time.Second * 30 + +func init() { + prometheus.MustRegister( + deviceListUpdateCount, + ) +} + +// DeviceListUpdater handles device list updates from remote servers. +// +// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock). +// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies +// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id +// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device: +// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the +// updater stores the latest list along with the latest stream ID. +// +// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers. +// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing +// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved +// from the database (which allows us to batch requests to the same server). This has a number of desirable properties: +// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible +// for that domain. +// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where +// we have many many servers) +// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers. +// +// The downsides are that: +// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free +// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts) +// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests +// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse +// than being stuck behind foo.bar +// +// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is +// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried. +type DeviceListUpdater struct { + process *process.ProcessContext + // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 + // request to the remote server and race. + // TODO: Put in an LRU cache to bound growth + userIDToMutex map[string]*sync.Mutex + mu *sync.Mutex // protects UserIDToMutex + + db DeviceListUpdaterDatabase + api DeviceListUpdaterAPI + producer KeyChangeProducer + fedClient fedsenderapi.KeyserverFederationAPI + workerChans []chan gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName + + // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will + // block on or timeout via a select. + userIDToChan map[string]chan bool + userIDToChanMu *sync.Mutex + rsAPI rsapi.KeyserverRoomserverAPI +} + +// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. +// Useful for testing. +type DeviceListUpdaterDatabase interface { + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error +} + +type DeviceListUpdaterAPI interface { + PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error +} + +// KeyChangeProducer is the interface for producers.KeyChange useful for testing. +type KeyChangeProducer interface { + ProduceKeyChanges(keys []api.DeviceMessage) error +} + +// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. +func NewDeviceListUpdater( + process *process.ProcessContext, db DeviceListUpdaterDatabase, + api DeviceListUpdaterAPI, producer KeyChangeProducer, + fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, + rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, +) *DeviceListUpdater { + return &DeviceListUpdater{ + process: process, + userIDToMutex: make(map[string]*sync.Mutex), + mu: &sync.Mutex{}, + db: db, + api: api, + producer: producer, + fedClient: fedClient, + thisServer: thisServer, + workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), + userIDToChan: make(map[string]chan bool), + userIDToChanMu: &sync.Mutex{}, + rsAPI: rsAPI, + } +} + +// Start the device list updater, which will try to refresh any stale device lists. +func (u *DeviceListUpdater) Start() error { + for i := 0; i < len(u.workerChans); i++ { + // Allocate a small buffer per channel. + // If the buffer limit is reached, backpressure will cause the processing of EDUs + // to stop (in this transaction) until key requests can be made. + ch := make(chan gomatrixserverlib.ServerName, 10) + u.workerChans[i] = ch + go u.worker(ch) + } + + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + offset, step := time.Second*10, time.Second + if max := len(staleLists); max > 120 { + step = (time.Second * 120) / time.Duration(max) + } + for _, userID := range staleLists { + userID := userID // otherwise we are only sending the last entry + time.AfterFunc(offset, func() { + u.notifyWorkers(userID) + }) + offset += step + } + return nil +} + +// CleanUp removes stale device entries for users we don't share a room with anymore +func (u *DeviceListUpdater) CleanUp() error { + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + + res := rsapi.QueryLeftUsersResponse{} + if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { + return err + } + + if len(res.LeftUsers) == 0 { + return nil + } + logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers)) + return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers) +} + +func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { + u.mu.Lock() + defer u.mu.Unlock() + if u.userIDToMutex[userID] == nil { + u.userIDToMutex[userID] = &sync.Mutex{} + } + return u.userIDToMutex[userID] +} + +// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it. +// Blocks until the device list is synced or the timeout is reached. +func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error { + mu := u.mutex(userID) + mu.Lock() + err := u.db.MarkDeviceListStale(ctx, userID, true) + mu.Unlock() + if err != nil { + return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err) + } + u.notifyWorkers(userID) + return nil +} + +// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest, +// which assumes when /send 200 OKs that the device lists have been updated. +func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error { + isDeviceListStale, err := u.update(ctx, event) + if err != nil { + return err + } + if isDeviceListStale { + // poke workers to handle stale device lists + u.notifyWorkers(event.UserID) + } + return nil +} + +func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) { + mu := u.mutex(event.UserID) + mu.Lock() + defer mu.Unlock() + // check if we have the prev IDs + exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID) + if err != nil { + return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) + } + // if this is the first time we're hearing about this user, sync the device list manually. + if len(event.PrevID) == 0 { + exists = false + } + util.GetLogger(ctx).WithFields(logrus.Fields{ + "prev_ids_exist": exists, + "user_id": event.UserID, + "device_id": event.DeviceID, + "stream_id": event.StreamID, + "prev_ids": event.PrevID, + "display_name": event.DeviceDisplayName, + "deleted": event.Deleted, + }).Trace("DeviceListUpdater.Update") + + // if we haven't missed anything update the database and notify users + if exists || event.Deleted { + k := event.Keys + if event.Deleted { + k = nil + } + keys := []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: event.DeviceID, + DisplayName: event.DeviceDisplayName, + KeyJSON: k, + UserID: event.UserID, + }, + StreamID: event.StreamID, + }, + } + + // DeviceKeysJSON will side-effect modify this, so it needs + // to be a copy, not sharing any pointers with the above. + deviceKeysCopy := *keys[0].DeviceKeys + deviceKeysCopy.KeyJSON = nil + existingKeys := []api.DeviceMessage{ + { + Type: keys[0].Type, + DeviceKeys: &deviceKeysCopy, + StreamID: keys[0].StreamID, + }, + } + + // fetch what keys we had already and only emit changes + if err = u.db.DeviceKeysJSON(ctx, existingKeys); err != nil { + // non-fatal, log and continue + util.GetLogger(ctx).WithError(err).WithField("user_id", event.UserID).Errorf( + "failed to query device keys json for calculating diffs", + ) + } + + err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil) + if err != nil { + return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err) + } + + if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil { + return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err) + } + return false, nil + } + + err = u.db.MarkDeviceListStale(ctx, event.UserID, true) + if err != nil { + return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err) + } + + return true, nil +} + +func (u *DeviceListUpdater) notifyWorkers(userID string) { + _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return + } + hash := fnv.New32a() + _, _ = hash.Write([]byte(remoteServer)) + index := int(int64(hash.Sum32()) % int64(len(u.workerChans))) + + ch := u.assignChannel(userID) + u.workerChans[index] <- remoteServer + select { + case <-ch: + case <-time.After(10 * time.Second): + // we don't return an error in this case as it's not a failure condition. + // we mainly block for the benefit of sytest anyway + } +} + +func (u *DeviceListUpdater) assignChannel(userID string) chan bool { + u.userIDToChanMu.Lock() + defer u.userIDToChanMu.Unlock() + if ch, ok := u.userIDToChan[userID]; ok { + return ch + } + ch := make(chan bool) + u.userIDToChan[userID] = ch + return ch +} + +func (u *DeviceListUpdater) clearChannel(userID string) { + u.userIDToChanMu.Lock() + defer u.userIDToChanMu.Unlock() + if ch, ok := u.userIDToChan[userID]; ok { + close(ch) + delete(u.userIDToChan, userID) + } +} + +func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { + retries := make(map[gomatrixserverlib.ServerName]time.Time) + retriesMu := &sync.Mutex{} + // restarter goroutine which will inject failed servers into ch when it is time + go func() { + var serversToRetry []gomatrixserverlib.ServerName + for { + serversToRetry = serversToRetry[:0] // reuse memory + time.Sleep(time.Second) + retriesMu.Lock() + now := time.Now() + for srv, retryAt := range retries { + if now.After(retryAt) { + serversToRetry = append(serversToRetry, srv) + } + } + for _, srv := range serversToRetry { + delete(retries, srv) + } + retriesMu.Unlock() + for _, srv := range serversToRetry { + ch <- srv + } + } + }() + for serverName := range ch { + retriesMu.Lock() + _, exists := retries[serverName] + retriesMu.Unlock() + if exists { + // Don't retry a server that we're already waiting for. + continue + } + waitTime, shouldRetry := u.processServer(serverName) + if shouldRetry { + retriesMu.Lock() + if _, exists = retries[serverName]; !exists { + retries[serverName] = time.Now().Add(waitTime) + } + retriesMu.Unlock() + } + } +} + +func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { + ctx := u.process.Context() + logger := util.GetLogger(ctx).WithField("server_name", serverName) + deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() + + waitTime := defaultWaitTime // How long should we wait to try again? + successCount := 0 // How many user requests failed? + + userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) + if err != nil { + logger.WithError(err).Error("Failed to load stale device lists") + return waitTime, true + } + + defer func() { + for _, userID := range userIDs { + // always clear the channel to unblock Update calls regardless of success/failure + u.clearChannel(userID) + } + }() + + for _, userID := range userIDs { + userWait, err := u.processServerUser(ctx, serverName, userID) + if err != nil { + if userWait > waitTime { + waitTime = userWait + } + break + } + successCount++ + } + + allUsersSucceeded := successCount == len(userIDs) + if !allUsersSucceeded { + logger.WithFields(logrus.Fields{ + "total": len(userIDs), + "succeeded": successCount, + "failed": len(userIDs) - successCount, + "wait_time": waitTime, + }).Debug("Failed to query device keys for some users") + } + return waitTime, !allUsersSucceeded +} + +func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) { + ctx, cancel := context.WithTimeout(ctx, requestTimeout) + defer cancel() + logger := util.GetLogger(ctx).WithFields(logrus.Fields{ + "server_name": serverName, + "user_id": userID, + }) + res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return time.Minute * 10, err + } + switch e := err.(type) { + case *json.UnmarshalTypeError, *json.SyntaxError: + logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID) + return defaultWaitTime, nil + case *fedsenderapi.FederationClientError: + if e.RetryAfter > 0 { + return e.RetryAfter, err + } else if e.Blacklisted { + return time.Hour * 8, err + } + case net.Error: + // Use the default waitTime, if it's a timeout. + // It probably doesn't make sense to try further users. + if !e.Timeout() { + logger.WithError(e).Debug("GetUserDevices returned net.Error") + return time.Minute * 10, err + } + case gomatrix.HTTPError: + // The remote server returned an error, give it some time to recover. + // This is to avoid spamming remote servers, which may not be Matrix servers anymore. + if e.Code >= 300 { + logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") + return hourWaitTime, err + } + default: + // Something else failed + logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err) + return time.Minute * 10, err + } + } + if res.UserID != userID { + logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID) + return defaultWaitTime, nil + } + if res.MasterKey != nil || res.SelfSigningKey != nil { + uploadReq := &api.PerformUploadDeviceKeysRequest{ + UserID: userID, + } + uploadRes := &api.PerformUploadDeviceKeysResponse{} + if res.MasterKey != nil { + if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil { + uploadReq.MasterKey = *res.MasterKey + } + } + if res.SelfSigningKey != nil { + if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil { + uploadReq.SelfSigningKey = *res.SelfSigningKey + } + } + _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + } + err = u.updateDeviceList(&res) + if err != nil { + logger.WithError(err).Error("Fetched device list but failed to store/emit it") + return defaultWaitTime, err + } + return defaultWaitTime, nil +} + +func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { + ctx := context.Background() // we've got the keys, don't time out when persisting them to the database. + keys := make([]api.DeviceMessage, len(res.Devices)) + existingKeys := make([]api.DeviceMessage, len(res.Devices)) + for i, device := range res.Devices { + keyJSON, err := json.Marshal(device.Keys) + if err != nil { + util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device") + continue + } + keys[i] = api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + StreamID: res.StreamID, + DeviceKeys: &api.DeviceKeys{ + DeviceID: device.DeviceID, + DisplayName: device.DisplayName, + UserID: res.UserID, + KeyJSON: keyJSON, + }, + } + existingKeys[i] = api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + UserID: res.UserID, + DeviceID: device.DeviceID, + }, + } + } + // fetch what keys we had already and only emit changes + if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil { + // non-fatal, log and continue + util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf( + "failed to query device keys json for calculating diffs", + ) + } + + err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID}) + if err != nil { + return fmt.Errorf("failed to store remote device keys: %w", err) + } + err = u.db.MarkDeviceListStale(ctx, res.UserID, false) + if err != nil { + return fmt.Errorf("failed to mark device list as fresh: %w", err) + } + err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false) + if err != nil { + return fmt.Errorf("failed to emit key changes for fresh device list: %w", err) + } + return nil +} diff --git a/userapi/internal/device_list_update_default.go b/userapi/internal/device_list_update_default.go new file mode 100644 index 00000000..7d357c95 --- /dev/null +++ b/userapi/internal/device_list_update_default.go @@ -0,0 +1,22 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !vw + +package internal + +import "time" + +const defaultWaitTime = time.Minute +const hourWaitTime = time.Hour diff --git a/userapi/internal/device_list_update_sytest.go b/userapi/internal/device_list_update_sytest.go new file mode 100644 index 00000000..1c60d2eb --- /dev/null +++ b/userapi/internal/device_list_update_sytest.go @@ -0,0 +1,25 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vw + +package internal + +import "time" + +// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite +// results in a one-hour wait time from a previous device so the test times out. This is fine for +// production, but makes an otherwise passing test fail. +const defaultWaitTime = time.Second +const hourWaitTime = time.Second diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go new file mode 100644 index 00000000..868fc9be --- /dev/null +++ b/userapi/internal/device_list_update_test.go @@ -0,0 +1,431 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "crypto/ed25519" + "fmt" + "io" + "net/http" + "net/url" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" +) + +var ( + ctx = context.Background() +) + +type mockKeyChangeProducer struct { + events []api.DeviceMessage +} + +func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error { + p.events = append(p.events, keys...) + return nil +} + +type mockDeviceListUpdaterDatabase struct { + staleUsers map[string]bool + prevIDsExist func(string, []int64) bool + storedKeys []api.DeviceMessage + mu sync.Mutex // protect staleUsers +} + +func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error { + return nil +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + d.mu.Lock() + defer d.mu.Unlock() + var result []string + for userID, isStale := range d.staleUsers { + if !isStale { + continue + } + _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return nil, err + } + if len(domains) == 0 { + result = append(result, userID) + continue + } + for _, d := range domains { + if remoteServer == d { + result = append(result, userID) + break + } + } + } + return result, nil +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + d.mu.Lock() + defer d.mu.Unlock() + d.staleUsers[userID] = isStale + return nil +} + +func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool { + d.mu.Lock() + defer d.mu.Unlock() + return d.staleUsers[userID] +} + +// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key +// for this (user, device). Does not modify the stream ID for keys. +func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error { + d.storedKeys = append(d.storedKeys, keys...) + return nil +} + +// PrevIDsExists returns true if all prev IDs exist for this user. +func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + return d.prevIDsExist(userID, prevIDs), nil +} + +func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + return nil +} + +type mockDeviceListUpdaterAPI struct { +} + +func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { + return nil +} + +type roundTripper struct { + fn func(*http.Request) (*http.Response, error) +} + +func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.fn(req) +} + +func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { + _, pkey, _ := ed25519.GenerateKey(nil) + fedClient := gomatrixserverlib.NewFederationClient( + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: gomatrixserverlib.ServerName("example.test"), + KeyID: gomatrixserverlib.KeyID("ed25519:test"), + PrivateKey: pkey, + }, + }, + ) + fedClient.Client = *gomatrixserverlib.NewClient( + gomatrixserverlib.WithTransport(&roundTripper{tripper}), + ) + return fedClient +} + +// Test that the device keys get persisted and emitted if we have the previous IDs. +func TestUpdateHavePrevID(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int64) bool { + return true + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost") + event := gomatrixserverlib.DeviceListUpdateEvent{ + DeviceDisplayName: "Foo Bar", + Deleted: false, + DeviceID: "FOO", + Keys: []byte(`{"key":"value"}`), + PrevID: []int64{0}, + StreamID: 1, + UserID: "@alice:localhost", + } + err := updater.Update(ctx, event) + if err != nil { + t.Fatalf("Update returned an error: %s", err) + } + want := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + StreamID: event.StreamID, + DeviceKeys: &api.DeviceKeys{ + DeviceID: event.DeviceID, + DisplayName: event.DeviceDisplayName, + KeyJSON: event.Keys, + UserID: event.UserID, + }, + } + if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { + t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) + } + if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { + t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + } + if db.isStale(event.UserID) { + t.Errorf("%s incorrectly marked as stale", event.UserID) + } +} + +// Test that device keys are fetched from the remote server if we are missing prev IDs +// and that the user's devices are marked as stale until it succeeds. +func TestUpdateNoPrevID(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int64) bool { + return false + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + remoteUserID := "@alice:example.somewhere" + var wg sync.WaitGroup + wg.Add(1) + keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + defer wg.Done() + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(` + { + "user_id": "` + remoteUserID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + }, nil + }) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test") + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + event := gomatrixserverlib.DeviceListUpdateEvent{ + DeviceDisplayName: "Mobile Phone", + Deleted: false, + DeviceID: "another_device_id", + Keys: []byte(`{"key":"value"}`), + PrevID: []int64{3}, + StreamID: 4, + UserID: remoteUserID, + } + err := updater.Update(ctx, event) + + if err != nil { + t.Fatalf("Update returned an error: %s", err) + } + t.Log("waiting for /users/devices to be called...") + wg.Wait() + // wait a bit for db to be updated... + time.Sleep(100 * time.Millisecond) + want := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + StreamID: 5, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "JLAFKJWSCS", + DisplayName: "Mobile Phone", + UserID: remoteUserID, + KeyJSON: []byte(keyJSON), + }, + } + // Now we should have a fresh list and the keys and emitted something + if db.isStale(event.UserID) { + t.Errorf("%s still marked as stale", event.UserID) + } + if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { + t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON)) + t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) + } + if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { + t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + } + +} + +// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the +// update is still ongoing. +func TestDebounce(t *testing.T) { + t.Skipf("panic on closed channel on GHA") + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int64) bool { + return true + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + fedCh := make(chan *http.Response, 1) + srv := gomatrixserverlib.ServerName("example.com") + userID := "@alice:example.com" + keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + incomingFedReq := make(chan struct{}) + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + close(incomingFedReq) + return <-fedCh, nil + }) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost") + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + + // hit this 5 times + var wg sync.WaitGroup + wg.Add(5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil { + t.Errorf("ManualUpdate: %s", err) + } + }() + } + + // wait until the updater hits federation + select { + case <-incomingFedReq: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for updater to hit federation") + } + + // user should be marked as stale + if !db.isStale(userID) { + t.Errorf("user %s not marked as stale", userID) + } + // now send the response over federation + fedCh <- &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(` + { + "user_id": "` + userID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + } + close(fedCh) + // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response + // and should panic with read on a closed channel + wg.Wait() + + // user is no longer stale now + if db.isStale(userID) { + t.Errorf("user %s is marked as stale", userID) + } +} + +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + t.Helper() + + base, _, _ := testrig.Base(nil) + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + if err != nil { + t.Fatal(err) + } + + return db, clearDB +} + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +func TestDeviceListUpdater_CleanUp(t *testing.T) { + processCtx := process.NewProcessContext() + + alice := test.NewUser(t) + bob := test.NewUser(t) + + // Bob is not joined to any of our rooms + rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}} + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clearDB := mustCreateKeyserverDB(t, dbType) + defer clearDB() + + // This should not get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil { + t.Error(err) + } + + // this one should get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil { + t.Error(err) + } + + updater := NewDeviceListUpdater(processCtx, db, nil, + nil, nil, + 0, rsAPI, "test") + if err := updater.CleanUp(); err != nil { + t.Error(err) + } + + // check that we still have Alice in our stale list + staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Error(err) + } + + // There should only be Alice + wantCount := 1 + if count := len(staleUsers); count != wantCount { + t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count) + } + + if staleUsers[0] != alice.ID { + t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID) + } + }) +} diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go new file mode 100644 index 00000000..be816fe5 --- /dev/null +++ b/userapi/internal/key_api.go @@ -0,0 +1,798 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/matrix-org/dendrite/userapi/api" +) + +func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { + userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset) + if err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + return nil + } + res.Offset = latest + res.UserIDs = userIDs + return nil +} + +func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { + res.KeyErrors = make(map[string]map[string]*api.KeyError) + if len(req.DeviceKeys) > 0 { + a.uploadLocalDeviceKeys(ctx, req, res) + } + if len(req.OneTimeKeys) > 0 { + a.uploadOneTimeKeys(ctx, req, res) + } + otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + if err != nil { + return err + } + res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks} + return nil +} + +func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { + res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) + res.Failures = make(map[string]interface{}) + // wrap request map in a top-level by-domain map + domainToDeviceKeys := make(map[string]map[string]map[string]string) + for userID, val := range req.OneTimeKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + continue // ignore invalid users + } + nested, ok := domainToDeviceKeys[string(serverName)] + if !ok { + nested = make(map[string]map[string]string) + } + nested[userID] = val + domainToDeviceKeys[string(serverName)] = nested + } + for domain, local := range domainToDeviceKeys { + if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + // claim local keys + keys, err := a.KeyDatabase.ClaimKeys(ctx, local) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), + } + } + util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys") + for _, key := range keys { + _, ok := res.OneTimeKeys[key.UserID] + if !ok { + res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage) + } + _, ok = res.OneTimeKeys[key.UserID][key.DeviceID] + if !ok { + res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) + } + for keyID, keyJSON := range key.KeyJSON { + res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON + } + } + delete(domainToDeviceKeys, domain) + } + if len(domainToDeviceKeys) > 0 { + a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) + } + return nil +} + +func (a *UserInternalAPI) claimRemoteKeys( + ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, +) { + var wg sync.WaitGroup // Wait for fan-out goroutines to finish + var mu sync.Mutex // Protects the response struct + var claimed int // Number of keys claimed in total + var failures int // Number of servers we failed to ask + + util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys)) + wg.Add(len(domainToDeviceKeys)) + + for d, k := range domainToDeviceKeys { + go func(domain string, keysToClaim map[string]map[string]string) { + fedCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + defer wg.Done() + + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) + + mu.Lock() + defer mu.Unlock() + + if err != nil { + util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") + res.Failures[domain] = map[string]interface{}{ + "message": err.Error(), + } + failures++ + return + } + + for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys { + res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) + for deviceID, keys := range deviceIDToKeys { + res.OneTimeKeys[userID][deviceID] = keys + claimed += len(keys) + } + } + }(d, k) + } + + wg.Wait() + util.GetLogger(ctx).WithFields(logrus.Fields{ + "num_keys": claimed, + "num_failures": failures, + }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys)) +} + +func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { + if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to delete device keys: %s", err), + } + } + return nil +} + +func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { + count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to query OTK counts: %s", err), + } + return nil + } + res.Count = *count + return nil +} + +func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { + msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query DB for device keys: %s", err), + } + return nil + } + maxStreamID := int64(0) + // remove deleted devices + var result []api.DeviceMessage + for _, m := range msgs { + if m.StreamID > maxStreamID { + maxStreamID = m.StreamID + } + if m.KeyJSON == nil || len(m.KeyJSON) == 0 { + continue + } + result = append(result, m) + } + res.Devices = result + res.StreamID = maxStreamID + return nil +} + +// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present +// in our database. +func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { + knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true) + if err != nil { + return err + } + if len(knownDevices) == 0 { + return nil // fmt.Errorf("unknown user %s", req.UserID) + } + + for i := range knownDevices { + if knownDevices[i].DeviceID == req.DeviceID { + return nil // we already know about this device + } + } + + return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) +} + +// nolint:gocyclo +func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { + var respMu sync.Mutex + res.DeviceKeys = make(map[string]map[string]json.RawMessage) + res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) + res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) + res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) + res.Failures = make(map[string]interface{}) + + // make a map from domain to device keys + domainToDeviceKeys := make(map[string]map[string][]string) + domainToCrossSigningKeys := make(map[string]map[string]struct{}) + for userID, deviceIDs := range req.UserToDevices { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + continue // ignore invalid users + } + domain := string(serverName) + // query local devices + if a.Config.Matrix.IsLocalServerName(serverName) { + deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query local device keys: %s", err), + } + return nil + } + + // pull out display names after we have the keys so we handle wildcards correctly + var dids []string + for _, dk := range deviceKeys { + dids = append(dids, dk.DeviceID) + } + var queryRes api.QueryDeviceInfosResponse + err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{ + DeviceIDs: dids, + }, &queryRes) + if err != nil { + util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") + } + + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + for _, dk := range deviceKeys { + if len(dk.KeyJSON) == 0 { + continue // don't include blank keys + } + // inject display name if known (either locally or remotely) + displayName := dk.DisplayName + if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { + displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName + } + dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{displayName}) + res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON + } + } else { + domainToDeviceKeys[domain] = make(map[string][]string) + domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) + } + // work out if our cross-signing request for this user was + // satisfied, if not add them to the list of things to fetch + if _, ok := res.MasterKeys[userID]; !ok { + if _, ok := domainToCrossSigningKeys[domain]; !ok { + domainToCrossSigningKeys[domain] = make(map[string]struct{}) + } + domainToCrossSigningKeys[domain][userID] = struct{}{} + } + if _, ok := res.SelfSigningKeys[userID]; !ok { + if _, ok := domainToCrossSigningKeys[domain]; !ok { + domainToCrossSigningKeys[domain] = make(map[string]struct{}) + } + domainToCrossSigningKeys[domain][userID] = struct{}{} + } + } + + // attempt to satisfy key queries from the local database first as we should get device updates pushed to us + domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys) + if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { + // perform key queries for remote devices + a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) + } + + // Now that we've done the potentially expensive work of asking the federation, + // try filling the cross-signing keys from the database that we know about. + a.crossSigningKeysFromDatabase(ctx, req, res) + + // Finally, append signatures that we know about + // TODO: This is horrible because we need to round-trip the signature from + // JSON, add the signatures and marshal it again, for some reason? + + for targetUserID, masterKey := range res.MasterKeys { + if masterKey.Signatures == nil { + masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for targetKeyID := range masterKey.Keys { + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) + if err != nil { + // Stop executing the function if the context was canceled/the deadline was exceeded, + // as we can't continue without a valid context. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") + continue + } + if len(sigMap) == 0 { + continue + } + for sourceUserID, forSourceUser := range sigMap { + for sourceKeyID, sourceSig := range forSourceUser { + if _, ok := masterKey.Signatures[sourceUserID]; !ok { + masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig + } + } + } + } + + for targetUserID, forUserID := range res.DeviceKeys { + for targetKeyID, key := range forUserID { + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) + if err != nil { + // Stop executing the function if the context was canceled/the deadline was exceeded, + // as we can't continue without a valid context. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") + continue + } + if len(sigMap) == 0 { + continue + } + var deviceKey gomatrixserverlib.DeviceKeys + if err = json.Unmarshal(key, &deviceKey); err != nil { + continue + } + if deviceKey.Signatures == nil { + deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for sourceUserID, forSourceUser := range sigMap { + for sourceKeyID, sourceSig := range forSourceUser { + if _, ok := deviceKey.Signatures[sourceUserID]; !ok { + deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig + } + } + if js, err := json.Marshal(deviceKey); err == nil { + res.DeviceKeys[targetUserID][targetKeyID] = js + } + } + } + return nil +} + +func (a *UserInternalAPI) remoteKeysFromDatabase( + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, +) map[string]map[string][]string { + fetchRemote := make(map[string]map[string][]string) + for domain, userToDeviceMap := range domainToDeviceKeys { + for userID, deviceIDs := range userToDeviceMap { + // we can't safely return keys from the db when all devices are requested as we don't + // know if one has just been added. + if len(deviceIDs) > 0 { + err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs) + if err == nil { + continue + } + util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase") + } + // fetch device lists from remote + if _, ok := fetchRemote[domain]; !ok { + fetchRemote[domain] = make(map[string][]string) + } + fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...) + + } + } + return fetchRemote +} + +func (a *UserInternalAPI) queryRemoteKeys( + ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, + domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{}, +) { + resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys)) + // allows us to wait until all federation servers have been poked + var wg sync.WaitGroup + // mutex for writing directly to res (e.g failures) + var respMu sync.Mutex + + domains := map[string]struct{}{} + for domain := range domainToDeviceKeys { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + domains[domain] = struct{}{} + } + for domain := range domainToCrossSigningKeys { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + domains[domain] = struct{}{} + } + wg.Add(len(domains)) + + // fan out + for domain := range domains { + go a.queryRemoteKeysOnServer( + ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain], + &wg, &respMu, timeout, resultCh, res, + ) + } + + // Close the result channel when the goroutines have quit so the for .. range exits + go func() { + wg.Wait() + close(resultCh) + }() + + processResult := func(result *gomatrixserverlib.RespQueryKeys) { + respMu.Lock() + defer respMu.Unlock() + for userID, nest := range result.DeviceKeys { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + for deviceID, deviceKey := range nest { + keyJSON, err := json.Marshal(deviceKey) + if err != nil { + continue + } + res.DeviceKeys[userID][deviceID] = keyJSON + } + } + + for userID, body := range result.MasterKeys { + res.MasterKeys[userID] = body + } + + for userID, body := range result.SelfSigningKeys { + res.SelfSigningKeys[userID] = body + } + + // TODO: do we want to persist these somewhere now + // that we have fetched them? + } + + for result := range resultCh { + processResult(result) + } +} + +func (a *UserInternalAPI) queryRemoteKeysOnServer( + ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{}, + wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, + res *api.QueryKeysResponse, +) { + defer wg.Done() + fedCtx := ctx + if timeout > 0 { + var cancel context.CancelFunc + fedCtx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + // for users who we do not have any knowledge about, try to start doing device list updates for them + // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but + // lack a stream ID. + userIDsForAllDevices := map[string]struct{}{} + for userID, deviceIDs := range devKeys { + if len(deviceIDs) == 0 { + userIDsForAllDevices[userID] = struct{}{} + } + } + // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing + // a device list update, so we'll populate those back into the /keys/query list if not + for userID := range crossSigningKeys { + if devKeys == nil { + devKeys = map[string][]string{} + } + if _, ok := userIDsForAllDevices[userID]; !ok { + devKeys[userID] = []string{} + } + } + for userID := range userIDsForAllDevices { + err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + logrus.ErrorKey: err, + "user_id": userID, + "server": serverName, + }).Error("Failed to manually update device lists for user") + // try to do it via /keys/query + devKeys[userID] = []string{} + continue + } + // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this + // user so the fact that we're populating all devices here isn't a problem so long as we have devices. + err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) + if err != nil { + logrus.WithFields(logrus.Fields{ + logrus.ErrorKey: err, + "user_id": userID, + "server": serverName, + }).Error("Failed to manually update device lists for user") + // try to do it via /keys/query + devKeys[userID] = []string{} + continue + } + } + if len(devKeys) == 0 { + return + } + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) + if err == nil { + resultCh <- &queryKeysResp + return + } + respMu.Lock() + res.Failures[serverName] = map[string]interface{}{ + "message": err.Error(), + } + respMu.Unlock() + + // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server + // is down, better to return something than nothing at all. Clients can know about the failure by + // inspecting the failures map though so they can know it's a cached response. + for userID, dkeys := range devKeys { + // drop the error as it's already a failure at this point + _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys) + } + + // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache + respMu.Lock() + if len(res.DeviceKeys) > 0 { + delete(res.Failures, serverName) + } + respMu.Unlock() +} + +func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase( + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, +) error { + keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) + // if we can't query the db or there are fewer keys than requested, fetch from remote. + if err != nil { + return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) + } + if len(keys) < len(deviceIDs) { + return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID) + } + if len(deviceIDs) == 0 && len(keys) == 0 { + return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) + } + respMu.Lock() + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + respMu.Unlock() + + for _, key := range keys { + if len(key.KeyJSON) == 0 { + continue // ignore deleted keys + } + // inject the display name + key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{key.DisplayName}) + respMu.Lock() + res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + respMu.Unlock() + } + return nil +} + +func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + // get a list of devices from the user API that actually exist, as + // we won't store keys for devices that don't exist + uapidevices := &api.QueryDevicesResponse{} + if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + return + } + if !uapidevices.UserExists { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("user %q does not exist", req.UserID), + } + return + } + existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices)) + for _, key := range uapidevices.Devices { + existingDeviceMap[key.ID] = struct{}{} + } + + // Get all of the user existing device keys so we can check for changes. + existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), + } + return + } + + // Work out whether we have device keys in the keyserver for devices that + // no longer exist in the user API. This is mostly an exercise to ensure + // that we keep some integrity between the two. + var toClean []gomatrixserverlib.KeyID + for _, k := range existingKeys { + if _, ok := existingDeviceMap[k.DeviceID]; !ok { + toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID)) + } + } + + if len(toClean) > 0 { + if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) + } else { + logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) + } + } + + var keysToStore []api.DeviceMessage + + if req.OnlyDisplayNameUpdates { + for _, existingKey := range existingKeys { + for _, newKey := range req.DeviceKeys { + switch { + case existingKey.UserID != newKey.UserID: + continue + case existingKey.DeviceID != newKey.DeviceID: + continue + case existingKey.DisplayName != newKey.DisplayName: + existingKey.DisplayName = newKey.DisplayName + } + } + keysToStore = append(keysToStore, existingKey) + } + } else { + // assert that the user ID / device ID are not lying for each key + for _, key := range req.DeviceKeys { + var serverName gomatrixserverlib.ServerName + _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) + if err != nil { + continue // ignore invalid users + } + if !a.Config.Matrix.IsLocalServerName(serverName) { + continue // ignore remote users + } + if len(key.KeyJSON) == 0 { + keysToStore = append(keysToStore, key.WithStreamID(0)) + continue // deleted keys don't need sanity checking + } + // check that the device in question actually exists in the user + // API before we try and store a key for it + if _, ok := existingDeviceMap[key.DeviceID]; !ok { + continue + } + gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str + gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str + if gotUserID == key.UserID && gotDeviceID == key.DeviceID { + keysToStore = append(keysToStore, key.WithStreamID(0)) + continue + } + + res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ + Err: fmt.Sprintf( + "user_id or device_id mismatch: users: %s - %s, devices: %s - %s", + gotUserID, key.UserID, gotDeviceID, key.DeviceID, + ), + }) + } + } + + // store the device keys and emit changes + err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), + } + return + } + err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) + if err != nil { + util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) + } +} + +func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + if req.UserID == "" { + res.Error = &api.KeyError{ + Err: "user ID missing", + } + } + if req.DeviceID != "" && len(req.OneTimeKeys) == 0 { + counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err), + } + } + if counts != nil { + res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) + } + return + } + for _, key := range req.OneTimeKeys { + // grab existing keys based on (user/device/algorithm/key ID) + keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) + i := 0 + for keyIDWithAlgo := range key.KeyJSON { + keyIDsWithAlgorithms[i] = keyIDWithAlgo + i++ + } + existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) + if err != nil { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: "failed to query existing one-time keys: " + err.Error(), + }) + continue + } + for keyIDWithAlgo := range existingKeys { + // if keys exist and the JSON doesn't match, error out as the key already exists + if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo), + }) + continue + } + } + // store one-time keys + counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key) + if err != nil { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()), + }) + continue + } + // collect counts + res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) + } + +} + +func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { + // if we only want to update the display names, we can skip the checks below + if onlyUpdateDisplayName { + return producer.ProduceKeyChanges(new) + } + // find keys in new that are not in existing + var keysAdded []api.DeviceMessage + for _, newKey := range new { + exists := false + for _, existingKey := range existing { + // Do not treat the absence of keys as equal, or else we will not emit key changes + // when users delete devices which never had a key to begin with as both KeyJSONs are nil. + if existingKey.DeviceKeysEqual(&newKey) { + exists = true + break + } + } + if !exists { + keysAdded = append(keysAdded, newKey) + } + } + return producer.ProduceKeyChanges(keysAdded) +} diff --git a/userapi/internal/key_api_test.go b/userapi/internal/key_api_test.go new file mode 100644 index 00000000..fc7e7e0d --- /dev/null +++ b/userapi/internal/key_api_test.go @@ -0,0 +1,161 @@ +package internal_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/dendrite/userapi/storage" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + base, _, _ := testrig.Base(nil) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to create new user db: %v", err) + } + return db, func() { + base.Close() + close() + } +} + +func Test_QueryDeviceMessages(t *testing.T) { + alice := test.NewUser(t) + type args struct { + req *api.QueryDeviceMessagesRequest + res *api.QueryDeviceMessagesResponse + } + tests := []struct { + name string + args args + wantErr bool + want *api.QueryDeviceMessagesResponse + }{ + { + name: "no existing keys", + args: args{ + req: &api.QueryDeviceMessagesRequest{ + UserID: "@doesNotExist:localhost", + }, + res: &api.QueryDeviceMessagesResponse{}, + }, + want: &api.QueryDeviceMessagesResponse{}, + }, + { + name: "existing user returns devices", + args: args{ + req: &api.QueryDeviceMessagesRequest{ + UserID: alice.ID, + }, + res: &api.QueryDeviceMessagesResponse{}, + }, + want: &api.QueryDeviceMessagesResponse{ + StreamID: 6, + Devices: []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + DisplayName: "first device", + UserID: alice.ID, + KeyJSON: []byte("ghi"), + }, + }, + { + Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{ + DeviceID: "mySecondDevice", + DisplayName: "second device", + UserID: alice.ID, + KeyJSON: []byte("jkl"), + }, // streamID 6 + }, + }, + }, + }, + } + + deviceMessages := []api.DeviceMessage{ + { // not the user we're looking for + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + UserID: "@doesNotExist:localhost", + }, + // streamID 1 for this user + }, + { // empty keyJSON will be ignored + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + UserID: alice.ID, + }, // streamID 1 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + UserID: alice.ID, + KeyJSON: []byte("abc"), + }, // streamID 2 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + UserID: alice.ID, + KeyJSON: []byte("def"), + }, // streamID 3 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + UserID: alice.ID, + KeyJSON: []byte(""), + }, // streamID 4 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + DisplayName: "first device", + UserID: alice.ID, + KeyJSON: []byte("ghi"), + }, // streamID 5 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "mySecondDevice", + UserID: alice.ID, + KeyJSON: []byte("jkl"), + DisplayName: "second device", + }, // streamID 6 + }, + } + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateDatabase(t, dbType) + defer closeDB() + if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil { + t.Fatalf("failed to store local devicesKeys") + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &internal.UserInternalAPI{ + KeyDatabase: db, + } + if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr { + t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr) + } + got := tt.args.res + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want) + } + }) + } + }) +} diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go new file mode 100644 index 00000000..1cbd9719 --- /dev/null +++ b/userapi/internal/user_api.go @@ -0,0 +1,970 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/internal/sqlutil" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + synctypes "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + userapiUtil "github.com/matrix-org/dendrite/userapi/util" +) + +type UserInternalAPI struct { + DB storage.UserDatabase + KeyDatabase storage.KeyDatabase + SyncProducer *producers.SyncAPI + KeyChangeProducer *producers.KeyChange + Config *config.UserAPI + + DisableTLSValidation bool + // AppServices is the list of all registered AS + AppServices []config.ApplicationService + RSAPI rsapi.UserRoomserverAPI + PgClient pushgateway.Client + FedClient fedsenderapi.KeyserverFederationAPI + Updater *DeviceListUpdater +} + +func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot update account data of remote users (server name %s)", domain) + } + if req.DataType == "" { + return fmt.Errorf("data type must not be empty") + } + if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil { + util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed") + return fmt.Errorf("failed to save account data: %w", err) + } + var ignoredUsers *synctypes.IgnoredUsers + if req.DataType == "m.ignored_user_list" { + ignoredUsers = &synctypes.IgnoredUsers{} + _ = json.Unmarshal(req.AccountData, ignoredUsers) + } + if req.DataType == "m.fully_read" { + if err := a.setFullyRead(ctx, req); err != nil { + return err + } + } + if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ + RoomID: req.RoomID, + Type: req.DataType, + IgnoredUsers: ignoredUsers, + }); err != nil { + util.GetLogger(ctx).WithError(err).Error("a.SyncProducer.SendAccountData failed") + return fmt.Errorf("failed to send account data to output: %w", err) + } + return nil +} + +func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error { + var output eventutil.ReadMarkerJSON + + if err := json.Unmarshal(req.AccountData, &output); err != nil { + return err + } + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") + return nil + } + if !a.Config.Matrix.IsLocalServerName(domain) { + return nil + } + + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) + if err != nil { + logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") + return err + } + + if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed") + return err + } + + // nothing changed, no need to notify the push gateway + if !deleted { + return nil + } + + if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") + return err + } + return nil +} + +func postRegisterJoinRooms(cfg *config.UserAPI, acc *api.Account, rsAPI rsapi.UserRoomserverAPI) { + // POST register behaviour: check if the user is a normal user. + // If the user is a normal user, add user to room specified in the configuration "auto_join_rooms". + if acc.AccountType != api.AccountTypeAppService && acc.AppServiceID == "" { + for room := range cfg.AutoJoinRooms { + userID := userutil.MakeUserID(acc.Localpart, cfg.Matrix.ServerName) + err := addUserToRoom(context.Background(), rsAPI, cfg.AutoJoinRooms[room], acc.Localpart, userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + "user_id": userID, + "room": cfg.AutoJoinRooms[room], + }).WithError(err).Errorf("user failed to auto-join room") + } + } + } +} + +// Add user to a room. This function currently working for auto_join_rooms config, +// which can add a newly registered user to a specified room. +func addUserToRoom( + ctx context.Context, + rsAPI rsapi.UserRoomserverAPI, + roomID string, + username string, + userID string, +) error { + addGroupContent := make(map[string]interface{}) + // This make sure the user's username can be displayed correctly. + // Because the newly-registered user doesn't have an avatar, the avatar_url is not needed. + addGroupContent["displayname"] = username + joinReq := rsapi.PerformJoinRequest{ + RoomIDOrAlias: roomID, + UserID: userID, + Content: addGroupContent, + } + joinRes := rsapi.PerformJoinResponse{} + return rsAPI.PerformJoin(ctx, &joinReq, &joinRes) +} + +func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %s is not local", serverName) + } + acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType) + if err != nil { + if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists + switch req.OnConflict { + case api.ConflictUpdate: + break + case api.ConflictAbort: + return &api.ErrorConflict{ + Message: err.Error(), + } + } + } + // account already exists + res.AccountCreated = false + res.Account = &api.Account{ + AppServiceID: req.AppServiceID, + Localpart: req.Localpart, + ServerName: serverName, + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), + AccountType: req.AccountType, + } + return nil + } + + // Inform the SyncAPI about the newly created push_rules + if err = a.SyncProducer.SendAccountData(acc.UserID, eventutil.AccountData{ + Type: "m.push_rules", + }); err != nil { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "user_id": acc.UserID, + }).WithError(err).Warn("failed to send account data to the SyncAPI") + } + + if req.AccountType == api.AccountTypeGuest { + res.AccountCreated = true + res.Account = acc + return nil + } + + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil { + return fmt.Errorf("a.DB.SetDisplayName: %w", err) + } + + postRegisterJoinRooms(a.Config, acc, a.RSAPI) + + res.AccountCreated = true + res.Account = acc + return nil +} + +func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { + if !a.Config.Matrix.IsLocalServerName(req.ServerName) { + return fmt.Errorf("server name %s is not local", req.ServerName) + } + if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil { + return err + } + if req.LogoutDevices { + if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil { + return err + } + } + res.PasswordUpdated = true + return nil +} + +func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %s is not local", serverName) + } + util.GetLogger(ctx).WithFields(logrus.Fields{ + "localpart": req.Localpart, + "device_id": req.DeviceID, + "display_name": req.DeviceDisplayName, + }).Info("PerformDeviceCreation") + dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) + if err != nil { + return err + } + res.DeviceCreated = true + res.Device = dev + if req.NoDeviceListUpdate { + return nil + } + // create empty device keys and upload them to trigger device list changes + return a.deviceListUpdate(dev.UserID, []string{dev.ID}) +} + +func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error { + util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("devices", req.DeviceIDs).Info("PerformDeviceDeletion") + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot PerformDeviceDeletion of remote users (server name %s)", domain) + } + deletedDeviceIDs := req.DeviceIDs + if len(req.DeviceIDs) == 0 { + var devices []api.Device + devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID) + for _, d := range devices { + deletedDeviceIDs = append(deletedDeviceIDs, d.ID) + } + } else { + err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs) + } + if err != nil { + return err + } + // Ask the keyserver to delete device keys and signatures for those devices + deleteReq := &api.PerformDeleteKeysRequest{ + UserID: req.UserID, + } + for _, keyID := range req.DeviceIDs { + deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) + } + deleteRes := &api.PerformDeleteKeysResponse{} + if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { + return err + } + if err := deleteRes.Error; err != nil { + return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err) + } + // create empty device keys and upload them to delete what was once there and trigger device list changes + return a.deviceListUpdate(req.UserID, deletedDeviceIDs) +} + +func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { + deviceKeys := make([]api.DeviceKeys, len(deviceIDs)) + for i, did := range deviceIDs { + deviceKeys[i] = api.DeviceKeys{ + UserID: userID, + DeviceID: did, + KeyJSON: nil, + } + } + + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ + UserID: userID, + DeviceKeys: deviceKeys, + }, &uploadRes); err != nil { + return err + } + if uploadRes.Error != nil { + return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error) + } + if len(uploadRes.KeyErrors) > 0 { + return fmt.Errorf("failed to delete device keys, key errors: %+v", uploadRes.KeyErrors) + } + return nil +} + +func (a *UserInternalAPI) PerformLastSeenUpdate( + ctx context.Context, + req *api.PerformLastSeenUpdateRequest, + res *api.PerformLastSeenUpdateResponse, +) error { + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { + return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) + } + return nil +} + +func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return err + } + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } + dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID) + if err == sql.ErrNoRows { + res.DeviceExists = false + return nil + } else if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") + return err + } + res.DeviceExists = true + + if dev.UserID != req.RequestingUserID { + res.Forbidden = true + return nil + } + + err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") + return err + } + if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { + // display name has changed: update the device key + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ + UserID: req.RequestingUserID, + DeviceKeys: []api.DeviceKeys{ + { + DeviceID: dev.ID, + DisplayName: *req.DisplayName, + KeyJSON: nil, + UserID: dev.UserID, + }, + }, + OnlyDisplayNameUpdates: true, + }, &uploadRes); err != nil { + return err + } + if uploadRes.Error != nil { + return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error) + } + if len(uploadRes.KeyErrors) > 0 { + return fmt.Errorf("failed to update device key display name, key errors: %+v", uploadRes.KeyErrors) + } + } + return nil +} + +func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) + } + prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err + } + res.UserExists = true + res.AvatarURL = prof.AvatarURL + res.DisplayName = prof.DisplayName + return nil +} + +func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { + profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit) + if err != nil { + return err + } + res.Profiles = profiles + return nil +} + +func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { + devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs) + if err != nil { + return err + } + res.DeviceInfo = make(map[string]struct { + DisplayName string + UserID string + }) + for _, d := range devices { + res.DeviceInfo[d.ID] = struct { + DisplayName string + UserID string + }{ + DisplayName: d.DisplayName, + UserID: d.UserID, + } + } + return nil +} + +func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) + } + devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain) + if err != nil { + return err + } + res.UserExists = true + res.Devices = devs + return nil +} + +func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query account data of remote users (server name %s)", domain) + } + if req.DataType != "" { + var data json.RawMessage + data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType) + if err != nil { + return err + } + res.RoomAccountData = make(map[string]map[string]json.RawMessage) + res.GlobalAccountData = make(map[string]json.RawMessage) + if data != nil { + if req.RoomID != "" { + if _, ok := res.RoomAccountData[req.RoomID]; !ok { + res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage) + } + res.RoomAccountData[req.RoomID][req.DataType] = data + } else { + res.GlobalAccountData[req.DataType] = data + } + } + return nil + } + global, rooms, err := a.DB.GetAccountData(ctx, local, domain) + if err != nil { + return err + } + res.RoomAccountData = rooms + res.GlobalAccountData = global + return nil +} + +func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAccessTokenRequest, res *api.QueryAccessTokenResponse) error { + if req.AppServiceUserID != "" { + appServiceDevice, err := a.queryAppServiceToken(ctx, req.AccessToken, req.AppServiceUserID) + if err != nil { + res.Err = err.Error() + } + res.Device = appServiceDevice + + return nil + } + device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err + } + localPart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return err + } + if !a.Config.Matrix.IsLocalServerName(domain) { + return nil + } + acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain) + if err != nil { + return err + } + device.AccountType = acc.AccountType + res.Device = device + return nil +} + +func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { + res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) + return +} + +// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem +// creating a 'device'. +func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { + // Search for app service with given access_token + var appService *config.ApplicationService + for _, as := range a.AppServices { + if as.ASToken == token { + appService = &as + break + } + } + if appService == nil { + return nil, nil + } + + // Create a dummy device for AS user + dev := api.Device{ + // Use AS dummy device ID + ID: "AS_Device", + // AS dummy device has AS's token. + AccessToken: token, + AppserviceID: appService.ID, + AccountType: api.AccountTypeAppService, + } + + localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) + if err != nil { + return nil, err + } + + if localpart != "" { // AS is masquerading as another user + // Verify that the user is registered + account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain) + // Verify that the account exists and either appServiceID matches or + // it belongs to the appservice user namespaces + if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { + // Set the userID of dummy device + dev.UserID = appServiceUserID + return &dev, nil + } + return nil, &api.ErrorForbidden{Message: "appservice has not registered this user"} + } + + // AS is not masquerading as any user, so use AS's sender_localpart + dev.UserID = appService.SenderLocalpart + return &dev, nil +} + +// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. +func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %q not locally configured", serverName) + } + + evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), + } + evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} + if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { + return err + } + if err := evacuateRes.Error; err != nil { + logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation") + } + + deviceReq := &api.PerformDeviceDeletionRequest{ + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), + } + deviceRes := &api.PerformDeviceDeletionResponse{} + if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { + return err + } + + pusherReq := &api.PerformPusherDeletionRequest{ + Localpart: req.Localpart, + } + if err := a.PerformPusherDeletion(ctx, pusherReq, &struct{}{}); err != nil { + return err + } + + err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName) + res.AccountDeactivated = err == nil + return err +} + +// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user +func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { + token := util.RandomString(24) + + exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID) + + res.Token = api.OpenIDToken{ + Token: token, + UserID: req.UserID, + ExpiresAtMS: exp, + } + + return err +} + +// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation +func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { + openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token) + if err != nil { + return err + } + + res.Sub = openIDTokenAttrs.UserID + res.ExpiresAtMS = openIDTokenAttrs.ExpiresAtMS + + return nil +} + +func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { + // Delete metadata + if req.DeleteBackup { + if req.Version == "" { + res.BadInput = true + res.Error = "must specify a version to delete" + return nil + } + exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) + if err != nil { + res.Error = fmt.Sprintf("failed to delete backup: %s", err) + } + res.Exists = exists + res.Version = req.Version + return nil + } + // Create metadata + if req.Version == "" { + version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to create backup: %s", err) + } + res.Exists = err == nil + res.Version = version + return nil + } + // Update metadata + if len(req.Keys.Rooms) == 0 { + err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + if err != nil { + res.Error = fmt.Sprintf("failed to update backup: %s", err) + } + res.Exists = err == nil + res.Version = req.Version + return nil + } + // Upload Keys for a specific version metadata + a.uploadBackupKeys(ctx, req, res) + return nil +} + +func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { + // you can only upload keys for the CURRENT version + version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "") + if err != nil { + res.Error = fmt.Sprintf("failed to query version: %s", err) + return + } + if deleted { + res.Error = "backup was deleted" + return + } + if version != req.Version { + res.BadInput = true + res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version) + return + } + res.Exists = true + res.Version = version + + // map keys to a form we can upload more easily - the map ensures we have no duplicates. + var uploads []api.InternalKeyBackupSession + for roomID, data := range req.Keys.Rooms { + for sessionID, sessionData := range data.Sessions { + uploads = append(uploads, api.InternalKeyBackupSession{ + RoomID: roomID, + SessionID: sessionID, + KeyBackupSession: sessionData, + }) + } + } + count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads) + if err != nil { + res.Error = fmt.Sprintf("failed to upsert keys: %s", err) + return + } + res.KeyCount = count + res.KeyETag = etag +} + +func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error { + version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) + res.Version = version + if err != nil { + if err == sql.ErrNoRows { + res.Exists = false + return nil + } + res.Error = fmt.Sprintf("failed to query key backup: %s", err) + return nil + } + res.Algorithm = algorithm + res.AuthData = authData + res.ETag = etag + res.Exists = !deleted + + if !req.ReturnKeys { + res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID) + if err != nil { + res.Error = fmt.Sprintf("failed to count keys: %s", err) + } + return nil + } + + result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) + if err != nil { + res.Error = fmt.Sprintf("failed to query keys: %s", err) + return nil + } + res.Keys = result + return nil +} + +func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { + if req.Limit == 0 || req.Limit > 1000 { + req.Limit = 1000 + } + + var fromID int64 + var err error + if req.From != "" { + fromID, err = strconv.ParseInt(req.From, 10, 64) + if err != nil { + return fmt.Errorf("QueryNotifications: parsing 'from': %w", err) + } + } + var filter tables.NotificationFilter = tables.AllNotifications + if req.Only == "highlight" { + filter = tables.HighlightNotifications + } + notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter) + if err != nil { + return err + } + if notifs == nil { + // This ensures empty is JSON-encoded as [] instead of null. + notifs = []*api.Notification{} + } + res.Notifications = notifs + if lastID >= 0 { + res.NextToken = strconv.FormatInt(lastID+1, 10) + } + return nil +} + +func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "localpart": req.Localpart, + "pushkey": req.Pusher.PushKey, + "display_name": req.Pusher.AppDisplayName, + }).Info("PerformPusherCreation") + if !req.Append { + err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey) + if err != nil { + return err + } + } + if req.Pusher.Kind == "" { + return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName) + } + if req.Pusher.PushKeyTS == 0 { + req.Pusher.PushKeyTS = int64(time.Now().Unix()) + } + return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName) +} + +func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { + pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName) + if err != nil { + return err + } + for i := range pushers { + logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) + if pushers[i].SessionID != req.SessionID { + err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName) + if err != nil { + return err + } + } + } + return nil +} + +func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { + var err error + res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName) + return err +} + +func (a *UserInternalAPI) PerformPushRulesPut( + ctx context.Context, + req *api.PerformPushRulesPutRequest, + _ *struct{}, +) error { + bs, err := json.Marshal(&req.RuleSets) + if err != nil { + return err + } + userReq := api.InputAccountDataRequest{ + UserID: req.UserID, + DataType: pushRulesAccountDataType, + AccountData: json.RawMessage(bs), + } + var userRes api.InputAccountDataResponse // empty + if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { + return err + } + return nil +} + +func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) + } + pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain) + if err != nil { + return fmt.Errorf("failed to query push rules: %w", err) + } + res.RuleSets = pushRules + return nil +} + +func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { + profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL) + res.Profile = profile + res.Changed = changed + return err +} + +func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error { + id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName) + if err != nil { + return err + } + res.ID = id + return nil +} + +func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { + var err error + res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName) + return err +} + +func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { + acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword) + switch err { + case sql.ErrNoRows: // user does not exist + return nil + case bcrypt.ErrMismatchedHashAndPassword: // user exists, but password doesn't match + return nil + case bcrypt.ErrHashTooShort: // user exists, but probably a passwordless account + return nil + default: + res.Exists = true + res.Account = acc + return nil + } +} + +func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { + profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName) + res.Profile = profile + res.Changed = changed + return err +} + +func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { + localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) + if err != nil { + return err + } + res.Localpart = localpart + res.ServerName = domain + return nil +} + +func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { + r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName) + if err != nil { + return err + } + res.ThreePIDs = r + return nil +} + +func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error { + return a.DB.RemoveThreePIDAssociation(ctx, req.ThreePID, req.Medium) +} + +func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { + return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) +} + +const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/producers/keychange.go b/userapi/producers/keychange.go new file mode 100644 index 00000000..da6cea31 --- /dev/null +++ b/userapi/producers/keychange.go @@ -0,0 +1,107 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package producers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" +) + +// KeyChange produces key change events for the sync API and federation sender to consume +type KeyChange struct { + Topic string + JetStream JetStreamPublisher + DB storage.KeyChangeDatabase +} + +// ProduceKeyChanges creates new change events for each key +func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { + userToDeviceCount := make(map[string]int) + for _, key := range keys { + id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) + if err != nil { + return err + } + key.DeviceChangeID = id + value, err := json.Marshal(key) + if err != nil { + return err + } + + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, key.UserID) + m.Data = value + + _, err = p.JetStream.PublishMsg(m) + if err != nil { + return err + } + + userToDeviceCount[key.UserID]++ + } + for userID, count := range userToDeviceCount { + logrus.WithFields(logrus.Fields{ + "user_id": userID, + "num_key_changes": count, + }).Tracef("Produced to key change topic '%s'", p.Topic) + } + return nil +} + +func (p *KeyChange) ProduceSigningKeyUpdate(key api.CrossSigningKeyUpdate) error { + output := &api.DeviceMessage{ + Type: api.TypeCrossSigningUpdate, + OutputCrossSigningKeyUpdate: &api.OutputCrossSigningKeyUpdate{ + CrossSigningKeyUpdate: key, + }, + } + + id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) + if err != nil { + return err + } + output.DeviceChangeID = id + + value, err := json.Marshal(output) + if err != nil { + return err + } + + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, key.UserID) + m.Data = value + + _, err = p.JetStream.PublishMsg(m) + if err != nil { + return err + } + + logrus.WithFields(logrus.Fields{ + "user_id": key.UserID, + }).Tracef("Produced to cross-signing update topic '%s'", p.Topic) + return nil +} diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 51eaa985..165de899 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -19,13 +19,13 @@ type JetStreamPublisher interface { // SyncAPI produces messages for the Sync API server to consume. type SyncAPI struct { - db storage.Database + db storage.Notification producer JetStreamPublisher clientDataTopic string notificationDataTopic string } -func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { +func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { return &SyncAPI{ db: db, producer: js, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index c22b7658..27837886 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -90,7 +90,7 @@ type KeyBackup interface { type LoginToken interface { // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. + // determined by the loginTokenLifetime given to the UserDatabase constructor. CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) // RemoveLoginToken removes the named token (and may clean up other expired tokens). @@ -130,7 +130,7 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } -type Database interface { +type UserDatabase interface { Account AccountData Device @@ -144,6 +144,78 @@ type Database interface { ThreePID } +type KeyChangeDatabase interface { + // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. + // `userID` is the the user who has changed their keys in some way. + StoreKeyChange(ctx context.Context, userID string) (int64, error) +} + +type KeyDatabase interface { + KeyChangeDatabase + // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination + // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. + ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + + // StoreOneTimeKeys persists the given one-time keys. + StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + + // OneTimeKeysCount returns a count of all OTKs for this device. + OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. + // Returns an error if there was a problem storing the keys. + StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) + + // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. + // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + + // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying + // cross-signing signatures relating to that device. + DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error + + // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key + // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. + ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) + + // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). + // A to offset of types.OffsetNewest means no upper limit. + // Returns the offset of the latest key change. + KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) + CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) + + StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error + StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error +} + type Statistics interface { UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 2a4777d7..05716037 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData( roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) + // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage + // when passing the data to trigger "NOT NULL" constraint + var data *json.RawMessage + if len(content) > 0 { + data = &content + } + _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data) return } diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go new file mode 100644 index 00000000..c0ecbd30 --- /dev/null +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -0,0 +1,102 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningKeysSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( + user_id TEXT NOT NULL, + key_type SMALLINT NOT NULL, + key_data TEXT NOT NULL, + PRIMARY KEY (user_id, key_type) +); +` + +const selectCrossSigningKeysForUserSQL = "" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1" + +const upsertCrossSigningKeysForUserSQL = "" + + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + + " VALUES($1, $2, $3)" + + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" + +type crossSigningKeysStatements struct { + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt +} + +func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { + s := &crossSigningKeysStatements{ + db: db, + } + _, err := db.Exec(crossSigningKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + }.Prepare(db) +} + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, +) (r types.CrossSigningKeyMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData gomatrixserverlib.Base64Bytes + if err := rows.Scan(&keyTypeInt, &keyData); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = keyData + } + return +} + +func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, +) error { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return fmt.Errorf("unknown key purpose %q", keyType) + } + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { + return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go new file mode 100644 index 00000000..b0117145 --- /dev/null +++ b/userapi/storage/postgres/cross_signing_sigs_table.go @@ -0,0 +1,131 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningSigsSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) +); + +CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); +` + +const selectCrossSigningSigsForTargetSQL = "" + + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3" + +const upsertCrossSigningSigsForTargetSQL = "" + + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + + " VALUES($1, $2, $3, $4, $5)" + + " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5" + +const deleteCrossSigningSigsForTargetSQL = "" + + "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" + +type crossSigningSigsStatements struct { + db *sql.DB + selectCrossSigningSigsForTargetStmt *sql.Stmt + upsertCrossSigningSigsForTargetStmt *sql.Stmt + deleteCrossSigningSigsForTargetStmt *sql.Stmt +} + +func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { + s := &crossSigningSigsStatements{ + db: db, + } + _, err := db.Exec(crossSigningSigsSchema) + if err != nil { + return nil, err + } + + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "keyserver: cross signing signature indexes", + Up: deltas.UpFixCrossSigningSignatureIndexes, + }) + if err = m.Up(context.Background()); err != nil { + return nil, err + } + + return s, sqlutil.StatementList{ + {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, + {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, + }.Prepare(db) +} + +func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) (r types.CrossSigningSigMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") + r = types.CrossSigningSigMap{} + for rows.Next() { + var userID string + var keyID gomatrixserverlib.KeyID + var signature gomatrixserverlib.Base64Bytes + if err := rows.Scan(&userID, &keyID, &signature); err != nil { + return nil, err + } + if _, ok := r[userID]; !ok { + r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + r[userID][keyID] = signature + } + return +} + +func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} + +func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) error { + if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { + return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/deltas/2022012016470000_key_changes.go b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go new file mode 100644 index 00000000..0cfe9e79 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go @@ -0,0 +1,69 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + // start counting from the last max offset, else 0. We need to do a count(*) first to see if there + // even are entries in this table to know if we can query for log_offset. Without the count then + // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't + // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/ + var count int + _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count) + if count > 0 { + var maxOffset int64 + _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset) + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil { + return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err) + } + } + + _, err := tx.ExecContext(ctx, ` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP SEQUENCE IF EXISTS keyserver_key_changes_seq; + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + log_offset BIGINT NOT NULL, + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go new file mode 100644 index 00000000..1a3d4fee --- /dev/null +++ b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; + ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id); + + CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; + ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id); + + DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx; + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go new file mode 100644 index 00000000..a9203857 --- /dev/null +++ b/userapi/storage/postgres/device_keys_table.go @@ -0,0 +1,213 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var deviceKeysSchema = ` +-- Stores device keys for users +CREATE TABLE IF NOT EXISTS keyserver_device_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- the stream ID of this key, scoped per-user. This gets updated when the device key changes. + -- This means we do not store an unbounded append-only log of device keys, which is not actually + -- required in the spec because in the event of a missed update the server fetches the entire + -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. + stream_id BIGINT NOT NULL, + display_name TEXT, + -- Clobber based on tuple of user/device. + CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) +); +` + +const upsertDeviceKeysSQL = "" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" + +const selectDeviceKeysSQL = "" + + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const selectBatchDeviceKeysSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" + +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" + +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" + +const deleteDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + +type deviceKeysStatements struct { + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + countStreamIDsForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt +} + +func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { + s := &deviceKeysStatements{ + db: db, + } + _, err := db.Exec(deviceKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) +} + +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + for i, key := range keys { + var keyJSONStr string + var streamID int64 + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + if err != nil && err != sql.ErrNoRows { + return err + } + // this will be '' when there is no device + keys[i].Type = api.TypeDeviceKeyUpdate + keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } + } + return nil +} + +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { + // nullable if there are no results + var nullStream sql.NullInt64 + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int64 + } + return +} + +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + // nullable if there are no results + var count sql.NullInt32 + err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int32), nil + } + return 0, nil +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ) + if err != nil { + return err + } + } + return nil +} + +func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + deviceIDMap := make(map[string]bool) + for _, d := range deviceIDs { + deviceIDMap[d] = true + } + var result []api.DeviceMessage + var displayName sql.NullString + for rows.Next() { + dk := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + UserID: userID, + }, + } + if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dk.DisplayName = displayName.String + } + // include the key if we want all keys (no device) or it was asked + if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { + result = append(result, dk) + } + } + return result, rows.Err() +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 7481ac5b..88f8839c 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice( if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { return nil, fmt.Errorf("insertDeviceStmt: %w", err) } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index 7b58f7ba..91a34c35 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/userapi/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go new file mode 100644 index 00000000..a0049414 --- /dev/null +++ b/userapi/storage/postgres/key_changes_table.go @@ -0,0 +1,127 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq; +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) +); +` + +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. +const upsertKeyChangeSQL = "" + + "INSERT INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" + + " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" + + " RETURNING change_id" + +const selectKeyChangesSQL = "" + + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return s, err + } + + if err = executeMigration(context.Background(), db); err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) +} + +func executeMigration(ctx context.Context, db *sql.DB) error { + // TODO: Remove when we are sure we are not having goose artefacts in the db + // This forces an error, which indicates the migration is already applied, since the + // column partition was removed from the table + migrationName := "keyserver: refactor key changes" + + var cName string + err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRefactorKeyChanges, + }) + + return m.Up(ctx) +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, fromOffset, toOffset int64, +) (userIDs []string, latestOffset int64, err error) { + latestOffset = fromOffset + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/userapi/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go new file mode 100644 index 00000000..972a5914 --- /dev/null +++ b/userapi/storage/postgres/one_time_keys_table.go @@ -0,0 +1,194 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var oneTimeKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 4-uple of user/device/key/algorithm. + CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); +` + +const upsertKeysSQL = "" + + "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" + + " DO UPDATE SET key_json = $6" + +const selectOneTimeKeysSQL = "" + + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" + +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM " + + " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + + " x GROUP BY algorithm" + +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + +type oneTimeKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt +} + +func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { + s := &oneTimeKeysStatements{ + db: db, + } + _, err := db.Exec(oneTimeKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) +} + +func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") + + result := make(map[string]json.RawMessage) + var ( + algorithmWithID string + keyJSONStr string + ) + for rows.Next() { + if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil { + return nil, err + } + result[algorithmWithID] = json.RawMessage(keyJSONStr) + } + return result, rows.Err() +} + +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + +func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { + now := time.Now().Unix() + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + + return counts, rows.Err() +} + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/userapi/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go new file mode 100644 index 00000000..c823b58c --- /dev/null +++ b/userapi/storage/postgres/stale_device_lists.go @@ -0,0 +1,131 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" + +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt + deleteStaleDeviceListsStmt *sql.Stmt +} + +func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, + }.Prepare(db) +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) + _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) + return err +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 92dc4808..673d123b 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) + if err != nil { + return nil, err + } + otk, err := NewPostgresOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewPostgresDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewPostgresKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewPostgresCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewPostgresCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index bf94f14d..d3272a03 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -59,6 +59,17 @@ type Database struct { OpenIDTokenLifetimeMS int64 } +type KeyDatabase struct { + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists + CrossSigningKeysTable tables.CrossSigningKeys + CrossSigningSigsTable tables.CrossSigningSigs + DB *sql.DB + Writer sqlutil.Writer +} + const ( // The length of generated device IDs deviceIDByteLength = 6 @@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages( ) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { return d.Stats.DailyRoomsMessages(ctx, nil, serverName) } + +// + +func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) +} + +func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) + return err + }) + return +} + +func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) +} + +func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) +} + +func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) + if err != nil { + return false, err + } + return count == len(prevIDs), nil +} + +func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, userID := range clearUserIDs { + err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) + if err != nil { + return err + } + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int64) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = streamID + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) +} + +func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { + var result []api.OneTimeKeys + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID, deviceToAlgo := range userToDeviceToAlgorithm { + for deviceID, algo := range deviceToAlgo { + keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) + if err != nil { + return err + } + if keyJSON != nil { + result = append(result, api.OneTimeKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: keyJSON, + }) + } + } + } + return nil + }) + return result, err +} + +func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) + return err + }) + return +} + +func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) + }) +} + +// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying +// cross-signing signatures relating to that device. +func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, deviceID := range deviceIDs { + if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) + } + if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) + } + if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) + } + } + return nil + }) +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { + keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) + } + results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + for purpose, key := range keyMap { + keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) + result := gomatrixserverlib.CrossSigningKey{ + UserID: userID, + Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + keyID: key, + }, + } + sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) + if err != nil { + continue + } + for sigUserID, forSigUserID := range sigMap { + if userID != sigUserID { + continue + } + if result.Signatures == nil { + result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + if _, ok := result.Signatures[sigUserID]; !ok { + result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for sigKeyID, sigBytes := range forSigUserID { + result.Signatures[sigUserID][sigKeyID] = sigBytes + } + } + results[purpose] = result + } + return results, nil +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { + return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) +} + +// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. +func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { + return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) +} + +// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. +func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for keyType, keyData := range keyMap { + if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { + return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) + } + } + return nil + }) +} + +// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. +func (d *KeyDatabase) StoreCrossSigningSigsForTarget( + ctx context.Context, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) + } + return nil + }) +} + +// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. +func (d *KeyDatabase) DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) + }) +} diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go new file mode 100644 index 00000000..10721fcc --- /dev/null +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -0,0 +1,101 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningKeysSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( + user_id TEXT NOT NULL, + key_type INTEGER NOT NULL, + key_data TEXT NOT NULL, + PRIMARY KEY (user_id, key_type) +); +` + +const selectCrossSigningKeysForUserSQL = "" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1" + +const upsertCrossSigningKeysForUserSQL = "" + + "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + + " VALUES($1, $2, $3)" + +type crossSigningKeysStatements struct { + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt +} + +func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { + s := &crossSigningKeysStatements{ + db: db, + } + _, err := db.Exec(crossSigningKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + }.Prepare(db) +} + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, +) (r types.CrossSigningKeyMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData gomatrixserverlib.Base64Bytes + if err := rows.Scan(&keyTypeInt, &keyData); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = keyData + } + return +} + +func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, +) error { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return fmt.Errorf("unknown key purpose %q", keyType) + } + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { + return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go new file mode 100644 index 00000000..2be00c9c --- /dev/null +++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go @@ -0,0 +1,129 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningSigsSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) +); + +CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); +` + +const selectCrossSigningSigsForTargetSQL = "" + + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4" + +const upsertCrossSigningSigsForTargetSQL = "" + + "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + + " VALUES($1, $2, $3, $4, $5)" + +const deleteCrossSigningSigsForTargetSQL = "" + + "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" + +type crossSigningSigsStatements struct { + db *sql.DB + selectCrossSigningSigsForTargetStmt *sql.Stmt + upsertCrossSigningSigsForTargetStmt *sql.Stmt + deleteCrossSigningSigsForTargetStmt *sql.Stmt +} + +func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { + s := &crossSigningSigsStatements{ + db: db, + } + _, err := db.Exec(crossSigningSigsSchema) + if err != nil { + return nil, err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "keyserver: cross signing signature indexes", + Up: deltas.UpFixCrossSigningSignatureIndexes, + }) + if err = m.Up(context.Background()); err != nil { + return nil, err + } + + return s, sqlutil.StatementList{ + {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, + {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, + }.Prepare(db) +} + +func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) (r types.CrossSigningSigMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed") + r = types.CrossSigningSigMap{} + for rows.Next() { + var userID string + var keyID gomatrixserverlib.KeyID + var signature gomatrixserverlib.Base64Bytes + if err := rows.Scan(&userID, &keyID, &signature); err != nil { + return nil, err + } + if _, ok := r[userID]; !ok { + r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + r[userID][keyID] = signature + } + return +} + +func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} + +func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) error { + if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { + return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go new file mode 100644 index 00000000..cd0f19df --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go @@ -0,0 +1,66 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + // start counting from the last max offset, else 0. + var maxOffset int64 + var userID string + _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset) + + _, err := tx.ExecContext(ctx, ` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + // to start counting from maxOffset, insert a row with that value + if userID != "" { + _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID) + return err + } + return nil +} + +func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + offset BIGINT NOT NULL, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (partition, offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go new file mode 100644 index 00000000..d4e38dea --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go @@ -0,0 +1,71 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) + ); + + INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) + SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; + + DROP TABLE keyserver_cross_signing_sigs; + ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; + + CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, target_user_id, target_key_id) + ); + + INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) + SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; + + DROP TABLE keyserver_cross_signing_sigs; + ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; + + DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx; + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/device_keys_table.go b/userapi/storage/sqlite3/device_keys_table.go new file mode 100644 index 00000000..15e69cc4 --- /dev/null +++ b/userapi/storage/sqlite3/device_keys_table.go @@ -0,0 +1,213 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var deviceKeysSchema = ` +-- Stores device keys for users +CREATE TABLE IF NOT EXISTS keyserver_device_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + stream_id BIGINT NOT NULL, + display_name TEXT, + -- Clobber based on tuple of user/device. + UNIQUE (user_id, device_id) +); +` + +const upsertDeviceKeysSQL = "" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (user_id, device_id)" + + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" + +const selectDeviceKeysSQL = "" + + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const selectBatchDeviceKeysSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" + +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" + +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" + +const deleteDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + +type deviceKeysStatements struct { + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt +} + +func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { + s := &deviceKeysStatements{ + db: db, + } + _, err := db.Exec(deviceKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + // {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) +} + +func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + deviceIDMap := make(map[string]bool) + for _, d := range deviceIDs { + deviceIDMap[d] = true + } + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + var result []api.DeviceMessage + var displayName sql.NullString + for rows.Next() { + dk := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + UserID: userID, + }, + } + if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dk.DisplayName = displayName.String + } + // include the key if we want all keys (no device) or it was asked + if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { + result = append(result, dk) + } + } + return result, rows.Err() +} + +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + for i, key := range keys { + var keyJSONStr string + var streamID int64 + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + if err != nil && err != sql.ErrNoRows { + return err + } + // this will be '' when there is no device + keys[i].Type = api.TypeDeviceKeyUpdate + keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } + } + return nil +} + +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { + // nullable if there are no results + var nullStream sql.NullInt64 + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int64 + } + return +} + +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + iStreamIDs := make([]interface{}, len(streamIDs)+1) + iStreamIDs[0] = userID + for i := range streamIDs { + iStreamIDs[i+1] = streamIDs[i] + } + query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) + // nullable if there are no results + var count sql.NullInt64 + err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int64), nil + } + return 0, nil +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ) + if err != nil { + return err + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 449e4549..65e17527 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice( if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, @@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) DeleteDevice( @@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices( if err != nil { return err } + defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed") stmt := sqlutil.TxStmt(txn, prep) params := make([]interface{}, len(devices)+2) params[0] = localpart diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 7883ffb1..ed274631 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/userapi/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go new file mode 100644 index 00000000..923bb57e --- /dev/null +++ b/userapi/storage/sqlite3/key_changes_table.go @@ -0,0 +1,125 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (user_id) +); +` + +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. +const upsertKeyChangeSQL = "" + + "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " RETURNING change_id" + +const selectKeyChangesSQL = "" + + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return s, err + } + + if err = executeMigration(context.Background(), db); err != nil { + return nil, err + } + + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) +} + +func executeMigration(ctx context.Context, db *sql.DB) error { + // TODO: Remove when we are sure we are not having goose artefacts in the db + // This forces an error, which indicates the migration is already applied, since the + // column partition was removed from the table + migrationName := "keyserver: refactor key changes" + + var cName string + err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRefactorKeyChanges, + }) + return m.Up(ctx) +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, fromOffset, toOffset int64, +) (userIDs []string, latestOffset int64, err error) { + latestOffset = fromOffset + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/userapi/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go new file mode 100644 index 00000000..a992d399 --- /dev/null +++ b/userapi/storage/sqlite3/one_time_keys_table.go @@ -0,0 +1,208 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var oneTimeKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 4-uple of user/device/key/algorithm. + UNIQUE (user_id, device_id, key_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); +` + +const upsertKeysSQL = "" + + "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (user_id, device_id, key_id, algorithm)" + + " DO UPDATE SET key_json = $6" + +const selectOneTimeKeysSQL = "" + + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" + +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM " + + " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + + " x GROUP BY algorithm" + +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + +type oneTimeKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt +} + +func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { + s := &oneTimeKeysStatements{ + db: db, + } + _, err := db.Exec(oneTimeKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) +} + +func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") + + wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) + for _, ka := range keyIDsWithAlgorithms { + wantSet[ka] = true + } + + result := make(map[string]json.RawMessage) + for rows.Next() { + var keyID string + var algorithm string + var keyJSONStr string + if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { + return nil, err + } + keyIDWithAlgo := algorithm + ":" + keyID + if wantSet[keyIDWithAlgo] { + result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) + } + } + return result, rows.Err() +} + +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + +func (s *oneTimeKeysStatements) InsertOneTimeKeys( + ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys, +) (*api.OneTimeKeysCount, error) { + now := time.Now().Unix() + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + + return counts, rows.Err() +} + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + if err != nil { + return nil, err + } + if keyJSON == "" { + return nil, nil + } + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/userapi/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go new file mode 100644 index 00000000..f078fc99 --- /dev/null +++ b/userapi/storage/sqlite3/stale_device_lists.go @@ -0,0 +1,145 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" + +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" + +type staleDeviceListsStatements struct { + db *sql.DB + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt + // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime +} + +func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{ + db: db, + } + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime + }.Prepare(db) +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") + stmt = sqlutil.TxStmt(txn, stmt) + + params := make([]any, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index a1365c94..72b3ba49 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, 4, @@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, @@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) registeredAfter := time.Now().AddDate(0, 0, -30) diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 85a1f706..0f3eeed1 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -30,8 +30,8 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) -// NewDatabase creates a new accounts and profiles database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { +// NewUserDatabase creates a new accounts and profiles database +func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err @@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) + if err != nil { + return nil, err + } + otk, err := NewSqliteOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewSqliteDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewSqliteKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewSqliteCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewSqliteCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index 42221e75..0329fb46 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -29,15 +29,36 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) -// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { +func NewUserDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + serverName gomatrixserverlib.ServerName, + bcryptCost int, + openIDTokenLifetimeMS int64, + loginTokenLifetime time.Duration, + serverNoticesLocalpart string, +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) default: return nil, fmt.Errorf("unexpected database type") } } + +// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme) +// and sets postgres connection parameters. +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewKeyDatabase(base, dbProperties) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewKeyDatabase(base, dbProperties) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 23aafff0..f52e7e17 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -4,9 +4,12 @@ import ( "context" "encoding/json" "fmt" + "reflect" + "sync" "testing" "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" @@ -29,14 +32,14 @@ var ( ctx = context.Background() ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { - t.Fatalf("NewUserAPIDatabase returned %s", err) + t.Fatalf("NewUserDatabase returned %s", err) } return db, func() { close() @@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun // Tests storing and getting account data func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) { // Tests the creation of accounts func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) { accessToken := util.RandomString(16) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") @@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) { room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() wantAuthData := json.RawMessage("my auth data") @@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) { func Test_LoginToken(t *testing.T) { alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create a new token @@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) { token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS @@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create account, which also creates a profile @@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() appID := util.RandomString(8) @@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() threePID := util.RandomString(8) medium := util.RandomString(8) @@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) { room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // generate some dummy notifications for i := 0; i < 10; i++ { @@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) { assert.Equal(t, int64(0), total) }) } + +func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + t.Fatalf("failed to create new database: %v", err) + } + return db, close +} + +func MustNotError(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("operation failed: %s", err) +} + +func TestKeyChanges(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + _, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDC { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) + } + if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesNoDupes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + if deviceChangeIDA == deviceChangeIDB { + t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) + } + deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeID { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) + } + if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesUpperLimit(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + _, err = db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDB { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) + } + if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +var dbLock sync.Mutex +var deviceArray = []string{"AAA", "another_device"} + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + var err error + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + dbLock.Lock() + defer dbLock.Unlock() + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) + + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int64{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } + }) +} diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 5d5d292e..163e3e17 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -32,10 +32,10 @@ func NewUserAPIDatabase( openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string, -) (Database, error) { +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 9221e571..693e7303 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -20,10 +20,10 @@ import ( "encoding/json" "time" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/types" ) @@ -145,3 +145,47 @@ const ( // uint32. AllNotifications NotificationFilter = (1 << 31) - 1 ) + +type OneTimeKeys interface { + SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. + // Returns an empty map if the key does not exist. + SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error +} + +type DeviceKeys interface { + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) + CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error + DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error +} + +type KeyChanges interface { + InsertKeyChange(ctx context.Context, userID string) (int64, error) + // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. + // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. + SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) +} + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error +} + +type CrossSigningKeys interface { + SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error +} + +type CrossSigningSigs interface { + SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) + UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error +} diff --git a/userapi/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go new file mode 100644 index 00000000..b9bdafda --- /dev/null +++ b/userapi/storage/tables/stale_device_lists_test.go @@ -0,0 +1,94 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, nil) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresStaleDeviceListsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) + } + if err != nil { + t.Fatalf("failed to create new table: %s", err) + } + return tab, close +} + +func TestStaleDeviceLists(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := "@charlie:localhost" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateTable(t, dbType) + defer closeDB() + + if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + + // Query one server + wantStaleUsers := []string{alice.ID, bob.ID} + gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Query all servers + wantStaleUsers = []string{alice.ID, bob.ID, charlie} + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Delete stale devices + deleteUsers := []string{alice.ID, bob.ID} + if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil { + t.Fatalf("failed to delete stale device lists: %s", err) + } + + // Verify we don't get anything back after deleting + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + + if gotCount := len(gotStaleUsers); gotCount > 0 { + t.Fatalf("expected no stale users, got %d", gotCount) + } + }) +} diff --git a/userapi/types/storage.go b/userapi/types/storage.go new file mode 100644 index 00000000..7fb90454 --- /dev/null +++ b/userapi/types/storage.go @@ -0,0 +1,50 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "math" + + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + // OffsetNewest tells e.g. the database to get the most current data + OffsetNewest int64 = math.MaxInt64 + // OffsetOldest tells e.g. the database to get the oldest data + OffsetOldest int64 = 0 +) + +// KeyTypePurposeToInt maps a purpose to an integer, which is used in the +// database to reduce the amount of space taken up by this column. +var KeyTypePurposeToInt = map[gomatrixserverlib.CrossSigningKeyPurpose]int16{ + gomatrixserverlib.CrossSigningKeyPurposeMaster: 1, + gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: 2, + gomatrixserverlib.CrossSigningKeyPurposeUserSigning: 3, +} + +// KeyTypeIntToPurpose maps an integer to a purpose, which is used in the +// database to reduce the amount of space taken up by this column. +var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{ + 1: gomatrixserverlib.CrossSigningKeyPurposeMaster, + 2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, + 3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning, +} + +// Map of purpose -> public key +type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes + +// Map of user ID -> key ID -> signature +type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes diff --git a/userapi/userapi.go b/userapi/userapi.go index 2dd81d75..826bd721 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -17,13 +17,11 @@ package userapi import ( "time" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/internal/pushgateway" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/consumers" @@ -33,16 +31,20 @@ import ( "github.com/matrix-org/dendrite/userapi/util" ) -// NewInternalAPI returns a concerete implementation of the internal API. Callers +// NewInternalAPI returns a concrete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, cfg *config.UserAPI, - appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI, - rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client, -) api.UserInternalAPI { + base *base.BaseDendrite, + rsAPI rsapi.UserRoomserverAPI, + fedClient fedsenderapi.KeyserverFederationAPI, +) *internal.UserInternalAPI { + cfg := &base.Cfg.UserAPI js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + appServices := base.Cfg.Derived.ApplicationServices - db, err := storage.NewUserAPIDatabase( + pgClient := base.PushGatewayHTTPClient() + + db, err := storage.NewUserDatabase( base, &cfg.AccountDatabase, cfg.Matrix.ServerName, @@ -55,6 +57,11 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to accounts db") } + keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to key db") + } + syncProducer := producers.NewSyncAPI( db, js, // TODO: user API should handle syncs for account data. Right now, @@ -64,17 +71,50 @@ func NewInternalAPI( cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), cfg.Matrix.JetStream.Prefixed(jetstream.OutputNotificationData), ) + keyChangeProducer := &producers.KeyChange{ + Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + JetStream: js, + DB: keyDB, + } userAPI := &internal.UserInternalAPI{ DB: db, + KeyDatabase: keyDB, SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, Config: cfg, AppServices: appServices, - KeyAPI: keyAPI, RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, PgClient: pgClient, - Cfg: cfg, + FedClient: fedClient, + } + + updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable + userAPI.Updater = updater + // Remove users which we don't share a room with anymore + if err := updater.CleanUp(); err != nil { + logrus.WithError(err).Error("failed to cleanup stale device lists") + } + + go func() { + if err := updater.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start device list updater") + } + }() + + dlConsumer := consumers.NewDeviceListUpdateConsumer( + base.ProcessContext, cfg, js, updater, + ) + if err := dlConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start device list consumer") + } + + sigConsumer := consumers.NewSigningKeyUpdateConsumer( + base.ProcessContext, cfg, js, userAPI, + ) + if err := sigConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start signing key consumer") } receiptConsumer := consumers.NewOutputReceiptEventConsumer( diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 68d08c2f..08b1336b 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -21,7 +21,10 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/nats-io/nats.go" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/setup/config" @@ -38,32 +41,55 @@ const ( type apiTestOpts struct { loginTokenLifetime time.Duration + serverName string } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) { +type dummyProducer struct{} + +func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) { + return &nats.PubAck{}, nil +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) { if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + sName := serverName + if opts.serverName != "" { + sName = gomatrixserverlib.ServerName(opts.serverName) + } + accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") + }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } + keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to create key DB: %s", err) + } + cfg := &config.UserAPI{ Matrix: &config.Global{ SigningIdentity: gomatrixserverlib.SigningIdentity{ - ServerName: serverName, + ServerName: sName, }, }, } + syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "") + keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}} return &internal.UserInternalAPI{ - DB: accountDB, - Config: cfg, + DB: accountDB, + KeyDatabase: keyDB, + Config: cfg, + SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, }, accountDB, func() { close() baseclose() @@ -332,3 +358,292 @@ func TestQueryAccountByLocalpart(t *testing.T) { testCases(t, intAPI) }) } + +func TestAccountData(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + + testCases := []struct { + name string + inputData *api.InputAccountDataRequest + wantErr bool + }{ + { + name: "not a local user", + inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"}, + wantErr: true, + }, + { + name: "local user missing datatype", + inputData: &api.InputAccountDataRequest{UserID: alice.ID}, + wantErr: true, + }, + { + name: "missing json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil}, + wantErr: true, + }, + { + name: "with json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")}, + }, + { + name: "room data", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"}, + }, + { + name: "ignored users", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")}, + }, + { + name: "m.fully_read", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := api.InputAccountDataResponse{} + err := intAPI.InputAccountData(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + + // query the data again and compare + queryRes := api.QueryAccountDataResponse{} + queryReq := api.QueryAccountDataRequest{ + UserID: tc.inputData.UserID, + DataType: tc.inputData.DataType, + RoomID: tc.inputData.RoomID, + } + err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes) + if err != nil && !tc.wantErr { + t.Fatal(err) + } + // verify global data + if tc.inputData.RoomID == "" { + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType])) + } + } else { + // verify room data + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType])) + } + } + }) + } + }) +} + +func TestDevices(t *testing.T) { + ctx := context.Background() + + dupeAccessToken := util.RandomString(8) + + displayName := "testing" + + creationTests := []struct { + name string + inputData *api.PerformDeviceCreationRequest + wantErr bool + wantNewDevID bool + }{ + { + name: "not a local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"}, + wantErr: true, + }, + { + name: "implicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName}, + }, + { + name: "explicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "dupe token - ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + }, + { + name: "dupe token - not ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + wantErr: true, + }, + { + name: "test3 second device", // used to test deletion later + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "test3 third device", // used to test deletion later + wantNewDevID: true, + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + } + + deletionTests := []struct { + name string + inputData *api.PerformDeviceDeletionRequest + wantErr bool + wantDevices int + }{ + { + name: "deletion - not a local user", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"}, + wantErr: true, + }, + { + name: "deleting not existing devices should not error", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}}, + wantDevices: 1, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test"}, + wantDevices: 0, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test3:test"}, + wantDevices: 0, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + for _, tc := range creationTests { + t.Run(tc.name, func(t *testing.T) { + res := api.PerformDeviceCreationResponse{} + deviceID := util.RandomString(8) + tc.inputData.DeviceID = &deviceID + if tc.wantNewDevID { + tc.inputData.DeviceID = nil + } + err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if !res.DeviceCreated { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + // We only want to verify one device + if len(queryDevicesRes.Devices) > 1 { + return + } + res.Device.AccessToken = "" + + // At this point, there should only be one device + if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) { + t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0]) + } + + newDisplayName := "new name" + if tc.inputData.DeviceDisplayName == nil { + updateRes := api.PerformDeviceUpdateResponse{} + updateReq := api.PerformDeviceUpdateRequest{ + RequestingUserID: fmt.Sprintf("@%s:%s", tc.inputData.Localpart, "test"), + DeviceID: deviceID, + DisplayName: &newDisplayName, + } + + if err = intAPI.PerformDeviceUpdate(ctx, &updateReq, &updateRes); err != nil { + t.Fatal(err) + } + } + + queryDeviceInfosRes := api.QueryDeviceInfosResponse{} + queryDeviceInfosReq := api.QueryDeviceInfosRequest{DeviceIDs: []string{*tc.inputData.DeviceID}} + if err = intAPI.QueryDeviceInfos(ctx, &queryDeviceInfosReq, &queryDeviceInfosRes); err != nil { + t.Fatal(err) + } + gotDisplayName := queryDeviceInfosRes.DeviceInfo[*tc.inputData.DeviceID].DisplayName + if tc.inputData.DeviceDisplayName != nil { + wantDisplayName := *tc.inputData.DeviceDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } else { + wantDisplayName := newDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } + }) + } + + for _, tc := range deletionTests { + t.Run(tc.name, func(t *testing.T) { + delRes := api.PerformDeviceDeletionResponse{} + err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if tc.wantErr { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + + if len(queryDevicesRes.Devices) != tc.wantDevices { + t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices)) + } + + }) + } + }) +} + +// Tests that the session ID of a device is not reused when reusing the same device ID. +func TestDeviceIDReuse(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + res := api.PerformDeviceCreationResponse{} + // create a first device + deviceID := util.RandomString(8) + req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true} + err := intAPI.PerformDeviceCreation(ctx, &req, &res) + if err != nil { + t.Fatal(err) + } + + // Do the same request again, we expect a different sessionID + res2 := api.PerformDeviceCreationResponse{} + err = intAPI.PerformDeviceCreation(ctx, &req, &res2) + if err != nil { + t.Fatalf("expected no error, but got: %v", err) + } + + if res2.Device.SessionID == res.Device.SessionID { + t.Fatalf("expected a different session ID, but they are the same") + } + }) +} diff --git a/userapi/util/devices.go b/userapi/util/devices.go index c55fc799..31617d8c 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -19,7 +19,7 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { +func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { return nil, fmt.Errorf("db.GetPushers: %w", err) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index fc0ab39b..08d1371d 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -13,11 +13,11 @@ import ( ) // NotifyUserCountsAsync sends notifications to a local user's -// notification destinations. Database lookups run synchronously, but +// notification destinations. UserDatabase lookups run synchronously, but // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error { +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error { pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index f1d20259..421852d3 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -79,7 +79,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { defer close() base, _, _ := testrig.Base(nil) defer base.Close() - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "test", bcrypt.MinCost, 0, 0, "") if err != nil { diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go index 6e62210e..5f626b5b 100644 --- a/userapi/util/phonehomestats_test.go +++ b/userapi/util/phonehomestats_test.go @@ -21,7 +21,7 @@ func TestCollect(t *testing.T) { b, _, _ := testrig.Base(nil) connStr, closeDB := test.PrepareDBConnectionString(t, dbType) defer closeDB() - db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, 1000, 1000, "") if err != nil { -- cgit v1.2.3