{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module Network.PushNotify.APN
( newSession
, newMessage
, newMessageWithCustomPayload
, hexEncodedToken
, rawToken
, sendMessage
, sendSilentMessage
, sendRawMessage
, alertMessage
, bodyMessage
, emptyMessage
, setAlertMessage
, setMessageBody
, setBadge
, setCategory
, setSound
, clearAlertMessage
, clearBadge
, clearCategory
, clearSound
, addSupplementalField
, closeSession
, isOpen
, ApnSession
, JsonAps
, JsonApsAlert
, JsonApsMessage
, ApnMessageResult(..)
, ApnFatalError(..)
, ApnTemporaryError(..)
, ApnToken
) where
import Control.Concurrent
import Control.Concurrent.QSem
import Control.Exception.Lifted (Exception, try, bracket_, throw, throwIO)
import Control.Monad
import Control.Monad.Except
import Data.Aeson
import Data.Aeson.Types
import Data.ByteString (ByteString)
import Data.Char (toLower)
import Data.Default (def)
import Data.Either
import Data.Int
import Data.IORef
import Data.Map.Strict (Map)
import Data.Maybe
import Data.Pool
import Data.Semigroup ((<>))
import Data.Text (Text)
import Data.Time.Clock
import Data.Time.Clock.POSIX
import Data.Typeable (Typeable)
import Data.X509
import Data.X509.CertificateStore
import GHC.Generics
import Network.HTTP2 (ErrorCodeId,
toErrorCodeId)
import Network.HTTP2.Client
import Network.HTTP2.Client.FrameConnection
import Network.HTTP2.Client.Helpers
import Network.TLS hiding (sendData)
import Network.TLS.Extra.Cipher
import System.IO.Error
import System.Mem.Weak
import System.Random
import qualified Data.ByteString as S
import qualified Data.ByteString.Base16 as B16
import qualified Data.ByteString.Lazy as L
import qualified Data.List as DL
import qualified Data.Map.Strict as M
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Network.HPACK as HTTP2
import qualified Network.HTTP2 as HTTP2
-- | A handle used to send notifications: a pool of APNs connections
-- together with a flag recording whether the session is still open.
data ApnSession = ApnSession
    { apnSessionPool :: !(Pool ApnConnection)
      -- ^ Pooled HTTP/2 connections to the APNs endpoint.
    , apnSessionOpen :: !(IORef Bool)
      -- ^ 'True' until the session is closed.
    }
-- | Everything needed to open a connection to APNs.
data ApnConnectionInfo = ApnConnectionInfo
    { aciCertPath             :: !FilePath
      -- ^ Client certificate file (first argument to 'credentialLoadX509').
    , aciCertKey              :: !FilePath
      -- ^ Private key file (second argument to 'credentialLoadX509').
    , aciCaPath               :: !FilePath
      -- ^ CA certificate store read with 'readCertificateStore'.
    , aciHostname             :: !Text
      -- ^ APNs host; used for TLS/SNI and the @:authority@ header.
    , aciMaxConcurrentStreams :: !Int
      -- ^ Advertised max concurrent streams and size of the send semaphore.
    , aciTopic                :: !ByteString
      -- ^ Value sent in the @apns-topic@ request header.
    }

-- | One live HTTP/2 connection to APNs plus its bookkeeping.
data ApnConnection = ApnConnection
    { apnConnectionConnection        :: !Http2Client
      -- ^ The underlying HTTP/2 client.
    , apnConnectionInfo              :: !ApnConnectionInfo
      -- ^ Parameters this connection was opened with.
    , apnConnectionWorkerPool        :: !QSem
      -- ^ Bounds the number of concurrent sends on this connection.
    , apnConnectionFlowControlWorker :: !ThreadId
      -- ^ Background thread refreshing the incoming flow-control window.
    , apnConnectionOpen              :: !(IORef Bool)
      -- ^ Cleared when the server sends GOAWAY or the connection is closed.
    }
-- | A device token, stored hex-encoded exactly as it is placed in the
-- request path (@\/3\/device\/\<token\>@).
newtype ApnToken = ApnToken
    { unApnToken :: ByteString
    }

-- | Maps an 'IOError' to some error type.
-- NOTE(review): no instance or use is visible in this module — confirm it
-- is still needed.
class SpecifyError a where
    isAnError :: IOError -> a
-- | Builds a token from its raw binary form by hex-encoding it.
rawToken
    :: ByteString  -- ^ Raw token bytes.
    -> ApnToken
rawToken bytes = ApnToken (B16.encode bytes)

-- | Builds a token from a hex string. The input is decoded and re-encoded,
-- which normalises it to the encoder's hex form.
--
-- With the tuple-returning 'B16.decode' used here, the second component is
-- the input that could not be decoded; it is discarded. Any trailing
-- non-hex text is therefore silently dropped rather than reported.
hexEncodedToken
    :: Text  -- ^ Hex-encoded token.
    -> ApnToken
hexEncodedToken hex = ApnToken (B16.encode decoded)
  where
    (decoded, _undecodable) = B16.decode (TE.encodeUtf8 hex)
-- | Protocol-level failures thrown as exceptions while talking to APNs.
data ApnException
    = ApnExceptionHTTP ErrorCodeId
      -- ^ The HTTP/2 stream failed with this error code.
    | ApnExceptionJSON String
      -- ^ The error body could not be decoded; carries the decoder message.
    | ApnExceptionMissingHeader HTTP2.HeaderName
      -- ^ A required response header was absent.
    | ApnExceptionUnexpectedResponse
      -- ^ Presumably a response that fits no expected shape — confirm at use sites.
    deriving (Show, Typeable)

instance Exception ApnException

-- | Outcome of sending a single notification.
data ApnMessageResult
    = ApnMessageResultOk
      -- ^ APNs accepted the notification.
    | ApnMessageResultBackoff
      -- ^ A new stream could not be opened; retry later.
    | ApnMessageResultFatalError ApnFatalError
      -- ^ APNs rejected the notification with a permanent reason.
    | ApnMessageResultTemporaryError ApnTemporaryError
      -- ^ APNs rejected the notification with a transient reason.
    | ApnMessageResultIOError IOError
      -- ^ An I/O error occurred while sending.
    | ApnMessageResultClientError ClientError
      -- ^ The HTTP/2 client reported an error.
    deriving (Eq, Show)
-- | The @alert@ dictionary inside @aps@.
data JsonApsAlert = JsonApsAlert
    { jaaTitle :: !(Maybe Text)  -- ^ Optional alert title.
    , jaaBody  :: !Text          -- ^ Alert body text.
    } deriving (Generic, Show)

-- | Encodes as @{"title": ..., "body": ...}@. Fields are named by
-- lower-casing the record field and dropping the three-character @jaa@
-- prefix. A missing title is omitted rather than sent as @null@.
instance ToJSON JsonApsAlert where
    toJSON = genericToJSON alertOptions
      where
        alertOptions = defaultOptions
            { fieldLabelModifier = drop 3 . map toLower
            , omitNothingFields  = True
            }
-- | The contents of the @aps@ dictionary. Every field is optional.
data JsonApsMessage = JsonApsMessage
    { jamAlert    :: !(Maybe JsonApsAlert)  -- ^ Alert title/body.
    , jamBadge    :: !(Maybe Int)           -- ^ Badge number.
    , jamSound    :: !(Maybe Text)          -- ^ Sound name.
    , jamCategory :: !(Maybe Text)          -- ^ Notification category.
    } deriving (Generic, Show)
-- | An @aps@ dictionary with no fields set.
emptyMessage :: JsonApsMessage
emptyMessage = JsonApsMessage
    { jamAlert    = Nothing
    , jamBadge    = Nothing
    , jamSound    = Nothing
    , jamCategory = Nothing
    }

-- | Sets the sound name.
setSound
    :: Text
    -> JsonApsMessage
    -> JsonApsMessage
setSound sound msg = msg { jamSound = Just sound }

-- | Removes the sound.
clearSound
    :: JsonApsMessage
    -> JsonApsMessage
clearSound msg = msg { jamSound = Nothing }

-- | Sets the notification category.
setCategory
    :: Text
    -> JsonApsMessage
    -> JsonApsMessage
setCategory category msg = msg { jamCategory = Just category }

-- | Removes the notification category.
clearCategory
    :: JsonApsMessage
    -> JsonApsMessage
clearCategory msg = msg { jamCategory = Nothing }

-- | Sets the badge number.
setBadge
    :: Int
    -> JsonApsMessage
    -> JsonApsMessage
setBadge badge msg = msg { jamBadge = Just badge }

-- | Removes the badge number.
clearBadge
    :: JsonApsMessage
    -> JsonApsMessage
clearBadge msg = msg { jamBadge = Nothing }
-- | An @aps@ dictionary containing only an alert with a title and body.
alertMessage
    :: Text  -- ^ Title.
    -> Text  -- ^ Body.
    -> JsonApsMessage
alertMessage title text = setAlertMessage title text emptyMessage

-- | An @aps@ dictionary containing only an alert body (no title).
bodyMessage
    :: Text  -- ^ Body.
    -> JsonApsMessage
bodyMessage = flip setMessageBody emptyMessage

-- | Replaces the whole alert with one carrying the given title and body.
setAlertMessage
    :: Text  -- ^ Title.
    -> Text  -- ^ Body.
    -> JsonApsMessage
    -> JsonApsMessage
setAlertMessage title text msg =
    msg { jamAlert = Just (JsonApsAlert { jaaTitle = Just title, jaaBody = text }) }

-- | Sets the alert body, keeping an existing title if there is one.
setMessageBody
    :: Text  -- ^ Body.
    -> JsonApsMessage
    -> JsonApsMessage
setMessageBody text msg =
    msg { jamAlert = Just (maybe freshAlert withBody (jamAlert msg)) }
  where
    freshAlert = JsonApsAlert Nothing text
    withBody existing = existing { jaaBody = text }

-- | Removes the alert entirely.
clearAlertMessage
    :: JsonApsMessage
    -> JsonApsMessage
clearAlertMessage msg = msg { jamAlert = Nothing }
-- | Encodes the @aps@ dictionary. Field names are the record field names,
-- lower-cased, with the three-character @jam@ prefix dropped (for example,
-- @jamBadge@ becomes @badge@).
--
-- Unset ('Nothing') fields are omitted instead of being sent as JSON @null@
-- (previously an empty message encoded as
-- @{"alert":null,"badge":null,"sound":null,"category":null}@). This matches
-- the encoding of 'JsonApsAlert' and keeps the payload small.
instance ToJSON JsonApsMessage where
    toJSON = genericToJSON defaultOptions
        { fieldLabelModifier = drop 3 . map toLower
        , omitNothingFields  = True
        }
-- | A complete notification payload: the @aps@ dictionary, an optional
-- app-specific string, and arbitrary extra top-level fields.
data JsonAps = JsonAps
    { jaAps                :: !JsonApsMessage     -- ^ The @aps@ dictionary.
    , jaAppSpecificContent :: !(Maybe Text)       -- ^ Sent as @appspecificcontent@.
    , jaSupplementalFields :: !(Map Text Value)   -- ^ Extra top-level keys.
    } deriving (Generic, Show)

-- | Encodes @aps@ and @appspecificcontent@ (which is @null@ when absent),
-- followed by the supplemental fields in key order.
-- NOTE(review): the supplemental pairs come after the static ones in the
-- list passed to 'object', so a supplemental key named
-- @appspecificcontent@ probably overrides the static one — confirm.
instance ToJSON JsonAps where
    toJSON payload = object $
        [ "aps"                .= jaAps payload
        , "appspecificcontent" .= jaAppSpecificContent payload
        ] ++ M.toList (jaSupplementalFields payload)
-- | Wraps an @aps@ dictionary into a payload with no extra content.
newMessage
    :: JsonApsMessage
    -> JsonAps
newMessage aps = JsonAps
    { jaAps                = aps
    , jaAppSpecificContent = Nothing
    , jaSupplementalFields = M.empty
    }

-- | Like 'newMessage', but also sets the @appspecificcontent@ string.
newMessageWithCustomPayload
    :: JsonApsMessage
    -> Text
    -> JsonAps
newMessageWithCustomPayload message payload =
    (newMessage message) { jaAppSpecificContent = Just payload }

-- | Adds (or replaces) a top-level field in the payload. Calls 'error' if
-- the field name is @aps@, which must not be overwritten.
addSupplementalField :: ToJSON record =>
       Text    -- ^ Field name.
    -> record  -- ^ Field value, encoded with 'toJSON'.
    -> JsonAps
    -> JsonAps
addSupplementalField fieldName fieldValue msg
    | fieldName == "aps" =
        error "The 'aps' field may not be overwritten by user code"
    | otherwise =
        msg { jaSupplementalFields =
                M.insert fieldName (toJSON fieldValue) (jaSupplementalFields msg) }
-- | Creates a session backed by a pool of APNs connections.
--
-- Calls 'error' if the certificate, private key or CA store cannot be
-- loaded. Connections are opened lazily by the pool (one stripe, up to
-- @maxConnectionCount@ connections), and idle connections are closed after
-- 600 seconds.
newSession
    :: FilePath    -- ^ Private key file.
    -> FilePath    -- ^ Client certificate file.
    -> FilePath    -- ^ CA certificate store.
    -> Bool        -- ^ 'True' to use the development endpoint.
    -> Int         -- ^ Max concurrent streams per connection.
    -> Int         -- ^ Max number of pooled connections.
    -> ByteString  -- ^ Value for the @apns-topic@ header.
    -> IO ApnSession
newSession certKey certPath caPath dev maxparallel maxConnectionCount topic = do
    let hostname = if dev
            then "api.development.push.apple.com"
            else "api.push.apple.com"
        connInfo = ApnConnectionInfo certPath certKey caPath hostname maxparallel topic
    certsOk <- checkCertificates connInfo
    unless certsOk $ error "Unable to load certificates and/or the private key"
    openRef <- newIORef True
    let connectionUnusedTimeout :: NominalDiffTime
        connectionUnusedTimeout = 600
    pool <-
        createPool
            (newConnection connInfo) closeApnConnection 1 connectionUnusedTimeout maxConnectionCount
    let session =
            ApnSession
                { apnSessionPool = pool
                , apnSessionOpen = openRef
                }
    -- Previously the finalizer called 'closeSession', which raises 'error'
    -- when the user has already closed the session explicitly. The finalizer
    -- now only tears down the pool if the session is still open.
    -- NOTE(review): GHC documents that finalizers attached with
    -- 'addFinalizer' to non-primitive values may run earlier than expected.
    addFinalizer session $ do
        wasOpen <- atomicModifyIORef' openRef (False,)
        when wasOpen $ destroyAllResources pool
    return session
-- | Marks the session closed and destroys every pooled connection.
-- Calls 'error' if the session was already closed.
closeSession :: ApnSession -> IO ()
closeSession s = do
    wasOpen <- atomicModifyIORef' (apnSessionOpen s) (\previous -> (False, previous))
    if wasOpen
        then destroyAllResources (apnSessionPool s)
        else error "Session is already closed"

-- | Whether the session has not yet been closed.
isOpen :: ApnSession -> IO Bool
isOpen session = readIORef (apnSessionOpen session)
-- | Runs an action with a connection borrowed from the session's pool.
--
-- Calls 'error' (via 'ensureOpen') if the session is closed. A
-- 'ClientError' from the action is rethrown as an exception inside
-- 'withResource', so the pool discards that connection instead of reusing
-- it. The error is then caught by 'try' and returned in 'ClientIO'.
withConnection :: ApnSession -> (ApnConnection -> ClientIO a) -> ClientIO a
withConnection s action = do
    lift (ensureOpen s)
    ExceptT $ try $ withResource (apnSessionPool s) $ \conn ->
        runClientIO (action conn) >>= either throwIO return
-- | 'True' when both the CA store and the client certificate/key pair can
-- be loaded from the configured paths. The CA store is read first.
checkCertificates :: ApnConnectionInfo -> IO Bool
checkCertificates aci =
    (&&)
        <$> (isJust <$> readCertificateStore (aciCaPath aci))
        <*> (isRight <$> credentialLoadX509 (aciCertPath aci) (aciCertKey aci))
-- | Opens one TLS + HTTP/2 connection (port 443) to the host in the given
-- connection info. Used as the pool's resource constructor in 'newSession'.
--
-- The certificate files are re-read here. 'newSession' checks beforehand
-- (via 'checkCertificates') that they load, so the pattern matches below
-- should only fail if the files change later. Such a failure is raised as
-- an 'IOError' by 'fail' in 'IO'.
newConnection :: ApnConnectionInfo -> IO ApnConnection
newConnection aci = do
    Just castore <- readCertificateStore $ aciCaPath aci
    Right credential <- credentialLoadX509 (aciCertPath aci) (aciCertKey aci)
    let hostname = aciHostname aci
        maxConcurrentStreams = aciMaxConcurrentStreams aci
        shared = def
            { sharedCredentials = Credentials [credential]
            , sharedCAStore     = castore
            }
        clip = ClientParams
            { clientUseMaxFragmentLength = Nothing
              -- The service-identifier suffix used to be 'undefined', which
              -- crashes anything that forces it (for example a validation
              -- cache lookup). An explicit empty suffix is safe.
            , clientServerIdentification = (T.unpack hostname, "")
            , clientUseServerNameIndication = True
            , clientWantSessionResume = Nothing
            , clientShared = shared
            , clientHooks = def
                { onCertificateRequest = const . return . Just $ credential }
            , clientDebug = DebugParams
                { debugSeed = Nothing, debugPrintSeed = const $ return () }
            , clientSupported = def
                { supportedVersions = [ TLS12 ]
                , supportedCiphers  = ciphersuite_strong
                }
            }
        conf =
            [ (HTTP2.SettingsMaxFrameSize, 16384)
            , (HTTP2.SettingsMaxConcurrentStreams, maxConcurrentStreams)
            , (HTTP2.SettingsMaxHeaderBlockSize, 4096)
            , (HTTP2.SettingsInitialWindowSize, 65536)
            , (HTTP2.SettingsEnablePush, 1)
            ]
    openRef <- newIORef True
    -- On a GOAWAY from the server, mark this connection as no longer open.
    -- NOTE(review): nothing visible here reads 'apnConnectionOpen', so a
    -- connection that received GOAWAY may stay in the pool — confirm.
    let handleGoAway _ = lift $ writeIORef openRef False
    client <-
        fmap (either throw id) . runClientIO $ do
            frameConn <- newHttp2FrameConnection (T.unpack hostname) 443 (Just clip)
            c <- newHttp2Client frameConn 4096 4096 conf handleGoAway ignoreFallbackHandler
            linkAsyncs c
            return c
    -- Refresh the connection-level incoming window once per second. The
    -- result of each update is deliberately discarded.
    flowWorker <- forkIO $ forever $ do
        _ <- runClientIO $ _updateWindow $ _incomingFlowControl client
        threadDelay 1000000
    workersem <- newQSem maxConcurrentStreams
    return $ ApnConnection client aci workersem flowWorker openRef
-- | Closes one pooled connection; this is the pool's resource destructor.
--
-- Steps, in order:
--
-- * marks the connection closed;
-- * stops its flow-control worker thread;
-- * sends GOAWAY (NoError);
-- * closes the HTTP/2 client.
--
-- Previously GOAWAY and close ran in one 'ClientIO' block, so a failing
-- '_gtfo' short-circuited and '_close' never ran, leaking the connection.
-- '_close' is now always run, even if '_gtfo' throws. An exception from
-- '_gtfo' is still re-raised after '_close' has run.
closeApnConnection :: ApnConnection -> IO ()
closeApnConnection connection = do
    writeIORef (apnConnectionOpen connection) False
    killThread (apnConnectionFlowControlWorker connection)
    let client = apnConnectionConnection connection
    bracket_
        (return ())
        (void $ runClientIO $ _close client)
        (void $ runClientIO $ _gtfo client HTTP2.NoError "")
-- | Sends an already-serialised JSON payload to one device.
sendRawMessage
    :: ApnSession
    -> ApnToken
    -> ByteString  -- ^ Raw JSON payload.
    -> IO ApnMessageResult
sendRawMessage s token payload =
    catchErrors $ withConnection s $ \conn -> sendApnRaw conn token payload

-- | Serialises the payload with 'encode' and sends it to one device.
sendMessage
    :: ApnSession
    -> ApnToken
    -> JsonAps
    -> IO ApnMessageResult
sendMessage s token payload = catchErrors (withConnection s deliver)
  where
    deliver conn = sendApnRaw conn token (L.toStrict (encode payload))

-- | Sends a background (@content-available@) notification with no alert.
sendSilentMessage
    :: ApnSession
    -> ApnToken
    -> IO ApnMessageResult
sendSilentMessage s token = catchErrors (withConnection s deliver)
  where
    deliver conn = sendApnRaw conn token silentPayload
    silentPayload = "{\"aps\":{\"content-available\":1}}"

-- | Calls 'error' if the session has been closed.
ensureOpen :: ApnSession -> IO ()
ensureOpen s = isOpen s >>= \open ->
    if open then return () else error "Session is closed"
-- | Sends one notification on a new HTTP/2 stream of the given connection.
--
-- The send holds one slot of the connection's semaphore for its whole
-- duration. It returns 'ApnMessageResultBackoff' when '_startStream' returns
-- 'Left' (the stream could not be opened).
--
-- Status handling:
--
-- * 200 maps to Ok;
-- * 400, 403, 404, 405, 410 and 413 map to a fatal error;
-- * 429, 500 and 503 map to a temporary error;
-- * for error statuses, the reason is decoded from the JSON body's
--   @reason@ field.
--
-- Throws an 'ApnException' if any of these happen:
--
-- * the stream or a body frame fails;
-- * @:status@ is missing;
-- * the error body cannot be decoded;
-- * the status is not one of the above.
--
-- Fixes over the previous version:
--
-- * no debug printing of push-promise headers to stdout;
-- * multi-frame and empty bodies are handled, where a partial
--   @[Right body]@ pattern used to crash;
-- * 404 (BadPath) and unknown statuses no longer cause pattern-match
--   failures;
-- * JSON errors are thrown in 'ClientIO' rather than lazily from a pure
--   value.
sendApnRaw
    :: ApnConnection
    -> ApnToken
    -> ByteString  -- ^ JSON payload.
    -> ClientIO ApnMessageResult
sendApnRaw connection token message = bracket_
    (lift $ waitQSem (apnConnectionWorkerPool connection))
    (lift $ signalQSem (apnConnectionWorkerPool connection)) $ do
        let aci = apnConnectionInfo connection
            hostname = aciHostname aci
            topic = aciTopic aci
            client = apnConnectionConnection connection
            requestHeaders =
                [ ( ":method", "POST" )
                , ( ":scheme", "https" )
                , ( ":authority", TE.encodeUtf8 hostname )
                , ( ":path", "/3/device/" `S.append` unApnToken token )
                , ( "apns-topic", topic )
                ]
        res <- _startStream client $ \stream ->
            let initStream = headers stream requestHeaders id
                handler isfc osfc = do
                    upload message (HTTP2.setEndHeader . HTTP2.setEndStream) client (_outgoingFlowControl client) stream osfc
                    (errOrHeaders, frameResponses, _) <- waitStream stream isfc ignorePushPromise
                    case errOrHeaders of
                        Left err -> throwIO (ApnExceptionHTTP $ toErrorCodeId err)
                        Right hdrs -> interpretResponse (getHeaderEx ":status" hdrs) frameResponses
            in StreamDefinition initStream handler
        case res of
            Left _ -> return ApnMessageResultBackoff
            Right result -> return result
  where
    -- Push promises are not expected from APNs; ignore any that arrive.
    ignorePushPromise _ _ _ _ _ = return ()

    -- Map the response status to a result. The body is only read for
    -- error statuses.
    interpretResponse status frames
        | status == "200" = return ApnMessageResultOk
        | status `elem` fatalStatuses =
            collectBody frames >>= decodeReason ApnMessageResultFatalError
        | status `elem` temporaryStatuses =
            collectBody frames >>= decodeReason ApnMessageResultTemporaryError
        | otherwise = throwIO ApnExceptionUnexpectedResponse

    fatalStatuses, temporaryStatuses :: [HTTP2.HeaderValue]
    fatalStatuses = ["400", "403", "404", "405", "410", "413"]
    temporaryStatuses = ["429", "500", "503"]

    -- Concatenate all DATA frames; a failed frame aborts with the HTTP error.
    collectBody frames = case partitionEithers frames of
        (err : _, _) -> throwIO (ApnExceptionHTTP $ toErrorCodeId err)
        ([], chunks) -> return (S.concat chunks)

    decodeReason :: FromJSON response => (response -> ApnMessageResult) -> ByteString -> ClientIO ApnMessageResult
    decodeReason ctor = either (throwIO . ApnExceptionJSON) return . decodeBody . L.fromStrict
      where
        decodeBody body =
            eitherDecode body
            >>= parseEither (\obj -> ctor <$> obj .: "reason")

    getHeaderEx :: HTTP2.HeaderName -> [HTTP2.Header] -> HTTP2.HeaderValue
    getHeaderEx name headers = fromMaybe (throw $ ApnExceptionMissingHeader name) (DL.lookup name headers)
-- | Runs a send action and folds its failures into the result.
--
-- * A 'ClientError' returned by the action becomes
--   'ApnMessageResultClientError'.
-- * An 'IOError' thrown while running it becomes 'ApnMessageResultIOError'.
--
-- Other exceptions propagate.
catchErrors :: ClientIO ApnMessageResult -> IO ApnMessageResult
catchErrors act =
    (either ApnMessageResultClientError id <$> runClientIO act)
        `catchIOError` (return . ApnMessageResultIOError)
-- | Permanent rejection reasons. Each constructor's name, minus the
-- @ApnFatalError@ prefix, is the @reason@ string it is parsed from.
-- Unrecognised reason strings become 'ApnFatalErrorOther'.
data ApnFatalError
    = ApnFatalErrorBadCollapseId
    | ApnFatalErrorBadDeviceToken
    | ApnFatalErrorBadExpirationDate
    | ApnFatalErrorBadMessageId
    | ApnFatalErrorBadPriority
    | ApnFatalErrorBadTopic
    | ApnFatalErrorDeviceTokenNotForTopic
    | ApnFatalErrorDuplicateHeaders
    | ApnFatalErrorIdleTimeout
    | ApnFatalErrorMissingDeviceToken
    | ApnFatalErrorMissingTopic
    | ApnFatalErrorPayloadEmpty
    | ApnFatalErrorTopicDisallowed
    | ApnFatalErrorBadCertificate
    | ApnFatalErrorBadCertificateEnvironment
    | ApnFatalErrorExpiredProviderToken
    | ApnFatalErrorForbidden
    | ApnFatalErrorInvalidProviderToken
    | ApnFatalErrorMissingProviderToken
    | ApnFatalErrorBadPath
    | ApnFatalErrorMethodNotAllowed
    | ApnFatalErrorUnregistered
    | ApnFatalErrorPayloadTooLarge
    | ApnFatalErrorOther Text
      -- ^ Any reason string not listed above.
    deriving (Eq, Show, Generic)
-- | Parses a reason string by matching it against the constructor names
-- with the 13-character @ApnFatalError@ prefix removed. Any other string
-- becomes 'ApnFatalErrorOther'. Non-string input fails with the generic
-- parser's message.
instance FromJSON ApnFatalError where
    parseJSON json = case parse genericParser json of
        Success known -> return known
        Error err -> fallback err json
      where
        genericParser = genericParseJSON defaultOptions
            { constructorTagModifier = drop 13
            , sumEncoding = UntaggedValue
            }
        fallback _ (String other) = return (ApnFatalErrorOther other)
        fallback err _ = fail err
-- | Transient rejection reasons. Each constructor's name, minus the
-- @ApnTemporaryError@ prefix, is the @reason@ string it is parsed from.
data ApnTemporaryError
    = ApnTemporaryErrorTooManyProviderTokenUpdates
    | ApnTemporaryErrorTooManyRequests
    | ApnTemporaryErrorInternalServerError
    | ApnTemporaryErrorServiceUnavailable
    | ApnTemporaryErrorShutdown
    deriving (Enum, Eq, Show, Generic)

-- | Generic parser that strips the 17-character @ApnTemporaryError@ prefix
-- from constructor names. Unknown reason strings fail to parse.
instance FromJSON ApnTemporaryError where
    parseJSON = genericParseJSON temporaryOptions
      where
        temporaryOptions = defaultOptions { constructorTagModifier = drop 17 }