module Data.Floating.Environment (
module Control.Applicative,
RoundingMode(..), FloatException(..), FEnvState, FEnv,
fenvEval, withRoundingMode, raiseExceptions, fenvTrace,
unsafeSaveEnvironment, unsafeRestoreEnvironment,
unsafeRaiseExceptions,
unsafeSetRoundingMode, getRoundingMode
) where
#include <config.h>
import Prelude hiding (Float, Double, Floating(..), RealFloat(..))
import Data.Floating.Classes
import Control.Exception
import Control.Applicative
import Control.Monad
import System.IO.Unsafe
import Debug.Trace
import Foreign.C
import Foreign
-- Raw C entry points. All are imported "unsafe", so none of them may call
-- back into Haskell. NOTE(review): set_roundmode, get_roundmode,
-- fenv_restore and fenv_raise_excepts are not standard C functions and are
-- presumably provided by this package's own C sources -- confirm.

-- Sets the rounding mode from an integer code. Callers pass
-- fromEnum of a RoundingMode and treat any nonzero result as failure.
foreign import ccall unsafe "set_roundmode"
    set_roundmode :: CInt -> IO CInt
-- Returns the current rounding mode as an integer code. Callers treat a
-- negative result as failure and otherwise decode it with toEnum.
foreign import ccall unsafe "get_roundmode"
    get_roundmode :: IO CInt
-- C99 fegetenv: stores the current floating point environment through
-- the pointer. Callers treat any nonzero result as failure.
foreign import ccall unsafe "fegetenv"
    c_fegetenv :: Ptr FEnvState -> IO CInt
-- C99 feholdexcept: stores the current environment like fegetenv, then
-- clears the exception flags and installs non-stop mode. Callers treat
-- any nonzero result as failure.
foreign import ccall unsafe "feholdexcept"
    c_feholdexcept :: Ptr FEnvState -> IO CInt
-- Restores a saved environment. The second pointer receives a bitmask
-- that callers decode as raised exceptions (bit i = the FloatException
-- whose fromEnum is i). Nonzero result = failure.
foreign import ccall unsafe "fenv_restore"
    fenv_restore :: Ptr FEnvState -> Ptr CUInt -> IO CInt
-- Raises the exceptions in a bitmask built the same way (bit i = the
-- FloatException whose fromEnum is i). Nonzero result = failure.
foreign import ccall unsafe "fenv_raise_excepts"
    fenv_raise_excepts :: CUInt -> IO CInt
-- | Floating point rounding directions.
--
-- The constructor order is significant: 'fromEnum' of a mode is the integer
-- code exchanged with the C helpers (see 'unsafeSetRoundingMode' and
-- 'getRoundingMode'), so the constructors must not be reordered.
--
-- 'Eq' and 'Ord' are derived so callers can compare modes, for example
-- against the result of 'getRoundingMode'.
data RoundingMode = ToNearest | Upward | Downward | TowardZero
    deriving (Show, Read, Eq, Ord, Enum, Bounded)
-- | Floating point exception flags.
--
-- The constructor order is significant: an exception whose 'fromEnum' is
-- @i@ corresponds to bit @i@ of the masks exchanged with the C helpers
-- (see 'unsafeRaiseExceptions' and 'unsafeRestoreEnvironment'), so the
-- constructors must not be reordered.
--
-- 'Eq' and 'Ord' are derived so callers can inspect the list returned by
-- 'fenvEval', for example with 'elem'.
data FloatException = DivByZero | Inexact | Invalid | Overflow | Underflow
    deriving (Show, Read, Eq, Ord, Enum, Bounded)
-- | A saved floating point environment: a heap copy of a C @fenv_t@,
-- managed by a 'ForeignPtr' so it is freed by the garbage collector.
newtype FEnvState = FEnvState (ForeignPtr FEnvState)

