{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE KindSignatures #-}
module ArrayFire.Random
( createRandomEngine
, retainRandomEngine
, setRandomEngine
, getRandomEngineType
, randomEngineSetSeed
, getDefaultRandomEngine
, setDefaultRandomEngineType
, randomEngineGetSeed
, setSeed
, getSeed
, randn
, randu
, randomUniform
, randomNormal
) where
import Control.Exception
import Data.Proxy
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Marshal hiding (void)
import Foreign.Ptr
import Foreign.Storable
import ArrayFire.Exception
import ArrayFire.Internal.Types
import ArrayFire.Internal.Defines
import ArrayFire.Internal.Random
import ArrayFire.FFI
-- | Creates a fresh ArrayFire random engine of the requested type.
--
-- The 'Int' seed is converted with 'fromIntegral' before being passed to
-- @af_create_random_engine@. A non-success error code is raised through
-- 'throwAFError'. The resulting raw handle is wrapped in a 'ForeignPtr' whose
-- finalizer is @af_release_random_engine_finalizer@. The whole sequence runs
-- under 'mask_' so the handle cannot be lost between creation and wrapping.
createRandomEngine
  :: Int               -- ^ Seed for the new engine
  -> RandomEngineType  -- ^ Engine algorithm
  -> IO RandomEngine
createRandomEngine seed engineType = mask_ $ do
  engPtr <- alloca $ \out -> do
    err <- af_create_random_engine out (fromRandomEngine engineType) (fromIntegral seed)
    throwAFError err
    peek out
  RandomEngine <$> newForeignPtr af_release_random_engine_finalizer engPtr
-- | Obtains a new 'RandomEngine' handle from an existing one via
-- @af_retain_random_engine@, using the shared 'op1re' helper.
retainRandomEngine
  :: RandomEngine     -- ^ Engine to retain
  -> IO RandomEngine
retainRandomEngine engine = op1re engine af_retain_random_engine
-- | Raw binding to the @af_random_engine_set_type_@ C symbol. Unlike most
-- engine calls it receives the engine handle by value ('AFRandomEngine', not a
-- pointer to it).
-- NOTE(review): the trailing underscore suggests a project-local C shim around
-- ArrayFire's @af_random_engine_set_type@ (which takes a pointer to the
-- engine); confirm against the cbits wrapper.
foreign import ccall unsafe "af_random_engine_set_type_"
  af_random_engine_set_type_ :: AFRandomEngine -> AFRandomEngineType -> IO AFErr
-- | Changes the algorithm of an existing engine to the given
-- 'RandomEngineType'. The work is delegated to 'inPlaceEng', which supplies
-- the raw engine handle to @af_random_engine_set_type_@.
setRandomEngine
  :: RandomEngine      -- ^ Engine to modify
  -> RandomEngineType  -- ^ New engine type
  -> IO ()
setRandomEngine engine engineType =
  inPlaceEng engine $ \raw ->
    af_random_engine_set_type_ raw (fromRandomEngine engineType)
-- | Reads the engine's type back via @af_random_engine_get_type@ and converts
-- it from the raw FFI representation with 'toRandomEngine'.
getRandomEngineType
  :: RandomEngine         -- ^ Engine to query
  -> IO RandomEngineType
getRandomEngineType engine =
  fmap toRandomEngine (infoFromRandomEngine engine af_random_engine_get_type)
-- | Raw binding to the @af_random_engine_set_seed_@ C symbol. It receives the
-- engine handle by value ('AFRandomEngine', not a pointer to it).
-- NOTE(review): the seed is declared here as 'IntL', whereas the seed read
-- back through @af_random_engine_get_seed@ in this module is 'UIntL', and
-- ArrayFire's own @af_random_engine_set_seed@ takes @unsigned long long@.
-- Both are 64-bit so the bits pass through unchanged, but confirm the C shim's
-- parameter type and consider declaring it unsigned.
foreign import ccall unsafe "af_random_engine_set_seed_"
  af_random_engine_set_seed_ :: AFRandomEngine -> IntL -> IO AFErr
-- | Sets the seed of an existing engine. The 'Int' is converted with
-- 'fromIntegral' and handed to @af_random_engine_set_seed_@ through
-- 'inPlaceEng'.
randomEngineSetSeed
  :: RandomEngine  -- ^ Engine to reseed
  -> Int           -- ^ New seed
  -> IO ()
randomEngineSetSeed engine seed =
  inPlaceEng engine $ \raw ->
    af_random_engine_set_seed_ raw (fromIntegral seed)
-- | Returns a 'RandomEngine' referring to ArrayFire's default random engine.
--
-- The handle produced by @af_get_default_random_engine@ is owned by ArrayFire
-- itself and must not be handed to @af_release_random_engine@. Wrapping that
-- borrowed handle directly under the release finalizer would let garbage
-- collection release the library's default engine. The handle is therefore
-- retained first with @af_retain_random_engine@, mirroring ArrayFire's C++
-- @af::getDefaultRandomEngine@. Only the retained handle is put under
-- @af_release_random_engine_finalizer@.
--
-- Any non-success error code from either C call is raised via 'throwAFError'.
-- The sequence runs under 'mask_' so the retained handle cannot leak before it
-- is wrapped.
--
-- NOTE(review): the returned engine is a retained copy. Whether later changes
-- made through it (e.g. 'setRandomEngine') are visible through ArrayFire's
-- default engine depends on ArrayFire's retain semantics; confirm for your
-- ArrayFire version.
getDefaultRandomEngine
  :: IO RandomEngine
