module Database.Redis.ProtocolPipelining (
  Connection,
  connect, disconnect, request, send, recv, flush,
  ConnectionLostException(..),
  HostName, PortID(..)
) where
import           Prelude
import           Control.Exception
import           Control.Monad
import qualified Scanner
import qualified Data.ByteString as S
import           Data.IORef
import           Data.Typeable
import           Network
import qualified Network.BSD as BSD
import qualified Network.Socket as NS
import           System.IO
import           System.IO.Error
import           System.IO.Unsafe
import           Database.Redis.Protocol
-- | A pipelined connection to a Redis server.
--
-- Requests are written to 'connHandle' without waiting for replies. Replies
-- are parsed lazily from the same handle and handed out in order.
data Connection = Conn
  { connHandle :: Handle
    -- ^ Socket handle used both for writing requests and reading replies.
  , connReplies :: IORef [Reply]
    -- ^ Lazily produced, unbounded stream of replies not yet taken by 'recv'.
  , connPending :: IORef [Reply]
    -- ^ Replies that are expected but not yet parsed. The head is dropped
    -- each time a reply is parsed off the socket.
  , connPendingCnt :: IORef Int
    -- ^ Number of requests sent whose replies have not been parsed yet.
    -- Used by 'send' to bound how far ahead of the reader the writer may get.
  }
-- | Thrown when the connection can no longer be used. This happens when an
-- 'IOError' occurs while writing requests or reading replies, or when a reply
-- cannot be parsed from the socket.
data ConnectionLostException
  = ConnectionLost
  deriving (Show, Typeable)

instance Exception ConnectionLostException
-- | Open a connection to @hostName@ on @portID@.
--
-- For a 'PortNumber', an IPv4 stream socket with keep-alive enabled is
-- created and connected to the address returned by 'BSD.getHostByName'.
-- Any other 'PortID' is opened via 'connectTo'.
--
-- The handle is switched to binary mode. The lazy reply stream is then
-- installed as both the unconsumed replies and the pending replies. If any
-- step after opening fails, the handle (or socket) is closed before the
-- exception propagates.
connect :: HostName -> PortID -> IO Connection
connect hostName portID =
  bracketOnError (openHandle portID) hClose initConnection
  where
    -- Wrap an open handle in a fresh 'Connection' with its reply stream set up.
    initConnection h = do
      hSetBinaryMode h True
      repliesRef <- newIORef []
      pendingRef <- newIORef []
      countRef   <- newIORef 0
      let conn = Conn { connHandle     = h
                      , connReplies    = repliesRef
                      , connPending    = pendingRef
                      , connPendingCnt = countRef
                      }
      replies <- connGetReplies conn
      -- Both refs share the same lazy list: connReplies is consumed by recv,
      -- connPending is shortened as replies are parsed.
      writeIORef repliesRef replies
      writeIORef pendingRef replies
      return conn

    -- Open a handle to the server; the socket is closed if setup fails.
    openHandle (PortNumber port) =
      bracketOnError newSocket NS.close $ \sock -> do
        NS.setSocketOption sock NS.KeepAlive 1
        entry <- BSD.getHostByName hostName
        NS.connect sock (NS.SockAddrInet port (BSD.hostAddress entry))
        NS.socketToHandle sock ReadWriteMode
    openHandle _ = connectTo hostName portID

    newSocket = NS.socket NS.AF_INET NS.Stream 0
-- | Close the connection's handle if it is still open. Calling this on a
-- connection whose handle is already closed does nothing.
disconnect :: Connection -> IO ()
disconnect Conn{connHandle = h} =
  hIsOpen h >>= \stillOpen -> when stillOpen (hClose h)
