{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.SafeDivision
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.SafeDivision
  ( ArithException (..),
    SafeDivision (..),
  )
where

import Control.Exception (ArithException (DivideByZero, Overflow, Underflow))
import Control.Monad.Except (MonadError (throwError))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeNats (KnownNat, type (<=))
import Grisette.Internal.Core.Control.Monad.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.LogicalOp (LogicalOp ((.&&)))
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SEq (SEq ((.==)))
import Grisette.Internal.Core.Data.Class.SimpleMergeable
  ( mrgIf,
  )
import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con))
import Grisette.Internal.Core.Data.Class.TryMerge
  ( TryMerge,
    mrgSingle,
  )
import Grisette.Internal.SymPrim.BV
  ( IntN,
    WordN,
  )
import Grisette.Internal.SymPrim.Prim.Term
  ( PEvalDivModIntegralTerm
      ( pevalDivIntegralTerm,
        pevalModIntegralTerm,
        pevalQuotIntegralTerm,
        pevalRemIntegralTerm
      ),
  )
import Grisette.Internal.SymPrim.SymBV
  ( SymIntN (SymIntN),
    SymWordN (SymWordN),
  )
import Grisette.Internal.SymPrim.SymInteger (SymInteger (SymInteger))
import Grisette.Lib.Control.Monad (mrgReturn)
import Grisette.Lib.Control.Monad.Except (mrgThrowError)
import Grisette.Lib.Data.Functor (mrgFmap)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Control.Monad.Except

