aboutsummaryrefslogtreecommitdiff
path: root/src/Ssb/Peer/RPC.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Ssb/Peer/RPC.hs')
-rw-r--r--src/Ssb/Peer/RPC.hs859
1 files changed, 859 insertions, 0 deletions
diff --git a/src/Ssb/Peer/RPC.hs b/src/Ssb/Peer/RPC.hs
new file mode 100644
index 0000000..e5a4c9a
--- /dev/null
+++ b/src/Ssb/Peer/RPC.hs
@@ -0,0 +1,859 @@
+-- For more information kindly refer the to protocol guide
+-- https://ssbc.github.io/scuttlebutt-protocol-guide
+
+module Ssb.Peer.RPC where
+
+import Protolude
+
+import Control.Concurrent.STM
+import Control.Monad.Fail
+import Data.Aeson ( FromJSON
+ , ToJSON
+ )
+import Data.Aeson as Aeson
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Lazy as BS
+ ( toStrict )
+import qualified Data.Map.Strict as Map
+import Data.Default
+import Data.Either.Combinators ( mapLeft
+ , mapRight
+ )
+import Data.Text as Text
+import Data.Serialize as Serialize
+
+import Ssb.Aux
+import qualified Ssb.Identity as Ssb
+import qualified Ssb.Peer.BoxStream as BoxStream
+
+data BodyType = Binary | UTF8String | JSON | UnknownBodyType
+ deriving (Bounded,Eq,Show)
+
+instance Convertible Word8 BodyType where
+ convert v = case v .&. 3 of
+ 0 -> Binary
+ 1 -> UTF8String
+ 2 -> JSON
+ v -> UnknownBodyType
+
+instance Convertible BodyType Word8 where
+ convert Binary = 0
+ convert UTF8String = 1
+ convert JSON = 2
+ convert UnknownBodyType = 3
+
+data Flags = Flags
+ { unused1 :: Bool
+ , unused2 :: Bool
+ , unused3 :: Bool
+ , unused4 :: Bool
+ , isStream :: Bool
+ , isEndOrError :: Bool
+ , bodyType :: BodyType
+ } deriving (Eq,Show)
+
+instance Convertible Flags Word8 where
+ convert arg =
+ set (unused1 arg) 7
+ $ set (unused2 arg) 6
+ $ set (unused3 arg) 5
+ $ set (unused4 arg) 4
+ $ set (isStream arg) 3
+ $ set (isEndOrError arg) 2
+ $ convert (bodyType arg)
+ where
+ set True pos arg = setBit arg pos
+ set False _ arg = arg
+
+instance Convertible Word8 Flags where
+ convert w = Flags { unused1 = testBit w 7
+ , unused2 = testBit w 6
+ , unused3 = testBit w 5
+ , unused4 = testBit w 4
+ , isStream = testBit w 3
+ , isEndOrError = testBit w 2
+ , bodyType = convert w
+ }
+
+instance Default Flags where
+ def = convert (zeroBits :: Word8)
+
+instance Serialize.Serialize Flags where
+ get = convert <$> getWord8
+ put = putWord8 . convert
+
+-- | ProcedureType defines the type of remote call.
+data ProcedureType = Async | Source | Duplex
+ deriving (Eq,Generic,Ord,Show)
+
+instance FromJSON ProcedureType where
+ parseJSON = withText "ProcedureType" $ \v -> case v of
+ "async" -> return Async
+ "source" -> return Source
+ "duplex" -> return Duplex
+ otherwise -> fail $ "unknown value '" <> toS v <> "'"
+
+instance ToJSON ProcedureType where
+ toJSON Async = "async"
+ toJSON Source = "source"
+ toJSON Duplex = "duplex"
+
+-- | HeaderLength is the length of the RPC header in bytes
+headerLength :: Int
+headerLength = 9
+
+-- | bodySizeLength is the length of the bodySize parameter in the RPC header.
+bodySizeLength :: Int
+bodySizeLength = 4
+
+-- | requestNumberLength is the length of the requestNumberLength parameter in
+-- the RPC header.
+requestNumberLength :: Int
+requestNumberLength = 4
+
+-- | Header is the first part of a RPC message used for stream control and communication.
+data Header = Header
+ { flags :: Flags
+ , bodyLength :: Word32
+ , requestNumber :: Int32
+ } deriving (Eq,Generic,Show)
+
+instance Serialize.Serialize Header
+
+-- | GoodByeHeader is the RPC message header signalling the end of the RPC
+-- stream.
+goodByeHeader :: Header
+goodByeHeader = Header (convert (zeroBits :: Word8)) 0 0
+
+newHeader :: Flags -> Int32 -> ByteString -> Either Text Header
+newHeader flags reqNum body = do
+ let bodyLength = fromIntegral $ BS.length body
+ return Header { flags = flags
+ , bodyLength = bodyLength
+ , requestNumber = reqNum
+ }
+
+-- | MessagePayload describes and contains the contents of an RPC message.
+data MessagePayload = BinaryPayload ByteString | TextPayload ByteString | JSONPayload ByteString
+ deriving (Generic,Show)
+
+-- | Message is a single message in the RPC stream
+data Message = Message
+ { header :: Header
+ , body :: MessagePayload
+ } deriving (Generic,Show)
+
+instance Serialize.Serialize Message where
+ get = do
+ header <- Serialize.get
+ buf <- Serialize.getByteString (fromIntegral $ bodyLength header)
+ let typ = bodyType . flags $ header
+ payload <- case typ of
+ Binary -> return $ BinaryPayload buf
+ UTF8String -> return $ TextPayload buf
+ JSON -> return $ JSONPayload buf
+ UnknownBodyType -> fail "unknown body type"
+ return $ Message header payload
+
+ put msg = do
+ Serialize.put (header msg)
+ case body msg of
+ BinaryPayload buf -> Serialize.putByteString buf
+ JSONPayload buf -> Serialize.putByteString buf
+ TextPayload buf -> Serialize.putByteString buf
+
+-- | newJSONMessage is a convenience function to create a RPC message with a
+-- JSON payload.
+newJSONMessage :: ToJSON a => Int32 -> a -> Either Text Message
+newJSONMessage reqNum payload = do
+ let buf = BS.toStrict $ Aeson.encode payload
+ len <- maybeWord8 $ BS.length buf
+ header' <- newHeader (def { bodyType = JSON }) reqNum buf
+ return $ Message { header = header', body = JSONPayload buf }
+
+-- | newJSONMessage is a convenience function to decode a RPC JSON message.
+decodeJSONMessage :: FromJSON a => Message -> Either Text a
+decodeJSONMessage msg = case body msg of
+ JSONPayload buf -> decodeJSON buf
+ _ -> error "unexpected body type"
+
+isRequest :: Int32 -> Message -> Bool
+isRequest nextReqNum msg = do
+ let reqNum = requestNumber . header $ msg
+ nextReqNum == reqNum
+
+-- | Request is the message representing a RPC call.
+data Request a = Request
+ { name :: [Text]
+ , typ :: ProcedureType
+ , args :: a
+ } deriving (Show)
+
+instance (FromJSON a) => FromJSON (Request a) where
+ parseJSON = withObject "Request"
+ $ \v -> Request <$> v .: "name" <*> v .:? "type" .!= Async <*> v .: "args"
+
+instance (ToJSON a) => ToJSON (Request a) where
+ toJSON arg =
+ object ["name" .= name arg, "type" .= typ arg, "args" .= args arg]
+
+data Direction = Incoming | Outgoing
+ deriving (Eq,Generic,Show)
+
+data StreamStatus =
+ -- | Open the stream is active.
+ Open
+ -- | The (Async) stream is awaiting a response.
+ | AwaitingResponse
+ -- | The stream is closed and waiting for its peer to do the same.
+ | AwaitingCloseRecv
+ -- | The stream is closed.
+ | Closed
+ deriving (Eq,Show)
+
+data Stream = Stream {
+ streamID :: Int32
+ , streamType :: ProcedureType
+ , conn :: ConnState
+ , direction :: Direction
+ , status :: StreamStatus
+ , peer :: Ssb.PublicKey
+ }
+
+formatStream :: Stream -> Text
+formatStream stream =
+ (show $ streamID stream :: Text)
+ <> "\t"
+ <> show (streamType stream)
+ <> "\t"
+ <> show (direction stream)
+ <> "\t"
+ <> show (status stream)
+
+foldStream
+ :: (a -> MessagePayload -> IO (Either Text a))
+ -> a
+ -> Stream
+ -> IO (Either Text a)
+foldStream fn acc stream = do
+ msg <- readStream stream
+ case msg of
+ (Just msg') -> do
+ acc' <- fn acc msg'
+ either (return . error) (\v -> foldStream fn v stream) acc'
+ Nothing -> return . return $ acc
+
+foldJSONStream
+ :: (FromJSON b)
+ => (a -> b -> IO (Either Text a))
+ -> a
+ -> Stream
+ -> IO (Either Text a)
+foldJSONStream fn acc stream = do
+ msg <- readStreamJSON stream
+ case msg of
+ (Right (Just msg')) -> do
+ acc' <- fn acc msg'
+ either (return . error) (\v -> foldJSONStream fn v stream) acc'
+ (Right Nothing) -> return . return $ acc
+ (Left err ) -> return $ error err
+
+data ConnState = ConnState {
+ connPeer :: Ssb.PublicKey
+ , boxConn :: BoxStream.Conn
+ , streamsIn :: TMVar (Map Int32 (Stream, TChan (Maybe Message)))
+ , streamsOut :: TMVar (Map Int32 (Stream, TChan (Maybe Message)))
+ , nextIncomingReqNum :: TMVar Int32
+ , nextOutgoingReqNum :: TMVar Int32
+ , lock :: TMVar Bool
+ }
+
+-- | Endpoint identifies which Remote Procedure Call to make.
+data Endpoint = Endpoint [Text] ProcedureType
+ deriving (Eq, Ord)
+
+formatEndpoint :: Endpoint -> Text
+formatEndpoint (Endpoint paths typ) =
+ Text.intercalate "." paths <> ":" <> show typ
+
+-- | HandlerFunc's are used to serve Remove Procedure Calls.
+type HandlerFunc = Aeson.Value -> Stream -> IO (Either Text ())
+
+-- | notFoundHandlerFunc is tells the peer the Endpoint does not exist.
+notFoundHandlerFunc :: Endpoint -> HandlerFunc
+notFoundHandlerFunc endpoint _ _ = return
+ (Left $ "endpoint not found '" <> endpoint' <> "'")
+ where endpoint' = formatEndpoint endpoint
+
+-- | Handler can serve Remote Procedure Requests.
+class Handler h where
+ endpoints
+ :: h
+ -> [Endpoint]
+
+ -- | Serve calls an incoming Remote Procedure Call.
+ serve
+ :: h
+ -> Endpoint
+ -> HandlerFunc
+
+ -- TODO: Maybe add 'conn' to notifyConnect
+
+ -- | notifyConnect tells the handler when a peer has connected.
+ notifyConnect
+ :: h
+ -> Ssb.PublicKey
+ -> IO (Either Text ())
+
+ -- | notifyDisconnect tells the handler when a peer has disconnected.
+ notifyDisconnect
+ :: h
+ -> Ssb.PublicKey
+ -> IO (Either Text ())
+
+-- | Router serves Remote Procedure Calls with a set of handlers.
+data Router = Router {
+ endpointHandlers :: Map Endpoint HandlerFunc
+ , connectCallbacks :: [Ssb.PublicKey -> IO (Either Text ())]
+ , disconnectCallbacks :: [Ssb.PublicKey -> IO (Either Text ())]
+ } deriving (Generic)
+
+instance Default Router
+
+instance Handler Router where
+ endpoints demuxer = Map.keys (endpointHandlers demuxer)
+
+ serve demuxer endpoint = case Map.lookup endpoint endpointHandlers' of
+ Nothing -> (notFoundHandlerFunc endpoint)
+ Just handler -> handler
+ where endpointHandlers' = endpointHandlers demuxer
+
+ notifyConnect demuxer id = do
+ errs <- fmap lefts . sequence $ (\f -> f id) <$> connectCallbacks demuxer
+ if Protolude.null errs
+ then return . return $ ()
+ else return $ Left $ Text.intercalate ", " errs
+
+ notifyDisconnect demuxer id = do
+ errs <- fmap lefts . sequence $ (\f -> f id) <$> connectCallbacks demuxer
+ if Protolude.null errs
+ then return . return $ ()
+ else return $ Left $ Text.intercalate ", " errs
+
+
+-- | with adds a handler to the Router.
+with :: Handler h => Router -> h -> Router
+with demuxer handler = Router
+ { endpointHandlers = Map.union (endpointHandlers demuxer) $ Map.fromList
+ ((\e -> (e, serve handler e)) <$> endpoints handler)
+ , connectCallbacks = connectCallbacks demuxer <> [notifyConnect handler]
+ , disconnectCallbacks = disconnectCallbacks demuxer
+ <> [notifyDisconnect handler]
+ }
+
+-- | withM is a convenience function for using 'with' in monads.
+withM :: (MonadIO m, Handler h) => m Router -> m h -> m Router
+withM demuxer handler = with <$> demuxer <*> handler
+
+logMsg :: ConnState -> Text -> IO ()
+logMsg conn msg = do
+ _ <- atomically $ takeTMVar (lock conn)
+ print msg
+ atomically $ putTMVar (lock conn) True
+
+logDebug :: ConnState -> Text -> IO ()
+logDebug _ _ = return ()
+-- logDebug = logMsg
+
+-- | spawnConnection handles safely forking and closing RPC connections.
+spawnConnection
+ :: Stream
+ -> IO (Either Text ())
+ -> IO ()
+spawnConnection stream action = do
+ forkFinally
+ (do
+ res <- action
+ either (closeStreamWithError stream) (\_ -> closeStream stream) res
+ )
+ (\_ -> void (closeStream stream))
+ return ()
+
+-- | connect creates a Remote Procdure Call stream connection over the give Box
+-- Stream.
+connect
+ :: Handler h
+ => BoxStream.Conn
+ -> h
+ -> Ssb.PublicKey
+ -> (ConnState -> IO ())
+ -> IO (Either Text ())
+connect boxConn handler peer client = do
+ streamsIn <- newTMVarIO Map.empty
+ streamsOut <- newTMVarIO Map.empty
+ nextIncomingReqNum <- newTMVarIO 1
+ nextOutgoingReqNum <- newTMVarIO 1
+ lock <- newTMVarIO True
+ let conn = ConnState { connPeer = peer
+ , boxConn = boxConn
+ , streamsIn = streamsIn
+ , streamsOut = streamsOut
+ , nextIncomingReqNum = nextIncomingReqNum
+ , nextOutgoingReqNum = nextOutgoingReqNum
+ , lock = lock
+ }
+
+ forkIO $
+ -- make RPC calls on the peer
+ client conn
+ print "entering service loop"
+ ret <- serviceLoop conn
+ print "out of service loop"
+ _ <- notifyDisconnect handler (connPeer conn)
+ print "disconnecting"
+ disconnect conn
+ print "disconnected"
+ return ret
+ where
+ serviceLoop conn = do
+ msg <- readMessage conn
+ case msg of
+ Left err -> return $ error ("connection error: " <> err)
+ Right msg -> if header msg == goodByeHeader
+ then return . return $ ()
+ else do
+ ok <- handleMessage handler conn msg
+ if ok then serviceLoop conn else return . return $ ()
+
+-- TODO: Handle stream closing within demux loop
+-- The handleMessage loop is responsible for the lifetime of a stream
+-- connection. It detects and creates requests, it should therefor handle
+-- close requests.
+--
+-- The stream type matters when closing a stream. Async do not require a
+-- specific close message.
+-- TODO: Handle closing of Async streams
+-- TODO: Properly handle Error message
+
+-- TODO: Refactor or reallocate helper functions
+-- These functions were added in a rush to get request handling working. A
+-- better place could probably be found for them to improve code clarity.
+
+-- | isEndOfStream checks whether the message is the last of the stream.
+isEndOfStream = isEndOrError . flags . header
+
+-- | checks if the message is a response to an Async request.
+isAsyncResponse = not . isStream . flags . header
+
+-- | streamTable gets which lookup table to use for connections for the given
+-- stream. There are seperate ones for incoming and outgoing RPC streams.
+streamTable stream = case direction stream of
+ Incoming -> streamsIn (conn stream)
+ Outgoing -> streamsOut (conn stream)
+
+streamStatus :: Stream -> STM StreamStatus
+streamStatus stream = do
+ table <- readTMVar $ streamTable stream
+ return
+ $ fromMaybe Closed (status . fst <$> Map.lookup (streamID stream) table)
+
+-- | manageStreamConn manages stream connection changes when writing messages
+-- to a stream.
+-- TODO: beware of entering a deadlock in manageStreamConn
+-- TODO: request stuff should be moved here
+manageStreamConn :: Stream -> Message -> IO (Either Text ())
+manageStreamConn stream msg = do
+ table <- atomically $ takeTMVar $ streamTable stream
+ res <- writeMessage (conn stream) msg
+ let table' = if
+ | isLeft res -> closeStream table stream
+ | (not . isStream . flags . header $ msg) -> Map.adjust
+ (updateStreamStatus AwaitingResponse)
+ (streamID stream)
+ table
+ |
+-- TODO: confirm status of stream before closing it
+-- This happy path is not production ready
+ (isEndOrError . flags . header $ msg) -> Map.adjust
+ (updateStreamStatus AwaitingCloseRecv)
+ (streamID stream)
+ table
+ | otherwise -> table
+ when (isLeft res) $ do
+ let (_, c) = fromMaybe undefined $ Map.lookup (streamID stream) table
+ -- TODO: Find better way of closing channels on write error
+ atomically $ forM_ [1 .. 30] (\_ -> writeTChan c Nothing)
+ atomically $ putTMVar (streamTable stream) table'
+ return res
+ where
+ updateStreamStatus status (s, c) = (s { status = status }, c)
+ closeStream table tream = case direction stream of
+ Incoming -> Map.adjust (updateStreamStatus Closed) (streamID stream) table
+ Outgoing -> Map.delete (streamID stream) table
+
+
+-- TODO: Log important information for incoming messages
+handleMessage :: Handler h => h -> ConnState -> Message -> IO Bool
+handleMessage handler conn msg = do
+ nextReqNum <- atomically $ readTMVar (nextIncomingReqNum conn)
+ if
+ | (header msg) == goodByeHeader -> return False
+ | (isRequest nextReqNum msg) -> serveRequest handler conn msg
+ | (isEndOfStream msg) -> internalCloseStream conn msg
+ | (isAsyncResponse msg) -> do
+ demux conn msg
+ internalCloseStream conn msg
+ | otherwise -> demux conn msg
+ where
+ internalCloseStream conn msg = do
+ let flags = def { isStream = True, isEndOrError = True }
+ let requestNumber' = (requestNumber . header $ msg)
+ let mVarTable =
+ if requestNumber' > 0 then streamsIn conn else streamsOut conn
+ table <- atomically $ takeTMVar mVarTable
+ let streamID' = abs requestNumber'
+ let (stream, chan) = fromMaybe undefined $ Map.lookup streamID' table
+ newStatus <- case status stream of
+ Open -> do
+ let msg = fromRight undefined $ newCloseNotification streamID'
+ _ <- writeMessage conn msg
+ return AwaitingCloseRecv
+ AwaitingCloseRecv -> do
+ atomically $ writeTChan chan Nothing
+ return Closed
+ AwaitingResponse -> do
+ atomically $ writeTChan chan Nothing
+ return Closed
+ Closed -> do
+ logMsg conn
+ $ "received end of stream for already closed stream: "
+ <> (show streamID' :: Text)
+ return Closed
+ atomically $ putTMVar mVarTable $ case direction stream of
+ Incoming -> Map.adjust (updateStreamStatus newStatus) streamID' table
+ -- TODO: properly handle close of outgoing Duplex streams
+ Outgoing -> Map.delete streamID' table
+ return True
+ where updateStreamStatus status (s, c) = (s { status = status }, c)
+
+ serveRequest handler conn msg = do
+ let req = decodeJSONMessage msg
+ case req of
+ Left err -> do
+ logMsg conn $ "could not decode request: " <> err
+ return False
+ Right req -> do
+ let stream = Stream { streamID = requestNumber . header $ msg
+ , streamType = typ req
+ , conn = conn
+ , direction = Incoming
+ , status = Open
+ , peer = connPeer conn
+ }
+ chan <- newTChanIO
+ err <- atomically $ do
+ reqNum <- takeTMVar (nextIncomingReqNum conn)
+ table <- takeTMVar (streamsIn conn)
+ if Map.size table == (maxBound :: Int)
+ then return $ Left errTooManyRequests
+ else do
+ let (reqNum', table') = if streamID stream == reqNum
+ then (reqNum + 1, (Map.insert reqNum (stream, chan) table))
+ else (reqNum, table)
+ putTMVar (nextIncomingReqNum conn) reqNum'
+ putTMVar (streamsIn conn) table'
+ return . return $ ()
+ case err of
+ Left msg -> do
+ print msg
+ return False
+ Right _ -> do
+ -- Serving a request call, how do we close it nicely?
+ let endpoint = Endpoint (name req) (typ req)
+ spawnConnection stream $ (serve handler) endpoint (args req) stream
+ return True
+ where errTooManyRequests = "connection limit reached, dropping request"
+
+ demux conn msg = do
+ let reqNum = requestNumber . header $ msg
+ let table = if reqNum > 0 then streamsIn conn else streamsOut conn
+ let streamID' = abs reqNum
+ table' <- atomically $ readTMVar table
+ case Map.lookup streamID' table' of
+ Nothing -> do
+ logDebug conn
+ $ "message dropped due to missing stream: "
+ <> (show msg :: Text)
+ return ()
+ Just (_, chan) -> atomically $ writeTChan chan (Just msg)
+ return True
+
+-- | disconnect is a hacked version, it rudely disconnects all connections in
+-- order to unlock streams stuck on their reading their channels.
+--
+-- Look into closing streams elegantly.
+-- TODO: Ensure disconnect prevents new incoming and outgoing connections
+disconnect :: ConnState -> IO ()
+disconnect conn = do
+ table <- atomically $ takeTMVar (streamsIn conn)
+ forM_ table
+ $ \(_, c) -> atomically $ forM_ [1 .. 30] (\_ -> writeTChan c Nothing)
+ let table' = Map.map (\(s, c) -> (s { status = Closed }, c)) table
+ atomically $ putTMVar (streamsIn conn) table'
+
+ table <- atomically $ takeTMVar (streamsOut conn)
+ forM_ table
+ $ \(_, c) -> atomically $ forM_ [1 .. 30] (\_ -> writeTChan c Nothing)
+ let table' = Map.map (\(s, c) -> (s { status = Closed }, c)) table
+ atomically $ putTMVar (streamsOut conn) table'
+
+ _ <- writeMessage conn $ Message goodByeHeader (BinaryPayload "")
+ return ()
+
+-- | readBoxStream reads the given number of bytes from the Box Stream.
+readBoxStream :: BoxStream.Conn -> Int -> IO (Either Text ByteString)
+readBoxStream conn bytes = do
+ mbuf <- BoxStream.readStream conn bytes
+ let buf = join $ (maybeToRight errUnexpectedClose) <$> mbuf
+ return buf
+ where errUnexpectedClose = "rpc.readConn: unexpected end of box stream"
+
+
+-- | readMessage reads a single message from the RPC connection.
+readMessage :: ConnState -> IO (Either Text Message)
+readMessage conn = do
+ headerBuf <- readBoxStream (boxConn conn) headerLength
+ header <- (return $ headerBuf >>= decode)
+ case header of
+ Left err -> return $ Left $ "RPC.readMessage header: " <> err
+ Right header -> do
+ bodyBuf <- readBoxStream (boxConn conn) (fromIntegral $ bodyLength header)
+ let ret = liftA2 (<>) headerBuf bodyBuf >>= decode
+ logDebug conn
+ $ "readMessage ("
+ <> (Ssb.formatPublicKey $ connPeer conn)
+ <> ")"
+ <> (show ret)
+ return $ ret
+ where
+ decode :: Serialize a => ByteString -> Either Text a
+ decode = mapLeft toS . Serialize.decode
+
+writeMessage :: ConnState -> Message -> IO (Either Text ())
+writeMessage conn msg = do
+ logDebug conn
+ $ "writeMessage ("
+ <> (Ssb.formatPublicKey $ connPeer conn)
+ <> ")"
+ <> (show msg)
+ BoxStream.sendStream (boxConn conn) $ Serialize.encode msg
+
+-- | request makes a Remote Procedure Call on the peer.
+request
+ :: ToJSON a
+ => ConnState
+ -> Request a
+ -> (Stream -> IO (Either Text b))
+ -> IO (Either Text b)
+request conn req session = do
+ reqNum <- atomically $ takeTMVar (nextOutgoingReqNum conn)
+ let msg = newJSONMessage reqNum req
+ case msg of
+ Left err -> do
+ atomically $ putTMVar (nextOutgoingReqNum conn) reqNum
+ return $ Left err
+ Right msg -> do
+ let nextReqNum = reqNum + 1
+ atomically $ putTMVar (nextOutgoingReqNum conn) nextReqNum
+
+ -- TODO: Fix late night update
+ let flags' = flags . header $ msg
+ let
+ msg' = msg
+ { header = (header msg) { flags = flags'
+ { isStream = typ req /= Async
+ }
+ }
+ }
+
+ let stream = Stream { streamID = reqNum
+ , streamType = (typ req)
+ , conn = conn
+ , direction = Outgoing
+ , status = Open
+ , peer = connPeer conn
+ }
+ atomically $ do
+ streams <- takeTMVar (streamsOut conn)
+ chan <- newTChan
+ putTMVar (streamsOut conn) $ Map.insert reqNum (stream, chan) streams
+
+ _ <- writeMessage' stream msg'
+ result <- session stream
+ if streamType stream == Async
+ then return . return $ ()
+ else either (closeStreamWithError stream)
+ (\x -> closeStream stream)
+ result
+ return result
+ where writeMessage' = manageStreamConn
+
+-- | requestAsync is the Async version of request.
+requestAsync
+ :: (ToJSON a, FromJSON b) => ConnState -> Request a -> IO (Either Text b)
+requestAsync conn req = do
+ resp <- request conn req readStreamJSON
+ return $ resp >>= withErr "not response received"
+
+-- TODO: Use CloseNotification to signal close of stream.
+-- How is the JSON encoding done?
+data CloseNotification = CloseNotification ()
+
+newCloseNotification :: Int32 -> Either Text Message
+newCloseNotification streamID = do
+ let payload = JSONPayload "true"
+ let requestNumer = -1 * streamID
+ header <- newHeader
+ (def { isEndOrError = True, isStream = True, bodyType = JSON })
+ requestNumer
+ "true"
+ return $ Message header payload
+
+-- | closeStream politely shuts down an RPC stream.
+-- TODO: closeStream needs to be more adaptive. It only accounts for
+-- real streams.
+closeStream :: Stream -> IO (Either Text ())
+closeStream stream = case (status stream, streamType stream) of
+ (_ , Async) -> return . return $ ()
+ (Open, _ ) -> either (return . Left) (manageStreamConn stream)
+ $ newCloseNotification (streamID stream)
+ (Closed, _) -> return . return $ ()
+ (_ , _) -> return $ Left "could not close stream"
+
+-- ErrorNotification is the format used to communicate a stream error between peers.
+data ErrorNotification = ErrorNotification {
+ message :: Text
+ , stack :: Maybe Text
+ }
+
+instance FromJSON ErrorNotification where
+ parseJSON = withObject "Error"
+ $ \v -> ErrorNotification <$> v .: "message" <*> v .: "stack"
+
+instance ToJSON ErrorNotification where
+ toJSON arg = object
+ [ "name" .= ("error" :: Text)
+ , "message" .= message arg
+ , "stack" .= stack arg
+ ]
+
+-- TODO: deduplicate closeStreamWithError and closeStream code
+-- TODO: Verify close operation on error
+-- Does the remote end send an 'true' message to confirm? Or is the
+-- connection simply dropped.
+closeStreamWithError' :: Stream -> Text -> IO (Either Text ())
+closeStreamWithError' stream err = do
+ if streamType stream == Duplex || direction stream == Incoming
+ then do
+ let msg = ErrorNotification err Nothing
+ let flags = def { isStream = True, isEndOrError = True }
+ err <- writeStream stream
+ flags
+ (JSONPayload $ BS.toStrict $ Aeson.encode msg)
+ either
+ (\err ->
+ logMsg (conn stream) $ "could not notify remote of error: " <> err
+ )
+ return
+ err
+ else logMsg (conn stream) $ "closing stream with error: " <> err
+ atomically $ do
+ let table = streamTable stream
+ value <- takeTMVar table
+ putTMVar table $ Map.delete (streamID stream) value
+ return . return $ ()
+ where
+ streamTable stream = case direction stream of
+ Incoming -> streamsIn $ conn stream
+ Outgoing -> streamsOut $ conn stream
+
+-- TODO: Merge newCloseErrorNotification with newCloseNotification It has
+-- proven difficult to use the ToJSON typeclass as an argument.
+newCloseErrorNotification :: Int32 -> Text -> Either Text Message
+newCloseErrorNotification streamID msg = do
+ let payload = JSONPayload $ BS.toStrict $ Aeson.encode msg
+ let requestNumer = -1 * streamID
+ header <- newHeader
+ (def { isEndOrError = True, isStream = True, bodyType = JSON })
+ requestNumer
+ "true"
+ return $ Message header payload
+
+closeStreamWithError :: Stream -> Text -> IO (Either Text ())
+closeStreamWithError stream err = case streamType stream of
+ Async -> return . return $ ()
+ otherwise ->
+ either (return . Left) (manageStreamConn stream)
+ $ newCloseErrorNotification (streamID stream) err
+
+-- TODO: Properly translate RPC errors
+-- TODO: Close stream within lock to avoid race conditions
+
+readStream :: Stream -> IO (Maybe MessagePayload)
+readStream stream = do
+ let table = streamTable stream
+ table' <- atomically $ readTMVar table
+ let result = Map.lookup (streamID stream) table'
+ msg <- case result of
+ Nothing -> return Nothing
+ Just (stream', chan') -> if status stream' == Closed
+ then atomically $ join <$> tryReadTChan chan'
+ else atomically $ readTChan chan'
+ table'' <- atomically $ takeTMVar table
+ let table''' = case msg of
+ Nothing -> Map.delete (streamID stream) table'
+ Just _ -> table''
+ atomically $ putTMVar table table'''
+ return (body <$> msg)
+
+-- | readStreamJSON reads a single JSON message from the RPC stream.
+readStreamJSON :: FromJSON a => Stream -> IO (Either Text (Maybe a))
+readStreamJSON stream = do
+ resp <- readStream stream
+ case resp of
+ (Just (JSONPayload buf)) -> return (Just <$> decodeJSON buf)
+ (Just otherwise ) -> return (errPayload resp)
+ Nothing -> return . return $ Nothing
+ where
+ errPayload payload =
+ error ("expected JSONPayload but got " <> (show payload :: Text))
+ errEOS = error "end of stream"
+
+writeStream :: Stream -> Flags -> MessagePayload -> IO (Either Text ())
+writeStream stream flags payload = do
+ status <- atomically $ streamStatus stream
+ if status /= Open
+ then return $ Left "stream closed"
+ else do
+ let msgID = case direction stream of
+ Incoming -> -1 * streamID stream
+ Outgoing -> streamID stream
+ let flags' = flags { isStream = not $ (streamType stream) == Async }
+ let header' = case payload of
+ (BinaryPayload p) ->
+ newHeader (flags' { bodyType = Binary }) msgID p
+ (TextPayload p) ->
+ newHeader (flags' { bodyType = UTF8String }) msgID p
+ (JSONPayload p) -> newHeader (flags' { bodyType = JSON }) msgID p
+ case header' of
+ Left err -> return $ Left err
+ Right header' -> do
+ let msg = Message header' payload
+ writeMessage' stream msg
+ where writeMessage' = manageStreamConn
+
+-- | writeStreamJSON writes a single JSON message to the RPC stream.
+writeStreamJSON :: ToJSON a => Stream -> a -> IO (Either Text ())
+writeStreamJSON stream msg = do
+ let buf = Aeson.encode msg
+ writeStream stream def (JSONPayload $ BS.toStrict buf)