diff options
Diffstat (limited to 'roomserver/storage/postgres/storage.go')
-rw-r--r-- | roomserver/storage/postgres/storage.go | 97 |
1 files changed, 30 insertions, 67 deletions
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 03cfb7f0..53a58076 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -33,7 +33,6 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { shared.Database - statements statements events tables.Events eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys @@ -41,6 +40,8 @@ type Database struct { rooms tables.Rooms transactions tables.Transactions prevEvents tables.PreviousEvents + invites tables.Invites + membership tables.Membership db *sql.DB } @@ -52,9 +53,6 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } - if err = d.statements.prepare(d.db); err != nil { - return nil, err - } d.eventStateKeys, err = NewPostgresEventStateKeysTable(d.db) if err != nil { return nil, err @@ -95,6 +93,14 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.invites, err = NewPostgresInvitesTable(d.db) + if err != nil { + return nil, err + } + d.membership, err = NewPostgresMembershipTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventTypesTable: d.eventTypes, @@ -107,6 +113,8 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, StateSnapshotTable: stateSnapshot, PrevEventsTable: d.prevEvents, RoomAliasesTable: roomAliases, + InvitesTable: d.invites, + MembershipTable: d.membership, } return &d, nil } @@ -254,15 +262,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) } -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - // MembershipUpdater implements input.RoomEventDatabase func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, @@ -303,7 +302,7 @@ type membershipUpdater struct { d *Database roomNID types.RoomNID targetUserNID types.EventStateKeyNID - membership membershipState + membership tables.MembershipState } func (d *Database) membershipUpdaterTxn( @@ -314,11 +313,11 @@ func (d *Database) membershipUpdaterTxn( targetLocal bool, ) (types.MembershipUpdater, error) { - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { return nil, err } - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) if err != nil { return nil, err } @@ -330,17 +329,17 @@ func (d *Database) membershipUpdaterTxn( // IsInvite implements types.MembershipUpdater func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite + return u.membership == tables.MembershipStateInvite } // IsJoin implements types.MembershipUpdater func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin + return u.membership == tables.MembershipStateJoin } // IsLeave implements types.MembershipUpdater func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan + return u.membership == tables.MembershipStateLeaveOrBan } // SetToInvite implements types.MembershipUpdater @@ -349,15 +348,15 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er if err != nil { return false, err } - inserted, err := u.d.statements.insertInviteEvent( + inserted, err := u.d.invites.InsertInviteEvent( u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), ) if err != nil { return false, err } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + if u.membership != tables.MembershipStateInvite { + if err = u.d.membership.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, ); err != nil { return false, err } @@ -376,7 +375,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd // If this is a join event update, there is no invite to update if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( + inviteEventIDs, err = u.d.invites.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { @@ -390,10 +389,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd return nil, err } - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateJoin || isUpdate { + if err = u.d.membership.UpdateMembership( u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], + tables.MembershipStateJoin, nIDs[eventID], ); err != nil { return nil, err } @@ -408,7 +407,7 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s if err != nil { return nil, err } - inviteEventIDs, err := u.d.statements.updateInviteRetired( + inviteEventIDs, err := u.d.invites.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { @@ -421,10 +420,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s return nil, err } - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( + if u.membership != tables.MembershipStateLeaveOrBan { + if err = u.d.membership.UpdateMembership( u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], + tables.MembershipStateLeaveOrBan, nIDs[eventID], ); err != nil { return nil, err } @@ -432,42 +431,6 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s return inviteEventIDs, nil } -// GetMembership implements query.RoomserverQueryAPIDB -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) - if err != nil { - return - } - - senderMembershipEventNID, senderMembership, err := - d.statements.selectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, - ) - if err == sql.ErrNoRows { - // The user has never been a member of that room - return 0, false, nil - } else if err != nil { - return - } - - return senderMembershipEventNID, senderMembership == membershipStateJoin, nil -} - -// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB -func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, -) ([]types.EventNID, error) { - if joinOnly { - return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, localOnly, - ) - } - - return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly) -} - type transaction struct { ctx context.Context txn *sql.Tx |