{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Numeric.Decimal.BoundedArithmetic
  (
  -- * Arith Monad
    Arith(..)
  , arithM
  , arithMaybe
  , arithEither
  , arithError
  -- * Bounded
  , plusBounded
  , minusBounded
  , timesBounded
  , absBounded
  , fromIntegerBounded
  , divBounded
  , quotBounded
  , quotRemBounded
  ) where

import Control.Exception
import Control.Monad.Catch
import GHC.Stack

-- | Monad for performing safe computation: a value is either the successful
-- result of a checked computation or the exception that stopped it. Both
-- constructor fields are strict.
data Arith a
  = Arith !a
  -- ^ The computation succeeded with this value
  | ArithError !SomeException
  -- ^ The computation failed with this exception

-- | The bounds of @Arith a@ are the bounds of @a@, wrapped as successful
-- results.
instance Bounded a => Bounded (Arith a) where
  minBound = Arith minBound
  maxBound = Arith maxBound

-- | Run an `Arith` computation in any `MonadThrow`: a successful result is
-- returned with `pure`, while a failure is re-thrown with `throwM`.
--
-- >>> import Numeric.Decimal
-- >>> :set -XDataKinds
-- >>> arithM (1.1 * 123 :: Arith (Decimal RoundDown 3 Int))
-- 135.300
-- >>> arithM (1.1 - 123 :: Arith (Decimal RoundDown 3 Word))
-- *** Exception: arithmetic underflow
-- >>> 1.1 - 123 :: Arith (Decimal RoundDown 3 Word)
-- ArithError arithmetic underflow
--
-- @since 0.2.0
arithM :: MonadThrow m => Arith a -> m a
arithM (Arith a) = pure a
arithM (ArithError exc) = throwM exc

-- | `arithM` specialised to `Maybe`: a successful result becomes `Just`, and
-- any failure becomes `Nothing` (the exception itself is discarded).
--
-- @since 0.2.0
arithMaybe :: Arith a -> Maybe a
arithMaybe = \case
  Arith a -> Just a
  ArithError _ -> Nothing

-- | `arithM` specialised to @Either SomeException@: a successful result is
-- returned on the right, while a failure is thrown into the `Either` (ending
-- up on the left).
--
-- @since 0.2.0
arithEither :: Arith a -> Either SomeException a
arithEither = \case
  Arith a -> pure a
  ArithError exc -> throwM exc


-- | Extract the value of an `Arith` computation, calling `error` on any
-- failure. The error message is the failure's `displayException` text.
-- Because this goes through `error`, the thrown exception is an `ErrorCall`,
-- not a `userError`/`IOException`. Thanks to the `HasCallStack` constraint,
-- the reported call stack includes the caller of `arithError`. Should only be
-- used as a helper for testing and development.
--
-- @since 0.2.1
arithError :: HasCallStack => Arith a -> a
arithError = \case
  Arith a -> a
  ArithError exc -> error $ displayException exc



-- | Shows a successful result as @Arith <value>@ and a failure as
-- @ArithError <message>@, where the message is produced by `displayException`.
--
-- Precedence follows the usual `showsPrec` convention. The whole expression
-- is parenthesised only when it appears as an argument of an application
-- (precedence above 10). The wrapped value is itself shown at
-- application-argument precedence (11), so negative numbers and nested
-- constructors keep their own parentheses: @Arith (Just 3)@ and
-- @Arith (-1)@, rather than the ambiguous @Arith Just 3@ / @Arith -1@.
instance Show a => Show (Arith a) where
  showsPrec n r =
    case r of
      Arith a -> showsA "Arith" (showsPrec 11 a)
      ArithError exc -> showsA "ArithError" (showString (displayException exc))
    where
      -- Render "<prefix> <content>", parenthesised when the surrounding
      -- precedence requires it.
      showsA prefix content =
        showParen (n > 10) (showString prefix . showChar ' ' . content)

-- | Mapping transforms a successful value and passes a failure through
-- unchanged.
instance Functor Arith where
  fmap f (Arith r) = Arith (f r)
  fmap _ (ArithError exc) = ArithError exc
  {-# INLINE fmap #-}

-- | The function side is inspected first. If it failed, that failure is
-- returned and the argument is never inspected. Otherwise, a failed argument
-- is propagated, or the function is applied to the successful argument.
instance Applicative Arith where
  pure = Arith
  {-# INLINE pure #-}
  ArithError exc <*> _ = ArithError exc
  Arith fr <*> Arith r = Arith (fr r)
  Arith _ <*> ArithError exc = ArithError exc
  {-# INLINE (<*>) #-}

-- | Bind feeds a successful value to the continuation and stops at the first
-- `ArithError`, which is passed along unchanged.
--
-- `return` is intentionally not defined here. Its default is `pure` (which
-- for this type is `Arith`), the canonical definition; overriding it with
-- anything else is flagged by GHC's @-Wnoncanonical-monad-instances@.
instance Monad Arith where
  (>>=) fa fab =
    case fa of
      Arith fr -> fab fr
      ArithError exc -> ArithError exc
  {-# INLINE (>>=) #-}


-- | Throwing inside `Arith` does not raise anything. The exception is
-- wrapped into a failed result instead.
instance MonadThrow Arith where
  throwM e = ArithError (toException e)
  {-# INLINE throwM #-}


-----------------------------------
-- Bounded arithmetics ------------
-----------------------------------

-- | Add two bounded numbers. Raises `Overflow` when the sum would exceed
-- `maxBound` and `Underflow` when it would fall below `minBound`.
--
-- Going out of range is only possible when both operands have the same
-- (non-zero) sign. Only then is @x@ compared against @maxBound - y@ or
-- @minBound - y@; for those signs the limit cannot itself wrap around.
--
-- @since 0.1.0
plusBounded :: (MonadThrow m, Ord a, Num a, Bounded a) => a -> a -> m a
plusBounded x y
  | bothPositive && x > maxBound - y = throwM Overflow
  | bothNegative && x < minBound - y = throwM Underflow
  | otherwise = pure (x + y)
  where
    sx = signum x
    sameSign = sx == signum y
    bothPositive = sameSign && sx == 1
    bothNegative = sameSign && sx == -1
{-# INLINABLE plusBounded #-}

-- | Subtract two bounded numbers. Raises `Overflow` when the difference would
-- exceed `maxBound` and `Underflow` when it would fall below `minBound`.
--
-- Subtracting a negative @y@ can only overflow, and subtracting a positive
-- @y@ can only underflow. The corresponding limit (@maxBound + y@ or
-- @minBound + y@) cannot itself wrap around for that sign of @y@.
--
-- @since 0.1.0
minusBounded :: (MonadThrow m, Ord a, Num a, Bounded a) => a -> a -> m a
minusBounded x y
  | yNegative && x > maxBound + y = throwM Overflow
  | yPositive && x < minBound + y = throwM Underflow
  | otherwise = pure (x - y)
  where
    sy = signum y
    yNegative = sy == -1
    yPositive = sy == 1
{-# INLINABLE minusBounded #-}

-- | Compute the absolute value, raising `Overflow` when `abs` produces a
-- negative result. That happens for `minBound` of a fixed-width signed type,
-- whose negation does not fit.
--
-- @since 0.2.0
absBounded :: (MonadThrow m, Num p, Ord p) => p -> m p
absBounded d =
  let absd = abs d
   in if absd < 0
        then throwM Overflow
        else pure absd
{-# INLINABLE absBounded #-}


-- | Integer division rounding toward negative infinity, as `div` does. Raises
-- `DivideByZero` for a zero divisor and `Overflow` for
-- @minBound `div` (-1)@, whose true result does not fit in the type.
--
-- @since 0.1.0
divBounded :: (MonadThrow m, Integral a, Bounded a) => a -> a -> m a
divBounded x y
  | y == 0 = throwM DivideByZero
  | isMinusOne && x == minBound = throwM Overflow
  | otherwise = pure (x `div` y)
  where
    -- The sign check comes first: for unsigned types the literal -1 wraps to
    -- maxBound, and @y == -1@ alone would misfire there.
    isMinusOne = signum y == -1 && y == -1
{-# INLINABLE divBounded #-}


-- | Find the quotient of two numbers, truncating toward zero as `quot` does.
-- Raises `DivideByZero` for a zero divisor and `Overflow` for
-- @minBound `quot` (-1)@, whose true result does not fit in the type.
--
-- @since 0.1.0
quotBounded :: (MonadThrow m, Integral a, Bounded a) => a -> a -> m a
quotBounded x y
  | y == 0 = throwM DivideByZero
  | isMinusOne && x == minBound = throwM Overflow
  | otherwise = pure (x `quot` y)
  where
    -- Guard against wraparound for unsigned types such as Word, where the
    -- literal -1 is maxBound: the sign check rules those out first.
    isMinusOne = signum y == -1 && y == -1
{-# INLINABLE quotBounded #-}

-- | Find the quotient and remainder of two numbers, as `quotRem` does.
-- Raises `DivideByZero` for a zero divisor and `Overflow` for
-- @minBound `quotRem` (-1)@, whose quotient does not fit in the type.
--
-- @since 0.1.0
quotRemBounded :: (MonadThrow m, Integral a, Bounded a) => a -> a -> m (a, a)
quotRemBounded x y
  | y == 0 = throwM DivideByZero
  | isMinusOne && x == minBound = throwM Overflow
  | otherwise = pure (x `quotRem` y)
  where
    -- The sign check keeps unsigned types (where the literal -1 is maxBound)
    -- from being treated as a division by minus one.
    isMinusOne = signum y == -1 && y == -1
{-# INLINABLE quotRemBounded #-}


-- | Multiply two numbers. Raises `Overflow` when the product would exceed
-- `maxBound` and `Underflow` when it would fall below `minBound`.
--
-- The check never computes a wrapped product. Instead, @x@ is compared
-- against @maxBound `quot` y@ and @minBound `quot` y@. The @minBound * (-1)@
-- cases are handled up front. In particular, @y == -1@ is excluded from the
-- quotient check, because @minBound `quot` (-1)@ would itself overflow.
-- Guard order matters: when @y == 0@, the quotient limits are never
-- evaluated.
--
-- @since 0.1.0
timesBounded :: (MonadThrow m, Integral a, Bounded a) => a -> a -> m a
timesBounded x y
  -- minBound * (-1), in either argument order
  | sy == -1 && y == -1 && x == minBound = throwM Overflow
  | signum x == -1 && x == -1 && y == minBound = throwM Overflow
  -- positive y: x must lie within [minBound `quot` y, maxBound `quot` y]
  | sy == 1 && (lowLimit > x || x > highLimit) = outOfRange
  -- negative y other than -1: the limits swap roles, so x must lie within
  -- [maxBound `quot` y, minBound `quot` y]
  | sy == -1 && y /= -1 && (lowLimit < x || x < highLimit) = outOfRange
  | otherwise = pure (x * y)
  where
    -- Comparing the sign first also guards against wraparound for unsigned
    -- types, where the literal -1 is maxBound.
    sy = signum y
    highLimit = maxBound `quot` y
    lowLimit = minBound `quot` y
    -- Equal signs mean the true product is positive (past maxBound);
    -- otherwise it is negative (past minBound).
    outOfRange
      | sy == signum x = throwM Overflow
      | otherwise = throwM Underflow
{-# INLINABLE timesBounded #-}

-- | Convert an unbounded `Integer` into a `Bounded` `Integral` type. Raises
-- `Overflow` when the value is above the target's `maxBound` and `Underflow`
-- when it is below its `minBound`, instead of silently wrapping.
--
-- @since 0.1.0
fromIntegerBounded ::
     forall m a. (MonadThrow m, Integral a, Bounded a)
  => Integer
  -> m a
fromIntegerBounded x
  | x > upper = throwM Overflow
  | x < lower = throwM Underflow
  | otherwise = pure (fromInteger x)
  where
    -- Target bounds, widened to Integer so the comparison cannot wrap.
    upper = toInteger (maxBound :: a)
    lower = toInteger (minBound :: a)
{-# INLINABLE fromIntegerBounded #-}