module Resolve.DNS.TCP where

import Resolve.Types
import Resolve.DNS.Types
import qualified Resolve.DNS.Channel as C

import Data.Typeable
import Data.ByteString.Builder
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL


import Network.Socket hiding (recv, send, socket, Closed)
import Network.Socket.ByteString

import Data.Bits

import Control.Monad
import Control.Monad.Trans.Except
import Control.Monad.Trans.Class
import Control.Concurrent
import Control.Exception
import Control.Concurrent.STM.TMVar
import Control.Concurrent.STM.TVar
import Control.Monad.STM

import System.Log.Logger

-- | Logger name used for this module's debug messages.
nameM :: String
nameM = "Resolve.DNS.TCP"

-- | Settings for 'new'.
data Config = Config
  { socket :: Socket
    -- ^ Socket that 'new' writes outgoing frames to and reads incoming
    -- frames from. 'new' neither connects nor closes it.
    -- NOTE(review): presumably already connected to the DNS server by the
    -- caller; confirm at call sites.
  }

-- | Raised by the TCP channel's send and receive operations once the
-- background thread serving that direction has stopped.
data Closed = Closed deriving (Typeable, Show)

-- | Conversion to and from 'SomeException' is delegated to the helpers from
-- "Resolve.DNS.Types".
-- NOTE(review): presumably this nests 'Closed' under a shared DNS exception
-- parent so handlers for that parent also catch it; confirm in that module.
instance Exception Closed where
  fromException = dnsExceptionFromException
  toException = dnsExceptionToException

-- | Create a 'Resolver' that exchanges DNS messages over the socket in the
-- given 'Config'.
--
-- Every message is framed with a two-byte big-endian length prefix (the
-- DNS-over-TCP framing of RFC 1035, section 4.2.2). Two background threads
-- own the socket I\/O:
--
-- * the sender takes queued outgoing messages and writes them out;
-- * the receiver reads length-prefixed frames and queues them for the
--   channel's @recv@.
--
-- When a read returns no bytes (the peer closed the connection), the
-- receiver kills the sender and then exits. Once the sender has exited, the
-- channel's @send@ throws 'Closed'. Once the receiver has exited, the
-- channel's @recv@ throws 'Closed' as soon as no received frame is waiting.
-- A message longer than 65535 bytes cannot be framed, so @send@ rejects it
-- with an 'IOError'.
--
-- The returned resolver's 'delete' deletes the underlying channel and then
-- kills both threads, even if deleting the channel throws. It does not
-- close the socket.
new :: Config -> IO (Resolver Message Message)
new c = do
  qi <- newEmptyTMVarIO  -- received frames, handed out by the channel's recv
  qo <- newEmptyTMVarIO  -- outgoing frames, consumed by the sender thread
  si <- newTVarIO True   -- set to False once the receiver thread has exited
  so <- newTVarIO True   -- set to False once the sender thread has exited

  let sock = socket c

      -- The length prefix is 16 bits wide, so this is the largest message
      -- that can be framed.
      maxFrame = 0xffff :: Int

      chanSend msg = do
        -- A longer message would go out with a truncated length prefix and
        -- desynchronise every later frame on the connection.
        when (BS.length msg > maxFrame) $
          ioError $ userError $
            nameM ++ ": message of " ++ show (BS.length msg)
              ++ " bytes exceeds the 65535-byte TCP frame limit"
        accepted <- atomically $ do
          alive <- readTVar so
          when alive $ putTMVar qo msg
          return alive
        unless accepted $ throwIO Closed

      chanRecv = do
        -- While the receiver runs, block for the next frame. After it has
        -- exited, only drain a frame that is already waiting.
        next <- atomically $ do
          alive <- readTVar si
          if alive then Just <$> takeTMVar qi else tryTakeTMVar qi
        maybe (throwIO Closed) return next

      sendFrame = do
        msg <- atomically $ takeTMVar qo
        sendAll sock $ BSL.toStrict $ toLazyByteString $
          word16BE $ fromIntegral $ BS.length msg
        sendAll sock msg

      senderExited = do
        debugM nameM "send died"
        atomically $ writeTVar so False

      receiverExited = do
        debugM nameM "recv died"
        atomically $ writeTVar si False

  bracketOnError
    (do
        chan <- C.new C.Config { C.send = chanSend
                               , C.recv = chanRecv
                               , C.nick = "TCP<" ++ show sock ++ ">"
                               }

        -- 'finally' wraps the unmasked body instead of running inside it.
        -- That way the cleanup is installed while still masked, and it runs
        -- even if the thread is killed before its loop starts.
        sender <- forkIOWithUnmask $ \unmask ->
          unmask (forever sendFrame) `finally` senderExited

        -- recvExactly reads exactly n bytes. A zero-length read means the
        -- peer closed the connection: stop the sender as well, then end this
        -- thread.
        let recvExactly n = BSL.toStrict . toLazyByteString <$> go n
              where
                go k
                  | k <= 0 = return mempty
                  | otherwise = do
                      bs <- recv sock k
                      when (BS.null bs) $ do
                        throwTo sender ThreadKilled
                        throwIO ThreadKilled
                      mappend (byteString bs) <$> go (k - BS.length bs)
            recvFrame = do
              hdr <- recvExactly 2
              let len = (fromIntegral (BS.index hdr 0) `shiftL` 8)
                        .|. fromIntegral (BS.index hdr 1)
              msg <- recvExactly len
              atomically $ putTMVar qi msg

        receiver <- forkIOWithUnmask $ \unmask ->
          unmask (forever recvFrame) `finally` receiverExited

        return ( resolve chan
               , delete chan `finally` (killThread receiver >> killThread sender)
               )
    )
    (\(_, d) -> d)
    (\(r, d) -> return $ Resolver { resolve = r
                                  , delete = d
                                  }
    )