module System.Linux.Netlink.C
( makeSocket
, makeSocketGeneric
, closeSocket
, sendmsg
, recvmsg
, joinMulticastGroup
)
where
import Control.Exception (onException)
import Control.Monad (when)
import Data.ByteString (ByteString)
import Data.ByteString.Internal (createAndTrim, toForeignPtr)
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.Word (Word32)
import Foreign.C.Error (throwErrnoIf, throwErrnoIfMinus1, throwErrnoIfMinus1_)
import Foreign.C.Types
import Foreign.ForeignPtr (touchForeignPtr)
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import Foreign.Marshal.Array (withArrayLen)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Foreign.Storable (Storable(..))
import System.Posix.Types (CSsize)

import System.Linux.Netlink.Constants (eAF_NETLINK)
-- Raw libc bindings used by this module.
--
-- Parameter and result types mirror the Linux C prototypes:
--   * socklen_t is an unsigned 32-bit integer ('CUInt');
--   * size_t is 'CSize';
--   * ssize_t is 'CSsize'.
-- Previously bind's length was declared as a 64-bit 'Int', memset's length
-- as 'CInt', and the sendmsg/recvmsg results as 'CInt'. None of those match
-- the C ABI types.
--
-- socket/bind/close/setsockopt/memset use the cheaper @unsafe@ import.
-- sendmsg/recvmsg use the default (safe) import, which lets other Haskell
-- threads run while the call is in progress, presumably because these can
-- block on the socket.
foreign import ccall unsafe "socket"
  c_socket :: CInt -> CInt -> CInt -> IO CInt
foreign import ccall unsafe "bind"
  c_bind :: CInt -> Ptr SockAddrNetlink -> CUInt -> IO CInt
foreign import ccall unsafe "close"
  c_close :: CInt -> IO CInt
foreign import ccall unsafe "setsockopt"
  c_setsockopt :: CInt -> CInt -> CInt -> Ptr a -> CUInt -> IO CInt
-- memset's void* result is not needed, so it is declared as IO ().
foreign import ccall unsafe "memset"
  c_memset :: Ptr a -> CInt -> CSize -> IO ()
foreign import ccall "sendmsg"
  c_sendmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CSsize
foreign import ccall "recvmsg"
  c_recvmsg :: CInt -> Ptr MsgHdr -> CInt -> IO CSsize
-- | A netlink socket address. The wrapped 'Word32' is the port id written
-- at byte offset 4.
data SockAddrNetlink = SockAddrNetlink Word32

-- Hard-coded layout (as emitted by hsc2hs):
--   * 12 bytes in total, 4-byte aligned;
--   * the address family is a 16-bit value at offset 0;
--   * the port id is a 32-bit value at offset 4.
-- 'poke' zeroes the whole struct first. All remaining bytes, presumably
-- including nl_groups at offset 8, are therefore written as zero.
instance Storable SockAddrNetlink where
  sizeOf _ = 12
  alignment _ = 4
  peek ptr = do
    fam <- peekByteOff ptr 0 :: IO CShort
    -- Refuse to decode an address that is not a netlink address.
    if fam /= eAF_NETLINK
      then fail "Bad address family"
      else do
        rawPid <- peekByteOff ptr 4 :: IO CUInt
        return (SockAddrNetlink (fromIntegral rawPid))
  poke ptr (SockAddrNetlink portId) = do
    zero ptr
    pokeByteOff ptr 0 (eAF_NETLINK :: CShort)
    pokeByteOff ptr 4 (fromIntegral portId :: CUInt)
-- | One scatter/gather entry (@struct iovec@): a buffer pointer and the
-- number of bytes in that buffer.
data IoVec = IoVec (Ptr (), Int)

