{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}
module Control.Distributed.Fork.Lambda.Internal.Invoke
( withInvoke
) where
import Control.Concurrent.Async
import Control.Concurrent.MVar
import Control.Exception.Safe
import Control.Lens
import Control.Monad
import qualified Data.Aeson as A
import Data.Aeson.Lens
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.ByteString.Base64 as B64
import qualified Data.HashMap.Strict as HM
import qualified Data.Map.Strict as M
import Data.Maybe
import Data.Monoid
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Network.AWS
import Network.AWS.Lambda
import Network.AWS.SQS
import Text.Read
import Control.Concurrent.Throttled
import Control.Distributed.Fork.Backend
import Control.Distributed.Fork.Lambda.Internal.Stack (StackInfo (..))
-- | Bookkeeping shared between 'execute' and the queue-polling threads.
data LambdaState = LambdaState
  { lsInvocations :: M.Map Int (IO BS.ByteString -> IO ())
    -- ^ Callbacks of the invocations that have not received an outcome yet,
    -- keyed by invocation id. A callback is handed an action that either
    -- yields the response bytes or throws.
  , lsNextId :: Int
    -- ^ Id that will be given to the next invocation.
  }
-- | Everything an invocation or a polling thread needs to do its work.
data LambdaEnv = LambdaEnv
  { leState :: MVar LambdaState
    -- ^ Mutable table of pending invocations.
  , leStack :: StackInfo
    -- ^ Stack description; 'siFunc', 'siAnswerQueue' and
    -- 'siDeadLetterQueue' are read from it.
  , leEnv :: Env
    -- ^ AWS environment handed to 'runAWS'.
  }
-- | Build a 'LambdaEnv' that has no pending invocations and whose first
-- invocation id is 0.
newLambdaEnv :: Env -> StackInfo -> IO LambdaEnv
newLambdaEnv env st = do
  stateVar <- newMVar LambdaState { lsInvocations = M.empty, lsNextId = 0 }
  return LambdaEnv { leState = stateVar, leStack = st, leEnv = env }
