{-# LANGUAGE CPP, BangPatterns, ScopedTypeVariables, ForeignFunctionInterface #-}
module Numeric.SpecFunctions.Internal
    ( erf
    , erfc
    , invErf
    , invErfc
    , log2
    ) where
import Data.Bits       ((.&.), (.|.), shiftR)
import Data.Word       (Word64)
import qualified Data.Vector.Unboxed as U
import Numeric.MathFunctions.Constants
-- | The error function. Delegates directly to the C library's @erf@
-- through the 'c_erf' foreign import.
erf :: Double -> Double
{-# INLINE erf #-}
erf x = c_erf x
-- | The complementary error function, @erfc x == 1 - erf x@.
-- Delegates directly to the C library's @erfc@ through the 'c_erfc'
-- foreign import.
erfc :: Double -> Double
{-# INLINE erfc #-}
erfc x = c_erfc x
-- Raw bindings to the C math library's @erf@ and @erfc@.
--
-- These are marked @unsafe@ because they are short, non-blocking, pure
-- computations that never call back into Haskell. The lighter unsafe
-- calling convention skips the bookkeeping a @safe@ call needs to
-- release the capability around the foreign call.
foreign import ccall unsafe "erf"  c_erf  :: Double -> Double
foreign import ccall unsafe "erfc" c_erfc :: Double -> Double
-- | Inverse of the error function: @invErf p@ is the @x@ for which
-- @erf x == p@.
--
-- The domain is the closed interval [-1,1]. @invErf 1@ is +Infinity and
-- @invErf (-1)@ is -Infinity. Arguments outside [-1,1], and NaN, raise an
-- error that names 'invErf' and reports the argument actually passed in.
--
-- Implemented as @invErfc (1 - p)@, using @erf x == 1 - erfc x@. The
-- subtraction @1 - p@ loses relative precision as @|p|@ approaches 0.
-- For @|p|@ below about 2^-53 it rounds to exactly 1, so such tiny
-- arguments are not resolved to full relative accuracy.
invErf :: Double -- ^ @p@, in [-1,1]
       -> Double
invErf p
  -- Validate here instead of relying on invErfc: its error message
  -- would name the wrong function and report @1 - p@ rather than @p@.
  -- The positive form of the test also sends NaN to the error branch.
  | p >= -1 && p <= 1 = invErfc (1 - p)
  | otherwise         = modErr $ "invErf: p must be in [-1,1] got " ++ show p
-- | Inverse of the complementary error function: @invErfc p@ is the @x@
-- for which @erfc x == p@.
--
-- The domain is the closed interval [0,2]. The endpoints map to the
-- infinities: @invErfc 0@ is +Infinity and @invErfc 2@ is -Infinity.
-- Every other value outside (0,2), NaN included, is rejected with an
-- error.
invErfc :: Double -- ^ @p@, in [0,2]
        -> Double
invErfc p
  | p == 0         = m_pos_inf
  | p == 2         = m_neg_inf
  | p > 0 && p < 2 = if p <= 1 then root else negate root
  | otherwise      = modErr $ "invErfc: p must be in [0,2] got " ++ show p
  where
    -- Solve only on the lower half. Because erfc (-x) == 2 - erfc x,
    -- a p in (1,2) is folded to 2 - p and the sign is flipped above.
    q = if p <= 1 then p else 2 - p

    -- Rational starting approximation, expressed through t.
    t     = sqrt (negate (2 * log (0.5 * q)))
    guess = -0.70711 * ((2.30753 + t * 0.27061)
                        / (1 + t * (0.99229 + t * 0.04481)) - t)

    -- Refine with exactly two Halley steps on f x = erfc x - q.
    root = halley (halley guess)

    -- One Halley step. 1.12837916709551257 is 2 / sqrt pi, so the
    -- denominator equals -f'(x) - x * f(x) for this particular f.
    halley :: Double -> Double
    halley x =
      let err = erfc x - q
      in  x + err / (1.12837916709551257 * exp (negate (x * x)) - x * err)
-- | Floor of the base-2 logarithm of a positive 'Int', i.e. the index of
-- its highest set bit.
--
-- This is a binary search over bit positions. It walks the windows from
-- widest to narrowest. Whenever the remaining value has a bit set in the
-- upper half of the current window, the value is shifted down by that
-- half's width and the width is OR-ed into the result. Zero or negative
-- input raises an error.
log2 :: Int -> Int
log2 n
  | n <= 0    = modErr $ "log2: nonpositive input, got " ++ show n
  | otherwise = descend groups 0 n
  where
    descend :: [(Int, Int)] -> Int -> Int -> Int
    descend [] !acc _ = acc
    descend ((mask, width) : rest) !acc !m
      | m .&. mask /= 0 = descend rest (acc .|. width) (m `shiftR` width)
      | otherwise       = descend rest acc m

    -- (mask, shift) pairs, widest first. Each mask selects the upper
    -- half of a (2 * shift)-bit window.
    groups :: [(Int, Int)]
    groups =
      [ (fromIntegral (0xffffffff00000000 :: Word64), 32)
      , (fromIntegral (0xffff0000 :: Word64), 16)
      , (0xff00, 8)
      , (0xf0, 4)
      , (0x0c, 2)
      , (0x02, 1)
      ]
-- | Abort with 'error'. The message is the given text prefixed with
-- @Numeric.SpecFunctions.@, so callers pass "function: reason".
modErr :: String -> a
modErr = error . ("Numeric.SpecFunctions." ++)