aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/postgres/storage.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/postgres/storage.go')
-rw-r--r--roomserver/storage/postgres/storage.go97
1 files changed, 30 insertions, 67 deletions
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 03cfb7f0..53a58076 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -33,7 +33,6 @@ import (
// A Database is used to store room events and stream offsets.
type Database struct {
shared.Database
- statements statements
events tables.Events
eventTypes tables.EventTypes
eventStateKeys tables.EventStateKeys
@@ -41,6 +40,8 @@ type Database struct {
rooms tables.Rooms
transactions tables.Transactions
prevEvents tables.PreviousEvents
+ invites tables.Invites
+ membership tables.Membership
db *sql.DB
}
@@ -52,9 +53,6 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil {
return nil, err
}
- if err = d.statements.prepare(d.db); err != nil {
- return nil, err
- }
d.eventStateKeys, err = NewPostgresEventStateKeysTable(d.db)
if err != nil {
return nil, err
@@ -95,6 +93,14 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
if err != nil {
return nil, err
}
+ d.invites, err = NewPostgresInvitesTable(d.db)
+ if err != nil {
+ return nil, err
+ }
+ d.membership, err = NewPostgresMembershipTable(d.db)
+ if err != nil {
+ return nil, err
+ }
d.Database = shared.Database{
DB: d.db,
EventTypesTable: d.eventTypes,
@@ -107,6 +113,8 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
StateSnapshotTable: stateSnapshot,
PrevEventsTable: d.prevEvents,
RoomAliasesTable: roomAliases,
+ InvitesTable: d.invites,
+ MembershipTable: d.membership,
}
return &d, nil
}
@@ -254,15 +262,6 @@ func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventSta
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal)
}
-// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
-func (d *Database) GetInvitesForUser(
- ctx context.Context,
- roomNID types.RoomNID,
- targetUserNID types.EventStateKeyNID,
-) (senderUserIDs []types.EventStateKeyNID, err error) {
- return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
-}
-
// MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,
@@ -303,7 +302,7 @@ type membershipUpdater struct {
d *Database
roomNID types.RoomNID
targetUserNID types.EventStateKeyNID
- membership membershipState
+ membership tables.MembershipState
}
func (d *Database) membershipUpdaterTxn(
@@ -314,11 +313,11 @@ func (d *Database) membershipUpdaterTxn(
targetLocal bool,
) (types.MembershipUpdater, error) {
- if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
+ if err := d.membership.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err
}
- membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
+ membership, err := d.membership.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
if err != nil {
return nil, err
}
@@ -330,17 +329,17 @@ func (d *Database) membershipUpdaterTxn(
// IsInvite implements types.MembershipUpdater
func (u *membershipUpdater) IsInvite() bool {
- return u.membership == membershipStateInvite
+ return u.membership == tables.MembershipStateInvite
}
// IsJoin implements types.MembershipUpdater
func (u *membershipUpdater) IsJoin() bool {
- return u.membership == membershipStateJoin
+ return u.membership == tables.MembershipStateJoin
}
// IsLeave implements types.MembershipUpdater
func (u *membershipUpdater) IsLeave() bool {
- return u.membership == membershipStateLeaveOrBan
+ return u.membership == tables.MembershipStateLeaveOrBan
}
// SetToInvite implements types.MembershipUpdater
@@ -349,15 +348,15 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
if err != nil {
return false, err
}
- inserted, err := u.d.statements.insertInviteEvent(
+ inserted, err := u.d.invites.InsertInviteEvent(
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil {
return false, err
}
- if u.membership != membershipStateInvite {
- if err = u.d.statements.updateMembership(
- u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
+ if u.membership != tables.MembershipStateInvite {
+ if err = u.d.membership.UpdateMembership(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
); err != nil {
return false, err
}
@@ -376,7 +375,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
// If this is a join event update, there is no invite to update
if !isUpdate {
- inviteEventIDs, err = u.d.statements.updateInviteRetired(
+ inviteEventIDs, err = u.d.invites.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
@@ -390,10 +389,10 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
return nil, err
}
- if u.membership != membershipStateJoin || isUpdate {
- if err = u.d.statements.updateMembership(
+ if u.membership != tables.MembershipStateJoin || isUpdate {
+ if err = u.d.membership.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
- membershipStateJoin, nIDs[eventID],
+ tables.MembershipStateJoin, nIDs[eventID],
); err != nil {
return nil, err
}
@@ -408,7 +407,7 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
if err != nil {
return nil, err
}
- inviteEventIDs, err := u.d.statements.updateInviteRetired(
+ inviteEventIDs, err := u.d.invites.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
@@ -421,10 +420,10 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
return nil, err
}
- if u.membership != membershipStateLeaveOrBan {
- if err = u.d.statements.updateMembership(
+ if u.membership != tables.MembershipStateLeaveOrBan {
+ if err = u.d.membership.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
- membershipStateLeaveOrBan, nIDs[eventID],
+ tables.MembershipStateLeaveOrBan, nIDs[eventID],
); err != nil {
return nil, err
}
@@ -432,42 +431,6 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
return inviteEventIDs, nil
}
-// GetMembership implements query.RoomserverQueryAPIDB
-func (d *Database) GetMembership(
- ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
-) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
- requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID)
- if err != nil {
- return
- }
-
- senderMembershipEventNID, senderMembership, err :=
- d.statements.selectMembershipFromRoomAndTarget(
- ctx, roomNID, requestSenderUserNID,
- )
- if err == sql.ErrNoRows {
- // The user has never been a member of that room
- return 0, false, nil
- } else if err != nil {
- return
- }
-
- return senderMembershipEventNID, senderMembership == membershipStateJoin, nil
-}
-
-// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
-func (d *Database) GetMembershipEventNIDsForRoom(
- ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
-) ([]types.EventNID, error) {
- if joinOnly {
- return d.statements.selectMembershipsFromRoomAndMembership(
- ctx, roomNID, membershipStateJoin, localOnly,
- )
- }
-
- return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly)
-}
-
type transaction struct {
ctx context.Context
txn *sql.Tx