aboutsummaryrefslogtreecommitdiff
path: root/roomserver
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-09-03 17:20:54 +0100
committerGitHub <noreply@github.com>2020-09-03 17:20:54 +0100
commitb20386123e0cbdc53016231f0087d0047b5667e9 (patch)
treef037957006b0295709be9890c22fdb4563a1d2be /roomserver
parent6150de6cb3611ffc61ce10ed6714f65e51e38e78 (diff)
Move currentstateserver API to roomserver (#1387)
* Move currentstateserver API to roomserver Stub out DB functions for now, nothing uses the roomserver version yet. * Allow it to startup * Implement some current-state-server storage interface functions * Add missing package
Diffstat (limited to 'roomserver')
-rw-r--r--roomserver/acls/acls.go164
-rw-r--r--roomserver/acls/acls_test.go105
-rw-r--r--roomserver/api/api.go14
-rw-r--r--roomserver/api/api_trace.go41
-rw-r--r--roomserver/api/query.go104
-rw-r--r--roomserver/api/wrapper.go99
-rw-r--r--roomserver/internal/api.go6
-rw-r--r--roomserver/internal/query/query.go102
-rw-r--r--roomserver/inthttp/client.go72
-rw-r--r--roomserver/inthttp/server.go78
-rw-r--r--roomserver/storage/interface.go19
-rw-r--r--roomserver/storage/postgres/membership_table.go24
-rw-r--r--roomserver/storage/postgres/rooms_table.go48
-rw-r--r--roomserver/storage/shared/storage.go80
-rw-r--r--roomserver/storage/sqlite3/membership_table.go24
-rw-r--r--roomserver/storage/sqlite3/rooms_table.go49
-rw-r--r--roomserver/storage/tables/interface.go3
17 files changed, 1028 insertions, 4 deletions
diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go
new file mode 100644
index 00000000..775b6c73
--- /dev/null
+++ b/roomserver/acls/acls.go
@@ -0,0 +1,164 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package acls
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net"
+ "regexp"
+ "strings"
+ "sync"
+
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+type ServerACLDatabase interface {
+ // GetKnownRooms returns a list of all rooms we know about.
+ GetKnownRooms(ctx context.Context) ([]string, error)
+ // GetStateEvent returns the state event of a given type for a given room with a given state key
+ // If no event could be found, returns nil
+ // If there was an issue during the retrieval, returns an error
+ GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
+}
+
+type ServerACLs struct {
+ acls map[string]*serverACL // room ID -> ACL
+ aclsMutex sync.RWMutex // protects the above
+}
+
+func NewServerACLs(db ServerACLDatabase) *ServerACLs {
+ ctx := context.TODO()
+ acls := &ServerACLs{
+ acls: make(map[string]*serverACL),
+ }
+ // Look up all of the rooms that the current state server knows about.
+ rooms, err := db.GetKnownRooms(ctx)
+ if err != nil {
+ logrus.WithError(err).Fatalf("Failed to get known rooms")
+ }
+ // For each room, let's see if we have a server ACL state event. If we
+ // do then we'll process it into memory so that we have the regexes to
+ // hand.
+ for _, room := range rooms {
+ state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "")
+ if err != nil {
+ logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room)
+ continue
+ }
+ if state != nil {
+ acls.OnServerACLUpdate(&state.Event)
+ }
+ }
+ return acls
+}
+
+type ServerACL struct {
+ Allowed []string `json:"allow"`
+ Denied []string `json:"deny"`
+ AllowIPLiterals bool `json:"allow_ip_literals"`
+}
+
+type serverACL struct {
+ ServerACL
+ allowedRegexes []*regexp.Regexp
+ deniedRegexes []*regexp.Regexp
+}
+
+func compileACLRegex(orig string) (*regexp.Regexp, error) {
+ escaped := regexp.QuoteMeta(orig)
+ escaped = strings.Replace(escaped, "\\?", ".", -1)
+ escaped = strings.Replace(escaped, "\\*", ".*", -1)
+ return regexp.Compile(escaped)
+}
+
+func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) {
+ acls := &serverACL{}
+ if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil {
+ logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs")
+ return
+ }
+ // The spec calls only for * (zero or more chars) and ? (exactly one char)
+ // to be supported as wildcard components, so we will escape all of the regex
+ // special characters and then replace * and ? with their regex counterparts.
+ // https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
+ for _, orig := range acls.Allowed {
+ if expr, err := compileACLRegex(orig); err != nil {
+ logrus.WithError(err).Errorf("Failed to compile allowed regex")
+ } else {
+ acls.allowedRegexes = append(acls.allowedRegexes, expr)
+ }
+ }
+ for _, orig := range acls.Denied {
+ if expr, err := compileACLRegex(orig); err != nil {
+ logrus.WithError(err).Errorf("Failed to compile denied regex")
+ } else {
+ acls.deniedRegexes = append(acls.deniedRegexes, expr)
+ }
+ }
+ logrus.WithFields(logrus.Fields{
+ "allow_ip_literals": acls.AllowIPLiterals,
+ "num_allowed": len(acls.allowedRegexes),
+ "num_denied": len(acls.deniedRegexes),
+ }).Debugf("Updating server ACLs for %q", state.RoomID())
+ s.aclsMutex.Lock()
+ defer s.aclsMutex.Unlock()
+ s.acls[state.RoomID()] = acls
+}
+
+func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool {
+ s.aclsMutex.RLock()
+ // First of all check if we have an ACL for this room. If we don't then
+ // no servers are banned from the room.
+ acls, ok := s.acls[roomID]
+ if !ok {
+ s.aclsMutex.RUnlock()
+ return false
+ }
+ s.aclsMutex.RUnlock()
+ // Split the host and port apart. This is because the spec calls on us to
+ // validate the hostname only in cases where the port is also present.
+ if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil {
+ serverName = gomatrixserverlib.ServerName(serverNameOnly)
+ }
+ // Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding
+ // a /0 prefix length just to trick ParseCIDR into working. If we find that
+ // the server is an IP literal and we don't allow those then stop straight
+ // away.
+ if _, _, err := net.ParseCIDR(fmt.Sprintf("%s/0", serverName)); err == nil {
+ if !acls.AllowIPLiterals {
+ return true
+ }
+ }
+ // Check if the hostname matches one of the denied regexes. If it does then
+ // the server is banned from the room.
+ for _, expr := range acls.deniedRegexes {
+ if expr.MatchString(string(serverName)) {
+ return true
+ }
+ }
+ // Check if the hostname matches one of the allowed regexes. If it does then
+ // the server is NOT banned from the room.
+ for _, expr := range acls.allowedRegexes {
+ if expr.MatchString(string(serverName)) {
+ return false
+ }
+ }
+ // If we've got to this point then we haven't matched any regexes or an IP
+ // hostname if disallowed. The spec calls for default-deny here.
+ return true
+}
diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go
new file mode 100644
index 00000000..9fb6a558
--- /dev/null
+++ b/roomserver/acls/acls_test.go
@@ -0,0 +1,105 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package acls
+
+import (
+ "regexp"
+ "testing"
+)
+
+func TestOpenACLsWithBlacklist(t *testing.T) {
+ roomID := "!test:test.com"
+ allowRegex, err := compileACLRegex("*")
+ if err != nil {
+ t.Fatalf(err.Error())
+ }
+ denyRegex, err := compileACLRegex("foo.com")
+ if err != nil {
+ t.Fatalf(err.Error())
+ }
+
+ acls := ServerACLs{
+ acls: make(map[string]*serverACL),
+ }
+
+ acls.acls[roomID] = &serverACL{
+ ServerACL: ServerACL{
+ AllowIPLiterals: true,
+ },
+ allowedRegexes: []*regexp.Regexp{allowRegex},
+ deniedRegexes: []*regexp.Regexp{denyRegex},
+ }
+
+ if acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
+ t.Fatal("Expected 1.2.3.4 to be allowed but wasn't")
+ }
+ if acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
+ t.Fatal("Expected 1.2.3.4:2345 to be allowed but wasn't")
+ }
+ if !acls.IsServerBannedFromRoom("foo.com", roomID) {
+ t.Fatal("Expected foo.com to be banned but wasn't")
+ }
+ if !acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
+ t.Fatal("Expected foo.com:3456 to be banned but wasn't")
+ }
+ if acls.IsServerBannedFromRoom("bar.com", roomID) {
+ t.Fatal("Expected bar.com to be allowed but wasn't")
+ }
+ if acls.IsServerBannedFromRoom("bar.com:4567", roomID) {
+ t.Fatal("Expected bar.com:4567 to be allowed but wasn't")
+ }
+}
+
+func TestDefaultACLsWithWhitelist(t *testing.T) {
+ roomID := "!test:test.com"
+ allowRegex, err := compileACLRegex("foo.com")
+ if err != nil {
+ t.Fatalf(err.Error())
+ }
+
+ acls := ServerACLs{
+ acls: make(map[string]*serverACL),
+ }
+
+ acls.acls[roomID] = &serverACL{
+ ServerACL: ServerACL{
+ AllowIPLiterals: false,
+ },
+ allowedRegexes: []*regexp.Regexp{allowRegex},
+ deniedRegexes: []*regexp.Regexp{},
+ }
+
+ if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
+ t.Fatal("Expected 1.2.3.4 to be banned but wasn't")
+ }
+ if !acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
+ t.Fatal("Expected 1.2.3.4:2345 to be banned but wasn't")
+ }
+ if acls.IsServerBannedFromRoom("foo.com", roomID) {
+ t.Fatal("Expected foo.com to be allowed but wasn't")
+ }
+ if acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
+ t.Fatal("Expected foo.com:3456 to be allowed but wasn't")
+ }
+ if !acls.IsServerBannedFromRoom("bar.com", roomID) {
+ t.Fatal("Expected bar.com to be allowed but wasn't")
+ }
+ if !acls.IsServerBannedFromRoom("baz.com", roomID) {
+ t.Fatal("Expected baz.com to be allowed but wasn't")
+ }
+ if !acls.IsServerBannedFromRoom("qux.com:4567", roomID) {
+ t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
+ }
+}
diff --git a/roomserver/api/api.go b/roomserver/api/api.go
index 0fe30b8b..96bdc767 100644
--- a/roomserver/api/api.go
+++ b/roomserver/api/api.go
@@ -106,6 +106,20 @@ type RoomserverInternalAPI interface {
response *QueryStateAndAuthChainResponse,
) error
+ // QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from
+ // the response.
+ QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
+ // QueryRoomsForUser retrieves a list of room IDs matching the given query.
+ QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
+ // QueryBulkStateContent does a bulk query for state event content in the given rooms.
+ QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
+ // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
+ QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
+ // QueryKnownUsers returns a list of users that we know about from our joined rooms.
+ QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
+ // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
+ QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
+
// Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(
ctx context.Context,
diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go
index 9b53aa88..25da2e8e 100644
--- a/roomserver/api/api_trace.go
+++ b/roomserver/api/api_trace.go
@@ -236,6 +236,47 @@ func (t *RoomserverInternalAPITrace) RemoveRoomAlias(
return err
}
+func (t *RoomserverInternalAPITrace) QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error {
+ err := t.Impl.QueryCurrentState(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("QueryCurrentState req=%+v res=%+v", js(req), js(res))
+ return err
+}
+
+// QueryRoomsForUser retrieves a list of room IDs matching the given query.
+func (t *RoomserverInternalAPITrace) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error {
+ err := t.Impl.QueryRoomsForUser(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("QueryRoomsForUser req=%+v res=%+v", js(req), js(res))
+ return err
+}
+
+// QueryBulkStateContent does a bulk query for state event content in the given rooms.
+func (t *RoomserverInternalAPITrace) QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error {
+ err := t.Impl.QueryBulkStateContent(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("QueryBulkStateContent req=%+v res=%+v", js(req), js(res))
+ return err
+}
+
+// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
+func (t *RoomserverInternalAPITrace) QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error {
+ err := t.Impl.QuerySharedUsers(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("QuerySharedUsers req=%+v res=%+v", js(req), js(res))
+ return err
+}
+
+// QueryKnownUsers returns a list of users that we know about from our joined rooms.
+func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error {
+ err := t.Impl.QueryKnownUsers(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("QueryKnownUsers req=%+v res=%+v", js(req), js(res))
+ return err
+}
+
+// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
+func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error {
+ err := t.Impl.QueryServerBannedFromRoom(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("QueryServerBannedFromRoom req=%+v res=%+v", js(req), js(res))
+ return err
+}
+
func js(thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {
diff --git a/roomserver/api/query.go b/roomserver/api/query.go
index 4e1d09c3..d0d0474d 100644
--- a/roomserver/api/query.go
+++ b/roomserver/api/query.go
@@ -17,6 +17,11 @@
package api
import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -225,3 +230,102 @@ type QueryPublishedRoomsResponse struct {
// The list of published rooms.
RoomIDs []string
}
+
+type QuerySharedUsersRequest struct {
+ UserID string
+ ExcludeRoomIDs []string
+ IncludeRoomIDs []string
+}
+
+type QuerySharedUsersResponse struct {
+ UserIDsToCount map[string]int
+}
+
+type QueryRoomsForUserRequest struct {
+ UserID string
+ // The desired membership of the user. If this is the empty string then no rooms are returned.
+ WantMembership string
+}
+
+type QueryRoomsForUserResponse struct {
+ RoomIDs []string
+}
+
+type QueryBulkStateContentRequest struct {
+ // Returns state events in these rooms
+ RoomIDs []string
+ // If true, treats the '*' StateKey as "all state events of this type" rather than a literal value of '*'
+ AllowWildcards bool
+ // The state events to return. Only a small subset of tuples are allowed in this request as only certain events
+ // have their content fields extracted. Specifically, the tuple Type must be one of:
+ // m.room.avatar
+ // m.room.create
+ // m.room.canonical_alias
+ // m.room.guest_access
+ // m.room.history_visibility
+ // m.room.join_rules
+ // m.room.member
+ // m.room.name
+ // m.room.topic
+ // Any other tuple type will result in the query failing.
+ StateTuples []gomatrixserverlib.StateKeyTuple
+}
+type QueryBulkStateContentResponse struct {
+ // map of room ID -> tuple -> content_value
+ Rooms map[string]map[gomatrixserverlib.StateKeyTuple]string
+}
+
+type QueryCurrentStateRequest struct {
+ RoomID string
+ StateTuples []gomatrixserverlib.StateKeyTuple
+}
+
+type QueryCurrentStateResponse struct {
+ StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent
+}
+
+type QueryKnownUsersRequest struct {
+ UserID string `json:"user_id"`
+ SearchString string `json:"search_string"`
+ Limit int `json:"limit"`
+}
+
+type QueryKnownUsersResponse struct {
+ Users []authtypes.FullyQualifiedProfile `json:"profiles"`
+}
+
+type QueryServerBannedFromRoomRequest struct {
+ ServerName gomatrixserverlib.ServerName `json:"server_name"`
+ RoomID string `json:"room_id"`
+}
+
+type QueryServerBannedFromRoomResponse struct {
+ Banned bool `json:"banned"`
+}
+
+// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
+func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) {
+ se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents))
+ for k, v := range r.StateEvents {
+ // use 0x1F (unit separator) as the delimiter between type/state key,
+ se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v
+ }
+ return json.Marshal(se)
+}
+
+func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
+ res := make(map[string]*gomatrixserverlib.HeaderedEvent)
+ err := json.Unmarshal(data, &res)
+ if err != nil {
+ return err
+ }
+ r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res))
+ for k, v := range res {
+ fields := strings.Split(k, "\x1F")
+ r.StateEvents[gomatrixserverlib.StateKeyTuple{
+ EventType: fields[0],
+ StateKey: fields[1],
+ }] = v
+ }
+ return nil
+}
diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go
index 16f5e8e1..82a4a571 100644
--- a/roomserver/api/wrapper.go
+++ b/roomserver/api/wrapper.go
@@ -133,3 +133,102 @@ func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string)
}
return &res.Events[0]
}
+
+// GetStateEvent returns the current state event in the room or nil.
+func GetStateEvent(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent {
+ var res QueryCurrentStateResponse
+ err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{
+ RoomID: roomID,
+ StateTuples: []gomatrixserverlib.StateKeyTuple{tuple},
+ }, &res)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("Failed to QueryCurrentState")
+ return nil
+ }
+ ev, ok := res.StateEvents[tuple]
+ if ok {
+ return ev
+ }
+ return nil
+}
+
+// IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs.
+func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool {
+ req := &QueryServerBannedFromRoomRequest{
+ ServerName: serverName,
+ RoomID: roomID,
+ }
+ res := &QueryServerBannedFromRoomResponse{}
+ if err := rsAPI.QueryServerBannedFromRoom(ctx, req, res); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("Failed to QueryServerBannedFromRoom")
+ return true
+ }
+ return res.Banned
+}
+
+// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the
+// published room directory.
+// due to lots of switches
+// nolint:gocyclo
+func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) {
+ avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""}
+ nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""}
+ canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""}
+ topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""}
+ guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""}
+ visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""}
+ joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""}
+
+ var stateRes QueryBulkStateContentResponse
+ err := rsAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{
+ RoomIDs: roomIDs,
+ AllowWildcards: true,
+ StateTuples: []gomatrixserverlib.StateKeyTuple{
+ nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple,
+ {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"},
+ },
+ }, &stateRes)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed")
+ return nil, err
+ }
+ chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs))
+ i := 0
+ for roomID, data := range stateRes.Rooms {
+ pub := gomatrixserverlib.PublicRoom{
+ RoomID: roomID,
+ }
+ joinCount := 0
+ var joinRule, guestAccess string
+ for tuple, contentVal := range data {
+ if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" {
+ joinCount++
+ continue
+ }
+ switch tuple {
+ case avatarTuple:
+ pub.AvatarURL = contentVal
+ case nameTuple:
+ pub.Name = contentVal
+ case topicTuple:
+ pub.Topic = contentVal
+ case canonicalTuple:
+ pub.CanonicalAlias = contentVal
+ case visibilityTuple:
+ pub.WorldReadable = contentVal == "world_readable"
+ // need both of these to determine whether guests can join
+ case joinRuleTuple:
+ joinRule = contentVal
+ case guestTuple:
+ guestAccess = contentVal
+ }
+ }
+ if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" {
+ pub.GuestCanJoin = true
+ }
+ pub.JoinedMembersCount = joinCount
+ chunk[i] = pub
+ i++
+ }
+ return chunk, nil
+}
diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go
index 93c0be77..bdea650e 100644
--- a/roomserver/internal/api.go
+++ b/roomserver/internal/api.go
@@ -7,6 +7,7 @@ import (
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/config"
+ "github.com/matrix-org/dendrite/roomserver/acls"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/internal/perform"
@@ -46,8 +47,9 @@ func NewRoomserverAPI(
ServerName: cfg.Matrix.ServerName,
KeyRing: keyRing,
Queryer: &query.Queryer{
- DB: roomserverDB,
- Cache: caches,
+ DB: roomserverDB,
+ Cache: caches,
+ ServerACLs: acls.NewServerACLs(roomserverDB),
},
Inputer: &input.Inputer{
DB: roomserverDB,
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index b2799aef..f76c9316 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -16,9 +16,12 @@ package query
import (
"context"
+ "errors"
"fmt"
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/caching"
+ "github.com/matrix-org/dendrite/roomserver/acls"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state"
@@ -31,8 +34,9 @@ import (
)
type Queryer struct {
- DB storage.Database
- Cache caching.RoomServerCaches
+ DB storage.Database
+ Cache caching.RoomServerCaches
+ ServerACLs *acls.ServerACLs
}
// QueryLatestEventsAndState implements api.RoomserverInternalAPI
@@ -502,3 +506,97 @@ func (r *Queryer) QueryPublishedRooms(
res.RoomIDs = rooms
return nil
}
+
+func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
+ res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
+ for _, tuple := range req.StateTuples {
+ ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
+ if err != nil {
+ return err
+ }
+ if ev != nil {
+ res.StateEvents[tuple] = ev
+ }
+ }
+ return nil
+}
+
+func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
+ roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership)
+ if err != nil {
+ return err
+ }
+ res.RoomIDs = roomIDs
+ return nil
+}
+
+func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
+ users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit)
+ if err != nil {
+ return err
+ }
+ for _, user := range users {
+ res.Users = append(res.Users, authtypes.FullyQualifiedProfile{
+ UserID: user,
+ })
+ }
+ return nil
+}
+
+func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
+ events, err := r.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards)
+ if err != nil {
+ return err
+ }
+ res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
+ for _, ev := range events {
+ if res.Rooms[ev.RoomID] == nil {
+ res.Rooms[ev.RoomID] = make(map[gomatrixserverlib.StateKeyTuple]string)
+ }
+ room := res.Rooms[ev.RoomID]
+ room[gomatrixserverlib.StateKeyTuple{
+ EventType: ev.EventType,
+ StateKey: ev.StateKey,
+ }] = ev.ContentValue
+ res.Rooms[ev.RoomID] = room
+ }
+ return nil
+}
+
+func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
+ roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
+ if err != nil {
+ return err
+ }
+ roomIDs = append(roomIDs, req.IncludeRoomIDs...)
+ excludeMap := make(map[string]bool)
+ for _, roomID := range req.ExcludeRoomIDs {
+ excludeMap[roomID] = true
+ }
+ // filter out excluded rooms
+ j := 0
+ for i := range roomIDs {
+ // move elements to include to the beginning of the slice
+ // then trim elements on the right
+ if !excludeMap[roomIDs[i]] {
+ roomIDs[j] = roomIDs[i]
+ j++
+ }
+ }
+ roomIDs = roomIDs[:j]
+
+ users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
+ if err != nil {
+ return err
+ }
+ res.UserIDsToCount = users
+ return nil
+}
+
+func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error {
+ if r.ServerACLs == nil {
+ return errors.New("no server ACL tracking")
+ }
+ res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID)
+ return nil
+}
diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go
index 1657bcde..b414b0d8 100644
--- a/roomserver/inthttp/client.go
+++ b/roomserver/inthttp/client.go
@@ -43,6 +43,12 @@ const (
RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities"
RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom"
RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms"
+ RoomserverQueryCurrentStatePath = "/roomserver/queryCurrentState"
+ RoomserverQueryRoomsForUserPath = "/roomserver/queryRoomsForUser"
+ RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent"
+ RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers"
+ RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers"
+ RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
)
type httpRoomserverInternalAPI struct {
@@ -371,3 +377,69 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom(
}
return err
}
+
+func (h *httpRoomserverInternalAPI) QueryCurrentState(
+ ctx context.Context,
+ request *api.QueryCurrentStateRequest,
+ response *api.QueryCurrentStateResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
+ ctx context.Context,
+ request *api.QueryRoomsForUserRequest,
+ response *api.QueryRoomsForUserResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
+ ctx context.Context,
+ request *api.QueryBulkStateContentRequest,
+ response *api.QueryBulkStateContentResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+func (h *httpRoomserverInternalAPI) QuerySharedUsers(
+ ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
+
+func (h *httpRoomserverInternalAPI) QueryKnownUsers(
+ ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
+
+func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
+ ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go
index 0ac36a2a..ebfb296d 100644
--- a/roomserver/inthttp/server.go
+++ b/roomserver/inthttp/server.go
@@ -312,4 +312,82 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
+ internalAPIMux.Handle(RoomserverQueryCurrentStatePath,
+ httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse {
+ request := api.QueryCurrentStateRequest{}
+ response := api.QueryCurrentStateResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ internalAPIMux.Handle(RoomserverQueryRoomsForUserPath,
+ httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse {
+ request := api.QueryRoomsForUserRequest{}
+ response := api.QueryRoomsForUserResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ internalAPIMux.Handle(RoomserverQueryBulkStateContentPath,
+ httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse {
+ request := api.QueryBulkStateContentRequest{}
+ response := api.QueryBulkStateContentResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
+ httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
+ request := api.QuerySharedUsersRequest{}
+ response := api.QuerySharedUsersResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
+ httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse {
+ request := api.QueryKnownUsersRequest{}
+ response := api.QueryKnownUsersResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath,
+ httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse {
+ request := api.QueryServerBannedFromRoomRequest{}
+ response := api.QueryServerBannedFromRoomResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
}
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index ef7a9f09..c4119f7e 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -17,6 +17,7 @@ package storage
import (
"context"
+ "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
@@ -138,4 +139,22 @@ type Database interface {
PublishRoom(ctx context.Context, roomID string, publish bool) error
// Returns a list of room IDs for rooms which are published.
GetPublishedRooms(ctx context.Context) ([]string, error)
+
+ // TODO: factor out - from currentstateserver
+
+ // GetStateEvent returns the state event of a given type for a given room with a given state key
+ // If no event could be found, returns nil
+ // If there was an issue during the retrieval, returns an error
+ GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
+ // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
+ GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
+ // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
+ // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
+ GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
+ // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
+ JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
+ // GetKnownUsers searches all users that userID knows about.
+ GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
+ // GetKnownRooms returns a list of all rooms we know about.
+ GetKnownRooms(ctx context.Context) ([]string, error)
}
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index 13cef638..0799647e 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -99,6 +99,9 @@ const updateMembershipSQL = "" +
"UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" +
" WHERE room_nid = $1 AND target_nid = $2"
+const selectRoomsWithMembershipSQL = "" +
+ "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
+
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
@@ -108,6 +111,7 @@ type membershipStatements struct {
selectMembershipsFromRoomStmt *sql.Stmt
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
+ selectRoomsWithMembershipStmt *sql.Stmt
}
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
@@ -126,6 +130,7 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
+ {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
}.Prepare(db)
}
@@ -222,3 +227,22 @@ func (s *membershipStatements) UpdateMembership(
)
return err
}
+
+func (s *membershipStatements) SelectRoomsWithMembership(
+ ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
+) ([]types.RoomNID, error) {
+ rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
+ var roomNIDs []types.RoomNID
+ for rows.Next() {
+ var roomNID types.RoomNID
+ if err := rows.Scan(&roomNID); err != nil {
+ return nil, err
+ }
+ roomNIDs = append(roomNIDs, roomNID)
+ }
+ return roomNIDs, nil
+}
diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go
index 13c8e703..9d359146 100644
--- a/roomserver/storage/postgres/rooms_table.go
+++ b/roomserver/storage/postgres/rooms_table.go
@@ -21,6 +21,7 @@ import (
"errors"
"github.com/lib/pq"
+ "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
@@ -74,6 +75,12 @@ const selectRoomVersionForRoomNIDSQL = "" +
const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
+const selectRoomIDsSQL = "" +
+ "SELECT room_id FROM roomserver_rooms"
+
+const bulkSelectRoomIDsSQL = "" +
+ "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
+
type roomStatements struct {
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
@@ -82,6 +89,8 @@ type roomStatements struct {
updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt
+ selectRoomIDsStmt *sql.Stmt
+ bulkSelectRoomIDsStmt *sql.Stmt
}
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
@@ -98,9 +107,27 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
+ {&s.selectRoomIDsStmt, selectRoomIDsSQL},
+ {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
}.Prepare(db)
}
+func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
+ rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
+ var roomIDs []string
+ for rows.Next() {
+ var roomID string
+ if err = rows.Scan(&roomID); err != nil {
+ return nil, err
+ }
+ roomIDs = append(roomIDs, roomID)
+ }
+ return roomIDs, nil
+}
func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
@@ -197,3 +224,24 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
}
return roomVersion, err
}
+
+func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
+ var array pq.Int64Array
+ for _, nid := range roomNIDs {
+ array = append(array, int64(nid))
+ }
+ rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
+ var roomIDs []string
+ for rows.Next() {
+ var roomID string
+ if err = rows.Scan(&roomID); err != nil {
+ return nil, err
+ }
+ roomIDs = append(roomIDs, roomID)
+ }
+ return roomIDs, nil
+}
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 6e0ebd2c..5c447d66 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
+ csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
@@ -711,3 +712,82 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
}
return &evs[0]
}
+
+// GetStateEvent returns the current state event of a given type for a given room with a given state key
+// If no event could be found, returns nil
+// If there was an issue during the retrieval, returns an error
+func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
+ /*
+ roomInfo, err := d.RoomInfo(ctx, roomID)
+ if err != nil {
+ return nil, err
+ }
+ eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
+ if err != nil {
+ return nil, err
+ }
+ stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey)
+ if err != nil {
+ return nil, err
+ }
+ blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID})
+ if err != nil {
+ return nil, err
+ }
+ */
+ return nil, nil
+}
+
+// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
+func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
+ var membershipState tables.MembershipState
+ switch membership {
+ case "join":
+ membershipState = tables.MembershipStateJoin
+ case "invite":
+ membershipState = tables.MembershipStateInvite
+ case "leave":
+ membershipState = tables.MembershipStateLeaveOrBan
+ case "ban":
+ membershipState = tables.MembershipStateLeaveOrBan
+ default:
+ return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
+ }
+ stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
+ if err != nil {
+ return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
+ }
+ roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState)
+ if err != nil {
+ return nil, err
+ }
+ roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs)
+ if err != nil {
+ return nil, err
+ }
+ if len(roomIDs) != len(roomNIDs) {
+ return nil, fmt.Errorf("GetRoomsByMembership: missing room IDs, got %d want %d", len(roomIDs), len(roomNIDs))
+ }
+ return roomIDs, nil
+}
+
+// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
+// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
+func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]csstables.StrippedEvent, error) {
+ return nil, fmt.Errorf("not implemented yet")
+}
+
+// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
+func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
+ return nil, fmt.Errorf("not implemented yet")
+}
+
+// GetKnownUsers searches all users that userID knows about.
+func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
+ return nil, fmt.Errorf("not implemented yet")
+}
+
+// GetKnownRooms returns a list of all rooms we know about.
+func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
+ return d.RoomsTable.SelectRoomIDs(ctx)
+}
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index b3ee69c0..e850c80b 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -75,6 +75,9 @@ const updateMembershipSQL = "" +
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" +
" WHERE room_nid = $4 AND target_nid = $5"
+const selectRoomsWithMembershipSQL = "" +
+ "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
+
type membershipStatements struct {
db *sql.DB
insertMembershipStmt *sql.Stmt
@@ -84,6 +87,7 @@ type membershipStatements struct {
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt
selectLocalMembershipsFromRoomStmt *sql.Stmt
+ selectRoomsWithMembershipStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
}
@@ -105,6 +109,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
+ {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
}.Prepare(db)
}
@@ -203,3 +208,22 @@ func (s *membershipStatements) UpdateMembership(
)
return err
}
+
+func (s *membershipStatements) SelectRoomsWithMembership(
+ ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
+) ([]types.RoomNID, error) {
+ rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
+ var roomNIDs []types.RoomNID
+ for rows.Next() {
+ var roomNID types.RoomNID
+ if err := rows.Scan(&roomNID); err != nil {
+ return nil, err
+ }
+ roomNIDs = append(roomNIDs, roomNID)
+ }
+ return roomNIDs, nil
+}
diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go
index 4c1699d0..daacf86f 100644
--- a/roomserver/storage/sqlite3/rooms_table.go
+++ b/roomserver/storage/sqlite3/rooms_table.go
@@ -21,7 +21,9 @@ import (
"encoding/json"
"errors"
"fmt"
+ "strings"
+ "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
@@ -64,6 +66,12 @@ const selectRoomVersionForRoomNIDSQL = "" +
const selectRoomInfoSQL = "" +
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
+const selectRoomIDsSQL = "" +
+ "SELECT room_id FROM roomserver_rooms"
+
+const bulkSelectRoomIDsSQL = "" +
+ "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
+
type roomStatements struct {
db *sql.DB
insertRoomNIDStmt *sql.Stmt
@@ -73,6 +81,7 @@ type roomStatements struct {
updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt
selectRoomInfoStmt *sql.Stmt
+ selectRoomIDsStmt *sql.Stmt
}
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
@@ -91,9 +100,27 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
+ {&s.selectRoomIDsStmt, selectRoomIDsSQL},
}.Prepare(db)
}
+func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
+ rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
+ var roomIDs []string
+ for rows.Next() {
+ var roomID string
+ if err = rows.Scan(&roomID); err != nil {
+ return nil, err
+ }
+ roomIDs = append(roomIDs, roomID)
+ }
+ return roomIDs, nil
+}
+
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDsJSON string
@@ -203,3 +230,25 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
}
return roomVersion, err
}
+
+func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
+ iRoomNIDs := make([]interface{}, len(roomNIDs))
+ for i, v := range roomNIDs {
+ iRoomNIDs[i] = v
+ }
+ sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
+ rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
+ var roomIDs []string
+ for rows.Next() {
+ var roomID string
+ if err = rows.Scan(&roomID); err != nil {
+ return nil, err
+ }
+ roomIDs = append(roomIDs, roomID)
+ }
+ return roomIDs, nil
+}
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index c599dd3f..126c27b5 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -65,6 +65,8 @@ type Rooms interface {
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
+ SelectRoomIDs(ctx context.Context) ([]string, error)
+ BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
}
type Transactions interface {
@@ -120,6 +122,7 @@ type Membership interface {
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
+ SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
}
type Published interface {