{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Grisette.Internal.Core.Data.Class.SafeDivision
( ArithException (..),
SafeDivision (..),
)
where
import Control.Exception (ArithException (DivideByZero, Overflow, Underflow))
import Control.Monad.Except (MonadError (throwError))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeNats (KnownNat, type (<=))
import Grisette.Internal.Core.Control.Monad.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.LogicalOp (LogicalOp ((.&&)))
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SEq (SEq ((.==)))
import Grisette.Internal.Core.Data.Class.SimpleMergeable
( mrgIf,
)
import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con))
import Grisette.Internal.Core.Data.Class.TryMerge
( TryMerge,
mrgSingle,
)
import Grisette.Internal.SymPrim.BV
( IntN,
WordN,
)
import Grisette.Internal.SymPrim.Prim.Term
( PEvalDivModIntegralTerm
( pevalDivIntegralTerm,
pevalModIntegralTerm,
pevalQuotIntegralTerm,
pevalRemIntegralTerm
),
)
import Grisette.Internal.SymPrim.SymBV
( SymIntN (SymIntN),
SymWordN (SymWordN),
)
import Grisette.Internal.SymPrim.SymInteger (SymInteger (SymInteger))
import Grisette.Lib.Control.Monad (mrgReturn)
import Grisette.Lib.Control.Monad.Except (mrgThrowError)
import Grisette.Lib.Data.Functor (mrgFmap)
-- | Safe division with monadic error handling in multi-path execution.
--
-- Instead of raising an exception, each operation reports a failed division
-- through 'MonadError' and merges its result with 'TryMerge', so it can be
-- used in symbolic (multi-path) code. The instances in this module throw
-- 'DivideByZero' when the divisor is zero. The instances for signed bounded
-- types also throw 'Overflow' for 'div', 'divMod', 'quot' and 'quotRem' of
-- 'minBound' by @-1@.
--
-- A minimal instance provides either 'safeDiv' and 'safeMod', or
-- 'safeDivMod'. It also provides either 'safeQuot' and 'safeRem', or
-- 'safeQuotRem'. The remaining methods default to being derived from those.
class (MonadError e m, TryMerge m, Mergeable a) => SafeDivision e a m where
  -- | Safe counterpart of 'div'.
  --
  -- Defaults to the first component of 'safeDivMod'.
  safeDiv :: a -> a -> m a
  safeDiv l r = mrgFmap fst $ safeDivMod l r
  {-# INLINE safeDiv #-}

  -- | Safe counterpart of 'mod'.
  --
  -- Defaults to the second component of 'safeDivMod'.
  safeMod :: a -> a -> m a
  safeMod l r = mrgFmap snd $ safeDivMod l r
  {-# INLINE safeMod #-}

  -- | Safe counterpart of 'divMod'.
  --
  -- Defaults to running 'safeDiv' and then 'safeMod' on the same operands
  -- and returning the merged pair.
  safeDivMod :: a -> a -> m (a, a)
  safeDivMod l r = do
    divResult <- safeDiv l r
    modResult <- safeMod l r
    mrgReturn (divResult, modResult)
  {-# INLINE safeDivMod #-}

  -- | Safe counterpart of 'quot'.
  --
  -- Defaults to the first component of 'safeQuotRem'.
  safeQuot :: a -> a -> m a
  safeQuot l r = mrgFmap fst $ safeQuotRem l r
  {-# INLINE safeQuot #-}

  -- | Safe counterpart of 'rem'.
  --
  -- Defaults to the second component of 'safeQuotRem'.
  safeRem :: a -> a -> m a
  safeRem l r = mrgFmap snd $ safeQuotRem l r
  {-# INLINE safeRem #-}

  -- | Safe counterpart of 'quotRem'.
  --
  -- Defaults to running 'safeQuot' and then 'safeRem' on the same operands
  -- and returning the merged pair.
  safeQuotRem :: a -> a -> m (a, a)
  safeQuotRem l r = do
    quotResult <- safeQuot l r
    remResult <- safeRem l r
    mrgReturn (quotResult, remResult)
  {-# INLINE safeQuotRem #-}

  {-# MINIMAL
    ((safeDiv, safeMod) | safeDivMod),
    ((safeQuot, safeRem) | safeQuotRem)
    #-}
-- | Run a pure integral operation, reporting a zero divisor through
-- 'MonadError'.
--
-- Throws 'DivideByZero' (via 'mrgThrowError') when the divisor @r@ is zero.
-- Otherwise it returns @f l r@ through 'mrgReturn'. It performs no overflow
-- check, so it suits types where only a zero divisor can fail, and 'mod' or
-- 'rem' on signed types.
concreteSafeDivisionHelper ::
  (MonadError ArithException m, TryMerge m, Integral a, Mergeable r) =>
  -- | The underlying operation, e.g. 'div' or 'quotRem'.
  (a -> a -> r) ->
  -- | Dividend.
  a ->
  -- | Divisor.
  a ->
  m r
concreteSafeDivisionHelper f l r
  | r == 0 = mrgThrowError DivideByZero
  | otherwise = mrgReturn $ f l r
-- | Like 'concreteSafeDivisionHelper', but also rejects the one signed
-- bounded case whose quotient does not fit in the type.
--
-- Throws 'DivideByZero' when the divisor is zero. It throws 'Overflow' when
-- the dividend is 'minBound' and the divisor is @-1@. Otherwise it returns
-- @f l r@ through 'mrgReturn'.
concreteSignedBoundedSafeDivisionHelper ::
  ( MonadError ArithException m,
    TryMerge m,
    Integral a,
    Bounded a,
    Mergeable r
  ) =>
  -- | The underlying operation, e.g. 'div' or 'quotRem'.
  (a -> a -> r) ->
  -- | Dividend.
  a ->
  -- | Divisor.
  a ->
  m r
concreteSignedBoundedSafeDivisionHelper f l r
  | r == 0 = mrgThrowError DivideByZero
  -- The true quotient of minBound by -1 is maxBound + 1, which is not
  -- representable.
  | l == minBound && r == -1 = mrgThrowError Overflow
  | otherwise = mrgReturn $ f l r
-- Quote-character helpers for use inside CPP macros. QID expands to its
-- argument unchanged. QRIGHT appends a quote character to its argument.
-- QRIGHTT and QRIGHTU additionally append the tokens t and _, each
-- followed by a quote character. Routing the quote through QID is
-- presumably meant to stop CPP from treating it as a character literal.
-- NOTE(review): none of these macros are referenced elsewhere in this
-- file as shown; confirm they are still needed.
#define QUOTE() '
#define QID(a) a
#define QRIGHT(a) QID(a)'
#define QRIGHTT(a) QID(a)' t'
#define QRIGHTU(a) QID(a)' _'
-- Instance template for a concrete Integral type where a zero divisor is
-- the only possible failure. Every method delegates to
-- concreteSafeDivisionHelper. That helper throws DivideByZero for a zero
-- divisor and otherwise returns the result of the wrapped operation.
#define SAFE_DIVISION_CONCRETE(type) \
instance (MonadError ArithException m, TryMerge m) => \
SafeDivision ArithException type m where \
safeDiv = concreteSafeDivisionHelper div; \
safeMod = concreteSafeDivisionHelper mod; \
safeDivMod = concreteSafeDivisionHelper divMod; \
safeQuot = concreteSafeDivisionHelper quot; \
safeRem = concreteSafeDivisionHelper rem; \
safeQuotRem = concreteSafeDivisionHelper quotRem
-- Instance template for concrete signed bounded Integral types.
-- div, divMod, quot and quotRem go through
-- concreteSignedBoundedSafeDivisionHelper, which also throws Overflow for
-- minBound divided by -1. mod and rem only check for a zero divisor,
-- presumably because base already defines those results as 0 for a
-- divisor of -1.
#define SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(type) \
instance (MonadError ArithException m, TryMerge m) => \
SafeDivision ArithException type m where \
safeDiv = concreteSignedBoundedSafeDivisionHelper div; \
safeMod = concreteSafeDivisionHelper mod; \
safeDivMod = concreteSignedBoundedSafeDivisionHelper divMod; \
safeQuot = concreteSignedBoundedSafeDivisionHelper quot; \
safeRem = concreteSafeDivisionHelper rem; \
safeQuotRem = concreteSignedBoundedSafeDivisionHelper quotRem
-- Instance template for a bit-vector type constructor applied to a width
-- n with 1 <= n. Every method uses only the zero-divisor check from
-- concreteSafeDivisionHelper.
-- NOTE(review): this macro is not expanded anywhere in this file as shown;
-- confirm whether it is still needed.
#define SAFE_DIVISION_CONCRETE_BV(type) \
instance \
(MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) => \
SafeDivision ArithException (type n) m where \
safeDiv = concreteSafeDivisionHelper div; \
safeMod = concreteSafeDivisionHelper mod; \
safeDivMod = concreteSafeDivisionHelper divMod; \
safeQuot = concreteSafeDivisionHelper quot; \
safeRem = concreteSafeDivisionHelper rem; \
safeQuotRem = concreteSafeDivisionHelper quotRem
#if 1
-- Unbounded Integer: only a zero divisor can fail.
SAFE_DIVISION_CONCRETE(Integer)
-- Signed fixed-width integers: minBound divided by -1 also fails with Overflow.
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int8)
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int16)
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int32)
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int64)
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int)
-- Unsigned fixed-width integers: only a zero divisor can fail.
SAFE_DIVISION_CONCRETE(Word8)
SAFE_DIVISION_CONCRETE(Word16)
SAFE_DIVISION_CONCRETE(Word32)
SAFE_DIVISION_CONCRETE(Word64)
SAFE_DIVISION_CONCRETE(Word)
-- | Concrete signed bit-vectors.
--
-- A zero divisor throws 'DivideByZero'. 'safeDiv', 'safeDivMod', 'safeQuot'
-- and 'safeQuotRem' of 'minBound' by @-1@ throw 'Overflow'. 'safeMod' and
-- 'safeRem' only check for a zero divisor, matching the instances for the
-- fixed-width 'Int' types in this module.
instance
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) =>
  SafeDivision ArithException (IntN n) m
  where
  safeDiv = concreteSignedBoundedSafeDivisionHelper div
  safeMod = concreteSafeDivisionHelper mod
  safeDivMod = concreteSignedBoundedSafeDivisionHelper divMod
  safeQuot = concreteSignedBoundedSafeDivisionHelper quot
  safeRem = concreteSafeDivisionHelper rem
  safeQuotRem = concreteSignedBoundedSafeDivisionHelper quotRem
-- | Concrete unsigned bit-vectors: every operation fails only on a zero
-- divisor, which throws 'DivideByZero'.
instance
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) =>
  SafeDivision ArithException (WordN n) m
  where
  safeDiv = concreteSafeDivisionHelper div
  safeMod = concreteSafeDivisionHelper mod
  safeDivMod = concreteSafeDivisionHelper divMod
  safeQuot = concreteSafeDivisionHelper quot
  safeRem = concreteSafeDivisionHelper rem
  safeQuotRem = concreteSafeDivisionHelper quotRem
