{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      :   Grisette.IR.SymPrim.Data.Prim.PartialEval.Unfold
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.IR.SymPrim.Data.Prim.PartialEval.Unfold
  ( unaryUnfoldOnce,
    binaryUnfoldOnce,
  )
where

import Control.Monad.Except
import Data.Typeable
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bool
import Grisette.IR.SymPrim.Data.Prim.PartialEval.PartialEval

-- | Extend a partial unary rule so that it can look through an if-then-else
-- argument.
--
-- When the argument is an 'ITETerm' and the rule applies to at least one of
-- its two branches, the rule is pushed into both branches and the two results
-- are recombined with 'pevalITETerm' under the original condition. In every
-- other case the result is whatever the rule returns for the argument itself.
unaryPartialUnfoldOnce ::
  forall a b.
  (Typeable a, SupportedPrim b) =>
  PartialRuleUnary a b ->
  TotalRuleUnary a b ->
  PartialRuleUnary a b
unaryPartialUnfoldOnce partial fallback = ret
  where
    -- The first argument is the total rule used for a branch that neither the
    -- partial rule nor a further unfolding step can handle.
    oneLevel :: TotalRuleUnary a b -> PartialRuleUnary a b
    oneLevel fallback' x = case (x, partial x) of
      (ITETerm _ cond vt vf, pr) ->
        let pt = partial vt
            pf = partial vf
         in case (pt, pf) of
              -- The rule applies to neither branch: unfolding gains nothing,
              -- so keep the rule's verdict on the whole term.
              (Nothing, Nothing) -> pr
              -- The rule applies to at least one branch. A branch it rejects
              -- is retried through 'oneLevel' (so a nested 'ITETerm' in that
              -- branch is examined the same way) and otherwise handed to
              -- fallback'.
              --
              -- Both branches deliberately use the same fallback'.
              -- NOTE(review): assuming 'totalize' prefers the partial result
              -- and otherwise defers to the fallback, fallback' agrees with
              -- the plain @fallback@ here, because each handler only runs
              -- when the partial rule returned 'Nothing' for that branch.
              (mt, mf) ->
                pevalITETerm cond
                  <$> catchError mt (\_ -> Just $ totalize (oneLevel fallback') fallback' vt)
                  <*> catchError mf (\_ -> Just $ totalize (oneLevel fallback') fallback' vf)
      -- Not an if-then-else term: nothing to unfold.
      (_, pr) -> pr
    ret :: PartialRuleUnary a b
    ret = oneLevel (totalize @(Term a) @(Term b) partial fallback)

-- | Total version of the unary unfolding rule: apply the partial rule,
-- looking through an if-then-else argument where possible, and use the
-- fallback when that yields no result.
unaryUnfoldOnce ::
  forall a b.
  (Typeable a, SupportedPrim b) =>
  PartialRuleUnary a b ->
  TotalRuleUnary a b ->
  TotalRuleUnary a b
unaryUnfoldOnce partial fallback = totalize (unaryPartialUnfoldOnce partial fallback) fallback

-- | Extend a partial binary rule so that it can look through an if-then-else
-- term in either argument position.
--
-- The rule is first tried on the arguments as given. If it does not apply,
-- the first argument is unfolded when it is an 'ITETerm'; if that also yields
-- nothing, the second argument is unfolded when it is an 'ITETerm'.
binaryPartialUnfoldOnce ::
  forall a b c.
  (Typeable a, Typeable b, SupportedPrim c) =>
  PartialRuleBinary a b c ->
  TotalRuleBinary a b c ->
  PartialRuleBinary a b c
binaryPartialUnfoldOnce partial fallback = ret
  where
    -- Polymorphic in the argument types so that it can be re-entered with the
    -- flipped rule when unfolding the second argument.
    oneLevel :: (Typeable x, Typeable y) => PartialRuleBinary x y c -> TotalRuleBinary x y c -> PartialRuleBinary x y c
    oneLevel partial' fallback' x y =
      catchError
        (partial' x y)
        ( \_ ->
            catchError
              -- The rule failed on the terms as given: unfold the first
              -- argument if it is an if-then-else term.
              ( case x of
                  ITETerm _ cond vt vf -> left cond vt vf y partial' fallback'
                  _ -> Nothing
              )
              -- That failed too: unfold the second argument instead, reusing
              -- 'left' by flipping the argument order of both rules.
              ( \_ -> case y of
                  ITETerm _ cond vt vf -> left cond vt vf x (flip partial') (flip fallback')
                  _ -> Nothing
              )
        )
    -- Push the rule into both branches of an if-then-else term occupying the
    -- first argument position, keeping the other argument fixed.
    left ::
      (Typeable x, Typeable y) =>
      Term Bool ->
      Term x ->
      Term x ->
      Term y ->
      PartialRuleBinary x y c ->
      TotalRuleBinary x y c ->
      Maybe (Term c)
    left cond vt vf y partial' fallback' =
      let pt = partial' vt y
          pf = partial' vf y
       in case (pt, pf) of
            -- The rule applies to neither branch: report no result.
            (Nothing, Nothing) -> Nothing
            -- The rule applies to at least one branch. A branch it rejects is
            -- retried through 'oneLevel' and otherwise handed to fallback';
            -- the results are recombined under the original condition.
            (mt, mf) ->
              pevalITETerm cond
                <$> catchError mt (\_ -> Just $ totalize2 (oneLevel partial' fallback') fallback' vt y)
                <*> catchError mf (\_ -> Just $ totalize2 (oneLevel partial' fallback') fallback' vf y)
    ret :: PartialRuleBinary a b c
    ret = oneLevel partial (totalize2 @(Term a) @(Term b) @(Term c) partial fallback)

-- | Total version of the binary unfolding rule: apply the partial rule,
-- looking through an if-then-else term in either argument where possible, and
-- use the fallback when that yields no result.
binaryUnfoldOnce ::
  forall a b c.
  (Typeable a, Typeable b, SupportedPrim c) =>
  PartialRuleBinary a b c ->
  TotalRuleBinary a b c ->
  TotalRuleBinary a b c
binaryUnfoldOnce partial fallback = totalize2 (binaryPartialUnfoldOnce partial fallback) fallback