aboutsummaryrefslogtreecommitdiff
path: root/cmd/resolve-state/main.go
blob: 3a9f3ce4f2212c98464832b2221c929012e530f1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package main

import (
	"context"
	"flag"
	"fmt"
	"strconv"

	"github.com/matrix-org/dendrite/internal/caching"
	"github.com/matrix-org/dendrite/roomserver/storage"
	"github.com/matrix-org/dendrite/roomserver/types"
	"github.com/matrix-org/dendrite/setup"
	"github.com/matrix-org/dendrite/setup/base"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/gomatrixserverlib"
)

// This is a utility for inspecting state snapshots and running state resolution
// against real snapshots in an actual database.
//
// It takes one or more state snapshot NIDs as positional arguments, loads the
// union of their state events, and runs state resolution over them using the
// algorithm of the room version given by --roomversion. The resolved state is
// printed, optionally restricted to a single event type with --filtertype.
//
// All flags must come before the snapshot NIDs: flag parsing stops at the
// first positional argument, and anything after it is treated as a NID.
// The usual Dendrite flags handled by setup.ParseFlags (presumably including
// the config file location — confirm) are accepted as well.
//
// Usage: ./resolve-state [--roomversion=version] [--filtertype=type] snapshot [snapshot ...]
//   e.g. ./resolve-state --roomversion=5 1254 1235 1282
//        ./resolve-state --roomversion=5 --filtertype=m.room.member 1254 1235

var (
	// roomVersion is the room version handed to state resolution.
	roomVersion = flag.String("roomversion", "5", "the room version to parse events as")
	// filterType, when non-empty, limits the printed output to resolved
	// events whose type is exactly this value. It does not affect resolution.
	filterType = flag.String("filtertype", "", "the event type to filter on")
)

func main() {
	ctx := context.Background()
	cfg := setup.ParseFlags(true)
	cfg.Logging = append(cfg.Logging[:0], config.LogrusHook{
		Type:  "std",
		Level: "error",
	})
	base := base.NewBaseDendrite(cfg, "ResolveState", base.DisableMetrics)
	args := flag.Args()

	fmt.Println("Room version", *roomVersion)

	snapshotNIDs := []types.StateSnapshotNID{}
	for _, arg := range args {
		if i, err := strconv.Atoi(arg); err == nil {
			snapshotNIDs = append(snapshotNIDs, types.StateSnapshotNID(i))
		}
	}

	fmt.Println("Fetching", len(snapshotNIDs), "snapshot NIDs")

	cache, err := caching.NewInMemoryLRUCache(true)
	if err != nil {
		panic(err)
	}

	roomserverDB, err := storage.Open(base, &cfg.RoomServer.Database, cache)
	if err != nil {
		panic(err)
	}

	blockNIDs, err := roomserverDB.StateBlockNIDs(ctx, snapshotNIDs)
	if err != nil {
		panic(err)
	}

	var stateEntries []types.StateEntryList
	for _, list := range blockNIDs {
		entries, err2 := roomserverDB.StateEntries(ctx, list.StateBlockNIDs)
		if err2 != nil {
			panic(err2)
		}
		stateEntries = append(stateEntries, entries...)
	}

	var eventNIDs []types.EventNID
	for _, entry := range stateEntries {
		for _, e := range entry.StateEntries {
			eventNIDs = append(eventNIDs, e.EventNID)
		}
	}

	fmt.Println("Fetching", len(eventNIDs), "state events")
	eventEntries, err := roomserverDB.Events(ctx, eventNIDs)
	if err != nil {
		panic(err)
	}

	authEventIDMap := make(map[string]struct{})
	events := make([]*gomatrixserverlib.Event, len(eventEntries))
	for i := range eventEntries {
		events[i] = eventEntries[i].Event
		for _, authEventID := range eventEntries[i].AuthEventIDs() {
			authEventIDMap[authEventID] = struct{}{}
		}
	}

	authEventIDs := make([]string, 0, len(authEventIDMap))
	for authEventID := range authEventIDMap {
		authEventIDs = append(authEventIDs, authEventID)
	}

	fmt.Println("Fetching", len(authEventIDs), "auth events")
	authEventEntries, err := roomserverDB.EventsFromIDs(ctx, authEventIDs)
	if err != nil {
		panic(err)
	}

	authEvents := make([]*gomatrixserverlib.Event, len(authEventEntries))
	for i := range authEventEntries {
		authEvents[i] = authEventEntries[i].Event
	}

	fmt.Println("Resolving state")
	resolved, err := gomatrixserverlib.ResolveConflicts(
		gomatrixserverlib.RoomVersion(*roomVersion),
		events,
		authEvents,
	)
	if err != nil {
		panic(err)
	}

	fmt.Println("Resolved state contains", len(resolved), "events")
	filteringEventType := *filterType
	count := 0
	for _, event := range resolved {
		if filteringEventType != "" && event.Type() != filteringEventType {
			continue
		}
		count++
		fmt.Println()
		fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey())
		fmt.Printf("  %s\n", string(event.Content()))
	}

	fmt.Println()
	fmt.Println("Returned", count, "state events after filtering")
}