-- Hard-coded 64-bit Linux layout, 16 bytes in total:
--   * the buffer pointer is at offset 0;
--   * its length (a size_t) is at offset 8.
instance Storable IoVec where
  sizeOf _ = 16
  -- The struct begins with a pointer, so it must be at least pointer
  -- aligned. The previous value of 4 under-aligned it on 64-bit targets,
  -- where this layout applies.
  alignment _ = alignment (undefined :: Ptr ())
  peek p = do
    addr <- peekByteOff p 0
    len <- peekByteOff p 8 :: IO CSize
    return $ IoVec (addr, fromIntegral len)
  poke p (IoVec (addr, len)) = do
    zero p
    pokeByteOff p 0 addr
    pokeByteOff p 8 (fromIntegral len :: CSize)
-- | The parts of a @struct msghdr@ this module fills in: a pointer to an
-- array of 'IoVec' entries and the number of entries in that array.
data MsgHdr = MsgHdr (Ptr (), Int)

-- Hard-coded 64-bit Linux layout, 56 bytes in total:
--   * the iovec array pointer is at offset 16;
--   * its entry count (a size_t) is at offset 24.
-- 'poke' zeroes the whole struct first, so every other byte is left as zero.
-- That presumably makes msg_name and msg_control NULL and msg_flags 0.
instance Storable MsgHdr where
  sizeOf _ = 56
  -- The struct contains pointers, so it needs at least pointer alignment.
  -- The previous value of 4 under-aligned it on 64-bit targets, where this
  -- layout applies.
  alignment _ = alignment (undefined :: Ptr ())
  peek p = do
    iov <- peekByteOff p 16
    iovlen <- peekByteOff p 24 :: IO CSize
    return $ MsgHdr (iov, fromIntegral iovlen)
  poke p (MsgHdr (iov, iovlen)) = do
    zero p
    pokeByteOff p 16 iov
    pokeByteOff p 24 (fromIntegral iovlen :: CSize)
-- | Open and bind a netlink socket using protocol number 0 (NETLINK_ROUTE
-- in <linux/netlink.h>). This is shorthand for 'makeSocketGeneric' @0@.
makeSocket :: IO CInt
makeSocket = makeSocketGeneric defaultProtocol
  where
    defaultProtocol = 0
-- | Open a netlink socket of type 3 (SOCK_RAW on Linux) for the given
-- netlink protocol number.
--
-- The socket is bound to a 'SockAddrNetlink' with port id 0. Per
-- netlink(7), a zero port id lets the kernel pick one.
--
-- Throws an 'IOError' built from errno if @socket@ or @bind@ returns -1.
-- When @bind@ fails, the freshly created descriptor is closed before the
-- exception propagates. The previous version leaked it.
makeSocketGeneric
  :: Int     -- ^ netlink protocol number (third argument to @socket@)
  -> IO CInt -- ^ the bound socket descriptor
makeSocketGeneric prot = do
  sock <- throwErrnoIfMinus1 "makeSocket.socket" $
    c_socket eAF_NETLINK sockRaw (fromIntegral prot)
  bindToPortZero sock `onException` c_close sock
  return sock
  where
    -- SOCK_RAW on Linux.
    sockRaw :: CInt
    sockRaw = 3
    -- Bind to port id 0. The address length is taken from the Storable
    -- instance instead of a separate literal.
    bindToPortZero fd =
      with (SockAddrNetlink 0) $ \addr ->
        throwErrnoIfMinus1_ "makeSocket.bind" $
          c_bind fd (castPtr addr) (fromIntegral (sizeOf (SockAddrNetlink 0)))
-- | Close a socket descriptor. Throws an 'IOError' built from errno if
-- @close@ returns -1.
closeSocket :: CInt -> IO ()
closeSocket = throwErrnoIfMinus1_ "closeSocket" . c_close
-- | Send the given chunks with a single @sendmsg@ call and flags 0.
--
-- Each 'ByteString' becomes one 'IoVec' entry pointing directly at the
-- chunk's bytes, so the data is not first concatenated into a new buffer.
-- Throws an 'IOError' built from errno if @sendmsg@ returns -1.
sendmsg :: CInt -> [ByteString] -> IO ()
sendmsg fd bs =
  useManyAsPtrLen bs $ \chunks ->
    withArrayLen (map IoVec chunks) $ \count vecs ->
      let hdr = MsgHdr (castPtr vecs, count)
       in with hdr $ \hdrPtr ->
            throwErrnoIfMinus1_ "sendmsg" $
              c_sendmsg fd (castPtr hdrPtr) (0 :: CInt)
