aboutsummaryrefslogtreecommitdiff
path: root/setup
diff options
context:
space:
mode:
Diffstat (limited to 'setup')
-rw-r--r--setup/base/base.go38
-rw-r--r--setup/base/base_test.go49
-rw-r--r--setup/config/config.go15
-rw-r--r--setup/config/config_address.go45
-rw-r--r--setup/config/config_address_test.go25
5 files changed, 146 insertions, 26 deletions
diff --git a/setup/base/base.go b/setup/base/base.go
index aabdd793..dfe48ff3 100644
--- a/setup/base/base.go
+++ b/setup/base/base.go
@@ -20,9 +20,11 @@ import (
"database/sql"
"embed"
"encoding/json"
+ "errors"
"fmt"
"html/template"
"io"
+ "io/fs"
"net"
"net/http"
_ "net/http/pprof"
@@ -85,8 +87,6 @@ type BaseDendrite struct {
startupLock sync.Mutex
}
-const NoListener = ""
-
const HTTPServerTimeout = time.Minute * 5
type BaseDendriteOptions int
@@ -345,18 +345,17 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() {
// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs
// and adds a prometheus handler under /_dendrite/metrics.
func (b *BaseDendrite) SetupAndServeHTTP(
- externalHTTPAddr config.HTTPAddress,
+ externalHTTPAddr config.ServerAddress,
certFile, keyFile *string,
) {
// Manually unlocked right before actually serving requests,
// as we don't return from this method (defer doesn't work).
b.startupLock.Lock()
- externalAddr, _ := externalHTTPAddr.Address()
externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
externalServ := &http.Server{
- Addr: string(externalAddr),
+ Addr: externalHTTPAddr.Address,
WriteTimeout: HTTPServerTimeout,
Handler: externalRouter,
BaseContext: func(_ net.Listener) context.Context {
@@ -419,7 +418,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
b.startupLock.Unlock()
- if externalAddr != NoListener {
+ if externalHTTPAddr.Enabled() {
go func() {
var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
logrus.Infof("Starting external listener on %s", externalServ.Addr)
@@ -437,9 +436,30 @@ func (b *BaseDendrite) SetupAndServeHTTP(
}
}
} else {
- if err := externalServ.ListenAndServe(); err != nil {
- if err != http.ErrServerClosed {
- logrus.WithError(err).Fatal("failed to serve HTTP")
+ if externalHTTPAddr.IsUnixSocket() {
+ err := os.Remove(externalHTTPAddr.Address)
+ if err != nil && !errors.Is(err, fs.ErrNotExist) {
+ logrus.WithError(err).Fatal("failed to remove existing unix socket")
+ }
+ listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address)
+ if err != nil {
+ logrus.WithError(err).Fatal("failed to serve unix socket")
+ }
+ err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission)
+ if err != nil {
+ logrus.WithError(err).Fatal("failed to set unix socket permissions")
+ }
+ if err := externalServ.Serve(listener); err != nil {
+ if err != http.ErrServerClosed {
+ logrus.WithError(err).Fatal("failed to serve unix socket")
+ }
+ }
+
+ } else {
+ if err := externalServ.ListenAndServe(); err != nil {
+ if err != http.ErrServerClosed {
+ logrus.WithError(err).Fatal("failed to serve HTTP")
+ }
}
}
}
diff --git a/setup/base/base_test.go b/setup/base/base_test.go
index d906294c..658dc5b0 100644
--- a/setup/base/base_test.go
+++ b/setup/base/base_test.go
@@ -2,10 +2,13 @@ package base_test
import (
"bytes"
+ "context"
"embed"
"html/template"
+ "net"
"net/http"
"net/http/httptest"
+ "path"
"testing"
"time"
@@ -18,7 +21,7 @@ import (
//go:embed static/*.gotmpl
var staticContent embed.FS
-func TestLandingPage(t *testing.T) {
+func TestLandingPage_Tcp(t *testing.T) {
// generate the expected result
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
expectedRes := &bytes.Buffer{}
@@ -35,7 +38,9 @@ func TestLandingPage(t *testing.T) {
s.Close()
// start base with the listener and wait for it to be started
- go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil)
+ address, err := config.HTTPAddress(s.URL)
+ assert.NoError(t, err)
+ go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 10)
// When hitting /, we should be redirected to /_matrix/static, which should contain the landing page
@@ -55,3 +60,43 @@ func TestLandingPage(t *testing.T) {
// Using .String() for user friendly output
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
}
+
+func TestLandingPage_UnixSocket(t *testing.T) {
+ // generate the expected result
+ tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
+ expectedRes := &bytes.Buffer{}
+ err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{
+ "Version": internal.VersionString(),
+ })
+ assert.NoError(t, err)
+
+ b, _, _ := testrig.Base(nil)
+ defer b.Close()
+
+ tempDir := t.TempDir()
+ socket := path.Join(tempDir, "socket")
+ // start base with the listener and wait for it to be started
+ address := config.UnixSocketAddress(socket, 0755)
+ assert.NoError(t, err)
+ go b.SetupAndServeHTTP(address, nil, nil)
+ time.Sleep(time.Millisecond * 100)
+
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
+ return net.Dial("unix", socket)
+ },
+ },
+ }
+ resp, err := client.Get("http://unix/")
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ // read the response
+ buf := &bytes.Buffer{}
+ _, err = buf.ReadFrom(resp.Body)
+ assert.NoError(t, err)
+
+ // Using .String() for user friendly output
+ assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
+}
diff --git a/setup/config/config.go b/setup/config/config.go
index 84876616..1a25f71e 100644
--- a/setup/config/config.go
+++ b/setup/config/config.go
@@ -19,7 +19,6 @@ import (
"encoding/pem"
"fmt"
"io"
- "net/url"
"os"
"path/filepath"
"regexp"
@@ -131,20 +130,6 @@ func (d DataSource) IsPostgres() bool {
// A Topic in kafka.
type Topic string
-// An Address to listen on.
-type Address string
-
-// An HTTPAddress to listen on, starting with either http:// or https://.
-type HTTPAddress string
-
-func (h HTTPAddress) Address() (Address, error) {
- url, err := url.Parse(string(h))
- if err != nil {
- return "", err
- }
- return Address(url.Host), nil
-}
-
// FileSizeBytes is a file size in bytes
type FileSizeBytes int64
diff --git a/setup/config/config_address.go b/setup/config/config_address.go
new file mode 100644
index 00000000..0e4f0296
--- /dev/null
+++ b/setup/config/config_address.go
@@ -0,0 +1,45 @@
+package config
+
+import (
+ "io/fs"
+ "net/url"
+)
+
+const (
+ NetworkTCP = "tcp"
+ NetworkUnix = "unix"
+)
+
+type ServerAddress struct {
+ Address string
+ Scheme string
+ UnixSocketPermission fs.FileMode
+}
+
+func (s ServerAddress) Enabled() bool {
+ return s.Address != ""
+}
+
+func (s ServerAddress) IsUnixSocket() bool {
+ return s.Scheme == NetworkUnix
+}
+
+func (s ServerAddress) Network() string {
+ if s.Scheme == NetworkUnix {
+ return NetworkUnix
+ } else {
+ return NetworkTCP
+ }
+}
+
+func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress {
+ return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm}
+}
+
+func HTTPAddress(urlAddress string) (ServerAddress, error) {
+ parsedUrl, err := url.Parse(urlAddress)
+ if err != nil {
+ return ServerAddress{}, err
+ }
+ return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil
+}
diff --git a/setup/config/config_address_test.go b/setup/config/config_address_test.go
new file mode 100644
index 00000000..1be484fd
--- /dev/null
+++ b/setup/config/config_address_test.go
@@ -0,0 +1,25 @@
+package config
+
+import (
+ "io/fs"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestHttpAddress_ParseGood(t *testing.T) {
+ address, err := HTTPAddress("http://localhost:123")
+ assert.NoError(t, err)
+ assert.Equal(t, "localhost:123", address.Address)
+ assert.Equal(t, "tcp", address.Network())
+}
+
+func TestHttpAddress_ParseBad(t *testing.T) {
+ _, err := HTTPAddress(":")
+ assert.Error(t, err)
+}
+
+func TestUnixSocketAddress_Network(t *testing.T) {
+ address := UnixSocketAddress("/tmp", fs.FileMode(0755))
+ assert.Equal(t, "unix", address.Network())
+}