-- | Internal functions
module Network.Transport.Internal
  ( -- * Encoders/decoders
    encodeInt32
  , decodeInt32
  , encodeInt16
  , decodeInt16
  , prependLength
    -- * Miscellaneous abstractions
  , mapIOException
  , tryIO
  , tryToEnum
  , timeoutMaybe
  , asyncWhenCancelled
  -- * Replicated functionality from "base"
  , void
  , forkIOWithUnmask
    -- * Debugging
  , tlog
  ) where

#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif

import Foreign.Storable (pokeByteOff, peekByteOff)
import Foreign.C (CInt(..), CShort(..))
import Foreign.ForeignPtr (withForeignPtr)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (length)
import qualified Data.ByteString.Internal as BSI
  ( unsafeCreate
  , toForeignPtr
  , inlinePerformIO
  )
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Exception
  ( IOException
  , SomeException
  , AsyncException
  , Exception
  , catch
  , try
  , throw
  , throwIO
  , mask_
  )
import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.MVar (MVar, newEmptyMVar, takeMVar, putMVar)
import GHC.IO (unsafeUnmask)
import System.Timeout (timeout)
--import Control.Concurrent (myThreadId)

#ifdef mingw32_HOST_OS

-- Byte-order conversion primitives between host and network byte order:
-- htonl and ntohl for 32-bit values (CInt), htons and ntohs for 16-bit
-- values (CShort). On Windows (mingw32) they are imported with the stdcall
-- calling convention.
foreign import stdcall unsafe "htonl" htonl :: CInt -> CInt
foreign import stdcall unsafe "ntohl" ntohl :: CInt -> CInt
foreign import stdcall unsafe "htons" htons :: CShort -> CShort
foreign import stdcall unsafe "ntohs" ntohs :: CShort -> CShort

#else

-- The same four primitives on every other platform, imported with the C
-- calling convention.
foreign import ccall unsafe "htonl" htonl :: CInt -> CInt
foreign import ccall unsafe "ntohl" ntohl :: CInt -> CInt
foreign import ccall unsafe "htons" htons :: CShort -> CShort
foreign import ccall unsafe "ntohs" ntohs :: CShort -> CShort

#endif

-- | Encode a value as a 4-byte ByteString in network (big-endian) byte order.
--
-- The value is turned into an Int with 'fromEnum' and then narrowed to a
-- 'CInt' with 'fromIntegral', so values outside the 32-bit range wrap around.
encodeInt32 :: Enum a => a -> ByteString
encodeInt32 i32 = BSI.unsafeCreate 4 writeWord
  where
    writeWord buf = pokeByteOff buf 0 (htonl (fromIntegral (fromEnum i32)))

-- | Decode a 4-byte ByteString in network (big-endian) byte order.
--
-- The bytes are read as a signed 'CInt', converted to host order and then
-- passed through 'fromIntegral'. Evaluating the result throws an
-- 'IOException' (a 'userError') when the input is not exactly 4 bytes long.
decodeInt32 :: Num a => ByteString -> a
decodeInt32 bs
  | BS.length bs == 4 = BSI.inlinePerformIO readWord
  | otherwise         = throw $ userError "decodeInt32: Invalid length"
  where
    (fptr, off, _) = BSI.toForeignPtr bs
    readWord = withForeignPtr fptr $ \ptr ->
      fmap (fromIntegral . ntohl) (peekByteOff ptr off)

-- | Encode a value as a 2-byte ByteString in network (big-endian) byte order.
--
-- The value is turned into an Int with 'fromEnum' and then narrowed to a
-- 'CShort' with 'fromIntegral', so values outside the 16-bit range wrap
-- around.
encodeInt16 :: Enum a => a -> ByteString
encodeInt16 i16 = BSI.unsafeCreate 2 writeShort
  where
    writeShort buf = pokeByteOff buf 0 (htons (fromIntegral (fromEnum i16)))

-- | Decode a 2-byte ByteString in network (big-endian) byte order.
--
-- The bytes are read as a signed 'CShort', converted to host order and then
-- passed through 'fromIntegral'. Evaluating the result throws an
-- 'IOException' (a 'userError') when the input is not exactly 2 bytes long.
decodeInt16 :: Num a => ByteString -> a
decodeInt16 bs
  | BS.length bs == 2 = BSI.inlinePerformIO readShort
  | otherwise         = throw $ userError "decodeInt16: Invalid length"
  where
    (fptr, off, _) = BSI.toForeignPtr bs
    readShort = withForeignPtr fptr $ \ptr ->
      fmap (fromIntegral . ntohs) (peekByteOff ptr off)