-- | Perform one @recvmsg@ call (flags 0) into a fresh buffer of @len@ bytes.
-- Returns a 'ByteString' trimmed to the number of bytes the call reported.
--
-- Throws an 'IOError' built from errno if @recvmsg@ returns 0 or a
-- negative value.
recvmsg :: CInt -> Int -> IO ByteString
recvmsg fd len = createAndTrim len fill
  where
    -- Describe the whole buffer with a single iovec, then report the byte
    -- count back to createAndTrim.
    fill buf =
      with (IoVec (castPtr buf, len)) $ \vecPtr ->
        with (MsgHdr (castPtr vecPtr, 1)) $ \hdrPtr -> do
          received <- throwErrnoIf (<= 0) "recvmsg" $
            c_recvmsg fd (castPtr hdrPtr) (0 :: CInt)
          return (fromIntegral received)
-- | Run an action with raw @(pointer, length)@ views of each 'ByteString',
-- in the same order as the input list.
--
-- Every chunk is pinned by a nested 'unsafeUseAsCStringLen', so its buffer
-- stays alive for the whole continuation. The previous version extracted
-- pointers with 'unsafeForeignPtrToPtr' and relied on 'touchForeignPtr'
-- running after the action (via '<*'). That pattern is documented as
-- fragile, because the compiler may drop such a trailing touch.
--
-- The pointers are valid only inside @act@ and must not be written through.
useManyAsPtrLen :: [ByteString] -> ([(Ptr (), Int)] -> IO a) -> IO a
useManyAsPtrLen [] act = act []
useManyAsPtrLen (chunk : rest) act =
  unsafeUseAsCStringLen chunk $ \(ptr, len) ->
    -- Pin the remaining chunks, then prepend this chunk's view so the
    -- final list keeps the input order.
    useManyAsPtrLen rest (act . ((castPtr ptr, len) :))
-- | Size in bytes of the value a pointer refers to, taken from the pointee
-- type's 'Storable' instance. The pointer itself is never dereferenced.
sizeOfPtr :: (Storable a, Integral b) => Ptr a -> b
sizeOfPtr ptr = fromIntegral (sizeOf (pointee ptr))
  where
    -- Only used for its type; 'sizeOf' must not inspect its argument.
    pointee :: Ptr c -> c
    pointee _ = undefined
-- | Overwrite the @sizeOf@-many bytes that @p@ points to with zeros, using
-- @memset@. The size comes from the pointee type's 'Storable' instance.
zero :: Storable a => Ptr a -> IO ()
zero p = c_memset (castPtr p) 0 (sizeOfPtr p) >> return ()
-- | Run a monadic action and discard its result. This is a local equivalent
-- of 'Control.Monad.void', constrained to 'Monad'.
void :: Monad m => m a -> m ()
void act = do
  _ <- act
  return ()
-- | Subscribe a netlink socket to a multicast group. This issues
-- @setsockopt(fd, 270, 1, &group, sizeof group)@, i.e. level SOL_NETLINK and
-- option NETLINK_ADD_MEMBERSHIP in <linux/netlink.h>.
--
-- Throws an 'IOError' built from errno if @setsockopt@ returns -1.
joinMulticastGroup
  :: CInt   -- ^ netlink socket descriptor
  -> Word32 -- ^ multicast group number
  -> IO ()
joinMulticastGroup fd fid =
  with fid $ \ptr ->
    throwErrnoIfMinus1_ "joinMulticast" $
      c_setsockopt fd solNetlink netlinkAddMembership (castPtr ptr) optLen
  where
    -- The option length is the size of the value actually passed, 'fid'.
    -- It was previously computed from 'CInt'; both are 4 bytes.
    optLen = fromIntegral (sizeOf fid)
    -- SOL_NETLINK
    solNetlink = 270 :: CInt
    -- NETLINK_ADD_MEMBERSHIP
    netlinkAddMembership = 1 :: CInt