module GSL.Random.Gen.Internal (
RNG(..),
RNGType,
newRNG,
setSeed,
getSample,
getUniform,
getUniformPos,
getUniformInt,
getName,
getMax,
getMin,
getSize,
getState,
setState,
copyRNG,
cloneRNG,
mt19937,
) where
import Control.Monad ( liftM )
import Data.Word ( Word8, Word64 )
import Foreign.C.String ( CString, peekCAString )
import Foreign.C.Types ( CULong, CSize, CDouble )
import Foreign.ForeignPtr ( ForeignPtr, newForeignPtr, withForeignPtr )
import Foreign.Marshal.Array ( peekArray, pokeArray )
import Foreign.Ptr ( Ptr, FunPtr, nullPtr )
import Foreign.Storable ( peek )
import System.IO.Unsafe ( unsafePerformIO )
-- | A GSL random number generator: a managed pointer to a C @gsl_rng@.
-- Values created by 'newRNG' and 'cloneRNG' carry a finalizer that calls
-- @gsl_rng_free@ on the underlying pointer.
newtype RNG = MkRNG (ForeignPtr ())
-- | A generator algorithm: the (unmanaged) pointer that 'newRNG' hands to
-- @gsl_rng_alloc@.
newtype RNGType = MkRNGType (Ptr ())
-- | Allocate a new generator of the given type with @gsl_rng_alloc@.
--
-- The returned 'RNG' owns the C object: a @gsl_rng_free@ finalizer is
-- attached to its 'ForeignPtr'.
--
-- @gsl_rng_alloc@ can return a null pointer on allocation failure (when the
-- GSL error handler returns instead of aborting). That case raises an
-- 'IOError' instead of producing an 'RNG' whose every use, and whose
-- finalizer, would dereference null.
newRNG :: RNGType -> IO RNG
newRNG t = do
    ptr <- gsl_rng_alloc t
    if ptr == nullPtr
        then ioError $ userError "newRNG: gsl_rng_alloc returned a null pointer."
        else do
            fptr <- newForeignPtr p_gsl_rng_free ptr
            return $! MkRNG fptr
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_alloc :: RNGType -> IO (Ptr ())
foreign import ccall unsafe "gsl/gsl_rng.h &gsl_rng_free"
    p_gsl_rng_free :: FunPtr (Ptr () -> IO ())
-- | Reseed the generator via @gsl_rng_set@. The 'Word64' seed is
-- converted to C @unsigned long@ with 'fromIntegral'; on platforms where
-- that type is narrower, the seed wraps.
setSeed :: RNG -> Word64 -> IO ()
setSeed (MkRNG fptr) seed =
    withForeignPtr fptr $ \ptr -> gsl_rng_set ptr (fromIntegral seed)
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_set :: Ptr () -> CULong -> IO ()
-- | Draw the next raw integer from the generator (@gsl_rng_get@),
-- converted from C @unsigned long@ to 'Word64'. Per the GSL manual the
-- value lies between the generator's minimum and maximum (see 'getMin'
-- and 'getMax').
getSample :: RNG -> IO Word64
getSample (MkRNG fptr) =
    withForeignPtr fptr (liftM fromIntegral . gsl_rng_get)
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_get :: Ptr () -> IO CULong
-- | Draw a 'Double' from @gsl_rng_uniform@. The GSL manual documents the
-- result as uniform on [0, 1): zero may occur, one never does.
getUniform :: RNG -> IO Double
getUniform (MkRNG fptr) =
    withForeignPtr fptr $ \ptr -> do
        x <- gsl_rng_uniform ptr
        return (realToFrac x)
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_uniform :: Ptr () -> IO CDouble
-- | Like 'getUniform' but via @gsl_rng_uniform_pos@. The GSL manual
-- documents this variant as excluding zero, i.e. uniform on (0, 1).
getUniformPos :: RNG -> IO Double
getUniformPos (MkRNG fptr) =
    withForeignPtr fptr (liftM realToFrac . gsl_rng_uniform_pos)
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_uniform_pos :: Ptr () -> IO CDouble
-- | @getUniformInt rng n@ draws an integer from @gsl_rng_uniform_int@,
-- uniform on @[0, n-1]@ per the GSL manual.
--
-- Raises an 'IOError' (via 'userError') when:
--
-- * @n <= 0@, or
-- * @n@ exceeds the generator's range @gsl_rng_max - gsl_rng_min@.
--   GSL itself rejects such @n@ by invoking its error handler, which
--   aborts the whole process by default. Checking here also guarantees
--   that the conversion of @n@ to C @unsigned long@ cannot truncate.
getUniformInt :: RNG -> Int -> IO Int
getUniformInt (MkRNG fptr) n
    | n <= 0 =
        ioError $ userError $
            "getUniformInt: expected \"n\" to be greater than 0" ++
            " but got `" ++ show n ++ "' instead."
    | otherwise =
        withForeignPtr fptr $ \ptr -> do
            -- Evaluated by the comparison below, i.e. while ptr is alive.
            let range = toInteger (gsl_rng_max ptr) - toInteger (gsl_rng_min ptr)
            if toInteger n > range
                then ioError $ userError $
                         "getUniformInt: expected \"n\" to be at most " ++
                         show range ++ " but got `" ++ show n ++ "' instead."
                else liftM fromIntegral (gsl_rng_uniform_int ptr (fromIntegral n))
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_uniform_int :: Ptr () -> CULong -> IO CULong
-- | The algorithm name reported by @gsl_rng_name@. The C string is copied
-- into a Haskell 'String' while the generator is still held by
-- 'withForeignPtr'.
getName :: RNG -> IO String
getName (MkRNG fptr) =
    withForeignPtr fptr (peekCAString . gsl_rng_name)
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_name :: Ptr () -> CString
-- | The largest value 'getSample' can return (@gsl_rng_max@).
--
-- @gsl_rng_max@ is imported as a pure function, so its call is lazy. The
-- result is forced with '$!' inside 'withForeignPtr'. Otherwise the
-- returned thunk could run after the 'RNG' had been finalized, reading
-- freed memory.
getMax :: RNG -> IO Word64
getMax (MkRNG fptr) =
    withForeignPtr fptr $ \ptr ->
        return $! fromIntegral (gsl_rng_max ptr)
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_max :: Ptr () -> CULong
-- | The smallest value 'getSample' can return (@gsl_rng_min@).
--
-- @gsl_rng_min@ is imported as a pure function, so the result is forced
-- with '$!' inside 'withForeignPtr'. That keeps the foreign call from
-- being deferred past the point where the 'RNG' may be finalized.
getMin :: RNG -> IO Word64
getMin (MkRNG fptr) =
    withForeignPtr fptr $ \ptr ->
        return $! fromIntegral (gsl_rng_min ptr)
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_min :: Ptr () -> CULong
-- | Size in bytes of the generator's internal state (@gsl_rng_size@). This
-- is the length of the list returned by 'getState'.
--
-- @gsl_rng_size@ is imported as a pure function, so the result is forced
-- with '$!' while 'withForeignPtr' still keeps the generator alive.
getSize :: RNG -> IO Word64
getSize (MkRNG fptr) =
    withForeignPtr fptr $ \ptr ->
        return $! fromIntegral (gsl_rng_size ptr)
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_size :: Ptr () -> CSize
getState :: RNG -> IO [Word8]
getState rng@(MkRNG fptr) = do
n <- liftM (fromInteger . toInteger) (getSize rng)
withForeignPtr fptr $ \ptr ->
peekArray n (gsl_rng_state ptr)
foreign import ccall unsafe "gsl/gsl_randist.h"
gsl_rng_state :: Ptr () -> Ptr Word8
-- | Overwrite the generator's internal state with bytes previously
-- obtained from 'getState'.
--
-- The C state buffer is exactly @gsl_rng_size@ bytes long. A list of any
-- other length raises an 'IOError' (via 'userError'). Previously, a longer
-- list was poked past the end of the buffer (heap corruption), and a
-- shorter one left a partially overwritten state.
setState :: RNG -> [Word8] -> IO ()
setState (MkRNG fptr) state =
    withForeignPtr fptr $ \ptr -> do
        let expected = toInteger (gsl_rng_size ptr)
            actual   = toInteger (length state)
        if actual /= expected
            then ioError $ userError $
                     "setState: expected a state of " ++ show expected ++
                     " bytes but got " ++ show actual ++ " instead."
            else pokeArray (gsl_rng_state ptr) state
