aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/alias.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/alias.go')
-rw-r--r--roomserver/internal/alias.go102
1 files changed, 45 insertions, 57 deletions
diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go
index b04a56fe..a7f0aab9 100644
--- a/roomserver/internal/alias.go
+++ b/roomserver/internal/alias.go
@@ -35,27 +35,27 @@ import (
// SetRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) SetRoomAlias(
ctx context.Context,
- request *api.SetRoomAliasRequest,
- response *api.SetRoomAliasResponse,
-) error {
+ senderID spec.SenderID,
+ roomID spec.RoomID,
+ alias string,
+) (aliasAlreadyUsed bool, err error) {
// Check if the alias isn't already referring to a room
- roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
+ existingRoomID, err := r.DB.GetRoomIDForAlias(ctx, alias)
if err != nil {
- return err
+ return false, err
}
- if len(roomID) > 0 {
+
+ if len(existingRoomID) > 0 {
// If the alias already exists, stop the process
- response.AliasExists = true
- return nil
+ return true, nil
}
- response.AliasExists = false
// Save the new alias
- if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID, request.UserID); err != nil {
- return err
+ if err := r.DB.SetRoomAlias(ctx, alias, roomID.String(), string(senderID)); err != nil {
+ return false, err
}
- return nil
+ return false, nil
}
// GetRoomIDForAlias implements alias.RoomserverInternalAPI
@@ -116,90 +116,79 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID(
// nolint:gocyclo
// RemoveRoomAlias implements alias.RoomserverInternalAPI
// nolint: gocyclo
-func (r *RoomserverInternalAPI) RemoveRoomAlias(
- ctx context.Context,
- request *api.RemoveRoomAliasRequest,
- response *api.RemoveRoomAliasResponse,
-) error {
- roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
+func (r *RoomserverInternalAPI) RemoveRoomAlias(ctx context.Context, senderID spec.SenderID, alias string) (aliasFound bool, aliasRemoved bool, err error) {
+ roomID, err := r.DB.GetRoomIDForAlias(ctx, alias)
if err != nil {
- return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err)
+ return false, false, fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err)
}
if roomID == "" {
- response.Found = false
- response.Removed = false
- return nil
+ return false, false, nil
}
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
- return err
+ return true, false, err
}
- sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID)
+ sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID)
if err != nil || sender == nil {
- return fmt.Errorf("r.QueryUserIDForSender: %w", err)
+ return true, false, fmt.Errorf("r.QueryUserIDForSender: %w", err)
}
virtualHost := sender.Domain()
- response.Found = true
- creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias)
+ creatorID, err := r.DB.GetCreatorIDForAlias(ctx, alias)
if err != nil {
- return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err)
+ return true, false, fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err)
}
- if spec.SenderID(creatorID) != request.SenderID {
+ if spec.SenderID(creatorID) != senderID {
var plEvent *types.HeaderedEvent
var pls *gomatrixserverlib.PowerLevelContent
plEvent, err = r.DB.GetStateEvent(ctx, roomID, spec.MRoomPowerLevels, "")
if err != nil {
- return fmt.Errorf("r.DB.GetStateEvent: %w", err)
+ return true, false, fmt.Errorf("r.DB.GetStateEvent: %w", err)
}
pls, err = plEvent.PowerLevels()
if err != nil {
- return fmt.Errorf("plEvent.PowerLevels: %w", err)
+ return true, false, fmt.Errorf("plEvent.PowerLevels: %w", err)
}
- if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) {
- response.Removed = false
- return nil
+ if pls.UserLevel(senderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) {
+ return true, false, nil
}
}
ev, err := r.DB.GetStateEvent(ctx, roomID, spec.MRoomCanonicalAlias, "")
if err != nil && err != sql.ErrNoRows {
- return err
+ return true, false, err
} else if ev != nil {
stateAlias := gjson.GetBytes(ev.Content(), "alias").Str
// the alias to remove is currently set as the canonical alias, remove it
- if stateAlias == request.Alias {
+ if stateAlias == alias {
res, err := sjson.DeleteBytes(ev.Content(), "alias")
if err != nil {
- return err
+ return true, false, err
}
- senderID := request.SenderID
- if request.SenderID != ev.SenderID() {
- senderID = ev.SenderID()
- }
- sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID)
- if err != nil || sender == nil {
- return err
+ canonicalSenderID := ev.SenderID()
+ canonicalSender, err := r.QueryUserIDForSender(ctx, *validRoomID, canonicalSenderID)
+ if err != nil || canonicalSender == nil {
+ return true, false, err
}
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
- return err
+ return true, false, err
}
- identity, err := r.SigningIdentityFor(ctx, *validRoomID, *sender)
+ identity, err := r.SigningIdentityFor(ctx, *validRoomID, *canonicalSender)
if err != nil {
- return err
+ return true, false, err
}
proto := &gomatrixserverlib.ProtoEvent{
- SenderID: string(senderID),
+ SenderID: string(canonicalSenderID),
RoomID: ev.RoomID(),
Type: ev.Type(),
StateKey: ev.StateKey(),
@@ -208,34 +197,33 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
eventsNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(proto)
if err != nil {
- return fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err)
+ return true, false, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err)
}
if len(eventsNeeded.Tuples()) == 0 {
- return errors.New("expecting state tuples for event builder, got none")
+ return true, false, errors.New("expecting state tuples for event builder, got none")
}
stateRes := &api.QueryLatestEventsAndStateResponse{}
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
- return err
+ return true, false, err
}
newEvent, err := eventutil.BuildEvent(ctx, proto, &identity, time.Now(), &eventsNeeded, stateRes)
if err != nil {
- return err
+ return true, false, err
}
err = api.SendEvents(ctx, r, api.KindNew, []*types.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false)
if err != nil {
- return err
+ return true, false, err
}
}
}
// Remove the alias from the database
- if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil {
- return err
+ if err := r.DB.RemoveRoomAlias(ctx, alias); err != nil {
+ return true, false, err
}
- response.Removed = true
- return nil
+ return true, true, nil
}