{-# Language BlockArguments, OverloadedStrings #-}
{-# Language BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# Language GADTs #-}
module What4.Utils.FloatHelpers where

import qualified Control.Exception as Ex
import Data.Ratio(numerator,denominator)
import Data.Hashable
import GHC.Generics (Generic)
import GHC.Stack

import LibBF

import What4.BaseTypes
import What4.Panic (panic)

-- | The five IEEE-754 rounding modes understood by this module.
--
-- Constructor order matters: the derived 'Ord' and 'Enum' instances follow
-- it, so @fromEnum RNE == 0@ and @fromEnum RTZ == 4@.
data RoundingMode
  = RNE -- ^ Round to nearest, ties to even.
  | RNA -- ^ Round to nearest, ties away from zero.
  | RTP -- ^ Round toward positive infinity.
  | RTN -- ^ Round toward negative infinity.
  | RTZ -- ^ Round toward zero.
  deriving
    ( Show
    , Eq
    , Ord
    , Enum
    , Generic
    )

-- | Hashing uses the 'Generic'-based default methods of 'Hashable'.
instance Hashable RoundingMode where

-- | Extract the result from a LibBF @(result, status)@ pair.
--
-- A 'MemError' status is reported by raising 'Ex.HeapOverflow' with
-- 'Ex.throw' when the result is evaluated. Every other status is
-- discarded and the result is returned unchanged.
bfStatus :: HasCallStack => (a, Status) -> a
bfStatus (result, status) =
  case status of
    MemError -> Ex.throw Ex.HeapOverflow
    _        -> result

-- | LibBF options for a What4 floating-point precision and rounding mode.
--
-- The exponent and significand widths are read from the precision
-- representative. They are passed to 'fpOpts' together with the
-- rounding mode translated by 'toRoundMode'.
fppOpts :: FloatPrecisionRepr fpp -> RoundingMode -> BFOpts
fppOpts fpp r =
  case fpp of
    FloatingPointPrecisionRepr expW sigW ->
      let mode = toRoundMode r
      in fpOpts (intValue expW) (intValue sigW) mode

-- | Translate a What4 'RoundingMode' into the matching LibBF 'RoundMode'.
toRoundMode :: RoundingMode -> RoundMode
toRoundMode mode =
  case mode of
    RNE -> NearEven
    RNA -> NearAway
    RTP -> ToPosInf
    RTN -> ToNegInf
    RTZ -> ToZero

-- | Make LibBF options for the given precision and rounding mode.
--
-- @e@ is the number of exponent bits and @p@ the number of precision bits.
-- @e@ must lie in @[expBitsMin, expBitsMax]@ and @p@ in
-- @[precBitsMin, precBitsMax]@; otherwise this function panics.
-- The resulting options also enable 'allowSubnormal'.
fpOpts :: Integer -> Integer -> RoundMode -> BFOpts
fpOpts e p r =
  maybe invalid id $
    do expOpt  <- inRange expBits  expBitsMin  expBitsMax  e
       precOpt <- inRange precBits precBitsMin precBitsMax p
       pure (expOpt <> precOpt <> allowSubnormal <> rnd r)
  where
    invalid =
      panic "floatOpts"
        [ "Invalid Float size"
        , "exponent: " ++ show e
        , "precision: " ++ show p
        ]

    -- Build an option with @mk@ only when @x@ lies in the inclusive range
    -- [lo, hi]. The bounds are widened to 'Integer' before comparing, so
    -- an oversized @x@ is rejected rather than wrapped.
    inRange mk lo hi x
      | toInteger lo <= x, x <= toInteger hi = Just (mk (fromInteger x))
      | otherwise                            = Nothing


-- | Make a floating point number from an integer, using the precision and
-- rounding mode in @opts@.
--
-- The integer is converted with 'bfFromInteger' and then rounded once
-- with 'bfRoundFloat'. A LibBF memory error is re-raised by 'bfStatus'.
floatFromInteger :: BFOpts -> Integer -> BigFloat
floatFromInteger opts = bfStatus . bfRoundFloat opts . bfFromInteger

-- | Make a floating point number from a rational, using the precision and
-- rounding mode in @opts@.
--
-- When the denominator is 1, the numerator is rounded directly with
-- 'bfRoundFloat'. Otherwise the numerator is divided by the denominator
-- with 'bfDiv' under @opts@, so the result comes from a single LibBF
-- operation.
floatFromRational :: BFOpts -> Rational -> BigFloat
floatFromRational opts rat
  | d == 1    = bfStatus (bfRoundFloat opts n)
  | otherwise = bfStatus (bfDiv opts n (bfFromInteger d))
  where
    n = bfFromInteger (numerator rat)
    d = denominator rat


-- | Convert a floating point number to a rational, if possible.
--
-- NaN and the infinities have no rational value and yield 'Nothing'.
-- Zero yields @0@, whatever its sign. A finite non-zero value that LibBF
-- represents as mantissa @m@ and exponent @k@ denotes @m * 2^^k@, negated
-- when the sign is 'Neg'.
floatToRational :: BigFloat -> Maybe Rational
floatToRational bf =
  case bfToRep bf of
    BFNaN                -> Nothing
    BFRep _ Inf          -> Nothing
    BFRep _ Zero         -> Just 0
    BFRep sign (Num m k) -> Just (applySign sign (fromInteger m * 2 ^^ k))
  where
    applySign Pos = id
    applySign Neg = negate

-- | Convert a floating point number to an integer, if possible. Any
-- fractional part is removed according to the given rounding mode.
--
-- Returns 'Nothing' exactly when 'floatToRational' does (NaN or an
-- infinity).
floatToInteger :: RoundingMode -> BigFloat -> Maybe Integer
floatToInteger r fp =
  do rat <- floatToRational fp
     pure $ case r of
       RNE -> round rat      -- Prelude 'round' breaks ties to even
       RNA -> roundAway rat  -- nearest, ties away from zero
       RTP -> ceiling rat
       RTN -> floor rat
       RTZ -> truncate rat
  where
    -- Round to the nearest integer, breaking ties away from zero.
    -- (The previous code used ceiling/floor here, which always rounds away
    -- from zero, e.g. 1.2 became 2.) Shifting by one half and then taking
    -- floor/ceiling is exact because the arithmetic is on 'Rational'.
    roundAway :: Rational -> Integer
    roundAway x
      | x >= 0    = floor (x + 1 / 2)
      | otherwise = ceiling (x - 1 / 2)

-- | Round a floating point value to an integral value with the given
-- rounding mode ('bfRoundInt'). Then round that result to the precision
-- described by @fpp@, using the same rounding mode ('bfRoundFloat').
-- A LibBF memory error at either step is re-raised by 'bfStatus'.
floatRoundToInt ::
  HasCallStack =>
  FloatPrecisionRepr fpp -> RoundingMode -> BigFloat -> BigFloat
floatRoundToInt fpp r bf = bfStatus (bfRoundFloat opts integral)
  where
    opts     = fppOpts fpp r
    integral = bfStatus (bfRoundInt (toRoundMode r) bf)