{-# OPTIONS_GHC -Wno-orphans #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UnboxedTuples #-}
module ML.DMLC.XGBoost.Foreign where
import Foundation
import Foundation.Array.Internal
import Foundation.Class.Storable
import Foundation.Collection
import Foundation.Primitive hiding (toBytes)
import Foundation.String
import Foreign.Ptr (nullPtr)
import GHC.Exts
-- | Orphan 'PrimType' instance that lets raw pointers be stored as elements
-- of foundation's unboxed arrays (e.g. the pointer array handed to C by
-- 'withStringArray').
--
-- Elements are machine addresses accessed through the @Addr#@ primops; every
-- index/offset below counts pointer-sized elements, not bytes.
instance PrimType (Ptr a) where
    -- One element is exactly one machine pointer.
    primSizeInBytes _ = size (Proxy :: Proxy (Ptr a))
    {-# INLINE primSizeInBytes #-}

    -- log2 of 'primSizeInBytes': 4-byte pointers shift by 2, 8-byte
    -- pointers by 3. These must agree with the element size, otherwise
    -- element<->byte offset conversions are wrong.
    primShiftToBytes _ =
        let (CountOf k) = size (Proxy :: Proxy (Ptr a))
         in if k == 4 then 2 else 3
    {-# INLINE primShiftToBytes #-}

    primBaUIndex ba (Offset (I# n)) = Ptr (indexAddrArray# ba n)
    {-# INLINE primBaUIndex #-}

    primMbaURead mba (Offset (I# n)) = primitive $ \s1 ->
        case readAddrArray# mba n s1 of
            (# s2, r1 #) -> (# s2, Ptr r1 #)
    {-# INLINE primMbaURead #-}

    primMbaUWrite mba (Offset (I# n)) (Ptr w) = primitive $ \s1 ->
        (# writeAddrArray# mba n w s1, () #)
    {-# INLINE primMbaUWrite #-}

    primAddrIndex addr (Offset (I# n)) = Ptr (indexAddrOffAddr# addr n)
    {-# INLINE primAddrIndex #-}

    primAddrRead addr (Offset (I# n)) = primitive $ \s1 ->
        case readAddrOffAddr# addr n s1 of
            (# s2, r1 #) -> (# s2, Ptr r1 #)
    {-# INLINE primAddrRead #-}

    primAddrWrite addr (Offset (I# n)) (Ptr w) = primitive $ \s1 ->
        (# writeAddrOffAddr# addr n w s1, () #)
    {-# INLINE primAddrWrite #-}
-- | Pointer to the bytes of a C string.
type StringPtr = Ptr Word8

-- | Pointer to raw bytes (same representation as 'StringPtr').
type ByteArray = Ptr Word8

-- | Pointer to a C array of C string pointers.
type StringArray = Ptr (Ptr Word8)

-- | Pointer to a C array of single-precision floats.
type FloatArray = Ptr Float

-- | Pointer to a C array of unsigned 32-bit integers.
type UIntArray = Ptr Word32
-- | Encode a 'Bool' as a C-style integer flag: 'True' is 1, 'False' is 0.
boolToInt32 :: Bool -> Int32
boolToInt32 True = 1
boolToInt32 False = 0
-- | Decode a C-style integer flag: zero is 'False', any other value
-- (including negatives) is 'True'.
int32ToBool :: Int32 -> Bool
int32ToBool i = i /= 0
-- | Copy a 0-terminated byte string from C memory into a 'String'.
--
-- A null pointer yields the empty string. The bytes are converted with
-- 'fromBytesUnsafe', so they are trusted to be valid UTF-8 and are not
-- validated.
getString :: StringPtr -> IO String
getString ptr =
    if ptr == nullPtr
        then return ""
        else do
            bytes <- peekArrayEndedBy 0 ptr
            return (fromBytesUnsafe bytes)
-- | Copy exactly @nlen@ bytes from C memory into a 'String' (no terminator
-- is looked for).
--
-- A null pointer yields the empty string. The bytes are converted with
-- 'fromBytesUnsafe', so they are trusted to be valid UTF-8.
getString' :: CountOf Word8 -> StringPtr -> IO String
getString' nlen ptr
    | ptr /= nullPtr = fmap fromBytesUnsafe (peekArray nlen ptr)
    | otherwise = return ""
-- | Run an action with a pointer to a UTF-8 encoded, NUL-terminated copy of
-- the string.
--
-- The pointer is only valid for the duration of the action. A string that
-- already contains a NUL character will appear truncated to C code.
withString :: String -> (StringPtr -> IO a) -> IO a
withString s action = withPtr encoded action
  where
    -- Append the terminator before encoding so the C side sees a C string.
    encoded = toBytes UTF8 (s <> "\0")
-- | Read a null-pointer-terminated C array of C strings.
--
-- A null array pointer yields the empty list; each element is decoded with
-- 'getString'.
getStringArray :: StringArray -> IO [String]
getStringArray ptr
    | ptr == nullPtr = return []
    | otherwise = peekArrayEndedBy nullPtr ptr >>= decodeAll
  where
    decodeAll :: Array StringPtr -> IO [String]
    decodeAll = mapM getString . toList
-- | Read a C array holding exactly @nlen@ C string pointers.
--
-- A null array pointer yields the empty list; each element is decoded with
-- 'getString'.
getStringArray' :: CountOf StringPtr -> StringArray -> IO [String]
getStringArray' nlen ptr =
    if ptr == nullPtr
        then return []
        else do
            elems <- peekArray nlen ptr :: IO (Array StringPtr)
            mapM getString (toList elems)
-- | Marshal a list of strings into a C array of C strings for the duration
-- of an action.
--
-- Each string is UTF-8 encoded with a trailing NUL (see 'withString'). The
-- resulting pointers are stored, in input order, in a contiguous array whose
-- address is passed to the action. An empty list passes 'nullPtr'.
--
-- Every 'withString' scope is kept open (nested) until the action returns,
-- so each string buffer stays valid while C code reads it. Returning the
-- pointer out of 'withString' would leave it dangling. The pointer array is
-- not NULL-terminated; pass the list length separately where required.
withStringArray :: [String] -> (StringArray -> IO a) -> IO a
withStringArray [] f = f nullPtr
withStringArray ss f = go ss id
  where
    -- Open one 'withString' scope per element and collect the pointers in a
    -- difference list so the final array keeps the original order.
    go [] acc = withPtr (fromList (acc [])) f
    go (s : rest) acc = withString s $ \p -> go rest (acc . (p :))