{-# LANGUAGE ImplicitParams, RankNTypes, ExistentialQuantification, ScopedTypeVariables, GeneralizedNewtypeDeriving #-}

{- |
  Module      :  Numeric.IEEE.Monad
  Copyright   :  (c) Sterling Clover 2008 <s.clover@gmail.com>
  License     :  BSD3
  Maintainer  :  Matt Morrow <mjm2002@gmail.com>
  Stability   :  provisional
  Portability :  portable (FFI)

  The IEEE monad provides tools for enforcing the sequencing of calculations, giving
  fine-grained control over floating-point exception flags, evaluation within
  particular rounding modes, etc. The perturb family of functions is built on this,
  allowing pure computations parameterized over an arbitrary Floating type to be
  tested for numeric stability.

-}

module Numeric.IEEE.Monad where

import Numeric.IEEE.RoundMode (RoundMode(..))
import qualified Numeric.IEEE.RoundMode as RM
import qualified Numeric.IEEE.FloatExceptions as FE
import Control.Applicative
import Control.Concurrent
import Control.Exception
import Control.Monad

-- | Top-level entry point for code that uses the IEEE monad.
--
-- Allocates a fresh mutex and binds it to the implicit parameter
-- @?ieeeMutex@, which 'runIEEE' takes for the duration of each IEEE action.
-- All 'runIEEE' calls made inside the same 'withIeeeDo' scope (including
-- from several threads) are therefore serialized. Note that each call to
-- 'withIeeeDo' creates its own mutex, so separate scopes do not exclude
-- one another.
withIeeeDo :: ((?ieeeMutex :: MVar ()) => IO a) -> IO a
withIeeeDo f = do
  mutex <- newMVar ()
  let ?ieeeMutex = mutex in f

newtype IEEE a = IEEE {unIEEE :: IO a} deriving (Monad, Functor, Applicative)

-- | Run an 'IEEE' action while holding the mutex bound by 'withIeeeDo'.
--
-- The mutex is released when the action finishes, whether it returns
-- normally or throws. 'withMVar' restores the mutex on exceptions;
-- a manual take/put pair would leave it empty and deadlock every later
-- 'runIEEE' call in the same 'withIeeeDo' scope.
runIEEE :: (?ieeeMutex :: MVar()) => IEEE a -> IO a
runIEEE f = withMVar ?ieeeMutex (const (unIEEE f))

-- | The current rounding mode, as reported by 'RM.getRound', lifted into
-- 'IEEE'.
getRound :: IEEE RoundMode
getRound = IEEE RM.getRound

-- | Switch to the given rounding mode via 'RM.setRound', lifted into 'IEEE'.
-- The 'Bool' is passed through unchanged from 'RM.setRound' (presumably a
-- success flag; confirm in "Numeric.IEEE.RoundMode").
setRound :: RoundMode -> IEEE Bool
setRound = IEEE . RM.setRound

-- | Clear the listed floating-point exception flags via
-- 'FE.clearFloatExcepts', lifted into 'IEEE'. The 'Bool' result is passed
-- through unchanged (presumably a success flag; confirm in
-- "Numeric.IEEE.FloatExceptions").
clearFloatExcepts :: [ArithException] -> IEEE Bool
clearFloatExcepts = IEEE . FE.clearFloatExcepts

-- | The floating-point exception flags currently reported by
-- 'FE.getFloatExcepts', lifted into 'IEEE'.
getFloatExcepts :: IEEE [ArithException]
getFloatExcepts = IEEE FE.getFloatExcepts

-- | Force the argument to weak head normal form (via 'evaluate') at this
-- point in the 'IEEE' sequence, rather than lazily at some later time.
-- For flat numeric types such as 'Double', WHNF is full evaluation.
calculate :: a -> IEEE a
calculate x = IEEE (evaluate x)

-- | Like 'calculate', but also reports the floating-point exception flags
-- raised while forcing the value.
--
-- Any flags already set beforehand are cleared first, so the returned list
-- reflects only what this evaluation raised.
calculate' :: a -> IEEE (a,[FE.ArithException])
calculate' f = do
  _ <- clearFloatExcepts =<< getFloatExcepts
  value <- calculate f
  raised <- getFloatExcepts
  return (value, raised)

-- | Run an 'IEEE' action under the given rounding mode, forcing its result
-- (to WHNF, via 'calculate') while that mode is active.
--
-- If the current mode already equals the requested one, the mode is left
-- untouched. Otherwise it is switched, and the previous mode is restored
-- afterwards, including when the action throws ('bracket_').
--
-- The result is forced in both branches. If it were left lazy when the mode
-- is unchanged, it could later be evaluated under a different mode.
withRoundMode :: RoundMode -> IEEE a -> IEEE a
withRoundMode r f = do
  x <- getRound
  let run = f >>= calculate
  if x == r
    then run
    else IEEE $ bracket_ (unIEEE (setRound r)) (unIEEE (setRound x)) (unIEEE run)

-- | Evaluate a rounding-mode-polymorphic computation four times, once under
-- each of 'Upward', 'Downward', 'ToNearest' and 'TowardZero' (in that
-- order), and return the four results in that order.
--
-- Each result is forced (to WHNF, via 'calculate') while its mode is
-- active. The caller's original rounding mode is restored afterwards, even
-- if one of the calculations throws. The whole run holds the IEEE mutex
-- via 'runIEEE'.
--
-- The rank-N argument is instantiated directly at @b@. The former
-- @foo = id@ helper relied on deep subsumption, which GHC >= 9.0 no longer
-- performs by default, so it was removed.
perturb' :: (?ieeeMutex :: MVar (), Floating b) => (forall a. Floating a => IEEE a) -> IO (b, b, b, b)
perturb' f = runIEEE $ do
  saved <- getRound
  let inMode m = setRound m >> (calculate =<< f)
      measure  = (,,,) <$> inMode Upward
                       <*> inMode Downward
                       <*> inMode ToNearest
                       <*> inMode TowardZero
  IEEE $ unIEEE measure `finally` unIEEE (setRound saved)

-- | Sensitivity of a computation to rounding: the absolute difference
-- between its result under 'Upward' rounding and under 'Downward'
-- rounding, as produced by 'perturb''.
perturb :: (?ieeeMutex :: MVar (), Floating b) => (forall a. Floating a => IEEE a) -> IO b
perturb f = fmap spread (perturb' f)
  where
    spread (up, down, _, _) = abs (up - down)

-- | Relative magnitude of the instability introduced by rounding upwards
-- versus downwards.
--
-- With @u@ and @d@ the results under 'Upward' and 'Downward' rounding (from
-- 'perturb''), this is @abs ((u - d) / (u + d))@. That is the difference
-- divided by the /sum/ of the two results, i.e. half the difference relative
-- to their mean, not the difference divided by the mean.
--
-- When @u + d@ is zero the result follows the type's division semantics
-- (for 'Double', an infinity or NaN).
perturbedMag :: (?ieeeMutex :: MVar (), Floating b) => (forall a. Floating a => IEEE a) -> IO b
perturbedMag f = relative <$> perturb' f
  where
    relative (up, down, _, _) = abs ((up - down) / (up + down))