module Network.PushNotify.Apns.Send
( sendAPNS
, startAPNS
, closeAPNS
, withAPNS
, feedBackAPNS
) where
import Network.PushNotify.Apns.Types
import Network.PushNotify.Apns.Constants
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM.TChan
import Control.Monad.STM
import Control.Retry
import Data.Certificate.X509 (X509)
import Data.Convertible (convert)
import Data.Default
import Data.Int
import Data.IORef
import Data.Serialize
import Data.Text.Encoding (encodeUtf8,decodeUtf8)
import Data.Time.Clock
import Data.Time.Clock.POSIX
import qualified Data.Aeson.Encode as AE
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import qualified Data.ByteString.Base16 as B16
import qualified Data.HashSet as HS
import qualified Data.HashMap.Strict as HM
import qualified Control.Exception as CE
import qualified Crypto.Random.AESCtr as RNG
import Network
import Network.Socket.Internal (PortNumber(PortNum))
import Network.TLS
import Network.TLS.Extra (ciphersuite_all)
import System.Timeout
-- | TLS client parameters for talking to APNS.
--
-- Connects with TLS 1.1, accepts TLS 1.0 through 1.2, offers every cipher
-- suite from 'ciphersuite_all' and presents the given client certificate and
-- private key, both up front and whenever the server requests a certificate.
--
-- NOTE(review): 'onCertificatesRecv' accepts any server certificate, so the
-- server is never authenticated (man-in-the-middle is possible). Validating
-- against a certificate store should be considered.
connParams :: X509 -> PrivateKey -> Params
connParams cert privateKey =
    defaultParamsClient
        { pConnectVersion    = TLS11
        , pAllowedVersions   = [TLS10, TLS11, TLS12]
        , pCiphers           = ciphersuite_all
        , pCertificates      = credentials
        , onCertificatesRecv = \_ -> return CertificateUsageAccept
        , roleParams         = Client clientParams
        }
  where
    -- The single client credential, reused for the certificate request hook.
    credentials = [(cert, Just privateKey)]
    clientParams = ClientParams
        { clientWantSessionResume    = Nothing
        , clientUseMaxFragmentLength = Nothing
        , clientUseServerName        = Nothing
        , onCertificateRequest       = \_ -> return credentials
        }
-- | Open a TLS connection to the APNS gateway for the configured
-- 'environment' and complete the handshake.
--
-- If the handshake fails, the context (and therefore its underlying handle)
-- is closed before the exception is rethrown. Without this, each failed
-- attempt retried by the caller would leak a file descriptor.
connectAPNS :: APNSConfig -> IO Context
connectAPNS config = do
    let (host, port) =
            case environment config of
                Development -> (cDEVELOPMENT_URL, cDEVELOPMENT_PORT)
                Production  -> (cPRODUCTION_URL , cPRODUCTION_PORT)
                Local       -> (cLOCAL_URL      , cLOCAL_PORT)
    handle <- connectTo host $ PortNumber $ fromInteger port
    rng <- RNG.makeSystem
    ctx <- contextNewOnHandle handle (connParams (apnsCertificate config) (apnsPrivateKey config)) rng
    handshake ctx `CE.onException` contextClose ctx
    return ctx
-- | Start an APNS manager: create the request channel and fork the worker
-- thread that serves it.
--
-- If the worker ever dies with an exception, the manager state is switched
-- to 'Nothing', so subsequent 'sendAPNS' calls fail instead of blocking.
startAPNS :: APNSConfig -> IO APNSManager
startAPNS config = do
    requests <- newTChanIO
    stateRef <- newIORef (Just ())
    let markClosed :: CE.SomeException -> IO ()
        markClosed _ = atomicModifyIORef stateRef (const (Nothing, ()))
    workerId <- forkIO (apnsWorker config requests `CE.catch` markClosed)
    return (APNSManager stateRef requests workerId (timeoutLimit config))
