diff options
author | Devon Hudson <devonhudson@librem.one> | 2023-05-18 13:41:47 -0600 |
---|---|---|
committer | Devon Hudson <devonhudson@librem.one> | 2023-05-18 13:41:47 -0600 |
commit | 027a9b8ce0a7e2d577e2c41f9de7a6fe42ace655 (patch) | |
tree | 3bf20546b3db3052d3df0cdfbf0ac1cfb493c575 | |
parent | 345f025ee3654d120b9e668e943a4f2d428c12c7 (diff) |
Fix bug with nil interface return & add test
-rw-r--r-- | roomserver/internal/query/query.go | 6 | ||||
-rw-r--r-- | roomserver/internal/query/query_test.go | 33 |
2 files changed, 38 insertions, 1 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index e4dac45e..35cafd0e 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -868,7 +868,11 @@ func (r *Queryer) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types } func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) { - return r.DB.GetStateEvent(ctx, roomID.String(), string(eventType), "") + res, err := r.DB.GetStateEvent(ctx, roomID.String(), string(eventType), "") + if res == nil { + return nil, err + } + return res, err } func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) { diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 2ebf7f33..b6715cb0 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -18,10 +18,16 @@ import ( "context" "encoding/json" "testing" + "time" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // used to implement RoomserverInternalAPIEventDB to test getAuthChain @@ -155,3 +161,30 @@ func TestGetAuthChainMultiple(t *testing.T) { t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs) } } + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) + db, err := storage.Open(context.Background(), cm, &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}, caches) + if err != nil { + t.Fatalf("failed to create Database: %v", err) + } + return db, close +} + +func TestCurrentEventIsNil(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + querier := Queryer{ + DB: db, + } + + roomID, _ := spec.NewRoomID("!room:server") + event, _ := querier.CurrentStateEvent(context.Background(), *roomID, spec.MRoomMember, "@user:server") + if event != nil { + t.Fatal("Event should equal nil, most likely this is failing because the interface type is not nil, but the value is.") + } + }) +} |