module Network.PushNotify.Apns.Send
    ( sendAPNS
    , startAPNS
    , closeAPNS
    , withAPNS
    , feedBackAPNS
    ) where
import Network.PushNotify.Apns.Types
import Network.PushNotify.Apns.Constants
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM.TChan
import Control.Monad.STM
import Control.Retry
import Data.Certificate.X509            (X509)
import Data.Convertible                 (convert)
import Data.Default
import Data.Int
import Data.IORef
import Data.Serialize
import Data.Text.Encoding               (encodeUtf8,decodeUtf8)
import Data.Time.Clock
import Data.Time.Clock.POSIX
import qualified Data.Aeson.Encode      as AE
import qualified Data.ByteString        as B
import qualified Data.ByteString.Lazy   as LB
import qualified Data.ByteString.Base16 as B16
import qualified Data.HashSet           as HS
import qualified Data.HashMap.Strict    as HM
import qualified Control.Exception      as CE
import qualified Crypto.Random.AESCtr   as RNG
import Network
import Network.Socket.Internal          (PortNumber(PortNum))
import Network.TLS
import Network.TLS.Extra                (ciphersuite_all)
import System.Timeout
-- | TLS client parameters used for both the gateway and the feedback
-- connections.
--
-- The same client certificate / private key pair is offered up front and
-- again whenever the server requests a client certificate. TLS 1.1 is the
-- initially proposed version; 1.0 through 1.2 are allowed.
--
-- NOTE(review): 'onCertificatesRecv' accepts any certificate chain the server
-- presents, so the gateway's identity is never verified.
connParams :: X509 -> PrivateKey -> Params
connParams cert privateKey =
    defaultParamsClient
        { pConnectVersion    = TLS11
        , pAllowedVersions   = [TLS10, TLS11, TLS12]
        , pCiphers           = ciphersuite_all
        , pCertificates      = credentials
        , onCertificatesRecv = \_ -> return CertificateUsageAccept
        , roleParams         = Client clientSide
        }
  where
    -- The single (certificate, key) pair this client authenticates with.
    credentials = [(cert, Just privateKey)]
    clientSide  = ClientParams
        { clientWantSessionResume    = Nothing
        , clientUseMaxFragmentLength = Nothing
        , clientUseServerName        = Nothing
        , onCertificateRequest       = \_ -> return credentials
        }
-- | Open a TLS connection to the APNS gateway selected by the configured
-- 'environment' and perform the handshake with the configured client
-- certificate and private key.
--
-- If the handshake throws, the context is closed with 'contextClose'
-- (presumably releasing the underlying socket handle) before the exception
-- is re-raised. 'apnsWorker' retries this call through 'recoverAll', so
-- without this cleanup every failed attempt would leave a socket open.
connectAPNS :: APNSConfig -> IO Context
connectAPNS config = do
        -- Create the RNG before opening the socket, so a failure here
        -- cannot leave an open handle behind.
        rng     <- RNG.makeSystem
        handle  <- case environment config of
                    Development -> connectTo cDEVELOPMENT_URL
                                           $ PortNumber $ fromInteger cDEVELOPMENT_PORT
                    Production  -> connectTo cPRODUCTION_URL
                                           $ PortNumber $ fromInteger cPRODUCTION_PORT
                    Local       -> connectTo cLOCAL_URL
                                           $ PortNumber $ fromInteger cLOCAL_PORT
        ctx     <- contextNewOnHandle handle (connParams (apnsCertificate config) (apnsPrivateKey config)) rng
        handshake ctx `CE.onException` contextClose ctx
        return ctx
-- | Start the background worker that delivers notifications and return a
-- manager handle for it.
--
-- The manager's state starts as @Just ()@. If the worker thread ever
-- terminates with an exception (including being killed), the state is set
-- to 'Nothing', so later 'sendAPNS' calls fail instead of queueing work
-- that nobody will process.
startAPNS :: APNSConfig -> IO APNSManager
startAPNS config = do
        requestQueue <- newTChanIO
        stateRef     <- newIORef (Just ())
        let markDead :: CE.SomeException -> IO ()
            markDead _ = atomicModifyIORef stateRef (const (Nothing, ()))
        workerId     <- forkIO (apnsWorker config requestQueue `CE.catch` markDead)
        return (APNSManager stateRef requestQueue workerId (timeoutLimit config))
