diff options
author | Haskell Guy <haskell.guy@localhost> | 2020-05-26 13:07:50 +0200 |
---|---|---|
committer | Haskell Guy <haskell.guy@localhost> | 2020-05-26 13:37:29 +0200 |
commit | 41cde99ec6189dbecca6803a5aa4f6f18142e8ba (patch) | |
tree | 7a0ceab0d516b8c3b7b49313100ae50c97e875c3 /src/Ssb/Peer/RPC.hs | |
download | ssb-haskell-41cde99ec6189dbecca6803a5aa4f6f18142e8ba.tar.xz |
initial commit
Diffstat (limited to 'src/Ssb/Peer/RPC.hs')
-rw-r--r-- | src/Ssb/Peer/RPC.hs | 859 |
1 files changed, 859 insertions, 0 deletions
diff --git a/src/Ssb/Peer/RPC.hs b/src/Ssb/Peer/RPC.hs new file mode 100644 index 0000000..e5a4c9a --- /dev/null +++ b/src/Ssb/Peer/RPC.hs @@ -0,0 +1,859 @@ +-- For more information kindly refer the to protocol guide +-- https://ssbc.github.io/scuttlebutt-protocol-guide + +module Ssb.Peer.RPC where + +import Protolude + +import Control.Concurrent.STM +import Control.Monad.Fail +import Data.Aeson ( FromJSON + , ToJSON + ) +import Data.Aeson as Aeson +import qualified Data.ByteString as BS +import qualified Data.ByteString.Lazy as BS + ( toStrict ) +import qualified Data.Map.Strict as Map +import Data.Default +import Data.Either.Combinators ( mapLeft + , mapRight + ) +import Data.Text as Text +import Data.Serialize as Serialize + +import Ssb.Aux +import qualified Ssb.Identity as Ssb +import qualified Ssb.Peer.BoxStream as BoxStream + +data BodyType = Binary | UTF8String | JSON | UnknownBodyType + deriving (Bounded,Eq,Show) + +instance Convertible Word8 BodyType where + convert v = case v .&. 3 of + 0 -> Binary + 1 -> UTF8String + 2 -> JSON + v -> UnknownBodyType + +instance Convertible BodyType Word8 where + convert Binary = 0 + convert UTF8String = 1 + convert JSON = 2 + convert UnknownBodyType = 3 + +data Flags = Flags + { unused1 :: Bool + , unused2 :: Bool + , unused3 :: Bool + , unused4 :: Bool + , isStream :: Bool + , isEndOrError :: Bool + , bodyType :: BodyType + } deriving (Eq,Show) + +instance Convertible Flags Word8 where + convert arg = + set (unused1 arg) 7 + $ set (unused2 arg) 6 + $ set (unused3 arg) 5 + $ set (unused4 arg) 4 + $ set (isStream arg) 3 + $ set (isEndOrError arg) 2 + $ convert (bodyType arg) + where + set True pos arg = setBit arg pos + set False _ arg = arg + +instance Convertible Word8 Flags where + convert w = Flags { unused1 = testBit w 7 + , unused2 = testBit w 6 + , unused3 = testBit w 5 + , unused4 = testBit w 4 + , isStream = testBit w 3 + , isEndOrError = testBit w 2 + , bodyType = convert w + } + +instance Default Flags where + def = convert (zeroBits :: Word8) + +instance Serialize.Serialize Flags where + get = convert <$> getWord8 + put = putWord8 . convert + +-- | ProcedureType defines the type of remote call. +data ProcedureType = Async | Source | Duplex + deriving (Eq,Generic,Ord,Show) + +instance FromJSON ProcedureType where + parseJSON = withText "ProcedureType" $ \v -> case v of + "async" -> return Async + "source" -> return Source + "duplex" -> return Duplex + otherwise -> fail $ "unknown value '" <> toS v <> "'" + +instance ToJSON ProcedureType where + toJSON Async = "async" + toJSON Source = "source" + toJSON Duplex = "duplex" + +-- | HeaderLength is the length of the RPC header in bytes +headerLength :: Int +headerLength = 9 + +-- | bodySizeLength is the length of the bodySize parameter in the RPC header. +bodySizeLength :: Int +bodySizeLength = 4 + +-- | requestNumberLength is the length of the requestNumberLength parameter in +-- the RPC header. +requestNumberLength :: Int +requestNumberLength = 4 + +-- | Header is the first part of a RPC message used for stream control and communication. +data Header = Header + { flags :: Flags + , bodyLength :: Word32 + , requestNumber :: Int32 + } deriving (Eq,Generic,Show) + +instance Serialize.Serialize Header + +-- | GoodByeHeader is the RPC message header signalling the end of the RPC +-- stream. +goodByeHeader :: Header +goodByeHeader = Header (convert (zeroBits :: Word8)) 0 0 + +newHeader :: Flags -> Int32 -> ByteString -> Either Text Header +newHeader flags reqNum body = do + let bodyLength = fromIntegral $ BS.length body + return Header { flags = flags + , bodyLength = bodyLength + , requestNumber = reqNum + } + +-- | MessagePayload describes and contains the contents of an RPC message. +data MessagePayload = BinaryPayload ByteString | TextPayload ByteString | JSONPayload ByteString + deriving (Generic,Show) + +-- | Message is a single message in the RPC stream +data Message = Message + { header :: Header + , body :: MessagePayload + } deriving (Generic,Show) + +instance Serialize.Serialize Message where + get = do + header <- Serialize.get + buf <- Serialize.getByteString (fromIntegral $ bodyLength header) + let typ = bodyType . flags $ header + payload <- case typ of + Binary -> return $ BinaryPayload buf + UTF8String -> return $ TextPayload buf + JSON -> return $ JSONPayload buf + UnknownBodyType -> fail "unknown body type" + return $ Message header payload + + put msg = do + Serialize.put (header msg) + case body msg of + BinaryPayload buf -> Serialize.putByteString buf + JSONPayload buf -> Serialize.putByteString buf + TextPayload buf -> Serialize.putByteString buf + +-- | newJSONMessage is a convenience function to create a RPC message with a +-- JSON payload. +newJSONMessage :: ToJSON a => Int32 -> a -> Either Text Message +newJSONMessage reqNum payload = do + let buf = BS.toStrict $ Aeson.encode payload + len <- maybeWord8 $ BS.length buf + header' <- newHeader (def { bodyType = JSON }) reqNum buf + return $ Message { header = header', body = JSONPayload buf } + +-- | newJSONMessage is a convenience function to decode a RPC JSON message. +decodeJSONMessage :: FromJSON a => Message -> Either Text a +decodeJSONMessage msg = case body msg of + JSONPayload buf -> decodeJSON buf + _ -> error "unexpected body type" + +isRequest :: Int32 -> Message -> Bool +isRequest nextReqNum msg = do + let reqNum = requestNumber . header $ msg + nextReqNum == reqNum + +-- | Request is the message representing a RPC call. +data Request a = Request + { name :: [Text] + , typ :: ProcedureType + , args :: a + } deriving (Show) + +instance (FromJSON a) => FromJSON (Request a) where + parseJSON = withObject "Request" + $ \v -> Request <$> v .: "name" <*> v .:? "type" .!= Async <*> v .: "args" + +instance (ToJSON a) => ToJSON (Request a) where + toJSON arg = + object ["name" .= name arg, "type" .= typ arg, "args" .= args arg] + +data Direction = Incoming | Outgoing + deriving (Eq,Generic,Show) + +data StreamStatus = + -- | Open the stream is active. + Open + -- | The (Async) stream is awaiting a response. + | AwaitingResponse + -- | The stream is closed and waiting for its peer to do the same. + | AwaitingCloseRecv + -- | The stream is closed. + | Closed + deriving (Eq,Show) + +data Stream = Stream { + streamID :: Int32 + , streamType :: ProcedureType + , conn :: ConnState + , direction :: Direction + , status :: StreamStatus + , peer :: Ssb.PublicKey + } + +formatStream :: Stream -> Text +formatStream stream = + (show $ streamID stream :: Text) + <> "\t" + <> show (streamType stream) + <> "\t" + <> show (direction stream) + <> "\t" + <> show (status stream) + +foldStream + :: (a -> MessagePayload -> IO (Either Text a)) + -> a + -> Stream + -> IO (Either Text a) +foldStream fn acc stream = do + msg <- readStream stream + case msg of + (Just msg') -> do + acc' <- fn acc msg' + either (return . error) (\v -> foldStream fn v stream) acc' + Nothing -> return . return $ acc + +foldJSONStream + :: (FromJSON b) + => (a -> b -> IO (Either Text a)) + -> a + -> Stream + -> IO (Either Text a) +foldJSONStream fn acc stream = do + msg <- readStreamJSON stream + case msg of + (Right (Just msg')) -> do + acc' <- fn acc msg' + either (return . error) (\v -> foldJSONStream fn v stream) acc' + (Right Nothing) -> return . return $ acc + (Left err ) -> return $ error err + +data ConnState = ConnState { + connPeer :: Ssb.PublicKey + , boxConn :: BoxStream.Conn + , streamsIn :: TMVar (Map Int32 (Stream, TChan (Maybe Message))) + , streamsOut :: TMVar (Map Int32 (Stream, TChan (Maybe Message))) + , nextIncomingReqNum :: TMVar Int32 + , nextOutgoingReqNum :: TMVar Int32 + , lock :: TMVar Bool + } + +-- | Endpoint identifies which Remote Procedure Call to make. +data Endpoint = Endpoint [Text] ProcedureType + deriving (Eq, Ord) + +formatEndpoint :: Endpoint -> Text +formatEndpoint (Endpoint paths typ) = + Text.intercalate "." paths <> ":" <> show typ + +-- | HandlerFunc's are used to serve Remove Procedure Calls. +type HandlerFunc = Aeson.Value -> Stream -> IO (Either Text ()) + +-- | notFoundHandlerFunc is tells the peer the Endpoint does not exist. +notFoundHandlerFunc :: Endpoint -> HandlerFunc +notFoundHandlerFunc endpoint _ _ = return + (Left $ "endpoint not found '" <> endpoint' <> "'") + where endpoint' = formatEndpoint endpoint + +-- | Handler can serve Remote Procedure Requests. +class Handler h where + endpoints + :: h + -> [Endpoint] + + -- | Serve calls an incoming Remote Procedure Call. + serve + :: h + -> Endpoint + -> HandlerFunc + + -- TODO: Maybe add 'conn' to notifyConnect + + -- | notifyConnect tells the handler when a peer has connected. + notifyConnect + :: h + -> Ssb.PublicKey + -> IO (Either Text ()) + + -- | notifyDisconnect tells the handler when a peer has disconnected. + notifyDisconnect + :: h + -> Ssb.PublicKey + -> IO (Either Text ()) + +-- | Router serves Remote Procedure Calls with a set of handlers. +data Router = Router { + endpointHandlers :: Map Endpoint HandlerFunc + , connectCallbacks :: [Ssb.PublicKey -> IO (Either Text ())] + , disconnectCallbacks :: [Ssb.PublicKey -> IO (Either Text ())] + } deriving (Generic) + +instance Default Router + +instance Handler Router where + endpoints demuxer = Map.keys (endpointHandlers demuxer) + + serve demuxer endpoint = case Map.lookup endpoint endpointHandlers' of + Nothing -> (notFoundHandlerFunc endpoint) + Just handler -> handler + where endpointHandlers' = endpointHandlers demuxer + + notifyConnect demuxer id = do + errs <- fmap lefts . sequence $ (\f -> f id) <$> connectCallbacks demuxer + if Protolude.null errs + then return . return $ () + else return $ Left $ Text.intercalate ", " errs + + notifyDisconnect demuxer id = do + errs <- fmap lefts . sequence $ (\f -> f id) <$> connectCallbacks demuxer + if Protolude.null errs + then return . return $ () + else return $ Left $ Text.intercalate ", " errs + + +-- | with adds a handler to the Router. +with :: Handler h => Router -> h -> Router +with demuxer handler = Router + { endpointHandlers = Map.union (endpointHandlers demuxer) $ Map.fromList + ((\e -> (e, serve handler e)) <$> endpoints handler) + , connectCallbacks = connectCallbacks demuxer <> [notifyConnect handler] + , disconnectCallbacks = disconnectCallbacks demuxer + <> [notifyDisconnect handler] + } + +-- | withM is a convenience function for using 'with' in monads. +withM :: (MonadIO m, Handler h) => m Router -> m h -> m Router +withM demuxer handler = with <$> demuxer <*> handler + +logMsg :: ConnState -> Text -> IO () +logMsg conn msg = do + _ <- atomically $ takeTMVar (lock conn) + print msg + atomically $ putTMVar (lock conn) True + +logDebug :: ConnState -> Text -> IO () +logDebug _ _ = return () +-- logDebug = logMsg + +-- | spawnConnection handles safely forking and closing RPC connections. +spawnConnection + :: Stream + -> IO (Either Text ()) + -> IO () +spawnConnection stream action = do + forkFinally + (do + res <- action + either (closeStreamWithError stream) (\_ -> closeStream stream) res + ) + (\_ -> void (closeStream stream)) + return () + +-- | connect creates a Remote Procdure Call stream connection over the give Box +-- Stream. +connect + :: Handler h + => BoxStream.Conn + -> h + -> Ssb.PublicKey + -> (ConnState -> IO ()) + -> IO (Either Text ()) +connect boxConn handler peer client = do + streamsIn <- newTMVarIO Map.empty + streamsOut <- newTMVarIO Map.empty + nextIncomingReqNum <- newTMVarIO 1 + nextOutgoingReqNum <- newTMVarIO 1 + lock <- newTMVarIO True + let conn = ConnState { connPeer = peer + , boxConn = boxConn + , streamsIn = streamsIn + , streamsOut = streamsOut + , nextIncomingReqNum = nextIncomingReqNum + , nextOutgoingReqNum = nextOutgoingReqNum + , lock = lock + } + + forkIO $ + -- make RPC calls on the peer + client conn + print "entering service loop" + ret <- serviceLoop conn + print "out of service loop" + _ <- notifyDisconnect handler (connPeer conn) + print "disconnecting" + disconnect conn + print "disconnected" + return ret + where + serviceLoop conn = do + msg <- readMessage conn + case msg of + Left err -> return $ error ("connection error: " <> err) + Right msg -> if header msg == goodByeHeader + then return . return $ () + else do + ok <- handleMessage handler conn msg + if ok then serviceLoop conn else return . return $ () + +-- TODO: Handle stream closing within demux loop +-- The handleMessage loop is responsible for the lifetime of a stream +-- connection. It detects and creates requests, it should therefor handle +-- close requests. +-- +-- The stream type matters when closing a stream. Async do not require a +-- specific close message. +-- TODO: Handle closing of Async streams +-- TODO: Properly handle Error message + +-- TODO: Refactor or reallocate helper functions +-- These functions were added in a rush to get request handling working. A +-- better place could probably be found for them to improve code clarity. + +-- | isEndOfStream checks whether the message is the last of the stream. +isEndOfStream = isEndOrError . flags . header + +-- | checks if the message is a response to an Async request. +isAsyncResponse = not . isStream . flags . header + +-- | streamTable gets which lookup table to use for connections for the given +-- stream. There are seperate ones for incoming and outgoing RPC streams. +streamTable stream = case direction stream of + Incoming -> streamsIn (conn stream) + Outgoing -> streamsOut (conn stream) + +streamStatus :: Stream -> STM StreamStatus +streamStatus stream = do + table <- readTMVar $ streamTable stream + return + $ fromMaybe Closed (status . fst <$> Map.lookup (streamID stream) table) + +-- | manageStreamConn manages stream connection changes when writing messages +-- to a stream. +-- TODO: beware of entering a deadlock in manageStreamConn +-- TODO: request stuff should be moved here +manageStreamConn :: Stream -> Message -> IO (Either Text ()) +manageStreamConn stream msg = do + table <- atomically $ takeTMVar $ streamTable stream + res <- writeMessage (conn stream) msg + let table' = if + | isLeft res -> closeStream table stream + | (not . isStream . flags . header $ msg) -> Map.adjust + (updateStreamStatus AwaitingResponse) + (streamID stream) + table + | +-- TODO: confirm status of stream before closing it +-- This happy path is not production ready + (isEndOrError . flags . header $ msg) -> Map.adjust + (updateStreamStatus AwaitingCloseRecv) + (streamID stream) + table + | otherwise -> table + when (isLeft res) $ do + let (_, c) = fromMaybe undefined $ Map.lookup (streamID stream) table + -- TODO: Find better way of closing channels on write error + atomically $ forM_ [1 .. 30] (\_ -> writeTChan c Nothing) + atomically $ putTMVar (streamTable stream) table' + return res + where + updateStreamStatus status (s, c) = (s { status = status }, c) + closeStream table tream = case direction stream of + Incoming -> Map.adjust (updateStreamStatus Closed) (streamID stream) table + Outgoing -> Map.delete (streamID stream) table + + +-- TODO: Log important information for incoming messages +handleMessage :: Handler h => h -> ConnState -> Message -> IO Bool +handleMessage handler conn msg = do + nextReqNum <- atomically $ readTMVar (nextIncomingReqNum conn) + if + | (header msg) == goodByeHeader -> return False + | (isRequest nextReqNum msg) -> serveRequest handler conn msg + | (isEndOfStream msg) -> internalCloseStream conn msg + | (isAsyncResponse msg) -> do + demux conn msg + internalCloseStream conn msg + | otherwise -> demux conn msg + where + internalCloseStream conn msg = do + let flags = def { isStream = True, isEndOrError = True } + let requestNumber' = (requestNumber . header $ msg) + let mVarTable = + if requestNumber' > 0 then streamsIn conn else streamsOut conn + table <- atomically $ takeTMVar mVarTable + let streamID' = abs requestNumber' + let (stream, chan) = fromMaybe undefined $ Map.lookup streamID' table + newStatus <- case status stream of + Open -> do + let msg = fromRight undefined $ newCloseNotification streamID' + _ <- writeMessage conn msg + return AwaitingCloseRecv + AwaitingCloseRecv -> do + atomically $ writeTChan chan Nothing + return Closed + AwaitingResponse -> do + atomically $ writeTChan chan Nothing + return Closed + Closed -> do + logMsg conn + $ "received end of stream for already closed stream: " + <> (show streamID' :: Text) + return Closed + atomically $ putTMVar mVarTable $ case direction stream of + Incoming -> Map.adjust (updateStreamStatus newStatus) streamID' table + -- TODO: properly handle close of outgoing Duplex streams + Outgoing -> Map.delete streamID' table + return True + where updateStreamStatus status (s, c) = (s { status = status }, c) + + serveRequest handler conn msg = do + let req = decodeJSONMessage msg + case req of + Left err -> do + logMsg conn $ "could not decode request: " <> err + return False + Right req -> do + let stream = Stream { streamID = requestNumber . header $ msg + , streamType = typ req + , conn = conn + , direction = Incoming + , status = Open + , peer = connPeer conn + } + chan <- newTChanIO + err <- atomically $ do + reqNum <- takeTMVar (nextIncomingReqNum conn) + table <- takeTMVar (streamsIn conn) + if Map.size table == (maxBound :: Int) + then return $ Left errTooManyRequests + else do + let (reqNum', table') = if streamID stream == reqNum + then (reqNum + 1, (Map.insert reqNum (stream, chan) table)) + else (reqNum, table) + putTMVar (nextIncomingReqNum conn) reqNum' + putTMVar (streamsIn conn) table' + return . return $ () + case err of + Left msg -> do + print msg + return False + Right _ -> do + -- Serving a request call, how do we close it nicely? + let endpoint = Endpoint (name req) (typ req) + spawnConnection stream $ (serve handler) endpoint (args req) stream + return True + where errTooManyRequests = "connection limit reached, dropping request" + + demux conn msg = do + let reqNum = requestNumber . header $ msg + let table = if reqNum > 0 then streamsIn conn else streamsOut conn + let streamID' = abs reqNum + table' <- atomically $ readTMVar table + case Map.lookup streamID' table' of + Nothing -> do + logDebug conn + $ "message dropped due to missing stream: " + <> (show msg :: Text) + return () + Just (_, chan) -> atomically $ writeTChan chan (Just msg) + return True + +-- | disconnect is a hacked version, it rudely disconnects all connections in +-- order to unlock streams stuck on their reading their channels. +-- +-- Look into closing streams elegantly. +-- TODO: Ensure disconnect prevents new incoming and outgoing connections +disconnect :: ConnState -> IO () +disconnect conn = do + table <- atomically $ takeTMVar (streamsIn conn) + forM_ table + $ \(_, c) -> atomically $ forM_ [1 .. 30] (\_ -> writeTChan c Nothing) + let table' = Map.map (\(s, c) -> (s { status = Closed }, c)) table + atomically $ putTMVar (streamsIn conn) table' + + table <- atomically $ takeTMVar (streamsOut conn) + forM_ table + $ \(_, c) -> atomically $ forM_ [1 .. 30] (\_ -> writeTChan c Nothing) + let table' = Map.map (\(s, c) -> (s { status = Closed }, c)) table + atomically $ putTMVar (streamsOut conn) table' + + _ <- writeMessage conn $ Message goodByeHeader (BinaryPayload "") + return () + +-- | readBoxStream reads the given number of bytes from the Box Stream. +readBoxStream :: BoxStream.Conn -> Int -> IO (Either Text ByteString) +readBoxStream conn bytes = do + mbuf <- BoxStream.readStream conn bytes + let buf = join $ (maybeToRight errUnexpectedClose) <$> mbuf + return buf + where errUnexpectedClose = "rpc.readConn: unexpected end of box stream" + + +-- | readMessage reads a single message from the RPC connection. +readMessage :: ConnState -> IO (Either Text Message) +readMessage conn = do + headerBuf <- readBoxStream (boxConn conn) headerLength + header <- (return $ headerBuf >>= decode) + case header of + Left err -> return $ Left $ "RPC.readMessage header: " <> err + Right header -> do + bodyBuf <- readBoxStream (boxConn conn) (fromIntegral $ bodyLength header) + let ret = liftA2 (<>) headerBuf bodyBuf >>= decode + logDebug conn + $ "readMessage (" + <> (Ssb.formatPublicKey $ connPeer conn) + <> ")" + <> (show ret) + return $ ret + where + decode :: Serialize a => ByteString -> Either Text a + decode = mapLeft toS . Serialize.decode + +writeMessage :: ConnState -> Message -> IO (Either Text ()) +writeMessage conn msg = do + logDebug conn + $ "writeMessage (" + <> (Ssb.formatPublicKey $ connPeer conn) + <> ")" + <> (show msg) + BoxStream.sendStream (boxConn conn) $ Serialize.encode msg + +-- | request makes a Remote Procedure Call on the peer. +request + :: ToJSON a + => ConnState + -> Request a + -> (Stream -> IO (Either Text b)) + -> IO (Either Text b) +request conn req session = do + reqNum <- atomically $ takeTMVar (nextOutgoingReqNum conn) + let msg = newJSONMessage reqNum req + case msg of + Left err -> do + atomically $ putTMVar (nextOutgoingReqNum conn) reqNum + return $ Left err + Right msg -> do + let nextReqNum = reqNum + 1 + atomically $ putTMVar (nextOutgoingReqNum conn) nextReqNum + + -- TODO: Fix late night update + let flags' = flags . header $ msg + let + msg' = msg + { header = (header msg) { flags = flags' + { isStream = typ req /= Async + } + } + } + + let stream = Stream { streamID = reqNum + , streamType = (typ req) + , conn = conn + , direction = Outgoing + , status = Open + , peer = connPeer conn + } + atomically $ do + streams <- takeTMVar (streamsOut conn) + chan <- newTChan + putTMVar (streamsOut conn) $ Map.insert reqNum (stream, chan) streams + + _ <- writeMessage' stream msg' + result <- session stream + if streamType stream == Async + then return . return $ () + else either (closeStreamWithError stream) + (\x -> closeStream stream) + result + return result + where writeMessage' = manageStreamConn + +-- | requestAsync is the Async version of request. +requestAsync + :: (ToJSON a, FromJSON b) => ConnState -> Request a -> IO (Either Text b) +requestAsync conn req = do + resp <- request conn req readStreamJSON + return $ resp >>= withErr "not response received" + +-- TODO: Use CloseNotification to signal close of stream. +-- How is the JSON encoding done? +data CloseNotification = CloseNotification () + +newCloseNotification :: Int32 -> Either Text Message +newCloseNotification streamID = do + let payload = JSONPayload "true" + let requestNumer = -1 * streamID + header <- newHeader + (def { isEndOrError = True, isStream = True, bodyType = JSON }) + requestNumer + "true" + return $ Message header payload + +-- | closeStream politely shuts down an RPC stream. +-- TODO: closeStream needs to be more adaptive. It only accounts for +-- real streams. +closeStream :: Stream -> IO (Either Text ()) +closeStream stream = case (status stream, streamType stream) of + (_ , Async) -> return . return $ () + (Open, _ ) -> either (return . Left) (manageStreamConn stream) + $ newCloseNotification (streamID stream) + (Closed, _) -> return . return $ () + (_ , _) -> return $ Left "could not close stream" + +-- ErrorNotification is the format used to communicate a stream error between peers. +data ErrorNotification = ErrorNotification { + message :: Text + , stack :: Maybe Text + } + +instance FromJSON ErrorNotification where + parseJSON = withObject "Error" + $ \v -> ErrorNotification <$> v .: "message" <*> v .: "stack" + +instance ToJSON ErrorNotification where + toJSON arg = object + [ "name" .= ("error" :: Text) + , "message" .= message arg + , "stack" .= stack arg + ] + +-- TODO: deduplicate closeStreamWithError and closeStream code +-- TODO: Verify close operation on error +-- Does the remote end send an 'true' message to confirm? Or is the +-- connection simply dropped. +closeStreamWithError' :: Stream -> Text -> IO (Either Text ()) +closeStreamWithError' stream err = do + if streamType stream == Duplex || direction stream == Incoming + then do + let msg = ErrorNotification err Nothing + let flags = def { isStream = True, isEndOrError = True } + err <- writeStream stream + flags + (JSONPayload $ BS.toStrict $ Aeson.encode msg) + either + (\err -> + logMsg (conn stream) $ "could not notify remote of error: " <> err + ) + return + err + else logMsg (conn stream) $ "closing stream with error: " <> err + atomically $ do + let table = streamTable stream + value <- takeTMVar table + putTMVar table $ Map.delete (streamID stream) value + return . return $ () + where + streamTable stream = case direction stream of + Incoming -> streamsIn $ conn stream + Outgoing -> streamsOut $ conn stream + +-- TODO: Merge newCloseErrorNotification with newCloseNotification It has +-- proven difficult to use the ToJSON typeclass as an argument. +newCloseErrorNotification :: Int32 -> Text -> Either Text Message +newCloseErrorNotification streamID msg = do + let payload = JSONPayload $ BS.toStrict $ Aeson.encode msg + let requestNumer = -1 * streamID + header <- newHeader + (def { isEndOrError = True, isStream = True, bodyType = JSON }) + requestNumer + "true" + return $ Message header payload + +closeStreamWithError :: Stream -> Text -> IO (Either Text ()) +closeStreamWithError stream err = case streamType stream of + Async -> return . return $ () + otherwise -> + either (return . Left) (manageStreamConn stream) + $ newCloseErrorNotification (streamID stream) err + +-- TODO: Properly translate RPC errors +-- TODO: Close stream within lock to avoid race conditions + +readStream :: Stream -> IO (Maybe MessagePayload) +readStream stream = do + let table = streamTable stream + table' <- atomically $ readTMVar table + let result = Map.lookup (streamID stream) table' + msg <- case result of + Nothing -> return Nothing + Just (stream', chan') -> if status stream' == Closed + then atomically $ join <$> tryReadTChan chan' + else atomically $ readTChan chan' + table'' <- atomically $ takeTMVar table + let table''' = case msg of + Nothing -> Map.delete (streamID stream) table' + Just _ -> table'' + atomically $ putTMVar table table''' + return (body <$> msg) + +-- | readStreamJSON reads a single JSON message from the RPC stream. +readStreamJSON :: FromJSON a => Stream -> IO (Either Text (Maybe a)) +readStreamJSON stream = do + resp <- readStream stream + case resp of + (Just (JSONPayload buf)) -> return (Just <$> decodeJSON buf) + (Just otherwise ) -> return (errPayload resp) + Nothing -> return . return $ Nothing + where + errPayload payload = + error ("expected JSONPayload but got " <> (show payload :: Text)) + errEOS = error "end of stream" + +writeStream :: Stream -> Flags -> MessagePayload -> IO (Either Text ()) +writeStream stream flags payload = do + status <- atomically $ streamStatus stream + if status /= Open + then return $ Left "stream closed" + else do + let msgID = case direction stream of + Incoming -> -1 * streamID stream + Outgoing -> streamID stream + let flags' = flags { isStream = not $ (streamType stream) == Async } + let header' = case payload of + (BinaryPayload p) -> + newHeader (flags' { bodyType = Binary }) msgID p + (TextPayload p) -> + newHeader (flags' { bodyType = UTF8String }) msgID p + (JSONPayload p) -> newHeader (flags' { bodyType = JSON }) msgID p + case header' of + Left err -> return $ Left err + Right header' -> do + let msg = Message header' payload + writeMessage' stream msg + where writeMessage' = manageStreamConn + +-- | writeStreamJSON writes a single JSON message to the RPC stream. +writeStreamJSON :: ToJSON a => Stream -> a -> IO (Either Text ()) +writeStreamJSON stream msg = do + let buf = Aeson.encode msg + writeStream stream def (JSONPayload $ BS.toStrict buf) |