-- | Mark the manager as closed, so later 'sendAPNS' calls fail, then kill
-- its worker thread.
closeAPNS :: APNSManager -> IO ()
closeAPNS manager =
    atomicModifyIORef (mState manager) (const (Nothing, ()))
        >> killThread (mWorkerID manager)
-- | Send a message to every device token it contains, via the manager's
-- worker thread, and report which tokens succeeded and which failed.
--
-- Fails in 'IO' with "APNS Service closed." when the manager has been
-- closed or its worker has died.
sendAPNS :: APNSManager -> APNSmessage -> IO APNSresult
sendAPNS m msg = do
    state <- readIORef $ mState m
    case state of
        Nothing -> fail "APNS Service closed."
        Just () -> do
            let requestChan = mApnsChannel m
            var1 <- newEmptyMVar
            atomically $ writeTChan requestChan (var1, msg)
            -- The worker answers with its error channel and the identifier
            -- it assigned to the first token of this message.
            Just (errorChan, startNum) <- takeMVar var1
            -- Either an error identifier arrives, or the worker signals (by
            -- filling var1 again) that every token was written. In the
            -- latter case, wait the timeout for late error responses.
            v <- race (readChan errorChan)
                      (takeMVar var1 >> threadDelay (mTimeoutLimit m))
            let tokens = deviceTokens msg
                (succeeded, failed) = case v of
                    -- Per Apple's binary interface, the error carries the
                    -- identifier of the notification that failed. Tokens
                    -- numbered before it were accepted. It and all later
                    -- tokens (the worker stops sending and reconnects) are
                    -- reported as failed.
                    Left errId
                        | errId >= startNum ->
                            let (sent, rest) = splitAt (errId - startNum) (HS.toList tokens)
                            in (HS.fromList sent, HS.fromList rest)
                        -- The error belongs to an earlier request (or is 0 for
                        -- a sender failure): none of these tokens count.
                        | otherwise -> (HS.empty, tokens)
                    Right _ -> (tokens, HS.empty)
            return $ APNSresult succeeded failed
