diff options
author | ruben <code@rbn.im> | 2019-05-21 22:56:55 +0200 |
---|---|---|
committer | Brendan Abolivier <babolivier@matrix.org> | 2019-05-21 21:56:55 +0100 |
commit | 74827428bd3e11faab65f12204449c1b9469b0ae (patch) | |
tree | 0decafa542436a0667ed2d3e3cfd4df0f03de1e5 /roomserver | |
parent | 4d588f7008afe5600219ac0930c2eee2de5c447b (diff) |
use go module for dependencies (#594)
Diffstat (limited to 'roomserver')
35 files changed, 7786 insertions, 0 deletions
diff --git a/roomserver/README.md b/roomserver/README.md new file mode 100644 index 00000000..5a275760 --- /dev/null +++ b/roomserver/README.md @@ -0,0 +1,59 @@ +# RoomServer + + +## RoomServer Internals + +### Numeric IDs + +To save space matrix string identifiers are mapped to local numeric IDs. +The numeric IDs are more efficient to manipulate and use less space to store. +The numeric IDs are never exposed in the API the room server exposes. +The numeric IDs are converted to string IDs before they leave the room server. +The numeric ID for a string ID is never 0 to avoid being confused with go's +default zero value. +Zero is used to indicate that there was no corresponding string ID. +Well-known event types and event state keys are preassigned numeric IDs. + +### State Snapshot Storage + +The room server stores the state of the matrix room at each event. +For efficiency the state is stored as blocks of 3-tuples of numeric IDs for the +event type, event state key and event ID. For further efficiency the state +snapshots are stored as the combination of up to 64 these blocks. This allows +blocks of the room state to be reused in multiple snapshots. + +The resulting database tables look something like this: + + +-------------------------------------------------------------------+ + | Events | + +---------+-------------------+------------------+------------------+ + | EventNID| EventTypeNID | EventStateKeyNID | StateSnapshotNID | + +---------+-------------------+------------------+------------------+ + | 1 | m.room.create 1 | "" 1 | <nil> 0 | + | 2 | m.room.member 2 | "@user:foo" 2 | <nil> 0 | + | 3 | m.room.member 2 | "@user:bar" 3 | {1,2} 1 | + | 4 | m.room.message 3 | <nil> 0 | {1,2,3} 2 | + | 5 | m.room.member 2 | "@user:foo" 2 | {1,2,3} 2 | + | 6 | m.room.message 3 | <nil> 0 | {1,3,6} 3 | + +---------+-------------------+------------------+------------------+ + + +----------------------------------------+ + | State Snapshots | + +-----------------------+----------------+ + | EventStateSnapshotNID | StateBlockNIDs | + +-----------------------+----------------| + | 1 | {1} | + | 2 | {1,2} | + | 3 | {1,2,3} | + +-----------------------+----------------+ + + +-----------------------------------------------------------------+ + | State Blocks | + +---------------+-------------------+------------------+----------+ + | StateBlockNID | EventTypeNID | EventStateKeyNID | EventNID | + +---------------+-------------------+------------------+----------+ + | 1 | m.room.create 1 | "" 1 | 1 | + | 1 | m.room.member 2 | "@user:foo" 2 | 2 | + | 2 | m.room.member 2 | "@user:bar" 3 | 3 | + | 3 | m.room.member 2 | "@user:foo" 2 | 6 | + +---------------+-------------------+------------------+----------+ diff --git a/roomserver/alias/alias.go b/roomserver/alias/alias.go new file mode 100644 index 00000000..27279aad --- /dev/null +++ b/roomserver/alias/alias.go @@ -0,0 +1,285 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package alias + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/common/config" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// RoomserverAliasAPIDatabase has the storage APIs needed to implement the alias API. +type RoomserverAliasAPIDatabase interface { + // Save a given room alias with the room ID it refers to. + // Returns an error if there was a problem talking to the database. + SetRoomAlias(ctx context.Context, alias string, roomID string) error + // Look up the room ID a given alias refers to. + // Returns an error if there was a problem talking to the database. + GetRoomIDForAlias(ctx context.Context, alias string) (string, error) + // Look up all aliases referring to a given room ID. + // Returns an error if there was a problem talking to the database. + GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) + // Remove a given room alias. + // Returns an error if there was a problem talking to the database. + RemoveRoomAlias(ctx context.Context, alias string) error +} + +// RoomserverAliasAPI is an implementation of alias.RoomserverAliasAPI +type RoomserverAliasAPI struct { + DB RoomserverAliasAPIDatabase + Cfg *config.Dendrite + InputAPI roomserverAPI.RoomserverInputAPI + QueryAPI roomserverAPI.RoomserverQueryAPI + AppserviceAPI appserviceAPI.AppServiceQueryAPI +} + +// SetRoomAlias implements alias.RoomserverAliasAPI +func (r *RoomserverAliasAPI) SetRoomAlias( + ctx context.Context, + request *roomserverAPI.SetRoomAliasRequest, + response *roomserverAPI.SetRoomAliasResponse, +) error { + // Check if the alias isn't already referring to a room + roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) + if err != nil { + return err + } + if len(roomID) > 0 { + // If the alias already exists, stop the process + response.AliasExists = true + return nil + } + response.AliasExists = false + + // Save the new alias + if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID); err != nil { + return err + } + + // Send a m.room.aliases event with the updated list of aliases for this room + // At this point we've already committed the alias to the database so we + // shouldn't cancel this request. + // TODO: Ensure that we send unsent events when if server restarts. + return r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, request.RoomID) +} + +// GetRoomIDForAlias implements alias.RoomserverAliasAPI +func (r *RoomserverAliasAPI) GetRoomIDForAlias( + ctx context.Context, + request *roomserverAPI.GetRoomIDForAliasRequest, + response *roomserverAPI.GetRoomIDForAliasResponse, +) error { + // Look up the room ID in the database + roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) + if err != nil { + return err + } + + // No rooms found locally, try our application services by making a call to + // the appservice component + aliasReq := appserviceAPI.RoomAliasExistsRequest{Alias: request.Alias} + var aliasResp appserviceAPI.RoomAliasExistsResponse + if err = r.AppserviceAPI.RoomAliasExists(ctx, &aliasReq, &aliasResp); err != nil { + return err + } + + response.RoomID = roomID + return nil +} + +// GetAliasesForRoomID implements alias.RoomserverAliasAPI +func (r *RoomserverAliasAPI) GetAliasesForRoomID( + ctx context.Context, + request *roomserverAPI.GetAliasesForRoomIDRequest, + response *roomserverAPI.GetAliasesForRoomIDResponse, +) error { + // Look up the aliases in the database for the given RoomID + aliases, err := r.DB.GetAliasesForRoomID(ctx, request.RoomID) + if err != nil { + return err + } + + response.Aliases = aliases + return nil +} + +// RemoveRoomAlias implements alias.RoomserverAliasAPI +func (r *RoomserverAliasAPI) RemoveRoomAlias( + ctx context.Context, + request *roomserverAPI.RemoveRoomAliasRequest, + response *roomserverAPI.RemoveRoomAliasResponse, +) error { + // Look up the room ID in the database + roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) + if err != nil { + return err + } + + // Remove the dalias from the database + if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil { + return err + } + + // Send an updated m.room.aliases event + // At this point we've already committed the alias to the database so we + // shouldn't cancel this request. + // TODO: Ensure that we send unsent events when if server restarts. + return r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, roomID) +} + +type roomAliasesContent struct { + Aliases []string `json:"aliases"` +} + +// Build the updated m.room.aliases event to send to the room after addition or +// removal of an alias +func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent( + ctx context.Context, userID string, roomID string, +) error { + serverName := string(r.Cfg.Matrix.ServerName) + + builder := gomatrixserverlib.EventBuilder{ + Sender: userID, + RoomID: roomID, + Type: "m.room.aliases", + StateKey: &serverName, + } + + // Retrieve the updated list of aliases, marhal it and set it as the + // event's content + aliases, err := r.DB.GetAliasesForRoomID(ctx, roomID) + if err != nil { + return err + } + content := roomAliasesContent{Aliases: aliases} + rawContent, err := json.Marshal(content) + if err != nil { + return err + } + err = builder.SetContent(json.RawMessage(rawContent)) + if err != nil { + return err + } + + // Get needed state events and depth + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(&builder) + if err != nil { + return err + } + req := roomserverAPI.QueryLatestEventsAndStateRequest{ + RoomID: roomID, + StateToFetch: eventsNeeded.Tuples(), + } + var res roomserverAPI.QueryLatestEventsAndStateResponse + if err = r.QueryAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil { + return err + } + builder.Depth = res.Depth + builder.PrevEvents = res.LatestEvents + + // Add auth events + authEvents := gomatrixserverlib.NewAuthEvents(nil) + for i := range res.StateEvents { + err = authEvents.AddEvent(&res.StateEvents[i]) + if err != nil { + return err + } + } + refs, err := eventsNeeded.AuthEventReferences(&authEvents) + if err != nil { + return err + } + builder.AuthEvents = refs + + // Build the event + eventID := fmt.Sprintf("$%s:%s", util.RandomString(16), r.Cfg.Matrix.ServerName) + now := time.Now() + event, err := builder.Build( + eventID, now, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID, r.Cfg.Matrix.PrivateKey, + ) + if err != nil { + return err + } + + // Create the request + ire := roomserverAPI.InputRoomEvent{ + Kind: roomserverAPI.KindNew, + Event: event, + AuthEventIDs: event.AuthEventIDs(), + SendAsServer: serverName, + } + inputReq := roomserverAPI.InputRoomEventsRequest{ + InputRoomEvents: []roomserverAPI.InputRoomEvent{ire}, + } + var inputRes roomserverAPI.InputRoomEventsResponse + + // Send the request + return r.InputAPI.InputRoomEvents(ctx, &inputReq, &inputRes) +} + +// SetupHTTP adds the RoomserverAliasAPI handlers to the http.ServeMux. +func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) { + servMux.Handle( + roomserverAPI.RoomserverSetRoomAliasPath, + common.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse { + var request roomserverAPI.SetRoomAliasRequest + var response roomserverAPI.SetRoomAliasResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + roomserverAPI.RoomserverGetRoomIDForAliasPath, + common.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse { + var request roomserverAPI.GetRoomIDForAliasRequest + var response roomserverAPI.GetRoomIDForAliasResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + roomserverAPI.RoomserverRemoveRoomAliasPath, + common.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { + var request roomserverAPI.RemoveRoomAliasRequest + var response roomserverAPI.RemoveRoomAliasResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go new file mode 100644 index 00000000..57671071 --- /dev/null +++ b/roomserver/api/alias.go @@ -0,0 +1,183 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "net/http" + + commonHTTP "github.com/matrix-org/dendrite/common/http" + opentracing "github.com/opentracing/opentracing-go" +) + +// SetRoomAliasRequest is a request to SetRoomAlias +type SetRoomAliasRequest struct { + // ID of the user setting the alias + UserID string `json:"user_id"` + // New alias for the room + Alias string `json:"alias"` + // The room ID the alias is referring to + RoomID string `json:"room_id"` +} + +// SetRoomAliasResponse is a response to SetRoomAlias +type SetRoomAliasResponse struct { + // Does the alias already refer to a room? + AliasExists bool `json:"alias_exists"` +} + +// GetRoomIDForAliasRequest is a request to GetRoomIDForAlias +type GetRoomIDForAliasRequest struct { + // Alias we want to lookup + Alias string `json:"alias"` +} + +// GetRoomIDForAliasResponse is a response to GetRoomIDForAlias +type GetRoomIDForAliasResponse struct { + // The room ID the alias refers to + RoomID string `json:"room_id"` +} + +// GetAliasesForRoomIDRequest is a request to GetAliasesForRoomID +type GetAliasesForRoomIDRequest struct { + // The room ID we want to find aliases for + RoomID string `json:"room_id"` +} + +// GetAliasesForRoomIDResponse is a response to GetAliasesForRoomID +type GetAliasesForRoomIDResponse struct { + // The aliases the alias refers to + Aliases []string `json:"aliases"` +} + +// RemoveRoomAliasRequest is a request to RemoveRoomAlias +type RemoveRoomAliasRequest struct { + // ID of the user removing the alias + UserID string `json:"user_id"` + // The room alias to remove + Alias string `json:"alias"` +} + +// RemoveRoomAliasResponse is a response to RemoveRoomAlias +type RemoveRoomAliasResponse struct{} + +// RoomserverAliasAPI is used to save, lookup or remove a room alias +type RoomserverAliasAPI interface { + // Set a room alias + SetRoomAlias( + ctx context.Context, + req *SetRoomAliasRequest, + response *SetRoomAliasResponse, + ) error + + // Get the room ID for an alias + GetRoomIDForAlias( + ctx context.Context, + req *GetRoomIDForAliasRequest, + response *GetRoomIDForAliasResponse, + ) error + + // Get all known aliases for a room ID + GetAliasesForRoomID( + ctx context.Context, + req *GetAliasesForRoomIDRequest, + response *GetAliasesForRoomIDResponse, + ) error + + // Remove a room alias + RemoveRoomAlias( + ctx context.Context, + req *RemoveRoomAliasRequest, + response *RemoveRoomAliasResponse, + ) error +} + +// RoomserverSetRoomAliasPath is the HTTP path for the SetRoomAlias API. +const RoomserverSetRoomAliasPath = "/api/roomserver/setRoomAlias" + +// RoomserverGetRoomIDForAliasPath is the HTTP path for the GetRoomIDForAlias API. +const RoomserverGetRoomIDForAliasPath = "/api/roomserver/GetRoomIDForAlias" + +// RoomserverGetAliasesForRoomIDPath is the HTTP path for the GetAliasesForRoomID API. +const RoomserverGetAliasesForRoomIDPath = "/api/roomserver/GetAliasesForRoomID" + +// RoomserverRemoveRoomAliasPath is the HTTP path for the RemoveRoomAlias API. +const RoomserverRemoveRoomAliasPath = "/api/roomserver/removeRoomAlias" + +// NewRoomserverAliasAPIHTTP creates a RoomserverAliasAPI implemented by talking to a HTTP POST API. +// If httpClient is nil then it uses the http.DefaultClient +func NewRoomserverAliasAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverAliasAPI { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &httpRoomserverAliasAPI{roomserverURL, httpClient} +} + +type httpRoomserverAliasAPI struct { + roomserverURL string + httpClient *http.Client +} + +// SetRoomAlias implements RoomserverAliasAPI +func (h *httpRoomserverAliasAPI) SetRoomAlias( + ctx context.Context, + request *SetRoomAliasRequest, + response *SetRoomAliasResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverSetRoomAliasPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// GetRoomIDForAlias implements RoomserverAliasAPI +func (h *httpRoomserverAliasAPI) GetRoomIDForAlias( + ctx context.Context, + request *GetRoomIDForAliasRequest, + response *GetRoomIDForAliasResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// GetAliasesForRoomID implements RoomserverAliasAPI +func (h *httpRoomserverAliasAPI) GetAliasesForRoomID( + ctx context.Context, + request *GetAliasesForRoomIDRequest, + response *GetAliasesForRoomIDResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// RemoveRoomAlias implements RoomserverAliasAPI +func (h *httpRoomserverAliasAPI) RemoveRoomAlias( + ctx context.Context, + request *RemoveRoomAliasRequest, + response *RemoveRoomAliasResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/roomserver/api/input.go b/roomserver/api/input.go new file mode 100644 index 00000000..2c2e27c6 --- /dev/null +++ b/roomserver/api/input.go @@ -0,0 +1,139 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package api provides the types that are used to communicate with the roomserver. +package api + +import ( + "context" + "net/http" + + commonHTTP "github.com/matrix-org/dendrite/common/http" + "github.com/matrix-org/gomatrixserverlib" + opentracing "github.com/opentracing/opentracing-go" +) + +const ( + // KindOutlier event fall outside the contiguous event graph. + // We do not have the state for these events. + // These events are state events used to authenticate other events. + // They can become part of the contiguous event graph via backfill. + KindOutlier = 1 + // KindNew event extend the contiguous graph going forwards. + // They usually don't need state, but may include state if the + // there was a new event that references an event that we don't + // have a copy of. + KindNew = 2 + // KindBackfill event extend the contiguous graph going backwards. + // They always have state. + KindBackfill = 3 +) + +// DoNotSendToOtherServers tells us not to send the event to other matrix +// servers. +const DoNotSendToOtherServers = "" + +// InputRoomEvent is a matrix room event to add to the room server database. +// TODO: Implement UnmarshalJSON/MarshalJSON in a way that does something sensible with the event JSON. +type InputRoomEvent struct { + // Whether this event is new, backfilled or an outlier. + // This controls how the event is processed. + Kind int `json:"kind"` + // The event JSON for the event to add. + Event gomatrixserverlib.Event `json:"event"` + // List of state event IDs that authenticate this event. + // These are likely derived from the "auth_events" JSON key of the event. + // But can be different because the "auth_events" key can be incomplete or wrong. + // For example many matrix events forget to reference the m.room.create event even though it is needed for auth. + // (since synapse allows this to happen we have to allow it as well.) + AuthEventIDs []string `json:"auth_event_ids"` + // Whether the state is supplied as a list of event IDs or whether it + // should be derived from the state at the previous events. + HasState bool `json:"has_state"` + // Optional list of state event IDs forming the state before this event. + // These state events must have already been persisted. + // These are only used if HasState is true. + // The list can be empty, for example when storing the first event in a room. + StateEventIDs []string `json:"state_event_ids"` + // The server name to use to push this event to other servers. + // Or empty if this event shouldn't be pushed to other servers. + SendAsServer string `json:"send_as_server"` + // The transaction ID of the send request if sent by a local user and one + // was specified + TransactionID *TransactionID `json:"transaction_id"` +} + +// TransactionID contains the transaction ID sent by a client when sending an +// event, along with the ID of that device. +type TransactionID struct { + DeviceID string `json:"device_id"` + TransactionID string `json:"id"` +} + +// InputInviteEvent is a matrix invite event received over federation without +// the usual context a matrix room event would have. We usually do not have +// access to the events needed to check the event auth rules for the invite. +type InputInviteEvent struct { + Event gomatrixserverlib.Event `json:"event"` +} + +// InputRoomEventsRequest is a request to InputRoomEvents +type InputRoomEventsRequest struct { + InputRoomEvents []InputRoomEvent `json:"input_room_events"` + InputInviteEvents []InputInviteEvent `json:"input_invite_events"` +} + +// InputRoomEventsResponse is a response to InputRoomEvents +type InputRoomEventsResponse struct { + EventID string `json:"event_id"` +} + +// RoomserverInputAPI is used to write events to the room server. +type RoomserverInputAPI interface { + InputRoomEvents( + ctx context.Context, + request *InputRoomEventsRequest, + response *InputRoomEventsResponse, + ) error +} + +// RoomserverInputRoomEventsPath is the HTTP path for the InputRoomEvents API. +const RoomserverInputRoomEventsPath = "/api/roomserver/inputRoomEvents" + +// NewRoomserverInputAPIHTTP creates a RoomserverInputAPI implemented by talking to a HTTP POST API. +// If httpClient is nil then it uses the http.DefaultClient +func NewRoomserverInputAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverInputAPI { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &httpRoomserverInputAPI{roomserverURL, httpClient} +} + +type httpRoomserverInputAPI struct { + roomserverURL string + httpClient *http.Client +} + +// InputRoomEvents implements RoomserverInputAPI +func (h *httpRoomserverInputAPI) InputRoomEvents( + ctx context.Context, + request *InputRoomEventsRequest, + response *InputRoomEventsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverInputRoomEventsPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/roomserver/api/output.go b/roomserver/api/output.go new file mode 100644 index 00000000..c09d5a1e --- /dev/null +++ b/roomserver/api/output.go @@ -0,0 +1,138 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "github.com/matrix-org/gomatrixserverlib" +) + +// An OutputType is a type of roomserver output. +type OutputType string + +const ( + // OutputTypeNewRoomEvent indicates that the event is an OutputNewRoomEvent + OutputTypeNewRoomEvent OutputType = "new_room_event" + // OutputTypeNewInviteEvent indicates that the event is an OutputNewInviteEvent + OutputTypeNewInviteEvent OutputType = "new_invite_event" + // OutputTypeRetireInviteEvent indicates that the event is an OutputRetireInviteEvent + OutputTypeRetireInviteEvent OutputType = "retire_invite_event" +) + +// An OutputEvent is an entry in the roomserver output kafka log. +// Consumers should check the type field when consuming this event. +type OutputEvent struct { + // What sort of event this is. + Type OutputType `json:"type"` + // The content of event with type OutputTypeNewRoomEvent + NewRoomEvent *OutputNewRoomEvent `json:"new_room_event,omitempty"` + // The content of event with type OutputTypeNewInviteEvent + NewInviteEvent *OutputNewInviteEvent `json:"new_invite_event,omitempty"` + // The content of event with type OutputTypeRetireInviteEvent + RetireInviteEvent *OutputRetireInviteEvent `json:"retire_invite_event,omitempty"` +} + +// An OutputNewRoomEvent is written when the roomserver receives a new event. +// It contains the full matrix room event and enough information for a +// consumer to construct the current state of the room and the state before the +// event. +// +// When we talk about state in a matrix room we are talking about the state +// after a list of events. The current state is the state after the latest +// event IDs in the room. The state before an event is the state after its +// prev_events. +type OutputNewRoomEvent struct { + // The Event. + Event gomatrixserverlib.Event `json:"event"` + // The latest events in the room after this event. + // This can be used to set the prev events for new events in the room. + // This also can be used to get the full current state after this event. + LatestEventIDs []string `json:"latest_event_ids"` + // The state event IDs that were added to the state of the room by this event. + // Together with RemovesStateEventIDs this allows the receiver to keep an up to date + // view of the current state of the room. + AddsStateEventIDs []string `json:"adds_state_event_ids"` + // The state event IDs that were removed from the state of the room by this event. + RemovesStateEventIDs []string `json:"removes_state_event_ids"` + // The ID of the event that was output before this event. + // Or the empty string if this is the first event output for this room. + // This is used by consumers to check if they can safely update their + // current state using the delta supplied in AddsStateEventIDs and + // RemovesStateEventIDs. + // + // If the LastSentEventID doesn't match what they were expecting it to be + // they can use the LatestEventIDs to request the full current state. + LastSentEventID string `json:"last_sent_event_id"` + // The state event IDs that are part of the state at the event, but not + // part of the current state. Together with the StateBeforeRemovesEventIDs + // this can be used to construct the state before the event from the + // current state. The StateBeforeAddsEventIDs and StateBeforeRemovesEventIDs + // delta is applied after the AddsStateEventIDs and RemovesStateEventIDs. + // + // Consumers need to know the state at each event in order to determine + // which users and servers are allowed to see the event. This information + // is needed to apply the history visibility rules and to tell which + // servers we need to push events to over federation. + // + // The state is given as a delta against the current state because they are + // usually either the same state, or differ by just a couple of events. + StateBeforeAddsEventIDs []string `json:"state_before_adds_event_ids"` + // The state event IDs that are part of the current state, but not part + // of the state at the event. + StateBeforeRemovesEventIDs []string `json:"state_before_removes_event_ids"` + // The server name to use to push this event to other servers. + // Or empty if this event shouldn't be pushed to other servers. + // + // This is used by the federation sender component. We need to tell it what + // event it needs to send because it can't tell on its own. Normally if an + // event was created on this server then we are responsible for sending it. + // However there are a couple of exceptions. The first is that when the + // server joins a remote room through another matrix server, it is the job + // of the other matrix server to send the event over federation. The second + // is the reverse of the first, that is when a remote server joins a room + // that we are in over federation using our server it is our responsibility + // to send the join event to other matrix servers. + // + // We encode the server name that the event should be sent using here to + // future proof the API for virtual hosting. + SendAsServer string `json:"send_as_server"` + // The transaction ID of the send request if sent by a local user and one + // was specified + TransactionID *TransactionID `json:"transaction_id"` +} + +// An OutputNewInviteEvent is written whenever an invite becomes active. +// Invite events can be received outside of an existing room so have to be +// tracked separately from the room events themselves. +type OutputNewInviteEvent struct { + // The "m.room.member" invite event. + Event gomatrixserverlib.Event `json:"event"` +} + +// An OutputRetireInviteEvent is written whenever an existing invite is no longer +// active. An invite stops being active if the user joins the room or if the +// invite is rejected by the user. +type OutputRetireInviteEvent struct { + // The ID of the "m.room.member" invite event. + EventID string + // The target user ID of the "m.room.member" invite event that was retired. + TargetUserID string + // Optional event ID of the event that replaced the invite. + // This can be empty if the invite was rejected locally and we were unable + // to reach the server that originally sent the invite. + RetiredByEventID string + // The "membership" of the user after retiring the invite. One of "join" + // "leave" or "ban". + Membership string +} diff --git a/roomserver/api/query.go b/roomserver/api/query.go new file mode 100644 index 00000000..a544f8aa --- /dev/null +++ b/roomserver/api/query.go @@ -0,0 +1,480 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "net/http" + + commonHTTP "github.com/matrix-org/dendrite/common/http" + "github.com/matrix-org/gomatrixserverlib" + opentracing "github.com/opentracing/opentracing-go" +) + +// QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState +type QueryLatestEventsAndStateRequest struct { + // The room ID to query the latest events for. + RoomID string `json:"room_id"` + // The state key tuples to fetch from the room current state. + // If this list is empty or nil then no state events are returned. + StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"` +} + +// QueryLatestEventsAndStateResponse is a response to QueryLatestEventsAndState +// This is used when sending events to set the prev_events, auth_events and depth. +// It is also used to tell whether the event is allowed by the event auth rules. +type QueryLatestEventsAndStateResponse struct { + // Copy of the request for debugging. + QueryLatestEventsAndStateRequest + // Does the room exist? + // If the room doesn't exist this will be false and LatestEvents will be empty. + RoomExists bool `json:"room_exists"` + // The latest events in the room. + // These are used to set the prev_events when sending an event. + LatestEvents []gomatrixserverlib.EventReference `json:"latest_events"` + // The state events requested. + // This list will be in an arbitrary order. + // These are used to set the auth_events when sending an event. + // These are used to check whether the event is allowed. + StateEvents []gomatrixserverlib.Event `json:"state_events"` + // The depth of the latest events. + // This is one greater than the maximum depth of the latest events. + // This is used to set the depth when sending an event. + Depth int64 `json:"depth"` +} + +// QueryStateAfterEventsRequest is a request to QueryStateAfterEvents +type QueryStateAfterEventsRequest struct { + // The room ID to query the state in. + RoomID string `json:"room_id"` + // The list of previous events to return the events after. + PrevEventIDs []string `json:"prev_event_ids"` + // The state key tuples to fetch from the state + StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"` +} + +// QueryStateAfterEventsResponse is a response to QueryStateAfterEvents +type QueryStateAfterEventsResponse struct { + // Copy of the request for debugging. + QueryStateAfterEventsRequest + // Does the room exist on this roomserver? + // If the room doesn't exist this will be false and StateEvents will be empty. + RoomExists bool `json:"room_exists"` + // Do all the previous events exist on this roomserver? + // If some of previous events do not exist this will be false and StateEvents will be empty. + PrevEventsExist bool `json:"prev_events_exist"` + // The state events requested. + // This list will be in an arbitrary order. + StateEvents []gomatrixserverlib.Event `json:"state_events"` +} + +// QueryEventsByIDRequest is a request to QueryEventsByID +type QueryEventsByIDRequest struct { + // The event IDs to look up. + EventIDs []string `json:"event_ids"` +} + +// QueryEventsByIDResponse is a response to QueryEventsByID +type QueryEventsByIDResponse struct { + // Copy of the request for debugging. + QueryEventsByIDRequest + // A list of events with the requested IDs. + // If the roomserver does not have a copy of a requested event + // then it will omit that event from the list. + // If the roomserver thinks it has a copy of the event, but + // fails to read it from the database then it will fail + // the entire request. + // This list will be in an arbitrary order. + Events []gomatrixserverlib.Event `json:"events"` +} + +// QueryMembershipForUserRequest is a request to QueryMembership +type QueryMembershipForUserRequest struct { + // ID of the room to fetch membership from + RoomID string `json:"room_id"` + // ID of the user for whom membership is requested + UserID string `json:"user_id"` +} + +// QueryMembershipForUserResponse is a response to QueryMembership +type QueryMembershipForUserResponse struct { + // The EventID of the latest "m.room.member" event for the sender, + // if HasBeenInRoom is true. + EventID string `json:"event_id"` + // True if the user has been in room before and has either stayed in it or left it. + HasBeenInRoom bool `json:"has_been_in_room"` + // True if the user is in room. + IsInRoom bool `json:"is_in_room"` +} + +// QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom +type QueryMembershipsForRoomRequest struct { + // If true, only returns the membership events of "join" membership + JoinedOnly bool `json:"joined_only"` + // ID of the room to fetch memberships from + RoomID string `json:"room_id"` + // ID of the user sending the request + Sender string `json:"sender"` +} + +// QueryMembershipsForRoomResponse is a response to QueryMembershipsForRoom +type QueryMembershipsForRoomResponse struct { + // The "m.room.member" events (of "join" membership) in the client format + JoinEvents []gomatrixserverlib.ClientEvent `json:"join_events"` + // True if the user has been in room before and has either stayed in it or + // left it. + HasBeenInRoom bool `json:"has_been_in_room"` +} + +// QueryInvitesForUserRequest is a request to QueryInvitesForUser +type QueryInvitesForUserRequest struct { + // The room ID to look up invites in. + RoomID string `json:"room_id"` + // The User ID to look up invites for. + TargetUserID string `json:"target_user_id"` +} + +// QueryInvitesForUserResponse is a response to QueryInvitesForUser +// This is used when accepting an invite or rejecting a invite to tell which +// remote matrix servers to contact. +type QueryInvitesForUserResponse struct { + // A list of matrix user IDs for each sender of an active invite targeting + // the requested user ID. + InviteSenderUserIDs []string `json:"invite_sender_user_ids"` +} + +// QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent +type QueryServerAllowedToSeeEventRequest struct { + // The event ID to look up invites in. + EventID string `json:"event_id"` + // The server interested in the event + ServerName gomatrixserverlib.ServerName `json:"server_name"` +} + +// QueryServerAllowedToSeeEventResponse is a response to QueryServerAllowedToSeeEvent +type QueryServerAllowedToSeeEventResponse struct { + // Wether the server in question is allowed to see the event + AllowedToSeeEvent bool `json:"can_see_event"` +} + +// QueryMissingEventsRequest is a request to QueryMissingEvents +type QueryMissingEventsRequest struct { + // Events which are known previous to the gap in the timeline. + EarliestEvents []string `json:"earliest_events"` + // Latest known events. + LatestEvents []string `json:"latest_events"` + // Limit the number of events this query returns. + Limit int `json:"limit"` + // The server interested in the event + ServerName gomatrixserverlib.ServerName `json:"server_name"` +} + +// QueryMissingEventsResponse is a response to QueryMissingEvents +type QueryMissingEventsResponse struct { + // Missing events, arbritrary order. + Events []gomatrixserverlib.Event `json:"events"` +} + +// QueryStateAndAuthChainRequest is a request to QueryStateAndAuthChain +type QueryStateAndAuthChainRequest struct { + // The room ID to query the state in. + RoomID string `json:"room_id"` + // The list of prev events for the event. Used to calculate the state at + // the event + PrevEventIDs []string `json:"prev_event_ids"` + // The list of auth events for the event. Used to calculate the auth chain + AuthEventIDs []string `json:"auth_event_ids"` +} + +// QueryStateAndAuthChainResponse is a response to QueryStateAndAuthChain +type QueryStateAndAuthChainResponse struct { + // Copy of the request for debugging. + QueryStateAndAuthChainRequest + // Does the room exist on this roomserver? + // If the room doesn't exist this will be false and StateEvents will be empty. + RoomExists bool `json:"room_exists"` + // Do all the previous events exist on this roomserver? + // If some of previous events do not exist this will be false and StateEvents will be empty. + PrevEventsExist bool `json:"prev_events_exist"` + // The state and auth chain events that were requested. + // The lists will be in an arbitrary order. + StateEvents []gomatrixserverlib.Event `json:"state_events"` + AuthChainEvents []gomatrixserverlib.Event `json:"auth_chain_events"` +} + +// QueryBackfillRequest is a request to QueryBackfill. +type QueryBackfillRequest struct { + // Events to start paginating from. + EarliestEventsIDs []string `json:"earliest_event_ids"` + // The maximum number of events to retrieve. + Limit int `json:"limit"` + // The server interested in the events. + ServerName gomatrixserverlib.ServerName `json:"server_name"` +} + +// QueryBackfillResponse is a response to QueryBackfill. +type QueryBackfillResponse struct { + // Missing events, arbritrary order. + Events []gomatrixserverlib.Event `json:"events"` +} + +// RoomserverQueryAPI is used to query information from the room server. +type RoomserverQueryAPI interface { + // Query the latest events and state for a room from the room server. + QueryLatestEventsAndState( + ctx context.Context, + request *QueryLatestEventsAndStateRequest, + response *QueryLatestEventsAndStateResponse, + ) error + + // Query the state after a list of events in a room from the room server. + QueryStateAfterEvents( + ctx context.Context, + request *QueryStateAfterEventsRequest, + response *QueryStateAfterEventsResponse, + ) error + + // Query a list of events by event ID. + QueryEventsByID( + ctx context.Context, + request *QueryEventsByIDRequest, + response *QueryEventsByIDResponse, + ) error + + // Query the membership event for an user for a room. + QueryMembershipForUser( + ctx context.Context, + request *QueryMembershipForUserRequest, + response *QueryMembershipForUserResponse, + ) error + + // Query a list of membership events for a room + QueryMembershipsForRoom( + ctx context.Context, + request *QueryMembershipsForRoomRequest, + response *QueryMembershipsForRoomResponse, + ) error + + // Query a list of invite event senders for a user in a room. + QueryInvitesForUser( + ctx context.Context, + request *QueryInvitesForUserRequest, + response *QueryInvitesForUserResponse, + ) error + + // Query whether a server is allowed to see an event + QueryServerAllowedToSeeEvent( + ctx context.Context, + request *QueryServerAllowedToSeeEventRequest, + response *QueryServerAllowedToSeeEventResponse, + ) error + + // Query missing events for a room from roomserver + QueryMissingEvents( + ctx context.Context, + request *QueryMissingEventsRequest, + response *QueryMissingEventsResponse, + ) error + + // Query to get state and auth chain for a (potentially hypothetical) event. + // Takes lists of PrevEventIDs and AuthEventsIDs and uses them to calculate + // the state and auth chain to return. + QueryStateAndAuthChain( + ctx context.Context, + request *QueryStateAndAuthChainRequest, + response *QueryStateAndAuthChainResponse, + ) error + + // Query a given amount (or less) of events prior to a given set of events. + QueryBackfill( + ctx context.Context, + request *QueryBackfillRequest, + response *QueryBackfillResponse, + ) error +} + +// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. +const RoomserverQueryLatestEventsAndStatePath = "/api/roomserver/queryLatestEventsAndState" + +// RoomserverQueryStateAfterEventsPath is the HTTP path for the QueryStateAfterEvents API. +const RoomserverQueryStateAfterEventsPath = "/api/roomserver/queryStateAfterEvents" + +// RoomserverQueryEventsByIDPath is the HTTP path for the QueryEventsByID API. +const RoomserverQueryEventsByIDPath = "/api/roomserver/queryEventsByID" + +// RoomserverQueryMembershipForUserPath is the HTTP path for the QueryMembershipForUser API. +const RoomserverQueryMembershipForUserPath = "/api/roomserver/queryMembershipForUser" + +// RoomserverQueryMembershipsForRoomPath is the HTTP path for the QueryMembershipsForRoom API +const RoomserverQueryMembershipsForRoomPath = "/api/roomserver/queryMembershipsForRoom" + +// RoomserverQueryInvitesForUserPath is the HTTP path for the QueryInvitesForUser API +const RoomserverQueryInvitesForUserPath = "/api/roomserver/queryInvitesForUser" + +// RoomserverQueryServerAllowedToSeeEventPath is the HTTP path for the QueryServerAllowedToSeeEvent API +const RoomserverQueryServerAllowedToSeeEventPath = "/api/roomserver/queryServerAllowedToSeeEvent" + +// RoomserverQueryMissingEventsPath is the HTTP path for the QueryMissingEvents API +const RoomserverQueryMissingEventsPath = "/api/roomserver/queryMissingEvents" + +// RoomserverQueryStateAndAuthChainPath is the HTTP path for the QueryStateAndAuthChain API +const RoomserverQueryStateAndAuthChainPath = "/api/roomserver/queryStateAndAuthChain" + +// RoomserverQueryBackfillPath is the HTTP path for the QueryMissingEvents API +const RoomserverQueryBackfillPath = "/api/roomserver/QueryBackfill" + +// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. +// If httpClient is nil then it uses the http.DefaultClient +func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &httpRoomserverQueryAPI{roomserverURL, httpClient} +} + +type httpRoomserverQueryAPI struct { + roomserverURL string + httpClient *http.Client +} + +// QueryLatestEventsAndState implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryLatestEventsAndState( + ctx context.Context, + request *QueryLatestEventsAndStateRequest, + response *QueryLatestEventsAndStateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryStateAfterEvents implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryStateAfterEvents( + ctx context.Context, + request *QueryStateAfterEventsRequest, + response *QueryStateAfterEventsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryEventsByID implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryEventsByID( + ctx context.Context, + request *QueryEventsByIDRequest, + response *QueryEventsByIDResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryMembershipForUser implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryMembershipForUser( + ctx context.Context, + request *QueryMembershipForUserRequest, + response *QueryMembershipForUserResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryMembershipsForRoom implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryMembershipsForRoom( + ctx context.Context, + request *QueryMembershipsForRoomRequest, + response *QueryMembershipsForRoomResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryInvitesForUser implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryInvitesForUser( + ctx context.Context, + request *QueryInvitesForUserRequest, + response *QueryInvitesForUserResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryInvitesForUser") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryInvitesForUserPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryServerAllowedToSeeEvent( + ctx context.Context, + request *QueryServerAllowedToSeeEventRequest, + response *QueryServerAllowedToSeeEventResponse, +) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryMissingEvents implements RoomServerQueryAPI +func (h *httpRoomserverQueryAPI) QueryMissingEvents( + ctx context.Context, + request *QueryMissingEventsRequest, + response *QueryMissingEventsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryStateAndAuthChain implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryStateAndAuthChain( + ctx context.Context, + request *QueryStateAndAuthChainRequest, + response *QueryStateAndAuthChainResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryBackfill implements RoomServerQueryAPI +func (h *httpRoomserverQueryAPI) QueryBackfill( + ctx context.Context, + request *QueryBackfillRequest, + response *QueryBackfillResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBackfill") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go new file mode 100644 index 00000000..2dce6f6d --- /dev/null +++ b/roomserver/auth/auth.go @@ -0,0 +1,47 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import "github.com/matrix-org/gomatrixserverlib" + +// IsServerAllowed returns true if there exists a event in authEvents +// which allows server to view this event. That is true when a client on the server +// can view the event. Otherwise returns false. +func IsServerAllowed( + serverName gomatrixserverlib.ServerName, + authEvents []gomatrixserverlib.Event, +) bool { + for _, ev := range authEvents { + membership, err := ev.Membership() + if err != nil || membership != "join" { + continue + } + + stateKey := ev.StateKey() + if stateKey == nil { + continue + } + + _, domain, err := gomatrixserverlib.SplitID('@', *stateKey) + if err != nil { + continue + } + + if domain == serverName { + return true + } + } + + // TODO: Check if history visibility is shared and if the server is currently in the room + return false +} diff --git a/roomserver/input/authevents.go b/roomserver/input/authevents.go new file mode 100644 index 00000000..74be2ed3 --- /dev/null +++ b/roomserver/input/authevents.go @@ -0,0 +1,243 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package input + +import ( + "context" + "sort" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// checkAuthEvents checks that the event passes authentication checks +// Returns the numeric IDs for the auth events. +func checkAuthEvents( + ctx context.Context, + db RoomEventDatabase, + event gomatrixserverlib.Event, + authEventIDs []string, +) ([]types.EventNID, error) { + // Grab the numeric IDs for the supplied auth state events from the database. + authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs) + if err != nil { + return nil, err + } + // TODO: check for duplicate state keys here. + + // Work out which of the state events we actually need. + stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event}) + + // Load the actual auth events from the database. + authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) + if err != nil { + return nil, err + } + + // Check if the event is allowed. + if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + return nil, err + } + + // Return the numeric IDs for the auth events. + result := make([]types.EventNID, len(authStateEntries)) + for i := range authStateEntries { + result[i] = authStateEntries[i].EventNID + } + return result, nil +} + +type authEvents struct { + stateKeyNIDMap map[string]types.EventStateKeyNID + state stateEntryMap + events eventMap +} + +// Create implements gomatrixserverlib.AuthEventProvider +func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) { + return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil +} + +// PowerLevels implements gomatrixserverlib.AuthEventProvider +func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) { + return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil +} + +// JoinRules implements gomatrixserverlib.AuthEventProvider +func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) { + return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil +} + +// Memmber implements gomatrixserverlib.AuthEventProvider +func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) { + return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil +} + +// ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider +func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) { + return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil +} + +func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event { + eventNID, ok := ae.state.lookup(types.StateKeyTuple{ + EventTypeNID: typeNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + if !ok { + return nil + } + event, ok := ae.events.lookup(eventNID) + if !ok { + return nil + } + return &event.Event +} + +func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event { + stateKeyNID, ok := ae.stateKeyNIDMap[stateKey] + if !ok { + return nil + } + eventNID, ok := ae.state.lookup(types.StateKeyTuple{ + EventTypeNID: typeNID, + EventStateKeyNID: stateKeyNID, + }) + if !ok { + return nil + } + event, ok := ae.events.lookup(eventNID) + if !ok { + return nil + } + return &event.Event +} + +// loadAuthEvents loads the events needed for authentication from the supplied room state. +func loadAuthEvents( + ctx context.Context, + db RoomEventDatabase, + needed gomatrixserverlib.StateNeeded, + state []types.StateEntry, +) (result authEvents, err error) { + // Look up the numeric IDs for the state keys needed for auth. + var neededStateKeys []string + neededStateKeys = append(neededStateKeys, needed.Member...) + neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) + if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil { + return + } + + // Load the events we need. + result.state = state + var eventNIDs []types.EventNID + keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed) + for _, keyTuple := range keyTuplesNeeded { + eventNID, ok := result.state.lookup(keyTuple) + if ok { + eventNIDs = append(eventNIDs, eventNID) + } + } + if result.events, err = db.Events(ctx, eventNIDs); err != nil { + return + } + return +} + +// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. +func stateKeyTuplesNeeded( + stateKeyNIDMap map[string]types.EventStateKeyNID, + stateNeeded gomatrixserverlib.StateNeeded, +) []types.StateKeyTuple { + var keyTuples []types.StateKeyTuple + if stateNeeded.Create { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomCreateNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + if stateNeeded.PowerLevels { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomPowerLevelsNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + if stateNeeded.JoinRules { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomJoinRulesNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + for _, member := range stateNeeded.Member { + stateKeyNID, ok := stateKeyNIDMap[member] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }) + } + } + for _, token := range stateNeeded.ThirdPartyInvite { + stateKeyNID, ok := stateKeyNIDMap[token] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomThirdPartyInviteNID, + EventStateKeyNID: stateKeyNID, + }) + } + } + return keyTuples +} + +// Map from event type, state key tuple to numeric event ID. +// Implemented using binary search on a sorted array. +type stateEntryMap []types.StateEntry + +// lookup an entry in the event map. +func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size and are controlled by us. + list := []types.StateEntry(m) + i := sort.Search(len(list), func(i int) bool { + return !list[i].StateKeyTuple.LessThan(stateKey) + }) + if i < len(list) && list[i].StateKeyTuple == stateKey { + ok = true + eventNID = list[i].EventNID + } + return +} + +// Map from numeric event ID to event. +// Implemented using binary search on a sorted array. +type eventMap []types.Event + +// lookup an entry in the event map. +func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size are controlled by us. + list := []types.Event(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].EventNID >= eventNID + }) + if i < len(list) && list[i].EventNID == eventNID { + ok = true + event = &list[i] + } + return +} diff --git a/roomserver/input/authevents_test.go b/roomserver/input/authevents_test.go new file mode 100644 index 00000000..0621a084 --- /dev/null +++ b/roomserver/input/authevents_test.go @@ -0,0 +1,136 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package input + +import ( + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) { + var list []types.StateEntry + for i := int64(0); i < entries; i++ { + list = append(list, types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: types.EventTypeNID(i), + EventStateKeyNID: types.EventStateKeyNID(i), + }, + EventNID: types.EventNID(i), + }) + } + + for i := 0; i < b.N; i++ { + entryMap := stateEntryMap(list) + for j := int64(0); j < lookups; j++ { + entryMap.lookup(types.StateKeyTuple{ + EventTypeNID: types.EventTypeNID(j), + EventStateKeyNID: types.EventStateKeyNID(j), + }) + } + } +} + +func BenchmarkStateEntryMap100Lookup10(b *testing.B) { + benchmarkStateEntryMapLookup(100, 10, b) +} + +func BenchmarkStateEntryMap1000Lookup100(b *testing.B) { + benchmarkStateEntryMapLookup(1000, 100, b) +} + +func BenchmarkStateEntryMap100Lookup100(b *testing.B) { + benchmarkStateEntryMapLookup(100, 100, b) +} + +func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) { + benchmarkStateEntryMapLookup(1000, 10000, b) +} + +func TestStateEntryMap(t *testing.T) { + entryMap := stateEntryMap([]types.StateEntry{ + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1}, + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 3}, EventNID: 2}, + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 2, EventStateKeyNID: 1}, EventNID: 3}, + }) + + testCases := []struct { + inputTypeNID types.EventTypeNID + inputStateKey types.EventStateKeyNID + wantOK bool + wantEventNID types.EventNID + }{ + // Check that tuples that in the array are in the map. + {1, 1, true, 1}, + {1, 3, true, 2}, + {2, 1, true, 3}, + // Check that tuples that aren't in the array aren't in the map. + {0, 0, false, 0}, + {1, 2, false, 0}, + {3, 1, false, 0}, + } + + for _, testCase := range testCases { + keyTuple := types.StateKeyTuple{EventTypeNID: testCase.inputTypeNID, EventStateKeyNID: testCase.inputStateKey} + gotEventNID, gotOK := entryMap.lookup(keyTuple) + if testCase.wantOK != gotOK { + t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK) + } + if testCase.wantEventNID != gotEventNID { + t.Fatalf("stateEntryMap lookup(%v): want eventNID to be %v, got %v", keyTuple, testCase.wantEventNID, gotEventNID) + } + } +} + +func TestEventMap(t *testing.T) { + events := eventMap([]types.Event{ + {EventNID: 1}, + {EventNID: 2}, + {EventNID: 3}, + {EventNID: 5}, + {EventNID: 8}, + }) + + testCases := []struct { + inputEventNID types.EventNID + wantOK bool + wantEvent *types.Event + }{ + // Check that the IDs that are in the array are in the map. + {1, true, &events[0]}, + {2, true, &events[1]}, + {3, true, &events[2]}, + {5, true, &events[3]}, + {8, true, &events[4]}, + // Check that tuples that aren't in the array aren't in the map. + {0, false, nil}, + {4, false, nil}, + {6, false, nil}, + {7, false, nil}, + {9, false, nil}, + } + + for _, testCase := range testCases { + gotEvent, gotOK := events.lookup(testCase.inputEventNID) + if testCase.wantOK != gotOK { + t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK) + } + + if testCase.wantEvent != gotEvent { + t.Fatalf("eventMap lookup(%v): want event to be %v, got %v", testCase.inputEventNID, testCase.wantEvent, gotEvent) + } + } + +} diff --git a/roomserver/input/events.go b/roomserver/input/events.go new file mode 100644 index 00000000..feb15b3e --- /dev/null +++ b/roomserver/input/events.go @@ -0,0 +1,235 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package input + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// A RoomEventDatabase has the storage APIs needed to store a room event. +type RoomEventDatabase interface { + state.RoomStateDatabase + // Stores a matrix room event in the database + StoreEvent( + ctx context.Context, + event gomatrixserverlib.Event, + txnAndDeviceID *api.TransactionID, + authEventNIDs []types.EventNID, + ) (types.RoomNID, types.StateAtEvent, error) + // Look up the state entries for a list of string event IDs + // Returns an error if the there is an error talking to the database + // Returns a types.MissingEventError if the event IDs aren't in the database. + StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, + ) ([]types.StateEntry, error) + // Set the state at an event. + SetState( + ctx context.Context, + eventNID types.EventNID, + stateNID types.StateSnapshotNID, + ) error + // Look up the latest events in a room in preparation for an update. + // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. + // Returns the latest events in the room and the last eventID sent to the log along with an updater. + // If this returns an error then no further action is required. + GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, + ) (updater types.RoomRecentEventsUpdater, err error) + // Look up the string event IDs for a list of numeric event IDs + EventIDs( + ctx context.Context, eventNIDs []types.EventNID, + ) (map[types.EventNID]string, error) + // Build a membership updater for the target user in a room. + MembershipUpdater( + ctx context.Context, roomID, targerUserID string, + ) (types.MembershipUpdater, error) + // Look up event ID by transaction's info. + // This is used to determine if the room event is processed/processing already. + // Returns an empty string if no such event exists. + GetTransactionEventID( + ctx context.Context, transactionID string, + deviceID string, userID string, + ) (string, error) +} + +// OutputRoomEventWriter has the APIs needed to write an event to the output logs. +type OutputRoomEventWriter interface { + // Write a list of events for a room + WriteOutputEvents(roomID string, updates []api.OutputEvent) error +} + +// processRoomEvent can only be called once at a time +// +// TODO(#375): This should be rewritten to allow concurrent calls. The +// difficulty is in ensuring that we correctly annotate events with the correct +// state deltas when sending to kafka streams +func processRoomEvent( + ctx context.Context, + db RoomEventDatabase, + ow OutputRoomEventWriter, + input api.InputRoomEvent, +) (eventID string, err error) { + // Parse and validate the event JSON + event := input.Event + + // Check that the event passes authentication checks and work out the numeric IDs for the auth events. + authEventNIDs, err := checkAuthEvents(ctx, db, event, input.AuthEventIDs) + if err != nil { + return + } + + if input.TransactionID != nil { + tdID := input.TransactionID + eventID, err = db.GetTransactionEventID( + ctx, tdID.TransactionID, tdID.DeviceID, input.Event.Sender(), + ) + // On error OR event with the transaction already processed/processesing + if err != nil || eventID != "" { + return + } + } + + // Store the event + roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + if err != nil { + return + } + + if input.Kind == api.KindOutlier { + // For outliers we can stop after we've stored the event itself as it + // doesn't have any associated state to store and we don't need to + // notify anyone about it. + return event.EventID(), nil + } + + if stateAtEvent.BeforeStateSnapshotNID == 0 { + // We haven't calculated a state for this event yet. + // Lets calculate one. + err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event) + if err != nil { + return + } + } + + if input.Kind == api.KindBackfill { + // Backfill is not implemented. + panic("Not implemented") + } + + // Update the extremities of the event graph for the room + return event.EventID(), updateLatestEvents( + ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer, input.TransactionID, + ) +} + +func calculateAndSetState( + ctx context.Context, + db RoomEventDatabase, + input api.InputRoomEvent, + roomNID types.RoomNID, + stateAtEvent *types.StateAtEvent, + event gomatrixserverlib.Event, +) error { + var err error + if input.HasState { + // We've been told what the state at the event is so we don't need to calculate it. + // Check that those state events are in the database and store the state. + var entries []types.StateEntry + if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + return err + } + + if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil { + return err + } + } else { + // We haven't been told what the state at the event is so we need to calculate it from the prev_events + if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil { + return err + } + } + return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) +} + +func processInviteEvent( + ctx context.Context, + db RoomEventDatabase, + ow OutputRoomEventWriter, + input api.InputInviteEvent, +) (err error) { + if input.Event.StateKey() == nil { + return fmt.Errorf("invite must be a state event") + } + + roomID := input.Event.RoomID() + targetUserID := *input.Event.StateKey() + + updater, err := db.MembershipUpdater(ctx, roomID, targetUserID) + if err != nil { + return err + } + succeeded := false + defer common.EndTransaction(updater, &succeeded) + + if updater.IsJoin() { + // If the user is joined to the room then that takes precedence over this + // invite event. It makes little sense to move a user that is already + // joined to the room into the invite state. + // This could plausibly happen if an invite request raced with a join + // request for a user. For example if a user was invited to a public + // room and they joined the room at the same time as the invite was sent. + // The other way this could plausibly happen is if an invite raced with + // a kick. For example if a user was kicked from a room in error and in + // response someone else in the room re-invited them then it is possible + // for the invite request to race with the leave event so that the + // target receives invite before it learns that it has been kicked. + // There are a few ways this could be plausibly handled in the roomserver. + // 1) Store the invite, but mark it as retired. That will result in the + // permanent rejection of that invite event. So even if the target + // user leaves the room and the invite is retransmitted it will be + // ignored. However a new invite with a new event ID would still be + // accepted. + // 2) Silently discard the invite event. This means that if the event + // was retransmitted at a later date after the target user had left + // the room we would accept the invite. However since we hadn't told + // the sending server that the invite had been discarded it would + // have no reason to attempt to retry. + // 3) Signal the sending server that the user is already joined to the + // room. + // For now we will implement option 2. Since in the abesence of a retry + // mechanism it will be equivalent to option 1, and we don't have a + // signalling mechanism to implement option 3. + return nil + } + + outputUpdates, err := updateToInviteMembership(updater, &input.Event, nil) + if err != nil { + return err + } + + if err = ow.WriteOutputEvents(roomID, outputUpdates); err != nil { + return err + } + + succeeded = true + return nil +} diff --git a/roomserver/input/input.go b/roomserver/input/input.go new file mode 100644 index 00000000..bd029d8d --- /dev/null +++ b/roomserver/input/input.go @@ -0,0 +1,95 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package input contains the code processes new room events +package input + +import ( + "context" + "encoding/json" + "net/http" + "sync" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/util" + sarama "gopkg.in/Shopify/sarama.v1" +) + +// RoomserverInputAPI implements api.RoomserverInputAPI +type RoomserverInputAPI struct { + DB RoomEventDatabase + Producer sarama.SyncProducer + // The kafkaesque topic to output new room events to. + // This is the name used in kafka to identify the stream to write events to. + OutputRoomEventTopic string + // Protects calls to processRoomEvent + mutex sync.Mutex +} + +// WriteOutputEvents implements OutputRoomEventWriter +func (r *RoomserverInputAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error { + messages := make([]*sarama.ProducerMessage, len(updates)) + for i := range updates { + value, err := json.Marshal(updates[i]) + if err != nil { + return err + } + messages[i] = &sarama.ProducerMessage{ + Topic: r.OutputRoomEventTopic, + Key: sarama.StringEncoder(roomID), + Value: sarama.ByteEncoder(value), + } + } + return r.Producer.SendMessages(messages) +} + +// InputRoomEvents implements api.RoomserverInputAPI +func (r *RoomserverInputAPI) InputRoomEvents( + ctx context.Context, + request *api.InputRoomEventsRequest, + response *api.InputRoomEventsResponse, +) (err error) { + // We lock as processRoomEvent can only be called once at a time + r.mutex.Lock() + defer r.mutex.Unlock() + for i := range request.InputRoomEvents { + if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil { + return err + } + } + for i := range request.InputInviteEvents { + if err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil { + return err + } + } + return nil +} + +// SetupHTTP adds the RoomserverInputAPI handlers to the http.ServeMux. +func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux) { + servMux.Handle(api.RoomserverInputRoomEventsPath, + common.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse { + var request api.InputRoomEventsRequest + var response api.InputRoomEventsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.InputRoomEvents(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/roomserver/input/latest_events.go b/roomserver/input/latest_events.go new file mode 100644 index 00000000..c2f06393 --- /dev/null +++ b/roomserver/input/latest_events.go @@ -0,0 +1,293 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package input + +import ( + "bytes" + "context" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// updateLatestEvents updates the list of latest events for this room in the database and writes the +// event to the output log. +// The latest events are the events that aren't referenced by another event in the database: +// +// Time goes down the page. 1 is the m.room.create event (root). +// +// 1 After storing 1 the latest events are {1} +// | After storing 2 the latest events are {2} +// 2 After storing 3 the latest events are {3} +// / \ After storing 4 the latest events are {3,4} +// 3 4 After storing 5 the latest events are {5,4} +// | | After storing 6 the latest events are {5,6} +// 5 6 <--- latest After storing 7 the latest events are {6,7} +// | +// 7 <----- latest +// +// Can only be called once at a time +func updateLatestEvents( + ctx context.Context, + db RoomEventDatabase, + ow OutputRoomEventWriter, + roomNID types.RoomNID, + stateAtEvent types.StateAtEvent, + event gomatrixserverlib.Event, + sendAsServer string, + transactionID *api.TransactionID, +) (err error) { + updater, err := db.GetLatestEventsForUpdate(ctx, roomNID) + if err != nil { + return + } + succeeded := false + defer common.EndTransaction(updater, &succeeded) + + u := latestEventsUpdater{ + ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID, + stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer, + transactionID: transactionID, + } + if err = u.doUpdateLatestEvents(); err != nil { + return err + } + + succeeded = true + return +} + +// latestEventsUpdater tracks the state used to update the latest events in the +// room. It mostly just ferries state between the various function calls. +// The state could be passed using function arguments, but it becomes impractical +// when there are so many variables to pass around. +type latestEventsUpdater struct { + ctx context.Context + db RoomEventDatabase + updater types.RoomRecentEventsUpdater + ow OutputRoomEventWriter + roomNID types.RoomNID + stateAtEvent types.StateAtEvent + event gomatrixserverlib.Event + transactionID *api.TransactionID + // Which server to send this event as. + sendAsServer string + // The eventID of the event that was processed before this one. + lastEventIDSent string + // The latest events in the room after processing this event. + latest []types.StateAtEventAndReference + // The state entries removed from and added to the current state of the + // room as a result of processing this event. They are sorted lists. + removed []types.StateEntry + added []types.StateEntry + // The state entries that are removed and added to recover the state before + // the event being processed. They are sorted lists. + stateBeforeEventRemoves []types.StateEntry + stateBeforeEventAdds []types.StateEntry + // The snapshots of current state before and after processing this event + oldStateNID types.StateSnapshotNID + newStateNID types.StateSnapshotNID +} + +func (u *latestEventsUpdater) doUpdateLatestEvents() error { + prevEvents := u.event.PrevEvents() + oldLatest := u.updater.LatestEvents() + u.lastEventIDSent = u.updater.LastEventIDSent() + u.oldStateNID = u.updater.CurrentStateSnapshotNID() + + hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID) + if err != nil { + return err + } else if hasBeenSent { + // Already sent this event so we can stop processing + return nil + } + + if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil { + return err + } + + eventReference := u.event.EventReference() + // Check if this event is already referenced by another event in the room. + alreadyReferenced, err := u.updater.IsReferenced(eventReference) + if err != nil { + return err + } + + u.latest = calculateLatest(oldLatest, alreadyReferenced, prevEvents, types.StateAtEventAndReference{ + EventReference: eventReference, + StateAtEvent: u.stateAtEvent, + }) + + if err = u.latestState(); err != nil { + return err + } + + updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added) + if err != nil { + return err + } + + update, err := u.makeOutputNewRoomEvent() + if err != nil { + return err + } + updates = append(updates, *update) + + // Send the event to the output logs. + // We do this inside the database transaction to ensure that we only mark an event as sent if we sent it. + // (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but + // the write to the output log succeeds) + // TODO: This assumes that writing the event to the output log is synchronous. It should be possible to + // send the event asynchronously but we would need to ensure that 1) the events are written to the log in + // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the + // necessary bookkeeping we'll keep the event sending synchronous for now. + if err = u.ow.WriteOutputEvents(u.event.RoomID(), updates); err != nil { + return err + } + + if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { + return err + } + + return u.updater.MarkEventAsSent(u.stateAtEvent.EventNID) +} + +func (u *latestEventsUpdater) latestState() error { + var err error + + latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) + for i := range u.latest { + latestStateAtEvents[i] = u.latest[i].StateAtEvent + } + u.newStateNID, err = state.CalculateAndStoreStateAfterEvents( + u.ctx, u.db, u.roomNID, latestStateAtEvents, + ) + if err != nil { + return err + } + + u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots( + u.ctx, u.db, u.oldStateNID, u.newStateNID, + ) + if err != nil { + return err + } + + u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots( + u.ctx, u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, + ) + return err +} + +func calculateLatest( + oldLatest []types.StateAtEventAndReference, + alreadyReferenced bool, + prevEvents []gomatrixserverlib.EventReference, + newEvent types.StateAtEventAndReference, +) []types.StateAtEventAndReference { + var alreadyInLatest bool + var newLatest []types.StateAtEventAndReference + for _, l := range oldLatest { + keep := true + for _, prevEvent := range prevEvents { + if l.EventID == prevEvent.EventID && bytes.Equal(l.EventSHA256, prevEvent.EventSHA256) { + // This event can be removed from the latest events cause we've found an event that references it. + // (If an event is referenced by another event then it can't be one of the latest events in the room + // because we have an event that comes after it) + keep = false + break + } + } + if l.EventNID == newEvent.EventNID { + alreadyInLatest = true + } + if keep { + // Keep the event in the latest events. + newLatest = append(newLatest, l) + } + } + + if !alreadyReferenced && !alreadyInLatest { + // This event is not referenced by any of the events in the room + // and the event is not already in the latest events. + // Add it to the latest events + newLatest = append(newLatest, newEvent) + } + + return newLatest +} + +func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) { + + latestEventIDs := make([]string, len(u.latest)) + for i := range u.latest { + latestEventIDs[i] = u.latest[i].EventID + } + + ore := api.OutputNewRoomEvent{ + Event: u.event, + LastSentEventID: u.lastEventIDSent, + LatestEventIDs: latestEventIDs, + TransactionID: u.transactionID, + } + + var stateEventNIDs []types.EventNID + for _, entry := range u.added { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + for _, entry := range u.removed { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + for _, entry := range u.stateBeforeEventRemoves { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + for _, entry := range u.stateBeforeEventAdds { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] + eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs) + if err != nil { + return nil, err + } + for _, entry := range u.added { + ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID]) + } + for _, entry := range u.removed { + ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID]) + } + for _, entry := range u.stateBeforeEventRemoves { + ore.StateBeforeRemovesEventIDs = append(ore.StateBeforeRemovesEventIDs, eventIDMap[entry.EventNID]) + } + for _, entry := range u.stateBeforeEventAdds { + ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID]) + } + ore.SendAsServer = u.sendAsServer + + return &api.OutputEvent{ + Type: api.OutputTypeNewRoomEvent, + NewRoomEvent: &ore, + }, nil +} + +type eventNIDSorter []types.EventNID + +func (s eventNIDSorter) Len() int { return len(s) } +func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/input/membership.go b/roomserver/input/membership.go new file mode 100644 index 00000000..0c3fbb80 --- /dev/null +++ b/roomserver/input/membership.go @@ -0,0 +1,310 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package input + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// Membership values +// TODO: Factor these out somewhere sensible? +const join = "join" +const leave = "leave" +const invite = "invite" +const ban = "ban" + +// updateMembership updates the current membership and the invites for each +// user affected by a change in the current state of the room. +// Returns a list of output events to write to the kafka log to inform the +// consumers about the invites added or retired by the change in current state. +func updateMemberships( + ctx context.Context, + db RoomEventDatabase, + updater types.RoomRecentEventsUpdater, + removed, added []types.StateEntry, +) ([]api.OutputEvent, error) { + changes := membershipChanges(removed, added) + var eventNIDs []types.EventNID + for _, change := range changes { + if change.addedEventNID != 0 { + eventNIDs = append(eventNIDs, change.addedEventNID) + } + if change.removedEventNID != 0 { + eventNIDs = append(eventNIDs, change.removedEventNID) + } + } + + // Load the event JSON so we can look up the "membership" key. + // TODO: Maybe add a membership key to the events table so we can load that + // key without having to load the entire event JSON? + events, err := db.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + + var updates []api.OutputEvent + + for _, change := range changes { + var ae *gomatrixserverlib.Event + var re *gomatrixserverlib.Event + targetUserNID := change.EventStateKeyNID + if change.removedEventNID != 0 { + ev, _ := eventMap(events).lookup(change.removedEventNID) + if ev != nil { + re = &ev.Event + } + } + if change.addedEventNID != 0 { + ev, _ := eventMap(events).lookup(change.addedEventNID) + if ev != nil { + ae = &ev.Event + } + } + if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil { + return nil, err + } + } + return updates, nil +} + +func updateMembership( + updater types.RoomRecentEventsUpdater, targetUserNID types.EventStateKeyNID, + remove, add *gomatrixserverlib.Event, + updates []api.OutputEvent, +) ([]api.OutputEvent, error) { + var err error + // Default the membership to Leave if no event was added or removed. + oldMembership := leave + newMembership := leave + + if remove != nil { + oldMembership, err = remove.Membership() + if err != nil { + return nil, err + } + } + if add != nil { + newMembership, err = add.Membership() + if err != nil { + return nil, err + } + } + if oldMembership == newMembership && newMembership != join { + // If the membership is the same then nothing changed and we can return + // immediately, unless it's a Join update (e.g. profile update). + return updates, nil + } + + mu, err := updater.MembershipUpdater(targetUserNID) + if err != nil { + return nil, err + } + + switch newMembership { + case invite: + return updateToInviteMembership(mu, add, updates) + case join: + return updateToJoinMembership(mu, add, updates) + case leave, ban: + return updateToLeaveMembership(mu, add, newMembership, updates) + default: + panic(fmt.Errorf( + "input: membership %q is not one of the allowed values", newMembership, + )) + } +} + +func updateToInviteMembership( + mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, +) ([]api.OutputEvent, error) { + // We may have already sent the invite to the user, either because we are + // reprocessing this event, or because the we received this invite from a + // remote server via the federation invite API. In those cases we don't need + // to send the event. + needsSending, err := mu.SetToInvite(*add) + if err != nil { + return nil, err + } + if needsSending { + // We notify the consumers using a special event even though we will + // notify them about the change in current state as part of the normal + // room event stream. This ensures that the consumers only have to + // consider a single stream of events when determining whether a user + // is invited, rather than having to combine multiple streams themselves. + onie := api.OutputNewInviteEvent{ + Event: *add, + } + updates = append(updates, api.OutputEvent{ + Type: api.OutputTypeNewInviteEvent, + NewInviteEvent: &onie, + }) + } + return updates, nil +} + +func updateToJoinMembership( + mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, +) ([]api.OutputEvent, error) { + // If the user is already marked as being joined, we call SetToJoin to update + // the event ID then we can return immediately. Retired is ignored as there + // is no invite event to retire. + if mu.IsJoin() { + _, err := mu.SetToJoin(add.Sender(), add.EventID(), true) + if err != nil { + return nil, err + } + return updates, nil + } + // When we mark a user as being joined we will invalidate any invites that + // are active for that user. We notify the consumers that the invites have + // been retired using a special event, even though they could infer this + // by studying the state changes in the room event stream. + retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false) + if err != nil { + return nil, err + } + for _, eventID := range retired { + orie := api.OutputRetireInviteEvent{ + EventID: eventID, + Membership: join, + RetiredByEventID: add.EventID(), + TargetUserID: *add.StateKey(), + } + updates = append(updates, api.OutputEvent{ + Type: api.OutputTypeRetireInviteEvent, + RetireInviteEvent: &orie, + }) + } + return updates, nil +} + +func updateToLeaveMembership( + mu types.MembershipUpdater, add *gomatrixserverlib.Event, + newMembership string, updates []api.OutputEvent, +) ([]api.OutputEvent, error) { + // If the user is already neither joined, nor invited to the room then we + // can return immediately. + if mu.IsLeave() { + return updates, nil + } + // When we mark a user as having left we will invalidate any invites that + // are active for that user. We notify the consumers that the invites have + // been retired using a special event, even though they could infer this + // by studying the state changes in the room event stream. + retired, err := mu.SetToLeave(add.Sender(), add.EventID()) + if err != nil { + return nil, err + } + for _, eventID := range retired { + orie := api.OutputRetireInviteEvent{ + EventID: eventID, + Membership: newMembership, + RetiredByEventID: add.EventID(), + TargetUserID: *add.StateKey(), + } + updates = append(updates, api.OutputEvent{ + Type: api.OutputTypeRetireInviteEvent, + RetireInviteEvent: &orie, + }) + } + return updates, nil +} + +// membershipChanges pairs up the membership state changes from a sorted list +// of state removed and a sorted list of state added. +func membershipChanges(removed, added []types.StateEntry) []stateChange { + changes := pairUpChanges(removed, added) + var result []stateChange + for _, c := range changes { + if c.EventTypeNID == types.MRoomMemberNID { + result = append(result, c) + } + } + return result +} + +type stateChange struct { + types.StateKeyTuple + removedEventNID types.EventNID + addedEventNID types.EventNID +} + +// pairUpChanges pairs up the state events added and removed for each type, +// state key tuple. Assumes that removed and added are sorted. +func pairUpChanges(removed, added []types.StateEntry) []stateChange { + var ai int + var ri int + var result []stateChange + for { + switch { + case ai == len(added): + // We've reached the end of the added entries. + // The rest of the removed list are events that were removed without + // an event with the same state key being added. + for _, s := range removed[ri:] { + result = append(result, stateChange{ + StateKeyTuple: s.StateKeyTuple, + removedEventNID: s.EventNID, + }) + } + return result + case ri == len(removed): + // We've reached the end of the removed entries. + // The rest of the added list are events that were added without + // an event with the same state key being removed. + for _, s := range added[ai:] { + result = append(result, stateChange{ + StateKeyTuple: s.StateKeyTuple, + addedEventNID: s.EventNID, + }) + } + return result + case added[ai].StateKeyTuple == removed[ri].StateKeyTuple: + // The tuple is in both lists so an event with that key is being + // removed and another event with the same key is being added. + result = append(result, stateChange{ + StateKeyTuple: added[ai].StateKeyTuple, + removedEventNID: removed[ri].EventNID, + addedEventNID: added[ai].EventNID, + }) + ai++ + ri++ + case added[ai].StateKeyTuple.LessThan(removed[ri].StateKeyTuple): + // The lists are sorted so the added entry being less than the + // removed entry means that the added event was added without an + // event with the same key being removed. + result = append(result, stateChange{ + StateKeyTuple: added[ai].StateKeyTuple, + addedEventNID: added[ai].EventNID, + }) + ai++ + default: + // Reaching the default case implies that the removed entry is less + // than the added entry. Since the lists are sorted this means that + // the removed event was removed without an event with the same + // key being added. + result = append(result, stateChange{ + StateKeyTuple: removed[ai].StateKeyTuple, + removedEventNID: removed[ri].EventNID, + }) + ri++ + } + } +} diff --git a/roomserver/query/query.go b/roomserver/query/query.go new file mode 100644 index 00000000..b97d50b1 --- /dev/null +++ b/roomserver/query/query.go @@ -0,0 +1,802 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package query + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/auth" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// RoomserverQueryAPIEventDB has a convenience API to fetch events directly by +// EventIDs. +type RoomserverQueryAPIEventDB interface { + // Look up the Events for a list of event IDs. Does not error if event was + // not found. + // Returns an error if the retrieval went wrong. + EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) +} + +// RoomserverQueryAPIDatabase has the storage APIs needed to implement the query API. +type RoomserverQueryAPIDatabase interface { + state.RoomStateDatabase + RoomserverQueryAPIEventDB + // Look up the numeric ID for the room. + // Returns 0 if the room doesn't exists. + // Returns an error if there was a problem talking to the database. + RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) + // Look up event references for the latest events in the room and the current state snapshot. + // Returns the latest events, the current state and the maximum depth of the latest events plus 1. + // Returns an error if there was a problem talking to the database. + LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, + ) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + // Look up the numeric IDs for a list of events. + // Returns an error if there was a problem talking to the database. + EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + // Lookup the event IDs for a batch of event numeric IDs. + // Returns an error if the retrieval went wrong. + EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + // Lookup the membership of a given user in a given room. + // Returns the numeric ID of the latest membership event sent from this user + // in this room, along a boolean set to true if the user is still in this room, + // false if not. + // Returns an error if there was a problem talking to the database. + GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, + ) (membershipEventNID types.EventNID, stillInRoom bool, err error) + // Lookup the membership event numeric IDs for all user that are or have + // been members of a given room. Only lookup events of "join" membership if + // joinOnly is set to true. + // Returns an error if there was a problem talking to the database. + GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, + ) ([]types.EventNID, error) + // Look up the active invites targeting a user in a room and return the + // numeric state key IDs for the user IDs who sent them. + // Returns an error if there was a problem talking to the database. + GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, + ) (senderUserNIDs []types.EventStateKeyNID, err error) + // Look up the string event state keys for a list of numeric event state keys + // Returns an error if there was a problem talking to the database. + EventStateKeys( + context.Context, []types.EventStateKeyNID, + ) (map[types.EventStateKeyNID]string, error) +} + +// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI +type RoomserverQueryAPI struct { + DB RoomserverQueryAPIDatabase +} + +// QueryLatestEventsAndState implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryLatestEventsAndState( + ctx context.Context, + request *api.QueryLatestEventsAndStateRequest, + response *api.QueryLatestEventsAndStateResponse, +) error { + response.QueryLatestEventsAndStateRequest = *request + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) + if err != nil { + return err + } + if roomNID == 0 { + return nil + } + response.RoomExists = true + var currentStateSnapshotNID types.StateSnapshotNID + response.LatestEvents, currentStateSnapshotNID, response.Depth, err = + r.DB.LatestEventIDs(ctx, roomNID) + if err != nil { + return err + } + + // Look up the currrent state for the requested tuples. + stateEntries, err := state.LoadStateAtSnapshotForStringTuples( + ctx, r.DB, currentStateSnapshotNID, request.StateToFetch, + ) + if err != nil { + return err + } + + stateEvents, err := r.loadStateEvents(ctx, stateEntries) + if err != nil { + return err + } + + response.StateEvents = stateEvents + return nil +} + +// QueryStateAfterEvents implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryStateAfterEvents( + ctx context.Context, + request *api.QueryStateAfterEventsRequest, + response *api.QueryStateAfterEventsResponse, +) error { + response.QueryStateAfterEventsRequest = *request + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) + if err != nil { + return err + } + if roomNID == 0 { + return nil + } + response.RoomExists = true + + prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) + if err != nil { + switch err.(type) { + case types.MissingEventError: + return nil + default: + return err + } + } + response.PrevEventsExist = true + + // Look up the currrent state for the requested tuples. + stateEntries, err := state.LoadStateAfterEventsForStringTuples( + ctx, r.DB, prevStates, request.StateToFetch, + ) + if err != nil { + return err + } + + stateEvents, err := r.loadStateEvents(ctx, stateEntries) + if err != nil { + return err + } + + response.StateEvents = stateEvents + return nil +} + +// QueryEventsByID implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryEventsByID( + ctx context.Context, + request *api.QueryEventsByIDRequest, + response *api.QueryEventsByIDResponse, +) error { + response.QueryEventsByIDRequest = *request + + eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) + if err != nil { + return err + } + + var eventNIDs []types.EventNID + for _, nid := range eventNIDMap { + eventNIDs = append(eventNIDs, nid) + } + + events, err := r.loadEvents(ctx, eventNIDs) + if err != nil { + return err + } + + response.Events = events + return nil +} + +func (r *RoomserverQueryAPI) loadStateEvents( + ctx context.Context, stateEntries []types.StateEntry, +) ([]gomatrixserverlib.Event, error) { + eventNIDs := make([]types.EventNID, len(stateEntries)) + for i := range stateEntries { + eventNIDs[i] = stateEntries[i].EventNID + } + return r.loadEvents(ctx, eventNIDs) +} + +func (r *RoomserverQueryAPI) loadEvents( + ctx context.Context, eventNIDs []types.EventNID, +) ([]gomatrixserverlib.Event, error) { + stateEvents, err := r.DB.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + + result := make([]gomatrixserverlib.Event, len(stateEvents)) + for i := range stateEvents { + result[i] = stateEvents[i].Event + } + return result, nil +} + +// QueryMembershipForUser implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryMembershipForUser( + ctx context.Context, + request *api.QueryMembershipForUserRequest, + response *api.QueryMembershipForUserResponse, +) error { + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) + if err != nil { + return err + } + + membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.UserID) + if err != nil { + return err + } + + if membershipEventNID == 0 { + response.HasBeenInRoom = false + return nil + } + + response.IsInRoom = stillInRoom + eventIDMap, err := r.DB.EventIDs(ctx, []types.EventNID{membershipEventNID}) + if err != nil { + return err + } + + response.EventID = eventIDMap[membershipEventNID] + return nil +} + +// QueryMembershipsForRoom implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryMembershipsForRoom( + ctx context.Context, + request *api.QueryMembershipsForRoomRequest, + response *api.QueryMembershipsForRoomResponse, +) error { + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) + if err != nil { + return err + } + + membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.Sender) + if err != nil { + return err + } + + if membershipEventNID == 0 { + response.HasBeenInRoom = false + response.JoinEvents = nil + return nil + } + + response.HasBeenInRoom = true + response.JoinEvents = []gomatrixserverlib.ClientEvent{} + + var events []types.Event + if stillInRoom { + var eventNIDs []types.EventNID + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly) + if err != nil { + return err + } + + events, err = r.DB.Events(ctx, eventNIDs) + } else { + events, err = r.getMembershipsBeforeEventNID(ctx, membershipEventNID, request.JoinedOnly) + } + + if err != nil { + return err + } + + for _, event := range events { + clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll) + response.JoinEvents = append(response.JoinEvents, clientEvent) + } + + return nil +} + +// getMembershipsBeforeEventNID takes the numeric ID of an event and fetches the state +// of the event's room as it was when this event was fired, then filters the state events to +// only keep the "m.room.member" events with a "join" membership. These events are returned. +// Returns an error if there was an issue fetching the events. +func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( + ctx context.Context, eventNID types.EventNID, joinedOnly bool, +) ([]types.Event, error) { + events := []types.Event{} + // Lookup the event NID + eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID}) + if err != nil { + return nil, err + } + eventIDs := []string{eIDs[eventNID]} + + prevState, err := r.DB.StateAtEventIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + // Fetch the state as it was when this event was fired + stateEntries, err := state.LoadCombinedStateAfterEvents(ctx, r.DB, prevState) + if err != nil { + return nil, err + } + + var eventNIDs []types.EventNID + for _, entry := range stateEntries { + // Filter the events to retrieve to only keep the membership events + if entry.EventTypeNID == types.MRoomMemberNID { + eventNIDs = append(eventNIDs, entry.EventNID) + } + } + + // Get all of the events in this state + stateEvents, err := r.DB.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + + if !joinedOnly { + return stateEvents, nil + } + + // Filter the events to only keep the "join" membership events + for _, event := range stateEvents { + membership, err := event.Membership() + if err != nil { + return nil, err + } + + if membership == "join" { + events = append(events, event) + } + } + + return events, nil +} + +// QueryInvitesForUser implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryInvitesForUser( + ctx context.Context, + request *api.QueryInvitesForUserRequest, + response *api.QueryInvitesForUserResponse, +) error { + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) + if err != nil { + return err + } + + targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID}) + if err != nil { + return err + } + targetUserNID := targetUserNIDs[request.TargetUserID] + + senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID) + if err != nil { + return err + } + + senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs) + if err != nil { + return err + } + + for _, senderUserID := range senderUserIDs { + response.InviteSenderUserIDs = append(response.InviteSenderUserIDs, senderUserID) + } + + return nil +} + +// QueryServerAllowedToSeeEvent implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent( + ctx context.Context, + request *api.QueryServerAllowedToSeeEventRequest, + response *api.QueryServerAllowedToSeeEventResponse, +) (err error) { + response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent( + ctx, request.EventID, request.ServerName, + ) + return +} + +func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent( + ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stateEntries, err := state.LoadStateAtEvent(ctx, r.DB, eventID) + if err != nil { + return false, err + } + + // TODO: We probably want to make it so that we don't have to pull + // out all the state if possible. + stateAtEvent, err := r.loadStateEvents(ctx, stateEntries) + if err != nil { + return false, err + } + + return auth.IsServerAllowed(serverName, stateAtEvent), nil +} + +// QueryMissingEvents implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryMissingEvents( + ctx context.Context, + request *api.QueryMissingEventsRequest, + response *api.QueryMissingEventsResponse, +) error { + var front []string + eventsToFilter := make(map[string]bool, len(request.LatestEvents)) + visited := make(map[string]bool, request.Limit) // request.Limit acts as a hint to size. + for _, id := range request.EarliestEvents { + visited[id] = true + } + + for _, id := range request.LatestEvents { + if !visited[id] { + front = append(front, id) + eventsToFilter[id] = true + } + } + + resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) + if err != nil { + return err + } + + loadedEvents, err := r.loadEvents(ctx, resultNIDs) + if err != nil { + return err + } + + response.Events = make([]gomatrixserverlib.Event, 0, len(loadedEvents)-len(eventsToFilter)) + for _, event := range loadedEvents { + if !eventsToFilter[event.EventID()] { + response.Events = append(response.Events, event) + } + } + + return err +} + +// QueryBackfill implements api.RoomServerQueryAPI +func (r *RoomserverQueryAPI) QueryBackfill( + ctx context.Context, + request *api.QueryBackfillRequest, + response *api.QueryBackfillResponse, +) error { + var err error + var front []string + + // The limit defines the maximum number of events to retrieve, so it also + // defines the highest number of elements in the map below. + visited := make(map[string]bool, request.Limit) + + // The provided event IDs have already been seen by the request's emitter, + // and will be retrieved anyway, so there's no need to care about them if + // they appear in our exploration of the event tree. + for _, id := range request.EarliestEventsIDs { + visited[id] = true + } + + front = request.EarliestEventsIDs + + // Scan the event tree for events to send back. + resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) + if err != nil { + return err + } + + // Retrieve events from the list that was filled previously. + response.Events, err = r.loadEvents(ctx, resultNIDs) + return err +} + +func (r *RoomserverQueryAPI) scanEventTree( + ctx context.Context, front []string, visited map[string]bool, limit int, + serverName gomatrixserverlib.ServerName, +) (resultNIDs []types.EventNID, err error) { + var allowed bool + var events []types.Event + var next []string + var pre string + + resultNIDs = make([]types.EventNID, 0, limit) + + // Loop through the event IDs to retrieve the requested events and go + // through the whole tree (up to the provided limit) using the events' + // "prev_event" key. +BFSLoop: + for len(front) > 0 { + // Prevent unnecessary allocations: reset the slice only when not empty. + if len(next) > 0 { + next = make([]string, 0) + } + // Retrieve the events to process from the database. + events, err = r.DB.EventsFromIDs(ctx, front) + if err != nil { + return + } + + for _, ev := range events { + // Break out of the loop if the provided limit is reached. + if len(resultNIDs) == limit { + break BFSLoop + } + // Update the list of events to retrieve. + resultNIDs = append(resultNIDs, ev.EventNID) + // Loop through the event's parents. + for _, pre = range ev.PrevEventIDs() { + // Only add an event to the list of next events to process if it + // hasn't been seen before. + if !visited[pre] { + visited[pre] = true + allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName) + if err != nil { + return + } + + // If the event hasn't been seen before and the HS + // requesting to retrieve it is allowed to do so, add it to + // the list of events to retrieve. + if allowed { + next = append(next, pre) + } + } + } + } + // Repeat the same process with the parent events we just processed. + front = next + } + + return +} + +// QueryStateAndAuthChain implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryStateAndAuthChain( + ctx context.Context, + request *api.QueryStateAndAuthChainRequest, + response *api.QueryStateAndAuthChainResponse, +) error { + response.QueryStateAndAuthChainRequest = *request + roomNID, err := r.DB.RoomNID(ctx, request.RoomID) + if err != nil { + return err + } + if roomNID == 0 { + return nil + } + response.RoomExists = true + + prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) + if err != nil { + switch err.(type) { + case types.MissingEventError: + return nil + default: + return err + } + } + response.PrevEventsExist = true + + // Look up the currrent state for the requested tuples. + stateEntries, err := state.LoadCombinedStateAfterEvents( + ctx, r.DB, prevStates, + ) + if err != nil { + return err + } + + stateEvents, err := r.loadStateEvents(ctx, stateEntries) + if err != nil { + return err + } + + response.StateEvents = stateEvents + response.AuthChainEvents, err = getAuthChain(ctx, r.DB, request.AuthEventIDs) + return err +} + +// getAuthChain fetches the auth chain for the given auth events. +// An auth chain is the list of all events that are referenced in the +// auth_events section, and all their auth_events, recursively. +// The returned set of events contain the given events. +// Will *not* error if we don't have all auth events. +func getAuthChain( + ctx context.Context, dB RoomserverQueryAPIEventDB, authEventIDs []string, +) ([]gomatrixserverlib.Event, error) { + var authEvents []gomatrixserverlib.Event + + // List of event ids to fetch. These will be added to the result and + // their auth events will be fetched (if they haven't been previously) + eventsToFetch := authEventIDs + + // Set of events we've already fetched. + fetchedEventMap := make(map[string]bool) + + // Check if there's anything left to do + for len(eventsToFetch) > 0 { + // Convert eventIDs to events. First need to fetch NIDs + events, err := dB.EventsFromIDs(ctx, eventsToFetch) + if err != nil { + return nil, err + } + + // Work out a) which events we should add to the returned list of + // events and b) which of the auth events we haven't seen yet and + // add them to the list of events to fetch. + eventsToFetch = eventsToFetch[:0] + for _, event := range events { + fetchedEventMap[event.EventID()] = true + authEvents = append(authEvents, event.Event) + + // Now we need to fetch any auth events that we haven't + // previously seen. + for _, authEventID := range event.AuthEventIDs() { + if !fetchedEventMap[authEventID] { + fetchedEventMap[authEventID] = true + eventsToFetch = append(eventsToFetch, authEventID) + } + } + } + } + + return authEvents, nil +} + +// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux. +// nolint: gocyclo +func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { + servMux.Handle( + api.RoomserverQueryLatestEventsAndStatePath, + common.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { + var request api.QueryLatestEventsAndStateRequest + var response api.QueryLatestEventsAndStateResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + api.RoomserverQueryStateAfterEventsPath, + common.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse { + var request api.QueryStateAfterEventsRequest + var response api.QueryStateAfterEventsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + api.RoomserverQueryEventsByIDPath, + common.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { + var request api.QueryEventsByIDRequest + var response api.QueryEventsByIDResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + api.RoomserverQueryMembershipForUserPath, + common.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse { + var request api.QueryMembershipForUserRequest + var response api.QueryMembershipForUserResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + api.RoomserverQueryMembershipsForRoomPath, + common.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { + var request api.QueryMembershipsForRoomRequest + var response api.QueryMembershipsForRoomResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + api.RoomserverQueryInvitesForUserPath, + common.MakeInternalAPI("queryInvitesForUser", func(req *http.Request) util.JSONResponse { + var request api.QueryInvitesForUserRequest + var response api.QueryInvitesForUserResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryInvitesForUser(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + api.RoomserverQueryServerAllowedToSeeEventPath, + common.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { + var request api.QueryServerAllowedToSeeEventRequest + var response api.QueryServerAllowedToSeeEventResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + api.RoomserverQueryMissingEventsPath, + common.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse { + var request api.QueryMissingEventsRequest + var response api.QueryMissingEventsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + api.RoomserverQueryStateAndAuthChainPath, + common.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse { + var request api.QueryStateAndAuthChainRequest + var response api.QueryStateAndAuthChainResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + api.RoomserverQueryBackfillPath, + common.MakeInternalAPI("QueryBackfill", func(req *http.Request) util.JSONResponse { + var request api.QueryBackfillRequest + var response api.QueryBackfillResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryBackfill(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/roomserver/query/query_test.go b/roomserver/query/query_test.go new file mode 100644 index 00000000..76c2e158 --- /dev/null +++ b/roomserver/query/query_test.go @@ -0,0 +1,155 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package query + +import ( + "context" + "encoding/json" + "testing" + + "github.com/matrix-org/dendrite/common/test" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// used to implement RoomserverQueryAPIEventDB to test getAuthChain +type getEventDB struct { + eventMap map[string]gomatrixserverlib.Event +} + +func createEventDB() *getEventDB { + return &getEventDB{ + eventMap: make(map[string]gomatrixserverlib.Event), + } +} + +// Adds a fake event to the storage with given auth events. +func (db *getEventDB) addFakeEvent(eventID string, authIDs []string) error { + authEvents := []gomatrixserverlib.EventReference{} + for _, authID := range authIDs { + authEvents = append(authEvents, gomatrixserverlib.EventReference{ + EventID: authID, + }) + } + + builder := map[string]interface{}{ + "event_id": eventID, + "auth_events": authEvents, + } + + eventJSON, err := json.Marshal(&builder) + if err != nil { + return err + } + + event, err := gomatrixserverlib.NewEventFromTrustedJSON(eventJSON, false) + if err != nil { + return err + } + + db.eventMap[eventID] = event + + return nil +} + +// Adds multiple events at once, each entry in the map is an eventID and set of +// auth events that are converted to an event and added. +func (db *getEventDB) addFakeEvents(graph map[string][]string) error { + for eventID, authIDs := range graph { + err := db.addFakeEvent(eventID, authIDs) + if err != nil { + return err + } + } + + return nil +} + +// EventsFromIDs implements RoomserverQueryAPIEventDB +func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) { + for _, evID := range eventIDs { + res = append(res, types.Event{ + EventNID: 0, + Event: db.eventMap[evID], + }) + } + + return +} + +func TestGetAuthChainSingle(t *testing.T) { + db := createEventDB() + + err := db.addFakeEvents(map[string][]string{ + "a": {}, + "b": {"a"}, + "c": {"a", "b"}, + "d": {"b", "c"}, + "e": {"a", "d"}, + }) + + if err != nil { + t.Fatalf("Failed to add events to db: %v", err) + } + + result, err := getAuthChain(context.TODO(), db, []string{"e"}) + if err != nil { + t.Fatalf("getAuthChain failed: %v", err) + } + + var returnedIDs []string + for _, event := range result { + returnedIDs = append(returnedIDs, event.EventID()) + } + + expectedIDs := []string{"a", "b", "c", "d", "e"} + + if !test.UnsortedStringSliceEqual(expectedIDs, returnedIDs) { + t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs) + } +} + +func TestGetAuthChainMultiple(t *testing.T) { + db := createEventDB() + + err := db.addFakeEvents(map[string][]string{ + "a": {}, + "b": {"a"}, + "c": {"a", "b"}, + "d": {"b", "c"}, + "e": {"a", "d"}, + "f": {"a", "b", "c"}, + }) + + if err != nil { + t.Fatalf("Failed to add events to db: %v", err) + } + + result, err := getAuthChain(context.TODO(), db, []string{"e", "f"}) + if err != nil { + t.Fatalf("getAuthChain failed: %v", err) + } + + var returnedIDs []string + for _, event := range result { + returnedIDs = append(returnedIDs, event.EventID()) + } + + expectedIDs := []string{"a", "b", "c", "d", "e", "f"} + + if !test.UnsortedStringSliceEqual(expectedIDs, returnedIDs) { + t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs) + } +} diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go new file mode 100644 index 00000000..2ffbf67d --- /dev/null +++ b/roomserver/roomserver.go @@ -0,0 +1,68 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package roomserver + +import ( + "net/http" + + "github.com/matrix-org/dendrite/roomserver/api" + + asQuery "github.com/matrix-org/dendrite/appservice/query" + "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/matrix-org/dendrite/roomserver/alias" + "github.com/matrix-org/dendrite/roomserver/input" + "github.com/matrix-org/dendrite/roomserver/query" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/sirupsen/logrus" +) + +// SetupRoomServerComponent sets up and registers HTTP handlers for the +// RoomServer component. Returns instances of the various roomserver APIs, +// allowing other components running in the same process to hit the query the +// APIs directly instead of having to use HTTP. +func SetupRoomServerComponent( + base *basecomponent.BaseDendrite, +) (api.RoomserverAliasAPI, api.RoomserverInputAPI, api.RoomserverQueryAPI) { + roomserverDB, err := storage.Open(string(base.Cfg.Database.RoomServer)) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to room server db") + } + + inputAPI := input.RoomserverInputAPI{ + DB: roomserverDB, + Producer: base.KafkaProducer, + OutputRoomEventTopic: string(base.Cfg.Kafka.Topics.OutputRoomEvent), + } + + inputAPI.SetupHTTP(http.DefaultServeMux) + + queryAPI := query.RoomserverQueryAPI{DB: roomserverDB} + + queryAPI.SetupHTTP(http.DefaultServeMux) + + asAPI := asQuery.AppServiceQueryAPI{Cfg: base.Cfg} + + aliasAPI := alias.RoomserverAliasAPI{ + DB: roomserverDB, + Cfg: base.Cfg, + InputAPI: &inputAPI, + QueryAPI: &queryAPI, + AppserviceAPI: &asAPI, + } + + aliasAPI.SetupHTTP(http.DefaultServeMux) + + return &aliasAPI, &inputAPI, &queryAPI +} diff --git a/roomserver/state/state.go b/roomserver/state/state.go new file mode 100644 index 00000000..2a0b7f57 --- /dev/null +++ b/roomserver/state/state.go @@ -0,0 +1,966 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package state provides functions for reading state from the database. +// The functions for writing state to the database are the input package. +package state + +import ( + "context" + "fmt" + "sort" + "time" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" +) + +// A RoomStateDatabase has the storage APIs needed to load state from the database +type RoomStateDatabase interface { + // Store the room state at an event in the database + AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, + ) (types.StateSnapshotNID, error) + // Look up the state of a room at each event for a list of string event IDs. + // Returns an error if there is an error talking to the database + // Returns a types.MissingEventError if the room state for the event IDs aren't in the database + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + // Look up the numeric IDs for a list of string event types. + // Returns a map from string event type to numeric ID for the event type. + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + // Look up the numeric IDs for a list of string event state keys. + // Returns a map from string state key to numeric ID for the state key. + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + // Look up the numeric state data IDs for each numeric state snapshot ID + // The returned slice is sorted by numeric state snapshot ID. + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + // Look up the state data for each numeric state data ID + // The returned slice is sorted by numeric state data ID. + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + // Look up the state data for the state key tuples for each numeric state block ID + // This is used to fetch a subset of the room state at a snapshot. + // If a block doesn't contain any of the requested tuples then it can be discarded from the result. + // The returned slice is sorted by numeric state block ID. + StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, + ) ([]types.StateEntryList, error) + // Look up the Events for a list of numeric event IDs. + // Returns a sorted list of events. + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) + // Look up snapshot NID for an event ID string + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) +} + +// LoadStateAtSnapshot loads the full state of a room at a particular snapshot. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func LoadStateAtSnapshot( + ctx context.Context, db RoomStateDatabase, stateNID types.StateSnapshotNID, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +// LoadStateAtEvent loads the full state of a room at a particular event. +func LoadStateAtEvent( + ctx context.Context, db RoomStateDatabase, eventID string, +) ([]types.StateEntry, error) { + snapshotNID, err := db.SnapshotNIDFromEventID(ctx, eventID) + if err != nil { + return nil, err + } + + stateEntries, err := LoadStateAtSnapshot(ctx, db, snapshotNID) + if err != nil { + return nil, err + } + + return stateEntries, nil +} + +// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events +// and combines those snapshots together into a single list. +func LoadCombinedStateAfterEvents( + ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent, +) ([]types.StateEntry, error) { + stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) + for i, state := range prevStates { + stateNIDs[i] = state.BeforeStateSnapshotNID + } + // Fetch the state snapshots for the state before the each prev event from the database. + // Deduplicate the IDs before passing them to the database. + // There could be duplicates because the events could be state events where + // the snapshot of the room state before them was the same. + stateBlockNIDLists, err := db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs)) + if err != nil { + return nil, err + } + + var stateBlockNIDs []types.StateBlockNID + for _, list := range stateBlockNIDLists { + stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) + } + // Fetch the state entries that will be combined to create the snapshots. + // Deduplicate the IDs before passing them to the database. + // There could be duplicates because a block of state entries could be reused by + // multiple snapshots. + stateEntryLists, err := db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) + if err != nil { + return nil, err + } + stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine the entries from all the snapshots of state after each prev event into a single list. + var combined []types.StateEntry + for _, prevState := range prevStates { + // Grab the list of state data NIDs for this snapshot. + stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) + } + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + if prevState.IsStateEvent() { + // If the prev event was a state event then add an entry for the event itself + // so that we get the state after the event rather than the state before. + fullState = append(fullState, prevState.StateEntry) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + // Add the full state for this StateSnapshotNID. + combined = append(combined, fullState...) + } + return combined, nil +} + +// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. +func DifferenceBetweeenStateSnapshots( + ctx context.Context, db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID, +) (removed, added []types.StateEntry, err error) { + if oldStateNID == newStateNID { + // If the snapshot NIDs are the same then nothing has changed + return nil, nil, nil + } + + var oldEntries []types.StateEntry + var newEntries []types.StateEntry + if oldStateNID != 0 { + oldEntries, err = LoadStateAtSnapshot(ctx, db, oldStateNID) + if err != nil { + return nil, nil, err + } + } + if newStateNID != 0 { + newEntries, err = LoadStateAtSnapshot(ctx, db, newStateNID) + if err != nil { + return nil, nil, err + } + } + + var oldI int + var newI int + for { + switch { + case oldI == len(oldEntries): + // We've reached the end of the old entries. + // The rest of the new list must have been newly added. + added = append(added, newEntries[newI:]...) + return + case newI == len(newEntries): + // We've reached the end of the new entries. + // The rest of the old list must be have been removed. + removed = append(removed, oldEntries[oldI:]...) + return + case oldEntries[oldI] == newEntries[newI]: + // The entry is in both lists so skip over it. + oldI++ + newI++ + case oldEntries[oldI].LessThan(newEntries[newI]): + // The lists are sorted so the old entry being less than the new entry means that it only appears in the old list. + removed = append(removed, oldEntries[oldI]) + oldI++ + default: + // Reaching the default case implies that the new entry is less than the old entry. + // Since the lists are sorted this means that it only appears in the new list. + added = append(added, newEntries[newI]) + newI++ + } + } +} + +// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot. +// This is used when we only want to load a subset of the room state at a snapshot. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func LoadStateAtSnapshotForStringTuples( + ctx context.Context, + db RoomStateDatabase, + stateNID types.StateSnapshotNID, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateEntry, error) { + numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples) + if err != nil { + return nil, err + } + return loadStateAtSnapshotForNumericTuples(ctx, db, stateNID, numericTuples) +} + +// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs +// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded. +// Returns an error if there was a problem talking to the database. +func stringTuplesToNumericTuples( + ctx context.Context, + db RoomStateDatabase, + stringTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateKeyTuple, error) { + eventTypes := make([]string, len(stringTuples)) + stateKeys := make([]string, len(stringTuples)) + for i := range stringTuples { + eventTypes[i] = stringTuples[i].EventType + stateKeys[i] = stringTuples[i].StateKey + } + eventTypes = util.UniqueStrings(eventTypes) + eventTypeMap, err := db.EventTypeNIDs(ctx, eventTypes) + if err != nil { + return nil, err + } + stateKeys = util.UniqueStrings(stateKeys) + stateKeyMap, err := db.EventStateKeyNIDs(ctx, stateKeys) + if err != nil { + return nil, err + } + + var result []types.StateKeyTuple + for _, stringTuple := range stringTuples { + var numericTuple types.StateKeyTuple + var ok1, ok2 bool + numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType] + numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.StateKey] + // Discard the tuple if there wasn't a numeric ID for either the event type or the state key. + if ok1 && ok2 { + result = append(result, numericTuple) + } + } + + return result, nil +} + +// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot. +// This is used when we only want to load a subset of the room state at a snapshot. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func loadStateAtSnapshotForNumericTuples( + ctx context.Context, + db RoomStateDatabase, + stateNID types.StateSnapshotNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := db.StateEntriesForTuples( + ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples, + ) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // If the block is missing from the map it means that none of its entries matched a requested tuple. + // This can happen if the block doesn't contain an update for one of the requested tuples. + // If none of the requested tuples are in the block then it can be safely skipped. + continue + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +// LoadStateAfterEventsForStringTuples loads the state for a list of event type +// and state key pairs after list of events. +// This is used when we only want to load a subset of the room state after a list of events. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func LoadStateAfterEventsForStringTuples( + ctx context.Context, + db RoomStateDatabase, + prevStates []types.StateAtEvent, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateEntry, error) { + numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples) + if err != nil { + return nil, err + } + return loadStateAfterEventsForNumericTuples(ctx, db, prevStates, numericTuples) +} + +func loadStateAfterEventsForNumericTuples( + ctx context.Context, + db RoomStateDatabase, + prevStates []types.StateAtEvent, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + if len(prevStates) == 1 { + // Fast path for a single event. + prevState := prevStates[0] + result, err := loadStateAtSnapshotForNumericTuples( + ctx, db, prevState.BeforeStateSnapshotNID, stateKeyTuples, + ) + if err != nil { + return nil, err + } + if prevState.IsStateEvent() { + // The result is current the state before the requested event. + // We want the state after the requested event. + // If the requested event was a state event then we need to + // update that key in the result. + // If the requested event wasn't a state event then the state after + // it is the same as the state before it. + for i := range result { + if result[i].StateKeyTuple == prevState.StateKeyTuple { + result[i] = prevState.StateEntry + } + } + } + return result, nil + } + + // Slow path for more that one event. + // Load the entire state so that we can do conflict resolution if we need to. + // TODO: The are some optimistations we could do here: + // 1) We only need to do conflict resolution if there is a conflict in the + // requested tuples so we might try loading just those tuples and then + // checking for conflicts. + // 2) When there is a conflict we still only need to load the state + // needed to do conflict resolution which would save us having to load + // the full state. + + // TODO: Add metrics for this as it could take a long time for big rooms + // with large conflicts. + fullState, _, _, err := calculateStateAfterManyEvents(ctx, db, prevStates) + if err != nil { + return nil, err + } + + // Sort the full state so we can use it as a map. + sort.Sort(stateEntrySorter(fullState)) + + // Filter the full state down to the required tuples. + var result []types.StateEntry + for _, tuple := range stateKeyTuples { + eventNID, ok := stateEntryMap(fullState).lookup(tuple) + if ok { + result = append(result, types.StateEntry{ + StateKeyTuple: tuple, + EventNID: eventNID, + }) + } + } + sort.Sort(stateEntrySorter(result)) + return result, nil +} + +var calculateStateDurations = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_duration_microseconds", + Help: "How long it takes to calculate the state after a list of events", + }, + // Takes two labels: + // algorithm: + // The algorithm used to calculate the state or the step it failed on if it failed. + // Labels starting with "_" are used to indicate when the algorithm fails halfway. + // outcome: + // Whether the state was successfully calculated. + // + // The possible values for algorithm are: + // empty_state -> The list of events was empty so the state is empty. + // no_change -> The state hasn't changed. + // single_delta -> There was a single event added to the state in a way that can be encoded as a single delta + // full_state_no_conflicts -> We created a new copy of the full room state, but didn't enounter any conflicts + // while doing so. + // full_state_with_conflicts -> We created a new copy of the full room state and had to resolve conflicts to do so. + // _load_state_block_nids -> Failed loading the state block nids for a single previous state. + // _load_combined_state -> Failed to load the combined state. + // _resolve_conflicts -> Failed to resolve conflicts. + []string{"algorithm", "outcome"}, +) + +var calculateStatePrevEventLength = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_prev_event_length", + Help: "The length of the list of events to calculate the state after", + }, + []string{"algorithm", "outcome"}, +) + +var calculateStateFullStateLength = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_full_state_length", + Help: "The length of the full room state.", + }, + []string{"algorithm", "outcome"}, +) + +var calculateStateConflictLength = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_conflict_state_length", + Help: "The length of the conflicted room state.", + }, + []string{"algorithm", "outcome"}, +) + +type calculateStateMetrics struct { + algorithm string + startTime time.Time + prevEventLength int + fullStateLength int + conflictLength int +} + +func (c *calculateStateMetrics) stop(stateNID types.StateSnapshotNID, err error) (types.StateSnapshotNID, error) { + var outcome string + if err == nil { + outcome = "success" + } else { + outcome = "failure" + } + endTime := time.Now() + calculateStateDurations.WithLabelValues(c.algorithm, outcome).Observe( + float64(endTime.Sub(c.startTime).Nanoseconds()) / 1000., + ) + calculateStatePrevEventLength.WithLabelValues(c.algorithm, outcome).Observe( + float64(c.prevEventLength), + ) + calculateStateFullStateLength.WithLabelValues(c.algorithm, outcome).Observe( + float64(c.fullStateLength), + ) + calculateStateConflictLength.WithLabelValues(c.algorithm, outcome).Observe( + float64(c.conflictLength), + ) + return stateNID, err +} + +func init() { + prometheus.MustRegister( + calculateStateDurations, calculateStatePrevEventLength, + calculateStateFullStateLength, calculateStateConflictLength, + ) +} + +// CalculateAndStoreStateBeforeEvent calculates a snapshot of the state of a room before an event. +// Stores the snapshot of the state in the database. +// Returns a numeric ID for the snapshot of the state before the event. +func CalculateAndStoreStateBeforeEvent( + ctx context.Context, + db RoomStateDatabase, + event gomatrixserverlib.Event, + roomNID types.RoomNID, +) (types.StateSnapshotNID, error) { + // Load the state at the prev events. + prevEventRefs := event.PrevEvents() + prevEventIDs := make([]string, len(prevEventRefs)) + for i := range prevEventRefs { + prevEventIDs[i] = prevEventRefs[i].EventID + } + + prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs) + if err != nil { + return 0, err + } + + // The state before this event will be the state after the events that came before it. + return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates) +} + +// CalculateAndStoreStateAfterEvents finds the room state after the given events. +// Stores the resulting state in the database and returns a numeric ID for that snapshot. +func CalculateAndStoreStateAfterEvents( + ctx context.Context, + db RoomStateDatabase, + roomNID types.RoomNID, + prevStates []types.StateAtEvent, +) (types.StateSnapshotNID, error) { + metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} + + if len(prevStates) == 0 { + // 2) There weren't any prev_events for this event so the state is + // empty. + metrics.algorithm = "empty_state" + return metrics.stop(db.AddState(ctx, roomNID, nil, nil)) + } + + if len(prevStates) == 1 { + prevState := prevStates[0] + if prevState.EventStateKeyNID == 0 { + // 3) None of the previous events were state events and they all + // have the same state, so this event has exactly the same state + // as the previous events. + // This should be the common case. + metrics.algorithm = "no_change" + return metrics.stop(prevState.BeforeStateSnapshotNID, nil) + } + // The previous event was a state event so we need to store a copy + // of the previous state updated with that event. + stateBlockNIDLists, err := db.StateBlockNIDs( + ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}, + ) + if err != nil { + metrics.algorithm = "_load_state_blocks" + return metrics.stop(0, err) + } + stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs + if len(stateBlockNIDs) < maxStateBlockNIDs { + // 4) The number of state data blocks is small enough that we can just + // add the state event as a block of size one to the end of the blocks. + metrics.algorithm = "single_delta" + return metrics.stop(db.AddState( + ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + )) + } + // If there are too many deltas then we need to calculate the full state + // So fall through to calculateAndStoreStateAfterManyEvents + } + + return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics) +} + +// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. +// Increasing this number means that we can encode more of the state changes as simple deltas which means that +// we need fewer entries in the state data table. However making this number bigger will increase the size of +// the rows in the state table itself and will require more index lookups when retrieving a snapshot. +// TODO: Tune this to get the right balance between size and lookup performance. +const maxStateBlockNIDs = 64 + +// calculateAndStoreStateAfterManyEvents finds the room state after the given events. +// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event. +// Stores the resulting state and returns a numeric ID for the snapshot. +func calculateAndStoreStateAfterManyEvents( + ctx context.Context, + db RoomStateDatabase, + roomNID types.RoomNID, + prevStates []types.StateAtEvent, + metrics calculateStateMetrics, +) (types.StateSnapshotNID, error) { + + state, algorithm, conflictLength, err := + calculateStateAfterManyEvents(ctx, db, prevStates) + metrics.algorithm = algorithm + if err != nil { + return metrics.stop(0, err) + } + + // TODO: Check if we can encode the new state as a delta against the + // previous state. + metrics.conflictLength = conflictLength + metrics.fullStateLength = len(state) + return metrics.stop(db.AddState(ctx, roomNID, nil, state)) +} + +func calculateStateAfterManyEvents( + ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent, +) (state []types.StateEntry, algorithm string, conflictLength int, err error) { + var combined []types.StateEntry + // Conflict resolution. + // First stage: load the state after each of the prev events. + combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates) + if err != nil { + algorithm = "_load_combined_state" + return + } + + // Collect all the entries with the same type and key together. + // We don't care about the order here because the conflict resolution + // algorithm doesn't depend on the order of the prev events. + // Remove duplicate entires. + combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] + + // Find the conflicts + conflicts := findDuplicateStateKeys(combined) + + if len(conflicts) > 0 { + conflictLength = len(conflicts) + + // 5) There are conflicting state events, for each conflict workout + // what the appropriate state event is. + + // Work out which entries aren't conflicted. + var notConflicted []types.StateEntry + for _, entry := range combined { + if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { + notConflicted = append(notConflicted, entry) + } + } + + var resolved []types.StateEntry + resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts) + if err != nil { + algorithm = "_resolve_conflicts" + return + } + algorithm = "full_state_with_conflicts" + state = resolved + } else { + algorithm = "full_state_no_conflicts" + // 6) There weren't any conflicts + state = combined + } + return +} + +// resolveConflicts resolves a list of conflicted state entries. It takes two lists. +// The first is a list of all state entries that are not conflicted. +// The second is a list of all state entries that are conflicted +// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple. +// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. +// The returned list is sorted by state key tuple. +// Returns an error if there was a problem talking to the database. +func resolveConflicts( + ctx context.Context, + db RoomStateDatabase, + notConflicted, conflicted []types.StateEntry, +) ([]types.StateEntry, error) { + + // Load the conflicted events + conflictedEvents, eventIDMap, err := loadStateEvents(ctx, db, conflicted) + if err != nil { + return nil, err + } + + // Work out which auth events we need to load. + needed := gomatrixserverlib.StateNeededForAuth(conflictedEvents) + + // Find the numeric IDs for the necessary state keys. + var neededStateKeys []string + neededStateKeys = append(neededStateKeys, needed.Member...) + neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) + stateKeyNIDMap, err := db.EventStateKeyNIDs(ctx, neededStateKeys) + if err != nil { + return nil, err + } + + // Load the necessary auth events. + tuplesNeeded := stateKeyTuplesNeeded(stateKeyNIDMap, needed) + var authEntries []types.StateEntry + for _, tuple := range tuplesNeeded { + if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { + authEntries = append(authEntries, types.StateEntry{ + StateKeyTuple: tuple, + EventNID: eventNID, + }) + } + } + authEvents, _, err := loadStateEvents(ctx, db, authEntries) + if err != nil { + return nil, err + } + + // Resolve the conflicts. + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) + + // Map from the full events back to numeric state entries. + for _, resolvedEvent := range resolvedEvents { + entry, ok := eventIDMap[resolvedEvent.EventID()] + if !ok { + panic(fmt.Errorf("Missing state entry for event ID %q", resolvedEvent.EventID())) + } + notConflicted = append(notConflicted, entry) + } + + // Sort the result so it can be searched. + sort.Sort(stateEntrySorter(notConflicted)) + return notConflicted, nil +} + +// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. +func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { + var keyTuples []types.StateKeyTuple + if stateNeeded.Create { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomCreateNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + if stateNeeded.PowerLevels { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomPowerLevelsNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + if stateNeeded.JoinRules { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomJoinRulesNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + for _, member := range stateNeeded.Member { + stateKeyNID, ok := stateKeyNIDMap[member] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }) + } + } + for _, token := range stateNeeded.ThirdPartyInvite { + stateKeyNID, ok := stateKeyNIDMap[token] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomThirdPartyInviteNID, + EventStateKeyNID: stateKeyNID, + }) + } + } + return keyTuples +} + +// loadStateEvents loads the matrix events for a list of state entries. +// Returns a list of state events in no particular order and a map from string event ID back to state entry. +// The map can be used to recover which numeric state entry a given event is for. +// Returns an error if there was a problem talking to the database. +func loadStateEvents( + ctx context.Context, db RoomStateDatabase, entries []types.StateEntry, +) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { + eventNIDs := make([]types.EventNID, len(entries)) + for i := range entries { + eventNIDs[i] = entries[i].EventNID + } + events, err := db.Events(ctx, eventNIDs) + if err != nil { + return nil, nil, err + } + eventIDMap := map[string]types.StateEntry{} + result := make([]gomatrixserverlib.Event, len(entries)) + for i := range entries { + event, ok := eventMap(events).lookup(entries[i].EventNID) + if !ok { + panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID)) + } + result[i] = event.Event + eventIDMap[event.Event.EventID()] = entries[i] + } + return result, eventIDMap, nil +} + +// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. +// Returns a sorted list of those state entries. +func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { + var result []types.StateEntry + // j is the starting index of a block of entries with the same state key tuple. + j := 0 + for i := 1; i < len(a); i++ { + // Check if the state key tuple matches the start of the block + if a[j].StateKeyTuple != a[i].StateKeyTuple { + // If the state key tuple is different then we've reached the end of a block of duplicates. + // Check if the size of the block is bigger than one. + // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result + if j+1 != i { + // Add the block to the result. + result = append(result, a[j:i]...) + } + // Start a new block for the next state key tuple. + j = i + } + } + // Check if the last block with the same state key tuple had more than one event in it. + if j+1 != len(a) { + result = append(result, a[j:]...) + } + return result +} + +type stateEntrySorter []types.StateEntry + +func (s stateEntrySorter) Len() int { return len(s) } +func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stateBlockNIDListMap []types.StateBlockNIDList + +func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { + list := []types.StateBlockNIDList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateSnapshotNID >= stateNID + }) + if i < len(list) && list[i].StateSnapshotNID == stateNID { + ok = true + stateBlockNIDs = list[i].StateBlockNIDs + } + return +} + +type stateEntryListMap []types.StateEntryList + +func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { + list := []types.StateEntryList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateBlockNID >= stateBlockNID + }) + if i < len(list) && list[i].StateBlockNID == stateBlockNID { + ok = true + stateEntries = list[i].StateEntries + } + return +} + +type stateEntryByStateKeySorter []types.StateEntry + +func (s stateEntryByStateKeySorter) Len() int { return len(s) } +func (s stateEntryByStateKeySorter) Less(i, j int) bool { + return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) +} +func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stateNIDSorter []types.StateSnapshotNID + +func (s stateNIDSorter) Len() int { return len(s) } +func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { + return nids[:util.SortAndUnique(stateNIDSorter(nids))] +} + +type stateBlockNIDSorter []types.StateBlockNID + +func (s stateBlockNIDSorter) Len() int { return len(s) } +func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { + return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))] +} + +// Map from event type, state key tuple to numeric event ID. +// Implemented using binary search on a sorted array. +type stateEntryMap []types.StateEntry + +// lookup an entry in the event map. +func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size and are controlled by us. + list := []types.StateEntry(m) + i := sort.Search(len(list), func(i int) bool { + return !list[i].StateKeyTuple.LessThan(stateKey) + }) + if i < len(list) && list[i].StateKeyTuple == stateKey { + ok = true + eventNID = list[i].EventNID + } + return +} + +// Map from numeric event ID to event. +// Implemented using binary search on a sorted array. +type eventMap []types.Event + +// lookup an entry in the event map. +func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size are controlled by us. + list := []types.Event(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].EventNID >= eventNID + }) + if i < len(list) && list[i].EventNID == eventNID { + ok = true + event = &list[i] + } + return +} diff --git a/roomserver/state/state_test.go b/roomserver/state/state_test.go new file mode 100644 index 00000000..67af1867 --- /dev/null +++ b/roomserver/state/state_test.go @@ -0,0 +1,56 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +func TestFindDuplicateStateKeys(t *testing.T) { + testCases := []struct { + Input []types.StateEntry + Want []types.StateEntry + }{{ + Input: []types.StateEntry{ + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1}, + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 2}, + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 2, EventStateKeyNID: 2}, EventNID: 3}, + }, + Want: []types.StateEntry{ + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1}, + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 2}, + }, + }, { + Input: []types.StateEntry{ + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1}, + {StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 2}, EventNID: 2}, + }, + Want: nil, + }} + + for _, test := range testCases { + got := findDuplicateStateKeys(test.Input) + if len(got) != len(test.Want) { + t.Fatalf("Wanted %v, got %v", test.Want, got) + } + for i := range got { + if got[i] != test.Want[i] { + t.Fatalf("Wanted %v, got %v", test.Want, got) + } + } + } +} diff --git a/roomserver/storage/event_json_table.go b/roomserver/storage/event_json_table.go new file mode 100644 index 00000000..b81667d9 --- /dev/null +++ b/roomserver/storage/event_json_table.go @@ -0,0 +1,105 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventJSONSchema = ` +-- Stores the JSON for each event. This kept separate from the main events +-- table to keep the rows in the main events table small. +CREATE TABLE IF NOT EXISTS roomserver_event_json ( + -- Local numeric ID for the event. + event_nid BIGINT NOT NULL PRIMARY KEY, + -- The JSON for the event. + -- Stored as TEXT because this should be valid UTF-8. + -- Not stored as a JSONB because we always just pull the entire event + -- so there is no point in postgres parsing it. + -- Not stored as JSON because we already validate the JSON in the server + -- so there is no point in postgres validating it. + -- TODO: Should we be compressing the events with Snappy or DEFLATE? + event_json TEXT NOT NULL +); +` + +const insertEventJSONSQL = "" + + "INSERT INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +// Bulk event JSON lookup by numeric event ID. +// Sort by the numeric event ID. +// This means that we can use binary search to lookup by numeric event ID. +const bulkSelectEventJSONSQL = "" + + "SELECT event_nid, event_json FROM roomserver_event_json" + + " WHERE event_nid = ANY($1)" + + " ORDER BY event_nid ASC" + +type eventJSONStatements struct { + insertEventJSONStmt *sql.Stmt + bulkSelectEventJSONStmt *sql.Stmt +} + +func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(eventJSONSchema) + if err != nil { + return + } + return statementList{ + {&s.insertEventJSONStmt, insertEventJSONSQL}, + {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, + }.prepare(db) +} + +func (s *eventJSONStatements) insertEventJSON( + ctx context.Context, eventNID types.EventNID, eventJSON []byte, +) error { + _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) + return err +} + +type eventJSONPair struct { + EventNID types.EventNID + EventJSON []byte +} + +func (s *eventJSONStatements) bulkSelectEventJSON( + ctx context.Context, eventNIDs []types.EventNID, +) ([]eventJSONPair, error) { + rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + // We know that we will only get as many results as event NIDs + // because of the unique constraint on event NIDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than NIDs so we adjust the length of the slice before returning it. + results := make([]eventJSONPair, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var eventNID int64 + if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { + return nil, err + } + result.EventNID = types.EventNID(eventNID) + } + return results[:i], nil +} diff --git a/roomserver/storage/event_state_keys_table.go b/roomserver/storage/event_state_keys_table.go new file mode 100644 index 00000000..1ef93370 --- /dev/null +++ b/roomserver/storage/event_state_keys_table.go @@ -0,0 +1,153 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventStateKeysSchema = ` +-- Numeric versions of the event "state_key"s. State keys tend to be reused so +-- assigning each string a numeric ID should reduce the amount of data that +-- needs to be stored and fetched from the database. +-- It also means that many operations can work with int64 arrays rather than +-- string arrays which may help reduce GC pressure. +-- Well known state keys are pre-assigned numeric IDs: +-- 1 -> "" (the empty string) +-- Other state keys are automatically assigned numeric IDs starting from 2**16. +-- This leaves room to add more pre-assigned numeric IDs and clearly separates +-- the automatically assigned IDs from the pre-assigned IDs. +CREATE SEQUENCE IF NOT EXISTS roomserver_event_state_key_nid_seq START 65536; +CREATE TABLE IF NOT EXISTS roomserver_event_state_keys ( + -- Local numeric ID for the state key. + event_state_key_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_event_state_key_nid_seq'), + event_state_key TEXT NOT NULL CONSTRAINT roomserver_event_state_key_unique UNIQUE +); +INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key) VALUES + (1, '') ON CONFLICT DO NOTHING; +` + +// Same as insertEventTypeNIDSQL +const insertEventStateKeyNIDSQL = "" + + "INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)" + + " ON CONFLICT ON CONSTRAINT roomserver_event_state_key_unique" + + " DO NOTHING RETURNING (event_state_key_nid)" + +const selectEventStateKeyNIDSQL = "" + + "SELECT event_state_key_nid FROM roomserver_event_state_keys" + + " WHERE event_state_key = $1" + +// Bulk lookup from string state key to numeric ID for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeyNIDSQL = "" + + "SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" + + " WHERE event_state_key = ANY($1)" + +// Bulk lookup from numeric ID to string state key for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeySQL = "" + + "SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys" + + " WHERE event_state_key_nid = ANY($1)" + +type eventStateKeyStatements struct { + insertEventStateKeyNIDStmt *sql.Stmt + selectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyStmt *sql.Stmt +} + +func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(eventStateKeysSchema) + if err != nil { + return + } + return statementList{ + {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, + {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, + {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, + {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, + }.prepare(db) +} + +func (s *eventStateKeyStatements) insertEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + stmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt) + err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err +} + +func (s *eventStateKeyStatements) selectEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt) + err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err +} + +func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( + ctx, pq.StringArray(eventStateKeys), + ) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[stateKey] = types.EventStateKeyNID(stateKeyNID) + } + return result, nil +} + +func (s *eventStateKeyStatements) bulkSelectEventStateKey( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) + for i := range eventStateKeyNIDs { + nIDs[i] = int64(eventStateKeyNIDs[i]) + } + rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[types.EventStateKeyNID(stateKeyNID)] = stateKey + } + return result, nil +} diff --git a/roomserver/storage/event_types_table.go b/roomserver/storage/event_types_table.go new file mode 100644 index 00000000..7b8d53a5 --- /dev/null +++ b/roomserver/storage/event_types_table.go @@ -0,0 +1,146 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventTypesSchema = ` +-- Numeric versions of the event "type"s. Event types tend to be taken from a +-- small common pool. Assigning each a numeric ID should reduce the amount of +-- data that needs to be stored and fetched from the database. +-- It also means that many operations can work with int64 arrays rather than +-- string arrays which may help reduce GC pressure. +-- Well known event types are pre-assigned numeric IDs: +-- 1 -> m.room.create +-- 2 -> m.room.power_levels +-- 3 -> m.room.join_rules +-- 4 -> m.room.third_party_invite +-- 5 -> m.room.member +-- 6 -> m.room.redaction +-- 7 -> m.room.history_visibility +-- Picking well-known numeric IDs for the events types that require special +-- attention during state conflict resolution means that we write that code +-- using numeric constants. +-- It also means that the numeric IDs for common event types should be +-- consistent between different instances which might make ad-hoc debugging +-- easier. +-- Other event types are automatically assigned numeric IDs starting from 2**16. +-- This leaves room to add more pre-assigned numeric IDs and clearly separates +-- the automatically assigned IDs from the pre-assigned IDs. +CREATE SEQUENCE IF NOT EXISTS roomserver_event_type_nid_seq START 65536; +CREATE TABLE IF NOT EXISTS roomserver_event_types ( + -- Local numeric ID for the event type. + event_type_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_event_type_nid_seq'), + -- The string event_type. + event_type TEXT NOT NULL CONSTRAINT roomserver_event_type_unique UNIQUE +); +INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES + (1, 'm.room.create'), + (2, 'm.room.power_levels'), + (3, 'm.room.join_rules'), + (4, 'm.room.third_party_invite'), + (5, 'm.room.member'), + (6, 'm.room.redaction'), + (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; +` + +// Assign a new numeric event type ID. +// The usual case is that the event type is not in the database. +// In that case the ID will be assigned using the next value from the sequence. +// We use `RETURNING` to tell postgres to return the assigned ID. +// But it's possible that the type was added in a query that raced with us. +// This will result in a conflict on the event_type_unique constraint, in this +// case we do nothing. Postgresql won't return a row in that case so we rely on +// the caller catching the sql.ErrNoRows error and running a select to get the row. +// We could get postgresql to return the row on a conflict by updating the row +// but it doesn't seem like a good idea to modify the rows just to make postgresql +// return it. Modifying the rows will cause postgres to assign a new tuple for the +// row even though the data doesn't change resulting in unncesssary modifications +// to the indexes. +const insertEventTypeNIDSQL = "" + + "INSERT INTO roomserver_event_types (event_type) VALUES ($1)" + + " ON CONFLICT ON CONSTRAINT roomserver_event_type_unique" + + " DO NOTHING RETURNING (event_type_nid)" + +const selectEventTypeNIDSQL = "" + + "SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1" + +// Bulk lookup from string event type to numeric ID for that event type. +// Takes an array of strings as the query parameter. +const bulkSelectEventTypeNIDSQL = "" + + "SELECT event_type, event_type_nid FROM roomserver_event_types" + + " WHERE event_type = ANY($1)" + +type eventTypeStatements struct { + insertEventTypeNIDStmt *sql.Stmt + selectEventTypeNIDStmt *sql.Stmt + bulkSelectEventTypeNIDStmt *sql.Stmt +} + +func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(eventTypesSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, + {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, + {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, + }.prepare(db) +} + +func (s *eventTypeStatements) insertEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { + var eventTypeNID int64 + err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err +} + +func (s *eventTypeStatements) selectEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { + var eventTypeNID int64 + err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err +} + +func (s *eventTypeStatements) bulkSelectEventTypeNID( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + result := make(map[string]types.EventTypeNID, len(eventTypes)) + for rows.Next() { + var eventType string + var eventTypeNID int64 + if err := rows.Scan(&eventType, &eventTypeNID); err != nil { + return nil, err + } + result[eventType] = types.EventTypeNID(eventTypeNID) + } + return result, nil +} diff --git a/roomserver/storage/events_table.go b/roomserver/storage/events_table.go new file mode 100644 index 00000000..5bad939f --- /dev/null +++ b/roomserver/storage/events_table.go @@ -0,0 +1,410 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const eventsSchema = ` +-- The events table holds metadata for each event, the actual JSON is stored +-- separately to keep the size of the rows small. +CREATE SEQUENCE IF NOT EXISTS roomserver_event_nid_seq; +CREATE TABLE IF NOT EXISTS roomserver_events ( + -- Local numeric ID for the event. + event_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_event_nid_seq'), + -- Local numeric ID for the room the event is in. + -- This is never 0. + room_nid BIGINT NOT NULL, + -- Local numeric ID for the type of the event. + -- This is never 0. + event_type_nid BIGINT NOT NULL, + -- Local numeric ID for the state_key of the event + -- This is 0 if the event is not a state event. + event_state_key_nid BIGINT NOT NULL, + -- Whether the event has been written to the output log. + sent_to_output BOOLEAN NOT NULL DEFAULT FALSE, + -- Local numeric ID for the state at the event. + -- This is 0 if we don't know the state at the event. + -- If the state is not 0 then this event is part of the contiguous + -- part of the event graph + -- Since many different events can have the same state we store the + -- state into a separate state table and refer to it by numeric ID. + state_snapshot_nid BIGINT NOT NULL DEFAULT 0, + -- Depth of the event in the event graph. + depth BIGINT NOT NULL, + -- The textual event id. + -- Used to lookup the numeric ID when processing requests. + -- Needed for state resolution. + -- An event may only appear in this table once. + event_id TEXT NOT NULL CONSTRAINT roomserver_event_id_unique UNIQUE, + -- The sha256 reference hash for the event. + -- Needed for setting reference hashes when sending new events. + reference_sha256 BYTEA NOT NULL, + -- A list of numeric IDs for events that can authenticate this event. + auth_event_nids BIGINT[] NOT NULL +); +` + +const insertEventSQL = "" + + "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7)" + + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" + + " DO NOTHING" + + " RETURNING event_nid, state_snapshot_nid" + +const selectEventSQL = "" + + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" + +// Bulk lookup of events by string ID. +// Sort by the numeric IDs for event type and state key. +// This means we can use binary search to lookup entries by type and state key. +const bulkSelectStateEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_id = ANY($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + +const bulkSelectStateAtEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" + + " WHERE event_id = ANY($1)" + +const updateEventStateSQL = "" + + "UPDATE roomserver_events SET state_snapshot_nid = $2 WHERE event_nid = $1" + +const selectEventSentToOutputSQL = "" + + "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" + +const updateEventSentToOutputSQL = "" + + "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" + +const selectEventIDSQL = "" + + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" + +const bulkSelectStateAtEventAndReferenceSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + + " FROM roomserver_events WHERE event_nid = ANY($1)" + +const bulkSelectEventReferenceSQL = "" + + "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid = ANY($1)" + +const bulkSelectEventIDSQL = "" + + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid = ANY($1)" + +const bulkSelectEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)" + +const selectMaxEventDepthSQL = "" + + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" + +type eventStatements struct { + insertEventStmt *sql.Stmt + selectEventStmt *sql.Stmt + bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt + selectEventSentToOutputStmt *sql.Stmt + updateEventSentToOutputStmt *sql.Stmt + selectEventIDStmt *sql.Stmt + bulkSelectStateAtEventAndReferenceStmt *sql.Stmt + bulkSelectEventReferenceStmt *sql.Stmt + bulkSelectEventIDStmt *sql.Stmt + bulkSelectEventNIDStmt *sql.Stmt + selectMaxEventDepthStmt *sql.Stmt +} + +func (s *eventStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(eventsSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertEventStmt, insertEventSQL}, + {&s.selectEventStmt, selectEventSQL}, + {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, + {&s.updateEventStateStmt, updateEventStateSQL}, + {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, + {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, + {&s.selectEventIDStmt, selectEventIDSQL}, + {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, + {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, + {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, + {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, + }.prepare(db) +} + +func (s *eventStatements) insertEvent( + ctx context.Context, + roomNID types.RoomNID, + eventTypeNID types.EventTypeNID, + eventStateKeyNID types.EventStateKeyNID, + eventID string, + referenceSHA256 []byte, + authEventNIDs []types.EventNID, + depth int64, +) (types.EventNID, types.StateSnapshotNID, error) { + var eventNID int64 + var stateNID int64 + err := s.insertEventStmt.QueryRowContext( + ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + ).Scan(&eventNID, &stateNID) + return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err +} + +func (s *eventStatements) selectEvent( + ctx context.Context, eventID string, +) (types.EventNID, types.StateSnapshotNID, error) { + var eventNID int64 + var stateNID int64 + err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) + return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err +} + +// bulkSelectStateEventByID lookups a list of state events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) bulkSelectStateEventByID( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if i != len(eventIDs) { + // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. + // We don't know which ones were missing because we don't return the string IDs in the query. + // However it should be possible debug this by replaying queries or entries from the input kafka logs. + // If this turns out to be impossible and we do need the debug information here, it would be better + // to do it as a separate query rather than slowing down/complicating the common case. + return nil, types.MissingEventError( + fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)), + ) + } + return results, err +} + +// bulkSelectStateAtEventByID lookups the state at a list of events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError. +// If we do not have the state for any of the requested events it returns a types.MissingEventError. +func (s *eventStatements) bulkSelectStateAtEventByID( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.StateAtEvent, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + &result.BeforeStateSnapshotNID, + ); err != nil { + return nil, err + } + if result.BeforeStateSnapshotNID == 0 { + return nil, types.MissingEventError( + fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), + ) + } + } + if i != len(eventIDs) { + return nil, types.MissingEventError( + fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)), + ) + } + return results, err +} + +func (s *eventStatements) updateEventState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) + return err +} + +func (s *eventStatements) selectEventSentToOutput( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (sentToOutput bool, err error) { + stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt) + err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) + return +} + +func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { + stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := stmt.ExecContext(ctx, int64(eventNID)) + return err +} + +func (s *eventStatements) selectEventID( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (eventID string, err error) { + stmt := common.TxStmt(txn, s.selectEventIDStmt) + err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) + return +} + +func (s *eventStatements) bulkSelectStateAtEventAndReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]types.StateAtEventAndReference, error) { + stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.StateAtEventAndReference, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + var ( + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + stateSnapshotNID int64 + eventID string + eventSHA256 []byte + ) + if err = rows.Scan( + &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, + ); err != nil { + return nil, err + } + result := &results[i] + result.EventTypeNID = types.EventTypeNID(eventTypeNID) + result.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + result.EventNID = types.EventNID(eventNID) + result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) + result.EventID = eventID + result.EventSHA256 = eventSHA256 + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +func (s *eventStatements) bulkSelectEventReference( + ctx context.Context, eventNIDs []types.EventNID, +) ([]gomatrixserverlib.EventReference, error) { + rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { + return nil, err + } + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +// bulkSelectEventID returns a map from numeric event ID to string event ID. +func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make(map[types.EventNID]string, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + var eventNID int64 + var eventID string + if err = rows.Scan(&eventNID, &eventID); err != nil { + return nil, err + } + results[types.EventNID(eventNID)] = eventID + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { + rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make(map[string]types.EventNID, len(eventIDs)) + for rows.Next() { + var eventID string + var eventNID int64 + if err = rows.Scan(&eventID, &eventNID); err != nil { + return nil, err + } + results[eventID] = types.EventNID(eventNID) + } + return results, nil +} + +func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { + var result int64 + stmt := s.selectMaxEventDepthStmt + err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result) + if err != nil { + return 0, err + } + return result, nil +} + +func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { + nids := make([]int64, len(eventNIDs)) + for i := range eventNIDs { + nids[i] = int64(eventNIDs[i]) + } + return nids +} diff --git a/roomserver/storage/invite_table.go b/roomserver/storage/invite_table.go new file mode 100644 index 00000000..4f9cdfb4 --- /dev/null +++ b/roomserver/storage/invite_table.go @@ -0,0 +1,154 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const inviteSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_invites ( + -- The string ID of the invite event itself. + -- We can't use a numeric event ID here because we don't always have + -- enough information to store an invite in the event table. + -- In particular we don't always have a chain of auth_events for invites + -- received over federation. + invite_event_id TEXT PRIMARY KEY, + -- The numeric ID of the room the invite m.room.member event is in. + room_nid BIGINT NOT NULL, + -- The numeric ID for the state key of the invite m.room.member event. + -- This tells us who the invite is for. + -- This is used to query the active invites for a user. + target_nid BIGINT NOT NULL, + -- The numeric ID for the sender of the invite m.room.member event. + -- This tells us who sent the invite. + -- This is used to work out which matrix server we should talk to when + -- we try to join the room. + sender_nid BIGINT NOT NULL DEFAULT 0, + -- This is used to track whether the invite is still active. + -- This is set implicitly when processing new join and leave events and + -- explicitly when rejecting events over federation. + retired BOOLEAN NOT NULL DEFAULT FALSE, + -- The invite event JSON. + invite_event_json TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid) + WHERE NOT retired; +` +const insertInviteEventSQL = "" + + "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," + + " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" + + " ON CONFLICT DO NOTHING" + +const selectInviteActiveForUserInRoomSQL = "" + + "SELECT sender_nid FROM roomserver_invites" + + " WHERE target_nid = $1 AND room_nid = $2" + + " AND NOT retired" + +// Retire every active invite for a user in a room. +// Ideally we'd know which invite events were retired by a given update so we +// wouldn't need to remove every active invite. +// However the matrix protocol doesn't give us a way to reliably identify the +// invites that were retired, so we are forced to retire all of them. +const updateInviteRetiredSQL = "" + + "UPDATE roomserver_invites SET retired = TRUE" + + " WHERE room_nid = $1 AND target_nid = $2 AND NOT retired" + + " RETURNING invite_event_id" + +type inviteStatements struct { + insertInviteEventStmt *sql.Stmt + selectInviteActiveForUserInRoomStmt *sql.Stmt + updateInviteRetiredStmt *sql.Stmt +} + +func (s *inviteStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(inviteSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, + {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, + }.prepare(db) +} + +func (s *inviteStatements) insertInviteEvent( + ctx context.Context, + txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + targetUserNID, senderUserNID types.EventStateKeyNID, + inviteEventJSON []byte, +) (bool, error) { + result, err := common.TxStmt(txn, s.insertInviteEventStmt).ExecContext( + ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, + ) + if err != nil { + return false, err + } + count, err := result.RowsAffected() + if err != nil { + return false, err + } + return count != 0, nil +} + +func (s *inviteStatements) updateInviteRetired( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (eventIDs []string, err error) { + stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) + rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) + if err != nil { + return nil, err + } + defer (func() { err = rows.Close() })() + for rows.Next() { + var inviteEventID string + if err := rows.Scan(&inviteEventID); err != nil { + return nil, err + } + eventIDs = append(eventIDs, inviteEventID) + } + return +} + +// selectInviteActiveForUserInRoom returns a list of sender state key NIDs +func (s *inviteStatements) selectInviteActiveForUserInRoom( + ctx context.Context, + targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, +) ([]types.EventStateKeyNID, error) { + rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + ctx, targetUserNID, roomNID, + ) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + var result []types.EventStateKeyNID + for rows.Next() { + var senderUserNID int64 + if err := rows.Scan(&senderUserNID); err != nil { + return nil, err + } + result = append(result, types.EventStateKeyNID(senderUserNID)) + } + return result, nil +} diff --git a/roomserver/storage/membership_table.go b/roomserver/storage/membership_table.go new file mode 100644 index 00000000..88a9ed72 --- /dev/null +++ b/roomserver/storage/membership_table.go @@ -0,0 +1,193 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +type membershipState int64 + +const ( + membershipStateLeaveOrBan membershipState = 1 + membershipStateInvite membershipState = 2 + membershipStateJoin membershipState = 3 +) + +const membershipSchema = ` +-- The membership table is used to coordinate updates between the invite table +-- and the room state tables. +-- This table is updated in one of 3 ways: +-- 1) The membership of a user changes within the current state of the room. +-- 2) An invite is received outside of a room over federation. +-- 3) An invite is rejected outside of a room over federation. +CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid BIGINT NOT NULL, + -- Numeric state key ID for the user ID this state is for. + target_nid BIGINT NOT NULL, + -- Numeric state key ID for the user ID who changed the state. + -- This may be 0 since it is not always possible to identify the user that + -- changed the state. + sender_nid BIGINT NOT NULL DEFAULT 0, + -- The state the user is in within this room. + -- Default value is "membershipStateLeaveOrBan" + membership_nid BIGINT NOT NULL DEFAULT 1, + -- The numeric ID of the membership event. + -- It refers to the join membership event if the membership_nid is join (3), + -- and to the leave/ban membership event if the membership_nid is leave or + -- ban (1). + -- If the membership_nid is invite (2) and the user has been in the room + -- before, it will refer to the previous leave/ban membership event, and will + -- be equals to 0 (its default) if the user never joined the room before. + -- This NID is updated if the join event gets updated (e.g. profile update), + -- or if the user leaves/joins the room. + event_nid BIGINT NOT NULL DEFAULT 0, + UNIQUE (room_nid, target_nid) +); +` + +// Insert a row in to membership table so that it can be locked by the +// SELECT FOR UPDATE +const insertMembershipSQL = "" + + "INSERT INTO roomserver_membership (room_nid, target_nid)" + + " VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectMembershipFromRoomAndTargetSQL = "" + + "SELECT membership_nid, event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND target_nid = $2" + +const selectMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + +const selectMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + +const selectMembershipForUpdateSQL = "" + + "SELECT membership_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" + +const updateMembershipSQL = "" + + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + + " WHERE room_nid = $1 AND target_nid = $2" + +type membershipStatements struct { + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt +} + +func (s *membershipStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(membershipSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertMembershipStmt, insertMembershipSQL}, + {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, + {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, + {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.updateMembershipStmt, updateMembershipSQL}, + }.prepare(db) +} + +func (s *membershipStatements) insertMembership( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) error { + stmt := common.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + return err +} + +func (s *membershipStatements) selectMembershipForUpdate( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (membership membershipState, err error) { + err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( + ctx, roomNID, targetUserNID, + ).Scan(&membership) + return +} + +func (s *membershipStatements) selectMembershipFromRoomAndTarget( + ctx context.Context, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (eventNID types.EventNID, membership membershipState, err error) { + err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + ctx, roomNID, targetUserNID, + ).Scan(&membership, &eventNID) + return +} + +func (s *membershipStatements) selectMembershipsFromRoom( + ctx context.Context, roomNID types.RoomNID, +) (eventNIDs []types.EventNID, err error) { + rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID) + if err != nil { + return + } + + for rows.Next() { + var eNID types.EventNID + if err = rows.Scan(&eNID); err != nil { + return + } + eventNIDs = append(eventNIDs, eNID) + } + return +} +func (s *membershipStatements) selectMembershipsFromRoomAndMembership( + ctx context.Context, + roomNID types.RoomNID, membership membershipState, +) (eventNIDs []types.EventNID, err error) { + stmt := s.selectMembershipsFromRoomAndMembershipStmt + rows, err := stmt.QueryContext(ctx, roomNID, membership) + if err != nil { + return + } + + for rows.Next() { + var eNID types.EventNID + if err = rows.Scan(&eNID); err != nil { + return + } + eventNIDs = append(eventNIDs, eNID) + } + return +} + +func (s *membershipStatements) updateMembership( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + senderUserNID types.EventStateKeyNID, membership membershipState, + eventNID types.EventNID, +) error { + _, err := common.TxStmt(txn, s.updateMembershipStmt).ExecContext( + ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, + ) + return err +} diff --git a/roomserver/storage/prepare.go b/roomserver/storage/prepare.go new file mode 100644 index 00000000..61c49a3c --- /dev/null +++ b/roomserver/storage/prepare.go @@ -0,0 +1,36 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "database/sql" +) + +// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type statementList []struct { + statement **sql.Stmt + sql string +} + +// prepare the SQL for each statement in the list and assign the result to the prepared statement. +// nolint: safesql +func (s statementList) prepare(db *sql.DB) (err error) { + for _, statement := range s { + if *statement.statement, err = db.Prepare(statement.sql); err != nil { + return + } + } + return +} diff --git a/roomserver/storage/previous_events_table.go b/roomserver/storage/previous_events_table.go new file mode 100644 index 00000000..81d581a9 --- /dev/null +++ b/roomserver/storage/previous_events_table.go @@ -0,0 +1,99 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const previousEventSchema = ` +-- The previous events table stores the event_ids referenced by the events +-- stored in the events table. +-- This is used to tell if a new event is already referenced by an event in +-- the database. +CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + -- The string event ID taken from the prev_events key of an event. + previous_event_id TEXT NOT NULL, + -- The SHA256 reference hash taken from the prev_events key of an event. + previous_reference_sha256 BYTEA NOT NULL, + -- A list of numeric event IDs of events that reference this prev_event. + event_nids BIGINT[] NOT NULL, + CONSTRAINT roomserver_previous_event_id_unique UNIQUE (previous_event_id, previous_reference_sha256) +); +` + +// Insert an entry into the previous_events table. +// If there is already an entry indicating that an event references that previous event then +// add the event NID to the list to indicate that this event references that previous event as well. +// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room. +// The lock is necessary to avoid data races when checking whether an event is already referenced by another event. +const insertPreviousEventSQL = "" + + "INSERT INTO roomserver_previous_events" + + " (previous_event_id, previous_reference_sha256, event_nids)" + + " VALUES ($1, $2, array_append('{}'::bigint[], $3))" + + " ON CONFLICT ON CONSTRAINT roomserver_previous_event_id_unique" + + " DO UPDATE SET event_nids = array_append(roomserver_previous_events.event_nids, $3)" + + " WHERE $3 != ALL(roomserver_previous_events.event_nids)" + +// Check if the event is referenced by another event in the table. +// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. +const selectPreviousEventExistsSQL = "" + + "SELECT 1 FROM roomserver_previous_events" + + " WHERE previous_event_id = $1 AND previous_reference_sha256 = $2" + +type previousEventStatements struct { + insertPreviousEventStmt *sql.Stmt + selectPreviousEventExistsStmt *sql.Stmt +} + +func (s *previousEventStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(previousEventSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertPreviousEventStmt, insertPreviousEventSQL}, + {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, + }.prepare(db) +} + +func (s *previousEventStatements) insertPreviousEvent( + ctx context.Context, + txn *sql.Tx, + previousEventID string, + previousEventReferenceSHA256 []byte, + eventNID types.EventNID, +) error { + stmt := common.TxStmt(txn, s.insertPreviousEventStmt) + _, err := stmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ) + return err +} + +// Check if the event reference exists +// Returns sql.ErrNoRows if the event reference doesn't exist. +func (s *previousEventStatements) selectPreviousEventExists( + ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, +) error { + var ok int64 + stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt) + return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) +} diff --git a/roomserver/storage/room_aliases_table.go b/roomserver/storage/room_aliases_table.go new file mode 100644 index 00000000..f640c37f --- /dev/null +++ b/roomserver/storage/room_aliases_table.go @@ -0,0 +1,109 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" +) + +const roomAliasesSchema = ` +-- Stores room aliases and room IDs they refer to +CREATE TABLE IF NOT EXISTS roomserver_room_aliases ( + -- Alias of the room + alias TEXT NOT NULL PRIMARY KEY, + -- Room ID the alias refers to + room_id TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); +` + +const insertRoomAliasSQL = "" + + "INSERT INTO roomserver_room_aliases (alias, room_id) VALUES ($1, $2)" + +const selectRoomIDFromAliasSQL = "" + + "SELECT room_id FROM roomserver_room_aliases WHERE alias = $1" + +const selectAliasesFromRoomIDSQL = "" + + "SELECT alias FROM roomserver_room_aliases WHERE room_id = $1" + +const deleteRoomAliasSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE alias = $1" + +type roomAliasesStatements struct { + insertRoomAliasStmt *sql.Stmt + selectRoomIDFromAliasStmt *sql.Stmt + selectAliasesFromRoomIDStmt *sql.Stmt + deleteRoomAliasStmt *sql.Stmt +} + +func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(roomAliasesSchema) + if err != nil { + return + } + return statementList{ + {&s.insertRoomAliasStmt, insertRoomAliasSQL}, + {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, + {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, + {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, + }.prepare(db) +} + +func (s *roomAliasesStatements) insertRoomAlias( + ctx context.Context, alias string, roomID string, +) (err error) { + _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID) + return +} + +func (s *roomAliasesStatements) selectRoomIDFromAlias( + ctx context.Context, alias string, +) (roomID string, err error) { + err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *roomAliasesStatements) selectAliasesFromRoomID( + ctx context.Context, roomID string, +) (aliases []string, err error) { + aliases = []string{} + rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + if err != nil { + return + } + + for rows.Next() { + var alias string + if err = rows.Scan(&alias); err != nil { + return + } + + aliases = append(aliases, alias) + } + + return +} + +func (s *roomAliasesStatements) deleteRoomAlias( + ctx context.Context, alias string, +) (err error) { + _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) + return +} diff --git a/roomserver/storage/rooms_table.go b/roomserver/storage/rooms_table.go new file mode 100644 index 00000000..64193ffe --- /dev/null +++ b/roomserver/storage/rooms_table.go @@ -0,0 +1,155 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const roomsSchema = ` +CREATE SEQUENCE IF NOT EXISTS roomserver_room_nid_seq; +CREATE TABLE IF NOT EXISTS roomserver_rooms ( + -- Local numeric ID for the room. + room_nid BIGINT PRIMARY KEY DEFAULT nextval('roomserver_room_nid_seq'), + -- Textual ID for the room. + room_id TEXT NOT NULL CONSTRAINT roomserver_room_id_unique UNIQUE, + -- The most recent events in the room that aren't referenced by another event. + -- This list may empty if the server hasn't joined the room yet. + -- (The server will be in that state while it stores the events for the initial state of the room) + latest_event_nids BIGINT[] NOT NULL DEFAULT '{}'::BIGINT[], + -- The last event written to the output log for this room. + last_event_sent_nid BIGINT NOT NULL DEFAULT 0, + -- The state of the room after the current set of latest events. + -- This will be 0 if there are no latest events in the room. + state_snapshot_nid BIGINT NOT NULL DEFAULT 0 +); +` + +// Same as insertEventTypeNIDSQL +const insertRoomNIDSQL = "" + + "INSERT INTO roomserver_rooms (room_id) VALUES ($1)" + + " ON CONFLICT ON CONSTRAINT roomserver_room_id_unique" + + " DO NOTHING RETURNING (room_nid)" + +const selectRoomNIDSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + +const selectLatestEventNIDsSQL = "" + + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" + +const selectLatestEventNIDsForUpdateSQL = "" + + "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1 FOR UPDATE" + +const updateLatestEventNIDsSQL = "" + + "UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1" + +type roomStatements struct { + insertRoomNIDStmt *sql.Stmt + selectRoomNIDStmt *sql.Stmt + selectLatestEventNIDsStmt *sql.Stmt + selectLatestEventNIDsForUpdateStmt *sql.Stmt + updateLatestEventNIDsStmt *sql.Stmt +} + +func (s *roomStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(roomsSchema) + if err != nil { + return + } + return statementList{ + {&s.insertRoomNIDStmt, insertRoomNIDSQL}, + {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, + {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, + {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, + }.prepare(db) +} + +func (s *roomStatements) insertRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := common.TxStmt(txn, s.insertRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + +func (s *roomStatements) selectRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := common.TxStmt(txn, s.selectRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + +func (s *roomStatements) selectLatestEventNIDs( + ctx context.Context, roomNID types.RoomNID, +) ([]types.EventNID, types.StateSnapshotNID, error) { + var nids pq.Int64Array + var stateSnapshotNID int64 + stmt := s.selectLatestEventNIDsStmt + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) + if err != nil { + return nil, 0, err + } + eventNIDs := make([]types.EventNID, len(nids)) + for i := range nids { + eventNIDs[i] = types.EventNID(nids[i]) + } + return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil +} + +func (s *roomStatements) selectLatestEventsNIDsForUpdate( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { + var nids pq.Int64Array + var lastEventSentNID int64 + var stateSnapshotNID int64 + stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) + if err != nil { + return nil, 0, 0, err + } + eventNIDs := make([]types.EventNID, len(nids)) + for i := range nids { + eventNIDs[i] = types.EventNID(nids[i]) + } + return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil +} + +func (s *roomStatements) updateLatestEventNIDs( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + eventNIDs []types.EventNID, + lastEventSentNID types.EventNID, + stateSnapshotNID types.StateSnapshotNID, +) error { + stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt) + _, err := stmt.ExecContext( + ctx, + roomNID, + eventNIDsAsArray(eventNIDs), + int64(lastEventSentNID), + int64(stateSnapshotNID), + ) + return err +} diff --git a/roomserver/storage/sql.go b/roomserver/storage/sql.go new file mode 100644 index 00000000..05efa8dd --- /dev/null +++ b/roomserver/storage/sql.go @@ -0,0 +1,59 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "database/sql" +) + +type statements struct { + eventTypeStatements + eventStateKeyStatements + roomStatements + eventStatements + eventJSONStatements + stateSnapshotStatements + stateBlockStatements + previousEventStatements + roomAliasesStatements + inviteStatements + membershipStatements + transactionStatements +} + +func (s *statements) prepare(db *sql.DB) error { + var err error + + for _, prepare := range []func(db *sql.DB) error{ + s.eventTypeStatements.prepare, + s.eventStateKeyStatements.prepare, + s.roomStatements.prepare, + s.eventStatements.prepare, + s.eventJSONStatements.prepare, + s.stateSnapshotStatements.prepare, + s.stateBlockStatements.prepare, + s.previousEventStatements.prepare, + s.roomAliasesStatements.prepare, + s.inviteStatements.prepare, + s.membershipStatements.prepare, + s.transactionStatements.prepare, + } { + if err = prepare(db); err != nil { + return err + } + } + + return nil +} diff --git a/roomserver/storage/state_block_table.go b/roomserver/storage/state_block_table.go new file mode 100644 index 00000000..b2e8ef8a --- /dev/null +++ b/roomserver/storage/state_block_table.go @@ -0,0 +1,280 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + "fmt" + "sort" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" +) + +const stateDataSchema = ` +-- The state data map. +-- Designed to give enough information to run the state resolution algorithm +-- without hitting the database in the common case. +-- TODO: Is it worth replacing the unique btree index with a covering index so +-- that postgres could lookup the state using an index-only scan? +-- The type and state_key are included in the index to make it easier to +-- lookup a specific (type, state_key) pair for an event. It also makes it easy +-- to read the state for a given state_block_nid ordered by (type, state_key) +-- which in turn makes it easier to merge state data blocks. +CREATE SEQUENCE IF NOT EXISTS roomserver_state_block_nid_seq; +CREATE TABLE IF NOT EXISTS roomserver_state_block ( + -- Local numeric ID for this state data. + state_block_nid bigint NOT NULL, + event_type_nid bigint NOT NULL, + event_state_key_nid bigint NOT NULL, + event_nid bigint NOT NULL, + UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) +); +` + +const insertStateDataSQL = "" + + "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + + " VALUES ($1, $2, $3, $4)" + +const selectNextStateBlockNIDSQL = "" + + "SELECT nextval('roomserver_state_block_nid_seq')" + +// Bulk state lookup by numeric state block ID. +// Sort by the state_block_nid, event_type_nid, event_state_key_nid +// This means that all the entries for a given state_block_nid will appear +// together in the list and those entries will sorted by event_type_nid +// and event_state_key_nid. This property makes it easier to merge two +// state data blocks together. +const bulkSelectStateBlockEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + +// Bulk state lookup by numeric state block ID. +// Filters the rows in each block to the requested types and state keys. +// We would like to restrict to particular type state key pairs but we are +// restricted by the query language to pull the cross product of a list +// of types and a list state_keys. So we have to filter the result in the +// application to restrict it to the list of event types and state keys we +// actually wanted. +const bulkSelectFilteredStateBlockEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + + " AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + +type stateBlockStatements struct { + insertStateDataStmt *sql.Stmt + selectNextStateBlockNIDStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt + bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt +} + +func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(stateDataSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertStateDataStmt, insertStateDataSQL}, + {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, + {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, + {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, + }.prepare(db) +} + +func (s *stateBlockStatements) bulkInsertStateData( + ctx context.Context, + stateBlockNID types.StateBlockNID, + entries []types.StateEntry, +) error { + for _, entry := range entries { + _, err := s.insertStateDataStmt.ExecContext( + ctx, + int64(stateBlockNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), + ) + if err != nil { + return err + } + } + return nil +} + +func (s *stateBlockStatements) selectNextStateBlockNID( + ctx context.Context, +) (types.StateBlockNID, error) { + var stateBlockNID int64 + err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID) + return types.StateBlockNID(stateBlockNID), err +} + +func (s *stateBlockStatements) bulkSelectStateBlockEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + results := make([]types.StateEntryList, len(stateBlockNIDs)) + // current is a pointer to the StateEntryList to append the state entries to. + var current *types.StateEntryList + i := 0 + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we start appending to the next entry in the list. + current = &results[i] + current.StateBlockNID = types.StateBlockNID(stateBlockNID) + i++ + } + current.StateEntries = append(current.StateEntries, entry) + } + if i != len(stateBlockNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) + } + return results, nil +} + +func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + tuples := stateKeyTupleSorter(stateKeyTuples) + // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. + sort.Sort(tuples) + + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext( + ctx, + stateBlockNIDsAsArray(stateBlockNIDs), + eventTypeNIDArray, + eventStateKeyNIDArray, + ) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + var results []types.StateEntryList + var current types.StateEntryList + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + + // We can use binary search here because we sorted the tuples earlier + if !tuples.contains(entry.StateKeyTuple) { + // The select will return the cross product of types and state keys. + // So we need to check if type of the entry is in the list. + continue + } + + if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we append the current entry to the results and start adding to a new one. + // The first time through the loop current will be empty. + if current.StateEntries != nil { + results = append(results, current) + } + current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} + } + current.StateEntries = append(current.StateEntries, entry) + } + // Add the last entry to the list if it is not empty. + if current.StateEntries != nil { + results = append(results, current) + } + return results, nil +} + +func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + return pq.Int64Array(nids) +} + +type stateKeyTupleSorter []types.StateKeyTuple + +func (s stateKeyTupleSorter) Len() int { return len(s) } +func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Check whether a tuple is in the list. Assumes that the list is sorted. +func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { + i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) + return i < len(s) && s[i] == value +} + +// List the unique eventTypeNIDs and eventStateKeyNIDs. +// Assumes that the list is sorted. +func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) { + eventTypeNIDs = make(pq.Int64Array, len(s)) + eventStateKeyNIDs = make(pq.Int64Array, len(s)) + for i := range s { + eventTypeNIDs[i] = int64(s[i].EventTypeNID) + eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) + } + eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] + eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] + return +} + +type int64Sorter []int64 + +func (s int64Sorter) Len() int { return len(s) } +func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } +func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/state_block_table_test.go b/roomserver/storage/state_block_table_test.go new file mode 100644 index 00000000..f891b5bc --- /dev/null +++ b/roomserver/storage/state_block_table_test.go @@ -0,0 +1,85 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "sort" + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +func TestStateKeyTupleSorter(t *testing.T) { + input := stateKeyTupleSorter{ + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 1}, + } + want := []types.StateKeyTuple{ + {EventTypeNID: 1, EventStateKeyNID: 1}, + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + } + doNotWant := []types.StateKeyTuple{ + {EventTypeNID: 0, EventStateKeyNID: 0}, + {EventTypeNID: 1, EventStateKeyNID: 3}, + {EventTypeNID: 2, EventStateKeyNID: 1}, + {EventTypeNID: 3, EventStateKeyNID: 1}, + } + wantTypeNIDs := []int64{1, 2} + wantStateKeyNIDs := []int64{1, 2, 4} + + // Sort the input and check it's in the right order. + sort.Sort(input) + gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() + + for i := range want { + if input[i] != want[i] { + t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) + } + + if !input.contains(want[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) + } + } + + for i := range doNotWant { + if input.contains(doNotWant[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) + } + } + + if len(wantTypeNIDs) != len(gotTypeNIDs) { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + + for i := range wantTypeNIDs { + if wantTypeNIDs[i] != gotTypeNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } + + if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { + t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) + } + + for i := range wantStateKeyNIDs { + if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } +} diff --git a/roomserver/storage/state_snapshot_table.go b/roomserver/storage/state_snapshot_table.go new file mode 100644 index 00000000..aa14daad --- /dev/null +++ b/roomserver/storage/state_snapshot_table.go @@ -0,0 +1,118 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const stateSnapshotSchema = ` +-- The state of a room before an event. +-- Stored as a list of state_block entries stored in a separate table. +-- The actual state is constructed by combining all the state_block entries +-- referenced by state_block_nids together. If the same state key tuple appears +-- multiple times then the entry from the later state_block clobbers the earlier +-- entries. +-- This encoding format allows us to implement a delta encoding which is useful +-- because room state tends to accumulate small changes over time. Although if +-- the list of deltas becomes too long it becomes more efficient to encode +-- the full state under single state_block_nid. +CREATE SEQUENCE IF NOT EXISTS roomserver_state_snapshot_nid_seq; +CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( + -- Local numeric ID for the state. + state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'), + -- Local numeric ID of the room this state is for. + -- Unused in normal operation, but useful for background work or ad-hoc debugging. + room_nid bigint NOT NULL, + -- List of state_block_nids, stored sorted by state_block_nid. + state_block_nids bigint[] NOT NULL +); +` + +const insertStateSQL = "" + + "INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" + + " VALUES ($1, $2)" + + " RETURNING state_snapshot_nid" + +// Bulk state data NID lookup. +// Sorting by state_snapshot_nid means we can use binary search over the result +// to lookup the state data NIDs for a state snapshot NID. +const bulkSelectStateBlockNIDsSQL = "" + + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + + " WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC" + +type stateSnapshotStatements struct { + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt +} + +func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(stateSnapshotSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertStateStmt, insertStateSQL}, + {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + }.prepare(db) +} + +func (s *stateSnapshotStatements) insertState( + ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, +) (stateNID types.StateSnapshotNID, err error) { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + return +} + +func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + nids := make([]int64, len(stateNIDs)) + for i := range stateNIDs { + nids[i] = int64(stateNIDs[i]) + } + rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.StateBlockNIDList, len(stateNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var stateBlockNIDs pq.Int64Array + if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { + return nil, err + } + result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs)) + for k := range stateBlockNIDs { + result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k]) + } + } + if i != len(stateNIDs) { + return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) + } + return results, nil +} diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go new file mode 100644 index 00000000..f6c2fccd --- /dev/null +++ b/roomserver/storage/storage.go @@ -0,0 +1,705 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + // Import the postgres database driver. + _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// A Database is used to store room events and stream offsets. +type Database struct { + statements statements + db *sql.DB +} + +// Open a postgres database. +func Open(dataSourceName string) (*Database, error) { + var d Database + var err error + if d.db, err = sql.Open("postgres", dataSourceName); err != nil { + return nil, err + } + if err = d.statements.prepare(d.db); err != nil { + return nil, err + } + return &d, nil +} + +// StoreEvent implements input.EventDatabase +func (d *Database) StoreEvent( + ctx context.Context, event gomatrixserverlib.Event, + txnAndDeviceID *api.TransactionID, authEventNIDs []types.EventNID, +) (types.RoomNID, types.StateAtEvent, error) { + var ( + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error + ) + + if txnAndDeviceID != nil { + if err = d.statements.insertTransaction( + ctx, txnAndDeviceID.TransactionID, + txnAndDeviceID.DeviceID, event.Sender(), event.EventID(), + ); err != nil { + return 0, types.StateAtEvent{}, err + } + } + + if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil { + return 0, types.StateAtEvent{}, err + } + + if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil { + return 0, types.StateAtEvent{}, err + } + + eventStateKey := event.StateKey() + // Assigned a numeric ID for the state_key if there is one present. + // Otherwise set the numeric ID for the state_key to 0. + if eventStateKey != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil { + return 0, types.StateAtEvent{}, err + } + } + + if eventNID, stateNID, err = d.statements.insertEvent( + ctx, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) + } + if err != nil { + return 0, types.StateAtEvent{}, err + } + } + + if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateSnapshotNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil +} + +func (d *Database) assignRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + // Check if we already have a numeric ID in the database. + roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + } + } + return roomNID, err +} + +func (d *Database) assignEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { + // Check if we already have a numeric ID in the database. + eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType) + } + } + return eventTypeNID, err +} + +func (d *Database) assignStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + // Check if we already have a numeric ID in the database. + eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + } + } + return eventStateKeyNID, err +} + +// StateEntriesForEventIDs implements input.EventDatabase +func (d *Database) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return d.statements.bulkSelectStateEventByID(ctx, eventIDs) +} + +// EventTypeNIDs implements state.RoomStateDatabase +func (d *Database) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.statements.bulkSelectEventTypeNID(ctx, eventTypes) +} + +// EventStateKeyNIDs implements state.RoomStateDatabase +func (d *Database) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys) +} + +// EventStateKeys implements query.RoomserverQueryAPIDatabase +func (d *Database) EventStateKeys( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs) +} + +// EventNIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.statements.bulkSelectEventNID(ctx, eventIDs) +} + +// Events implements input.EventDatabase +func (d *Database) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) + if err != nil { + return nil, err + } + results := make([]types.Event, len(eventJSONs)) + for i, eventJSON := range eventJSONs { + result := &results[i] + result.EventNID = eventJSON.EventNID + // TODO: Use NewEventFromTrustedJSON for efficiency + result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON) + if err != nil { + return nil, err + } + } + return results, nil +} + +// AddState implements input.EventDatabase +func (d *Database) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (types.StateSnapshotNID, error) { + if len(state) > 0 { + stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx) + if err != nil { + return 0, err + } + if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil { + return 0, err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + } + + return d.statements.insertState(ctx, roomNID, stateBlockNIDs) +} + +// SetState implements input.EventDatabase +func (d *Database) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return d.statements.updateEventState(ctx, eventNID, stateNID) +} + +// StateAtEventIDs implements input.EventDatabase +func (d *Database) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) +} + +// StateBlockNIDs implements state.RoomStateDatabase +func (d *Database) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs) +} + +// StateEntries implements state.RoomStateDatabase +func (d *Database) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) +} + +// SnapshotNIDFromEventID implements state.RoomStateDatabase +func (d *Database) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.statements.selectEvent(ctx, eventID) + return stateNID, err +} + +// EventIDs implements input.RoomEventDatabase +func (d *Database) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return d.statements.bulkSelectEventID(ctx, eventNIDs) +} + +// GetLatestEventsForUpdate implements input.EventDatabase +func (d *Database) GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, +) (types.RoomRecentEventsUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + } + return &roomRecentEventsUpdater{ + transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// GetTransactionEventID implements input.EventDatabase +func (d *Database) GetTransactionEventID( + ctx context.Context, transactionID string, + deviceID string, userID string, +) (string, error) { + eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, deviceID, userID) + if err == sql.ErrNoRows { + return "", nil + } + return eventID, err +} + +type roomRecentEventsUpdater struct { + transaction + d *Database + roomNID types.RoomNID + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + for _, ref := range previousEventReferences { + if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return err + } + } + return nil +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { + err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, err +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { + return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { + return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) +} + +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) +} + +// RoomNID implements query.RoomserverQueryAPIDB +func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { + roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) + if err == sql.ErrNoRows { + return 0, nil + } + return roomNID, err +} + +// LatestEventIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, +) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { + eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) + if err != nil { + return nil, 0, 0, err + } + references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) + if err != nil { + return nil, 0, 0, err + } + depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) + if err != nil { + return nil, 0, 0, err + } + return references, currentStateSnapshotNID, depth, nil +} + +// GetInvitesForUser implements query.RoomserverQueryAPIDatabase +func (d *Database) GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (senderUserIDs []types.EventStateKeyNID, err error) { + return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) +} + +// SetRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string) error { + return d.statements.insertRoomAlias(ctx, alias, roomID) +} + +// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { + return d.statements.selectRoomIDFromAlias(ctx, alias) +} + +// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB +func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { + return d.statements.selectAliasesFromRoomID(ctx, roomID) +} + +// RemoveRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { + return d.statements.deleteRoomAlias(ctx, alias) +} + +// StateEntriesForTuples implements state.RoomStateDatabase +func (d *Database) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectFilteredStateBlockEntries( + ctx, stateBlockNIDs, stateKeyTuples, + ) +} + +// MembershipUpdater implements input.RoomEventDatabase +func (d *Database) MembershipUpdater( + ctx context.Context, roomID, targetUserID string, +) (types.MembershipUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + succeeded := false + defer func() { + if !succeeded { + txn.Rollback() // nolint: errcheck + } + }() + + roomNID, err := d.assignRoomNID(ctx, txn, roomID) + if err != nil { + return nil, err + } + + targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) + if err != nil { + return nil, err + } + + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + succeeded = true + return updater, nil +} + +type membershipUpdater struct { + transaction + d *Database + roomNID types.RoomNID + targetUserNID types.EventStateKeyNID + membership membershipState +} + +func (d *Database) membershipUpdaterTxn( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (types.MembershipUpdater, error) { + + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + return nil, err + } + + membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + return &membershipUpdater{ + transaction{ctx, txn}, d, roomNID, targetUserNID, membership, + }, nil +} + +// IsInvite implements types.MembershipUpdater +func (u *membershipUpdater) IsInvite() bool { + return u.membership == membershipStateInvite +} + +// IsJoin implements types.MembershipUpdater +func (u *membershipUpdater) IsJoin() bool { + return u.membership == membershipStateJoin +} + +// IsLeave implements types.MembershipUpdater +func (u *membershipUpdater) IsLeave() bool { + return u.membership == membershipStateLeaveOrBan +} + +// SetToInvite implements types.MembershipUpdater +func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + if err != nil { + return false, err + } + inserted, err := u.d.statements.insertInviteEvent( + u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return false, err + } + if u.membership != membershipStateInvite { + if err = u.d.statements.updateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + ); err != nil { + return false, err + } + } + return inserted, nil +} + +// SetToJoin implements types.MembershipUpdater +func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { + var inviteEventIDs []string + + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + } + + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != membershipStateJoin || isUpdate { + if err = u.d.statements.updateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateJoin, nIDs[eventID], + ); err != nil { + return nil, err + } + } + + return inviteEventIDs, nil +} + +// SetToLeave implements types.MembershipUpdater +func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + inviteEventIDs, err := u.d.statements.updateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != membershipStateLeaveOrBan { + if err = u.d.statements.updateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return nil, err + } + } + return inviteEventIDs, nil +} + +// GetMembership implements query.RoomserverQueryAPIDB +func (d *Database) GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, +) (membershipEventNID types.EventNID, stillInRoom bool, err error) { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) + if err != nil { + return + } + + senderMembershipEventNID, senderMembership, err := + d.statements.selectMembershipFromRoomAndTarget( + ctx, roomNID, requestSenderUserNID, + ) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return 0, false, nil + } else if err != nil { + return + } + + return senderMembershipEventNID, senderMembership == membershipStateJoin, nil +} + +// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB +func (d *Database) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, +) ([]types.EventNID, error) { + if joinOnly { + return d.statements.selectMembershipsFromRoomAndMembership( + ctx, roomNID, membershipStateJoin, + ) + } + + return d.statements.selectMembershipsFromRoom(ctx, roomNID) +} + +// EventsFromIDs implements query.RoomserverQueryAPIEventDB +func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.EventNIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + var nids []types.EventNID + for _, nid := range nidMap { + nids = append(nids, nid) + } + + return d.Events(ctx, nids) +} + +type transaction struct { + ctx context.Context + txn *sql.Tx +} + +// Commit implements types.Transaction +func (t *transaction) Commit() error { + return t.txn.Commit() +} + +// Rollback implements types.Transaction +func (t *transaction) Rollback() error { + return t.txn.Rollback() +} diff --git a/roomserver/storage/transactions_table.go b/roomserver/storage/transactions_table.go new file mode 100644 index 00000000..e9c904cc --- /dev/null +++ b/roomserver/storage/transactions_table.go @@ -0,0 +1,86 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" +) + +const transactionsSchema = ` +-- The transactions table holds transaction IDs with sender's info and event ID it belongs to. +-- This table is used by roomserver to prevent reprocessing of events. +CREATE TABLE IF NOT EXISTS roomserver_transactions ( + -- The transaction ID of the event. + transaction_id TEXT NOT NULL, + -- The device ID of the originating transaction. + device_id TEXT NOT NULL, + -- User ID of the sender who authored the event + user_id TEXT NOT NULL, + -- Event ID corresponding to the transaction + -- Required to return event ID to client on a duplicate request. + event_id TEXT NOT NULL, + -- A transaction ID is unique for a user and device + -- This automatically creates an index. + PRIMARY KEY (transaction_id, device_id, user_id) +); +` +const insertTransactionSQL = "" + + "INSERT INTO roomserver_transactions (transaction_id, device_id, user_id, event_id)" + + " VALUES ($1, $2, $3, $4)" + +const selectTransactionEventIDSQL = "" + + "SELECT event_id FROM roomserver_transactions" + + " WHERE transaction_id = $1 AND device_id = $2 AND user_id = $3" + +type transactionStatements struct { + insertTransactionStmt *sql.Stmt + selectTransactionEventIDStmt *sql.Stmt +} + +func (s *transactionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(transactionsSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertTransactionStmt, insertTransactionSQL}, + {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, + }.prepare(db) +} + +func (s *transactionStatements) insertTransaction( + ctx context.Context, + transactionID string, + deviceID string, + userID string, + eventID string, +) (err error) { + _, err = s.insertTransactionStmt.ExecContext( + ctx, transactionID, deviceID, userID, eventID, + ) + return +} + +func (s *transactionStatements) selectTransactionEventID( + ctx context.Context, + transactionID string, + deviceID string, + userID string, +) (eventID string, err error) { + err = s.selectTransactionEventIDStmt.QueryRowContext( + ctx, transactionID, deviceID, userID, + ).Scan(&eventID) + return +} diff --git a/roomserver/types/types.go b/roomserver/types/types.go new file mode 100644 index 00000000..d5fe3276 --- /dev/null +++ b/roomserver/types/types.go @@ -0,0 +1,203 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package types provides the types that are used internally within the roomserver. +package types + +import ( + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/gomatrixserverlib" +) + +// EventTypeNID is a numeric ID for an event type. +type EventTypeNID int64 + +// EventStateKeyNID is a numeric ID for an event state_key. +type EventStateKeyNID int64 + +// EventNID is a numeric ID for an event. +type EventNID int64 + +// RoomNID is a numeric ID for a room. +type RoomNID int64 + +// StateSnapshotNID is a numeric ID for the state at an event. +type StateSnapshotNID int64 + +// StateBlockNID is a numeric ID for a block of state data. +// These blocks of state data are combined to form the actual state. +type StateBlockNID int64 + +// A StateKeyTuple is a pair of a numeric event type and a numeric state key. +// It is used to lookup state entries. +type StateKeyTuple struct { + // The numeric ID for the event type. + EventTypeNID EventTypeNID + // The numeric ID for the state key. + EventStateKeyNID EventStateKeyNID +} + +// LessThan returns true if this state key is less than the other state key. +// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. +func (a StateKeyTuple) LessThan(b StateKeyTuple) bool { + if a.EventTypeNID != b.EventTypeNID { + return a.EventTypeNID < b.EventTypeNID + } + return a.EventStateKeyNID < b.EventStateKeyNID +} + +// A StateEntry is an entry in the room state of a matrix room. +type StateEntry struct { + StateKeyTuple + // The numeric ID for the event. + EventNID EventNID +} + +// LessThan returns true if this state entry is less than the other state entry. +// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. +func (a StateEntry) LessThan(b StateEntry) bool { + if a.StateKeyTuple != b.StateKeyTuple { + return a.StateKeyTuple.LessThan(b.StateKeyTuple) + } + return a.EventNID < b.EventNID +} + +// StateAtEvent is the state before and after a matrix event. +type StateAtEvent struct { + // The state before the event. + BeforeStateSnapshotNID StateSnapshotNID + // The state entry for the event itself, allows us to calculate the state after the event. + StateEntry +} + +// IsStateEvent returns whether the event the state is at is a state event. +func (s StateAtEvent) IsStateEvent() bool { + return s.EventStateKeyNID != 0 +} + +// StateAtEventAndReference is StateAtEvent and gomatrixserverlib.EventReference glued together. +// It is used when looking up the latest events in a room in the database. +// The gomatrixserverlib.EventReference is used to check whether a new event references the event. +// The StateAtEvent is used to construct the current state of the room from the latest events. +type StateAtEventAndReference struct { + StateAtEvent + gomatrixserverlib.EventReference +} + +// An Event is a gomatrixserverlib.Event with the numeric event ID attached. +// It is when performing bulk event lookup in the database. +type Event struct { + EventNID EventNID + gomatrixserverlib.Event +} + +const ( + // MRoomCreateNID is the numeric ID for the "m.room.create" event type. + MRoomCreateNID = 1 + // MRoomPowerLevelsNID is the numeric ID for the "m.room.power_levels" event type. + MRoomPowerLevelsNID = 2 + // MRoomJoinRulesNID is the numeric ID for the "m.room.join_rules" event type. + MRoomJoinRulesNID = 3 + // MRoomThirdPartyInviteNID is the numeric ID for the "m.room.third_party_invite" event type. + MRoomThirdPartyInviteNID = 4 + // MRoomMemberNID is the numeric ID for the "m.room.member" event type. + MRoomMemberNID = 5 + // MRoomRedactionNID is the numeric ID for the "m.room.redaction" event type. + MRoomRedactionNID = 6 + // MRoomHistoryVisibilityNID is the numeric ID for the "m.room.history_visibility" event type. + MRoomHistoryVisibilityNID = 7 +) + +const ( + // EmptyStateKeyNID is the numeric ID for the empty state key. + EmptyStateKeyNID = 1 +) + +// StateBlockNIDList is used to return the result of bulk StateBlockNID lookups from the database. +type StateBlockNIDList struct { + StateSnapshotNID StateSnapshotNID + StateBlockNIDs []StateBlockNID +} + +// StateEntryList is used to return the result of bulk state entry lookups from the database. +type StateEntryList struct { + StateBlockNID StateBlockNID + StateEntries []StateEntry +} + +// A RoomRecentEventsUpdater is used to update the recent events in a room. +// (On postgresql this wraps a database transaction that holds a "FOR UPDATE" +// lock on the row in the rooms table holding the latest events for the room.) +type RoomRecentEventsUpdater interface { + // The latest event IDs and state in the room. + LatestEvents() []StateAtEventAndReference + // The event ID of the latest event written to the output log in the room. + LastEventIDSent() string + // The current state of the room. + CurrentStateSnapshotNID() StateSnapshotNID + // Store the previous events referenced by an event. + // This adds the event NID to an entry in the database for each of the previous events. + // If there isn't an entry for one of previous events then an entry is created. + // If the entry already lists the event NID as a referrer then the entry unmodified. + // (i.e. the operation is idempotent) + StorePreviousEvents(eventNID EventNID, previousEventReferences []gomatrixserverlib.EventReference) error + // Check whether the eventReference is already referenced by another matrix event. + IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) + // Set the list of latest events for the room. + // This replaces the current list stored in the database with the given list + SetLatestEvents( + roomNID RoomNID, latest []StateAtEventAndReference, lastEventNIDSent EventNID, + currentStateSnapshotNID StateSnapshotNID, + ) error + // Check if the event has already be written to the output logs. + HasEventBeenSent(eventNID EventNID) (bool, error) + // Mark the event as having been sent to the output logs. + MarkEventAsSent(eventNID EventNID) error + // Build a membership updater for the target user in this room. + // It will share the same transaction as this updater. + MembershipUpdater(targetUserNID EventStateKeyNID) (MembershipUpdater, error) + // Implements Transaction so it can be committed or rolledback + common.Transaction +} + +// A MembershipUpdater is used to update the membership of a user in a room. +// (On postgresql this wraps a database transaction that holds a "FOR UPDATE" +// lock on the row in the membership table for this user in the room) +// The caller should call one of SetToInvite, SetToJoin or SetToLeave once to +// make the update, or none of them if no update is required. +type MembershipUpdater interface { + // True if the target user is invited to the room before updating. + IsInvite() bool + // True if the target user is joined to the room before updating. + IsJoin() bool + // True if the target user is not invited or joined to the room before updating. + IsLeave() bool + // Set the state to invite. + // Returns whether this invite needs to be sent + SetToInvite(event gomatrixserverlib.Event) (needsSending bool, err error) + // Set the state to join or updates the event ID in the database. + // Returns a list of invite event IDs that this state change retired. + SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) + // Set the state to leave. + // Returns a list of invite event IDs that this state change retired. + SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) + // Implements Transaction so it can be committed or rolledback. + common.Transaction +} + +// A MissingEventError is an error that happened because the roomserver was +// missing requested events from its database. +type MissingEventError string + +func (e MissingEventError) Error() string { return string(e) } |