-- | Marshals an 'FEnvState' by copying the raw @fenv_t@ bytes.
--
-- 'peek' copies the bytes at the pointer into a freshly allocated buffer
-- owned by the result; 'poke' copies the buffer back out.
instance Storable FEnvState where
    sizeOf _    = SIZEOF_FENV_T
    alignment _ = ALIGNOF_FENV_T
    peek src = do
        -- mallocForeignPtr uses this instance's sizeOf and alignment, so
        -- the copy is correctly aligned for C code that later receives it
        -- as an fenv_t pointer (mallocForeignPtrBytes ignores alignment).
        fp <- mallocForeignPtr
        withForeignPtr fp $ \dst -> copyBytes dst src SIZEOF_FENV_T
        return (FEnvState fp)
    poke dst (FEnvState fp) =
        withForeignPtr fp $ \src -> copyBytes dst src SIZEOF_FENV_T
-- | A suspended computation producing an @a@.
--
-- The argument type is hidden (existentially quantified), so combinators
-- can pair up arguments and compose functions without running anything.
-- The computation is only run when the function is applied to its
-- argument, e.g. by 'fenvEval'.
data FEnv a
    = forall s . FEnv
        (s -> a) -- ^ function that produces the result
        !s       -- ^ its argument, forced to WHNF on construction
-- | Post-composes a function onto the suspended computation; the stored
-- argument is left untouched.
instance Functor FEnv where
    fmap h (FEnv k s) = FEnv (\t -> h (k t)) s
-- | 'pure' suspends the identity function on a value. '<*>' pairs the two
-- stored arguments and applies both functions only when the result is
-- finally demanded.
instance Applicative FEnv where
    pure v = FEnv id v
    FEnv kf sf <*> FEnv kv sv = FEnv (\(a, b) -> kf a (kv b)) (sf, sv)
-- | Placeholder instance: suspended computations cannot be compared.
--
-- '==' is bound directly to 'error', so any use of it (and of '/=', whose
-- default is defined via '==') raises an error at runtime.
-- NOTE(review): presumably exists only to satisfy an 'Eq' superclass that
-- older versions of 'Num' required -- confirm whether it is still needed.
instance Eq a => Eq (FEnv a) where
    (==) = error "The Eq instance for FEnv is a lie."
-- | Opaque rendering: a suspended computation is shown as a fixed
-- placeholder without running it.
instance Show a => Show (FEnv a) where
    show _ = "<<FEnv>>"
-- | Pointwise lifting of 'Num' into suspended computations. Every
-- operation just builds a larger 'FEnv'; nothing is evaluated until the
-- computation is run (e.g. by 'fenvEval').
instance Num a => Num (FEnv a) where
    (+)         = liftA2 (+)
    -- Restored: this line was a garbled "() = liftA2 ()", which does not
    -- compile and left subtraction undefined.
    (-)         = liftA2 (-)
    (*)         = liftA2 (*)
    negate      = liftA negate
    signum      = liftA signum
    abs         = liftA abs
    fromInteger = pure . fromInteger
-- | Pointwise lifting of 'Fractional'; operations are suspended like those
-- of the 'Num' instance.
instance Fractional a => Fractional (FEnv a) where
    x / y          = (/) <$> x <*> y
    recip x        = liftA recip x
    fromRational r = pure (fromRational r)
-- | Pointwise lifting of this package's 'Floating' class (not the Prelude
-- one, which is hidden). Each method suspends the underlying operation.
instance Floating a => Floating (FEnv a) where
    -- binary
    x ** y  = liftA2 (**) x y
    -- roots
    sqrt x  = liftA sqrt x
    -- inverse trigonometric
    acos x  = liftA acos x
    asin x  = liftA asin x
    atan x  = liftA atan x
    -- trigonometric
    cos x   = liftA cos x
    sin x   = liftA sin x
    tan x   = liftA tan x
    -- hyperbolic
    cosh x  = liftA cosh x
    sinh x  = liftA sinh x
    tanh x  = liftA tanh x
    -- exponential / logarithm
    exp x   = liftA exp x
    log x   = liftA log x
    -- inverse hyperbolic
    acosh x = liftA acosh x
    asinh x = liftA asinh x
    atanh x = liftA atanh x
