{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE CApiFFI           #-}
{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.HTTP.MicroClient
    ( 
      SockStream
    , ssFromSocket
    , ssConnect
    , ssToSocket
    , ssClose
    , ssId
    , ssRead
    , ssPeek
    , ssPeekBuf
    , ssRead'
    , ssReadN
    , ssUnRead
    , ssWrite
    , ssReadCnt
    , ssWriteCnt
      
    , HttpResponse(..)
    , HttpCode
    , Method(..)
    , ReqURI
    , HostPort
    , MsgHeader
    , TransferEncoding(..)
    , mkHttp11Req
    , mkHttp11GetReq
    , recvHttpResponse
      
    , recvHttpHeaders
    , httpHeaderGetInfos
    , recvChunk
      
    , getSockAddr
    , splitUrl
    , getPOSIXTimeSecs
    , getPOSIXTimeUSecs
    ) where
#if !MIN_VERSION_base(4,8,0)
import           Control.Applicative
#endif
import           Control.DeepSeq              (NFData (rnf), deepseq)
import           Control.Exception
import           Control.Monad
import           Data.ByteString              (ByteString)
import qualified Data.ByteString              as B
import           Data.ByteString.Lex.Integral (packDecimal, readDecimal,
                                               readHexadecimal)
import qualified Data.ByteString.Unsafe       as B
import           Data.IORef
import           Data.Maybe
#if !MIN_VERSION_base(4,11,0)
import           Data.Monoid
#endif
import           Data.Tuple
import           Data.Word
import           Network.BSD                  (getHostByName, getProtocolNumber,
                                               hostAddress)
import           Network.Socket               (Family (AF_INET), HostName,
                                               PortNumber, ProtocolNumber,
                                               SockAddr (SockAddrInet), Socket,
                                               SocketType (Stream), bind, close,
                                               connect, socket)
import           Network.Socket.ByteString
import           Network.URI
import           System.IO.Error
import           System.IO.Unsafe             (unsafePerformIO)
-- | A connected 'Socket' bundled with a push-back buffer, byte counters and
-- an identifier. All socket I/O in this module goes through this type.
data SockStream = SockStream {-# NOUNPACK #-} !Socket -- ^ the underlying socket
                             {-# UNPACK #-} !(IORef ByteString) -- ^ push-back buffer: bytes already received (or pushed back via 'ssUnRead') but not yet consumed
                             {-# UNPACK #-} !(IORef Word64) -- ^ total number of bytes received from the socket
                             {-# UNPACK #-} !(IORef Word64) -- ^ total number of bytes sent via 'ssWrite'
                             {-# UNPACK #-} !Int -- ^ identifier handed out by 'getSsCounterId' at construction
-- | Optional tracing hook wrapped around the 'SockStream' operations.
--
-- As compiled (the @#if 0@ branch is disabled) this simply runs the action.
-- Enabling the first branch prints the message, the current push-back buffer
-- and the received-byte counter before running the action.
ssDebug :: String -> SockStream -> IO a -> IO a
#if 0
ssDebug msg (SockStream _ bufref cntref _) act = do
    cnt <- readIORef cntref
    buf <- readIORef bufref
    putStrLn $ "DEBUG: " ++ msg ++ " SockStream _ " ++ show buf ++ " " ++ show cnt
    act
#else
ssDebug _ _ act = act
#endif
-- | Process-global source of 'SockStream' identifiers, starting at 1.
--
-- Created with 'unsafePerformIO'; the NOINLINE pragma keeps GHC from
-- inlining the 'newIORef' call at use sites, which could otherwise create
-- several independent counters.
ssIdCounter :: IORef Int
ssIdCounter = unsafePerformIO $ newIORef 1
{-# NOINLINE ssIdCounter #-}
-- | Atomically take the next identifier: returns the current counter value
-- and leaves the counter incremented by one.
getSsCounterId :: IO Int
getSsCounterId = atomicModifyIORef' ssIdCounter takeAndBump
  where
    takeAndBump current = (current + 1, current)
-- | Wrap an already connected 'Socket' with an empty push-back buffer,
-- zeroed received/sent counters and a freshly allocated identifier.
ssFromSocket :: Socket -> IO SockStream
ssFromSocket sock = do
    pendingRef <- newIORef B.empty
    rxRef      <- newIORef 0
    txRef      <- newIORef 0
    ident      <- getSsCounterId
    return (SockStream sock pendingRef rxRef txRef ident)
-- | The 'Socket' underlying a 'SockStream'.
ssToSocket :: SockStream -> Socket
ssToSocket stream = case stream of
    SockStream sock _ _ _ _ -> sock
-- | The identifier assigned to this 'SockStream' when it was created.
ssId :: SockStream -> Int
ssId stream = case stream of
    SockStream _ _ _ _ ident -> ident
-- | Close the underlying socket. The push-back buffer and counters are left
-- untouched.
ssClose :: SockStream -> IO ()
ssClose stream = close (ssToSocket stream)
-- | Open an IPv4 TCP connection to @rsa@, optionally binding the local end to
-- @lsa@ first, and wrap it in a 'SockStream'.
--
-- If binding, connecting or wrapping throws, the socket is closed before the
-- exception propagates ('bracketOnError').
ssConnect :: Maybe SockAddr -> SockAddr -> IO SockStream
ssConnect lsa rsa = bracketOnError openSock close setUp
  where
    openSock = socket AF_INET Stream tcpProtoNum
    setUp sock = do
        maybe (return ()) (bind sock) lsa
        connect sock rsa
        ssFromSocket sock
{-# INLINE ssConnect #-}
-- | Return the next available chunk of input.
--
-- Pending push-back data is returned first (and the buffer emptied).
-- Otherwise a single 'recv' of up to 32752 bytes is performed and its length
-- added to the received-byte counter. An empty result signals that the peer
-- closed the connection (see 'ssRead'').
ssRead :: SockStream -> IO ByteString
ssRead ss@(SockStream sock pendingRef rxRef _ _) = ssDebug "ssRead" ss $
    readIORef pendingRef >>= \pending ->
        if not (B.null pending)
            then writeIORef pendingRef B.empty >> return pending
            else do
                fresh <- recv sock 32752
                modifyIORef' rxRef (+ fromIntegral (B.length fresh))
                return fresh
-- | Like 'ssRead', but an empty chunk (peer closed the connection) is turned
-- into an EOF 'IOError'.
ssRead' :: SockStream -> IO ByteString
ssRead' ss = ssRead ss >>= rejectEof
  where
    rejectEof chunk
      | B.null chunk = ioError (mkIOError eofErrorType "ssRead'" Nothing Nothing)
      | otherwise    = return chunk
-- | Return the next available chunk of input without consuming it.
--
-- If the push-back buffer is empty, one 'recv' (up to 32752 bytes) is
-- performed, counted as received, and stored in the buffer so the following
-- read returns the same bytes again.
ssPeek :: SockStream -> IO ByteString
ssPeek ss@(SockStream sock pendingRef rxRef _ _) = ssDebug "ssPeek" ss $ do
    pending <- readIORef pendingRef
    if not (B.null pending)
        then return pending
        else do
            fresh <- recv sock 32752
            modifyIORef' rxRef (+ fromIntegral (B.length fresh))
            writeIORef pendingRef fresh
            return fresh
-- | Return the current contents of the push-back buffer; never touches the
-- socket.
ssPeekBuf :: SockStream -> IO ByteString
ssPeekBuf ss = case ss of
    SockStream _ pendingRef _ _ _ -> ssDebug "ssPeekBuf" ss (readIORef pendingRef)
-- | Read exactly @l0@ bytes: the push-back buffer is consumed first, then the
-- socket is read as often as necessary.
--
-- Surplus bytes from the last 'recv' are kept in the push-back buffer, and
-- every received byte is added to the received-byte counter.
--
-- Throws:
--
--   * an EOF 'IOError' if the peer closes the connection before enough bytes
--     arrived (the push-back buffer is then left as it was);
--   * a user 'IOError' if @l0@ does not fit in an 'Int'. Without this check,
--     such a length would wrap to a negative number and silently yield an
--     empty result.
ssReadN :: SockStream -> Word64 -> IO ByteString
ssReadN ss@(SockStream s bufref rcntref _ _) l0 = ssDebug "ssReadN" ss $ do
    when (l0 > fromIntegral (maxBound :: Int)) $ ioError tooLongEx
    buf <- readIORef bufref
    let l    = fromIntegral l0 :: Int
        need = l - B.length buf
    if need <= 0
    then
        -- Enough data is already buffered: split it, keep the remainder.
        atomicModifyIORef' bufref (swap . B.splitAt l)
    else do
        (chunks, rest, received) <- go need [buf] 0
        writeIORef   bufref $! rest
        modifyIORef' rcntref (+ fromIntegral received)
        return $! B.concat chunks
  where
    -- Receive until @n@ (> 0) more bytes are collected. Chunks are kept in
    -- reverse and concatenated once at the end, so large reads cost O(total)
    -- instead of re-copying the accumulator after every 'recv'.
    -- Returns (chunks in order, surplus bytes, total bytes received).
    go :: Int -> [ByteString] -> Int -> IO ([ByteString], ByteString, Int)
    go !n acc !received = do
        chunk <- recv s 32752
        let len = B.length chunk
        when (len == 0) $ ioError eofEx
        if len >= n
        then return ( reverse (B.unsafeTake n chunk : acc)
                    , B.unsafeDrop n chunk
                    , received + len )
        else go (n - len) (chunk : acc) (received + len)

    eofEx     = mkIOError eofErrorType "ssReadN" Nothing Nothing
    tooLongEx = userError "ssReadN: requested length exceeds maxBound :: Int"
-- | Push bytes back in front of the push-back buffer, so the next read returns
-- them first. The received-byte counter is not changed.
ssUnRead :: ByteString -> SockStream -> IO ()
ssUnRead bytes ss@(SockStream _ pendingRef _ _ _) =
    ssDebug "ssUnRead" ss $ modifyIORef' pendingRef (\pending -> bytes <> pending)
-- | Number of bytes consumed so far: bytes received from the socket minus the
-- bytes still waiting in the push-back buffer.
--
-- NOTE(review): data injected with 'ssUnRead' that never came from the socket
-- can make this subtraction wrap around.
ssReadCnt :: SockStream -> IO Word64
ssReadCnt ss@(SockStream _ pendingRef rxRef _ _) = ssDebug "ssReadCnt" ss $ do
    pending  <- readIORef pendingRef
    received <- readIORef rxRef
    let unconsumed = fromIntegral (B.length pending)
    return $! received - unconsumed
-- | Send the whole 'ByteString' ('sendAll') and add its length to the
-- sent-byte counter. An empty input sends nothing.
ssWrite :: ByteString -> SockStream -> IO ()
ssWrite buf ss@(SockStream sock _ _ txRef _) = ssDebug "ssWrite" ss $
    unless (B.null buf) $ do
        sendAll sock buf
        modifyIORef' txRef (+ fromIntegral (B.length buf))
-- | Total number of bytes sent through 'ssWrite' on this stream.
ssWriteCnt :: SockStream -> IO Word64
ssWriteCnt stream = case stream of
    SockStream _ _ _ txRef _ -> readIORef txRef
-- | Protocol number of @tcp@, looked up once via 'getProtocolNumber' and
-- shared by all callers (NOINLINE keeps GHC from duplicating the lookup).
-- Because of 'unsafePerformIO', a failing lookup surfaces as an exception the
-- first time this value is forced (e.g. inside 'ssConnect').
tcpProtoNum :: ProtocolNumber
tcpProtoNum = unsafePerformIO $ getProtocolNumber "tcp"
{-# NOINLINE tcpProtoNum #-}
-- | Resolve @hostname@ with 'getHostByName' (IPv4) and pair the address
-- reported by 'hostAddress' with @port@.
getSockAddr :: HostName -> PortNumber -> IO SockAddr
getSockAddr hostname port =
    getHostByName hostname >>= \entry ->
        return $! SockAddrInet port (hostAddress entry)
-- | Numeric HTTP status code. 'httpHeaderGetInfos' uses -1 when the status
-- line carries no parsable code.
type HttpCode = Int
-- | How the length of a response body is determined.
data TransferEncoding = TeIdentity !Word64 -- ^ fixed body length in bytes, from @Content-Length@
                      | TeChunked          -- ^ chunked transfer coding
                      | TeInvalid          -- ^ neither chunked coding nor a usable @Content-Length@
                      deriving (Show,Eq)
-- | Forcing to WHNF is sufficient: the length in 'TeIdentity' is a strict
-- field, so the bang pattern evaluates everything.
instance NFData TransferEncoding where rnf !_ = ()
-- | Extract the status code, the connection-close flag and the body framing
-- from a header block as produced by 'recvHttpHeaders'.
--
-- The list is in reverse wire order: its last element is the status line and
-- all preceding elements are header fields. Field names are compared
-- case-insensitively and values are trimmed of surrounding spaces/tabs
-- (RFC 7230 §3.2):
--
--   * @Connection@ is a comma-separated token list; a @close@ token (any
--     case) sets the flag.
--   * The body is chunked when the final @Transfer-Encoding@ coding is
--     @chunked@.
--   * Otherwise the first @Content-Length@ must consist of decimal digits
--     only, giving 'TeIdentity'.
--   * Anything else yields 'TeInvalid'.
--
-- A missing or non-numeric status code is reported as @-1@. Calls 'error' if
-- the list is empty or the protocol version is not @HTTP/1.1@.
httpHeaderGetInfos :: [ByteString] -> (HttpCode, Bool, TransferEncoding)
httpHeaderGetInfos [] = error "httpHeaderGetInfos: empty header list"
httpHeaderGetInfos hds0
    | ver /= "HTTP/1.1" = error "unsupported HTTP version"
    | otherwise         = (code, connClose, if chunkTx then TeChunked else clen)
  where
    hds = init hds0
    -- Status line: "HTTP/1.1 <code> <reason>". Breaking (rather than a lazy
    -- list pattern) cannot crash on a line without spaces.
    (ver, afterVer) = B.break (== 0x20) (last hds0)
    code | Just (n,_) <- readDecimal (B.drop 1 afterVer) = n
         | otherwise = -1

    -- (lower-cased field name, trimmed field value) for every header line
    fields      = map splitField hds
    values name = [ v | (k, v) <- fields, k == name ]

    connClose = any (elem "close" . tokens) (values "connection")
    chunkTx   = any (lastIsChunked . reverse . tokens) (values "transfer-encoding")
    lastIsChunked (t:_) = t == "chunked"
    lastIsChunked []    = False

    clen | (v:_) <- values "content-length" =
               case readDecimal v of
                   Just (n, rest) | B.null rest -> TeIdentity n
                   _                            -> TeInvalid
         | otherwise = TeInvalid

    splitField h = let (name, rest) = B.break (== 0x3a) h
                   in (lowerAscii name, trimOws (B.drop 1 rest))
    tokens       = map (lowerAscii . trimOws) . B.split 0x2c
    lowerAscii   = B.map (\w -> if w >= 0x41 && w <= 0x5a then w + 0x20 else w)
    trimOws      = fst . B.spanEnd isOws . B.dropWhile isOws
    isOws w      = w == 0x20 || w == 0x09
-- | Receive a complete response header block: the status line and header
-- fields, terminated by an empty line.
--
-- Input is split on LF and a trailing CR is removed from each line. The
-- result is in /reverse/ order: header fields first (the last field at the
-- head), with the status line as the final element. The terminating empty
-- line is not included. Bytes received after the empty line are pushed back
-- onto the stream for the body reader. If the connection closes first, the
-- EOF 'IOError' from 'ssRead'' propagates.
recvHttpHeaders :: SockStream -> IO [ByteString]
recvHttpHeaders ss = do
    firstChunk <- ssRead' ss
    (leftover, h0:hds) <- collect (splitLines (firstChunk, []))
    unless (B.null h0) $ fail "recvHttpHeaders"
    ssUnRead leftover ss
    return hds
  where
    -- Keep receiving until the most recently parsed line is the empty
    -- terminator line.
    collect st@(partial, acc) = case acc of
        (newest:_) | B.null newest -> return st
        _ -> do
            more <- ssRead' ss
            collect (splitLines (partial <> more, acc))

    -- Peel complete LF-terminated lines off the front of the input and push
    -- them (CR-stripped) onto the accumulator. Stop right after an empty line,
    -- or when no complete line is left; the unparsed rest is returned.
    splitLines :: (ByteString,[ByteString]) -> (ByteString,[ByteString])
    splitLines (input, acc) = case B.elemIndex 10 input of
        Nothing -> (input, acc)
        Just i  ->
            let line = stripCR (B.unsafeTake i input)
                next = (B.unsafeDrop (i + 1) input, line : acc)
            in if B.null line then next else splitLines next
-- | A fully received HTTP response, as returned by 'recvHttpResponse'.
data HttpResponse = HttpResponse
    { respCode       :: !HttpCode    -- ^ status code (-1 if not parsable)
    , respKeepalive  :: !Bool        -- ^ 'True' unless 'httpHeaderGetInfos' reported a connection-close header
    , respContentLen :: !Word64      -- ^ total body length in bytes
    , respHeader     :: [MsgHeader]  -- ^ raw lines from 'recvHttpHeaders': reverse wire order, status line last
    , respContent    :: [ByteString] -- ^ body as a list of pieces; concatenate them for the full body
    } deriving Show
-- | Only the two lazy list fields need deep forcing; the other fields are
-- strict and thus already evaluated.
instance NFData HttpResponse where
    rnf (HttpResponse _ _ _ h c) = h `deepseq` c `deepseq` ()
-- | Receive one complete HTTP/1.1 response: the header block, then the body
-- as framed by the headers.
--
-- Responses with a 1xx, 204 (No Content) or 304 (Not Modified) status never
-- carry a body (RFC 7230 §3.3.3). For them no body is read, even when a
-- @Content-Length@ header is present. Reading it would block forever,
-- waiting for bytes the server will never send. Such responses get an
-- empty body and a length of 0.
--
-- Otherwise the body is read according to @Content-Length@ or chunked
-- transfer coding; if neither applies, this fails via 'fail'.
recvHttpResponse :: SockStream -> IO HttpResponse
recvHttpResponse ss = do
    hds <- recvHttpHeaders ss
    let (code, needClose, te) = httpHeaderGetInfos hds
    (clen',body) <-
        if hasNoBody code
        then return (0, [])
        else case te of
            TeIdentity n -> recvIdentityBody n
            TeChunked    -> recvChunkedBody
            TeInvalid    -> fail "invalid response w/ invalid transfer-encoding/content-length"
    return $! HttpResponse code (not needClose) clen' hds body
  where
    -- RFC 7230 §3.3.3 rule 1: these responses end with the header block.
    hasNoBody :: HttpCode -> Bool
    hasNoBody c = (c >= 100 && c < 200) || c == 204 || c == 304

    recvIdentityBody n = do
        res <- ssReadN ss n
        return (n, [res])

    recvChunkedBody = do
        res <- collect []
        return (fI $ sum $ map B.length res, res)

    -- Gather chunks (in reverse, then restore order) until the empty
    -- terminating chunk; tail-recursive, so long bodies do not nest actions.
    collect acc = do
        bs <- recvChunk ss
        if B.null bs
        then return (reverse acc)
        else collect (bs : acc)
-- | Receive one chunk of a chunked-encoded body (RFC 7230 §4.1) and return its
-- payload. An empty result marks the last (zero-size) chunk.
--
-- In the chunk-size line, everything from the first @;@ onwards (chunk
-- extensions) is ignored. Spaces or tabs before it are ignored as well. The
-- CRLF after the chunk data is consumed and checked.
--
-- Limitation: trailer fields after the last chunk are not supported. The
-- line following the zero-size chunk must be empty, otherwise this fails.
recvChunk :: SockStream -> IO ByteString
recvChunk ss = go B.empty
  where
    go buf
      | Just j <- B.elemIndex 10 buf = do
          let (sizeLine, rest) = (stripCR $ B.unsafeTake j buf, B.unsafeDrop (j+1) buf)
          -- Everything after the size line belongs to the chunk data.
          ssUnRead rest ss
          chunksize <- maybe (fail "invalid chunk-size") return $ readHex (sizeField sizeLine)
          bs <- ssReadN ss chunksize
          dropCrLf
          return bs
      | otherwise = do
          buf' <- ssRead' ss
          go (buf<>buf')

    dropCrLf = do
        tmp <- ssReadN ss 2
        unless (tmp == "\r\n") $ fail "dropCrLf: expected CRLF"

    -- chunk-size [ BWS ";" chunk-ext ]  ->  chunk-size
    sizeField = fst . B.spanEnd isWs . B.takeWhile (/= 0x3b)
    isWs w    = w == 0x20 || w == 0x09

    -- The whole field must be hexadecimal digits.
    readHex bs | Just (n,rest) <- readHexadecimal bs, B.null rest = Just n
               | otherwise = Nothing
{-# INLINE recvChunk #-}
-- | Split an absolute @http:@ URL into (host name, port, request target).
--
-- * The port defaults to 80, both when it is absent and when it is empty
--   (@http://host:/@, RFC 3986 §3.2.3).
-- * A port that is not a decimal number in 0..65535 is rejected; it is never
--   truncated to 16 bits.
-- * The request target is the path (@/@ if empty) followed by the query
--   string, if any. The fragment is dropped, since it is never sent to the
--   server.
-- * User information in the authority is rejected.
splitUrl :: String -> Either String (String,PortNumber,String)
splitUrl url0 = do
    uri <- note "invalid URI" $ parseAbsoluteURI url0
    unless (uriScheme uri == "http:") $
        Left "URI must have 'http' scheme"
    urlauth <- note "missing host-part in URI" $ uriAuthority uri
    let hostname = uriRegName urlauth
    when (null hostname) $
        Left "empty hostname in URL"
    unless (null . uriUserInfo $ urlauth) $
        Left "user/pass in URL not supported"
    portnum <- case uriPort urlauth of
        ""      -> return 80
        ":"     -> return 80
        ':':tmp | [(n, "")] <- reads tmp
                , n <= (65535 :: Word) -> return $! fromIntegral n
        _       -> Left "invalid port-number"
    let path = if null (uriPath uri) then "/" else uriPath uri
    return (hostname, portnum, path ++ uriQuery uri)
-- | HTTP request methods supported by 'mkHttp11Req'. Each constructor is
-- rendered as its own name in the request line. The derived 'Enum' instance
-- follows the declaration order.
data Method = GET | POST | HEAD | PUT | DELETE | TRACE | CONNECT | OPTIONS
            deriving (Show,Eq,Enum)
-- | Request target, inserted verbatim after the method in the request line.
type ReqURI    = ByteString 
-- | Value inserted verbatim after @Host: @ (presumably @host[:port]@).
type HostPort  = ByteString 
-- | A single header line without its trailing CRLF (e.g. @Accept: text/html@).
type MsgHeader = ByteString 
-- | Render a complete HTTP/1.1 request message, in this order:
--
--   1. the request line;
--   2. a @Host@ header;
--   3. a @Connection: close@ header, only when @keepalive@ is 'False';
--   4. each extra header line followed by CRLF;
--   5. either a bare CRLF (no body), or a @Content-Length@ header, the empty
--      line and the body.
--
-- NOTE(review): the path, host and extra headers are copied verbatim. Callers
-- must make sure they contain no CR/LF, or the message can be split.
mkHttp11Req :: Method
            -> ReqURI
            -> HostPort
            -> Bool               -- ^ keep the connection alive?
            -> [MsgHeader]        -- ^ extra header lines (without CRLF)
            -> (Maybe ByteString) -- ^ optional request body
            -> ByteString
mkHttp11Req method urlpath hostport keepalive xhdrs mbody =
    B.concat (requestLine ++ hostLine ++ connLine ++ headerLines ++ bodyPart)
  where
    requestLine = [verb, urlpath, " HTTP/1.1\r\n"]
    hostLine    = ["Host: ", hostport, "\r\n"]
    connLine    = if keepalive then [] else ["Connection: close\r\n"]
    headerLines = concatMap (\hdr -> [hdr, "\r\n"]) xhdrs
    bodyPart    = maybe ["\r\n"] withLength mbody

    withLength body =
        [ "Content-Length: "
        , fromJust (packDecimal (B.length body))
        , "\r\n\r\n"
        , body
        ]

    verb = case method of
        GET     -> "GET "
        POST    -> "POST "
        HEAD    -> "HEAD "
        PUT     -> "PUT "
        DELETE  -> "DELETE "
        TRACE   -> "TRACE "
        CONNECT -> "CONNECT "
        OPTIONS -> "OPTIONS "
-- | Render a body-less @GET@ request; shorthand for 'mkHttp11Req' with 'GET'
-- and no body.
mkHttp11GetReq :: ReqURI -> HostPort -> Bool -> [MsgHeader] -> ByteString
mkHttp11GetReq urlpath hostport keepalive xhdrs = withoutBody Nothing
  where
    withoutBody = mkHttp11Req GET urlpath hostport keepalive xhdrs
-- | Run the action on the contained value; do nothing for 'Nothing'.
whenJust :: Monad m => Maybe a -> (a -> m ()) -> m ()
whenJust mb f = case mb of
    Just x  -> f x
    Nothing -> return ()
{-# INLINE whenJust #-}
-- | Turn a 'Maybe' into an 'Either', using the given value as the 'Left' for
-- 'Nothing'.
note :: a -> Maybe b -> Either a b
note a mb = case mb of
    Nothing -> Left a
    Just b  -> Right b
-- | Short alias for 'fromIntegral' (same wrap-around on narrowing).
fI :: (Integral a, Num b) => a -> b
fI x = fromIntegral x
-- | Remove one trailing carriage return (byte 0x0D), if present.
-- Used after splitting input on LF, to turn CRLF-terminated lines into bare
-- lines.
stripCR :: ByteString -> ByteString
stripCR s
  | B.null s                = s
-- The unchecked variants are used only when bytestring provides them; the
-- empty case is handled above, so both branches behave the same.
#if MIN_VERSION_bytestring(0,10,2)
  | B.unsafeLast s == 0x0d  = B.unsafeInit s
#else
  | B.last s == 0x0d        = B.init s
#endif
  | otherwise               = s
-- | Time reading from the C function @get_posix_time_secs@ (declared in
-- @hs_uhttpc.h@), presumably POSIX time in seconds as a 'Double'.
-- NOTE(review): clock source and resolution live in the C code, which is not
-- visible here. Imported as an @unsafe@ call, so the C side must be short and
-- must not call back into Haskell.
foreign import capi unsafe "hs_uhttpc.h get_posix_time_secs" getPOSIXTimeSecs :: IO Double
-- | Time reading from the C function @get_posix_time_usecs@ (declared in
-- @hs_uhttpc.h@), presumably POSIX time in microseconds.
-- NOTE(review): confirm the unit against the C implementation.
foreign import capi unsafe "hs_uhttpc.h get_posix_time_usecs" getPOSIXTimeUSecs :: IO Word64