{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE MagicHash #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.Float128
-- Copyright   :  (C) 2020 Claude Heiland-Allen
-- License     :  BSD3
-- Maintainer  :  Claude Heiland-Allen <claude@mathr.co.uk>
-- Stability   :  experimental
-- Portability :  non-portable (needs C _Float128)
--
-- This module contains a Float128 type that can be used if you need the
-- extra precision or range from the IEEE _Float128 quadruple precision type.
-- It has a 15-bit exponent field (compared to 11 bits for Double)
-- and a 113-bit significand, of which 112 bits are stored explicitly
-- (compared to a 53-bit significand for Double).
--
-- Performance is likely to be poor, as the instances are implemented using
-- FFI with Ptr Float128, copying to and from memory around each operation.
-- If you need to bind to functions taking/returning _Float128 you need to
-- write wrapper functions expecting pointers to _Float128 instead, as GHC
-- does not expose a CFloat128 FFI type.
module Numeric.Float128
  (
  -- * _Float128 data type
    Float128(..)
  -- * RealFrac alternatives
  , truncate'
  , round'
  , ceiling'
  , floor'
  -- * Conversions
  , fromDouble
  , toDouble
  , fromInt
  , toInt
  ) where

import Data.Bits (bit, testBit, (.&.), (.|.), shiftL, shiftR)
import Data.Ratio ((%), numerator, denominator)
import Data.Word (Word64)
import Foreign (Ptr, castPtr, with, alloca)
import Foreign.C.Types (CIntMax(..), CInt(..), CDouble(..))
import Foreign.Storable (Storable(..))
import GHC.ByteOrder (ByteOrder(..), targetByteOrder)
import GHC.Exts (Int(..))
import GHC.Integer.Logarithms (integerLog2#)
import Numeric (showFloat, readFloat, readSigned, fromRat)
import System.IO.Unsafe (unsafePerformIO)

-- | The _Float128 type, held as its raw 128-bit encoding split into two
-- 64-bit words.
data Float128
  = F128
      !Word64 -- ^ most significant word: sign bit, 15-bit exponent field
              --   and the top 48 bits of the fraction
      !Word64 -- ^ least significant word: the low 64 bits of the fraction

-- | Whether the most significant 'Word64' of a _Float128 is stored at the
-- lower address.  This follows the target's byte order instead of being
-- hard-coded to little-endian.  Hard-coding breaks 'decodeFloat' /
-- 'encodeFloat', which interpret the two words by significance, on
-- big-endian targets.
-- NOTE(review): assumes the C compiler lays out _Float128 in the target's
-- native byte order, as for other IEEE types -- confirm on exotic ABIs.
bigEndian :: Bool
bigEndian = targetByteOrder == BigEndian

-- | Memory representation: two consecutive 'Word64's.  The most
-- significant word is stored first when 'bigEndian' is 'True', and second
-- otherwise.
instance Storable Float128 where
  sizeOf _ = sizeOf (0 :: Word64) * 2
  alignment _ = alignment (0 :: Word64) * 2

  peek ptr = do
    let wp = castPtr ptr :: Ptr Word64
    w0 <- peekElemOff wp 0
    w1 <- peekElemOff wp 1
    return (if bigEndian then F128 w0 w1 else F128 w1 w0)

  poke ptr (F128 hi lo) = do
    let wp = castPtr ptr :: Ptr Word64
        (w0, w1)
          | bigEndian = (hi, lo)
          | otherwise = (lo, hi)
    pokeElemOff wp 0 w0
    pokeElemOff wp 1 w1

-- | Equality is decided by the C comparisons f128_eq and f128_ne, so the
-- treatment of NaN and signed zeros is whatever those helpers implement.
instance Eq Float128 where
  a == b = cmp f128_eq a b
  a /= b = cmp f128_ne a b

-- | Ordering via the C comparison helpers.  'compare' is not defined
-- here, so the class default (built from '==' and '<=') is used.
instance Ord Float128 where
  a <  b = cmp f128_lt a b
  a <= b = cmp f128_le a b
  a >  b = cmp f128_gt a b
  a >= b = cmp f128_ge a b
  -- min/max delegate to C (f128_min/f128_max) rather than the class
  -- defaults, so NaN handling follows the C helpers
  min a b = f2 f128_min a b
  max a b = f2 f128_max a b

-- | Arithmetic is done in C; integer literals go through 'encodeFloat'.
instance Num Float128 where
  x + y = f2 f128_add x y
  x - y = f2 f128_sub x y
  x * y = f2 f128_mul x y
  negate x = f1 f128_neg x
  abs x = f1 f128_abs x
  signum x = f1 f128_sgn x
  fromInteger n = encodeFloat n 0

-- | Exact conversion built from 'decodeFloat': the mantissa times a power
-- of two.  'decodeFloat' maps NaN and the infinities to (0, 0), so those
-- convert to 0.
instance Real Float128 where
  toRational l =
    let (m, e) = decodeFloat l
    in  if e >= 0
          then (m `shiftL` e) % 1
          else m % bit (negate e)

instance Fractional Float128 where
  -- Correctly rounded conversion using 'fromRat' from "Numeric".  It
  -- scales the rational into range (using floatDigits / floatRange), rounds
  -- once, and hands an in-range mantissa to 'encodeFloat'.  Converting
  -- the numerator and denominator separately and then dividing would
  -- overflow whenever either one exceeds the Float128 range.  With that
  -- approach, 1 % 2^16490 would give 0 and 3^13000 % 2^20600 would give
  -- NaN, although both values are representable.
  fromRational = fromRat
  (/) = f2 f128_div
  recip = f1 f128_recip

instance RealFrac Float128 where
  -- Integer part towards zero (floor for non-negative values, ceiling for
  -- negative ones) plus the remaining fraction.  The otherwise case is
  -- NaN, which is returned unchanged with integer part 0.
  properFraction l
    | l >= 0 = let n' = floor' l
                   f = l - n'
               in  (fromInteger . toInteger' $ n', f)
    | l <  0 = let n' = ceiling' l
                   f = l - n'
               in  (fromInteger . toInteger' $ n', f)
    | otherwise = (0, l) -- NaN
  truncate = fromInteger . toInteger' . truncate'
  -- The Haskell Report requires halfway cases to round to the even
  -- integer.  So this is derived from floor' instead of round', whose
  -- tie-breaking is whatever the C helper f128_round implements.
  -- For finite l, d = l - n lies in [0, 1].  The subtraction is exact
  -- except when -1 < l < 0.  There k = -1 is odd, so a d rounded to 0.5
  -- or 1 still correctly yields 0.  For NaN and infinities, d is NaN
  -- (assuming IEEE arithmetic in C), both tests fail, and the result is
  -- toInteger' n, which is 0.
  round l = fromInteger (if d > 0.5 || (d == 0.5 && odd k) then k + 1 else k)
    where
      n = floor' l
      k = toInteger' n
      d = l - n
  ceiling  = fromInteger . toInteger' . ceiling'
  floor    = fromInteger . toInteger' . floor'

-- | Convert an integral-valued Float128 to an 'Integer' via 'decodeFloat'.
-- Any fractional bits are shifted out ('shiftR' on 'Integer' is an
-- arithmetic shift, so this rounds towards negative infinity).  NaN and
-- the infinities give 0, because 'decodeFloat' maps them to (0, 0).
toInteger' :: Float128 -> Integer
toInteger' l
  | e >= 0 = shiftL m e
  | otherwise = shiftR m (negate e)
  where
    (m, e) = decodeFloat l

-- | Float128-valued counterpart of 'truncate', computed by the C helper
-- f128_trunc.
truncate' :: Float128 -> Float128
truncate' = f1 f128_trunc

-- | Float128-valued counterpart of 'round', computed by the C helper
-- f128_round.  Halfway cases follow whatever f128_round implements.
-- NOTE(review): if that is C's roundf128, ties go away from zero rather
-- than to even -- confirm against the C side.
round' :: Float128 -> Float128
round' = f1 f128_round

-- | Float128-valued counterpart of 'ceiling', computed by the C helper
-- f128_ceil.
ceiling' :: Float128 -> Float128
ceiling' = f1 f128_ceil

-- | Float128-valued counterpart of 'floor', computed by the C helper
-- f128_floor.
floor' :: Float128 -> Float128
floor' = f1 f128_floor

-- | Elementary functions, each delegated to the matching C helper.
instance Floating Float128 where
  -- f128_pi writes the constant into a scratch buffer, which is read back
  pi = unsafePerformIO (alloca (\buf -> f128_pi buf >> peek buf))

  -- exponentials, logarithms and powers ('logBase' uses the class default)
  exp  = f1 f128_exp
  log  = f1 f128_log
  sqrt = f1 f128_sqrt
  (**) = f2 f128_pow

  -- circular functions and their inverses
  sin  = f1 f128_sin
  cos  = f1 f128_cos
  tan  = f1 f128_tan
  asin = f1 f128_asin
  acos = f1 f128_acos
  atan = f1 f128_atan

  -- hyperbolic functions and their inverses
  sinh  = f1 f128_sinh
  cosh  = f1 f128_cosh
  tanh  = f1 f128_tanh
  asinh = f1 f128_asinh
  acosh = f1 f128_acosh
  atanh = f1 f128_atanh

instance RealFloat Float128 where
  floatRadix _ = 2
  floatDigits _ = 113
  -- The range of 'exponent' over normal values: the biased exponent field
  -- runs from 1 to 32766, and 'exponent' maps a field value f to f - 16382.
  floatRange _ = (-16381,16384)

  -- Integer mantissa and power-of-two exponent.  For a normal value, the
  -- implicit leading bit is restored (bit 48 of the most significant
  -- word, which is bit 112 of the mantissa).  The exponent bias 16383 and
  -- the 112 fraction bits are subtracted from the exponent.  Denormals
  -- are first scaled up by 2^128, recursively if still denormal, so that
  -- their mantissa is normalised too.  NaN, infinities and zeros decode
  -- to (0, 0).
  decodeFloat l@(F128 msw lsw)
    | isNaN l = (0, 0)
    | isInfinite l = (0, 0)
    | l == 0 = (0, 0)
    | isDenormalized l = case decodeFloat (scaleFloat 128 l) of
        (m, e) -> (m, e - 128)
    | otherwise =
        ( (if s then negate else id) (shiftL (0x1000000000000 .|. toInteger msw .&. 0xFFFFffffFFFF) 64 .|. toInteger lsw)
        , fromIntegral e0 - 16383 - 112
        )
    where
      s = shiftR msw 48 `testBit` 15 -- sign: bit 63 of msw
      e0 = shiftR msw 48 .&. (bit 15 - 1) -- 15-bit biased exponent field

  -- Inverse of 'decodeFloat': builds m * 2^e.  When m has more than 113
  -- significant bits, the excess is rounded to nearest with ties to even
  -- rather than truncated, so 'fromInteger' rounds correctly too.
  -- Results beyond the largest finite value become infinity.  Results
  -- below the normal range are encoded directly as denormals (rounding
  -- only once), or as zero when they underflow completely.
  encodeFloat m e
    | m == 0 = F128 0 0
    | m <  0 = negate (encodeFloat (negate m) e)
    | b >= bit 15 - 1 = F128 (shiftL (bit 15 - 1) 48) 0 -- infinity
    -- b <= 0 forces e <= -16383, so e + 16494 cannot overflow Int
    | b <= 0 = fromBits (roundShift (negate (e + 16494))) -- denormal
    | otherwise = fromBits (shiftL (b - 1) 112 + roundShift t) -- normal
    where
      -- m has l + 1 significant bits
      l = I# (integerLog2# m)
      -- right shift that leaves exactly 113 significant bits
      t = l - 112
      -- biased exponent before rounding; computed as an Integer so that
      -- extreme values of e cannot overflow Int
      b = toInteger e + toInteger t + 16383 + 112
      -- m shifted right by s bits, rounded to nearest with ties to even;
      -- a non-positive s is an exact left shift
      roundShift s
        | s <= 0 = m `shiftL` negate s
        | s > l + 1 = 0 -- m < 2^(s-1): less than half of the last place
        | otherwise =
            let q = m `shiftR` s
                r = m .&. (bit s - 1)
                h = bit (s - 1)
            in  if r > h || (r == h && q `testBit` 0) then q + 1 else q
      -- How the 128-bit encoding x is formed:
      --  * Normal: x = (b - 1) * 2^112 plus the 113-bit mantissa.  The
      --    mantissa's leading bit supplies the missing 1 in the exponent
      --    field, and a rounding carry to 2^113 bumps the exponent
      --    (possibly up to infinity).
      --  * Denormal: x is just the mantissa, and a carry to 2^112 yields
      --    the smallest normal number.
      fromBits x = F128 (fromInteger (shiftR x 64)) (fromInteger (x .&. 0xFFFFffffFFFFffff))

  -- For nonzero finite x, 'exponent' x is snd (decodeFloat x) + 113.
  -- For normal values it is computed directly from the exponent field.
  exponent l@(F128 msw _)
    | isNaN l = 0
    | isInfinite l = 0
    | l == 0 = 0
    | isDenormalized l = snd (decodeFloat l) + 113
    | otherwise = fromIntegral e0 - 16383 - 112 + 113
    where
      e0 = shiftR msw 48 .&. (bit 15 - 1)

  -- Fraction part as produced by the C helper f128_frexp; the exponent
  -- out-parameter is discarded.
  significand l = unsafePerformIO $ do
    with l $ \lp -> alloca $ \ep -> do
      f128_frexp lp lp ep
      peek lp

  -- Multiply by 2^e using the C helper f128_ldexp.  e is clamped before
  -- being narrowed to 'CInt', because a plain 'fromIntegral' wraps around
  -- for exponents outside the 'CInt' range (scaling by 2^32 would act
  -- like scaling by 1).  The bound (floatRange width plus 4 * floatDigits,
  -- i.e. 33217) exceeds the 16384 + 16494 binades between the smallest
  -- denormal and the overflow threshold.  So a clamped e still overflows
  -- or underflows exactly when the unclamped one would.
  scaleFloat e l = unsafePerformIO $
    with l $ \lp -> do
      f128_ldexp lp lp (fromIntegral (max (negate lim) (min lim e)))
      peek lp
    where
      (lo, hi) = floatRange l
      lim = hi - lo + 4 * floatDigits l

  isNaN = tst f128_isnan
  isInfinite = tst f128_isinf
  isDenormalized = tst f128_isdenorm
  isNegativeZero = tst f128_isnegzero
  isIEEE _ = True

  atan2 = f2 f128_atan2

-- | Parsed with "Numeric"'s 'readFloat'.  'readSigned' adds support for a
-- leading minus sign.  The precedence argument is not used.
instance Read Float128 where
  readsPrec _prec = readSigned readFloat

-- | Rendered with 'showFloat'.  Output that starts with a minus sign is
-- wrapped in parentheses when the surrounding precedence exceeds 6 (the
-- precedence of binary minus).
instance Show Float128 where
  showsPrec p x = showParen (p > 6 && startsWithMinus) (showString rendered)
    where
      rendered = showFloat x ""
      startsWithMinus = case rendered of
        '-' : _ -> True
        _ -> False

-- | Convert an 'Int' using the C helper f128_set_i, which writes the
-- result into a scratch buffer that is then read back.
fromInt :: Int -> Float128
fromInt n = unsafePerformIO (alloca (\out -> f128_set_i out n >> peek out))

-- | Convert to 'Int' using the C helper f128_get_i.  How NaN and
-- out-of-range values are handled is decided on the C side.
toInt :: Float128 -> Int
toInt x = unsafePerformIO (with x (\p -> f128_get_i p))

-- | Convert a 'Double' using the C helper f128_set_d, which writes the
-- result into a scratch buffer that is then read back.
fromDouble :: Double -> Float128
fromDouble d = unsafePerformIO (alloca (\out -> f128_set_d out d >> peek out))

-- | Convert to 'Double' using the C helper f128_get_d.
toDouble :: Float128 -> Double
toDouble x = unsafePerformIO (with x (\p -> f128_get_d p))

-- | Lift a C binary operation into a pure function.  Both operands are
-- copied into temporary buffers, and the C function is called with a
-- fresh result buffer, which is read back.
-- 'unsafePerformIO' is justified only if the C function's output depends
-- solely on its operands (NOTE(review): assumed for all f128_* helpers --
-- confirm on the C side).
f2 :: F2 -> Float128 -> Float128 -> Float128
f2 op x y = unsafePerformIO $
  with x $ \xp ->
  with y $ \yp ->
  alloca $ \out -> op out xp yp >> peek out

-- | Shape of the C binary operations used with 'f2'.  The first pointer
-- is the result buffer that 'f2' reads back afterwards; the second and
-- third point to the two operands.
type F2 = Ptr Float128 -> Ptr Float128 -> Ptr Float128 -> IO ()

-- Binary operations (arithmetic, power, min/max, atan2), all called via 'f2'.
foreign import ccall unsafe "f128_add"   f128_add   :: F2
foreign import ccall unsafe "f128_sub"   f128_sub   :: F2
foreign import ccall unsafe "f128_mul"   f128_mul   :: F2
foreign import ccall unsafe "f128_div"   f128_div   :: F2
foreign import ccall unsafe "f128_pow"   f128_pow   :: F2
foreign import ccall unsafe "f128_min"   f128_min   :: F2
foreign import ccall unsafe "f128_max"   f128_max   :: F2
foreign import ccall unsafe "f128_atan2" f128_atan2 :: F2

-- | Lift a C unary operation into a pure function.  The operand is copied
-- into a temporary buffer, and the C function is called with a fresh
-- result buffer, which is read back.
-- 'unsafePerformIO' is justified only if the C function's output depends
-- solely on its operand (NOTE(review): assumed for all f128_* helpers --
-- confirm on the C side).
f1 :: F1 -> Float128 -> Float128
f1 op x = unsafePerformIO $
  with x $ \xp ->
  alloca $ \out -> op out xp >> peek out

-- | Shape of the C unary operations used with 'f1'.  The first pointer is
-- the result buffer that 'f1' reads back afterwards; the second points to
-- the operand.
type F1 = Ptr Float128 -> Ptr Float128 -> IO ()

-- Unary operations, all called via 'f1': sign handling, sqrt/recip,
-- exp/log, trigonometric and hyperbolic functions and their inverses.
-- The last four (floor/ceil/round/trunc) back truncate', round',
-- ceiling' and floor'.
foreign import ccall unsafe "f128_abs"   f128_abs   :: F1
foreign import ccall unsafe "f128_sgn"   f128_sgn   :: F1
foreign import ccall unsafe "f128_neg"   f128_neg   :: F1
foreign import ccall unsafe "f128_sqrt"  f128_sqrt  :: F1
foreign import ccall unsafe "f128_recip" f128_recip :: F1
foreign import ccall unsafe "f128_exp"   f128_exp   :: F1
foreign import ccall unsafe "f128_log"   f128_log   :: F1
foreign import ccall unsafe "f128_sin"   f128_sin   :: F1
foreign import ccall unsafe "f128_cos"   f128_cos   :: F1
foreign import ccall unsafe "f128_tan"   f128_tan   :: F1
foreign import ccall unsafe "f128_sinh"  f128_sinh  :: F1
foreign import ccall unsafe "f128_cosh"  f128_cosh  :: F1
foreign import ccall unsafe "f128_tanh"  f128_tanh  :: F1
foreign import ccall unsafe "f128_asin"  f128_asin  :: F1
foreign import ccall unsafe "f128_acos"  f128_acos  :: F1
foreign import ccall unsafe "f128_atan"  f128_atan  :: F1
foreign import ccall unsafe "f128_asinh" f128_asinh :: F1
foreign import ccall unsafe "f128_acosh" f128_acosh :: F1
foreign import ccall unsafe "f128_atanh" f128_atanh :: F1
foreign import ccall unsafe "f128_floor" f128_floor :: F1
foreign import ccall unsafe "f128_ceil"  f128_ceil  :: F1
foreign import ccall unsafe "f128_round" f128_round :: F1
foreign import ccall unsafe "f128_trunc" f128_trunc :: F1

-- | Shape of the C comparison helpers: both operands are passed by
-- pointer, and a nonzero 'CInt' result means true.
type CMP = Ptr Float128 -> Ptr Float128 -> IO CInt

-- | Run a C comparison on two values, copying each into a temporary
-- buffer first.  Pure via 'unsafePerformIO' on the assumption that the
-- C helper only reads its operands (NOTE(review): confirm on the C side).
cmp :: CMP -> Float128 -> Float128 -> Bool
cmp test x y = unsafePerformIO $
  with x $ \xp ->
  with y $ \yp ->
  fmap (/= 0) (test xp yp)

-- Comparisons used by the Eq and Ord instances via 'cmp'; each returns
-- nonzero for true.
foreign import ccall unsafe "f128_eq" f128_eq :: CMP
foreign import ccall unsafe "f128_ne" f128_ne :: CMP
foreign import ccall unsafe "f128_lt" f128_lt :: CMP
foreign import ccall unsafe "f128_le" f128_le :: CMP
foreign import ccall unsafe "f128_gt" f128_gt :: CMP
foreign import ccall unsafe "f128_ge" f128_ge :: CMP

-- | Shape of the C classification predicates: the value is passed by
-- pointer, and a nonzero 'CInt' result means true.
type TST = Ptr Float128 -> IO CInt

-- | Run a C predicate on a value copied into a temporary buffer.  Pure via
-- 'unsafePerformIO' on the assumption that the C helper only reads its
-- operand (NOTE(review): confirm on the C side).
tst :: TST -> Float128 -> Bool
tst test x = unsafePerformIO (with x (fmap (/= 0) . test))

-- Classification predicates used by the RealFloat instance via 'tst';
-- each returns nonzero for true.
foreign import ccall unsafe "f128_isnan" f128_isnan :: TST
foreign import ccall unsafe "f128_isinf" f128_isinf :: TST
foreign import ccall unsafe "f128_isdenorm" f128_isdenorm :: TST
foreign import ccall unsafe "f128_isnegzero" f128_isnegzero :: TST

-- Readers: convert the pointed-to value to a Haskell 'Double' / 'Int'
-- (used by toDouble and toInt).
foreign import ccall unsafe "f128_get_d" f128_get_d :: Ptr Float128 -> IO Double
foreign import ccall unsafe "f128_get_i" f128_get_i :: Ptr Float128 -> IO Int

-- Writers: store a converted 'Double' / 'Int' through the pointer
-- (used by fromDouble and fromInt).
foreign import ccall unsafe "f128_set_d" f128_set_d :: Ptr Float128 -> Double -> IO ()
foreign import ccall unsafe "f128_set_i" f128_set_i :: Ptr Float128 -> Int -> IO ()

-- Power-of-two scaling and fraction/exponent splitting, used by
-- scaleFloat and significand.  Both call sites pass the same pointer as
-- the first two arguments.  NOTE(review): presumably these are (result,
-- operand, ...) in the result-first order of F1/F2 -- confirm on the C
-- side.
foreign import ccall unsafe "f128_ldexp" f128_ldexp :: Ptr Float128 -> Ptr Float128 -> CInt -> IO ()
foreign import ccall unsafe "f128_frexp" f128_frexp :: Ptr Float128 -> Ptr Float128 -> Ptr CInt -> IO ()

-- Writes pi through the pointer (used by the Floating instance).
foreign import ccall unsafe "f128_pi"    f128_pi    :: Ptr Float128 -> IO ()