aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/api/api.go14
-rw-r--r--keyserver/internal/internal.go11
-rw-r--r--keyserver/inthttp/client.go18
-rw-r--r--keyserver/inthttp/server.go11
-rw-r--r--keyserver/storage/interface.go3
-rw-r--r--keyserver/storage/postgres/one_time_keys_table.go22
-rw-r--r--keyserver/storage/shared/storage.go4
-rw-r--r--keyserver/storage/sqlite3/one_time_keys_table.go22
-rw-r--r--keyserver/storage/tables/interface.go1
9 files changed, 106 insertions, 0 deletions
diff --git a/keyserver/api/api.go b/keyserver/api/api.go
index 6795498f..eb2f9e24 100644
--- a/keyserver/api/api.go
+++ b/keyserver/api/api.go
@@ -31,6 +31,7 @@ type KeyInternalAPI interface {
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
+ QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse)
}
// KeyError is returned if there was a problem performing/querying the server
@@ -157,3 +158,16 @@ type QueryKeyChangesResponse struct {
// Set if there was a problem handling the request.
Error *KeyError
}
+
+type QueryOneTimeKeysRequest struct {
+ // The local user to query OTK counts for
+ UserID string
+ // The device to query OTK counts for
+ DeviceID string
+}
+
+type QueryOneTimeKeysResponse struct {
+ // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
+ Count OneTimeKeysCount
+ Error *KeyError
+}
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index bb828663..3c8dff84 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -168,6 +168,17 @@ func (a *KeyInternalAPI) claimRemoteKeys(
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
}
+func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) {
+ count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
+ }
+ return
+ }
+ res.Count = *count
+}
+
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go
index 3f9690b5..b65cbdaf 100644
--- a/keyserver/inthttp/client.go
+++ b/keyserver/inthttp/client.go
@@ -31,6 +31,7 @@ const (
PerformClaimKeysPath = "/keyserver/performClaimKeys"
QueryKeysPath = "/keyserver/queryKeys"
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
+ QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
)
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
@@ -108,6 +109,23 @@ func (h *httpKeyInternalAPI) QueryKeys(
}
}
+func (h *httpKeyInternalAPI) QueryOneTimeKeys(
+ ctx context.Context,
+ request *api.QueryOneTimeKeysRequest,
+ response *api.QueryOneTimeKeysResponse,
+) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys")
+ defer span.Finish()
+
+ apiURL := h.apiURL + QueryOneTimeKeysPath
+ err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ if err != nil {
+ response.Error = &api.KeyError{
+ Err: err.Error(),
+ }
+ }
+}
+
func (h *httpKeyInternalAPI) QueryKeyChanges(
ctx context.Context,
request *api.QueryKeyChangesRequest,
diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go
index f3d2882c..615b6f80 100644
--- a/keyserver/inthttp/server.go
+++ b/keyserver/inthttp/server.go
@@ -58,6 +58,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
+ internalAPIMux.Handle(QueryOneTimeKeysPath,
+ httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse {
+ request := api.QueryOneTimeKeysRequest{}
+ response := api.QueryOneTimeKeysResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ s.QueryOneTimeKeys(req.Context(), &request, &response)
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
internalAPIMux.Handle(QueryKeyChangesPath,
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
request := api.QueryKeyChangesRequest{}
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index fade7522..0e0158e5 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -29,6 +29,9 @@ type Database interface {
// StoreOneTimeKeys persists the given one-time keys.
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
+ // OneTimeKeysCount returns a count of all OTKs for this device.
+ OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
+
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go
index a9d05548..df215d5a 100644
--- a/keyserver/storage/postgres/one_time_keys_table.go
+++ b/keyserver/storage/postgres/one_time_keys_table.go
@@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
return result, rows.Err()
}
+func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ counts := &api.OneTimeKeysCount{
+ DeviceID: deviceID,
+ UserID: userID,
+ KeyCount: make(map[string]int),
+ }
+ rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+ return counts, nil
+}
+
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeKeysCount{
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 8c2534f5..44cb0cc2 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -39,6 +39,10 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
}
+func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
+}
+
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go
index fecf533e..b35407cd 100644
--- a/keyserver/storage/sqlite3/one_time_keys_table.go
+++ b/keyserver/storage/sqlite3/one_time_keys_table.go
@@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
return result, rows.Err()
}
+func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ counts := &api.OneTimeKeysCount{
+ DeviceID: deviceID,
+ UserID: userID,
+ KeyCount: make(map[string]int),
+ }
+ rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+ return counts, nil
+}
+
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeKeysCount{
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index 8b89283f..c6e43be4 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -24,6 +24,7 @@ import (
type OneTimeKeys interface {
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
+ CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
// Returns an empty map if the key does not exist.