{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.System.Random.MWC (
(:~>),
uniform, uniformR,
randomArray, randomArrayWith,
module System.Random.MWC,
) where
import Prelude as P
import System.Random.MWC hiding ( uniform, uniformR )
import qualified System.Random.MWC as R
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Array.Data as A
import Data.Array.Accelerate.Array.Sugar as Sugar
-- | A per-element generator.
--
-- The first argument is the index of the element being produced. The second
-- is the mutable MWC generator to draw from. The result is that element's
-- value, produced in 'IO'.
type (:~>) sh e = sh -> GenIO -> IO e
{-# INLINE uniform #-}
-- | Element generator that draws each value from the supplied generator
-- using mwc-random's 'R.uniform'.
--
-- The element index is ignored, so every element is produced the same way
-- regardless of its position in the array.
uniform :: (Shape sh, Elt e, Variate e) => sh :~> e
uniform _index gen = R.uniform gen
{-# INLINE uniformR #-}
-- | Element generator that draws each value within the given bounds,
-- using mwc-random's 'R.uniformR'.
--
-- The bounds pair is passed to 'R.uniformR' unchanged, so its interpretation
-- (e.g. inclusivity) is whatever mwc-random defines. The element index is
-- ignored.
uniformR :: (Shape sh, Elt e, Variate e) => (e, e) -> sh :~> e
uniformR range _index gen = R.uniformR range gen
{-# INLINE randomArray #-}
-- | Build an array of the given shape whose elements come from the
-- per-element generator @f@.
--
-- Every call obtains a brand-new generator via 'createSystemRandom' and then
-- delegates to 'randomArrayWith'. Use 'randomArrayWith' directly to reuse a
-- generator across calls.
randomArray :: (Shape sh, Elt e) => sh :~> e -> sh -> IO (Array sh e)
randomArray f sh =
  (\freshGen -> randomArrayWith freshGen f sh) =<< createSystemRandom
{-# INLINE randomArrayWith #-}
-- | Build an array of the given shape using an existing generator.
--
-- The element payload is filled by 'runRandomArray'. It is then wrapped
-- together with the shape (converted with 'fromElt') in the 'Array'
-- constructor, and the resulting array is forced to WHNF before being
-- returned.
randomArrayWith
    :: (Shape sh, Elt e)
    => GenIO            -- ^ generator to draw every element from
    -> sh :~> e         -- ^ per-element generator function
    -> sh               -- ^ shape of the resulting array
    -> IO (Array sh e)
randomArrayWith gen f sh =
  (\payload -> return $! Array (fromElt sh) payload) =<< runRandomArray f sh gen
{-# INLINE runRandomArray #-}
-- | Allocate mutable storage for every element of shape @sh@ and fill it.
--
-- Slots are filled one at a time, for linear indices @0 .. size sh - 1@ in
-- increasing order. For each slot, @f@ is called with the multi-dimensional
-- index produced by 'Sugar.fromIndex' and with the supplied generator. The
-- result is converted with 'fromElt' and stored with 'unsafeWriteArrayData'.
--
-- Because the writes are strictly sequential, the generator is consumed in a
-- fixed order.
runRandomArray
    :: (Shape sh, Elt e)
    => sh :~> e
    -> sh
    -> GenIO
    -> IO (MutableArrayData (EltRepr e))
runRandomArray f sh gen = do
  -- Element count, computed once and reused for both the allocation size
  -- and the loop bound.
  let !n = Sugar.size sh
  arr <- newArrayData n
  let write !i
        | i P.>= n  = return ()
        | otherwise = do
            x <- f (Sugar.fromIndex sh i) gen
            -- Index is always in [0, n), so the unchecked write is in bounds.
            unsafeWriteArrayData arr i (fromElt x)
            write (i + 1)
  write 0
  return arr