{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
module Network.PushNotify.APN
( newSession
, newMessage
, newMessageWithCustomPayload
, hexEncodedToken
, rawToken
, sendMessage
, sendSilentMessage
, sendRawMessage
, alertMessage
, bodyMessage
, emptyMessage
, setAlertMessage
, setMessageBody
, setBadge
, setCategory
, setSound
, clearAlertMessage
, clearBadge
, clearCategory
, clearSound
, addSupplementalField
, closeSession
, isOpen
, ApnSession
, JsonAps
, JsonApsAlert
, JsonApsMessage
, ApnMessageResult(..)
, ApnFatalError(..)
, ApnTemporaryError(..)
, ApnToken
) where
import Control.Concurrent
import Control.Concurrent.QSem
import Control.Exception
import Control.Monad
import Data.Aeson
import Data.Aeson.Types
import Data.ByteString (ByteString)
import Data.Char (toLower)
import Data.Default (def)
import Data.Either
import Data.Int
import Data.IORef
import Data.Map.Strict (Map)
import Data.Maybe
import Data.Semigroup ((<>))
import Data.Text (Text)
import Data.Time.Clock.POSIX
import Data.Typeable (Typeable)
import Data.X509
import Data.X509.CertificateStore
import GHC.Generics
import Network.HTTP2 (ErrorCodeId,
toErrorCodeId)
import Network.HTTP2.Client
import Network.HTTP2.Client.FrameConnection
import Network.HTTP2.Client.Helpers
import Network.TLS hiding (sendData)
import Network.TLS.Extra.Cipher
import System.IO.Error
import System.Mem.Weak
import System.Random
import qualified Data.ByteString as S
import qualified Data.ByteString.Base16 as B16
import qualified Data.ByteString.Lazy as L
import qualified Data.List as DL
import qualified Data.Map.Strict as M
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Network.HPACK as HTTP2
import qualified Network.HTTP2 as HTTP2
-- | A handle to APNs: a pool of HTTP/2 connections plus the settings used to
-- open new ones. Created by 'newSession' and shut down with 'closeSession'.
data ApnSession = ApnSession
    { apnSessionPool :: !(IORef [ApnConnection])
      -- ^ Connections available for reuse. 'withConnection' takes a
      -- connection out of this list while it is in use and puts it back
      -- afterwards; 'closeSession' empties it and closes what it finds.
    , apnSessionConnectionInfo :: !ApnConnectionInfo
      -- ^ Settings 'newConnection' uses to open additional connections.
    , apnSessionConnectionManager :: !ThreadId
      -- ^ Thread running 'manage' (idle-connection cleanup); killed by
      -- 'closeSession'.
    , apnSessionOpen :: !(IORef Bool)
      -- ^ 'False' once 'closeSession' has run; checked by 'ensureOpen'.
    }
-- | Everything needed to open a connection to the APNs gateway.
data ApnConnectionInfo = ApnConnectionInfo
    { aciCertPath :: !FilePath
      -- ^ Client certificate file (first argument to 'credentialLoadX509').
    , aciCertKey :: !FilePath
      -- ^ Private key file (second argument to 'credentialLoadX509').
    , aciCaPath :: !FilePath
      -- ^ Path handed to 'readCertificateStore' to build the CA store.
    , aciHostname :: !Text
      -- ^ Gateway host. Used for the TCP/TLS connection (port 443), as the
      -- TLS server identification and as the @:authority@ header.
    , aciMaxConcurrentStreams :: !Int
      -- ^ Advertised as SETTINGS_MAX_CONCURRENT_STREAMS and used as the
      -- initial count of each connection's worker semaphore.
    , aciTopic :: !ByteString
      -- ^ Sent verbatim as the @apns-topic@ request header.
    }
-- | One live HTTP/2 connection to the gateway, as held in the session pool.
data ApnConnection = ApnConnection
    { apnConnectionConnection :: !Http2Client
      -- ^ The underlying HTTP/2 client.
    , apnConnectionInfo :: !ApnConnectionInfo
      -- ^ Settings this connection was opened with.
    , apnConnectionWorkerPool :: !QSem
      -- ^ Created with 'aciMaxConcurrentStreams' units; 'sendApnRaw' holds
      -- one unit for the duration of each request.
    , apnConnectionLastUsed :: !Int64
      -- ^ POSIX time in whole seconds (rounded 'getPOSIXTime') when the
      -- connection was created or last handed out; 'manage' compares it
      -- against its idle timeout.
    , apnConnectionFlowControlWorker :: !ThreadId
      -- ^ Thread that periodically calls '_updateWindow' on the incoming
      -- flow control; killed by 'closeApnConnection'.
    , apnConnectionOpen :: !(IORef Bool)
      -- ^ Set to 'False' by the GOAWAY handler installed in 'newConnection'
      -- and by 'closeApnConnection'; 'withConnection' skips such connections.
    }
-- | A device token in the hexadecimal form that 'sendApnRaw' places in the
-- request path (@/3/device/<token>@). Build one with 'rawToken' or
-- 'hexEncodedToken'.
newtype ApnToken = ApnToken { unApnToken :: ByteString }
-- | Result types that can carry a caught 'IOError' as an ordinary value.
-- 'catchIOErrors' uses 'isAnError' to turn the exception into a result.
class SpecifyError a where
    isAnError :: IOError -> a
-- | Wrap a binary device token, hex-encoding it for use in the request path.
rawToken
    :: ByteString -- ^ token bytes exactly as delivered by the device
    -> ApnToken
rawToken bytes = ApnToken (B16.encode bytes)
-- | Wrap a device token that is already hex-encoded text.
--
-- The text is decoded and re-encoded, which normalises it to lowercase hex.
-- Decoding stops at the first character that is not valid hex, and whatever
-- follows is silently discarded, so malformed input yields a shorter (possibly
-- empty) token rather than an error.
hexEncodedToken
    :: Text -- ^ hexadecimal representation of the token
    -> ApnToken
hexEncodedToken hex =
    let (bytes, _unparsed) = B16.decode (TE.encodeUtf8 hex)
    in ApnToken (B16.encode bytes)
-- | Exceptions used by 'sendApnRaw' when a response cannot be turned into an
-- 'ApnMessageResult'.
data ApnException = ApnExceptionHTTP ErrorCodeId
                    -- ^ The HTTP/2 stream reported an error code.
                  | ApnExceptionJSON String
                    -- ^ The error body could not be decoded; carries the
                    -- decoder's message.
                  | ApnExceptionMissingHeader HTTP2.HeaderName
                    -- ^ A required response header (such as @:status@) was absent.
                  | ApnExceptionUnexpectedResponse
                    -- ^ NOTE(review): presumably a response that matches no
                    -- expected shape.
    deriving (Show, Typeable)
instance Exception ApnException
-- | Outcome of sending a single notification.
data ApnMessageResult = ApnMessageResultOk
                        -- ^ The gateway answered with status 200.
                      | ApnMessageResultBackoff
                        -- ^ The request stream could not be started on the
                        -- chosen connection ('_startStream' returned 'Left').
                      | ApnMessageResultFatalError ApnFatalError
                        -- ^ A non-200 status that 'sendApnRaw' maps to a fatal
                        -- error, with the decoded @reason@.
                      | ApnMessageResultTemporaryError ApnTemporaryError
                        -- ^ A non-200 status that 'sendApnRaw' maps to a
                        -- temporary error, with the decoded @reason@.
                      | ApnMessageResultIOError IOError
                        -- ^ An 'IOError' raised while sending, converted to a
                        -- value by 'catchIOErrors'.
    deriving (Eq, Show)
-- | Lets 'catchIOErrors' represent a caught 'IOError' as a result.
instance SpecifyError ApnMessageResult where
    isAnError = ApnMessageResultIOError
-- | The @alert@ dictionary inside @aps@.
data JsonApsAlert = JsonApsAlert
    { jaaTitle :: !(Maybe Text) -- ^ optional title; omitted from JSON when absent
    , jaaBody  :: !Text         -- ^ alert text
    } deriving (Generic, Show)

-- | Serialised with the @jaa@ prefix stripped and keys lowercased
-- (@title@, @body@); a missing title is left out entirely.
instance ToJSON JsonApsAlert where
    toJSON = genericToJSON alertOptions
      where
        alertOptions = defaultOptions
            { omitNothingFields  = True
            , fieldLabelModifier = map toLower . drop 3
            }
-- | The contents of the @aps@ dictionary. The 'ToJSON' instance below writes
-- each field under its name with the three-character @jam@ prefix dropped and
-- the rest lowercased (@alert@, @badge@, @sound@, @category@).
data JsonApsMessage
    = JsonApsMessage
    { jamAlert :: !(Maybe JsonApsAlert)
    , jamBadge :: !(Maybe Int)
    , jamSound :: !(Maybe Text)
    , jamCategory :: !(Maybe Text)
    } deriving (Generic, Show)
-- | An @aps@ dictionary with no alert, badge, sound or category set.
emptyMessage :: JsonApsMessage
emptyMessage = JsonApsMessage
    { jamAlert    = Nothing
    , jamBadge    = Nothing
    , jamSound    = Nothing
    , jamCategory = Nothing
    }
-- | Set (or replace) the sound name.
setSound :: Text -> JsonApsMessage -> JsonApsMessage
setSound sound msg = msg { jamSound = pure sound }
-- | Remove the sound name.
clearSound :: JsonApsMessage -> JsonApsMessage
clearSound = \msg -> msg { jamSound = Nothing }
-- | Set (or replace) the notification category.
setCategory :: Text -> JsonApsMessage -> JsonApsMessage
setCategory category msg = msg { jamCategory = pure category }
-- | Remove the notification category.
clearCategory :: JsonApsMessage -> JsonApsMessage
clearCategory = \msg -> msg { jamCategory = Nothing }
-- | Set (or replace) the badge number.
setBadge :: Int -> JsonApsMessage -> JsonApsMessage
setBadge badge msg = msg { jamBadge = pure badge }
-- | Remove the badge number.
clearBadge :: JsonApsMessage -> JsonApsMessage
clearBadge = \msg -> msg { jamBadge = Nothing }
-- | A message whose only content is an alert with the given title and body.
alertMessage
    :: Text -- ^ title
    -> Text -- ^ body
    -> JsonApsMessage
alertMessage title body =
    emptyMessage { jamAlert = Just (JsonApsAlert (Just title) body) }
-- | A message whose only content is an untitled alert with the given body.
bodyMessage
    :: Text -- ^ body
    -> JsonApsMessage
bodyMessage body = emptyMessage { jamAlert = Just (JsonApsAlert Nothing body) }
-- | Replace the alert with one that has the given title and body.
setAlertMessage
    :: Text -- ^ title
    -> Text -- ^ body
    -> JsonApsMessage
    -> JsonApsMessage
setAlertMessage title body msg =
    msg { jamAlert = Just (JsonApsAlert { jaaTitle = Just title, jaaBody = body }) }
-- | Set the alert body, keeping an existing title; creates an untitled alert
-- when none is present.
setMessageBody
    :: Text -- ^ body
    -> JsonApsMessage
    -> JsonApsMessage
setMessageBody body msg = msg { jamAlert = Just updated }
  where
    updated = maybe (JsonApsAlert Nothing body)
                    (\alert -> alert { jaaBody = body })
                    (jamAlert msg)
-- | Remove the alert entirely (title and body).
clearAlertMessage :: JsonApsMessage -> JsonApsMessage
clearAlertMessage = \msg -> msg { jamAlert = Nothing }
-- | Serialise the @aps@ dictionary. Keys are the field names with the @jam@
-- prefix dropped and lowercased (@alert@, @badge@, @sound@, @category@).
--
-- Unset ('Nothing') fields are omitted rather than written as @null@. This
-- matches the 'JsonApsAlert' instance and keeps placeholders such as
-- @"badge":null@ out of the size-limited payload.
instance ToJSON JsonApsMessage where
    toJSON = genericToJSON defaultOptions
        { fieldLabelModifier = drop 3 . map toLower
        , omitNothingFields  = True
        }
-- | A complete notification payload. The 'ToJSON' instance writes 'jaAps'
-- under @"aps"@, 'jaAppSpecificContent' under @"appspecificcontent"@, and
-- every entry of 'jaSupplementalFields' as an additional top-level key.
data JsonAps
    = JsonAps
    { jaAps :: !JsonApsMessage
    , jaAppSpecificContent :: !(Maybe Text)
    , jaSupplementalFields :: !(Map Text Value)
      -- ^ Extra top-level keys; 'addSupplementalField' refuses @"aps"@.
    } deriving (Generic, Show)
-- | Top-level payload object: the fixed @aps@ and @appspecificcontent@ keys
-- first, followed by the supplemental fields in key order.
instance ToJSON JsonAps where
    toJSON msg = object $
          ("aps" .= jaAps msg)
        : ("appspecificcontent" .= jaAppSpecificContent msg)
        : M.toList (jaSupplementalFields msg)
-- | Wrap an @aps@ dictionary into a payload with no custom content.
newMessage
    :: JsonApsMessage
    -> JsonAps
newMessage aps = JsonAps
    { jaAps                = aps
    , jaAppSpecificContent = Nothing
    , jaSupplementalFields = M.empty
    }
-- | Like 'newMessage', additionally setting the @appspecificcontent@ text.
newMessageWithCustomPayload
    :: JsonApsMessage
    -> Text -- ^ value for @appspecificcontent@
    -> JsonAps
newMessageWithCustomPayload message payload =
    (newMessage message) { jaAppSpecificContent = Just payload }
-- | Add (or replace) a top-level key in the payload.
--
-- Calls 'error' when the key is @"aps"@, which is reserved for the
-- @aps@ dictionary.
addSupplementalField :: ToJSON record =>
       Text   -- ^ key
    -> record -- ^ value, converted with 'toJSON'
    -> JsonAps
    -> JsonAps
addSupplementalField fieldName fieldValue oldAPN
    | fieldName == "aps" = error "The 'aps' field may not be overwritten by user code"
    | otherwise =
        let extended = M.insert fieldName (toJSON fieldValue) (jaSupplementalFields oldAPN)
        in oldAPN { jaSupplementalFields = extended }
-- | Create a session after checking that the certificate, key and CA store
-- can be loaded (calls 'error' otherwise).
--
-- Connections are opened lazily by 'withConnection'. A manager thread
-- ('manage') closes connections idle for more than 600 seconds.
--
-- The session is closed automatically once it becomes unreachable. The
-- finalizer is attached to the session's open-flag 'IORef' with
-- 'mkWeakIORef': finalizers on ordinary Haskell values (as 'addFinalizer'
-- on the record did) may run too early. The finalizer only closes a session
-- that is still open, because 'closeSession' calls 'error' on a second call.
newSession
    :: FilePath   -- ^ private key file
    -> FilePath   -- ^ certificate file
    -> FilePath   -- ^ CA store path
    -> Bool       -- ^ use the development gateway
    -> Int        -- ^ maximum concurrent streams per connection
    -> ByteString -- ^ topic, sent as @apns-topic@
    -> IO ApnSession
newSession certKey certPath caPath dev maxparallel topic = do
    let hostname = if dev
            then "api.development.push.apple.com"
            else "api.push.apple.com"
        connInfo = ApnConnectionInfo certPath certKey caPath hostname maxparallel topic
    certsOk <- checkCertificates connInfo
    unless certsOk $ error "Unable to load certificates and/or the private key"
    connections <- newIORef []
    connectionManager <- forkIO $ manage 600 connections
    openRef <- newIORef True
    let session = ApnSession connections connInfo connectionManager openRef
    _ <- mkWeakIORef openRef $ do
        stillOpen <- readIORef openRef
        when stillOpen $ closeSession session
    return session
-- | Close a session: mark it closed, stop the idle-connection manager and
-- close every connection currently in the pool. Calls 'error' if the session
-- was already closed.
closeSession :: ApnSession -> IO ()
closeSession session = do
    wasOpen <- atomicModifyIORef' (apnSessionOpen session) (\old -> (False, old))
    unless wasOpen $ error "Session is already closed"
    killThread $ apnSessionConnectionManager session
    pooled <- atomicModifyIORef' (apnSessionPool session) (\conns -> ([], conns))
    mapM_ closeApnConnection pooled
-- | Whether 'closeSession' has not yet been called on this session.
isOpen :: ApnSession -> IO Bool
isOpen session = readIORef (apnSessionOpen session)
-- | Run an action with a connection from the session's pool.
--
-- A connection is picked at random and removed from the pool for the
-- duration of the action. It is then returned with a refreshed
-- 'apnConnectionLastUsed'. When the pool is empty, a new connection is opened
-- and added to the pool after use. Pooled connections whose open flag is
-- 'False' are dropped and another one is tried. Calls 'error' (via
-- 'ensureOpen') if the session is closed.
--
-- The random pick and the removal happen in one 'atomicModifyIORef'' on the
-- /current/ pool contents. Concurrent callers therefore can no longer remove
-- the wrong element, or index past the end of a list another thread has
-- shrunk, as the snapshot-then-'removeNth' approach could.
--
-- The connection is put back with 'finally'. An exception from the action
-- thus no longer strands an open connection outside the pool, where neither
-- 'manage' nor 'closeSession' could close it.
withConnection :: ApnSession -> (ApnConnection -> IO a) -> IO a
withConnection s action = do
    ensureOpen s
    seed <- randomIO :: IO Int
    taken <- atomicModifyIORef' pool (takeAt seed)
    case taken of
        Nothing -> do
            conn <- newConnection s
            action conn `finally` release conn
        Just conn -> do
            open <- readIORef (apnConnectionOpen conn)
            if open
                then do
                    currtime <- round <$> getPOSIXTime :: IO Int64
                    let conn1 = conn { apnConnectionLastUsed = currtime }
                    action conn1 `finally` release conn1
                -- Closed connection: leave it out of the pool and try again.
                else withConnection s action
  where
    pool = apnSessionPool s

    -- Put a connection back at the head of the pool.
    release conn = atomicModifyIORef' pool (\conns -> (conn : conns, ()))

    -- Remove the element at (seed `mod` length) of the list as it is now;
    -- 'mod' with a positive divisor is never negative.
    takeAt _ [] = ([], Nothing)
    takeAt seed conns =
        case splitAt (seed `mod` length conns) conns of
            (before, conn : after) -> (before ++ after, Just conn)
            (_, [])                -> (conns, Nothing)
-- | 'True' when both the CA store and the client certificate/key pair load
-- successfully from the configured paths.
checkCertificates :: ApnConnectionInfo -> IO Bool
checkCertificates aci = do
    store <- readCertificateStore (aciCaPath aci)
    cred <- credentialLoadX509 (aciCertPath aci) (aciCertKey aci)
    pure $ case (store, cred) of
        (Just _, Right _) -> True
        _                 -> False
-- | Replace the element at the given zero-based index.
--
-- Total: an index past the end (or a negative index) leaves the list
-- unchanged instead of failing with a non-exhaustive pattern match.
replaceNth :: (Eq n, Num n) => n -> a -> [a] -> [a]
replaceNth _ _ [] = []
replaceNth n newVal (x:xs)
    | n == 0    = newVal : xs
    | otherwise = x : replaceNth (n - 1) newVal xs
-- | Remove the element at the given zero-based index.
--
-- Total: an index past the end (or a negative index) leaves the list
-- unchanged instead of failing with a non-exhaustive pattern match.
removeNth :: (Eq n, Num n) => n -> [a] -> [a]
removeNth _ [] = []
removeNth n (x:xs)
    | n == 0    = xs
    | otherwise = x : removeNth (n - 1) xs
-- | Background loop that closes idle pooled connections.
--
-- Every 60 seconds it removes from the pool each connection whose
-- 'apnConnectionLastUsed' is more than @timeout@ seconds old, and closes it.
-- Connections lent out by 'withConnection' are not in the pool, so they are
-- never closed here.
--
-- Closing an expired connection is best-effort. The peer may already have
-- dropped it, and a synchronous exception from 'closeApnConnection' used to
-- kill this thread permanently. Such exceptions are now ignored, while
-- asynchronous ones (e.g. the 'killThread' from 'closeSession') still
-- terminate the loop.
manage :: Int64 -> IORef [ApnConnection] -> IO ()
manage timeout ioref = forever $ do
    currtime <- round <$> getPOSIXTime :: IO Int64
    let minTime = currtime - timeout
        isExpired c = apnConnectionLastUsed c < minTime
    expiredOnes <- atomicModifyIORef' ioref $ \conns ->
        let (expired, alive) = DL.partition isExpired conns
        in (alive, expired)
    mapM_ closeQuietly expiredOnes
    threadDelay 60000000
  where
    closeQuietly c = closeApnConnection c `catch` rethrowAsync

    rethrowAsync :: SomeException -> IO ()
    rethrowAsync e = case fromException e of
        Just asyncEx -> throwIO (asyncEx :: SomeAsyncException)
        Nothing      -> return ()
-- | Open a new TLS + HTTP/2 connection to the session's gateway host on
-- port 443, authenticating with the configured client certificate.
--
-- Failure to load the CA store or the certificate/key now raises an 'IOError'
-- with a descriptive message. Previously it raised an opaque pattern-match
-- failure; the 'IOError' type is unchanged, so 'catchIOErrors' still handles
-- it. The TLS service-identification suffix was 'undefined' and is now an
-- empty 'ByteString'.
newConnection :: ApnSession -> IO ApnConnection
newConnection apnSession = do
    let aci = apnSessionConnectionInfo apnSession
        hostname = aciHostname aci
        maxConcurrentStreams = aciMaxConcurrentStreams aci
    castore <- readCertificateStore (aciCaPath aci)
        >>= maybe (ioError . userError $ "Unable to load the CA certificate store from " ++ aciCaPath aci)
                  return
    credential <- credentialLoadX509 (aciCertPath aci) (aciCertKey aci)
        >>= either (ioError . userError . ("Unable to load the client certificate and/or private key: " ++))
                   return
    let credentials = Credentials [credential]
        shared = def { sharedCredentials = credentials
                     , sharedCAStore = castore }
        clip = ClientParams
            { clientUseMaxFragmentLength = Nothing
            , clientServerIdentification = (T.unpack hostname, "")
            , clientUseServerNameIndication = True
            , clientWantSessionResume = Nothing
            , clientShared = shared
            , clientHooks = def
                { onCertificateRequest = const . return . Just $ credential }
            , clientDebug = DebugParams { debugSeed = Nothing, debugPrintSeed = const $ return () }
            , clientSupported = def
                { supportedVersions = [ TLS12 ]
                , supportedCiphers = ciphersuite_strong }
            }
        conf = [ (HTTP2.SettingsMaxFrameSize, 16384)
               , (HTTP2.SettingsMaxConcurrentStreams, maxConcurrentStreams)
               , (HTTP2.SettingsMaxHeaderBlockSize, 4096)
               , (HTTP2.SettingsInitialWindowSize, 65536)
               , (HTTP2.SettingsEnablePush, 1)
               ]
    httpFrameConnection <- newHttp2FrameConnection (T.unpack hostname) 443 (Just clip)
    openRef <- newIORef True
    -- A GOAWAY from the server marks the connection unusable for
    -- 'withConnection'.
    let handleGoAway _ = writeIORef openRef False
    client <- newHttp2Client httpFrameConnection 4096 4096 conf handleGoAway ignoreFallbackHandler
    -- NOTE(review): this links the client's background threads to the calling
    -- thread (whichever caller first needed a connection); confirm intended.
    linkAsyncs client
    -- Periodically replenish the incoming flow-control window.
    flowWorker <- forkIO $ forever $ do
        _ <- _updateWindow $ _incomingFlowControl client
        threadDelay 1000000
    workersem <- newQSem maxConcurrentStreams
    currtime <- round <$> getPOSIXTime :: IO Int64
    return $ ApnConnection client aci workersem currtime flowWorker openRef
-- | Close one connection: mark it closed, stop its flow-control worker, send
-- GOAWAY and close the underlying client.
--
-- '_close' runs even if sending GOAWAY throws (e.g. the peer already closed
-- the socket), so the connection's resources are released either way. Any
-- exception is still re-raised after '_close'.
closeApnConnection :: ApnConnection -> IO ()
closeApnConnection connection = do
    writeIORef (apnConnectionOpen connection) False
    killThread (apnConnectionFlowControlWorker connection)
    let client = apnConnectionConnection connection
    _gtfo client HTTP2.NoError "" `finally` _close client
-- | Send an already-serialised JSON payload to one device. An 'IOError' raised
-- while sending becomes 'ApnMessageResultIOError'.
sendRawMessage
    :: ApnSession
    -> ApnToken
    -> ByteString -- ^ JSON payload, sent as-is
    -> IO ApnMessageResult
sendRawMessage s token payload =
    catchIOErrors (withConnection s (\conn -> sendApnRaw conn token payload))
-- | Serialise a payload to JSON and send it to one device.
sendMessage
    :: ApnSession
    -> ApnToken
    -> JsonAps
    -> IO ApnMessageResult
sendMessage s token payload = sendRawMessage s token (L.toStrict (encode payload))
-- | Send a payload containing only @content-available: 1@ to one device.
sendSilentMessage
    :: ApnSession
    -> ApnToken
    -> IO ApnMessageResult
sendSilentMessage s token = sendRawMessage s token "{\"aps\":{\"content-available\":1}}"
-- | Call 'error' if the session has already been closed.
ensureOpen :: ApnSession -> IO ()
ensureOpen session = do
    stillOpen <- isOpen session
    when (not stillOpen) $ error "Session is closed"
-- | Send one request on a connection and interpret the response.
--
-- Holds one unit of the connection's worker semaphore for the whole request.
-- Returns 'ApnMessageResultBackoff' when the stream cannot be started.
-- Otherwise the @:status@ header decides the result:
--
-- * 200 is success.
-- * 400, 403, 404, 405, 410 and 413 are fatal.
-- * 429, 500 and 503 are temporary.
--
-- Error bodies are decoded as @{"reason": ...}@.
--
-- Throws 'ApnException' for:
--
-- * stream errors;
-- * a missing @:status@;
-- * undecodable bodies;
-- * any other status ('ApnExceptionUnexpectedResponse').
--
-- Changes from the previous version:
--
-- * 404 (BadPath) is now handled.
-- * Unknown statuses throw instead of hitting a non-exhaustive @case@.
-- * Multi-frame bodies are concatenated instead of failing a @[Right body]@
--   pattern.
-- * The debug 'print' of push-promise headers is removed.
-- * The result is 'evaluate'd, so decoding exceptions surface here rather
--   than later in the caller's code.
sendApnRaw
    :: ApnConnection
    -> ApnToken
    -> ByteString
    -> IO ApnMessageResult
sendApnRaw connection token message =
    bracket_ (waitQSem workerPool) (signalQSem workerPool) $ do
        res <- _startStream client $ \stream ->
            StreamDefinition (headers stream requestHeaders id) (handleStream stream)
        return $ either (const ApnMessageResultBackoff) id res
  where
    workerPool = apnConnectionWorkerPool connection
    aci        = apnConnectionInfo connection
    hostname   = aciHostname aci
    topic      = aciTopic aci
    client     = apnConnectionConnection connection

    requestHeaders =
        [ ( ":method", "POST" )
        , ( ":scheme", "https" )
        , ( ":authority", TE.encodeUtf8 hostname )
        , ( ":path", "/3/device/" `S.append` unApnToken token )
        , ( "apns-topic", topic )
        ]

    -- Upload the body (ending the stream), then wait for and classify the reply.
    handleStream stream isfc osfc = do
        upload message (HTTP2.setEndHeader . HTTP2.setEndStream) client (_outgoingFlowControl client) stream osfc
        (errOrHeaders, frameResponses, _) <- waitStream stream isfc ignorePushPromise
        case errOrHeaders of
            Left err   -> throwIO (ApnExceptionHTTP $ toErrorCodeId err)
            Right hdrs -> evaluate $
                classify (getHeaderEx ":status" hdrs) (responseBody frameResponses)

    -- APNs is not expected to push; any promised stream is ignored.
    ignorePushPromise _ _ _ _ _ = return ()

    -- Concatenate all DATA frames; a frame-level error becomes ApnExceptionHTTP.
    -- Lazy, so a 200 response never inspects the (empty) body.
    responseBody frames =
        either (throw . ApnExceptionHTTP . toErrorCodeId) S.concat (sequence frames)

    classify :: HTTP2.HeaderValue -> ByteString -> ApnMessageResult
    classify status body = case status of
        "200" -> ApnMessageResultOk
        "400" -> decodeReason ApnMessageResultFatalError body
        "403" -> decodeReason ApnMessageResultFatalError body
        "404" -> decodeReason ApnMessageResultFatalError body
        "405" -> decodeReason ApnMessageResultFatalError body
        "410" -> decodeReason ApnMessageResultFatalError body
        "413" -> decodeReason ApnMessageResultFatalError body
        "429" -> decodeReason ApnMessageResultTemporaryError body
        "500" -> decodeReason ApnMessageResultTemporaryError body
        "503" -> decodeReason ApnMessageResultTemporaryError body
        _     -> throw ApnExceptionUnexpectedResponse

    decodeReason :: FromJSON response => (response -> ApnMessageResult) -> ByteString -> ApnMessageResult
    decodeReason ctor = either (throw . ApnExceptionJSON) id . decodeBody . L.fromStrict
      where
        decodeBody body =
            eitherDecode body
            >>= parseEither (\obj -> ctor <$> obj .: "reason")

    getHeaderEx :: HTTP2.HeaderName -> [HTTP2.Header] -> HTTP2.HeaderValue
    getHeaderEx name hdrs = fromMaybe (throw $ ApnExceptionMissingHeader name) (DL.lookup name hdrs)
-- | Run an action, turning any 'IOError' it raises into a result value via
-- 'isAnError'. Other exception types propagate unchanged.
catchIOErrors :: SpecifyError a => IO a -> IO a
catchIOErrors action = action `catchIOError` (pure . isAnError)
-- | Reasons attached to responses that 'sendApnRaw' maps to
-- 'ApnMessageResultFatalError'. Any unrecognised reason string is kept as
-- 'ApnFatalErrorOther'.
data ApnFatalError = ApnFatalErrorBadCollapseId
                   | ApnFatalErrorBadDeviceToken
                   | ApnFatalErrorBadExpirationDate
                   | ApnFatalErrorBadMessageId
                   | ApnFatalErrorBadPriority
                   | ApnFatalErrorBadTopic
                   | ApnFatalErrorDeviceTokenNotForTopic
                   | ApnFatalErrorDuplicateHeaders
                   | ApnFatalErrorIdleTimeout
                   | ApnFatalErrorMissingDeviceToken
                   | ApnFatalErrorMissingTopic
                   | ApnFatalErrorPayloadEmpty
                   | ApnFatalErrorTopicDisallowed
                   | ApnFatalErrorBadCertificate
                   | ApnFatalErrorBadCertificateEnvironment
                   | ApnFatalErrorExpiredProviderToken
                   | ApnFatalErrorForbidden
                   | ApnFatalErrorInvalidProviderToken
                   | ApnFatalErrorMissingProviderToken
                   | ApnFatalErrorBadPath
                   | ApnFatalErrorMethodNotAllowed
                   | ApnFatalErrorUnregistered
                   | ApnFatalErrorPayloadTooLarge
                   | ApnFatalErrorOther Text
    deriving (Eq, Show, Generic)

-- | Parse a reason string. Constructor names minus the 13-character
-- @ApnFatalError@ prefix are tried first (untagged encoding). If that fails
-- and the input is a string, it becomes 'ApnFatalErrorOther'; anything else
-- is a parse failure.
instance FromJSON ApnFatalError where
    parseJSON json = case parse genericParser json of
        Success known -> return known
        Error err
            | String other <- json -> return (ApnFatalErrorOther other)
            | otherwise            -> fail err
      where
        genericParser = genericParseJSON defaultOptions
            { constructorTagModifier = drop 13
            , sumEncoding = UntaggedValue
            }
-- | Reasons attached to responses that 'sendApnRaw' maps to
-- 'ApnMessageResultTemporaryError'.
data ApnTemporaryError = ApnTemporaryErrorTooManyProviderTokenUpdates
                       | ApnTemporaryErrorTooManyRequests
                       | ApnTemporaryErrorInternalServerError
                       | ApnTemporaryErrorServiceUnavailable
                       | ApnTemporaryErrorShutdown
    deriving (Enum, Eq, Show, Generic)

-- | Parse a reason string: the constructor name minus the 17-character
-- @ApnTemporaryError@ prefix. Unknown strings are a parse failure.
instance FromJSON ApnTemporaryError where
    parseJSON = genericParseJSON temporaryOptions
      where
        temporaryOptions = defaultOptions { constructorTagModifier = \tag -> drop 17 tag }