aboutsummaryrefslogtreecommitdiff
path: root/src/Ssb/Peer/RPC.hs
blob: 3286a33368cb84aab32649735eb6e782fa69debd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
-- For more information, please refer to the protocol guide
-- https://ssbc.github.io/scuttlebutt-protocol-guide

module Ssb.Peer.RPC where

import           Protolude

import           Control.Concurrent.STM
import           Control.Monad.Fail
import           Data.Aeson                     ( FromJSON
                                                , ToJSON
                                                )
import           Data.Aeson                    as Aeson
import qualified Data.ByteString               as BS
import qualified Data.ByteString.Lazy          as BS
                                                ( toStrict )
import qualified Data.Map.Strict               as Map
import           Data.Default
import           Data.Either.Combinators        ( mapLeft
                                                , mapRight
                                                )
import           Data.Text                     as Text
import           Data.Serialize                as Serialize

import           Ssb.Aux
import qualified Ssb.Identity                  as Ssb
import qualified Ssb.Peer.BoxStream            as BoxStream

-- | BodyType is the encoding of a message body.  It is stored in the low two
-- bits of the header flags byte.
data BodyType
    = Binary          -- ^ raw bytes (value 0)
    | UTF8String      -- ^ UTF-8 text (value 1)
    | JSON            -- ^ JSON document (value 2)
    | UnknownBodyType -- ^ the unassigned value 3
    deriving (Bounded,Eq,Show)

-- | Decode the body type from the two least significant bits of a flags byte.
instance Convertible Word8 BodyType where
  convert byte
    | lowBits == 0 = Binary
    | lowBits == 1 = UTF8String
    | lowBits == 2 = JSON
    | otherwise    = UnknownBodyType
    where lowBits = byte .&. 3

-- | Encode a body type as the two-bit value stored in the flags byte.
instance Convertible BodyType Word8 where
  convert kind = case kind of
    Binary          -> 0
    UTF8String      -> 1
    JSON            -> 2
    UnknownBodyType -> 3

-- | Flags is the decoded form of the first byte of an RPC header (see the
-- 'Convertible' instances for the exact bit positions).
data Flags = Flags
    { unused1      :: Bool     -- ^ bit 7, unused
    , unused2      :: Bool     -- ^ bit 6, unused
    , unused3      :: Bool     -- ^ bit 5, unused
    , unused4      :: Bool     -- ^ bit 4, unused
    , isStream     :: Bool     -- ^ bit 3, set for stream messages
    , isEndOrError :: Bool     -- ^ bit 2, set on end-of-stream or error
    , bodyType     :: BodyType -- ^ bits 1-0, encoding of the body
    } deriving (Eq,Show)

-- | Encode flags as a header byte: each boolean field sets its bit on top of
-- the two body-type bits.
instance Convertible Flags Word8 where
  convert fs = foldr setWhen (convert (bodyType fs)) flagBits
   where
    flagBits =
      [ (unused1 fs     , 7)
      , (unused2 fs     , 6)
      , (unused3 fs     , 5)
      , (unused4 fs     , 4)
      , (isStream fs    , 3)
      , (isEndOrError fs, 2)
      ]
    setWhen (True , pos) byte = setBit byte pos
    setWhen (False, _  ) byte = byte

-- | Decode a header byte into its flags; the body type comes from the two
-- lowest bits.
instance Convertible Word8 Flags where
  convert byte = Flags
    { unused1      = flagAt 7
    , unused2      = flagAt 6
    , unused3      = flagAt 5
    , unused4      = flagAt 4
    , isStream     = flagAt 3
    , isEndOrError = flagAt 2
    , bodyType     = convert byte
    }
    where flagAt = testBit byte

-- | The default flags are those of an all-zero header byte: every bit clear
-- and a 'Binary' body.
instance Default Flags where
  def = convert (0 :: Word8)

-- | Flags are (de)serialized as the single header byte.
instance Serialize.Serialize Flags where
  put flagSet = putWord8 (convert flagSet)
  get         = fmap convert getWord8

-- | ProcedureType defines the type of remote call.
data ProcedureType
    = Async  -- ^ JSON "async"; sent without the stream flag
    | Source -- ^ JSON "source"
    | Duplex -- ^ JSON "duplex"
    deriving (Eq,Generic,Ord,Show)

-- | Parse a procedure type from its lower-case JSON string; any other string
-- is rejected.
instance FromJSON ProcedureType where
  parseJSON = withText "ProcedureType" fromName
   where
    fromName "async"  = return Async
    fromName "source" = return Source
    fromName "duplex" = return Duplex
    fromName other    = fail $ "unknown value '" <> toS other <> "'"

-- | Encode a procedure type as its lower-case JSON string.
instance ToJSON ProcedureType where
  toJSON procType = Aeson.String $ case procType of
    Async  -> "async"
    Source -> "source"
    Duplex -> "duplex"

-- | headerLength is the length of the RPC header in bytes: one flags byte,
-- followed by the body size and the request number.
headerLength :: Int
headerLength = 1 + bodySizeLength + requestNumberLength

-- | bodySizeLength is the length, in bytes, of the body size field of the RPC
-- header.
bodySizeLength :: Int
bodySizeLength = 4

-- | requestNumberLength is the length, in bytes, of the request number field
-- of the RPC header.
requestNumberLength :: Int
requestNumberLength = 4

