{-|
Module: Numeric.Rounded.Hardware.Backend.FastFFI

The types in this module implement interval addition, subtraction, and reciprocal in assembly.

Currently, the only platform supported is x86_64.

One of the following technologies is used to control the rounding mode:

    * SSE2 MXCSR
    * AVX512 EVEX encoding

You should not need to import this module directly.

This module may not be available depending on the platform or package flags.
-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GHCForeignImportPrim #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UnliftedFFITypes #-}
module Numeric.Rounded.Hardware.Backend.FastFFI
  ( CDouble(..)
  , fastIntervalAdd
  , fastIntervalSub
  , fastIntervalRecip
  , VUM.MVector(MV_CFloat, MV_CDouble)
  , VU.Vector(V_CFloat, V_CDouble)
  ) where
import           Control.DeepSeq (NFData (..))
import           Data.Coerce
import           Data.Proxy
import           Data.Tagged
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import qualified FFIWrapper.Double as D
import           Foreign.C.String (CString, peekCString)
import           Foreign.Storable (Storable)
import           GHC.Exts
import           GHC.Generics (Generic)
import           GHC.Int (Int64 (I64#))
import           GHC.Word (Word64 (W64#))
import qualified Numeric.Rounded.Hardware.Backend.C as C
import           Numeric.Rounded.Hardware.Internal.Class
import           System.IO.Unsafe (unsafePerformIO)
import           Unsafe.Coerce

#include "MachDeps.h"

--
-- Double
--

-- | A 'Double' whose rounded operations are provided by this backend.
--
-- 'Show' and 'Generic' are stock-derived, so they mention the 'CDouble'
-- constructor. Comparison, arithmetic and 'Storable' layout are exactly
-- those of the wrapped 'Double'.
newtype CDouble = CDouble Double
  deriving stock (Show, Generic)
  deriving newtype (Eq, Ord, Num, Storable)

-- | Fully evaluating a 'CDouble' means evaluating the wrapped 'Double'.
instance NFData CDouble where
  rnf (CDouble d) = rnf d

-- | Ring operations for 'CDouble'.
--
-- Interval addition and subtraction go through the assembly primitives
-- ('fastIntervalAdd', 'fastIntervalSub'). Everything else delegates either to
-- "FFIWrapper.Double" or to the 'C.CDouble' instance. 'coerce' adapts the
-- newtype and 'Rounded' wrappers in each case.
instance RoundedRing CDouble where
  roundedAdd = coerce D.roundedAdd
  roundedSub = coerce D.roundedSub
  roundedMul = coerce D.roundedMul
  roundedFusedMultiplyAdd = coerce D.roundedFMA
  intervalAdd x x' y y' = coerce fastIntervalAdd x x' y y'
  intervalSub x x' y y' = coerce fastIntervalSub x x' y y'
  intervalMul x x' y y' = (coerce D.intervalMul_down x x' y y', coerce D.intervalMul_up x x' y y')
  -- The lower bound needs only z; the upper bound needs only z'.
  intervalMulAdd x x' y y' z z' = (coerce D.intervalMulAdd_down x x' y y' z, coerce D.intervalMulAdd_up x x' y y' z')
  roundedFromInteger = coerce (roundedFromInteger :: RoundingMode -> Integer -> C.CDouble)
  intervalFromInteger = coerce (intervalFromInteger :: Integer -> (Rounded 'TowardNegInf C.CDouble, Rounded 'TowardInf C.CDouble))
  -- Report the C backend's name with a "+FastFFI" suffix. The assembly
  -- interval backend's own name is appended only when it differs.
  backendNameT = Tagged $ let base = backendName (Proxy :: Proxy C.CDouble)
                              intervals = intervalBackendName
                          in if base == intervals
                             then base ++ "+FastFFI"
                             else base ++ "+FastFFI(" ++ intervals ++ ")"
  {-# INLINE roundedAdd #-}
  {-# INLINE roundedSub #-}
  {-# INLINE roundedMul #-}
  {-# INLINE roundedFusedMultiplyAdd #-}
  {-# INLINE intervalAdd #-}
  {-# INLINE intervalSub #-}
  {-# INLINE intervalMul #-}
  {-# INLINE intervalMulAdd #-}
  {-# INLINE roundedFromInteger #-}
  {-# INLINE intervalFromInteger #-}

-- | Division-related operations for 'CDouble'.
--
-- Interval reciprocal goes through the assembly primitive
-- ('fastIntervalRecip'). The remaining operations delegate either to
-- "FFIWrapper.Double" or to the 'C.CDouble' instance via 'coerce'.
instance RoundedFractional CDouble where
  roundedDiv = coerce D.roundedDiv
  intervalDiv x x' y y' = (coerce D.intervalDiv_down x x' y y', coerce D.intervalDiv_up x x' y y')
  -- The lower bound needs only z; the upper bound needs only z'.
  intervalDivAdd x x' y y' z z' = (coerce D.intervalDivAdd_down x x' y y' z, coerce D.intervalDivAdd_up x x' y y' z')
  intervalRecip x x' = coerce fastIntervalRecip x x'
  roundedFromRational = coerce (roundedFromRational :: RoundingMode -> Rational -> C.CDouble)
  roundedFromRealFloat r x = coerce (roundedFromRealFloat r x :: C.CDouble)
  intervalFromRational = coerce (intervalFromRational :: Rational -> (Rounded 'TowardNegInf C.CDouble, Rounded 'TowardInf C.CDouble))
  {-# INLINE roundedDiv #-}
  {-# INLINE intervalDiv #-}
  {-# INLINE intervalDivAdd #-}
  {-# INLINE intervalRecip #-}
  {-# INLINE roundedFromRational #-}
  {-# INLINE roundedFromRealFloat #-}
  {-# INLINE intervalFromRational #-}

-- | Square root for 'CDouble'. It delegates to 'D.roundedSqrt', and 'coerce'
-- adapts that function's signature to 'CDouble'.
instance RoundedSqrt CDouble where
  roundedSqrt mode x = coerce D.roundedSqrt mode x
  {-# INLINE roundedSqrt #-}

-- | Reinterpret a storable vector of 'CDouble' as a storable vector of
-- 'C.CDouble', so the C backend's vector operations can be reused.
--
-- The whole vector is converted with 'unsafeCoerce' instead of 'coerce'.
-- NOTE(review): presumably this is because 'VS.Vector' gives its element type
-- a nominal role. The cast relies on 'CDouble' and 'C.CDouble' sharing the
-- same runtime and 'Storable' representation; confirm this for 'C.CDouble'.
toCVectorS :: VS.Vector CDouble -> VS.Vector C.CDouble
toCVectorS = unsafeCoerce
{-# INLINE toCVectorS #-}

-- | Inverse of 'toCVectorS'. The same representation assumption applies.
fromCVectorS :: VS.Vector C.CDouble -> VS.Vector CDouble
fromCVectorS = unsafeCoerce
{-# INLINE fromCVectorS #-}

-- | Element-wise ring operations on storable vectors. Each one runs on the
-- 'C.CDouble' instance and casts the vectors back.
instance RoundedRing_Vector VS.Vector CDouble where
  roundedSum mode xs = coerce (roundedSum mode (toCVectorS xs))
  zipWith_roundedAdd mode xs ys =
    fromCVectorS (zipWith_roundedAdd mode (toCVectorS xs) (toCVectorS ys))
  zipWith_roundedSub mode xs ys =
    fromCVectorS (zipWith_roundedSub mode (toCVectorS xs) (toCVectorS ys))
  zipWith_roundedMul mode xs ys =
    fromCVectorS (zipWith_roundedMul mode (toCVectorS xs) (toCVectorS ys))
  zipWith3_roundedFusedMultiplyAdd mode xs ys zs =
    fromCVectorS (zipWith3_roundedFusedMultiplyAdd mode (toCVectorS xs) (toCVectorS ys) (toCVectorS zs))
  {-# INLINE roundedSum #-}
  {-# INLINE zipWith_roundedAdd #-}
  {-# INLINE zipWith_roundedSub #-}
  {-# INLINE zipWith_roundedMul #-}
  {-# INLINE zipWith3_roundedFusedMultiplyAdd #-}

-- | Element-wise division on storable vectors, via the 'C.CDouble' instance.
instance RoundedFractional_Vector VS.Vector CDouble where
  zipWith_roundedDiv mode xs ys =
    fromCVectorS (zipWith_roundedDiv mode (toCVectorS xs) (toCVectorS ys))
  {-# INLINE zipWith_roundedDiv #-}

-- | Element-wise square root on storable vectors, via the 'C.CDouble' instance.
instance RoundedSqrt_Vector VS.Vector CDouble where
  map_roundedSqrt mode xs = fromCVectorS (map_roundedSqrt mode (toCVectorS xs))
  {-# INLINE map_roundedSqrt #-}

-- Unboxed-vector operations for 'CDouble' are borrowed from the 'C.CDouble'
-- instances through DerivingVia.
-- NOTE(review): this requires the method types to be 'Coercible' between the
-- two element types. For the 'VU.Vector' data family that presumably depends
-- on the newtype data instances ('V_CDouble'/'MV_CDouble' further down, and
-- their counterparts in the C backend) being in scope. Confirm before
-- changing either side's representation.
deriving via C.CDouble instance RoundedRing_Vector VU.Vector CDouble
deriving via C.CDouble instance RoundedFractional_Vector VU.Vector CDouble
deriving via C.CDouble instance RoundedSqrt_Vector VU.Vector CDouble

--
-- FFI
--

-- | Interval addition implemented in assembly (symbol
-- @rounded_hw_interval_add@).
--
-- This is a @foreign import prim@, so arguments and results are passed in the
-- registers noted next to each type rather than through the C calling
-- convention. Parameter order and count must therefore match the assembly
-- exactly. It backs 'intervalAdd' in the 'RoundedRing' instance, so the lower
-- result is presumably rounded toward negative infinity and the upper toward
-- positive infinity.
foreign import prim "rounded_hw_interval_add"
  fastIntervalAdd# :: Double# -- lower 1, %xmm1
                   -> Double# -- upper 1, %xmm2
                   -> Double# -- lower 2, %xmm3
                   -> Double# -- upper 2, %xmm4
                   -> (# Double#  -- lower, %xmm1
                       , Double#  -- upper, %xmm2
                       #)

-- | Interval subtraction implemented in assembly (symbol
-- @rounded_hw_interval_sub@).
--
-- Same register convention as 'fastIntervalAdd#': the first two arguments
-- are the bounds of the first interval, and the next two are the bounds of
-- the second. It backs 'intervalSub' in the 'RoundedRing' instance.
-- NOTE(review): presumably this computes interval 1 minus interval 2;
-- confirm against the assembly.
foreign import prim "rounded_hw_interval_sub"
  fastIntervalSub# :: Double# -- lower 1, %xmm1
                   -> Double# -- upper 1, %xmm2
                   -> Double# -- lower 2, %xmm3
                   -> Double# -- upper 2, %xmm4
                   -> (# Double#  -- lower, %xmm1
                       , Double#  -- upper, %xmm2
                       #)

-- | Interval reciprocal implemented in assembly (symbol
-- @rounded_hw_interval_recip@). It takes one interval's bounds in the
-- registers noted beside each type and returns the result bounds.
-- It backs 'intervalRecip' in the 'RoundedFractional' instance.
foreign import prim "rounded_hw_interval_recip"
  fastIntervalRecip# :: Double# -- lower 1, %xmm1
                     -> Double# -- upper 1, %xmm2
                     -> (# Double#  -- lower, %xmm1
                         , Double#  -- upper, %xmm2
                         #)

-- | Interval square root implemented in assembly (symbol
-- @rounded_hw_interval_sqrt@), using the same register convention as
-- 'fastIntervalRecip#'.
-- NOTE(review): its wrapper 'fastIntervalSqrt' is not exported and is not
-- used in the visible code.
foreign import prim "rounded_hw_interval_sqrt"
  fastIntervalSqrt# :: Double# -- lower 1, %xmm1
                    -> Double# -- upper 1, %xmm2
                    -> (# Double#  -- lower, %xmm1
                        , Double#  -- upper, %xmm2
                        #)

-- | The unboxed payload types of 'Int64' and 'Word64', i.e. the field types of
-- the 'I64#' and 'W64#' constructors. They must match exactly so that
-- @I64# x@ can pass @x@ straight to a prim import.
--
-- Before GHC 9.4, 64-bit platforms represented 'Int64'/'Word64' with
-- 'Int#'/'Word#'. From GHC 9.4 on, 'Int64#'/'Word64#' are used on every
-- platform, as 32-bit platforms always did.
-- NOTE(review): on 64-bit targets an 'Int64#' argument is assumed to arrive in
-- the same register as an 'Int#'. Confirm against
-- @rounded_hw_interval_from_int64@.
#if WORD_SIZE_IN_BITS >= 64 && __GLASGOW_HASKELL__ < 904
type INT64# = Int#
type WORD64# = Word#
#else
type INT64# = Int64#
type WORD64# = Word64#
#endif

-- | Conversion of a 64-bit integer to a pair of 'Double#' bounds, implemented
-- in assembly (symbol @rounded_hw_interval_from_int64@).
-- NOTE(review): the register holding the integer argument is not documented
-- here. Its wrapper 'fastIntervalFromInt64' is not exported and is not used
-- in the visible code.
foreign import prim "rounded_hw_interval_from_int64"
  fastIntervalFromInt64# :: INT64# -- value
                         -> (# Double# -- lower, %xmm1
                             , Double# -- upper, %xmm2
                             #)

{-
foreign import prim "rounded_hw_interval_from_word64"
  fastIntervalFromWord64# :: WORD64# -- value
                          -> (# Double# -- lower, %xmm1
                              , Double# -- upper, %xmm2
                              #)
-}

-- | Boxed wrapper around 'fastIntervalAdd#'.
--
-- Arguments are @lower1 upper1 lower2 upper2@, and the result is the
-- @(lower, upper)@ pair computed by the assembly routine.
fastIntervalAdd :: Double -> Double -> Double -> Double -> (Double, Double)
fastIntervalAdd (D# lo1) (D# hi1) (D# lo2) (D# hi2) =
  let (# lo, hi #) = fastIntervalAdd# lo1 hi1 lo2 hi2
  in (D# lo, D# hi)
{-# INLINE fastIntervalAdd #-}

-- | Boxed wrapper around 'fastIntervalSub#'.
--
-- Arguments are @lower1 upper1 lower2 upper2@, and the result is the
-- @(lower, upper)@ pair computed by the assembly routine.
fastIntervalSub :: Double -> Double -> Double -> Double -> (Double, Double)
fastIntervalSub (D# lo1) (D# hi1) (D# lo2) (D# hi2) =
  let (# lo, hi #) = fastIntervalSub# lo1 hi1 lo2 hi2
  in (D# lo, D# hi)
{-# INLINE fastIntervalSub #-}

-- | Boxed wrapper around 'fastIntervalRecip#'. It takes an interval's
-- @lower upper@ bounds and returns the @(lower, upper)@ result.
fastIntervalRecip :: Double -> Double -> (Double, Double)
fastIntervalRecip (D# lo1) (D# hi1) =
  let (# lo, hi #) = fastIntervalRecip# lo1 hi1
  in (D# lo, D# hi)
{-# INLINE fastIntervalRecip #-}

-- | Boxed wrapper around 'fastIntervalSqrt#'. It takes an interval's
-- @lower upper@ bounds and returns the @(lower, upper)@ result.
-- NOTE(review): not exported and not used in the visible code.
fastIntervalSqrt :: Double -> Double -> (Double, Double)
fastIntervalSqrt (D# lo1) (D# hi1) =
  let (# lo, hi #) = fastIntervalSqrt# lo1 hi1
  in (D# lo, D# hi)
{-# INLINE fastIntervalSqrt #-}

-- | Boxed wrapper around 'fastIntervalFromInt64#'. It returns the
-- @(lower, upper)@ 'Double' bounds produced for the given integer.
-- NOTE(review): not exported and not used in the visible code.
fastIntervalFromInt64 :: Int64 -> (Double, Double)
fastIntervalFromInt64 (I64# n) =
  let (# lo, hi #) = fastIntervalFromInt64# n
  in (D# lo, D# hi)
{-# INLINE fastIntervalFromInt64 #-}

{-
fastIntervalFromWord64 :: Word64 -> (Double, Double)
fastIntervalFromWord64 (W64# x) = case fastIntervalFromWord64# x of
  (# l, h #) -> (D# l, D# h)
{-# INLINE fastIntervalFromWord64 #-}
-}

--
-- Backend name
--

-- | Address of the C symbol @rounded_hw_interval_backend_name@, as a
-- 'CString'.
-- NOTE(review): the @&@ import yields the symbol's own address. This assumes
-- the symbol is a NUL-terminated @char@ array, not a @char *@ variable.
-- Confirm against the C/assembly source.
foreign import ccall "&rounded_hw_interval_backend_name"
  c_interval_backend_name :: CString

-- | Name of the assembly interval backend, decoded once into a 'String'.
--
-- 'unsafePerformIO' is acceptable here only because the C string is assumed
-- to be immutable static data, so reading it has no observable effect.
-- The NOINLINE pragma keeps this a single shared top-level value, which is
-- the standard precaution for top-level 'unsafePerformIO' bindings. Without
-- it, GHC may inline the read and repeat it at every use site.
intervalBackendName :: String
intervalBackendName = unsafePerformIO (peekCString c_interval_backend_name)
{-# NOINLINE intervalBackendName #-}

--
-- instance for Data.Vector.Unboxed.Unbox
--

-- Unboxed (mutable and immutable) vectors of 'CDouble' are represented as
-- unboxed vectors of 'Double'. Both constructors appear in this module's
-- export list.
newtype instance VUM.MVector s CDouble = MV_CDouble (VUM.MVector s Double)
newtype instance VU.Vector CDouble = V_CDouble (VU.Vector Double)

-- | Mutable unboxed storage for 'CDouble', backed by a mutable unboxed vector
-- of 'Double'. Every method unwraps 'MV_CDouble' and delegates to the
-- 'Double' instance, with 'coerce' converting individual elements.
--
-- All methods are INLINE, matching the vector package's own unboxed
-- instances, so this thin delegation layer does not block inlining and
-- specialisation at use sites.
instance VGM.MVector VUM.MVector CDouble where
  basicLength (MV_CDouble mv) = VGM.basicLength mv
  basicUnsafeSlice i l (MV_CDouble mv) = MV_CDouble (VGM.basicUnsafeSlice i l mv)
  basicOverlaps (MV_CDouble mv) (MV_CDouble mv') = VGM.basicOverlaps mv mv'
  basicUnsafeNew l = MV_CDouble <$> VGM.basicUnsafeNew l
  basicInitialize (MV_CDouble mv) = VGM.basicInitialize mv
  basicUnsafeReplicate i x = MV_CDouble <$> VGM.basicUnsafeReplicate i (coerce x)
  basicUnsafeRead (MV_CDouble mv) i = coerce <$> VGM.basicUnsafeRead mv i
  basicUnsafeWrite (MV_CDouble mv) i x = VGM.basicUnsafeWrite mv i (coerce x)
  basicClear (MV_CDouble mv) = VGM.basicClear mv
  basicSet (MV_CDouble mv) x = VGM.basicSet mv (coerce x)
  basicUnsafeCopy (MV_CDouble mv) (MV_CDouble mv') = VGM.basicUnsafeCopy mv mv'
  basicUnsafeMove (MV_CDouble mv) (MV_CDouble mv') = VGM.basicUnsafeMove mv mv'
  basicUnsafeGrow (MV_CDouble mv) n = MV_CDouble <$> VGM.basicUnsafeGrow mv n
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicOverlaps #-}
  {-# INLINE basicUnsafeNew #-}
  {-# INLINE basicInitialize #-}
  {-# INLINE basicUnsafeReplicate #-}
  {-# INLINE basicUnsafeRead #-}
  {-# INLINE basicUnsafeWrite #-}
  {-# INLINE basicClear #-}
  {-# INLINE basicSet #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE basicUnsafeMove #-}
  {-# INLINE basicUnsafeGrow #-}

-- | Immutable unboxed vectors of 'CDouble', backed by an unboxed vector of
-- 'Double'. Methods unwrap 'V_CDouble'/'MV_CDouble' and delegate, with
-- 'coerce' converting individual elements.
--
-- All methods are INLINE, matching the vector package's own unboxed
-- instances, so indexing and freezing stay as cheap as for 'Double'.
instance VG.Vector VU.Vector CDouble where
  basicUnsafeFreeze (MV_CDouble mv) = V_CDouble <$> VG.basicUnsafeFreeze mv
  basicUnsafeThaw (V_CDouble v) = MV_CDouble <$> VG.basicUnsafeThaw v
  basicLength (V_CDouble v) = VG.basicLength v
  basicUnsafeSlice i l (V_CDouble v) = V_CDouble (VG.basicUnsafeSlice i l v)
  basicUnsafeIndexM (V_CDouble v) i = coerce <$> VG.basicUnsafeIndexM v i
  basicUnsafeCopy (MV_CDouble mv) (V_CDouble v) = VG.basicUnsafeCopy mv v
  elemseq (V_CDouble v) x y = VG.elemseq v (coerce x) y
  {-# INLINE basicUnsafeFreeze #-}
  {-# INLINE basicUnsafeThaw #-}
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicUnsafeIndexM #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE elemseq #-}

-- 'CDouble' is unboxable through the 'VGM.MVector' and 'VG.Vector' instances
-- above, which store elements as 'Double'.
instance VU.Unbox CDouble