aboutsummaryrefslogtreecommitdiff
path: root/roomserver
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-03-17 17:05:21 +0000
committerNeil Alexander <neilalexander@users.noreply.github.com>2022-03-17 17:05:21 +0000
commit4e64c270dbe5d438325903e4404ed4b9ec43c039 (patch)
tree703043599b6a3ed316df980de493b7ba03156d4e /roomserver
parent0fb94fc781a71219d5e537788e976bec1d84382c (diff)
Various bug fixes and tweaks around invites and membership
Diffstat (limited to 'roomserver')
-rw-r--r--roomserver/internal/helpers/helpers.go2
-rw-r--r--roomserver/internal/perform/perform_invite.go7
-rw-r--r--roomserver/internal/perform/perform_leave.go4
-rw-r--r--roomserver/storage/postgres/membership_table.go10
-rw-r--r--roomserver/storage/shared/membership_updater.go10
-rw-r--r--roomserver/storage/sqlite3/membership_table.go10
-rw-r--r--roomserver/storage/tables/interface.go2
7 files changed, 28 insertions, 17 deletions
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index 78a875c7..e67bbfca 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -28,7 +28,7 @@ func UpdateToInviteMembership(
// reprocessing this event, or because the we received this invite from a
// remote server via the federation invite API. In those cases we don't need
// to send the event.
- needsSending, err := mu.SetToInvite(*add)
+ needsSending, err := mu.SetToInvite(add)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index 6559cd08..6111372d 100644
--- a/roomserver/internal/perform/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
@@ -139,13 +140,15 @@ func (r *Inviter) PerformInvite(
// will never pass auth checks due to lacking room state, but we
// still need to tell the client about the invite so we can accept
// it, hence we return an output event to send to the sync api.
- updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
+ var updater *shared.MembershipUpdater
+ updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
if err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
unwrapped := event.Unwrap()
- outputUpdates, err := helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion)
+ var outputUpdates []api.OutputEvent
+ outputUpdates, err = helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion)
if err != nil {
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
}
diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go
index 49ddd481..1e5fb9f1 100644
--- a/roomserver/internal/perform/perform_leave.go
+++ b/roomserver/internal/perform/perform_leave.go
@@ -91,12 +91,12 @@ func (r *Leaver) performLeaveRoomByID(
}
// check that this is not a "server notice room"
accData := &userapi.QueryAccountDataResponse{}
- if err := r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{
+ if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{
UserID: req.UserID,
RoomID: req.RoomID,
DataType: "m.tag",
}, accData); err != nil {
- return nil, fmt.Errorf("unable to query account data")
+ return nil, fmt.Errorf("unable to query account data: %w", err)
}
if roomData, ok := accData.RoomAccountData[req.RoomID]; ok {
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index 12717874..6ed5293e 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -276,11 +276,15 @@ func (s *membershipStatements) UpdateMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool,
-) error {
- _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
+) (bool, error) {
+ res, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, forgotten,
)
- return err
+ if err != nil {
+ return false, err
+ }
+ rows, err := res.RowsAffected()
+ return rows > 0, err
}
func (s *membershipStatements) SelectRoomsWithMembership(
diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go
index 8f3f3d63..b7db9f81 100644
--- a/roomserver/storage/shared/membership_updater.go
+++ b/roomserver/storage/shared/membership_updater.go
@@ -92,7 +92,7 @@ func (u *MembershipUpdater) IsKnock() bool {
}
// SetToInvite implements types.MembershipUpdater
-func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
+func (u *MembershipUpdater) SetToInvite(event *gomatrixserverlib.Event) (bool, error) {
var inserted bool
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
@@ -106,7 +106,7 @@ func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
}
if u.membership != tables.MembershipStateInvite {
- if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil {
+ if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
@@ -142,7 +142,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
}
if u.membership != tables.MembershipStateJoin || isUpdate {
- if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil {
+ if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
@@ -176,7 +176,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
}
if u.membership != tables.MembershipStateLeaveOrBan {
- if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil {
+ if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
@@ -201,7 +201,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
- if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil {
+ if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index 43567a94..7ed86b61 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -253,12 +253,16 @@ func (s *membershipStatements) UpdateMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool,
-) error {
+) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
- _, err := stmt.ExecContext(
+ res, err := stmt.ExecContext(
ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID,
)
- return err
+ if err != nil {
+ return false, err
+ }
+ rows, err := res.RowsAffected()
+ return rows > 0, err
}
func (s *membershipStatements) SelectRoomsWithMembership(
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index 04e3c96c..97e4afcf 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -125,7 +125,7 @@ type Membership interface {
SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
- UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
+ UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)