{-# LANGUAGE
        MultiParamTypeClasses
  #-}

module Data.Random.Source.GSL
    ( wrapRNG
    , newWrappedRNG
    , WrappedRNG
    ) where

import GSL.Random.Gen
import Data.Random.Source

import Data.Word
import Data.List
import Data.Maybe
import Data.Bits

-- | A GSL 'RNG' paired with a strategy for turning its samples into
-- full-width 'Word64's. The strategy is chosen once, when the generator is
-- wrapped.
--
-- An 'RNG' need not emit words spanning the whole 'Word64' range. So
-- 'getRandomWordFrom' must check the generator's range and, if it falls
-- short, stitch several samples together into one 'Word64'. Sampling from a
-- bare 'RNG' repeats that check on every draw; a 'WrappedRNG' does it ahead
-- of time. Integral sampling becomes cheaper, at the small cost of unwrapping
-- the 'RNG' when sampling 'Double's.
data WrappedRNG
    -- | Samples start at 0 and reach at least @2^32-1@: two samples, 32 bits
    -- each, make up one 'Word64'.
    = RNG32
        { unwrapRNG :: RNG
        }
    -- | Samples cover the full 'Word64' range, so each is used as-is.
    | RNG64
        { unwrapRNG :: RNG
        }
    -- | Any other range: words come from a pre-built sampling action.
    | RNGOther
        { unwrapRNG :: RNG
        , getWord64 :: IO Word64
          -- ^ assembles one 'Word64' from as many samples as it needs
        }

-- | Allocate a fresh GSL generator of the given type with 'newRNG' and wrap
-- it with 'wrapRNG'. The generator's range is inspected once, here, rather
-- than on every draw.
newWrappedRNG :: RNGType -> IO WrappedRNG
newWrappedRNG rngType = newRNG rngType >>= wrapRNG

-- | Inspect the output range of an 'RNG' once, and pick a strategy for
-- assembling full-range 'Word64's from its samples:
--
-- * 'RNG64' when samples cover exactly @[0, 2^64-1]@ and can be used as-is;
--
-- * 'RNG32' when samples start at 0 and reach at least @2^32-1@, so two of
--   them make one word;
--
-- * 'RNGOther' otherwise. It carries a pre-built action that shifts each
--   sample down to start at 0 and packs samples together until 64 bits are
--   filled.
wrapRNG :: RNG -> IO WrappedRNG
wrapRNG rng = do
    lo <- getMin rng
    hi <- getMax rng
    return (classify lo hi)
  where
    word32Max :: Word64
    word32Max = fromIntegral (maxBound :: Word32)

    classify lo hi
        | lo == 0 && hi == maxBound  = RNG64 rng
        | lo == 0 && hi >= word32Max = RNG32 rng
        | otherwise                  = RNGOther rng (packSamples lo (hi - lo))

    -- Fill a Word64 from samples of a generator whose output, once 'lo' is
    -- subtracted, lies in [0, width]. Each sample is xor-ed in at the next
    -- 'bitsPerSample'-bit offset; bits shifted past bit 63 are dropped.
    -- Sampling stops once the offset reaches 64.
    --
    -- NOTE(review): when width + 1 is not a power of two, the samples do not
    -- hit every bit pattern equally often. Consecutive samples occupy
    -- disjoint bit ranges, and no masking or rejection is done, so that skew
    -- is carried into the packed word. It is only approximately uniform.
    packSamples :: Word64 -> Word64 -> IO Word64
    packSamples lo width = go 0 0
      where
        -- Clamp to at least 1 so a degenerate generator (min == max) cannot
        -- keep the offset at 0 and loop forever.
        bitsPerSample = max 1 (log2 width)

        go offset acc
            | offset >= 64 = return acc
            | otherwise = do
                s <- getSample rng
                go (offset + bitsPerSample) $! (acc `xor` ((s - lo) `shiftL` offset))

-- | Number of bits needed to represent the argument: the index of the first
-- entry of 'powersOf2' that exceeds it. So @log2 0 == 0@, @log2 1 == 1@ and
-- @log2 (2^32-1) == 32@. Values of @2^63@ and above exceed every listed
-- power and need all 64 bits. Previously that case crashed in 'fromJust'.
log2 :: Word64 -> Int
log2 x = fromMaybe 64 (findIndex (> x) powersOf2)

-- | @[2^0, 2^1 .. 2^63]@: every power of two representable in a 'Word64'.
powersOf2 :: [Word64]
powersOf2 = map bit [0 .. 63]

-- | Draw from a 'WrappedRNG' using the strategy chosen by 'wrapRNG'.
instance RandomSource IO WrappedRNG where
    -- The first sample supplies the low half and the second the high half.
    -- The low sample is masked because 'wrapRNG' also classifies generators
    -- whose maximum exceeds 2^32-1 as 'RNG32'. Unmasked, their extra bits
    -- would be OR-ed into the high half and bias it towards 1s.
    getRandomWordFrom (RNG32 rng) = do
        lowHalf  <- getSample rng
        highHalf <- getSample rng
        return ((lowHalf .&. 0xFFFFFFFF) .|. (highHalf `shiftL` 32))
    getRandomWordFrom (RNG64 rng) = getSample rng
    -- Match the constructor rather than calling the partial field accessor.
    getRandomWordFrom (RNGOther _ nextWord) = nextWord

    getRandomDoubleFrom = getUniform . unwrapRNG

    -- GSL's uniform_int draws from [0, n-1], so n must be 256 (not 255)
    -- for the byte value 255 to be reachable.
    getRandomByteFrom rng = fmap fromIntegral (getUniformInt (unwrapRNG rng) 256)

-- | Sample directly from a bare GSL 'RNG'. Each word draw calls 'wrapRNG',
-- re-inspecting the generator's range every time; wrap the generator once
-- with 'wrapRNG' to avoid that repeated work.
instance RandomSource IO RNG where
    -- GSL's uniform_int draws from [0, n-1], so n must be 256 (not 255)
    -- for the byte value 255 to be reachable.
    getRandomByteFrom rng = fmap fromIntegral (getUniformInt rng 256)
    getRandomWordFrom rng = wrapRNG rng >>= getRandomWordFrom
    getRandomDoubleFrom = getUniform