{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module System.Systemd.Internal where

import           Control.Exception         (bracket)
import           Control.Monad
import           Control.Monad.IO.Class    (liftIO)
import           Control.Monad.Trans.Maybe
import qualified Data.ByteString.Char8     as BC
import           Data.ByteString.Unsafe    (unsafeUseAsCStringLen)
import           Data.List
import           Foreign.C.Types           (CInt (..))
import           Foreign.Marshal           (free, mallocBytes)
import           Foreign.Ptr
import           Network.Socket
import           Network.Socket.Address    hiding (recvFrom, sendTo)
import           Network.Socket.ByteString
import           System.Posix.Env
import           System.Posix.Types        (Fd (..))

-- | Name of the environment variable that holds the path of the socket
-- state notifications are sent to (read by 'notifyWithFD_', cleared by
-- 'unsetEnvironnement').
envVariableName :: String
envVariableName = "NOTIFY_SOCKET"

-- | Binding to the C helper @sd_notify_with_fd@, called by 'sendBufWithFdTo'.
--
-- Arguments, in the order 'sendBufWithFdTo' passes them: the sending
-- socket's descriptor, a pointer to the message bytes, the message length,
-- a pointer to the serialised destination address, the address length,
-- and the descriptor to transmit alongside the message.
--
-- 'sendBufWithFdTo' converts the result to 'Int' and returns it.
-- Presumably it is a byte count, or a negative value on error; TODO
-- confirm against the C implementation.
--
-- Declared @unsafe@, so the C code must not call back into Haskell.
foreign import ccall unsafe "sd_notify_with_fd"
  c_sd_notify_with_fd :: CInt -> Ptr a -> CInt -> Ptr b -> CInt -> CInt -> IO CInt

-- | Unset every environment variable this module associates with systemd:
-- the notification socket variable ('envVariableName') plus
-- @LISTEN_PID@, @LISTEN_FDS@ and @LISTEN_FDNAMES@.
--
-- Calls to functions like 'System.Systemd.Daemon.notify' and
-- 'System.Systemd.Daemon.getActivatedSockets' will return
-- 'Nothing' after that.
unsetEnvironnement :: IO ()
unsetEnvironnement = forM_ systemdVars unsetEnv
  where
    systemdVars = envVariableName : ["LISTEN_PID", "LISTEN_FDS", "LISTEN_FDNAMES"]

-- | Send @state@ as a datagram to @addr@ through @sock@, transmitting
-- @fdToSend@ alongside it via the C helper @sd_notify_with_fd@.
--
-- Returns the helper's result converted to 'Int'. Presumably this is the
-- number of bytes sent, or a negative value on failure; TODO confirm
-- against the C implementation.
--
-- The socket's descriptor is only borrowed (via 'withFdSocket'). @sock@
-- therefore stays open and owned by the caller, who must still close it.
sendBufWithFdTo :: Socket -> BC.ByteString -> SockAddr -> Fd -> IO Int
sendBufWithFdTo sock state addr fdToSend =
  unsafeUseAsCStringLen state $ \(msgPtr, msgLen) ->
    -- Only the allocation happens in the acquire step, so the buffer is
    -- freed even if serialising the address throws.
    bracket (mallocBytes addrSize) free $ \addrPtr -> do
      pokeSocketAddress addrPtr addr
      -- 'withFdSocket' keeps the socket alive for the duration of the call.
      -- 'socketToFd' would instead dup the descriptor and close the socket,
      -- leaking the duplicate and invalidating the caller's socket.
      withFdSocket sock $ \sockFd ->
        fromIntegral <$> c_sd_notify_with_fd sockFd msgPtr (fromIntegral msgLen)
                                             addrPtr (fromIntegral addrSize)
                                             (fromIntegral fdToSend)
  where
    addrSize = sizeOfSocketAddress addr

-- | Send @state@ to the socket named by the @NOTIFY_SOCKET@ environment
-- variable ('envVariableName'), optionally passing the descriptor @fd@
-- along with it.
--
-- Returns @Just ()@ when at least as many bytes as the message length were
-- sent. Returns 'Nothing' in any of these cases:
--
-- * @state@ is empty;
-- * the variable is unset;
-- * the variable does not start with @\'\@\'@ or @\'/\'@;
-- * the variable is shorter than two characters;
-- * too few bytes were sent.
--
-- When @unset_env@ is 'True', 'unsetEnvironnement' is called after the
-- attempt returns, whether it produced 'Just' or 'Nothing'.
--
-- NOTE(review): 'BC.pack' keeps only the low 8 bits of each 'Char', so
-- non-Latin-1 text in @state@ is truncated.
notifyWithFD_ :: Bool -> String -> Maybe Fd -> IO (Maybe ())
notifyWithFD_ unset_env state fd = do
  res <- runMaybeT notifyImpl
  when unset_env unsetEnvironnement
  return res
  where
    -- A usable path is a leading '@' or '/' followed by at least one more
    -- character.
    isValidPath path = length path >= 2
                    && ("@" `isPrefixOf` path || "/" `isPrefixOf` path)

    -- A leading '@' denotes an abstract socket, whose address begins with
    -- a NUL byte instead.
    toSocketPath ('@' : rest) = '\0' : rest
    toSocketPath path         = path

    payload = BC.pack state

    notifyImpl = do
      guard $ not (null state)
      socketPath <- MaybeT (getEnv envVariableName)
      guard $ isValidPath socketPath
      let addr = SockAddrUnix (toSocketPath socketPath)
      -- 'bracket' closes the socket even if sending throws.
      nbBytes <- liftIO $ bracket (socket AF_UNIX Datagram 0) close $ \sock ->
        case fd of
          Nothing       -> sendTo sock payload addr
          Just fdToSend -> sendBufWithFdTo sock payload addr fdToSend
      -- BC.pack yields one byte per Char, so BC.length equals length state.
      guard $ nbBytes >= BC.length payload

-- | Return the raw descriptor underlying a 'Socket', wrapped as an 'Fd'.
--
-- NOTE(review): on network >= 3.1 this goes through 'unsafeFdSocket',
-- which does not keep the socket alive. The caller must ensure the
-- 'Socket' is not closed or finalised while the returned 'Fd' is in use.
socketToFd_ :: Socket -> IO Fd
#if ! MIN_VERSION_network(3,1,0)
socketToFd_ sock = Fd <$> fdSocket sock
#else
socketToFd_ sock = Fd <$> unsafeFdSocket sock
#endif

-- | Wrap an existing file descriptor in a 'Socket' using 'mkSocket'.
fdToSocket :: Fd -> IO Socket
fdToSocket (Fd rawFd) = mkSocket rawFd