aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--clientapi/producers/eduserver.go26
-rw-r--r--clientapi/routing/routing.go25
-rw-r--r--clientapi/routing/sendtodevice.go70
-rw-r--r--cmd/dendrite-client-api-server/main.go2
-rw-r--r--cmd/dendrite-demo-libp2p/main.go2
-rw-r--r--cmd/dendrite-edu-server/main.go3
-rw-r--r--cmd/dendrite-federation-api-server/main.go2
-rw-r--r--cmd/dendrite-monolith-server/main.go2
-rw-r--r--cmd/dendritejs/main.go3
-rw-r--r--dendrite-config.yaml7
-rw-r--r--eduserver/api/input.go42
-rw-r--r--eduserver/api/output.go19
-rw-r--r--eduserver/cache/cache.go17
-rw-r--r--eduserver/cache/cache_test.go4
-rw-r--r--eduserver/eduserver.go15
-rw-r--r--eduserver/input/input.go98
-rw-r--r--federationapi/routing/send.go19
-rw-r--r--federationapi/routing/send_test.go8
-rw-r--r--go.mod2
-rw-r--r--go.sum4
-rw-r--r--internal/config/config.go2
-rw-r--r--internal/sql.go62
-rw-r--r--syncapi/consumers/eduserver_sendtodevice.go113
-rw-r--r--syncapi/consumers/eduserver_typing.go (renamed from syncapi/consumers/eduserver.go)0
-rw-r--r--syncapi/storage/interface.go32
-rw-r--r--syncapi/storage/postgres/send_to_device_table.go171
-rw-r--r--syncapi/storage/postgres/syncserver.go6
-rw-r--r--syncapi/storage/shared/syncserver.go148
-rw-r--r--syncapi/storage/sqlite3/send_to_device_table.go172
-rw-r--r--syncapi/storage/sqlite3/syncserver.go6
-rw-r--r--syncapi/storage/storage_test.go99
-rw-r--r--syncapi/storage/tables/interface.go39
-rw-r--r--syncapi/sync/notifier.go16
-rw-r--r--syncapi/sync/notifier_test.go2
-rw-r--r--syncapi/sync/requestpool.go71
-rw-r--r--syncapi/syncapi.go9
-rw-r--r--syncapi/types/types.go23
-rw-r--r--sytest-blacklist5
-rw-r--r--sytest-whitelist12
39 files changed, 1300 insertions, 58 deletions
diff --git a/clientapi/producers/eduserver.go b/clientapi/producers/eduserver.go
index 30c40fb7..102c1fad 100644
--- a/clientapi/producers/eduserver.go
+++ b/clientapi/producers/eduserver.go
@@ -14,6 +14,7 @@ package producers
import (
"context"
+ "encoding/json"
"time"
"github.com/matrix-org/dendrite/eduserver/api"
@@ -52,3 +53,28 @@ func (p *EDUServerProducer) SendTyping(
return err
}
+
+// SendToDevice sends a typing event to EDU server
+func (p *EDUServerProducer) SendToDevice(
+ ctx context.Context, sender, userID, deviceID, eventType string,
+ message interface{},
+) error {
+ js, err := json.Marshal(message)
+ if err != nil {
+ return err
+ }
+ requestData := api.InputSendToDeviceEvent{
+ UserID: userID,
+ DeviceID: deviceID,
+ SendToDeviceEvent: gomatrixserverlib.SendToDeviceEvent{
+ Sender: sender,
+ Type: eventType,
+ Content: js,
+ },
+ }
+ request := api.InputSendToDeviceEventRequest{
+ InputSendToDeviceEvent: requestData,
+ }
+ response := api.InputSendToDeviceEventResponse{}
+ return p.InputAPI.InputSendToDeviceEvent(ctx, &request, &response)
+}
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index 934d9f06..83e399ac 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -274,6 +274,31 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
+ r0mux.Handle("/sendToDevice/{eventType}/{txnID}",
+ internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
+ vars, err := internal.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ txnID := vars["txnID"]
+ return SendToDevice(req, device, eduProducer, transactionsCache, vars["eventType"], &txnID)
+ }),
+ ).Methods(http.MethodPut, http.MethodOptions)
+
+ // This is only here because sytest refers to /unstable for this endpoint
+ // rather than r0. It's an exact duplicate of the above handler.
+ // TODO: Remove this if/when sytest is fixed!
+ unstableMux.Handle("/sendToDevice/{eventType}/{txnID}",
+ internal.MakeAuthAPI("send_to_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
+ vars, err := internal.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ txnID := vars["txnID"]
+ return SendToDevice(req, device, eduProducer, transactionsCache, vars["eventType"], &txnID)
+ }),
+ ).Methods(http.MethodPut, http.MethodOptions)
+
r0mux.Handle("/account/whoami",
internal.MakeAuthAPI("whoami", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return Whoami(req, device)
diff --git a/clientapi/routing/sendtodevice.go b/clientapi/routing/sendtodevice.go
new file mode 100644
index 00000000..5d3060d7
--- /dev/null
+++ b/clientapi/routing/sendtodevice.go
@@ -0,0 +1,70 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/clientapi/httputil"
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ "github.com/matrix-org/dendrite/clientapi/producers"
+ "github.com/matrix-org/dendrite/internal/transactions"
+ "github.com/matrix-org/util"
+)
+
+// SendToDevice handles PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}
+// sends the device events to the EDU Server
+func SendToDevice(
+ req *http.Request, device *authtypes.Device,
+ eduProducer *producers.EDUServerProducer,
+ txnCache *transactions.Cache,
+ eventType string, txnID *string,
+) util.JSONResponse {
+ if txnID != nil {
+ if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok {
+ return *res
+ }
+ }
+
+ var httpReq struct {
+ Messages map[string]map[string]json.RawMessage `json:"messages"`
+ }
+ resErr := httputil.UnmarshalJSONRequest(req, &httpReq)
+ if resErr != nil {
+ return *resErr
+ }
+
+ for userID, byUser := range httpReq.Messages {
+ for deviceID, message := range byUser {
+ if err := eduProducer.SendToDevice(
+ req.Context(), device.UserID, userID, deviceID, eventType, message,
+ ); err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("eduProducer.SendToDevice failed")
+ return jsonerror.InternalServerError()
+ }
+ }
+ }
+
+ res := util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: struct{}{},
+ }
+
+ if txnID != nil {
+ txnCache.AddTransaction(device.AccessToken, *txnID, &res)
+ }
+
+ return res
+}
diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go
index e06adf8f..f919243d 100644
--- a/cmd/dendrite-client-api-server/main.go
+++ b/cmd/dendrite-client-api-server/main.go
@@ -39,7 +39,7 @@ func main() {
rsAPI := base.CreateHTTPRoomserverAPIs()
fsAPI := base.CreateHTTPFederationSenderAPIs()
rsAPI.SetFederationSenderAPI(fsAPI)
- eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New())
+ eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB, federation, keyRing,
diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go
index fc56b9bb..e9d01fd9 100644
--- a/cmd/dendrite-demo-libp2p/main.go
+++ b/cmd/dendrite-demo-libp2p/main.go
@@ -148,7 +148,7 @@ func main() {
&base.Base, keyRing, federation,
)
eduInputAPI := eduserver.SetupEDUServerComponent(
- &base.Base, cache.New(),
+ &base.Base, cache.New(), deviceDB,
)
asAPI := appservice.SetupAppServiceAPIComponent(
&base.Base, accountDB, deviceDB, federation, rsAPI, transactions.New(),
diff --git a/cmd/dendrite-edu-server/main.go b/cmd/dendrite-edu-server/main.go
index 66e17e57..ca0460f8 100644
--- a/cmd/dendrite-edu-server/main.go
+++ b/cmd/dendrite-edu-server/main.go
@@ -29,8 +29,9 @@ func main() {
logrus.WithError(err).Warn("BaseDendrite close failed")
}
}()
+ deviceDB := base.CreateDeviceDB()
- eduserver.SetupEDUServerComponent(base, cache.New())
+ eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
base.SetupAndServeHTTP(string(base.Cfg.Bind.EDUServer), string(base.Cfg.Listen.EDUServer))
diff --git a/cmd/dendrite-federation-api-server/main.go b/cmd/dendrite-federation-api-server/main.go
index 5425d117..af63b549 100644
--- a/cmd/dendrite-federation-api-server/main.go
+++ b/cmd/dendrite-federation-api-server/main.go
@@ -39,7 +39,7 @@ func main() {
rsAPI := base.CreateHTTPRoomserverAPIs()
asAPI := base.CreateHTTPAppServiceAPIs()
rsAPI.SetFederationSenderAPI(fsAPI)
- eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New())
+ eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
eduProducer := producers.NewEDUServerProducer(eduInputAPI)
federationapi.SetupFederationAPIComponent(
diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go
index 8367cd9d..ef114ccd 100644
--- a/cmd/dendrite-monolith-server/main.go
+++ b/cmd/dendrite-monolith-server/main.go
@@ -87,7 +87,7 @@ func main() {
}
eduInputAPI := eduserver.SetupEDUServerComponent(
- base, cache.New(),
+ base, cache.New(), deviceDB,
)
if base.EnableHTTPAPIs {
eduInputAPI = base.CreateHTTPEDUServerAPIs()
diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go
index 45f23d9a..9a02e71e 100644
--- a/cmd/dendritejs/main.go
+++ b/cmd/dendritejs/main.go
@@ -175,6 +175,7 @@ func main() {
cfg.Database.SyncAPI = "file:/idb/dendritejs_syncapi.db"
cfg.Kafka.Topics.UserUpdates = "user_updates"
cfg.Kafka.Topics.OutputTypingEvent = "output_typing_event"
+ cfg.Kafka.Topics.OutputSendToDeviceEvent = "output_send_to_device_event"
cfg.Kafka.Topics.OutputClientData = "output_client_data"
cfg.Kafka.Topics.OutputRoomEvent = "output_room_event"
cfg.Matrix.TrustedIDServers = []string{
@@ -206,7 +207,7 @@ func main() {
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node)
rsAPI := roomserver.SetupRoomServerComponent(base, keyRing, federation)
- eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New())
+ eduInputAPI := eduserver.SetupEDUServerComponent(base, cache.New(), deviceDB)
asQuery := appservice.SetupAppServiceAPIComponent(
base, accountDB, deviceDB, federation, rsAPI, transactions.New(),
)
diff --git a/dendrite-config.yaml b/dendrite-config.yaml
index 1802b8b7..a5b29597 100644
--- a/dendrite-config.yaml
+++ b/dendrite-config.yaml
@@ -104,7 +104,8 @@ kafka:
topics:
output_room_event: roomserverOutput
output_client_data: clientapiOutput
- output_typing_event: eduServerOutput
+ output_typing_event: eduServerTypingOutput
+ output_send_to_device_event: eduServerSendToDeviceOutput
user_updates: userUpdates
# The postgres connection configs for connecting to the databases e.g a postgres:// URI
@@ -137,8 +138,8 @@ listen:
federation_sender: "localhost:7776"
appservice_api: "localhost:7777"
edu_server: "localhost:7778"
- key_server: "localhost:7779"
- server_key_api: "localhost:7780"
+ key_server: "localhost:7779"
+ server_key_api: "localhost:7780"
# The configuration for tracing the dendrite components.
tracing:
diff --git a/eduserver/api/input.go b/eduserver/api/input.go
index 8b5b6d76..fa7f30cb 100644
--- a/eduserver/api/input.go
+++ b/eduserver/api/input.go
@@ -1,3 +1,7 @@
+// Copyright 2017 Vector Creations Ltd
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -37,6 +41,12 @@ type InputTypingEvent struct {
OriginServerTS gomatrixserverlib.Timestamp `json:"origin_server_ts"`
}
+type InputSendToDeviceEvent struct {
+ UserID string `json:"user_id"`
+ DeviceID string `json:"device_id"`
+ gomatrixserverlib.SendToDeviceEvent
+}
+
// InputTypingEventRequest is a request to EDUServerInputAPI
type InputTypingEventRequest struct {
InputTypingEvent InputTypingEvent `json:"input_typing_event"`
@@ -45,6 +55,14 @@ type InputTypingEventRequest struct {
// InputTypingEventResponse is a response to InputTypingEvents
type InputTypingEventResponse struct{}
+// InputSendToDeviceEventRequest is a request to EDUServerInputAPI
+type InputSendToDeviceEventRequest struct {
+ InputSendToDeviceEvent InputSendToDeviceEvent `json:"input_send_to_device_event"`
+}
+
+// InputSendToDeviceEventResponse is a response to InputSendToDeviceEventRequest
+type InputSendToDeviceEventResponse struct{}
+
// EDUServerInputAPI is used to write events to the typing server.
type EDUServerInputAPI interface {
InputTypingEvent(
@@ -52,11 +70,20 @@ type EDUServerInputAPI interface {
request *InputTypingEventRequest,
response *InputTypingEventResponse,
) error
+
+ InputSendToDeviceEvent(
+ ctx context.Context,
+ request *InputSendToDeviceEventRequest,
+ response *InputSendToDeviceEventResponse,
+ ) error
}
// EDUServerInputTypingEventPath is the HTTP path for the InputTypingEvent API.
const EDUServerInputTypingEventPath = "/eduserver/input"
+// EDUServerInputSendToDeviceEventPath is the HTTP path for the InputSendToDeviceEvent API.
+const EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice"
+
// NewEDUServerInputAPIHTTP creates a EDUServerInputAPI implemented by talking to a HTTP POST API.
func NewEDUServerInputAPIHTTP(eduServerURL string, httpClient *http.Client) (EDUServerInputAPI, error) {
if httpClient == nil {
@@ -70,7 +97,7 @@ type httpEDUServerInputAPI struct {
httpClient *http.Client
}
-// InputRoomEvents implements EDUServerInputAPI
+// InputTypingEvent implements EDUServerInputAPI
func (h *httpEDUServerInputAPI) InputTypingEvent(
ctx context.Context,
request *InputTypingEventRequest,
@@ -82,3 +109,16 @@ func (h *httpEDUServerInputAPI) InputTypingEvent(
apiURL := h.eduServerURL + EDUServerInputTypingEventPath
return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
+
+// InputSendToDeviceEvent implements EDUServerInputAPI
+func (h *httpEDUServerInputAPI) InputSendToDeviceEvent(
+ ctx context.Context,
+ request *InputSendToDeviceEventRequest,
+ response *InputSendToDeviceEventResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "InputSendToDeviceEvent")
+ defer span.Finish()
+
+ apiURL := h.eduServerURL + EDUServerInputSendToDeviceEventPath
+ return internalHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
diff --git a/eduserver/api/output.go b/eduserver/api/output.go
index 8696acf4..e6ded841 100644
--- a/eduserver/api/output.go
+++ b/eduserver/api/output.go
@@ -1,3 +1,7 @@
+// Copyright 2017 Vector Creations Ltd
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -12,7 +16,11 @@
package api
-import "time"
+import (
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
// OutputTypingEvent is an entry in typing server output kafka log.
// This contains the event with extra fields used to create 'm.typing' event
@@ -32,3 +40,12 @@ type TypingEvent struct {
UserID string `json:"user_id"`
Typing bool `json:"typing"`
}
+
+// OutputSendToDeviceEvent is an entry in the send-to-device output kafka log.
+// This contains the full event content, along with the user ID and device ID
+// to which it is destined.
+type OutputSendToDeviceEvent struct {
+ UserID string `json:"user_id"`
+ DeviceID string `json:"device_id"`
+ gomatrixserverlib.SendToDeviceEvent
+}
diff --git a/eduserver/cache/cache.go b/eduserver/cache/cache.go
index 46f7a2b1..dd535a6d 100644
--- a/eduserver/cache/cache.go
+++ b/eduserver/cache/cache.go
@@ -1,3 +1,7 @@
+// Copyright 2017 Vector Creations Ltd
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -109,6 +113,19 @@ func (t *EDUCache) AddTypingUser(
return t.GetLatestSyncPosition()
}
+// AddSendToDeviceMessage increases the sync position for
+// send-to-device updates.
+// Returns the sync position before update, as the caller
+// will use this to record the current stream position
+// at the time that the send-to-device message was sent.
+func (t *EDUCache) AddSendToDeviceMessage() int64 {
+ t.Lock()
+ defer t.Unlock()
+ latestSyncPosition := t.latestSyncPosition
+ t.latestSyncPosition++
+ return latestSyncPosition
+}
+
// addUser with mutex lock & replace the previous timer.
// Returns the latest typing sync position after update.
func (t *EDUCache) addUser(
diff --git a/eduserver/cache/cache_test.go b/eduserver/cache/cache_test.go
index d1b2f8bd..c7d01879 100644
--- a/eduserver/cache/cache_test.go
+++ b/eduserver/cache/cache_test.go
@@ -1,3 +1,7 @@
+// Copyright 2017 Vector Creations Ltd
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go
index 14fbd332..6f664eb6 100644
--- a/eduserver/eduserver.go
+++ b/eduserver/eduserver.go
@@ -1,3 +1,7 @@
+// Copyright 2017 Vector Creations Ltd
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -13,6 +17,7 @@
package eduserver
import (
+ "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/eduserver/input"
@@ -26,11 +31,15 @@ import (
func SetupEDUServerComponent(
base *basecomponent.BaseDendrite,
eduCache *cache.EDUCache,
+ deviceDB devices.Database,
) api.EDUServerInputAPI {
inputAPI := &input.EDUServerInputAPI{
- Cache: eduCache,
- Producer: base.KafkaProducer,
- OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent),
+ Cache: eduCache,
+ DeviceDB: deviceDB,
+ Producer: base.KafkaProducer,
+ OutputTypingEventTopic: string(base.Cfg.Kafka.Topics.OutputTypingEvent),
+ OutputSendToDeviceEventTopic: string(base.Cfg.Kafka.Topics.OutputSendToDeviceEvent),
+ ServerName: base.Cfg.Matrix.ServerName,
}
inputAPI.SetupHTTP(base.InternalAPIMux)
diff --git a/eduserver/input/input.go b/eduserver/input/input.go
index 73777e32..4e305195 100644
--- a/eduserver/input/input.go
+++ b/eduserver/input/input.go
@@ -1,3 +1,7 @@
+// Copyright 2017 Vector Creations Ltd
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -20,11 +24,13 @@ import (
"github.com/Shopify/sarama"
"github.com/gorilla/mux"
+ "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/eduserver/cache"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
)
// EDUServerInputAPI implements api.EDUServerInputAPI
@@ -33,8 +39,14 @@ type EDUServerInputAPI struct {
Cache *cache.EDUCache
// The kafka topic to output new typing events to.
OutputTypingEventTopic string
+ // The kafka topic to output new send to device events to.
+ OutputSendToDeviceEventTopic string
// kafka producer
Producer sarama.SyncProducer
+ // device database
+ DeviceDB devices.Database
+ // our server name
+ ServerName gomatrixserverlib.ServerName
}
// InputTypingEvent implements api.EDUServerInputAPI
@@ -54,10 +66,20 @@ func (t *EDUServerInputAPI) InputTypingEvent(
t.Cache.RemoveUser(ite.UserID, ite.RoomID)
}
- return t.sendEvent(ite)
+ return t.sendTypingEvent(ite)
+}
+
+// InputTypingEvent implements api.EDUServerInputAPI
+func (t *EDUServerInputAPI) InputSendToDeviceEvent(
+ ctx context.Context,
+ request *api.InputSendToDeviceEventRequest,
+ response *api.InputSendToDeviceEventResponse,
+) error {
+ ise := &request.InputSendToDeviceEvent
+ return t.sendToDeviceEvent(ise)
}
-func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error {
+func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error {
ev := &api.TypingEvent{
Type: gomatrixserverlib.MTyping,
RoomID: ite.RoomID,
@@ -90,6 +112,65 @@ func (t *EDUServerInputAPI) sendEvent(ite *api.InputTypingEvent) error {
return err
}
+func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) error {
+ devices := []string{}
+ localpart, domain, err := gomatrixserverlib.SplitID('@', ise.UserID)
+ if err != nil {
+ return err
+ }
+
+ // If the event is targeted locally then we want to expand the wildcard
+ // out into individual device IDs so that we can send them to each respective
+ // device. If the event isn't targeted locally then we can't expand the
+ // wildcard as we don't know about the remote devices, so instead we leave it
+ // as-is, so that the federation sender can send it on with the wildcard intact.
+ if domain == t.ServerName && ise.DeviceID == "*" {
+ devs, err := t.DeviceDB.GetDevicesByLocalpart(context.TODO(), localpart)
+ if err != nil {
+ return err
+ }
+ for _, dev := range devs {
+ devices = append(devices, dev.ID)
+ }
+ } else {
+ devices = append(devices, ise.DeviceID)
+ }
+
+ for _, device := range devices {
+ ote := &api.OutputSendToDeviceEvent{
+ UserID: ise.UserID,
+ DeviceID: device,
+ SendToDeviceEvent: ise.SendToDeviceEvent,
+ }
+
+ logrus.WithFields(logrus.Fields{
+ "user_id": ise.UserID,
+ "device_id": ise.DeviceID,
+ "event_type": ise.Type,
+ }).Info("handling send-to-device message")
+
+ eventJSON, err := json.Marshal(ote)
+ if err != nil {
+ logrus.WithError(err).Error("sendToDevice failed json.Marshal")
+ return err
+ }
+
+ m := &sarama.ProducerMessage{
+ Topic: string(t.OutputSendToDeviceEventTopic),
+ Key: sarama.StringEncoder(ote.UserID),
+ Value: sarama.ByteEncoder(eventJSON),
+ }
+
+ _, _, err = t.Producer.SendMessage(m)
+ if err != nil {
+ logrus.WithError(err).Error("sendToDevice failed t.Producer.SendMessage")
+ return err
+ }
+ }
+
+ return nil
+}
+
// SetupHTTP adds the EDUServerInputAPI handlers to the http.ServeMux.
func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) {
internalAPIMux.Handle(api.EDUServerInputTypingEventPath,
@@ -105,4 +186,17 @@ func (t *EDUServerInputAPI) SetupHTTP(internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
+ internalAPIMux.Handle(api.EDUServerInputSendToDeviceEventPath,
+ internal.MakeInternalAPI("inputSendToDeviceEvents", func(req *http.Request) util.JSONResponse {
+ var request api.InputSendToDeviceEventRequest
+ var response api.InputSendToDeviceEventResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := t.InputSendToDeviceEvent(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
}
diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go
index b514af0a..74b4c014 100644
--- a/federationapi/routing/send.go
+++ b/federationapi/routing/send.go
@@ -265,6 +265,25 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) {
if err := t.eduProducer.SendTyping(t.context, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
util.GetLogger(t.context).WithError(err).Error("Failed to send typing event to edu server")
}
+ case gomatrixserverlib.MDirectToDevice:
+ // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema
+ var directPayload gomatrixserverlib.ToDeviceMessage
+ if err := json.Unmarshal(e.Content, &directPayload); err != nil {
+ util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal send-to-device events")
+ continue
+ }
+ for userID, byUser := range directPayload.Messages {
+ for deviceID, message := range byUser {
+ // TODO: check that the user and the device actually exist here
+ if err := t.eduProducer.SendToDevice(t.context, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil {
+ util.GetLogger(t.context).WithError(err).WithFields(logrus.Fields{
+ "sender": directPayload.Sender,
+ "user_id": userID,
+ "device_id": deviceID,
+ }).Error("Failed to send send-to-device event to edu server")
+ }
+ }
+ }
default:
util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu")
}
diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go
index cb8aec6f..3e28a347 100644
--- a/federationapi/routing/send_test.go
+++ b/federationapi/routing/send_test.go
@@ -77,6 +77,14 @@ func (p *testEDUProducer) InputTypingEvent(
return nil
}
+func (p *testEDUProducer) InputSendToDeviceEvent(
+ ctx context.Context,
+ request *eduAPI.InputSendToDeviceEventRequest,
+ response *eduAPI.InputSendToDeviceEventResponse,
+) error {
+ return nil
+}
+
type testRoomserverAPI struct {
inputRoomEvents []api.InputRoomEvent
queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse
diff --git a/go.mod b/go.mod
index 4365ea50..cc60e1a2 100644
--- a/go.mod
+++ b/go.mod
@@ -18,7 +18,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
- github.com/matrix-org/gomatrixserverlib v0.0.0-20200528122156-fbb320a2ee61
+ github.com/matrix-org/gomatrixserverlib v0.0.0-20200601162724-79e93fe989cf
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
github.com/mattn/go-sqlite3 v2.0.2+incompatible
diff --git a/go.sum b/go.sum
index c08cfa5d..6d9c2725 100644
--- a/go.sum
+++ b/go.sum
@@ -356,8 +356,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 h1:Yb+Wlf
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4=
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20200528122156-fbb320a2ee61 h1:3rgoGvj/skUWg+u9E6ycEFs2ZGenEjr28ZtAhAhmZeM=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20200528122156-fbb320a2ee61/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20200601162724-79e93fe989cf h1:iT2dfJ6JmYNRZBQTXeCNwsZIvfkBbFggzclM8iKnbR0=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20200601162724-79e93fe989cf/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
diff --git a/internal/config/config.go b/internal/config/config.go
index 2a95069a..a20cc0ea 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -152,6 +152,8 @@ type Dendrite struct {
OutputClientData Topic `yaml:"output_client_data"`
// Topic for eduserver/api.OutputTypingEvent events.
OutputTypingEvent Topic `yaml:"output_typing_event"`
+ // Topic for eduserver/api.OutputSendToDeviceEvent events.
+ OutputSendToDeviceEvent Topic `yaml:"output_send_to_device_event"`
// Topic for user updates (profile, presence)
UserUpdates Topic `yaml:"user_updates"`
}
diff --git a/internal/sql.go b/internal/sql.go
index d6a5a308..546954bd 100644
--- a/internal/sql.go
+++ b/internal/sql.go
@@ -1,4 +1,6 @@
// Copyright 2017 Vector Creations Ltd
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,9 +18,12 @@ package internal
import (
"database/sql"
+ "errors"
"fmt"
"runtime"
"time"
+
+ "go.uber.org/atomic"
)
// A Transaction is something that can be committed or rolledback.
@@ -107,3 +112,60 @@ type DbProperties interface {
MaxOpenConns() int
ConnMaxLifetime() time.Duration
}
+
+// TransactionWriter allows queuing database writes so that you don't
+// contend on database locks in, e.g. SQLite. Only one task will run
+// at a time on a given TransactionWriter.
+type TransactionWriter struct {
+ running atomic.Bool
+ todo chan transactionWriterTask
+}
+
+func NewTransactionWriter() *TransactionWriter {
+ return &TransactionWriter{
+ todo: make(chan transactionWriterTask),
+ }
+}
+
+// transactionWriterTask represents a specific task.
+type transactionWriterTask struct {
+ db *sql.DB
+ f func(txn *sql.Tx) error
+ wait chan error
+}
+
+// Do queues a task to be run by a TransactionWriter. The function
+// provided will be ran within a transaction as supplied by the
+// database parameter. This will block until the task is finished.
+func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error {
+ if w.todo == nil {
+ return errors.New("not initialised")
+ }
+ if !w.running.Load() {
+ go w.run()
+ }
+ task := transactionWriterTask{
+ db: db,
+ f: f,
+ wait: make(chan error, 1),
+ }
+ w.todo <- task
+ return <-task.wait
+}
+
+// run processes the tasks for a given transaction writer. Only one
+// of these goroutines will run at a time. A transaction will be
+// opened using the database object from the task and then this will
+// be passed as a parameter to the task function.
+func (w *TransactionWriter) run() {
+ if !w.running.CAS(false, true) {
+ return
+ }
+ defer w.running.Store(false)
+ for task := range w.todo {
+ task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
+ return task.f(txn)
+ })
+ close(task.wait)
+ }
+}
diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go
new file mode 100644
index 00000000..48701803
--- /dev/null
+++ b/syncapi/consumers/eduserver_sendtodevice.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package consumers
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/Shopify/sarama"
+ "github.com/matrix-org/dendrite/eduserver/api"
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/config"
+ "github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/sync"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ log "github.com/sirupsen/logrus"
+)
+
+// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server.
+type OutputSendToDeviceEventConsumer struct {
+ sendToDeviceConsumer *internal.ContinualConsumer
+ db storage.Database
+ serverName gomatrixserverlib.ServerName // our server name
+ notifier *sync.Notifier
+}
+
+// NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer.
+// Call Start() to begin consuming from the EDU server.
+func NewOutputSendToDeviceEventConsumer(
+ cfg *config.Dendrite,
+ kafkaConsumer sarama.Consumer,
+ n *sync.Notifier,
+ store storage.Database,
+) *OutputSendToDeviceEventConsumer {
+
+ consumer := internal.ContinualConsumer{
+ Topic: string(cfg.Kafka.Topics.OutputSendToDeviceEvent),
+ Consumer: kafkaConsumer,
+ PartitionStore: store,
+ }
+
+ s := &OutputSendToDeviceEventConsumer{
+ sendToDeviceConsumer: &consumer,
+ db: store,
+ serverName: cfg.Matrix.ServerName,
+ notifier: n,
+ }
+
+ consumer.ProcessMessage = s.onMessage
+
+ return s
+}
+
+// Start consuming from EDU api
+func (s *OutputSendToDeviceEventConsumer) Start() error {
+ return s.sendToDeviceConsumer.Start()
+}
+
+func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
+ var output api.OutputSendToDeviceEvent
+ if err := json.Unmarshal(msg.Value, &output); err != nil {
+ // If the message was invalid, log it and move on to the next message in the stream
+ log.WithError(err).Errorf("EDU server output log: message parse failure")
+ return err
+ }
+
+ _, domain, err := gomatrixserverlib.SplitID('@', output.UserID)
+ if err != nil {
+ return err
+ }
+ if domain != s.serverName {
+ return nil
+ }
+
+ util.GetLogger(context.TODO()).WithFields(log.Fields{
+ "sender": output.Sender,
+ "user_id": output.UserID,
+ "device_id": output.DeviceID,
+ "event_type": output.Type,
+ }).Info("sync API received send-to-device event from EDU server")
+
+ streamPos := s.db.AddSendToDevice()
+
+ _, err = s.db.StoreNewSendForDeviceMessage(
+ context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent,
+ )
+ if err != nil {
+ log.WithError(err).Errorf("failed to store send-to-device message")
+ return err
+ }
+
+ s.notifier.OnNewSendToDevice(
+ output.UserID,
+ []string{output.DeviceID},
+ types.NewStreamToken(0, streamPos),
+ )
+
+ return nil
+}
diff --git a/syncapi/consumers/eduserver.go b/syncapi/consumers/eduserver_typing.go
index 12b1efbc..12b1efbc 100644
--- a/syncapi/consumers/eduserver.go
+++ b/syncapi/consumers/eduserver_typing.go
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index 7e1a40fd..566e5d58 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -55,10 +55,12 @@ type Database interface {
// sync response for the given user. Events returned will include any client
// transaction IDs associated with the given device. These transaction IDs come
// from when the device sent the event via an API that included a transaction
- // ID.
- IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
- // CompleteSync returns a complete /sync API response for the given user.
- CompleteSync(ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error)
+ // ID. A response object must be provided for IncrementaSync to populate - it
+ // will not create one.
+ IncrementalSync(ctx context.Context, res *types.Response, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
+ // CompleteSync returns a complete /sync API response for the given user. A response object
+ // must be provided for CompleteSync to populate - it will not create one.
+ CompleteSync(ctx context.Context, res *types.Response, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error)
// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes
@@ -104,4 +106,26 @@ type Database interface {
StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
SyncStreamPosition(ctx context.Context) (types.StreamPosition, error)
+ // AddSendToDevice increases the EDU position in the cache and returns the stream position.
+ AddSendToDevice() types.StreamPosition
+ // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists:
+ // - "events": a list of send-to-device events that should be included in the sync
+ // - "changes": a list of send-to-device events that should be updated in the database by
+ // CleanSendToDeviceUpdates
+ // - "deletions": a list of send-to-device events which have been confirmed as sent and
+ // can be deleted altogether by CleanSendToDeviceUpdates
+ // The token supplied should be the current requested sync token, e.g. from the "since"
+ // parameter.
+ SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
+ // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
+ StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
+ // CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
+ // result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
+ // SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after
+ // starting to wait for an incremental sync with timeout).
+ // The token supplied should be the current requested sync token, e.g. from the "since"
+ // parameter.
+ CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
+ // SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
+ SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
}
diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go
new file mode 100644
index 00000000..335a05ef
--- /dev/null
+++ b/syncapi/storage/postgres/send_to_device_table.go
@@ -0,0 +1,171 @@
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/syncapi/storage/tables"
+ "github.com/matrix-org/dendrite/syncapi/types"
+)
+
+const sendToDeviceSchema = `
+CREATE SEQUENCE IF NOT EXISTS syncapi_send_to_device_id;
+
+-- Stores send-to-device messages.
+CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
+ -- The ID that uniquely identifies this message.
+ id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_send_to_device_id'),
+ -- The user ID to send the message to.
+ user_id TEXT NOT NULL,
+ -- The device ID to send the message to.
+ device_id TEXT NOT NULL,
+ -- The event content JSON.
+ content TEXT NOT NULL,
+ -- The token that was supplied to the /sync at the time that this
+ -- message was included in a sync response, or NULL if we haven't
+ -- included it in a /sync response yet.
+ sent_by_token TEXT
+);
+`
+
+const insertSendToDeviceMessageSQL = `
+ INSERT INTO syncapi_send_to_device (user_id, device_id, content)
+ VALUES ($1, $2, $3)
+`
+
+const countSendToDeviceMessagesSQL = `
+ SELECT COUNT(*)
+ FROM syncapi_send_to_device
+ WHERE user_id = $1 AND device_id = $2
+`
+
+const selectSendToDeviceMessagesSQL = `
+ SELECT id, user_id, device_id, content, sent_by_token
+ FROM syncapi_send_to_device
+ WHERE user_id = $1 AND device_id = $2
+ ORDER BY id DESC
+`
+
+const updateSentSendToDeviceMessagesSQL = `
+ UPDATE syncapi_send_to_device SET sent_by_token = $1
+ WHERE id = ANY($2)
+`
+
+const deleteSendToDeviceMessagesSQL = `
+ DELETE FROM syncapi_send_to_device WHERE id = ANY($1)
+`
+
+type sendToDeviceStatements struct {
+ insertSendToDeviceMessageStmt *sql.Stmt
+ countSendToDeviceMessagesStmt *sql.Stmt
+ selectSendToDeviceMessagesStmt *sql.Stmt
+ updateSentSendToDeviceMessagesStmt *sql.Stmt
+ deleteSendToDeviceMessagesStmt *sql.Stmt
+}
+
+func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
+ s := &sendToDeviceStatements{}
+ _, err := db.Exec(sendToDeviceSchema)
+ if err != nil {
+ return nil, err
+ }
+ if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
+ return nil, err
+ }
+ if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
+ return nil, err
+ }
+ if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
+ return nil, err
+ }
+ if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil {
+ return nil, err
+ }
+ if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
+ return nil, err
+ }
+ return s, nil
+}
+
+func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
+) (err error) {
+ _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
+ return
+}
+
+func (s *sendToDeviceStatements) CountSendToDeviceMessages(
+ ctx context.Context, txn *sql.Tx, userID, deviceID string,
+) (count int, err error) {
+ row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
+ if err = row.Scan(&count); err != nil {
+ return
+ }
+ return count, nil
+}
+
+func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
+ ctx context.Context, txn *sql.Tx, userID, deviceID string,
+) (events []types.SendToDeviceEvent, err error) {
+ rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
+
+ for rows.Next() {
+ var id types.SendToDeviceNID
+ var userID, deviceID, content string
+ var sentByToken *string
+ if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
+ return
+ }
+ event := types.SendToDeviceEvent{
+ ID: id,
+ UserID: userID,
+ DeviceID: deviceID,
+ }
+ if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
+ return
+ }
+ if sentByToken != nil {
+ if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
+ event.SentByToken = &token
+ }
+ }
+ events = append(events, event)
+ }
+
+ return events, rows.Err()
+}
+
+func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
+ ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
+) (err error) {
+ _, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
+ return
+}
+
+func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
+ ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
+) (err error) {
+ _, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
+ return
+}
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go
index dc73350a..8a8f964a 100644
--- a/syncapi/storage/postgres/syncserver.go
+++ b/syncapi/storage/postgres/syncserver.go
@@ -69,6 +69,10 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (*
if err != nil {
return nil, err
}
+ sendToDevice, err := NewPostgresSendToDeviceTable(d.db)
+ if err != nil {
+ return nil, err
+ }
d.Database = shared.Database{
DB: d.db,
Invites: invites,
@@ -77,6 +81,8 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (*
Topology: topology,
CurrentRoomState: currState,
BackwardExtremities: backwardExtremities,
+ SendToDevice: sendToDevice,
+ SendToDeviceWriter: internal.NewTransactionWriter(),
EDUCache: cache.New(),
}
return &d, nil
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index 888f85e0..497c043a 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -1,3 +1,17 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package shared
import (
@@ -27,6 +41,8 @@ type Database struct {
Topology tables.Topology
CurrentRoomState tables.CurrentRoomState
BackwardExtremities tables.BackwardsExtremities
+ SendToDevice tables.SendToDevice
+ SendToDeviceWriter *internal.TransactionWriter
EDUCache *cache.EDUCache
}
@@ -89,6 +105,10 @@ func (d *Database) RemoveTypingUser(
return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID))
}
+func (d *Database) AddSendToDevice() types.StreamPosition {
+ return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage())
+}
+
func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.EDUCache.SetTimeoutCallback(fn)
}
@@ -528,14 +548,14 @@ func (d *Database) addEDUDeltaToResponse(
}
func (d *Database) IncrementalSync(
- ctx context.Context,
+ ctx context.Context, res *types.Response,
device authtypes.Device,
fromPos, toPos types.StreamingToken,
numRecentEventsPerRoom int,
wantFullState bool,
) (*types.Response, error) {
nextBatchPos := fromPos.WithUpdates(toPos)
- res := types.NewResponse(nextBatchPos)
+ res.NextBatch = nextBatchPos.String()
var joinedRoomIDs []string
var err error
@@ -568,12 +588,12 @@ func (d *Database) IncrementalSync(
// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed
// to it. It returns toPos and joinedRoomIDs for use of adding EDUs.
+// nolint:nakedret
func (d *Database) getResponseWithPDUsForCompleteSync(
- ctx context.Context,
+ ctx context.Context, res *types.Response,
userID string,
numRecentEventsPerRoom int,
) (
- res *types.Response,
toPos types.StreamingToken,
joinedRoomIDs []string,
err error,
@@ -604,7 +624,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
To: toPos.PDUPosition(),
}
- res = types.NewResponse(toPos)
+ res.NextBatch = toPos.String()
// Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@@ -662,14 +682,15 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
}
succeeded = true
- return res, toPos, joinedRoomIDs, err
+ return //res, toPos, joinedRoomIDs, err
}
func (d *Database) CompleteSync(
- ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int,
+ ctx context.Context, res *types.Response,
+ device authtypes.Device, numRecentEventsPerRoom int,
) (*types.Response, error) {
- res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
- ctx, device.UserID, numRecentEventsPerRoom,
+ toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
+ ctx, res, device.UserID, numRecentEventsPerRoom,
)
if err != nil {
return nil, err
@@ -1028,6 +1049,115 @@ func (d *Database) currentStateStreamEventsForRoom(
return s, nil
}
+func (d *Database) SendToDeviceUpdatesWaiting(
+ ctx context.Context, userID, deviceID string,
+) (bool, error) {
+ count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID)
+ if err != nil {
+ return false, err
+ }
+ return count > 0, nil
+}
+
+func (d *Database) AddSendToDeviceEvent(
+ ctx context.Context, txn *sql.Tx,
+ userID, deviceID, content string,
+) error {
+ return d.SendToDevice.InsertSendToDeviceMessage(
+ ctx, txn, userID, deviceID, content,
+ )
+}
+
+func (d *Database) StoreNewSendForDeviceMessage(
+ ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
+) (types.StreamPosition, error) {
+ j, err := json.Marshal(event)
+ if err != nil {
+ return streamPos, err
+ }
+ // Delegate the database write task to the SendToDeviceWriter. It'll guarantee
+ // that we don't lock the table for writes in more than one place.
+ err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
+ return d.AddSendToDeviceEvent(
+ ctx, txn, userID, deviceID, string(j),
+ )
+ })
+ if err != nil {
+ return streamPos, err
+ }
+ return streamPos, nil
+}
+
+func (d *Database) SendToDeviceUpdatesForSync(
+ ctx context.Context,
+ userID, deviceID string,
+ token types.StreamingToken,
+) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
+ // First of all, get our send-to-device updates for this user.
+ events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
+ }
+
+ // If there's nothing to do then stop here.
+ if len(events) == 0 {
+ return nil, nil, nil, nil
+ }
+
+ // Work out whether we need to update any of the database entries.
+ toReturn := []types.SendToDeviceEvent{}
+ toUpdate := []types.SendToDeviceNID{}
+ toDelete := []types.SendToDeviceNID{}
+ for _, event := range events {
+ if event.SentByToken == nil {
+ // If the event has no sent-by token yet then we haven't attempted to send
+ // it. Record the current requested sync token in the database.
+ toUpdate = append(toUpdate, event.ID)
+ toReturn = append(toReturn, event)
+ event.SentByToken = &token
+ } else if token.IsAfter(*event.SentByToken) {
+ // The event had a sync token, therefore we've sent it before. The current
+ // sync token is now after the stored one so we can assume that the client
+ // successfully completed the previous sync (it would re-request it otherwise)
+ // so we can remove the entry from the database.
+ toDelete = append(toDelete, event.ID)
+ } else {
+ // It looks like the sync is being re-requested, maybe it timed out or
+ // failed. Re-send any that should have been acknowledged by now.
+ toReturn = append(toReturn, event)
+ }
+ }
+
+ return toReturn, toUpdate, toDelete, nil
+}
+
+func (d *Database) CleanSendToDeviceUpdates(
+ ctx context.Context,
+ toUpdate, toDelete []types.SendToDeviceNID,
+ token types.StreamingToken,
+) (err error) {
+ if len(toUpdate) == 0 && len(toDelete) == 0 {
+ return nil
+ }
+ // If we need to write to the database then we'll ask the SendToDeviceWriter to
+ // do that for us. It'll guarantee that we don't lock the table for writes in
+ // more than one place.
+ err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
+ // Delete any send-to-device messages marked for deletion.
+ if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
+ return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
+ }
+
+ // Now update any outstanding send-to-device messages with the new sync token.
+ if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
+ return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
+ }
+
+ return nil
+ })
+ return
+}
+
// There may be some overlap where events in stateEvents are already in recentEvents, so filter
// them out so we don't include them twice in the /sync response. They should be in recentEvents
// only, so clients get to the correct state once they have rolled forward.
diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go
new file mode 100644
index 00000000..0d03f23e
--- /dev/null
+++ b/syncapi/storage/sqlite3/send_to_device_table.go
@@ -0,0 +1,172 @@
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "strings"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/syncapi/storage/tables"
+ "github.com/matrix-org/dendrite/syncapi/types"
+)
+
+const sendToDeviceSchema = `
+-- Stores send-to-device messages.
+CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
+ -- The ID that uniquely identifies this message.
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The user ID to send the message to.
+ user_id TEXT NOT NULL,
+ -- The device ID to send the message to.
+ device_id TEXT NOT NULL,
+ -- The event content JSON.
+ content TEXT NOT NULL,
+ -- The token that was supplied to the /sync at the time that this
+ -- message was included in a sync response, or NULL if we haven't
+ -- included it in a /sync response yet.
+ sent_by_token TEXT
+);
+`
+
+const insertSendToDeviceMessageSQL = `
+ INSERT INTO syncapi_send_to_device (user_id, device_id, content)
+ VALUES ($1, $2, $3)
+`
+
+const countSendToDeviceMessagesSQL = `
+ SELECT COUNT(*)
+ FROM syncapi_send_to_device
+ WHERE user_id = $1 AND device_id = $2
+`
+
+const selectSendToDeviceMessagesSQL = `
+ SELECT id, user_id, device_id, content, sent_by_token
+ FROM syncapi_send_to_device
+ WHERE user_id = $1 AND device_id = $2
+ ORDER BY id DESC
+`
+
+const updateSentSendToDeviceMessagesSQL = `
+ UPDATE syncapi_send_to_device SET sent_by_token = $1
+ WHERE id IN ($2)
+`
+
+const deleteSendToDeviceMessagesSQL = `
+ DELETE FROM syncapi_send_to_device WHERE id IN ($1)
+`
+
+type sendToDeviceStatements struct {
+ insertSendToDeviceMessageStmt *sql.Stmt
+ selectSendToDeviceMessagesStmt *sql.Stmt
+ countSendToDeviceMessagesStmt *sql.Stmt
+}
+
+func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
+ s := &sendToDeviceStatements{}
+ _, err := db.Exec(sendToDeviceSchema)
+ if err != nil {
+ return nil, err
+ }
+ if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
+ return nil, err
+ }
+ if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
+ return nil, err
+ }
+ if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
+ return nil, err
+ }
+ return s, nil
+}
+
+func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
+) (err error) {
+ _, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
+ return
+}
+
+func (s *sendToDeviceStatements) CountSendToDeviceMessages(
+ ctx context.Context, txn *sql.Tx, userID, deviceID string,
+) (count int, err error) {
+ row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
+ if err = row.Scan(&count); err != nil {
+ return
+ }
+ return count, nil
+}
+
+func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
+ ctx context.Context, txn *sql.Tx, userID, deviceID string,
+) (events []types.SendToDeviceEvent, err error) {
+ rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
+
+ for rows.Next() {
+ var id types.SendToDeviceNID
+ var userID, deviceID, content string
+ var sentByToken *string
+ if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
+ return
+ }
+ event := types.SendToDeviceEvent{
+ ID: id,
+ UserID: userID,
+ DeviceID: deviceID,
+ }
+ if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
+ return
+ }
+ if sentByToken != nil {
+ if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
+ event.SentByToken = &token
+ }
+ }
+ events = append(events, event)
+ }
+
+ return events, rows.Err()
+}
+
+func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
+ ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
+) (err error) {
+ query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", internal.QueryVariadic(1+len(nids)), 1)
+ params := make([]interface{}, 1+len(nids))
+ params[0] = token
+ for k, v := range nids {
+ params[k+1] = v
+ }
+ _, err = txn.ExecContext(ctx, query, params...)
+ return
+}
+
+func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
+ ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
+) (err error) {
+ query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", internal.QueryVariadic(len(nids)), 1)
+ params := make([]interface{}, 1+len(nids))
+ for k, v := range nids {
+ params[k] = v
+ }
+ _, err = txn.ExecContext(ctx, query, params...)
+ return
+}
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
index 8ab1d404..5ba07617 100644
--- a/syncapi/storage/sqlite3/syncserver.go
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -95,6 +95,10 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil {
return err
}
+ sendToDevice, err := NewSqliteSendToDeviceTable(d.db)
+ if err != nil {
+ return err
+ }
d.Database = shared.Database{
DB: d.db,
Invites: invites,
@@ -103,6 +107,8 @@ func (d *SyncServerDatasource) prepare() (err error) {
BackwardExtremities: bwExtrem,
CurrentRoomState: roomState,
Topology: topology,
+ SendToDevice: sendToDevice,
+ SendToDeviceWriter: internal.NewTransactionWriter(),
EDUCache: cache.New(),
}
return nil
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index bb8554f4..4661ede4 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -3,6 +3,7 @@ package storage_test
import (
"context"
"crypto/ed25519"
+ "encoding/json"
"fmt"
"testing"
"time"
@@ -157,7 +158,8 @@ func TestSyncResponse(t *testing.T) {
from := types.NewStreamToken( // pretend we are at the penultimate event
positions[len(positions)-2], types.StreamPosition(0),
)
- return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false)
+ res := types.NewResponse()
+ return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
},
WantTimeline: events[len(events)-1:],
},
@@ -169,8 +171,9 @@ func TestSyncResponse(t *testing.T) {
from := types.NewStreamToken( // pretend we are 10 events behind
positions[len(positions)-11], types.StreamPosition(0),
)
+ res := types.NewResponse()
// limit is set to 5
- return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false)
+ return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
},
// want the last 5 events, NOT the last 10.
WantTimeline: events[len(events)-5:],
@@ -180,8 +183,9 @@ func TestSyncResponse(t *testing.T) {
{
Name: "CompleteSync limited",
DoSync: func() (*types.Response, error) {
+ res := types.NewResponse()
// limit set to 5
- return db.CompleteSync(ctx, testUserDeviceA, 5)
+ return db.CompleteSync(ctx, res, testUserDeviceA, 5)
},
// want the last 5 events
WantTimeline: events[len(events)-5:],
@@ -193,7 +197,8 @@ func TestSyncResponse(t *testing.T) {
{
Name: "CompleteSync",
DoSync: func() (*types.Response, error) {
- return db.CompleteSync(ctx, testUserDeviceA, len(events)+1)
+ res := types.NewResponse()
+ return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
},
WantTimeline: events,
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
@@ -234,7 +239,8 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
positions[len(positions)-2], types.StreamPosition(0),
)
- res, err := db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false)
+ res := types.NewResponse()
+ res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
if err != nil {
t.Fatalf("failed to IncrementalSync with latest token")
}
@@ -512,6 +518,89 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
}
}
+func TestSendToDeviceBehaviour(t *testing.T) {
+ //t.Parallel()
+ db := MustCreateDatabase(t)
+
+ // At this point there should be no messages. We haven't sent anything
+ // yet.
+ events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
+ t.Fatal("first call should have no updates")
+ }
+ err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0))
+ if err != nil {
+ return
+ }
+
+ // Try sending a message.
+ streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{
+ Sender: "bob",
+ Type: "m.type",
+ Content: json.RawMessage("{}"),
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // At this point we should get exactly one message. We're sending the sync position
+ // that we were given from the update and the send-to-device update will be updated
+ // in the database to reflect that this was the sync position we sent the message at.
+ events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
+ t.Fatal("second call should have one update")
+ }
+ err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
+ if err != nil {
+ return
+ }
+
+ // At this point we should still have one message because we haven't progressed the
+ // sync position yet. This is equivalent to the client failing to /sync and retrying
+ // with the same position.
+ events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
+ t.Fatal("third call should have one update still")
+ }
+ err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
+ if err != nil {
+ return
+ }
+
+ // At this point we should now have no updates, because we've progressed the sync
+ // position. Therefore the update from before will not be sent again.
+ events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
+ t.Fatal("fourth call should have no updates")
+ }
+ err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1))
+ if err != nil {
+ return
+ }
+
+ // At this point we should still have no updates, because no new updates have been
+ // sent.
+ events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
+ t.Fatal("fifth call should have no updates")
+ }
+}
+
func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) {
if len(gots) != len(wants) {
t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants))
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index bc3b6941..0b7d1595 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -1,3 +1,17 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package tables
import (
@@ -94,3 +108,28 @@ type BackwardsExtremities interface {
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
}
+
+// SendToDevice tracks send-to-device messages which are sent to individual
+// clients. Each message gets inserted into this table at the point that we
+// receive it from the EDU server.
+//
+// We're supposed to try and do our best to deliver send-to-device messages
+// once, but the only way that we can really guarantee that they have been
+// delivered is if the client successfully requests the next sync as given
+// in the next_batch. Each time the device syncs, we will request all of the
+// updates that either haven't been sent yet, along with all updates that we
+// *have* sent but we haven't confirmed to have been received yet. If it's the
+// first time we're sending a given update then we update the table to say
+// what the "since" parameter was when we tried to send it.
+//
+// When the client syncs again, if their "since" parameter is *later* than
+// the recorded one, we drop the entry from the DB as it's "sent". If the
+// sync parameter isn't later then we will keep including the updates in the
+// sync response, as the client is seemingly trying to repeat the same /sync.
+type SendToDevice interface {
+ InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error)
+ SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error)
+ UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
+ DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
+ CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
+}
diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go
index 9b410a0c..325e7535 100644
--- a/syncapi/sync/notifier.go
+++ b/syncapi/sync/notifier.go
@@ -120,6 +120,18 @@ func (n *Notifier) OnNewEvent(
}
}
+func (n *Notifier) OnNewSendToDevice(
+ userID string, deviceIDs []string,
+ posUpdate types.StreamingToken,
+) {
+ n.streamLock.Lock()
+ defer n.streamLock.Unlock()
+ latestPos := n.currPos.WithUpdates(posUpdate)
+ n.currPos = latestPos
+
+ n.wakeupUserDevice(userID, deviceIDs, latestPos)
+}
+
// GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed.
// notify for anything before sincePos
@@ -189,8 +201,8 @@ func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) {
// wakeupUserDevice will wake up the sync stream for a specific user device. Other
// device streams will be left alone.
// nolint:unused
-func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) {
- for userID, deviceID := range userDevices {
+func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) {
+ for _, deviceID := range deviceIDs {
if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil {
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
}
diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go
index 14ddef20..13231557 100644
--- a/syncapi/sync/notifier_test.go
+++ b/syncapi/sync/notifier_test.go
@@ -172,7 +172,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
time.Sleep(1 * time.Second)
wake := "two"
- n.wakeupUserDevice(map[string]string{alice: wake}, syncPositionAfter)
+ n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter)
if result := <-awoken; result != wake {
t.Fatalf("expected to wake %q, got %q", wake, result)
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index bd29b333..8b93cad4 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -1,4 +1,6 @@
// Copyright 2017 Vector Creations Ltd
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,6 +17,7 @@
package sync
import (
+ "context"
"net/http"
"time"
@@ -54,17 +57,18 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
JSON: jsonerror.Unknown(err.Error()),
}
}
+
logger := util.GetLogger(req.Context()).WithFields(log.Fields{
- "userID": device.UserID,
- "deviceID": device.ID,
- "since": syncReq.since,
- "timeout": syncReq.timeout,
- "limit": syncReq.limit,
+ "user_id": device.UserID,
+ "device_id": device.ID,
+ "since": syncReq.since,
+ "timeout": syncReq.timeout,
+ "limit": syncReq.limit,
})
currPos := rp.notifier.CurrentPosition()
- if shouldReturnImmediately(syncReq) {
+ if rp.shouldReturnImmediately(syncReq) {
syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil {
logger.WithError(err).Error("rp.currentSyncForUser failed")
@@ -116,7 +120,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// response. This ensures that we don't waste the hard work
// of calculating the sync only to get timed out before we
// can respond
-
syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil {
logger.WithError(err).Error("rp.currentSyncForUser failed")
@@ -134,19 +137,59 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
}
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
+ res = types.NewResponse()
+
+ since := types.NewStreamToken(0, 0)
+ if req.since != nil {
+ since = *req.since
+ }
+
+ // See if we have any new tasks to do for the send-to-device messaging.
+ events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, since)
+ if err != nil {
+ return nil, err
+ }
+
// TODO: handle ignored users
if req.since == nil {
- res, err = rp.db.CompleteSync(req.ctx, req.device, req.limit)
+ res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
} else {
- res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState)
+ res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
}
-
if err != nil {
return
}
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
+ if err != nil {
+ return
+ }
+
+ // Before we return the sync response, make sure that we take action on
+ // any send-to-device database updates or deletions that we need to do.
+ // Then add the updates into the sync response.
+ if len(updates) > 0 || len(deletions) > 0 {
+ // Handle the updates and deletions in the database.
+ err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since)
+ if err != nil {
+ return
+ }
+ }
+ if len(events) > 0 {
+ // Add the updates into the sync response.
+ for _, event := range events {
+ res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
+ }
+
+ // Get the next_batch from the sync response and increase the
+ // EDU counter.
+ if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil {
+ pos.Positions[1]++
+ res.NextBatch = pos.String()
+ }
+ }
+
return
}
@@ -238,6 +281,10 @@ func (rp *RequestPool) appendAccountData(
// shouldReturnImmediately returns whether the /sync request is an initial sync,
// or timeout=0, or full_state=true, in any of the cases the request should
// return immediately.
-func shouldReturnImmediately(syncReq *syncRequest) bool {
- return syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState
+func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool {
+ if syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState {
+ return true
+ }
+ waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID)
+ return werr == nil && waiting
}
diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go
index 9251f618..762f4e9d 100644
--- a/syncapi/syncapi.go
+++ b/syncapi/syncapi.go
@@ -78,7 +78,14 @@ func SetupSyncAPIComponent(
base.Cfg, base.KafkaConsumer, notifier, syncDB,
)
if err = typingConsumer.Start(); err != nil {
- logrus.WithError(err).Panicf("failed to start typing server consumer")
+ logrus.WithError(err).Panicf("failed to start typing consumer")
+ }
+
+ sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer(
+ base.Cfg, base.KafkaConsumer, notifier, syncDB,
+ )
+ if err = sendToDeviceConsumer.Start(); err != nil {
+ logrus.WithError(err).Panicf("failed to start send-to-device consumer")
}
routing.Setup(base.PublicAPIMux, requestPool, syncDB, deviceDB, federation, rsAPI, cfg)
diff --git a/syncapi/types/types.go b/syncapi/types/types.go
index caa1b3ad..c1f09fba 100644
--- a/syncapi/types/types.go
+++ b/syncapi/types/types.go
@@ -296,13 +296,14 @@ type Response struct {
Invite map[string]InviteResponse `json:"invite"`
Leave map[string]LeaveResponse `json:"leave"`
} `json:"rooms"`
+ ToDevice struct {
+ Events []gomatrixserverlib.SendToDeviceEvent `json:"events"`
+ } `json:"to_device"`
}
// NewResponse creates an empty response with initialised maps.
-func NewResponse(token StreamingToken) *Response {
- res := Response{
- NextBatch: token.String(),
- }
+func NewResponse() *Response {
+ res := Response{}
// Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section,
// so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors.
res.Rooms.Join = make(map[string]JoinResponse)
@@ -315,6 +316,7 @@ func NewResponse(token StreamingToken) *Response {
// This also applies to NewJoinResponse, NewInviteResponse and NewLeaveResponse.
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
+ res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0)
return &res
}
@@ -326,7 +328,8 @@ func (r *Response) IsEmpty() bool {
len(r.Rooms.Invite) == 0 &&
len(r.Rooms.Leave) == 0 &&
len(r.AccountData.Events) == 0 &&
- len(r.Presence.Events) == 0
+ len(r.Presence.Events) == 0 &&
+ len(r.ToDevice.Events) == 0
}
// JoinResponse represents a /sync response for a room which is under the 'join' key.
@@ -393,3 +396,13 @@ func NewLeaveResponse() *LeaveResponse {
res.Timeline.Events = make([]gomatrixserverlib.ClientEvent, 0)
return &res
}
+
+type SendToDeviceNID int
+
+type SendToDeviceEvent struct {
+ gomatrixserverlib.SendToDeviceEvent
+ ID SendToDeviceNID
+ UserID string
+ DeviceID string
+ SentByToken *StreamingToken
+}
diff --git a/sytest-blacklist b/sytest-blacklist
index caad2545..1efc207f 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -39,3 +39,8 @@ Ignore invite in incremental sync
# Blacklisted because this test calls /r0/events which we don't implement
New room members see their own join event
Existing members see new members' join events
+
+# Blacklisted because the federation work for these hasn't been finished yet.
+Can recv device messages over federation
+Device messages over federation wake up /sync
+Wildcard device messages over federation wake up /sync
diff --git a/sytest-whitelist b/sytest-whitelist
index d4e6be9a..6236b28e 100644
--- a/sytest-whitelist
+++ b/sytest-whitelist
@@ -289,3 +289,15 @@ Existing members see new members' join events
Inbound federation can receive events
Inbound federation can receive redacted events
Can logout current device
+Can send a message directly to a device using PUT /sendToDevice
+Can recv a device message using /sync
+Can recv device messages until they are acknowledged
+Device messages with the same txn_id are deduplicated
+Device messages wake up /sync
+# TODO: separate PR for: Can recv device messages over federation
+# TODO: separate PR for: Device messages over federation wake up /sync
+Can send messages with a wildcard device id
+Can send messages with a wildcard device id to two devices
+Wildcard device messages wake up /sync
+# TODO: separate PR for: Wildcard device messages over federation wake up /sync
+Can send a to-device message to two users which both receive it using /sync