aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-03-01 17:06:47 +0100
committerGitHub <noreply@github.com>2023-03-01 17:06:47 +0100
commit6c20f8f742a7e03710fae81df6ef98bac31da2b1 (patch)
tree202e962951dc41c949a71c7f5c1deb6d1da78843
parent1aa70b0f56825a4a5f92c38cabb1fe841cec6e18 (diff)
Refactor `StoreEvent`, add `MaybeRedactEvent`, create an `EventDatabase` (#2989)
This PR changes the following: - `StoreEvent` now only stores an event (and possibly prev event), instead of also doing redactions - Adds a `MaybeRedactEvent` (pulled out from `StoreEvent`), which should be called after storing events - a few other things
-rw-r--r--appservice/consumers/roomserver.go1
-rw-r--r--clientapi/routing/redaction.go2
-rw-r--r--cmd/resolve-state/main.go11
-rw-r--r--federationapi/consumers/roomserver.go3
-rw-r--r--federationapi/routing/eventauth.go2
-rw-r--r--federationapi/routing/events.go25
-rw-r--r--federationapi/routing/state.go2
-rw-r--r--internal/hooks/hooks.go4
-rw-r--r--roomserver/api/api.go13
-rw-r--r--roomserver/api/query.go3
-rw-r--r--roomserver/api/wrapper.go3
-rw-r--r--roomserver/internal/helpers/auth.go10
-rw-r--r--roomserver/internal/helpers/helpers.go20
-rw-r--r--roomserver/internal/helpers/helpers_test.go9
-rw-r--r--roomserver/internal/input/input_events.go50
-rw-r--r--roomserver/internal/input/input_membership.go2
-rw-r--r--roomserver/internal/input/input_missing.go8
-rw-r--r--roomserver/internal/perform/perform_admin.go2
-rw-r--r--roomserver/internal/perform/perform_backfill.go32
-rw-r--r--roomserver/internal/perform/perform_inbound_peek.go6
-rw-r--r--roomserver/internal/perform/perform_invite.go4
-rw-r--r--roomserver/internal/query/query.go130
-rw-r--r--roomserver/internal/query/query_test.go6
-rw-r--r--roomserver/roomserver_test.go14
-rw-r--r--roomserver/state/state.go12
-rw-r--r--roomserver/storage/interface.go48
-rw-r--r--roomserver/storage/postgres/storage.go39
-rw-r--r--roomserver/storage/shared/room_updater.go8
-rw-r--r--roomserver/storage/shared/storage.go372
-rw-r--r--roomserver/storage/shared/storage_test.go12
-rw-r--r--roomserver/storage/sqlite3/storage.go41
-rw-r--r--setup/mscs/msc2836/msc2836.go7
-rw-r--r--syncapi/consumers/roomserver.go1
-rw-r--r--syncapi/routing/memberships.go2
34 files changed, 486 insertions, 418 deletions
diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go
index ac68f4bd..528de63e 100644
--- a/appservice/consumers/roomserver.go
+++ b/appservice/consumers/roomserver.go
@@ -122,6 +122,7 @@ func (s *OutputRoomEventConsumer) onMessage(
if len(output.NewRoomEvent.AddsStateEventIDs) > 0 {
newEventID := output.NewRoomEvent.Event.EventID()
eventsReq := &api.QueryEventsByIDRequest{
+ RoomID: output.NewRoomEvent.Event.RoomID(),
EventIDs: make([]string, 0, len(output.NewRoomEvent.AddsStateEventIDs)),
}
eventsRes := &api.QueryEventsByIDResponse{}
diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go
index 7841b3b0..f86bbc8f 100644
--- a/clientapi/routing/redaction.go
+++ b/clientapi/routing/redaction.go
@@ -57,7 +57,7 @@ func SendRedaction(
}
}
- ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID)
+ ev := roomserverAPI.GetEvent(req.Context(), rsAPI, roomID, eventID)
if ev == nil {
return util.JSONResponse{
Code: 400,
diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go
index e3840bbc..a9cc80cb 100644
--- a/cmd/resolve-state/main.go
+++ b/cmd/resolve-state/main.go
@@ -62,9 +62,10 @@ func main() {
panic(err)
}
- stateres := state.NewStateResolution(roomserverDB, &types.RoomInfo{
+ roomInfo := &types.RoomInfo{
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
- })
+ }
+ stateres := state.NewStateResolution(roomserverDB, roomInfo)
if *difference {
if len(snapshotNIDs) != 2 {
@@ -87,7 +88,7 @@ func main() {
}
var eventEntries []types.Event
- eventEntries, err = roomserverDB.Events(ctx, 0, eventNIDs)
+ eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs)
if err != nil {
panic(err)
}
@@ -145,7 +146,7 @@ func main() {
}
fmt.Println("Fetching", len(eventNIDMap), "state events")
- eventEntries, err := roomserverDB.Events(ctx, 0, eventNIDs)
+ eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs)
if err != nil {
panic(err)
}
@@ -165,7 +166,7 @@ func main() {
}
fmt.Println("Fetching", len(authEventIDs), "auth events")
- authEventEntries, err := roomserverDB.EventsFromIDs(ctx, 0, authEventIDs)
+ authEventEntries, err := roomserverDB.EventsFromIDs(ctx, roomInfo, authEventIDs)
if err != nil {
panic(err)
}
diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go
index 82a4db3f..378b96ba 100644
--- a/federationapi/consumers/roomserver.go
+++ b/federationapi/consumers/roomserver.go
@@ -173,6 +173,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
// Finally, work out if there are any more events missing.
if len(missingEventIDs) > 0 {
eventsReq := &api.QueryEventsByIDRequest{
+ RoomID: ore.Event.RoomID(),
EventIDs: missingEventIDs,
}
eventsRes := &api.QueryEventsByIDResponse{}
@@ -483,7 +484,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents(
// At this point the missing events are neither the event itself nor are
// they present in our local database. Our only option is to fetch them
// from the roomserver using the query API.
- eventReq := api.QueryEventsByIDRequest{EventIDs: missing}
+ eventReq := api.QueryEventsByIDRequest{EventIDs: missing, RoomID: event.RoomID()}
var eventResp api.QueryEventsByIDResponse
if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil {
return nil, err
diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go
index 868785a9..2f1f3baf 100644
--- a/federationapi/routing/eventauth.go
+++ b/federationapi/routing/eventauth.go
@@ -36,7 +36,7 @@ func GetEventAuth(
return *err
}
- event, resErr := fetchEvent(ctx, rsAPI, eventID)
+ event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID)
if resErr != nil {
return *resErr
}
diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go
index 6168912b..b4129241 100644
--- a/federationapi/routing/events.go
+++ b/federationapi/routing/events.go
@@ -20,10 +20,11 @@ import (
"net/http"
"time"
- "github.com/matrix-org/dendrite/clientapi/jsonerror"
- "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
+
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ "github.com/matrix-org/dendrite/roomserver/api"
)
// GetEvent returns the requested event
@@ -38,7 +39,9 @@ func GetEvent(
if err != nil {
return *err
}
- event, err := fetchEvent(ctx, rsAPI, eventID)
+ // /_matrix/federation/v1/event/{eventId} doesn't have a roomID, we use an empty string,
+ // which results in `QueryEventsByID` to first get the event and use that to determine the roomID.
+ event, err := fetchEvent(ctx, rsAPI, "", eventID)
if err != nil {
return *err
}
@@ -60,21 +63,13 @@ func allowedToSeeEvent(
rsAPI api.FederationRoomserverAPI,
eventID string,
) *util.JSONResponse {
- var authResponse api.QueryServerAllowedToSeeEventResponse
- err := rsAPI.QueryServerAllowedToSeeEvent(
- ctx,
- &api.QueryServerAllowedToSeeEventRequest{
- EventID: eventID,
- ServerName: origin,
- },
- &authResponse,
- )
+ allowed, err := rsAPI.QueryServerAllowedToSeeEvent(ctx, origin, eventID)
if err != nil {
resErr := util.ErrorResponse(err)
return &resErr
}
- if !authResponse.AllowedToSeeEvent {
+ if !allowed {
resErr := util.MessageResponse(http.StatusForbidden, "server not allowed to see event")
return &resErr
}
@@ -83,11 +78,11 @@ func allowedToSeeEvent(
}
// fetchEvent fetches the event without auth checks. Returns an error if the event cannot be found.
-func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) {
+func fetchEvent(ctx context.Context, rsAPI api.FederationRoomserverAPI, roomID, eventID string) (*gomatrixserverlib.Event, *util.JSONResponse) {
var eventsResponse api.QueryEventsByIDResponse
err := rsAPI.QueryEventsByID(
ctx,
- &api.QueryEventsByIDRequest{EventIDs: []string{eventID}},
+ &api.QueryEventsByIDRequest{EventIDs: []string{eventID}, RoomID: roomID},
&eventsResponse,
)
if err != nil {
diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go
index 1d08d0a8..1120cf26 100644
--- a/federationapi/routing/state.go
+++ b/federationapi/routing/state.go
@@ -107,7 +107,7 @@ func getState(
return nil, nil, err
}
- event, resErr := fetchEvent(ctx, rsAPI, eventID)
+ event, resErr := fetchEvent(ctx, rsAPI, roomID, eventID)
if resErr != nil {
return nil, nil, resErr
}
diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go
index 223282a2..d6c79e98 100644
--- a/internal/hooks/hooks.go
+++ b/internal/hooks/hooks.go
@@ -16,7 +16,9 @@
// Hooks can only be run in monolith mode.
package hooks
-import "sync"
+import (
+ "sync"
+)
const (
// KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent
diff --git a/roomserver/api/api.go b/roomserver/api/api.go
index 73732ae3..f6d003a4 100644
--- a/roomserver/api/api.go
+++ b/roomserver/api/api.go
@@ -54,7 +54,8 @@ type QueryBulkStateContentAPI interface {
}
type QueryEventsAPI interface {
- // Query a list of events by event ID.
+ // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
+ // which room to use by querying the first events roomID.
QueryEventsByID(
ctx context.Context,
req *QueryEventsByIDRequest,
@@ -71,7 +72,8 @@ type SyncRoomserverAPI interface {
QueryBulkStateContentAPI
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
- // Query a list of events by event ID.
+ // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
+ // which room to use by querying the first events roomID.
QueryEventsByID(
ctx context.Context,
req *QueryEventsByIDRequest,
@@ -108,7 +110,8 @@ type SyncRoomserverAPI interface {
}
type AppserviceRoomserverAPI interface {
- // Query a list of events by event ID.
+ // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
+ // which room to use by querying the first events roomID.
QueryEventsByID(
ctx context.Context,
req *QueryEventsByIDRequest,
@@ -182,6 +185,8 @@ type FederationRoomserverAPI interface {
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
+ // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
+ // which room to use by querying the first events roomID.
QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error
// Query to get state and auth chain for a (potentially hypothetical) event.
// Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
@@ -193,7 +198,7 @@ type FederationRoomserverAPI interface {
// Query missing events for a room from roomserver
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event
- QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error
+ QueryServerAllowedToSeeEvent(ctx context.Context, serverName gomatrixserverlib.ServerName, eventID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
QueryRestrictedJoinAllowed(ctx context.Context, req *QueryRestrictedJoinAllowedRequest, res *QueryRestrictedJoinAllowedResponse) error
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
diff --git a/roomserver/api/query.go b/roomserver/api/query.go
index 4ef548e1..24722db0 100644
--- a/roomserver/api/query.go
+++ b/roomserver/api/query.go
@@ -86,6 +86,9 @@ type QueryStateAfterEventsResponse struct {
// QueryEventsByIDRequest is a request to QueryEventsByID
type QueryEventsByIDRequest struct {
+ // The roomID to query events for. If this is empty, we first try to fetch the roomID from the database
+ // as this is needed for further processing/parsing events.
+ RoomID string `json:"room_id"`
// The event IDs to look up.
EventIDs []string `json:"event_ids"`
}
diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go
index 252be557..f220560e 100644
--- a/roomserver/api/wrapper.go
+++ b/roomserver/api/wrapper.go
@@ -108,9 +108,10 @@ func SendInputRoomEvents(
}
// GetEvent returns the event or nil, even on errors.
-func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, eventID string) *gomatrixserverlib.HeaderedEvent {
+func GetEvent(ctx context.Context, rsAPI QueryEventsAPI, roomID, eventID string) *gomatrixserverlib.HeaderedEvent {
var res QueryEventsByIDResponse
err := rsAPI.QueryEventsByID(ctx, &QueryEventsByIDRequest{
+ RoomID: roomID,
EventIDs: []string{eventID},
}, &res)
if err != nil {
diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go
index 27c8dd8f..9defe794 100644
--- a/roomserver/internal/helpers/auth.go
+++ b/roomserver/internal/helpers/auth.go
@@ -67,7 +67,7 @@ func CheckForSoftFail(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
- authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries)
+ authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
if err != nil {
return true, fmt.Errorf("loadAuthEvents: %w", err)
}
@@ -85,7 +85,7 @@ func CheckForSoftFail(
func CheckAuthEvents(
ctx context.Context,
db storage.RoomDatabase,
- roomNID types.RoomNID,
+ roomInfo *types.RoomInfo,
event *gomatrixserverlib.HeaderedEvent,
authEventIDs []string,
) ([]types.EventNID, error) {
@@ -100,7 +100,7 @@ func CheckAuthEvents(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
- authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries)
+ authEvents, err := loadAuthEvents(ctx, db, roomInfo, stateNeeded, authStateEntries)
if err != nil {
return nil, fmt.Errorf("loadAuthEvents: %w", err)
}
@@ -193,7 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
func loadAuthEvents(
ctx context.Context,
db state.StateResolutionStorage,
- roomNID types.RoomNID,
+ roomInfo *types.RoomInfo,
needed gomatrixserverlib.StateNeeded,
state []types.StateEntry,
) (result authEvents, err error) {
@@ -216,7 +216,7 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID)
}
}
- if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil {
+ if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil {
return
}
roomID := ""
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index ee1610cf..9a70bcc9 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, err
}
- events, err := db.Events(ctx, info.RoomNID, eventNIDs)
+ events, err := db.Events(ctx, info, eventNIDs)
if err != nil {
return false, err
}
@@ -157,7 +157,7 @@ func IsInvitePending(
// only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events.
func GetMembershipsAtState(
- ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool,
+ ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) {
var eventNIDs types.EventNIDs
@@ -177,7 +177,7 @@ func GetMembershipsAtState(
util.Unique(eventNIDs)
// Get all of the events in this state
- stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
+ stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if err != nil {
return nil, err
}
@@ -227,9 +227,9 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types
}
func LoadEvents(
- ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID,
+ ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) {
- stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
+ stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if err != nil {
return nil, err
}
@@ -242,13 +242,13 @@ func LoadEvents(
}
func LoadStateEvents(
- ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry,
+ ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
) ([]*gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID
}
- return LoadEvents(ctx, db, roomNID, eventNIDs)
+ return LoadEvents(ctx, db, roomInfo, eventNIDs)
}
func CheckServerAllowedToSeeEvent(
@@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState(
return nil, nil
}
- return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries)
+ return LoadStateEvents(ctx, db, info, filteredEntries)
}
// TODO: Remove this when we have tests to assert correctness of this function
@@ -366,7 +366,7 @@ BFSLoop:
next = make([]string, 0)
}
// Retrieve the events to process from the database.
- events, err = db.EventsFromIDs(ctx, info.RoomNID, front)
+ events, err = db.EventsFromIDs(ctx, info, front)
if err != nil {
return resultNIDs, redactEventIDs, err
}
@@ -467,7 +467,7 @@ func QueryLatestEventsAndState(
return err
}
- stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries)
+ stateEvents, err := LoadStateEvents(ctx, db, roomInfo, stateEntries)
if err != nil {
return err
}
diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go
index 62730df1..c056e704 100644
--- a/roomserver/internal/helpers/helpers_test.go
+++ b/roomserver/internal/helpers/helpers_test.go
@@ -4,9 +4,10 @@ import (
"context"
"testing"
- "github.com/matrix-org/dendrite/roomserver/types"
"github.com/stretchr/testify/assert"
+ "github.com/matrix-org/dendrite/roomserver/types"
+
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test"
@@ -38,9 +39,9 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
var authNIDs []types.EventNID
for _, x := range room.Events() {
- roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap())
+ roomInfo, err := db.GetOrCreateRoomInfo(context.Background(), x.Unwrap())
assert.NoError(t, err)
- assert.Greater(t, roomNID, types.RoomNID(0))
+ assert.NotNil(t, roomInfo)
eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type())
assert.NoError(t, err)
@@ -49,7 +50,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey())
assert.NoError(t, err)
- evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false)
+ evNID, _, err := db.StoreEvent(context.Background(), x.Event, roomInfo, eventTypeNID, eventStateKeyNID, authNIDs, false)
assert.NoError(t, err)
authNIDs = append(authNIDs, evNID)
}
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index fe35efb2..ede345a9 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -24,9 +24,10 @@ import (
"fmt"
"time"
- "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/tidwall/gjson"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
+
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
@@ -274,8 +275,10 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted.
+ redactAllowed := true
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
isRejected = true
+ redactAllowed = false
rejectionErr = err
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
}
@@ -323,7 +326,7 @@ func (r *Inputer) processRoomEvent(
// burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent {
- historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev)
+ historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo, input, missingPrev)
if err != nil {
return fmt.Errorf("r.processStateBefore: %w", err)
}
@@ -332,9 +335,11 @@ func (r *Inputer) processRoomEvent(
}
}
- roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event)
- if err != nil {
- return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
+ if roomInfo == nil {
+ roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, event)
+ if err != nil {
+ return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err)
+ }
}
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type())
@@ -348,15 +353,24 @@ func (r *Inputer) processRoomEvent(
}
// Store the event.
- _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
+ eventNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err)
}
// if storing this event results in it being redacted then do so.
- if !isRejected && redactedEventID == event.EventID() {
- if err = eventutil.RedactEvent(redactionEvent, event); err != nil {
- return fmt.Errorf("eventutil.RedactEvent: %w", rerr)
+ var (
+ redactedEventID string
+ redactionEvent *gomatrixserverlib.Event
+ redactedEvent *gomatrixserverlib.Event
+ )
+ if !isRejected && !isCreateEvent {
+ redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, redactAllowed)
+ if err != nil {
+ return err
+ }
+ if redactedEvent != nil {
+ redactedEventID = redactedEvent.EventID()
}
}
@@ -489,7 +503,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse
// nolint:nakedret
func (r *Inputer) processStateBefore(
ctx context.Context,
- roomNID types.RoomNID,
+ roomInfo *types.RoomInfo,
input *api.InputRoomEvent,
missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
@@ -505,7 +519,7 @@ func (r *Inputer) processStateBefore(
case input.HasState:
// If we're overriding the state then we need to go and retrieve
// them from the database. It's a hard error if they are missing.
- stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs)
+ stateEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, input.StateEventIDs)
if err != nil {
return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err)
}
@@ -604,7 +618,7 @@ func (r *Inputer) fetchAuthEvents(
}
for _, authEventID := range authEventIDs {
- authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID})
+ authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{}
continue
@@ -690,9 +704,11 @@ nextAuthEvent:
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
}
- roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent)
- if err != nil {
- return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
+ if roomInfo == nil {
+ roomInfo, err = r.DB.GetOrCreateRoomInfo(ctx, authEvent)
+ if err != nil {
+ return fmt.Errorf("r.DB.GetOrCreateRoomInfo: %w", err)
+ }
}
eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type())
@@ -706,7 +722,7 @@ nextAuthEvent:
}
// Finally, store the event in the database.
- eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
+ eventNID, _, err := r.DB.StoreEvent(ctx, authEvent, roomInfo, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err)
}
@@ -782,7 +798,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event
return err
}
- memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs)
+ memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs)
if err != nil {
return err
}
diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go
index 99a01255..e1dfa6cf 100644
--- a/roomserver/internal/input/input_membership.go
+++ b/roomserver/internal/input/input_membership.go
@@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships(
// Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON?
- events, err := updater.Events(ctx, 0, eventNIDs)
+ events, err := updater.Events(ctx, nil, eventNIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go
index c8b7d31d..9627f15a 100644
--- a/roomserver/internal/input/input_missing.go
+++ b/roomserver/internal/input/input_missing.go
@@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
for _, entry := range stateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
- stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs)
+ stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs)
if err != nil {
t.log.WithError(err).Warnf("failed to load state events locally")
return nil
@@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
missingEventList = append(missingEventList, evID)
}
t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
- events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList)
+ events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList)
if err != nil {
return nil
}
@@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
}
t.haveEventsMutex.Unlock()
- events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList)
+ events, err := t.db.EventsFromIDs(ctx, t.roomInfo, missingEventList)
if err != nil {
return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err)
}
@@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
if localFirst {
// fetch from the roomserver
- events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID})
+ events, err := t.db.EventsFromIDs(ctx, t.roomInfo, []string{missingEventID})
if err != nil {
t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
} else if len(events) == 1 {
diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go
index 2efe2255..45089bdd 100644
--- a/roomserver/internal/perform/perform_admin.go
+++ b/roomserver/internal/perform/perform_admin.go
@@ -70,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil
}
- memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs)
+ memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go
index 3a3a049d..411f4202 100644
--- a/roomserver/internal/perform/perform_backfill.go
+++ b/roomserver/internal/perform/perform_backfill.go
@@ -23,7 +23,6 @@ import (
"github.com/sirupsen/logrus"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
- "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
@@ -86,7 +85,7 @@ func (r *Backfiller) PerformBackfill(
// Retrieve events from the list that was filled previously. If we fail to get
// events from the database then attempt once to get them from federation instead.
var loadedEvents []*gomatrixserverlib.Event
- loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
+ loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
if err != nil {
if _, ok := err.(types.MissingEventError); ok {
return r.backfillViaFederation(ctx, request, response)
@@ -473,7 +472,7 @@ FindSuccessor:
// Retrieve all "m.room.member" state events of "join" membership, which
// contains the list of users in the room before the event, therefore all
// the servers in it at that moment.
- memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true)
+ memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info, stateEntries, true)
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
return nil
@@ -532,7 +531,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
roomNID = nid.RoomNID
}
}
- eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs)
+ eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs)
if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err
@@ -562,7 +561,7 @@ func joinEventsFromHistoryVisibility(
}
// Get all of the events in this state
- stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs)
+ stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
if err != nil {
// even though the default should be shared, restricting the visibility to joined
// feels more secure here.
@@ -585,7 +584,7 @@ func joinEventsFromHistoryVisibility(
if err != nil {
return nil, visibility, err
}
- evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs)
+ evs, err := db.Events(ctx, roomInfo, joinEventNIDs)
return evs, visibility, err
}
@@ -606,7 +605,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
i++
}
- roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap())
+ roomInfo, err := db.GetOrCreateRoomInfo(ctx, ev.Unwrap())
if err != nil {
logrus.WithError(err).Error("failed to get or create roomNID")
continue
@@ -624,23 +623,22 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
continue
}
- var redactedEventID string
- var redactionEvent *gomatrixserverlib.Event
- eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false)
+ eventNID, _, err = db.StoreEvent(ctx, ev.Unwrap(), roomInfo, eventTypeNID, eventStateKeyNID, authNids, false)
if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue
}
+
+ _, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev.Unwrap(), true)
+ if err != nil {
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
+ continue
+ }
// If storing this event results in it being redacted, then do so.
// It's also possible for this event to be a redaction which results in another event being
// redacted, which we don't care about since we aren't returning it in this backfill.
- if redactedEventID == ev.EventID() {
- eventToRedact := ev.Unwrap()
- if err := eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
- continue
- }
- ev = eventToRedact.Headered(ev.RoomVersion)
+ if redactedEvent != nil && redactedEvent.EventID() == ev.EventID() {
+ ev = redactedEvent.Headered(ev.RoomVersion)
events[j] = ev
}
backfilledEventMap[ev.EventID()] = types.Event{
diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go
index 9ac9edc4..1fb6eb43 100644
--- a/roomserver/internal/perform/perform_inbound_peek.go
+++ b/roomserver/internal/perform/perform_inbound_peek.go
@@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil {
return err
}
- latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID})
+ latestEvents, err := r.DB.EventsFromIDs(ctx, info, []string{latestEventRefs[0].EventID})
if err != nil {
return err
}
@@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil {
return err
}
- stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
+ stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
if err != nil {
return err
}
@@ -100,7 +100,7 @@ func (r *InboundPeeker) PerformInboundPeek(
}
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
- authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
+ authEvents, err := query.GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil {
return err
}
diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index 118e1b87..13d13f7b 100644
--- a/roomserver/internal/perform/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite(
// try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on
// trust.
- _, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs())
+ _, err = helpers.CheckAuthEvents(ctx, r.DB, info, event, event.AuthEventIDs())
if err != nil {
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event",
@@ -291,7 +291,7 @@ func buildInviteStrippedState(
for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID)
}
- stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs)
+ stateEvents, err := db.Events(ctx, info, stateNIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index ac34e0ff..c5b74422 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -21,11 +21,12 @@ import (
"errors"
"fmt"
- "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
+
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/acls"
@@ -102,7 +103,7 @@ func (r *Queryer) QueryStateAfterEvents(
return err
}
- stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
+ stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
if err != nil {
return err
}
@@ -114,7 +115,7 @@ func (r *Queryer) QueryStateAfterEvents(
}
authEventIDs = util.UniqueStrings(authEventIDs)
- authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
+ authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil {
return fmt.Errorf("getAuthChain: %w", err)
}
@@ -132,24 +133,46 @@ func (r *Queryer) QueryStateAfterEvents(
return nil
}
-// QueryEventsByID implements api.RoomserverInternalAPI
+// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
+// which room to use by querying the first events roomID.
func (r *Queryer) QueryEventsByID(
ctx context.Context,
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
- events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs)
+ if len(request.EventIDs) == 0 {
+ return nil
+ }
+ var err error
+ // We didn't receive a room ID, we need to fetch it first before we can continue.
+ // This happens for e.g. ` /_matrix/federation/v1/event/{eventId}`
+ var roomInfo *types.RoomInfo
+ if request.RoomID == "" {
+ var eventNIDs map[string]types.EventMetadata
+ eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]})
+ if err != nil {
+ return err
+ }
+ if len(eventNIDs) == 0 {
+ return nil
+ }
+ roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID)
+ } else {
+ roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID)
+ }
+ if err != nil {
+ return err
+ }
+ if roomInfo == nil {
+ return nil
+ }
+ events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs)
if err != nil {
return err
}
for _, event := range events {
- roomVersion, verr := r.roomVersion(event.RoomID())
- if verr != nil {
- return verr
- }
-
- response.Events = append(response.Events, event.Headered(roomVersion))
+ response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion))
}
return nil
@@ -186,7 +209,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom
response.HasBeenInRoom = true
- evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID})
+ evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID})
if err != nil {
return err
}
@@ -268,10 +291,10 @@ func (r *Queryer) QueryMembershipAtEvent(
// once. If we have more than one membership event, we need to get the state for each state entry.
if canShortCircuit {
if len(memberships) == 0 {
- memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
+ memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
}
} else {
- memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
+ memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
}
if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err)
@@ -318,7 +341,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
}
- events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
+ events, err = r.DB.Events(ctx, info, eventNIDs)
if err != nil {
return fmt.Errorf("r.DB.Events: %w", err)
}
@@ -357,14 +380,14 @@ func (r *Queryer) QueryMembershipsForRoom(
return err
}
- events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
+ events, err = r.DB.Events(ctx, info, eventNIDs)
} else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
}
- events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly)
+ events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly)
}
if err != nil {
@@ -412,39 +435,39 @@ func (r *Queryer) QueryServerJoinedToRoom(
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
func (r *Queryer) QueryServerAllowedToSeeEvent(
ctx context.Context,
- request *api.QueryServerAllowedToSeeEventRequest,
- response *api.QueryServerAllowedToSeeEventResponse,
-) (err error) {
- events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID})
+ serverName gomatrixserverlib.ServerName,
+ eventID string,
+) (allowed bool, err error) {
+ events, err := r.DB.EventNIDs(ctx, []string{eventID})
if err != nil {
return
}
if len(events) == 0 {
- response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
- return
- }
- roomID := events[0].RoomID()
-
- inRoomReq := &api.QueryServerJoinedToRoomRequest{
- RoomID: roomID,
- ServerName: request.ServerName,
+ return allowed, nil
}
- inRoomRes := &api.QueryServerJoinedToRoomResponse{}
- if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil {
- return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err)
- }
-
- info, err := r.DB.RoomInfo(ctx, roomID)
+ info, err := r.DB.RoomInfoByNID(ctx, events[eventID].RoomNID)
if err != nil {
- return err
+ return allowed, err
}
if info == nil || info.IsStub() {
- return nil
+ return allowed, nil
}
- response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
- ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom,
+ var isInRoom bool
+ if r.IsLocalServerName(serverName) || serverName == "" {
+ isInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID)
+ if err != nil {
+ return allowed, fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err)
+ }
+ } else {
+ isInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, serverName)
+ if err != nil {
+ return allowed, fmt.Errorf("r.DB.GetServerInRoom: %w", err)
+ }
+ }
+
+ return helpers.CheckServerAllowedToSeeEvent(
+ ctx, r.DB, info, eventID, serverName, isInRoom,
)
- return
}
// QueryMissingEvents implements api.RoomserverInternalAPI
@@ -466,19 +489,22 @@ func (r *Queryer) QueryMissingEvents(
eventsToFilter[id] = true
}
}
- events, err := r.DB.EventsFromIDs(ctx, 0, front)
+ if len(front) == 0 {
+ return nil // no events to query, give up.
+ }
+ events, err := r.DB.EventNIDs(ctx, []string{front[0]})
if err != nil {
return err
}
if len(events) == 0 {
return nil // we are missing the events being asked to search from, give up.
}
- info, err := r.DB.RoomInfo(ctx, events[0].RoomID())
+ info, err := r.DB.RoomInfoByNID(ctx, events[front[0]].RoomNID)
if err != nil {
return err
}
if info == nil || info.IsStub() {
- return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
+ return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
}
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
@@ -486,7 +512,7 @@ func (r *Queryer) QueryMissingEvents(
return err
}
- loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
+ loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
if err != nil {
return err
}
@@ -529,7 +555,7 @@ func (r *Queryer) QueryStateAndAuthChain(
// TODO: this probably means it should be a different query operation...
if request.OnlyFetchAuthChain {
var authEvents []*gomatrixserverlib.Event
- authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs)
+ authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs)
if err != nil {
return err
}
@@ -556,7 +582,7 @@ func (r *Queryer) QueryStateAndAuthChain(
}
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
- authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
+ authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil {
return err
}
@@ -611,18 +637,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
return nil, rejected, false, err
}
- events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries)
+ events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries)
return events, rejected, false, err
}
-type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error)
+type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error)
// GetAuthChain fetches the auth chain for the given auth events. An auth chain
// is the list of all events that are referenced in the auth_events section, and
// all their auth_events, recursively. The returned set of events contain the
// given events. Will *not* error if we don't have all auth events.
func GetAuthChain(
- ctx context.Context, fn eventsFromIDs, authEventIDs []string,
+ ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string,
) ([]*gomatrixserverlib.Event, error) {
// List of event IDs to fetch. On each pass, these events will be requested
// from the database and the `eventsToFetch` will be updated with any new
@@ -633,7 +659,7 @@ func GetAuthChain(
for len(eventsToFetch) > 0 {
// Try to retrieve the events from the database.
- events, err := fn(ctx, 0, eventsToFetch)
+ events, err := fn(ctx, roomInfo, eventsToFetch)
if err != nil {
return nil, err
}
@@ -852,7 +878,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS
}
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
- chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs)
+ chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs)
if err != nil {
return err
}
@@ -971,7 +997,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid
// membership event.
for _, joinNID := range joinNIDs {
- events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID})
+ events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID})
if err != nil || len(events) != 1 {
continue
}
diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go
index 16761157..265f326d 100644
--- a/roomserver/internal/query/query_test.go
+++ b/roomserver/internal/query/query_test.go
@@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
}
// EventsFromIDs implements RoomserverInternalAPIEventDB
-func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) {
+func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) (res []types.Event, err error) {
for _, evID := range eventIDs {
res = append(res, types.Event{
EventNID: 0,
@@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err)
}
- result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"})
+ result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}
@@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err)
}
- result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"})
+ result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e", "f"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}
diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go
index 304311c4..cfa27e54 100644
--- a/roomserver/roomserver_test.go
+++ b/roomserver/roomserver_test.go
@@ -278,6 +278,16 @@ func TestPurgeRoom(t *testing.T) {
if roomInfo == nil {
t.Fatalf("room does not exist")
}
+
+ //
+ roomInfo2, err := db.RoomInfoByNID(ctx, roomInfo.RoomNID)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(roomInfo, roomInfo2) {
+ t.Fatalf("expected roomInfos to be the same, but they aren't")
+ }
+
// remember the roomInfo before purging
existingRoomInfo := roomInfo
@@ -333,6 +343,10 @@ func TestPurgeRoom(t *testing.T) {
if roomInfo != nil {
t.Fatalf("room should not exist after purging: %+v", roomInfo)
}
+ roomInfo2, err = db.RoomInfoByNID(ctx, existingRoomInfo.RoomNID)
+ if err == nil {
+ t.Fatalf("expected room to not exist, but it does: %#v", roomInfo2)
+ }
// validation below
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index cec542d7..9af2f857 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -41,8 +41,8 @@ type StateResolutionStorage interface {
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
- Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
- EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
+ Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
+ EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
}
type StateResolution struct {
@@ -975,7 +975,7 @@ func (v *StateResolution) resolveConflictsV2(
// Store the newly found auth events in the auth set for this event.
var authEventMap map[string]types.StateEntry
- authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo.RoomNID, conflictedEvent, knownAuthEvents)
+ authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, v.roomInfo, conflictedEvent, knownAuthEvents)
if err != nil {
return err
}
@@ -1091,7 +1091,7 @@ func (v *StateResolution) loadStateEvents(
eventNIDs = append(eventNIDs, entry.EventNID)
}
}
- events, err := v.db.Events(ctx, v.roomInfo.RoomNID, eventNIDs)
+ events, err := v.db.Events(ctx, v.roomInfo, eventNIDs)
if err != nil {
return nil, nil, err
}
@@ -1120,7 +1120,7 @@ type authEventLoader struct {
// loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events.
func (l *authEventLoader) loadAuthEvents(
- ctx context.Context, roomNID types.RoomNID, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
+ ctx context.Context, roomInfo *types.RoomInfo, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
l.Lock()
defer l.Unlock()
@@ -1155,7 +1155,7 @@ func (l *authEventLoader) loadAuthEvents(
// If we need to get events from the database, go and fetch
// those now.
if len(l.lookupFromDB) > 0 {
- eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomNID, l.lookupFromDB)
+ eventsFromDB, err := l.v.db.EventsFromIDs(ctx, roomInfo, l.lookupFromDB)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
}
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index 88ec5667..a41a8a9b 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -29,6 +29,7 @@ type Database interface {
SupportsConcurrentRoomInputs() bool
// RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
+ RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
// Store the room state at an event in the database
AddState(
ctx context.Context,
@@ -69,12 +70,12 @@ type Database interface {
) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events.
- Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
+ Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
- // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
- StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
+ // Stores a matrix room event in the database. Returns the room NID, the state snapshot or an error.
+ StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
// Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database
// Returns a types.MissingEventError if the event IDs aren't in the database.
@@ -135,7 +136,7 @@ type Database interface {
// EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was
// not found.
// Returns an error if the retrieval went wrong.
- EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
+ EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
// Publish or unpublish a room from the room directory.
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
// Returns a list of room IDs for rooms which are published.
@@ -179,36 +180,53 @@ type Database interface {
GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error)
- GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
+ GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
+ MaybeRedactEvent(
+ ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
+ ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error)
}
type RoomDatabase interface {
+ EventDatabase
// RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
+ RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
// IsEventRejected returns true if the event is known and rejected.
IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error)
MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error)
- // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
- StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
- StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
- SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
- StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
- Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
- EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
- EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
- EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
- GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
+ GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (*types.RoomInfo, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
+ GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
+}
+
+type EventDatabase interface {
+ EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
+ EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
+ EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
+ StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
+ EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error)
+ SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
+ StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
+ SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
+ EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
+ EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
+ Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
+ // MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error
+ // (nil if there was nothing to do)
+ MaybeRedactEvent(
+ ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
+ ) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error)
+ StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, error)
}
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 87208438..d98a5cf9 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -194,23 +194,28 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
return err
}
d.Database = shared.Database{
- DB: db,
- Cache: cache,
- Writer: writer,
- EventTypesTable: eventTypes,
- EventStateKeysTable: eventStateKeys,
- EventJSONTable: eventJSON,
- EventsTable: events,
- RoomsTable: rooms,
- StateBlockTable: stateBlock,
- StateSnapshotTable: stateSnapshot,
- PrevEventsTable: prevEvents,
- RoomAliasesTable: roomAliases,
- InvitesTable: invites,
- MembershipTable: membership,
- PublishedTable: published,
- RedactionsTable: redactions,
- Purge: purge,
+ DB: db,
+ EventDatabase: shared.EventDatabase{
+ DB: db,
+ Cache: cache,
+ Writer: writer,
+ EventsTable: events,
+ EventJSONTable: eventJSON,
+ EventTypesTable: eventTypes,
+ EventStateKeysTable: eventStateKeys,
+ PrevEventsTable: prevEvents,
+ RedactionsTable: redactions,
+ },
+ Cache: cache,
+ Writer: writer,
+ RoomsTable: rooms,
+ StateBlockTable: stateBlock,
+ StateSnapshotTable: stateSnapshot,
+ RoomAliasesTable: roomAliases,
+ InvitesTable: invites,
+ MembershipTable: membership,
+ PublishedTable: published,
+ Purge: purge,
}
return nil
}
diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go
index 5006c3c5..dc1db082 100644
--- a/roomserver/storage/shared/room_updater.go
+++ b/roomserver/storage/shared/room_updater.go
@@ -116,8 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent
})
}
-func (u *RoomUpdater) Events(ctx context.Context, _ types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) {
- return u.d.events(ctx, u.txn, u.roomInfo.RoomNID, eventNIDs)
+func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
+ return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs)
}
func (u *RoomUpdater) SnapshotNIDFromEventID(
@@ -195,8 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs(
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
}
-func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) {
- return u.d.eventsFromIDs(ctx, u.txn, roomNID, eventIDs, NoFilter)
+func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) {
+ return u.d.eventsFromIDs(ctx, u.txn, u.roomInfo, eventIDs, NoFilter)
}
// IsReferenced implements types.RoomRecentEventsUpdater
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index aac5bc36..be3f228d 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -9,7 +9,6 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
- "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/caching"
@@ -28,6 +27,23 @@ import (
const redactionsArePermanent = true
type Database struct {
+ DB *sql.DB
+ EventDatabase
+ Cache caching.RoomServerCaches
+ Writer sqlutil.Writer
+ RoomsTable tables.Rooms
+ StateSnapshotTable tables.StateSnapshot
+ StateBlockTable tables.StateBlock
+ RoomAliasesTable tables.RoomAliases
+ InvitesTable tables.Invites
+ MembershipTable tables.Membership
+ PublishedTable tables.Published
+ Purge tables.Purge
+ GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
+}
+
+// EventDatabase contains all tables needed to work with events
+type EventDatabase struct {
DB *sql.DB
Cache caching.RoomServerCaches
Writer sqlutil.Writer
@@ -35,17 +51,8 @@ type Database struct {
EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys
- RoomsTable tables.Rooms
- StateSnapshotTable tables.StateSnapshot
- StateBlockTable tables.StateBlock
- RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents
- InvitesTable tables.Invites
- MembershipTable tables.Membership
- PublishedTable tables.Published
RedactionsTable tables.Redactions
- Purge tables.Purge
- GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
}
func (d *Database) SupportsConcurrentRoomInputs() bool {
@@ -58,13 +65,13 @@ func (d *Database) GetMembershipForHistoryVisibility(
return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...)
}
-func (d *Database) EventTypeNIDs(
+func (d *EventDatabase) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return d.eventTypeNIDs(ctx, nil, eventTypes)
}
-func (d *Database) eventTypeNIDs(
+func (d *EventDatabase) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID)
@@ -91,7 +98,7 @@ func (d *Database) eventTypeNIDs(
return result, nil
}
-func (d *Database) EventStateKeys(
+func (d *EventDatabase) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
@@ -116,13 +123,13 @@ func (d *Database) EventStateKeys(
return result, nil
}
-func (d *Database) EventStateKeyNIDs(
+func (d *EventDatabase) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
}
-func (d *Database) eventStateKeyNIDs(
+func (d *EventDatabase) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID)
@@ -174,7 +181,7 @@ func (d *Database) eventStateKeyNIDs(
return result, nil
}
-func (d *Database) StateEntriesForEventIDs(
+func (d *EventDatabase) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string, excludeRejected bool,
) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs, excludeRejected)
@@ -213,6 +220,17 @@ func (d *Database) stateEntriesForTuples(
return lists, nil
}
+func (d *Database) RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error) {
+ roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{roomNID})
+ if err != nil {
+ return nil, err
+ }
+ if len(roomIDs) == 0 {
+ return nil, fmt.Errorf("room does not exist")
+ }
+ return d.roomInfo(ctx, nil, roomIDs[0])
+}
+
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.roomInfo(ctx, nil, roomID)
}
@@ -292,7 +310,7 @@ func (d *Database) addState(
return
}
-func (d *Database) EventNIDs(
+func (d *EventDatabase) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventMetadata, error) {
return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
@@ -305,7 +323,7 @@ const (
FilterUnsentOnly UnsentFilter = true
)
-func (d *Database) eventNIDs(
+func (d *EventDatabase) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
) (map[string]types.EventMetadata, error) {
switch filter {
@@ -318,7 +336,7 @@ func (d *Database) eventNIDs(
}
}
-func (d *Database) SetState(
+func (d *EventDatabase) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@@ -326,19 +344,19 @@ func (d *Database) SetState(
})
}
-func (d *Database) StateAtEventIDs(
+func (d *EventDatabase) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
}
-func (d *Database) SnapshotNIDFromEventID(
+func (d *EventDatabase) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
return d.snapshotNIDFromEventID(ctx, nil, eventID)
}
-func (d *Database) snapshotNIDFromEventID(
+func (d *EventDatabase) snapshotNIDFromEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
@@ -351,17 +369,17 @@ func (d *Database) snapshotNIDFromEventID(
return stateNID, err
}
-func (d *Database) EventIDs(
+func (d *EventDatabase) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
}
-func (d *Database) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) {
- return d.eventsFromIDs(ctx, nil, roomNID, eventIDs, NoFilter)
+func (d *EventDatabase) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) {
+ return d.eventsFromIDs(ctx, nil, roomInfo, eventIDs, NoFilter)
}
-func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
+func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
if err != nil {
return nil, err
@@ -370,15 +388,9 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types
var nids []types.EventNID
for _, nid := range nidMap {
nids = append(nids, nid.EventNID)
- if roomNID != 0 && roomNID != nid.RoomNID {
- logrus.Errorf("expected events from room %d, but also found %d", roomNID, nid.RoomNID)
- }
- if roomNID == 0 {
- roomNID = nid.RoomNID
- }
}
- return d.events(ctx, txn, roomNID, nids)
+ return d.events(ctx, txn, roomInfo, nids)
}
func (d *Database) LatestEventIDs(
@@ -517,19 +529,17 @@ func (d *Database) GetInvitesForUser(
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
}
-func (d *Database) Events(
- ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID,
-) ([]types.Event, error) {
- return d.events(ctx, nil, roomNID, eventNIDs)
+func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
+ return d.events(ctx, nil, roomInfo, eventNIDs)
}
-func (d *Database) events(
- ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, inputEventNIDs types.EventNIDs,
+func (d *EventDatabase) events(
+ ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs,
) ([]types.Event, error) {
- if roomNID == 0 {
- // No need to go further, as we won't find any events for this room.
- return nil, nil
+ if roomInfo == nil { // this should never happen
+ return nil, fmt.Errorf("unable to parse events without roomInfo")
}
+
sort.Sort(inputEventNIDs)
events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs))
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
@@ -566,31 +576,9 @@ func (d *Database) events(
eventIDs = map[types.EventNID]string{}
}
- var roomVersion gomatrixserverlib.RoomVersion
- var fetchRoomVersion bool
- var ok bool
- var roomID string
- if roomID, ok = d.Cache.GetRoomServerRoomID(roomNID); ok {
- roomVersion, ok = d.Cache.GetRoomVersion(roomID)
- if !ok {
- fetchRoomVersion = true
- }
- }
-
- if roomVersion == "" || fetchRoomVersion {
- var dbRoomVersions map[types.RoomNID]gomatrixserverlib.RoomVersion
- dbRoomVersions, err = d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, []types.RoomNID{roomNID})
- if err != nil {
- return nil, err
- }
- if roomVersion, ok = dbRoomVersions[roomNID]; !ok {
- return nil, fmt.Errorf("unable to find roomversion for room %d", roomNID)
- }
- }
-
for _, eventJSON := range eventJSONs {
events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
- eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion,
+ eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomInfo.RoomVersion,
)
if err != nil {
return nil, err
@@ -660,8 +648,8 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e
return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID)
}
-// GetOrCreateRoomNID gets or creates a new roomNID for the given event
-func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, err error) {
+// GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID.
+func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserverlib.Event) (roomInfo *types.RoomInfo, err error) {
// Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
@@ -670,8 +658,9 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver
// room.
var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
- return 0, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
+ return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
}
+ var roomNID types.RoomNID
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion)
if err != nil {
@@ -679,7 +668,10 @@ func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserver
}
return nil
})
- return roomNID, err
+ return &types.RoomInfo{
+ RoomVersion: roomVersion,
+ RoomNID: roomNID,
+ }, err
}
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
@@ -710,25 +702,22 @@ func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKe
return eventStateKeyNID, nil
}
-func (d *Database) StoreEvent(
+func (d *EventDatabase) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
- roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
+ roomInfo *types.RoomInfo, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
authEventNIDs []types.EventNID, isRejected bool,
-) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
+) (types.EventNID, types.StateAtEvent, error) {
var (
- eventNID types.EventNID
- stateNID types.StateSnapshotNID
- redactionEvent *gomatrixserverlib.Event
- redactedEventID string
- err error
+ eventNID types.EventNID
+ stateNID types.StateSnapshotNID
+ err error
)
- // Second writer is using the database-provided transaction, probably from the
- // room updater, for easy roll-back if required.
+
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if eventNID, stateNID, err = d.EventsTable.InsertEvent(
ctx,
txn,
- roomNID,
+ roomInfo.RoomNID,
eventTypeNID,
eventStateKeyNID,
event.EventID(),
@@ -751,16 +740,26 @@ func (d *Database) StoreEvent(
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
}
- if !isRejected { // ignore rejected redaction events
- redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, roomNID, eventNID, event)
- if err != nil {
- return fmt.Errorf("d.handleRedactions: %w", err)
+
+ if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
+ // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
+ // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
+ // function only does SELECTs though so the created txn (at this point) is just a read txn like
+ // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
+ // to do writes however then this will need to go inside `Writer.Do`.
+
+ // The following is a copy of RoomUpdater.StorePreviousEvents
+ for _, ref := range prevEvents {
+ if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
+ return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
+ }
}
}
+
return nil
})
if err != nil {
- return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err)
+ return 0, types.StateAtEvent{}, fmt.Errorf("d.Writer.Do: %w", err)
}
// We should attempt to update the previous events table with any
@@ -768,33 +767,6 @@ func (d *Database) StoreEvent(
// events updater because it somewhat works as a mutex, ensuring
// that there's a row-level lock on the latest room events (well,
// on Postgres at least).
- if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
- // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
- // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
- // function only does SELECTs though so the created txn (at this point) is just a read txn like
- // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
- // to do writes however then this will need to go inside `Writer.Do`.
- succeeded := false
- var roomInfo *types.RoomInfo
- roomInfo, err = d.roomInfo(ctx, nil, event.RoomID())
- if err != nil {
- return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
- }
- if roomInfo == nil && len(prevEvents) > 0 {
- return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
- }
- var updater *RoomUpdater
- updater, err = d.GetRoomUpdater(ctx, roomInfo)
- if err != nil {
- return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
- }
- defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
-
- if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
- return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
- }
- succeeded = true
- }
return eventNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID,
@@ -805,7 +777,7 @@ func (d *Database) StoreEvent(
},
EventNID: eventNID,
},
- }, redactionEvent, redactedEventID, err
+ }, err
}
func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error {
@@ -893,7 +865,7 @@ func (d *Database) assignEventTypeNID(
return eventTypeNID, nil
}
-func (d *Database) assignStateKeyNID(
+func (d *EventDatabase) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey)
@@ -937,7 +909,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
return roomVersion, err
}
-// handleRedactions manages the redacted status of events. There's two cases to consider in order to comply with the spec:
+// MaybeRedactEvent manages the redacted status of events. There's two cases to consider in order to comply with the spec:
// "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid."
// https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
// These cases are:
@@ -952,95 +924,95 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
// when loading events to determine whether to apply redactions. This keeps the hot-path of reading events quick as we don't need
// to cross-reference with other tables when loading.
//
-// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction.
-func (d *Database) handleRedactions(
- ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event,
-) (*gomatrixserverlib.Event, string, error) {
- var err error
- isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
- if isRedactionEvent {
- // an event which redacts itself should be ignored
- if event.EventID() == event.Redacts() {
- return nil, "", nil
- }
+// Returns the redaction event and the redacted event if this call resulted in a redaction.
+func (d *EventDatabase) MaybeRedactEvent(
+ ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event, redactAllowed bool,
+) (*gomatrixserverlib.Event, *gomatrixserverlib.Event, error) {
+ var (
+ redactionEvent, redactedEvent *types.Event
+ err error
+ validated bool
+ ignoreRedaction bool
+ )
- err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{
- Validated: false,
- RedactionEventID: event.EventID(),
- RedactsEventID: event.Redacts(),
- })
- if err != nil {
- return nil, "", fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err)
+ wErr := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
+ if isRedactionEvent {
+ // an event which redacts itself should be ignored
+ if event.EventID() == event.Redacts() {
+ return nil
+ }
+
+ err = d.RedactionsTable.InsertRedaction(ctx, txn, tables.RedactionInfo{
+ Validated: false,
+ RedactionEventID: event.EventID(),
+ RedactsEventID: event.Redacts(),
+ })
+ if err != nil {
+ return fmt.Errorf("d.RedactionsTable.InsertRedaction: %w", err)
+ }
}
- }
- redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, roomNID, eventNID, event)
- if err != nil {
- return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err)
- }
- if validated || redactedEvent == nil || redactionEvent == nil {
- // we've seen this redaction before or there is nothing to redact
- return nil, "", nil
- }
- if redactedEvent.RoomID() != redactionEvent.RoomID() {
- // redactions across rooms aren't allowed
- return nil, "", nil
- }
+ redactionEvent, redactedEvent, validated, err = d.loadRedactionPair(ctx, txn, roomInfo, eventNID, event)
+ switch {
+ case err != nil:
+ return fmt.Errorf("d.loadRedactionPair: %w", err)
+ case validated || redactedEvent == nil || redactionEvent == nil:
+ // we've seen this redaction before or there is nothing to redact
+ return nil
+ case redactedEvent.RoomID() != redactionEvent.RoomID():
+ // redactions across rooms aren't allowed
+ ignoreRedaction = true
+ return nil
+ }
- // Get the power level from the database, so we can verify the user is allowed to redact the event
- powerLevels, err := d.GetStateEvent(ctx, event.RoomID(), gomatrixserverlib.MRoomPowerLevels, "")
- if err != nil {
- return nil, "", fmt.Errorf("d.GetStateEvent: %w", err)
- }
- if powerLevels == nil {
- return nil, "", fmt.Errorf("unable to fetch m.room.power_levels event from database for room %s", event.RoomID())
- }
- pl, err := powerLevels.PowerLevels()
- if err != nil {
- return nil, "", fmt.Errorf("unable to get powerlevels for room: %w", err)
- }
+ // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. (redactAllowed)
+ // 2. The domain of the redaction event’s sender matches that of the original event’s sender.
+ _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender())
+ _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender())
+ if !redactAllowed || sender1 != sender2 {
+ ignoreRedaction = true
+ return nil
+ }
- redactUser := pl.UserLevel(redactionEvent.Sender())
- switch {
- case redactUser >= pl.Redact:
- // The power level of the redaction event’s sender is greater than or equal to the redact level.
- case redactedEvent.Sender() == redactionEvent.Sender():
- // The domain of the redaction event’s sender matches that of the original event’s sender.
- default:
- return nil, "", nil
- }
+ // mark the event as redacted
+ if redactionsArePermanent {
+ redactedEvent.Redact()
+ }
- // mark the event as redacted
- if redactionsArePermanent {
- redactedEvent.Redact()
- }
+ err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
+ if err != nil {
+ return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
+ }
+ // NOTSPEC: sytest relies on this unspecced field existing :(
+ err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID())
+ if err != nil {
+ return fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
+ }
+ // overwrite the eventJSON table
+ err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON())
+ if err != nil {
+ return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
+ }
- err = redactedEvent.SetUnsignedField("redacted_because", redactionEvent)
- if err != nil {
- return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
- }
- // NOTSPEC: sytest relies on this unspecced field existing :(
- err = redactedEvent.SetUnsignedField("redacted_by", redactionEvent.EventID())
- if err != nil {
- return nil, "", fmt.Errorf("redactedEvent.SetUnsignedField: %w", err)
- }
- // overwrite the eventJSON table
- err = d.EventJSONTable.InsertEventJSON(ctx, txn, redactedEvent.EventNID, redactedEvent.JSON())
- if err != nil {
- return nil, "", fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
+ err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
+ if err != nil {
+ return fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
+ }
+ return nil
+ })
+ if wErr != nil {
+ return nil, nil, err
}
-
- err = d.RedactionsTable.MarkRedactionValidated(ctx, txn, redactionEvent.EventID(), true)
- if err != nil {
- err = fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err)
+ if ignoreRedaction || redactionEvent == nil || redactedEvent == nil {
+ return nil, nil, nil
}
-
- return redactionEvent.Event, redactedEvent.EventID(), err
+ return redactionEvent.Event, redactedEvent.Event, nil
}
// loadRedactionPair returns both the redaction event and the redacted event, else nil.
-func (d *Database) loadRedactionPair(
- ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event,
+func (d *EventDatabase) loadRedactionPair(
+ ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, eventNID types.EventNID, event *gomatrixserverlib.Event,
) (*types.Event, *types.Event, bool, error) {
var redactionEvent, redactedEvent *types.Event
var info *tables.RedactionInfo
@@ -1072,16 +1044,16 @@ func (d *Database) loadRedactionPair(
}
if isRedactionEvent {
- redactedEvent = d.loadEvent(ctx, roomNID, info.RedactsEventID)
+ redactedEvent = d.loadEvent(ctx, roomInfo, info.RedactsEventID)
} else {
- redactionEvent = d.loadEvent(ctx, roomNID, info.RedactionEventID)
+ redactionEvent = d.loadEvent(ctx, roomInfo, info.RedactionEventID)
}
return redactionEvent, redactedEvent, info.Validated, nil
}
// applyRedactions will redact events that have an `unsigned.redacted_because` field.
-func (d *Database) applyRedactions(events []types.Event) {
+func (d *EventDatabase) applyRedactions(events []types.Event) {
for i := range events {
if result := gjson.GetBytes(events[i].Unsigned(), "redacted_because"); result.Exists() {
events[i].Redact()
@@ -1090,7 +1062,7 @@ func (d *Database) applyRedactions(events []types.Event) {
}
// loadEvent loads a single event or returns nil on any problems/missing event
-func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID string) *types.Event {
+func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo, eventID string) *types.Event {
nids, err := d.EventNIDs(ctx, []string{eventID})
if err != nil {
return nil
@@ -1098,7 +1070,7 @@ func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID
if len(nids) == 0 {
return nil
}
- evs, err := d.Events(ctx, roomNID, []types.EventNID{nids[eventID].EventNID})
+ evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID})
if err != nil {
return nil
}
@@ -1144,7 +1116,7 @@ func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *type
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
- roomInfo, err := d.RoomInfo(ctx, roomID)
+ roomInfo, err := d.roomInfo(ctx, nil, roomID)
if err != nil {
return nil, err
}
@@ -1209,7 +1181,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
// Same as GetStateEvent but returns all matching state events with this event type. Returns no error
// if there are no events with this event type.
func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*gomatrixserverlib.HeaderedEvent, error) {
- roomInfo, err := d.RoomInfo(ctx, roomID)
+ roomInfo, err := d.roomInfo(ctx, nil, roomID)
if err != nil {
return nil, err
}
@@ -1340,7 +1312,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion)
// TODO: This feels like this is going to be really slow...
for _, roomID := range roomIDs {
- roomInfo, err2 := d.RoomInfo(ctx, roomID)
+ roomInfo, err2 := d.roomInfo(ctx, nil, roomID)
if err2 != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2)
}
diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go
index 3acb55a3..684e80b8 100644
--- a/roomserver/storage/shared/storage_test.go
+++ b/roomserver/storage/shared/storage_test.go
@@ -52,12 +52,14 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat
cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
+ evDb := shared.EventDatabase{EventStateKeysTable: stateKeyTable, Cache: cache}
+
return &shared.Database{
- DB: db,
- EventStateKeysTable: stateKeyTable,
- MembershipTable: membershipTable,
- Writer: sqlutil.NewExclusiveWriter(),
- Cache: cache,
+ DB: db,
+ EventDatabase: evDb,
+ MembershipTable: membershipTable,
+ Writer: sqlutil.NewExclusiveWriter(),
+ Cache: cache,
}, func() {
err := base.Close()
assert.NoError(t, err)
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index 392edd28..2adedd2d 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -203,24 +203,29 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
}
d.Database = shared.Database{
- DB: db,
- Cache: cache,
- Writer: writer,
- EventsTable: events,
- EventTypesTable: eventTypes,
- EventStateKeysTable: eventStateKeys,
- EventJSONTable: eventJSON,
- RoomsTable: rooms,
- StateBlockTable: stateBlock,
- StateSnapshotTable: stateSnapshot,
- PrevEventsTable: prevEvents,
- RoomAliasesTable: roomAliases,
- InvitesTable: invites,
- MembershipTable: membership,
- PublishedTable: published,
- RedactionsTable: redactions,
- GetRoomUpdaterFn: d.GetRoomUpdater,
- Purge: purge,
+ DB: db,
+ EventDatabase: shared.EventDatabase{
+ DB: db,
+ Cache: cache,
+ Writer: writer,
+ EventsTable: events,
+ EventTypesTable: eventTypes,
+ EventStateKeysTable: eventStateKeys,
+ EventJSONTable: eventJSON,
+ PrevEventsTable: prevEvents,
+ RedactionsTable: redactions,
+ },
+ Cache: cache,
+ Writer: writer,
+ RoomsTable: rooms,
+ StateBlockTable: stateBlock,
+ StateSnapshotTable: stateSnapshot,
+ RoomAliasesTable: roomAliases,
+ InvitesTable: invites,
+ MembershipTable: membership,
+ PublishedTable: published,
+ GetRoomUpdaterFn: d.GetRoomUpdater,
+ Purge: purge,
}
return nil
}
diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go
index bc369c16..4bb6a5ee 100644
--- a/setup/mscs/msc2836/msc2836.go
+++ b/setup/mscs/msc2836/msc2836.go
@@ -253,7 +253,7 @@ func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONRespo
var res MSC2836EventRelationshipsResponse
var returnEvents []*gomatrixserverlib.HeaderedEvent
// Can the user see (according to history visibility) event_id? If no, reject the request, else continue.
- event := rc.getLocalEvent(rc.req.EventID)
+ event := rc.getLocalEvent(rc.req.RoomID, rc.req.EventID)
if event == nil {
event = rc.fetchUnknownEvent(rc.req.EventID, rc.req.RoomID)
}
@@ -592,7 +592,7 @@ func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelation
// lookForEvent returns the event for the event ID given, by trying to query remote servers
// if the event ID is unknown via /event_relationships.
func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
- event := rc.getLocalEvent(eventID)
+ event := rc.getLocalEvent(rc.req.RoomID, eventID)
if event == nil {
queryRes := rc.remoteEventRelationships(eventID)
if queryRes != nil {
@@ -622,9 +622,10 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent
return nil
}
-func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent {
+func (rc *reqCtx) getLocalEvent(roomID, eventID string) *gomatrixserverlib.HeaderedEvent {
var queryEventsRes roomserver.QueryEventsByIDResponse
err := rc.rsAPI.QueryEventsByID(rc.ctx, &roomserver.QueryEventsByIDRequest{
+ RoomID: roomID,
EventIDs: []string{eventID},
}, &queryEventsRes)
if err != nil {
diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go
index 21838039..a8d4d2b2 100644
--- a/syncapi/consumers/roomserver.go
+++ b/syncapi/consumers/roomserver.go
@@ -212,6 +212,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
// Finally, work out if there are any more events missing.
if len(missingEventIDs) > 0 {
eventsReq := &api.QueryEventsByIDRequest{
+ RoomID: ev.RoomID(),
EventIDs: missingEventIDs,
}
eventsRes := &api.QueryEventsByIDResponse{}
diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go
index 9ffdf513..8efd77ce 100644
--- a/syncapi/routing/memberships.go
+++ b/syncapi/routing/memberships.go
@@ -109,7 +109,7 @@ func GetMemberships(
}
qryRes := &api.QueryEventsByIDResponse{}
- if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs}, qryRes); err != nil {
+ if err := rsAPI.QueryEventsByID(req.Context(), &api.QueryEventsByIDRequest{EventIDs: eventIDs, RoomID: roomID}, qryRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed")
return jsonerror.InternalServerError()
}