getDefaultRandomEngine = mask_ $ do
  engPtr <- alloca $ \out -> do
    throwAFError =<< af_get_default_random_engine out
    -- Borrowed, library-owned handle: never release this one.
    borrowed <- peek out
    -- Take our own reference so the finalizer below releases only that.
    throwAFError =<< af_retain_random_engine out borrowed
    peek out
  fptr <- newForeignPtr af_release_random_engine_finalizer engPtr
  pure (RandomEngine fptr)
-- | Sets the type of ArrayFire's default random engine via
-- @af_set_default_random_engine_type@. The error code is checked by 'afCall'.
setDefaultRandomEngineType
  :: RandomEngineType  -- ^ New default engine type
  -> IO ()
setDefaultRandomEngineType =
  afCall . af_set_default_random_engine_type . fromRandomEngine
-- | Reads an engine's seed via @af_random_engine_get_seed@ and converts it to
-- 'Int' with 'fromIntegral'.
randomEngineGetSeed
  :: RandomEngine  -- ^ Engine to query
  -> IO Int
randomEngineGetSeed engine = do
  seed <- infoFromRandomEngine engine af_random_engine_get_seed
  pure (fromIntegral seed)
-- | Sets ArrayFire's global seed via @af_set_seed@. The 'Int' is converted with
-- 'fromIntegral' and the error code is checked by 'afCall'.
setSeed :: Int -> IO ()
setSeed seed = afCall (af_set_seed (fromIntegral seed))
-- | Reads ArrayFire's global seed via @af_get_seed@ and converts it to 'Int'
-- with 'fromIntegral'.
getSeed :: IO Int
getSeed = do
  seed <- afCall1 af_get_seed
  pure (fromIntegral seed)
-- | Shared driver for engine-based generators (@af_random_uniform@,
-- @af_random_normal@, ...).
--
-- It keeps the engine's 'ForeignPtr' alive for the duration of the call and
-- marshals the dimension list into a temporary C array. It then invokes the
-- generator with the dimension count, dimensions, element dtype of @a@ and the
-- raw engine handle. Any error code is raised via 'throwAFError'. The produced
-- array handle is wrapped with @af_release_array_finalizer@. Everything runs
-- under 'mask_'.
randEng
  :: forall a . AFType a
  => [Int]
  -- ^ Dimensions of the array to generate
  -> (Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> AFRandomEngine -> IO AFErr)
  -- ^ Raw generator: output slot, ndims, dims, dtype, engine
  -> RandomEngine
  -- ^ Engine supplying the random stream
  -> IO (Array a)
randEng dims generate (RandomEngine engineFPtr) =
  mask_ . withForeignPtr engineFPtr $ \enginePtr -> do
    arrPtr <- alloca $ \out ->
      withArray (map fromIntegral dims) $ \dimsPtr -> do
        err <- generate out ndims dimsPtr dtype enginePtr
        throwAFError err
        peek out
    Array <$> newForeignPtr af_release_array_finalizer arrPtr
  where
    ndims :: CUInt
    ndims = fromIntegral (length dims)

    dtype :: AFDtype
    dtype = afType (Proxy @a)
-- | Shared driver for the engine-less generators (@af_randn@, @af_randu@).
--
-- It nulls the output slot with 'zeroOutArray' and marshals the dimension list
-- into a temporary C array. It then calls the generator with the dimension
-- count, dimensions and the element dtype of @a@. Any error code is raised via
-- 'throwAFError'. The resulting array handle is wrapped with
-- @af_release_array_finalizer@. The whole sequence runs under 'mask_'.
rand
  :: forall a . AFType a
  => [Int]
  -- ^ Dimensions of the array to generate
  -> (Ptr AFArray -> CUInt -> Ptr DimT -> AFDtype -> IO AFErr)
  -- ^ Raw generator: output slot, ndims, dims, dtype
  -> IO (Array a)
rand dims generate = mask_ $ do
  arrPtr <- alloca $ \out -> do
    zeroOutArray out
    withArray (map fromIntegral dims) $ \dimsPtr -> do
      err <- generate out ndims dimsPtr dtype
      throwAFError err
      peek out
  Array <$> newForeignPtr af_release_array_finalizer arrPtr
  where
    ndims :: CUInt
    ndims = fromIntegral (length dims)

    dtype :: AFDtype
    dtype = afType (Proxy @a)
-- | Generates an array of the given dimensions with ArrayFire's @af_randn@
-- (normally distributed values, per the ArrayFire documentation). It uses the
-- global generator rather than an explicit 'RandomEngine'.
randn
  :: forall a
   . (AFType a, Fractional a)
  => [Int]          -- ^ Dimensions
  -> IO (Array a)
randn = flip (rand @a) af_randn
-- | Generates an array of the given dimensions with ArrayFire's @af_randu@
-- (uniformly distributed values, per the ArrayFire documentation). It uses the
-- global generator rather than an explicit 'RandomEngine'.
randu
  :: forall a . AFType a
  => [Int]          -- ^ Dimensions
  -> IO (Array a)
randu = flip (rand @a) af_randu
-- | Generates an array of the given dimensions with @af_random_uniform@,
-- drawing from the supplied 'RandomEngine'.
randomUniform
  :: forall a . AFType a
  => [Int]          -- ^ Dimensions
  -> RandomEngine   -- ^ Engine to draw from
  -> IO (Array a)
randomUniform dims = randEng @a dims af_random_uniform
-- | Generates an array of the given dimensions with @af_random_normal@,
-- drawing from the supplied 'RandomEngine'.
-- NOTE(review): unlike 'randn', this carries no 'Fractional' constraint.
-- ArrayFire is expected to reject non-floating dtypes at runtime; confirm.
randomNormal
  :: forall a
   . AFType a
  => [Int]          -- ^ Dimensions
  -> RandomEngine   -- ^ Engine to draw from
  -> IO (Array a)
randomNormal dims = randEng @a dims af_random_normal