{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Eta reduce" #-}

-- |
-- Module      :   Grisette.Unified.Internal.Class.UnifiedITEOp
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Unified.Internal.Class.UnifiedITEOp
  ( symIte,
    symIteMerge,
    UnifiedITEOp (..),
  )
where

import Control.Monad.Identity (Identity (runIdentity))
import Data.Kind (Constraint)
import Data.Type.Bool (If)
import Data.Typeable (Typeable)
import Grisette.Internal.Core.Control.Monad.Union (Union)
import Grisette.Internal.Core.Data.Class.ITEOp (ITEOp)
import qualified Grisette.Internal.Core.Data.Class.ITEOp
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import qualified Grisette.Internal.Core.Data.Class.PlainUnion
import Grisette.Unified.Internal.BaseMonad (BaseMonad)
import Grisette.Unified.Internal.EvalModeTag (EvalModeTag (S), IsConMode)
import Grisette.Unified.Internal.UnifiedBool (UnifiedBool (GetBool))
import Grisette.Unified.Internal.Util (withMode)

-- | Unified `Grisette.Internal.Core.Data.Class.ITEOp.symIte` operation.
--
-- In concrete mode (@'C@) the condition is a plain 'Bool' and this is an
-- ordinary @if c then a else b@. In symbolic mode (@'S@) the call is
-- forwarded to the core `Grisette.Internal.Core.Data.Class.ITEOp.symIte`,
-- which requires an 'ITEOp' instance for @v@.
--
-- This function isn't able to infer the mode of the boolean variable, so you
-- need to provide the mode explicitly. For example:
--
-- > symIte @mode (a .== b) ...
-- > symIte (a .== b :: SymBool) ...
-- > symIte (a .== b :: GetBool mode) ...
symIte ::
  forall mode v.
  (Typeable mode, UnifiedITEOp mode v) =>
  GetBool mode ->
  v ->
  v ->
  v
symIte c a b =
  withMode @mode
    -- Concrete mode: the condition is a native 'Bool', so a plain
    -- conditional is enough.
    (withBaseITEOp @mode @v $ if c then a else b)
    -- Symbolic mode: 'withBaseITEOp' brings the 'ITEOp' constraint on @v@
    -- into scope, as the core symbolic if-then-else needs it.
    ( withBaseITEOp @mode @v $
        Grisette.Internal.Core.Data.Class.ITEOp.symIte c a b
    )

-- | Unified `Grisette.Internal.Core.Data.Class.PlainUnion.symIteMerge`
-- operation.
--
-- In concrete mode the base monad is 'Identity', so the single wrapped value
-- is returned with 'runIdentity'. In symbolic mode the base monad is 'Union',
-- and the call is forwarded to the core
-- `Grisette.Internal.Core.Data.Class.PlainUnion.symIteMerge`, whose signature
-- requires both 'ITEOp' and 'Mergeable' for @v@.
--
-- This function isn't able to infer the mode of the base monad from the result,
-- so you need to provide the mode explicitly. For example:
--
-- > symIteMerge @mode ...
-- > symIteMerge (... :: BaseMonad mode v) ...
symIteMerge ::
  forall mode v.
  (Typeable mode, UnifiedITEOp mode v, Mergeable v) =>
  BaseMonad mode v ->
  v
symIteMerge m =
  withMode @mode
    -- Concrete mode: the base monad is 'Identity'; just unwrap it.
    (withBaseITEOp @mode @v $ runIdentity m)
    -- Symbolic mode: the base monad is 'Union'; 'withBaseITEOp' supplies
    -- the 'ITEOp' constraint on @v@ that the core function needs.
    ( withBaseITEOp @mode @v $
        Grisette.Internal.Core.Data.Class.PlainUnion.symIteMerge m
    )

-- | A class that provides a unified if-then-else ('ITEOp') capability across
-- evaluation modes.
--
-- We use this type class to help resolve the constraints for `ITEOp`. In
-- concrete mode no constraint is required; otherwise the value type must have
-- an 'ITEOp' instance.
class UnifiedITEOp mode v where
  -- | Run a continuation that needs the mode-appropriate 'ITEOp' constraint
  -- for @v@. That constraint is the empty constraint @()@ when
  -- @IsConMode mode@ holds, and @'ITEOp' v@ otherwise.
  withBaseITEOp ::
    ((If (IsConMode mode) (() :: Constraint) (ITEOp v)) => r) -> r

-- | Fallback instance for any type whose mode-appropriate 'ITEOp' constraint
-- is already satisfiable from the context.
--
-- It is marked INCOHERENT so that GHC may commit to it even when a more
-- specific instance (such as the 'Union' instance in this module) might also
-- match, e.g. when @mode@ is still polymorphic.
instance
  {-# INCOHERENT #-}
  ( Typeable mode,
    If (IsConMode mode) (() :: Constraint) (ITEOp a)
  ) =>
  UnifiedITEOp mode a
  where
  -- Dispatch on the mode tag. In both branches the continuation's
  -- constraint is discharged by this instance's own context, so the same
  -- continuation is passed for both modes.
  withBaseITEOp r = withMode @mode r r
  {-# INLINE withBaseITEOp #-}

-- | Symbolic-mode instance for 'Union' values: the 'ITEOp' constraint for
-- @'Union' v@ is derived from the element type's 'UnifiedITEOp' instance
-- together with 'Mergeable' @v@.
instance (Mergeable v, UnifiedITEOp 'S v) => UnifiedITEOp 'S (Union v) where
  -- Obtain @'ITEOp' v@ from the element instance; with 'Mergeable' @v@ in
  -- scope this yields @'ITEOp' ('Union' v)@.
  -- NOTE(review): relies on an @('ITEOp' v, 'Mergeable' v) =>
  -- 'ITEOp' ('Union' v)@ instance defined elsewhere — confirm.
  withBaseITEOp r = withBaseITEOp @'S @v r
  {-# INLINE withBaseITEOp #-}