{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE ViewPatterns        #-}
--------------------------------------------------------------------------------
-- |
-- Module      : ArrayFire.Util
-- Copyright   : David Johnson (c) 2019-2020
-- License     : BSD 3
-- Maintainer  : David Johnson <djohnson.m@gmail.com>
-- Stability   : Experimental
-- Portability : GHC
--
-- Various utilities for working with the ArrayFire C library
--
-- @
-- import qualified ArrayFire as A
-- import           Control.Monad
--
-- main :: IO ()
-- main = do
--   let arr = A.constant [1,1,1,1] 10
--   idx <- A.saveArray "key" arr "file.array" False
--   foundIndex <- A.readArrayKeyCheck "file.array" "key"
--   when (idx == foundIndex) $ do
--     array <- A.readArrayKey "file.array" "key"
--     print array
-- @
-- @
-- ArrayFire Array
-- [ 1 1 1 1 ]
--         10
-- @
--------------------------------------------------------------------------------
module ArrayFire.Util where

import Control.Exception

import Data.Proxy
import Foreign.C.String
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Marshal         hiding (void)
import Foreign.Storable
import System.IO.Unsafe

import ArrayFire.Internal.Types
import ArrayFire.Internal.Util

import ArrayFire.Exception
import ArrayFire.FFI

-- | Retrieve version for ArrayFire API
--
-- Returns the @(major, minor, patch)@ triple filled in by @af_get_version@.
-- A non-success status from ArrayFire is raised via 'throwAFError'.
--
-- @
-- >>> 'print' '=<<' 'getVersion'
-- @
-- @
-- (3,6,4)
-- @
getVersion :: IO (Int, Int, Int)
getVersion =
  alloca $ \major ->
    alloca $ \minor ->
      alloca $ \patch -> do
        throwAFError =<< af_get_version major minor patch
        (,,) <$> peekInt major <*> peekInt minor <*> peekInt patch
  where
    -- Read a C integer out-parameter and widen it to a Haskell 'Int'.
    peekInt p = fromIntegral <$> peek p

-- | Print an 'Array' to standard output using @af_print_array@.
--
-- The underlying handle is kept alive for the duration of the call, and
-- a non-success status from ArrayFire is raised via 'throwAFError'.
--
-- @
-- >>> 'printArray' (constant \@'Double' [1] 1)
-- @
-- @
-- ArrayFire Array
--   [ 1 1 1 1 ]
--       1.0
-- @
printArray
  :: Array a
  -- ^ Input 'Array'
  -> IO ()
printArray (Array fptr) =
  mask_ (withForeignPtr fptr printHandle)
  where
    -- Hand the raw handle to ArrayFire and surface any error code.
    printHandle handle = af_print_array handle >>= throwAFError

-- | Get the git revision string reported by the ArrayFire library
-- (@af_get_revision@), copied into a Haskell 'String'.
--
-- @
-- >>> 'putStrLn' '=<<' 'getRevision'
-- @
-- @
-- 1b8030c5
-- @
getRevision :: IO String
getRevision = af_get_revision >>= peekCString

-- | Print an 'Array' to standard output with a caller-supplied label and
-- display precision (wraps @af_print_array_gen@).
--
-- The label is passed to ArrayFire as the expression/name of the array.
-- A non-success status from ArrayFire is raised via 'throwAFError'.
--
-- @
-- >>> printArrayGen "test" (constant \@'Double' [1] 1) 2
-- @
-- @
-- test
--   [ 1 1 1 1 ]
--       1.00
-- @
printArrayGen
  :: String
  -- ^ Expression or name of the array, printed as its label
  -> Array a
  -- ^ Input 'Array'
  -> Int
  -- ^ Precision used when displaying values
  -> IO ()
printArrayGen s (Array fptr) (fromIntegral -> prec) =
  mask_ . withForeignPtr fptr $ \ptr ->
    withCString s $ \cstr ->
      throwAFError =<< af_print_array_gen cstr ptr prec

-- | Saves 'Array' to disk
--
-- Writes the array to a binary file, tagged with the given key. Arrays
-- stored this way can be loaded again with 'readArrayIndex' or
-- 'readArrayKey'.
-- <http://arrayfire.org/docs/group__stream__func__save.htm>
--
-- @
-- >>> saveArray "my array" (constant \@'Double' [1] 1) "array.file" 'True'
-- @
-- @
-- 0
-- @
saveArray
  :: String
  -- ^ An expression used as tag/key for the 'Array' when reading it back
  -> Array a
  -- ^ Input 'Array'
  -> FilePath
  -- ^ Path that 'Array' will be saved
  -> Bool
  -- ^ Used to append to an existing file when 'True' and create or overwrite an existing file when 'False'
  -> IO Int
  -- ^ The index location of the 'Array' in the file
saveArray key (Array fptr) filename appendFlag =
  mask_ . withForeignPtr fptr $ \arrPtr ->
    alloca $ \idxPtr ->
      withCString key $ \keyC ->
        withCString filename $ \fileC -> do
          -- ArrayFire takes the append flag as a C boolean (0 or 1).
          let append = fromIntegral (fromEnum appendFlag)
          throwAFError =<< af_save_array idxPtr keyC arrPtr fileC append
          fromIntegral <$> peek idxPtr

-- | Reads an 'Array' from a file by its position in that file
--
-- The 'saveArray' and readArray functions are designed to provide store and read access to arrays using files written to disk.
-- <http://arrayfire.org/docs/group__stream__func__save.htm>
--
-- @
-- >>> readArrayIndex "array.file" 0
-- @
-- @
-- ArrayFire Array
--   [ 1 1 1 1 ]
--          10.0000
-- @
readArrayIndex
  :: FilePath
  -- ^ Path of the file the 'Array' was saved to
  -> Int
  -- ^ Position of the 'Array' within the file (as returned by 'saveArray')
  -> IO (Array a)
readArrayIndex str idx =
  withCString str $ \cstr ->
    createArray' $ \out ->
      af_read_array_index out cstr (fromIntegral idx)

-- | Reads an 'Array' from a file by the key it was saved under
--
-- @
-- >>> readArrayKey "array.file" "my array"
-- @
-- @
-- ArrayFire Array
--    [ 1 1 1 1 ]
--        10.0000
-- @
readArrayKey
  :: FilePath
  -- ^ Path to 'Array'
  -> String
  -- ^ Key of 'Array' on disk
  -> IO (Array a)
  -- ^ Returned 'Array'
readArrayKey fn key =
  withCString fn $ \fileC ->
    withCString key $ \keyC ->
      let readInto out = af_read_array_key out fileC keyC
      in createArray' readInto

-- | Checks whether a key exists in a file written by 'saveArray'
--
-- When reading by key, it may be a good idea to run this function first to
-- check for the key, and then read the 'Array' with 'readArrayIndex' using
-- the returned index.
-- <http://arrayfire.org/docs/group__stream__func__read.htm#ga31522b71beee2b1c06d49b5aa65a5c6f>
--
-- @
-- >>> readArrayKeyCheck "array.file" "my array"
-- @
-- @
-- 0
-- @
readArrayKeyCheck
  :: FilePath
  -- ^ Path to file
  -> String
  -- ^ Tag/name of the array to look up. The key needs to have an exact match.
  -> IO Int
  -- ^ Index of the array in the file (the ArrayFire docs specify @-1@ when the key is not found)
readArrayKeyCheck path key =
  withCString path $ \pathC ->
    withCString key $ \keyC ->
      fromIntegral <$>
        afCall1 (\p -> af_read_array_key_check p pathC keyC)

-- | Convert ArrayFire 'Array' to 'String', used for 'Show' instance.
--
-- Equivalent to @'arrayToString' \"ArrayFire Array\" a 4 'True'@: the
-- label is fixed, values use precision 4, and the transpose flag is set.
--
-- @
-- >>> 'putStrLn' '$' 'arrayString' (constant \@'Double' [1,1,1,1] 10)
-- @
-- @
-- ArrayFire Array
--    [ 1 1 1 1 ]
--        10.0000
-- @
arrayString
  :: Array a
  -- ^ Input 'Array'
  -> String
  -- ^ 'String' representation of 'Array'
arrayString arr = arrayToString "ArrayFire Array" arr 4 True

-- | Convert ArrayFire Array to String
--
-- Wraps @af_array_to_string@. ArrayFire allocates the output C string on
-- the host, and the caller is responsible for releasing it with
-- @af_free_host@. This function copies it into a Haskell 'String' and then
-- frees the ArrayFire buffer, so repeated rendering (e.g. via 'Show') does
-- not leak host memory.
--
-- @
-- >>> 'putStrLn' '$' 'arrayToString' "ArrayFire Array" (constant \@'Double' [1,1,1,1] 10) 4 'False'
-- @
-- @
-- ArrayFire Array
--    [ 1 1 1 1 ]
--        10.0000
-- @
arrayToString
  :: String
  -- ^ Name of 'Array'
  -> Array a
  -- ^ 'Array' input
  -> Int
  -- ^ Precision of 'Array' values.
  -> Bool
  -- ^ If 'True', takes the transpose before rendering to 'String'
  -> String
  -- ^ 'Array' rendered to 'String'
arrayToString expr (Array fptr) (fromIntegral -> prec) (fromIntegral . fromEnum -> trans) =
  -- Rendering only reads the array; the temporary host buffer is released
  -- before returning, which is why exposing this as a pure function is
  -- considered acceptable here.
  unsafePerformIO . mask_ . withForeignPtr fptr $ \aptr ->
    withCString expr $ \expCstr ->
      alloca $ \ocstr -> do
        throwAFError =<< af_array_to_string ocstr expCstr aptr prec trans
        -- Only reached on success, so *ocstr now points at a buffer we own.
        -- Its af_free_host status is deliberately ignored: there is no
        -- meaningful recovery, and throwing from a release action would
        -- mask any exception raised while copying.
        bracket (peek ocstr) afFreeHostCString peekCString

-- | Release a host buffer that ArrayFire allocated on our behalf (here, the
-- output string of @af_array_to_string@). Used by 'arrayToString'.
foreign import ccall unsafe "af_free_host"
  afFreeHostCString :: CString -> IO CInt

-- | Size of the ArrayFire data type that corresponds to a Haskell type, as
-- reported by @af_get_size_of@.
--
-- @
-- >>> 'getSizeOf' ('Proxy' \@ 'Double')
-- @
-- @
-- 8
-- @
getSizeOf
  :: forall a . AFType a
  => Proxy a
  -- ^ Witness of Haskell type that mirrors ArrayFire type.
  -> Int
  -- ^ Size of ArrayFire type
getSizeOf proxy = unsafePerformIO (mask_ (alloca query))
  where
    -- Ask ArrayFire for the size of the dtype and read the out-parameter.
    query sizePtr = do
      throwAFError =<< af_get_size_of sizePtr (afType proxy)
      fromIntegral <$> peek sizePtr