{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TypeApplications  #-}

module Control.Distributed.Fork.Lambda.Internal.Invoke
  ( withInvoke
  ) where

--------------------------------------------------------------------------------
import Control.Concurrent.Async
import Control.Concurrent.MVar
import Control.Exception.Safe
import Control.Lens
import Control.Monad
import qualified Data.Aeson as A
import Data.Aeson.Lens
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.ByteString.Base64 as B64
import qualified Data.HashMap.Strict as HM
import qualified Data.Map.Strict as M
import Data.Maybe
import Data.Monoid
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Network.AWS
import Network.AWS.Lambda
import Network.AWS.SQS
import Text.Read
--------------------------------------------------------------------------------
import Control.Concurrent.Throttled
import Control.Distributed.Fork.Backend
import Control.Distributed.Fork.Lambda.Internal.Stack (StackInfo (..))
--------------------------------------------------------------------------------

{-
Answers come back asynchronously, so we keep a state holding one callback
per pending invocation. Every invocation gets its own incrementing id, which
is how the responses are told apart.
-}

-- | Pending invocations plus the id that will be handed out next.
data LambdaState = LambdaState
  { lsInvocations :: M.Map Int (IO BS.ByteString -> IO ())
    -- ^ callbacks of invocations still waiting for a result, keyed by id
  , lsNextId      :: Int
    -- ^ id to assign to the next invocation
  }

-- | State shared by the invoking side and the queue listener threads.
data LambdaEnv = LambdaEnv
  { leState :: MVar LambdaState
  , leStack :: StackInfo
  , leEnv   :: Env
  }

-- | Build an environment with no pending invocations; ids start at 0.
newLambdaEnv :: Env -> StackInfo -> IO LambdaEnv
newLambdaEnv env st = do
  stateVar <- newMVar LambdaState { lsInvocations = M.empty, lsNextId = 0 }
  pure LambdaEnv { leState = stateVar, leStack = st, leEnv = env }

{- When invoking a function, we register a callback under a fresh id in the state and then call Lambda.
Registering first means a quick answer is never dropped for having an
unknown id; if the call fails (or the caller stops waiting) before an answer
is delivered, the registration is removed again.
-}

-- | Invoke the Lambda function asynchronously with the given input and block
-- until a listener thread delivers the result.
--
-- The payload is a JSON object carrying the base64-encoded input under @"d"@
-- and the invocation id under @"i"@.
--
-- Throws 'InvokeException' when the invoke call returns a non-2xx status
-- code, and rethrows whatever the delivered result action throws.
execute :: LambdaEnv -> Throttle -> BS.ByteString -> BackendM BS.ByteString
execute LambdaEnv{..} throttle input = do
  -- Register a one-shot callback under a fresh id *before* invoking.
  mvar <- liftIO $ newEmptyMVar @(IO BS.ByteString)
  id' <- liftIO $ modifyMVar leState $ \LambdaState{..} ->
    return
      ( LambdaState
          { lsNextId = lsNextId + 1
          , lsInvocations = M.insert lsNextId (void . tryPutMVar mvar) lsInvocations
          }
      , lsNextId
      )

  let -- Drop our callback again; a no-op once a listener has consumed it.
      unregister :: IO ()
      unregister = modifyMVar_ leState $ \s ->
        return s { lsInvocations = M.delete id' (lsInvocations s) }

      payload = BL.toStrict . A.encode $ A.object
        [ (T.pack "d", A.toJSON . T.decodeUtf8 $ B64.encode input)
        , (T.pack "i", A.toJSON id')
        ]

      -- 'Event' invocation type: the call does not wait for the result; it is
      -- delivered through the queues polled by answerThread/deadLetterThread.
      invocation = invoke (siFunc leStack) payload & iInvocationType ?~ Event

  liftIO . (`onException` unregister) $ do
    irs <- throttled throttle . runResourceT . runAWS leEnv $ send invocation
    let statusCode = irs ^. irsStatusCode
    unless (statusCode `div` 100 == 2) $
      throwIO . InvokeException $
        "Invoke failed. Status code: " <> T.pack (show statusCode)
    -- Wait for the answer: block until a listener hands over the result
    -- action, then run it (it returns the decoded answer or throws).
    join $ readMVar mvar

-- | Remove the callback registered under the given id and hand it the action
-- that produces the invocation's result. Ids without a registered callback
-- (already answered, or abandoned by the caller) are ignored.
completeInvocation :: MVar LambdaState -> Int -> IO BS.ByteString -> IO ()
completeInvocation stateVar id' result =
  modifyMVar_ stateVar $ \s ->
    case M.updateLookupWithKey (\_ _ -> Nothing) id' (lsInvocations s) of
      (Nothing, _) -> return s
      (Just callback, rest) -> s { lsInvocations = rest } <$ callback result

{- And then we listen from answerQueue for the responses -}

-- | Poll the answer queue forever, delivering each answer to the caller that
-- registered the id found in the message's @Id@ attribute.
--
-- 'sqsReceiveSome' deletes a batch before we process it, so a message that
-- fails to decode is printed and skipped rather than aborting (and thereby
-- losing) the rest of the batch.
answerThread :: LambdaEnv -> IO ()
answerThread LambdaEnv{..} =
  runResourceT . runAWS leEnv . forever $ do
    msgs <- sqsReceiveSome $ siAnswerQueue leStack
    liftIO . forM_ msgs $ \msg -> handleAny print $ do
      id' <- decodeId msg
      completeInvocation leState id' (decodeResponse msg)
  where
    decodeId :: Message -> IO Int
    decodeId msg =
      case HM.lookup "Id" (msg ^. mMessageAttributes) of
        Nothing -> throwIO . InvokeException $
          "Error decoding answer: can not find Id: " <> T.pack (show msg)
        Just av -> case readMaybe . T.unpack <$> av ^. mavStringValue of
          Nothing -> throwIO . InvokeException $ "Error decoding answer: empty Id."
          Just Nothing -> throwIO . InvokeException $ "Error decoding answer: can not decode Id."
          Just (Just x) -> return x

    -- The body is the base64-encoded answer; decoding is deferred until the
    -- waiting caller runs the action.
    decodeResponse :: Message -> IO BS.ByteString
    decodeResponse msg =
      case B64.decode . T.encodeUtf8 <$> msg ^. mBody of
        Nothing -> throwIO . InvokeException $ "Error decoding answer: no body."
        Just (Left err) -> throwIO . InvokeException $ "Error decoding answer: " <> T.pack err
        Just (Right x) -> return x

{- And then we listen from the dead letter queue for failed invocations -}

-- | Poll the dead-letter queue forever and wake the matching caller with an
-- action that throws 'InvokeException'. Like 'answerThread', a message that
-- cannot be decoded is printed and skipped.
deadLetterThread :: LambdaEnv -> IO ()
deadLetterThread LambdaEnv{..} =
  runResourceT . runAWS leEnv . forever $ do
    msgs <- sqsReceiveSome $ siDeadLetterQueue leStack
    liftIO . forM_ msgs $ \msg -> handleAny print $ do
      id' <- decodeId msg
      completeInvocation leState id' failure
  where
    failure :: IO a
    failure = throwIO . InvokeException $ "Lambda function failed."

    -- Reads the id from the "i" field of the JSON body, i.e. the body is
    -- expected to be the payload built by 'execute'.
    decodeId :: Message -> IO Int
    decodeId msg =
      case msg ^? mBody . _Just . key "i" . _Number of
        Nothing -> throwIO . InvokeException $ "Can not find Id: " <> T.pack (show msg)
        Just x -> return $ truncate x

{- A helper function to read from SQS queues. -}

-- | Long-poll a queue for up to 10 messages (requesting their @Id@ attribute)
-- and delete the received messages right away. Messages are deleted before
-- they are processed, so one that later fails to process is not redelivered.
--
-- Throws 'InvokeException' when the receive or the batch delete does not
-- return status 200.
sqsReceiveSome :: T.Text -> AWS [Message]
sqsReceiveSome queue = do
  rmrs <- send $ receiveMessage queue
    & rmVisibilityTimeout ?~ 10
    & rmWaitTimeSeconds ?~ 10
    & rmMaxNumberOfMessages ?~ 10
    & rmMessageAttributeNames .~ ["Id"]
  unless (rmrs ^. rmrsResponseStatus == 200) $
    liftIO . throwIO . InvokeException $
      "Error receiving messages: " <> T.pack (show $ rmrs ^. rmrsResponseStatus)
  let msgs = rmrs ^. rmrsMessages
      -- Total replacement for 'fromJust': messages without a receipt handle
      -- simply cannot be deleted.
      receipts = [receipt | Just receipt <- map (^. mReceiptHandle) msgs]
  -- An empty DeleteMessageBatch request is rejected, so skip it entirely.
  unless (null receipts) $ do
    dmbrs <- send $ deleteMessageBatch queue
      & dmbEntries .~
          [ deleteMessageBatchRequestEntry (T.pack $ show i) receipt
          | (i, receipt) <- zip [(0 :: Integer) ..] receipts
          ]
    unless (dmbrs ^. dmbrsResponseStatus == 200) $
      liftIO . throwIO . InvokeException $
        "Error deleting received messages: " <> T.pack (show $ dmbrs ^. dmbrsResponseStatus)
  return msgs

--------------------------------------------------------------------------------

-- | Run the continuation with an invoke function backed by background
-- listeners: four on the answer queue and two on the dead-letter queue.
-- All invocations share one throttle created with @newThrottle 128@.
-- Every listener is cancelled when the continuation returns or throws.
withInvoke
  :: Env
  -> StackInfo
  -> ((BS.ByteString -> BackendM BS.ByteString) -> IO a)
  -> IO a
withInvoke env stack f = do
  le <- newLambdaEnv env stack
  throttle <- newThrottle 128
  let -- A listener never stops by itself: any synchronous exception is
      -- printed and the listener restarts.
      supervise :: (LambdaEnv -> IO ()) -> IO ()
      supervise listener = forever $ catchAny (listener le) print

      workers :: [IO ()]
      workers =
        replicate 4 (supervise answerThread)
          ++ replicate 2 (supervise deadLetterThread)
  -- Nested 'withAsync' guarantees every started listener is cancelled, even
  -- if an exception arrives while the others are still being started.
  foldr (\w k -> withAsync w (const k)) (f (execute le throttle)) workers

--------------------------------------------------------------------------------

-- | Raised when invoking the Lambda function, or decoding its answer, fails.
newtype InvokeException = InvokeException T.Text
  deriving Show

instance Exception InvokeException