-- | @copyRNG dst src@ overwrites @dst@ with a copy of @src@ using
-- @gsl_rng_memcpy@. Both generators are kept alive for the duration of the
-- call.
--
-- NOTE(review): the GSL manual requires both generators to be of the same
-- type. That is not checked here.
copyRNG :: RNG -> RNG -> IO ()
copyRNG (MkRNG fdst) (MkRNG fsrc) =
    withForeignPtr fsrc $ \src ->
        withForeignPtr fdst $ \dst ->
            gsl_rng_memcpy dst src
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_memcpy :: Ptr () -> Ptr () -> IO ()
-- | Create an independent copy of a generator with @gsl_rng_clone@.
--
-- The clone gets its own @gsl_rng_free@ finalizer. The original is held by
-- 'withForeignPtr' while it is being cloned. A null pointer from
-- @gsl_rng_clone@ (allocation failure with the GSL error handler not
-- aborting) raises an 'IOError' instead of producing an unusable 'RNG'.
cloneRNG :: RNG -> IO RNG
cloneRNG (MkRNG fptr) =
    withForeignPtr fptr $ \ptr -> do
        ptr' <- gsl_rng_clone ptr
        if ptr' == nullPtr
            then ioError $ userError "cloneRNG: gsl_rng_clone returned a null pointer."
            else do
                fptr' <- newForeignPtr p_gsl_rng_free ptr'
                return $! MkRNG fptr'
foreign import ccall unsafe "gsl/gsl_rng.h"
    gsl_rng_clone :: Ptr () -> IO (Ptr ())
-- | The MT19937 Mersenne Twister generator type.
--
-- GSL exports @gsl_rng_mt19937@ as a global /variable/ of type
-- @const gsl_rng_type *@. A @&@ import therefore yields the address of
-- that variable (a pointer to a pointer), not the type descriptor that
-- @gsl_rng_alloc@ expects. The descriptor is obtained by reading the
-- variable.
--
-- 'unsafePerformIO' is safe here because the variable is initialised by
-- the C library and never modified. NOINLINE makes the read happen once.
mt19937 :: RNGType
mt19937 = MkRNGType (unsafePerformIO (peek p_gsl_rng_mt19937))
{-# NOINLINE mt19937 #-}
foreign import ccall unsafe "gsl/gsl_rng.h &gsl_rng_mt19937"
    p_gsl_rng_mt19937 :: Ptr (Ptr ())