aboutsummaryrefslogtreecommitdiff
path: root/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'server.go')
-rw-r--r--server.go140
1 files changed, 140 insertions, 0 deletions
diff --git a/server.go b/server.go
new file mode 100644
index 0000000..e34e3f2
--- /dev/null
+++ b/server.go
@@ -0,0 +1,140 @@
+package main
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "os/user"
+ "strings"
+)
+
+const TableName = "membership"
+
+type Server struct {
+ HostName string
+ Groups []string
+ GroupsFile string
+}
+
+func ServeUnixSocket(listenner net.Listener, srv *Server) error {
+ for {
+ // Accept an incoming connection.
+ conn, err := listenner.Accept()
+ if errors.Is(err, net.ErrClosed) {
+ break
+ } else if err != nil {
+ fmt.Fprintln(os.Stderr, err.Error())
+
+ continue
+ }
+
+ go func(conn net.Conn) {
+ defer conn.Close()
+
+ err = srv.HandleConn(conn, conn)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, err.Error())
+ }
+ }(conn)
+
+ }
+
+ return nil
+}
+
+func (s *Server) readGroupList() ([]string, error) {
+ f, err := os.Open(s.GroupsFile)
+ if err != nil {
+ return nil, err
+ }
+
+ buf, err := io.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+
+ groups := strings.Split(string(buf), "\n")
+
+ return groups, nil
+}
+
+func (s *Server) HandleConn(w io.Writer, r io.Reader) error {
+ for {
+ var req Request
+
+ _, err := readNetString(r, &req)
+ if errors.Is(err, io.EOF) {
+ break
+ }
+
+ if req.Name != TableName {
+ if _, err = writeNetString(w, &ReplyPerm{Reason: "unexpected name"}); err != nil {
+ return err
+ }
+
+ return nil
+ }
+
+ strs := strings.SplitN(req.Key, "@", 2)
+ if len(strs) != 2 {
+ return fmt.Errorf("unexpected email address format")
+ }
+ userName := strs[0]
+ hostName := strs[1]
+
+ if hostName != s.HostName {
+ if _, err := writeNetString(w, &ReplyNotFound{}); err != nil {
+ return err
+ }
+
+ return nil
+ }
+
+ u, err := user.Lookup(userName)
+ if _, ok := err.(user.UnknownUserError); ok {
+ _, err := writeNetString(w, &ReplyNotFound{})
+ return err
+ } else if err != nil {
+ return err
+ }
+
+ groups, err := lookupGroupNames(u)
+ if err != nil {
+ return err
+ }
+
+ isListed := make(map[string]bool)
+
+ for _, v := range s.Groups {
+ isListed[v] = true
+ }
+
+ groupsFile, err := s.readGroupList()
+ if err != nil {
+ return fmt.Errorf("reading white list: %w", err)
+ }
+
+ for _, v := range groupsFile {
+ isListed[v] = true
+ }
+
+ found := false
+
+ for _, name := range groups {
+ found = isListed[name]
+ if found {
+ _, err = writeNetString(w, &ReplyOK{Data: "OK"})
+ return err
+ }
+ }
+
+ _, err = writeNetString(w, &ReplyNotFound{})
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}