-- | Pointwise lifting of this package's 'RealFloat' class (not the Prelude
-- one, which is hidden). Methods are grouped by arity; constants are
-- wrapped with 'pure'.
instance RealFloat a => RealFloat (FEnv a) where
    -- ternary
    fma x y z     = liftA3 fma x y z
    -- binary
    copysign x y  = liftA2 copysign x y
    nextafter x y = liftA2 nextafter x y
    fmod x y      = liftA2 fmod x y
    frem x y      = liftA2 frem x y
    atan2 x y     = liftA2 atan2 x y
    hypot x y     = liftA2 hypot x y
    -- unary
    cbrt x        = liftA cbrt x
    exp2 x        = liftA exp2 x
    expm1 x       = liftA expm1 x
    log10 x       = liftA log10 x
    log1p x       = liftA log1p x
    log2 x        = liftA log2 x
    logb x        = liftA logb x
    erf x         = liftA erf x
    erfc x        = liftA erfc x
    gamma x       = liftA gamma x
    lgamma x      = liftA lgamma x
    nearbyint x   = liftA nearbyint x
    rint x        = liftA rint x
    -- constants
    infinity      = pure infinity
    nan           = pure nan
    pi            = pure pi
-- | Captures the current floating point environment.
--
-- When @reset@ is 'True' the capture uses C @feholdexcept@, which also
-- clears the exception flags and installs non-stop mode. Otherwise it uses
-- @fegetenv@, which leaves the environment unchanged. The environment is
-- written into a temporary buffer and then copied out with 'peek'.
--
-- Fails with an 'IOError' if the C call reports a nonzero status.
unsafeSaveEnvironment :: Bool -> IO FEnvState
unsafeSaveEnvironment reset = alloca go
  where
    go buf = do
        status <- capture buf
        when (status /= 0) $ fail "Error saving floating point environment."
        peek buf
    capture
        | reset     = c_feholdexcept
        | otherwise = c_fegetenv
-- | Reinstalls a saved floating point environment and reports which
-- exceptions the C helper flagged.
--
-- The helper writes a bitmask; bit @i@ is decoded as the 'FloatException'
-- whose 'fromEnum' is @i@. NOTE(review): presumably the mask reports
-- exceptions raised since the environment was saved -- confirm against
-- the C implementation of fenv_restore.
--
-- Fails with an 'IOError' if the helper reports a nonzero status.
unsafeRestoreEnvironment :: FEnvState -> IO [FloatException]
unsafeRestoreEnvironment (FEnvState saved) = alloca $ \maskPtr -> do
    status <- withForeignPtr saved $ \envPtr -> fenv_restore envPtr maskPtr
    when (status /= 0) $
        fail "Error restoring floating point environment."
    raised <- peek maskPtr
    pure $! [ e | e <- [minBound .. maxBound], testBit raised (fromEnum e) ]
-- | Raises the given floating point exceptions.
--
-- The list is encoded as a bitmask (bit @i@ = the exception whose
-- 'fromEnum' is @i@; duplicates are harmless) and passed to the C helper.
-- Fails with an 'IOError' if the helper reports a nonzero status.
unsafeRaiseExceptions :: [FloatException] -> IO ()
unsafeRaiseExceptions ex = do
    status <- fenv_raise_excepts requested
    when (status /= 0) $ fail "Error raising floating point exceptions."
  where
    requested = foldr (.|.) 0 [ bit (fromEnum e) | e <- ex ]
-- | Sets the current rounding mode by passing its 'fromEnum' code to the
-- C helper. Fails with an 'IOError' if the helper reports a nonzero
-- status.
unsafeSetRoundingMode :: RoundingMode -> IO ()
unsafeSetRoundingMode mode =
    set_roundmode code >>= \status ->
        when (status /= 0) $ fail "Error setting rounding mode"
  where
    code = fromIntegral (fromEnum mode)