-- | Stop the manager.
--
-- The state is marked closed first, so new 'sendAPNS' calls are refused;
-- then the worker thread is killed.
closeAPNS :: APNSManager -> IO ()
closeAPNS m = markClosed >> killThread (mWorkerID m)
  where
    markClosed = atomicModifyIORef (mState m) (const (Nothing, ()))
-- | Send one notification to every device token of the message and wait for
-- the outcome.
--
-- The request is queued for the worker, which answers through @var1@ twice:
--
--   1. with the error channel and the identifier assigned to the message's
--      first token;
--   2. with 'Nothing' once every frame has been written.
--
-- The call then waits either for an error identifier on that channel, or
-- for 'mTimeoutLimit' after sending finished. A timeout means every token
-- is counted as sent.
--
-- When an error identifier @s@ belongs to this message (@s >= startNum@),
-- tokens before position @s - startNum@ are counted as sent. Positions
-- follow 'HS.toList' order, which is also the order the worker sends in.
-- The reported token and all tokens after it are counted as failed.
--
-- An identifier below @startNum@ (including the 0 the worker reports when a
-- connection breaks) marks every token of this message as failed.
--
-- Fails if the manager has been closed.
sendAPNS :: APNSManager -> APNSmessage -> IO APNSresult
sendAPNS m msg = do
    s <- readIORef $ mState m
    case s of
      Nothing -> fail "APNS Service closed."
      Just () -> do
        let requestChan = mApnsChannel m
        var1 <- newEmptyMVar
        atomically $ writeTChan requestChan (var1,msg)
        registration <- takeMVar var1
        case registration of
          -- The worker always registers a message before signalling
          -- completion; fail explicitly rather than with a pattern-match error.
          Nothing -> fail "APNS worker did not register the message."
          Just (errorChan,startNum) -> do
            -- Left: an error identifier was reported.
            -- Right: sending finished and no error arrived within the timeout.
            v <- race (readChan errorChan) (takeMVar var1 >> threadDelay (mTimeoutLimit m))
            let tokens = deviceTokens msg
                (succeeded,failed) = case v of
                    Left errIdent
                      | errIdent >= startNum ->
                          (\(a,b) -> (HS.fromList a,HS.fromList b)) $
                              splitAt (errIdent - startNum) $ HS.toList tokens
                      | otherwise -> (HS.empty,tokens)
                    Right _ -> (tokens,HS.empty)
            return $ APNSresult succeeded failed