-- | Safe division with monadic error handling in multi-path
-- execution. These procedures throw an exception when the
-- divisor is zero. The result should be able to handle errors with
-- `MonadError`.
--
-- The instances defined in this module throw 'DivideByZero' when the divisor
-- is zero. For the signed bounded types they additionally throw 'Overflow'
-- when 'minBound' is divided by @-1@ with 'safeDiv', 'safeQuot', 'safeDivMod'
-- or 'safeQuotRem'.
--
-- A minimal instance defines either 'safeDiv' and 'safeMod' or 'safeDivMod',
-- and either 'safeQuot' and 'safeRem' or 'safeQuotRem'; the remaining
-- methods are derived from those by the default implementations.
class (MonadError e m, TryMerge m, Mergeable a) => SafeDivision e a m where
  -- | Safe signed 'div' with monadic error handling in multi-path execution.
  --
  -- Defaults to the first component of 'safeDivMod'.
  --
  -- >>> safeDiv (ssym "a") (ssym "b") :: ExceptT ArithException UnionM SymInteger
  -- ExceptT {If (= b 0) (Left divide by zero) (Right (div a b))}
  safeDiv :: a -> a -> m a
  safeDiv l r = mrgFmap fst $ safeDivMod l r
  {-# INLINE safeDiv #-}

  -- | Safe signed 'mod' with monadic error handling in multi-path execution.
  --
  -- Defaults to the second component of 'safeDivMod'.
  --
  -- >>> safeMod (ssym "a") (ssym "b") :: ExceptT ArithException UnionM SymInteger
  -- ExceptT {If (= b 0) (Left divide by zero) (Right (mod a b))}
  safeMod :: a -> a -> m a
  safeMod l r = mrgFmap snd $ safeDivMod l r
  {-# INLINE safeMod #-}

  -- | Safe signed 'divMod' with monadic error handling in multi-path execution.
  --
  -- Defaults to running 'safeDiv' and then 'safeMod' on the same operands and
  -- pairing their results.
  --
  -- >>> safeDivMod (ssym "a") (ssym "b") :: ExceptT ArithException UnionM (SymInteger, SymInteger)
  -- ExceptT {If (= b 0) (Left divide by zero) (Right ((div a b),(mod a b)))}
  safeDivMod :: a -> a -> m (a, a)
  safeDivMod l r = do
    d <- safeDiv l r
    m <- safeMod l r
    mrgReturn (d, m)
  {-# INLINE safeDivMod #-}

  -- | Safe signed 'quot' with monadic error handling in multi-path execution.
  --
  -- Defaults to the first component of 'safeQuotRem'.
  safeQuot :: a -> a -> m a
  safeQuot l r = mrgFmap fst $ safeQuotRem l r
  {-# INLINE safeQuot #-}

  -- | Safe signed 'rem' with monadic error handling in multi-path execution.
  --
  -- Defaults to the second component of 'safeQuotRem'.
  safeRem :: a -> a -> m a
  safeRem l r = mrgFmap snd $ safeQuotRem l r
  {-# INLINE safeRem #-}

  -- | Safe signed 'quotRem' with monadic error handling in multi-path execution.
  --
  -- Defaults to running 'safeQuot' and then 'safeRem' on the same operands and
  -- pairing their results.
  safeQuotRem :: a -> a -> m (a, a)
  safeQuotRem l r = do
    q <- safeQuot l r
    m <- safeRem l r
    mrgReturn (q, m)
  {-# INLINE safeQuotRem #-}

  {-# MINIMAL
    ((safeDiv, safeMod) | safeDivMod),
    ((safeQuot, safeRem) | safeQuotRem)
    #-}

-- | Checked wrapper around a concrete division-like operation.
--
-- @concreteSafeDivisionHelper f l r@ throws 'DivideByZero' (via
-- 'mrgThrowError') when the divisor @r@ is zero. Otherwise it returns
-- @f l r@ via 'mrgReturn'. Here @f@ is the unchecked operation, such as
-- 'div', 'mod' or 'quotRem'.
concreteSafeDivisionHelper ::
  (MonadError ArithException m, TryMerge m, Integral a, Mergeable r) =>
  (a -> a -> r) ->
  a ->
  a ->
  m r
concreteSafeDivisionHelper f l r
  | r == 0 = mrgThrowError DivideByZero
  | otherwise = mrgReturn $ f l r

-- | Checked wrapper around a concrete division-like operation on a signed,
-- bounded type.
--
-- @concreteSignedBoundedSafeDivisionHelper f l r@ checks the operands in order:
--
-- * It throws 'DivideByZero' when the divisor @r@ is zero.
-- * It throws 'Overflow' when the dividend @l@ is 'minBound' and the divisor
--   is @-1@, because the quotient would not fit in the type.
-- * Otherwise it returns @f l r@ via 'mrgReturn'.
concreteSignedBoundedSafeDivisionHelper ::
  ( MonadError ArithException m,
    TryMerge m,
    Integral a,
    Bounded a,
    Mergeable r
  ) =>
  (a -> a -> r) ->
  a ->
  a ->
  m r
concreteSignedBoundedSafeDivisionHelper f l r
  | r == 0 = mrgThrowError DivideByZero
  | l == minBound && r == -1 = mrgThrowError Overflow
  | otherwise = mrgReturn $ f l r

-- Small token helpers for the C preprocessor. QUOTE expands to a lone
-- apostrophe and QID expands to its argument unchanged. The QRIGHT family
-- appends an apostrophe to its argument, and the T and U variants then emit
-- one further token. NOTE(review): none of these macros is referenced in the
-- visible part of this module; confirm whether they are still needed.
#define QUOTE() '
#define QID(a) a
#define QRIGHT(a) QID(a)'

#define QRIGHTT(a) QID(a)' t'
#define QRIGHTU(a) QID(a)' _'

-- Declares a SafeDivision instance for a concrete Integral type. Every method
-- delegates to concreteSafeDivisionHelper with the matching unchecked
-- operation, so the only error ever raised is DivideByZero. It is therefore
-- only suitable for types whose division cannot overflow, such as Integer
-- and the unsigned Word types.
#define SAFE_DIVISION_CONCRETE(type) \
instance (MonadError ArithException m, TryMerge m) => \
  SafeDivision ArithException type m where \
  safeDiv = concreteSafeDivisionHelper div; \
  safeMod = concreteSafeDivisionHelper mod; \
  safeDivMod = concreteSafeDivisionHelper divMod; \
  safeQuot = concreteSafeDivisionHelper quot; \
  safeRem = concreteSafeDivisionHelper rem; \
  safeQuotRem = concreteSafeDivisionHelper quotRem

-- Declares a SafeDivision instance for a signed, bounded concrete type.
-- The quotient-producing operations (div, quot, divMod, quotRem) go through
-- concreteSignedBoundedSafeDivisionHelper. That helper also rejects minBound
-- divided by -1 with Overflow. The mod and rem operations only check for a
-- zero divisor, because in GHC base they return 0 for a divisor of -1 rather
-- than overflowing.
#define SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(type) \
instance (MonadError ArithException m, TryMerge m) => \
  SafeDivision ArithException type m where \
  safeDiv = concreteSignedBoundedSafeDivisionHelper div; \
  safeMod = concreteSafeDivisionHelper mod; \
  safeDivMod = concreteSignedBoundedSafeDivisionHelper divMod; \
  safeQuot = concreteSignedBoundedSafeDivisionHelper quot; \
  safeRem = concreteSafeDivisionHelper rem; \
  safeQuotRem = concreteSignedBoundedSafeDivisionHelper quotRem

-- Declares a SafeDivision instance for a bit-vector type constructor applied
-- to a width n with 1 <= n. Every method only checks for a zero divisor (via
-- concreteSafeDivisionHelper). It is therefore only appropriate for
-- bit-vectors whose division cannot overflow, e.g. unsigned ones.
#define SAFE_DIVISION_CONCRETE_BV(type) \
instance \
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) => \
  SafeDivision ArithException (type n) m where \
  safeDiv = concreteSafeDivisionHelper div; \
  safeMod = concreteSafeDivisionHelper mod; \
  safeDivMod = concreteSafeDivisionHelper divMod; \
  safeQuot = concreteSafeDivisionHelper quot; \
  safeRem = concreteSafeDivisionHelper rem; \
  safeQuotRem = concreteSafeDivisionHelper quotRem

#if 1
-- Concrete machine and arbitrary-precision integers.
SAFE_DIVISION_CONCRETE(Integer)
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int8)
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int16)
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int32)
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int64)
SAFE_DIVISION_CONCRETE_SIGNED_BOUNDED(Int)
SAFE_DIVISION_CONCRETE(Word8)
SAFE_DIVISION_CONCRETE(Word16)
SAFE_DIVISION_CONCRETE(Word32)
SAFE_DIVISION_CONCRETE(Word64)
SAFE_DIVISION_CONCRETE(Word)

-- Concrete signed bit-vectors. These behave like the fixed-width signed
-- integers: div, quot, divMod and quotRem also reject minBound divided by -1
-- with Overflow, while mod and rem only check for a zero divisor.
instance
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) =>
  SafeDivision ArithException (IntN n) m
  where
  safeDiv = concreteSignedBoundedSafeDivisionHelper div
  safeMod = concreteSafeDivisionHelper mod
  safeDivMod = concreteSignedBoundedSafeDivisionHelper divMod
  safeQuot = concreteSignedBoundedSafeDivisionHelper quot
  safeRem = concreteSafeDivisionHelper rem
  safeQuotRem = concreteSignedBoundedSafeDivisionHelper quotRem

-- Concrete unsigned bit-vectors: only a zero divisor is an error.
SAFE_DIVISION_CONCRETE_BV(WordN)
#endif

-- Expands to one instance method definition for a symbolic wrapper whose
-- constructor holds a term. The generated method branches with mrgIf on
-- whether the divisor equals con 0. On that branch it throws DivideByZero;
-- otherwise it returns the wrapped result of applying the term-level
-- operation to both payloads.
#define SAFE_DIVISION_SYMBOLIC_FUNC(name, type, op) \
name (type l) rs@(type r) = \
  mrgIf \
    (rs .== con 0) \
    (throwError DivideByZero) \
    (mrgSingle $ type $ op l r); \

-- Same as the single-result variant above, but the success branch returns a
-- pair built from two term-level operations (e.g. a quotient and remainder).
#define SAFE_DIVISION_SYMBOLIC_FUNC2(name, type, op1, op2) \
name (type l) rs@(type r) = \
  mrgIf \
    (rs .== con 0) \
    (throwError DivideByZero) \
    (mrgSingle (type $ op1 l r, type $ op2 l r)); \

#if 1
-- Symbolic unbounded integers: every method only guards against a zero
-- divisor, which the generated definitions test symbolically.
instance
  (MonadUnion m, MonadError ArithException m) =>
  SafeDivision ArithException SymInteger m where
  -- Paired results.
  SAFE_DIVISION_SYMBOLIC_FUNC2(safeDivMod, SymInteger, pevalDivIntegralTerm, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC2(safeQuotRem, SymInteger, pevalQuotIntegralTerm, pevalRemIntegralTerm)
  -- Flooring division and modulus.
  SAFE_DIVISION_SYMBOLIC_FUNC(safeDiv, SymInteger, pevalDivIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeMod, SymInteger, pevalModIntegralTerm)
  -- Truncating division and remainder.
  SAFE_DIVISION_SYMBOLIC_FUNC(safeQuot, SymInteger, pevalQuotIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeRem, SymInteger, pevalRemIntegralTerm)
#endif

-- Expands to one instance method definition for a symbolic signed bounded
-- wrapper. The generated method checks two conditions in turn:
-- on paths where the divisor equals con 0 it throws DivideByZero; on paths
-- where the divisor is con (-1) and the dividend is con minBound it throws
-- Overflow; otherwise it returns the wrapped result of the term-level
-- operation.
#define SAFE_DIVISION_SYMBOLIC_FUNC_BOUNDED_SIGNED(name, type, op) \
name ls@(type l) rs@(type r) = \
  mrgIf \
    (rs .== con 0) \
    (throwError DivideByZero) \
    (mrgIf (rs .== con (-1) .&& ls .== con minBound) \
      (throwError Overflow) \
      (mrgSingle $ type $ op l r)); \

-- Pair-returning counterpart of the macro above: same two checks, but the
-- success branch returns the results of two term-level operations.
#define SAFE_DIVISION_SYMBOLIC_FUNC2_BOUNDED_SIGNED(name, type, op1, op2) \
name ls@(type l) rs@(type r) = \
  mrgIf \
    (rs .== con 0) \
    (throwError DivideByZero) \
    (mrgIf (rs .== con (-1) .&& ls .== con minBound) \
      (throwError Overflow) \
      (mrgSingle (type $ op1 l r, type $ op2 l r))); \

#if 1
-- Symbolic signed bit-vectors: the quotient-producing methods also reject
-- minBound divided by -1 with Overflow, while mod and rem are only checked
-- for a zero divisor.
instance
  (MonadError ArithException m, MonadUnion m, KnownNat n, 1 <= n) =>
  SafeDivision ArithException (SymIntN n) m where
  -- Overflow-checked, single result.
  SAFE_DIVISION_SYMBOLIC_FUNC_BOUNDED_SIGNED(safeDiv, SymIntN, pevalDivIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC_BOUNDED_SIGNED(safeQuot, SymIntN, pevalQuotIntegralTerm)
  -- Overflow-checked, paired result.
  SAFE_DIVISION_SYMBOLIC_FUNC2_BOUNDED_SIGNED(safeDivMod, SymIntN, pevalDivIntegralTerm, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC2_BOUNDED_SIGNED(safeQuotRem, SymIntN, pevalQuotIntegralTerm, pevalRemIntegralTerm)
  -- Zero-divisor check only.
  SAFE_DIVISION_SYMBOLIC_FUNC(safeMod, SymIntN, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeRem, SymIntN, pevalRemIntegralTerm)
#endif

#if 1
-- Symbolic unsigned bit-vectors: every method only guards against a zero
-- divisor.
instance
  (MonadError ArithException m, MonadUnion m, KnownNat n, 1 <= n) =>
  SafeDivision ArithException (SymWordN n) m where
  -- Paired results.
  SAFE_DIVISION_SYMBOLIC_FUNC2(safeDivMod, SymWordN, pevalDivIntegralTerm, pevalModIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC2(safeQuotRem, SymWordN, pevalQuotIntegralTerm, pevalRemIntegralTerm)
  -- Flooring division and modulus.
  SAFE_DIVISION_SYMBOLIC_FUNC(safeDiv, SymWordN, pevalDivIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeMod, SymWordN, pevalModIntegralTerm)
  -- Truncating division and remainder.
  SAFE_DIVISION_SYMBOLIC_FUNC(safeQuot, SymWordN, pevalQuotIntegralTerm)
  SAFE_DIVISION_SYMBOLIC_FUNC(safeRem, SymWordN, pevalRemIntegralTerm)
#endif