{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Trustworthy #-}

-- |
-- Module      :   Grisette.Core.Data.Class.Error
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.Data.Class.Error
  ( -- * Error transformation
    TransformError (..),

    -- * Throwing error
    symAssertWith,
    symAssertTransformableError,
    symThrowTransformableError,
  )
where

import Control.Monad.Except
import Grisette.Core.Control.Monad.Union
import Grisette.Core.Data.Class.Bool
import Grisette.Core.Data.Class.Mergeable
import Grisette.Core.Data.Class.SimpleMergeable
import {-# SOURCE #-} Grisette.IR.SymPrim.Data.SymPrim

-- $setup
-- >>> import Control.Exception
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim
-- >>> :set -XOverloadedStrings
-- >>> :set -XFlexibleContexts

-- | This class indicates that the error type @to@ can always represent the
-- error type @from@.
--
-- This is useful in implementing generic procedures that may throw errors.
-- For example, we support symbolic division and modulo operations. These
-- operations should throw an error when the divisor is zero, and we use the
-- standard error type 'Control.Exception.ArithException' for this purpose.
-- However, the user may use other types to represent errors, so we need this
-- type class to transform the 'Control.Exception.ArithException' to the
-- user-defined types.
--
-- Another example of these generic procedures is the
-- 'Grisette.Core.symAssert' and 'Grisette.Core.symAssume' functions.
-- They can be used with any error types that are
-- compatible with the 'Grisette.Core.AssertionError' and
-- 'Grisette.Core.VerificationConditions' types, respectively.
--
-- This module provides two generic instances: every type can be transformed
-- into itself (via 'id'), and every type can be transformed into @()@ (the
-- original error is discarded).
class TransformError from to where
  -- | Transforms an error with type @from@ to an error with type @to@.
  transformError :: from -> to

-- | Every error type can trivially represent itself.
--
-- Marked @OVERLAPPABLE@ so that more specific instances (such as the one
-- targeting @()@) take precedence.
instance {-# OVERLAPPABLE #-} TransformError a a where
  transformError = id
  {-# INLINE transformError #-}

-- | Any error can be transformed into @()@ by discarding it.
instance {-# OVERLAPS #-} TransformError a () where
  transformError _ = ()
  {-# INLINE transformError #-}

-- | Disambiguates between the @TransformError a a@ and
-- @TransformError a ()@ instances when both types are @()@.
instance {-# OVERLAPPING #-} TransformError () () where
  transformError _ = ()
  {-# INLINE transformError #-}

-- | Used within a monadic multi path computation to begin exception processing.
--
-- Terminate the current execution path with the specified error. Compatible
-- errors can be transformed.
--
-- The error is first converted with 'transformError', then thrown with
-- 'throwError', and the resulting computation is merged with 'merge'.
--
-- >>> symThrowTransformableError Overflow :: ExceptT AssertionError UnionM ()
-- ExceptT {Left AssertionError}
symThrowTransformableError ::
  ( Mergeable to,
    Mergeable a,
    TransformError from to,
    MonadError to erm,
    MonadUnion erm
  ) =>
  -- | The error to throw; it is converted to @to@ before being thrown.
  from ->
  erm a
symThrowTransformableError = merge . throwError . transformError
{-# INLINE symThrowTransformableError #-}

-- | Used within a monadic multi path computation for exception processing.
--
-- Terminate the current execution path with the specified error if the
-- condition does not hold. Compatible errors can be transformed.
--
-- >>> let assert = symAssertTransformableError AssertionError
-- >>> assert "a" :: ExceptT AssertionError UnionM ()
-- ExceptT {If (! a) (Left AssertionError) (Right ())}
symAssertTransformableError ::
  ( Mergeable to,
    TransformError from to,
    MonadError to erm,
    MonadUnion erm
  ) =>
  -- | The error to throw (after transformation) when the condition is false.
  from ->
  -- | The condition that must hold for the path to continue.
  SymBool ->
  erm ()
symAssertTransformableError err cond =
  mrgIf cond (return ()) (symThrowTransformableError err)
{-# INLINE symAssertTransformableError #-}

-- | Used within a monadic multi path computation for exception processing.
--
-- Terminate the current execution path with the given error if the condition
-- does not hold. Unlike 'symAssertTransformableError', the error is thrown
-- as is, without going through 'TransformError'.
symAssertWith ::
  ( Mergeable e,
    MonadError e erm,
    MonadUnion erm
  ) =>
  -- | The error to throw when the condition is false.
  e ->
  -- | The condition that must hold for the path to continue.
  SymBool ->
  erm ()
symAssertWith err cond = mrgIf cond (return ()) (throwError err)
{-# INLINE symAssertWith #-}