diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-03-17 17:05:21 +0000 |
---|---|---|
committer | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-03-17 17:05:21 +0000 |
commit | 4e64c270dbe5d438325903e4404ed4b9ec43c039 (patch) | |
tree | 703043599b6a3ed316df980de493b7ba03156d4e /roomserver | |
parent | 0fb94fc781a71219d5e537788e976bec1d84382c (diff) |
Various bug fixes and tweaks around invites and membership
Diffstat (limited to 'roomserver')
-rw-r--r-- | roomserver/internal/helpers/helpers.go | 2 | ||||
-rw-r--r-- | roomserver/internal/perform/perform_invite.go | 7 | ||||
-rw-r--r-- | roomserver/internal/perform/perform_leave.go | 4 | ||||
-rw-r--r-- | roomserver/storage/postgres/membership_table.go | 10 | ||||
-rw-r--r-- | roomserver/storage/shared/membership_updater.go | 10 | ||||
-rw-r--r-- | roomserver/storage/sqlite3/membership_table.go | 10 | ||||
-rw-r--r-- | roomserver/storage/tables/interface.go | 2 |
7 files changed, 28 insertions, 17 deletions
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 78a875c7..e67bbfca 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -28,7 +28,7 @@ func UpdateToInviteMembership( // reprocessing this event, or because the we received this invite from a // remote server via the federation invite API. In those cases we don't need // to send the event. - needsSending, err := mu.SetToInvite(*add) + needsSending, err := mu.SetToInvite(add) if err != nil { return nil, err } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 6559cd08..6111372d 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" @@ -139,13 +140,15 @@ func (r *Inviter) PerformInvite( // will never pass auth checks due to lacking room state, but we // still need to tell the client about the invite so we can accept // it, hence we return an output event to send to the sync api. - updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion) + var updater *shared.MembershipUpdater + updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion) if err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } unwrapped := event.Unwrap() - outputUpdates, err := helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion) + var outputUpdates []api.OutputEvent + outputUpdates, err = helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion) if err != nil { return nil, fmt.Errorf("updateToInviteMembership: %w", err) } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 49ddd481..1e5fb9f1 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -91,12 +91,12 @@ func (r *Leaver) performLeaveRoomByID( } // check that this is not a "server notice room" accData := &userapi.QueryAccountDataResponse{} - if err := r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ + if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ UserID: req.UserID, RoomID: req.RoomID, DataType: "m.tag", }, accData); err != nil { - return nil, fmt.Errorf("unable to query account data") + return nil, fmt.Errorf("unable to query account data: %w", err) } if roomData, ok := accData.RoomAccountData[req.RoomID]; ok { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 12717874..6ed5293e 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -276,11 +276,15 @@ func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool, -) error { - _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( +) (bool, error) { + res, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, forgotten, ) - return err + if err != nil { + return false, err + } + rows, err := res.RowsAffected() + return rows > 0, err } func (s *membershipStatements) SelectRoomsWithMembership( diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 8f3f3d63..b7db9f81 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -92,7 +92,7 @@ func (u *MembershipUpdater) IsKnock() bool { } // SetToInvite implements types.MembershipUpdater -func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { +func (u *MembershipUpdater) SetToInvite(event *gomatrixserverlib.Event) (bool, error) { var inserted bool err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) @@ -106,7 +106,7 @@ func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) } if u.membership != tables.MembershipStateInvite { - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil { + if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -142,7 +142,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } if u.membership != tables.MembershipStateJoin || isUpdate { - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil { + if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -176,7 +176,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } if u.membership != tables.MembershipStateLeaveOrBan { - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil { + if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -201,7 +201,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er return fmt.Errorf("u.d.EventNIDs: %w", err) } - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil { + if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 43567a94..7ed86b61 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -253,12 +253,16 @@ func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool, -) error { +) (bool, error) { stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) - _, err := stmt.ExecContext( + res, err := stmt.ExecContext( ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID, ) - return err + if err != nil { + return false, err + } + rows, err := res.RowsAffected() + return rows > 0, err } func (s *membershipStatements) SelectRoomsWithMembership( diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 04e3c96c..97e4afcf 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -125,7 +125,7 @@ type Membership interface { SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) - UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error + UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) |