{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Unused LANGUAGE pragma" #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.SymTabularFun
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.SymTabularFun
  ( type (=~>) (SymTabularFun),
  )
where

import Control.DeepSeq (NFData (rnf))
import qualified Data.Binary as Binary
import Data.Bytes.Serial (Serial (deserialize, serialize))
import Data.Hashable (Hashable (hashWithSalt))
import qualified Data.Serialize as Cereal
import Data.String (IsString (fromString))
import Grisette.Internal.Core.Data.Class.Function
  ( Apply (FunType, apply),
    Function ((#)),
  )
import Grisette.Internal.Core.Data.Class.Solvable
  ( Solvable (con, conView, ssym, sym),
  )
import Grisette.Internal.SymPrim.AllSyms (AllSyms (allSymsS), SomeSym (SomeSym))
import Grisette.Internal.SymPrim.Prim.Term
  ( ConRep (ConType),
    LinkedRep (underlyingTerm, wrapTerm),
    PEvalApplyTerm (pevalApplyTerm),
    SupportedNonFuncPrim,
    SupportedPrim,
    SymRep (SymType),
    Term (ConTerm),
    conTerm,
    pformatTerm,
    symTerm,
    typedAnySymbol,
  )
import Grisette.Internal.SymPrim.TabularFun (type (=->))
import Language.Haskell.TH.Syntax (Lift (liftTyped))

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Grisette.Backend
-- >>> import Data.Proxy

-- | Symbolic tabular function type.
--
-- A value of type @sa =~> sb@ wraps a 'Term' of the concrete tabular function
-- type @ca =-> cb@, where @ca@ and @cb@ are the concrete types linked to the
-- symbolic types @sa@ and @sb@ through 'LinkedRep'. It can be a named symbolic
-- constant (for example written as a string literal) or a concrete table
-- lifted with 'con'. Apply it to a symbolic argument with @(#)@.
--
-- A named symbolic function stays symbolic when applied:
--
-- >>> f' = "f" :: SymInteger =~> SymInteger
-- >>> f = (f' #)
-- >>> f 1
-- (apply f 1)
--
-- A concrete table maps each listed input to its output and every other input
-- to the default value. Applying it to a symbolic argument yields a chain of
-- @ite@ terms:
--
-- >>> f' = con (TabularFun [(1, 2), (2, 3)] 4) :: SymInteger =~> SymInteger
-- >>> f = (f' #)
-- >>> f 1
-- 2
-- >>> f 2
-- 3
-- >>> f 3
-- 4
-- >>> f "b"
-- (ite (= b 1) 2 (ite (= b 2) 3 4))
data sa =~> sb where
  -- | Wraps a term of the concrete tabular function type @ca =-> cb@. The
  -- constraints are packed into the constructor, so pattern matching on it
  -- makes them available again.
  SymTabularFun ::
    ( LinkedRep ca sa,
      LinkedRep cb sb,
      SupportedPrim (ca =-> cb),
      SupportedNonFuncPrim ca
    ) =>
    Term (ca =-> cb) ->
    sa =~> sb

-- Right-associative at the lowest precedence, so "a =~> b =~> c" parses as
-- "a =~> (b =~> c)" (a curried function).
infixr 0 =~>

-- | Lifts a symbolic tabular function into a typed Template Haskell splice by
-- quoting the constructor applied to the (lifted) underlying term.
instance Lift (sa =~> sb) where
  liftTyped (SymTabularFun t) = [||SymTabularFun t||]

-- | Fully evaluates the wrapped term.
instance NFData (sa =~> sb) where
  rnf (SymTabularFun t) = rnf t

-- | The concrete counterpart of a symbolic tabular function is the concrete
-- tabular function between the concrete counterparts of its argument and
-- result types.
instance
  (ConRep a, ConRep b) =>
  ConRep (a =~> b)
  where
  type ConType (a =~> b) = (ConType a) =-> (ConType b)

-- | The symbolic counterpart of a concrete tabular function is the symbolic
-- tabular function between the symbolic counterparts of its argument and
-- result types.
instance
  ( SymRep a,
    SymRep b,
    SupportedPrim (a =-> b)
  ) =>
  SymRep (a =-> b)
  where
  type SymType (a =-> b) = (SymType a) =~> (SymType b)

-- | Links the concrete tabular function type @ca =-> cb@ with the symbolic
-- tabular function type @sa =~> sb@. Converting in either direction only
-- wraps or unwraps the 'SymTabularFun' constructor.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca =-> cb),
    SupportedNonFuncPrim ca
  ) =>
  LinkedRep (ca =-> cb) (sa =~> sb)
  where
  underlyingTerm (SymTabularFun a) = a
  wrapTerm = SymTabularFun

-- | Applies a symbolic tabular function to a symbolic argument. The result
-- term is built with 'pevalApplyTerm', which (as the examples on the type
-- show) simplifies applying a concrete table to a concrete argument.
instance Function (sa =~> sb) sa sb where
  (SymTabularFun f) # t = wrapTerm $ pevalApplyTerm f (underlyingTerm t)

-- | Turns a (possibly curried) symbolic tabular function into an ordinary
-- Haskell function. The first argument is applied with @(#)@, and any
-- remaining arguments are handled by the 'Apply' instance of the result
-- type @st@.
instance (Apply st) => Apply (sa =~> st) where
  type FunType (sa =~> st) = sa -> FunType st
  apply uf a = apply (uf # a)

-- | Constants of symbolic tabular function type.
--
-- * 'con' lifts a concrete table into a constant term.
-- * 'sym' creates a symbolic constant from a @Symbol@.
-- * 'conView' recovers the concrete table only when the wrapped term is a
--   constant term. Any other term yields 'Nothing'.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca =-> cb),
    SupportedNonFuncPrim ca
  ) =>
  Solvable (ca =-> cb) (sa =~> sb)
  where
  con = SymTabularFun . conTerm
  sym = SymTabularFun . symTerm . typedAnySymbol
  -- Only the payload of the constant term is needed; its other fields are
  -- ignored.
  conView (SymTabularFun (ConTerm _ _ _ _ t)) = Just t
  conView _ = Nothing

-- | Allows writing a named symbolic tabular function as a string literal,
-- e.g. @"f" :: SymInteger =~> SymInteger@. The string is converted to an
-- identifier and passed to 'ssym'.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca =-> cb),
    SupportedNonFuncPrim ca
  ) =>
  IsString (sa =~> sb)
  where
  fromString = ssym . fromString

-- | Pretty-prints the wrapped term with 'pformatTerm'.
instance Show (sa =~> sb) where
  show (SymTabularFun t) = pformatTerm t

-- | Equality of the underlying 'Term' values. This is not a solver-backed
-- check that the two functions agree on every input.
instance Eq (sa =~> sb) where
  SymTabularFun l == SymTabularFun r = l == r

-- | Hashes the wrapped term, mirroring the 'Eq' instance, which also
-- delegates to the wrapped term.
instance Hashable (sa =~> sb) where
  hashWithSalt s (SymTabularFun v) = s `hashWithSalt` v

-- | Reports the whole value as a single symbolic value.
instance AllSyms (sa =~> sb) where
  -- Matching on the constructor brings the constraints packed into
  -- 'SymTabularFun' into scope. 'SomeSym' needs them to build the
  -- 'LinkedRep' evidence for this type.
  allSymsS v@SymTabularFun {} = (SomeSym v :)

-- | Serializes through the wrapped term. 'serialize' writes the underlying
-- term, and 'deserialize' reads a term and wraps it in 'SymTabularFun'.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca =-> cb),
    SupportedNonFuncPrim ca
  ) =>
  Serial (sa =~> sb)
  where
  serialize = serialize . underlyingTerm
  deserialize = SymTabularFun <$> deserialize

-- | cereal support. Delegates to the 'Serial' instance so all serialization
-- formats share one encoding.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca =-> cb),
    SupportedNonFuncPrim ca
  ) =>
  Cereal.Serialize (sa =~> sb)
  where
  put = serialize
  get = deserialize

-- | binary support. Delegates to the 'Serial' instance so all serialization
-- formats share one encoding.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca =-> cb),
    SupportedNonFuncPrim ca
  ) =>
  Binary.Binary (sa =~> sb)
  where
  put = serialize
  get = deserialize