aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2022-07-12 08:23:58 +0200
committerGitHub <noreply@github.com>2022-07-12 08:23:58 +0200
commit09f0ff14c88fec2ae5b3db79054378a8ced255a9 (patch)
treef433d00d3320c32ffb91ec695748a9afefbe4604 /syncapi
parent3ea21273bcc151b36eec412d0ec550642fe9b04f (diff)
Minor SendToDevice fix (#2565)
* Avoid unnecessary marshalling if sending to the local server * Fix ordering of ToDevice messages * Revive SendToDevice test
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/storage/postgres/send_to_device_table.go10
-rw-r--r--syncapi/storage/sqlite3/send_to_device_table.go8
-rw-r--r--syncapi/storage/storage_test.go185
3 files changed, 121 insertions, 82 deletions
diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go
index 47c1cdae..96d6844f 100644
--- a/syncapi/storage/postgres/send_to_device_table.go
+++ b/syncapi/storage/postgres/send_to_device_table.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/sirupsen/logrus"
)
const sendToDeviceSchema = `
@@ -51,7 +52,7 @@ const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
- ORDER BY id DESC
+ ORDER BY id ASC
`
const deleteSendToDeviceMessagesSQL = `
@@ -112,17 +113,18 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
return
}
- if id > lastPos {
- lastPos = id
- }
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
DeviceID: deviceID,
}
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
+ logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue
}
+ if id > lastPos {
+ lastPos = id
+ }
events = append(events, event)
}
if lastPos == 0 {
diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go
index 0b1d5bbf..5285acbe 100644
--- a/syncapi/storage/sqlite3/send_to_device_table.go
+++ b/syncapi/storage/sqlite3/send_to_device_table.go
@@ -49,7 +49,7 @@ const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
- ORDER BY id DESC
+ ORDER BY id ASC
`
const deleteSendToDeviceMessagesSQL = `
@@ -120,9 +120,6 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return
}
- if id > lastPos {
- lastPos = id
- }
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
@@ -132,6 +129,9 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue
}
+ if id > lastPos {
+ lastPos = id
+ }
events = append(events, event)
}
if lastPos == 0 {
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index 563c92e3..c7415170 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -1,7 +1,9 @@
package storage_test
import (
+ "bytes"
"context"
+ "encoding/json"
"fmt"
"reflect"
"testing"
@@ -394,90 +396,125 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
from = topologyTokenBefore(t, db, paginatedEvents[len(paginatedEvents)-1].EventID())
}
}
+*/
func TestSendToDeviceBehaviour(t *testing.T) {
- //t.Parallel()
- db := MustCreateDatabase(t)
+ t.Parallel()
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ deviceID := "one"
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := MustCreateDatabase(t, dbType)
+ defer close()
+ // At this point there should be no messages. We haven't sent anything
+ // yet.
+ _, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 {
+ t.Fatal("first call should have no updates")
+ }
- // At this point there should be no messages. We haven't sent anything
- // yet.
- _, events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
- t.Fatal("first call should have no updates")
- }
- err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{})
- if err != nil {
- return
- }
+ err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, 100)
+ if err != nil {
+ return
+ }
- // Try sending a message.
- streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
- Sender: "bob",
- Type: "m.type",
- Content: json.RawMessage("{}"),
- })
- if err != nil {
- t.Fatal(err)
- }
+ // Try sending a message.
+ streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
+ Sender: bob.ID,
+ Type: "m.type",
+ Content: json.RawMessage("{}"),
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
- // At this point we should get exactly one message. We're sending the sync position
- // that we were given from the update and the send-to-device update will be updated
- // in the database to reflect that this was the sync position we sent the message at.
- _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
- t.Fatal("second call should have one update")
- }
- err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
- if err != nil {
- return
- }
+ // At this point we should get exactly one message. We're sending the sync position
+ // that we were given from the update and the send-to-device update will be updated
+ // in the database to reflect that this was the sync position we sent the message at.
+ streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if count := len(events); count != 1 {
+ t.Fatalf("second call should have one update, got %d", count)
+ }
+ err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
+ if err != nil {
+ return
+ }
- // At this point we should still have one message because we haven't progressed the
- // sync position yet. This is equivalent to the client failing to /sync and retrying
- // with the same position.
- _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
- t.Fatal("third call should have one update still")
- }
- err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
- if err != nil {
- return
- }
+ // At this point we should still have one message because we haven't progressed the
+ // sync position yet. This is equivalent to the client failing to /sync and retrying
+ // with the same position.
+ streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 1 {
+ t.Fatal("third call should have one update still")
+ }
+ err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos+1)
+ if err != nil {
+ return
+ }
- // At this point we should now have no updates, because we've progressed the sync
- // position. Therefore the update from before will not be sent again.
- _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
- t.Fatal("fourth call should have no updates")
- }
- err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos + 1})
- if err != nil {
- return
- }
+ // At this point we should now have no updates, because we've progressed the sync
+ // position. Therefore the update from before will not be sent again.
+ _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos+1, streamPos+2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 {
+ t.Fatal("fourth call should have no updates")
+ }
+ err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos+1)
+ if err != nil {
+ return
+ }
- // At this point we should still have no updates, because no new updates have been
- // sent.
- _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
- t.Fatal("fifth call should have no updates")
- }
+ // At this point we should still have no updates, because no new updates have been
+ // sent.
+ _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 {
+ t.Fatal("fifth call should have no updates")
+ }
+
+ // Send some more messages and verify the ordering is correct ("in order of arrival")
+ var lastPos types.StreamPosition = 0
+ for i := 0; i < 10; i++ {
+ streamPos, err = db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
+ Sender: bob.ID,
+ Type: "m.type",
+ Content: json.RawMessage(fmt.Sprintf(`{ "count": %d }`, i)),
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ lastPos = streamPos
+ }
+
+ _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
+ if err != nil {
+ t.Fatalf("unable to get events: %v", err)
+ }
+
+ for i := 0; i < 10; i++ {
+ want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
+ got := events[i].Content
+ if !bytes.Equal(got, want) {
+ t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
+ }
+ }
+ })
}
+/*
func TestInviteBehaviour(t *testing.T) {
db := MustCreateDatabase(t)
inviteRoom1 := "!inviteRoom1:somewhere"