-- | Header is the first part of a RPC message used for stream control and communication.
--
-- The 'Serialize.Serialize' instance is derived through 'Generic', which
-- encodes the fields in declaration order.  The field order below is
-- therefore the wire layout: the one-byte 'Flags', then the body length, then
-- the request number ('headerLength' bytes in total).  Do not reorder fields.
data Header = Header
    { flags         :: Flags  -- ^ stream/end/body-type flags
    , bodyLength    :: Word32 -- ^ number of body bytes following the header
    , requestNumber :: Int32  -- ^ stream number; negated on messages replying to a peer-initiated stream (see 'writeStream')
    } deriving (Eq,Generic,Show)

instance Serialize.Serialize Header

-- | goodByeHeader is the all-zero RPC message header that signals the end of
-- the RPC connection.
goodByeHeader :: Header
goodByeHeader = Header { flags = def, bodyLength = 0, requestNumber = 0 }

-- | newHeader builds the header for a message carrying the given body.
--
-- Returns 'Left' when the body is too long for the header's 32-bit length
-- field, instead of silently truncating the length.
newHeader :: Flags -> Int32 -> ByteString -> Either Text Header
newHeader flags' reqNum body'
  -- compare as Integer so the check is also correct where Int is 32 bits
  | toInteger len > toInteger (maxBound :: Word32) =
    Left "RPC.newHeader: body too large for 32-bit length field"
  | otherwise = Right Header { flags         = flags'
                             , bodyLength    = fromIntegral len
                             , requestNumber = reqNum
                             }
  where len = BS.length body'

-- | MessagePayload holds the raw body bytes of an RPC message, tagged with
-- the body type they were sent as.
data MessagePayload
    = BinaryPayload ByteString -- ^ 'Binary' body
    | TextPayload ByteString   -- ^ 'UTF8String' body
    | JSONPayload ByteString   -- ^ 'JSON' body
    deriving (Generic,Show)

-- | Message is a single message in the RPC stream: a header followed by its
-- body.
data Message = Message
    { header :: Header         -- ^ flags, body length and request number
    , body   :: MessagePayload -- ^ the message body
    } deriving (Generic,Show)

-- | A message is encoded as its header followed by the raw body bytes.
-- Decoding reads exactly 'bodyLength' body bytes and fails on an unknown
-- body type.  Encoding does not recompute the header's length field.
instance Serialize.Serialize Message where
  get = do
    hdr   <- Serialize.get
    bytes <- Serialize.getByteString (fromIntegral (bodyLength hdr))
    Message hdr <$> decodePayload (bodyType (flags hdr)) bytes
   where
    decodePayload Binary          = return . BinaryPayload
    decodePayload UTF8String      = return . TextPayload
    decodePayload JSON            = return . JSONPayload
    decodePayload UnknownBodyType = const (fail "unknown body type")

  put (Message hdr payload) = do
    Serialize.put hdr
    Serialize.putByteString (payloadBytes payload)
   where
    payloadBytes (BinaryPayload bytes) = bytes
    payloadBytes (TextPayload bytes)   = bytes
    payloadBytes (JSONPayload bytes)   = bytes

