{-# LINE 1 "src/System/Socket/Family/Netlink.hsc" #-}
{-|
Module      : System.Socket.Family.Netlink
Description : Extends System.Socket with the netlink socket family.
Copyright   : (c) Formaltech Inc. 2017
License     : BSD3
Maintainer  : protob3n@gmail.com
Stability   : experimental
Portability : Linux
-}
{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TypeFamilies #-}
module System.Socket.Family.Netlink
    ( Netlink
    , SocketAddress()
    , NetlinkGroup(..)
    , netlinkAddress
    , netlinkAddressPid
    , netlinkKernel
    ) where

import Data.Bits ((.|.), (.&.), shiftL)
import Data.Functor ((<$>))
import Data.Serialize (Serialize(..), encode, decode)
import Data.Serialize (putWord16host, putWord32host, getWord16host, getWord32host)
import Data.Word (Word16)
import Foreign.Ptr (castPtr)
import Foreign.Storable (Storable(..))
import GHC.Word (Word32)
import System.Posix (getProcessID)
import System.Random (randomRIO)
import qualified Data.ByteString.Char8 as S

import System.Socket




-- | Type-level tag for the netlink socket family. It has no constructors and
-- only appears as a type argument.
data Netlink

instance Family Netlink where
    -- | Netlink socket address, mirroring @struct sockaddr_nl@ from
    -- @linux/netlink.h@.
    data SocketAddress Netlink = SocketAddressNetlink
        { netlinkPid    :: Word32 -- ^ Netlink source address (@nl_pid@).
        , netlinkGroups :: Word32 -- ^ Group subscription mask (@nl_groups@).
        } deriving (Eq, Read, Show)
    -- Numeric address family: 16, Linux's AF_NETLINK. The argument is never
    -- inspected (Netlink has no values).
    familyNumber = const 16
{-# LINE 50 "src/System/Socket/Family/Netlink.hsc" #-}
-- | Host-byte-order encoding that follows the @struct sockaddr_nl@ layout
-- (16-bit family, 16-bit pad, 32-bit pid, 32-bit group mask). It then appends
-- an extra 32-bit zero pad, for 16 bytes in total. Decoding reads all 16 bytes
-- but ignores the family and both pads.
--
-- NOTE(review): @struct sockaddr_nl@ itself is 12 bytes (the 'Storable'
-- instance's 'sizeOf' is 12), so the trailing pad is not part of the C struct.
-- Confirm nothing relies on the 16-byte form before changing it.
instance Serialize (SocketAddress Netlink) where
    put nl = do
        -- nl_family (16), nl_pad
        mapM_ putWord16host [16, 0]
        -- nl_pid, nl_groups, trailing pad
        mapM_ putWord32host [netlinkPid nl, netlinkGroups nl, 0]
    get = do
        _family <- getWord16host
        _pad16  <- getWord16host
        pid     <- getWord32host
        groups  <- getWord32host
        _pad32  <- getWord32host
        return $ SocketAddressNetlink { netlinkPid = pid, netlinkGroups = groups }
-- | Marshals to and from the 12-byte @struct sockaddr_nl@. The layout is a
-- 16-bit family at offset 0, a 16-bit pad at 2, the 32-bit pid at 4 and the
-- 32-bit group mask at 8, all in host byte order.
--
-- Each field is read and written individually at its struct offset, so
-- exactly 'sizeOf' bytes are touched.
--
-- Do not route this through the 'Serialize' encoding, for two reasons:
--
-- * That encoding is 16 bytes long and would overrun the struct.
-- * Marshalling its bytes as 'Char' is wrong, because a 'Storable' 'Char' is
--   4 bytes wide.
instance Storable (SocketAddress Netlink) where
    sizeOf    _ = 12
    alignment _ = 4
    peek ptr    = do
        -- nl_family (offset 0) and nl_pad (offset 2) are ignored, matching the
        -- 'Serialize' decoder.
        pid    <- peekByteOff ptr 4
        groups <- peekByteOff ptr 8
        return $ SocketAddressNetlink pid groups
    poke ptr nl = do
        pokeByteOff ptr 0 (16 :: Word16) -- nl_family: same value as familyNumber
        pokeByteOff ptr 2 (0  :: Word16) -- nl_pad
        pokeByteOff ptr 4 (netlinkPid nl)
        pokeByteOff ptr 8 (netlinkGroups nl)

-- | Types whose values name netlink groups. This is an open class rather than
-- a fixed enumeration because each netlink subsystem defines its own groups.
class NetlinkGroup group where
    -- | The bit(s) this group contributes to an address's group mask; masks of
    -- several groups are combined with bitwise OR.
    netlinkGroupNumber :: group -> Word32

-- | Build a group subscription mask by OR-ing together the
-- 'netlinkGroupNumber' of every group. An empty list yields 0.
netlinkGroupMask :: NetlinkGroup g => [g] -> Word32
netlinkGroupMask groups = foldr (\g acc -> netlinkGroupNumber g .|. acc) 0 groups

-- | Construct a netlink socket address subscribed to the given groups, with a
-- generated source address.
--
-- The source address (@nl_pid@) packs two values into 32 bits:
--
-- * The low 22 bits hold the calling process's ID.
-- * The high 10 bits hold a random value in @[1, 1023]@.
--
-- Because the random part is never zero, the result never equals the bare
-- process ID. Two calls in the same process may still collide (1 in 1023
-- chance per pair).
netlinkAddress :: NetlinkGroup g => [g] -> IO (SocketAddress Netlink)
netlinkAddress gs = do
    pid <- fromIntegral <$> getProcessID
    -- Draw only as many random bits as survive the shift below; previously a
    -- full-range draw was shifted, discarding all but its low 10 bits, which
    -- could be zero.
    rid <- randomRIO (1, maxRandom)
    let id' = (pid .&. linuxPidMask) .|. (rid `shiftL` linuxPidBits)
    return $ SocketAddressNetlink id' (netlinkGroupMask gs)
    where
    -- Bits reserved for the PID: the max pid for 64-bit Linux is 2^22 - 1.
    linuxPidBits :: Int
    linuxPidBits = 22
    -- 0x003fffff: selects the PID bits.
    linuxPidMask :: Word32
    linuxPidMask = (1 `shiftL` linuxPidBits) - 1
    -- Largest value that fits in the remaining 32 - 22 = 10 high bits.
    maxRandom :: Word32
    maxRandom = (1 `shiftL` (32 - linuxPidBits)) - 1

-- | Like 'netlinkAddress', but the caller supplies the source address
-- (@nl_pid@) instead of having one generated, so no 'IO' is needed.
netlinkAddressPid :: NetlinkGroup g => Word32 -> [g] -> SocketAddress Netlink
netlinkAddressPid pid groups =
    SocketAddressNetlink { netlinkPid = pid, netlinkGroups = netlinkGroupMask groups }

-- | The kernel's netlink address: source address 0 with an empty group mask.
netlinkKernel :: SocketAddress Netlink
netlinkKernel = SocketAddressNetlink { netlinkPid = 0, netlinkGroups = 0 }