{-# LANGUAGE ConstrainedClassMethods #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -Wno-orphans #-} module Numeric.Rounded.Hardware.Internal.Class ( module Numeric.Rounded.Hardware.Internal.Class , module Numeric.Rounded.Hardware.Internal.Rounding ) where import Data.Coerce import Data.Proxy import Data.Ratio import Data.Tagged import qualified Data.Vector.Generic as VG import Numeric.Rounded.Hardware.Internal.Rounding import Prelude hiding (fromInteger, fromRational, recip, sqrt, (*), (+), (-), (/)) import qualified Prelude -- | Rounding-controlled version of 'Num'. class Ord a => RoundedRing a where roundedAdd :: RoundingMode -> a -> a -> a roundedSub :: RoundingMode -> a -> a -> a roundedMul :: RoundingMode -> a -> a -> a roundedFusedMultiplyAdd :: RoundingMode -> a -> a -> a -> a roundedFromInteger :: RoundingMode -> Integer -> a -- roundedToFloat :: RoundingMode -> a -> Float -- roundedToDouble :: RoundingMode -> a -> Double -- | -- prop> \x_lo x_hi y_lo y_hi -> intervalAdd (Rounded x_lo) (Rounded x_hi) (Rounded y_lo) (Rounded y_hi) == (Rounded (roundedAdd TowardNegInf x_lo y_lo), Rounded (roundedAdd TowardInf x_hi y_hi)) intervalAdd :: Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalAdd x_lo x_hi y_lo y_hi = (x_lo + y_lo, x_hi + y_hi) where (+) :: forall r. Rounding r => Rounded r a -> Rounded r a -> Rounded r a Rounded x + Rounded y = Rounded (roundedAdd (rounding (Proxy :: Proxy r)) x y) -- | -- prop> \x_lo x_hi y_lo y_hi -> intervalSub (Rounded x_lo) (Rounded x_hi) (Rounded y_lo) (Rounded y_hi) == (Rounded (roundedSub TowardNegInf x_lo y_hi), Rounded (roundedSub TowardInf x_hi y_lo)) intervalSub :: Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalSub x_lo x_hi y_lo y_hi = (x_lo - coerce y_hi, x_hi - coerce y_lo) where (-) :: forall r. Rounding r => Rounded r a -> Rounded r a -> Rounded r a Rounded x - Rounded y = Rounded (roundedSub (rounding (Proxy :: Proxy r)) x y) intervalMul :: Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalMul x_lo x_hi y_lo y_hi = ( minimum [ x_lo * y_lo , x_lo * coerce y_hi , coerce x_hi * y_lo , coerce x_hi * coerce y_hi ] , maximum [ coerce x_lo * coerce y_lo , coerce x_lo * y_hi , x_hi * coerce y_lo , x_hi * y_hi ] ) where (*) :: forall r. Rounding r => Rounded r a -> Rounded r a -> Rounded r a Rounded x * Rounded y = Rounded (roundedMul (rounding (Proxy :: Proxy r)) x y) intervalMulAdd :: Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalMulAdd x_lo x_hi y_lo y_hi z_lo z_hi = case intervalMul x_lo x_hi y_lo y_hi of (xy_lo, xy_hi) -> intervalAdd xy_lo xy_hi z_lo z_hi intervalFromInteger :: Integer -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalFromInteger x = (fromInteger x, fromInteger x) where fromInteger :: forall r. Rounding r => Integer -> Rounded r a fromInteger y = Rounded (roundedFromInteger (rounding (Proxy :: Proxy r)) y) {-# INLINE intervalAdd #-} {-# INLINE intervalSub #-} {-# INLINE intervalMul #-} {-# INLINE intervalFromInteger #-} backendNameT :: Tagged a String -- | Returns the name of backend as a string. -- -- Example: -- -- @ -- >>> :m + Data.Proxy -- >>> 'backendName' (Proxy :: Proxy Double) -- "FastFFI+SSE2" -- @ backendName :: RoundedRing a => proxy a -> String backendName = Data.Tagged.proxy backendNameT {-# INLINE backendName #-} -- | Rounding-controlled version of 'Fractional'. class RoundedRing a => RoundedFractional a where roundedDiv :: RoundingMode -> a -> a -> a roundedRecip :: RoundingMode -> a -> a default roundedRecip :: Num a => RoundingMode -> a -> a roundedRecip r = roundedDiv r 1 roundedFromRational :: RoundingMode -> Rational -> a roundedFromRealFloat :: RealFloat b => RoundingMode -> b -> a default roundedFromRealFloat :: (Fractional a, RealFloat b) => RoundingMode -> b -> a roundedFromRealFloat r x | isNaN x = 0 Prelude./ 0 | isInfinite x = if x > 0 then 1 Prelude./ 0 else -1 Prelude./ 0 | isNegativeZero x = -0 | otherwise = roundedFromRational r (toRational x) intervalDiv :: Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalDiv x_lo x_hi y_lo y_hi = ( minimum [ x_lo / y_lo , x_lo / coerce y_hi , coerce x_hi / y_lo , coerce x_hi / coerce y_hi ] , maximum [ coerce x_lo / coerce y_lo , coerce x_lo / y_hi , x_hi / coerce y_lo , x_hi / y_hi ] ) where (/) :: forall r. Rounding r => Rounded r a -> Rounded r a -> Rounded r a Rounded x / Rounded y = Rounded (roundedDiv (rounding (Proxy :: Proxy r)) x y) intervalDivAdd :: Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalDivAdd x_lo x_hi y_lo y_hi z_lo z_hi = case intervalDiv x_lo x_hi y_lo y_hi of (xy_lo, xy_hi) -> intervalAdd xy_lo xy_hi z_lo z_hi intervalRecip :: Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalRecip x_lo x_hi = (recip (coerce x_hi), recip (coerce x_lo)) where recip :: forall r. Rounding r => Rounded r a -> Rounded r a recip (Rounded x) = Rounded (roundedRecip (rounding (Proxy :: Proxy r)) x) intervalFromRational :: Rational -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalFromRational x = (fromRational x, fromRational x) where fromRational :: forall r. Rounding r => Rational -> Rounded r a fromRational y = Rounded (roundedFromRational (rounding (Proxy :: Proxy r)) y) {-# INLINE intervalDiv #-} {-# INLINE intervalRecip #-} {-# INLINE intervalFromRational #-} -- | Rounding-controlled version of 'sqrt'. class RoundedRing a => RoundedSqrt a where roundedSqrt :: RoundingMode -> a -> a intervalSqrt :: Rounded 'TowardNegInf a -> Rounded 'TowardInf a -> (Rounded 'TowardNegInf a, Rounded 'TowardInf a) intervalSqrt x y = (sqrt x, sqrt y) where sqrt :: forall r. Rounding r => Rounded r a -> Rounded r a sqrt (Rounded z) = Rounded (roundedSqrt (rounding (Proxy :: Proxy r)) z) {-# INLINE intervalSqrt #-} -- | Lifted version of 'RoundedRing' class RoundedRing a => RoundedRing_Vector vector a where -- | Equivalent to @\\r -> foldl ('roundedAdd' r) 0@ roundedSum :: RoundingMode -> vector a -> a -- | Equivalent to @zipWith . 'roundedAdd'@ zipWith_roundedAdd :: RoundingMode -> vector a -> vector a -> vector a -- | Equivalent to @zipWith . 'roundedSub'@ zipWith_roundedSub :: RoundingMode -> vector a -> vector a -> vector a -- | Equivalent to @zipWith . 'roundedMul'@ zipWith_roundedMul :: RoundingMode -> vector a -> vector a -> vector a -- | Equivalent to @zipWith3 . 'roundedFusedMultiplyAdd'@ zipWith3_roundedFusedMultiplyAdd :: RoundingMode -> vector a -> vector a -> vector a -> vector a default roundedSum :: (VG.Vector vector a, Num a) => RoundingMode -> vector a -> a roundedSum mode = VG.foldl' (roundedAdd mode) 0 default zipWith_roundedAdd :: (VG.Vector vector a) => RoundingMode -> vector a -> vector a -> vector a zipWith_roundedAdd mode = VG.zipWith (roundedAdd mode) default zipWith_roundedSub :: (VG.Vector vector a) => RoundingMode -> vector a -> vector a -> vector a zipWith_roundedSub mode = VG.zipWith (roundedSub mode) default zipWith_roundedMul :: (VG.Vector vector a) => RoundingMode -> vector a -> vector a -> vector a zipWith_roundedMul mode = VG.zipWith (roundedMul mode) default zipWith3_roundedFusedMultiplyAdd :: (VG.Vector vector a) => RoundingMode -> vector a -> vector a -> vector a -> vector a zipWith3_roundedFusedMultiplyAdd mode = VG.zipWith3 (roundedFusedMultiplyAdd mode) -- | Lifted version of 'RoundedFractional' class (RoundedFractional a, RoundedRing_Vector vector a) => RoundedFractional_Vector vector a where -- | Equivalent to @zipWith . 'roundedDiv'@ zipWith_roundedDiv :: RoundingMode -> vector a -> vector a -> vector a -- map_roundedRecip :: RoundingMode -> vector a -> vector a default zipWith_roundedDiv :: (VG.Vector vector a) => RoundingMode -> vector a -> vector a -> vector a zipWith_roundedDiv mode = VG.zipWith (roundedDiv mode) -- | Lifted version of 'RoundedSqrt' class (RoundedSqrt a, RoundedRing_Vector vector a) => RoundedSqrt_Vector vector a where -- | Equivalent to @map . 'roundedSqrt'@ map_roundedSqrt :: RoundingMode -> vector a -> vector a default map_roundedSqrt :: (VG.Vector vector a) => RoundingMode -> vector a -> vector a map_roundedSqrt mode = VG.map (roundedSqrt mode) instance (Rounding r, Num a, RoundedRing a) => Num (Rounded r a) where Rounded x + Rounded y = Rounded (roundedAdd (rounding (Proxy :: Proxy r)) x y) Rounded x - Rounded y = Rounded (roundedSub (rounding (Proxy :: Proxy r)) x y) Rounded x * Rounded y = Rounded (roundedMul (rounding (Proxy :: Proxy r)) x y) negate = coerce (negate :: a -> a) abs = coerce (abs :: a -> a) signum = coerce (signum :: a -> a) fromInteger x = Rounded (roundedFromInteger (rounding (Proxy :: Proxy r)) x) {-# INLINE (+) #-} {-# INLINE (-) #-} {-# INLINE (*) #-} {-# INLINE negate #-} {-# INLINE abs #-} {-# INLINE signum #-} {-# INLINE fromInteger #-} instance (Rounding r, Num a, RoundedFractional a) => Fractional (Rounded r a) where Rounded x / Rounded y = Rounded (roundedDiv (rounding (Proxy :: Proxy r)) x y) recip (Rounded x) = Rounded (roundedRecip (rounding (Proxy :: Proxy r)) x) fromRational x = Rounded (roundedFromRational (rounding (Proxy :: Proxy r)) x) {-# INLINE (/) #-} {-# INLINE recip #-} {-# INLINE fromRational #-} deriving newtype instance (Rounding r, Real a, RoundedFractional a) => Real (Rounded r a) deriving newtype instance (Rounding r, RealFrac a, RoundedFractional a) => RealFrac (Rounded r a) -- no instance for Floating/RealFloat currently... -- These instances are provided in Numeric.Rounded.Hardware.Backend.Default: -- instance RoundedRing Float -- instance RoundedFractional Float -- instance RoundedSqrt Float -- instance RoundedRing Double -- instance RoundedFractional Double -- instance RoundedSqrt Double instance RoundedRing Integer where roundedAdd _ = (Prelude.+) roundedSub _ = (Prelude.-) roundedMul _ = (Prelude.*) roundedFusedMultiplyAdd _ x y z = x Prelude.* y Prelude.+ z roundedFromInteger _ = id backendNameT = Tagged "Integer" instance RoundedFractional Integer where roundedDiv r x y = roundedFromRational r (x % y) roundedFromRational ToNearest = round roundedFromRational TowardNegInf = floor roundedFromRational TowardInf = ceiling roundedFromRational TowardZero = truncate roundedFromRealFloat r x | isNaN x = error "NaN" | isInfinite x = error "Infinity" | otherwise = roundedFromRational r (toRational x) -- TODO: instance RoundedSqrt Integer instance Integral a => RoundedRing (Ratio a) where roundedAdd _ = (Prelude.+) roundedSub _ = (Prelude.-) roundedMul _ = (Prelude.*) roundedFusedMultiplyAdd _ x y z = x Prelude.* y Prelude.+ z roundedFromInteger _ = Prelude.fromInteger backendNameT = Tagged "Rational" instance Integral a => RoundedFractional (Ratio a) where roundedDiv _ = (Prelude./) roundedRecip _ = Prelude.recip roundedFromRational _ = Prelude.fromRational roundedFromRealFloat _ x | isNaN x = error "NaN" | isInfinite x = error "Infinity" | otherwise = Prelude.fromRational (toRational x) -- There is no RoundedSqrt (Ratio a)