-- | newJSONMessage is a convenience function to create a RPC message with a
-- JSON payload.
--
-- The body length is recorded by 'newHeader' in the header's 32-bit length
-- field, so no further size check is made here.  The previous code bound an
-- unused @maybeWord8@ result, which (judging by its name) could only reject
-- bodies longer than 255 bytes.
newJSONMessage :: ToJSON a => Int32 -> a -> Either Text Message
newJSONMessage reqNum payload = do
  let buf = BS.toStrict $ Aeson.encode payload
  header' <- newHeader (def { bodyType = JSON }) reqNum buf
  return $ Message { header = header', body = JSONPayload buf }

-- | decodeJSONMessage is a convenience function to decode a RPC JSON message.
--
-- Returns 'Left' if the message body is not a JSON payload or cannot be
-- decoded.  Previously a non-JSON body raised an exception, which took down
-- the caller (e.g. the connection's service loop) on a malformed request.
decodeJSONMessage :: FromJSON a => Message -> Either Text a
decodeJSONMessage msg = case body msg of
  JSONPayload buf -> decodeJSON buf
  _               -> Left "unexpected body type"

-- | isRequest reports whether the message carries the next expected incoming
-- request number, i.e. whether it opens a new peer-initiated stream.
isRequest :: Int32 -> Message -> Bool
isRequest nextReqNum = (== nextReqNum) . requestNumber . header

-- | Request is the message representing a RPC call.
data Request a = Request
    { name :: [Text]        -- ^ procedure name, split into its parts
    , typ  :: ProcedureType -- ^ kind of call
    , args :: a             -- ^ call arguments
    } deriving (Show)

-- | Parse a request object.  A missing @type@ key defaults to 'Async'.
instance (FromJSON a) => FromJSON (Request a) where
  parseJSON = withObject "Request" $ \obj -> do
    procName <- obj .: "name"
    procType <- obj .:? "type" .!= Async
    procArgs <- obj .: "args"
    return (Request procName procType procArgs)

-- | Encode a request as an object with @name@, @type@ and @args@ keys.
instance (ToJSON a) => ToJSON (Request a) where
  toJSON (Request procName procType procArgs) =
    object ["name" .= procName, "type" .= procType, "args" .= procArgs]

-- | Direction records which side opened a stream.
data Direction
    = Incoming -- ^ opened by the peer
    | Outgoing -- ^ opened by us (see 'request')
    deriving (Eq,Generic,Show)

-- | StreamStatus tracks where a stream is in its lifecycle.
data StreamStatus
    = Open
      -- ^ The stream is active.
    | AwaitingResponse
      -- ^ The (Async) stream is awaiting a response.
    | AwaitingCloseRecv
      -- ^ The stream is closed and waiting for its peer to do the same.
    | Closed
      -- ^ The stream is closed.
    deriving (Eq,Show)

-- | Stream is a single RPC stream (one remote procedure call) on a
-- connection.
data Stream = Stream
    { streamID   :: Int32         -- ^ positive request number that opened the stream
    , streamType :: ProcedureType -- ^ kind of call carried on the stream
    , conn       :: ConnState     -- ^ connection the stream belongs to
    , direction  :: Direction     -- ^ which side opened the stream
    , status     :: StreamStatus  -- ^ status when this value was built; the live status lives in the connection's stream table
    , peer       :: Ssb.PublicKey -- ^ public key of the remote peer
    }

-- | formatStream renders a stream as tab-separated ID, type, direction and
-- status, for logging.
formatStream :: Stream -> Text
formatStream stream = Text.intercalate "\t"
  [ show (streamID stream)
  , show (streamType stream)
  , show (direction stream)
  , show (status stream)
  ]

-- | foldStream folds @fn@ over every payload read from the stream until the
-- stream ends, returning the final accumulator.
--
-- Stops at the first 'Left' returned by @fn@ and returns that error.  The
-- previous code returned an 'error' thunk here, which raised an exception
-- when the caller inspected the result.
foldStream
  :: (a -> MessagePayload -> IO (Either Text a))
  -> a
  -> Stream
  -> IO (Either Text a)
foldStream fn acc stream = do
  msg <- readStream stream
  case msg of
    Nothing      -> return (Right acc)
    Just payload -> fn acc payload
      >>= either (return . Left) (\acc' -> foldStream fn acc' stream)

-- | foldJSONStream folds @fn@ over every JSON value read from the stream
-- until the stream ends, returning the final accumulator.
--
-- Read/decode failures and the first 'Left' returned by @fn@ are returned as
-- 'Left' (previously they became 'error' thunks that crashed the caller).
foldJSONStream
  :: (FromJSON b)
  => (a -> b -> IO (Either Text a))
  -> a
  -> Stream
  -> IO (Either Text a)
foldJSONStream fn acc stream = do
  msg <- readStreamJSON stream
  case msg of
    Left  err          -> return (Left err)
    Right Nothing      -> return (Right acc)
    Right (Just value) -> fn acc value
      >>= either (return . Left) (\acc' -> foldJSONStream fn acc' stream)

-- | ConnState is the shared state of one RPC connection to a peer.
data ConnState = ConnState
    { connPeer           :: Ssb.PublicKey
      -- ^ public key of the remote peer
    , boxConn            :: BoxStream.Conn
      -- ^ Box Stream connection that messages are read from and written to
    , streamsIn          :: TMVar (Map Int32 (Stream, TChan (Maybe Message)))
      -- ^ peer-initiated streams by stream ID, each with the channel its
      -- messages are delivered on ('Nothing' marks end of stream)
    , streamsOut         :: TMVar (Map Int32 (Stream, TChan (Maybe Message)))
      -- ^ streams we initiated, in the same form as 'streamsIn'
    , nextIncomingReqNum :: TMVar Int32
      -- ^ request number expected for the next peer-initiated stream
    , nextOutgoingReqNum :: TMVar Int32
      -- ^ request number used for our next outgoing request
    , lock               :: TMVar Bool
      -- ^ held by 'logMsg' while printing so log lines do not interleave
    }

-- | Endpoint identifies which Remote Procedure Call to make: the procedure
-- name split into its parts, plus the type of call.
data Endpoint
    = Endpoint [Text] ProcedureType
    deriving (Eq, Ord)

-- | formatEndpoint renders an endpoint as @part.part:Type@, for messages.
formatEndpoint :: Endpoint -> Text
formatEndpoint (Endpoint paths typ) =
  mconcat [Text.intercalate "." paths, ":", show typ]

-- | HandlerFunc's are used to serve Remote Procedure Calls.  A HandlerFunc
-- receives the call's @args@ and the stream the call arrived on; a 'Left'
-- result closes the stream with an error.
type HandlerFunc = Aeson.Value -> Stream -> IO (Either Text ())

-- | notFoundHandlerFunc tells the peer the Endpoint does not exist, by
-- failing the call with an "endpoint not found" error.
notFoundHandlerFunc :: Endpoint -> HandlerFunc
notFoundHandlerFunc endpoint _args _stream =
  return . Left $ "endpoint not found '" <> formatEndpoint endpoint <> "'"

-- | Handler can serve Remote Procedure Requests.
class Handler h  where
    -- | endpoints lists the Endpoints this handler serves.
    endpoints
        :: h
        -> [Endpoint]

    -- | serve returns the HandlerFunc that answers an incoming Remote
    -- Procedure Call to the given Endpoint.
    serve
        :: h
        -> Endpoint
        -> HandlerFunc

    -- TODO: Maybe add 'conn' to notifyConnect

    -- | notifyConnect tells the handler when a peer has connected.
    notifyConnect
        :: h
        -> Ssb.PublicKey
        -> IO (Either Text ())

    -- | notifyDisconnect tells the handler when a peer has disconnected.
    notifyDisconnect
        :: h
        -> Ssb.PublicKey
        -> IO (Either Text ())

-- | Router serves Remote Procedure Calls with a set of handlers.
data Router = Router
    { endpointHandlers    :: Map Endpoint HandlerFunc
      -- ^ handler function for each served endpoint
    , connectCallbacks    :: [Ssb.PublicKey -> IO (Either Text ())]
      -- ^ callbacks run when a peer connects
    , disconnectCallbacks :: [Ssb.PublicKey -> IO (Either Text ())]
      -- ^ callbacks run when a peer disconnects
    } deriving (Generic)

-- | The default Router is built generically from its fields' defaults
-- (presumably an empty map and empty callback lists; confirm against
-- data-default).
instance Default Router

-- | A Router dispatches calls to the handler registered for the endpoint, or
-- to 'notFoundHandlerFunc', and fans peer notifications out to every
-- registered callback.
instance Handler Router where
  endpoints = Map.keys . endpointHandlers

  serve router endpoint =
    Map.findWithDefault (notFoundHandlerFunc endpoint)
                        endpoint
                        (endpointHandlers router)

  notifyConnect router = runPeerCallbacks (connectCallbacks router)

  -- Disconnect notifications go to the disconnect callbacks (the previous
  -- code ran the connect callbacks here by mistake).
  notifyDisconnect router = runPeerCallbacks (disconnectCallbacks router)

-- | runPeerCallbacks runs every callback for the peer, in order.  All
-- callbacks run even if some fail; their error messages are joined with ", ".
runPeerCallbacks
  :: [Ssb.PublicKey -> IO (Either Text ())]
  -> Ssb.PublicKey
  -> IO (Either Text ())
runPeerCallbacks callbacks peerKey = do
  errs <- lefts <$> mapM ($ peerKey) callbacks
  return $ if Protolude.null errs
    then Right ()
    else Left (Text.intercalate ", " errs)


-- | with adds a handler to the Router.  Endpoints already served by the
-- router keep their existing handler; the new handler's connect/disconnect
-- callbacks run after the existing ones.
with :: Handler h => Router -> h -> Router
with router handler = router
  { endpointHandlers    = endpointHandlers router `Map.union` added
  , connectCallbacks    = connectCallbacks router <> [notifyConnect handler]
  , disconnectCallbacks = disconnectCallbacks router
                            <> [notifyDisconnect handler]
  }
  where added = Map.fromList [ (e, serve handler e) | e <- endpoints handler ]

-- | withM is 'with' lifted over a monad: it runs the router action, then the
-- handler action, and combines their results.
withM :: (MonadIO m, Handler h) => m Router -> m h -> m Router
withM = liftA2 with

-- | logMsg prints a line while holding the connection's log lock, so lines
-- from concurrent streams do not interleave.
--
-- The lock is released even if printing throws.  Previously an exception
-- left it taken, and every later log call blocked forever.
logMsg :: ConnState -> Text -> IO ()
logMsg conn msg =
  bracket_ (atomically $ takeTMVar (lock conn))
           (atomically $ putTMVar (lock conn) True)
           (putStrLn msg)

-- | logDebug is a no-op debug logger.  Swap in the commented definition
-- below to route debug output through 'logMsg'.
logDebug :: ConnState -> Text -> IO ()
logDebug _ = const (return ())
-- logDebug = logMsg

-- | spawnConnection runs a stream's action on its own thread and closes the
-- stream exactly once when the action finishes:
--
-- * with an error notification if the action returned 'Left';
-- * with a normal close if it returned 'Right' or threw an exception.
--
-- Previously the stream was closed in the thread body and again in the
-- finally handler, which sent a second close message for every stream.
spawnConnection
  :: Stream
  -> IO (Either Text ())
  -> IO ()
spawnConnection stream action = void $ forkFinally action finish
 where
  finish (Right (Left err)) = void (closeStreamWithError stream err)
  finish _                  = void (closeStream stream)

-- | connect creates a Remote Procedure Call stream connection over the given
-- Box Stream.
--
-- @client@ is forked to make RPC calls on the peer.  The current thread
-- serves incoming messages with @handler@ until one of these happens:
--
-- * the goodbye header arrives;
-- * the handler stops the loop;
-- * reading a message fails.
--
-- Afterwards the handler is notified of the disconnect and 'disconnect' is
-- called.  Returns 'Left' when reading from the connection failed.
-- Previously that case returned an 'error' thunk that crashed the caller.
connect
  :: Handler h
  => BoxStream.Conn
  -> h
  -> Ssb.PublicKey
  -> (ConnState -> IO ())
  -> IO (Either Text ())
connect boxConn handler peer client = do
  streamsIn          <- newTMVarIO Map.empty
  streamsOut         <- newTMVarIO Map.empty
  nextIncomingReqNum <- newTMVarIO 1
  nextOutgoingReqNum <- newTMVarIO 1
  lock               <- newTMVarIO True
  let conn = ConnState { connPeer           = peer
                       , boxConn            = boxConn
                       , streamsIn          = streamsIn
                       , streamsOut         = streamsOut
                       , nextIncomingReqNum = nextIncomingReqNum
                       , nextOutgoingReqNum = nextOutgoingReqNum
                       , lock               = lock
                       }

  -- make RPC calls on the peer
  _   <- forkIO $ client conn
  ret <- serviceLoop conn
  _   <- notifyDisconnect handler (connPeer conn)
  disconnect conn
  return ret
 where
  serviceLoop conn = do
    next <- readMessage conn
    case next of
      Left err -> return . Left $ "connection error: " <> err
      Right msg
        | header msg == goodByeHeader -> return (Right ())
        | otherwise -> do
          ok <- handleMessage handler conn msg
          if ok then serviceLoop conn else return (Right ())

-- TODO: Handle stream closing within demux loop
--    The handleMessage loop is responsible for the lifetime of a stream
--    connection.  It detects and creates requests, so it should therefore
--    also handle close requests.
--
--    The stream type matters when closing a stream. Async streams do not
--    require a specific close message.
-- TODO: Handle closing of Async streams
-- TODO: Properly handle Error message

-- TODO: Refactor or relocate helper functions
--    These functions were added in a rush to get request handling working.  A
--    better place could probably be found for them to improve code clarity.

-- | isEndOfStream checks whether the message has the end/error flag set,
-- i.e. whether it is the last message of its stream.
isEndOfStream :: Message -> Bool
isEndOfStream msg = isEndOrError (flags (header msg))

-- | isAsyncResponse checks if the message is a response to an Async request,
-- i.e. whether it was sent without the stream flag.
isAsyncResponse :: Message -> Bool
isAsyncResponse msg = not (isStream (flags (header msg)))

-- | streamTable returns the lookup table that holds the given stream.
-- Incoming and outgoing RPC streams are kept in separate tables.
streamTable :: Stream -> TMVar (Map Int32 (Stream, TChan (Maybe Message)))
streamTable stream
  | direction stream == Incoming = streamsIn owner
  | otherwise                    = streamsOut owner
  where owner = conn stream

-- | streamStatus reads the live status of a stream from its lookup table.  A
-- stream that is no longer in the table counts as 'Closed'.
streamStatus :: Stream -> STM StreamStatus
streamStatus stream =
  maybe Closed (status . fst) . Map.lookup (streamID stream)
    <$> readTMVar (streamTable stream)

-- | manageStreamConn writes a message to a stream's connection and updates
-- the stream's entry in its lookup table to match what was sent:
--
-- * on a failed write, an incoming stream is marked 'Closed' and an outgoing
--   one is removed, and its reader is woken with end-of-stream markers;
-- * a message without the stream flag moves the stream to 'AwaitingResponse';
-- * a message with the end/error flag moves it to 'AwaitingCloseRecv'.
--
-- The stream table is held during the write.  It is always put back, even if
-- the write throws, so other threads cannot deadlock on it.  A stream that is
-- no longer in the table is tolerated instead of crashing.
-- TODO: beware of entering a deadlock in manageStreamConn
-- TODO: request stuff should be moved here
-- TODO: confirm status of stream before closing it; this happy path is not
-- production ready.
manageStreamConn :: Stream -> Message -> IO (Either Text ())
manageStreamConn stream msg = mask $ \restore -> do
  table <- atomically $ takeTMVar tableVar
  res   <- restore (writeMessage (conn stream) msg)
             `onException` atomically (putTMVar tableVar table)
  let msgFlags = flags (header msg)
      table'
        | isLeft res              = dropStream table
        | not (isStream msgFlags) =
          Map.adjust (setStatus AwaitingResponse) sid table
        | isEndOrError msgFlags   =
          Map.adjust (setStatus AwaitingCloseRecv) sid table
        | otherwise               = table
  when (isLeft res) $
    -- TODO: Find better way of closing channels on write error
    forM_ (Map.lookup sid table) $ \(_, chan) ->
      atomically $ replicateM_ 30 (writeTChan chan Nothing)
  atomically $ putTMVar tableVar table'
  return res
 where
  tableVar = streamTable stream
  sid      = streamID stream
  setStatus newStatus (s, c) = (s { status = newStatus }, c)
  dropStream entries = case direction stream of
    Incoming -> Map.adjust (setStatus Closed) sid entries
    Outgoing -> Map.delete sid entries


-- TODO: Log important information for incoming messages

-- | handleMessage processes one message read from the connection.  It
-- returns 'False' when the service loop should stop:
--
-- * a goodbye header stops the loop;
-- * a message with the next expected incoming request number registers a new
--   incoming stream and serves it on its own thread;
-- * an end-of-stream message advances its stream towards 'Closed';
-- * a non-stream (Async) response is delivered and then closes its stream;
-- * anything else is delivered to its stream's reader.
--
-- Positive request numbers address streams opened by the peer; negative ones
-- address streams we opened.  Messages for unknown streams are dropped and
-- logged instead of crashing the loop.
handleMessage :: Handler h => h -> ConnState -> Message -> IO Bool
handleMessage handler conn msg = do
  nextReqNum <- atomically $ readTMVar (nextIncomingReqNum conn)
  if
    | header msg == goodByeHeader -> return False
    | isRequest nextReqNum msg    -> serveRequest
    | isEndOfStream msg           -> internalCloseStream
    | isAsyncResponse msg         -> demux >> internalCloseStream
    | otherwise                   -> demux
 where
  reqNum    = requestNumber (header msg)
  streamID' = abs reqNum
  tableVar  = if reqNum > 0 then streamsIn conn else streamsOut conn

  updateStreamStatus st (s, c) = (s { status = st }, c)

  -- Handle an end-of-stream message.  The table is held while the stream's
  -- new status is computed and is put back even if that throws.
  internalCloseStream = do
    table <- atomically $ takeTMVar tableVar
    case Map.lookup streamID' table of
      Nothing -> do
        atomically $ putTMVar tableVar table
        logDebug conn
          $  "received end of stream for unknown stream: "
          <> (show streamID' :: Text)
        return True
      Just (stream, chan) -> do
        newStatus <- nextStatus stream chan
          `onException` atomically (putTMVar tableVar table)
        atomically $ putTMVar tableVar $ case direction stream of
          Incoming -> Map.adjust (updateStreamStatus newStatus) streamID' table
          -- TODO: properly handle close of outgoing Duplex streams
          Outgoing -> Map.delete streamID' table
        return True

  nextStatus stream chan = case status stream of
    Open -> do
      -- newCloseNotification negates its argument.  Messages on outgoing
      -- streams carry the positive stream ID (as in writeStream), so
      -- pre-negate for those.
      let closeArg = case direction stream of
            Incoming -> streamID'
            Outgoing -> negate streamID'
      case newCloseNotification closeArg of
        Left  err      -> logMsg conn $ "could not build close notification: " <> err
        Right closeMsg -> void (writeMessage conn closeMsg)
      return AwaitingCloseRecv
    AwaitingCloseRecv -> do
      atomically $ writeTChan chan Nothing
      return Closed
    AwaitingResponse -> do
      atomically $ writeTChan chan Nothing
      return Closed
    Closed -> do
      logMsg conn
        $  "received end of stream for already closed stream: "
        <> (show streamID' :: Text)
      return Closed

  serveRequest = case decodeJSONMessage msg of
    Left err -> do
      logMsg conn $ "could not decode request: " <> err
      return False
    Right req -> do
      let stream = Stream { streamID   = reqNum
                          , streamType = typ req
                          , conn       = conn
                          , direction  = Incoming
                          , status     = Open
                          , peer       = connPeer conn
                          }
      chan       <- newTChanIO
      registered <- atomically $ do
        expected <- takeTMVar (nextIncomingReqNum conn)
        table    <- takeTMVar (streamsIn conn)
        let full = Map.size table == (maxBound :: Int)
            (expected', table')
              | full               = (expected, table)
              | reqNum == expected =
                (expected + 1, Map.insert expected (stream, chan) table)
              | otherwise          = (expected, table)
        -- Both TMVars are always put back, including when the table is full.
        putTMVar (nextIncomingReqNum conn) expected'
        putTMVar (streamsIn conn)          table'
        return $ if full then Left errTooManyRequests else Right ()
      case registered of
        Left err -> do
          putStrLn err
          return False
        Right () -> do
          -- Serving a request call, how do we close it nicely?
          let endpoint = Endpoint (name req) (typ req)
          spawnConnection stream $ serve handler endpoint (args req) stream
          return True

  errTooManyRequests = "connection limit reached, dropping request" :: Text

  demux = do
    table <- atomically $ readTMVar tableVar
    case Map.lookup streamID' table of
      Nothing -> logDebug conn
        $  "message dropped due to missing stream: "
        <> (show msg :: Text)
      Just (_, chan) -> atomically $ writeTChan chan (Just msg)
    return True

-- | disconnect rudely tears down every stream of the connection and says
-- goodbye to the peer.  It does not close streams elegantly; look into that.
--
-- For each table (incoming first, then outgoing):
--
-- * every stream channel receives 30 end-of-stream markers, so readers
--   blocked on it wake up;
-- * every stream is marked 'Closed'.
--
-- Finally the goodbye header is sent.  A failed send is ignored.
-- TODO: Ensure disconnect prevents new incoming and outgoing connections
disconnect :: ConnState -> IO ()
disconnect conn = do
  mapM_ closeAll [streamsIn conn, streamsOut conn]
  void . writeMessage conn $ Message goodByeHeader (BinaryPayload "")
 where
  closeAll tableVar = do
    table <- atomically $ takeTMVar tableVar
    forM_ table $ \(_, chan) ->
      atomically $ replicateM_ 30 (writeTChan chan Nothing)
    atomically . putTMVar tableVar $ Map.map markClosed table
  markClosed (s, c) = (s { status = Closed }, c)

-- | readBoxStream reads the given number of bytes from the Box Stream.  A
-- closed Box Stream is reported as an "unexpected end of box stream" error.
readBoxStream :: BoxStream.Conn -> Int -> IO (Either Text ByteString)
readBoxStream conn bytes =
  (>>= maybeToRight errUnexpectedClose) <$> BoxStream.readStream conn bytes
  where errUnexpectedClose = "rpc.readConn: unexpected end of box stream"


-- | readMessage reads a single message from the RPC connection.  It reads the
-- fixed-size header first, then as many body bytes as the header announces,
-- and decodes the two together.
readMessage :: ConnState -> IO (Either Text Message)
readMessage conn = do
  headerBuf <- readBoxStream box headerLength
  case headerBuf >>= decodeBytes of
    Left err -> return . Left $ "RPC.readMessage header: " <> err
    Right hdr -> do
      bodyBuf <- readBoxStream box (fromIntegral (bodyLength hdr))
      let result = liftA2 (<>) headerBuf bodyBuf >>= decodeBytes
      logDebug conn $ mconcat
        [ "readMessage ("
        , Ssb.formatPublicKey (connPeer conn)
        , ")"
        , show result
        ]
      return result
 where
  box = boxConn conn
  decodeBytes :: Serialize a => ByteString -> Either Text a
  decodeBytes = mapLeft toS . Serialize.decode

-- | writeMessage serializes a message and sends it over the connection's Box
-- Stream, returning the send result.
writeMessage :: ConnState -> Message -> IO (Either Text ())
writeMessage conn msg = do
  logDebug conn $ mconcat
    [ "writeMessage ("
    , Ssb.formatPublicKey (connPeer conn)
    , ")"
    , show msg
    ]
  BoxStream.sendStream (boxConn conn) (Serialize.encode msg)

-- | request makes a Remote Procedure Call on the peer.
--
-- Steps:
--
-- 1. The call is registered as an outgoing stream under the next outgoing
--    request number, and the request is sent.
-- 2. @session@ is run against the new stream.
-- 3. Non-Async streams are then closed, with an error notification if
--    @session@ returned 'Left'.
--
-- The session's result is returned as-is.  If the request cannot be encoded,
-- the request number is handed back and the error is returned.
request
  :: ToJSON a
  => ConnState
  -> Request a
  -> (Stream -> IO (Either Text b))
  -> IO (Either Text b)
request conn req session = do
  reqNum <- atomically $ takeTMVar reqNumVar
  case newJSONMessage reqNum req of
    Left err -> do
      atomically $ putTMVar reqNumVar reqNum
      return (Left err)
    Right msg -> do
      atomically $ putTMVar reqNumVar (reqNum + 1)
      let stream = Stream { streamID   = reqNum
                          , streamType = typ req
                          , conn       = conn
                          , direction  = Outgoing
                          , status     = Open
                          , peer       = connPeer conn
                          }
      atomically $ do
        chan    <- newTChan
        streams <- takeTMVar (streamsOut conn)
        putTMVar (streamsOut conn) (Map.insert reqNum (stream, chan) streams)
      _      <- manageStreamConn stream (markStreaming msg)
      result <- session stream
      unless (streamType stream == Async) $ void $ either
        (closeStreamWithError stream)
        (const $ closeStream stream)
        result
      return result
 where
  reqNumVar = nextOutgoingReqNum conn
  -- TODO: Fix late night update
  -- Every call except Async goes out with the stream flag set.
  markStreaming msg =
    let hdr = header msg
    in  msg { header = hdr { flags = (flags hdr) { isStream = typ req /= Async } } }

-- | requestAsync is the Async version of request: it reads a single JSON
-- response.  If the stream ends without a response, it fails with "no
-- response received" (this message previously read "not response received").
requestAsync
  :: (ToJSON a, FromJSON b) => ConnState -> Request a -> IO (Either Text b)
requestAsync conn req =
  (>>= withErr "no response received") <$> request conn req readStreamJSON

-- TODO: Use CloseNotification to signal close of stream.
--   How is the JSON encoding done?
-- NOTE(review): not referenced elsewhere in this file; close messages are
-- currently built by newCloseNotification.
data CloseNotification = CloseNotification ()

-- | newCloseNotification builds the end-of-stream message: stream and
-- end/error flags set, and a JSON @true@ body.  The request number is the
-- negated argument.
newCloseNotification :: Int32 -> Either Text Message
newCloseNotification streamID =
  (\hdr -> Message hdr (JSONPayload endBody))
    <$> newHeader endFlags (negate streamID) endBody
 where
  endBody :: ByteString
  endBody  = "true"
  endFlags = def { isEndOrError = True, isStream = True, bodyType = JSON }

-- | closeStream politely shuts down an RPC stream by sending an end-of-stream
-- message.
--
-- * Async streams need no close and succeed immediately.
-- * Streams already 'Closed' succeed immediately.
-- * Any other non-'Open' status is an error.
--
-- The message's request number follows 'writeStream': it is negated for
-- incoming streams and kept positive for outgoing ones.  Previously outgoing
-- streams were closed with a negated number.
-- TODO: closeStream needs to be more adaptive.  It only accounts for
-- real streams.
closeStream :: Stream -> IO (Either Text ())
closeStream stream = case (status stream, streamType stream) of
  (_     , Async) -> return (Right ())
  (Open  , _    ) -> either (return . Left) (manageStreamConn stream)
    $ newCloseNotification closeArg
  (Closed, _    ) -> return (Right ())
  (_     , _    ) -> return $ Left "could not close stream"
 where
  -- newCloseNotification negates its argument, so pre-negate for outgoing
  -- streams to end up with their positive stream ID.
  closeArg = case direction stream of
    Incoming -> streamID stream
    Outgoing -> negate (streamID stream)

-- | ErrorNotification is the format used to communicate a stream error
-- between peers.
data ErrorNotification = ErrorNotification
    { message :: Text       -- ^ human-readable error message
    , stack   :: Maybe Text -- ^ optional stack trace
    }

-- | Parse an error object.  @stack@ is optional: a missing key or a @null@
-- value both give 'Nothing'.  With '.:', a missing @stack@ key failed the
-- whole parse.
instance FromJSON ErrorNotification where
  parseJSON = withObject "Error"
    $ \v -> ErrorNotification <$> v .: "message" <*> v .:? "stack"

-- | Encode an error as an object with @name@ "error", @message@ and @stack@.
instance ToJSON ErrorNotification where
  toJSON (ErrorNotification msg trace) = object
    [ "name"    .= ("error" :: Text)
    , "message" .= msg
    , "stack"   .= trace
    ]

-- | closeStreamWithError' ends a stream because of an error and removes it
-- from its lookup table.
--
-- For Duplex and incoming streams, the remote is sent an end/error message
-- with an 'ErrorNotification' body; a failed send is only logged.  For other
-- streams the error is only logged.  Always returns @Right ()@.
-- TODO: deduplicate closeStreamWithError and closeStream code
-- TODO: Verify close operation on error
--    Does the remote end send an 'true' message to confirm?  Or is the
--    connection simply dropped.
closeStreamWithError' :: Stream -> Text -> IO (Either Text ())
closeStreamWithError' stream err = do
  if streamType stream /= Duplex && direction stream /= Incoming
    then logMsg owner $ "closing stream with error: " <> err
    else do
      sent <- writeStream stream endFlags (JSONPayload notification)
      case sent of
        Left sendErr ->
          logMsg owner $ "could not notify remote of error: " <> sendErr
        Right () -> return ()
  atomically $ do
    entries <- takeTMVar (streamTable stream)
    putTMVar (streamTable stream) (Map.delete (streamID stream) entries)
  return (Right ())
 where
  owner        = conn stream
  endFlags     = def { isStream = True, isEndOrError = True }
  notification = BS.toStrict . Aeson.encode $ ErrorNotification err Nothing

-- | newCloseErrorNotification builds an end/error stream message whose body
-- is the error text encoded as a JSON string.  The request number is the
-- negated argument.
--
-- The header's body length is computed from the encoded error body.
-- Previously it was computed from "true", so the advertised length did not
-- match the bytes sent and broke message framing on the peer.
-- TODO: Merge newCloseErrorNotification with newCloseNotification It has
-- proven difficult to use the ToJSON typeclass as an argument.
newCloseErrorNotification :: Int32 -> Text -> Either Text Message
newCloseErrorNotification streamID msg = do
  let buf          = BS.toStrict $ Aeson.encode msg
  let requestNumer = -1 * streamID
  header <- newHeader
    (def { isEndOrError = True, isStream = True, bodyType = JSON })
    requestNumer
    buf
  return $ Message header (JSONPayload buf)

-- | closeStreamWithError ends a stream with an error message carrying @err@.
-- Async streams need no close and succeed immediately.
--
-- The message's request number follows 'writeStream': it is negated for
-- incoming streams and kept positive for outgoing ones.  Previously outgoing
-- streams got a negated number.
closeStreamWithError :: Stream -> Text -> IO (Either Text ())
closeStreamWithError stream err = case streamType stream of
  Async -> return (Right ())
  _     ->
    either (return . Left) (manageStreamConn stream)
      $ newCloseErrorNotification closeArg err
 where
  -- newCloseErrorNotification negates its argument, so pre-negate for
  -- outgoing streams to end up with their positive stream ID.
  closeArg = case direction stream of
    Incoming -> streamID stream
    Outgoing -> negate (streamID stream)

-- TODO: Properly translate RPC errors
-- TODO: Close stream within lock to avoid race conditions

-- | readStream reads the next payload of a stream.  It returns 'Nothing' at
-- end of stream, and also when the stream is unknown or already 'Closed' with
-- nothing buffered.
--
-- On 'Nothing' the stream is removed from its lookup table.  The removal is
-- applied to the table as it is now.  Previously it used the snapshot taken
-- before the (possibly long, blocking) read, which discarded streams added or
-- updated by other threads in the meantime.
readStream :: Stream -> IO (Maybe MessagePayload)
readStream stream = do
  entries <- atomically $ readTMVar tableVar
  msg     <- case Map.lookup sid entries of
    Nothing -> return Nothing
    Just (stream', chan)
      | status stream' == Closed -> atomically $ join <$> tryReadTChan chan
      | otherwise                -> atomically $ readTChan chan
  when (isNothing msg) $ atomically $ do
    current <- takeTMVar tableVar
    putTMVar tableVar (Map.delete sid current)
  return (body <$> msg)
 where
  tableVar = streamTable stream
  sid      = streamID stream

-- | readStreamJSON reads a single JSON message from the RPC stream.
--
-- Returns:
--
-- * @Right Nothing@ at end of stream;
-- * 'Left' when the payload cannot be decoded or is not JSON.
--
-- A non-JSON payload previously produced an 'error' thunk that crashed the
-- caller.
readStreamJSON :: FromJSON a => Stream -> IO (Either Text (Maybe a))
readStreamJSON stream = do
  resp <- readStream stream
  return $ case resp of
    Just (JSONPayload buf) -> Just <$> decodeJSON buf
    Just _                 ->
      Left ("expected JSONPayload but got " <> (show resp :: Text))
    Nothing                -> Right Nothing

-- | writeStream sends one payload on a stream.  It returns
-- @Left "stream closed"@ unless the stream's table entry is 'Open'.
--
-- Header fields:
--
-- * request number: negated stream ID on incoming streams, the stream ID
--   itself on outgoing streams;
-- * stream flag: set for every stream type except Async;
-- * body type: taken from the payload constructor;
-- * remaining flags: taken from @flags@.
writeStream :: Stream -> Flags -> MessagePayload -> IO (Either Text ())
writeStream stream flags payload = do
  current <- atomically $ streamStatus stream
  if current == Open
    then either (return . Left) send mkHeader
    else return $ Left "stream closed"
 where
  send hdr = manageStreamConn stream (Message hdr payload)
  wireID   = case direction stream of
    Incoming -> negate (streamID stream)
    Outgoing -> streamID stream
  (kind, bytes) = case payload of
    BinaryPayload p -> (Binary    , p)
    TextPayload   p -> (UTF8String, p)
    JSONPayload   p -> (JSON      , p)
  mkHeader = newHeader
    (flags { isStream = streamType stream /= Async, bodyType = kind })
    wireID
    bytes

-- | writeStreamJSON encodes a value as JSON and writes it to the RPC stream
-- with default flags (see 'writeStream').
writeStreamJSON :: ToJSON a => Stream -> a -> IO (Either Text ())
writeStreamJSON stream =
  writeStream stream def . JSONPayload . BS.toStrict . Aeson.encode