diff options
author | Till <2353100+S7evinK@users.noreply.github.com> | 2022-04-11 09:05:23 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-11 09:05:23 +0200 |
commit | 69f2ff7c82abe0731a05febde88098f4cd34ab8d (patch) | |
tree | c7e0e2e65550f8b2d3b50e385acc6ac4cdbc90d2 /syncapi | |
parent | b4b2fbc36b1eb1b46640feadbe7e1729c864a898 (diff) |
Correctly use provided filters (#2339)
* Apply filters correctly
* Fix issues; Use prepareWithFilters
* Update gmsl & tests
* go.mod..
* PR comments
Diffstat (limited to 'syncapi')
-rw-r--r-- | syncapi/routing/context.go | 4 | ||||
-rw-r--r-- | syncapi/storage/postgres/current_room_state_table.go | 5 | ||||
-rw-r--r-- | syncapi/storage/postgres/filtering.go | 32 | ||||
-rw-r--r-- | syncapi/storage/postgres/output_room_events_table.go | 26 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/account_data_table.go | 41 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/filtering.go | 50 | ||||
-rw-r--r-- | syncapi/streams/stream_pdu.go | 6 |
7 files changed, 101 insertions, 63 deletions
diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 2412bc2a..aaa0c61b 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -60,7 +60,9 @@ func Context( Headers: nil, } } - filter.Rooms = append(filter.Rooms, roomID) + if filter.Rooms != nil { + *filter.Rooms = append(*filter.Rooms, roomID) + } ctx := req.Context() membershipRes := roomserver.QueryMembershipForUserResponse{} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 69e6e30e..fe68788d 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState( excludeEventIDs []string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) + senders, notSenders := getSendersStateFilterFilter(stateFilter) rows, err := stmt.QueryContext(ctx, roomID, - pq.StringArray(stateFilter.Senders), - pq.StringArray(stateFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), stateFilter.ContainsURL, diff --git a/syncapi/storage/postgres/filtering.go b/syncapi/storage/postgres/filtering.go index dcc42136..a2ca4215 100644 --- a/syncapi/storage/postgres/filtering.go +++ b/syncapi/storage/postgres/filtering.go @@ -16,21 +16,45 @@ package postgres import ( "strings" + + "github.com/matrix-org/gomatrixserverlib" ) // filterConvertWildcardToSQL converts wildcards as defined in // https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter // to SQL wildcards that can be used with LIKE() -func filterConvertTypeWildcardToSQL(values []string) []string { +func filterConvertTypeWildcardToSQL(values *[]string) []string { if values == nil { // Return nil instead of []string{} so IS NULL can work correctly when // the return value is passed into SQL queries return nil } - ret := make([]string, len(values)) - for i := range values { - ret[i] = strings.Replace(values[i], "*", "%", -1) + v := *values + ret := make([]string, len(v)) + for i := range v { + ret[i] = strings.Replace(v[i], "*", "%", -1) } return ret } + +// TODO: Replace when Dendrite uses Go 1.18 +func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) { + if filter.Senders != nil { + senders = *filter.Senders + } + if filter.NotSenders != nil { + notSenders = *filter.NotSenders + } + return senders, notSenders +} + +func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) { + if filter.Senders != nil { + senders = *filter.Senders + } + if filter.NotSenders != nil { + notSenders = *filter.NotSenders + } + return senders, notSenders +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index a30e220b..269cd449 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -204,11 +204,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange( stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) - + senders, notSenders := getSendersStateFilterFilter(stateFilter) rows, err := stmt.QueryContext( ctx, r.Low(), r.High(), pq.StringArray(roomIDs), - pq.StringArray(stateFilter.Senders), - pq.StringArray(stateFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), stateFilter.ContainsURL, @@ -353,10 +353,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( } else { stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } + senders, notSenders := getSendersRoomEventFilter(eventFilter) rows, err := stmt.QueryContext( ctx, roomID, r.Low(), r.High(), - pq.StringArray(eventFilter.Senders), - pq.StringArray(eventFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), eventFilter.Limit+1, @@ -398,11 +399,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, ) ([]types.StreamEvent, error) { + senders, notSenders := getSendersRoomEventFilter(eventFilter) stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt) rows, err := stmt.QueryContext( ctx, roomID, r.Low(), r.High(), - pq.StringArray(eventFilter.Senders), - pq.StringArray(eventFilter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)), eventFilter.Limit, @@ -480,10 +482,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn func (s *outputRoomEventsStatements) SelectContextBeforeEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ) (evts []*gomatrixserverlib.HeaderedEvent, err error) { + senders, notSenders := getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext( ctx, roomID, id, filter.Limit, - pq.StringArray(filter.Senders), - pq.StringArray(filter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), ) @@ -512,10 +515,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent( func (s *outputRoomEventsStatements) SelectContextAfterEvent( ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter, ) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) { + senders, notSenders := getSendersRoomEventFilter(filter) rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext( ctx, roomID, id, filter.Limit, - pq.StringArray(filter.Senders), - pq.StringArray(filter.NotSenders), + pq.StringArray(senders), + pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)), ) diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 5b2287e6..b0aeb70f 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -41,10 +41,10 @@ const insertAccountDataSQL = "" + " ON CONFLICT (user_id, room_id, type) DO UPDATE" + " SET id = $5" +// further parameters are added by prepareWithFilters const selectAccountDataInRangeSQL = "" + "SELECT room_id, type FROM syncapi_account_data_type" + - " WHERE user_id = $1 AND id > $2 AND id <= $3" + - " ORDER BY id ASC" + " WHERE user_id = $1 AND id > $2 AND id <= $3" const selectMaxAccountDataIDSQL = "" + "SELECT MAX(id) FROM syncapi_account_data_type" @@ -94,18 +94,25 @@ func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, userID string, r types.Range, - accountDataFilterPart *gomatrixserverlib.EventFilter, + filter *gomatrixserverlib.EventFilter, ) (data map[string][]string, err error) { data = make(map[string][]string) - - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High()) + stmt, params, err := prepareWithFilters( + s.db, nil, selectAccountDataInRangeSQL, + []interface{}{ + userID, r.Low(), r.High(), + }, + filter.Senders, filter.NotSenders, + filter.Types, filter.NotTypes, + []string{}, filter.Limit, FilterOrderAsc, + ) + + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") - var entries int - for rows.Next() { var dataType string var roomID string @@ -114,31 +121,11 @@ func (s *accountDataStatements) SelectAccountDataInRange( return } - // check if we should add this by looking at the filter. - // It would be nice if we could do this in SQL-land, but the mix of variadic - // and positional parameters makes the query annoyingly hard to do, it's easier - // and clearer to do it in Go-land. If there are no filters for [not]types then - // this gets skipped. - for _, includeType := range accountDataFilterPart.Types { - if includeType != dataType { // TODO: wildcard support - continue - } - } - for _, excludeType := range accountDataFilterPart.NotTypes { - if excludeType == dataType { // TODO: wildcard support - continue - } - } - if len(data[roomID]) > 0 { data[roomID] = append(data[roomID], dataType) } else { data[roomID] = []string{dataType} } - entries++ - if entries >= accountDataFilterPart.Limit { - break - } } return data, nil diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go index 11f3e647..54b12ddf 100644 --- a/syncapi/storage/sqlite3/filtering.go +++ b/syncapi/storage/sqlite3/filtering.go @@ -25,32 +25,48 @@ const ( // parts. func prepareWithFilters( db *sql.DB, txn *sql.Tx, query string, params []interface{}, - senders, notsenders, types, nottypes []string, excludeEventIDs []string, + senders, notsenders, types, nottypes *[]string, excludeEventIDs []string, limit int, order FilterOrder, ) (*sql.Stmt, []interface{}, error) { offset := len(params) - if count := len(senders); count > 0 { - query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range senders { - params, offset = append(params, v), offset+1 + if senders != nil { + if count := len(*senders); count > 0 { + query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *senders { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND sender = ""` } } - if count := len(notsenders); count > 0 { - query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range notsenders { - params, offset = append(params, v), offset+1 + if notsenders != nil { + if count := len(*notsenders); count > 0 { + query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *notsenders { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND sender NOT = ""` } } - if count := len(types); count > 0 { - query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range types { - params, offset = append(params, v), offset+1 + if types != nil { + if count := len(*types); count > 0 { + query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *types { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND type = ""` } } - if count := len(nottypes); count > 0 { - query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range nottypes { - params, offset = append(params, v), offset+1 + if nottypes != nil { + if count := len(*nottypes); count > 0 { + query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) + for _, v := range *nottypes { + params, offset = append(params, v), offset+1 + } + } else { + query += ` AND type NOT = ""` } } if count := len(excludeEventIDs); count > 0 { diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index ab200e00..bcaf6ca3 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -423,8 +423,12 @@ func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *ty return err } req.IgnoredUsers = *ignores + userList := make([]string, 0, len(ignores.List)) for userID := range ignores.List { - eventFilter.NotSenders = append(eventFilter.NotSenders, userID) + userList = append(userList, userID) + } + if len(userList) > 0 { + eventFilter.NotSenders = &userList } return nil } |