-- | Worker loop that owns the connection to the APNS gateway.
--
-- Each round connects (retrying according to 'apnsRetrySettings') and runs
-- two threads over the connection:
--
--   * a sender, which dequeues requests from @requestChan@, numbers the
--     device tokens of each request with consecutive 'Int32' identifiers and
--     writes one notification per token;
--   * a receiver, which waits for data from the gateway and yields the
--     identifier found in it (0 if it cannot be parsed or reading fails).
--
-- As soon as either thread stops, an identifier is broadcast on the error
-- channel (0 when the sender itself failed), so waiting 'sendAPNS' calls can
-- classify their tokens. The connection is then closed and a new round
-- starts.
--
-- Both threads are scoped with 'withAsync' and the context is closed in
-- 'CE.finally'. As a result, neither thread outlives the worker and the
-- connection is released even when the worker is killed (e.g. by
-- 'closeAPNS').
apnsWorker :: APNSConfig -> TChan (MVar (Maybe (Chan Int,Int)) , APNSmessage) -> IO ()
apnsWorker config requestChan = do
    ctx <- recoverAll (apnsRetrySettings config) $ connectAPNS config
    errorChan <- newChan
    lock <- newMVar ()
    runSession ctx errorChan lock `CE.finally` ignoreErrors (contextClose ctx)
    apnsWorker config requestChan
  where
    -- Run sender and receiver until one of them stops, then report the
    -- failing identifier on the error channel.
    runSession :: Context -> Chan Int -> MVar () -> IO ()
    runSession ctx errorChan lock =
        withAsync (orZero $ sender 1 lock requestChan errorChan ctx) $ \s ->
        withAsync (orZero $ receiver ctx) $ \r -> do
            res <- waitEither s r
            case res of
                -- The sender only returns when it failed. Identifiers start
                -- at 1, so 0 marks every waiting request as failed.
                Left _ -> do
                    cancel r
                    writeChan errorChan 0
                -- The gateway sent an error. Take the lock first, so the
                -- sender is not cancelled between dequeuing a request and
                -- handing its start identifier back to 'sendAPNS'.
                Right v -> do
                    takeMVar lock
                    cancel s
                    writeChan errorChan v

    -- Turn any exception into the result 0.
    orZero :: IO Int -> IO Int
    orZero act = CE.catch act handler
      where
        handler :: CE.SomeException -> IO Int
        handler _ = return 0

    -- Best-effort cleanup: discard any exception.
    ignoreErrors :: IO () -> IO ()
    ignoreErrors act = CE.catch act handler
      where
        handler :: CE.SomeException -> IO ()
        handler _ = return ()

    sender :: Int32
           -> MVar ()
           -> TChan (MVar (Maybe (Chan Int,Int)) , APNSmessage)
           -> Chan Int
           -> Context
           -> IO Int
    sender n lock requestChan errorChan c = do
        -- Wait for a request before taking the lock, so the lock is held
        -- only while a request is dequeued and answered.
        _ <- atomically $ peekTChan requestChan
        takeMVar lock
        (var, msg) <- atomically $ readTChan requestChan
        let list = HS.toList $ deviceTokens msg
            len  = convert $ HS.size $ deviceTokens msg
            -- Restart numbering at 1 if this request's identifiers would
            -- overflow Int32.
            num  = if (n + len :: Int32) < 0 then 1 else n
        -- The duplicate only sees identifiers written from now on.
        echan <- dupChan errorChan
        putMVar var $ Just (echan, convert num)
        putMVar lock ()
        ctime <- getPOSIXTime
        _ <- loop var c num (createPut msg ctime) list
        sender (num + len) lock requestChan errorChan c

    -- Read from the gateway, skip two bytes (command and status in APNS's
    -- error-response layout) and return the following big-endian 32-bit
    -- identifier, or 0 if that many bytes are not available.
    receiver :: Context -> IO Int
    receiver c = do
        dat <- recvData c
        case runGet (getWord16be >> getWord32be) dat of
            Right ident -> return (convert ident)
            Left _      -> return 0

    -- Send one notification per token, numbered consecutively from @num@.
    -- Then signal completion to 'sendAPNS' by putting 'Nothing' into @var@.
    loop :: MVar (Maybe (Chan Int,Int))
         -> Context
         -> Int32
         -> (DeviceToken -> Int32 -> Put)
         -> [DeviceToken]
         -> IO Bool
    loop var _ _ _ [] = tryPutMVar var Nothing
    loop var ctx num cput (x:xs) = do
        sendData ctx $ LB.fromChunks [runPut $ cput x num]
        loop var ctx (num + 1) cput xs

    -- Serialize one notification frame: command 1, identifier, expiry,
    -- length-prefixed token bytes and length-prefixed JSON payload.
    -- Without an explicit expiry, the notification expires one day after
    -- @ctime@. The payload and expiry depend only on the message, so they
    -- are computed once and shared by every token.
    createPut :: APNSmessage -> NominalDiffTime -> DeviceToken -> Int32 -> Put
    createPut msg ctime = \dst identifier ->
        let btoken = fst $ B16.decode $ encodeUtf8 dst
        in if LB.length bpayload > 256
            then fail "Too long payload"
            else do
                putWord8 1
                putWord32be $ convert identifier
                putWord32be expiryTime
                putWord16be $ convert $ B.length btoken
                putByteString btoken
                putWord16be $ convert $ LB.length bpayload
                putLazyByteString bpayload
      where
        bpayload   = AE.encode msg
        expiryTime = case expiry msg of
            Nothing -> round (ctime + posixDayLength)
            Just t  -> round (utcTimeToPOSIXSeconds t)
