diff options
Diffstat (limited to 'src/Ssb/Peer')
-rw-r--r-- | src/Ssb/Peer/BoxStream.hs | 367 | ||||
-rw-r--r-- | src/Ssb/Peer/RPC.hs | 859 | ||||
-rw-r--r-- | src/Ssb/Peer/RPC/Gossip.hs | 139 | ||||
-rw-r--r-- | src/Ssb/Peer/RPC/Room.hs | 330 | ||||
-rw-r--r-- | src/Ssb/Peer/RPC/WhoAmI.hs | 49 | ||||
-rw-r--r-- | src/Ssb/Peer/SecretHandshake.hs | 695 | ||||
-rw-r--r-- | src/Ssb/Peer/TCP.hs | 100 |
7 files changed, 2539 insertions, 0 deletions
diff --git a/src/Ssb/Peer/BoxStream.hs b/src/Ssb/Peer/BoxStream.hs new file mode 100644 index 0000000..d804e56 --- /dev/null +++ b/src/Ssb/Peer/BoxStream.hs @@ -0,0 +1,367 @@ +-- | This module implements Scuttlebutt's Box Stream. +-- +-- For more information kindly refer the to protocol guide +-- https://ssbc.github.io/scuttlebutt-protocol-guide + +module Ssb.Peer.BoxStream where + +import Protolude hiding ( Identity ) + +import qualified Crypto.Hash.SHA256 as SHA256 +import qualified Data.ByteString as BS +import qualified Data.ByteString.Base64 as Base64 +import Data.Either.Combinators ( mapLeft + , mapRight + ) +import Control.Concurrent.STM +import qualified Data.Serialize as Serialize +import qualified Data.Serialize.Put as Serialize +import qualified Network.Simple.TCP as TCP + +import Ssb.Aux +import qualified Ssb.Identity as SSB +import qualified Ssb.Peer.SecretHandshake as SH + +import qualified Crypto.Saltine.Class as Nacl +import qualified Crypto.Saltine.Core.Auth as NaclAuth +import qualified Crypto.Saltine.Core.SecretBox as Nacl +import qualified Crypto.Saltine.Core.ScalarMult + as Nacl + +import Pipes +import qualified Pipes.Prelude as P + +-- | HeaderLength is the length of a Box Stream header in bytes +headerLength :: Int +headerLength = 34 + +-- | MaxBodyLength is the maximum Box Stream body length in bytes +maxBodyLength :: Int +maxBodyLength = 4096 + +data Header = Header + { bodyLength :: Word16 + , authTag :: AuthTag + } deriving (Eq,Generic,Show) + +instance Serialize.Serialize Header + +-- | GoodBye is the message signalling the end of the Box Stream +goodByeHeader :: Header +goodByeHeader = Header + { bodyLength = 0 + , authTag = AuthTag $ BS.pack $ replicate authTagLength 0 + } + +newHeader :: ByteString -> ByteString -> Either Text Header +newHeader authTag body = do + let bodyLength = BS.length body + if bodyLength >= maxLength + then Left "body size too big" + else return $ Header { bodyLength = fromIntegral bodyLength + , authTag = AuthTag authTag + } + where maxLength = fromIntegral (maxBound :: Word16) + +data Message = Message + { header :: Header + , body :: ByteString + } + +encryptMessage :: Nacl.Key -> Nacl.Nonce -> ByteString -> Either Text ByteString +encryptMessage key nonce buf = do + let (authTag, ebody) = Nacl.secretboxDetached key bodyNonce buf + header <- newHeader authTag ebody + let eheader = Nacl.secretbox key headerNonce $ Serialize.encode header + return $ eheader <> ebody + where + headerNonce = nonce + bodyNonce = increment nonce + +-- | A breakdown of the message alignment would be nice +-- | Problem is the body size is variable + +-- | TODO: deduplicate +-- | TODO: safe take, safe tail + +-- TODO: handle goodbye + +-- TODO: Properly describe the function w/ Nonce update +-- The decryption / update functions should return the nonce after +-- evaluation. This increases the difficulty for describing in the language. +-- Current work around is to model the behaviour outside the function. + +decryptHeader :: Nacl.Key -> Nacl.Nonce -> ByteString -> Either Text Header +decryptHeader key nonce buf = do + let eheader = BS.take headerLength buf + headerBuf <- withErr (errHeader eheader) + $ Nacl.secretboxOpen key nonce eheader + decodeByteString headerBuf :: Either Text Header + where errHeader h = "could not decrypt header: " <> (show h :: Text) + +decryptMessage :: Nacl.Key -> Nacl.Nonce -> ByteString -> Either Text ByteString +decryptMessage key nonce buf = do + header <- decryptHeader key nonce buf + + let rest = BS.drop headerLength buf + ebody <- withErr "message body is smaller than messages body length" + $ takeMay (fromIntegral $ bodyLength header) rest + withErr (errBody ebody) $ Nacl.secretboxOpenDetached + key + (increment nonce) + (extractAuthTag $ authTag header) + ebody + where errBody b = "could not decrypt body: " <> (show b :: Text) + +goodBye :: ConnState -> ByteString +goodBye state = Nacl.secretbox (key state) (nonce state) $ encodeByteString goodByeHeader + +-- | clientToServerKey is the key for client to server stream encryption. +clientToServerKey :: SH.SharedSecrets -> Either Text Nacl.Key +clientToServerKey sharedSecrets = do + secretB <- withErr errMissingB $ SH.secretB sharedSecrets + let layer1 = SHA256.hash $ SHA256.hash + ( SH.network sharedSecrets + <> Nacl.encode (SH.secretab sharedSecrets) + <> Nacl.encode (SH.secretaB sharedSecrets) + <> Nacl.encode (SH.secretAb sharedSecrets) + ) + let layer2 = SHA256.hash $ layer1 <> SSB.extractPublicKey secretB + maybeToRight errKey $ Nacl.decode layer2 + where + errMissingB = "missing shared secret B" + errKey = "badly formatted sodium secret box key" + +-- | serverToClientKey is the key for server to client stream encryption. +serverToClientKey :: SH.SharedSecrets -> Either Text Nacl.Key +serverToClientKey sharedSecrets = do + secretA <- withErr errMissingA $ SH.secretA sharedSecrets + let layer1 = SHA256.hash + (SHA256.hash + ( SH.network sharedSecrets + <> Nacl.encode (SH.secretab sharedSecrets) + <> Nacl.encode (SH.secretaB sharedSecrets) + <> Nacl.encode (SH.secretAb sharedSecrets) + ) + ) + let layer2 = SHA256.hash (layer1 <> SSB.extractPublicKey secretA) + maybeToRight errKey $ Nacl.decode layer2 + where + errMissingA = "missing shared secret A" + errKey = "badly formatted sodium secret box key" + +clientToServerNonce :: SH.SharedSecrets -> Either Text Nacl.Nonce +clientToServerNonce sharedSecrets = do + secretb <- withErr errMissing $ SH.secretb sharedSecrets + key <- withErr errBadNet $ Nacl.decode $ SH.network sharedSecrets + let auth = NaclAuth.auth key (SSB.extractPublicKey secretb) + let noncebuf = BS.take 24 $ Nacl.encode auth + withErr errMsg $ Nacl.decode noncebuf + where + errBadNet = "badly formatted network id" + errMissing = "missing shared secret a" + errMsg = "badly formatted sodium nonce" + +serverToClientNonce :: SH.SharedSecrets -> Either Text Nacl.Nonce +serverToClientNonce sharedSecrets = do + a <- withErr errMissing $ SH.secreta sharedSecrets + key <- withErr errBadNet $ Nacl.decode $ SH.network sharedSecrets + let auth = NaclAuth.auth key (SSB.extractPublicKey a) + let noncebuf = BS.take 24 $ Nacl.encode auth + withErr errMsg $ Nacl.decode noncebuf + where + errMissing = "missing shared secret a" + errBadNet = "badly formatted network id" + errHMAC = "badly formatted HMAC" + errMsg = "badly formatted sodium nonce" + +-- The documentation's Client / Server key and nonce terminology is replaced +-- with local/remote fields to simplify implmenentation. +data ConnState = ConnState { + key :: Nacl.Key + , nonce :: Nacl.Nonce + , remoteKey :: Nacl.Key + , remoteNonce :: Nacl.Nonce + , buffer :: ByteString + , socket :: TCP.Socket + } + +newtype Conn = Conn ((TMVar ConnState),(TMVar ConnState)) + +inc :: (Word64, Word64, Word64) -> (Word64, Word64, Word64) +inc (w1, w2, w3) | w3 /= maxBound = (w1, w2, w3 + 1) + | w2 /= maxBound = (w1, w2 + 1, 0) + | w1 /= maxBound = (w1 + 1, 0, 0) + | otherwise = (0, 0, 0) + + +-- TODO : finish me +increment :: Nacl.Nonce -> Nacl.Nonce +increment arg = + fromMaybe undefined + $ Nacl.decode + $ Serialize.encode w1' + <> Serialize.encode w2' + <> Serialize.encode w3' + where + noncebuf = Nacl.encode arg + (b1, e3) = BS.splitAt 16 noncebuf + (e1, e2) = BS.splitAt 8 b1 + w3 = fromRight undefined $ Serialize.decode e3 :: Word64 + w2 = fromRight undefined $ Serialize.decode e2 :: Word64 + w1 = fromRight undefined $ Serialize.decode e1 :: Word64 + (w1', w2', w3') = inc (w1, w2, w3) + + +-- | TODO: update me for handling multiple encryptions + +newConnState :: TCP.Socket -> SH.SharedSecrets -> Either Text ConnState +newConnState socket sharedSecrets = + ConnState + <$> clientToServerKey sharedSecrets + <*> clientToServerNonce sharedSecrets + <*> serverToClientKey sharedSecrets + <*> serverToClientNonce sharedSecrets + <*> Right "" + <*> Right socket + +newServerConnState :: TCP.Socket -> SH.SharedSecrets -> Either Text ConnState +newServerConnState socket sharedSecrets = + ConnState + <$> serverToClientKey sharedSecrets + <*> serverToClientNonce sharedSecrets + <*> clientToServerKey sharedSecrets + <*> clientToServerNonce sharedSecrets + <*> Right "" + <*> Right socket + +-- TODO: Fix underlying network functions +-- Send never seems to fail. +send :: ConnState -> ByteString -> IO (Either Text ()) +send state buf = runExceptT (TCP.send (socket state) buf) + +read :: ConnState -> Int -> IO (Maybe ByteString) +read state 0 = return Nothing +read state bytes = do + buf <- TCP.recv (socket state) bytes + case buf of + Nothing -> return Nothing + Just buf -> + if BS.length buf == bytes + then return $ Just buf + else fmap (buf <>) <$> read state (bytes - BS.length buf) + +-- TODO: Keep connection terminology tied to peer and local. +-- Using network terminology such as 'server' can be confusing in other when +-- functions are used in other contexts. +connectClient :: TCP.Socket -> SH.SharedSecrets -> IO (Either Text Conn) +connectClient socket sharedSecrets = do + let state = newConnState socket sharedSecrets + case state of + Left err -> return $ Left err + Right state -> do + rstate <- newTMVarIO state + wstate <- newTMVarIO state + return . return $ Conn (rstate, wstate) + +connectServer :: TCP.Socket -> SH.SharedSecrets -> IO (Either Text Conn) +connectServer socket sharedSecrets = do + let state = newServerConnState socket sharedSecrets + case state of + Left err -> return $ Left err + Right state -> do + rstate <- newTMVarIO state + wstate <- newTMVarIO state + return . return $ Conn (rstate, wstate) + + +disconnect :: ConnState -> IO (Either Text ()) +disconnect connState = send connState $ goodBye connState + +-- TODO: Find out how to avoid these stair cases + +readStream' + :: ConnState -> Int -> IO (ConnState, Either Text (Maybe ByteString)) +readStream' connState bytes = if BS.length (buffer connState) >= bytes + then do + let (buf, rem) = BS.splitAt bytes (buffer connState) + let connState' = connState { buffer = rem } + return (connState', Right (Just buf)) + else do + buf <- withErr errNoHeader <$> read connState headerLength + let header' = buf >>= decryptHeader key' nonce' + if header' == Right goodByeHeader + -- TODO: Error if not enough bytes available + then return (connState, return Nothing) + else do + let bodyLength' = bodyLength <$> (buf >>= decryptHeader key' nonce') + case bodyLength' of + Left err -> return $ (connState, Left err) + Right bodyLength' -> do + ePayload <- (withErr errNoBody) + <$> read connState (fromIntegral bodyLength') + case ePayload of + Left err -> return $ (connState, Left err) + Right payload -> + case + decryptMessage key' + nonce' + (fromRight undefined buf <> payload) + of + Left err -> return (connState, Left err) + Right payload -> readStream' + (updateNonce . appendBuffer payload $ connState) + bytes + where + errNoHeader = "could not read header" + errNoBody = "could not read body" + key' = remoteKey connState + nonce' = remoteNonce connState + updateNonce connState = + connState { remoteNonce = (increment . increment) (remoteNonce connState) } + appendBuffer buf connState = connState { buffer = buffer connState <> buf } + +readStream :: Conn -> Int -> IO (Either Text (Maybe ByteString)) +readStream (Conn (mVar, _)) bytes = do + state <- atomically $ takeTMVar mVar + (state', ret) <- readStream' state bytes + atomically $ putTMVar mVar state' + return ret + +sendStream' :: ConnState -> ByteString -> IO (ConnState, Either Text ()) +sendStream' connState msg = do + let eMsg = encryptMessage key' nonce' msg + case eMsg of + Left err -> return (connState, Left err) + Right buf -> do + ret <- send connState buf + return (updateNonce connState, ret) + where + key' = key connState + nonce' = nonce connState + updateNonce connState = + connState { nonce = (increment . increment) (nonce connState) } + +sendStream :: Conn -> ByteString -> IO (Either Text ()) +sendStream (Conn (_, mVar)) buf = do + state <- atomically $ takeTMVar mVar + (state', ret) <- sendStream' state buf + atomically $ putTMVar mVar state' + return ret + +-- | authTagLength is the AuthTag size in bytes +authTagLength = 16 + +newtype AuthTag = AuthTag ByteString + deriving (Eq,Generic,Show) + +instance Serialize.Serialize AuthTag where + get = AuthTag <$> Serialize.getByteString authTagLength + put (AuthTag buf) = Serialize.putByteString buf + +extractAuthTag :: AuthTag -> ByteString +extractAuthTag (AuthTag buf) = buf + +-- | Aux functions +takeMay :: Int -> ByteString -> Maybe ByteString +takeMay l arg = if BS.length arg <= l then return $ BS.take l arg else Nothing diff --git a/src/Ssb/Peer/RPC.hs b/src/Ssb/Peer/RPC.hs new file mode 100644 index 0000000..e5a4c9a --- /dev/null +++ b/src/Ssb/Peer/RPC.hs @@ -0,0 +1,859 @@ +-- For more information kindly refer the to protocol guide +-- https://ssbc.github.io/scuttlebutt-protocol-guide + +module Ssb.Peer.RPC where + +import Protolude + +import Control.Concurrent.STM +import Control.Monad.Fail +import Data.Aeson ( FromJSON + , ToJSON + ) +import Data.Aeson as Aeson +import qualified Data.ByteString as BS +import qualified Data.ByteString.Lazy as BS + ( toStrict ) +import qualified Data.Map.Strict as Map +import Data.Default +import Data.Either.Combinators ( mapLeft + , mapRight + ) +import Data.Text as Text +import Data.Serialize as Serialize + +import Ssb.Aux +import qualified Ssb.Identity as Ssb +import qualified Ssb.Peer.BoxStream as BoxStream + +data BodyType = Binary | UTF8String | JSON | UnknownBodyType + deriving (Bounded,Eq,Show) + +instance Convertible Word8 BodyType where + convert v = case v .&. 3 of + 0 -> Binary + 1 -> UTF8String + 2 -> JSON + v -> UnknownBodyType + +instance Convertible BodyType Word8 where + convert Binary = 0 + convert UTF8String = 1 + convert JSON = 2 + convert UnknownBodyType = 3 + +data Flags = Flags + { unused1 :: Bool + , unused2 :: Bool + , unused3 :: Bool + , unused4 :: Bool + , isStream :: Bool + , isEndOrError :: Bool + , bodyType :: BodyType + } deriving (Eq,Show) + +instance Convertible Flags Word8 where + convert arg = + set (unused1 arg) 7 + $ set (unused2 arg) 6 + $ set (unused3 arg) 5 + $ set (unused4 arg) 4 + $ set (isStream arg) 3 + $ set (isEndOrError arg) 2 + $ convert (bodyType arg) + where + set True pos arg = setBit arg pos + set False _ arg = arg + +instance Convertible Word8 Flags where + convert w = Flags { unused1 = testBit w 7 + , unused2 = testBit w 6 + , unused3 = testBit w 5 + , unused4 = testBit w 4 + , isStream = testBit w 3 + , isEndOrError = testBit w 2 + , bodyType = convert w + } + +instance Default Flags where + def = convert (zeroBits :: Word8) + +instance Serialize.Serialize Flags where + get = convert <$> getWord8 + put = putWord8 . convert + +-- | ProcedureType defines the type of remote call. +data ProcedureType = Async | Source | Duplex + deriving (Eq,Generic,Ord,Show) + +instance FromJSON ProcedureType where + parseJSON = withText "ProcedureType" $ \v -> case v of + "async" -> return Async + "source" -> return Source + "duplex" -> return Duplex + otherwise -> fail $ "unknown value '" <> toS v <> "'" + +instance ToJSON ProcedureType where + toJSON Async = "async" + toJSON Source = "source" + toJSON Duplex = "duplex" + +-- | HeaderLength is the length of the RPC header in bytes +headerLength :: Int +headerLength = 9 + +-- | bodySizeLength is the length of the bodySize parameter in the RPC header. +bodySizeLength :: Int +bodySizeLength = 4 + +-- | requestNumberLength is the length of the requestNumberLength parameter in +-- the RPC header. +requestNumberLength :: Int +requestNumberLength = 4 + +-- | Header is the first part of a RPC message used for stream control and communication. +data Header = Header + { flags :: Flags + , bodyLength :: Word32 + , requestNumber :: Int32 + } deriving (Eq,Generic,Show) + +instance Serialize.Serialize Header + +-- | GoodByeHeader is the RPC message header signalling the end of the RPC +-- stream. +goodByeHeader :: Header +goodByeHeader = Header (convert (zeroBits :: Word8)) 0 0 + +newHeader :: Flags -> Int32 -> ByteString -> Either Text Header +newHeader flags reqNum body = do + let bodyLength = fromIntegral $ BS.length body + return Header { flags = flags + , bodyLength = bodyLength + , requestNumber = reqNum + } + +-- | MessagePayload describes and contains the contents of an RPC message. +data MessagePayload = BinaryPayload ByteString | TextPayload ByteString | JSONPayload ByteString + deriving (Generic,Show) + +-- | Message is a single message in the RPC stream +data Message = Message + { header :: Header + , body :: MessagePayload + } deriving (Generic,Show) + +instance Serialize.Serialize Message where + get = do + header <- Serialize.get + buf <- Serialize.getByteString (fromIntegral $ bodyLength header) + let typ = bodyType . flags $ header + payload <- case typ of + Binary -> return $ BinaryPayload buf + UTF8String -> return $ TextPayload buf + JSON -> return $ JSONPayload buf + UnknownBodyType -> fail "unknown body type" + return $ Message header payload + + put msg = do + Serialize.put (header msg) + case body msg of + BinaryPayload buf -> Serialize.putByteString buf + JSONPayload buf -> Serialize.putByteString buf + TextPayload buf -> Serialize.putByteString buf + +-- | newJSONMessage is a convenience function to create a RPC message with a +-- JSON payload. +newJSONMessage :: ToJSON a => Int32 -> a -> Either Text Message +newJSONMessage reqNum payload = do + let buf = BS.toStrict $ Aeson.encode payload + len <- maybeWord8 $ BS.length buf + header' <- newHeader (def { bodyType = JSON }) reqNum buf + return $ Message { header = header', body = JSONPayload buf } + +-- | newJSONMessage is a convenience function to decode a RPC JSON message. +decodeJSONMessage :: FromJSON a => Message -> Either Text a +decodeJSONMessage msg = case body msg of + JSONPayload buf -> decodeJSON buf + _ -> error "unexpected body type" + +isRequest :: Int32 -> Message -> Bool +isRequest nextReqNum msg = do + let reqNum = requestNumber . header $ msg + nextReqNum == reqNum + +-- | Request is the message representing a RPC call. +data Request a = Request + { name :: [Text] + , typ :: ProcedureType + , args :: a + } deriving (Show) + +instance (FromJSON a) => FromJSON (Request a) where + parseJSON = withObject "Request" + $ \v -> Request <$> v .: "name" <*> v .:? "type" .!= Async <*> v .: "args" + +instance (ToJSON a) => ToJSON (Request a) where + toJSON arg = + object ["name" .= name arg, "type" .= typ arg, "args" .= args arg] + +data Direction = Incoming | Outgoing + deriving (Eq,Generic,Show) + +data StreamStatus = + -- | Open the stream is active. + Open + -- | The (Async) stream is awaiting a response. + | AwaitingResponse + -- | The stream is closed and waiting for its peer to do the same. + | AwaitingCloseRecv + -- | The stream is closed. + | Closed + deriving (Eq,Show) + +data Stream = Stream { + streamID :: Int32 + , streamType :: ProcedureType + , conn :: ConnState + , direction :: Direction + , status :: StreamStatus + , peer :: Ssb.PublicKey + } + +formatStream :: Stream -> Text +formatStream stream = + (show $ streamID stream :: Text) + <> "\t" + <> show (streamType stream) + <> "\t" + <> show (direction stream) + <> "\t" + <> show (status stream) + +foldStream + :: (a -> MessagePayload -> IO (Either Text a)) + -> a + -> Stream + -> IO (Either Text a) +foldStream fn acc stream = do + msg <- readStream stream + case msg of + (Just msg') -> do + acc' <- fn acc msg' + either (return . error) (\v -> foldStream fn v stream) acc' + Nothing -> return . return $ acc + +foldJSONStream + :: (FromJSON b) + => (a -> b -> IO (Either Text a)) + -> a + -> Stream + -> IO (Either Text a) +foldJSONStream fn acc stream = do + msg <- readStreamJSON stream + case msg of + (Right (Just msg')) -> do + acc' <- fn acc msg' + either (return . error) (\v -> foldJSONStream fn v stream) acc' + (Right Nothing) -> return . return $ acc + (Left err ) -> return $ error err + +data ConnState = ConnState { + connPeer :: Ssb.PublicKey + , boxConn :: BoxStream.Conn + , streamsIn :: TMVar (Map Int32 (Stream, TChan (Maybe Message))) + , streamsOut :: TMVar (Map Int32 (Stream, TChan (Maybe Message))) + , nextIncomingReqNum :: TMVar Int32 + , nextOutgoingReqNum :: TMVar Int32 + , lock :: TMVar Bool + } + +-- | Endpoint identifies which Remote Procedure Call to make. +data Endpoint = Endpoint [Text] ProcedureType + deriving (Eq, Ord) + +formatEndpoint :: Endpoint -> Text +formatEndpoint (Endpoint paths typ) = + Text.intercalate "." paths <> ":" <> show typ + +-- | HandlerFunc's are used to serve Remove Procedure Calls. +type HandlerFunc = Aeson.Value -> Stream -> IO (Either Text ()) + +-- | notFoundHandlerFunc is tells the peer the Endpoint does not exist. +notFoundHandlerFunc :: Endpoint -> HandlerFunc +notFoundHandlerFunc endpoint _ _ = return + (Left $ "endpoint not found '" <> endpoint' <> "'") + where endpoint' = formatEndpoint endpoint + +-- | Handler can serve Remote Procedure Requests. +class Handler h where + endpoints + :: h + -> [Endpoint] + + -- | Serve calls an incoming Remote Procedure Call. + serve + :: h + -> Endpoint + -> HandlerFunc + + -- TODO: Maybe add 'conn' to notifyConnect + + -- | notifyConnect tells the handler when a peer has connected. + notifyConnect + :: h + -> Ssb.PublicKey + -> IO (Either Text ()) + + -- | notifyDisconnect tells the handler when a peer has disconnected. + notifyDisconnect + :: h + -> Ssb.PublicKey + -> IO (Either Text ()) + +-- | Router serves Remote Procedure Calls with a set of handlers. +data Router = Router { + endpointHandlers :: Map Endpoint HandlerFunc + , connectCallbacks :: [Ssb.PublicKey -> IO (Either Text ())] + , disconnectCallbacks :: [Ssb.PublicKey -> IO (Either Text ())] + } deriving (Generic) + +instance Default Router + +instance Handler Router where + endpoints demuxer = Map.keys (endpointHandlers demuxer) + + serve demuxer endpoint = case Map.lookup endpoint endpointHandlers' of + Nothing -> (notFoundHandlerFunc endpoint) + Just handler -> handler + where endpointHandlers' = endpointHandlers demuxer + + notifyConnect demuxer id = do + errs <- fmap lefts . sequence $ (\f -> f id) <$> connectCallbacks demuxer + if Protolude.null errs + then return . return $ () + else return $ Left $ Text.intercalate ", " errs + + notifyDisconnect demuxer id = do + errs <- fmap lefts . sequence $ (\f -> f id) <$> connectCallbacks demuxer + if Protolude.null errs + then return . return $ () + else return $ Left $ Text.intercalate ", " errs + + +-- | with adds a handler to the Router. +with :: Handler h => Router -> h -> Router +with demuxer handler = Router + { endpointHandlers = Map.union (endpointHandlers demuxer) $ Map.fromList + ((\e -> (e, serve handler e)) <$> endpoints handler) + , connectCallbacks = connectCallbacks demuxer <> [notifyConnect handler] + , disconnectCallbacks = disconnectCallbacks demuxer + <> [notifyDisconnect handler] + } + +-- | withM is a convenience function for using 'with' in monads. +withM :: (MonadIO m, Handler h) => m Router -> m h -> m Router +withM demuxer handler = with <$> demuxer <*> handler + +logMsg :: ConnState -> Text -> IO () +logMsg conn msg = do + _ <- atomically $ takeTMVar (lock conn) + print msg + atomically $ putTMVar (lock conn) True + +logDebug :: ConnState -> Text -> IO () +logDebug _ _ = return () +-- logDebug = logMsg + +-- | spawnConnection handles safely forking and closing RPC connections. +spawnConnection + :: Stream + -> IO (Either Text ()) + -> IO () +spawnConnection stream action = do + forkFinally + (do + res <- action + either (closeStreamWithError stream) (\_ -> closeStream stream) res + ) + (\_ -> void (closeStream stream)) + return () + +-- | connect creates a Remote Procdure Call stream connection over the give Box +-- Stream. +connect + :: Handler h + => BoxStream.Conn + -> h + -> Ssb.PublicKey + -> (ConnState -> IO ()) + -> IO (Either Text ()) +connect boxConn handler peer client = do + streamsIn <- newTMVarIO Map.empty + streamsOut <- newTMVarIO Map.empty + nextIncomingReqNum <- newTMVarIO 1 + nextOutgoingReqNum <- newTMVarIO 1 + lock <- newTMVarIO True + let conn = ConnState { connPeer = peer + , boxConn = boxConn + , streamsIn = streamsIn + , streamsOut = streamsOut + , nextIncomingReqNum = nextIncomingReqNum + , nextOutgoingReqNum = nextOutgoingReqNum + , lock = lock + } + + forkIO $ + -- make RPC calls on the peer + client conn + print "entering service loop" + ret <- serviceLoop conn + print "out of service loop" + _ <- notifyDisconnect handler (connPeer conn) + print "disconnecting" + disconnect conn + print "disconnected" + return ret + where + serviceLoop conn = do + msg <- readMessage conn + case msg of + Left err -> return $ error ("connection error: " <> err) + Right msg -> if header msg == goodByeHeader + then return . return $ () + else do + ok <- handleMessage handler conn msg + if ok then serviceLoop conn else return . return $ () + +-- TODO: Handle stream closing within demux loop +-- The handleMessage loop is responsible for the lifetime of a stream +-- connection. It detects and creates requests, it should therefor handle +-- close requests. +-- +-- The stream type matters when closing a stream. Async do not require a +-- specific close message. +-- TODO: Handle closing of Async streams +-- TODO: Properly handle Error message + +-- TODO: Refactor or reallocate helper functions +-- These functions were added in a rush to get request handling working. A +-- better place could probably be found for them to improve code clarity. + +-- | isEndOfStream checks whether the message is the last of the stream. +isEndOfStream = isEndOrError . flags . header + +-- | checks if the message is a response to an Async request. +isAsyncResponse = not . isStream . flags . header + +-- | streamTable gets which lookup table to use for connections for the given +-- stream. There are seperate ones for incoming and outgoing RPC streams. +streamTable stream = case direction stream of + Incoming -> streamsIn (conn stream) + Outgoing -> streamsOut (conn stream) + +streamStatus :: Stream -> STM StreamStatus +streamStatus stream = do + table <- readTMVar $ streamTable stream + return + $ fromMaybe Closed (status . fst <$> Map.lookup (streamID stream) table) + +-- | manageStreamConn manages stream connection changes when writing messages +-- to a stream. +-- TODO: beware of entering a deadlock in manageStreamConn +-- TODO: request stuff should be moved here +manageStreamConn :: Stream -> Message -> IO (Either Text ()) +manageStreamConn stream msg = do + table <- atomically $ takeTMVar $ streamTable stream + res <- writeMessage (conn stream) msg + let table' = if + | isLeft res -> closeStream table stream + | (not . isStream . flags . header $ msg) -> Map.adjust + (updateStreamStatus AwaitingResponse) + (streamID stream) + table + | +-- TODO: confirm status of stream before closing it +-- This happy path is not production ready + (isEndOrError . flags . header $ msg) -> Map.adjust + (updateStreamStatus AwaitingCloseRecv) + (streamID stream) + table + | otherwise -> table + when (isLeft res) $ do + let (_, c) = fromMaybe undefined $ Map.lookup (streamID stream) table + -- TODO: Find better way of closing channels on write error + atomically $ forM_ [1 .. 30] (\_ -> writeTChan c Nothing) + atomically $ putTMVar (streamTable stream) table' + return res + where + updateStreamStatus status (s, c) = (s { status = status }, c) + closeStream table tream = case direction stream of + Incoming -> Map.adjust (updateStreamStatus Closed) (streamID stream) table + Outgoing -> Map.delete (streamID stream) table + + +-- TODO: Log important information for incoming messages +handleMessage :: Handler h => h -> ConnState -> Message -> IO Bool +handleMessage handler conn msg = do + nextReqNum <- atomically $ readTMVar (nextIncomingReqNum conn) + if + | (header msg) == goodByeHeader -> return False + | (isRequest nextReqNum msg) -> serveRequest handler conn msg + | (isEndOfStream msg) -> internalCloseStream conn msg + | (isAsyncResponse msg) -> do + demux conn msg + internalCloseStream conn msg + | otherwise -> demux conn msg + where + internalCloseStream conn msg = do + let flags = def { isStream = True, isEndOrError = True } + let requestNumber' = (requestNumber . header $ msg) + let mVarTable = + if requestNumber' > 0 then streamsIn conn else streamsOut conn + table <- atomically $ takeTMVar mVarTable + let streamID' = abs requestNumber' + let (stream, chan) = fromMaybe undefined $ Map.lookup streamID' table + newStatus <- case status stream of + Open -> do + let msg = fromRight undefined $ newCloseNotification streamID' + _ <- writeMessage conn msg + return AwaitingCloseRecv + AwaitingCloseRecv -> do + atomically $ writeTChan chan Nothing + return Closed + AwaitingResponse -> do + atomically $ writeTChan chan Nothing + return Closed + Closed -> do + logMsg conn + $ "received end of stream for already closed stream: " + <> (show streamID' :: Text) + return Closed + atomically $ putTMVar mVarTable $ case direction stream of + Incoming -> Map.adjust (updateStreamStatus newStatus) streamID' table + -- TODO: properly handle close of outgoing Duplex streams + Outgoing -> Map.delete streamID' table + return True + where updateStreamStatus status (s, c) = (s { status = status }, c) + + serveRequest handler conn msg = do + let req = decodeJSONMessage msg + case req of + Left err -> do + logMsg conn $ "could not decode request: " <> err + return False + Right req -> do + let stream = Stream { streamID = requestNumber . header $ msg + , streamType = typ req + , conn = conn + , direction = Incoming + , status = Open + , peer = connPeer conn + } + chan <- newTChanIO + err <- atomically $ do + reqNum <- takeTMVar (nextIncomingReqNum conn) + table <- takeTMVar (streamsIn conn) + if Map.size table == (maxBound :: Int) + then return $ Left errTooManyRequests + else do + let (reqNum', table') = if streamID stream == reqNum + then (reqNum + 1, (Map.insert reqNum (stream, chan) table)) + else (reqNum, table) + putTMVar (nextIncomingReqNum conn) reqNum' + putTMVar (streamsIn conn) table' + return . return $ () + case err of + Left msg -> do + print msg + return False + Right _ -> do + -- Serving a request call, how do we close it nicely? + let endpoint = Endpoint (name req) (typ req) + spawnConnection stream $ (serve handler) endpoint (args req) stream + return True + where errTooManyRequests = "connection limit reached, dropping request" + + demux conn msg = do + let reqNum = requestNumber . header $ msg + let table = if reqNum > 0 then streamsIn conn else streamsOut conn + let streamID' = abs reqNum + table' <- atomically $ readTMVar table + case Map.lookup streamID' table' of + Nothing -> do + logDebug conn + $ "message dropped due to missing stream: " + <> (show msg :: Text) + return () + Just (_, chan) -> atomically $ writeTChan chan (Just msg) + return True + +-- | disconnect is a hacked version, it rudely disconnects all connections in +-- order to unlock streams stuck on their reading their channels. +-- +-- Look into closing streams elegantly. +-- TODO: Ensure disconnect prevents new incoming and outgoing connections +disconnect :: ConnState -> IO () +disconnect conn = do + table <- atomically $ takeTMVar (streamsIn conn) + forM_ table + $ \(_, c) -> atomically $ forM_ [1 .. 30] (\_ -> writeTChan c Nothing) + let table' = Map.map (\(s, c) -> (s { status = Closed }, c)) table + atomically $ putTMVar (streamsIn conn) table' + + table <- atomically $ takeTMVar (streamsOut conn) + forM_ table + $ \(_, c) -> atomically $ forM_ [1 .. 30] (\_ -> writeTChan c Nothing) + let table' = Map.map (\(s, c) -> (s { status = Closed }, c)) table + atomically $ putTMVar (streamsOut conn) table' + + _ <- writeMessage conn $ Message goodByeHeader (BinaryPayload "") + return () + +-- | readBoxStream reads the given number of bytes from the Box Stream. +readBoxStream :: BoxStream.Conn -> Int -> IO (Either Text ByteString) +readBoxStream conn bytes = do + mbuf <- BoxStream.readStream conn bytes + let buf = join $ (maybeToRight errUnexpectedClose) <$> mbuf + return buf + where errUnexpectedClose = "rpc.readConn: unexpected end of box stream" + + +-- | readMessage reads a single message from the RPC connection. +readMessage :: ConnState -> IO (Either Text Message) +readMessage conn = do + headerBuf <- readBoxStream (boxConn conn) headerLength + header <- (return $ headerBuf >>= decode) + case header of + Left err -> return $ Left $ "RPC.readMessage header: " <> err + Right header -> do + bodyBuf <- readBoxStream (boxConn conn) (fromIntegral $ bodyLength header) + let ret = liftA2 (<>) headerBuf bodyBuf >>= decode + logDebug conn + $ "readMessage (" + <> (Ssb.formatPublicKey $ connPeer conn) + <> ")" + <> (show ret) + return $ ret + where + decode :: Serialize a => ByteString -> Either Text a + decode = mapLeft toS . Serialize.decode + +writeMessage :: ConnState -> Message -> IO (Either Text ()) +writeMessage conn msg = do + logDebug conn + $ "writeMessage (" + <> (Ssb.formatPublicKey $ connPeer conn) + <> ")" + <> (show msg) + BoxStream.sendStream (boxConn conn) $ Serialize.encode msg + +-- | request makes a Remote Procedure Call on the peer. +request + :: ToJSON a + => ConnState + -> Request a + -> (Stream -> IO (Either Text b)) + -> IO (Either Text b) +request conn req session = do + reqNum <- atomically $ takeTMVar (nextOutgoingReqNum conn) + let msg = newJSONMessage reqNum req + case msg of + Left err -> do + atomically $ putTMVar (nextOutgoingReqNum conn) reqNum + return $ Left err + Right msg -> do + let nextReqNum = reqNum + 1 + atomically $ putTMVar (nextOutgoingReqNum conn) nextReqNum + + -- TODO: Fix late night update + let flags' = flags . header $ msg + let + msg' = msg + { header = (header msg) { flags = flags' + { isStream = typ req /= Async + } + } + } + + let stream = Stream { streamID = reqNum + , streamType = (typ req) + , conn = conn + , direction = Outgoing + , status = Open + , peer = connPeer conn + } + atomically $ do + streams <- takeTMVar (streamsOut conn) + chan <- newTChan + putTMVar (streamsOut conn) $ Map.insert reqNum (stream, chan) streams + + _ <- writeMessage' stream msg' + result <- session stream + if streamType stream == Async + then return . return $ () + else either (closeStreamWithError stream) + (\x -> closeStream stream) + result + return result + where writeMessage' = manageStreamConn + +-- | requestAsync is the Async version of request. +requestAsync + :: (ToJSON a, FromJSON b) => ConnState -> Request a -> IO (Either Text b) +requestAsync conn req = do + resp <- request conn req readStreamJSON + return $ resp >>= withErr "not response received" + +-- TODO: Use CloseNotification to signal close of stream. +-- How is the JSON encoding done? +data CloseNotification = CloseNotification () + +newCloseNotification :: Int32 -> Either Text Message +newCloseNotification streamID = do + let payload = JSONPayload "true" + let requestNumer = -1 * streamID + header <- newHeader + (def { isEndOrError = True, isStream = True, bodyType = JSON }) + requestNumer + "true" + return $ Message header payload + +-- | closeStream politely shuts down an RPC stream. +-- TODO: closeStream needs to be more adaptive. It only accounts for +-- real streams. +closeStream :: Stream -> IO (Either Text ()) +closeStream stream = case (status stream, streamType stream) of + (_ , Async) -> return . return $ () + (Open, _ ) -> either (return . Left) (manageStreamConn stream) + $ newCloseNotification (streamID stream) + (Closed, _) -> return . return $ () + (_ , _) -> return $ Left "could not close stream" + +-- ErrorNotification is the format used to communicate a stream error between peers. +data ErrorNotification = ErrorNotification { + message :: Text + , stack :: Maybe Text + } + +instance FromJSON ErrorNotification where + parseJSON = withObject "Error" + $ \v -> ErrorNotification <$> v .: "message" <*> v .: "stack" + +instance ToJSON ErrorNotification where + toJSON arg = object + [ "name" .= ("error" :: Text) + , "message" .= message arg + , "stack" .= stack arg + ] + +-- TODO: deduplicate closeStreamWithError and closeStream code +-- TODO: Verify close operation on error +-- Does the remote end send an 'true' message to confirm? Or is the +-- connection simply dropped. +closeStreamWithError' :: Stream -> Text -> IO (Either Text ()) +closeStreamWithError' stream err = do + if streamType stream == Duplex || direction stream == Incoming + then do + let msg = ErrorNotification err Nothing + let flags = def { isStream = True, isEndOrError = True } + err <- writeStream stream + flags + (JSONPayload $ BS.toStrict $ Aeson.encode msg) + either + (\err -> + logMsg (conn stream) $ "could not notify remote of error: " <> err + ) + return + err + else logMsg (conn stream) $ "closing stream with error: " <> err + atomically $ do + let table = streamTable stream + value <- takeTMVar table + putTMVar table $ Map.delete (streamID stream) value + return . return $ () + where + streamTable stream = case direction stream of + Incoming -> streamsIn $ conn stream + Outgoing -> streamsOut $ conn stream + +-- TODO: Merge newCloseErrorNotification with newCloseNotification It has +-- proven difficult to use the ToJSON typeclass as an argument. +newCloseErrorNotification :: Int32 -> Text -> Either Text Message +newCloseErrorNotification streamID msg = do + let payload = JSONPayload $ BS.toStrict $ Aeson.encode msg + let requestNumer = -1 * streamID + header <- newHeader + (def { isEndOrError = True, isStream = True, bodyType = JSON }) + requestNumer + "true" + return $ Message header payload + +closeStreamWithError :: Stream -> Text -> IO (Either Text ()) +closeStreamWithError stream err = case streamType stream of + Async -> return . return $ () + otherwise -> + either (return . Left) (manageStreamConn stream) + $ newCloseErrorNotification (streamID stream) err + +-- TODO: Properly translate RPC errors +-- TODO: Close stream within lock to avoid race conditions + +readStream :: Stream -> IO (Maybe MessagePayload) +readStream stream = do + let table = streamTable stream + table' <- atomically $ readTMVar table + let result = Map.lookup (streamID stream) table' + msg <- case result of + Nothing -> return Nothing + Just (stream', chan') -> if status stream' == Closed + then atomically $ join <$> tryReadTChan chan' + else atomically $ readTChan chan' + table'' <- atomically $ takeTMVar table + let table''' = case msg of + Nothing -> Map.delete (streamID stream) table' + Just _ -> table'' + atomically $ putTMVar table table''' + return (body <$> msg) + +-- | readStreamJSON reads a single JSON message from the RPC stream. +readStreamJSON :: FromJSON a => Stream -> IO (Either Text (Maybe a)) +readStreamJSON stream = do + resp <- readStream stream + case resp of + (Just (JSONPayload buf)) -> return (Just <$> decodeJSON buf) + (Just otherwise ) -> return (errPayload resp) + Nothing -> return . return $ Nothing + where + errPayload payload = + error ("expected JSONPayload but got " <> (show payload :: Text)) + errEOS = error "end of stream" + +writeStream :: Stream -> Flags -> MessagePayload -> IO (Either Text ()) +writeStream stream flags payload = do + status <- atomically $ streamStatus stream + if status /= Open + then return $ Left "stream closed" + else do + let msgID = case direction stream of + Incoming -> -1 * streamID stream + Outgoing -> streamID stream + let flags' = flags { isStream = not $ (streamType stream) == Async } + let header' = case payload of + (BinaryPayload p) -> + newHeader (flags' { bodyType = Binary }) msgID p + (TextPayload p) -> + newHeader (flags' { bodyType = UTF8String }) msgID p + (JSONPayload p) -> newHeader (flags' { bodyType = JSON }) msgID p + case header' of + Left err -> return $ Left err + Right header' -> do + let msg = Message header' payload + writeMessage' stream msg + where writeMessage' = manageStreamConn + +-- | writeStreamJSON writes a single JSON message to the RPC stream. +writeStreamJSON :: ToJSON a => Stream -> a -> IO (Either Text ()) +writeStreamJSON stream msg = do + let buf = Aeson.encode msg + writeStream stream def (JSONPayload $ BS.toStrict buf) diff --git a/src/Ssb/Peer/RPC/Gossip.hs b/src/Ssb/Peer/RPC/Gossip.hs new file mode 100644 index 0000000..b97e4a4 --- /dev/null +++ b/src/Ssb/Peer/RPC/Gossip.hs @@ -0,0 +1,139 @@ +-- | This module implements Scuttlebutt's Remote Procedure Call for +-- CreateHistoryStream. +-- +-- For more information kindly refer the to protocol guide +-- https://ssbc.github.io/scuttlebutt-protocol-guide + +module Ssb.Peer.RPC.Gossip where + +import Protolude hiding ( Identity ) + +import Control.Concurrent.STM +import qualified Data.Aeson as Aeson +import Data.Aeson ( FromJSON + , ToJSON + ) +import Data.Default +import qualified Data.Map.Strict as Map + +import Ssb.Aux +import qualified Ssb.Feed as Feed +import qualified Ssb.Peer.RPC as RPC + + +-- | TODO: Comment +-- | TODO: Naming +-- | TODO: Keyed Message support +-- | TODO: Proper Request default values setting +data Request = Request + { id :: Feed.FeedID + , sequence :: Maybe Int + , limit :: Maybe Int + , live :: Maybe Bool + , old :: Maybe Bool + , keys :: Bool + } deriving (Generic,Show) + +newRequest :: Feed.FeedID -> Request +newRequest id = Request { id = id + , sequence = Nothing + , limit = Nothing + , live = Just False + , old = Just False + , keys = True + } + + +instance FromJSON Request + +instance ToJSON Request where + toJSON = Aeson.genericToJSON (Aeson.defaultOptions {Aeson.omitNothingFields = True}) + +-- TODO: reduce friction for introducing RPC requests + +createHistoryStreamRequest :: Request -> RPC.Request [Request] +createHistoryStreamRequest req = RPC.Request + { RPC.name = ["createHistoryStream"] + , RPC.typ = RPC.Source + , RPC.args = [req] + } + +createHistoryStream + :: FromJSON b + => RPC.ConnState + -> Request + -> a + -> (a -> Feed.VerifiableMessage b -> IO (Either Text a)) + -> IO (Either Text a) +createHistoryStream conn req init cmd = RPC.request + conn + (createHistoryStreamRequest req) + (RPC.foldStream cmd' init) + where + cmd' a payload = case payload of + RPC.JSONPayload buf -> do + msg <- Feed.decodeJSONVerifiableMessage buf + either (return . error) (cmd a) msg + v@otherwise -> return $ error "expected JSON but got something else" + +data KeyedMessage a = KeyedMessage + { key :: Feed.MessageID + , timestamp :: Int + , value :: Feed.Message a + } deriving (Generic,Show) + +instance (FromJSON a) => FromJSON (KeyedMessage a) + +instance (ToJSON a) => ToJSON (KeyedMessage a) + +--createKeyedHistoryStream +-- :: RPC.ConnState +-- -> Request +-- -> (KeyedMessage -> IO (Either Text a)) +-- -> IO (Either Text a) +--createHistoryStream conn req = +-- RPC.request conn (createHistoryStreamRequest req) +-- + +newtype Gossiper a = Gossiper (TMVar (Feed.Feeds a)) + +newGossiper :: ToJSON a => IO (Gossiper a) +newGossiper = do + mVar <- newTMVarIO Feed.emptyFeeds + return $ Gossiper mVar + +addFeed :: ToJSON a => Gossiper a -> Feed.Feed a -> IO () +addFeed (Gossiper (mFeeds)) feed = do + atomically $ do + feeds <- takeTMVar mFeeds + putTMVar mFeeds (Feed.insert feed feeds) + +writeFeed :: ToJSON a => RPC.Stream -> Feed.Feed a -> IO (Either Text ()) +writeFeed stream (Feed.Feed _ msgs) = do + return <$> forM_ + msgs + (\msg -> do + let msg' = Feed.encodeJSONVerifiableMessage msg + err <- RPC.writeStream stream def (RPC.JSONPayload msg') + either (\err -> print err) (\_ -> return ()) err + ) + +instance ToJSON a => RPC.Handler (Gossiper a) where + endpoints h = [RPC.Endpoint ["createHistoryStream"] RPC.Source] + + serve (Gossiper mFeeds) (RPC.Endpoint ["createHistoryStream"] RPC.Source) arg stream + = do + feeds <- atomically $ readTMVar mFeeds + let req = decodeJSON (encodeJSON arg) :: Either Text [Request] + case req of + Left err -> do + return . return $ () + Right [] -> return . return $ () + Right [arg] -> do + case Feed.lookup (id arg) feeds of + Just feed -> writeFeed stream feed + Nothing -> return . return $ () + + notifyConnect _ _ = return . return $ () + + notifyDisconnect _ _ = return . return $ () diff --git a/src/Ssb/Peer/RPC/Room.hs b/src/Ssb/Peer/RPC/Room.hs new file mode 100644 index 0000000..47bc956 --- /dev/null +++ b/src/Ssb/Peer/RPC/Room.hs @@ -0,0 +1,330 @@ +-- | This module implements Scuttlebutt's Remote Procedure Call for +-- Rooms. +-- +-- For more information kindly refer [WHERE] + +-- TODO: Documentation for SSB-Room + +module Ssb.Peer.RPC.Room where + + +import Protolude hiding ( Identity ) +import qualified Data.Aeson as Aeson +import Control.Concurrent.STM +import Data.Default +import qualified Data.Map.Strict as Map +import Data.Time.Clock ( UTCTime + , getCurrentTime + ) +import qualified Data.Text as Text + +import Ssb.Aux +import qualified Ssb.Identity as Ssb +import qualified Ssb.Pub as Ssb +import Ssb.Network +import qualified Ssb.Feed as Feed +import qualified Ssb.Peer.RPC as RPC + +seed :: Text +seed = "SSB+Room+PSK3TLYC2T86EHQCUHBUHASCASE18JBV24=" + +data Invite = Invite + { host :: Host + , port :: Port + , key :: Ssb.PublicKey + } deriving (Eq, Show) + +formatInvite :: Invite -> Text +formatInvite arg = + "net:" + <> host arg + <> ":" + <> port arg + <> "~shs" + <> ":" + <> formatPublicKey' arg + <> ":" + <> seed + where + formatPublicKey' arg = + Text.dropEnd 8 $ Text.drop 1 $ Ssb.formatPublicKey $ key arg + +-- TODO: Seriously, figure out how to use duplicate field names +data Room = Room { + endpoints :: TMVar (Map Ssb.PublicKey (RPC.ConnState, [Tunnel])) + , notifyChange :: TChan Bool + , roomName :: Text + , roomDesc :: Text + , tunnels :: TMVar (Map Tunnel [ThreadId]) + } + +newRoom :: Text -> Text -> IO Room +newRoom name desc = do + endpoints <- newTMVarIO Map.empty + notifier <- newBroadcastTChanIO + tunnels <- newTMVarIO Map.empty + return $ Room endpoints notifier name desc tunnels + +getEndpoints :: Room -> IO [Ssb.PublicKey] +getEndpoints h = do + let mVar = endpoints h + Map.keys <$> atomically (readTMVar mVar) + +lookUpPeer :: Room -> Ssb.PublicKey -> IO (Maybe RPC.ConnState) +lookUpPeer h peer = do + endpoints' <- atomically $ readTMVar (endpoints h) + return $ fst <$> Map.lookup peer endpoints' + +errConnLimitReached = "peer limit reached" + +registerPeer :: Room -> RPC.Stream -> IO (Either Text ()) +registerPeer h stream = do + let mVar = endpoints h + atomically $ do + endpoints' <- takeTMVar mVar + if Map.size endpoints' == (maxBound :: Int) + then return $ Left errConnLimitReached + else do + putTMVar mVar + $ Map.insert (RPC.peer stream) (RPC.conn stream, []) endpoints' + writeTChan (notifyChange h) True + return $ Right () + +unregisterPeer :: Room -> Ssb.PublicKey -> IO () +unregisterPeer h peer = do + let mVar = endpoints h + atomically $ do + endpoints' <- takeTMVar mVar + putTMVar mVar $ Map.delete peer endpoints' + writeTChan (notifyChange h) True + +data IsRoomResponse = IsRoomResponse { + name :: Text + , description :: Text + } deriving (Generic, Show) + +instance Aeson.FromJSON IsRoomResponse + +instance Aeson.ToJSON IsRoomResponse + +newIsRoomResponse :: Room -> IsRoomResponse +newIsRoomResponse l = IsRoomResponse (roomName l) (roomDesc l) + +announceRequest :: RPC.Request [Text] +announceRequest = + RPC.Request { name = ["tunnel", "announce"], typ = RPC.Async, args = [] } + +announce :: RPC.ConnState -> IO (Either Text ()) +announce conn = RPC.requestAsync conn announceRequest + +data ConnectRequest = ConnectRequest { + target :: Ssb.PublicKey + , portal :: Ssb.PublicKey + } deriving (Generic, Show) + +instance Aeson.FromJSON ConnectRequest + +instance Aeson.ToJSON ConnectRequest + +leaveRequest :: RPC.Request [Text] +leaveRequest = + RPC.Request { name = ["tunnel", "leave"], typ = RPC.Async, args = [] } + +leave :: RPC.ConnState -> IO (Either Text ()) +leave conn = RPC.requestAsync conn leaveRequest + +pingRequest :: RPC.Request [Text] +pingRequest = + RPC.Request { name = ["tunnel", "ping"], typ = RPC.Async, args = [] } + +ping :: RPC.ConnState -> IO (Either Text UTCTime) +ping conn = RPC.requestAsync conn pingRequest + +-- | fork creates a new thread, incrementing the counter in the given mVar. +fork mVar action = do + fork' <- newEmptyTMVarIO + atomically $ do + forks <- takeTMVar mVar + putTMVar mVar $ fork' : forks + forkFinally + action + (\_ -> do + print "exiting fork" + atomically $ putTMVar fork' () + ) + +-- | wait returns when all forks have completed. +waitForkGroup mVar threads = do + cs <- atomically $ takeTMVar mVar + case cs of + [] -> return () + m : ms -> do + atomically $ putTMVar mVar ms + atomically $ takeTMVar m + forM_ threads killThread + waitForkGroup mVar [] + +forwardMessages :: RPC.Stream -> RPC.Stream -> IO (Either Text ()) +forwardMessages s1 s2 = do + mMsg <- RPC.readStream s1 + case mMsg of + Nothing -> return $ Right () + Just msg -> do + res <- RPC.writeStream s2 def msg + case res of + Left err -> return $ Left err + Right _ -> forwardMessages s1 s2 + +type Tunnel = (Ssb.PublicKey, Ssb.PublicKey) + +newTunnel arg1 arg2 = + -- the tunnel's entries need to be ordered to make pairs unique + if arg1 < arg2 then (arg1, arg2) else (arg2, arg1) + +createTunnel :: Room -> (RPC.Stream, RPC.Stream) -> IO (Either Text ()) +createTunnel room (stream1, stream2) = do + let peer1 = RPC.connPeer $ RPC.conn stream1 + let peer2 = RPC.connPeer $ RPC.conn stream2 + --print + -- $ "creating tunnel for " + -- <> Ssb.formatPublicKey peer1 + -- <> " <-> " + -- <> Ssb.formatPublicKey peer2 + let tunnel = newTunnel peer1 peer2 + + exists <- atomically $ do + tunnels' <- readTMVar (tunnels room) + return $ Map.member tunnel tunnels' + if exists + then return $ Left "only one tunnel allowed" + else do + waiter <- newTMVarIO [] + thread1 <- fork waiter $ forwardMessages stream1 stream2 + thread2 <- fork waiter $ forwardMessages stream2 stream1 + let threads = [thread1, thread2] + + atomically $ do + endpoints' <- takeTMVar (endpoints room) + let endpoints'' = Map.adjust (addTunnel tunnel) (fst tunnel) endpoints' + let endpoints''' = + Map.adjust (addTunnel tunnel) (snd tunnel) endpoints'' + putTMVar (endpoints room) endpoints''' + + tunnels' <- takeTMVar (tunnels room) + let tunnels'' = Map.insert tunnel threads tunnels' + putTMVar (tunnels room) tunnels'' + + waitForkGroup waiter threads + + atomically $ do + endpoints' <- takeTMVar (endpoints room) + let endpoints'' = + Map.adjust (removeTunnel tunnel) (fst tunnel) endpoints' + let endpoints''' = + Map.adjust (removeTunnel tunnel) (snd tunnel) endpoints'' + putTMVar (endpoints room) endpoints''' + + tunnels' <- takeTMVar (tunnels room) + putTMVar (tunnels room) $ Map.delete tunnel tunnels' + + return . return $ () + where + addTunnel arg (conn, tunnels) = (conn, tunnels ++ [arg]) + removeTunnel arg (conn, tunnels) = (conn, filter (arg /=) tunnels) + +connect :: Room -> RPC.Stream -> ConnectRequest -> IO (Either Text ()) +connect room stream req = do + let tunnel = newTunnel (RPC.connPeer $ RPC.conn $ stream) (target req) + tunnels' <- atomically $ readTMVar (tunnels room) + mPeer <- lookUpPeer room (target req) + let peerConn = do + bool (return ()) errSelfNotAllowed (fst tunnel == snd tunnel) + bool (return ()) errOnlyUniqueTunnel (Map.member tunnel tunnels') + maybeToRight errPeerNotAvailable mPeer + case peerConn of + Left err -> do + print $ "errorr! " <> err + return $ Left err + Right conn -> RPC.request conn (rpcRequest req) + $ \peerStream -> createTunnel room (stream, peerStream) + where + errOnlyUniqueTunnel = Left $ "only unique tunnels are allowed" + errPeerNotAvailable = "peer is not connected" :: Text + errSelfNotAllowed = Left $ "connecting to self not allowed" + rpcRequest arg = + RPC.Request { name = ["tunnel", "connect"], typ = RPC.Duplex, args = [arg] } + +leave' :: Room -> Ssb.PublicKey -> IO (Either Text ()) +leave' room peer = do + endpoints' <- atomically $ readTMVar (endpoints room) + tunnels' <- atomically $ readTMVar (tunnels room) + + let etunnels = fromMaybe mempty $ snd <$> Map.lookup peer endpoints' + forM_ etunnels $ \tunnel -> do + let threads = fromMaybe mempty $ Map.lookup tunnel tunnels' + forM_ threads killThread + unregisterPeer room peer + return . return $ () + +instance RPC.Handler Room where + endpoints _ = + [ RPC.Endpoint ["tunnel", "announce"] RPC.Async + , RPC.Endpoint ["tunnel", "connect"] RPC.Duplex + , RPC.Endpoint ["tunnel", "endpoints"] RPC.Async + , RPC.Endpoint ["tunnel", "leave"] RPC.Async + , RPC.Endpoint ["tunnel", "isRoom"] RPC.Async + , RPC.Endpoint ["tunnel", "ping"] RPC.Async + ] + + -- The announce and leave endpoints are defined by the server's JS, but + -- never called by the client. The official (NodeJS) project uses + -- disconnect notifications from the SSB-server. + -- + serve h (RPC.Endpoint ["tunnel", "announce"] _) args stream = + registerPeer h stream + + -- should decode request + serve h (RPC.Endpoint ["tunnel", "connect"] RPC.Duplex) args stream = do + let args' = + decodeJSON (toS $ Aeson.encode args) :: Either + Text + [ConnectRequest] + case args' of + Left err -> return $ Left err + Right [connReq] -> connect h stream connReq + otherwise -> return $ Left "bad target argument" + + serve h (RPC.Endpoint ["tunnel", "endpoints"] _) _ stream = do + err <- registerPeer h stream + case err of + Left msg -> return $ Left msg + Right _ -> do + change <- atomically $ dupTChan $ notifyChange h + while $ do + endpoints' <- getEndpoints h + let resp = filter (/= RPC.peer stream) endpoints' + res <- RPC.writeStreamJSON stream resp + if isLeft res then return False else atomically $ readTChan change + return $ Right () + where + while f = do + continue <- f + if continue then (while f) else return False + + serve room (RPC.Endpoint ["tunnel", "leave"] _) _ stream = + leave' room (RPC.peer stream) + + serve room (RPC.Endpoint ["tunnel", "isRoom"] _) _ stream = + RPC.writeStreamJSON stream (newIsRoomResponse room) + + serve room (RPC.Endpoint ["tunnel", "ping"] _) _ stream = do + resp <- getCurrentTime + RPC.writeStreamJSON stream resp + + serve room endpoint@otherwise arg stream = (RPC.notFoundHandlerFunc endpoint) arg stream + + notifyConnect _ _ = return . return $ () + + notifyDisconnect room peer = do + _ <- leave' room peer + return . return $ () diff --git a/src/Ssb/Peer/RPC/WhoAmI.hs b/src/Ssb/Peer/RPC/WhoAmI.hs new file mode 100644 index 0000000..be979fc --- /dev/null +++ b/src/Ssb/Peer/RPC/WhoAmI.hs @@ -0,0 +1,49 @@ +-- | This module implements Scuttlebutt's Remote Procedure Call for +-- Ping. +-- +-- For more information kindly refer [WHERE] + +-- TODO: Update above documentation + +module Ssb.Peer.RPC.WhoAmI where + +import Protolude hiding ( Identity ) +import Data.Aeson as Aeson (FromJSON,ToJSON) + +import qualified Ssb.Identity as Ssb +import qualified Ssb.Feed as Feed +import qualified Ssb.Peer.RPC as RPC + +whoAmIRequest :: RPC.Request [Text] +whoAmIRequest = RPC.Request { + name = ["whoami"] + , typ = RPC.Async + , args = [] + } + +newtype WhoAmIResponse = WhoAmIResponse + { id :: Feed.FeedID + } deriving (Eq,Generic,Show) + +instance FromJSON WhoAmIResponse + +instance ToJSON WhoAmIResponse + +whoAmI + :: RPC.ConnState + -> IO (Either Text WhoAmIResponse) +whoAmI conn = RPC.requestAsync conn whoAmIRequest + +newtype Handler = Handler () + +newHandler :: Handler +newHandler = Handler () + +instance RPC.Handler Handler where + endpoints h = [RPC.Endpoint ["whoami"] RPC.Async] + + serve (Handler ssbID) (RPC.Endpoint ["whoami"] RPC.Async) _ stream = + RPC.writeStreamJSON stream (WhoAmIResponse $ Feed.FeedID (RPC.peer stream)) + + notifyConnect _ _ = return . return $ () + notifyDisconnect _ _ = return . return $ () diff --git a/src/Ssb/Peer/SecretHandshake.hs b/src/Ssb/Peer/SecretHandshake.hs new file mode 100644 index 0000000..4245542 --- /dev/null +++ b/src/Ssb/Peer/SecretHandshake.hs @@ -0,0 +1,695 @@ +-- | This module implements Scuttlebutt's Secret Handshake. +-- +-- For more information kindly refer the to protocol guide +-- https://ssbc.github.io/scuttlebutt-protocol-guide + +-- | TODO: Take care of possible import loop +-- | TODO: Optimize handling of PublicKey (extractPublicKey) + +module Ssb.Peer.SecretHandshake where + +import Protolude hiding ( Identity ) +import qualified Data.ByteString as BS +import Data.Default +import qualified Crypto.Hash.SHA256 as SHA256 +import qualified Crypto.Saltine.Core.ScalarMult + as ScalarMult +import qualified Crypto.Saltine.Class as Nacl +import qualified Crypto.Saltine.Core.Auth as Auth +import qualified Crypto.Saltine.Core.Box as Box +import qualified Crypto.Saltine.Core.SecretBox as SecretBox +import qualified Crypto.Saltine.Core.Sign as Sign + +import Ssb.Network +import Ssb.Identity +import qualified Sodium + +-- | ChallengeLength is the length of a challenge message in bytes +challengeLength :: Int +challengeLength = 64 + +-- | ClientAuthLength is the length of a clientAuth message in bytes +clientAuthLength :: Int +clientAuthLength = 16 + 32 + 64 + +-- | ServerAcceptLength is the length of a serverAccept message in bytes +serverAcceptLength :: Int +serverAcceptLength = 16 + 64 + +-- | MACLength is the length of a MAC in bytes +macLength :: Int +macLength = 16 + +-- | NetworkIdentifier defines which of the possible networks is being used. +-- Most traffic is on MainNet, and others may be used for testing purposes. +type NetworkIdentifier = ByteString + +type SharedSecret = ScalarMult.GroupElement + +-- | SharedSecrets are the result of Scuttlebutt's handshake +-- TODO: make shared secrets readable and showable +data SharedSecrets = SharedSecrets + { network :: NetworkIdentifier + , secreta :: Maybe PublicKey + , secretA :: Maybe PublicKey + , secretb :: Maybe PublicKey + , secretB :: Maybe PublicKey + , secretab :: SharedSecret + , secretaB :: SharedSecret + , secretAb :: SharedSecret + , serverHMAC :: Auth.Authenticator + } + +-- | ConnStatus defines the progress of the handshake. +data ConnStatus = + StartingHandshake + | AwaitingClientHello + | AwaitingServerHello + | AwaitingClientAuthentication + | AwaitingServerAccept + | HandshakeComplete + deriving Show + +-- | Message sent between Scuttlebutt peers. +-- TODO: Add encoding and processing of remaining messages +data Message = + ClientHello Auth.Authenticator -- | Client's HMAC + PublicKey -- | Client's Ephemeral Public Key + NetworkIdentifier + | ServerHello Auth.Authenticator -- | Server's HMAC + PublicKey -- | Server's Ephemeral Public Key + NetworkIdentifier + -- TODO: Can this be renamed? + | ClientAuthMessage ByteString -- | Detached Signature A + PublicKey -- | Client long term Public Key + | ServerAccept ByteString -- | Detached Signature B + +-- | ConnState holds important details during the connection process. +-- +-- TODO: define a getter method for fields. Is it possible to get the field +-- name for the error message? +data ConnState = ConnState + { connState :: ConnStatus + , networkID :: NetworkIdentifier + , clientPrivateKey :: Maybe PrivateKey + , clientPublicKey :: Maybe PublicKey + , clientEphemeralPrivKey :: Maybe PrivateKey + , clientEphemeralPubKey :: Maybe PublicKey + , clientHMAC :: Maybe Auth.Authenticator + , serverPrivateKey :: Maybe PrivateKey + , serverPublicKey :: Maybe PublicKey + , serverEphemeralPrivKey :: Maybe PrivateKey + , serverEphemeralPubKey :: Maybe PublicKey + , serverHMAC :: Maybe Auth.Authenticator + , sharedSecretab :: Maybe SharedSecret + , sharedSecretaB :: Maybe SharedSecret + , sharedSecretAb :: Maybe SharedSecret + , detachedSignatureA :: Maybe ByteString + , detachedSignatureB :: Maybe ByteString + } + +-- | TODO: confirm use of default +instance Default ConnState where + def = ConnState { connState = StartingHandshake + , networkID = "" + , clientPrivateKey = def + , clientPublicKey = def + , clientEphemeralPrivKey = def + , clientEphemeralPubKey = def + , clientHMAC = def + , serverPrivateKey = def + , serverPublicKey = def + , serverEphemeralPrivKey = def + , serverEphemeralPubKey = def + , serverHMAC = def + , sharedSecretab = def + , sharedSecretaB = def + , sharedSecretAb = def + , detachedSignatureA = def + , detachedSignatureB = def + } + +must :: Text -> Maybe a -> Either Text a +must field = maybeToEither ("missing " <> field) + +-- | Create the state for initiating a Handshake given the Scuttlebutt User's key pair. +newClientConnState + :: NetworkIdentifier + -> Identity + -> PublicKey + -> IO ConnState +newClientConnState network clientID serverPubKey = do + let clientPrivKey = Ssb.Identity.privateKey clientID + let clientPubKey = Ssb.Identity.publicKey clientID + + (ephPrivKey, ephPubKey) <- Box.newKeypair + return $ def { connState = StartingHandshake + , networkID = network + , clientPrivateKey = clientPrivKey + , clientPublicKey = Just clientPubKey + , clientEphemeralPrivKey = Just $ PrivateKey (Nacl.encode ephPrivKey) + , clientEphemeralPubKey = Just $ PublicKey (Nacl.encode ephPubKey) + , serverPublicKey = Just serverPubKey + } +-- | Create the state for initiating a Handshake given the Scuttlebutt User's key pair. +newServerConnState + :: NetworkIdentifier + -> Identity + -> IO ConnState +newServerConnState network serverID = do + let serverPrivKey = Ssb.Identity.privateKey serverID + let serverPubKey = Ssb.Identity.publicKey serverID + + (ephPrivKey, ephPubKey) <- Box.newKeypair + return $ def { connState = AwaitingClientHello + , networkID = network + , serverEphemeralPrivKey = Just $ PrivateKey (Nacl.encode ephPrivKey) + , serverEphemeralPubKey = Just $ PublicKey (Nacl.encode ephPubKey) + , serverPrivateKey = serverPrivKey + , serverPublicKey = Just serverPubKey + } + +-- | Create shared secrets given the Handshake's final connection state. +newSharedSecrets :: ConnState -> Either Text SharedSecrets +newSharedSecrets state = do + ssab <- must "secret key ab" $ sharedSecretab state + ssaB <- must "secret key aB" $ sharedSecretaB state + ssAb <- must "secret key Ab" $ sharedSecretAb state + serverHMAC' <- must "secret HMAC" $ serverHMAC (state :: ConnState) + return $ SharedSecrets { network = networkID state + , secreta = clientEphemeralPubKey state + , secretA = clientPublicKey state + , secretb = serverEphemeralPubKey state + , secretB = serverPublicKey state + , secretab = ssab + , secretaB = ssaB + , secretAb = ssAb + , serverHMAC = serverHMAC' + } + +newClientAuthMessage :: ConnState -> Either Text Message +newClientAuthMessage state = do + let network = networkID state + cliLTPrivKey <- must "client Private Key" $ clientPrivateKey state + cliLTPubKey <- must "client Public Key" $ clientPublicKey state + srvLTPubKey <- must "server Public Key" $ serverPublicKey state + + cliEphPrivKey <- must "client Private Key" $ clientEphemeralPrivKey state + srvEphPubKey <- must "server Ephemeral Public Key" + $ serverEphemeralPubKey state + + ssab <- must "shared secret ab" $ sharedSecretab state + + detachedSignatureA <- newDetachedSignatureA network + srvLTPubKey + ssab + cliLTPrivKey + + return $ ClientAuthMessage detachedSignatureA + cliLTPubKey + +newClientHello :: ConnState -> Either Text Message +newClientHello state = do + cliEphPubKey <- maybeToEither noKeyMsg $ clientEphemeralPubKey state + key <- maybeToEither badNetMsg $ Nacl.decode (networkID state) + let auth = Auth.auth key (extractPublicKey cliEphPubKey) + return $ ClientHello auth cliEphPubKey (networkID state) + where + badNetMsg = "badly formatted Network Identifier" + noKeyMsg = "clientEphemeralKey required" + +decodeClientHello :: ConnState -> ByteString -> Either Text Message +decodeClientHello state buf = do + let network = networkID state + let (hmacbuf, cliEphPubKey) = BS.splitAt 32 buf + + key <- maybeToEither badNetMsg $ Nacl.decode network + auth <- maybeToEither badHMACMsg $ Nacl.decode hmacbuf + let msg = cliEphPubKey + + if Auth.verify key auth msg + then Right $ ClientHello auth (PublicKey cliEphPubKey) network + else Left badVerificationMsg + where + badNetMsg = "badly formatted Network Identifier" + badHMACMsg = "badly formatted server HMAC" + badPubKeyMsg = "badly formatted server Public Key" + badVerificationMsg = "verification failed" + +-- TODO: check if its possible to change the function depending on the return type. + +newServerHello :: ConnState -> Either Text Message +newServerHello state = do + srvEphPubKey <- maybeToEither noKeyMsg $ serverEphemeralPubKey state + key <- maybeToEither badNetMsg $ Nacl.decode (networkID state) + let auth = Auth.auth key (extractPublicKey srvEphPubKey) + return $ ServerHello auth srvEphPubKey (networkID state) + where + badNetMsg = "badly formatted Network Identifier" + noKeyMsg = "clientEphemeralKey required" + +decodeServerHello :: ConnState -> ByteString -> Either Text Message +decodeServerHello state buf = do + let network = networkID state + let (hmacbuf, srvEphPubKey) = BS.splitAt 32 buf + + key <- maybeToEither badNetMsg $ Nacl.decode network + auth <- maybeToEither badHMACMsg $ Nacl.decode hmacbuf + let msg = srvEphPubKey + + if Auth.verify key auth msg + then Right $ ServerHello auth (PublicKey srvEphPubKey) network + else Left badVerificationMsg + where + badNetMsg = "badly formatted Network Identifier" + badHMACMsg = "badly formatted server HMAC" + badPubKeyMsg = "badly formatted server Public Key" + badVerificationMsg = "verification failed" + +decodeClientAuthMessage :: ConnState -> ByteString -> Either Text Message +decodeClientAuthMessage state buf = do + let network = networkID state + serverPublicKey <- must "serverPublicKey" $ serverPublicKey state + sharedSecretab <- must "sharedSecretab" $ sharedSecretab state + sharedSecretaB <- must "sharedSecretaB" $ sharedSecretaB state + + key <- + naclDecode "key" + $ SHA256.hash + $ network + <> Nacl.encode sharedSecretab + <> Nacl.encode sharedSecretaB + let nonce = Nacl.zero + msg3 <- maybeToEither "could not open secret box" + $ SecretBox.secretboxOpen key nonce buf + + -- TODO: Make the client auth message length a constant + msg3 <- if (BS.length msg3 == 96) + then (return msg3) + else (Left badMessageLength) + let detachedSignatureA = BS.take 64 msg3 + clientLongTermPubKey <- naclDecode "client Long Term Public Key" + $ BS.drop 64 msg3 + + let msg = + (network :: ByteString) + <> (extractPublicKey serverPublicKey) + <> SHA256.hash (Nacl.encode sharedSecretab) + if Sign.signVerifyDetached + clientLongTermPubKey + detachedSignatureA + msg + then Right state {connState = HandshakeComplete} + else Left "client verification failed" + + return $ ClientAuthMessage + detachedSignatureA + (PublicKey $ Nacl.encode clientLongTermPubKey) + where + badMessageLength = "unexpected length of Client Authentication Message" + naclDecode msg = maybeToEither msg . Nacl.decode + +newServerAccept :: ConnState -> Either Text Message +newServerAccept state = do + detachedSignatureB' <- maybeToEither noSigB (detachedSignatureB state) + return $ ServerAccept detachedSignatureB' + where + noSigB = "detachedSignatureB required" + +decodeServerAccept :: ConnState -> ByteString -> Either Text Message +decodeServerAccept state buf = do + let network = networkID state + sharedSecretab <- must "sharedSecretab" $ sharedSecretab state + sharedSecretaB <- must "sharedSecretaB" $ sharedSecretaB state + sharedSecretAb <- must "sharedSecretAb" $ sharedSecretAb state + + key <- + naclDecode "key" + $ SHA256.hash + $ network + <> Nacl.encode sharedSecretab + <> Nacl.encode sharedSecretaB + <> Nacl.encode sharedSecretAb + let nonce = Nacl.zero + + detachedSignatureB <- secretBoxOpen key nonce buf + return $ ServerAccept detachedSignatureB + where + naclDecode msg = + maybeToEither ("could not decode " <> msg :: Text) . Nacl.decode + secretBoxOpen key nonce msg = + maybeToEither "could not open secret box" + $ SecretBox.secretboxOpen key nonce msg + +-- | generate a signature used in the Client Authentication +newDetachedSignatureA + :: NetworkIdentifier + -> Ssb.Identity.PublicKey + -> SharedSecret + -> PrivateKey + -> Either Text ByteString +newDetachedSignatureA network serverLongTermPubKey sharedSecretab clientLongTermPrivKey + = do + clientLongTermPrivKey' <- maybeToEither badCliKeyMsg + $ Nacl.decode $ extractPrivateKey clientLongTermPrivKey + let secretChecksum = SHA256.hash $ Nacl.encode sharedSecretab + let msg = + (network :: ByteString) + <> extractPublicKey serverLongTermPubKey + <> (secretChecksum :: ByteString) + return $ Sign.signDetached clientLongTermPrivKey' msg + where + badSrvKeyMsg = "badly encoded long term server public key" + badCliKeyMsg = "badly encoded long term client private key" + +calcSharedSecretab :: PrivateKey -> PublicKey -> Either Text SharedSecret +calcSharedSecretab cliEphPrivKey srvEphPubKey = do + cliEphPrivKey' <- maybeToEither "badly formatted client ephemeral private key" + $ Nacl.decode $ extractPrivateKey cliEphPrivKey + srvEphPubKey' <- maybeToEither "badly formatted server ephemeral public key" + $ Nacl.decode $ extractPublicKey srvEphPubKey + return $ ScalarMult.mult cliEphPrivKey' srvEphPubKey' + +-- | generate a signature used in the Server acknowledgement +newDetachedSignatureB + :: NetworkIdentifier + -> ByteString + -> PublicKey + -> SharedSecret + -> PrivateKey + -> Either Text ByteString +newDetachedSignatureB network detachedSignatureA clientPublicKey sharedSecretab serverPrivateKey = do + key <- naclDecode badPrivkey $ extractPrivateKey serverPrivateKey + let msg = + (network :: ByteString) + <> detachedSignatureA + <> (extractPublicKey clientPublicKey) + <> SHA256.hash (Nacl.encode sharedSecretab) + return $ Sign.signDetached key msg + where + badPrivkey = "badly formatted private key" + naclDecode msg = maybeToEither msg . Nacl.decode + + +-- | Server Longterm PK should be converted to curve25519 +-- Does not look like a problem given the Golang code +-- TODO: Implement type conversion here +clientCalcSharedSecretaB :: PrivateKey -> PublicKey -> Either Text SharedSecret +clientCalcSharedSecretaB clientEphemeralSK serverLongtermPK = do + cliEphPrivKey' <- maybeToEither "badly formatted client ephemeral private key" + $ Nacl.decode $ extractPrivateKey clientEphemeralSK + srvLTPubKey' <- maybeToEither "badly formatted server long term public key" + $ Nacl.decode $ extractPublicKey serverLongtermPK + curvePublicKey <- + maybeToEither "badly formatted curve25519" + $ Nacl.decode . Nacl.encode $ Sodium.publicKeyToCurve25519 srvLTPubKey' + return $ ScalarMult.mult cliEphPrivKey' curvePublicKey + +serverCalcSharedSecretaB :: PrivateKey -> PublicKey -> Either Text SharedSecret +serverCalcSharedSecretaB serverLongtermSK clientEphemeralPK = do + srvLTPrivKey' <- maybeToEither "badly formatted server long term private key" + $ Nacl.decode $ extractPrivateKey serverLongtermSK + cliEphPubKey' <- maybeToEither "badly formatted client ephemeral public key" + $ Nacl.decode $ extractPublicKey clientEphemeralPK + curvePrivKey <- + maybeToEither "badly formatted curve25519" + $ Nacl.decode . Nacl.encode $ Sodium.secretKeyToCurve25519 srvLTPrivKey' + return $ ScalarMult.mult curvePrivKey cliEphPubKey' + +calcSharedSecretAb :: PrivateKey -> PublicKey -> Either Text SharedSecret +calcSharedSecretAb clientLongTermPrivKey serverEphemeralPubKey = do + cliLTPrivKey' <- naclDecode "badly formatted client long term private key" + $ extractPrivateKey clientLongTermPrivKey + curveSecretKey <- + naclDecode "badly formatted curve25519" + . Nacl.encode + $ Sodium.secretKeyToCurve25519 cliLTPrivKey' + srvEphPubKey' <- naclDecode "badly formatted server ephemeral public key" + $ extractPublicKey serverEphemeralPubKey + return $ ScalarMult.mult curveSecretKey srvEphPubKey' + where naclDecode msg = maybeToEither msg . Nacl.decode + +serverCalcSharedSecretAb + :: PrivateKey + -> PublicKey + -> Either Text SharedSecret +serverCalcSharedSecretAb serverEphemeralPrivKey clientLongTermPubKey = do + srvEphPrivKey' <- naclDecode "here bad formatted server long term private key" + $ extractPrivateKey serverEphemeralPrivKey + cliLTPubKey' <- naclDecode "badly formatted client long term public key" + $ extractPublicKey clientLongTermPubKey + curvePublicKey <- + naclDecode "badly formatted curve25519" + . Nacl.encode + $ Sodium.publicKeyToCurve25519 cliLTPubKey' + return $ ScalarMult.mult srvEphPrivKey' curvePublicKey + where + naclDecode msg = maybeToEither msg . Nacl.decode + + +-- | encode and serialize the message in preparation to send to peer. +encode :: ConnState -> Message -> Either Text ByteString +encode state msg = case msg of + ClientHello auth pubKey network -> do + cliEphPubKey <- maybeToEither noKeyMsg $ clientEphemeralPubKey state + return $ Nacl.encode auth <> extractPublicKey cliEphPubKey + where + noKeyMsg = "clientEphemeralKey required" + ServerHello auth pubKey network -> do + return $ Nacl.encode auth <> extractPublicKey pubKey + where + noKeyMsg = "clientEphemeralKey required" + ClientAuthMessage dSigA cliLTPubKey -> do + let network = networkID state + ssab <- must "shared secret ab" $ sharedSecretab state + ssaB <- must "shared secret aB" $ sharedSecretaB state + + key <- + maybeToEither badKeyMsg + $ Nacl.decode + $ SHA256.hash + $ network + <> Nacl.encode ssab + <> Nacl.encode ssaB + let nonce = Nacl.zero + let msg = dSigA <> extractPublicKey cliLTPubKey + return $ SecretBox.secretbox key nonce msg + where badKeyMsg = "clientEphemeralKey required" + ServerAccept detachedSignatureB -> do + let network = networkID state + ssab <- must "shared secret ab" $ sharedSecretab state + ssaB <- must "shared secret aB" $ sharedSecretaB state + ssAb <- must "shared secret Ab" $ sharedSecretAb state + + key <- + maybeToEither badKeyMsg + $ Nacl.decode + $ SHA256.hash + $ ((network :: ByteString) + <> Nacl.encode ssab + <> Nacl.encode ssaB + <> Nacl.encode ssAb) + let nonce = Nacl.zero + let msg = detachedSignatureB + return $ SecretBox.secretbox key nonce msg + where badKeyMsg = "clientEphemeralKey required" + +-- | update the connection state and return any reponse message for the peer. +-- TODO: Process secretAb +process :: ConnState -> Message -> IO (Either Text (ConnState, Maybe Message)) +process state (ClientHello hmac cliEphPubKey network) = do + stateUpdate <- return $ do + srvLTPrivKey <- must "server Private Key" + $ serverPrivateKey state + srvEphPrivKey <- must "server ephemeral Private Key" + $ serverEphemeralPrivKey state + + ssab <- calcSharedSecretab srvEphPrivKey cliEphPubKey + -- TODO: srvLTPubKey should be curved Process sk_to_curve25519 + ssaB <- serverCalcSharedSecretaB srvLTPrivKey cliEphPubKey + + return $ state { connState = AwaitingClientAuthentication + , clientEphemeralPubKey = Just cliEphPubKey + , serverHMAC = Just hmac + , sharedSecretab = Just ssab + , sharedSecretaB = Just ssaB + } + return $ stateUpdate >>= \state' -> case newServerHello state' of + Right msg' -> return $ (state', Just msg') + Left err -> Left err + +process state (ServerHello hmac ephPubKey network) = do + stateUpdate <- return $ do + cliLTPrivKey <- must "client Private Key" $ clientPrivateKey state + cliEphPrivKey <- must "ephemeral client Private Key" + $ clientEphemeralPrivKey state + srvLTPubKey <- must "server Public Key" $ serverPublicKey state + + ssab <- calcSharedSecretab cliEphPrivKey ephPubKey + ssaB <- clientCalcSharedSecretaB cliEphPrivKey srvLTPubKey + ssAb <- calcSharedSecretAb cliLTPrivKey ephPubKey + + return $ state { connState = AwaitingServerAccept + , serverHMAC = Just hmac + , serverEphemeralPubKey = Just ephPubKey + , sharedSecretab = Just ssab + , sharedSecretaB = Just ssaB + , sharedSecretAb = Just ssAb + } + return $ stateUpdate >>= \state' -> case newClientAuthMessage state' of + Right msg' -> return $ (state', Just msg') + Left err -> Left err + +process state (ClientAuthMessage detachedSignatureA clientLongTermPubKey) = do + stateUpdate <- return $ do + let network = networkID state + srvPrivKey <- must "server private key" + $ serverPrivateKey state + srvEphPrivKey <- must "server Long Term ephemeral private key" + $ serverEphemeralPrivKey state + sharedSecretAb <- serverCalcSharedSecretAb srvEphPrivKey clientLongTermPubKey + sharedSecretab <- must "shared secret ab" $ sharedSecretab state + + detachedSignatureB <- newDetachedSignatureB + network + detachedSignatureA + clientLongTermPubKey + sharedSecretab + srvPrivKey + + return $ state { connState = HandshakeComplete + , detachedSignatureA = Just detachedSignatureA + , detachedSignatureB = Just detachedSignatureB + , clientPublicKey = Just clientLongTermPubKey + , sharedSecretAb = Just sharedSecretAb + } + return $ stateUpdate >>= \state' -> case newServerAccept state' of + Right msg' -> return $ (state', Just msg') + Left err -> Left err + +process state (ServerAccept dSigB) = do + stateUpdate <- return $ do + let network = networkID state + cliLTPrivKey <- must "client private key" $ clientPrivateKey state + cliLTPubKey <- must "client public key" $ clientPublicKey state + srvLTPubKey <- must "server public key" $ serverPublicKey state + ssab <- must "shared secret ab" $ sharedSecretab state + + detachedSignatureA <- newDetachedSignatureA network + srvLTPubKey + ssab + cliLTPrivKey + + keyBuf <- must "server Public Key" $ serverPublicKey state + key <- maybeToEither "badly formatted public key" $ Nacl.decode $ extractPublicKey keyBuf + let msg = network <> detachedSignatureA <> extractPublicKey cliLTPubKey <> SHA256.hash (Nacl.encode ssab) + + if Sign.signVerifyDetached key dSigB msg + then Right state {connState = HandshakeComplete} + else Left "server verification failed" + + return $ case stateUpdate of + Right state' -> return (state', Nothing) + Left err -> Left err + +-- TODO: Investigate a better way to separate network from handshake logic +type ReadFn = Int -> IO (Maybe ByteString) +type SendFn = ByteString -> IO () + +-- | readMsg decodes the next expected message from the byte stream. +readMsg :: ConnState -> ReadFn -> IO (Either Text Message) +readMsg state read = case connState state of + AwaitingClientHello -> do + mbuf <- read' challengeLength + return $ do + buf <- mbuf + decodeClientHello state buf + AwaitingServerHello -> do + mbuf <- read' challengeLength + return $ do + buf <- mbuf + decodeServerHello state buf + AwaitingClientAuthentication -> do + mbuf <- read' 112 + return $ do + buf <- mbuf + decodeClientAuthMessage state buf + AwaitingServerAccept -> do + mbuf <- read' serverAcceptLength + return $ do + buf <- mbuf + decodeServerAccept state buf + _ -> return $ Left "unknown state" + where read' len = maybeToEither "connection broken" <$> read len + +-- TODO: use Either instead +sendMsg :: SendFn -> ConnState -> Message -> IO () +sendMsg send state msg = do + case encode state msg of + Left err -> die err + Right buf -> send buf + +-- | startHandshake initializes the connection with the Scuttlebutt peer +-- returning the new shared secrets upon completion. +startHandshake + :: SendFn + -> ReadFn + -> NetworkIdentifier + -> Identity + -> PublicKey + -> IO (Either Text SharedSecrets) +startHandshake send recv network clientID srvPubKey = do + state <- newClientConnState network clientID srvPubKey + let clientHello = fromRight undefined (newClientHello state) + let state' = state { connState = AwaitingServerHello } + finalState <- loop state' recv (Just clientHello) + return $ finalState >>= newSharedSecrets + where + loop :: ConnState -> ReadFn -> Maybe Message -> IO (Either Text ConnState) + loop state _ Nothing = return . return $ state + loop state recv (Just msg) = do + sendMsg send state msg + case connState state of + HandshakeComplete -> return . return $ state + _ -> do + resp <- readMsg state recv + case resp of + Left err -> return $ Left err + Right msg -> do + res <- process state msg + case res of + Left err -> return $ Left ("handshake error while connecting to peer: " <> err) + Right (state', msg') -> + loop state' recv msg' + +welcomeHandshake + :: SendFn + -> ReadFn + -> NetworkIdentifier + -> Identity + -> IO (Either Text SharedSecrets) +welcomeHandshake send recv network serverID = do + state <- newServerConnState network serverID + finalState <- loop state + return $ finalState >>= newSharedSecrets + where + loop :: ConnState -> IO (Either Text ConnState) + loop state = do + msg <- readMsg state recv + case msg of + Left err -> return $ Left err + Right msg -> do + res <- process state msg + case res of + Left err -> return $ Left $ "handshake failed: " <> err + Right (state', msg') -> + case msg' of + Nothing -> return $ Right state' + Just msg'' -> do + + sendMsg send state' msg'' + case connState state' of + HandshakeComplete -> return $ Right state' + _ -> loop state' + + diff --git a/src/Ssb/Peer/TCP.hs b/src/Ssb/Peer/TCP.hs new file mode 100644 index 0000000..27e3f07 --- /dev/null +++ b/src/Ssb/Peer/TCP.hs @@ -0,0 +1,100 @@ +-- | This module implements basic TCP connectivity for Scuttlebutt. + +module Ssb.Peer.TCP where + +import Protolude hiding ( Identity ) +import Data.Maybe ( fromJust ) +import qualified Network.Simple.TCP as TCP + +import Ssb.Aux +import qualified Ssb.Identity as Ssb +import Ssb.Network +import Ssb.Peer +import qualified Ssb.Peer.BoxStream as BoxStream +import qualified Ssb.Peer.SecretHandshake as SH +import qualified Ssb.Peer.RPC as RPC + +connectBoxStream + :: Host + -> Port + -> NetworkIdentifier + -> Ssb.Identity + -> Ssb.Identity + -> (BoxStream.Conn -> IO ()) + -> IO (Either Text ()) +connectBoxStream host port networkID id peer cmd = + TCP.connect (toS host) (toS port) $ \(socket, addr) -> do + res <- SH.startHandshake (TCP.send socket) + (TCP.recv socket) + networkID + id + (Ssb.publicKey peer) + case res of + Left err -> return $ Left ("client handshake error: " <> err) + Right sharedSecrets -> do + conn <- BoxStream.connectServer socket sharedSecrets + case conn of + Left err -> return (Left err) + Right conn -> return <$> cmd conn + +serveBoxStream + :: Host + -> Port + -> NetworkIdentifier + -> Ssb.Identity + -> (BoxStream.Conn -> Ssb.Identity -> IO ()) + -> IO () +serveBoxStream host port networkID id cmd = + TCP.serve (TCP.Host $ toS host) (toS port) $ \(socket, remoteAddr) -> do + res <- SH.welcomeHandshake (TCP.send socket) + (TCP.recv socket) + networkID + id + case res of + Left err -> print $ "client handshake error: " <> err + Right sharedSecrets -> do + let peerID = (fromMaybe undefined $ SH.secretA sharedSecrets) + conn <- BoxStream.connectServer socket sharedSecrets + case conn of + Left err -> print $ "client error: " <> err + Right conn -> cmd conn (Ssb.Identity Nothing peerID) + +connectRPC + :: RPC.Handler a + => a + -> Host + -> Port + -> NetworkIdentifier + -> Ssb.Identity + -> Ssb.Identity + -> (RPC.ConnState -> IO ()) + -> IO (Either Text ()) +connectRPC handler host port networkID id peer cmd = + TCP.connect (toS host) (toS port) $ \(socket, addr) -> do + res <- SH.startHandshake (TCP.send socket) + (TCP.recv socket) + networkID + id + (Ssb.publicKey peer) + case res of + Left err -> return $ error ("client handshake error: " <> err) + Right sharedSecrets -> do + conn <- BoxStream.connectClient socket sharedSecrets + case conn of + Left err -> return $ Left err + Right conn -> RPC.connect conn handler (Ssb.publicKey peer) cmd + +serveRPC + :: RPC.Handler a + => a + -> Host + -> Port + -> NetworkIdentifier + -> Ssb.Identity + -> IO () +serveRPC handler host port networkID id = + serveBoxStream host port networkID id $ \conn peer -> do + res <- RPC.connect conn handler (Ssb.publicKey peer) (\_ -> return ()) + case res of + Left err -> print $ "RPC error serving client: " <> err + Right _ -> return () |