{-
 - Copyright (C) 2010 Nick Bowler.
 -
 - License BSD2:  2-clause BSD license.  See LICENSE for full terms.
 - This is free software: you are free to change and redistribute it.
 - There is NO WARRANTY, to the extent permitted by law.
 -}

-- | Access to the floating point environment.  Performing this access within
-- a Haskell program turns out to be extremely problematic, because floating
-- point operations are secretly impure.  For example, the innocent-looking
-- function
--
-- @  (+) :: Double -> Double -> Double@
--
-- potentially both depends on and modifies the global floating point
-- environment.
--
-- This module avoids the referential transparency problems that occur as a
-- result of accessing the floating point environment by restricting when
-- computations which access it are evaluated.  There is some minor discipline
-- required of the programmer: she must arrange her code so that pure floating
-- point expressions are not forced during a call to 'fenvEval'.
-- See @fenv-impure.hs@ in the @examples/@ directory of the altfloat
-- distribution for why this discipline is necessary.
--
-- FEnv is an instance of the numeric classes, so it should be possible to use
-- natural syntax.  Note that the operations done on FEnv are stored so that
-- they can be performed later, so take care not to construct huge thunks
-- when using this interface.
--
-- This interface has not been tested in multi-threaded programs.  It might
-- work: more info is needed about GHC's threading support.
{-# LANGUAGE CPP, ForeignFunctionInterface, ExistentialQuantification #-}
{-# OPTIONS_GHC -I. #-}
module Data.Floating.Environment (
    module Control.Applicative,

    -- * Data types
    RoundingMode(..), FloatException(..), FEnvState, FEnv,

    -- * Controlled access to the floating point environment
    -- | These functions can still break referential transparency, because it
    -- is possible to arrange for a pure floating point expression to be forced
    -- during the execution of 'fenvEval'.  The easiest way to ensure that this
    -- does not happen is to only use such expressions as the argument to
    -- 'pure'; never as the argument to 'fmap'.
    fenvEval, withRoundingMode, raiseExceptions, fenvTrace,

    -- * Direct access to the floating point environment
    -- | Special care must be taken when using these functions.  Modifying the
    -- floating point environment will affect all floating point computations
    -- that have not yet been evaluated.
    unsafeSaveEnvironment, unsafeRestoreEnvironment,
    unsafeRaiseExceptions,
    unsafeSetRoundingMode, getRoundingMode
) where

#include <config.h>

import Prelude hiding (Float, Double, Floating(..), RealFloat(..))

import Data.Floating.Classes
import Control.Exception
import Control.Applicative
import Control.Monad

import System.IO.Unsafe
import Debug.Trace

import Foreign.C
import Foreign

-- Rounding-mode helpers implemented in C (not visible here -- presumably
-- part of this package's C sources; TODO confirm).  Callers below pass
-- 'fromEnum' of a 'RoundingMode' to set_roundmode and treat a nonzero result
-- as failure; get_roundmode's result is treated as failure when negative and
-- is otherwise converted back with 'toEnum'.
foreign import ccall unsafe "set_roundmode"
    set_roundmode :: CInt -> IO CInt
foreign import ccall unsafe "get_roundmode"
    get_roundmode :: IO CInt

-- <fenv.h> functions that store the current environment into the given
-- fenv_t buffer.  feholdexcept is the variant used when the caller also
-- wants exception flags cleared and non-stop mode installed (see
-- 'unsafeSaveEnvironment').  Callers treat a nonzero result as failure.
foreign import ccall unsafe "fegetenv"
    c_fegetenv :: Ptr FEnvState -> IO CInt
foreign import ccall unsafe "feholdexcept"
    c_feholdexcept :: Ptr FEnvState -> IO CInt
-- Package C helpers (not visible here).  As used below, fenv_restore
-- installs the saved environment and writes a mask of flagged exceptions to
-- its second argument; fenv_raise_excepts takes a mask in the same encoding.
-- In both masks, bit (fromEnum e) stands for 'FloatException' e.  Callers
-- treat a nonzero result as failure.
foreign import ccall unsafe "fenv_restore"
    fenv_restore :: Ptr FEnvState -> Ptr CUInt -> IO CInt
foreign import ccall unsafe "fenv_raise_excepts"
    fenv_raise_excepts :: CUInt -> IO CInt

-- | Floating point rounding modes.  The constructor order is significant:
-- 'fromEnum' of a mode is the code exchanged with the C rounding-mode
-- helpers, so do not reorder the constructors.
data RoundingMode = ToNearest | Upward | Downward | TowardZero
    deriving (Eq, Ord, Show, Read, Enum, Bounded)

-- | Floating point exceptions.  The constructor order is significant:
-- 'fromEnum' of an exception is its bit position in the masks exchanged with
-- the C helpers, so do not reorder the constructors.  'Eq' is derived so
-- that results such as the list returned by 'fenvEval' can be inspected
-- (e.g. with 'elem').
data FloatException = DivByZero | Inexact | Invalid | Overflow | Underflow
    deriving (Eq, Ord, Show, Read, Enum, Bounded)

-- | Opaque type which stores the complete floating point environment.  It
-- corresponds to the C type @fenv_t@; the saved bytes are held in a buffer
-- owned by the 'ForeignPtr'.
newtype FEnvState = FEnvState (ForeignPtr FEnvState)

-- Raw byte copies of an @fenv_t@.  Size and alignment come from the
-- SIZEOF_FENV_T / ALIGNOF_FENV_T macros (presumably provided by config.h).
instance Storable FEnvState where
    sizeOf    _ = SIZEOF_FENV_T
    alignment _ = ALIGNOF_FENV_T

    -- 'mallocForeignPtr' sizes and aligns the buffer using this very
    -- instance, so the copy honours ALIGNOF_FENV_T; 'mallocForeignPtrBytes'
    -- would only give the allocator's default alignment.
    peek src = do
        fp <- mallocForeignPtr
        withForeignPtr fp $ \dst -> copyBytes dst src SIZEOF_FENV_T
        return (FEnvState fp)

    poke dst (FEnvState fp) =
        withForeignPtr fp $ \src -> copyBytes dst src SIZEOF_FENV_T

-- | A deferred floating point computation, to be run in a (possibly
-- modified) floating point environment.  An FEnv pairs a function with the
-- argument it will later be applied to.  The argument is held strictly, so
-- it is in weak head normal form.  'fenvEval' performs the application.
-- Operations combined through the instances of FEnv are recorded, not run.
-- 'pure' is strict, so floating point values stored in the container are
-- forced when the container is built.
--
-- The 'Eq' and 'Show' instances exist only because Num requires them; do
-- not use them.
data FEnv a = forall s . FEnv (s -> a) !s

-- Invariant for the instances below: every time the contained value is
-- extracted, the stored function is applied to its stored argument exactly
-- once, and that result must not be memoized.  This is why FEnv has no
-- Monad instance: join (FEnv f x) = f x has the right type, but does not
-- respect this invariant.

instance Functor FEnv where
    fmap h (FEnv run arg) = FEnv (\v -> h (run v)) arg

instance Applicative FEnv where
    pure x = FEnv id x
    FEnv runL argL <*> FEnv runR argR =
        FEnv (\(l, r) -> runL l (runR r)) (argL, argR)

-- Placeholder instances: in the Haskell 98 class hierarchy Eq and Show are
-- superclasses of Num, so they exist only to permit the Num instance below.
-- Showing an FEnv yields a fixed string; comparing two is an error.
instance Show a => Show (FEnv a) where
    show _ = "<<FEnv>>"
instance Eq a => Eq (FEnv a) where
    (==) = error "The Eq instance for FEnv is a lie."

-- Arithmetic is recorded for deferred evaluation; integer literals are
-- wrapped with the strict 'pure'.
instance Num a => Num (FEnv a) where
    x + y         = (+) <$> x <*> y
    x - y         = (-) <$> x <*> y
    x * y         = (*) <$> x <*> y
    negate        = fmap negate
    signum        = fmap signum
    abs           = fmap abs
    fromInteger n = pure (fromInteger n)

-- Division is recorded for deferred evaluation; rational literals are
-- wrapped with the strict 'pure'.
instance Fractional a => Fractional (FEnv a) where
    x / y          = (/) <$> x <*> y
    recip          = fmap recip
    fromRational r = pure (fromRational r)

-- Every Floating operation is lifted into FEnv: it is recorded now and
-- performed only when the container is evaluated.
instance Floating a => Floating (FEnv a) where
    x ** y = (**) <$> x <*> y
    sqrt   = fmap sqrt
    acos   = fmap acos
    asin   = fmap asin
    atan   = fmap atan
    cos    = fmap cos
    sin    = fmap sin
    tan    = fmap tan
    cosh   = fmap cosh
    sinh   = fmap sinh
    tanh   = fmap tanh
    exp    = fmap exp
    log    = fmap log
    acosh  = fmap acosh
    asinh  = fmap asinh
    atanh  = fmap atanh

-- Every RealFloat operation is lifted into FEnv: it is recorded now and
-- performed only when the container is evaluated.  Constants are wrapped
-- with the strict 'pure'.
instance RealFloat a => RealFloat (FEnv a) where
    fma x y z     = fma <$> x <*> y <*> z
    copysign x y  = copysign <$> x <*> y
    nextafter x y = nextafter <$> x <*> y
    fmod x y      = fmod <$> x <*> y
    frem x y      = frem <$> x <*> y
    atan2 x y     = atan2 <$> x <*> y
    hypot x y     = hypot <$> x <*> y
    cbrt          = fmap cbrt
    exp2          = fmap exp2
    expm1         = fmap expm1
    log10         = fmap log10
    log1p         = fmap log1p
    log2          = fmap log2
    logb          = fmap logb
    erf           = fmap erf
    erfc          = fmap erfc
    gamma         = fmap gamma
    lgamma        = fmap lgamma
    nearbyint     = fmap nearbyint
    rint          = fmap rint

    infinity = pure infinity
    nan      = pure nan
    pi       = pure pi

-- | Captures the current floating point environment.  With 'True', the
-- capture goes through C's @feholdexcept@, which also clears all exception
-- flags and installs non-stop (continue on exceptions) mode.  With 'False'
-- the environment is only read, via @fegetenv@.  Throws an 'IOError' if the
-- C call reports failure.
unsafeSaveEnvironment :: Bool -> IO FEnvState
unsafeSaveEnvironment reset = alloca go
    where
        go buf = do
            status <- save buf
            when (status /= 0) $
                fail "Error saving floating point environment."
            peek buf
        save
            | reset     = c_feholdexcept
            | otherwise = c_fegetenv

-- | Reinstates a previously saved floating point environment.  The result
-- lists the floating point exceptions that the C helper reports as having
-- occurred before the restore.  Throws an 'IOError' if the helper reports
-- failure.
unsafeRestoreEnvironment :: FEnvState -> IO [FloatException]
unsafeRestoreEnvironment (FEnvState saved) = alloca $ \flagsPtr -> do
    status <- withForeignPtr saved $ \envPtr -> fenv_restore envPtr flagsPtr
    when (status /= 0) $
        fail "Error restoring floating point environment."
    flags <- peek flagsPtr
    -- Bit (fromEnum e) of the mask is set when exception e was flagged.
    return $! [ e | e <- [minBound .. maxBound], testBit flags (fromEnum e) ]

-- | Raises the given floating point exceptions immediately.  The list is
-- encoded as a bit mask (bit @fromEnum e@ for each exception @e@) for the C
-- helper.  Throws an 'IOError' if the helper reports failure.
unsafeRaiseExceptions :: [FloatException] -> IO ()
unsafeRaiseExceptions ex = do
    status <- fenv_raise_excepts bits
    when (status /= 0) $ fail "Error raising floating point exceptions."
    where
        bits = foldr (\e acc -> setBit acc (fromEnum e)) 0 ex

-- | Changes the current rounding mode immediately by passing
-- @fromEnum mode@ to the C helper.  Throws an 'IOError' if the helper
-- returns a nonzero status.
unsafeSetRoundingMode :: RoundingMode -> IO ()
unsafeSetRoundingMode mode = do
    let code = fromIntegral (fromEnum mode)
    status <- set_roundmode code
    when (status /= 0) $ fail "Error setting rounding mode"

-- | Returns the current rounding mode.  Throws an 'IOError' if the C helper
-- reports failure (a negative result) or returns a code that does not
-- correspond to any 'RoundingMode' constructor.
getRoundingMode :: IO RoundingMode
getRoundingMode = do
    rc <- get_roundmode
    -- 'toEnum' is partial.  Without this range check, an unknown code would
    -- be returned as a lazy thunk that fails only when some later caller
    -- forces it, far from this call and with an unhelpful message.
    let code  = fromIntegral rc :: Int
        valid = code >= fromEnum (minBound :: RoundingMode)
             && code <= fromEnum (maxBound :: RoundingMode)
    unless valid $ fail "Error getting rounding mode"
    return $! toEnum code

-- | Evaluate an FEnv using a specific rounding mode.  Rounding mode selections
-- nest: subcomputations might use another mode.  The default rounding mode is
-- unspecified.
--
-- Each time the value is extracted, the current mode is saved and @mode@ is
-- installed.  The contained computation is then evaluated to weak head
-- normal form.  Afterwards the saved mode is reinstated, also when the
-- evaluation throws, so an exception cannot leak the temporary mode into
-- later floating point computations.
withRoundingMode :: RoundingMode -> FEnv a -> FEnv a
withRoundingMode mode (FEnv f x) = FEnv unsafePerformIO $
    bracket getRoundingMode unsafeSetRoundingMode $ \_ -> do
        unsafeSetRoundingMode mode
        evaluate (f x)

-- | Raise floating point exceptions as part of an FEnv computation.  Every
-- time the value is extracted, 'unsafeRaiseExceptions' is run and the given
-- computation is evaluated.  The two are combined with 'seq', so their
-- relative order is not specified.
raiseExceptions :: [FloatException] -> FEnv a -> FEnv a
raiseExceptions ex comp = seq <$> raising <*> comp
    where
        raising = FEnv unsafePerformIO (unsafeRaiseExceptions ex)

-- | Debugging aid for the floating point environment handling.
-- @fenvTrace msg x@ is an FEnv holding @x@ that emits @msg@ through
-- 'Debug.Trace.trace' each time its value is extracted.
fenvTrace :: String -> a -> FEnv a
fenvTrace s x = trace s <$> pure x

-- | Runs all the computations which are recorded in an FEnv container.  The
-- floating point environment is preserved across this call, and any floating
-- point exceptions which were raised during the computation are returned.
--
-- The computation runs after the environment has been saved with exception
-- flags cleared and non-stop mode installed ('unsafeSaveEnvironment' 'True').
-- The contained value is evaluated to weak head normal form only.  If that
-- evaluation throws, the saved environment is still restored before the
-- exception propagates.
fenvEval :: FEnv a -> IO (a, [FloatException])
fenvEval (FEnv f x) = do
    env <- unsafeSaveEnvironment True
    rc  <- evaluate (f x) `onException` unsafeRestoreEnvironment env
    ex  <- unsafeRestoreEnvironment env
    return (rc, ex)