{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      :   Grisette.IR.SymPrim.Data.Prim.InternedTerm.TermSubstitution
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.IR.SymPrim.Data.Prim.InternedTerm.TermSubstitution (substTerm) where

import Grisette.Core.Data.MemoUtils
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.InternedCtors
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.SomeTerm
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term
import Grisette.IR.SymPrim.Data.Prim.PartialEval.BV
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool
import Grisette.IR.SymPrim.Data.Prim.PartialEval.GeneralFun
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Integer
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Num
import Grisette.IR.SymPrim.Data.Prim.PartialEval.TabularFun
import Type.Reflection
import Unsafe.Coerce

-- | Substitute a term for a symbol throughout another term.
--
-- @substTerm sym term body@ replaces every 'SymTerm' in @body@ whose symbol
-- equals @sym@ by @term@. Every other node is rebuilt from its rewritten
-- children through the matching partial evaluator (@peval*@ /
-- @partialEval*@), so the result is simplified again after substitution.
--
-- Traversal goes through 'htmemo', keyed on 'SomeTerm'. Within one
-- application of @substTerm sym term@, an already-seen sub-term is looked up
-- rather than rewritten again.
--
-- Constants of a general function type (@x --> y@) are descended into: the
-- function's body is rewritten. When the function binds @sym@ itself, the
-- bound symbol shadows @sym@ and the constant is returned unchanged. All
-- other constants are returned unchanged.
--
-- NOTE(review): bound symbols are never renamed, so this substitution is not
-- capture-avoiding. If @term@ mentions the symbol bound by a 'GeneralFun'
-- being descended into, that occurrence becomes bound. Confirm this cannot
-- happen for the symbols used as function arguments.
substTerm ::
  forall a b.
  (SupportedPrim a, SupportedPrim b) =>
  TypedSymbol a ->
  Term a ->
  Term b ->
  Term b
substTerm sym term = gov
  where
    -- Typed wrapper around the memoised, untyped 'go'. Every case of 'go'
    -- rebuilds a term of the same type as its input, which is what the
    -- 'unsafeCoerce' back to @Term x@ relies on.
    gov :: (SupportedPrim x) => Term x -> Term x
    gov t = case go (SomeTerm t) of
      SomeTerm v -> unsafeCoerce v

    go :: SomeTerm -> SomeTerm
    go = htmemo $ \stm@(SomeTerm (tm :: Term v)) ->
      case tm of
        -- A constant whose type has the shape @f x y@ with @f ~ (-->)@ is a
        -- 'GeneralFun'; its body is a term and is rewritten unless the
        -- function's own bound symbol shadows @sym@.
        ConTerm _ cv -> case (typeRep :: TypeRep v) of
          App (App gf _) _ ->
            case eqTypeRep gf (typeRep @(-->)) of
              Just HRefl -> case cv of
                GeneralFun sym1 tm1 ->
                  if someTypedSymbol sym1 == someTypedSymbol sym
                    then stm
                    else SomeTerm $ conTerm $ GeneralFun sym1 (gov tm1)
              Nothing -> stm
          _ -> stm
        -- Replace the target symbol. NOTE(review): the 'unsafeCoerce' from
        -- @Term a@ to @Term v@ presumably relies on 'SomeTypedSymbol'
        -- equality also comparing the symbols' types -- confirm.
        SymTerm _ ts ->
          SomeTerm $
            if someTypedSymbol ts == someTypedSymbol sym
              then unsafeCoerce term
              else tm
        UnaryTerm _ tag te ->
          SomeTerm $ partialEvalUnary tag (gov te)
        BinaryTerm _ tag te te' ->
          SomeTerm $ partialEvalBinary tag (gov te) (gov te')
        TernaryTerm _ tag op1 op2 op3 ->
          SomeTerm $ partialEvalTernary tag (gov op1) (gov op2) (gov op3)
        NotTerm _ op ->
          SomeTerm $ pevalNotTerm (gov op)
        OrTerm _ op1 op2 ->
          SomeTerm $ pevalOrTerm (gov op1) (gov op2)
        AndTerm _ op1 op2 ->
          SomeTerm $ pevalAndTerm (gov op1) (gov op2)
        EqvTerm _ op1 op2 ->
          SomeTerm $ pevalEqvTerm (gov op1) (gov op2)
        ITETerm _ c op1 op2 ->
          SomeTerm $ pevalITETerm (gov c) (gov op1) (gov op2)
        AddNumTerm _ op1 op2 ->
          SomeTerm $ pevalAddNumTerm (gov op1) (gov op2)
        UMinusNumTerm _ op ->
          SomeTerm $ pevalUMinusNumTerm (gov op)
        TimesNumTerm _ op1 op2 ->
          SomeTerm $ pevalTimesNumTerm (gov op1) (gov op2)
        AbsNumTerm _ op ->
          SomeTerm $ pevalAbsNumTerm (gov op)
        SignumNumTerm _ op ->
          SomeTerm $ pevalSignumNumTerm (gov op)
        LTNumTerm _ op1 op2 ->
          SomeTerm $ pevalLtNumTerm (gov op1) (gov op2)
        LENumTerm _ op1 op2 ->
          SomeTerm $ pevalLeNumTerm (gov op1) (gov op2)
        AndBitsTerm _ op1 op2 ->
          SomeTerm $ pevalAndBitsTerm (gov op1) (gov op2)
        OrBitsTerm _ op1 op2 ->
          SomeTerm $ pevalOrBitsTerm (gov op1) (gov op2)
        XorBitsTerm _ op1 op2 ->
          SomeTerm $ pevalXorBitsTerm (gov op1) (gov op2)
        ComplementBitsTerm _ op ->
          SomeTerm $ pevalComplementBitsTerm (gov op)
        -- The shift/rotate amount is an 'Id', not a term: passed through.
        ShiftBitsTerm _ op n ->
          SomeTerm $ pevalShiftBitsTerm (gov op) n
        RotateBitsTerm _ op n ->
          SomeTerm $ pevalRotateBitsTerm (gov op) n
        BVConcatTerm _ op1 op2 ->
          SomeTerm $ pevalBVConcatTerm (gov op1) (gov op2)
        -- The index/width proxies carry type-level information only.
        BVSelectTerm _ ix w op ->
          SomeTerm $ pevalBVSelectTerm ix w (gov op)
        -- Fields are (signedness flag, extension-width proxy, operand).
        BVExtendTerm _ signed n op ->
          SomeTerm $ pevalBVExtendTerm signed n (gov op)
        TabularFunApplyTerm _ f op ->
          SomeTerm $ pevalTabularFunApplyTerm (gov f) (gov op)
        GeneralFunApplyTerm _ f op ->
          SomeTerm $ pevalGeneralFunApplyTerm (gov f) (gov op)
        DivIntegerTerm _ op1 op2 ->
          SomeTerm $ pevalDivIntegerTerm (gov op1) (gov op2)
        ModIntegerTerm _ op1 op2 ->
          SomeTerm $ pevalModIntegerTerm (gov op1) (gov op2)