diff options
author | devonh <devon.dmytro@gmail.com> | 2023-09-15 14:39:06 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-15 14:39:06 +0000 |
commit | 8245b24100b0afaa046bb3fe52f0994f906c8ab1 (patch) | |
tree | f8d65fd4c27e208e5829f5ca0f66d79c639fabdb /roomserver | |
parent | 058081e68e4e23400645c6206cedddba8a31507e (diff) |
Update gmsl to use new validated RoomID on PDUs (#3200)
GMSL returns a `spec.RoomID` when calling `PDU.RoomID()`
Diffstat (limited to 'roomserver')
-rw-r--r-- | roomserver/acls/acls.go | 2 | ||||
-rw-r--r-- | roomserver/api/wrapper.go | 2 | ||||
-rw-r--r-- | roomserver/auth/auth.go | 6 | ||||
-rw-r--r-- | roomserver/internal/alias.go | 2 | ||||
-rw-r--r-- | roomserver/internal/api.go | 2 | ||||
-rw-r--r-- | roomserver/internal/helpers/auth.go | 4 | ||||
-rw-r--r-- | roomserver/internal/helpers/helpers.go | 6 | ||||
-rw-r--r-- | roomserver/internal/input/input.go | 2 | ||||
-rw-r--r-- | roomserver/internal/input/input_events.go | 54 | ||||
-rw-r--r-- | roomserver/internal/input/input_latest_events.go | 4 | ||||
-rw-r--r-- | roomserver/internal/input/input_membership.go | 10 | ||||
-rw-r--r-- | roomserver/internal/input/input_missing.go | 14 | ||||
-rw-r--r-- | roomserver/internal/perform/perform_backfill.go | 8 | ||||
-rw-r--r-- | roomserver/internal/perform/perform_invite.go | 8 | ||||
-rw-r--r-- | roomserver/internal/query/query_room_hierarchy.go | 4 | ||||
-rw-r--r-- | roomserver/internal/query/query_test.go | 1 | ||||
-rw-r--r-- | roomserver/storage/shared/storage.go | 21 |
17 files changed, 58 insertions, 92 deletions
diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index b04828b6..601ce906 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -119,7 +119,7 @@ func (s *ServerACLs) OnServerACLUpdate(state gomatrixserverlib.PDU) { }).Debugf("Updating server ACLs for %q", state.RoomID()) s.aclsMutex.Lock() defer s.aclsMutex.Unlock() - s.acls[state.RoomID()] = acls + s.acls[state.RoomID().String()] = acls } func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID string) bool { diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 2505a993..0ad5d201 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -75,7 +75,7 @@ func SendEventWithState( } logrus.WithContext(ctx).WithFields(logrus.Fields{ - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), "event_id": event.EventID(), "outliers": len(ires), "state_ids": len(stateEventIDs), diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index df95851e..d5172dab 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -85,11 +85,7 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySende continue } - validRoomID, err := spec.NewRoomID(ev.RoomID()) - if err != nil { - continue - } - userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey)) + userID, err := querier.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey)) if err != nil { continue } diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index a7f0aab9..5ceda7e0 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -189,7 +189,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(ctx context.Context, senderID sp proto := &gomatrixserverlib.ProtoEvent{ SenderID: string(canonicalSenderID), - RoomID: ev.RoomID(), + RoomID: ev.RoomID().String(), Type: ev.Type(), StateKey: ev.StateKey(), Content: res, diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 530147da..1e08f6a3 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -239,7 +239,7 @@ func (r *RoomserverInternalAPI) HandleInvite( if err != nil { return err } - return r.OutputProducer.ProduceRoomEvents(inviteEvent.RoomID(), outputEvents) + return r.OutputProducer.ProduceRoomEvents(inviteEvent.RoomID().String(), outputEvents) } func (r *RoomserverInternalAPI) PerformCreateRoom( diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 89fae244..9da751b1 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -218,9 +218,9 @@ func loadAuthEvents( roomID := "" for _, ev := range result.events { if roomID == "" { - roomID = ev.RoomID() + roomID = ev.RoomID().String() } - if ev.RoomID() != roomID { + if ev.RoomID().String() != roomID { result.valid = false break } diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index febabf41..b2e21bf5 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -54,7 +54,7 @@ func UpdateToInviteMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - RoomID: add.RoomID(), + RoomID: add.RoomID().String(), Membership: spec.Join, RetiredByEventID: add.EventID(), TargetSenderID: spec.SenderID(*add.StateKey()), @@ -396,7 +396,7 @@ BFSLoop: // It's nasty that we have to extract the room ID from an event, but many federation requests // only talk in event IDs, no room IDs at all (!!!) ev := events[0] - isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID()) + isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID().String()) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") } @@ -419,7 +419,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier) + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID().String(), pre, serverName, isServerInRoom, querier) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 99056359..40475153 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -358,7 +358,7 @@ func (r *Inputer) queueInputRoomEvents( // For each event, marshal the input room event and then // send it into the input queue. for _, e := range request.InputRoomEvents { - roomID := e.Event.RoomID() + roomID := e.Event.RoomID().String() subj := r.Cfg.Matrix.JetStream.Prefixed(jetstream.InputRoomEventSubj(roomID)) msg := &nats.Msg{ Subject: subj, diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index bf321662..77b50d0e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -87,7 +87,7 @@ func (r *Inputer) processRoomEvent( } trace, ctx := internal.StartRegion(ctx, "processRoomEvent") - trace.SetTag("room_id", input.Event.RoomID()) + trace.SetTag("room_id", input.Event.RoomID().String()) trace.SetTag("event_id", input.Event.EventID()) defer trace.EndRegion() @@ -96,7 +96,7 @@ func (r *Inputer) processRoomEvent( defer func() { timetaken := time.Since(started) processRoomEventDuration.With(prometheus.Labels{ - "room_id": input.Event.RoomID(), + "room_id": input.Event.RoomID().String(), }).Observe(float64(timetaken.Milliseconds())) }() @@ -105,7 +105,7 @@ func (r *Inputer) processRoomEvent( event := headered.PDU logger := util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": event.EventID(), - "room_id": event.RoomID(), + "room_id": event.RoomID().String(), "kind": input.Kind, "origin": input.Origin, "type": event.Type(), @@ -120,19 +120,15 @@ func (r *Inputer) processRoomEvent( // Don't waste time processing the event if the room doesn't exist. // A room entry locally will only be created in response to a create // event. - roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID()) + roomInfo, rerr := r.DB.RoomInfo(ctx, event.RoomID().String()) if rerr != nil { return fmt.Errorf("r.DB.RoomInfo: %w", rerr) } isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("") if roomInfo == nil && !isCreateEvent { - return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) + return fmt.Errorf("room %s does not exist for event %s", event.RoomID().String(), event.EventID()) } - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return err - } - sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()) + sender, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) } @@ -179,7 +175,7 @@ func (r *Inputer) processRoomEvent( // If we have missing events (auth or prev), we build a list of servers to ask if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), ExcludeSelf: true, ExcludeBlacklisted: true, } @@ -395,12 +391,12 @@ func (r *Inputer) processRoomEvent( // Request the room info again — it's possible that the room has been // created by now if it didn't exist already. - roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID()) + roomInfo, err = r.DB.RoomInfo(ctx, event.RoomID().String()) if err != nil { return fmt.Errorf("updater.RoomInfo: %w", err) } if roomInfo == nil { - return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) + return fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID().String()) } if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) { @@ -459,7 +455,7 @@ func (r *Inputer) processRoomEvent( if userErr != nil { return userErr } - err = r.RSAPI.StoreUserRoomPublicKey(ctx, mapping.MXIDMapping.UserRoomKey, *storeUserID, *validRoomID) + err = r.RSAPI.StoreUserRoomPublicKey(ctx, mapping.MXIDMapping.UserRoomKey, *storeUserID, event.RoomID()) if err != nil { return fmt.Errorf("failed storing user room public key: %w", err) } @@ -481,7 +477,7 @@ func (r *Inputer) processRoomEvent( return fmt.Errorf("r.updateLatestEvents: %w", err) } case api.KindOld: - err = r.OutputProducer.ProduceRoomEvents(event.RoomID(), []api.OutputEvent{ + err = r.OutputProducer.ProduceRoomEvents(event.RoomID().String(), []api.OutputEvent{ { Type: api.OutputTypeOldRoomEvent, OldRoomEvent: &api.OutputOldRoomEvent{ @@ -507,7 +503,7 @@ func (r *Inputer) processRoomEvent( // so notify downstream components to redact this event - they should have it if they've // been tracking our output log. if redactedEventID != "" { - err = r.OutputProducer.ProduceRoomEvents(event.RoomID(), []api.OutputEvent{ + err = r.OutputProducer.ProduceRoomEvents(event.RoomID().String(), []api.OutputEvent{ { Type: api.OutputTypeRedactedEvent, RedactedEvent: &api.OutputRedactedEvent{ @@ -536,7 +532,7 @@ func (r *Inputer) processRoomEvent( // handleRemoteRoomUpgrade updates published rooms and room aliases func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { - oldRoomID := event.RoomID() + oldRoomID := event.RoomID().String() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, string(event.SenderID())) } @@ -596,7 +592,7 @@ func (r *Inputer) processStateBefore( StateKey: "", }) stateBeforeReq := &api.QueryStateAfterEventsRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), PrevEventIDs: event.PrevEventIDs(), StateToFetch: tuplesNeeded, } @@ -606,7 +602,7 @@ func (r *Inputer) processStateBefore( } switch { case !stateBeforeRes.RoomExists: - rejectionErr = fmt.Errorf("room %q does not exist", event.RoomID()) + rejectionErr = fmt.Errorf("room %q does not exist", event.RoomID().String()) return case !stateBeforeRes.PrevEventsExist: rejectionErr = fmt.Errorf("prev events of %q are not known", event.EventID()) @@ -707,7 +703,7 @@ func (r *Inputer) fetchAuthEvents( // Request the entire auth chain for the event in question. This should // contain all of the auth events — including ones that we already know — // so we'll need to filter through those in the next section. - res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.Version(), event.RoomID(), event.EventID()) + res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.Version(), event.RoomID().String(), event.EventID()) if err != nil { logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) continue @@ -866,25 +862,20 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) latestReq := &api.QueryLatestEventsAndStateRequest{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), } latestRes := &api.QueryLatestEventsAndStateResponse{} if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { return err } - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return err - } - prevEvents := latestRes.LatestEvents for _, memberEvent := range memberEvents { if memberEvent.StateKey() == nil { continue } - memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey())) + memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*memberEvent.StateKey())) if err != nil { continue } @@ -912,7 +903,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r stateKey := *memberEvent.StateKey() fledglingEvent := &gomatrixserverlib.ProtoEvent{ - RoomID: event.RoomID(), + RoomID: event.RoomID().String(), Type: spec.MRoomMember, StateKey: &stateKey, SenderID: stateKey, @@ -928,12 +919,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r return err } - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return err - } - - signingIdentity, err := r.SigningIdentity(ctx, *validRoomID, *memberUserID) + signingIdentity, err := r.SigningIdentity(ctx, event.RoomID(), *memberUserID) if err != nil { return err } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 940783e0..ec03d6f1 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -197,7 +197,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // send the event asynchronously but we would need to ensure that 1) the events are written to the log in // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // necessary bookkeeping we'll keep the event sending synchronous for now. - if err = u.api.OutputProducer.ProduceRoomEvents(u.event.RoomID(), updates); err != nil { + if err = u.api.OutputProducer.ProduceRoomEvents(u.event.RoomID().String(), updates); err != nil { return fmt.Errorf("u.api.WriteOutputEvents: %w", err) } @@ -290,7 +290,7 @@ func (u *latestEventsUpdater) latestState() error { if removed := len(u.removed) - len(u.added); !u.rewritesState && removed > 0 { logrus.WithFields(logrus.Fields{ "event_id": u.event.EventID(), - "room_id": u.event.RoomID(), + "room_id": u.event.RoomID().String(), "old_state_nid": u.oldStateNID, "new_state_nid": u.newStateNID, "old_latest": u.oldLatest.EventIDs(), diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index c46f8dba..4cfc2cda 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -139,11 +139,7 @@ func (r *Inputer) updateMembership( func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool { isTargetLocalUser := false if statekey := event.StateKey(); statekey != nil { - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - return isTargetLocalUser - } - userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey)) + userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey)) if err != nil || userID == nil { return isTargetLocalUser } @@ -168,7 +164,7 @@ func updateToJoinMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - RoomID: add.RoomID(), + RoomID: add.RoomID().String(), Membership: spec.Join, RetiredByEventID: add.EventID(), TargetSenderID: spec.SenderID(*add.StateKey()), @@ -195,7 +191,7 @@ func updateToLeaveMembership( Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ EventID: eventID, - RoomID: add.RoomID(), + RoomID: add.RoomID().String(), Membership: newMembership, RetiredByEventID: add.EventID(), TargetSenderID: spec.SenderID(*add.StateKey()), diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 5b4c0727..d9ab291e 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -84,7 +84,7 @@ func (t *missingStateReq) processEventWithMissingState( // need to fallback to /state. t.log = util.GetLogger(ctx).WithFields(map[string]interface{}{ "txn_event": e.EventID(), - "room_id": e.RoomID(), + "room_id": e.RoomID().String(), "txn_prev_events": e.PrevEventIDs(), }) @@ -264,7 +264,7 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e // Look up what the state is after the backward extremity. This will either // come from the roomserver, if we know all the required events, or it will // come from a remote server via /state_ids if not. - prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID(), prevEventID) + prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID().String(), prevEventID) switch err2 := err.(type) { case gomatrixserverlib.EventValidationError: if !err2.Persistable { @@ -316,9 +316,9 @@ func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e } // There's more than one previous state - run them all through state res var err error - t.roomsMu.Lock(e.RoomID()) + t.roomsMu.Lock(e.RoomID().String()) resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, e) - t.roomsMu.Unlock(e.RoomID()) + t.roomsMu.Unlock(e.RoomID().String()) switch err2 := err.(type) { case gomatrixserverlib.EventValidationError: if !err2.Persistable { @@ -510,7 +510,7 @@ retryAllowedState: }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: - h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) + h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID().String(), missing.AuthEventID, true) switch e := err2.(type) { case gomatrixserverlib.EventValidationError: if !e.Persistable { @@ -546,7 +546,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver trace, ctx := internal.StartRegion(ctx, "getMissingEvents") defer trace.EndRegion() - logger := t.log.WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) + logger := t.log.WithField("event_id", e.EventID()).WithField("room_id", e.RoomID().String()) latest, _, _, err := t.db.LatestEventIDs(ctx, t.roomInfo.RoomNID) if err != nil { return nil, false, false, fmt.Errorf("t.DB.LatestEventIDs: %w", err) @@ -560,7 +560,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver var missingResp *fclient.RespMissingEvents for _, server := range t.servers { var m fclient.RespMissingEvents - if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), fclient.MissingEvents{ + if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID().String(), fclient.MissingEvents{ Limit: 20, // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. EarliestEvents: latestEvents, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 33200e81..dafa5873 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -301,7 +301,7 @@ func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent return ids, nil } if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") { - util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID()).Info("Backfilled to the beginning of the room") + util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID().String()).Info("Backfilled to the beginning of the room") b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{} return nil, nil } @@ -494,11 +494,7 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - validRoomID, err := spec.NewRoomID(event.RoomID()) - if err != nil { - continue - } - if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil { + if sender, err := b.querier.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { serverSet[sender.Domain()] = true } } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index e07780d6..3abb69cb 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -100,16 +100,12 @@ func (r *Inviter) ProcessInviteMembership( var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - validRoomID, err := spec.NewRoomID(inviteEvent.RoomID()) - if err != nil { - return nil, err - } - userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey())) + userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain()) - if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { + if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID().String(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{ diff --git a/roomserver/internal/query/query_room_hierarchy.go b/roomserver/internal/query/query_room_hierarchy.go index 7274be52..76eba12b 100644 --- a/roomserver/internal/query/query_room_hierarchy.go +++ b/roomserver/internal/query/query_room_hierarchy.go @@ -513,14 +513,14 @@ func restrictedJoinRuleAllowedRooms(ctx context.Context, joinRuleEv *types.Heade } var jrContent gomatrixserverlib.JoinRuleContent if err := json.Unmarshal(joinRuleEv.Content(), &jrContent); err != nil { - util.GetLogger(ctx).Warnf("failed to check join_rule on room %s: %s", joinRuleEv.RoomID(), err) + util.GetLogger(ctx).Warnf("failed to check join_rule on room %s: %s", joinRuleEv.RoomID().String(), err) return nil } for _, allow := range jrContent.Allow { if allow.Type == spec.MRoomMembership { allowedRoomID, err := spec.NewRoomID(allow.RoomID) if err != nil { - util.GetLogger(ctx).Warnf("invalid room ID '%s' found in join_rule on room %s: %s", allow.RoomID, joinRuleEv.RoomID(), err) + util.GetLogger(ctx).Warnf("invalid room ID '%s' found in join_rule on room %s: %s", allow.RoomID, joinRuleEv.RoomID().String(), err) } else { allows = append(allows, *allowedRoomID) } diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 619d9303..296960b2 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -49,6 +49,7 @@ func (db *getEventDB) addFakeEvent(eventID string, authIDs []string) error { } builder := map[string]interface{}{ "event_id": eventID, + "room_id": "!room:a", "auth_events": authEvents, } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index b09c5afb..3331c602 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -696,8 +696,8 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) } - roomNID, nidOK := d.Cache.GetRoomServerRoomNID(event.RoomID()) - cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(event.RoomID()) + roomNID, nidOK := d.Cache.GetRoomServerRoomNID(event.RoomID().String()) + cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(event.RoomID().String()) // if we found both, the roomNID and version in our cache, no need to query the database if nidOK && versionOK { return &types.RoomInfo{ @@ -707,14 +707,14 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver } err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion) + roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID().String(), roomVersion) if err != nil { return err } return nil }) if roomVersion != "" { - d.Cache.StoreRoomVersion(event.RoomID(), roomVersion) + d.Cache.StoreRoomVersion(event.RoomID().String(), roomVersion) } return &types.RoomInfo{ RoomVersion: roomVersion, @@ -1026,24 +1026,19 @@ func (d *EventDatabase) MaybeRedactEvent( case validated || redactedEvent == nil || redactionEvent == nil: // we've seen this redaction before or there is nothing to redact return nil - case redactedEvent.RoomID() != redactionEvent.RoomID(): + case redactedEvent.RoomID().String() != redactionEvent.RoomID().String(): // redactions across rooms aren't allowed ignoreRedaction = true return nil } - var validRoomID *spec.RoomID - validRoomID, err = spec.NewRoomID(redactedEvent.RoomID()) - if err != nil { - return err - } sender1Domain := "" - sender1, err1 := querier.QueryUserIDForSender(ctx, *validRoomID, redactedEvent.SenderID()) + sender1, err1 := querier.QueryUserIDForSender(ctx, redactedEvent.RoomID(), redactedEvent.SenderID()) if err1 == nil { sender1Domain = string(sender1.Domain()) } sender2Domain := "" - sender2, err2 := querier.QueryUserIDForSender(ctx, *validRoomID, redactionEvent.SenderID()) + sender2, err2 := querier.QueryUserIDForSender(ctx, redactedEvent.RoomID(), redactionEvent.SenderID()) if err2 == nil { sender2Domain = string(sender2.Domain()) } @@ -1522,7 +1517,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } result[i] = tables.StrippedEvent{ EventType: ev.Type(), - RoomID: ev.RoomID(), + RoomID: ev.RoomID().String(), StateKey: *ev.StateKey(), ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{PDU: ev}), } |