module Data.Comp.Thunk
(TermT
,CxtT
,thunk
,injectT
,whnf
,whnf'
,whnfPr
,nf
,nfPr
,eval
,eval2
,deepEval
,deepEval2
,(#>)
,(#>>)
,AlgT
,cataT
,cataTM
,eqT
,strict
,strictAt) where
import Data.Comp.Term
import Data.Comp.Equality
import Data.Comp.Algebra
import Data.Comp.Ops ((:+:)(..), fromInr)
import Data.Comp.Sum
import Data.Comp.Number
import Data.Foldable hiding (and)
import qualified Data.Set as Set
import Data.Traversable
import Control.Monad hiding (sequence,mapM)
import Prelude hiding (foldr, foldl,foldr1, foldl1,sequence,mapM)
-- | Terms over the signature @f@ extended with the monad @m@. A node is
-- either a suspended monadic computation producing a term (the 'Inl' side,
-- see 'thunk') or an ordinary @f@-node (the 'Inr' side, see 'injectT').
type TermT m f = Term (m :+: f)
-- | Contexts (with hole marker @h@ and hole type @a@) over the signature @f@
-- extended with the monad @m@; the context analogue of 'TermT'.
type CxtT m h f a = Cxt h (m :+: f) a
-- | Suspends a monadic computation that yields a context, embedding it as a
-- single node on the monadic ('Inl') side of the signature.
thunk :: m (CxtT m h f a) -> CxtT m h f a
thunk suspended = inject_ Inl suspended
-- | Injects a node of the sub-signature @g@ into a thunked context, placing
-- it on the ordinary ('Inr') side of the extended signature.
injectT :: (g :<: f) => g (CxtT m h f a) -> CxtT m h f a
injectT node = inject_ (\layer -> Inr (inj layer)) node
-- | Evaluates a thunked term to weak head normal form. Suspended
-- computations at the root are run repeatedly until the root is an
-- @f@-node, which is returned. Its subterms are left as they are and may
-- still be suspended.
whnf :: Monad m => TermT m f -> m (f (TermT m f))
whnf term = case term of
    Term (Inl suspended) -> suspended >>= whnf
    Term (Inr layer)     -> return layer
-- | Like 'whnf', but wraps the resulting top-level @f@-node back into a
-- term, so that the result has the same type as the input.
whnf' :: Monad m => TermT m f -> m (TermT m f)
whnf' t = do
    layer <- whnf t
    return (inject_ Inr layer)
-- | Like 'whnf', but also projects the resulting top-level node into the
-- sub-signature @g@. Calls 'fail' with @"projection failed"@ when the node
-- does not belong to @g@.
--
-- NOTE(review): 'fail' is used under a plain 'Monad' constraint. This only
-- compiles on GHCs where 'fail' is still a 'Monad' method (before the
-- MonadFail change in base-4.13). Porting needs a 'MonadFail' constraint
-- here and in 'nfPr' — confirm the supported compiler range.
whnfPr :: (Monad m, g :<: f) => TermT m f -> m (g (TermT m f))
whnfPr t = whnf t >>= maybe (fail "projection failed") return . proj
-- | Defers a continuation until its argument is in weak head normal form.
-- The result is a thunk that evaluates @t@ with 'whnf' and then applies
-- @cont@ to the resulting top-level node.
eval :: Monad m => (f (TermT m f) -> TermT m f) -> TermT m f -> TermT m f
eval cont t = thunk (liftM cont (whnf t))
infixl 1 #>
-- | Operator form of 'eval' with the arguments flipped: @t #> cont@ applies
-- @cont@ to the weak head normal form of @t@, inside a thunk.
(#>) :: Monad m => TermT m f -> (f (TermT m f) -> TermT m f) -> TermT m f
t #> cont = eval cont t
-- | Two-argument variant of 'eval'. It brings @x@ to weak head normal form
-- first, then @y@, and passes both top-level nodes to @cont@.
eval2 :: Monad m => (f (TermT m f) -> f (TermT m f) -> TermT m f)
                 -> TermT m f -> TermT m f -> TermT m f
eval2 cont x y = x #> \xLayer -> y #> cont xLayer
-- | Evaluates a thunked term completely. The root is brought to weak head
-- normal form, and then every subterm is normalised recursively. The
-- result is a plain @Term f@ with no suspended computations left.
nf :: (Monad m, Traversable f) => TermT m f -> m (Term f)
nf t = do
    layer    <- whnf t
    children <- mapM nf layer
    return (Term children)
-- | Like 'nf', but projects every node into the sub-signature @g@ via
-- 'whnfPr'. It fails (through 'whnfPr') as soon as a node is found that
-- does not belong to @g@.
nfPr :: (Monad m, Traversable g, g :<: f) => TermT m f -> m (Term g)
nfPr t = do
    layer    <- whnfPr t
    children <- mapM nfPr layer
    return (Term children)
-- | Defers a continuation until its argument is fully evaluated.
--
-- If 'deepProject_' with 'fromInr' can project @v@ to a plain @Term f@
-- (presumably: @v@ contains no suspended computations), @cont@ is applied
-- to it immediately. Otherwise a thunk is built that computes the normal
-- form of @v@ with 'nf' and then applies @cont@.
deepEval :: (Traversable f, Monad m) =>
            (Term f -> TermT m f) -> TermT m f -> TermT m f
deepEval cont v = maybe deferred cont (deepProject_ fromInr v)
    where deferred = thunk (liftM cont (nf v))
infixl 1 #>>
-- | Operator form of 'deepEval' with the arguments flipped: @t #>> cont@
-- applies @cont@ to the full normal form of @t@.
(#>>) :: (Monad m, Traversable f) => TermT m f -> (Term f -> TermT m f) -> TermT m f
t #>> cont = deepEval cont t
-- | Two-argument variant of 'deepEval'. It fully evaluates @x@ first, then
-- @y@, and passes both normal forms to @cont@.
deepEval2 :: (Monad m, Traversable f) =>
             (Term f -> Term f -> TermT m f)
          -> TermT m f -> TermT m f -> TermT m f
deepEval2 cont x y = x #>> \xNormal -> y #>> cont xNormal
-- | An algebra over the signature @f@ whose carrier is the type of thunked
-- terms over the signature @g@ with monad @m@.
type AlgT m f g = Alg f (TermT m g)
-- | Monadic catamorphism over a thunked term. Suspended computations are
-- run when they are reached, so the whole term gets forced. The monadic
-- algebra is applied bottom-up: every @f@-node's children are folded
-- first, then @alg@ is applied to the node.
cataTM :: forall m f a . (Traversable f, Monad m) => AlgM m f a -> TermT m f -> m a
cataTM alg = go
    where
      go :: TermT m f -> m a
      go term = case term of
          Term (Inl suspended) -> suspended >>= go
          Term (Inr layer)     -> mapM go layer >>= alg
-- | Catamorphism over a thunked term with a pure algebra. This is 'cataTM'
-- with the algebra's result lifted into the monad.
cataT :: (Traversable f, Monad m) => Alg f a -> TermT m f -> m a
cataT alg = cataTM (\layer -> return (alg layer))
-- | Builds a strict node. The result is a thunk that first forces every
-- argument of @x@ to weak head normal form (via 'whnf''), and only then
-- injects the node into the term.
strict :: (f :<: g, Traversable f, Monad m) => f (TermT m g) -> TermT m g
strict x = thunk $ do
    forced <- mapM whnf' x
    return (injectT forced)
-- | A position selector. Given a functorial value, it returns the elements
-- of interest. It is polymorphic in the element type, requiring only 'Ord'.
-- 'strictAt' uses it to choose which arguments of a node to force.
type Pos f = forall a . Ord a => f a -> [a]
-- | Like 'strict', but forces only the arguments chosen by the position
-- selector @p@. Those arguments are brought to weak head normal form via
-- 'whnf''; all others are passed through unchanged and may stay suspended.
--
-- The arguments are tagged with 'number' before @p@ is applied. Membership
-- in the set of elements returned by @p@ decides which ones are forced.
strictAt :: (f :<: g, Traversable f, Monad m) => Pos f -> f (TermT m g) -> TermT m g
strictAt p s = thunk $ liftM (inject_ (Inr . inj)) $ mapM run s'
    where s' = number s
          -- The selected positions do not depend on the argument being
          -- visited. The set is therefore built once and shared by every
          -- call to 'run', instead of being rebuilt for each argument.
          strictPositions = Set.fromList (p s')
          run e | e `Set.member` strictPositions = whnf' (unNumbered e)
                | otherwise                      = return (unNumbered e)
-- | Decides equality of two thunked terms, evaluating them only as far as
-- needed.
--
-- Both terms are brought to weak head normal form and their top-level
-- nodes are compared with 'eqMod'. If 'eqMod' returns 'Nothing', the result
-- is 'False'. Otherwise the paired subterms are compared recursively, from
-- left to right.
--
-- The recursive comparison short-circuits: once one pair of subterms is
-- found to differ, the remaining pairs are not evaluated. Forcing them
-- could not change the result, and their thunks may be expensive or may
-- not terminate.
eqT :: (EqF f, Foldable f, Functor f, Monad m) => TermT m f -> TermT m f -> m Bool
eqT s t = do s' <- whnf s
             t' <- whnf t
             case eqMod s' t' of
               Nothing -> return False
               Just l  -> allEqual l
    where -- Compare subterm pairs in order, stopping at the first mismatch.
          allEqual [] = return True
          allEqual ((x, y) : rest) = do
              same <- eqT x y
              if same then allEqual rest else return False