-- | Run one payload on the Lambda function and wait for its outcome.
--
-- A fresh invocation id is allocated and a callback for it is registered in
-- the shared state. The input is base64-encoded and sent together with the
-- id (JSON fields @d@ and @i@) using invocation type 'Event'. The call then
-- blocks until some thread fills the callback, and runs the delivered action.
--
-- Throws 'InvokeException' when the invoke call does not return a 2xx status
-- code; re-throws whatever the delivered action throws.
--
-- The registered callback is removed again on every path where nobody will
-- wait for it any more (invoke threw, non-2xx status, waiting was
-- interrupted), so failed invocations do not accumulate in 'lsInvocations'.
execute :: LambdaEnv -> Throttle -> BS.ByteString -> BackendM BS.ByteString
execute LambdaEnv{..} throttle input = do
  mvar <- liftIO $ newEmptyMVar @(IO BS.ByteString)
  -- Allocate an id and register the callback before invoking, so that an
  -- answer can never arrive ahead of its registration.
  id' <- liftIO $ modifyMVar leState $ \LambdaState{..} -> return
    ( LambdaState { lsNextId = lsNextId + 1
                  , lsInvocations =
                      M.insert lsNextId (void . tryPutMVar mvar) lsInvocations
                  }
    , lsNextId
    )
  let unregister :: IO ()
      unregister = modifyMVar_ leState $ \s ->
        return s { lsInvocations = M.delete id' (lsInvocations s) }
  irs <- liftIO $
    ( throttled throttle . runResourceT . runAWS leEnv $
        send $ invoke
          (siFunc leStack)
          (BL.toStrict . A.encode $ A.object
            [ (T.pack "d", A.toJSON . T.decodeUtf8 $ B64.encode input)
            , (T.pack "i", A.toJSON id')
            ])
          & iInvocationType ?~ Event
    ) `onException` unregister
  submitted
  unless (irs ^. irsStatusCode `div` 100 == 2) $ do
    liftIO unregister
    throwIO . InvokeException $
      "Invoke failed. Status code: " <> T.pack (show $ irs ^. irsStatusCode)
  -- If the delivered action itself throws, the entry has already been removed
  -- by the thread that delivered it and 'unregister' is a harmless no-op.
  liftIO $ join (readMVar mvar) `onException` unregister
-- | Poll the answer queue forever and hand every response to the invocation
-- that is waiting for it.
--
-- For each message the invocation id is read from the @Id@ message attribute.
-- If a callback is registered under that id it is removed from the state and
-- given an action that base64-decodes the message body. Messages whose id is
-- not registered are ignored.
answerThread :: LambdaEnv -> IO ()
answerThread LambdaEnv {..} = runResourceT . runAWS leEnv . forever $ do
  msgs <- sqsReceiveSome $ siAnswerQueue leStack
  -- 'sqsReceiveSome' has already deleted the whole batch from the queue, so
  -- a message that cannot be handled must not abort the loop: doing so would
  -- drop the answers that follow it in the same batch. Report and move on.
  forM_ msgs $ \msg -> liftIO $ deliver msg `catchAny` print
  where
    -- Route one answer message to its pending invocation, if there is one.
    deliver :: Message -> IO ()
    deliver msg = do
      id' <- decodeId msg
      modifyMVar_ leState $ \s ->
        case M.updateLookupWithKey (\_ _ -> Nothing) id' (lsInvocations s) of
          (Nothing, _) -> return s
          (Just x, s') -> s { lsInvocations = s' } <$ x (decodeResponse msg)

    -- Read the invocation id from the "Id" message attribute.
    decodeId :: Message -> IO Int
    decodeId msg =
      case HM.lookup "Id" (msg ^. mMessageAttributes) of
        Nothing -> throwIO . InvokeException $
          "Error decoding answer: can not find Id: " <> T.pack (show msg)
        Just av -> case readMaybe . T.unpack <$> av ^. mavStringValue of
          Nothing -> throwIO . InvokeException $
            "Error decoding answer: empty Id."
          Just Nothing -> throwIO . InvokeException $
            "Error decoding answer: can not decode Id."
          Just (Just x) -> return x

    -- Base64-decode the message body; this runs in the waiting caller, not
    -- in this thread, because only the action is handed over.
    decodeResponse :: Message -> IO BS.ByteString
    decodeResponse msg =
      case B64.decode . T.encodeUtf8 <$> msg ^. mBody of
        Nothing -> throwIO . InvokeException $
          "Error decoding answer: no body."
        Just (Left err) -> throwIO . InvokeException $
          "Error decoding answer: " <> T.pack err
        Just (Right x) -> return x
-- | Poll the dead-letter queue forever and fail the invocation each message
-- refers to.
--
-- The invocation id is read from the numeric field @i@ of the JSON message
-- body. If a callback is registered under that id it is removed from the
-- state and given an action that throws 'InvokeException'. Messages whose id
-- is not registered are ignored.
deadLetterThread :: LambdaEnv -> IO ()
deadLetterThread LambdaEnv {..} = runResourceT . runAWS leEnv . forever $ do
  msgs <- sqsReceiveSome $ siDeadLetterQueue leStack
  -- 'sqsReceiveSome' has already deleted the whole batch from the queue, so
  -- a message that cannot be handled must not abort the loop: doing so would
  -- drop the failure notices that follow it in the same batch.
  forM_ msgs $ \msg -> liftIO $ deliver msg `catchAny` print
  where
    -- Mark the invocation named by one dead-letter message as failed.
    deliver :: Message -> IO ()
    deliver msg = do
      id' <- decodeId msg
      modifyMVar_ leState $ \s ->
        case M.updateLookupWithKey (\_ _ -> Nothing) id' (lsInvocations s) of
          (Nothing, _) -> return s
          (Just x, s') -> s { lsInvocations = s' } <$ x failure

    -- Outcome handed to the waiting caller; it throws when the caller runs it.
    failure :: IO a
    failure = throwIO . InvokeException $ "Lambda function failed."

    -- Read the invocation id from the "i" field of the JSON body.
    decodeId :: Message -> IO Int
    decodeId msg = case msg ^? mBody . _Just . key "i" . _Number of
      Nothing -> throwIO . InvokeException $
        "Can not find Id: " <> T.pack (show msg)
      Just x -> return $ truncate x
-- | Receive a batch of messages from a queue and delete what was received.
--
-- The receive request uses a visibility timeout of 10, a wait time of 10
-- seconds, at most 10 messages, and asks only for the @Id@ message attribute.
-- The received messages are then removed with a single batch delete.
--
-- Throws 'InvokeException' when either the receive or the delete call answers
-- with a response status other than 200.
sqsReceiveSome :: T.Text -> AWS [Message]
sqsReceiveSome queue = do
  rmrs <- send $
    receiveMessage queue
      & rmVisibilityTimeout ?~ 10
      & rmWaitTimeSeconds ?~ 10
      & rmMaxNumberOfMessages ?~ 10
      & rmMessageAttributeNames .~ ["Id"]
  unless (rmrs ^. rmrsResponseStatus == 200) $
    liftIO . throwIO . InvokeException $
      "Error receiving messages: " <> T.pack (show $ rmrs ^. rmrsResponseStatus)
  let msgs = rmrs ^. rmrsMessages
      -- A message without a receipt handle cannot be deleted; leave it out
      -- of the batch rather than crashing on a partial 'fromJust'. The entry
      -- ids only have to be unique within this one request.
      entries =
        [ deleteMessageBatchRequestEntry (T.pack $ show i) receipt
        | (i, receipt) <-
            zip [(0::Integer)..] (mapMaybe (^. mReceiptHandle) msgs)
        ]
  unless (null entries) $ do
    dmbrs <- send $ deleteMessageBatch queue & dmbEntries .~ entries
    unless (dmbrs ^. dmbrsResponseStatus == 200) $
      liftIO . throwIO . InvokeException $
        -- Report the status of the delete call, not of the earlier receive.
        "Error deleting received messages: "
          <> T.pack (show $ dmbrs ^. dmbrsResponseStatus)
  return msgs
-- | Run an action that is given a function for executing payloads on Lambda.
--
-- A fresh 'LambdaEnv' and a throttle ('newThrottle' 128) are created, then
-- four threads running 'answerThread' and two running 'deadLetterThread' are
-- started. Each of them loops forever; an exception caught by 'catchAny' is
-- printed and the worker is started again. All six threads are cancelled
-- once the supplied action returns or throws.
withInvoke :: Env -> StackInfo -> ((BS.ByteString -> BackendM BS.ByteString) -> IO a) -> IO a
withInvoke env stack f = do
  le <- newLambdaEnv env stack
  throttle <- newThrottle 128
  answerWorkers <- replicateM 4 (spawn (answerThread le))
  deadLetterWorkers <- replicateM 2 (spawn (deadLetterThread le))
  f (execute le throttle)
    `finally` mapM_ cancel (answerWorkers ++ deadLetterWorkers)
  where
    -- Keep a worker alive in the background, printing whatever makes it stop.
    spawn :: IO () -> IO (Async ())
    spawn worker = async . forever $ worker `catchAny` print
-- | Exception thrown by this module; the 'T.Text' is a human-readable
-- description of what went wrong (failed invoke call, failed queue
-- operation, undecodable message, or failed Lambda run).
newtype InvokeException = InvokeException T.Text
  deriving (Show)

instance Exception InvokeException