server.go (2339B)
1 package main 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "net" 8 "os" 9 "os/user" 10 "strings" 11 ) 12 13 const TableName = "membership" 14 15 type Server struct { 16 HostName string 17 Groups []string 18 GroupsFile string 19 } 20 21 func ServeUnixSocket(listenner net.Listener, srv *Server) error { 22 for { 23 // Accept an incoming connection. 24 conn, err := listenner.Accept() 25 if errors.Is(err, net.ErrClosed) { 26 break 27 } else if err != nil { 28 fmt.Fprintln(os.Stderr, err.Error()) 29 30 continue 31 } 32 33 go func(conn net.Conn) { 34 defer conn.Close() 35 36 err = srv.HandleConn(conn, conn) 37 if err != nil { 38 fmt.Fprintln(os.Stderr, err.Error()) 39 } 40 }(conn) 41 42 } 43 44 return nil 45 } 46 47 func (s *Server) readGroupList() ([]string, error) { 48 f, err := os.Open(s.GroupsFile) 49 if err != nil { 50 return nil, err 51 } 52 53 buf, err := io.ReadAll(f) 54 if err != nil { 55 return nil, err 56 } 57 58 groups := strings.Split(string(buf), "\n") 59 60 return groups, nil 61 } 62 63 func (s *Server) HandleConn(w io.Writer, r io.Reader) error { 64 for { 65 var req Request 66 67 _, err := readNetString(r, &req) 68 if errors.Is(err, io.EOF) { 69 break 70 } 71 72 if req.Name != TableName { 73 if _, err = writeNetString(w, &ReplyPerm{Reason: "unexpected name"}); err != nil { 74 return err 75 } 76 77 return nil 78 } 79 80 strs := strings.SplitN(req.Key, "@", 2) 81 if len(strs) != 2 { 82 return fmt.Errorf("unexpected email address format") 83 } 84 userName := strs[0] 85 hostName := strs[1] 86 87 if hostName != s.HostName { 88 if _, err := writeNetString(w, &ReplyNotFound{}); err != nil { 89 return err 90 } 91 92 return nil 93 } 94 95 u, err := user.Lookup(userName) 96 if _, ok := err.(user.UnknownUserError); ok { 97 _, err := writeNetString(w, &ReplyNotFound{}) 98 return err 99 } else if err != nil { 100 return err 101 } 102 103 groups, err := lookupGroupNames(u) 104 if err != nil { 105 return err 106 } 107 108 isListed := make(map[string]bool) 109 110 for _, v := range s.Groups { 111 isListed[v] = true 112 } 113 114 groupsFile, err := s.readGroupList() 115 if err != nil { 116 return fmt.Errorf("reading white list: %w", err) 117 } 118 119 for _, v := range groupsFile { 120 isListed[v] = true 121 } 122 123 found := false 124 125 for _, name := range groups { 126 found = isListed[name] 127 if found { 128 _, err = writeNetString(w, &ReplyOK{Data: "OK"}) 129 return err 130 } 131 } 132 133 _, err = writeNetString(w, &ReplyNotFound{}) 134 if err != nil { 135 return err 136 } 137 } 138 139 return nil 140 }