aboutsummaryrefslogtreecommitdiff
path: root/syncapi/streams/stream_accountdata.go
blob: aa7f0937d11d111e5cd96bfa41f2d590063f5fd6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package streams

import (
	"context"

	"github.com/matrix-org/dendrite/syncapi/types"
	userapi "github.com/matrix-org/dendrite/userapi/api"
	"github.com/matrix-org/gomatrixserverlib"
)

// AccountDataStreamProvider is the stream provider responsible for account
// data in sync responses. It embeds StreamProvider and additionally holds a
// user API client, which its methods call (via QueryAccountData) to fetch the
// account data content for a user.
type AccountDataStreamProvider struct {
	StreamProvider
	// userAPI is queried for global and per-room account data.
	userAPI userapi.UserInternalAPI
}

// Setup runs the embedded StreamProvider's Setup and then, while holding
// latestMutex, seeds the provider's latest position with the maximum account
// data stream position stored in the database.
//
// It panics if that position cannot be read from the database.
func (p *AccountDataStreamProvider) Setup() {
	p.StreamProvider.Setup()

	p.latestMutex.Lock()
	defer p.latestMutex.Unlock()

	maxPos, err := p.DB.MaxStreamPositionForAccountData(context.Background())
	if err != nil {
		panic(err)
	}
	p.latest = maxPos
}

// CompleteSync adds all of the requesting user's account data to the sync
// response in req.
//
// It queries the user API once for the user identified by req.Device.UserID.
// Every global account data entry is appended to the response's top-level
// account data events, and, for each room already present in the response's
// joined rooms, every account data entry for that room is appended to that
// room's account data events.
//
// If the user API query fails the error is logged and the response is left
// unmodified. In all cases the provider's latest position is returned.
func (p *AccountDataStreamProvider) CompleteSync(
	ctx context.Context,
	req *types.SyncRequest,
) types.StreamPosition {
	query := &userapi.QueryAccountDataRequest{UserID: req.Device.UserID}
	result := &userapi.QueryAccountDataResponse{}

	err := p.userAPI.QueryAccountData(ctx, query, result)
	if err != nil {
		req.Log.WithError(err).Error("p.userAPI.QueryAccountData failed")
		return p.LatestPosition(ctx)
	}

	// Global (non-room) account data goes at the top level of the response.
	for eventType, content := range result.GlobalAccountData {
		event := gomatrixserverlib.ClientEvent{
			Type:    eventType,
			Content: gomatrixserverlib.RawJSON(content),
		}
		req.Response.AccountData.Events = append(req.Response.AccountData.Events, event)
	}

	// Per-room account data is only attached to rooms that are already part
	// of the response's joined rooms.
	for roomID, joined := range req.Response.Rooms.Join {
		for eventType, content := range result.RoomAccountData[roomID] {
			event := gomatrixserverlib.ClientEvent{
				Type:    eventType,
				Content: gomatrixserverlib.RawJSON(content),
			}
			joined.AccountData.Events = append(joined.AccountData.Events, event)
			// Store the updated entry back into the map after each append.
			req.Response.Rooms.Join[roomID] = joined
		}
	}

	return p.LatestPosition(ctx)
}

// IncrementalSync adds to the sync response in req the account data that
// changed for the requesting user between the stream positions from and to.
//
// The database is asked which account data types changed in that range,
// grouped by room ID (the empty room ID denotes global account data). For
// every (room, type) pair the current content is then fetched from the user
// API and appended either to the response's top-level account data events
// (global) or to the account data events of the matching joined room.
//
// If the database reports no changes, the global "m.push_rules" type is
// requested anyway (see the TODO below).
//
// Returns from if the database lookup fails, so that the range is not treated
// as consumed, and to otherwise. Failures of individual user API queries are
// logged and that entry is skipped.
func (p *AccountDataStreamProvider) IncrementalSync(
	ctx context.Context,
	req *types.SyncRequest,
	from, to types.StreamPosition,
) types.StreamPosition {
	r := types.Range{
		From: from,
		To:   to,
	}
	accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead

	dataTypes, err := p.DB.GetAccountDataInRange(
		ctx, req.Device.UserID, r, &accountDataFilter,
	)
	if err != nil {
		req.Log.WithError(err).Error("p.DB.GetAccountDataInRange failed")
		return from
	}

	if len(dataTypes) == 0 {
		// TODO: this fixes the sytest but is it the right thing to do?
		//
		// Assign a fresh map instead of writing into the returned one:
		// len() is also 0 for a nil map, and writing to a nil map panics.
		dataTypes = map[string][]string{"": {"m.push_rules"}}
	}

	// Iterate over the rooms
	for roomID, roomDataTypes := range dataTypes {
		// Request the missing data from the database
		for _, dataType := range roomDataTypes {
			dataReq := userapi.QueryAccountDataRequest{
				UserID:   req.Device.UserID,
				RoomID:   roomID,
				DataType: dataType,
			}
			dataRes := userapi.QueryAccountDataResponse{}
			if err := p.userAPI.QueryAccountData(ctx, &dataReq, &dataRes); err != nil {
				req.Log.WithError(err).Error("p.userAPI.QueryAccountData failed")
				continue
			}

			event := gomatrixserverlib.ClientEvent{Type: dataType}

			// An empty room ID means this is global account data.
			if roomID == "" {
				globalData, ok := dataRes.GlobalAccountData[dataType]
				if !ok {
					continue
				}
				event.Content = gomatrixserverlib.RawJSON(globalData)
				req.Response.AccountData.Events = append(req.Response.AccountData.Events, event)
				continue
			}

			roomData, ok := dataRes.RoomAccountData[roomID][dataType]
			if !ok {
				continue
			}
			event.Content = gomatrixserverlib.RawJSON(roomData)
			// Fetch the joined-room entry, append to it and store it back
			// into the map.
			joinData := req.Response.Rooms.Join[roomID]
			joinData.AccountData.Events = append(joinData.AccountData.Events, event)
			req.Response.Rooms.Join[roomID] = joinData
		}
	}

	return to
}