{-# LANGUAGE RecordWildCards #-}
module Database.Bolt.Connection.Connection where

import           Control.Applicative    (pure, (<$>))
import           Control.Exception      (onException, throwIO)
import           Control.Monad          (forM_, when)
import           Control.Monad.Trans    (MonadIO (..))
import           Data.ByteString        (ByteString, null)
import           Data.Default           (Default (..))
import           GHC.Stack              (HasCallStack, withFrozenCallStack)
import           Network.Connection     (ConnectionParams (..), connectTo, connectionClose,
                                        connectionGetExact, connectionPut, connectionSetSecure,
                                        initConnectionContext)
import           Network.Socket         (PortNumber)
import           Prelude                hiding (null)
import           System.Timeout         (timeout)

import           Database.Bolt.Connection.Type (BoltError (..), ConnectionWithTimeout (..))

-- | Open a connection to the given host and port, optionally upgrading it
-- to a secure (TLS) connection.
--
-- The timeout in seconds is converted to microseconds, applied to the
-- connection attempt and to the secure upgrade, and stored in the returned
-- 'ConnectionWithTimeout'. 'TimeOut' is thrown when either step does not
-- finish within that interval.
connect
  :: MonadIO m
  => HasCallStack
  => Bool
     -- ^ Use secure connection
  -> String
     -- ^ Hostname
  -> PortNumber
     -- ^ Port
  -> Int
     -- ^ Connection and read timeout in seconds
  -> m ConnectionWithTimeout
connect secure host port timeSec = liftIO $ do
  let timeUsec = 1000000 * timeSec
      -- The connection is always opened in plain mode; the secure layer is
      -- negotiated afterwards when requested.
      params   = ConnectionParams { connectionHostname  = host
                                  , connectionPort      = port
                                  , connectionUseSecure = Nothing
                                  , connectionUseSocks  = Nothing
                                  }
  ctx  <- initConnectionContext
  conn <- timeoutThrow timeUsec $ connectTo ctx params
  -- The secure upgrade talks to the peer as well, so it is bounded by the
  -- same timeout; if it fails or times out the freshly opened connection is
  -- closed instead of being leaked, and the original exception propagates.
  when secure $
    timeoutThrow timeUsec (connectionSetSecure ctx conn def)
      `onException` connectionClose conn
  pure $ ConnectionWithTimeout conn timeUsec

-- | Close the underlying connection. The close operation itself is bounded
-- by the connection's timeout; 'TimeOut' is thrown if it does not finish
-- in time.
close :: MonadIO m => HasCallStack => ConnectionWithTimeout -> m ()
close cwt =
  liftIO (timeoutThrow (cwtTimeoutUsec cwt) (connectionClose (cwtConnection cwt)))

-- | Read exactly the requested number of bytes from the connection, bounded
-- by the connection's timeout ('TimeOut' is thrown when it expires).
--
-- Yields 'Nothing' when the bytes obtained are empty and 'Just' the bytes
-- otherwise.
recv :: MonadIO m => HasCallStack => ConnectionWithTimeout -> Int -> m (Maybe ByteString)
recv cwt size = liftIO $ do
  bytes <- timeoutThrow (cwtTimeoutUsec cwt) (connectionGetExact (cwtConnection cwt) size)
  pure $ if null bytes then Nothing else Just bytes

-- | Write a single 'ByteString' to the connection, bounded by the
-- connection's timeout; 'TimeOut' is thrown if the write does not finish
-- in time.
send :: MonadIO m => HasCallStack => ConnectionWithTimeout -> ByteString -> m ()
send cwt payload = liftIO $ do
  let limitUsec = cwtTimeoutUsec cwt
      rawConn   = cwtConnection cwt
  timeoutThrow limitUsec (connectionPut rawConn payload)

-- | Write several chunks to the connection, one after another and in order.
--
-- Every chunk is written with 'send', which already bounds the write by the
-- connection's timeout, so each chunk gets its own timeout interval and no
-- additional timeout wrapper is needed here.
sendMany :: MonadIO m => HasCallStack => ConnectionWithTimeout -> [ByteString] -> m ()
sendMany conn chunks = liftIO $ forM_ chunks (send conn)

-- | Run an 'IO' action under a time limit given in microseconds.
--
-- The action's result is returned when it finishes in time; otherwise
-- 'TimeOut' is thrown. The call stack is frozen so that the one attached to
-- the error is that of the caller rather than of this helper.
timeoutThrow :: HasCallStack => Int -> IO a -> IO a
timeoutThrow timeUsec action =
  withFrozenCallStack $
    timeout timeUsec action >>= maybe (throwIO TimeOut) pure