aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-03-10 13:17:28 +0000
committerGitHub <noreply@github.com>2022-03-10 13:17:28 +0000
commite485f9c2bd15bca397229444399fa7e168eca43d (patch)
tree4efe11f99423ba450371521a8f50d5b4234ba7fe
parente1881627d18c541bfad0f5dc210b0786b8e55f35 (diff)
64-bit stream IDs for device list updates (#2267)
-rw-r--r--federationapi/consumers/keychange.go4
-rw-r--r--keyserver/api/api.go6
-rw-r--r--keyserver/internal/device_list_update.go2
-rw-r--r--keyserver/internal/device_list_update_test.go14
-rw-r--r--keyserver/internal/internal.go2
-rw-r--r--keyserver/storage/interface.go2
-rw-r--r--keyserver/storage/postgres/device_keys_table.go10
-rw-r--r--keyserver/storage/shared/storage.go12
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go14
-rw-r--r--keyserver/storage/storage_test.go2
-rw-r--r--keyserver/storage/tables/interface.go2
11 files changed, 33 insertions, 37 deletions
diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go
index 22dbc32d..33d716d2 100644
--- a/federationapi/consumers/keychange.go
+++ b/federationapi/consumers/keychange.go
@@ -203,9 +203,9 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
return err == nil
}
-func prevID(streamID int) []int {
+func prevID(streamID int64) []int64 {
if streamID <= 1 {
return nil
}
- return []int{streamID - 1}
+ return []int64{streamID - 1}
}
diff --git a/keyserver/api/api.go b/keyserver/api/api.go
index 54eb04f8..d361c622 100644
--- a/keyserver/api/api.go
+++ b/keyserver/api/api.go
@@ -70,7 +70,7 @@ type DeviceMessage struct {
*DeviceKeys `json:"DeviceKeys,omitempty"`
*eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
// A monotonically increasing number which represents device changes for this user.
- StreamID int
+ StreamID int64
DeviceChangeID int64
}
@@ -108,7 +108,7 @@ type DeviceKeys struct {
}
// WithStreamID returns a copy of this device message with the given stream ID
-func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
+func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
return DeviceMessage{
DeviceKeys: k,
StreamID: streamID,
@@ -281,7 +281,7 @@ type QueryDeviceMessagesRequest struct {
type QueryDeviceMessagesResponse struct {
// The latest stream ID
- StreamID int
+ StreamID int64
Devices []DeviceMessage
Error *KeyError
}
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
index 974d0196..4b2b8c18 100644
--- a/keyserver/internal/device_list_update.go
+++ b/keyserver/internal/device_list_update.go
@@ -109,7 +109,7 @@ type DeviceListUpdaterDatabase interface {
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
- PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
+ PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go
index ff939355..0033a508 100644
--- a/keyserver/internal/device_list_update_test.go
+++ b/keyserver/internal/device_list_update_test.go
@@ -46,7 +46,7 @@ func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) erro
type mockDeviceListUpdaterDatabase struct {
staleUsers map[string]bool
- prevIDsExist func(string, []int) bool
+ prevIDsExist func(string, []int64) bool
storedKeys []api.DeviceMessage
mu sync.Mutex // protect staleUsers
}
@@ -101,7 +101,7 @@ func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Contex
}
// PrevIDsExists returns true if all prev IDs exist for this user.
-func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
+func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
return d.prevIDsExist(userID, prevIDs), nil
}
@@ -139,7 +139,7 @@ func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrix
func TestUpdateHavePrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
- prevIDsExist: func(string, []int) bool {
+ prevIDsExist: func(string, []int64) bool {
return true
},
}
@@ -151,7 +151,7 @@ func TestUpdateHavePrevID(t *testing.T) {
Deleted: false,
DeviceID: "FOO",
Keys: []byte(`{"key":"value"}`),
- PrevID: []int{0},
+ PrevID: []int64{0},
StreamID: 1,
UserID: "@alice:localhost",
}
@@ -185,7 +185,7 @@ func TestUpdateHavePrevID(t *testing.T) {
func TestUpdateNoPrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
- prevIDsExist: func(string, []int) bool {
+ prevIDsExist: func(string, []int64) bool {
return false
},
}
@@ -226,7 +226,7 @@ func TestUpdateNoPrevID(t *testing.T) {
Deleted: false,
DeviceID: "another_device_id",
Keys: []byte(`{"key":"value"}`),
- PrevID: []int{3},
+ PrevID: []int64{3},
StreamID: 4,
UserID: remoteUserID,
}
@@ -268,7 +268,7 @@ func TestDebounce(t *testing.T) {
t.Skipf("panic on closed channel on GHA")
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
- prevIDsExist: func(string, []int) bool {
+ prevIDsExist: func(string, []int64) bool {
return true
},
}
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index 0a8bef95..cc9d3a61 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -205,7 +205,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
}
return
}
- maxStreamID := 0
+ maxStreamID := int64(0)
for _, m := range msgs {
if m.StreamID > maxStreamID {
maxStreamID = m.StreamID
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index 4dffe695..16e03477 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -49,7 +49,7 @@ type Database interface {
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
- PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
+ PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index 628301cf..ccd20cbd 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -121,7 +121,7 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
- var streamID int
+ var streamID int64
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
@@ -138,15 +138,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return nil
}
-func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results
- var nullStream sql.NullInt32
+ var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
- streamID = nullStream.Int32
+ streamID = nullStream.Int64
}
return
}
@@ -211,7 +211,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
}
dk.UserID = userID
var keyJSON string
- var streamID int
+ var streamID int64
var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index f2790c8d..03215b93 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -59,12 +59,8 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage)
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
-func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
- sids := make([]int64, len(prevIDs))
- for i := range prevIDs {
- sids[i] = int64(prevIDs[i])
- }
- count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids)
+func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
+ count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
if err != nil {
return false, err
}
@@ -85,7 +81,7 @@ func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceM
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
// work out the latest stream IDs for each user
- userIDToStreamID := make(map[string]int)
+ userIDToStreamID := make(map[string]int64)
for _, k := range keys {
userIDToStreamID[k.UserID] = 0
}
@@ -95,7 +91,7 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
if err != nil {
return err
}
- userIDToStreamID[userID] = int(streamID)
+ userIDToStreamID[userID] = streamID
}
// set the stream IDs for each key
for i := range keys {
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index b461424c..e77b49b3 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -145,7 +145,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.Type = api.TypeDeviceKeyUpdate
dk.UserID = userID
var keyJSON string
- var streamID int
+ var streamID int64
var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err
@@ -166,7 +166,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
- var streamID int
+ var streamID int64
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
@@ -183,15 +183,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return nil
}
-func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results
- var nullStream sql.NullInt32
+ var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
- streamID = nullStream.Int32
+ streamID = nullStream.Int64
}
return
}
@@ -204,13 +204,13 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
}
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
// nullable if there are no results
- var count sql.NullInt32
+ var count sql.NullInt64
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
if err != nil {
return 0, err
}
if count.Valid {
- return int(count.Int32), nil
+ return int(count.Int64), nil
}
return 0, nil
}
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
index 4d513724..84d2098a 100644
--- a/keyserver/storage/storage_test.go
+++ b/keyserver/storage/storage_test.go
@@ -177,7 +177,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
- wantStreamIDs := map[string]int{
+ wantStreamIDs := map[string]int64{
"AAA": 3,
"another_device": 2,
}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index cd171959..f840cd1f 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -37,7 +37,7 @@ type OneTimeKeys interface {
type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
- SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
+ SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error