-- | Put a header in front of a list of chunks. The header is the combined
-- byte length of all chunks, encoded with 'encodeInt32'.
prependLength :: [ByteString] -> [ByteString]
prependLength chunks = header : chunks
  where
    header     = encodeInt32 totalBytes
    totalBytes = sum (map BS.length chunks)

-- | Run an IO action and convert any 'IOException' it throws into another
-- exception, built by the supplied function, which is then rethrown.
-- Exceptions of any other type pass through untouched.
mapIOException :: Exception e => (IOException -> e) -> IO a -> IO a
mapIOException f p = p `catch` \ioe -> throwIO (f ioe)

-- | Run an IO action and capture an 'IOException' as a 'Left' value instead
-- of letting it propagate. The result is lifted into any 'MonadIO'.
-- Exceptions of other types still propagate.
tryIO :: MonadIO m => IO a -> m (Either IOException a)
tryIO action = liftIO (try action)

-- | Debug logging hook. It is currently a no-op: the message is discarded.
-- To get thread-tagged tracing, swap in the commented-out definition below
-- (it also needs the commented-out myThreadId import).
tlog :: MonadIO m => String -> m ()
tlog = const (return ())
{-
tlog msg = liftIO $ do
  tid <- myThreadId
  putStrLn $ show tid ++ ": "  ++ msg
-}

-- | Run an action for its effects only, discarding its result.
-- Defined here because not every supported version of "base" exports 'void'.
void :: Monad m => m a -> m ()
void = (>> return ())

-- | Fork a thread whose body receives a function that runs an action with
-- asynchronous exceptions unmasked.
--
-- This replicates the "base" function of the same name, which older versions
-- of "base" (such as the one shipped with GHC 7.0.4) do not provide. As with
-- 'forkIO', the new thread starts in the masking state of its parent.
forkIOWithUnmask :: ((forall a . IO a -> IO a) -> IO ()) -> IO ThreadId
forkIOWithUnmask io =
  let body = io unsafeUnmask
  in  forkIO body

-- | Total counterpart of 'toEnum'. It yields @Nothing@ when the Int lies
-- outside @[fromEnum minBound, fromEnum maxBound]@, and otherwise
-- @Just (toEnum n)@.
tryToEnum :: (Enum a, Bounded a) => Int -> Maybe a
tryToEnum = check minBound maxBound
  where
    check :: Enum b => b -> b -> Int -> Maybe b
    check lo hi n
      | n < fromEnum lo = Nothing
      | n > fromEnum hi = Nothing
      | otherwise       = Just (toEnum n)

-- | Optionally bound the running time of an action.
--
-- With @Nothing@ the action runs as-is. With @Just n@ it runs under
-- 'timeout' with a limit of @n@ microseconds. If the limit is exceeded, the
-- supplied exception is thrown with 'throwIO'; otherwise the result of the
-- action is returned.
timeoutMaybe :: Exception e => Maybe Int -> e -> IO a -> IO a
timeoutMaybe mLimit err action =
  case mLimit of
    Nothing    -> action
    Just limit -> timeout limit action >>= maybe (throwIO err) return

-- | @asyncWhenCancelled g f@ runs @f@ in a separate thread and waits for it
-- to complete.
--
-- * If @f@ throws an exception, it is caught in the worker thread and
--   rethrown in the calling thread. No cleanup is run in that case; we
--   assume none is necessary.
--
-- * If the calling thread is interrupted before @f@ completes, by any
--   exception and not only an 'AsyncException' such as 'ThreadKilled'
--   (for example the exception used by 'System.Timeout.timeout'), that
--   exception is rethrown immediately. A new thread is then forked that
--   waits for @f@ to finish and passes its result to the cleanup handler @g@.
asyncWhenCancelled :: forall a. (a -> IO ()) -> IO a -> IO a
asyncWhenCancelled g f = mask_ $ do
    mvar <- newEmptyMVar
    -- 'try' captures every exception from f, so the worker fills the MVar
    -- whether f returns or throws.
    _ <- forkIO $ try f >>= putMVar mvar
    -- takeMVar is interruptible (even inside a mask_)
    catch (takeMVar mvar) (exceptionHandler mvar) >>= either throwIO return
  where
    -- Catch SomeException rather than only AsyncException. The worker always
    -- fills the MVar, so an exception from takeMVar here means the calling
    -- thread was interrupted, whatever the exception type. The Timeout
    -- exception thrown by System.Timeout.timeout, for example, is not an
    -- AsyncException, and catching only AsyncException skipped the cleanup.
    exceptionHandler :: MVar (Either SomeException a)
                     -> SomeException
                     -> IO (Either SomeException a)
    exceptionHandler mvar ex = do
      _ <- forkIO $ takeMVar mvar >>= either (const $ return ()) g
      throwIO ex