module ML.DMLC.XGBoost.Rabit.FFI where

import Foundation
import Foundation.Array.Internal
import Foundation.Class.Storable
import Foundation.Collection
import Foundation.Foreign

import Foreign.Ptr (nullPtr)
import Foreign.Marshal.Alloc (alloca)
import qualified Foreign.Storable (peek)

import ML.DMLC.XGBoost.Foreign

-- | Raw binding to the C symbol @RabitInit@ (argument count, argument array).
foreign import ccall unsafe "RabitInit"
    c_rabitInit :: Int32 -> StringArray -> IO ()

-- | Raw binding to the C symbol @RabitFinalize@.
foreign import ccall unsafe "RabitFinalize"
    c_rabitFinalize :: IO ()

-- | Raw binding to the C symbol @RabitGetRank@.
foreign import ccall unsafe "RabitGetRank"
    c_rabitGetRank :: IO Int32

-- | Raw binding to the C symbol @RabitGetWorldSize@.
foreign import ccall unsafe "RabitGetWorldSize"
    c_rabitGetWorldSize :: IO Int32

-- | Raw binding to the C symbol @RabitIsDistributed@; the result is an
-- integer flag that 'rabitIsDistributed' converts to a 'Bool'.
foreign import ccall unsafe "RabitIsDistributed"
    c_rabitIsDistributed :: IO Int32

-- | Raw binding to the C symbol @RabitTrackerPrint@.
foreign import ccall unsafe "RabitTrackerPrint"
    c_rabitTrackerPrint :: StringPtr -> IO ()

-- | Raw binding to the C symbol @RabitGetProcessorName@
-- (output buffer, output length pointer, maximum length).
foreign import ccall unsafe "RabitGetProcessorName"
    c_rabitGetProcessorName :: StringPtr -> Ptr CULong -> CULong -> IO ()

-- | Raw binding to the C symbol @RabitBroadcast@ (buffer, size, root).
foreign import ccall unsafe "RabitBroadcast"
    c_rabitBroadcast :: Ptr a -> CULong -> Int32 -> IO ()

-- | Raw binding to the C symbol @RabitAllreduce@
-- (buffer, element count, data type code, operator code, and two extra
-- pointer arguments).
foreign import ccall unsafe "RabitAllreduce"
    c_rabitAllreduce
        :: Ptr a -> CSize -> Int32 -> Int32 -> Ptr () -> Ptr () -> IO ()

-- | Raw binding to the C symbol @RabitLoadCheckPoint@.
foreign import ccall unsafe "RabitLoadCheckPoint"
    c_rabitLoadCheckPoint
        :: Ptr StringPtr -> Ptr CULong -> Ptr StringPtr -> Ptr CULong -> IO Int32

-- | Raw binding to the C symbol @RabitCheckPoint@.
foreign import ccall unsafe "RabitCheckPoint"
    c_rabitCheckPoint :: StringPtr -> CULong -> StringPtr -> CULong -> IO ()

-- | Raw binding to the C symbol @RabitVersionNumber@.
foreign import ccall unsafe "RabitVersionNumber"
    c_rabitVersionNumber :: IO Int32

-- | Raw binding to the C symbol @RabitLinkTag@.
foreign import ccall unsafe "RabitLinkTag"
    c_rabitLinkTag :: IO Int32

-- | Call @RabitInit@ with the given arguments.
--
-- The list is marshalled with 'withStringArray' and its length is passed
-- as the argument count.
rabitInit :: [String] -> IO ()
rabitInit args = withStringArray args (c_rabitInit argc)
  where
    CountOf argCount = length args
    argc = fromIntegral argCount

-- | Call @RabitFinalize@.
rabitFinalize :: IO ()
rabitFinalize = c_rabitFinalize

-- | Result of @RabitGetRank@.
rabitGetRank :: IO Int32
rabitGetRank = c_rabitGetRank

-- | Result of @RabitGetWorldSize@.
--
-- Note that the Haskell name spells \"Word\" where the C symbol says
-- \"World\"; the name is kept as-is for existing callers.
rabitGetWordSize :: IO Int32
rabitGetWordSize = c_rabitGetWorldSize

-- | Result of @RabitIsDistributed@, converted with 'int32ToBool'.
rabitIsDistributed :: IO Bool
rabitIsDistributed = int32ToBool <$> c_rabitIsDistributed
-- | Pass a message to @RabitTrackerPrint@.
rabitTrackerPrint :: String -> IO ()
rabitTrackerPrint msg = withString msg $ \pmsg -> c_rabitTrackerPrint pmsg