-- | Reads the current rounding mode from the C helper.
--
-- The helper returns an integer code that is decoded with 'toEnum'. Any
-- code outside the range of 'RoundingMode' (negative codes included) is
-- reported as an 'IOError' here. Previously a too-large code passed the
-- check and made the lazy 'toEnum' throw a pure error later, wherever the
-- result happened to be forced. The result is now forced before returning.
getRoundingMode :: IO RoundingMode
getRoundingMode = do
    rc <- get_roundmode
    let code = fromIntegral rc :: Int
        lo   = fromEnum (minBound :: RoundingMode)
        hi   = fromEnum (maxBound :: RoundingMode)
    unless (code >= lo && code <= hi) $ fail "Error getting rounding mode"
    return $! toEnum code
-- | Runs a computation under the given rounding mode.
--
-- When the resulting 'FEnv' is run, it saves the current mode, switches
-- to @mode@, forces the inner result to WHNF, and then restores the saved
-- mode. The restore is guaranteed by 'bracket', so it also happens if
-- evaluation throws. Previously an exception left the process in the
-- wrong rounding mode.
--
-- 'unsafePerformIO' is used because 'FEnv' computations are pure values;
-- the mode switch is confined to the evaluation of this node. Only WHNF is
-- forced, so lazy parts of the result are evaluated later under whatever
-- mode is current then.
-- NOTE(review): the rounding mode is per OS thread; with the threaded RTS
-- a Haskell thread could migrate between OS threads during evaluation.
withRoundingMode :: RoundingMode -> FEnv a -> FEnv a
withRoundingMode mode (FEnv f x) = FEnv unsafePerformIO $
    bracket getRoundingMode unsafeSetRoundingMode $ \_ -> do
        unsafeSetRoundingMode mode
        evaluate (f x)
-- | Attaches the raising of the given exceptions to a computation.
--
-- When the result is run, both the raise (done via 'unsafePerformIO') and
-- the computation are forced with 'seq'. 'seq' does not fix which of the
-- two happens first; because the raised flags are only read when the
-- environment is restored, the order does not affect what 'fenvEval'
-- reports.
raiseExceptions :: [FloatException] -> FEnv a -> FEnv a
raiseExceptions ex comp = seq <$> raising <*> comp
  where
    raising = FEnv unsafePerformIO (unsafeRaiseExceptions ex)
-- | Lifts a value into 'FEnv', emitting a 'trace' message when the
-- result is evaluated. The value itself is forced to WHNF when the 'FEnv'
-- is built, because 'pure' stores it in a strict field.
fenvTrace :: String -> a -> FEnv a
fenvTrace msg v = trace msg <$> pure v
-- | Runs an 'FEnv' computation in a fresh floating point environment and
-- returns its result together with the exceptions reported on restore.
--
-- The caller's environment is saved with @feholdexcept@, which also clears
-- the exception flags. The result is then forced to WHNF and the saved
-- environment is restored. The restore now also happens if evaluation
-- throws (the exception is then re-thrown), so the caller is never left in
-- the cleared, non-stop environment. Async exceptions are masked outside
-- the evaluation itself.
--
-- Only WHNF is forced: lazy components of the result are evaluated later,
-- outside the controlled environment.
-- NOTE(review): the floating point environment is per OS thread; in the
-- threaded RTS a Haskell thread may migrate between OS threads while the
-- computation runs. Consider running this in a bound thread -- confirm.
fenvEval :: FEnv a -> IO (a, [FloatException])
fenvEval (FEnv f x) = mask $ \restore -> do
    env <- unsafeSaveEnvironment True
    result <- restore (evaluate (f x))
        `onException` unsafeRestoreEnvironment env
    excepts <- unsafeRestoreEnvironment env
    return (result, excepts)