-- | Run an action with a freshly started 'APNSManager', closing it afterwards
-- even if the action throws.
withAPNS :: APNSConfig -> (APNSManager -> IO a) -> IO a
withAPNS confg = CE.bracket (startAPNS confg) closeAPNS
-- | Open a TLS connection to the APNS feedback service for the configured
-- 'environment' and complete the handshake.
--
-- If the handshake fails, the context (and therefore its underlying handle)
-- is closed before the exception is rethrown, so the handle does not leak.
connectFeedBackAPNS :: APNSConfig -> IO Context
connectFeedBackAPNS config = do
    let (host, port) =
            case environment config of
                Development -> (cDEVELOPMENT_FEEDBACK_URL, cDEVELOPMENT_FEEDBACK_PORT)
                Production  -> (cPRODUCTION_FEEDBACK_URL , cPRODUCTION_FEEDBACK_PORT)
                Local       -> (cLOCAL_FEEDBACK_URL      , cLOCAL_FEEDBACK_PORT)
    handle <- connectTo host $ PortNumber $ fromInteger port
    rng <- RNG.makeSystem
    ctx <- contextNewOnHandle handle (connParams (apnsCertificate config) (apnsPrivateKey config)) rng
    handshake ctx `CE.onException` contextClose ctx
    return ctx
-- | Query the APNS feedback service and collect the device tokens it reports,
-- mapped to the time reported for each.
--
-- Tuples are parsed on a separate thread. Collection stops once no tuple has
-- arrived within @timeoutLimit config@ (passed to 'timeout').
--
-- Data is buffered across reads, so several tuples in one record and tuples
-- split across records are all decoded. The receiving thread is always
-- killed and the connection shut down. Exceptions from 'bye' and
-- 'contextClose' are ignored, because the server may already have closed the
-- connection and such a failure must not discard the collected result.
feedBackAPNS :: APNSConfig -> IO APNSFeedBackresult
feedBackAPNS config = do
    ctx <- connectFeedBackAPNS config
    var <- newEmptyMVar
    tID <- forkIO $ loopReceive var ctx
    waitAndCheck var HM.empty `CE.finally` release tID ctx
  where
    -- Stop the receiving thread, then shut the connection down best-effort.
    release :: ThreadId -> Context -> IO ()
    release tID ctx = do
        killThread tID
        ignoreErrors (bye ctx)
        ignoreErrors (contextClose ctx)

    ignoreErrors :: IO () -> IO ()
    ignoreErrors act = CE.catch act handler
      where
        handler :: CE.SomeException -> IO ()
        handler _ = return ()

    -- One tuple: 32-bit timestamp, 16-bit token length, then the token bytes
    -- (returned hex-encoded).
    getData :: Get (DeviceToken,UTCTime)
    getData = do
        time     <- getWord32be
        tokenLen <- getWord16be
        dtoken   <- getBytes $ convert tokenLen
        return ( decodeUtf8 $ B16.encode dtoken
               , posixSecondsToUTCTime $ fromInteger $ convert time )

    -- Decode as many complete tuples as the buffer holds, handing each to
    -- @var@. Read more data when the buffer runs short, and stop when the
    -- connection yields no more data.
    loopReceive :: MVar (DeviceToken,UTCTime) -> Context -> IO ()
    loopReceive var ctx = go B.empty
      where
        go buffer =
            case runGetState getData buffer 0 of
                Right (tuple, rest) -> do
                    putMVar var tuple
                    go rest
                Left _ -> do
                    dat <- recvData ctx
                    if B.null dat
                        then return ()
                        else go (B.append buffer dat)

    waitAndCheck :: MVar (DeviceToken,UTCTime) -> HM.HashMap DeviceToken UTCTime -> IO APNSFeedBackresult
    waitAndCheck var hmap = do
        v <- timeout (timeoutLimit config) $ takeMVar var
        case v of
            Nothing     -> return $ APNSFeedBackresult hmap
            Just (d, t) -> waitAndCheck var (HM.insert d t hmap)