-- | Query @RabitGetProcessorName@ and decode the result.
--
-- A mutable buffer of @nlimit@ elements is handed to the C side together
-- with that limit; the length written back through the out-pointer
-- determines how many elements are decoded by @getString'@.
rabitGetProcessorName :: IO String
rabitGetProcessorName = do
    let nlimit = 64
    buf <- mutNew (CountOf nlimit)
    alloca $ \plen -> withMutablePtr buf $ \pbuf -> do
        c_rabitGetProcessorName pbuf plen (fromIntegral nlimit)
        nlen <- Foreign.Storable.peek plen
        getString' (CountOf (fromIntegral nlen)) pbuf

-- | Call @RabitBroadcast@ on the given buffer.
--
-- The function name is a historical misspelling of \"Broadcast\" and is
-- kept unchanged for existing callers.
rabitBoradcast
    :: Ptr a -- ^ the pointer to the send or receive buffer.
    -> Int32 -- ^ the size of data
    -> Int32 -- ^ the root of process
    -> IO ()
rabitBoradcast pdata nlen root = c_rabitBroadcast pdata (fromIntegral nlen) root

-- | Reduction operator passed to 'rabitAllreduce'.
--
-- Ref: https://github.com/dmlc/rabit/blob/master/include/rabit/internal/engine.h#L172
data AllreduceOpType = KMax | KMin | KSum | KBitwiseOR
    deriving Eq

-- | Numeric codes sent over the FFI. 'fromEnum' must be the exact inverse
-- of 'toEnum' (0..3); previously it produced 0, 2, 3, 4, so every operator
-- except 'KMax' was sent to the C side with the wrong code.
instance Enum AllreduceOpType where
    toEnum 0 = KMax
    toEnum 1 = KMin
    toEnum 2 = KSum
    toEnum 3 = KBitwiseOR
    toEnum _ = error "No such AllreduceOpType"

    fromEnum KMax       = 0
    fromEnum KMin       = 1
    fromEnum KSum       = 2
    fromEnum KBitwiseOR = 3

-- | Element type tag passed to 'rabitAllreduce'.
--
-- Ref: https://github.com/dmlc/rabit/blob/master/include/rabit/internal/engine.h#L179
data AllreduceDataType
    = KChar
    | KUChar
    | KInt
    | KUInt
    | KLong
    | KULong
    | KFloat
    | KDouble
    | KLongLong
    | KULongLong
    deriving Eq

-- | Numeric codes 0..9, in constructor order.
instance Enum AllreduceDataType where
    toEnum 0 = KChar
    toEnum 1 = KUChar
    toEnum 2 = KInt
    toEnum 3 = KUInt
    toEnum 4 = KLong
    toEnum 5 = KULong
    toEnum 6 = KFloat
    toEnum 7 = KDouble
    toEnum 8 = KLongLong
    toEnum 9 = KULongLong
    toEnum _ = error "No such AllreduceDataType"

    fromEnum KChar      = 0
    fromEnum KUChar     = 1
    fromEnum KInt       = 2
    fromEnum KUInt      = 3
    fromEnum KLong      = 4
    fromEnum KULong     = 5
    fromEnum KFloat     = 6
    fromEnum KDouble    = 7
    fromEnum KLongLong  = 8
    fromEnum KULongLong = 9

-- | Call @RabitAllreduce@ on the given buffer.
--
-- The data type and operator are converted to their numeric codes with
-- 'fromEnum'; the two trailing pointer arguments of the C function are
-- passed as 'nullPtr'.
rabitAllreduce
    :: Ptr a -- ^ buffer for both sending and receiving data
    -> Int32 -- ^ number of elements to be reduced
    -> AllreduceDataType
    -> AllreduceOpType
    -> IO ()
rabitAllreduce pdata count dtype optype =
    c_rabitAllreduce
        pdata
        (fromIntegral count)
        (fromIntegral . fromEnum $ dtype)
        (fromIntegral . fromEnum $ optype)
        nullPtr
        nullPtr

-- | Result of @RabitVersionNumber@.
rabitVersionNumber :: IO Int32
rabitVersionNumber = c_rabitVersionNumber

-- | Result of @RabitLinkTag@.
rabitLinkTag :: IO Int32
rabitLinkTag = c_rabitLinkTag