{-# LANGUAGE BangPatterns, DeriveDataTypeable #-}
module Numeric.QD.QuadDouble
  ( QuadDouble(QuadDouble)
  , toDouble
  , fromDouble
  , toDoubleDouble
  , fromDoubleDouble
  , sqr
  ) where

import Foreign (Ptr, alloca, castPtr, Storable(..), unsafePerformIO, with)
import Foreign.C (CDouble, CInt)
import Data.Bits (shiftL, shiftR)
import Data.Ratio ((%))
import Data.Typeable (Typeable(..))
import Numeric (showFloat, readSigned, readFloat)

import Numeric.QD.DoubleDouble (DoubleDouble(DoubleDouble))
import Numeric.QD.QuadDouble.Raw
  ( c_qd_add
  , c_qd_sub
  , c_qd_mul
  , c_qd_div
  , c_qd_pi
  , c_qd_exp
  , c_qd_sqrt
  , c_qd_log
  , c_qd_sin
  , c_qd_cos
  , c_qd_tan
  , c_qd_asin
  , c_qd_acos
  , c_qd_atan
  , c_qd_sinh
  , c_qd_cosh
  , c_qd_tanh
  , c_qd_asinh
  , c_qd_acosh
  , c_qd_atanh
  , c_qd_comp
  , c_qd_neg
  , c_qd_abs
  , c_qd_aint
  , c_qd_nint
  , c_qd_ceil
  , c_qd_floor
  , c_qd_atan2
  , c_qd_sqr
  )

-- | A quad-double number: four strict, unpacked 'CDouble' components.
-- The represented value is the sum of the components (see 'toRational'),
-- and in memory the components are stored as four consecutive doubles
-- in constructor order (see the 'Storable' instance).
data QuadDouble
  = QuadDouble
      {-# UNPACK #-} !CDouble
      {-# UNPACK #-} !CDouble
      {-# UNPACK #-} !CDouble
      {-# UNPACK #-} !CDouble
  deriving Typeable

-- | Convert to a 'Double' by keeping only the leading component;
-- the three trailing components are discarded.
toDouble :: QuadDouble -> Double
toDouble (QuadDouble hi _ _ _) = realToFrac hi

-- | Embed a 'Double' as the leading component, with the three trailing
-- components set to zero.
fromDouble :: Double -> QuadDouble
fromDouble x = QuadDouble (realToFrac x) 0 0 0

-- | Truncate to a 'DoubleDouble' made of the two leading components;
-- the two trailing components are discarded.
toDoubleDouble :: QuadDouble -> DoubleDouble
toDoubleDouble (QuadDouble hi lo _ _) = DoubleDouble hi lo

-- | Widen a 'DoubleDouble': its two components become the two leading
-- components and the two trailing components are zero.
fromDoubleDouble :: DoubleDouble -> QuadDouble
fromDoubleDouble (DoubleDouble hi lo) = QuadDouble hi lo 0 0

-- | Equality is decided by the same C comparison that backs 'compare', so
-- the 'Eq' and 'Ord' instances always agree. '(/=)' uses the class
-- default, @not (a == b)@.
instance Eq QuadDouble where
  (!a) == (!b) = case compare a b of
    EQ -> True
    _  -> False

-- | Ordering via the C routine @c_qd_comp@. Both operands are marshalled to
-- temporary buffers; the routine writes a 'CInt' whose sign (negative, zero,
-- positive) is translated into 'LT', 'EQ' or 'GT'. The other comparison
-- operators use the class defaults built on 'compare'.
instance Ord QuadDouble where
  compare !a !b = unsafePerformIO $
    with a $ \pa ->
    with b $ \pb ->
    alloca $ \pr -> do
      c_qd_comp (castPtr pa) (castPtr pb) (castPtr pr)
      !sign <- peek pr
      return (compare sign (0 :: CInt))

-- | Ring operations delegate to the QD C routines. 'signum' classifies its
-- argument with a single comparison against zero, and 'fromInteger' goes
-- through the exact 'fromRational' splitting of the 'Fractional' instance.
instance Num QuadDouble where
  (+) = lift_qd_qd_qd c_qd_add
  (-) = lift_qd_qd_qd c_qd_sub
  (*) = lift_qd_qd_qd c_qd_mul
  negate = lift_qd_qd c_qd_neg
  abs = lift_qd_qd c_qd_abs
  signum !a = case compare a 0 of
    LT -> -1
    EQ -> 0
    GT -> 1
  fromInteger !i = fromRational (toRational i)

-- | Square a value using the C routine @c_qd_sqr@, rather than the general
-- multiplication routine behind '(*)'.
sqr :: QuadDouble -> QuadDouble
sqr x = lift_qd_qd c_qd_sqr x

-- | Division delegates to @c_qd_div@ and 'recip' is @1 / b@.
--
-- 'fromRational' builds the four components greedily. Each step takes the
-- 'CDouble' nearest to what remains of the rational and carries the exact
-- remainder forward to the next component.
instance Fractional QuadDouble where
  (/) = lift_qd_qd_qd c_qd_div
  recip !b = 1 / b
  fromRational !k = QuadDouble a b c d
    where
      -- Nearest double to q, together with the exact leftover.
      split :: Rational -> (CDouble, Rational)
      split q = let x = fromRational q in (x, q - toRational x)
      (a, ka) = split k
      (b, kb) = split ka
      (c, kc) = split kb
      d = fromRational kc

-- | The exact value is the (exact, rational) sum of the four components.
instance Real QuadDouble where
  toRational (QuadDouble a b c d) = sum (map toRational [a, b, c, d])

-- | 'properFraction' is computed on the exact rational value. The rounding
-- methods first round in QD arithmetic (@c_qd_aint@, @c_qd_nint@,
-- @c_qd_ceil@, @c_qd_floor@), then convert the already-integral result
-- exactly through 'Rational'.
instance RealFrac QuadDouble where
  properFraction x = (whole, fromRational frac)
    where
      (whole, frac) = properFraction (toRational x)
  truncate x = truncate (toRational (lift_qd_qd c_qd_aint x))
  round x = round (toRational (lift_qd_qd c_qd_nint x))
  ceiling x = ceiling (toRational (lift_qd_qd c_qd_ceil x))
  floor x = floor (toRational (lift_qd_qd c_qd_floor x))

-- | Transcendental functions delegate to the QD C routines. '(**)' and
-- 'logBase' are not overridden, so the class defaults (built on 'exp' and
-- 'log') apply.
instance Floating QuadDouble where
  -- The constant is written by the C library into a temporary buffer.
  pi = unsafePerformIO $ alloca $ \r -> do
    c_qd_pi (castPtr r)
    peek r

  -- Exponential, logarithm and square root.
  exp = lift_qd_qd c_qd_exp
  log = lift_qd_qd c_qd_log
  sqrt = lift_qd_qd c_qd_sqrt

  -- Trigonometric functions and their inverses.
  sin = lift_qd_qd c_qd_sin
  cos = lift_qd_qd c_qd_cos
  tan = lift_qd_qd c_qd_tan
  asin = lift_qd_qd c_qd_asin
  acos = lift_qd_qd c_qd_acos
  atan = lift_qd_qd c_qd_atan

  -- Hyperbolic functions and their inverses.
  sinh = lift_qd_qd c_qd_sinh
  cosh = lift_qd_qd c_qd_cosh
  tanh = lift_qd_qd c_qd_tanh
  asinh = lift_qd_qd c_qd_asinh
  acosh = lift_qd_qd c_qd_acosh
  atanh = lift_qd_qd c_qd_atanh

instance RealFloat QuadDouble where
  -- Radix and exponent range are those of the underlying 'CDouble'
  -- components; the precision is four times that of one component.
  floatRadix _ = 2
  floatDigits _ = 4 * floatDigits (error "Numeric.QD.QuadDouble.floatDigits" :: CDouble)
  floatRange _ = floatRange (error "Numeric.QD.QuadDouble.floatRange" :: CDouble)
  -- Decode the exact rational value of x into (m, e) with
  -- x ~= m * 2 ^^ e and 2 ^ (floatDigits - 1) <= |m| < 2 ^ floatDigits.
  -- A logarithm gives an initial scale k. If the rounded mantissa at that
  -- scale has the wrong width, the mantissa is recomputed (and re-rounded)
  -- at the neighbouring scale. Shifting bits instead would truncate or
  -- zero-fill the last bit.
  decodeFloat !x = case toRational x of
    0 -> (0, 0)
    r -> let ff = floatDigits x
             radix = fromInteger (floatRadix x) :: Rational
             -- Bounds for a mantissa with exactly ff binary digits.
             mMin = 1 `shiftL` (ff - 1)
             mMax = 1 `shiftL` ff
             -- Estimate of k such that radix ^^ k * |x| is about radix ^ ff.
             k0 = floor $ fromIntegral ff - logBase (fromIntegral $ floatRadix x) (abs x)
             -- (^^) rather than (^): k is negative whenever |x| >= radix ^ ff
             -- (e.g. 1e100), and (^) rejects negative exponents with a
             -- "Negative exponent" error.
             search k
               | abs m < mMin = search (k + 1)
               | abs m >= mMax = search (k - 1)
               | otherwise = (m, negate k)
               where m = round (radix ^^ k * r)
         in  search k0
  encodeFloat m e = scaleFloat e (fromInteger m)
  -- exponent and significand use the class defaults, which are defined in
  -- terms of decodeFloat and encodeFloat.
  scaleFloat !n !(QuadDouble a b c d) = QuadDouble (scaleFloat n a) (scaleFloat n b) (scaleFloat n c) (scaleFloat n d)
  isNaN !(QuadDouble a b c d) = isNaN a || isNaN b || isNaN c || isNaN d
  isInfinite !(QuadDouble a b c d) = isInfinite a || isInfinite b || isInfinite c || isInfinite d
  isDenormalized !(QuadDouble a b c d) = isDenormalized a || isDenormalized b || isDenormalized c || isDenormalized d
  -- Negative zero: the first nonzero-or-signed component, scanning from the
  -- leading one, is a negative zero.
  isNegativeZero !(QuadDouble a b c d) = isNegativeZero a || (a == 0 && (isNegativeZero b || (b == 0 && (isNegativeZero c || (c == 0 && isNegativeZero d)))))
  -- A QuadDouble is a sum of four doubles, not an IEEE 754 interchange
  -- format, so it does not claim IEEE semantics.
  isIEEE _ = False
  atan2 = lift_qd_qd_qd c_qd_atan2

-- FIXME: no 'Enum QuadDouble' instance is provided yet, so enumerations such as [a .. b] are unavailable for QuadDouble.

-- | Shows the value with 'showFloat'. As with the standard 'Double' instance,
-- negative values (including negative zero) are wrapped in parentheses when
-- shown at a precedence above 6 (for example as a constructor argument).
-- This gives @Just (-1.0)@ instead of the unparseable @Just -1.0@.
-- At the top level, @show@ output is unchanged.
instance Show QuadDouble where
  showsPrec p x = showParen (p > 6 && (x < 0 || isNegativeZero x)) (showFloat x)

-- | Parses an optionally negative decimal literal with
-- @'readSigned' 'readFloat'@. The precedence argument is ignored.
instance Read QuadDouble where
  readsPrec _prec input = readSigned readFloat input

-- | Memory layout: four consecutive 'CDouble's in constructor order, aligned
-- like a single 'CDouble'.
instance Storable QuadDouble where
  sizeOf _ = 4 * sizeOf (error "Numeric.QD.QuadDouble.sizeOf" :: CDouble)
  alignment _ = alignment (error "Numeric.QD.QuadDouble.alignment" :: CDouble)
  peek !ptr = do
    let !base = castPtr ptr :: Ptr CDouble
    w <- peekElemOff base 0
    x <- peekElemOff base 1
    y <- peekElemOff base 2
    z <- peekElemOff base 3
    return (QuadDouble w x y z)
  poke !ptr !(QuadDouble w x y z) = do
    let !base = castPtr ptr :: Ptr CDouble
    -- Write the components to slots 0..3, in order.
    mapM_ (uncurry (pokeElemOff base)) (zip [0 ..] [w, x, y, z])

-- | Turn a unary C routine of shape @f(in, out)@ into a pure function.
-- The argument is marshalled into a temporary buffer, the routine writes
-- its result into a second temporary buffer, and the result is read back.
-- NOTE(review): 'unsafePerformIO' is sound only if each wrapped routine's
-- output depends solely on its input; confirm for every binding passed here.
lift_qd_qd :: (Ptr CDouble -> Ptr CDouble -> IO ()) -> QuadDouble -> QuadDouble
lift_qd_qd f !a = unsafePerformIO $
  with a $ \pIn ->
  alloca $ \pOut -> do
    f (castPtr pIn) (castPtr pOut)
    peek pOut

-- | Turn a binary C routine of shape @f(lhs, rhs, out)@ into a pure function.
-- Both operands are marshalled into temporary buffers, the routine writes
-- its result into a third temporary buffer, and the result is read back.
-- NOTE(review): 'unsafePerformIO' is sound only if each wrapped routine's
-- output depends solely on its inputs; confirm for every binding passed here.
lift_qd_qd_qd :: (Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> IO ()) -> QuadDouble -> QuadDouble -> QuadDouble
lift_qd_qd_qd f !a !b = unsafePerformIO $
  with a $ \pA ->
  with b $ \pB ->
  alloca $ \pOut -> do
    f (castPtr pA) (castPtr pB) (castPtr pOut)
    peek pOut