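-- | Internal helpers: the per-type 'Cache' of magic values and the global
--   exponent bounds 'minExponent' / 'maxExponent'.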
module Numeric.ExpExtended.Internal
  ( Cache(..)
  , cacheDefault
  , minExponent
  , maxExponent
  ) where

import Data.Bits (bit, shiftL, shiftR)

-- | Cache of useful magic values for a floating-point type @a@.
--
--   In the field docs, \"base\" refers to the underlying type @a@ itself, and
--   \"extended\" refers to the extended-exponent representation built on top
--   of it. 'cacheDefault' fills every field from the 'RealFloat' instance of
--   @a@. The 'Integer', 'Int' and @a@ fields are strict; the three
--   function-valued fields are lazy.
data Cache a = Cache
  { cRadix       :: !Integer
  -- ^ base 'floatRadix'
  , cDigits      :: !Int
  -- ^ base 'floatDigits'
  , cRangeMin    :: !Int
  -- ^ base 'fst' . 'floatRange' (smallest exponent of a normalised value)
  , cRangeMax    :: !Int
  -- ^ base 'snd' . 'floatRange' (largest exponent of a finite value)
  , cSupExponent :: !Int
  -- ^ magic for overflow checks: an exponent shift at least this large
  --   overflows the extended exponent range regardless of the starting value
  --
  --   > (finite :: a) && supExponent <= e ==> maxExponent < exponent finite + e
  , cInfExponent :: !Int
  -- ^ magic for underflow checks: an exponent shift at most this large
  --   underflows the extended exponent range regardless of the starting value
  --
  --   > (finite :: a) && e <= infExponent ==> exponent finite + e < minExponent
  , cUpShift     :: Integer -> Int -> Integer
  -- ^ radix-aware 'shiftL': multiply by @radix ^ e@
  --   ('cacheDefault' uses 'shiftL' itself when the radix is 2)
  , cDownShift   :: Integer -> Int -> Integer
  -- ^ radix-aware 'shiftR': floor-divide by @radix ^ e@
  --   ('cacheDefault' uses 'shiftR' itself when the radix is 2)
  , cRadixPower  :: Int -> Integer
  -- ^ radix-aware 'bit': @radix ^ e@
  --   ('cacheDefault' uses 'bit' itself when the radix is 2)
  , cExpMin      :: !Int
  -- ^ smaller than this exponent and base 'exp' is 1
  , cExpMax      :: !Int
  -- ^ larger  than this exponent and base 'exp' overflows to inf or 0
  , cExpInf      :: !Int
  -- ^ smaller than this exponent and extended 'exp' is 1
  , cExpSup      :: !Int
  -- ^ larger  than this exponent and extended 'exp' overflows to inf or 0
  , cLogRadix    :: !a
  -- ^ base 'log' . 'fromInteger' . 'floatRadix' (natural log of the radix)
  , cRadix'      :: !a
  -- ^ base 'fromInteger' . 'floatRadix' (the radix as a value of type @a@)
  }

-- | The 'Cache' for the 'RealFloat' type @a@.
--
--   Every field is derived from the class methods of @a@. Those methods
--   are applied to a placeholder obtained with 'unProxy', so only the type
--   is consulted, never a value. In radix 2 the shift/power fields use the
--   'Data.Bits' primitives; any other radix falls back to arithmetic with
--   powers of the radix.
cacheDefault :: RealFloat a => Cache a
cacheDefault = cache
  where
    -- Placeholder of type @a@; must never be evaluated.
    witness = unProxy cache
    radix = floatRadix witness
    radix' = fromIntegral radix
    (rangeLo, rangeHi) = floatRange witness
    binary = radix == 2
    cache = Cache
      { cRadix       = radix
      , cDigits      = floatDigits witness
      , cRangeMin    = rangeLo
      , cRangeMax    = rangeHi
      , cSupExponent = supExponentDefault witness
      , cInfExponent = infExponentDefault witness
      , cUpShift     = if binary then shiftL else \n e -> n * (radix ^ e)
      , cDownShift   = if binary then shiftR else \n e -> n `div` (radix ^ e)
      , cRadixPower  = if binary then bit else (radix ^)
      , cExpMin      = expMinDefault witness
      , cExpMax      = expMaxDefault witness
      , cExpInf      = expInfDefault witness
      , cExpSup      = expSupDefault witness
      , cLogRadix    = log radix'
      , cRadix'      = radix'
      }

-- | A value of the proxied type, usable only as a type witness (for example
--   as the argument to 'floatRadix' or 'floatDigits').
--
--   The result must never be evaluated. If it is, the error names this
--   function, so the accidental use can be found quickly (a bare
--   'undefined' gives no such context).
unProxy :: proxy a -> a
unProxy _ = error "Numeric.ExpExtended.Internal.unProxy: type witness was evaluated"

-- | Slack kept between the exponent bounds and half of the 'Int' range.
--   It leaves room for small exponent adjustments on top of large changes.
--
--   TODO prove this works properly
margin :: Int
margin = 2

-- | Maximum exponent.
--
--   Half of 'maxBound', less 'margin'. The sum of two in-range exponents
--   (plus a small adjustment) therefore cannot overflow 'Int', which keeps
--   the overflow checks cheap.
maxExponent :: Int
maxExponent = (maxBound `div` 2) - margin

-- | Minimum exponent.
--
--   Half of 'minBound', plus 'margin'. The sum of two in-range exponents
--   (plus a small adjustment) therefore cannot overflow 'Int', which keeps
--   the underflow checks cheap.
minExponent :: Int
minExponent = margin + minBound `div` 2

{-
Derivation for 'expSupDefault'. With radix b, solve for the exponent e at
which exp (b ^ e) reaches b ^ maxExponent:

  b ^ maxExponent                   = exp (b ^ e)
  maxExponent * log b               = b ^ e
  maxExponent * log b               = exp (e * log b)
  log (maxExponent * log b) / log b = e
-}
-- | Default for 'cExpSup': the floor of the @e@ solved for above.
--
--   The computation is done in 'Double'. Only 'floatRadix' of the argument
--   is consulted.
expSupDefault :: RealFloat a => a -> Int
expSupDefault m = floor (logBase radix (fromIntegral maxExponent * log radix))
  where
    radix :: Double
    radix = fromInteger (floatRadix m)

-- | Default for 'cExpInf': minus the number of radix digits in the
--   significand of the argument's type. Only 'floatDigits' of the argument
--   is consulted.
expInfDefault :: RealFloat a => a -> Int
expInfDefault = negate . floatDigits

-- | Default for 'cExpMax'.
--
--   Computes @floor (logBase b (log (b ^ (maxE - 1))))@ in the argument's
--   own type, where @b@ is the radix and @maxE@ is @snd floatRange@. For
--   IEEE-like types, @b ^ (maxE - 1)@ is the largest finite power of the
--   radix. This gives 9 for 'Double' and 6 for 'Float'.
--
--   NOTE(review): with Haskell's 'exponent' convention (significand in
--   @[1\/b, 1)@), an argument whose exponent is @cExpMax + 1@ may still have
--   a finite base 'exp' (e.g. 512 :: Double). Confirm which exponent
--   convention callers compare against.
expMaxDefault :: RealFloat a => a -> Int
expMaxDefault m = floor (logBase radix (log largestPower))
  where
    radix = fromIntegral (floatRadix m) `asTypeOf` m
    largestPower = scaleFloat (snd (floatRange m)) (recip radix)

-- | Default for 'cExpMin': minus the number of radix digits in the
--   significand of the argument's type. Only 'floatDigits' of the argument
--   is consulted.
expMinDefault :: RealFloat a => a -> Int
expMinDefault = negate . floatDigits


-- | Default for 'cSupExponent'.
--
--   Shifting the exponent of any finite value of the argument's type by at
--   least this much takes it past 'maxExponent'. Smaller shifts may still
--   overflow, depending on the starting exponent. The @floatDigits@ term
--   presumably accounts for denormals, whose exponents go below @fst
--   floatRange@ (TODO confirm for non-IEEE types).
supExponentDefault :: RealFloat a => a -> Int
supExponentDefault x = maxExponent + floatDigits x - minE
  where
    (minE, _) = floatRange x

-- | Default for 'cInfExponent'.
--
--   Shifting the exponent of any finite value of the argument's type by at
--   most this much takes it below 'minExponent'. Larger shifts may still
--   underflow, depending on the starting exponent.
infExponentDefault :: RealFloat a => a -> Int
infExponentDefault x = minExponent - (maxE + floatDigits x)
  where
    (_, maxE) = floatRange x