aboutsummaryrefslogtreecommitdiff
path: root/src/Ssb/Peer/RPC/Room.hs
blob: 1f53cf0dfb015d382f678fe4bc53e6ca31bddc23 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
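-- | This module implements the Scuttlebutt Remote Procedure Call (RPC)
-- endpoints for Rooms (the @tunnel.*@ methods).
--
-- For more information, refer to the SSB Room documentation.
-- TODO: add a link to the SSB Room specification.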

-- TODO: Documentation for SSB-Room

module Ssb.Peer.RPC.Room where


import           Protolude               hiding ( Identity )
import qualified Data.Aeson                    as Aeson
import           Control.Concurrent.STM
import           Data.Default
import qualified Data.Map.Strict               as Map
import           Data.Time.Clock                ( UTCTime
                                                , getCurrentTime
                                                )
import qualified Data.Text                     as Text

import  Ssb.Aux
import qualified Ssb.Identity                  as Ssb
import qualified Ssb.Pub                       as Ssb
import           Ssb.Network
import qualified Ssb.Feed                      as Feed
import qualified Ssb.Peer.RPC                  as RPC

-- | Fixed secret appended as the last @:@-separated component of every
-- invite produced by 'formatInvite'.
--
-- NOTE(review): the value looks like a pre-shared key for the room's
-- handshake; confirm what consumes it before changing it.
seed :: Text
seed = "SSB+Room+PSK3TLYC2T86EHQCUHBUHASCASE18JBV24="

-- | Address of a room, rendered to text by 'formatInvite'.
data Invite = Invite
  { host :: Host           -- ^ network host placed after @net:@
  , port :: Port           -- ^ network port placed after the host
  , key  :: Ssb.PublicKey  -- ^ public key placed in the @~shs@ section
  }
  deriving (Eq, Show)

-- | Render an 'Invite' as @net:HOST:PORT~shs:KEY:SEED@.
--
-- @KEY@ is the formatted public key with its first character and its last
-- eight characters removed; @SEED@ is the module-level 'seed'.
formatInvite :: Invite -> Text
formatInvite invite = mconcat
  [ "net:"
  , host invite
  , ":"
  , port invite
  , "~shs"
  , ":"
  , shsKey
  , ":"
  , seed
  ]
 where
  -- Strip the leading sigil and an 8-character suffix from the formatted
  -- key (presumably ".ed25519"; TODO confirm against Ssb.formatPublicKey).
  shsKey = Text.dropEnd 8 . Text.drop 1 . Ssb.formatPublicKey $ key invite

-- | Shared, mutable state of a room server.
--
-- TODO: Seriously, figure out how to use duplicate field names
data Room = Room
  { endpoints    :: TMVar (Map Ssb.PublicKey (RPC.ConnState, [Tunnel]))
    -- ^ registered peers: their connection and the tunnels they belong to
  , notifyChange :: TChan Bool
    -- ^ written with 'True' whenever the set of endpoints changes
  , roomName     :: Text
    -- ^ name reported by @tunnel.isRoom@
  , roomDesc     :: Text
    -- ^ description reported by @tunnel.isRoom@
  , tunnels      :: TMVar (Map Tunnel [ThreadId])
    -- ^ active tunnels and the threads forwarding their messages
  }

-- | Create a room with the given name and description, no registered
-- endpoints, no tunnels, and a fresh broadcast change channel.
newRoom :: Text -> Text -> IO Room
newRoom name desc =
  Room
    <$> newTMVarIO Map.empty
    <*> newBroadcastTChanIO
    <*> pure name
    <*> pure desc
    <*> newTMVarIO Map.empty

-- | Public keys of all peers currently registered in the room, in
-- ascending key order.
getEndpoints :: Room -> IO [Ssb.PublicKey]
getEndpoints room = atomically $ Map.keys <$> readTMVar (endpoints room)

-- | Connection of a registered @peer@, or 'Nothing' when the peer is not
-- currently registered as an endpoint.
lookUpPeer :: Room -> Ssb.PublicKey -> IO (Maybe RPC.ConnState)
lookUpPeer room peer =
  atomically $ fmap fst . Map.lookup peer <$> readTMVar (endpoints room)

-- | Error returned by 'registerPeer' when the endpoint map cannot grow.
errConnLimitReached :: Text
errConnLimitReached = "peer limit reached"

-- | Register the peer behind @stream@ as a room endpoint.
--
-- The peer is inserted with its connection and an empty tunnel list; an
-- existing entry for the same key is overwritten.  On success 'True' is
-- broadcast on 'notifyChange'.  Returns @Left errConnLimitReached@ when the
-- endpoint map is already at its maximum size.
registerPeer :: Room -> RPC.Stream -> IO (Either Text ())
registerPeer h stream = atomically $ do
  endpoints' <- takeTMVar mVar
  if Map.size endpoints' == (maxBound :: Int)
    then do
      -- Put the unchanged map back: leaving the TMVar empty would make
      -- every later reader of the room's endpoints block forever.
      putTMVar mVar endpoints'
      return $ Left errConnLimitReached
    else do
      putTMVar mVar
        $ Map.insert (RPC.peer stream) (RPC.conn stream, []) endpoints'
      writeTChan (notifyChange h) True
      return $ Right ()
 where
  mVar = endpoints h

-- | Remove @peer@ from the room's endpoints and broadcast 'True' on
-- 'notifyChange'.  The notification is sent even if @peer@ was not
-- registered.
unregisterPeer :: Room -> Ssb.PublicKey -> IO ()
unregisterPeer h peer = atomically $ do
  remaining <- Map.delete peer <$> takeTMVar (endpoints h)
  putTMVar (endpoints h) remaining
  writeTChan (notifyChange h) True

-- | Reply to a @tunnel.isRoom@ request, serialised through the
-- 'Generic'-based Aeson instances below.
data IsRoomResponse = IsRoomResponse
  { name        :: Text  -- ^ the room's name
  , description :: Text  -- ^ the room's description
  }
  deriving (Generic, Show)

instance Aeson.FromJSON IsRoomResponse
instance Aeson.ToJSON IsRoomResponse

-- | Build the @tunnel.isRoom@ reply from the room's name and description.
newIsRoomResponse :: Room -> IsRoomResponse
newIsRoomResponse = IsRoomResponse <$> roomName <*> roomDesc

-- | Request for the async @tunnel.announce@ method; it carries no arguments.
announceRequest :: RPC.Request [Text]
announceRequest = RPC.Request
  { name = ["tunnel", "announce"]
  , typ  = RPC.Async
  , args = []
  }

-- | Send 'announceRequest' over @conn@ with 'RPC.requestAsync'.
announce :: RPC.ConnState -> IO (Either Text ())
announce = flip RPC.requestAsync announceRequest

-- | Argument of a @tunnel.connect@ request, (de)serialised through the
-- 'Generic'-based Aeson instances below.
data ConnectRequest = ConnectRequest
  { target :: Ssb.PublicKey  -- ^ peer the caller wants a tunnel to
  , portal :: Ssb.PublicKey  -- ^ NOTE(review): not read in this module; presumably the room's key
  }
  deriving (Generic, Show)

instance Aeson.FromJSON ConnectRequest
instance Aeson.ToJSON ConnectRequest

-- | Request for the async @tunnel.leave@ method; it carries no arguments.
leaveRequest :: RPC.Request [Text]
leaveRequest = RPC.Request
  { name = ["tunnel", "leave"]
  , typ  = RPC.Async
  , args = []
  }

-- | Send 'leaveRequest' over @conn@ with 'RPC.requestAsync'.
leave :: RPC.ConnState -> IO (Either Text ())
leave = (`RPC.requestAsync` leaveRequest)

-- | Request for the async @tunnel.ping@ method; it carries no arguments.
pingRequest :: RPC.Request [Text]
pingRequest = RPC.Request
  { name = ["tunnel", "ping"]
  , typ  = RPC.Async
  , args = []
  }

-- | Send 'pingRequest' over @conn@; a successful result is a 'UTCTime'.
ping :: RPC.ConnState -> IO (Either Text UTCTime)
ping = flip RPC.requestAsync pingRequest

-- | Fork @action@ as a member of the fork group held in @mVar@.
--
-- A fresh, empty completion flag is pushed onto the front of the group's
-- list before the thread starts.  The flag is filled once @action@
-- finishes, whether it returns normally or dies with an exception
-- ('forkFinally').  Use 'waitForkGroup' to wait on the group.
fork :: TMVar [TMVar ()] -> IO a -> IO ThreadId
fork mVar action = do
  done <- newEmptyTMVarIO
  atomically $ takeTMVar mVar >>= putTMVar mVar . (done :)
  forkFinally action (\_ -> atomically $ putTMVar done ())

-- | Wait on a fork group created with 'fork'.
--
-- Blocks until /any/ member of the group completes, then kills every
-- thread in @threads@ and waits until all members have completed.  Members
-- forked while this runs are picked up on the next round.  On return the
-- group's 'TMVar' still holds an (empty) list, so it is safe to reuse.
waitForkGroup :: TMVar [TMVar ()] -> [ThreadId] -> IO ()
waitForkGroup mVar threads = do
  forks <- atomically $ swapTMVar mVar []
  unless (null forks) $ do
    -- Wake on whichever flag is filled first.  'readTMVar' leaves the flag
    -- set, so the drain below does not block on it.
    atomically $ foldr (orElse . readTMVar) retry forks
    forM_ threads killThread
    forM_ forks $ atomically . takeTMVar
    waitForkGroup mVar []

-- | Copy every message read from @s1@ to @s2@.
--
-- Returns @Right ()@ once @s1@ yields 'Nothing', or the first error
-- reported by a write to @s2@.
forwardMessages :: RPC.Stream -> RPC.Stream -> IO (Either Text ())
forwardMessages s1 s2 = loop
 where
  loop = RPC.readStream s1 >>= maybe (return (Right ())) relay
  relay msg = RPC.writeStream s2 def msg >>= either (return . Left) (const loop)

-- | Key identifying a tunnel: the public keys of its two peers.  Keys are
-- built with 'newTunnel', which puts the smaller key first so that each
-- pair of peers has a single key.
type Tunnel = (Ssb.PublicKey, Ssb.PublicKey)

-- | Build the canonical key for a tunnel between two peers.
--
-- The pair is ordered (smaller element first) so that a tunnel between
-- @a@ and @b@ gets the same key regardless of which side requested it.
newTunnel :: Ord a => a -> a -> (a, a)
newTunnel arg1 arg2
  | arg1 < arg2 = (arg1, arg2)
  | otherwise   = (arg2, arg1)

-- | Relay messages in both directions between two duplex streams.
--
-- The tunnel is keyed by the ordered pair of the two peers' keys (see
-- 'newTunnel').  At most one tunnel may exist per pair; a second request
-- for the same pair returns @Left "only one tunnel allowed"@.
--
-- Blocks until the tunnel is torn down ('waitForkGroup' on the two
-- forwarding threads).  The tunnel is then removed from the room's tunnel
-- table and from both peers' endpoint entries.  Cleanup runs under
-- 'finally', so it also happens if this thread is interrupted, and any
-- forwarder still running at that point is killed.
createTunnel :: Room -> (RPC.Stream, RPC.Stream) -> IO (Either Text ())
createTunnel room (stream1, stream2) = do
  let peer1  = RPC.connPeer $ RPC.conn stream1
  let peer2  = RPC.connPeer $ RPC.conn stream2
  let tunnel = newTunnel peer1 peer2

  -- Check and reserve in one transaction so that two concurrent requests
  -- for the same pair cannot both pass the uniqueness check.
  reserved <- atomically $ do
    tunnels' <- takeTMVar (tunnels room)
    let isNew = not (Map.member tunnel tunnels')
    putTMVar (tunnels room)
      $ if isNew then Map.insert tunnel [] tunnels' else tunnels'
    return isNew

  if not reserved
    then return $ Left "only one tunnel allowed"
    else (runTunnel tunnel >> return (Right ())) `finally` teardown tunnel
 where
  -- Start both forwarders, record them, and wait for the group to finish.
  runTunnel tunnel = do
    waiter  <- newTMVarIO []
    thread1 <- fork waiter $ forwardMessages stream1 stream2
    thread2 <- fork waiter $ forwardMessages stream2 stream1
    let threads = [thread1, thread2]
    atomically $ do
      endpoints' <- takeTMVar (endpoints room)
      putTMVar (endpoints room) $ adjustBoth (addTunnel tunnel) tunnel endpoints'
      tunnels' <- takeTMVar (tunnels room)
      putTMVar (tunnels room) $ Map.insert tunnel threads tunnels'
    waitForkGroup waiter threads

  -- Forget the tunnel and stop any forwarder that may still be running
  -- (killing an already-finished thread is a no-op).
  teardown tunnel = do
    threads <- atomically $ do
      endpoints' <- takeTMVar (endpoints room)
      putTMVar (endpoints room)
        $ adjustBoth (removeTunnel tunnel) tunnel endpoints'
      tunnels' <- takeTMVar (tunnels room)
      putTMVar (tunnels room) $ Map.delete tunnel tunnels'
      return $ fromMaybe [] (Map.lookup tunnel tunnels')
    forM_ threads killThread

  -- Apply @f@ to the endpoint entries of both peers of the tunnel.
  adjustBoth f (a, b) = Map.adjust f b . Map.adjust f a
  addTunnel t (conn, ts) = (conn, ts ++ [t])
  removeTunnel t (conn, ts) = (conn, filter (t /=) ts)

-- | Handle a @tunnel.connect@ request arriving on @stream@.
--
-- Refuses the request, logging the reason and returning it as 'Left', in
-- three cases, checked in this order:
--
-- * the caller targets itself;
-- * a tunnel between the two peers already exists;
-- * the target is not a registered endpoint.
--
-- Otherwise it opens a duplex @tunnel.connect@ request to the target,
-- forwarding @req@, and splices both streams together with 'createTunnel'.
connect :: Room -> RPC.Stream -> ConnectRequest -> IO (Either Text ())
connect room stream req = do
  let tunnel = newTunnel (RPC.connPeer (RPC.conn stream)) (target req)
  active <- atomically $ readTMVar (tunnels room)
  mPeer  <- lookUpPeer room (target req)
  case validate tunnel active mPeer of
    Left err -> do
      putStrLn $ "room error:  " <> err
      return $ Left err
    Right conn ->
      RPC.request conn (forwardRequest req)
        $ \peerStream -> createTunnel room (stream, peerStream)
 where
  validate (a, b) active mPeer
    | a == b                   = Left "connecting to self not allowed"
    | Map.member (a, b) active = Left "only unique tunnels are allowed"
    | otherwise = maybeToRight ("peer is not connected" :: Text) mPeer
  forwardRequest arg =
    RPC.Request { name = ["tunnel", "connect"], typ = RPC.Duplex, args = [arg] }

-- | Disconnect @peer@ from the room.
--
-- Kills the forwarding threads of every tunnel recorded for @peer@, then
-- removes it from the endpoints with 'unregisterPeer'.  Always returns
-- @Right ()@.
leave' :: Room -> Ssb.PublicKey -> IO (Either Text ())
leave' room peer = do
  registered <- atomically $ readTMVar (endpoints room)
  active     <- atomically $ readTMVar (tunnels room)
  let peerTunnels = maybe [] snd (Map.lookup peer registered)
      doomed      = concatMap (\t -> Map.findWithDefault [] t active) peerTunnels
  mapM_ killThread doomed
  unregisterPeer room peer
  return (Right ())

instance RPC.Handler Room where
  endpoints _ =
    [ RPC.Endpoint ["tunnel", "announce"] RPC.Async
    , RPC.Endpoint ["tunnel", "connect"] RPC.Duplex
    , RPC.Endpoint ["tunnel", "endpoints"] RPC.Async
    , RPC.Endpoint ["tunnel", "leave"] RPC.Async
    , RPC.Endpoint ["tunnel", "isRoom"] RPC.Async
    , RPC.Endpoint ["tunnel", "ping"] RPC.Async
    ]

  -- The announce and leave endpoints are defined by the server's JS, but
  -- never called by the client.  The official (NodeJS) project uses
  -- disconnect notifications from the SSB-server.
  serve h (RPC.Endpoint ["tunnel", "announce"] _) _ stream =
    registerPeer h stream

  -- Decode the request's single 'ConnectRequest' argument (via a JSON
  -- round-trip) and hand it to 'connect'.
  serve h (RPC.Endpoint ["tunnel", "connect"] RPC.Duplex) args stream = do
    let args' =
          decodeJSON (toS $ Aeson.encode args) :: Either Text [ConnectRequest]
    case args' of
      Left  err       -> return $ Left err
      Right [connReq] -> connect h stream connReq
      _               -> return $ Left "bad target argument"

  -- Register the caller, then send it the list of the other registered
  -- peers.  The list is re-sent after every change notification; the loop
  -- stops when a write fails or a 'False' notification arrives.
  serve h (RPC.Endpoint ["tunnel", "endpoints"] _) _ stream = do
    registered <- registerPeer h stream
    case registered of
      Left  msg -> return $ Left msg
      Right _   -> do
        change <- atomically $ dupTChan $ notifyChange h
        let pushEndpoints = do
              others <- filter (/= RPC.peer stream) <$> getEndpoints h
              res    <- RPC.writeStreamJSON stream others
              unless (isLeft res) $ do
                again <- atomically $ readTChan change
                when again pushEndpoints
        pushEndpoints
        return $ Right ()

  serve room (RPC.Endpoint ["tunnel", "leave"] _) _ stream =
    leave' room (RPC.peer stream)

  serve room (RPC.Endpoint ["tunnel", "isRoom"] _) _ stream =
    RPC.writeStreamJSON stream (newIsRoomResponse room)

  serve _ (RPC.Endpoint ["tunnel", "ping"] _) _ stream = do
    now <- getCurrentTime
    RPC.writeStreamJSON stream now

  -- HACK: return OK when endpoint not known to avoid disconnecting clients
  serve _ _ _ _ = return . return $ ()
  --serve room endpoint@otherwise arg stream = (RPC.notFoundHandlerFunc endpoint) arg stream

  notifyConnect _ _ = return . return $ ()

  -- A disconnecting peer leaves the room: its tunnels are torn down and it
  -- is unregistered.
  notifyDisconnect room peer = do
    _ <- leave' room peer
    return . return $ ()