module Data.Comp.Thunk
(TermT
,CxtT
,thunk
,whnf
,whnf'
,whnfPr
,nf
,nfPr
,eval
,eval2
,deepEval
,deepEval2
,(#>)
,(#>>)
,AlgT
,cataT
,cataTM
,eqT
,strict
,strictAt) where
import Data.Comp.Term
import Data.Comp.Equality
import Data.Comp.Algebra
import Data.Comp.Ops
import Data.Comp.Sum
import Data.Comp.Number
import Data.Foldable hiding (and)
import qualified Data.Set as Set
import Data.Traversable
import Control.Monad hiding (sequence,mapM)
import Prelude hiding (foldr, foldl,foldr1, foldl1,sequence,mapM)
-- | Terms over signature @f@ whose nodes may also be thunks: monadic
-- computations in @m@ (the left summand) that produce the actual subterm
-- when run.
type TermT m f = Term (m :+: f)
-- | Contexts (with hole marker @h@ and hole type @a@) over signature @f@
-- whose nodes may also be thunks in monad @m@.
type CxtT m h f a = Cxt h (m :+: f) a
-- | Wrap a monadic computation that yields a context as a thunk node.
-- Building the node does not run the computation; it is only run when the
-- node is forced (for example by 'whnf'). Works for any signature @f@
-- with @m :<: f@.
thunk :: (m :<: f) => m (Cxt h f a) -> Cxt h f a
thunk comp = inject comp
-- | Reduce a term to weak head normal form.
--
-- Thunk nodes at the root are run repeatedly until a proper @f@ node
-- appears, and that top-most layer is returned. Thunks below the root are
-- left untouched.
whnf :: Monad m => TermT m f -> m (f (TermT m f))
whnf (Term node) =
    case node of
      Inl pending -> pending >>= whnf
      Inr layer   -> return layer
-- | Variant of 'whnf' that re-wraps the evaluated top layer as a term, so
-- the result has the same type as the input.
whnf' :: Monad m => TermT m f -> m (TermT m f)
whnf' t = do
    layer <- whnf t
    return (inject layer)
-- | Reduce a term to weak head normal form (see 'whnf') and project the
-- resulting top layer to the sub-signature @g@.
--
-- If the top layer is not a @g@ node, the computation calls 'fail' with
-- the message @"projection failed"@.
--
-- NOTE(review): since GHC 8.8 (base 4.13) 'fail' is a 'MonadFail' method,
-- so this 'Monad'-only signature does not type-check there. 'nfPr' would
-- need the same constraint change; confirm which GHC versions are targeted.
whnfPr :: (Monad m, g :<: f) => TermT m f -> m (g (TermT m f))
whnfPr t = whnf t >>= (maybe (fail "projection failed") return . proj)
-- | Inspect the top-most non-thunk node of a term with a continuation.
--
-- The result is a thunk. When forced, it reduces @t@ with 'whnf' and
-- passes the resulting top layer to @cont@.
eval :: Monad m => (f (TermT m f) -> TermT m f) -> TermT m f -> TermT m f
eval cont t = thunk (liftM cont (whnf t))
infixl 1 #>
-- | Infix, argument-flipped form of 'eval': @t #> cont@ inspects the
-- top-most non-thunk node of @t@ with @cont@.
(#>) :: Monad m => TermT m f -> (f (TermT m f) -> TermT m f) -> TermT m f
t #> cont = eval cont t
-- | Two-argument version of 'eval'.
--
-- The result is a thunk that first reduces @x@ to weak head normal form,
-- then (inside a further thunk) @y@, and hands both top layers to @cont@.
eval2 :: Monad m => (f (TermT m f) -> f (TermT m f) -> TermT m f)
      -> TermT m f -> TermT m f -> TermT m f
eval2 cont x y = eval withLeft x
  where
    withLeft x' = eval (cont x') y
-- | Fully evaluate a term.
--
-- The root is reduced with 'whnf', then every child is normalised
-- recursively, producing a plain term with no thunk nodes.
nf :: (Monad m, Traversable f) => TermT m f -> m (Term f)
nf t = do
    layer    <- whnf t
    children <- mapM nf layer
    return (Term children)
-- | Like 'nf', but every layer is projected to the sub-signature @g@ via
-- 'whnfPr'. Evaluation fails (through 'fail') when a node outside @g@ is
-- reached.
nfPr :: (Monad m, Traversable g, g :<: f) => TermT m f -> m (Term g)
nfPr t = do
    layer    <- whnfPr t
    children <- mapM nfPr layer
    return (Term children)
-- | Apply a continuation that needs a fully evaluated term.
--
-- If 'deepProject' succeeds on @v@, @cont@ is applied directly to the
-- projected plain term. Otherwise the result is a thunk that first
-- computes @'nf' v@ and then applies @cont@.
deepEval :: (Traversable f, Monad m) =>
            (Term f -> TermT m f) -> TermT m f -> TermT m f
deepEval cont v = maybe viaThunk cont (deepProject v)
  where
    viaThunk = thunk (liftM cont (nf v))
infixl 1 #>>
-- | Infix, argument-flipped form of 'deepEval': @v #>> cont@ applies
-- @cont@ to the fully evaluated form of @v@.
(#>>) :: (Monad m, Traversable f) => TermT m f -> (Term f -> TermT m f) -> TermT m f
v #>> cont = deepEval cont v
-- | Two-argument version of 'deepEval': @cont@ receives both @x@ and @y@
-- as fully evaluated plain terms (@x@ is handled first).
deepEval2 :: (Monad m, Traversable f) =>
             (Term f -> Term f -> TermT m f)
          -> TermT m f -> TermT m f -> TermT m f
deepEval2 cont x y = deepEval withLeft x
  where
    withLeft x' = deepEval (cont x') y
-- | An algebra over signature @f@ whose carrier is a term with thunks in
-- monad @m@ over signature @g@.
type AlgT m f g = Alg f (TermT m g)
-- | Monadic catamorphism over a term with thunks.
--
-- Thunk nodes are run as they are encountered. For every @f@ node, the
-- children are folded first (in 'mapM' order) and the results are passed
-- to @alg@.
cataTM :: forall m f a . (Traversable f, Monad m) => AlgM m f a -> TermT m f -> m a
cataTM alg = go
  where
    go :: TermT m f -> m a
    go (Term node) =
        case node of
          Inl pending -> pending >>= go
          Inr layer   -> mapM go layer >>= alg
-- | Catamorphism with a pure algebra over a term with thunks. Thunks are
-- run in @m@ as they are reached (see 'cataTM').
cataT :: (Traversable f, Monad m) => Alg f a -> TermT m f -> m a
cataT alg = cataTM (\ layer -> return (alg layer))
-- | Make a functor application strict in all its immediate subterms.
--
-- The result is a thunk. When forced, it reduces every child of @x@ to
-- weak head normal form (via 'whnf'') and then rebuilds the node with
-- 'inject'.
strict :: (f :<: g, Traversable f, Monad m) => f (TermT m g) -> TermT m g
strict x = thunk $ do
    forced <- mapM whnf' x
    return (inject forced)
-- | A position selector for functor @f@: given a functor value, it returns
-- some of its components (of any 'Ord' type @a@) as a list. 'strictAt'
-- uses it to choose which arguments to force.
type Pos f = forall a . Ord a => f a -> [a]
-- | Variant of 'strict' that only forces the immediate subterms selected by
-- the position selector @p@.
--
-- The result is a thunk. When forced, it reduces each selected child to
-- weak head normal form (via 'whnf''), keeps all other children as they
-- are, and rebuilds the node with 'inject'.
strictAt :: (f :<: g, Traversable f, Monad m) => Pos f -> f (TermT m g) -> TermT m g
strictAt p s = thunk $ liftM inject $ mapM run s'
  where
    -- Tag every child via 'number' so the selected children can be looked
    -- up in a 'Set'. NOTE(review): assumes the 'Ord' instance of
    -- 'Numbered' compares the tags, so equal-valued children at different
    -- positions stay distinct; confirm.
    s' = number s
    -- The selected positions do not depend on the child being inspected,
    -- so the set is built once here rather than once per child.
    strictPositions = Set.fromList (p s')
    isStrict e = Set.member e strictPositions
    run e
      | isStrict e = whnf' (unNumbered e)
      | otherwise  = return (unNumbered e)
-- | Structural equality of two terms with thunks, computed in @m@.
--
-- Both terms are reduced to weak head normal form and their top layers are
-- compared with 'eqMod'. When 'eqMod' returns 'Nothing', the result is
-- 'False'. Otherwise the paired children are compared recursively, in the
-- order 'eqMod' lists them.
--
-- The comparison stops at the first mismatching pair, so the thunks of the
-- remaining children are not forced.
eqT :: (EqF f, Foldable f, Functor f, Monad m) => TermT m f -> TermT m f -> m Bool
eqT s t = do
    s' <- whnf s
    t' <- whnf t
    case eqMod s' t' of
      Nothing    -> return False
      Just pairs -> allEqual pairs
  where
    -- Short-circuiting monadic conjunction. Unlike @liftM and . mapM@, it
    -- does not run the comparisons of the pairs after the first mismatch.
    -- That avoids needless forcing, and divergence on infinite remaining
    -- subterms.
    allEqual [] = return True
    allEqual ((a, b) : rest) = do
        same <- eqT a b
        if same then allEqual rest else return False