-- | Background worker that owns one TLS connection to the APNS gateway at a
-- time.
--
-- It connects (retrying according to 'apnsRetrySettings') and then runs two
-- threads on that connection:
--
--   * a sender, which takes requests from @requestChan@, hands each caller a
--     duplicate of the shared error channel plus the identifier of the
--     message's first token, and writes one frame per device token;
--   * a receiver, which waits for one response from the gateway and returns
--     the identifier it carries (0 if the response cannot be parsed).
--
-- When either thread finishes, the other is cancelled and an identifier is
-- broadcast on the error channel: 0 when the sender stopped, otherwise the
-- receiver's value. The connection is then closed and the worker starts
-- over with a fresh connection.
--
-- Both threads are scoped with 'withAsync', and the connection is closed in
-- 'CE.finally'. Killing the worker thread (as 'closeAPNS' does) therefore
-- also cancels them and closes the connection instead of leaking them.
apnsWorker :: APNSConfig -> TChan (MVar (Maybe (Chan Int,Int)) , APNSmessage) -> IO ()
apnsWorker config requestChan = do
        ctx        <- recoverAll (apnsRetrySettings config) $ connectAPNS config
        errorChan  <- newChan
        lock       <- newMVar ()
        session ctx errorChan lock
            `CE.finally` CE.catch (contextClose ctx) (\(_ :: CE.SomeException) -> return ())
        apnsWorker config requestChan
        where
            -- Run sender and receiver on one connection until one of them stops.
            session :: Context -> Chan Int -> MVar () -> IO ()
            session ctx errorChan lock =
                withAsync (catch $ sender 1 lock requestChan errorChan ctx) $ \s ->
                withAsync (catch $ receiver ctx) $ \r -> do
                    res <- waitEither s r
                    case res of
                        Left  _ -> do
                            cancel r
                            writeChan errorChan 0
                        Right v -> do
                            -- Take the lock so the sender is not cancelled
                            -- halfway through registering a request.
                            takeMVar lock
                            cancel s
                            writeChan errorChan v
            -- Turn any exception into identifier 0 (best-effort by design).
            catch :: IO Int -> IO Int
            catch m = CE.catch m (\(e :: CE.SomeException) -> return 0)
            sender  :: Int32
                    -> MVar ()
                    -> TChan (MVar (Maybe (Chan Int,Int)) , APNSmessage)
                    -> Chan Int
                    -> Context
                    -> IO Int
            sender n lock requestChan errorChan c = do
                    -- Wait for a request without removing it, so the lock is
                    -- held only while a request is actually being registered
                    -- (never while blocked on an empty channel).
                    _           <- atomically $ peekTChan requestChan
                    takeMVar lock
                    (var,msg)   <- atomically $ readTChan requestChan
                    let list = HS.toList $ deviceTokens msg
                        len  = convert $ HS.size $ deviceTokens msg
                        -- Restart identifiers at 1 if this message would overflow Int32.
                        num  = if (n + len :: Int32) < 0 then 1 else n
                    echan       <- dupChan errorChan
                    putMVar var $ Just (echan,convert num)
                    putMVar lock ()
                    ctime       <- getPOSIXTime
                    _           <- loop var c num (createPut msg ctime) list
                    sender (num+len) lock requestChan errorChan c
            -- Read one response: a 2-byte header (skipped) followed by a
            -- 4-byte big-endian identifier.
            receiver :: Context -> IO Int
            receiver c = do
                    dat <- recvData c
                    case runGet (getWord16be >> getWord32be) dat of
                        Right ident -> return (convert ident)
                        Left _      -> return 0
            -- Send one frame per token with consecutive identifiers, then
            -- signal the caller that sending is complete.
            loop :: MVar (Maybe (Chan Int,Int))
                 -> Context
                 -> Int32
                 -> (DeviceToken -> Int32 -> Put)
                 -> [DeviceToken]
                 -> IO Bool
            loop var _   _   _    []     = tryPutMVar var Nothing
            loop var ctx num cput (x:xs) = do
                    sendData ctx $ LB.fromChunks [ (runPut $ cput x num) ]
                    loop var ctx (num+1) cput xs
-- | Serialize one enhanced-format (command 1) frame for a single device
-- token.
--
-- Frame layout:
--
--   * command byte 1;
--   * 4-byte identifier;
--   * 4-byte expiry in POSIX seconds — the message's 'expiry', or one day
--     after @ctime@ when it is unset;
--   * 2-byte token length, then the hex-decoded token;
--   * 2-byte payload length, then the JSON payload.
--
-- The function is written as @createPut msg ctime = \\dst identifier -> …@,
-- and the payload, its size check and the expiry live in the @where@
-- clause. A partial application @createPut msg ctime@ (as the worker passes
-- to its send loop) therefore encodes the message once and shares the
-- result across every token, instead of re-encoding it per token.
--
-- Fails with "Too long payload" when the encoded payload exceeds 256 bytes.
createPut :: APNSmessage -> NominalDiffTime -> DeviceToken -> Int32 -> Put
createPut msg ctime = \dst identifier ->
    if tooLong
       -- 'PutM' has no failure mode of its own; its default 'fail' was
       -- 'error', so raise the same error explicitly.
       then error "Too long payload"
       else do
            -- Hex-decode the token; any undecodable suffix is silently dropped.
            let btoken = fst $ B16.decode $ encodeUtf8 dst
            putWord8 1
            putWord32be $ convert identifier
            putWord32be expiryTime
            putWord16be $ convert $ B.length btoken
            putByteString btoken
            putWord16be $ convert $ LB.length bpayload
            putLazyByteString bpayload
  where
    bpayload   = AE.encode msg
    tooLong    = LB.length bpayload > 256
    expiryTime = case expiry msg of
                   Nothing -> round (ctime + posixDayLength)
                   Just t  -> round (utcTimeToPOSIXSeconds t)
