aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-07-21 08:34:01 +0200
committerGitHub <noreply@github.com>2023-07-21 08:34:01 +0200
commite216c2fbf0fd117ddb8b96b05d514b9987cbb0d2 (patch)
treedae1a7ff940d1fc0bbb88e20656b53e1a8b0d0a6 /internal
parent958282749391a13dc6f03c1dd13a9554fb5db3ae (diff)
Update ConnectionManager to still allow component defined connections (#3154)
Diffstat (limited to 'internal')
-rw-r--r--internal/sqlutil/connection_manager.go67
-rw-r--r--internal/sqlutil/connection_manager_test.go22
2 files changed, 60 insertions, 29 deletions
diff --git a/internal/sqlutil/connection_manager.go b/internal/sqlutil/connection_manager.go
index 4933cfaf..437da6c8 100644
--- a/internal/sqlutil/connection_manager.go
+++ b/internal/sqlutil/connection_manager.go
@@ -17,16 +17,21 @@ package sqlutil
import (
"database/sql"
"fmt"
+ "sync"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
)
type Connections struct {
- db *sql.DB
- writer Writer
- globalConfig config.DatabaseOptions
- processContext *process.ProcessContext
+ globalConfig config.DatabaseOptions
+ processContext *process.ProcessContext
+ existingConnections sync.Map
+}
+
+type con struct {
+ db *sql.DB
+ writer Writer
}
func NewConnectionManager(processCtx *process.ProcessContext, globalConfig config.DatabaseOptions) *Connections {
@@ -38,9 +43,13 @@ func NewConnectionManager(processCtx *process.ProcessContext, globalConfig confi
func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB, Writer, error) {
var err error
+ // If no connectionString was provided, try the global one
if dbProperties.ConnectionString == "" {
- // if no connectionString was provided, try the global one
dbProperties = &c.globalConfig
+ // If we still don't have a connection string, that's a problem
+ if dbProperties.ConnectionString == "" {
+ return nil, nil, fmt.Errorf("no database connections configured")
+ }
}
writer := NewDummyWriter()
@@ -48,30 +57,30 @@ func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB,
writer = NewExclusiveWriter()
}
- if dbProperties.ConnectionString != "" && c.db == nil {
- // Open a new database connection using the supplied config.
- c.db, err = Open(dbProperties, writer)
- if err != nil {
- return nil, nil, err
- }
- c.writer = writer
- go func() {
- if c.processContext == nil {
- return
- }
- // If we have a ProcessContext, start a component and wait for
- // Dendrite to shut down to cleanly close the database connection.
- c.processContext.ComponentStarted()
- <-c.processContext.WaitForShutdown()
- _ = c.db.Close()
- c.processContext.ComponentFinished()
- }()
- return c.db, c.writer, nil
+ existing, loaded := c.existingConnections.LoadOrStore(dbProperties.ConnectionString, &con{})
+ if loaded {
+ // We found an existing connection
+ ex := existing.(*con)
+ return ex.db, ex.writer, nil
}
- if c.db != nil && c.writer != nil {
- // Ignore the supplied config and return the global pool and
- // writer.
- return c.db, c.writer, nil
+
+ // Open a new database connection using the supplied config.
+ db, err := Open(dbProperties, writer)
+ if err != nil {
+ return nil, nil, err
}
- return nil, nil, fmt.Errorf("no database connections configured")
+ c.existingConnections.Store(dbProperties.ConnectionString, &con{db: db, writer: writer})
+ go func() {
+ if c.processContext == nil {
+ return
+ }
+ // If we have a ProcessContext, start a component and wait for
+ // Dendrite to shut down to cleanly close the database connection.
+ c.processContext.ComponentStarted()
+ <-c.processContext.WaitForShutdown()
+ _ = db.Close()
+ c.processContext.ComponentFinished()
+ }()
+ return db, writer, nil
+
}
diff --git a/internal/sqlutil/connection_manager_test.go b/internal/sqlutil/connection_manager_test.go
index 965d3b9b..5086684b 100644
--- a/internal/sqlutil/connection_manager_test.go
+++ b/internal/sqlutil/connection_manager_test.go
@@ -48,6 +48,22 @@ func TestConnectionManager(t *testing.T) {
if !reflect.DeepEqual(writer, writer2) {
t.Fatalf("expected database writer to be reused")
}
+
+ // This test does not work with Postgres, because we can't just simply append
+ // "x" or replace the database to use.
+ if dbType == test.DBTypePostgres {
+ return
+ }
+
+ // Test different connection string
+ dbProps = &config.DatabaseOptions{ConnectionString: config.DataSource(conStr + "x")}
+ db3, _, err := cm.Connection(dbProps)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if reflect.DeepEqual(db, db3) {
+ t.Fatalf("expected different database connection")
+ }
})
})
@@ -115,4 +131,10 @@ func TestConnectionManager(t *testing.T) {
if err == nil {
t.Fatal("expected an error but got none")
}
+
+ // empty connection string is not allowed
+ _, _, err = cm2.Connection(&config.DatabaseOptions{})
+ if err == nil {
+ t.Fatal("expected an error but got none")
+ }
}