#endif
-- Defines one SafeDivision method for a symbolic wrapper type. The method
-- unwraps both operands and uses mrgIf to split on whether the divisor
-- equals zero. That branch throws DivideByZero; the other branch rewraps
-- the term built by op.
-- The last line of each macro intentionally has no trailing backslash.
-- Otherwise CPP would splice the following line into the macro body.
#define SAFE_DIVISION_SYMBOLIC_FUNC(name, type, op) \
name (type l) rs@(type r) = \
  mrgIf \
    (rs .== con 0) \
    (throwError DivideByZero) \
    (mrgSingle $ type $ op l r);
-- Pair-returning variant for the divMod and quotRem style methods. op1 and
-- op2 build the first and second components of the result.
#define SAFE_DIVISION_SYMBOLIC_FUNC2(name, type, op1, op2) \
name (type l) rs@(type r) = \
  mrgIf \
    (rs .== con 0) \
    (throwError DivideByZero) \
    (mrgSingle (type $ op1 l r, type $ op2 l r));
#if 1
-- | Symbolic unbounded integers: the divisor is compared with zero
-- symbolically, and the zero branch throws 'DivideByZero'.
instance
  (MonadUnion m, MonadError ArithException m) =>
  SafeDivision ArithException SymInteger m
  where
  SAFE_DIVISION_SYMBOLIC_FUNC(safeDiv, SymInteger, pevalDivIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeMod, SymInteger, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeQuot, SymInteger, pevalQuotIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeRem, SymInteger, pevalRemIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC2(safeDivMod, SymInteger, pevalDivIntegralTerm, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC2(safeQuotRem, SymInteger, pevalQuotIntegralTerm, pevalRemIntegralTerm)
#endif
-- Signed bounded variant of the symbolic method template. The method first
-- splits on a zero divisor, which throws DivideByZero. It then splits on a
-- dividend equal to minBound with a divisor of -1, which throws Overflow.
-- Otherwise it rewraps the term built by op.
-- The last line of each macro intentionally has no trailing backslash.
-- Otherwise CPP would splice the following line into the macro body.
#define SAFE_DIVISION_SYMBOLIC_FUNC_BOUNDED_SIGNED(name, type, op) \
name ls@(type l) rs@(type r) = \
  mrgIf \
    (rs .== con 0) \
    (throwError DivideByZero) \
    (mrgIf (rs .== con (-1) .&& ls .== con minBound) \
      (throwError Overflow) \
      (mrgSingle $ type $ op l r));
-- Pair-returning signed bounded variant. op1 and op2 build the first and
-- second components of the result.
#define SAFE_DIVISION_SYMBOLIC_FUNC2_BOUNDED_SIGNED(name, type, op1, op2) \
name ls@(type l) rs@(type r) = \
  mrgIf \
    (rs .== con 0) \
    (throwError DivideByZero) \
    (mrgIf (rs .== con (-1) .&& ls .== con minBound) \
      (throwError Overflow) \
      (mrgSingle (type $ op1 l r, type $ op2 l r)));
#if 1
-- | Symbolic signed bit-vectors.
--
-- A zero divisor throws 'DivideByZero'. 'safeDiv', 'safeQuot', 'safeDivMod'
-- and 'safeQuotRem' also throw 'Overflow' for 'minBound' by @-1@. 'safeMod'
-- and 'safeRem' only check for a zero divisor.
instance
  (MonadError ArithException m, MonadUnion m, KnownNat n, 1 <= n) =>
  SafeDivision ArithException (SymIntN n) m
  where
  SAFE_DIVISION_SYMBOLIC_FUNC_BOUNDED_SIGNED(safeDiv, SymIntN, pevalDivIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeMod, SymIntN, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC_BOUNDED_SIGNED(safeQuot, SymIntN, pevalQuotIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeRem, SymIntN, pevalRemIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC2_BOUNDED_SIGNED(safeDivMod, SymIntN, pevalDivIntegralTerm, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC2_BOUNDED_SIGNED(safeQuotRem, SymIntN, pevalQuotIntegralTerm, pevalRemIntegralTerm)
#endif
#if 1
-- | Symbolic unsigned bit-vectors: every operation fails only on a zero
-- divisor, which throws 'DivideByZero'.
instance
  (MonadError ArithException m, MonadUnion m, KnownNat n, 1 <= n) =>
  SafeDivision ArithException (SymWordN n) m
  where
  SAFE_DIVISION_SYMBOLIC_FUNC(safeDiv, SymWordN, pevalDivIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeMod, SymWordN, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeQuot, SymWordN, pevalQuotIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeRem, SymWordN, pevalRemIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC2(safeDivMod, SymWordN, pevalDivIntegralTerm, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC2(safeQuotRem, SymWordN, pevalQuotIntegralTerm, pevalRemIntegralTerm)
#endif