{-# LANGUAGE BangPatterns, DeriveDataTypeable #-}
module Numeric.QD.DoubleDouble
  ( DoubleDouble(DoubleDouble)
  , toDouble
  , fromDouble
  , sqr
  ) where

import Foreign (Ptr, alloca, castPtr, Storable(..), unsafePerformIO, with)
import Foreign.C (CDouble, CInt)
import Data.Ratio ((%))
import Data.Bits (shiftL, shiftR)
import Data.Typeable (Typeable(..))
import Numeric (showFloat, readSigned, readFloat)

import Numeric.QD.DoubleDouble.Raw
  ( c_dd_add
  , c_dd_sub
  , c_dd_mul
  , c_dd_div
  , c_dd_pi
  , c_dd_exp
  , c_dd_sqrt
  , c_dd_log
  , c_dd_sin
  , c_dd_cos
  , c_dd_tan
  , c_dd_asin
  , c_dd_acos
  , c_dd_atan
  , c_dd_sinh
  , c_dd_cosh
  , c_dd_tanh
  , c_dd_asinh
  , c_dd_acosh
  , c_dd_atanh
  , c_dd_comp
  , c_dd_neg
  , c_dd_abs
  , c_dd_aint
  , c_dd_nint
  , c_dd_ceil
  , c_dd_floor
  , c_dd_atan2
  , c_dd_sqr
  )

-- | A double-double number: a pair of 'CDouble's whose exact sum is the
-- represented value.  The first field is the leading (high-order) part and
-- the second is the trailing (low-order) correction, as built by
-- 'fromDouble' and 'fromRational'.  Both fields are strict and unpacked.
data DoubleDouble
  = DoubleDouble
      {-# UNPACK #-} !CDouble -- leading component
      {-# UNPACK #-} !CDouble -- trailing component
  deriving Typeable

-- | The leading component of a 'DoubleDouble' as a plain 'Double'; the
-- trailing correction is discarded.
--
-- NOTE(review): 'realToFrac' on 'CDouble' may go through 'Rational' when
-- its rewrite rules do not fire (e.g. unoptimised builds), which does not
-- preserve NaN or negative zero -- confirm whether callers care.
toDouble :: DoubleDouble -> Double
toDouble (DoubleDouble hi _lo) = realToFrac hi

-- | Embed a 'Double' as the leading component of a 'DoubleDouble', with a
-- zero trailing component.
fromDouble :: Double -> DoubleDouble
fromDouble !d = DoubleDouble hi 0
  where
    hi = realToFrac d

-- Equality is decided by the same C comparison that backs 'compare':
-- two values are equal exactly when 'compare' reports 'EQ'.
instance Eq DoubleDouble where
  x == y = case compare x y of
             EQ -> True
             _  -> False
  x /= y = not (x == y)

-- Ordering delegates to the C routine 'c_dd_comp', which receives both
-- operands by pointer and writes a 'CInt' through its third pointer; the
-- sign of that integer (relative to 0) is the result.
instance Ord DoubleDouble where
  compare !x !y = unsafePerformIO $
    with x $ \px ->
    with y $ \py ->
    alloca $ \pr -> do
      c_dd_comp (castPtr px) (castPtr py) (castPtr pr)
      !sign <- peek pr
      return (compare sign (0 :: CInt))

-- Rendering uses 'showFloat'.  'showsPrec' wraps negative values (and
-- negative zero) in parentheses when shown at a precedence above 6, as the
-- standard 'Double' instance does.  Without this, e.g.
-- @show (Just (-1 :: DoubleDouble))@ would produce @Just -1.0...@, which the
-- 'Read' instance cannot parse back.  At top level ('show') the output is
-- exactly @showFloat x ""@, as before.
instance Show DoubleDouble where
  showsPrec d x = showParen (d > 6 && (x < 0 || isNegativeZero x)) (showFloat x)

-- Parsing accepts an optionally negated floating-point literal via
-- 'readSigned' and 'readFloat'; the precedence argument is ignored.
instance Read DoubleDouble where
  readsPrec = const (readSigned readFloat)

-- Arithmetic delegates to the QD C routines through 'lift_dd_dd' (unary)
-- and 'lift_dd_dd_dd' (binary).
instance Num DoubleDouble where
  (+) = lift_dd_dd_dd c_dd_add
  (*) = lift_dd_dd_dd c_dd_mul
  (-) = lift_dd_dd_dd c_dd_sub
  negate = lift_dd_dd c_dd_neg
  abs = lift_dd_dd c_dd_abs
  -- A value that compares equal to zero is returned unchanged rather than
  -- replaced by +0.  Negative zero thus keeps its sign (as with 'Double').
  -- If the C comparison reports EQ for NaN, the NaN also propagates instead
  -- of silently becoming 0.
  signum !a = case a `compare` 0 of
    LT -> -1
    EQ -> a
    GT -> 1
  -- Integers go through the exact Rational path ('fromRational').  Values
  -- wider than one 'CDouble' therefore keep their next-lower bits in the
  -- trailing component.
  fromInteger !i = fromRational (i % 1)

-- | Square a 'DoubleDouble' using the dedicated C squaring routine
-- ('c_dd_sqr') instead of the general multiplication used by '(*)'.
sqr :: DoubleDouble -> DoubleDouble
sqr x = lift_dd_dd c_dd_sqr x

instance Fractional DoubleDouble where
  (/) = lift_dd_dd_dd c_dd_div
  recip !b = 1 / b
  -- Split an exact rational into a leading 'CDouble' (the rounding of @k@)
  -- and a trailing 'CDouble' (the rounding of what the leading part missed).
  -- If @k@ is beyond the finite 'CDouble' range, the leading part rounds to
  -- +/-Infinity.  'toRational' of an infinity is not its true value, so the
  -- remainder would be meaningless and could yield pairs such as
  -- (Infinity, -1e307); the trailing part is set to 0 instead.
  fromRational !k
    | isInfinite hi = DoubleDouble hi 0
    | otherwise     = DoubleDouble hi lo
    where
      hi = fromRational k
      lo = fromRational (k - toRational hi)

-- The represented value is the exact sum of both components; since
-- Rational addition is exact, so is this conversion (for finite parts).
instance Real DoubleDouble where
  toRational (DoubleDouble hi lo) = exactHi + exactLo
    where
      exactHi = toRational hi
      exactLo = toRational lo

-- Integral conversions finish on the exact 'Rational' value of the pair.
instance RealFrac DoubleDouble where
  properFraction k = let (n, r) = properFraction (toRational k)
                     in  (n, fromRational r)
  truncate = truncate . toRational . lift_dd_dd c_dd_aint
  -- 'toRational' is exact, so rounding the Rational directly gives the
  -- round-half-to-even result the Haskell Report requires for 'round'.
  -- Rounding first with the C nearest-integer routine would instead inherit
  -- whatever tie-breaking rule that routine uses.
  round = round . toRational
  ceiling = ceiling . toRational . lift_dd_dd c_dd_ceil
  floor = floor . toRational . lift_dd_dd c_dd_floor

-- Each transcendental function is a direct call into the matching C routine
-- through 'lift_dd_dd'.  'pi' comes from 'c_dd_pi', which writes its result
-- through an output pointer.
instance Floating DoubleDouble where
  pi = unsafePerformIO $ alloca $ \out -> do
    c_dd_pi (castPtr out)
    peek out

  -- exponential, logarithm, square root
  exp   = lift_dd_dd c_dd_exp
  log   = lift_dd_dd c_dd_log
  sqrt  = lift_dd_dd c_dd_sqrt

  -- circular functions and their inverses
  sin   = lift_dd_dd c_dd_sin
  cos   = lift_dd_dd c_dd_cos
  tan   = lift_dd_dd c_dd_tan
  asin  = lift_dd_dd c_dd_asin
  acos  = lift_dd_dd c_dd_acos
  atan  = lift_dd_dd c_dd_atan

  -- hyperbolic functions and their inverses
  sinh  = lift_dd_dd c_dd_sinh
  cosh  = lift_dd_dd c_dd_cosh
  tanh  = lift_dd_dd c_dd_tanh
  asinh = lift_dd_dd c_dd_asinh
  acosh = lift_dd_dd c_dd_acosh
  atanh = lift_dd_dd c_dd_atanh

instance RealFloat DoubleDouble where
  floatRadix _ = 2
  floatDigits _ = 2 * floatDigits (error "Numeric.QD.DoubleDouble.floatDigits" :: CDouble)
  floatRange _ = floatRange (error "Numeric.QD.DoubleDouble.floatRange" :: CDouble)
  -- Decode the exact value of the pair (via 'toRational') into (m, e).
  -- The result satisfies 2^(ff-1) <= |m| < 2^ff (ff = floatDigits) and
  -- m * 2^e is the value rounded once, to nearest with ties to even.
  -- Rounding is sign-symmetric: decodeFloat (-x) negates the mantissa.
  decodeFloat !x = case toRational x of
    0 -> (0, 0)
    r -> let ff = floatDigits x
             -- Exponent estimate from the logarithm.  It can be off by one
             -- because the logarithm is inexact near powers of two.  It is
             -- corrected with exact Rational comparisons before the single
             -- rounding step, so no bits are lost to shifting afterwards.
             guess = floor (fromIntegral ff - logBase 2 (abs x)) :: Int
             scaled k = abs r * 2 ^^ k
             lower = 2 ^ (ff - 1) :: Rational
             upper = 2 ^ ff :: Rational
             -- The window [lower, upper) spans exactly one binade, so this
             -- moves monotonically to the unique fitting exponent.
             settle k
               | scaled k < lower  = settle (k + 1)
               | scaled k >= upper = settle (k - 1)
               | otherwise         = k
             k0 = settle guess
             m = round (r * 2 ^^ k0) :: Integer
         in  if abs m == 2 ^ ff
               -- rounding carried into a new top bit; halving is exact
               then (m `quot` 2, 1 - k0)
               else (m, negate k0)
  encodeFloat m e = scaleFloat e (fromInteger m)
  -- exponent _ = -- use default implementation
  -- significand _ = -- use default implementation
  scaleFloat !n !(DoubleDouble a b) = DoubleDouble (scaleFloat n a) (scaleFloat n b)
  isNaN !(DoubleDouble a b) = isNaN a || isNaN b
  isInfinite !(DoubleDouble a b) = isInfinite a || isInfinite b
  isDenormalized !(DoubleDouble a b) = isDenormalized a || isDenormalized b
  isNegativeZero !(DoubleDouble a b) = isNegativeZero a || (a == 0 && isNegativeZero b)
  -- A pair of doubles is not itself an IEEE 754 binary format, so generic
  -- code must not assume IEEE semantics for this type.
  isIEEE _ = False
  atan2 = lift_dd_dd_dd c_dd_atan2

-- instance Enum DoubleDouble -- FIXME

-- Memory layout: two consecutive 'CDouble's, the leading component at
-- element offset 0 and the trailing component at element offset 1.
instance Storable DoubleDouble where
  sizeOf _ = 2 * sizeOf (error "Numeric.QD.DoubleDouble.sizeOf" :: CDouble)
  alignment _ = alignment (error "Numeric.QD.DoubleDouble.alignment" :: CDouble)
  peek !p = do
      hi <- peekElemOff slots 0
      lo <- peekElemOff slots 1
      return (DoubleDouble hi lo)
    where
      slots = castPtr p
  poke !p !(DoubleDouble hi lo) = do
      pokeElemOff slots 0 hi
      pokeElemOff slots 1 lo
    where
      slots = castPtr p

-- | Run a unary C routine of shape @f input output@ on a 'DoubleDouble'.
-- The argument is marshalled into a temporary buffer, the routine writes
-- into a second temporary buffer, and that buffer is read back as the
-- result.
--
-- NOTE(review): 'unsafePerformIO' is only sound if @f@ is pure (its output
-- depends solely on its input) -- confirm for every routine passed here.
lift_dd_dd :: (Ptr CDouble -> Ptr CDouble -> IO ()) -> DoubleDouble -> DoubleDouble
lift_dd_dd f !a = unsafePerformIO $
  with a $ \input ->
  alloca $ \output -> do
    f (castPtr input) (castPtr output)
    peek output

-- | Run a binary C routine of shape @f left right output@ on two
-- 'DoubleDouble's.  Both arguments are marshalled into temporary buffers,
-- the routine writes into a third temporary buffer, and that buffer is read
-- back as the result.
--
-- NOTE(review): as with 'lift_dd_dd', this relies on @f@ being pure.
lift_dd_dd_dd :: (Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> IO ()) -> DoubleDouble -> DoubleDouble -> DoubleDouble
lift_dd_dd_dd f !a !b = unsafePerformIO $
  with a $ \left ->
  with b $ \right ->
  alloca $ \output -> do
    f (castPtr left) (castPtr right) (castPtr output)
    peek output