diff options
author | Kegsay <kegan@matrix.org> | 2020-07-22 17:04:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-22 17:04:57 +0100 |
commit | 541a23f712a1cff2e8ba0ada41ceff90e86ee42d (patch) | |
tree | 1fdacfd2b7479dd6a075bb5680b372f9b7939b59 /userapi | |
parent | 1e71fd645ed9bbac87627434b303659a195512c7 (diff) |
Handle inbound federation E2E key queries/claims (#1215)
* Handle inbound /keys/claim and /keys/query requests
* Add display names to device key responses
* Linting
Diffstat (limited to 'userapi')
-rw-r--r-- | userapi/api/api.go | 14 | ||||
-rw-r--r-- | userapi/internal/api.go | 21 | ||||
-rw-r--r-- | userapi/inthttp/client.go | 13 | ||||
-rw-r--r-- | userapi/inthttp/server.go | 14 | ||||
-rw-r--r-- | userapi/storage/devices/interface.go | 1 | ||||
-rw-r--r-- | userapi/storage/devices/postgres/devices_table.go | 36 | ||||
-rw-r--r-- | userapi/storage/devices/postgres/storage.go | 4 | ||||
-rw-r--r-- | userapi/storage/devices/sqlite3/devices_table.go | 43 | ||||
-rw-r--r-- | userapi/storage/devices/sqlite3/storage.go | 4 |
9 files changed, 148 insertions, 2 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go index cf0f0563..bd0773f8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -30,6 +30,7 @@ type UserInternalAPI interface { QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error } // InputAccountDataRequest is the request for InputAccountData @@ -44,6 +45,19 @@ type InputAccountDataRequest struct { type InputAccountDataResponse struct { } +// QueryDeviceInfosRequest is the request to QueryDeviceInfos +type QueryDeviceInfosRequest struct { + DeviceIDs []string +} + +// QueryDeviceInfosResponse is the response to QueryDeviceInfos +type QueryDeviceInfosResponse struct { + DeviceInfo map[string]struct { + DisplayName string + UserID string + } +} + // QueryAccessTokenRequest is the request for QueryAccessToken type QueryAccessTokenRequest struct { AccessToken string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 1d10d1d8..2de8f960 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -125,6 +125,27 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil return nil } +func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { + devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) + if err != nil { + return err + } + res.DeviceInfo = make(map[string]struct { + DisplayName string + UserID string + }) + for _, d := range devices { + res.DeviceInfo[d.ID] = struct { + DisplayName string + UserID string + }{ + DisplayName: d.DisplayName, + UserID: d.UserID, + } + } + return nil +} + func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 4ab0d690..b2b42823 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -35,6 +35,7 @@ const ( QueryAccessTokenPath = "/userapi/queryAccessToken" QueryDevicesPath = "/userapi/queryDevices" QueryAccountDataPath = "/userapi/queryAccountData" + QueryDeviceInfosPath = "/userapi/queryDeviceInfos" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -101,6 +102,18 @@ func (h *httpUserInternalAPI) QueryProfile( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) QueryDeviceInfos( + ctx context.Context, + request *api.QueryDeviceInfosRequest, + response *api.QueryDeviceInfosResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos") + defer span.Finish() + + apiURL := h.apiURL + QueryDeviceInfosPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) QueryAccessToken( ctx context.Context, request *api.QueryAccessTokenRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 8f3be773..d8e151ad 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/util" ) +// nolint: gocyclo func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { internalAPIMux.Handle(PerformAccountCreationPath, httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { @@ -103,4 +104,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryDeviceInfosPath, + httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse { + request := api.QueryDeviceInfosRequest{} + response := api.QueryDeviceInfosResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 4bdb5785..3c9ec934 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -24,6 +24,7 @@ type Database interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked // and replaced with the given accessToken. If the given accessToken is already in use for another device, diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 1d036d1b..03bf7c72 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" + const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" +const selectDevicesByIDSQL = "" + + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" + type devicesStatements struct { insertDeviceStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt + selectDevicesByIDStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt @@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { return } + if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { + return + } s.serverName = server return } @@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device + var displayName sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) + if displayName.Valid { + dev.DisplayName = displayName.String + } } return &dev, err } +func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") + var devices []api.Device + for rows.Next() { + var dev api.Device + var localpart string + var displayName sql.NullString + if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dev.DisplayName = displayName.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + return devices, rows.Err() +} + func (s *devicesStatements) selectDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 801657bd..6ac802bb 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -71,6 +71,10 @@ func (d *Database) GetDevicesByLocalpart( return d.devices.selectDevicesByLocalpart(ctx, localpart) } +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked // and replaced with the given accessToken. If the given accessToken is already in use for another device, diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index ec52c64b..efe6f927 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -20,6 +20,7 @@ import ( "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" + const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" +const selectDevicesByIDSQL = "" + + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" + type devicesStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -79,6 +83,7 @@ type devicesStatements struct { selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt + selectDevicesByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt @@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { return } + if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { + return + } s.serverName = server return } @@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device + var displayName sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) + if displayName.Valid { + dev.DisplayName = displayName.String + } } return &dev, err } @@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, nil } + +func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) + iDeviceIDs := make([]interface{}, len(deviceIDs)) + for i := range deviceIDs { + iDeviceIDs[i] = deviceIDs[i] + } + + rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") + var devices []api.Device + for rows.Next() { + var dev api.Device + var localpart string + var displayName sql.NullString + if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dev.DisplayName = displayName.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + return devices, rows.Err() +} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index f248abda..b9f08ca1 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -77,6 +77,10 @@ func (d *Database) GetDevicesByLocalpart( return d.devices.selectDevicesByLocalpart(ctx, localpart) } +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked // and replaced with the given accessToken. If the given accessToken is already in use for another device, |