diff options
Diffstat (limited to 'syncapi/routing')
-rw-r--r-- | syncapi/routing/context.go | 7 | ||||
-rw-r--r-- | syncapi/routing/routing.go | 28 | ||||
-rw-r--r-- | syncapi/routing/search.go | 344 |
3 files changed, 374 insertions, 5 deletions
diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 13c4e9d8..1ebdfe60 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -37,11 +37,11 @@ import ( type ContextRespsonse struct { End string `json:"end"` - Event gomatrixserverlib.ClientEvent `json:"event"` + Event *gomatrixserverlib.ClientEvent `json:"event,omitempty"` EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after,omitempty"` EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before,omitempty"` Start string `json:"start"` - State []gomatrixserverlib.ClientEvent `json:"state"` + State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` } func Context( @@ -162,8 +162,9 @@ func Context( eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) + ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll) response := ContextRespsonse{ - Event: gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll), + Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, State: gomatrixserverlib.HeaderedToClientEvents(newState, gomatrixserverlib.FormatAll), diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 6bc495d8..8f84a134 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -18,15 +18,18 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) // Setup configures the given mux with sync-server listeners @@ -40,6 +43,7 @@ func Setup( rsAPI api.SyncRoomserverAPI, cfg *config.SyncAPI, lazyLoadCache caching.LazyLoadCache, + fts *fulltext.Search, ) { v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() @@ -95,4 +99,24 @@ func Setup( ) }), ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/search", + httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if !cfg.Fulltext.Enabled { + return util.JSONResponse{ + Code: http.StatusNotImplemented, + JSON: jsonerror.Unknown("Search has been disabled by the server administrator."), + } + } + var nextBatch *string + if err := req.ParseForm(); err != nil { + return jsonerror.InternalServerError() + } + if req.Form.Has("next_batch") { + nb := req.FormValue("next_batch") + nextBatch = &nb + } + return Search(req, device, syncDB, fts, nextBatch) + }), + ).Methods(http.MethodPost, http.MethodOptions) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go new file mode 100644 index 00000000..341efeb1 --- /dev/null +++ b/syncapi/routing/search.go @@ -0,0 +1,344 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "github.com/blevesearch/bleve/v2/search" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/fulltext" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/userapi/api" +) + +// nolint:gocyclo +func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts *fulltext.Search, from *string) util.JSONResponse { + start := time.Now() + var ( + searchReq SearchRequest + err error + ctx = req.Context() + ) + resErr := httputil.UnmarshalJSONRequest(req, &searchReq) + if resErr != nil { + logrus.Error("failed to unmarshal search request") + return *resErr + } + + nextBatch := 0 + if from != nil && *from != "" { + nextBatch, err = strconv.Atoi(*from) + if err != nil { + return jsonerror.InternalServerError() + } + } + + if searchReq.SearchCategories.RoomEvents.Filter.Limit == 0 { + searchReq.SearchCategories.RoomEvents.Filter.Limit = 5 + } + + // only search rooms the user is actually joined to + joinedRooms, err := syncDB.RoomIDsWithMembership(ctx, device.UserID, "join") + if err != nil { + return jsonerror.InternalServerError() + } + if len(joinedRooms) == 0 { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("User not joined to any rooms."), + } + } + joinedRoomsMap := make(map[string]struct{}, len(joinedRooms)) + for _, roomID := range joinedRooms { + joinedRoomsMap[roomID] = struct{}{} + } + rooms := []string{} + if searchReq.SearchCategories.RoomEvents.Filter.Rooms != nil { + for _, roomID := range *searchReq.SearchCategories.RoomEvents.Filter.Rooms { + if _, ok := joinedRoomsMap[roomID]; ok { + rooms = append(rooms, roomID) + } + } + } else { + rooms = joinedRooms + } + + if len(rooms) == 0 { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Unknown("User not allowed to search in this room(s)."), + } + } + + orderByTime := searchReq.SearchCategories.RoomEvents.OrderBy == "recent" + + result, err := fts.Search( + searchReq.SearchCategories.RoomEvents.SearchTerm, + rooms, + searchReq.SearchCategories.RoomEvents.Keys, + searchReq.SearchCategories.RoomEvents.Filter.Limit, + nextBatch, + orderByTime, + ) + if err != nil { + logrus.WithError(err).Error("failed to search fulltext") + return jsonerror.InternalServerError() + } + logrus.Debugf("Search took %s", result.Took) + + // From was specified but empty, return no results, only the count + if from != nil && *from == "" { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: SearchResponse{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + Count: int(result.Total), + NextBatch: nil, + }, + }, + }, + } + } + + results := []Result{} + + wantEvents := make([]string, 0, len(result.Hits)) + eventScore := make(map[string]*search.DocumentMatch) + + for _, hit := range result.Hits { + wantEvents = append(wantEvents, hit.ID) + eventScore[hit.ID] = hit + } + + // Filter on m.room.message, as otherwise we also get events like m.reaction + // which "breaks" displaying results in Element Web. + types := []string{"m.room.message"} + roomFilter := &gomatrixserverlib.RoomEventFilter{ + Rooms: &rooms, + Types: &types, + } + + evs, err := syncDB.Events(ctx, wantEvents) + if err != nil { + logrus.WithError(err).Error("failed to get events from database") + return jsonerror.InternalServerError() + } + + groups := make(map[string]RoomResult) + knownUsersProfiles := make(map[string]ProfileInfo) + + // Sort the events by depth, as the returned values aren't ordered + if orderByTime { + sort.Slice(evs, func(i, j int) bool { + return evs[i].Depth() > evs[j].Depth() + }) + } + + stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent) + for _, event := range evs { + eventsBefore, eventsAfter, err := contextEvents(ctx, syncDB, event, roomFilter, searchReq) + if err != nil { + logrus.WithError(err).Error("failed to get context events") + return jsonerror.InternalServerError() + } + startToken, endToken, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter) + if err != nil { + logrus.WithError(err).Error("failed to get start/end") + return jsonerror.InternalServerError() + } + + profileInfos := make(map[string]ProfileInfo) + for _, ev := range append(eventsBefore, eventsAfter...) { + profile, ok := knownUsersProfiles[event.Sender()] + if !ok { + stateEvent, err := syncDB.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender()) + if err != nil { + logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") + continue + } + if stateEvent == nil { + continue + } + profile = ProfileInfo{ + AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, + DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, + } + knownUsersProfiles[event.Sender()] = profile + } + profileInfos[ev.Sender()] = profile + } + + results = append(results, Result{ + Context: SearchContextResponse{ + Start: startToken.String(), + End: endToken.String(), + EventsAfter: gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatSync), + EventsBefore: gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatSync), + ProfileInfo: profileInfos, + }, + Rank: eventScore[event.EventID()].Score, + Result: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll), + }) + roomGroup := groups[event.RoomID()] + roomGroup.Results = append(roomGroup.Results, event.EventID()) + groups[event.RoomID()] = roomGroup + if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok { + stateFilter := gomatrixserverlib.DefaultStateFilter() + state, err := syncDB.CurrentState(ctx, event.RoomID(), &stateFilter, nil) + if err != nil { + logrus.WithError(err).Error("unable to get current state") + return jsonerror.InternalServerError() + } + stateForRooms[event.RoomID()] = gomatrixserverlib.HeaderedToClientEvents(state, gomatrixserverlib.FormatSync) + } + } + + var nextBatchResult *string = nil + if int(result.Total) > nextBatch+len(results) { + nb := strconv.Itoa(len(results) + nextBatch) + nextBatchResult = &nb + } else if int(result.Total) == nextBatch+len(results) { + // Sytest expects a next_batch even if we don't actually have any more results + nb := "" + nextBatchResult = &nb + } + + res := SearchResponse{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + Count: int(result.Total), + Groups: Groups{RoomID: groups}, + Results: results, + NextBatch: nextBatchResult, + Highlights: strings.Split(searchReq.SearchCategories.RoomEvents.SearchTerm, " "), + State: stateForRooms, + }, + }, + } + + logrus.Debugf("Full search request took %v", time.Since(start)) + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: res, + } +} + +// contextEvents returns the events around a given eventID +func contextEvents( + ctx context.Context, + syncDB storage.Database, + event *gomatrixserverlib.HeaderedEvent, + roomFilter *gomatrixserverlib.RoomEventFilter, + searchReq SearchRequest, +) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) { + id, _, err := syncDB.SelectContextEvent(ctx, event.RoomID(), event.EventID()) + if err != nil { + logrus.WithError(err).Error("failed to query context event") + return nil, nil, err + } + roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit + eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter) + if err != nil { + logrus.WithError(err).Error("failed to query before context event") + return nil, nil, err + } + roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit + _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter) + if err != nil { + logrus.WithError(err).Error("failed to query after context event") + return nil, nil, err + } + return eventsBefore, eventsAfter, err +} + +type SearchRequest struct { + SearchCategories struct { + RoomEvents struct { + EventContext struct { + AfterLimit int `json:"after_limit,omitempty"` + BeforeLimit int `json:"before_limit,omitempty"` + IncludeProfile bool `json:"include_profile,omitempty"` + } `json:"event_context"` + Filter gomatrixserverlib.StateFilter `json:"filter"` + Groupings struct { + GroupBy []struct { + Key string `json:"key"` + } `json:"group_by"` + } `json:"groupings"` + IncludeState bool `json:"include_state"` + Keys []string `json:"keys"` + OrderBy string `json:"order_by"` + SearchTerm string `json:"search_term"` + } `json:"room_events"` + } `json:"search_categories"` +} + +type SearchResponse struct { + SearchCategories SearchCategories `json:"search_categories"` +} +type RoomResult struct { + NextBatch *string `json:"next_batch,omitempty"` + Order int `json:"order"` + Results []string `json:"results"` +} + +type Groups struct { + RoomID map[string]RoomResult `json:"room_id"` +} + +type Result struct { + Context SearchContextResponse `json:"context"` + Rank float64 `json:"rank"` + Result gomatrixserverlib.ClientEvent `json:"result"` +} + +type SearchContextResponse struct { + End string `json:"end"` + EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after"` + EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before"` + Start string `json:"start"` + ProfileInfo map[string]ProfileInfo `json:"profile_info"` +} + +type ProfileInfo struct { + AvatarURL string `json:"avatar_url"` + DisplayName string `json:"display_name"` +} + +type RoomEvents struct { + Count int `json:"count"` + Groups Groups `json:"groups"` + Highlights []string `json:"highlights"` + NextBatch *string `json:"next_batch,omitempty"` + Results []Result `json:"results"` + State map[string][]gomatrixserverlib.ClientEvent `json:"state,omitempty"` +} +type SearchCategories struct { + RoomEvents RoomEvents `json:"room_events"` +} |