{-# OPTIONS_GHC -Wno-orphans #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UnboxedTuples #-}

module ML.DMLC.XGBoost.Foreign where

import Foundation
import Foundation.Array.Internal
import Foundation.Class.Storable
import Foundation.Collection
import Foundation.Primitive hiding (toBytes)
import Foundation.String

import Foreign.Ptr (nullPtr)
import GHC.Exts

{-- Orphan instance that makes `Ptr a` an instance of foundation's PrimType --}

instance PrimType (Ptr a) where
    -- Element size comes from foundation's 'StorableFixed' instance for 'Ptr'.
    primSizeInBytes _ = size (Proxy :: Proxy (Ptr a))
    {-# INLINE primSizeInBytes #-}

    -- log2 of the element size in BYTES (not bits), matching foundation's
    -- own instances (e.g. Word32 -> 2, Word64 -> 3). A 4-byte pointer
    -- (32-bit target) shifts by 2; otherwise an 8-byte pointer is assumed
    -- and shifts by 3.
    primShiftToBytes _ = let (CountOf k) = size (Proxy :: Proxy (Ptr a)) in if k == 4 then 2 else 3
    {-# INLINE primShiftToBytes #-}

    -- The remaining methods unwrap/rewrap the 'Ptr' constructor around the
    -- raw Addr# and delegate to GHC's Addr# primops. The offset @n@ is an
    -- element index (the primops scale it by the element size), not a byte
    -- offset.
    primBaUIndex ba (Offset (I# n)) = Ptr (indexAddrArray# ba n)
    {-# INLINE primBaUIndex #-}

    primMbaURead mba (Offset (I# n)) = primitive $ \s1 -> let !(# s2, r1 #) = readAddrArray# mba n s1
                                                           in (# s2, Ptr r1 #)
    {-# INLINE primMbaURead #-}

    primMbaUWrite mba (Offset (I# n)) (Ptr w) = primitive $ \s1 -> (# writeAddrArray# mba n w s1, () #)
    {-# INLINE primMbaUWrite #-}

    primAddrIndex addr (Offset (I# n)) = Ptr (indexAddrOffAddr# addr n)
    {-# INLINE primAddrIndex #-}

    primAddrRead addr (Offset (I# n)) = primitive $ \s1 -> let !(# s2, r1 #) = readAddrOffAddr# addr n s1
                                                            in (# s2, Ptr r1 #)
    {-# INLINE primAddrRead #-}

    primAddrWrite addr (Offset (I# n)) (Ptr w) = primitive $ \s1 -> (# writeAddrOffAddr# addr n w s1, () #)
    {-# INLINE primAddrWrite #-}

{-- Types --------------------------------------------------------------------}

-- | Raw pointer to the bytes of a C string.
type StringPtr   = Ptr Word8
-- | Raw pointer to an untyped byte buffer (same representation as 'StringPtr').
type ByteArray   = Ptr Word8
-- | Raw pointer to a C array of string pointers.
type StringArray = Ptr StringPtr
-- | Raw pointer to a C array of 'Float' values.
type FloatArray  = Ptr Float
-- | Raw pointer to a C array of unsigned 32-bit integers.
type UIntArray   = Ptr Word32

{-- Utilities ----------------------------------------------------------------}

-- | Encode a 'Bool' as a C-style integer flag: 'True' -> 1, 'False' -> 0.
boolToInt32 :: Bool -> Int32
boolToInt32 True  = 1
boolToInt32 False = 0

-- | Decode a C-style integer flag: zero is 'False', any other value is 'True'.
int32ToBool :: Int32 -> Bool
int32ToBool = (/= 0)

-- | Read the bytes at the pointer up to (not including) the first 0 byte
-- and turn them into a 'String'. A null pointer yields the empty string.
-- The bytes are converted with 'fromBytesUnsafe', i.e. without UTF-8
-- validation.
getString :: StringPtr -> IO String
getString ptr =
    if ptr == nullPtr
        then return ""
        else do
            bytes <- peekArrayEndedBy 0 ptr
            return (fromBytesUnsafe bytes)

-- | Read exactly @nlen@ bytes from the pointer and turn them into a
-- 'String' (no terminator is expected). A null pointer yields the empty
-- string. The bytes are converted with 'fromBytesUnsafe', i.e. without
-- UTF-8 validation.
getString' :: CountOf Word8 -> StringPtr -> IO String
getString' nlen ptr
    | ptr /= nullPtr = fmap fromBytesUnsafe (peekArray nlen ptr)
    | otherwise      = return ""

-- | Encode the string as UTF-8, append a 0 terminator byte, and run the
-- action with a pointer to the encoded bytes. The pointer comes from
-- 'withPtr' and should not be used after the action returns.
withString :: String -> (StringPtr -> IO a) -> IO a
withString s act = withPtr terminated act
  where
    terminated = toBytes UTF8 (s <> "\0")

-- | Read a null-pointer-terminated array of string pointers and decode
-- each entry with 'getString'. A null array pointer yields @[]@.
getStringArray :: StringArray -> IO [String]
getStringArray ptr
    | ptr == nullPtr = return []
    | otherwise      = peekPtrs >>= mapM getString . toList
  where
    peekPtrs :: IO (Array StringPtr)
    peekPtrs = peekArrayEndedBy nullPtr ptr

-- | Read exactly @nlen@ string pointers from the array and decode each
-- entry with 'getString' (each entry is expected to be 0-terminated).
-- A null array pointer yields @[]@.
getStringArray' :: CountOf StringPtr -> StringArray -> IO [String]
getStringArray' nlen ptr
    | ptr == nullPtr = return []
    | otherwise      = peekPtrs >>= mapM getString . toList
  where
    peekPtrs :: IO (Array StringPtr)
    peekPtrs = peekArray nlen ptr

-- | Marshal a list of strings into a C array of pointers to 0-terminated
-- UTF-8 strings (each built by 'withString'), and run the action with a
-- pointer to that array. An empty list is passed as 'nullPtr'.
--
-- Every string is marshalled inside a nested 'withString' scope that is
-- still open while the action runs, so all string pointers stay valid for
-- the whole call. Do NOT collect the pointers via @withString s return@:
-- that lets each pointer escape its 'withPtr' scope and leaves the array
-- full of dangling pointers.
--
-- NOTE(review): the pointer array itself is not null-terminated; the C side
-- presumably receives the element count separately.
withStringArray :: [String] -> (StringArray -> IO a) -> IO a
withStringArray [] f = f nullPtr
withStringArray ss f = go ss id
  where
    -- Open one 'withString' scope per element, accumulating the pointers in
    -- their original order with a difference list; once every scope is open,
    -- expose them to the action as one contiguous array.
    go [] acc = withPtr (fromList (acc [])) f
    go (s : rest) acc = withString s $ \p -> go rest (acc . (p :))