aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-10-19 12:03:12 +0100
committerNeil Alexander <neilalexander@users.noreply.github.com>2022-10-19 12:03:12 +0100
commitc1463db6c9183aa67ef41e7ea85ed36dc5817d18 (patch)
tree00ba61dca01a01ed19f7dfe03977c86583b769e8
parentf3dae0e749ca35b1527fbfcb0371e89d0e9833ab (diff)
Fix concurrent map write in key server
-rw-r--r--keyserver/internal/internal.go17
1 files changed, 11 insertions, 6 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index 06fc4987..d2ea2093 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -250,6 +250,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
// nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
+ var respMu sync.Mutex
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@@ -329,7 +330,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
- domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
+ domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
// perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
@@ -407,7 +408,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(
- ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
+ ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys {
@@ -415,7 +416,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
// we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added.
if len(deviceIDs) > 0 {
- err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
+ err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
if err == nil {
continue
}
@@ -542,7 +543,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
respMu.Lock()
- err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
+ err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
respMu.Unlock()
if err != nil {
logrus.WithFields(logrus.Fields{
@@ -573,7 +574,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// inspecting the failures map though so they can know it's a cached response.
for userID, dkeys := range devKeys {
// drop the error as it's already a failure at this point
- _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
+ _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
}
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
@@ -585,7 +586,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
}
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
- ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
+ ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
@@ -598,9 +599,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
if len(deviceIDs) == 0 && len(keys) == 0 {
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
}
+ respMu.Lock()
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
+ respMu.Unlock()
for _, key := range keys {
if len(key.KeyJSON) == 0 {
@@ -610,7 +613,9 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName})
+ respMu.Lock()
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
+ respMu.Unlock()
}
return nil
}