aboutsummaryrefslogtreecommitdiff
path: root/roomserver
diff options
context:
space:
mode:
authorruben <code@rbn.im>2019-05-21 22:56:55 +0200
committerBrendan Abolivier <babolivier@matrix.org>2019-05-21 21:56:55 +0100
commit74827428bd3e11faab65f12204449c1b9469b0ae (patch)
tree0decafa542436a0667ed2d3e3cfd4df0f03de1e5 /roomserver
parent4d588f7008afe5600219ac0930c2eee2de5c447b (diff)
use go module for dependencies (#594)
Diffstat (limited to 'roomserver')
-rw-r--r--roomserver/README.md59
-rw-r--r--roomserver/alias/alias.go285
-rw-r--r--roomserver/api/alias.go183
-rw-r--r--roomserver/api/input.go139
-rw-r--r--roomserver/api/output.go138
-rw-r--r--roomserver/api/query.go480
-rw-r--r--roomserver/auth/auth.go47
-rw-r--r--roomserver/input/authevents.go243
-rw-r--r--roomserver/input/authevents_test.go136
-rw-r--r--roomserver/input/events.go235
-rw-r--r--roomserver/input/input.go95
-rw-r--r--roomserver/input/latest_events.go293
-rw-r--r--roomserver/input/membership.go310
-rw-r--r--roomserver/query/query.go802
-rw-r--r--roomserver/query/query_test.go155
-rw-r--r--roomserver/roomserver.go68
-rw-r--r--roomserver/state/state.go966
-rw-r--r--roomserver/state/state_test.go56
-rw-r--r--roomserver/storage/event_json_table.go105
-rw-r--r--roomserver/storage/event_state_keys_table.go153
-rw-r--r--roomserver/storage/event_types_table.go146
-rw-r--r--roomserver/storage/events_table.go410
-rw-r--r--roomserver/storage/invite_table.go154
-rw-r--r--roomserver/storage/membership_table.go193
-rw-r--r--roomserver/storage/prepare.go36
-rw-r--r--roomserver/storage/previous_events_table.go99
-rw-r--r--roomserver/storage/room_aliases_table.go109
-rw-r--r--roomserver/storage/rooms_table.go155
-rw-r--r--roomserver/storage/sql.go59
-rw-r--r--roomserver/storage/state_block_table.go280
-rw-r--r--roomserver/storage/state_block_table_test.go85
-rw-r--r--roomserver/storage/state_snapshot_table.go118
-rw-r--r--roomserver/storage/storage.go705
-rw-r--r--roomserver/storage/transactions_table.go86
-rw-r--r--roomserver/types/types.go203
35 files changed, 7786 insertions, 0 deletions
diff --git a/roomserver/README.md b/roomserver/README.md
new file mode 100644
index 00000000..5a275760
--- /dev/null
+++ b/roomserver/README.md
@@ -0,0 +1,59 @@
+# RoomServer
+
+
+## RoomServer Internals
+
+### Numeric IDs
+
+To save space matrix string identifiers are mapped to local numeric IDs.
+The numeric IDs are more efficient to manipulate and use less space to store.
+The numeric IDs are never exposed in the API the room server exposes.
+The numeric IDs are converted to string IDs before they leave the room server.
+The numeric ID for a string ID is never 0 to avoid being confused with go's
+default zero value.
+Zero is used to indicate that there was no corresponding string ID.
+Well-known event types and event state keys are preassigned numeric IDs.
+
+### State Snapshot Storage
+
+The room server stores the state of the matrix room at each event.
+For efficiency the state is stored as blocks of 3-tuples of numeric IDs for the
+event type, event state key and event ID. For further efficiency the state
+snapshots are stored as the combination of up to 64 these blocks. This allows
+blocks of the room state to be reused in multiple snapshots.
+
+The resulting database tables look something like this:
+
+ +-------------------------------------------------------------------+
+ | Events |
+ +---------+-------------------+------------------+------------------+
+ | EventNID| EventTypeNID | EventStateKeyNID | StateSnapshotNID |
+ +---------+-------------------+------------------+------------------+
+ | 1 | m.room.create 1 | "" 1 | <nil> 0 |
+ | 2 | m.room.member 2 | "@user:foo" 2 | <nil> 0 |
+ | 3 | m.room.member 2 | "@user:bar" 3 | {1,2} 1 |
+ | 4 | m.room.message 3 | <nil> 0 | {1,2,3} 2 |
+ | 5 | m.room.member 2 | "@user:foo" 2 | {1,2,3} 2 |
+ | 6 | m.room.message 3 | <nil> 0 | {1,3,6} 3 |
+ +---------+-------------------+------------------+------------------+
+
+ +----------------------------------------+
+ | State Snapshots |
+ +-----------------------+----------------+
+ | EventStateSnapshotNID | StateBlockNIDs |
+ +-----------------------+----------------|
+ | 1 | {1} |
+ | 2 | {1,2} |
+ | 3 | {1,2,3} |
+ +-----------------------+----------------+
+
+ +-----------------------------------------------------------------+
+ | State Blocks |
+ +---------------+-------------------+------------------+----------+
+ | StateBlockNID | EventTypeNID | EventStateKeyNID | EventNID |
+ +---------------+-------------------+------------------+----------+
+ | 1 | m.room.create 1 | "" 1 | 1 |
+ | 1 | m.room.member 2 | "@user:foo" 2 | 2 |
+ | 2 | m.room.member 2 | "@user:bar" 3 | 3 |
+ | 3 | m.room.member 2 | "@user:foo" 2 | 6 |
+ +---------------+-------------------+------------------+----------+
diff --git a/roomserver/alias/alias.go b/roomserver/alias/alias.go
new file mode 100644
index 00000000..27279aad
--- /dev/null
+++ b/roomserver/alias/alias.go
@@ -0,0 +1,285 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package alias
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "time"
+
+ appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/common/config"
+ roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+// RoomserverAliasAPIDatabase has the storage APIs needed to implement the alias API.
+type RoomserverAliasAPIDatabase interface {
+ // Save a given room alias with the room ID it refers to.
+ // Returns an error if there was a problem talking to the database.
+ SetRoomAlias(ctx context.Context, alias string, roomID string) error
+ // Look up the room ID a given alias refers to.
+ // Returns an error if there was a problem talking to the database.
+ GetRoomIDForAlias(ctx context.Context, alias string) (string, error)
+ // Look up all aliases referring to a given room ID.
+ // Returns an error if there was a problem talking to the database.
+ GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
+ // Remove a given room alias.
+ // Returns an error if there was a problem talking to the database.
+ RemoveRoomAlias(ctx context.Context, alias string) error
+}
+
+// RoomserverAliasAPI is an implementation of alias.RoomserverAliasAPI
+type RoomserverAliasAPI struct {
+ DB RoomserverAliasAPIDatabase
+ Cfg *config.Dendrite
+ InputAPI roomserverAPI.RoomserverInputAPI
+ QueryAPI roomserverAPI.RoomserverQueryAPI
+ AppserviceAPI appserviceAPI.AppServiceQueryAPI
+}
+
+// SetRoomAlias implements alias.RoomserverAliasAPI
+func (r *RoomserverAliasAPI) SetRoomAlias(
+ ctx context.Context,
+ request *roomserverAPI.SetRoomAliasRequest,
+ response *roomserverAPI.SetRoomAliasResponse,
+) error {
+ // Check if the alias isn't already referring to a room
+ roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
+ if err != nil {
+ return err
+ }
+ if len(roomID) > 0 {
+ // If the alias already exists, stop the process
+ response.AliasExists = true
+ return nil
+ }
+ response.AliasExists = false
+
+ // Save the new alias
+ if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID); err != nil {
+ return err
+ }
+
+ // Send a m.room.aliases event with the updated list of aliases for this room
+ // At this point we've already committed the alias to the database so we
+ // shouldn't cancel this request.
+ // TODO: Ensure that we send unsent events when if server restarts.
+ return r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, request.RoomID)
+}
+
+// GetRoomIDForAlias implements alias.RoomserverAliasAPI
+func (r *RoomserverAliasAPI) GetRoomIDForAlias(
+ ctx context.Context,
+ request *roomserverAPI.GetRoomIDForAliasRequest,
+ response *roomserverAPI.GetRoomIDForAliasResponse,
+) error {
+ // Look up the room ID in the database
+ roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
+ if err != nil {
+ return err
+ }
+
+ // No rooms found locally, try our application services by making a call to
+ // the appservice component
+ aliasReq := appserviceAPI.RoomAliasExistsRequest{Alias: request.Alias}
+ var aliasResp appserviceAPI.RoomAliasExistsResponse
+ if err = r.AppserviceAPI.RoomAliasExists(ctx, &aliasReq, &aliasResp); err != nil {
+ return err
+ }
+
+ response.RoomID = roomID
+ return nil
+}
+
+// GetAliasesForRoomID implements alias.RoomserverAliasAPI
+func (r *RoomserverAliasAPI) GetAliasesForRoomID(
+ ctx context.Context,
+ request *roomserverAPI.GetAliasesForRoomIDRequest,
+ response *roomserverAPI.GetAliasesForRoomIDResponse,
+) error {
+ // Look up the aliases in the database for the given RoomID
+ aliases, err := r.DB.GetAliasesForRoomID(ctx, request.RoomID)
+ if err != nil {
+ return err
+ }
+
+ response.Aliases = aliases
+ return nil
+}
+
+// RemoveRoomAlias implements alias.RoomserverAliasAPI
+func (r *RoomserverAliasAPI) RemoveRoomAlias(
+ ctx context.Context,
+ request *roomserverAPI.RemoveRoomAliasRequest,
+ response *roomserverAPI.RemoveRoomAliasResponse,
+) error {
+ // Look up the room ID in the database
+ roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
+ if err != nil {
+ return err
+ }
+
+ // Remove the dalias from the database
+ if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil {
+ return err
+ }
+
+ // Send an updated m.room.aliases event
+ // At this point we've already committed the alias to the database so we
+ // shouldn't cancel this request.
+ // TODO: Ensure that we send unsent events when if server restarts.
+ return r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, roomID)
+}
+
+type roomAliasesContent struct {
+ Aliases []string `json:"aliases"`
+}
+
+// Build the updated m.room.aliases event to send to the room after addition or
+// removal of an alias
+func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent(
+ ctx context.Context, userID string, roomID string,
+) error {
+ serverName := string(r.Cfg.Matrix.ServerName)
+
+ builder := gomatrixserverlib.EventBuilder{
+ Sender: userID,
+ RoomID: roomID,
+ Type: "m.room.aliases",
+ StateKey: &serverName,
+ }
+
+ // Retrieve the updated list of aliases, marhal it and set it as the
+ // event's content
+ aliases, err := r.DB.GetAliasesForRoomID(ctx, roomID)
+ if err != nil {
+ return err
+ }
+ content := roomAliasesContent{Aliases: aliases}
+ rawContent, err := json.Marshal(content)
+ if err != nil {
+ return err
+ }
+ err = builder.SetContent(json.RawMessage(rawContent))
+ if err != nil {
+ return err
+ }
+
+ // Get needed state events and depth
+ eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(&builder)
+ if err != nil {
+ return err
+ }
+ req := roomserverAPI.QueryLatestEventsAndStateRequest{
+ RoomID: roomID,
+ StateToFetch: eventsNeeded.Tuples(),
+ }
+ var res roomserverAPI.QueryLatestEventsAndStateResponse
+ if err = r.QueryAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil {
+ return err
+ }
+ builder.Depth = res.Depth
+ builder.PrevEvents = res.LatestEvents
+
+ // Add auth events
+ authEvents := gomatrixserverlib.NewAuthEvents(nil)
+ for i := range res.StateEvents {
+ err = authEvents.AddEvent(&res.StateEvents[i])
+ if err != nil {
+ return err
+ }
+ }
+ refs, err := eventsNeeded.AuthEventReferences(&authEvents)
+ if err != nil {
+ return err
+ }
+ builder.AuthEvents = refs
+
+ // Build the event
+ eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), r.Cfg.Matrix.ServerName)
+ now := time.Now()
+ event, err := builder.Build(
+ eventID, now, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Create the request
+ ire := roomserverAPI.InputRoomEvent{
+ Kind: roomserverAPI.KindNew,
+ Event: event,
+ AuthEventIDs: event.AuthEventIDs(),
+ SendAsServer: serverName,
+ }
+ inputReq := roomserverAPI.InputRoomEventsRequest{
+ InputRoomEvents: []roomserverAPI.InputRoomEvent{ire},
+ }
+ var inputRes roomserverAPI.InputRoomEventsResponse
+
+ // Send the request
+ return r.InputAPI.InputRoomEvents(ctx, &inputReq, &inputRes)
+}
+
+// SetupHTTP adds the RoomserverAliasAPI handlers to the http.ServeMux.
+func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) {
+ servMux.Handle(
+ roomserverAPI.RoomserverSetRoomAliasPath,
+ common.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse {
+ var request roomserverAPI.SetRoomAliasRequest
+ var response roomserverAPI.SetRoomAliasResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ roomserverAPI.RoomserverGetRoomIDForAliasPath,
+ common.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse {
+ var request roomserverAPI.GetRoomIDForAliasRequest
+ var response roomserverAPI.GetRoomIDForAliasResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ roomserverAPI.RoomserverRemoveRoomAliasPath,
+ common.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse {
+ var request roomserverAPI.RemoveRoomAliasRequest
+ var response roomserverAPI.RemoveRoomAliasResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+}
diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go
new file mode 100644
index 00000000..57671071
--- /dev/null
+++ b/roomserver/api/alias.go
@@ -0,0 +1,183 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package api
+
+import (
+ "context"
+ "net/http"
+
+ commonHTTP "github.com/matrix-org/dendrite/common/http"
+ opentracing "github.com/opentracing/opentracing-go"
+)
+
+// SetRoomAliasRequest is a request to SetRoomAlias
+type SetRoomAliasRequest struct {
+ // ID of the user setting the alias
+ UserID string `json:"user_id"`
+ // New alias for the room
+ Alias string `json:"alias"`
+ // The room ID the alias is referring to
+ RoomID string `json:"room_id"`
+}
+
+// SetRoomAliasResponse is a response to SetRoomAlias
+type SetRoomAliasResponse struct {
+ // Does the alias already refer to a room?
+ AliasExists bool `json:"alias_exists"`
+}
+
+// GetRoomIDForAliasRequest is a request to GetRoomIDForAlias
+type GetRoomIDForAliasRequest struct {
+ // Alias we want to lookup
+ Alias string `json:"alias"`
+}
+
+// GetRoomIDForAliasResponse is a response to GetRoomIDForAlias
+type GetRoomIDForAliasResponse struct {
+ // The room ID the alias refers to
+ RoomID string `json:"room_id"`
+}
+
+// GetAliasesForRoomIDRequest is a request to GetAliasesForRoomID
+type GetAliasesForRoomIDRequest struct {
+ // The room ID we want to find aliases for
+ RoomID string `json:"room_id"`
+}
+
+// GetAliasesForRoomIDResponse is a response to GetAliasesForRoomID
+type GetAliasesForRoomIDResponse struct {
+ // The aliases the alias refers to
+ Aliases []string `json:"aliases"`
+}
+
+// RemoveRoomAliasRequest is a request to RemoveRoomAlias
+type RemoveRoomAliasRequest struct {
+ // ID of the user removing the alias
+ UserID string `json:"user_id"`
+ // The room alias to remove
+ Alias string `json:"alias"`
+}
+
+// RemoveRoomAliasResponse is a response to RemoveRoomAlias
+type RemoveRoomAliasResponse struct{}
+
+// RoomserverAliasAPI is used to save, lookup or remove a room alias
+type RoomserverAliasAPI interface {
+ // Set a room alias
+ SetRoomAlias(
+ ctx context.Context,
+ req *SetRoomAliasRequest,
+ response *SetRoomAliasResponse,
+ ) error
+
+ // Get the room ID for an alias
+ GetRoomIDForAlias(
+ ctx context.Context,
+ req *GetRoomIDForAliasRequest,
+ response *GetRoomIDForAliasResponse,
+ ) error
+
+ // Get all known aliases for a room ID
+ GetAliasesForRoomID(
+ ctx context.Context,
+ req *GetAliasesForRoomIDRequest,
+ response *GetAliasesForRoomIDResponse,
+ ) error
+
+ // Remove a room alias
+ RemoveRoomAlias(
+ ctx context.Context,
+ req *RemoveRoomAliasRequest,
+ response *RemoveRoomAliasResponse,
+ ) error
+}
+
+// RoomserverSetRoomAliasPath is the HTTP path for the SetRoomAlias API.
+const RoomserverSetRoomAliasPath = "/api/roomserver/setRoomAlias"
+
+// RoomserverGetRoomIDForAliasPath is the HTTP path for the GetRoomIDForAlias API.
+const RoomserverGetRoomIDForAliasPath = "/api/roomserver/GetRoomIDForAlias"
+
+// RoomserverGetAliasesForRoomIDPath is the HTTP path for the GetAliasesForRoomID API.
+const RoomserverGetAliasesForRoomIDPath = "/api/roomserver/GetAliasesForRoomID"
+
+// RoomserverRemoveRoomAliasPath is the HTTP path for the RemoveRoomAlias API.
+const RoomserverRemoveRoomAliasPath = "/api/roomserver/removeRoomAlias"
+
+// NewRoomserverAliasAPIHTTP creates a RoomserverAliasAPI implemented by talking to a HTTP POST API.
+// If httpClient is nil then it uses the http.DefaultClient
+func NewRoomserverAliasAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverAliasAPI {
+ if httpClient == nil {
+ httpClient = http.DefaultClient
+ }
+ return &httpRoomserverAliasAPI{roomserverURL, httpClient}
+}
+
+type httpRoomserverAliasAPI struct {
+ roomserverURL string
+ httpClient *http.Client
+}
+
+// SetRoomAlias implements RoomserverAliasAPI
+func (h *httpRoomserverAliasAPI) SetRoomAlias(
+ ctx context.Context,
+ request *SetRoomAliasRequest,
+ response *SetRoomAliasResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverSetRoomAliasPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// GetRoomIDForAlias implements RoomserverAliasAPI
+func (h *httpRoomserverAliasAPI) GetRoomIDForAlias(
+ ctx context.Context,
+ request *GetRoomIDForAliasRequest,
+ response *GetRoomIDForAliasResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// GetAliasesForRoomID implements RoomserverAliasAPI
+func (h *httpRoomserverAliasAPI) GetAliasesForRoomID(
+ ctx context.Context,
+ request *GetAliasesForRoomIDRequest,
+ response *GetAliasesForRoomIDResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// RemoveRoomAlias implements RoomserverAliasAPI
+func (h *httpRoomserverAliasAPI) RemoveRoomAlias(
+ ctx context.Context,
+ request *RemoveRoomAliasRequest,
+ response *RemoveRoomAliasResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
diff --git a/roomserver/api/input.go b/roomserver/api/input.go
new file mode 100644
index 00000000..2c2e27c6
--- /dev/null
+++ b/roomserver/api/input.go
@@ -0,0 +1,139 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package api provides the types that are used to communicate with the roomserver.
+package api
+
+import (
+ "context"
+ "net/http"
+
+ commonHTTP "github.com/matrix-org/dendrite/common/http"
+ "github.com/matrix-org/gomatrixserverlib"
+ opentracing "github.com/opentracing/opentracing-go"
+)
+
+const (
+ // KindOutlier event fall outside the contiguous event graph.
+ // We do not have the state for these events.
+ // These events are state events used to authenticate other events.
+ // They can become part of the contiguous event graph via backfill.
+ KindOutlier = 1
+ // KindNew event extend the contiguous graph going forwards.
+ // They usually don't need state, but may include state if the
+ // there was a new event that references an event that we don't
+ // have a copy of.
+ KindNew = 2
+ // KindBackfill event extend the contiguous graph going backwards.
+ // They always have state.
+ KindBackfill = 3
+)
+
+// DoNotSendToOtherServers tells us not to send the event to other matrix
+// servers.
+const DoNotSendToOtherServers = ""
+
+// InputRoomEvent is a matrix room event to add to the room server database.
+// TODO: Implement UnmarshalJSON/MarshalJSON in a way that does something sensible with the event JSON.
+type InputRoomEvent struct {
+ // Whether this event is new, backfilled or an outlier.
+ // This controls how the event is processed.
+ Kind int `json:"kind"`
+ // The event JSON for the event to add.
+ Event gomatrixserverlib.Event `json:"event"`
+ // List of state event IDs that authenticate this event.
+ // These are likely derived from the "auth_events" JSON key of the event.
+ // But can be different because the "auth_events" key can be incomplete or wrong.
+ // For example many matrix events forget to reference the m.room.create event even though it is needed for auth.
+ // (since synapse allows this to happen we have to allow it as well.)
+ AuthEventIDs []string `json:"auth_event_ids"`
+ // Whether the state is supplied as a list of event IDs or whether it
+ // should be derived from the state at the previous events.
+ HasState bool `json:"has_state"`
+ // Optional list of state event IDs forming the state before this event.
+ // These state events must have already been persisted.
+ // These are only used if HasState is true.
+ // The list can be empty, for example when storing the first event in a room.
+ StateEventIDs []string `json:"state_event_ids"`
+ // The server name to use to push this event to other servers.
+ // Or empty if this event shouldn't be pushed to other servers.
+ SendAsServer string `json:"send_as_server"`
+ // The transaction ID of the send request if sent by a local user and one
+ // was specified
+ TransactionID *TransactionID `json:"transaction_id"`
+}
+
+// TransactionID contains the transaction ID sent by a client when sending an
+// event, along with the ID of that device.
+type TransactionID struct {
+ DeviceID string `json:"device_id"`
+ TransactionID string `json:"id"`
+}
+
+// InputInviteEvent is a matrix invite event received over federation without
+// the usual context a matrix room event would have. We usually do not have
+// access to the events needed to check the event auth rules for the invite.
+type InputInviteEvent struct {
+ Event gomatrixserverlib.Event `json:"event"`
+}
+
+// InputRoomEventsRequest is a request to InputRoomEvents
+type InputRoomEventsRequest struct {
+ InputRoomEvents []InputRoomEvent `json:"input_room_events"`
+ InputInviteEvents []InputInviteEvent `json:"input_invite_events"`
+}
+
+// InputRoomEventsResponse is a response to InputRoomEvents
+type InputRoomEventsResponse struct {
+ EventID string `json:"event_id"`
+}
+
+// RoomserverInputAPI is used to write events to the room server.
+type RoomserverInputAPI interface {
+ InputRoomEvents(
+ ctx context.Context,
+ request *InputRoomEventsRequest,
+ response *InputRoomEventsResponse,
+ ) error
+}
+
+// RoomserverInputRoomEventsPath is the HTTP path for the InputRoomEvents API.
+const RoomserverInputRoomEventsPath = "/api/roomserver/inputRoomEvents"
+
+// NewRoomserverInputAPIHTTP creates a RoomserverInputAPI implemented by talking to a HTTP POST API.
+// If httpClient is nil then it uses the http.DefaultClient
+func NewRoomserverInputAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverInputAPI {
+ if httpClient == nil {
+ httpClient = http.DefaultClient
+ }
+ return &httpRoomserverInputAPI{roomserverURL, httpClient}
+}
+
+type httpRoomserverInputAPI struct {
+ roomserverURL string
+ httpClient *http.Client
+}
+
+// InputRoomEvents implements RoomserverInputAPI
+func (h *httpRoomserverInputAPI) InputRoomEvents(
+ ctx context.Context,
+ request *InputRoomEventsRequest,
+ response *InputRoomEventsResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverInputRoomEventsPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
diff --git a/roomserver/api/output.go b/roomserver/api/output.go
new file mode 100644
index 00000000..c09d5a1e
--- /dev/null
+++ b/roomserver/api/output.go
@@ -0,0 +1,138 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package api
+
+import (
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// An OutputType is a type of roomserver output.
+type OutputType string
+
+const (
+ // OutputTypeNewRoomEvent indicates that the event is an OutputNewRoomEvent
+ OutputTypeNewRoomEvent OutputType = "new_room_event"
+ // OutputTypeNewInviteEvent indicates that the event is an OutputNewInviteEvent
+ OutputTypeNewInviteEvent OutputType = "new_invite_event"
+ // OutputTypeRetireInviteEvent indicates that the event is an OutputRetireInviteEvent
+ OutputTypeRetireInviteEvent OutputType = "retire_invite_event"
+)
+
+// An OutputEvent is an entry in the roomserver output kafka log.
+// Consumers should check the type field when consuming this event.
+type OutputEvent struct {
+ // What sort of event this is.
+ Type OutputType `json:"type"`
+ // The content of event with type OutputTypeNewRoomEvent
+ NewRoomEvent *OutputNewRoomEvent `json:"new_room_event,omitempty"`
+ // The content of event with type OutputTypeNewInviteEvent
+ NewInviteEvent *OutputNewInviteEvent `json:"new_invite_event,omitempty"`
+ // The content of event with type OutputTypeRetireInviteEvent
+ RetireInviteEvent *OutputRetireInviteEvent `json:"retire_invite_event,omitempty"`
+}
+
+// An OutputNewRoomEvent is written when the roomserver receives a new event.
+// It contains the full matrix room event and enough information for a
+// consumer to construct the current state of the room and the state before the
+// event.
+//
+// When we talk about state in a matrix room we are talking about the state
+// after a list of events. The current state is the state after the latest
+// event IDs in the room. The state before an event is the state after its
+// prev_events.
+type OutputNewRoomEvent struct {
+ // The Event.
+ Event gomatrixserverlib.Event `json:"event"`
+ // The latest events in the room after this event.
+ // This can be used to set the prev events for new events in the room.
+ // This also can be used to get the full current state after this event.
+ LatestEventIDs []string `json:"latest_event_ids"`
+ // The state event IDs that were added to the state of the room by this event.
+ // Together with RemovesStateEventIDs this allows the receiver to keep an up to date
+ // view of the current state of the room.
+ AddsStateEventIDs []string `json:"adds_state_event_ids"`
+ // The state event IDs that were removed from the state of the room by this event.
+ RemovesStateEventIDs []string `json:"removes_state_event_ids"`
+ // The ID of the event that was output before this event.
+ // Or the empty string if this is the first event output for this room.
+ // This is used by consumers to check if they can safely update their
+ // current state using the delta supplied in AddsStateEventIDs and
+ // RemovesStateEventIDs.
+ //
+ // If the LastSentEventID doesn't match what they were expecting it to be
+ // they can use the LatestEventIDs to request the full current state.
+ LastSentEventID string `json:"last_sent_event_id"`
+ // The state event IDs that are part of the state at the event, but not
+ // part of the current state. Together with the StateBeforeRemovesEventIDs
+ // this can be used to construct the state before the event from the
+ // current state. The StateBeforeAddsEventIDs and StateBeforeRemovesEventIDs
+ // delta is applied after the AddsStateEventIDs and RemovesStateEventIDs.
+ //
+ // Consumers need to know the state at each event in order to determine
+ // which users and servers are allowed to see the event. This information
+ // is needed to apply the history visibility rules and to tell which
+ // servers we need to push events to over federation.
+ //
+ // The state is given as a delta against the current state because they are
+ // usually either the same state, or differ by just a couple of events.
+ StateBeforeAddsEventIDs []string `json:"state_before_adds_event_ids"`
+ // The state event IDs that are part of the current state, but not part
+ // of the state at the event.
+ StateBeforeRemovesEventIDs []string `json:"state_before_removes_event_ids"`
+ // The server name to use to push this event to other servers.
+ // Or empty if this event shouldn't be pushed to other servers.
+ //
+ // This is used by the federation sender component. We need to tell it what
+ // event it needs to send because it can't tell on its own. Normally if an
+ // event was created on this server then we are responsible for sending it.
+ // However there are a couple of exceptions. The first is that when the
+ // server joins a remote room through another matrix server, it is the job
+ // of the other matrix server to send the event over federation. The second
+ // is the reverse of the first, that is when a remote server joins a room
+ // that we are in over federation using our server it is our responsibility
+ // to send the join event to other matrix servers.
+ //
+ // We encode the server name that the event should be sent using here to
+ // future proof the API for virtual hosting.
+ SendAsServer string `json:"send_as_server"`
+ // The transaction ID of the send request if sent by a local user and one
+ // was specified
+ TransactionID *TransactionID `json:"transaction_id"`
+}
+
+// An OutputNewInviteEvent is written whenever an invite becomes active.
+// Invite events can be received outside of an existing room so have to be
+// tracked separately from the room events themselves.
+type OutputNewInviteEvent struct {
+ // The "m.room.member" invite event.
+ Event gomatrixserverlib.Event `json:"event"`
+}
+
+// An OutputRetireInviteEvent is written whenever an existing invite is no longer
+// active. An invite stops being active if the user joins the room or if the
+// invite is rejected by the user.
+type OutputRetireInviteEvent struct {
+ // The ID of the "m.room.member" invite event.
+ EventID string
+ // The target user ID of the "m.room.member" invite event that was retired.
+ TargetUserID string
+ // Optional event ID of the event that replaced the invite.
+ // This can be empty if the invite was rejected locally and we were unable
+ // to reach the server that originally sent the invite.
+ RetiredByEventID string
+ // The "membership" of the user after retiring the invite. One of "join"
+ // "leave" or "ban".
+ Membership string
+}
diff --git a/roomserver/api/query.go b/roomserver/api/query.go
new file mode 100644
index 00000000..a544f8aa
--- /dev/null
+++ b/roomserver/api/query.go
@@ -0,0 +1,480 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package api
+
+import (
+ "context"
+ "net/http"
+
+ commonHTTP "github.com/matrix-org/dendrite/common/http"
+ "github.com/matrix-org/gomatrixserverlib"
+ opentracing "github.com/opentracing/opentracing-go"
+)
+
+// QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState
+type QueryLatestEventsAndStateRequest struct {
+ // The room ID to query the latest events for.
+ RoomID string `json:"room_id"`
+ // The state key tuples to fetch from the room current state.
+ // If this list is empty or nil then no state events are returned.
+ StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"`
+}
+
+// QueryLatestEventsAndStateResponse is a response to QueryLatestEventsAndState
+// This is used when sending events to set the prev_events, auth_events and depth.
+// It is also used to tell whether the event is allowed by the event auth rules.
+type QueryLatestEventsAndStateResponse struct {
+ // Copy of the request for debugging.
+ QueryLatestEventsAndStateRequest
+ // Does the room exist?
+ // If the room doesn't exist this will be false and LatestEvents will be empty.
+ RoomExists bool `json:"room_exists"`
+ // The latest events in the room.
+ // These are used to set the prev_events when sending an event.
+ LatestEvents []gomatrixserverlib.EventReference `json:"latest_events"`
+ // The state events requested.
+ // This list will be in an arbitrary order.
+ // These are used to set the auth_events when sending an event.
+ // These are used to check whether the event is allowed.
+ StateEvents []gomatrixserverlib.Event `json:"state_events"`
+ // The depth of the latest events.
+ // This is one greater than the maximum depth of the latest events.
+ // This is used to set the depth when sending an event.
+ Depth int64 `json:"depth"`
+}
+
+// QueryStateAfterEventsRequest is a request to QueryStateAfterEvents
+type QueryStateAfterEventsRequest struct {
+ // The room ID to query the state in.
+ RoomID string `json:"room_id"`
+ // The list of previous events to return the events after.
+ PrevEventIDs []string `json:"prev_event_ids"`
+ // The state key tuples to fetch from the state
+ StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"`
+}
+
+// QueryStateAfterEventsResponse is a response to QueryStateAfterEvents
+type QueryStateAfterEventsResponse struct {
+ // Copy of the request for debugging.
+ QueryStateAfterEventsRequest
+ // Does the room exist on this roomserver?
+ // If the room doesn't exist this will be false and StateEvents will be empty.
+ RoomExists bool `json:"room_exists"`
+ // Do all the previous events exist on this roomserver?
+ // If some of previous events do not exist this will be false and StateEvents will be empty.
+ PrevEventsExist bool `json:"prev_events_exist"`
+ // The state events requested.
+ // This list will be in an arbitrary order.
+ StateEvents []gomatrixserverlib.Event `json:"state_events"`
+}
+
+// QueryEventsByIDRequest is a request to QueryEventsByID
+type QueryEventsByIDRequest struct {
+ // The event IDs to look up.
+ EventIDs []string `json:"event_ids"`
+}
+
+// QueryEventsByIDResponse is a response to QueryEventsByID
+type QueryEventsByIDResponse struct {
+ // Copy of the request for debugging.
+ QueryEventsByIDRequest
+ // A list of events with the requested IDs.
+ // If the roomserver does not have a copy of a requested event
+ // then it will omit that event from the list.
+ // If the roomserver thinks it has a copy of the event, but
+ // fails to read it from the database then it will fail
+ // the entire request.
+ // This list will be in an arbitrary order.
+ Events []gomatrixserverlib.Event `json:"events"`
+}
+
+// QueryMembershipForUserRequest is a request to QueryMembership
+type QueryMembershipForUserRequest struct {
+ // ID of the room to fetch membership from
+ RoomID string `json:"room_id"`
+ // ID of the user for whom membership is requested
+ UserID string `json:"user_id"`
+}
+
+// QueryMembershipForUserResponse is a response to QueryMembership
+type QueryMembershipForUserResponse struct {
+ // The EventID of the latest "m.room.member" event for the sender,
+ // if HasBeenInRoom is true.
+ EventID string `json:"event_id"`
+ // True if the user has been in room before and has either stayed in it or left it.
+ HasBeenInRoom bool `json:"has_been_in_room"`
+ // True if the user is in room.
+ IsInRoom bool `json:"is_in_room"`
+}
+
+// QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom
+type QueryMembershipsForRoomRequest struct {
+ // If true, only returns the membership events of "join" membership
+ JoinedOnly bool `json:"joined_only"`
+ // ID of the room to fetch memberships from
+ RoomID string `json:"room_id"`
+ // ID of the user sending the request
+ Sender string `json:"sender"`
+}
+
+// QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom
+type QueryMembershipsForRoomResponse struct {
+ // The "m.room.member" events (of "join" membership) in the client format
+ JoinEvents []gomatrixserverlib.ClientEvent `json:"join_events"`
+ // True if the user has been in room before and has either stayed in it or
+ // left it.
+ HasBeenInRoom bool `json:"has_been_in_room"`
+}
+
+// QueryInvitesForUserRequest is a request to QueryInvitesForUser
+type QueryInvitesForUserRequest struct {
+ // The room ID to look up invites in.
+ RoomID string `json:"room_id"`
+ // The User ID to look up invites for.
+ TargetUserID string `json:"target_user_id"`
+}
+
+// QueryInvitesForUserResponse is a response to QueryInvitesForUser
+// This is used when accepting an invite or rejecting a invite to tell which
+// remote matrix servers to contact.
+type QueryInvitesForUserResponse struct {
+ // A list of matrix user IDs for each sender of an active invite targeting
+ // the requested user ID.
+ InviteSenderUserIDs []string `json:"invite_sender_user_ids"`
+}
+
+// QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent
+type QueryServerAllowedToSeeEventRequest struct {
+ // The event ID to look up invites in.
+ EventID string `json:"event_id"`
+ // The server interested in the event
+ ServerName gomatrixserverlib.ServerName `json:"server_name"`
+}
+
+// QueryServerAllowedToSeeEventResponse is a response to QueryServerAllowedToSeeEvent
+type QueryServerAllowedToSeeEventResponse struct {
+ // Wether the server in question is allowed to see the event
+ AllowedToSeeEvent bool `json:"can_see_event"`
+}
+
+// QueryMissingEventsRequest is a request to QueryMissingEvents
+type QueryMissingEventsRequest struct {
+ // Events which are known previous to the gap in the timeline.
+ EarliestEvents []string `json:"earliest_events"`
+ // Latest known events.
+ LatestEvents []string `json:"latest_events"`
+ // Limit the number of events this query returns.
+ Limit int `json:"limit"`
+ // The server interested in the event
+ ServerName gomatrixserverlib.ServerName `json:"server_name"`
+}
+
+// QueryMissingEventsResponse is a response to QueryMissingEvents
+type QueryMissingEventsResponse struct {
+ // Missing events, arbritrary order.
+ Events []gomatrixserverlib.Event `json:"events"`
+}
+
+// QueryStateAndAuthChainRequest is a request to QueryStateAndAuthChain
+type QueryStateAndAuthChainRequest struct {
+ // The room ID to query the state in.
+ RoomID string `json:"room_id"`
+ // The list of prev events for the event. Used to calculate the state at
+ // the event
+ PrevEventIDs []string `json:"prev_event_ids"`
+ // The list of auth events for the event. Used to calculate the auth chain
+ AuthEventIDs []string `json:"auth_event_ids"`
+}
+
+// QueryStateAndAuthChainResponse is a response to QueryStateAndAuthChain
+type QueryStateAndAuthChainResponse struct {
+ // Copy of the request for debugging.
+ QueryStateAndAuthChainRequest
+ // Does the room exist on this roomserver?
+ // If the room doesn't exist this will be false and StateEvents will be empty.
+ RoomExists bool `json:"room_exists"`
+ // Do all the previous events exist on this roomserver?
+ // If some of previous events do not exist this will be false and StateEvents will be empty.
+ PrevEventsExist bool `json:"prev_events_exist"`
+ // The state and auth chain events that were requested.
+ // The lists will be in an arbitrary order.
+ StateEvents []gomatrixserverlib.Event `json:"state_events"`
+ AuthChainEvents []gomatrixserverlib.Event `json:"auth_chain_events"`
+}
+
+// QueryBackfillRequest is a request to QueryBackfill.
+type QueryBackfillRequest struct {
+ // Events to start paginating from.
+ EarliestEventsIDs []string `json:"earliest_event_ids"`
+ // The maximum number of events to retrieve.
+ Limit int `json:"limit"`
+ // The server interested in the events.
+ ServerName gomatrixserverlib.ServerName `json:"server_name"`
+}
+
+// QueryBackfillResponse is a response to QueryBackfill.
+type QueryBackfillResponse struct {
+ // Missing events, arbritrary order.
+ Events []gomatrixserverlib.Event `json:"events"`
+}
+
+// RoomserverQueryAPI is used to query information from the room server.
+type RoomserverQueryAPI interface {
+ // Query the latest events and state for a room from the room server.
+ QueryLatestEventsAndState(
+ ctx context.Context,
+ request *QueryLatestEventsAndStateRequest,
+ response *QueryLatestEventsAndStateResponse,
+ ) error
+
+ // Query the state after a list of events in a room from the room server.
+ QueryStateAfterEvents(
+ ctx context.Context,
+ request *QueryStateAfterEventsRequest,
+ response *QueryStateAfterEventsResponse,
+ ) error
+
+ // Query a list of events by event ID.
+ QueryEventsByID(
+ ctx context.Context,
+ request *QueryEventsByIDRequest,
+ response *QueryEventsByIDResponse,
+ ) error
+
+ // Query the membership event for an user for a room.
+ QueryMembershipForUser(
+ ctx context.Context,
+ request *QueryMembershipForUserRequest,
+ response *QueryMembershipForUserResponse,
+ ) error
+
+ // Query a list of membership events for a room
+ QueryMembershipsForRoom(
+ ctx context.Context,
+ request *QueryMembershipsForRoomRequest,
+ response *QueryMembershipsForRoomResponse,
+ ) error
+
+ // Query a list of invite event senders for a user in a room.
+ QueryInvitesForUser(
+ ctx context.Context,
+ request *QueryInvitesForUserRequest,
+ response *QueryInvitesForUserResponse,
+ ) error
+
+ // Query whether a server is allowed to see an event
+ QueryServerAllowedToSeeEvent(
+ ctx context.Context,
+ request *QueryServerAllowedToSeeEventRequest,
+ response *QueryServerAllowedToSeeEventResponse,
+ ) error
+
+ // Query missing events for a room from roomserver
+ QueryMissingEvents(
+ ctx context.Context,
+ request *QueryMissingEventsRequest,
+ response *QueryMissingEventsResponse,
+ ) error
+
+ // Query to get state and auth chain for a (potentially hypothetical) event.
+ // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate
+ // the state and auth chain to return.
+ QueryStateAndAuthChain(
+ ctx context.Context,
+ request *QueryStateAndAuthChainRequest,
+ response *QueryStateAndAuthChainResponse,
+ ) error
+
+ // Query a given amount (or less) of events prior to a given set of events.
+ QueryBackfill(
+ ctx context.Context,
+ request *QueryBackfillRequest,
+ response *QueryBackfillResponse,
+ ) error
+}
+
+// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API.
+const RoomserverQueryLatestEventsAndStatePath = "/api/roomserver/queryLatestEventsAndState"
+
+// RoomserverQueryStateAfterEventsPath is the HTTP path for the QueryStateAfterEvents API.
+const RoomserverQueryStateAfterEventsPath = "/api/roomserver/queryStateAfterEvents"
+
+// RoomserverQueryEventsByIDPath is the HTTP path for the QueryEventsByID API.
+const RoomserverQueryEventsByIDPath = "/api/roomserver/queryEventsByID"
+
+// RoomserverQueryMembershipForUserPath is the HTTP path for the QueryMembershipForUser API.
+const RoomserverQueryMembershipForUserPath = "/api/roomserver/queryMembershipForUser"
+
+// RoomserverQueryMembershipsForRoomPath is the HTTP path for the QueryMembershipsForRoom API
+const RoomserverQueryMembershipsForRoomPath = "/api/roomserver/queryMembershipsForRoom"
+
+// RoomserverQueryInvitesForUserPath is the HTTP path for the QueryInvitesForUser API
+const RoomserverQueryInvitesForUserPath = "/api/roomserver/queryInvitesForUser"
+
+// RoomserverQueryServerAllowedToSeeEventPath is the HTTP path for the QueryServerAllowedToSeeEvent API
+const RoomserverQueryServerAllowedToSeeEventPath = "/api/roomserver/queryServerAllowedToSeeEvent"
+
+// RoomserverQueryMissingEventsPath is the HTTP path for the QueryMissingEvents API
+const RoomserverQueryMissingEventsPath = "/api/roomserver/queryMissingEvents"
+
+// RoomserverQueryStateAndAuthChainPath is the HTTP path for the QueryStateAndAuthChain API
+const RoomserverQueryStateAndAuthChainPath = "/api/roomserver/queryStateAndAuthChain"
+
+// RoomserverQueryBackfillPath is the HTTP path for the QueryMissingEvents API
+const RoomserverQueryBackfillPath = "/api/roomserver/QueryBackfill"
+
+// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API.
+// If httpClient is nil then it uses the http.DefaultClient
+func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI {
+ if httpClient == nil {
+ httpClient = http.DefaultClient
+ }
+ return &httpRoomserverQueryAPI{roomserverURL, httpClient}
+}
+
+type httpRoomserverQueryAPI struct {
+ roomserverURL string
+ httpClient *http.Client
+}
+
+// QueryLatestEventsAndState implements RoomserverQueryAPI
+func (h *httpRoomserverQueryAPI) QueryLatestEventsAndState(
+ ctx context.Context,
+ request *QueryLatestEventsAndStateRequest,
+ response *QueryLatestEventsAndStateResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// QueryStateAfterEvents implements RoomserverQueryAPI
+func (h *httpRoomserverQueryAPI) QueryStateAfterEvents(
+ ctx context.Context,
+ request *QueryStateAfterEventsRequest,
+ response *QueryStateAfterEventsResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// QueryEventsByID implements RoomserverQueryAPI
+func (h *httpRoomserverQueryAPI) QueryEventsByID(
+ ctx context.Context,
+ request *QueryEventsByIDRequest,
+ response *QueryEventsByIDResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// QueryMembershipForUser implements RoomserverQueryAPI
+func (h *httpRoomserverQueryAPI) QueryMembershipForUser(
+ ctx context.Context,
+ request *QueryMembershipForUserRequest,
+ response *QueryMembershipForUserResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// QueryMembershipsForRoom implements RoomserverQueryAPI
+func (h *httpRoomserverQueryAPI) QueryMembershipsForRoom(
+ ctx context.Context,
+ request *QueryMembershipsForRoomRequest,
+ response *QueryMembershipsForRoomResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// QueryInvitesForUser implements RoomserverQueryAPI
+func (h *httpRoomserverQueryAPI) QueryInvitesForUser(
+ ctx context.Context,
+ request *QueryInvitesForUserRequest,
+ response *QueryInvitesForUserResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryInvitesForUser")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryInvitesForUserPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
+func (h *httpRoomserverQueryAPI) QueryServerAllowedToSeeEvent(
+ ctx context.Context,
+ request *QueryServerAllowedToSeeEventRequest,
+ response *QueryServerAllowedToSeeEventResponse,
+) (err error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// QueryMissingEvents implements RoomServerQueryAPI
+func (h *httpRoomserverQueryAPI) QueryMissingEvents(
+ ctx context.Context,
+ request *QueryMissingEventsRequest,
+ response *QueryMissingEventsResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// QueryStateAndAuthChain implements RoomserverQueryAPI
+func (h *httpRoomserverQueryAPI) QueryStateAndAuthChain(
+ ctx context.Context,
+ request *QueryStateAndAuthChainRequest,
+ response *QueryStateAndAuthChainResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+// QueryBackfill implements RoomServerQueryAPI
+func (h *httpRoomserverQueryAPI) QueryBackfill(
+ ctx context.Context,
+ request *QueryBackfillRequest,
+ response *QueryBackfillResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBackfill")
+ defer span.Finish()
+
+ apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath
+ return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go
new file mode 100644
index 00000000..2dce6f6d
--- /dev/null
+++ b/roomserver/auth/auth.go
@@ -0,0 +1,47 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package auth
+
+import "github.com/matrix-org/gomatrixserverlib"
+
+// IsServerAllowed returns true if there exists a event in authEvents
+// which allows server to view this event. That is true when a client on the server
+// can view the event. Otherwise returns false.
+func IsServerAllowed(
+ serverName gomatrixserverlib.ServerName,
+ authEvents []gomatrixserverlib.Event,
+) bool {
+ for _, ev := range authEvents {
+ membership, err := ev.Membership()
+ if err != nil || membership != "join" {
+ continue
+ }
+
+ stateKey := ev.StateKey()
+ if stateKey == nil {
+ continue
+ }
+
+ _, domain, err := gomatrixserverlib.SplitID('@', *stateKey)
+ if err != nil {
+ continue
+ }
+
+ if domain == serverName {
+ return true
+ }
+ }
+
+ // TODO: Check if history visibility is shared and if the server is currently in the room
+ return false
+}
diff --git a/roomserver/input/authevents.go b/roomserver/input/authevents.go
new file mode 100644
index 00000000..74be2ed3
--- /dev/null
+++ b/roomserver/input/authevents.go
@@ -0,0 +1,243 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package input
+
+import (
+ "context"
+ "sort"
+
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// checkAuthEvents checks that the event passes authentication checks
+// Returns the numeric IDs for the auth events.
+func checkAuthEvents(
+ ctx context.Context,
+ db RoomEventDatabase,
+ event gomatrixserverlib.Event,
+ authEventIDs []string,
+) ([]types.EventNID, error) {
+ // Grab the numeric IDs for the supplied auth state events from the database.
+ authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
+ if err != nil {
+ return nil, err
+ }
+ // TODO: check for duplicate state keys here.
+
+ // Work out which of the state events we actually need.
+ stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
+
+ // Load the actual auth events from the database.
+ authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check if the event is allowed.
+ if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
+ return nil, err
+ }
+
+ // Return the numeric IDs for the auth events.
+ result := make([]types.EventNID, len(authStateEntries))
+ for i := range authStateEntries {
+ result[i] = authStateEntries[i].EventNID
+ }
+ return result, nil
+}
+
+type authEvents struct {
+ stateKeyNIDMap map[string]types.EventStateKeyNID
+ state stateEntryMap
+ events eventMap
+}
+
+// Create implements gomatrixserverlib.AuthEventProvider
+func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) {
+ return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil
+}
+
+// PowerLevels implements gomatrixserverlib.AuthEventProvider
+func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) {
+ return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil
+}
+
+// JoinRules implements gomatrixserverlib.AuthEventProvider
+func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) {
+ return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil
+}
+
+// Memmber implements gomatrixserverlib.AuthEventProvider
+func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) {
+ return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil
+}
+
+// ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider
+func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) {
+ return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
+}
+
+func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
+ eventNID, ok := ae.state.lookup(types.StateKeyTuple{
+ EventTypeNID: typeNID,
+ EventStateKeyNID: types.EmptyStateKeyNID,
+ })
+ if !ok {
+ return nil
+ }
+ event, ok := ae.events.lookup(eventNID)
+ if !ok {
+ return nil
+ }
+ return &event.Event
+}
+
+func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event {
+ stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
+ if !ok {
+ return nil
+ }
+ eventNID, ok := ae.state.lookup(types.StateKeyTuple{
+ EventTypeNID: typeNID,
+ EventStateKeyNID: stateKeyNID,
+ })
+ if !ok {
+ return nil
+ }
+ event, ok := ae.events.lookup(eventNID)
+ if !ok {
+ return nil
+ }
+ return &event.Event
+}
+
+// loadAuthEvents loads the events needed for authentication from the supplied room state.
+func loadAuthEvents(
+ ctx context.Context,
+ db RoomEventDatabase,
+ needed gomatrixserverlib.StateNeeded,
+ state []types.StateEntry,
+) (result authEvents, err error) {
+ // Look up the numeric IDs for the state keys needed for auth.
+ var neededStateKeys []string
+ neededStateKeys = append(neededStateKeys, needed.Member...)
+ neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
+ if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil {
+ return
+ }
+
+ // Load the events we need.
+ result.state = state
+ var eventNIDs []types.EventNID
+ keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
+ for _, keyTuple := range keyTuplesNeeded {
+ eventNID, ok := result.state.lookup(keyTuple)
+ if ok {
+ eventNIDs = append(eventNIDs, eventNID)
+ }
+ }
+ if result.events, err = db.Events(ctx, eventNIDs); err != nil {
+ return
+ }
+ return
+}
+
+// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
+func stateKeyTuplesNeeded(
+ stateKeyNIDMap map[string]types.EventStateKeyNID,
+ stateNeeded gomatrixserverlib.StateNeeded,
+) []types.StateKeyTuple {
+ var keyTuples []types.StateKeyTuple
+ if stateNeeded.Create {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomCreateNID,
+ EventStateKeyNID: types.EmptyStateKeyNID,
+ })
+ }
+ if stateNeeded.PowerLevels {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomPowerLevelsNID,
+ EventStateKeyNID: types.EmptyStateKeyNID,
+ })
+ }
+ if stateNeeded.JoinRules {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomJoinRulesNID,
+ EventStateKeyNID: types.EmptyStateKeyNID,
+ })
+ }
+ for _, member := range stateNeeded.Member {
+ stateKeyNID, ok := stateKeyNIDMap[member]
+ if ok {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomMemberNID,
+ EventStateKeyNID: stateKeyNID,
+ })
+ }
+ }
+ for _, token := range stateNeeded.ThirdPartyInvite {
+ stateKeyNID, ok := stateKeyNIDMap[token]
+ if ok {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomThirdPartyInviteNID,
+ EventStateKeyNID: stateKeyNID,
+ })
+ }
+ }
+ return keyTuples
+}
+
+// Map from event type, state key tuple to numeric event ID.
+// Implemented using binary search on a sorted array.
+type stateEntryMap []types.StateEntry
+
+// lookup an entry in the event map.
+func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
+ // Since the list is sorted we can implement this using binary search.
+ // This is faster than using a hash map.
+ // We don't have to worry about pathological cases because the keys are fixed
+ // size and are controlled by us.
+ list := []types.StateEntry(m)
+ i := sort.Search(len(list), func(i int) bool {
+ return !list[i].StateKeyTuple.LessThan(stateKey)
+ })
+ if i < len(list) && list[i].StateKeyTuple == stateKey {
+ ok = true
+ eventNID = list[i].EventNID
+ }
+ return
+}
+
+// Map from numeric event ID to event.
+// Implemented using binary search on a sorted array.
+type eventMap []types.Event
+
+// lookup an entry in the event map.
+func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
+ // Since the list is sorted we can implement this using binary search.
+ // This is faster than using a hash map.
+ // We don't have to worry about pathological cases because the keys are fixed
+ // size are controlled by us.
+ list := []types.Event(m)
+ i := sort.Search(len(list), func(i int) bool {
+ return list[i].EventNID >= eventNID
+ })
+ if i < len(list) && list[i].EventNID == eventNID {
+ ok = true
+ event = &list[i]
+ }
+ return
+}
diff --git a/roomserver/input/authevents_test.go b/roomserver/input/authevents_test.go
new file mode 100644
index 00000000..0621a084
--- /dev/null
+++ b/roomserver/input/authevents_test.go
@@ -0,0 +1,136 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package input
+
+import (
+ "testing"
+
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
+ var list []types.StateEntry
+ for i := int64(0); i < entries; i++ {
+ list = append(list, types.StateEntry{
+ StateKeyTuple: types.StateKeyTuple{
+ EventTypeNID: types.EventTypeNID(i),
+ EventStateKeyNID: types.EventStateKeyNID(i),
+ },
+ EventNID: types.EventNID(i),
+ })
+ }
+
+ for i := 0; i < b.N; i++ {
+ entryMap := stateEntryMap(list)
+ for j := int64(0); j < lookups; j++ {
+ entryMap.lookup(types.StateKeyTuple{
+ EventTypeNID: types.EventTypeNID(j),
+ EventStateKeyNID: types.EventStateKeyNID(j),
+ })
+ }
+ }
+}
+
+func BenchmarkStateEntryMap100Lookup10(b *testing.B) {
+ benchmarkStateEntryMapLookup(100, 10, b)
+}
+
+func BenchmarkStateEntryMap1000Lookup100(b *testing.B) {
+ benchmarkStateEntryMapLookup(1000, 100, b)
+}
+
+func BenchmarkStateEntryMap100Lookup100(b *testing.B) {
+ benchmarkStateEntryMapLookup(100, 100, b)
+}
+
+func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) {
+ benchmarkStateEntryMapLookup(1000, 10000, b)
+}
+
+func TestStateEntryMap(t *testing.T) {
+ entryMap := stateEntryMap([]types.StateEntry{
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 3}, EventNID: 2},
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 2, EventStateKeyNID: 1}, EventNID: 3},
+ })
+
+ testCases := []struct {
+ inputTypeNID types.EventTypeNID
+ inputStateKey types.EventStateKeyNID
+ wantOK bool
+ wantEventNID types.EventNID
+ }{
+ // Check that tuples that in the array are in the map.
+ {1, 1, true, 1},
+ {1, 3, true, 2},
+ {2, 1, true, 3},
+ // Check that tuples that aren't in the array aren't in the map.
+ {0, 0, false, 0},
+ {1, 2, false, 0},
+ {3, 1, false, 0},
+ }
+
+ for _, testCase := range testCases {
+ keyTuple := types.StateKeyTuple{EventTypeNID: testCase.inputTypeNID, EventStateKeyNID: testCase.inputStateKey}
+ gotEventNID, gotOK := entryMap.lookup(keyTuple)
+ if testCase.wantOK != gotOK {
+ t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK)
+ }
+ if testCase.wantEventNID != gotEventNID {
+ t.Fatalf("stateEntryMap lookup(%v): want eventNID to be %v, got %v", keyTuple, testCase.wantEventNID, gotEventNID)
+ }
+ }
+}
+
+func TestEventMap(t *testing.T) {
+ events := eventMap([]types.Event{
+ {EventNID: 1},
+ {EventNID: 2},
+ {EventNID: 3},
+ {EventNID: 5},
+ {EventNID: 8},
+ })
+
+ testCases := []struct {
+ inputEventNID types.EventNID
+ wantOK bool
+ wantEvent *types.Event
+ }{
+ // Check that the IDs that are in the array are in the map.
+ {1, true, &events[0]},
+ {2, true, &events[1]},
+ {3, true, &events[2]},
+ {5, true, &events[3]},
+ {8, true, &events[4]},
+ // Check that tuples that aren't in the array aren't in the map.
+ {0, false, nil},
+ {4, false, nil},
+ {6, false, nil},
+ {7, false, nil},
+ {9, false, nil},
+ }
+
+ for _, testCase := range testCases {
+ gotEvent, gotOK := events.lookup(testCase.inputEventNID)
+ if testCase.wantOK != gotOK {
+ t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
+ }
+
+ if testCase.wantEvent != gotEvent {
+ t.Fatalf("eventMap lookup(%v): want event to be %v, got %v", testCase.inputEventNID, testCase.wantEvent, gotEvent)
+ }
+ }
+
+}
diff --git a/roomserver/input/events.go b/roomserver/input/events.go
new file mode 100644
index 00000000..feb15b3e
--- /dev/null
+++ b/roomserver/input/events.go
@@ -0,0 +1,235 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package input
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/state"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// A RoomEventDatabase has the storage APIs needed to store a room event.
+type RoomEventDatabase interface {
+ state.RoomStateDatabase
+ // Stores a matrix room event in the database
+ StoreEvent(
+ ctx context.Context,
+ event gomatrixserverlib.Event,
+ txnAndDeviceID *api.TransactionID,
+ authEventNIDs []types.EventNID,
+ ) (types.RoomNID, types.StateAtEvent, error)
+ // Look up the state entries for a list of string event IDs
+ // Returns an error if the there is an error talking to the database
+ // Returns a types.MissingEventError if the event IDs aren't in the database.
+ StateEntriesForEventIDs(
+ ctx context.Context, eventIDs []string,
+ ) ([]types.StateEntry, error)
+ // Set the state at an event.
+ SetState(
+ ctx context.Context,
+ eventNID types.EventNID,
+ stateNID types.StateSnapshotNID,
+ ) error
+ // Look up the latest events in a room in preparation for an update.
+ // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
+ // Returns the latest events in the room and the last eventID sent to the log along with an updater.
+ // If this returns an error then no further action is required.
+ GetLatestEventsForUpdate(
+ ctx context.Context, roomNID types.RoomNID,
+ ) (updater types.RoomRecentEventsUpdater, err error)
+ // Look up the string event IDs for a list of numeric event IDs
+ EventIDs(
+ ctx context.Context, eventNIDs []types.EventNID,
+ ) (map[types.EventNID]string, error)
+ // Build a membership updater for the target user in a room.
+ MembershipUpdater(
+ ctx context.Context, roomID, targerUserID string,
+ ) (types.MembershipUpdater, error)
+ // Look up event ID by transaction's info.
+ // This is used to determine if the room event is processed/processing already.
+ // Returns an empty string if no such event exists.
+ GetTransactionEventID(
+ ctx context.Context, transactionID string,
+ deviceID string, userID string,
+ ) (string, error)
+}
+
+// OutputRoomEventWriter has the APIs needed to write an event to the output logs.
+type OutputRoomEventWriter interface {
+ // Write a list of events for a room
+ WriteOutputEvents(roomID string, updates []api.OutputEvent) error
+}
+
+// processRoomEvent can only be called once at a time
+//
+// TODO(#375): This should be rewritten to allow concurrent calls. The
+// difficulty is in ensuring that we correctly annotate events with the correct
+// state deltas when sending to kafka streams
+func processRoomEvent(
+ ctx context.Context,
+ db RoomEventDatabase,
+ ow OutputRoomEventWriter,
+ input api.InputRoomEvent,
+) (eventID string, err error) {
+ // Parse and validate the event JSON
+ event := input.Event
+
+ // Check that the event passes authentication checks and work out the numeric IDs for the auth events.
+ authEventNIDs, err := checkAuthEvents(ctx, db, event, input.AuthEventIDs)
+ if err != nil {
+ return
+ }
+
+ if input.TransactionID != nil {
+ tdID := input.TransactionID
+ eventID, err = db.GetTransactionEventID(
+ ctx, tdID.TransactionID, tdID.DeviceID, input.Event.Sender(),
+ )
+ // On error OR event with the transaction already processed/processesing
+ if err != nil || eventID != "" {
+ return
+ }
+ }
+
+ // Store the event
+ roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
+ if err != nil {
+ return
+ }
+
+ if input.Kind == api.KindOutlier {
+ // For outliers we can stop after we've stored the event itself as it
+ // doesn't have any associated state to store and we don't need to
+ // notify anyone about it.
+ return event.EventID(), nil
+ }
+
+ if stateAtEvent.BeforeStateSnapshotNID == 0 {
+ // We haven't calculated a state for this event yet.
+ // Lets calculate one.
+ err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event)
+ if err != nil {
+ return
+ }
+ }
+
+ if input.Kind == api.KindBackfill {
+ // Backfill is not implemented.
+ panic("Not implemented")
+ }
+
+ // Update the extremities of the event graph for the room
+ return event.EventID(), updateLatestEvents(
+ ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer, input.TransactionID,
+ )
+}
+
+func calculateAndSetState(
+ ctx context.Context,
+ db RoomEventDatabase,
+ input api.InputRoomEvent,
+ roomNID types.RoomNID,
+ stateAtEvent *types.StateAtEvent,
+ event gomatrixserverlib.Event,
+) error {
+ var err error
+ if input.HasState {
+ // We've been told what the state at the event is so we don't need to calculate it.
+ // Check that those state events are in the database and store the state.
+ var entries []types.StateEntry
+ if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
+ return err
+ }
+
+ if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil {
+ return err
+ }
+ } else {
+ // We haven't been told what the state at the event is so we need to calculate it from the prev_events
+ if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil {
+ return err
+ }
+ }
+ return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
+}
+
+func processInviteEvent(
+ ctx context.Context,
+ db RoomEventDatabase,
+ ow OutputRoomEventWriter,
+ input api.InputInviteEvent,
+) (err error) {
+ if input.Event.StateKey() == nil {
+ return fmt.Errorf("invite must be a state event")
+ }
+
+ roomID := input.Event.RoomID()
+ targetUserID := *input.Event.StateKey()
+
+ updater, err := db.MembershipUpdater(ctx, roomID, targetUserID)
+ if err != nil {
+ return err
+ }
+ succeeded := false
+ defer common.EndTransaction(updater, &succeeded)
+
+ if updater.IsJoin() {
+ // If the user is joined to the room then that takes precedence over this
+ // invite event. It makes little sense to move a user that is already
+ // joined to the room into the invite state.
+ // This could plausibly happen if an invite request raced with a join
+ // request for a user. For example if a user was invited to a public
+ // room and they joined the room at the same time as the invite was sent.
+ // The other way this could plausibly happen is if an invite raced with
+ // a kick. For example if a user was kicked from a room in error and in
+ // response someone else in the room re-invited them then it is possible
+ // for the invite request to race with the leave event so that the
+ // target receives invite before it learns that it has been kicked.
+ // There are a few ways this could be plausibly handled in the roomserver.
+ // 1) Store the invite, but mark it as retired. That will result in the
+ // permanent rejection of that invite event. So even if the target
+ // user leaves the room and the invite is retransmitted it will be
+ // ignored. However a new invite with a new event ID would still be
+ // accepted.
+ // 2) Silently discard the invite event. This means that if the event
+ // was retransmitted at a later date after the target user had left
+ // the room we would accept the invite. However since we hadn't told
+ // the sending server that the invite had been discarded it would
+ // have no reason to attempt to retry.
+ // 3) Signal the sending server that the user is already joined to the
+ // room.
+ // For now we will implement option 2. Since in the abesence of a retry
+ // mechanism it will be equivalent to option 1, and we don't have a
+ // signalling mechanism to implement option 3.
+ return nil
+ }
+
+ outputUpdates, err := updateToInviteMembership(updater, &input.Event, nil)
+ if err != nil {
+ return err
+ }
+
+ if err = ow.WriteOutputEvents(roomID, outputUpdates); err != nil {
+ return err
+ }
+
+ succeeded = true
+ return nil
+}
diff --git a/roomserver/input/input.go b/roomserver/input/input.go
new file mode 100644
index 00000000..bd029d8d
--- /dev/null
+++ b/roomserver/input/input.go
@@ -0,0 +1,95 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package input contains the code processes new room events
+package input
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "sync"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/util"
+ sarama "gopkg.in/Shopify/sarama.v1"
+)
+
+// RoomserverInputAPI implements api.RoomserverInputAPI
+type RoomserverInputAPI struct {
+ DB RoomEventDatabase
+ Producer sarama.SyncProducer
+ // The kafkaesque topic to output new room events to.
+ // This is the name used in kafka to identify the stream to write events to.
+ OutputRoomEventTopic string
+ // Protects calls to processRoomEvent
+ mutex sync.Mutex
+}
+
+// WriteOutputEvents implements OutputRoomEventWriter
+func (r *RoomserverInputAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error {
+ messages := make([]*sarama.ProducerMessage, len(updates))
+ for i := range updates {
+ value, err := json.Marshal(updates[i])
+ if err != nil {
+ return err
+ }
+ messages[i] = &sarama.ProducerMessage{
+ Topic: r.OutputRoomEventTopic,
+ Key: sarama.StringEncoder(roomID),
+ Value: sarama.ByteEncoder(value),
+ }
+ }
+ return r.Producer.SendMessages(messages)
+}
+
+// InputRoomEvents implements api.RoomserverInputAPI
+func (r *RoomserverInputAPI) InputRoomEvents(
+ ctx context.Context,
+ request *api.InputRoomEventsRequest,
+ response *api.InputRoomEventsResponse,
+) (err error) {
+ // We lock as processRoomEvent can only be called once at a time
+ r.mutex.Lock()
+ defer r.mutex.Unlock()
+ for i := range request.InputRoomEvents {
+ if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil {
+ return err
+ }
+ }
+ for i := range request.InputInviteEvents {
+ if err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// SetupHTTP adds the RoomserverInputAPI handlers to the http.ServeMux.
+func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux) {
+ servMux.Handle(api.RoomserverInputRoomEventsPath,
+ common.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse {
+ var request api.InputRoomEventsRequest
+ var response api.InputRoomEventsResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := r.InputRoomEvents(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+}
diff --git a/roomserver/input/latest_events.go b/roomserver/input/latest_events.go
new file mode 100644
index 00000000..c2f06393
--- /dev/null
+++ b/roomserver/input/latest_events.go
@@ -0,0 +1,293 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package input
+
+import (
+ "bytes"
+ "context"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/state"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+// updateLatestEvents updates the list of latest events for this room in the database and writes the
+// event to the output log.
+// The latest events are the events that aren't referenced by another event in the database:
+//
+// Time goes down the page. 1 is the m.room.create event (root).
+//
+// 1 After storing 1 the latest events are {1}
+// | After storing 2 the latest events are {2}
+// 2 After storing 3 the latest events are {3}
+// / \ After storing 4 the latest events are {3,4}
+// 3 4 After storing 5 the latest events are {5,4}
+// | | After storing 6 the latest events are {5,6}
+// 5 6 <--- latest After storing 7 the latest events are {6,7}
+// |
+// 7 <----- latest
+//
+// Can only be called once at a time
+func updateLatestEvents(
+ ctx context.Context,
+ db RoomEventDatabase,
+ ow OutputRoomEventWriter,
+ roomNID types.RoomNID,
+ stateAtEvent types.StateAtEvent,
+ event gomatrixserverlib.Event,
+ sendAsServer string,
+ transactionID *api.TransactionID,
+) (err error) {
+ updater, err := db.GetLatestEventsForUpdate(ctx, roomNID)
+ if err != nil {
+ return
+ }
+ succeeded := false
+ defer common.EndTransaction(updater, &succeeded)
+
+ u := latestEventsUpdater{
+ ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,
+ stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer,
+ transactionID: transactionID,
+ }
+ if err = u.doUpdateLatestEvents(); err != nil {
+ return err
+ }
+
+ succeeded = true
+ return
+}
+
+// latestEventsUpdater tracks the state used to update the latest events in the
+// room. It mostly just ferries state between the various function calls.
+// The state could be passed using function arguments, but it becomes impractical
+// when there are so many variables to pass around.
+type latestEventsUpdater struct {
+ ctx context.Context
+ db RoomEventDatabase
+ updater types.RoomRecentEventsUpdater
+ ow OutputRoomEventWriter
+ roomNID types.RoomNID
+ stateAtEvent types.StateAtEvent
+ event gomatrixserverlib.Event
+ transactionID *api.TransactionID
+ // Which server to send this event as.
+ sendAsServer string
+ // The eventID of the event that was processed before this one.
+ lastEventIDSent string
+ // The latest events in the room after processing this event.
+ latest []types.StateAtEventAndReference
+ // The state entries removed from and added to the current state of the
+ // room as a result of processing this event. They are sorted lists.
+ removed []types.StateEntry
+ added []types.StateEntry
+ // The state entries that are removed and added to recover the state before
+ // the event being processed. They are sorted lists.
+ stateBeforeEventRemoves []types.StateEntry
+ stateBeforeEventAdds []types.StateEntry
+ // The snapshots of current state before and after processing this event
+ oldStateNID types.StateSnapshotNID
+ newStateNID types.StateSnapshotNID
+}
+
+func (u *latestEventsUpdater) doUpdateLatestEvents() error {
+ prevEvents := u.event.PrevEvents()
+ oldLatest := u.updater.LatestEvents()
+ u.lastEventIDSent = u.updater.LastEventIDSent()
+ u.oldStateNID = u.updater.CurrentStateSnapshotNID()
+
+ hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID)
+ if err != nil {
+ return err
+ } else if hasBeenSent {
+ // Already sent this event so we can stop processing
+ return nil
+ }
+
+ if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil {
+ return err
+ }
+
+ eventReference := u.event.EventReference()
+ // Check if this event is already referenced by another event in the room.
+ alreadyReferenced, err := u.updater.IsReferenced(eventReference)
+ if err != nil {
+ return err
+ }
+
+ u.latest = calculateLatest(oldLatest, alreadyReferenced, prevEvents, types.StateAtEventAndReference{
+ EventReference: eventReference,
+ StateAtEvent: u.stateAtEvent,
+ })
+
+ if err = u.latestState(); err != nil {
+ return err
+ }
+
+ updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added)
+ if err != nil {
+ return err
+ }
+
+ update, err := u.makeOutputNewRoomEvent()
+ if err != nil {
+ return err
+ }
+ updates = append(updates, *update)
+
+ // Send the event to the output logs.
+ // We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
+ // (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
+ // the write to the output log succeeds)
+ // TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
+ // send the event asynchronously but we would need to ensure that 1) the events are written to the log in
+ // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
+ // necessary bookkeeping we'll keep the event sending synchronous for now.
+ if err = u.ow.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
+ return err
+ }
+
+ if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
+ return err
+ }
+
+ return u.updater.MarkEventAsSent(u.stateAtEvent.EventNID)
+}
+
+func (u *latestEventsUpdater) latestState() error {
+ var err error
+
+ latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
+ for i := range u.latest {
+ latestStateAtEvents[i] = u.latest[i].StateAtEvent
+ }
+ u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(
+ u.ctx, u.db, u.roomNID, latestStateAtEvents,
+ )
+ if err != nil {
+ return err
+ }
+
+ u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(
+ u.ctx, u.db, u.oldStateNID, u.newStateNID,
+ )
+ if err != nil {
+ return err
+ }
+
+ u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots(
+ u.ctx, u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
+ )
+ return err
+}
+
+func calculateLatest(
+ oldLatest []types.StateAtEventAndReference,
+ alreadyReferenced bool,
+ prevEvents []gomatrixserverlib.EventReference,
+ newEvent types.StateAtEventAndReference,
+) []types.StateAtEventAndReference {
+ var alreadyInLatest bool
+ var newLatest []types.StateAtEventAndReference
+ for _, l := range oldLatest {
+ keep := true
+ for _, prevEvent := range prevEvents {
+ if l.EventID == prevEvent.EventID && bytes.Equal(l.EventSHA256, prevEvent.EventSHA256) {
+ // This event can be removed from the latest events cause we've found an event that references it.
+ // (If an event is referenced by another event then it can't be one of the latest events in the room
+ // because we have an event that comes after it)
+ keep = false
+ break
+ }
+ }
+ if l.EventNID == newEvent.EventNID {
+ alreadyInLatest = true
+ }
+ if keep {
+ // Keep the event in the latest events.
+ newLatest = append(newLatest, l)
+ }
+ }
+
+ if !alreadyReferenced && !alreadyInLatest {
+ // This event is not referenced by any of the events in the room
+ // and the event is not already in the latest events.
+ // Add it to the latest events
+ newLatest = append(newLatest, newEvent)
+ }
+
+ return newLatest
+}
+
+func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {
+
+ latestEventIDs := make([]string, len(u.latest))
+ for i := range u.latest {
+ latestEventIDs[i] = u.latest[i].EventID
+ }
+
+ ore := api.OutputNewRoomEvent{
+ Event: u.event,
+ LastSentEventID: u.lastEventIDSent,
+ LatestEventIDs: latestEventIDs,
+ TransactionID: u.transactionID,
+ }
+
+ var stateEventNIDs []types.EventNID
+ for _, entry := range u.added {
+ stateEventNIDs = append(stateEventNIDs, entry.EventNID)
+ }
+ for _, entry := range u.removed {
+ stateEventNIDs = append(stateEventNIDs, entry.EventNID)
+ }
+ for _, entry := range u.stateBeforeEventRemoves {
+ stateEventNIDs = append(stateEventNIDs, entry.EventNID)
+ }
+ for _, entry := range u.stateBeforeEventAdds {
+ stateEventNIDs = append(stateEventNIDs, entry.EventNID)
+ }
+ stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
+ eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs)
+ if err != nil {
+ return nil, err
+ }
+ for _, entry := range u.added {
+ ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID])
+ }
+ for _, entry := range u.removed {
+ ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID])
+ }
+ for _, entry := range u.stateBeforeEventRemoves {
+ ore.StateBeforeRemovesEventIDs = append(ore.StateBeforeRemovesEventIDs, eventIDMap[entry.EventNID])
+ }
+ for _, entry := range u.stateBeforeEventAdds {
+ ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID])
+ }
+ ore.SendAsServer = u.sendAsServer
+
+ return &api.OutputEvent{
+ Type: api.OutputTypeNewRoomEvent,
+ NewRoomEvent: &ore,
+ }, nil
+}
+
+type eventNIDSorter []types.EventNID
+
+func (s eventNIDSorter) Len() int { return len(s) }
+func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
+func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
diff --git a/roomserver/input/membership.go b/roomserver/input/membership.go
new file mode 100644
index 00000000..0c3fbb80
--- /dev/null
+++ b/roomserver/input/membership.go
@@ -0,0 +1,310 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package input
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// Membership values
+// TODO: Factor these out somewhere sensible?
+const join = "join"
+const leave = "leave"
+const invite = "invite"
+const ban = "ban"
+
+// updateMembership updates the current membership and the invites for each
+// user affected by a change in the current state of the room.
+// Returns a list of output events to write to the kafka log to inform the
+// consumers about the invites added or retired by the change in current state.
+func updateMemberships(
+ ctx context.Context,
+ db RoomEventDatabase,
+ updater types.RoomRecentEventsUpdater,
+ removed, added []types.StateEntry,
+) ([]api.OutputEvent, error) {
+ changes := membershipChanges(removed, added)
+ var eventNIDs []types.EventNID
+ for _, change := range changes {
+ if change.addedEventNID != 0 {
+ eventNIDs = append(eventNIDs, change.addedEventNID)
+ }
+ if change.removedEventNID != 0 {
+ eventNIDs = append(eventNIDs, change.removedEventNID)
+ }
+ }
+
+ // Load the event JSON so we can look up the "membership" key.
+ // TODO: Maybe add a membership key to the events table so we can load that
+ // key without having to load the entire event JSON?
+ events, err := db.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ var updates []api.OutputEvent
+
+ for _, change := range changes {
+ var ae *gomatrixserverlib.Event
+ var re *gomatrixserverlib.Event
+ targetUserNID := change.EventStateKeyNID
+ if change.removedEventNID != 0 {
+ ev, _ := eventMap(events).lookup(change.removedEventNID)
+ if ev != nil {
+ re = &ev.Event
+ }
+ }
+ if change.addedEventNID != 0 {
+ ev, _ := eventMap(events).lookup(change.addedEventNID)
+ if ev != nil {
+ ae = &ev.Event
+ }
+ }
+ if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
+ return nil, err
+ }
+ }
+ return updates, nil
+}
+
+func updateMembership(
+ updater types.RoomRecentEventsUpdater, targetUserNID types.EventStateKeyNID,
+ remove, add *gomatrixserverlib.Event,
+ updates []api.OutputEvent,
+) ([]api.OutputEvent, error) {
+ var err error
+ // Default the membership to Leave if no event was added or removed.
+ oldMembership := leave
+ newMembership := leave
+
+ if remove != nil {
+ oldMembership, err = remove.Membership()
+ if err != nil {
+ return nil, err
+ }
+ }
+ if add != nil {
+ newMembership, err = add.Membership()
+ if err != nil {
+ return nil, err
+ }
+ }
+ if oldMembership == newMembership && newMembership != join {
+ // If the membership is the same then nothing changed and we can return
+ // immediately, unless it's a Join update (e.g. profile update).
+ return updates, nil
+ }
+
+ mu, err := updater.MembershipUpdater(targetUserNID)
+ if err != nil {
+ return nil, err
+ }
+
+ switch newMembership {
+ case invite:
+ return updateToInviteMembership(mu, add, updates)
+ case join:
+ return updateToJoinMembership(mu, add, updates)
+ case leave, ban:
+ return updateToLeaveMembership(mu, add, newMembership, updates)
+ default:
+ panic(fmt.Errorf(
+ "input: membership %q is not one of the allowed values", newMembership,
+ ))
+ }
+}
+
+func updateToInviteMembership(
+ mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
+) ([]api.OutputEvent, error) {
+ // We may have already sent the invite to the user, either because we are
+ // reprocessing this event, or because the we received this invite from a
+ // remote server via the federation invite API. In those cases we don't need
+ // to send the event.
+ needsSending, err := mu.SetToInvite(*add)
+ if err != nil {
+ return nil, err
+ }
+ if needsSending {
+ // We notify the consumers using a special event even though we will
+ // notify them about the change in current state as part of the normal
+ // room event stream. This ensures that the consumers only have to
+ // consider a single stream of events when determining whether a user
+ // is invited, rather than having to combine multiple streams themselves.
+ onie := api.OutputNewInviteEvent{
+ Event: *add,
+ }
+ updates = append(updates, api.OutputEvent{
+ Type: api.OutputTypeNewInviteEvent,
+ NewInviteEvent: &onie,
+ })
+ }
+ return updates, nil
+}
+
+func updateToJoinMembership(
+ mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
+) ([]api.OutputEvent, error) {
+ // If the user is already marked as being joined, we call SetToJoin to update
+ // the event ID then we can return immediately. Retired is ignored as there
+ // is no invite event to retire.
+ if mu.IsJoin() {
+ _, err := mu.SetToJoin(add.Sender(), add.EventID(), true)
+ if err != nil {
+ return nil, err
+ }
+ return updates, nil
+ }
+ // When we mark a user as being joined we will invalidate any invites that
+ // are active for that user. We notify the consumers that the invites have
+ // been retired using a special event, even though they could infer this
+ // by studying the state changes in the room event stream.
+ retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false)
+ if err != nil {
+ return nil, err
+ }
+ for _, eventID := range retired {
+ orie := api.OutputRetireInviteEvent{
+ EventID: eventID,
+ Membership: join,
+ RetiredByEventID: add.EventID(),
+ TargetUserID: *add.StateKey(),
+ }
+ updates = append(updates, api.OutputEvent{
+ Type: api.OutputTypeRetireInviteEvent,
+ RetireInviteEvent: &orie,
+ })
+ }
+ return updates, nil
+}
+
+func updateToLeaveMembership(
+ mu types.MembershipUpdater, add *gomatrixserverlib.Event,
+ newMembership string, updates []api.OutputEvent,
+) ([]api.OutputEvent, error) {
+ // If the user is already neither joined, nor invited to the room then we
+ // can return immediately.
+ if mu.IsLeave() {
+ return updates, nil
+ }
+ // When we mark a user as having left we will invalidate any invites that
+ // are active for that user. We notify the consumers that the invites have
+ // been retired using a special event, even though they could infer this
+ // by studying the state changes in the room event stream.
+ retired, err := mu.SetToLeave(add.Sender(), add.EventID())
+ if err != nil {
+ return nil, err
+ }
+ for _, eventID := range retired {
+ orie := api.OutputRetireInviteEvent{
+ EventID: eventID,
+ Membership: newMembership,
+ RetiredByEventID: add.EventID(),
+ TargetUserID: *add.StateKey(),
+ }
+ updates = append(updates, api.OutputEvent{
+ Type: api.OutputTypeRetireInviteEvent,
+ RetireInviteEvent: &orie,
+ })
+ }
+ return updates, nil
+}
+
+// membershipChanges pairs up the membership state changes from a sorted list
+// of state removed and a sorted list of state added.
+func membershipChanges(removed, added []types.StateEntry) []stateChange {
+ changes := pairUpChanges(removed, added)
+ var result []stateChange
+ for _, c := range changes {
+ if c.EventTypeNID == types.MRoomMemberNID {
+ result = append(result, c)
+ }
+ }
+ return result
+}
+
+type stateChange struct {
+ types.StateKeyTuple
+ removedEventNID types.EventNID
+ addedEventNID types.EventNID
+}
+
+// pairUpChanges pairs up the state events added and removed for each type,
+// state key tuple. Assumes that removed and added are sorted.
+func pairUpChanges(removed, added []types.StateEntry) []stateChange {
+ var ai int
+ var ri int
+ var result []stateChange
+ for {
+ switch {
+ case ai == len(added):
+ // We've reached the end of the added entries.
+ // The rest of the removed list are events that were removed without
+ // an event with the same state key being added.
+ for _, s := range removed[ri:] {
+ result = append(result, stateChange{
+ StateKeyTuple: s.StateKeyTuple,
+ removedEventNID: s.EventNID,
+ })
+ }
+ return result
+ case ri == len(removed):
+ // We've reached the end of the removed entries.
+ // The rest of the added list are events that were added without
+ // an event with the same state key being removed.
+ for _, s := range added[ai:] {
+ result = append(result, stateChange{
+ StateKeyTuple: s.StateKeyTuple,
+ addedEventNID: s.EventNID,
+ })
+ }
+ return result
+ case added[ai].StateKeyTuple == removed[ri].StateKeyTuple:
+ // The tuple is in both lists so an event with that key is being
+ // removed and another event with the same key is being added.
+ result = append(result, stateChange{
+ StateKeyTuple: added[ai].StateKeyTuple,
+ removedEventNID: removed[ri].EventNID,
+ addedEventNID: added[ai].EventNID,
+ })
+ ai++
+ ri++
+ case added[ai].StateKeyTuple.LessThan(removed[ri].StateKeyTuple):
+ // The lists are sorted so the added entry being less than the
+ // removed entry means that the added event was added without an
+ // event with the same key being removed.
+ result = append(result, stateChange{
+ StateKeyTuple: added[ai].StateKeyTuple,
+ addedEventNID: added[ai].EventNID,
+ })
+ ai++
+ default:
+ // Reaching the default case implies that the removed entry is less
+ // than the added entry. Since the lists are sorted this means that
+ // the removed event was removed without an event with the same
+ // key being added.
+ result = append(result, stateChange{
+ StateKeyTuple: removed[ai].StateKeyTuple,
+ removedEventNID: removed[ri].EventNID,
+ })
+ ri++
+ }
+ }
+}
diff --git a/roomserver/query/query.go b/roomserver/query/query.go
new file mode 100644
index 00000000..b97d50b1
--- /dev/null
+++ b/roomserver/query/query.go
@@ -0,0 +1,802 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package query
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/auth"
+ "github.com/matrix-org/dendrite/roomserver/state"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+// RoomserverQueryAPIEventDB has a convenience API to fetch events directly by
+// EventIDs.
+type RoomserverQueryAPIEventDB interface {
+ // Look up the Events for a list of event IDs. Does not error if event was
+ // not found.
+ // Returns an error if the retrieval went wrong.
+ EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
+}
+
+// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API.
+type RoomserverQueryAPIDatabase interface {
+ state.RoomStateDatabase
+ RoomserverQueryAPIEventDB
+ // Look up the numeric ID for the room.
+ // Returns 0 if the room doesn't exists.
+ // Returns an error if there was a problem talking to the database.
+ RoomNID(ctx context.Context, roomID string) (types.RoomNID, error)
+ // Look up event references for the latest events in the room and the current state snapshot.
+ // Returns the latest events, the current state and the maximum depth of the latest events plus 1.
+ // Returns an error if there was a problem talking to the database.
+ LatestEventIDs(
+ ctx context.Context, roomNID types.RoomNID,
+ ) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
+ // Look up the numeric IDs for a list of events.
+ // Returns an error if there was a problem talking to the database.
+ EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
+ // Lookup the event IDs for a batch of event numeric IDs.
+ // Returns an error if the retrieval went wrong.
+ EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
+ // Lookup the membership of a given user in a given room.
+ // Returns the numeric ID of the latest membership event sent from this user
+ // in this room, along a boolean set to true if the user is still in this room,
+ // false if not.
+ // Returns an error if there was a problem talking to the database.
+ GetMembership(
+ ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
+ ) (membershipEventNID types.EventNID, stillInRoom bool, err error)
+ // Lookup the membership event numeric IDs for all user that are or have
+ // been members of a given room. Only lookup events of "join" membership if
+ // joinOnly is set to true.
+ // Returns an error if there was a problem talking to the database.
+ GetMembershipEventNIDsForRoom(
+ ctx context.Context, roomNID types.RoomNID, joinOnly bool,
+ ) ([]types.EventNID, error)
+ // Look up the active invites targeting a user in a room and return the
+ // numeric state key IDs for the user IDs who sent them.
+ // Returns an error if there was a problem talking to the database.
+ GetInvitesForUser(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ targetUserNID types.EventStateKeyNID,
+ ) (senderUserNIDs []types.EventStateKeyNID, err error)
+ // Look up the string event state keys for a list of numeric event state keys
+ // Returns an error if there was a problem talking to the database.
+ EventStateKeys(
+ context.Context, []types.EventStateKeyNID,
+ ) (map[types.EventStateKeyNID]string, error)
+}
+
+// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
+type RoomserverQueryAPI struct {
+ DB RoomserverQueryAPIDatabase
+}
+
+// QueryLatestEventsAndState implements api.RoomserverQueryAPI
+func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
+ ctx context.Context,
+ request *api.QueryLatestEventsAndStateRequest,
+ response *api.QueryLatestEventsAndStateResponse,
+) error {
+ response.QueryLatestEventsAndStateRequest = *request
+ roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
+ if err != nil {
+ return err
+ }
+ if roomNID == 0 {
+ return nil
+ }
+ response.RoomExists = true
+ var currentStateSnapshotNID types.StateSnapshotNID
+ response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
+ r.DB.LatestEventIDs(ctx, roomNID)
+ if err != nil {
+ return err
+ }
+
+ // Look up the currrent state for the requested tuples.
+ stateEntries, err := state.LoadStateAtSnapshotForStringTuples(
+ ctx, r.DB, currentStateSnapshotNID, request.StateToFetch,
+ )
+ if err != nil {
+ return err
+ }
+
+ stateEvents, err := r.loadStateEvents(ctx, stateEntries)
+ if err != nil {
+ return err
+ }
+
+ response.StateEvents = stateEvents
+ return nil
+}
+
+// QueryStateAfterEvents implements api.RoomserverQueryAPI
+func (r *RoomserverQueryAPI) QueryStateAfterEvents(
+ ctx context.Context,
+ request *api.QueryStateAfterEventsRequest,
+ response *api.QueryStateAfterEventsResponse,
+) error {
+ response.QueryStateAfterEventsRequest = *request
+ roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
+ if err != nil {
+ return err
+ }
+ if roomNID == 0 {
+ return nil
+ }
+ response.RoomExists = true
+
+ prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
+ if err != nil {
+ switch err.(type) {
+ case types.MissingEventError:
+ return nil
+ default:
+ return err
+ }
+ }
+ response.PrevEventsExist = true
+
+ // Look up the currrent state for the requested tuples.
+ stateEntries, err := state.LoadStateAfterEventsForStringTuples(
+ ctx, r.DB, prevStates, request.StateToFetch,
+ )
+ if err != nil {
+ return err
+ }
+
+ stateEvents, err := r.loadStateEvents(ctx, stateEntries)
+ if err != nil {
+ return err
+ }
+
+ response.StateEvents = stateEvents
+ return nil
+}
+
+// QueryEventsByID implements api.RoomserverQueryAPI
+func (r *RoomserverQueryAPI) QueryEventsByID(
+ ctx context.Context,
+ request *api.QueryEventsByIDRequest,
+ response *api.QueryEventsByIDResponse,
+) error {
+ response.QueryEventsByIDRequest = *request
+
+ eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs)
+ if err != nil {
+ return err
+ }
+
+ var eventNIDs []types.EventNID
+ for _, nid := range eventNIDMap {
+ eventNIDs = append(eventNIDs, nid)
+ }
+
+ events, err := r.loadEvents(ctx, eventNIDs)
+ if err != nil {
+ return err
+ }
+
+ response.Events = events
+ return nil
+}
+
+func (r *RoomserverQueryAPI) loadStateEvents(
+ ctx context.Context, stateEntries []types.StateEntry,
+) ([]gomatrixserverlib.Event, error) {
+ eventNIDs := make([]types.EventNID, len(stateEntries))
+ for i := range stateEntries {
+ eventNIDs[i] = stateEntries[i].EventNID
+ }
+ return r.loadEvents(ctx, eventNIDs)
+}
+
+func (r *RoomserverQueryAPI) loadEvents(
+ ctx context.Context, eventNIDs []types.EventNID,
+) ([]gomatrixserverlib.Event, error) {
+ stateEvents, err := r.DB.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ result := make([]gomatrixserverlib.Event, len(stateEvents))
+ for i := range stateEvents {
+ result[i] = stateEvents[i].Event
+ }
+ return result, nil
+}
+
+// QueryMembershipForUser implements api.RoomserverQueryAPI
+func (r *RoomserverQueryAPI) QueryMembershipForUser(
+ ctx context.Context,
+ request *api.QueryMembershipForUserRequest,
+ response *api.QueryMembershipForUserResponse,
+) error {
+ roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
+ if err != nil {
+ return err
+ }
+
+ membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.UserID)
+ if err != nil {
+ return err
+ }
+
+ if membershipEventNID == 0 {
+ response.HasBeenInRoom = false
+ return nil
+ }
+
+ response.IsInRoom = stillInRoom
+ eventIDMap, err := r.DB.EventIDs(ctx, []types.EventNID{membershipEventNID})
+ if err != nil {
+ return err
+ }
+
+ response.EventID = eventIDMap[membershipEventNID]
+ return nil
+}
+
+// QueryMembershipsForRoom implements api.RoomserverQueryAPI
+func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
+ ctx context.Context,
+ request *api.QueryMembershipsForRoomRequest,
+ response *api.QueryMembershipsForRoomResponse,
+) error {
+ roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
+ if err != nil {
+ return err
+ }
+
+ membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.Sender)
+ if err != nil {
+ return err
+ }
+
+ if membershipEventNID == 0 {
+ response.HasBeenInRoom = false
+ response.JoinEvents = nil
+ return nil
+ }
+
+ response.HasBeenInRoom = true
+ response.JoinEvents = []gomatrixserverlib.ClientEvent{}
+
+ var events []types.Event
+ if stillInRoom {
+ var eventNIDs []types.EventNID
+ eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
+ if err != nil {
+ return err
+ }
+
+ events, err = r.DB.Events(ctx, eventNIDs)
+ } else {
+ events, err = r.getMembershipsBeforeEventNID(ctx, membershipEventNID, request.JoinedOnly)
+ }
+
+ if err != nil {
+ return err
+ }
+
+ for _, event := range events {
+ clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll)
+ response.JoinEvents = append(response.JoinEvents, clientEvent)
+ }
+
+ return nil
+}
+
+// getMembershipsBeforeEventNID takes the numeric ID of an event and fetches the state
+// of the event's room as it was when this event was fired, then filters the state events to
+// only keep the "m.room.member" events with a "join" membership. These events are returned.
+// Returns an error if there was an issue fetching the events.
+func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
+ ctx context.Context, eventNID types.EventNID, joinedOnly bool,
+) ([]types.Event, error) {
+ events := []types.Event{}
+ // Lookup the event NID
+ eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID})
+ if err != nil {
+ return nil, err
+ }
+ eventIDs := []string{eIDs[eventNID]}
+
+ prevState, err := r.DB.StateAtEventIDs(ctx, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // Fetch the state as it was when this event was fired
+ stateEntries, err := state.LoadCombinedStateAfterEvents(ctx, r.DB, prevState)
+ if err != nil {
+ return nil, err
+ }
+
+ var eventNIDs []types.EventNID
+ for _, entry := range stateEntries {
+ // Filter the events to retrieve to only keep the membership events
+ if entry.EventTypeNID == types.MRoomMemberNID {
+ eventNIDs = append(eventNIDs, entry.EventNID)
+ }
+ }
+
+ // Get all of the events in this state
+ stateEvents, err := r.DB.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ if !joinedOnly {
+ return stateEvents, nil
+ }
+
+ // Filter the events to only keep the "join" membership events
+ for _, event := range stateEvents {
+ membership, err := event.Membership()
+ if err != nil {
+ return nil, err
+ }
+
+ if membership == "join" {
+ events = append(events, event)
+ }
+ }
+
+ return events, nil
+}
+
+// QueryInvitesForUser implements api.RoomserverQueryAPI
+func (r *RoomserverQueryAPI) QueryInvitesForUser(
+ ctx context.Context,
+ request *api.QueryInvitesForUserRequest,
+ response *api.QueryInvitesForUserResponse,
+) error {
+ roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
+ if err != nil {
+ return err
+ }
+
+ targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID})
+ if err != nil {
+ return err
+ }
+ targetUserNID := targetUserNIDs[request.TargetUserID]
+
+ senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID)
+ if err != nil {
+ return err
+ }
+
+ senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
+ if err != nil {
+ return err
+ }
+
+ for _, senderUserID := range senderUserIDs {
+ response.InviteSenderUserIDs = append(response.InviteSenderUserIDs, senderUserID)
+ }
+
+ return nil
+}
+
+// QueryServerAllowedToSeeEvent implements api.RoomserverQueryAPI
+func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent(
+ ctx context.Context,
+ request *api.QueryServerAllowedToSeeEventRequest,
+ response *api.QueryServerAllowedToSeeEventResponse,
+) (err error) {
+ response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent(
+ ctx, request.EventID, request.ServerName,
+ )
+ return
+}
+
+func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent(
+ ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName,
+) (bool, error) {
+ stateEntries, err := state.LoadStateAtEvent(ctx, r.DB, eventID)
+ if err != nil {
+ return false, err
+ }
+
+ // TODO: We probably want to make it so that we don't have to pull
+ // out all the state if possible.
+ stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
+ if err != nil {
+ return false, err
+ }
+
+ return auth.IsServerAllowed(serverName, stateAtEvent), nil
+}
+
+// QueryMissingEvents implements api.RoomserverQueryAPI
+func (r *RoomserverQueryAPI) QueryMissingEvents(
+ ctx context.Context,
+ request *api.QueryMissingEventsRequest,
+ response *api.QueryMissingEventsResponse,
+) error {
+ var front []string
+ eventsToFilter := make(map[string]bool, len(request.LatestEvents))
+ visited := make(map[string]bool, request.Limit) // request.Limit acts as a hint to size.
+ for _, id := range request.EarliestEvents {
+ visited[id] = true
+ }
+
+ for _, id := range request.LatestEvents {
+ if !visited[id] {
+ front = append(front, id)
+ eventsToFilter[id] = true
+ }
+ }
+
+ resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName)
+ if err != nil {
+ return err
+ }
+
+ loadedEvents, err := r.loadEvents(ctx, resultNIDs)
+ if err != nil {
+ return err
+ }
+
+ response.Events = make([]gomatrixserverlib.Event, 0, len(loadedEvents)-len(eventsToFilter))
+ for _, event := range loadedEvents {
+ if !eventsToFilter[event.EventID()] {
+ response.Events = append(response.Events, event)
+ }
+ }
+
+ return err
+}
+
+// QueryBackfill implements api.RoomServerQueryAPI
+func (r *RoomserverQueryAPI) QueryBackfill(
+ ctx context.Context,
+ request *api.QueryBackfillRequest,
+ response *api.QueryBackfillResponse,
+) error {
+ var err error
+ var front []string
+
+ // The limit defines the maximum number of events to retrieve, so it also
+ // defines the highest number of elements in the map below.
+ visited := make(map[string]bool, request.Limit)
+
+ // The provided event IDs have already been seen by the request's emitter,
+ // and will be retrieved anyway, so there's no need to care about them if
+ // they appear in our exploration of the event tree.
+ for _, id := range request.EarliestEventsIDs {
+ visited[id] = true
+ }
+
+ front = request.EarliestEventsIDs
+
+ // Scan the event tree for events to send back.
+ resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName)
+ if err != nil {
+ return err
+ }
+
+ // Retrieve events from the list that was filled previously.
+ response.Events, err = r.loadEvents(ctx, resultNIDs)
+ return err
+}
+
+func (r *RoomserverQueryAPI) scanEventTree(
+ ctx context.Context, front []string, visited map[string]bool, limit int,
+ serverName gomatrixserverlib.ServerName,
+) (resultNIDs []types.EventNID, err error) {
+ var allowed bool
+ var events []types.Event
+ var next []string
+ var pre string
+
+ resultNIDs = make([]types.EventNID, 0, limit)
+
+ // Loop through the event IDs to retrieve the requested events and go
+ // through the whole tree (up to the provided limit) using the events'
+ // "prev_event" key.
+BFSLoop:
+ for len(front) > 0 {
+ // Prevent unnecessary allocations: reset the slice only when not empty.
+ if len(next) > 0 {
+ next = make([]string, 0)
+ }
+ // Retrieve the events to process from the database.
+ events, err = r.DB.EventsFromIDs(ctx, front)
+ if err != nil {
+ return
+ }
+
+ for _, ev := range events {
+ // Break out of the loop if the provided limit is reached.
+ if len(resultNIDs) == limit {
+ break BFSLoop
+ }
+ // Update the list of events to retrieve.
+ resultNIDs = append(resultNIDs, ev.EventNID)
+ // Loop through the event's parents.
+ for _, pre = range ev.PrevEventIDs() {
+ // Only add an event to the list of next events to process if it
+ // hasn't been seen before.
+ if !visited[pre] {
+ visited[pre] = true
+ allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName)
+ if err != nil {
+ return
+ }
+
+ // If the event hasn't been seen before and the HS
+ // requesting to retrieve it is allowed to do so, add it to
+ // the list of events to retrieve.
+ if allowed {
+ next = append(next, pre)
+ }
+ }
+ }
+ }
+ // Repeat the same process with the parent events we just processed.
+ front = next
+ }
+
+ return
+}
+
+// QueryStateAndAuthChain implements api.RoomserverQueryAPI
+func (r *RoomserverQueryAPI) QueryStateAndAuthChain(
+ ctx context.Context,
+ request *api.QueryStateAndAuthChainRequest,
+ response *api.QueryStateAndAuthChainResponse,
+) error {
+ response.QueryStateAndAuthChainRequest = *request
+ roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
+ if err != nil {
+ return err
+ }
+ if roomNID == 0 {
+ return nil
+ }
+ response.RoomExists = true
+
+ prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
+ if err != nil {
+ switch err.(type) {
+ case types.MissingEventError:
+ return nil
+ default:
+ return err
+ }
+ }
+ response.PrevEventsExist = true
+
+ // Look up the currrent state for the requested tuples.
+ stateEntries, err := state.LoadCombinedStateAfterEvents(
+ ctx, r.DB, prevStates,
+ )
+ if err != nil {
+ return err
+ }
+
+ stateEvents, err := r.loadStateEvents(ctx, stateEntries)
+ if err != nil {
+ return err
+ }
+
+ response.StateEvents = stateEvents
+ response.AuthChainEvents, err = getAuthChain(ctx, r.DB, request.AuthEventIDs)
+ return err
+}
+
+// getAuthChain fetches the auth chain for the given auth events.
+// An auth chain is the list of all events that are referenced in the
+// auth_events section, and all their auth_events, recursively.
+// The returned set of events contain the given events.
+// Will *not* error if we don't have all auth events.
+func getAuthChain(
+ ctx context.Context, dB RoomserverQueryAPIEventDB, authEventIDs []string,
+) ([]gomatrixserverlib.Event, error) {
+ var authEvents []gomatrixserverlib.Event
+
+ // List of event ids to fetch. These will be added to the result and
+ // their auth events will be fetched (if they haven't been previously)
+ eventsToFetch := authEventIDs
+
+ // Set of events we've already fetched.
+ fetchedEventMap := make(map[string]bool)
+
+ // Check if there's anything left to do
+ for len(eventsToFetch) > 0 {
+ // Convert eventIDs to events. First need to fetch NIDs
+ events, err := dB.EventsFromIDs(ctx, eventsToFetch)
+ if err != nil {
+ return nil, err
+ }
+
+ // Work out a) which events we should add to the returned list of
+ // events and b) which of the auth events we haven't seen yet and
+ // add them to the list of events to fetch.
+ eventsToFetch = eventsToFetch[:0]
+ for _, event := range events {
+ fetchedEventMap[event.EventID()] = true
+ authEvents = append(authEvents, event.Event)
+
+ // Now we need to fetch any auth events that we haven't
+ // previously seen.
+ for _, authEventID := range event.AuthEventIDs() {
+ if !fetchedEventMap[authEventID] {
+ fetchedEventMap[authEventID] = true
+ eventsToFetch = append(eventsToFetch, authEventID)
+ }
+ }
+ }
+ }
+
+ return authEvents, nil
+}
+
+// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
+// nolint: gocyclo
+func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
+ servMux.Handle(
+ api.RoomserverQueryLatestEventsAndStatePath,
+ common.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
+ var request api.QueryLatestEventsAndStateRequest
+ var response api.QueryLatestEventsAndStateResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ api.RoomserverQueryStateAfterEventsPath,
+ common.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
+ var request api.QueryStateAfterEventsRequest
+ var response api.QueryStateAfterEventsResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ api.RoomserverQueryEventsByIDPath,
+ common.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
+ var request api.QueryEventsByIDRequest
+ var response api.QueryEventsByIDResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ api.RoomserverQueryMembershipForUserPath,
+ common.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse {
+ var request api.QueryMembershipForUserRequest
+ var response api.QueryMembershipForUserResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ api.RoomserverQueryMembershipsForRoomPath,
+ common.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
+ var request api.QueryMembershipsForRoomRequest
+ var response api.QueryMembershipsForRoomResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ api.RoomserverQueryInvitesForUserPath,
+ common.MakeInternalAPI("queryInvitesForUser", func(req *http.Request) util.JSONResponse {
+ var request api.QueryInvitesForUserRequest
+ var response api.QueryInvitesForUserResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryInvitesForUser(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ api.RoomserverQueryServerAllowedToSeeEventPath,
+ common.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
+ var request api.QueryServerAllowedToSeeEventRequest
+ var response api.QueryServerAllowedToSeeEventResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ api.RoomserverQueryMissingEventsPath,
+ common.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse {
+ var request api.QueryMissingEventsRequest
+ var response api.QueryMissingEventsResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ api.RoomserverQueryStateAndAuthChainPath,
+ common.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
+ var request api.QueryStateAndAuthChainRequest
+ var response api.QueryStateAndAuthChainResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ servMux.Handle(
+ api.RoomserverQueryBackfillPath,
+ common.MakeInternalAPI("QueryBackfill", func(req *http.Request) util.JSONResponse {
+ var request api.QueryBackfillRequest
+ var response api.QueryBackfillResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.ErrorResponse(err)
+ }
+ if err := r.QueryBackfill(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+}
diff --git a/roomserver/query/query_test.go b/roomserver/query/query_test.go
new file mode 100644
index 00000000..76c2e158
--- /dev/null
+++ b/roomserver/query/query_test.go
@@ -0,0 +1,155 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package query
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/matrix-org/dendrite/common/test"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// used to implement RoomserverQueryAPIEventDB to test getAuthChain
+type getEventDB struct {
+ eventMap map[string]gomatrixserverlib.Event
+}
+
+func createEventDB() *getEventDB {
+ return &getEventDB{
+ eventMap: make(map[string]gomatrixserverlib.Event),
+ }
+}
+
+// Adds a fake event to the storage with given auth events.
+func (db *getEventDB) addFakeEvent(eventID string, authIDs []string) error {
+ authEvents := []gomatrixserverlib.EventReference{}
+ for _, authID := range authIDs {
+ authEvents = append(authEvents, gomatrixserverlib.EventReference{
+ EventID: authID,
+ })
+ }
+
+ builder := map[string]interface{}{
+ "event_id": eventID,
+ "auth_events": authEvents,
+ }
+
+ eventJSON, err := json.Marshal(&builder)
+ if err != nil {
+ return err
+ }
+
+ event, err := gomatrixserverlib.NewEventFromTrustedJSON(eventJSON, false)
+ if err != nil {
+ return err
+ }
+
+ db.eventMap[eventID] = event
+
+ return nil
+}
+
+// Adds multiple events at once, each entry in the map is an eventID and set of
+// auth events that are converted to an event and added.
+func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
+ for eventID, authIDs := range graph {
+ err := db.addFakeEvent(eventID, authIDs)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// EventsFromIDs implements RoomserverQueryAPIEventDB
+func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) {
+ for _, evID := range eventIDs {
+ res = append(res, types.Event{
+ EventNID: 0,
+ Event: db.eventMap[evID],
+ })
+ }
+
+ return
+}
+
+func TestGetAuthChainSingle(t *testing.T) {
+ db := createEventDB()
+
+ err := db.addFakeEvents(map[string][]string{
+ "a": {},
+ "b": {"a"},
+ "c": {"a", "b"},
+ "d": {"b", "c"},
+ "e": {"a", "d"},
+ })
+
+ if err != nil {
+ t.Fatalf("Failed to add events to db: %v", err)
+ }
+
+ result, err := getAuthChain(context.TODO(), db, []string{"e"})
+ if err != nil {
+ t.Fatalf("getAuthChain failed: %v", err)
+ }
+
+ var returnedIDs []string
+ for _, event := range result {
+ returnedIDs = append(returnedIDs, event.EventID())
+ }
+
+ expectedIDs := []string{"a", "b", "c", "d", "e"}
+
+ if !test.UnsortedStringSliceEqual(expectedIDs, returnedIDs) {
+ t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs)
+ }
+}
+
+func TestGetAuthChainMultiple(t *testing.T) {
+ db := createEventDB()
+
+ err := db.addFakeEvents(map[string][]string{
+ "a": {},
+ "b": {"a"},
+ "c": {"a", "b"},
+ "d": {"b", "c"},
+ "e": {"a", "d"},
+ "f": {"a", "b", "c"},
+ })
+
+ if err != nil {
+ t.Fatalf("Failed to add events to db: %v", err)
+ }
+
+ result, err := getAuthChain(context.TODO(), db, []string{"e", "f"})
+ if err != nil {
+ t.Fatalf("getAuthChain failed: %v", err)
+ }
+
+ var returnedIDs []string
+ for _, event := range result {
+ returnedIDs = append(returnedIDs, event.EventID())
+ }
+
+ expectedIDs := []string{"a", "b", "c", "d", "e", "f"}
+
+ if !test.UnsortedStringSliceEqual(expectedIDs, returnedIDs) {
+ t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs)
+ }
+}
diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go
new file mode 100644
index 00000000..2ffbf67d
--- /dev/null
+++ b/roomserver/roomserver.go
@@ -0,0 +1,68 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package roomserver
+
+import (
+ "net/http"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+
+ asQuery "github.com/matrix-org/dendrite/appservice/query"
+ "github.com/matrix-org/dendrite/common/basecomponent"
+ "github.com/matrix-org/dendrite/roomserver/alias"
+ "github.com/matrix-org/dendrite/roomserver/input"
+ "github.com/matrix-org/dendrite/roomserver/query"
+ "github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/sirupsen/logrus"
+)
+
+// SetupRoomServerComponent sets up and registers HTTP handlers for the
+// RoomServer component. Returns instances of the various roomserver APIs,
+// allowing other components running in the same process to hit the query the
+// APIs directly instead of having to use HTTP.
+func SetupRoomServerComponent(
+ base *basecomponent.BaseDendrite,
+) (api.RoomserverAliasAPI, api.RoomserverInputAPI, api.RoomserverQueryAPI) {
+ roomserverDB, err := storage.Open(string(base.Cfg.Database.RoomServer))
+ if err != nil {
+ logrus.WithError(err).Panicf("failed to connect to room server db")
+ }
+
+ inputAPI := input.RoomserverInputAPI{
+ DB: roomserverDB,
+ Producer: base.KafkaProducer,
+ OutputRoomEventTopic: string(base.Cfg.Kafka.Topics.OutputRoomEvent),
+ }
+
+ inputAPI.SetupHTTP(http.DefaultServeMux)
+
+ queryAPI := query.RoomserverQueryAPI{DB: roomserverDB}
+
+ queryAPI.SetupHTTP(http.DefaultServeMux)
+
+ asAPI := asQuery.AppServiceQueryAPI{Cfg: base.Cfg}
+
+ aliasAPI := alias.RoomserverAliasAPI{
+ DB: roomserverDB,
+ Cfg: base.Cfg,
+ InputAPI: &inputAPI,
+ QueryAPI: &queryAPI,
+ AppserviceAPI: &asAPI,
+ }
+
+ aliasAPI.SetupHTTP(http.DefaultServeMux)
+
+ return &aliasAPI, &inputAPI, &queryAPI
+}
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
new file mode 100644
index 00000000..2a0b7f57
--- /dev/null
+++ b/roomserver/state/state.go
@@ -0,0 +1,966 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package state provides functions for reading state from the database.
+// The functions for writing state to the database are the input package.
+package state
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "time"
+
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/prometheus/client_golang/prometheus"
+)
+
+// A RoomStateDatabase has the storage APIs needed to load state from the database
+type RoomStateDatabase interface {
+ // Store the room state at an event in the database
+ AddState(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ stateBlockNIDs []types.StateBlockNID,
+ state []types.StateEntry,
+ ) (types.StateSnapshotNID, error)
+ // Look up the state of a room at each event for a list of string event IDs.
+ // Returns an error if there is an error talking to the database
+ // Returns a types.MissingEventError if the room state for the event IDs aren't in the database
+ StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
+ // Look up the numeric IDs for a list of string event types.
+ // Returns a map from string event type to numeric ID for the event type.
+ EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
+ // Look up the numeric IDs for a list of string event state keys.
+ // Returns a map from string state key to numeric ID for the state key.
+ EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
+ // Look up the numeric state data IDs for each numeric state snapshot ID
+ // The returned slice is sorted by numeric state snapshot ID.
+ StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
+ // Look up the state data for each numeric state data ID
+ // The returned slice is sorted by numeric state data ID.
+ StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
+ // Look up the state data for the state key tuples for each numeric state block ID
+ // This is used to fetch a subset of the room state at a snapshot.
+ // If a block doesn't contain any of the requested tuples then it can be discarded from the result.
+ // The returned slice is sorted by numeric state block ID.
+ StateEntriesForTuples(
+ ctx context.Context,
+ stateBlockNIDs []types.StateBlockNID,
+ stateKeyTuples []types.StateKeyTuple,
+ ) ([]types.StateEntryList, error)
+ // Look up the Events for a list of numeric event IDs.
+ // Returns a sorted list of events.
+ Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
+ // Look up snapshot NID for an event ID string
+ SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
+}
+
+// LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
+// This is typically the state before an event or the current state of a room.
+// Returns a sorted list of state entries or an error if there was a problem talking to the database.
+func LoadStateAtSnapshot(
+ ctx context.Context, db RoomStateDatabase, stateNID types.StateSnapshotNID,
+) ([]types.StateEntry, error) {
+ stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
+ if err != nil {
+ return nil, err
+ }
+ // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
+ stateBlockNIDList := stateBlockNIDLists[0]
+
+ stateEntryLists, err := db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
+ if err != nil {
+ return nil, err
+ }
+ stateEntriesMap := stateEntryListMap(stateEntryLists)
+
+ // Combine all the state entries for this snapshot.
+ // The order of state block NIDs in the list tells us the order to combine them in.
+ var fullState []types.StateEntry
+ for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
+ entries, ok := stateEntriesMap.lookup(stateBlockNID)
+ if !ok {
+ // This should only get hit if the database is corrupt.
+ // It should be impossible for an event to reference a NID that doesn't exist
+ panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
+ }
+ fullState = append(fullState, entries...)
+ }
+
+ // Stable sort so that the most recent entry for each state key stays
+ // remains later in the list than the older entries for the same state key.
+ sort.Stable(stateEntryByStateKeySorter(fullState))
+ // Unique returns the last entry and hence the most recent entry for each state key.
+ fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
+ return fullState, nil
+}
+
+// LoadStateAtEvent loads the full state of a room at a particular event.
+func LoadStateAtEvent(
+ ctx context.Context, db RoomStateDatabase, eventID string,
+) ([]types.StateEntry, error) {
+ snapshotNID, err := db.SnapshotNIDFromEventID(ctx, eventID)
+ if err != nil {
+ return nil, err
+ }
+
+ stateEntries, err := LoadStateAtSnapshot(ctx, db, snapshotNID)
+ if err != nil {
+ return nil, err
+ }
+
+ return stateEntries, nil
+}
+
+// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
+// and combines those snapshots together into a single list.
+func LoadCombinedStateAfterEvents(
+ ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
+) ([]types.StateEntry, error) {
+ stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
+ for i, state := range prevStates {
+ stateNIDs[i] = state.BeforeStateSnapshotNID
+ }
+ // Fetch the state snapshots for the state before the each prev event from the database.
+ // Deduplicate the IDs before passing them to the database.
+ // There could be duplicates because the events could be state events where
+ // the snapshot of the room state before them was the same.
+ stateBlockNIDLists, err := db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs))
+ if err != nil {
+ return nil, err
+ }
+
+ var stateBlockNIDs []types.StateBlockNID
+ for _, list := range stateBlockNIDLists {
+ stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...)
+ }
+ // Fetch the state entries that will be combined to create the snapshots.
+ // Deduplicate the IDs before passing them to the database.
+ // There could be duplicates because a block of state entries could be reused by
+ // multiple snapshots.
+ stateEntryLists, err := db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs))
+ if err != nil {
+ return nil, err
+ }
+ stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists)
+ stateEntriesMap := stateEntryListMap(stateEntryLists)
+
+ // Combine the entries from all the snapshots of state after each prev event into a single list.
+ var combined []types.StateEntry
+ for _, prevState := range prevStates {
+ // Grab the list of state data NIDs for this snapshot.
+ stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID)
+ if !ok {
+ // This should only get hit if the database is corrupt.
+ // It should be impossible for an event to reference a NID that doesn't exist
+ panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID))
+ }
+
+ // Combine all the state entries for this snapshot.
+ // The order of state block NIDs in the list tells us the order to combine them in.
+ var fullState []types.StateEntry
+ for _, stateBlockNID := range stateBlockNIDs {
+ entries, ok := stateEntriesMap.lookup(stateBlockNID)
+ if !ok {
+ // This should only get hit if the database is corrupt.
+ // It should be impossible for an event to reference a NID that doesn't exist
+ panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
+ }
+ fullState = append(fullState, entries...)
+ }
+ if prevState.IsStateEvent() {
+ // If the prev event was a state event then add an entry for the event itself
+ // so that we get the state after the event rather than the state before.
+ fullState = append(fullState, prevState.StateEntry)
+ }
+
+ // Stable sort so that the most recent entry for each state key stays
+ // remains later in the list than the older entries for the same state key.
+ sort.Stable(stateEntryByStateKeySorter(fullState))
+ // Unique returns the last entry and hence the most recent entry for each state key.
+ fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
+ // Add the full state for this StateSnapshotNID.
+ combined = append(combined, fullState...)
+ }
+ return combined, nil
+}
+
+// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
+func DifferenceBetweeenStateSnapshots(
+ ctx context.Context, db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID,
+) (removed, added []types.StateEntry, err error) {
+ if oldStateNID == newStateNID {
+ // If the snapshot NIDs are the same then nothing has changed
+ return nil, nil, nil
+ }
+
+ var oldEntries []types.StateEntry
+ var newEntries []types.StateEntry
+ if oldStateNID != 0 {
+ oldEntries, err = LoadStateAtSnapshot(ctx, db, oldStateNID)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if newStateNID != 0 {
+ newEntries, err = LoadStateAtSnapshot(ctx, db, newStateNID)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ var oldI int
+ var newI int
+ for {
+ switch {
+ case oldI == len(oldEntries):
+ // We've reached the end of the old entries.
+ // The rest of the new list must have been newly added.
+ added = append(added, newEntries[newI:]...)
+ return
+ case newI == len(newEntries):
+ // We've reached the end of the new entries.
+ // The rest of the old list must be have been removed.
+ removed = append(removed, oldEntries[oldI:]...)
+ return
+ case oldEntries[oldI] == newEntries[newI]:
+ // The entry is in both lists so skip over it.
+ oldI++
+ newI++
+ case oldEntries[oldI].LessThan(newEntries[newI]):
+ // The lists are sorted so the old entry being less than the new entry means that it only appears in the old list.
+ removed = append(removed, oldEntries[oldI])
+ oldI++
+ default:
+ // Reaching the default case implies that the new entry is less than the old entry.
+ // Since the lists are sorted this means that it only appears in the new list.
+ added = append(added, newEntries[newI])
+ newI++
+ }
+ }
+}
+
+// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot.
+// This is used when we only want to load a subset of the room state at a snapshot.
+// If there is no entry for a given event type and state key pair then it will be discarded.
+// This is typically the state before an event or the current state of a room.
+// Returns a sorted list of state entries or an error if there was a problem talking to the database.
+func LoadStateAtSnapshotForStringTuples(
+ ctx context.Context,
+ db RoomStateDatabase,
+ stateNID types.StateSnapshotNID,
+ stateKeyTuples []gomatrixserverlib.StateKeyTuple,
+) ([]types.StateEntry, error) {
+ numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
+ if err != nil {
+ return nil, err
+ }
+ return loadStateAtSnapshotForNumericTuples(ctx, db, stateNID, numericTuples)
+}
+
+// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs
+// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded.
+// Returns an error if there was a problem talking to the database.
+func stringTuplesToNumericTuples(
+ ctx context.Context,
+ db RoomStateDatabase,
+ stringTuples []gomatrixserverlib.StateKeyTuple,
+) ([]types.StateKeyTuple, error) {
+ eventTypes := make([]string, len(stringTuples))
+ stateKeys := make([]string, len(stringTuples))
+ for i := range stringTuples {
+ eventTypes[i] = stringTuples[i].EventType
+ stateKeys[i] = stringTuples[i].StateKey
+ }
+ eventTypes = util.UniqueStrings(eventTypes)
+ eventTypeMap, err := db.EventTypeNIDs(ctx, eventTypes)
+ if err != nil {
+ return nil, err
+ }
+ stateKeys = util.UniqueStrings(stateKeys)
+ stateKeyMap, err := db.EventStateKeyNIDs(ctx, stateKeys)
+ if err != nil {
+ return nil, err
+ }
+
+ var result []types.StateKeyTuple
+ for _, stringTuple := range stringTuples {
+ var numericTuple types.StateKeyTuple
+ var ok1, ok2 bool
+ numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType]
+ numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.StateKey]
+ // Discard the tuple if there wasn't a numeric ID for either the event type or the state key.
+ if ok1 && ok2 {
+ result = append(result, numericTuple)
+ }
+ }
+
+ return result, nil
+}
+
+// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot.
+// This is used when we only want to load a subset of the room state at a snapshot.
+// If there is no entry for a given event type and state key pair then it will be discarded.
+// This is typically the state before an event or the current state of a room.
+// Returns a sorted list of state entries or an error if there was a problem talking to the database.
+func loadStateAtSnapshotForNumericTuples(
+ ctx context.Context,
+ db RoomStateDatabase,
+ stateNID types.StateSnapshotNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntry, error) {
+ stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
+ if err != nil {
+ return nil, err
+ }
+ // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
+ stateBlockNIDList := stateBlockNIDLists[0]
+
+ stateEntryLists, err := db.StateEntriesForTuples(
+ ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples,
+ )
+ if err != nil {
+ return nil, err
+ }
+ stateEntriesMap := stateEntryListMap(stateEntryLists)
+
+ // Combine all the state entries for this snapshot.
+ // The order of state block NIDs in the list tells us the order to combine them in.
+ var fullState []types.StateEntry
+ for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
+ entries, ok := stateEntriesMap.lookup(stateBlockNID)
+ if !ok {
+ // If the block is missing from the map it means that none of its entries matched a requested tuple.
+ // This can happen if the block doesn't contain an update for one of the requested tuples.
+ // If none of the requested tuples are in the block then it can be safely skipped.
+ continue
+ }
+ fullState = append(fullState, entries...)
+ }
+
+ // Stable sort so that the most recent entry for each state key stays
+ // remains later in the list than the older entries for the same state key.
+ sort.Stable(stateEntryByStateKeySorter(fullState))
+ // Unique returns the last entry and hence the most recent entry for each state key.
+ fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
+ return fullState, nil
+}
+
+// LoadStateAfterEventsForStringTuples loads the state for a list of event type
+// and state key pairs after list of events.
+// This is used when we only want to load a subset of the room state after a list of events.
+// If there is no entry for a given event type and state key pair then it will be discarded.
+// This is typically the state before an event.
+// Returns a sorted list of state entries or an error if there was a problem talking to the database.
+func LoadStateAfterEventsForStringTuples(
+ ctx context.Context,
+ db RoomStateDatabase,
+ prevStates []types.StateAtEvent,
+ stateKeyTuples []gomatrixserverlib.StateKeyTuple,
+) ([]types.StateEntry, error) {
+ numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
+ if err != nil {
+ return nil, err
+ }
+ return loadStateAfterEventsForNumericTuples(ctx, db, prevStates, numericTuples)
+}
+
+func loadStateAfterEventsForNumericTuples(
+ ctx context.Context,
+ db RoomStateDatabase,
+ prevStates []types.StateAtEvent,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntry, error) {
+ if len(prevStates) == 1 {
+ // Fast path for a single event.
+ prevState := prevStates[0]
+ result, err := loadStateAtSnapshotForNumericTuples(
+ ctx, db, prevState.BeforeStateSnapshotNID, stateKeyTuples,
+ )
+ if err != nil {
+ return nil, err
+ }
+ if prevState.IsStateEvent() {
+ // The result is current the state before the requested event.
+ // We want the state after the requested event.
+ // If the requested event was a state event then we need to
+ // update that key in the result.
+ // If the requested event wasn't a state event then the state after
+ // it is the same as the state before it.
+ for i := range result {
+ if result[i].StateKeyTuple == prevState.StateKeyTuple {
+ result[i] = prevState.StateEntry
+ }
+ }
+ }
+ return result, nil
+ }
+
+ // Slow path for more that one event.
+ // Load the entire state so that we can do conflict resolution if we need to.
+ // TODO: The are some optimistations we could do here:
+ // 1) We only need to do conflict resolution if there is a conflict in the
+ // requested tuples so we might try loading just those tuples and then
+ // checking for conflicts.
+ // 2) When there is a conflict we still only need to load the state
+ // needed to do conflict resolution which would save us having to load
+ // the full state.
+
+ // TODO: Add metrics for this as it could take a long time for big rooms
+ // with large conflicts.
+ fullState, _, _, err := calculateStateAfterManyEvents(ctx, db, prevStates)
+ if err != nil {
+ return nil, err
+ }
+
+ // Sort the full state so we can use it as a map.
+ sort.Sort(stateEntrySorter(fullState))
+
+ // Filter the full state down to the required tuples.
+ var result []types.StateEntry
+ for _, tuple := range stateKeyTuples {
+ eventNID, ok := stateEntryMap(fullState).lookup(tuple)
+ if ok {
+ result = append(result, types.StateEntry{
+ StateKeyTuple: tuple,
+ EventNID: eventNID,
+ })
+ }
+ }
+ sort.Sort(stateEntrySorter(result))
+ return result, nil
+}
+
+var calculateStateDurations = prometheus.NewSummaryVec(
+ prometheus.SummaryOpts{
+ Namespace: "dendrite",
+ Subsystem: "roomserver",
+ Name: "calculate_state_duration_microseconds",
+ Help: "How long it takes to calculate the state after a list of events",
+ },
+ // Takes two labels:
+ // algorithm:
+ // The algorithm used to calculate the state or the step it failed on if it failed.
+ // Labels starting with "_" are used to indicate when the algorithm fails halfway.
+ // outcome:
+ // Whether the state was successfully calculated.
+ //
+ // The possible values for algorithm are:
+ // empty_state -> The list of events was empty so the state is empty.
+ // no_change -> The state hasn't changed.
+ // single_delta -> There was a single event added to the state in a way that can be encoded as a single delta
+ // full_state_no_conflicts -> We created a new copy of the full room state, but didn't enounter any conflicts
+ // while doing so.
+ // full_state_with_conflicts -> We created a new copy of the full room state and had to resolve conflicts to do so.
+ // _load_state_block_nids -> Failed loading the state block nids for a single previous state.
+ // _load_combined_state -> Failed to load the combined state.
+ // _resolve_conflicts -> Failed to resolve conflicts.
+ []string{"algorithm", "outcome"},
+)
+
+var calculateStatePrevEventLength = prometheus.NewSummaryVec(
+ prometheus.SummaryOpts{
+ Namespace: "dendrite",
+ Subsystem: "roomserver",
+ Name: "calculate_state_prev_event_length",
+ Help: "The length of the list of events to calculate the state after",
+ },
+ []string{"algorithm", "outcome"},
+)
+
+var calculateStateFullStateLength = prometheus.NewSummaryVec(
+ prometheus.SummaryOpts{
+ Namespace: "dendrite",
+ Subsystem: "roomserver",
+ Name: "calculate_state_full_state_length",
+ Help: "The length of the full room state.",
+ },
+ []string{"algorithm", "outcome"},
+)
+
+var calculateStateConflictLength = prometheus.NewSummaryVec(
+ prometheus.SummaryOpts{
+ Namespace: "dendrite",
+ Subsystem: "roomserver",
+ Name: "calculate_state_conflict_state_length",
+ Help: "The length of the conflicted room state.",
+ },
+ []string{"algorithm", "outcome"},
+)
+
+type calculateStateMetrics struct {
+ algorithm string
+ startTime time.Time
+ prevEventLength int
+ fullStateLength int
+ conflictLength int
+}
+
+func (c *calculateStateMetrics) stop(stateNID types.StateSnapshotNID, err error) (types.StateSnapshotNID, error) {
+ var outcome string
+ if err == nil {
+ outcome = "success"
+ } else {
+ outcome = "failure"
+ }
+ endTime := time.Now()
+ calculateStateDurations.WithLabelValues(c.algorithm, outcome).Observe(
+ float64(endTime.Sub(c.startTime).Nanoseconds()) / 1000.,
+ )
+ calculateStatePrevEventLength.WithLabelValues(c.algorithm, outcome).Observe(
+ float64(c.prevEventLength),
+ )
+ calculateStateFullStateLength.WithLabelValues(c.algorithm, outcome).Observe(
+ float64(c.fullStateLength),
+ )
+ calculateStateConflictLength.WithLabelValues(c.algorithm, outcome).Observe(
+ float64(c.conflictLength),
+ )
+ return stateNID, err
+}
+
+func init() {
+ prometheus.MustRegister(
+ calculateStateDurations, calculateStatePrevEventLength,
+ calculateStateFullStateLength, calculateStateConflictLength,
+ )
+}
+
+// CalculateAndStoreStateBeforeEvent calculates a snapshot of the state of a room before an event.
+// Stores the snapshot of the state in the database.
+// Returns a numeric ID for the snapshot of the state before the event.
+func CalculateAndStoreStateBeforeEvent(
+ ctx context.Context,
+ db RoomStateDatabase,
+ event gomatrixserverlib.Event,
+ roomNID types.RoomNID,
+) (types.StateSnapshotNID, error) {
+ // Load the state at the prev events.
+ prevEventRefs := event.PrevEvents()
+ prevEventIDs := make([]string, len(prevEventRefs))
+ for i := range prevEventRefs {
+ prevEventIDs[i] = prevEventRefs[i].EventID
+ }
+
+ prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs)
+ if err != nil {
+ return 0, err
+ }
+
+ // The state before this event will be the state after the events that came before it.
+ return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates)
+}
+
+// CalculateAndStoreStateAfterEvents finds the room state after the given events.
+// Stores the resulting state in the database and returns a numeric ID for that snapshot.
+func CalculateAndStoreStateAfterEvents(
+ ctx context.Context,
+ db RoomStateDatabase,
+ roomNID types.RoomNID,
+ prevStates []types.StateAtEvent,
+) (types.StateSnapshotNID, error) {
+ metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
+
+ if len(prevStates) == 0 {
+ // 2) There weren't any prev_events for this event so the state is
+ // empty.
+ metrics.algorithm = "empty_state"
+ return metrics.stop(db.AddState(ctx, roomNID, nil, nil))
+ }
+
+ if len(prevStates) == 1 {
+ prevState := prevStates[0]
+ if prevState.EventStateKeyNID == 0 {
+ // 3) None of the previous events were state events and they all
+ // have the same state, so this event has exactly the same state
+ // as the previous events.
+ // This should be the common case.
+ metrics.algorithm = "no_change"
+ return metrics.stop(prevState.BeforeStateSnapshotNID, nil)
+ }
+ // The previous event was a state event so we need to store a copy
+ // of the previous state updated with that event.
+ stateBlockNIDLists, err := db.StateBlockNIDs(
+ ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID},
+ )
+ if err != nil {
+ metrics.algorithm = "_load_state_blocks"
+ return metrics.stop(0, err)
+ }
+ stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs
+ if len(stateBlockNIDs) < maxStateBlockNIDs {
+ // 4) The number of state data blocks is small enough that we can just
+ // add the state event as a block of size one to the end of the blocks.
+ metrics.algorithm = "single_delta"
+ return metrics.stop(db.AddState(
+ ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
+ ))
+ }
+ // If there are too many deltas then we need to calculate the full state
+ // So fall through to calculateAndStoreStateAfterManyEvents
+ }
+
+ return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics)
+}
+
+// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
+// Increasing this number means that we can encode more of the state changes as simple deltas which means that
+// we need fewer entries in the state data table. However making this number bigger will increase the size of
+// the rows in the state table itself and will require more index lookups when retrieving a snapshot.
+// TODO: Tune this to get the right balance between size and lookup performance.
+const maxStateBlockNIDs = 64
+
+// calculateAndStoreStateAfterManyEvents finds the room state after the given events.
+// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event.
+// Stores the resulting state and returns a numeric ID for the snapshot.
+func calculateAndStoreStateAfterManyEvents(
+ ctx context.Context,
+ db RoomStateDatabase,
+ roomNID types.RoomNID,
+ prevStates []types.StateAtEvent,
+ metrics calculateStateMetrics,
+) (types.StateSnapshotNID, error) {
+
+ state, algorithm, conflictLength, err :=
+ calculateStateAfterManyEvents(ctx, db, prevStates)
+ metrics.algorithm = algorithm
+ if err != nil {
+ return metrics.stop(0, err)
+ }
+
+ // TODO: Check if we can encode the new state as a delta against the
+ // previous state.
+ metrics.conflictLength = conflictLength
+ metrics.fullStateLength = len(state)
+ return metrics.stop(db.AddState(ctx, roomNID, nil, state))
+}
+
+func calculateStateAfterManyEvents(
+ ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
+) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
+ var combined []types.StateEntry
+ // Conflict resolution.
+ // First stage: load the state after each of the prev events.
+ combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates)
+ if err != nil {
+ algorithm = "_load_combined_state"
+ return
+ }
+
+ // Collect all the entries with the same type and key together.
+ // We don't care about the order here because the conflict resolution
+ // algorithm doesn't depend on the order of the prev events.
+ // Remove duplicate entires.
+ combined = combined[:util.SortAndUnique(stateEntrySorter(combined))]
+
+ // Find the conflicts
+ conflicts := findDuplicateStateKeys(combined)
+
+ if len(conflicts) > 0 {
+ conflictLength = len(conflicts)
+
+ // 5) There are conflicting state events, for each conflict workout
+ // what the appropriate state event is.
+
+ // Work out which entries aren't conflicted.
+ var notConflicted []types.StateEntry
+ for _, entry := range combined {
+ if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok {
+ notConflicted = append(notConflicted, entry)
+ }
+ }
+
+ var resolved []types.StateEntry
+ resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts)
+ if err != nil {
+ algorithm = "_resolve_conflicts"
+ return
+ }
+ algorithm = "full_state_with_conflicts"
+ state = resolved
+ } else {
+ algorithm = "full_state_no_conflicts"
+ // 6) There weren't any conflicts
+ state = combined
+ }
+ return
+}
+
+// resolveConflicts resolves a list of conflicted state entries. It takes two lists.
+// The first is a list of all state entries that are not conflicted.
+// The second is a list of all state entries that are conflicted
+// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple.
+// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
+// The returned list is sorted by state key tuple.
+// Returns an error if there was a problem talking to the database.
+func resolveConflicts(
+ ctx context.Context,
+ db RoomStateDatabase,
+ notConflicted, conflicted []types.StateEntry,
+) ([]types.StateEntry, error) {
+
+ // Load the conflicted events
+ conflictedEvents, eventIDMap, err := loadStateEvents(ctx, db, conflicted)
+ if err != nil {
+ return nil, err
+ }
+
+ // Work out which auth events we need to load.
+ needed := gomatrixserverlib.StateNeededForAuth(conflictedEvents)
+
+ // Find the numeric IDs for the necessary state keys.
+ var neededStateKeys []string
+ neededStateKeys = append(neededStateKeys, needed.Member...)
+ neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
+ stateKeyNIDMap, err := db.EventStateKeyNIDs(ctx, neededStateKeys)
+ if err != nil {
+ return nil, err
+ }
+
+ // Load the necessary auth events.
+ tuplesNeeded := stateKeyTuplesNeeded(stateKeyNIDMap, needed)
+ var authEntries []types.StateEntry
+ for _, tuple := range tuplesNeeded {
+ if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
+ authEntries = append(authEntries, types.StateEntry{
+ StateKeyTuple: tuple,
+ EventNID: eventNID,
+ })
+ }
+ }
+ authEvents, _, err := loadStateEvents(ctx, db, authEntries)
+ if err != nil {
+ return nil, err
+ }
+
+ // Resolve the conflicts.
+ resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents)
+
+ // Map from the full events back to numeric state entries.
+ for _, resolvedEvent := range resolvedEvents {
+ entry, ok := eventIDMap[resolvedEvent.EventID()]
+ if !ok {
+ panic(fmt.Errorf("Missing state entry for event ID %q", resolvedEvent.EventID()))
+ }
+ notConflicted = append(notConflicted, entry)
+ }
+
+ // Sort the result so it can be searched.
+ sort.Sort(stateEntrySorter(notConflicted))
+ return notConflicted, nil
+}
+
+// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
+func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
+ var keyTuples []types.StateKeyTuple
+ if stateNeeded.Create {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomCreateNID,
+ EventStateKeyNID: types.EmptyStateKeyNID,
+ })
+ }
+ if stateNeeded.PowerLevels {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomPowerLevelsNID,
+ EventStateKeyNID: types.EmptyStateKeyNID,
+ })
+ }
+ if stateNeeded.JoinRules {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomJoinRulesNID,
+ EventStateKeyNID: types.EmptyStateKeyNID,
+ })
+ }
+ for _, member := range stateNeeded.Member {
+ stateKeyNID, ok := stateKeyNIDMap[member]
+ if ok {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomMemberNID,
+ EventStateKeyNID: stateKeyNID,
+ })
+ }
+ }
+ for _, token := range stateNeeded.ThirdPartyInvite {
+ stateKeyNID, ok := stateKeyNIDMap[token]
+ if ok {
+ keyTuples = append(keyTuples, types.StateKeyTuple{
+ EventTypeNID: types.MRoomThirdPartyInviteNID,
+ EventStateKeyNID: stateKeyNID,
+ })
+ }
+ }
+ return keyTuples
+}
+
+// loadStateEvents loads the matrix events for a list of state entries.
+// Returns a list of state events in no particular order and a map from string event ID back to state entry.
+// The map can be used to recover which numeric state entry a given event is for.
+// Returns an error if there was a problem talking to the database.
+func loadStateEvents(
+ ctx context.Context, db RoomStateDatabase, entries []types.StateEntry,
+) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) {
+ eventNIDs := make([]types.EventNID, len(entries))
+ for i := range entries {
+ eventNIDs[i] = entries[i].EventNID
+ }
+ events, err := db.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, nil, err
+ }
+ eventIDMap := map[string]types.StateEntry{}
+ result := make([]gomatrixserverlib.Event, len(entries))
+ for i := range entries {
+ event, ok := eventMap(events).lookup(entries[i].EventNID)
+ if !ok {
+ panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID))
+ }
+ result[i] = event.Event
+ eventIDMap[event.Event.EventID()] = entries[i]
+ }
+ return result, eventIDMap, nil
+}
+
+// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.
+// Returns a sorted list of those state entries.
+func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry {
+ var result []types.StateEntry
+ // j is the starting index of a block of entries with the same state key tuple.
+ j := 0
+ for i := 1; i < len(a); i++ {
+ // Check if the state key tuple matches the start of the block
+ if a[j].StateKeyTuple != a[i].StateKeyTuple {
+ // If the state key tuple is different then we've reached the end of a block of duplicates.
+ // Check if the size of the block is bigger than one.
+ // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result
+ if j+1 != i {
+ // Add the block to the result.
+ result = append(result, a[j:i]...)
+ }
+ // Start a new block for the next state key tuple.
+ j = i
+ }
+ }
+ // Check if the last block with the same state key tuple had more than one event in it.
+ if j+1 != len(a) {
+ result = append(result, a[j:]...)
+ }
+ return result
+}
+
+type stateEntrySorter []types.StateEntry
+
+func (s stateEntrySorter) Len() int { return len(s) }
+func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
+func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+type stateBlockNIDListMap []types.StateBlockNIDList
+
+func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
+ list := []types.StateBlockNIDList(m)
+ i := sort.Search(len(list), func(i int) bool {
+ return list[i].StateSnapshotNID >= stateNID
+ })
+ if i < len(list) && list[i].StateSnapshotNID == stateNID {
+ ok = true
+ stateBlockNIDs = list[i].StateBlockNIDs
+ }
+ return
+}
+
+type stateEntryListMap []types.StateEntryList
+
+func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
+ list := []types.StateEntryList(m)
+ i := sort.Search(len(list), func(i int) bool {
+ return list[i].StateBlockNID >= stateBlockNID
+ })
+ if i < len(list) && list[i].StateBlockNID == stateBlockNID {
+ ok = true
+ stateEntries = list[i].StateEntries
+ }
+ return
+}
+
+type stateEntryByStateKeySorter []types.StateEntry
+
+func (s stateEntryByStateKeySorter) Len() int { return len(s) }
+func (s stateEntryByStateKeySorter) Less(i, j int) bool {
+ return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
+}
+func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+type stateNIDSorter []types.StateSnapshotNID
+
+func (s stateNIDSorter) Len() int { return len(s) }
+func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
+func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID {
+ return nids[:util.SortAndUnique(stateNIDSorter(nids))]
+}
+
+type stateBlockNIDSorter []types.StateBlockNID
+
+func (s stateBlockNIDSorter) Len() int { return len(s) }
+func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
+func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID {
+ return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))]
+}
+
+// Map from event type, state key tuple to numeric event ID.
+// Implemented using binary search on a sorted array.
+type stateEntryMap []types.StateEntry
+
+// lookup an entry in the event map.
+func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
+ // Since the list is sorted we can implement this using binary search.
+ // This is faster than using a hash map.
+ // We don't have to worry about pathological cases because the keys are fixed
+ // size and are controlled by us.
+ list := []types.StateEntry(m)
+ i := sort.Search(len(list), func(i int) bool {
+ return !list[i].StateKeyTuple.LessThan(stateKey)
+ })
+ if i < len(list) && list[i].StateKeyTuple == stateKey {
+ ok = true
+ eventNID = list[i].EventNID
+ }
+ return
+}
+
+// Map from numeric event ID to event.
+// Implemented using binary search on a sorted array.
+type eventMap []types.Event
+
+// lookup an entry in the event map.
+func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
+ // Since the list is sorted we can implement this using binary search.
+ // This is faster than using a hash map.
+ // We don't have to worry about pathological cases because the keys are fixed
+ // size are controlled by us.
+ list := []types.Event(m)
+ i := sort.Search(len(list), func(i int) bool {
+ return list[i].EventNID >= eventNID
+ })
+ if i < len(list) && list[i].EventNID == eventNID {
+ ok = true
+ event = &list[i]
+ }
+ return
+}
diff --git a/roomserver/state/state_test.go b/roomserver/state/state_test.go
new file mode 100644
index 00000000..67af1867
--- /dev/null
+++ b/roomserver/state/state_test.go
@@ -0,0 +1,56 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "testing"
+
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+func TestFindDuplicateStateKeys(t *testing.T) {
+ testCases := []struct {
+ Input []types.StateEntry
+ Want []types.StateEntry
+ }{{
+ Input: []types.StateEntry{
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 2},
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 2, EventStateKeyNID: 2}, EventNID: 3},
+ },
+ Want: []types.StateEntry{
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 2},
+ },
+ }, {
+ Input: []types.StateEntry{
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
+ {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 2}, EventNID: 2},
+ },
+ Want: nil,
+ }}
+
+ for _, test := range testCases {
+ got := findDuplicateStateKeys(test.Input)
+ if len(got) != len(test.Want) {
+ t.Fatalf("Wanted %v, got %v", test.Want, got)
+ }
+ for i := range got {
+ if got[i] != test.Want[i] {
+ t.Fatalf("Wanted %v, got %v", test.Want, got)
+ }
+ }
+ }
+}
diff --git a/roomserver/storage/event_json_table.go b/roomserver/storage/event_json_table.go
new file mode 100644
index 00000000..b81667d9
--- /dev/null
+++ b/roomserver/storage/event_json_table.go
@@ -0,0 +1,105 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const eventJSONSchema = `
+-- Stores the JSON for each event. This kept separate from the main events
+-- table to keep the rows in the main events table small.
+CREATE TABLE IF NOT EXISTS roomserver_event_json (
+ -- Local numeric ID for the event.
+ event_nid BIGINT NOT NULL PRIMARY KEY,
+ -- The JSON for the event.
+ -- Stored as TEXT because this should be valid UTF-8.
+ -- Not stored as a JSONB because we always just pull the entire event
+ -- so there is no point in postgres parsing it.
+ -- Not stored as JSON because we already validate the JSON in the server
+ -- so there is no point in postgres validating it.
+ -- TODO: Should we be compressing the events with Snappy or DEFLATE?
+ event_json TEXT NOT NULL
+);
+`
+
+const insertEventJSONSQL = "" +
+ "INSERT INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)" +
+ " ON CONFLICT DO NOTHING"
+
+// Bulk event JSON lookup by numeric event ID.
+// Sort by the numeric event ID.
+// This means that we can use binary search to lookup by numeric event ID.
+const bulkSelectEventJSONSQL = "" +
+ "SELECT event_nid, event_json FROM roomserver_event_json" +
+ " WHERE event_nid = ANY($1)" +
+ " ORDER BY event_nid ASC"
+
+type eventJSONStatements struct {
+ insertEventJSONStmt *sql.Stmt
+ bulkSelectEventJSONStmt *sql.Stmt
+}
+
+func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(eventJSONSchema)
+ if err != nil {
+ return
+ }
+ return statementList{
+ {&s.insertEventJSONStmt, insertEventJSONSQL},
+ {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
+ }.prepare(db)
+}
+
+func (s *eventJSONStatements) insertEventJSON(
+ ctx context.Context, eventNID types.EventNID, eventJSON []byte,
+) error {
+ _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON)
+ return err
+}
+
+type eventJSONPair struct {
+ EventNID types.EventNID
+ EventJSON []byte
+}
+
+func (s *eventJSONStatements) bulkSelectEventJSON(
+ ctx context.Context, eventNIDs []types.EventNID,
+) ([]eventJSONPair, error) {
+ rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ // We know that we will only get as many results as event NIDs
+ // because of the unique constraint on event NIDs.
+ // So we can allocate an array of the correct size now.
+ // We might get fewer results than NIDs so we adjust the length of the slice before returning it.
+ results := make([]eventJSONPair, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ var eventNID int64
+ if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
+ return nil, err
+ }
+ result.EventNID = types.EventNID(eventNID)
+ }
+ return results[:i], nil
+}
diff --git a/roomserver/storage/event_state_keys_table.go b/roomserver/storage/event_state_keys_table.go
new file mode 100644
index 00000000..1ef93370
--- /dev/null
+++ b/roomserver/storage/event_state_keys_table.go
@@ -0,0 +1,153 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const eventStateKeysSchema = `
+-- Numeric versions of the event "state_key"s. State keys tend to be reused so
+-- assigning each string a numeric ID should reduce the amount of data that
+-- needs to be stored and fetched from the database.
+-- It also means that many operations can work with int64 arrays rather than
+-- string arrays which may help reduce GC pressure.
+-- Well known state keys are pre-assigned numeric IDs:
+-- 1 -> "" (the empty string)
+-- Other state keys are automatically assigned numeric IDs starting from 2**16.
+-- This leaves room to add more pre-assigned numeric IDs and clearly separates
+-- the automatically assigned IDs from the pre-assigned IDs.
+CREATE SEQUENCE IF NOT EXISTS roomserver_event_state_key_nid_seq START 65536;
+CREATE TABLE IF NOT EXISTS roomserver_event_state_keys (
+ -- Local numeric ID for the state key.
+ event_state_key_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_event_state_key_nid_seq'),
+ event_state_key TEXT NOT NULL CONSTRAINT roomserver_event_state_key_unique UNIQUE
+);
+INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key) VALUES
+ (1, '') ON CONFLICT DO NOTHING;
+`
+
+// Same as insertEventTypeNIDSQL
+const insertEventStateKeyNIDSQL = "" +
+ "INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)" +
+ " ON CONFLICT ON CONSTRAINT roomserver_event_state_key_unique" +
+ " DO NOTHING RETURNING (event_state_key_nid)"
+
+const selectEventStateKeyNIDSQL = "" +
+ "SELECT event_state_key_nid FROM roomserver_event_state_keys" +
+ " WHERE event_state_key = $1"
+
+// Bulk lookup from string state key to numeric ID for that state key.
+// Takes an array of strings as the query parameter.
+const bulkSelectEventStateKeyNIDSQL = "" +
+ "SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" +
+ " WHERE event_state_key = ANY($1)"
+
+// Bulk lookup from numeric ID to string state key for that state key.
+// Takes an array of strings as the query parameter.
+const bulkSelectEventStateKeySQL = "" +
+ "SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" +
+ " WHERE event_state_key_nid = ANY($1)"
+
+type eventStateKeyStatements struct {
+ insertEventStateKeyNIDStmt *sql.Stmt
+ selectEventStateKeyNIDStmt *sql.Stmt
+ bulkSelectEventStateKeyNIDStmt *sql.Stmt
+ bulkSelectEventStateKeyStmt *sql.Stmt
+}
+
+func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(eventStateKeysSchema)
+ if err != nil {
+ return
+ }
+ return statementList{
+ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
+ {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
+ {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
+ {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
+ }.prepare(db)
+}
+
+func (s *eventStateKeyStatements) insertEventStateKeyNID(
+ ctx context.Context, txn *sql.Tx, eventStateKey string,
+) (types.EventStateKeyNID, error) {
+ var eventStateKeyNID int64
+ stmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt)
+ err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
+ return types.EventStateKeyNID(eventStateKeyNID), err
+}
+
+func (s *eventStateKeyStatements) selectEventStateKeyNID(
+ ctx context.Context, txn *sql.Tx, eventStateKey string,
+) (types.EventStateKeyNID, error) {
+ var eventStateKeyNID int64
+ stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
+ err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
+ return types.EventStateKeyNID(eventStateKeyNID), err
+}
+
+func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
+ ctx context.Context, eventStateKeys []string,
+) (map[string]types.EventStateKeyNID, error) {
+ rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
+ ctx, pq.StringArray(eventStateKeys),
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
+ for rows.Next() {
+ var stateKey string
+ var stateKeyNID int64
+ if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
+ return nil, err
+ }
+ result[stateKey] = types.EventStateKeyNID(stateKeyNID)
+ }
+ return result, nil
+}
+
+func (s *eventStateKeyStatements) bulkSelectEventStateKey(
+ ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
+) (map[types.EventStateKeyNID]string, error) {
+ nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
+ for i := range eventStateKeyNIDs {
+ nIDs[i] = int64(eventStateKeyNIDs[i])
+ }
+ rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
+ for rows.Next() {
+ var stateKey string
+ var stateKeyNID int64
+ if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
+ return nil, err
+ }
+ result[types.EventStateKeyNID(stateKeyNID)] = stateKey
+ }
+ return result, nil
+}
diff --git a/roomserver/storage/event_types_table.go b/roomserver/storage/event_types_table.go
new file mode 100644
index 00000000..7b8d53a5
--- /dev/null
+++ b/roomserver/storage/event_types_table.go
@@ -0,0 +1,146 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const eventTypesSchema = `
+-- Numeric versions of the event "type"s. Event types tend to be taken from a
+-- small common pool. Assigning each a numeric ID should reduce the amount of
+-- data that needs to be stored and fetched from the database.
+-- It also means that many operations can work with int64 arrays rather than
+-- string arrays which may help reduce GC pressure.
+-- Well known event types are pre-assigned numeric IDs:
+-- 1 -> m.room.create
+-- 2 -> m.room.power_levels
+-- 3 -> m.room.join_rules
+-- 4 -> m.room.third_party_invite
+-- 5 -> m.room.member
+-- 6 -> m.room.redaction
+-- 7 -> m.room.history_visibility
+-- Picking well-known numeric IDs for the events types that require special
+-- attention during state conflict resolution means that we write that code
+-- using numeric constants.
+-- It also means that the numeric IDs for common event types should be
+-- consistent between different instances which might make ad-hoc debugging
+-- easier.
+-- Other event types are automatically assigned numeric IDs starting from 2**16.
+-- This leaves room to add more pre-assigned numeric IDs and clearly separates
+-- the automatically assigned IDs from the pre-assigned IDs.
+CREATE SEQUENCE IF NOT EXISTS roomserver_event_type_nid_seq START 65536;
+CREATE TABLE IF NOT EXISTS roomserver_event_types (
+ -- Local numeric ID for the event type.
+ event_type_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_event_type_nid_seq'),
+ -- The string event_type.
+ event_type TEXT NOT NULL CONSTRAINT roomserver_event_type_unique UNIQUE
+);
+INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
+ (1, 'm.room.create'),
+ (2, 'm.room.power_levels'),
+ (3, 'm.room.join_rules'),
+ (4, 'm.room.third_party_invite'),
+ (5, 'm.room.member'),
+ (6, 'm.room.redaction'),
+ (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
+`
+
+// Assign a new numeric event type ID.
+// The usual case is that the event type is not in the database.
+// In that case the ID will be assigned using the next value from the sequence.
+// We use `RETURNING` to tell postgres to return the assigned ID.
+// But it's possible that the type was added in a query that raced with us.
+// This will result in a conflict on the event_type_unique constraint, in this
+// case we do nothing. Postgresql won't return a row in that case so we rely on
+// the caller catching the sql.ErrNoRows error and running a select to get the row.
+// We could get postgresql to return the row on a conflict by updating the row
+// but it doesn't seem like a good idea to modify the rows just to make postgresql
+// return it. Modifying the rows will cause postgres to assign a new tuple for the
+// row even though the data doesn't change resulting in unncesssary modifications
+// to the indexes.
+const insertEventTypeNIDSQL = "" +
+ "INSERT INTO roomserver_event_types (event_type) VALUES ($1)" +
+ " ON CONFLICT ON CONSTRAINT roomserver_event_type_unique" +
+ " DO NOTHING RETURNING (event_type_nid)"
+
+const selectEventTypeNIDSQL = "" +
+ "SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1"
+
+// Bulk lookup from string event type to numeric ID for that event type.
+// Takes an array of strings as the query parameter.
+const bulkSelectEventTypeNIDSQL = "" +
+ "SELECT event_type, event_type_nid FROM roomserver_event_types" +
+ " WHERE event_type = ANY($1)"
+
+type eventTypeStatements struct {
+ insertEventTypeNIDStmt *sql.Stmt
+ selectEventTypeNIDStmt *sql.Stmt
+ bulkSelectEventTypeNIDStmt *sql.Stmt
+}
+
+func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(eventTypesSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
+ {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
+ {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL},
+ }.prepare(db)
+}
+
+func (s *eventTypeStatements) insertEventTypeNID(
+ ctx context.Context, eventType string,
+) (types.EventTypeNID, error) {
+ var eventTypeNID int64
+ err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
+ return types.EventTypeNID(eventTypeNID), err
+}
+
+func (s *eventTypeStatements) selectEventTypeNID(
+ ctx context.Context, eventType string,
+) (types.EventTypeNID, error) {
+ var eventTypeNID int64
+ err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
+ return types.EventTypeNID(eventTypeNID), err
+}
+
+func (s *eventTypeStatements) bulkSelectEventTypeNID(
+ ctx context.Context, eventTypes []string,
+) (map[string]types.EventTypeNID, error) {
+ rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ result := make(map[string]types.EventTypeNID, len(eventTypes))
+ for rows.Next() {
+ var eventType string
+ var eventTypeNID int64
+ if err := rows.Scan(&eventType, &eventTypeNID); err != nil {
+ return nil, err
+ }
+ result[eventType] = types.EventTypeNID(eventTypeNID)
+ }
+ return result, nil
+}
diff --git a/roomserver/storage/events_table.go b/roomserver/storage/events_table.go
new file mode 100644
index 00000000..5bad939f
--- /dev/null
+++ b/roomserver/storage/events_table.go
@@ -0,0 +1,410 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const eventsSchema = `
+-- The events table holds metadata for each event, the actual JSON is stored
+-- separately to keep the size of the rows small.
+CREATE SEQUENCE IF NOT EXISTS roomserver_event_nid_seq;
+CREATE TABLE IF NOT EXISTS roomserver_events (
+ -- Local numeric ID for the event.
+ event_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_event_nid_seq'),
+ -- Local numeric ID for the room the event is in.
+ -- This is never 0.
+ room_nid BIGINT NOT NULL,
+ -- Local numeric ID for the type of the event.
+ -- This is never 0.
+ event_type_nid BIGINT NOT NULL,
+ -- Local numeric ID for the state_key of the event
+ -- This is 0 if the event is not a state event.
+ event_state_key_nid BIGINT NOT NULL,
+ -- Whether the event has been written to the output log.
+ sent_to_output BOOLEAN NOT NULL DEFAULT FALSE,
+ -- Local numeric ID for the state at the event.
+ -- This is 0 if we don't know the state at the event.
+ -- If the state is not 0 then this event is part of the contiguous
+ -- part of the event graph
+ -- Since many different events can have the same state we store the
+ -- state into a separate state table and refer to it by numeric ID.
+ state_snapshot_nid BIGINT NOT NULL DEFAULT 0,
+ -- Depth of the event in the event graph.
+ depth BIGINT NOT NULL,
+ -- The textual event id.
+ -- Used to lookup the numeric ID when processing requests.
+ -- Needed for state resolution.
+ -- An event may only appear in this table once.
+ event_id TEXT NOT NULL CONSTRAINT roomserver_event_id_unique UNIQUE,
+ -- The sha256 reference hash for the event.
+ -- Needed for setting reference hashes when sending new events.
+ reference_sha256 BYTEA NOT NULL,
+ -- A list of numeric IDs for events that can authenticate this event.
+ auth_event_nids BIGINT[] NOT NULL
+);
+`
+
+const insertEventSQL = "" +
+ "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth)" +
+ " VALUES ($1, $2, $3, $4, $5, $6, $7)" +
+ " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" +
+ " DO NOTHING" +
+ " RETURNING event_nid, state_snapshot_nid"
+
+const selectEventSQL = "" +
+ "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
+
+// Bulk lookup of events by string ID.
+// Sort by the numeric IDs for event type and state key.
+// This means we can use binary search to lookup entries by type and state key.
+const bulkSelectStateEventByIDSQL = "" +
+ "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
+ " WHERE event_id = ANY($1)" +
+ " ORDER BY event_type_nid, event_state_key_nid ASC"
+
+const bulkSelectStateAtEventByIDSQL = "" +
+ "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" +
+ " WHERE event_id = ANY($1)"
+
+const updateEventStateSQL = "" +
+ "UPDATE roomserver_events SET state_snapshot_nid = $2 WHERE event_nid = $1"
+
+const selectEventSentToOutputSQL = "" +
+ "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1"
+
+const updateEventSentToOutputSQL = "" +
+ "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1"
+
+const selectEventIDSQL = "" +
+ "SELECT event_id FROM roomserver_events WHERE event_nid = $1"
+
+const bulkSelectStateAtEventAndReferenceSQL = "" +
+ "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
+ " FROM roomserver_events WHERE event_nid = ANY($1)"
+
+const bulkSelectEventReferenceSQL = "" +
+ "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid = ANY($1)"
+
+const bulkSelectEventIDSQL = "" +
+ "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid = ANY($1)"
+
+const bulkSelectEventNIDSQL = "" +
+ "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)"
+
+const selectMaxEventDepthSQL = "" +
+ "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
+
+type eventStatements struct {
+ insertEventStmt *sql.Stmt
+ selectEventStmt *sql.Stmt
+ bulkSelectStateEventByIDStmt *sql.Stmt
+ bulkSelectStateAtEventByIDStmt *sql.Stmt
+ updateEventStateStmt *sql.Stmt
+ selectEventSentToOutputStmt *sql.Stmt
+ updateEventSentToOutputStmt *sql.Stmt
+ selectEventIDStmt *sql.Stmt
+ bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
+ bulkSelectEventReferenceStmt *sql.Stmt
+ bulkSelectEventIDStmt *sql.Stmt
+ bulkSelectEventNIDStmt *sql.Stmt
+ selectMaxEventDepthStmt *sql.Stmt
+}
+
+func (s *eventStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(eventsSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertEventStmt, insertEventSQL},
+ {&s.selectEventStmt, selectEventSQL},
+ {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
+ {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
+ {&s.updateEventStateStmt, updateEventStateSQL},
+ {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
+ {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL},
+ {&s.selectEventIDStmt, selectEventIDSQL},
+ {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
+ {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
+ {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
+ {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
+ {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
+ }.prepare(db)
+}
+
+func (s *eventStatements) insertEvent(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ eventTypeNID types.EventTypeNID,
+ eventStateKeyNID types.EventStateKeyNID,
+ eventID string,
+ referenceSHA256 []byte,
+ authEventNIDs []types.EventNID,
+ depth int64,
+) (types.EventNID, types.StateSnapshotNID, error) {
+ var eventNID int64
+ var stateNID int64
+ err := s.insertEventStmt.QueryRowContext(
+ ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
+ eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
+ ).Scan(&eventNID, &stateNID)
+ return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
+}
+
+func (s *eventStatements) selectEvent(
+ ctx context.Context, eventID string,
+) (types.EventNID, types.StateSnapshotNID, error) {
+ var eventNID int64
+ var stateNID int64
+ err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID)
+ return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
+}
+
+// bulkSelectStateEventByID lookups a list of state events by event ID.
+// If any of the requested events are missing from the database it returns a types.MissingEventError
+func (s *eventStatements) bulkSelectStateEventByID(
+ ctx context.Context, eventIDs []string,
+) ([]types.StateEntry, error) {
+ rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ // We know that we will only get as many results as event IDs
+ // because of the unique constraint on event IDs.
+ // So we can allocate an array of the correct size now.
+ // We might get fewer results than IDs so we adjust the length of the slice before returning it.
+ results := make([]types.StateEntry, len(eventIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ if err = rows.Scan(
+ &result.EventTypeNID,
+ &result.EventStateKeyNID,
+ &result.EventNID,
+ ); err != nil {
+ return nil, err
+ }
+ }
+ if i != len(eventIDs) {
+ // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
+ // We don't know which ones were missing because we don't return the string IDs in the query.
+ // However it should be possible debug this by replaying queries or entries from the input kafka logs.
+ // If this turns out to be impossible and we do need the debug information here, it would be better
+ // to do it as a separate query rather than slowing down/complicating the common case.
+ return nil, types.MissingEventError(
+ fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
+ )
+ }
+ return results, err
+}
+
+// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
+// If any of the requested events are missing from the database it returns a types.MissingEventError.
+// If we do not have the state for any of the requested events it returns a types.MissingEventError.
+func (s *eventStatements) bulkSelectStateAtEventByID(
+ ctx context.Context, eventIDs []string,
+) ([]types.StateAtEvent, error) {
+ rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make([]types.StateAtEvent, len(eventIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ if err = rows.Scan(
+ &result.EventTypeNID,
+ &result.EventStateKeyNID,
+ &result.EventNID,
+ &result.BeforeStateSnapshotNID,
+ ); err != nil {
+ return nil, err
+ }
+ if result.BeforeStateSnapshotNID == 0 {
+ return nil, types.MissingEventError(
+ fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
+ )
+ }
+ }
+ if i != len(eventIDs) {
+ return nil, types.MissingEventError(
+ fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
+ )
+ }
+ return results, err
+}
+
+func (s *eventStatements) updateEventState(
+ ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
+) error {
+ _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
+ return err
+}
+
+func (s *eventStatements) selectEventSentToOutput(
+ ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
+) (sentToOutput bool, err error) {
+ stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt)
+ err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
+ return
+}
+
+func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
+ stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt)
+ _, err := stmt.ExecContext(ctx, int64(eventNID))
+ return err
+}
+
+func (s *eventStatements) selectEventID(
+ ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
+) (eventID string, err error) {
+ stmt := common.TxStmt(txn, s.selectEventIDStmt)
+ err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID)
+ return
+}
+
+func (s *eventStatements) bulkSelectStateAtEventAndReference(
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
+) ([]types.StateAtEventAndReference, error) {
+ stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt)
+ rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make([]types.StateAtEventAndReference, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ var (
+ eventTypeNID int64
+ eventStateKeyNID int64
+ eventNID int64
+ stateSnapshotNID int64
+ eventID string
+ eventSHA256 []byte
+ )
+ if err = rows.Scan(
+ &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256,
+ ); err != nil {
+ return nil, err
+ }
+ result := &results[i]
+ result.EventTypeNID = types.EventTypeNID(eventTypeNID)
+ result.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
+ result.EventNID = types.EventNID(eventNID)
+ result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID)
+ result.EventID = eventID
+ result.EventSHA256 = eventSHA256
+ }
+ if i != len(eventNIDs) {
+ return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
+ }
+ return results, nil
+}
+
+func (s *eventStatements) bulkSelectEventReference(
+ ctx context.Context, eventNIDs []types.EventNID,
+) ([]gomatrixserverlib.EventReference, error) {
+ rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make([]gomatrixserverlib.EventReference, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil {
+ return nil, err
+ }
+ }
+ if i != len(eventNIDs) {
+ return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
+ }
+ return results, nil
+}
+
+// bulkSelectEventID returns a map from numeric event ID to string event ID.
+func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
+ rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make(map[types.EventNID]string, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ var eventNID int64
+ var eventID string
+ if err = rows.Scan(&eventNID, &eventID); err != nil {
+ return nil, err
+ }
+ results[types.EventNID(eventNID)] = eventID
+ }
+ if i != len(eventNIDs) {
+ return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
+ }
+ return results, nil
+}
+
+// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
+// If an event ID is not in the database then it is omitted from the map.
+func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
+ rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make(map[string]types.EventNID, len(eventIDs))
+ for rows.Next() {
+ var eventID string
+ var eventNID int64
+ if err = rows.Scan(&eventID, &eventNID); err != nil {
+ return nil, err
+ }
+ results[eventID] = types.EventNID(eventNID)
+ }
+ return results, nil
+}
+
+func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {
+ var result int64
+ stmt := s.selectMaxEventDepthStmt
+ err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result)
+ if err != nil {
+ return 0, err
+ }
+ return result, nil
+}
+
+func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
+ nids := make([]int64, len(eventNIDs))
+ for i := range eventNIDs {
+ nids[i] = int64(eventNIDs[i])
+ }
+ return nids
+}
diff --git a/roomserver/storage/invite_table.go b/roomserver/storage/invite_table.go
new file mode 100644
index 00000000..4f9cdfb4
--- /dev/null
+++ b/roomserver/storage/invite_table.go
@@ -0,0 +1,154 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const inviteSchema = `
+CREATE TABLE IF NOT EXISTS roomserver_invites (
+ -- The string ID of the invite event itself.
+ -- We can't use a numeric event ID here because we don't always have
+ -- enough information to store an invite in the event table.
+ -- In particular we don't always have a chain of auth_events for invites
+ -- received over federation.
+ invite_event_id TEXT PRIMARY KEY,
+ -- The numeric ID of the room the invite m.room.member event is in.
+ room_nid BIGINT NOT NULL,
+ -- The numeric ID for the state key of the invite m.room.member event.
+ -- This tells us who the invite is for.
+ -- This is used to query the active invites for a user.
+ target_nid BIGINT NOT NULL,
+ -- The numeric ID for the sender of the invite m.room.member event.
+ -- This tells us who sent the invite.
+ -- This is used to work out which matrix server we should talk to when
+ -- we try to join the room.
+ sender_nid BIGINT NOT NULL DEFAULT 0,
+ -- This is used to track whether the invite is still active.
+ -- This is set implicitly when processing new join and leave events and
+ -- explicitly when rejecting events over federation.
+ retired BOOLEAN NOT NULL DEFAULT FALSE,
+ -- The invite event JSON.
+ invite_event_json TEXT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid)
+ WHERE NOT retired;
+`
+const insertInviteEventSQL = "" +
+ "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
+ " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
+ " ON CONFLICT DO NOTHING"
+
+const selectInviteActiveForUserInRoomSQL = "" +
+ "SELECT sender_nid FROM roomserver_invites" +
+ " WHERE target_nid = $1 AND room_nid = $2" +
+ " AND NOT retired"
+
+// Retire every active invite for a user in a room.
+// Ideally we'd know which invite events were retired by a given update so we
+// wouldn't need to remove every active invite.
+// However the matrix protocol doesn't give us a way to reliably identify the
+// invites that were retired, so we are forced to retire all of them.
+const updateInviteRetiredSQL = "" +
+ "UPDATE roomserver_invites SET retired = TRUE" +
+ " WHERE room_nid = $1 AND target_nid = $2 AND NOT retired" +
+ " RETURNING invite_event_id"
+
+type inviteStatements struct {
+ insertInviteEventStmt *sql.Stmt
+ selectInviteActiveForUserInRoomStmt *sql.Stmt
+ updateInviteRetiredStmt *sql.Stmt
+}
+
+func (s *inviteStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(inviteSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertInviteEventStmt, insertInviteEventSQL},
+ {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
+ {&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
+ }.prepare(db)
+}
+
+func (s *inviteStatements) insertInviteEvent(
+ ctx context.Context,
+ txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
+ targetUserNID, senderUserNID types.EventStateKeyNID,
+ inviteEventJSON []byte,
+) (bool, error) {
+ result, err := common.TxStmt(txn, s.insertInviteEventStmt).ExecContext(
+ ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
+ )
+ if err != nil {
+ return false, err
+ }
+ count, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return count != 0, nil
+}
+
+func (s *inviteStatements) updateInviteRetired(
+ ctx context.Context,
+ txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) (eventIDs []string, err error) {
+ stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
+ rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
+ if err != nil {
+ return nil, err
+ }
+ defer (func() { err = rows.Close() })()
+ for rows.Next() {
+ var inviteEventID string
+ if err := rows.Scan(&inviteEventID); err != nil {
+ return nil, err
+ }
+ eventIDs = append(eventIDs, inviteEventID)
+ }
+ return
+}
+
+// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
+func (s *inviteStatements) selectInviteActiveForUserInRoom(
+ ctx context.Context,
+ targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
+) ([]types.EventStateKeyNID, error) {
+ rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
+ ctx, targetUserNID, roomNID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ var result []types.EventStateKeyNID
+ for rows.Next() {
+ var senderUserNID int64
+ if err := rows.Scan(&senderUserNID); err != nil {
+ return nil, err
+ }
+ result = append(result, types.EventStateKeyNID(senderUserNID))
+ }
+ return result, nil
+}
diff --git a/roomserver/storage/membership_table.go b/roomserver/storage/membership_table.go
new file mode 100644
index 00000000..88a9ed72
--- /dev/null
+++ b/roomserver/storage/membership_table.go
@@ -0,0 +1,193 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+type membershipState int64
+
+const (
+ membershipStateLeaveOrBan membershipState = 1
+ membershipStateInvite membershipState = 2
+ membershipStateJoin membershipState = 3
+)
+
+const membershipSchema = `
+-- The membership table is used to coordinate updates between the invite table
+-- and the room state tables.
+-- This table is updated in one of 3 ways:
+-- 1) The membership of a user changes within the current state of the room.
+-- 2) An invite is received outside of a room over federation.
+-- 3) An invite is rejected outside of a room over federation.
+CREATE TABLE IF NOT EXISTS roomserver_membership (
+ room_nid BIGINT NOT NULL,
+ -- Numeric state key ID for the user ID this state is for.
+ target_nid BIGINT NOT NULL,
+ -- Numeric state key ID for the user ID who changed the state.
+ -- This may be 0 since it is not always possible to identify the user that
+ -- changed the state.
+ sender_nid BIGINT NOT NULL DEFAULT 0,
+ -- The state the user is in within this room.
+ -- Default value is "membershipStateLeaveOrBan"
+ membership_nid BIGINT NOT NULL DEFAULT 1,
+ -- The numeric ID of the membership event.
+ -- It refers to the join membership event if the membership_nid is join (3),
+ -- and to the leave/ban membership event if the membership_nid is leave or
+ -- ban (1).
+ -- If the membership_nid is invite (2) and the user has been in the room
+ -- before, it will refer to the previous leave/ban membership event, and will
+ -- be equals to 0 (its default) if the user never joined the room before.
+ -- This NID is updated if the join event gets updated (e.g. profile update),
+ -- or if the user leaves/joins the room.
+ event_nid BIGINT NOT NULL DEFAULT 0,
+ UNIQUE (room_nid, target_nid)
+);
+`
+
+// Insert a row in to membership table so that it can be locked by the
+// SELECT FOR UPDATE
+const insertMembershipSQL = "" +
+ "INSERT INTO roomserver_membership (room_nid, target_nid)" +
+ " VALUES ($1, $2)" +
+ " ON CONFLICT DO NOTHING"
+
+const selectMembershipFromRoomAndTargetSQL = "" +
+ "SELECT membership_nid, event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1 AND target_nid = $2"
+
+const selectMembershipsFromRoomAndMembershipSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1 AND membership_nid = $2"
+
+const selectMembershipsFromRoomSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1"
+
+const selectMembershipForUpdateSQL = "" +
+ "SELECT membership_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE"
+
+const updateMembershipSQL = "" +
+ "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" +
+ " WHERE room_nid = $1 AND target_nid = $2"
+
+type membershipStatements struct {
+ insertMembershipStmt *sql.Stmt
+ selectMembershipForUpdateStmt *sql.Stmt
+ selectMembershipFromRoomAndTargetStmt *sql.Stmt
+ selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
+ selectMembershipsFromRoomStmt *sql.Stmt
+ updateMembershipStmt *sql.Stmt
+}
+
+func (s *membershipStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(membershipSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertMembershipStmt, insertMembershipSQL},
+ {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
+ {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
+ {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
+ {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
+ {&s.updateMembershipStmt, updateMembershipSQL},
+ }.prepare(db)
+}
+
+func (s *membershipStatements) insertMembership(
+ ctx context.Context,
+ txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) error {
+ stmt := common.TxStmt(txn, s.insertMembershipStmt)
+ _, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
+ return err
+}
+
+func (s *membershipStatements) selectMembershipForUpdate(
+ ctx context.Context,
+ txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) (membership membershipState, err error) {
+ err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
+ ctx, roomNID, targetUserNID,
+ ).Scan(&membership)
+ return
+}
+
+func (s *membershipStatements) selectMembershipFromRoomAndTarget(
+ ctx context.Context,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) (eventNID types.EventNID, membership membershipState, err error) {
+ err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
+ ctx, roomNID, targetUserNID,
+ ).Scan(&membership, &eventNID)
+ return
+}
+
+func (s *membershipStatements) selectMembershipsFromRoom(
+ ctx context.Context, roomNID types.RoomNID,
+) (eventNIDs []types.EventNID, err error) {
+ rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID)
+ if err != nil {
+ return
+ }
+
+ for rows.Next() {
+ var eNID types.EventNID
+ if err = rows.Scan(&eNID); err != nil {
+ return
+ }
+ eventNIDs = append(eventNIDs, eNID)
+ }
+ return
+}
+func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
+ ctx context.Context,
+ roomNID types.RoomNID, membership membershipState,
+) (eventNIDs []types.EventNID, err error) {
+ stmt := s.selectMembershipsFromRoomAndMembershipStmt
+ rows, err := stmt.QueryContext(ctx, roomNID, membership)
+ if err != nil {
+ return
+ }
+
+ for rows.Next() {
+ var eNID types.EventNID
+ if err = rows.Scan(&eNID); err != nil {
+ return
+ }
+ eventNIDs = append(eventNIDs, eNID)
+ }
+ return
+}
+
+func (s *membershipStatements) updateMembership(
+ ctx context.Context,
+ txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ senderUserNID types.EventStateKeyNID, membership membershipState,
+ eventNID types.EventNID,
+) error {
+ _, err := common.TxStmt(txn, s.updateMembershipStmt).ExecContext(
+ ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID,
+ )
+ return err
+}
diff --git a/roomserver/storage/prepare.go b/roomserver/storage/prepare.go
new file mode 100644
index 00000000..61c49a3c
--- /dev/null
+++ b/roomserver/storage/prepare.go
@@ -0,0 +1,36 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "database/sql"
+)
+
+// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
+type statementList []struct {
+ statement **sql.Stmt
+ sql string
+}
+
+// prepare the SQL for each statement in the list and assign the result to the prepared statement.
+// nolint: safesql
+func (s statementList) prepare(db *sql.DB) (err error) {
+ for _, statement := range s {
+ if *statement.statement, err = db.Prepare(statement.sql); err != nil {
+ return
+ }
+ }
+ return
+}
diff --git a/roomserver/storage/previous_events_table.go b/roomserver/storage/previous_events_table.go
new file mode 100644
index 00000000..81d581a9
--- /dev/null
+++ b/roomserver/storage/previous_events_table.go
@@ -0,0 +1,99 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const previousEventSchema = `
+-- The previous events table stores the event_ids referenced by the events
+-- stored in the events table.
+-- This is used to tell if a new event is already referenced by an event in
+-- the database.
+CREATE TABLE IF NOT EXISTS roomserver_previous_events (
+ -- The string event ID taken from the prev_events key of an event.
+ previous_event_id TEXT NOT NULL,
+ -- The SHA256 reference hash taken from the prev_events key of an event.
+ previous_reference_sha256 BYTEA NOT NULL,
+ -- A list of numeric event IDs of events that reference this prev_event.
+ event_nids BIGINT[] NOT NULL,
+ CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id, previous_reference_sha256)
+);
+`
+
+// Insert an entry into the previous_events table.
+// If there is already an entry indicating that an event references that previous event then
+// add the event NID to the list to indicate that this event references that previous event as well.
+// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
+// The lock is necessary to avoid data races when checking whether an event is already referenced by another event.
+const insertPreviousEventSQL = "" +
+ "INSERT INTO roomserver_previous_events" +
+ " (previous_event_id, previous_reference_sha256, event_nids)" +
+ " VALUES ($1, $2, array_append('{}'::bigint[], $3))" +
+ " ON CONFLICT ON CONSTRAINT roomserver_previous_event_id_unique" +
+ " DO UPDATE SET event_nids = array_append(roomserver_previous_events.event_nids, $3)" +
+ " WHERE $3 != ALL(roomserver_previous_events.event_nids)"
+
+// Check if the event is referenced by another event in the table.
+// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
+const selectPreviousEventExistsSQL = "" +
+ "SELECT 1 FROM roomserver_previous_events" +
+ " WHERE previous_event_id = $1 AND previous_reference_sha256 = $2"
+
+type previousEventStatements struct {
+ insertPreviousEventStmt *sql.Stmt
+ selectPreviousEventExistsStmt *sql.Stmt
+}
+
+func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(previousEventSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertPreviousEventStmt, insertPreviousEventSQL},
+ {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
+ }.prepare(db)
+}
+
+func (s *previousEventStatements) insertPreviousEvent(
+ ctx context.Context,
+ txn *sql.Tx,
+ previousEventID string,
+ previousEventReferenceSHA256 []byte,
+ eventNID types.EventNID,
+) error {
+ stmt := common.TxStmt(txn, s.insertPreviousEventStmt)
+ _, err := stmt.ExecContext(
+ ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
+ )
+ return err
+}
+
+// Check if the event reference exists
+// Returns sql.ErrNoRows if the event reference doesn't exist.
+func (s *previousEventStatements) selectPreviousEventExists(
+ ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
+) error {
+ var ok int64
+ stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt)
+ return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
+}
diff --git a/roomserver/storage/room_aliases_table.go b/roomserver/storage/room_aliases_table.go
new file mode 100644
index 00000000..f640c37f
--- /dev/null
+++ b/roomserver/storage/room_aliases_table.go
@@ -0,0 +1,109 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+)
+
+const roomAliasesSchema = `
+-- Stores room aliases and room IDs they refer to
+CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
+ -- Alias of the room
+ alias TEXT NOT NULL PRIMARY KEY,
+ -- Room ID the alias refers to
+ room_id TEXT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
+`
+
+const insertRoomAliasSQL = "" +
+ "INSERT INTO roomserver_room_aliases (alias, room_id) VALUES ($1, $2)"
+
+const selectRoomIDFromAliasSQL = "" +
+ "SELECT room_id FROM roomserver_room_aliases WHERE alias = $1"
+
+const selectAliasesFromRoomIDSQL = "" +
+ "SELECT alias FROM roomserver_room_aliases WHERE room_id = $1"
+
+const deleteRoomAliasSQL = "" +
+ "DELETE FROM roomserver_room_aliases WHERE alias = $1"
+
+type roomAliasesStatements struct {
+ insertRoomAliasStmt *sql.Stmt
+ selectRoomIDFromAliasStmt *sql.Stmt
+ selectAliasesFromRoomIDStmt *sql.Stmt
+ deleteRoomAliasStmt *sql.Stmt
+}
+
+func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(roomAliasesSchema)
+ if err != nil {
+ return
+ }
+ return statementList{
+ {&s.insertRoomAliasStmt, insertRoomAliasSQL},
+ {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
+ {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
+ {&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
+ }.prepare(db)
+}
+
+func (s *roomAliasesStatements) insertRoomAlias(
+ ctx context.Context, alias string, roomID string,
+) (err error) {
+ _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID)
+ return
+}
+
+func (s *roomAliasesStatements) selectRoomIDFromAlias(
+ ctx context.Context, alias string,
+) (roomID string, err error) {
+ err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
+ if err == sql.ErrNoRows {
+ return "", nil
+ }
+ return
+}
+
+func (s *roomAliasesStatements) selectAliasesFromRoomID(
+ ctx context.Context, roomID string,
+) (aliases []string, err error) {
+ aliases = []string{}
+ rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
+ if err != nil {
+ return
+ }
+
+ for rows.Next() {
+ var alias string
+ if err = rows.Scan(&alias); err != nil {
+ return
+ }
+
+ aliases = append(aliases, alias)
+ }
+
+ return
+}
+
+func (s *roomAliasesStatements) deleteRoomAlias(
+ ctx context.Context, alias string,
+) (err error) {
+ _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
+ return
+}
diff --git a/roomserver/storage/rooms_table.go b/roomserver/storage/rooms_table.go
new file mode 100644
index 00000000..64193ffe
--- /dev/null
+++ b/roomserver/storage/rooms_table.go
@@ -0,0 +1,155 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const roomsSchema = `
+CREATE SEQUENCE IF NOT EXISTS roomserver_room_nid_seq;
+CREATE TABLE IF NOT EXISTS roomserver_rooms (
+ -- Local numeric ID for the room.
+ room_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_room_nid_seq'),
+ -- Textual ID for the room.
+ room_id TEXT NOT NULL CONSTRAINT roomserver_room_id_unique UNIQUE,
+ -- The most recent events in the room that aren't referenced by another event.
+ -- This list may empty if the server hasn't joined the room yet.
+ -- (The server will be in that state while it stores the events for the initial state of the room)
+ latest_event_nids BIGINT[] NOT NULL DEFAULT '{}'::BIGINT[],
+ -- The last event written to the output log for this room.
+ last_event_sent_nid BIGINT NOT NULL DEFAULT 0,
+ -- The state of the room after the current set of latest events.
+ -- This will be 0 if there are no latest events in the room.
+ state_snapshot_nid BIGINT NOT NULL DEFAULT 0
+);
+`
+
+// Same as insertEventTypeNIDSQL
+const insertRoomNIDSQL = "" +
+ "INSERT INTO roomserver_rooms (room_id) VALUES ($1)" +
+ " ON CONFLICT ON CONSTRAINT roomserver_room_id_unique" +
+ " DO NOTHING RETURNING (room_nid)"
+
+const selectRoomNIDSQL = "" +
+ "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
+
+const selectLatestEventNIDsSQL = "" +
+ "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
+
+const selectLatestEventNIDsForUpdateSQL = "" +
+ "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1 FOR UPDATE"
+
+const updateLatestEventNIDsSQL = "" +
+ "UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1"
+
+type roomStatements struct {
+ insertRoomNIDStmt *sql.Stmt
+ selectRoomNIDStmt *sql.Stmt
+ selectLatestEventNIDsStmt *sql.Stmt
+ selectLatestEventNIDsForUpdateStmt *sql.Stmt
+ updateLatestEventNIDsStmt *sql.Stmt
+}
+
+func (s *roomStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(roomsSchema)
+ if err != nil {
+ return
+ }
+ return statementList{
+ {&s.insertRoomNIDStmt, insertRoomNIDSQL},
+ {&s.selectRoomNIDStmt, selectRoomNIDSQL},
+ {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
+ {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
+ {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
+ }.prepare(db)
+}
+
+func (s *roomStatements) insertRoomNID(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (types.RoomNID, error) {
+ var roomNID int64
+ stmt := common.TxStmt(txn, s.insertRoomNIDStmt)
+ err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
+ return types.RoomNID(roomNID), err
+}
+
+func (s *roomStatements) selectRoomNID(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (types.RoomNID, error) {
+ var roomNID int64
+ stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
+ err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
+ return types.RoomNID(roomNID), err
+}
+
+func (s *roomStatements) selectLatestEventNIDs(
+ ctx context.Context, roomNID types.RoomNID,
+) ([]types.EventNID, types.StateSnapshotNID, error) {
+ var nids pq.Int64Array
+ var stateSnapshotNID int64
+ stmt := s.selectLatestEventNIDsStmt
+ err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
+ if err != nil {
+ return nil, 0, err
+ }
+ eventNIDs := make([]types.EventNID, len(nids))
+ for i := range nids {
+ eventNIDs[i] = types.EventNID(nids[i])
+ }
+ return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
+}
+
+func (s *roomStatements) selectLatestEventsNIDsForUpdate(
+ ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
+) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
+ var nids pq.Int64Array
+ var lastEventSentNID int64
+ var stateSnapshotNID int64
+ stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
+ err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
+ if err != nil {
+ return nil, 0, 0, err
+ }
+ eventNIDs := make([]types.EventNID, len(nids))
+ for i := range nids {
+ eventNIDs[i] = types.EventNID(nids[i])
+ }
+ return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
+}
+
+func (s *roomStatements) updateLatestEventNIDs(
+ ctx context.Context,
+ txn *sql.Tx,
+ roomNID types.RoomNID,
+ eventNIDs []types.EventNID,
+ lastEventSentNID types.EventNID,
+ stateSnapshotNID types.StateSnapshotNID,
+) error {
+ stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ roomNID,
+ eventNIDsAsArray(eventNIDs),
+ int64(lastEventSentNID),
+ int64(stateSnapshotNID),
+ )
+ return err
+}
diff --git a/roomserver/storage/sql.go b/roomserver/storage/sql.go
new file mode 100644
index 00000000..05efa8dd
--- /dev/null
+++ b/roomserver/storage/sql.go
@@ -0,0 +1,59 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "database/sql"
+)
+
+type statements struct {
+ eventTypeStatements
+ eventStateKeyStatements
+ roomStatements
+ eventStatements
+ eventJSONStatements
+ stateSnapshotStatements
+ stateBlockStatements
+ previousEventStatements
+ roomAliasesStatements
+ inviteStatements
+ membershipStatements
+ transactionStatements
+}
+
+func (s *statements) prepare(db *sql.DB) error {
+ var err error
+
+ for _, prepare := range []func(db *sql.DB) error{
+ s.eventTypeStatements.prepare,
+ s.eventStateKeyStatements.prepare,
+ s.roomStatements.prepare,
+ s.eventStatements.prepare,
+ s.eventJSONStatements.prepare,
+ s.stateSnapshotStatements.prepare,
+ s.stateBlockStatements.prepare,
+ s.previousEventStatements.prepare,
+ s.roomAliasesStatements.prepare,
+ s.inviteStatements.prepare,
+ s.membershipStatements.prepare,
+ s.transactionStatements.prepare,
+ } {
+ if err = prepare(db); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/roomserver/storage/state_block_table.go b/roomserver/storage/state_block_table.go
new file mode 100644
index 00000000..b2e8ef8a
--- /dev/null
+++ b/roomserver/storage/state_block_table.go
@@ -0,0 +1,280 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "sort"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/util"
+)
+
+const stateDataSchema = `
+-- The state data map.
+-- Designed to give enough information to run the state resolution algorithm
+-- without hitting the database in the common case.
+-- TODO: Is it worth replacing the unique btree index with a covering index so
+-- that postgres could lookup the state using an index-only scan?
+-- The type and state_key are included in the index to make it easier to
+-- lookup a specific (type, state_key) pair for an event. It also makes it easy
+-- to read the state for a given state_block_nid ordered by (type, state_key)
+-- which in turn makes it easier to merge state data blocks.
+CREATE SEQUENCE IF NOT EXISTS roomserver_state_block_nid_seq;
+CREATE TABLE IF NOT EXISTS roomserver_state_block (
+ -- Local numeric ID for this state data.
+ state_block_nid bigint NOT NULL,
+ event_type_nid bigint NOT NULL,
+ event_state_key_nid bigint NOT NULL,
+ event_nid bigint NOT NULL,
+ UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
+);
+`
+
+const insertStateDataSQL = "" +
+ "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
+ " VALUES ($1, $2, $3, $4)"
+
+const selectNextStateBlockNIDSQL = "" +
+ "SELECT nextval('roomserver_state_block_nid_seq')"
+
+// Bulk state lookup by numeric state block ID.
+// Sort by the state_block_nid, event_type_nid, event_state_key_nid
+// This means that all the entries for a given state_block_nid will appear
+// together in the list and those entries will sorted by event_type_nid
+// and event_state_key_nid. This property makes it easier to merge two
+// state data blocks together.
+const bulkSelectStateBlockEntriesSQL = "" +
+ "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
+ " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" +
+ " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
+
+// Bulk state lookup by numeric state block ID.
+// Filters the rows in each block to the requested types and state keys.
+// We would like to restrict to particular type state key pairs but we are
+// restricted by the query language to pull the cross product of a list
+// of types and a list state_keys. So we have to filter the result in the
+// application to restrict it to the list of event types and state keys we
+// actually wanted.
+const bulkSelectFilteredStateBlockEntriesSQL = "" +
+ "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
+ " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" +
+ " AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" +
+ " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
+
+type stateBlockStatements struct {
+ insertStateDataStmt *sql.Stmt
+ selectNextStateBlockNIDStmt *sql.Stmt
+ bulkSelectStateBlockEntriesStmt *sql.Stmt
+ bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
+}
+
+func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(stateDataSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertStateDataStmt, insertStateDataSQL},
+ {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
+ {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
+ {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
+ }.prepare(db)
+}
+
+func (s *stateBlockStatements) bulkInsertStateData(
+ ctx context.Context,
+ stateBlockNID types.StateBlockNID,
+ entries []types.StateEntry,
+) error {
+ for _, entry := range entries {
+ _, err := s.insertStateDataStmt.ExecContext(
+ ctx,
+ int64(stateBlockNID),
+ int64(entry.EventTypeNID),
+ int64(entry.EventStateKeyNID),
+ int64(entry.EventNID),
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *stateBlockStatements) selectNextStateBlockNID(
+ ctx context.Context,
+) (types.StateBlockNID, error) {
+ var stateBlockNID int64
+ err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
+ return types.StateBlockNID(stateBlockNID), err
+}
+
+func (s *stateBlockStatements) bulkSelectStateBlockEntries(
+ ctx context.Context, stateBlockNIDs []types.StateBlockNID,
+) ([]types.StateEntryList, error) {
+ nids := make([]int64, len(stateBlockNIDs))
+ for i := range stateBlockNIDs {
+ nids[i] = int64(stateBlockNIDs[i])
+ }
+ rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ results := make([]types.StateEntryList, len(stateBlockNIDs))
+ // current is a pointer to the StateEntryList to append the state entries to.
+ var current *types.StateEntryList
+ i := 0
+ for rows.Next() {
+ var (
+ stateBlockNID int64
+ eventTypeNID int64
+ eventStateKeyNID int64
+ eventNID int64
+ entry types.StateEntry
+ )
+ if err := rows.Scan(
+ &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
+ ); err != nil {
+ return nil, err
+ }
+ entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
+ entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
+ entry.EventNID = types.EventNID(eventNID)
+ if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
+ // The state entry row is for a different state data block to the current one.
+ // So we start appending to the next entry in the list.
+ current = &results[i]
+ current.StateBlockNID = types.StateBlockNID(stateBlockNID)
+ i++
+ }
+ current.StateEntries = append(current.StateEntries, entry)
+ }
+ if i != len(stateBlockNIDs) {
+ return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
+ }
+ return results, nil
+}
+
+func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
+ ctx context.Context,
+ stateBlockNIDs []types.StateBlockNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntryList, error) {
+ tuples := stateKeyTupleSorter(stateKeyTuples)
+ // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
+ sort.Sort(tuples)
+
+ eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
+ rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext(
+ ctx,
+ stateBlockNIDsAsArray(stateBlockNIDs),
+ eventTypeNIDArray,
+ eventStateKeyNIDArray,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ var results []types.StateEntryList
+ var current types.StateEntryList
+ for rows.Next() {
+ var (
+ stateBlockNID int64
+ eventTypeNID int64
+ eventStateKeyNID int64
+ eventNID int64
+ entry types.StateEntry
+ )
+ if err := rows.Scan(
+ &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
+ ); err != nil {
+ return nil, err
+ }
+ entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
+ entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
+ entry.EventNID = types.EventNID(eventNID)
+
+ // We can use binary search here because we sorted the tuples earlier
+ if !tuples.contains(entry.StateKeyTuple) {
+ // The select will return the cross product of types and state keys.
+ // So we need to check if type of the entry is in the list.
+ continue
+ }
+
+ if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
+ // The state entry row is for a different state data block to the current one.
+ // So we append the current entry to the results and start adding to a new one.
+ // The first time through the loop current will be empty.
+ if current.StateEntries != nil {
+ results = append(results, current)
+ }
+ current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
+ }
+ current.StateEntries = append(current.StateEntries, entry)
+ }
+ // Add the last entry to the list if it is not empty.
+ if current.StateEntries != nil {
+ results = append(results, current)
+ }
+ return results, nil
+}
+
+func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
+ nids := make([]int64, len(stateBlockNIDs))
+ for i := range stateBlockNIDs {
+ nids[i] = int64(stateBlockNIDs[i])
+ }
+ return pq.Int64Array(nids)
+}
+
+type stateKeyTupleSorter []types.StateKeyTuple
+
+func (s stateKeyTupleSorter) Len() int { return len(s) }
+func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
+func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// Check whether a tuple is in the list. Assumes that the list is sorted.
+func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
+ i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
+ return i < len(s) && s[i] == value
+}
+
+// List the unique eventTypeNIDs and eventStateKeyNIDs.
+// Assumes that the list is sorted.
+func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
+ eventTypeNIDs = make(pq.Int64Array, len(s))
+ eventStateKeyNIDs = make(pq.Int64Array, len(s))
+ for i := range s {
+ eventTypeNIDs[i] = int64(s[i].EventTypeNID)
+ eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
+ }
+ eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
+ eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
+ return
+}
+
+type int64Sorter []int64
+
+func (s int64Sorter) Len() int { return len(s) }
+func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
+func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
diff --git a/roomserver/storage/state_block_table_test.go b/roomserver/storage/state_block_table_test.go
new file mode 100644
index 00000000..f891b5bc
--- /dev/null
+++ b/roomserver/storage/state_block_table_test.go
@@ -0,0 +1,85 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "sort"
+ "testing"
+
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+func TestStateKeyTupleSorter(t *testing.T) {
+ input := stateKeyTupleSorter{
+ {EventTypeNID: 1, EventStateKeyNID: 2},
+ {EventTypeNID: 1, EventStateKeyNID: 4},
+ {EventTypeNID: 2, EventStateKeyNID: 2},
+ {EventTypeNID: 1, EventStateKeyNID: 1},
+ }
+ want := []types.StateKeyTuple{
+ {EventTypeNID: 1, EventStateKeyNID: 1},
+ {EventTypeNID: 1, EventStateKeyNID: 2},
+ {EventTypeNID: 1, EventStateKeyNID: 4},
+ {EventTypeNID: 2, EventStateKeyNID: 2},
+ }
+ doNotWant := []types.StateKeyTuple{
+ {EventTypeNID: 0, EventStateKeyNID: 0},
+ {EventTypeNID: 1, EventStateKeyNID: 3},
+ {EventTypeNID: 2, EventStateKeyNID: 1},
+ {EventTypeNID: 3, EventStateKeyNID: 1},
+ }
+ wantTypeNIDs := []int64{1, 2}
+ wantStateKeyNIDs := []int64{1, 2, 4}
+
+ // Sort the input and check it's in the right order.
+ sort.Sort(input)
+ gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
+
+ for i := range want {
+ if input[i] != want[i] {
+ t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
+ }
+
+ if !input.contains(want[i]) {
+ t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
+ }
+ }
+
+ for i := range doNotWant {
+ if input.contains(doNotWant[i]) {
+ t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
+ }
+ }
+
+ if len(wantTypeNIDs) != len(gotTypeNIDs) {
+ t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
+ }
+
+ for i := range wantTypeNIDs {
+ if wantTypeNIDs[i] != gotTypeNIDs[i] {
+ t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
+ }
+ }
+
+ if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
+ t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
+ }
+
+ for i := range wantStateKeyNIDs {
+ if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
+ t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
+ }
+ }
+}
diff --git a/roomserver/storage/state_snapshot_table.go b/roomserver/storage/state_snapshot_table.go
new file mode 100644
index 00000000..aa14daad
--- /dev/null
+++ b/roomserver/storage/state_snapshot_table.go
@@ -0,0 +1,118 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const stateSnapshotSchema = `
+-- The state of a room before an event.
+-- Stored as a list of state_block entries stored in a separate table.
+-- The actual state is constructed by combining all the state_block entries
+-- referenced by state_block_nids together. If the same state key tuple appears
+-- multiple times then the entry from the later state_block clobbers the earlier
+-- entries.
+-- This encoding format allows us to implement a delta encoding which is useful
+-- because room state tends to accumulate small changes over time. Although if
+-- the list of deltas becomes too long it becomes more efficient to encode
+-- the full state under single state_block_nid.
+CREATE SEQUENCE IF NOT EXISTS roomserver_state_snapshot_nid_seq;
+CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
+ -- Local numeric ID for the state.
+ state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'),
+ -- Local numeric ID of the room this state is for.
+ -- Unused in normal operation, but useful for background work or ad-hoc debugging.
+ room_nid bigint NOT NULL,
+ -- List of state_block_nids, stored sorted by state_block_nid.
+ state_block_nids bigint[] NOT NULL
+);
+`
+
+const insertStateSQL = "" +
+ "INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" +
+ " VALUES ($1, $2)" +
+ " RETURNING state_snapshot_nid"
+
+// Bulk state data NID lookup.
+// Sorting by state_snapshot_nid means we can use binary search over the result
+// to lookup the state data NIDs for a state snapshot NID.
+const bulkSelectStateBlockNIDsSQL = "" +
+ "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
+ " WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC"
+
+type stateSnapshotStatements struct {
+ insertStateStmt *sql.Stmt
+ bulkSelectStateBlockNIDsStmt *sql.Stmt
+}
+
+func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(stateSnapshotSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertStateStmt, insertStateSQL},
+ {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
+ }.prepare(db)
+}
+
+func (s *stateSnapshotStatements) insertState(
+ ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
+) (stateNID types.StateSnapshotNID, err error) {
+ nids := make([]int64, len(stateBlockNIDs))
+ for i := range stateBlockNIDs {
+ nids[i] = int64(stateBlockNIDs[i])
+ }
+ err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
+ return
+}
+
+func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
+ ctx context.Context, stateNIDs []types.StateSnapshotNID,
+) ([]types.StateBlockNIDList, error) {
+ nids := make([]int64, len(stateNIDs))
+ for i := range stateNIDs {
+ nids[i] = int64(stateNIDs[i])
+ }
+ rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make([]types.StateBlockNIDList, len(stateNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ var stateBlockNIDs pq.Int64Array
+ if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
+ return nil, err
+ }
+ result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
+ for k := range stateBlockNIDs {
+ result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
+ }
+ }
+ if i != len(stateNIDs) {
+ return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
+ }
+ return results, nil
+}
diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go
new file mode 100644
index 00000000..f6c2fccd
--- /dev/null
+++ b/roomserver/storage/storage.go
@@ -0,0 +1,705 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+
+ // Import the postgres database driver.
+ _ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// A Database is used to store room events and stream offsets.
+type Database struct {
+ statements statements
+ db *sql.DB
+}
+
+// Open a postgres database.
+func Open(dataSourceName string) (*Database, error) {
+ var d Database
+ var err error
+ if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
+ return nil, err
+ }
+ if err = d.statements.prepare(d.db); err != nil {
+ return nil, err
+ }
+ return &d, nil
+}
+
+// StoreEvent implements input.EventDatabase
+func (d *Database) StoreEvent(
+ ctx context.Context, event gomatrixserverlib.Event,
+ txnAndDeviceID *api.TransactionID, authEventNIDs []types.EventNID,
+) (types.RoomNID, types.StateAtEvent, error) {
+ var (
+ roomNID types.RoomNID
+ eventTypeNID types.EventTypeNID
+ eventStateKeyNID types.EventStateKeyNID
+ eventNID types.EventNID
+ stateNID types.StateSnapshotNID
+ err error
+ )
+
+ if txnAndDeviceID != nil {
+ if err = d.statements.insertTransaction(
+ ctx, txnAndDeviceID.TransactionID,
+ txnAndDeviceID.DeviceID, event.Sender(), event.EventID(),
+ ); err != nil {
+ return 0, types.StateAtEvent{}, err
+ }
+ }
+
+ if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil {
+ return 0, types.StateAtEvent{}, err
+ }
+
+ if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil {
+ return 0, types.StateAtEvent{}, err
+ }
+
+ eventStateKey := event.StateKey()
+ // Assigned a numeric ID for the state_key if there is one present.
+ // Otherwise set the numeric ID for the state_key to 0.
+ if eventStateKey != nil {
+ if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil {
+ return 0, types.StateAtEvent{}, err
+ }
+ }
+
+ if eventNID, stateNID, err = d.statements.insertEvent(
+ ctx,
+ roomNID,
+ eventTypeNID,
+ eventStateKeyNID,
+ event.EventID(),
+ event.EventReference().EventSHA256,
+ authEventNIDs,
+ event.Depth(),
+ ); err != nil {
+ if err == sql.ErrNoRows {
+ // We've already inserted the event so select the numeric event ID
+ eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID())
+ }
+ if err != nil {
+ return 0, types.StateAtEvent{}, err
+ }
+ }
+
+ if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil {
+ return 0, types.StateAtEvent{}, err
+ }
+
+ return roomNID, types.StateAtEvent{
+ BeforeStateSnapshotNID: stateNID,
+ StateEntry: types.StateEntry{
+ StateKeyTuple: types.StateKeyTuple{
+ EventTypeNID: eventTypeNID,
+ EventStateKeyNID: eventStateKeyNID,
+ },
+ EventNID: eventNID,
+ },
+ }, nil
+}
+
+func (d *Database) assignRoomNID(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (types.RoomNID, error) {
+ // Check if we already have a numeric ID in the database.
+ roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID)
+ if err == sql.ErrNoRows {
+ // We don't have a numeric ID so insert one into the database.
+ roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
+ if err == sql.ErrNoRows {
+ // We raced with another insert so run the select again.
+ roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
+ }
+ }
+ return roomNID, err
+}
+
+func (d *Database) assignEventTypeNID(
+ ctx context.Context, eventType string,
+) (types.EventTypeNID, error) {
+ // Check if we already have a numeric ID in the database.
+ eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType)
+ if err == sql.ErrNoRows {
+ // We don't have a numeric ID so insert one into the database.
+ eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType)
+ if err == sql.ErrNoRows {
+ // We raced with another insert so run the select again.
+ eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType)
+ }
+ }
+ return eventTypeNID, err
+}
+
+func (d *Database) assignStateKeyNID(
+ ctx context.Context, txn *sql.Tx, eventStateKey string,
+) (types.EventStateKeyNID, error) {
+ // Check if we already have a numeric ID in the database.
+ eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
+ if err == sql.ErrNoRows {
+ // We don't have a numeric ID so insert one into the database.
+ eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey)
+ if err == sql.ErrNoRows {
+ // We raced with another insert so run the select again.
+ eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
+ }
+ }
+ return eventStateKeyNID, err
+}
+
+// StateEntriesForEventIDs implements input.EventDatabase
+func (d *Database) StateEntriesForEventIDs(
+ ctx context.Context, eventIDs []string,
+) ([]types.StateEntry, error) {
+ return d.statements.bulkSelectStateEventByID(ctx, eventIDs)
+}
+
+// EventTypeNIDs implements state.RoomStateDatabase
+func (d *Database) EventTypeNIDs(
+ ctx context.Context, eventTypes []string,
+) (map[string]types.EventTypeNID, error) {
+ return d.statements.bulkSelectEventTypeNID(ctx, eventTypes)
+}
+
+// EventStateKeyNIDs implements state.RoomStateDatabase
+func (d *Database) EventStateKeyNIDs(
+ ctx context.Context, eventStateKeys []string,
+) (map[string]types.EventStateKeyNID, error) {
+ return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys)
+}
+
+// EventStateKeys implements query.RoomserverQueryAPIDatabase
+func (d *Database) EventStateKeys(
+ ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
+) (map[types.EventStateKeyNID]string, error) {
+ return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs)
+}
+
+// EventNIDs implements query.RoomserverQueryAPIDatabase
+func (d *Database) EventNIDs(
+ ctx context.Context, eventIDs []string,
+) (map[string]types.EventNID, error) {
+ return d.statements.bulkSelectEventNID(ctx, eventIDs)
+}
+
+// Events implements input.EventDatabase
+func (d *Database) Events(
+ ctx context.Context, eventNIDs []types.EventNID,
+) ([]types.Event, error) {
+ eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+ results := make([]types.Event, len(eventJSONs))
+ for i, eventJSON := range eventJSONs {
+ result := &results[i]
+ result.EventNID = eventJSON.EventNID
+ // TODO: Use NewEventFromTrustedJSON for efficiency
+ result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return results, nil
+}
+
+// AddState implements input.EventDatabase
+func (d *Database) AddState(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ stateBlockNIDs []types.StateBlockNID,
+ state []types.StateEntry,
+) (types.StateSnapshotNID, error) {
+ if len(state) > 0 {
+ stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx)
+ if err != nil {
+ return 0, err
+ }
+ if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil {
+ return 0, err
+ }
+ stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
+ }
+
+ return d.statements.insertState(ctx, roomNID, stateBlockNIDs)
+}
+
+// SetState implements input.EventDatabase
+func (d *Database) SetState(
+ ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
+) error {
+ return d.statements.updateEventState(ctx, eventNID, stateNID)
+}
+
+// StateAtEventIDs implements input.EventDatabase
+func (d *Database) StateAtEventIDs(
+ ctx context.Context, eventIDs []string,
+) ([]types.StateAtEvent, error) {
+ return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs)
+}
+
+// StateBlockNIDs implements state.RoomStateDatabase
+func (d *Database) StateBlockNIDs(
+ ctx context.Context, stateNIDs []types.StateSnapshotNID,
+) ([]types.StateBlockNIDList, error) {
+ return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs)
+}
+
+// StateEntries implements state.RoomStateDatabase
+func (d *Database) StateEntries(
+ ctx context.Context, stateBlockNIDs []types.StateBlockNID,
+) ([]types.StateEntryList, error) {
+ return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs)
+}
+
+// SnapshotNIDFromEventID implements state.RoomStateDatabase
+func (d *Database) SnapshotNIDFromEventID(
+ ctx context.Context, eventID string,
+) (types.StateSnapshotNID, error) {
+ _, stateNID, err := d.statements.selectEvent(ctx, eventID)
+ return stateNID, err
+}
+
+// EventIDs implements input.RoomEventDatabase
+func (d *Database) EventIDs(
+ ctx context.Context, eventNIDs []types.EventNID,
+) (map[types.EventNID]string, error) {
+ return d.statements.bulkSelectEventID(ctx, eventNIDs)
+}
+
+// GetLatestEventsForUpdate implements input.EventDatabase
+func (d *Database) GetLatestEventsForUpdate(
+ ctx context.Context, roomNID types.RoomNID,
+) (types.RoomRecentEventsUpdater, error) {
+ txn, err := d.db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
+ d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
+ if err != nil {
+ txn.Rollback() // nolint: errcheck
+ return nil, err
+ }
+ stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
+ if err != nil {
+ txn.Rollback() // nolint: errcheck
+ return nil, err
+ }
+ var lastEventIDSent string
+ if lastEventNIDSent != 0 {
+ lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
+ if err != nil {
+ txn.Rollback() // nolint: errcheck
+ return nil, err
+ }
+ }
+ return &roomRecentEventsUpdater{
+ transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
+ }, nil
+}
+
+// GetTransactionEventID implements input.EventDatabase
+func (d *Database) GetTransactionEventID(
+ ctx context.Context, transactionID string,
+ deviceID string, userID string,
+) (string, error) {
+ eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, deviceID, userID)
+ if err == sql.ErrNoRows {
+ return "", nil
+ }
+ return eventID, err
+}
+
+type roomRecentEventsUpdater struct {
+ transaction
+ d *Database
+ roomNID types.RoomNID
+ latestEvents []types.StateAtEventAndReference
+ lastEventIDSent string
+ currentStateSnapshotNID types.StateSnapshotNID
+}
+
+// LatestEvents implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
+ return u.latestEvents
+}
+
+// LastEventIDSent implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) LastEventIDSent() string {
+ return u.lastEventIDSent
+}
+
+// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
+ return u.currentStateSnapshotNID
+}
+
+// StorePreviousEvents implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
+ for _, ref := range previousEventReferences {
+ if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// IsReferenced implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
+ err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
+ if err == nil {
+ return true, nil
+ }
+ if err == sql.ErrNoRows {
+ return false, nil
+ }
+ return false, err
+}
+
+// SetLatestEvents implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) SetLatestEvents(
+ roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
+ currentStateSnapshotNID types.StateSnapshotNID,
+) error {
+ eventNIDs := make([]types.EventNID, len(latest))
+ for i := range latest {
+ eventNIDs[i] = latest[i].EventNID
+ }
+ return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
+}
+
+// HasEventBeenSent implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
+ return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID)
+}
+
+// MarkEventAsSent implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
+ return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID)
+}
+
+func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
+ return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID)
+}
+
+// RoomNID implements query.RoomserverQueryAPIDB
+func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
+ roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID)
+ if err == sql.ErrNoRows {
+ return 0, nil
+ }
+ return roomNID, err
+}
+
+// LatestEventIDs implements query.RoomserverQueryAPIDatabase
+func (d *Database) LatestEventIDs(
+ ctx context.Context, roomNID types.RoomNID,
+) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
+ eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID)
+ if err != nil {
+ return nil, 0, 0, err
+ }
+ references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs)
+ if err != nil {
+ return nil, 0, 0, err
+ }
+ depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs)
+ if err != nil {
+ return nil, 0, 0, err
+ }
+ return references, currentStateSnapshotNID, depth, nil
+}
+
+// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
+func (d *Database) GetInvitesForUser(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ targetUserNID types.EventStateKeyNID,
+) (senderUserIDs []types.EventStateKeyNID, err error) {
+ return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
+}
+
+// SetRoomAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string) error {
+ return d.statements.insertRoomAlias(ctx, alias, roomID)
+}
+
+// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
+ return d.statements.selectRoomIDFromAlias(ctx, alias)
+}
+
+// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB
+func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
+ return d.statements.selectAliasesFromRoomID(ctx, roomID)
+}
+
+// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
+ return d.statements.deleteRoomAlias(ctx, alias)
+}
+
+// StateEntriesForTuples implements state.RoomStateDatabase
+func (d *Database) StateEntriesForTuples(
+ ctx context.Context,
+ stateBlockNIDs []types.StateBlockNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntryList, error) {
+ return d.statements.bulkSelectFilteredStateBlockEntries(
+ ctx, stateBlockNIDs, stateKeyTuples,
+ )
+}
+
+// MembershipUpdater implements input.RoomEventDatabase
+func (d *Database) MembershipUpdater(
+ ctx context.Context, roomID, targetUserID string,
+) (types.MembershipUpdater, error) {
+ txn, err := d.db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ succeeded := false
+ defer func() {
+ if !succeeded {
+ txn.Rollback() // nolint: errcheck
+ }
+ }()
+
+ roomNID, err := d.assignRoomNID(ctx, txn, roomID)
+ if err != nil {
+ return nil, err
+ }
+
+ targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID)
+ if err != nil {
+ return nil, err
+ }
+
+ updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
+ if err != nil {
+ return nil, err
+ }
+
+ succeeded = true
+ return updater, nil
+}
+
+type membershipUpdater struct {
+ transaction
+ d *Database
+ roomNID types.RoomNID
+ targetUserNID types.EventStateKeyNID
+ membership membershipState
+}
+
+func (d *Database) membershipUpdaterTxn(
+ ctx context.Context,
+ txn *sql.Tx,
+ roomNID types.RoomNID,
+ targetUserNID types.EventStateKeyNID,
+) (types.MembershipUpdater, error) {
+
+ if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
+ return nil, err
+ }
+
+ membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
+ if err != nil {
+ return nil, err
+ }
+
+ return &membershipUpdater{
+ transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
+ }, nil
+}
+
+// IsInvite implements types.MembershipUpdater
+func (u *membershipUpdater) IsInvite() bool {
+ return u.membership == membershipStateInvite
+}
+
+// IsJoin implements types.MembershipUpdater
+func (u *membershipUpdater) IsJoin() bool {
+ return u.membership == membershipStateJoin
+}
+
+// IsLeave implements types.MembershipUpdater
+func (u *membershipUpdater) IsLeave() bool {
+ return u.membership == membershipStateLeaveOrBan
+}
+
+// SetToInvite implements types.MembershipUpdater
+func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
+ if err != nil {
+ return false, err
+ }
+ inserted, err := u.d.statements.insertInviteEvent(
+ u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
+ )
+ if err != nil {
+ return false, err
+ }
+ if u.membership != membershipStateInvite {
+ if err = u.d.statements.updateMembership(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
+ ); err != nil {
+ return false, err
+ }
+ }
+ return inserted, nil
+}
+
+// SetToJoin implements types.MembershipUpdater
+func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
+ var inviteEventIDs []string
+
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
+ if err != nil {
+ return nil, err
+ }
+
+ // If this is a join event update, there is no invite to update
+ if !isUpdate {
+ inviteEventIDs, err = u.d.statements.updateInviteRetired(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Look up the NID of the new join event
+ nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
+ if err != nil {
+ return nil, err
+ }
+
+ if u.membership != membershipStateJoin || isUpdate {
+ if err = u.d.statements.updateMembership(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
+ membershipStateJoin, nIDs[eventID],
+ ); err != nil {
+ return nil, err
+ }
+ }
+
+ return inviteEventIDs, nil
+}
+
+// SetToLeave implements types.MembershipUpdater
+func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
+ if err != nil {
+ return nil, err
+ }
+ inviteEventIDs, err := u.d.statements.updateInviteRetired(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Look up the NID of the new leave event
+ nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
+ if err != nil {
+ return nil, err
+ }
+
+ if u.membership != membershipStateLeaveOrBan {
+ if err = u.d.statements.updateMembership(
+ u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
+ membershipStateLeaveOrBan, nIDs[eventID],
+ ); err != nil {
+ return nil, err
+ }
+ }
+ return inviteEventIDs, nil
+}
+
+// GetMembership implements query.RoomserverQueryAPIDB
+func (d *Database) GetMembership(
+ ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
+) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
+ requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID)
+ if err != nil {
+ return
+ }
+
+ senderMembershipEventNID, senderMembership, err :=
+ d.statements.selectMembershipFromRoomAndTarget(
+ ctx, roomNID, requestSenderUserNID,
+ )
+ if err == sql.ErrNoRows {
+ // The user has never been a member of that room
+ return 0, false, nil
+ } else if err != nil {
+ return
+ }
+
+ return senderMembershipEventNID, senderMembership == membershipStateJoin, nil
+}
+
+// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
+func (d *Database) GetMembershipEventNIDsForRoom(
+ ctx context.Context, roomNID types.RoomNID, joinOnly bool,
+) ([]types.EventNID, error) {
+ if joinOnly {
+ return d.statements.selectMembershipsFromRoomAndMembership(
+ ctx, roomNID, membershipStateJoin,
+ )
+ }
+
+ return d.statements.selectMembershipsFromRoom(ctx, roomNID)
+}
+
+// EventsFromIDs implements query.RoomserverQueryAPIEventDB
+func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
+ nidMap, err := d.EventNIDs(ctx, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ var nids []types.EventNID
+ for _, nid := range nidMap {
+ nids = append(nids, nid)
+ }
+
+ return d.Events(ctx, nids)
+}
+
+type transaction struct {
+ ctx context.Context
+ txn *sql.Tx
+}
+
+// Commit implements types.Transaction
+func (t *transaction) Commit() error {
+ return t.txn.Commit()
+}
+
+// Rollback implements types.Transaction
+func (t *transaction) Rollback() error {
+ return t.txn.Rollback()
+}
diff --git a/roomserver/storage/transactions_table.go b/roomserver/storage/transactions_table.go
new file mode 100644
index 00000000..e9c904cc
--- /dev/null
+++ b/roomserver/storage/transactions_table.go
@@ -0,0 +1,86 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+ "context"
+ "database/sql"
+)
+
+const transactionsSchema = `
+-- The transactions table holds transaction IDs with sender's info and event ID it belongs to.
+-- This table is used by roomserver to prevent reprocessing of events.
+CREATE TABLE IF NOT EXISTS roomserver_transactions (
+ -- The transaction ID of the event.
+ transaction_id TEXT NOT NULL,
+ -- The device ID of the originating transaction.
+ device_id TEXT NOT NULL,
+ -- User ID of the sender who authored the event
+ user_id TEXT NOT NULL,
+ -- Event ID corresponding to the transaction
+ -- Required to return event ID to client on a duplicate request.
+ event_id TEXT NOT NULL,
+ -- A transaction ID is unique for a user and device
+ -- This automatically creates an index.
+ PRIMARY KEY (transaction_id, device_id, user_id)
+);
+`
+const insertTransactionSQL = "" +
+ "INSERT INTO roomserver_transactions (transaction_id, device_id, user_id, event_id)" +
+ " VALUES ($1, $2, $3, $4)"
+
+const selectTransactionEventIDSQL = "" +
+ "SELECT event_id FROM roomserver_transactions" +
+ " WHERE transaction_id = $1 AND device_id = $2 AND user_id = $3"
+
+type transactionStatements struct {
+ insertTransactionStmt *sql.Stmt
+ selectTransactionEventIDStmt *sql.Stmt
+}
+
+func (s *transactionStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(transactionsSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertTransactionStmt, insertTransactionSQL},
+ {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
+ }.prepare(db)
+}
+
+func (s *transactionStatements) insertTransaction(
+ ctx context.Context,
+ transactionID string,
+ deviceID string,
+ userID string,
+ eventID string,
+) (err error) {
+ _, err = s.insertTransactionStmt.ExecContext(
+ ctx, transactionID, deviceID, userID, eventID,
+ )
+ return
+}
+
+func (s *transactionStatements) selectTransactionEventID(
+ ctx context.Context,
+ transactionID string,
+ deviceID string,
+ userID string,
+) (eventID string, err error) {
+ err = s.selectTransactionEventIDStmt.QueryRowContext(
+ ctx, transactionID, deviceID, userID,
+ ).Scan(&eventID)
+ return
+}
diff --git a/roomserver/types/types.go b/roomserver/types/types.go
new file mode 100644
index 00000000..d5fe3276
--- /dev/null
+++ b/roomserver/types/types.go
@@ -0,0 +1,203 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package types provides the types that are used internally within the roomserver.
+package types
+
+import (
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// EventTypeNID is a numeric ID for an event type.
+type EventTypeNID int64
+
+// EventStateKeyNID is a numeric ID for an event state_key.
+type EventStateKeyNID int64
+
+// EventNID is a numeric ID for an event.
+type EventNID int64
+
+// RoomNID is a numeric ID for a room.
+type RoomNID int64
+
+// StateSnapshotNID is a numeric ID for the state at an event.
+type StateSnapshotNID int64
+
+// StateBlockNID is a numeric ID for a block of state data.
+// These blocks of state data are combined to form the actual state.
+type StateBlockNID int64
+
+// A StateKeyTuple is a pair of a numeric event type and a numeric state key.
+// It is used to lookup state entries.
+type StateKeyTuple struct {
+ // The numeric ID for the event type.
+ EventTypeNID EventTypeNID
+ // The numeric ID for the state key.
+ EventStateKeyNID EventStateKeyNID
+}
+
+// LessThan returns true if this state key is less than the other state key.
+// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries.
+func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
+ if a.EventTypeNID != b.EventTypeNID {
+ return a.EventTypeNID < b.EventTypeNID
+ }
+ return a.EventStateKeyNID < b.EventStateKeyNID
+}
+
+// A StateEntry is an entry in the room state of a matrix room.
+type StateEntry struct {
+ StateKeyTuple
+ // The numeric ID for the event.
+ EventNID EventNID
+}
+
+// LessThan returns true if this state entry is less than the other state entry.
+// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries.
+func (a StateEntry) LessThan(b StateEntry) bool {
+ if a.StateKeyTuple != b.StateKeyTuple {
+ return a.StateKeyTuple.LessThan(b.StateKeyTuple)
+ }
+ return a.EventNID < b.EventNID
+}
+
+// StateAtEvent is the state before and after a matrix event.
+type StateAtEvent struct {
+ // The state before the event.
+ BeforeStateSnapshotNID StateSnapshotNID
+ // The state entry for the event itself, allows us to calculate the state after the event.
+ StateEntry
+}
+
+// IsStateEvent returns whether the event the state is at is a state event.
+func (s StateAtEvent) IsStateEvent() bool {
+ return s.EventStateKeyNID != 0
+}
+
+// StateAtEventAndReference is StateAtEvent and gomatrixserverlib.EventReference glued together.
+// It is used when looking up the latest events in a room in the database.
+// The gomatrixserverlib.EventReference is used to check whether a new event references the event.
+// The StateAtEvent is used to construct the current state of the room from the latest events.
+type StateAtEventAndReference struct {
+ StateAtEvent
+ gomatrixserverlib.EventReference
+}
+
+// An Event is a gomatrixserverlib.Event with the numeric event ID attached.
+// It is when performing bulk event lookup in the database.
+type Event struct {
+ EventNID EventNID
+ gomatrixserverlib.Event
+}
+
+const (
+ // MRoomCreateNID is the numeric ID for the "m.room.create" event type.
+ MRoomCreateNID = 1
+ // MRoomPowerLevelsNID is the numeric ID for the "m.room.power_levels" event type.
+ MRoomPowerLevelsNID = 2
+ // MRoomJoinRulesNID is the numeric ID for the "m.room.join_rules" event type.
+ MRoomJoinRulesNID = 3
+ // MRoomThirdPartyInviteNID is the numeric ID for the "m.room.third_party_invite" event type.
+ MRoomThirdPartyInviteNID = 4
+ // MRoomMemberNID is the numeric ID for the "m.room.member" event type.
+ MRoomMemberNID = 5
+ // MRoomRedactionNID is the numeric ID for the "m.room.redaction" event type.
+ MRoomRedactionNID = 6
+ // MRoomHistoryVisibilityNID is the numeric ID for the "m.room.history_visibility" event type.
+ MRoomHistoryVisibilityNID = 7
+)
+
+const (
+ // EmptyStateKeyNID is the numeric ID for the empty state key.
+ EmptyStateKeyNID = 1
+)
+
+// StateBlockNIDList is used to return the result of bulk StateBlockNID lookups from the database.
+type StateBlockNIDList struct {
+ StateSnapshotNID StateSnapshotNID
+ StateBlockNIDs []StateBlockNID
+}
+
+// StateEntryList is used to return the result of bulk state entry lookups from the database.
+type StateEntryList struct {
+ StateBlockNID StateBlockNID
+ StateEntries []StateEntry
+}
+
+// A RoomRecentEventsUpdater is used to update the recent events in a room.
+// (On postgresql this wraps a database transaction that holds a "FOR UPDATE"
+// lock on the row in the rooms table holding the latest events for the room.)
+type RoomRecentEventsUpdater interface {
+ // The latest event IDs and state in the room.
+ LatestEvents() []StateAtEventAndReference
+ // The event ID of the latest event written to the output log in the room.
+ LastEventIDSent() string
+ // The current state of the room.
+ CurrentStateSnapshotNID() StateSnapshotNID
+ // Store the previous events referenced by an event.
+ // This adds the event NID to an entry in the database for each of the previous events.
+ // If there isn't an entry for one of previous events then an entry is created.
+ // If the entry already lists the event NID as a referrer then the entry unmodified.
+ // (i.e. the operation is idempotent)
+ StorePreviousEvents(eventNID EventNID, previousEventReferences []gomatrixserverlib.EventReference) error
+ // Check whether the eventReference is already referenced by another matrix event.
+ IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error)
+ // Set the list of latest events for the room.
+ // This replaces the current list stored in the database with the given list
+ SetLatestEvents(
+ roomNID RoomNID, latest []StateAtEventAndReference, lastEventNIDSent EventNID,
+ currentStateSnapshotNID StateSnapshotNID,
+ ) error
+ // Check if the event has already be written to the output logs.
+ HasEventBeenSent(eventNID EventNID) (bool, error)
+ // Mark the event as having been sent to the output logs.
+ MarkEventAsSent(eventNID EventNID) error
+ // Build a membership updater for the target user in this room.
+ // It will share the same transaction as this updater.
+ MembershipUpdater(targetUserNID EventStateKeyNID) (MembershipUpdater, error)
+ // Implements Transaction so it can be committed or rolledback
+ common.Transaction
+}
+
+// A MembershipUpdater is used to update the membership of a user in a room.
+// (On postgresql this wraps a database transaction that holds a "FOR UPDATE"
+// lock on the row in the membership table for this user in the room)
+// The caller should call one of SetToInvite, SetToJoin or SetToLeave once to
+// make the update, or none of them if no update is required.
+type MembershipUpdater interface {
+ // True if the target user is invited to the room before updating.
+ IsInvite() bool
+ // True if the target user is joined to the room before updating.
+ IsJoin() bool
+ // True if the target user is not invited or joined to the room before updating.
+ IsLeave() bool
+ // Set the state to invite.
+ // Returns whether this invite needs to be sent
+ SetToInvite(event gomatrixserverlib.Event) (needsSending bool, err error)
+ // Set the state to join or updates the event ID in the database.
+ // Returns a list of invite event IDs that this state change retired.
+ SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error)
+ // Set the state to leave.
+ // Returns a list of invite event IDs that this state change retired.
+ SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error)
+ // Implements Transaction so it can be committed or rolledback.
+ common.Transaction
+}
+
+// A MissingEventError is an error that happened because the roomserver was
+// missing requested events from its database.
+type MissingEventError string
+
+func (e MissingEventError) Error() string { return string(e) }