aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2022-10-07 10:54:42 +0200
committerGitHub <noreply@github.com>2022-10-07 10:54:42 +0200
commitb9d0e9f7ed7ce1f4be72a25c6f5185a6e809f019 (patch)
treea628e1c81752e0f5a296180357c6bac183a11135
parent453b50e1d3ce4b4906e050fa8ebe7bcc4c881600 (diff)
Add test for `QueryDeviceMessages` (#2773)
Adds tests for `QueryDeviceMessages` and also includes some optimizations to reduce allocations in the DB layer.
-rw-r--r--keyserver/internal/internal.go8
-rw-r--r--keyserver/internal/internal_test.go156
-rw-r--r--keyserver/storage/postgres/device_keys_table.go16
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go16
4 files changed, 172 insertions, 24 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index a0280dff..06fc4987 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -212,15 +212,13 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
return nil
}
maxStreamID := int64(0)
+ // remove deleted devices
+ var result []api.DeviceMessage
for _, m := range msgs {
if m.StreamID > maxStreamID {
maxStreamID = m.StreamID
}
- }
- // remove deleted devices
- var result []api.DeviceMessage
- for _, m := range msgs {
- if m.KeyJSON == nil {
+ if m.KeyJSON == nil || len(m.KeyJSON) == 0 {
continue
}
result = append(result, m)
diff --git a/keyserver/internal/internal_test.go b/keyserver/internal/internal_test.go
new file mode 100644
index 00000000..8a2c9c5d
--- /dev/null
+++ b/keyserver/internal/internal_test.go
@@ -0,0 +1,156 @@
+package internal_test
+
+import (
+ "context"
+ "reflect"
+ "testing"
+
+ "github.com/matrix-org/dendrite/keyserver/api"
+ "github.com/matrix-org/dendrite/keyserver/internal"
+ "github.com/matrix-org/dendrite/keyserver/storage"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+)
+
+func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := storage.NewDatabase(nil, &config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ })
+ if err != nil {
+ t.Fatalf("failed to create new user db: %v", err)
+ }
+ return db, close
+}
+
+func Test_QueryDeviceMessages(t *testing.T) {
+ alice := test.NewUser(t)
+ type args struct {
+ req *api.QueryDeviceMessagesRequest
+ res *api.QueryDeviceMessagesResponse
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ want *api.QueryDeviceMessagesResponse
+ }{
+ {
+ name: "no existing keys",
+ args: args{
+ req: &api.QueryDeviceMessagesRequest{
+ UserID: "@doesNotExist:localhost",
+ },
+ res: &api.QueryDeviceMessagesResponse{},
+ },
+ want: &api.QueryDeviceMessagesResponse{},
+ },
+ {
+ name: "existing user returns devices",
+ args: args{
+ req: &api.QueryDeviceMessagesRequest{
+ UserID: alice.ID,
+ },
+ res: &api.QueryDeviceMessagesResponse{},
+ },
+ want: &api.QueryDeviceMessagesResponse{
+ StreamID: 6,
+ Devices: []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ DisplayName: "first device",
+ UserID: alice.ID,
+ KeyJSON: []byte("ghi"),
+ },
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "mySecondDevice",
+ DisplayName: "second device",
+ UserID: alice.ID,
+ KeyJSON: []byte("jkl"),
+ }, // streamID 6
+ },
+ },
+ },
+ },
+ }
+
+ deviceMessages := []api.DeviceMessage{
+ { // not the user we're looking for
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ UserID: "@doesNotExist:localhost",
+ },
+ // streamID 1 for this user
+ },
+ { // empty keyJSON will be ignored
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ }, // streamID 1
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte("abc"),
+ }, // streamID 2
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte("def"),
+ }, // streamID 3
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte(""),
+ }, // streamID 4
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ DisplayName: "first device",
+ UserID: alice.ID,
+ KeyJSON: []byte("ghi"),
+ }, // streamID 5
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "mySecondDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte("jkl"),
+ DisplayName: "second device",
+ }, // streamID 6
+ },
+ }
+ ctx := context.Background()
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, closeDB := mustCreateDatabase(t, dbType)
+ defer closeDB()
+ if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil {
+ t.Fatalf("failed to store local devicesKeys")
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &internal.KeyInternalAPI{
+ DB: db,
+ }
+ if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
+ t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ got := tt.args.res
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want)
+ }
+ })
+ }
+ })
+}
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index ccd20cbd..2aa11c52 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -20,6 +20,7 @@ import (
"time"
"github.com/lib/pq"
+
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
@@ -204,20 +205,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
deviceIDMap[d] = true
}
var result []api.DeviceMessage
+ var displayName sql.NullString
for rows.Next() {
dk := api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{},
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ UserID: userID,
+ },
}
- dk.UserID = userID
- var keyJSON string
- var streamID int64
- var displayName sql.NullString
- if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
+ if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
return nil, err
}
- dk.KeyJSON = []byte(keyJSON)
- dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
}
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index e77b49b3..73768da5 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -137,21 +137,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
}
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceMessage
+ var displayName sql.NullString
for rows.Next() {
dk := api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{},
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ UserID: userID,
+ },
}
- dk.Type = api.TypeDeviceKeyUpdate
- dk.UserID = userID
- var keyJSON string
- var streamID int64
- var displayName sql.NullString
- if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
+ if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
return nil, err
}
- dk.KeyJSON = []byte(keyJSON)
- dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
}