aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--roomserver/acls/acls.go47
-rw-r--r--roomserver/acls/acls_test.go56
2 files changed, 91 insertions, 12 deletions
diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go
index 660f4f3b..017682e0 100644
--- a/roomserver/acls/acls.go
+++ b/roomserver/acls/acls.go
@@ -41,15 +41,21 @@ type ServerACLDatabase interface {
}
type ServerACLs struct {
- acls map[string]*serverACL // room ID -> ACL
- aclsMutex sync.RWMutex // protects the above
+ acls map[string]*serverACL // room ID -> ACL
+ aclsMutex sync.RWMutex // protects the above
+ aclRegexCache map[string]**regexp.Regexp // Cache from "serverName" -> pointer to a regex
+ aclRegexCacheMutex sync.RWMutex // protects the above
}
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
ctx := context.TODO()
acls := &ServerACLs{
acls: make(map[string]*serverACL),
+ // Be generous when creating the cache, as in reality
+ // there are hundreds of servers in an ACL.
+ aclRegexCache: make(map[string]**regexp.Regexp, 100),
}
+
// Look up all of the rooms that the current state server knows about.
rooms, err := db.GetKnownRooms(ctx)
if err != nil {
@@ -67,6 +73,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
for _, event := range events {
acls.OnServerACLUpdate(event)
}
+
return acls
}
@@ -78,8 +85,8 @@ type ServerACL struct {
type serverACL struct {
ServerACL
- allowedRegexes []*regexp.Regexp
- deniedRegexes []*regexp.Regexp
+ allowedRegexes []**regexp.Regexp
+ deniedRegexes []**regexp.Regexp
}
func compileACLRegex(orig string) (*regexp.Regexp, error) {
@@ -89,6 +96,25 @@ func compileACLRegex(orig string) (*regexp.Regexp, error) {
return regexp.Compile(escaped)
}
+// cachedCompileACLRegex is a wrapper around compileACLRegex with added caching
+func (s *ServerACLs) cachedCompileACLRegex(orig string) (**regexp.Regexp, error) {
+ s.aclRegexCacheMutex.RLock()
+ re, ok := s.aclRegexCache[orig]
+ if ok {
+ s.aclRegexCacheMutex.RUnlock()
+ return re, nil
+ }
+ s.aclRegexCacheMutex.RUnlock()
+ compiled, err := compileACLRegex(orig)
+ if err != nil {
+ return nil, err
+ }
+ s.aclRegexCacheMutex.Lock()
+ defer s.aclRegexCacheMutex.Unlock()
+ s.aclRegexCache[orig] = &compiled
+ return &compiled, nil
+}
+
func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
acls := &serverACL{}
if err := json.Unmarshal([]byte(strippedEvent.ContentValue), &acls.ServerACL); err != nil {
@@ -100,14 +126,14 @@ func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
// special characters and then replace * and ? with their regex counterparts.
// https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
for _, orig := range acls.Allowed {
- if expr, err := compileACLRegex(orig); err != nil {
+ if expr, err := s.cachedCompileACLRegex(orig); err != nil {
logrus.WithError(err).Errorf("Failed to compile allowed regex")
} else {
acls.allowedRegexes = append(acls.allowedRegexes, expr)
}
}
for _, orig := range acls.Denied {
- if expr, err := compileACLRegex(orig); err != nil {
+ if expr, err := s.cachedCompileACLRegex(orig); err != nil {
logrus.WithError(err).Errorf("Failed to compile denied regex")
} else {
acls.deniedRegexes = append(acls.deniedRegexes, expr)
@@ -118,6 +144,11 @@ func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) {
"num_allowed": len(acls.allowedRegexes),
"num_denied": len(acls.deniedRegexes),
}).Debugf("Updating server ACLs for %q", strippedEvent.RoomID)
+
+ // Clear out Denied and Allowed, now that we have the compiled regexes.
+ // They are not needed anymore from this point on.
+ acls.Denied = nil
+ acls.Allowed = nil
s.aclsMutex.Lock()
defer s.aclsMutex.Unlock()
s.acls[strippedEvent.RoomID] = acls
@@ -150,14 +181,14 @@ func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID s
// Check if the hostname matches one of the denied regexes. If it does then
// the server is banned from the room.
for _, expr := range acls.deniedRegexes {
- if expr.MatchString(string(serverName)) {
+ if (*expr).MatchString(string(serverName)) {
return true
}
}
// Check if the hostname matches one of the allowed regexes. If it does then
// the server is NOT banned from the room.
for _, expr := range acls.allowedRegexes {
- if expr.MatchString(string(serverName)) {
+ if (*expr).MatchString(string(serverName)) {
return false
}
}
diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go
index 9fb6a558..efe1d209 100644
--- a/roomserver/acls/acls_test.go
+++ b/roomserver/acls/acls_test.go
@@ -15,8 +15,14 @@
package acls
import (
+ "context"
"regexp"
"testing"
+
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
+ "github.com/stretchr/testify/assert"
)
func TestOpenACLsWithBlacklist(t *testing.T) {
@@ -38,8 +44,8 @@ func TestOpenACLsWithBlacklist(t *testing.T) {
ServerACL: ServerACL{
AllowIPLiterals: true,
},
- allowedRegexes: []*regexp.Regexp{allowRegex},
- deniedRegexes: []*regexp.Regexp{denyRegex},
+ allowedRegexes: []**regexp.Regexp{&allowRegex},
+ deniedRegexes: []**regexp.Regexp{&denyRegex},
}
if acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
@@ -77,8 +83,8 @@ func TestDefaultACLsWithWhitelist(t *testing.T) {
ServerACL: ServerACL{
AllowIPLiterals: false,
},
- allowedRegexes: []*regexp.Regexp{allowRegex},
- deniedRegexes: []*regexp.Regexp{},
+ allowedRegexes: []**regexp.Regexp{&allowRegex},
+ deniedRegexes: []**regexp.Regexp{},
}
if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
@@ -103,3 +109,45 @@ func TestDefaultACLsWithWhitelist(t *testing.T) {
t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
}
}
+
+var (
+ content1 = `{"allow":["*"],"allow_ip_literals":false,"deny":["hello.world", "*.hello.world"]}`
+)
+
+type dummyACLDB struct{}
+
+func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) {
+ return []string{"1", "2"}, nil
+}
+
+func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) {
+ return []tables.StrippedEvent{
+ {
+ RoomID: "1",
+ ContentValue: content1,
+ },
+ {
+ RoomID: "2",
+ ContentValue: content1,
+ },
+ }, nil
+}
+
+func TestCachedRegex(t *testing.T) {
+ db := dummyACLDB{}
+ wantBannedServer := spec.ServerName("hello.world")
+
+ acls := NewServerACLs(db)
+
+ // Check that hello.world is banned in room 1
+ banned := acls.IsServerBannedFromRoom(wantBannedServer, "1")
+ assert.True(t, banned)
+
+ // Check that hello.world is banned in room 2
+ banned = acls.IsServerBannedFromRoom(wantBannedServer, "2")
+ assert.True(t, banned)
+
+ // Check that matrix.hello.world is banned in room 2
+ banned = acls.IsServerBannedFromRoom("matrix."+wantBannedServer, "2")
+ assert.True(t, banned)
+}