aboutsummaryrefslogtreecommitdiff
path: root/internal/sqlutil/trace.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sqlutil/trace.go')
-rw-r--r--internal/sqlutil/trace.go26
1 files changed, 21 insertions, 5 deletions
diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go
index f6644d59..fbd983be 100644
--- a/internal/sqlutil/trace.go
+++ b/internal/sqlutil/trace.go
@@ -25,6 +25,7 @@ import (
"strings"
"time"
+ "github.com/matrix-org/dendrite/internal/config"
"github.com/ngrok/sqlmw"
"github.com/sirupsen/logrus"
)
@@ -77,7 +78,22 @@ func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest [
// Open opens a database specified by its database driver name and a driver-specific data source name,
// usually consisting of at least a database name and connection information. Includes tracing driver
// if DENDRITE_TRACE_SQL=1
-func Open(driverName, dsn string, dbProperties DbProperties) (*sql.DB, error) {
+func Open(dbProperties *config.DatabaseOptions) (*sql.DB, error) {
+ var err error
+ var driverName, dsn string
+ switch {
+ case dbProperties.ConnectionString.IsSQLite():
+ driverName = SQLiteDriverName()
+ dsn, err = ParseFileURI(dbProperties.ConnectionString)
+ if err != nil {
+ return nil, fmt.Errorf("ParseFileURI: %w", err)
+ }
+ case dbProperties.ConnectionString.IsPostgres():
+ driverName = "postgres"
+ dsn = string(dbProperties.ConnectionString)
+ default:
+ return nil, fmt.Errorf("invalid database connection string %q", dbProperties.ConnectionString)
+ }
if tracingEnabled {
// install the wrapped driver
driverName += "-trace"
@@ -86,11 +102,11 @@ func Open(driverName, dsn string, dbProperties DbProperties) (*sql.DB, error) {
if err != nil {
return nil, err
}
- if driverName != SQLiteDriverName() && dbProperties != nil {
+ if driverName != SQLiteDriverName() {
logrus.WithFields(logrus.Fields{
- "MaxOpenConns": dbProperties.MaxOpenConns(),
- "MaxIdleConns": dbProperties.MaxIdleConns(),
- "ConnMaxLifetime": dbProperties.ConnMaxLifetime(),
+ "MaxOpenConns": dbProperties.MaxOpenConns,
+ "MaxIdleConns": dbProperties.MaxIdleConns,
+ "ConnMaxLifetime": dbProperties.ConnMaxLifetime,
"dataSourceName": regexp.MustCompile(`://[^@]*@`).ReplaceAllLiteralString(dsn, "://"),
}).Debug("Setting DB connection limits")
db.SetMaxOpenConns(dbProperties.MaxOpenConns())