-- | Write a raw, already-encoded request to the connection without waiting
-- for its reply. The bytes may remain buffered in the handle until 'flush'
-- is called or until the reply stream needs to read from the socket.
--
-- Any 'IOError' raised by the write is rethrown as 'ConnectionLost'.
--
-- Each call increments the outstanding-reply counter. Once 1000 or more
-- replies are outstanding, the oldest pending reply is forced, which blocks
-- until it has been read and parsed. Presumably this bounds the build-up of
-- unread replies in long pipelines.
send :: Connection -> S.ByteString -> IO ()
send Conn { connHandle     = h
          , connPending    = pendingRef
          , connPendingCnt = countRef
          } s = do
  ioErrorToConnLost (S.hPut h s)
  outstanding <- atomicModifyIORef' countRef bump
  when (outstanding >= maxPending) $ do
    oldest : _ <- readIORef pendingRef
    oldest `seq` pure ()
  where
    -- Increment the counter, returning the new value as the result as well.
    bump c = let c' = c + 1 in (c', c')
    -- Outstanding-reply threshold at which send starts forcing old replies.
    maxPending = 1000 :: Int
-- | Take the next reply off the connection's reply stream.
--
-- The returned 'Reply' is lazy. Reading and parsing from the socket happens
-- when it is forced, not when 'recv' runs.
--
-- NOTE(review): the read and the write of 'connReplies' are separate, not
-- atomic, operations. Confirm that callers serialize access to a connection.
recv :: Connection -> IO Reply
recv Conn{connReplies = repliesRef} = do
  next : remaining <- readIORef repliesRef
  writeIORef repliesRef remaining
  pure next
-- | Push any requests buffered in the connection's handle out to the server.
flush :: Connection -> IO ()
flush = hFlush . connHandle
-- | Send one request and return the next reply from the connection.
--
-- The reply comes from 'recv', so it is the next unconsumed reply. It only
-- corresponds to @req@ if no earlier replies are still waiting to be taken.
request :: Connection -> S.ByteString -> IO Reply
request conn req = do
  send conn req
  recv conn
-- | Build the lazy, unbounded stream of replies parsed from the connection's
-- handle.
--
-- Each element is produced on demand via 'unsafeInterleaveIO'. Nothing is
-- read from the socket until an element is forced. Forcing a reply first
-- forces the reply before it, so replies are always parsed in the order the
-- server sent them.
--
-- Each parsed reply drops the head of 'connPending' and decrements
-- 'connPendingCnt' (never below zero).
--
-- Forcing an element throws 'ConnectionLost' if the scanner fails or if an
-- 'IOError' occurs while flushing or reading the handle.
connGetReplies :: Connection -> IO [Reply]
connGetReplies Conn { connHandle     = h
                    , connPending    = pendingRef
                    , connPendingCnt = countRef
                    } =
  go S.empty (SingleLine "previous of first")
  where
    go rest previous = do
      -- Lazy pattern: the receive below is deferred until r or rest' is demanded.
      ~(r, rest') <- unsafeInterleaveIO $ do
        -- Force the previous reply first so replies are parsed in order.
        previous `seq` return ()
        scanResult <- Scanner.scanWith readMore reply rest
        case scanResult of
          Scanner.Fail{}       -> errConnClosed
          Scanner.More{}       -> error "Hedis: parseWith returned Partial"
          Scanner.Done rest' r -> do
            -- r is the reply at the head of connPending. Now that it has been
            -- received, drop it. drop 1 is total, unlike a (_:rs) pattern.
            atomicModifyIORef' pendingRef $ \rs -> (drop 1 rs, ())
            -- One fewer outstanding reply. Clamp at 0 so the count cannot go
            -- negative when more replies are parsed than were sent
            -- (presumably server-pushed messages, e.g. pub/sub; confirm).
            atomicModifyIORef' countRef $ \n -> (max 0 (n - 1), ())
            return (r, rest')
      rs <- unsafeInterleaveIO (go rest' r)
      return (r:rs)

    -- Flush buffered requests before blocking on a read, so the server has
    -- actually received the requests whose replies we are waiting for.
    readMore = ioErrorToConnLost $ do
      hFlush h
      S.hGetSome h 4096
-- | Run an action, converting any 'IOError' it throws into 'ConnectionLost'.
-- Exceptions of other types propagate unchanged.
ioErrorToConnLost :: IO a -> IO a
ioErrorToConnLost action = catchIOError action (\_ -> errConnClosed)

-- | Abort the current 'IO' computation by throwing 'ConnectionLost'.
errConnClosed :: IO a
errConnClosed = throwIO ConnectionLost