diff options
Diffstat (limited to 'server.go')
-rw-r--r-- | server.go | 140 |
1 files changed, 140 insertions, 0 deletions
diff --git a/server.go b/server.go new file mode 100644 index 0000000..e34e3f2 --- /dev/null +++ b/server.go @@ -0,0 +1,140 @@ +package main + +import ( + "errors" + "fmt" + "io" + "net" + "os" + "os/user" + "strings" +) + +const TableName = "membership" + +type Server struct { + HostName string + Groups []string + GroupsFile string +} + +func ServeUnixSocket(listenner net.Listener, srv *Server) error { + for { + // Accept an incoming connection. + conn, err := listenner.Accept() + if errors.Is(err, net.ErrClosed) { + break + } else if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + + continue + } + + go func(conn net.Conn) { + defer conn.Close() + + err = srv.HandleConn(conn, conn) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + } + }(conn) + + } + + return nil +} + +func (s *Server) readGroupList() ([]string, error) { + f, err := os.Open(s.GroupsFile) + if err != nil { + return nil, err + } + + buf, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + groups := strings.Split(string(buf), "\n") + + return groups, nil +} + +func (s *Server) HandleConn(w io.Writer, r io.Reader) error { + for { + var req Request + + _, err := readNetString(r, &req) + if errors.Is(err, io.EOF) { + break + } + + if req.Name != TableName { + if _, err = writeNetString(w, &ReplyPerm{Reason: "unexpected name"}); err != nil { + return err + } + + return nil + } + + strs := strings.SplitN(req.Key, "@", 2) + if len(strs) != 2 { + return fmt.Errorf("unexpected email address format") + } + userName := strs[0] + hostName := strs[1] + + if hostName != s.HostName { + if _, err := writeNetString(w, &ReplyNotFound{}); err != nil { + return err + } + + return nil + } + + u, err := user.Lookup(userName) + if _, ok := err.(user.UnknownUserError); ok { + _, err := writeNetString(w, &ReplyNotFound{}) + return err + } else if err != nil { + return err + } + + groups, err := lookupGroupNames(u) + if err != nil { + return err + } + + isListed := make(map[string]bool) + + for _, v := range s.Groups { + isListed[v] = true + } + + groupsFile, err := s.readGroupList() + if err != nil { + return fmt.Errorf("reading white list: %w", err) + } + + for _, v := range groupsFile { + isListed[v] = true + } + + found := false + + for _, name := range groups { + found = isListed[name] + if found { + _, err = writeNetString(w, &ReplyOK{Data: "OK"}) + return err + } + } + + _, err = writeNetString(w, &ReplyNotFound{}) + if err != nil { + return err + } + } + + return nil +} |