-- | Run an action with a freshly started manager.
--
-- The manager is closed when the action returns or throws.
withAPNS :: APNSConfig -> (APNSManager -> IO a) -> IO a
withAPNS confg = CE.bracket (startAPNS confg) closeAPNS
-- | Open a TLS connection to the APNS feedback service selected by the
-- configured 'environment' and perform the handshake with the configured
-- client certificate and private key.
--
-- If the handshake throws, the context is closed with 'contextClose'
-- (presumably releasing the underlying socket handle) before the exception
-- is re-raised, so a failed connection attempt does not leave a socket open.
connectFeedBackAPNS :: APNSConfig -> IO Context
connectFeedBackAPNS config = do
        -- Create the RNG before opening the socket, so a failure here
        -- cannot leave an open handle behind.
        rng     <- RNG.makeSystem
        handle  <- case environment config of
                    Development -> connectTo cDEVELOPMENT_FEEDBACK_URL
                                           $ PortNumber $ fromInteger cDEVELOPMENT_FEEDBACK_PORT
                    Production  -> connectTo cPRODUCTION_FEEDBACK_URL
                                           $ PortNumber $ fromInteger cPRODUCTION_FEEDBACK_PORT
                    Local       -> connectTo cLOCAL_FEEDBACK_URL
                                           $ PortNumber $ fromInteger cLOCAL_FEEDBACK_PORT
        ctx     <- contextNewOnHandle handle (connParams (apnsCertificate config) (apnsPrivateKey config)) rng
        handshake ctx `CE.onException` contextClose ctx
        return ctx
-- | Query the APNS feedback service once and collect the tokens it reports.
--
-- Connects to the feedback endpoint, reads (token, timestamp) tuples on a
-- background thread, and gathers them into a map. Collection stops when no
-- new tuple arrives within 'timeoutLimit' (passed to 'timeout', so
-- presumably microseconds). A token reported twice keeps the later
-- timestamp.
--
-- Wire format of each tuple:
--
--   * 4-byte big-endian POSIX time;
--   * 2-byte token length;
--   * the raw token, which is returned hex-encoded.
--
-- The reader thread is killed and the connection released even if waiting
-- is interrupted. A failing close_notify (e.g. because the server already
-- hung up) does not discard the collected result.
feedBackAPNS :: APNSConfig -> IO APNSFeedBackresult
feedBackAPNS config = do
        ctx     <- connectFeedBackAPNS config
        var     <- newEmptyMVar
        tID     <- forkIO $ loopReceive var ctx
        waitAndCheck var HM.empty `CE.finally` (killThread tID >> closeConnection ctx)
        where
            -- Best-effort close_notify, then release the connection.
            closeConnection :: Context -> IO ()
            closeConnection ctx = do
                        CE.catch (bye ctx) (\(_ :: CE.SomeException) -> return ())
                        contextClose ctx
            getData :: Get (DeviceToken,UTCTime)
            getData = do
                        time    <- getWord32be
                        len     <- getWord16be
                        dtoken  <- getBytes $ convert len
                        return (    decodeUtf8 $ B16.encode dtoken
                               ,    posixSecondsToUTCTime $ fromInteger $ convert time )
            -- A single TLS record may carry several tuples, or only part of
            -- one. Received bytes are therefore buffered, and every complete
            -- tuple is handed over before more data is read. An empty read
            -- (connection finished) ends the loop.
            loopReceive :: MVar (DeviceToken,UTCTime) -> Context -> IO ()
            loopReceive var ctx = receiveMore B.empty
                where
                    receiveMore pending = do
                        dat <- recvData ctx
                        if B.null dat
                            then return ()
                            else drain (B.append pending dat)
                    -- 'getData' can only fail for lack of input, so 'Left'
                    -- means "incomplete tuple: read more".
                    drain buf = case runGetState getData buf 0 of
                        Right (tuple, rest) -> do
                            putMVar var tuple
                            drain rest
                        Left _ -> receiveMore buf
            -- Accumulate tuples until one fails to arrive within the timeout.
            waitAndCheck :: MVar (DeviceToken,UTCTime) -> HM.HashMap DeviceToken UTCTime -> IO APNSFeedBackresult
            waitAndCheck var hmap = do
                        v <- timeout (timeoutLimit config) $ takeMVar var
                        case v of
                            Nothing -> return $ APNSFeedBackresult hmap
                            Just (d,t)  -> waitAndCheck var (HM.insert d t hmap)