--------------------------------------------------------------------------------
-- |
-- Module      : ArrayFire.Backend
-- Copyright   : David Johnson (c) 2019-2020
-- License     : BSD 3
-- Maintainer  : David Johnson <djohnson.m@gmail.com>
-- Stability   : Experimental
-- Portability : GHC
--
-- Set and get available ArrayFire 'Backend's.
--
-- @
-- module Main where
--
-- import ArrayFire
--
-- main :: IO ()
-- main = print =<< getAvailableBackends
-- @
--
-- Example output on a machine with the CPU and OpenCL backends:
--
-- @
-- [CPU,OpenCL]
-- @
--------------------------------------------------------------------------------
module ArrayFire.Backend where

import ArrayFire.FFI
import ArrayFire.Internal.Backend
import ArrayFire.Internal.Types

-- | Set the 'Backend' used for subsequent ArrayFire calls.
--
-- Converts the Haskell 'Backend' to its FFI representation with
-- 'toAFBackend' and passes it to @af_set_backend@; the returned error
-- code is handled by 'afCall'.
--
-- >>> setBackend OpenCL
setBackend
  :: Backend
  -- ^ 'Backend' to use for 'Array' construction
  -> IO ()
setBackend = afCall . af_set_backend . toAFBackend

-- | Retrieve the number of 'Backend's available.
--
-- Calls @af_get_backend_count@ through 'afCall1' and widens the
-- resulting 'CUInt' to an 'Int'.
--
-- >>> getBackendCount
-- 2
getBackendCount :: IO Int
getBackendCount =
  fromIntegral <$>
    afCall1 af_get_backend_count

-- | Retrieve the available 'Backend's.
--
-- Reads the 'CInt' returned by @af_get_available_backends@ (via 'afCall1'),
-- converts it to an 'Int', and decodes it into a list with 'toBackends'.
-- NOTE(review): the C value is presumably a bitmask of backends; the
-- decoding lives in 'toBackends'.
--
-- >>> mapM_ print =<< getAvailableBackends
-- CPU
-- OpenCL
getAvailableBackends :: IO [Backend]
getAvailableBackends =
  toBackends . fromIntegral <$>
    afCall1 af_get_available_backends

-- | Retrieve the 'Backend' on which a specific 'Array' was created.
--
-- Queries @af_get_backend_id@ for the array via 'infoFromArray' and
-- converts the FFI value with 'toBackend'.
--
-- >>> getBackend (scalar \@Double 2.0)
-- OpenCL
getBackend :: Array a -> Backend
getBackend arr = toBackend (infoFromArray arr af_get_backend_id)

-- | Retrieve the currently active 'Backend'.
--
-- Calls @af_get_active_backend@ through 'afCall1' and converts the
-- FFI value with 'toBackend'.
--
-- >>> getActiveBackend
-- OpenCL
getActiveBackend :: IO Backend
getActiveBackend = toBackend <$> afCall1 af_get_active_backend

-- | Retrieve the ID of the device on which an 'Array' was created.
--
-- Queries @af_get_device_id@ for the array via 'infoFromArray' and
-- widens the resulting 'CInt' to an 'Int'.
--
-- >>> getDeviceID (scalar \@Double 2.0)
-- 1
getDeviceID :: Array a -> Int
getDeviceID arr = fromIntegral (infoFromArray arr af_get_device_id)