{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ForeignFunctionInterface #-}

-- |
-- Filelike provides a typeclass for Unix-style "everything is a file" IO,
-- and implementations abstracting standard IO, files and network connections.
-- This module also provides TLS wrapping over other filelike types.
module System.IO.Uniform.Targets (TlsSettings(..), UniformIO(..), SocketIO, FileIO, TlsStream, BoundPort, SomeIO(..), connectTo, connectToHost, bindPort, accept, openFile, getPeer, closePort) where

import Foreign
import Foreign.C.Types
import Foreign.C.String
import Foreign.C.Error
import qualified Data.IP as IP
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.List as L
import Control.Exception
import qualified Network.Socket as Soc
import System.IO.Error

import Data.Default.Class

-- | Settings consumed by the 'startTls' functions.
data TlsSettings = TlsSettings
  { tlsPrivateKeyFile :: String
    -- ^ Name of the private key file; handed to the C TLS layer as a C string.
  , tlsCertificateChainFile :: String
    -- ^ Name of the certificate chain file; handed to the C TLS layer as a C string.
  }
  deriving (Read, Show)

-- | The default settings leave both file names empty.
instance Default TlsSettings where
  def = TlsSettings { tlsPrivateKeyFile = "", tlsCertificateChainFile = "" }

-- |
-- Typeclass for IO objects that behave like a Unix file (independent of seeking support).
class UniformIO a where
  -- | @fRead fd n@ reads a block of at most @n@ bytes of data from the
  --  filelike object @fd@.
  --  Reading blocks if there is no data available, but returns immediately
  --  once any amount of data is available.
  fRead  :: a -> Int -> IO ByteString
  -- | @fPut fd text@ writes all the bytes of @text@ into the filelike object,
  --  taking care of retrying if needed.
  fPut   :: a -> ByteString -> IO ()
  -- | @fClose fd@ closes the filelike object, releasing any allocated resource.
  --  Resources may leak if it is not called for every opened @fd@.
  fClose :: a -> IO ()
  -- | @startTls settings fd@ starts a TLS connection over the filelike object
  --  @fd@, using the key and certificate chain files named in @settings@.
  startTls :: TlsSettings -> a -> IO TlsStream
  -- | @isSecure fd@ indicates whether the data written to or read from @fd@
  --  is secure at the transport level.
  isSecure :: a -> Bool
  
-- | Existential wrapper around any value of a 'UniformIO' type, letting
-- streams of different concrete types be stored and used uniformly.
data SomeIO = forall io. UniformIO io => SomeIO io

-- Every operation unwraps the existential and delegates to the hidden value.
instance UniformIO SomeIO where
  fRead (SomeIO inner) = fRead inner
  fPut (SomeIO inner) = fPut inner
  fClose (SomeIO inner) = fClose inner
  startTls settings (SomeIO inner) = startTls settings inner
  isSecure (SomeIO inner) = isSecure inner

-- Opaque stand-ins for the C-side descriptors. They have no constructors:
-- Haskell code only ever holds them behind a 'Ptr' obtained from the
-- foreign imports.
data Nethandler
data SockDs
data FileDs
data TlsDs

-- | A bound IP port from where to accept SocketIO connections.
newtype BoundPort = BoundPort {lis :: Ptr Nethandler}

-- | A network connection, as produced by connecting to a host or accepting a client.
newtype SocketIO = SocketIO {sock :: Ptr SockDs}

-- | A file opened for bidirectional IO.
newtype FileIO = FileIO {fd :: Ptr FileDs}

-- | A TLS session layered over another filelike object.
newtype TlsStream = TlsStream {tls :: Ptr TlsDs}

-- | UniformIO IP connections.
instance UniformIO SocketIO where
  -- A non-positive request returns an empty string without touching the
  -- socket: a negative size is not a valid buffer for allocaArray, and the
  -- negative CInt length would otherwise reach the C side unchecked.
  -- Requests larger than a CInt can represent are clamped, so the length
  -- handed to C never wraps around. A negative C result raises an IOError
  -- built from errno.
  fRead s n
    | n <= 0 = return BS.empty
    | otherwise = allocaArray len $ \buf -> do
        count <- c_recvSock (sock s) buf (fromIntegral len)
        if count < 0
          then throwErrno "could not read"
          else BS.packCStringLen (buf, fromIntegral count)
    where
      len = min n (fromIntegral (maxBound :: CInt))
  -- Hands the whole buffer to C in a single call; a negative result raises
  -- an IOError built from errno.
  fPut s t = BS.useAsCStringLen t $ \(str, n) -> do
    count <- c_sendSock (sock s) str (fromIntegral n)
    if count < 0
      then throwErrno "could not write"
      else return ()
  fClose s = c_closeSock (sock s)
  -- Both file names are marshalled to temporary C strings for the duration
  -- of the call. A null result from C raises an IOError built from errno.
  startTls st s =
    withCString (tlsCertificateChainFile st) $ \cert ->
      withCString (tlsPrivateKeyFile st) $ \key -> do
        r <- c_startSockTls (sock s) cert key
        if r == nullPtr
          then throwErrno "could not start TLS"
          else return (TlsStream r)
  isSecure _ = False
  
-- | UniformIO type for file IO.
instance UniformIO FileIO where
  -- A non-positive request returns an empty string without touching the
  -- file: a negative size is not a valid buffer for allocaArray, and the
  -- negative CInt length would otherwise reach the C side unchecked.
  -- Requests larger than a CInt can represent are clamped, so the length
  -- handed to C never wraps around. A negative C result raises an IOError
  -- built from errno.
  fRead s n
    | n <= 0 = return BS.empty
    | otherwise = allocaArray len $ \buf -> do
        count <- c_recvFile (fd s) buf (fromIntegral len)
        if count < 0
          then throwErrno "could not read"
          else BS.packCStringLen (buf, fromIntegral count)
    where
      len = min n (fromIntegral (maxBound :: CInt))
  -- Hands the whole buffer to C in a single call; a negative result raises
  -- an IOError built from errno.
  fPut s t = BS.useAsCStringLen t $ \(str, n) -> do
    count <- c_sendFile (fd s) str (fromIntegral n)
    if count < 0
      then throwErrno "could not write"
      else return ()
  fClose s = c_closeFile (fd s)
  -- TLS over plain files is not implemented. Returning a TlsStream around a
  -- null pointer would let every later TlsStream operation pass that null
  -- pointer straight to C, so fail immediately with an IOError instead.
  startTls _ _ =
    throwIO $ mkIOError illegalOperationErrorType "TLS over files is not implemented" Nothing Nothing
  isSecure _ = False
  
-- | UniformIO wrapper that applies TLS to communication on filelike objects.
instance UniformIO TlsStream where
  -- A non-positive request returns an empty string without touching the
  -- session: a negative size is not a valid buffer for allocaArray, and the
  -- negative CInt length would otherwise reach the C side unchecked.
  -- Requests larger than a CInt can represent are clamped, so the length
  -- handed to C never wraps around. A negative C result raises an IOError
  -- built from errno.
  fRead s n
    | n <= 0 = return BS.empty
    | otherwise = allocaArray len $ \buf -> do
        count <- c_recvTls (tls s) buf (fromIntegral len)
        if count < 0
          then throwErrno "could not read"
          else BS.packCStringLen (buf, fromIntegral count)
    where
      len = min n (fromIntegral (maxBound :: CInt))
  -- Hands the whole buffer to C in a single call; a negative result raises
  -- an IOError built from errno.
  fPut s t = BS.useAsCStringLen t $ \(str, n) -> do
    count <- c_sendTls (tls s) str (fromIntegral n)
    if count < 0
      then throwErrno "could not write"
      else return ()
  fClose s = c_closeTls (tls s)
  -- Already a TLS stream: the settings are ignored and the stream is returned as is.
  startTls _ s = return s
  isSecure _ = True

-- | @connectToHost hostName port@ resolves @hostName@ and connects to @port@
-- on the first address returned by 'Soc.getAddrInfo'.
--
-- Throws a does-not-exist 'IOError' ("host not found") when resolution
-- yields no address, or when the first address is neither IPv4 nor IPv6.
-- Errors raised by 'Soc.getAddrInfo' itself or by 'connectTo' propagate.
connectToHost :: String -> Int -> IO SocketIO
connectToHost host port = resolve >>= \ip -> connectTo ip port
  where
    notFound :: IO a
    notFound = throwIO $ mkIOError doesNotExistErrorType "host not found" Nothing Nothing

    resolve :: IO IP.IP
    resolve = Soc.getAddrInfo Nothing (Just host) Nothing >>= firstAddress

    firstAddress :: [Soc.AddrInfo] -> IO IP.IP
    firstAddress [] = notFound
    firstAddress (info:_) = fromSockAddr (Soc.addrAddress info)

    fromSockAddr :: Soc.SockAddr -> IO IP.IP
    fromSockAddr (Soc.SockAddrInet _ v4) = return (IP.IPv4 (IP.fromHostAddress v4))
    fromSockAddr (Soc.SockAddrInet6 _ _ v6 _) = return (IP.IPv6 (IP.fromHostAddress6 v6))
    fromSockAddr _ = notFound


-- | @connectTo ipAddress port@ connects to the given port of the host at the
-- given IP address.
--
-- IPv6 addresses are handed to C as 16 bytes, most significant byte first.
-- A null handle from C raises an 'IOError' built from errno.
connectTo :: IP.IP -> Int -> IO SocketIO
connectTo host port = do
  conn <- open host
  if sock conn == nullPtr
    then throwErrno "could not connect to host"
    else return conn
  where
    cPort :: CInt
    cPort = fromIntegral port

    open :: IP.IP -> IO SocketIO
    open (IP.IPv4 v4) =
      fmap SocketIO (c_connect4 (fromIntegral (IP.toHostAddress v4)) cPort)
    open (IP.IPv6 v6) =
      fmap SocketIO (withArray (addressBytes v6) (\bytes -> c_connect6 bytes cPort))

    -- The four 32-bit words of the address, each split big-endian.
    addressBytes :: IP.IPv6 -> [CUChar]
    addressBytes v6 =
      let (a, b, c, d) = IP.toHostAddress6 v6
      in concatMap wordBytes [a, b, c, d]

    wordBytes :: Word32 -> [CUChar]
    wordBytes w = [fromIntegral ((w `shiftR` k) .&. 0xff) | k <- [24, 16, 8, 0]]
  
-- | @bindPort port@ binds to the given IP port, becoming ready to accept
-- connections on it.
--
-- Binding to port numbers under 1024 will fail unless performed by the
-- superuser; once bound, a process can reduce its privileges and still
-- accept clients on that port. A null handle from C raises an 'IOError'
-- built from errno.
bindPort :: Int -> IO BoundPort
bindPort port = c_getPort (fromIntegral port) >>= checked . BoundPort
  where
    checked :: BoundPort -> IO BoundPort
    checked bound
      | lis bound == nullPtr = throwErrno "could not bind to port"
      | otherwise = return bound
  
-- | @accept port@ accepts a client on a port previously bound with
-- 'bindPort'. A null handle from C raises an 'IOError' built from errno.
accept :: BoundPort -> IO SocketIO
accept port = c_accept (lis port) >>= checked . SocketIO
  where
    checked :: SocketIO -> IO SocketIO
    checked conn
      | sock conn == nullPtr = throwErrno "could not accept connection"
      | otherwise = return conn
  
-- | Opens the named file for bidirectional IO.
--
-- The name is marshalled to a temporary C string for the duration of the
-- call. A null handle from C raises an 'IOError' built from errno.
openFile :: String -> IO FileIO
openFile fileName = do
  opened <- fmap FileIO (withCString fileName c_createFile)
  if fd opened /= nullPtr
    then return opened
    else throwErrno "could not open file"

-- | Gets the address and port of the peer of an internet connection.
--
-- The C side is given a 'CUInt' slot for an IPv4 address, a 16-byte buffer
-- for an IPv6 address, and a flag slot. A flag of 1 selects the IPv6 buffer;
-- any other value selects the IPv4 slot. The C return value is used as the
-- peer port. A return of -1 raises an 'IOError' built from errno.
getPeer :: SocketIO -> IO (IP.IP, Int)
getPeer s =
  allocaArray 16 $ \v6Buf ->
    alloca $ \v4Buf ->
      alloca $ \familyBuf -> do
        port <- c_getPeer (sock s) v4Buf v6Buf familyBuf
        if port == -1
          then throwErrno "could not get peer address"
          else do
            family <- peek familyBuf
            addr <-
              if family == 1
                then fmap (IP.IPv6 . IP.toIPv6b . map fromIntegral) (peekArray 16 v6Buf)
                else fmap (IP.IPv4 . IP.fromHostAddress . fromIntegral) (peek v4Buf)
            return (addr, fromIntegral port)
    
-- | Closes a 'BoundPort', releasing any resource used by it.
closePort :: BoundPort -> IO ()
closePort = c_closePort . lis

-- Handle constructors. Callers in this module treat a null result as a
-- failure and report it through errno.
foreign import ccall safe "getPort" c_getPort :: CInt -> IO (Ptr Nethandler)
foreign import ccall safe "createFromHandler" c_accept :: Ptr Nethandler -> IO (Ptr SockDs)
foreign import ccall safe "createFromFileName" c_createFile :: CString -> IO (Ptr FileDs)
foreign import ccall safe "createToIPv4Host" c_connect4 :: CUInt -> CInt -> IO (Ptr SockDs)
foreign import ccall safe "createToIPv6Host" c_connect6 :: Ptr CUChar -> CInt -> IO (Ptr SockDs)

-- Bound port.
foreign import ccall safe "closeHandler" c_closePort :: Ptr Nethandler -> IO ()

-- Sockets. Callers treat a negative send/recv result, a null result from
-- startSockTls and a -1 result from getPeer as failures reported through errno.
foreign import ccall safe "sockDsSend" c_sendSock :: Ptr SockDs -> Ptr CChar -> CInt -> IO CInt
foreign import ccall safe "sockDsRecv" c_recvSock :: Ptr SockDs -> Ptr CChar -> CInt -> IO CInt
foreign import ccall safe "closeSockDs" c_closeSock :: Ptr SockDs -> IO ()
foreign import ccall safe "startSockTls" c_startSockTls :: Ptr SockDs -> CString -> CString -> IO (Ptr TlsDs)
foreign import ccall safe "getPeer" c_getPeer :: Ptr SockDs -> Ptr CUInt -> Ptr CUChar -> Ptr CInt -> IO CInt

-- Files. Callers treat a negative send/recv result as a failure reported through errno.
foreign import ccall safe "fileDsSend" c_sendFile :: Ptr FileDs -> Ptr CChar -> CInt -> IO CInt
foreign import ccall safe "fileDsRecv" c_recvFile :: Ptr FileDs -> Ptr CChar -> CInt -> IO CInt
foreign import ccall safe "closeFileDs" c_closeFile :: Ptr FileDs -> IO ()

-- TLS sessions. Callers treat a negative send/recv result as a failure reported through errno.
foreign import ccall safe "tlsDsSend" c_sendTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt
foreign import ccall safe "tlsDsRecv" c_recvTls :: Ptr TlsDs -> Ptr CChar -> CInt -> IO CInt
foreign import ccall safe "closeTlsDs" c_closeTls :: Ptr TlsDs -> IO ()