diff options
Diffstat (limited to 'internal/sqlutil/trace.go')
-rw-r--r-- | internal/sqlutil/trace.go | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go index f6644d59..fbd983be 100644 --- a/internal/sqlutil/trace.go +++ b/internal/sqlutil/trace.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/matrix-org/dendrite/internal/config" "github.com/ngrok/sqlmw" "github.com/sirupsen/logrus" ) @@ -77,7 +78,22 @@ func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest [ // Open opens a database specified by its database driver name and a driver-specific data source name, // usually consisting of at least a database name and connection information. Includes tracing driver // if DENDRITE_TRACE_SQL=1 -func Open(driverName, dsn string, dbProperties DbProperties) (*sql.DB, error) { +func Open(dbProperties *config.DatabaseOptions) (*sql.DB, error) { + var err error + var driverName, dsn string + switch { + case dbProperties.ConnectionString.IsSQLite(): + driverName = SQLiteDriverName() + dsn, err = ParseFileURI(dbProperties.ConnectionString) + if err != nil { + return nil, fmt.Errorf("ParseFileURI: %w", err) + } + case dbProperties.ConnectionString.IsPostgres(): + driverName = "postgres" + dsn = string(dbProperties.ConnectionString) + default: + return nil, fmt.Errorf("invalid database connection string %q", dbProperties.ConnectionString) + } if tracingEnabled { // install the wrapped driver driverName += "-trace" @@ -86,11 +102,11 @@ func Open(driverName, dsn string, dbProperties DbProperties) (*sql.DB, error) { if err != nil { return nil, err } - if driverName != SQLiteDriverName() && dbProperties != nil { + if driverName != SQLiteDriverName() { logrus.WithFields(logrus.Fields{ - "MaxOpenConns": dbProperties.MaxOpenConns(), - "MaxIdleConns": dbProperties.MaxIdleConns(), - "ConnMaxLifetime": dbProperties.ConnMaxLifetime(), + "MaxOpenConns": dbProperties.MaxOpenConns, + "MaxIdleConns": dbProperties.MaxIdleConns, + "ConnMaxLifetime": dbProperties.ConnMaxLifetime, "dataSourceName": regexp.MustCompile(`://[^@]*@`).ReplaceAllLiteralString(dsn, "://"), }).Debug("Setting DB connection limits") db.SetMaxOpenConns(dbProperties.MaxOpenConns()) |