{-# LANGUAGE RecordWildCards #-}
module Database.Bolt.Connection.Connection where

import           Control.Applicative    (pure, (<$>))
import           Control.Exception      (onException, throwIO)
import           Control.Monad          (forM_, when)
import           Control.Monad.Trans    (MonadIO (..))
import           Data.ByteString        (ByteString, null)
import           Data.Default           (Default (..))
import           GHC.Stack              (HasCallStack, withFrozenCallStack)
import           Network.Socket         (PortNumber)
import           Network.Connection     (ConnectionParams (..), connectTo, connectionClose,
                                         connectionGetExact, connectionPut, connectionSetSecure,
                                         initConnectionContext)
import           Prelude                hiding (null)
import           System.Timeout         (timeout)

import           Database.Bolt.Connection.Type (BoltError (..), ConnectionWithTimeout (..))

-- | Open a connection to the given host and port, optionally upgrading it
-- to TLS.
--
-- The timeout (given in seconds) is converted to microseconds. It bounds
-- the plain connect and, when @secure@ is set, the TLS upgrade. It is also
-- stored in the returned 'ConnectionWithTimeout' for later reads, writes
-- and close. Throws 'TimeOut' if either step exceeds it.
--
-- If the TLS upgrade fails for any reason (including a timeout), the
-- already-opened connection is closed before the exception propagates, so
-- no socket is leaked.
connect
  :: MonadIO m
  => HasCallStack
  => Bool
     -- ^ Use secure connection
  -> String
     -- ^ Hostname
  -> PortNumber
  -> Int
     -- ^ Connection and read timeout in seconds
  -> m ConnectionWithTimeout
connect secure host port timeSec = liftIO $ do
  let timeUsec = 1000000 * timeSec
  ctx  <- initConnectionContext
  conn <- timeoutThrow timeUsec $
    connectTo ctx ConnectionParams { connectionHostname  = host
                                   , connectionPort      = port
                                   , connectionUseSecure = Nothing
                                   , connectionUseSocks  = Nothing
                                   }
  -- The TLS upgrade happens after the plain connect has succeeded; if it
  -- throws or hangs past the timeout, release the connection rather than
  -- leaking it to the caller who never receives a handle to close.
  when secure $
    timeoutThrow timeUsec (connectionSetSecure ctx conn def)
      `onException` connectionClose conn
  pure $ ConnectionWithTimeout conn timeUsec

-- | Close the underlying connection with 'connectionClose'. Throws
-- 'TimeOut' if closing takes longer than the connection's stored timeout.
close :: MonadIO m => HasCallStack => ConnectionWithTimeout -> m ()
close cwt =
  let usec   = cwtTimeoutUsec cwt
      closer = connectionClose (cwtConnection cwt)
  in  liftIO (timeoutThrow usec closer)

-- | Read @n@ bytes from the connection via 'connectionGetExact', bounded by
-- the connection's stored timeout (throws 'TimeOut' when exceeded).
--
-- Returns 'Nothing' if the read produced an empty 'ByteString', and
-- @'Just' bytes@ otherwise.
recv :: MonadIO m => HasCallStack => ConnectionWithTimeout -> Int -> m (Maybe ByteString)
recv cwt n = liftIO $ do
  bytes <- timeoutThrow (cwtTimeoutUsec cwt) (connectionGetExact (cwtConnection cwt) n)
  pure $ if null bytes then Nothing else Just bytes

-- | Write one 'ByteString' to the connection with 'connectionPut'. Throws
-- 'TimeOut' if the write does not finish within the stored timeout.
send :: MonadIO m => HasCallStack => ConnectionWithTimeout -> ByteString -> m ()
send cwt bytes =
  liftIO $ timeoutThrow (cwtTimeoutUsec cwt) $ connectionPut (cwtConnection cwt) bytes

-- | Send each chunk in order using 'send'.
--
-- Each chunk gets its own timeout, applied inside 'send'; the batch as a
-- whole is not bounded by a single timeout. Throws 'TimeOut' from the
-- first chunk whose write exceeds the connection's timeout, and later
-- chunks are then not sent.
--
-- 'send' already wraps the write in 'timeoutThrow' with the same duration.
-- Wrapping it again here only registered a second, redundant timer per
-- chunk without changing the outcome, so the outer wrapper is omitted.
sendMany :: MonadIO m => HasCallStack => ConnectionWithTimeout -> [ByteString] -> m ()
sendMany conn@ConnectionWithTimeout{} chunks = liftIO $ forM_ chunks (send conn)

-- | Run an 'IO' action with a limit given in microseconds.
--
-- Yields the action's result if it completes in time; otherwise throws
-- 'TimeOut'. The call stack is frozen with 'withFrozenCallStack', so the
-- stack captured by 'TimeOut' ends at this helper's caller instead of
-- growing inside it. Per 'timeout', a negative limit waits indefinitely.
timeoutThrow :: HasCallStack => Int -> IO a -> IO a
timeoutThrow timeUsec action =
  withFrozenCallStack $
    timeout timeUsec action >>= maybe (throwIO TimeOut) pure