module Data.Comp.Thunk
(TermT
,CxtT
,thunk
,whnf
,whnf'
,whnfPr
,nf
,nfPr
,eval
,eval2
,deepEval
,deepEval2
,(#>)
,(#>>)
,AlgT
,cataT
,cataTM
,eqT
,strict
,strictAt) where
import Data.Comp.Algebra
import Data.Comp.Equality
import Data.Comp.Mapping
import Data.Comp.Ops ((:+:) (..), fromInr)
import Data.Comp.Sum
import Data.Comp.Term
import Data.Foldable hiding (and)
import qualified Data.IntSet as IntSet
import Control.Monad hiding (mapM, sequence)
import Data.Traversable
import Prelude hiding (foldl, foldl1, foldr, foldr1, mapM, sequence)
-- | Terms over the signature @f@ extended with the monad @m@: an 'Inl'
-- node holds a suspended monadic computation (a thunk) that yields a
-- term, an 'Inr' node is an ordinary @f@ constructor.
type TermT m f = Term (m :+: f)

-- | Context counterpart of 'TermT': a 'Cxt' with parameters @h@ and @a@
-- over the same thunk-extended signature.
type CxtT m h f a = Cxt h (m :+: f) a

-- | Wrap a monadic computation producing a context as a thunk node.
thunk :: m (CxtT m h f a) -> CxtT m h f a
thunk suspended = inject_ Inl suspended
-- | Evaluate a term to weak head normal form: thunks at the root are run
-- (repeatedly, if a thunk yields another thunk) until an @f@ constructor
-- appears, which is then returned. Its arguments are left unevaluated.
whnf :: Monad m => TermT m f -> m (f (TermT m f))
whnf (Term node) = case node of
  Inl pending -> pending >>= whnf
  Inr value   -> return value
-- | Like 'whnf', but the evaluated root constructor is re-wrapped as a
-- term so the result can be used wherever a 'TermT' is expected.
whnf' :: Monad m => TermT m f -> m (TermT m f)
whnf' t = do
  root <- whnf t
  return (inject_ Inr root)
-- | Evaluate a term to weak head normal form (via 'whnf') and project the
-- root constructor into the subsignature @g@. If the root constructor is
-- not part of @g@, the failure is signalled with 'fail' in the monad.
whnfPr :: (Monad m, g :<: f) => TermT m f -> m (g (TermT m f))
whnfPr t = whnf t >>= maybe (fail "projection failed") return . proj
-- | Build a thunk that evaluates @t@ to weak head normal form and then
-- continues with @cont@ on the resulting root constructor. The
-- evaluation only happens when the returned thunk is itself forced.
eval :: Monad m => (f (TermT m f) -> TermT m f) -> TermT m f -> TermT m f
eval cont t = thunk (liftM cont (whnf t))
infixl 1 #>

-- | Operator form of 'eval' with the arguments swapped:
-- @t #> cont@ is @eval cont t@.
(#>) :: Monad m => TermT m f -> (f (TermT m f) -> TermT m f) -> TermT m f
t #> cont = eval cont t
-- | Two-argument variant of 'eval': evaluate @x@, then @y@, to weak head
-- normal form and pass both root constructors to @cont@.
eval2 :: Monad m => (f (TermT m f) -> f (TermT m f) -> TermT m f)
      -> TermT m f -> TermT m f -> TermT m f
eval2 cont x y = eval afterX x
  where
    afterX x' = eval (cont x') y
-- | Evaluate a term completely: reduce the root to weak head normal form,
-- then normalise each argument recursively. The result is a plain
-- @Term f@ and therefore contains no thunks.
nf :: (Monad m, Traversable f) => TermT m f -> m (Term f)
nf t = do
  root <- whnf t
  args <- mapM nf root
  return (Term args)
-- | Like 'nf', but every constructor is projected into the subsignature
-- @g@ using 'whnfPr'; a constructor outside @g@ makes the normalisation
-- fail in the monad.
nfPr :: (Monad m, Traversable g, g :<: f) => TermT m f -> m (Term g)
nfPr t = do
  root <- whnfPr t
  args <- mapM nfPr root
  return (Term args)
-- | Continue with @cont@ on the fully evaluated form of @v@. If
-- 'deepProject_' with 'fromInr' succeeds on @v@ (presumably: @v@ has no
-- thunk nodes), @cont@ is applied directly; otherwise a thunk is built
-- that first normalises @v@ with 'nf'.
deepEval :: (Traversable f, Monad m) =>
            (Term f -> TermT m f) -> TermT m f -> TermT m f
deepEval cont v = maybe deferred cont (deepProject_ fromInr v)
  where
    deferred = thunk (liftM cont (nf v))
infixl 1 #>>

-- | Operator form of 'deepEval' with the arguments swapped:
-- @t #>> cont@ is @deepEval cont t@.
(#>>) :: (Monad m, Traversable f) => TermT m f -> (Term f -> TermT m f) -> TermT m f
t #>> cont = deepEval cont t
-- | Two-argument variant of 'deepEval': fully evaluate @x@, then @y@, and
-- pass both normal forms to @cont@.
deepEval2 :: (Monad m, Traversable f) =>
             (Term f -> Term f -> TermT m f)
          -> TermT m f -> TermT m f -> TermT m f
deepEval2 cont x y = deepEval afterX x
  where
    afterX x' = deepEval (cont x') y
-- | Algebras over @f@ whose carrier is thunk-extended terms over @g@.
type AlgT m f g = Alg f (TermT m g)

-- | Monadic catamorphism over a term with thunks. A thunk node is run and
-- the fold continues on the term it yields; for an @f@ node the arguments
-- are folded first (via 'mapM') and the algebra is applied to the result.
cataTM :: forall m f a . (Traversable f, Monad m) => AlgM m f a -> TermT m f -> m a
cataTM alg = go
  where
    go :: TermT m f -> m a
    go (Term node) = case node of
      Inl pending -> pending >>= go
      Inr layer   -> alg =<< mapM go layer

-- | Variant of 'cataTM' for a pure algebra: the algebra is lifted with
-- 'return', so only the thunks contribute monadic effects.
cataT :: (Traversable f, Monad m) => Alg f a -> TermT m f -> m a
cataT alg = cataTM algM
  where
    algM layer = return (alg layer)
-- | Build a thunk that evaluates every argument of the given @f@-node to
-- weak head normal form (via 'whnf'') and then injects the node into the
-- signature @g@.
strict :: (f :<: g, Traversable f, Monad m) => f (TermT m g) -> TermT m g
strict x = thunk $ do
  forced <- mapM whnf' x
  return (inject_ (Inr . inj) forced)
-- | A position function: given an @f@-node, it lists the arguments found
-- at some selected positions. 'strictAt' uses it to choose which
-- arguments are evaluated strictly.
type Pos f = forall a . f a -> [a]

-- | Like 'strict', but only the arguments selected by the position
-- function @p@ are evaluated to weak head normal form (via 'whnf'');
-- all other arguments are kept unevaluated.
--
-- Arguments are tagged with 'number' first, and an argument counts as
-- strict when its index appears among the numbered arguments that @p@
-- returns for the tagged node.
strictAt :: (f :<: g, Traversable f, Monad m) => Pos f -> f (TermT m g) -> TermT m g
strictAt p s = thunk $ liftM (inject_ (Inr . inj)) $ mapM run s'
  where
    s' = number s
    -- The set of strict indices does not depend on the argument being
    -- inspected, so build it once and share it, instead of rebuilding it
    -- inside 'isStrict' for every argument.
    strictIdx = IntSet.fromList [i | Numbered i _ <- p s']
    isStrict (Numbered i _) = IntSet.member i strictIdx
    run e
      | isStrict e = whnf' (unNumbered e)
      | otherwise  = return (unNumbered e)
-- | Monadic equality of terms with thunks. Both terms are evaluated to
-- weak head normal form. If 'eqMod' finds that the root constructors
-- differ, the result is 'False'; otherwise the paired arguments are
-- compared recursively.
--
-- The argument comparisons stop at the first mismatch. Subterms after it
-- are never forced, which avoids evaluating thunks whose value cannot
-- change the result.
eqT :: (EqF f, Foldable f, Functor f, Monad m) => TermT m f -> TermT m f -> m Bool
eqT s t = do
  s' <- whnf s
  t' <- whnf t
  case eqMod s' t' of
    Nothing -> return False
    Just l  -> allEq l
  where
    -- Short-circuiting conjunction over the argument pairs. The previous
    -- version used 'mapM', which ran every comparison before combining
    -- the results with 'and'.
    allEq [] = return True
    allEq ((a, b) : rest) = do
      same <- eqT a b
      if same then allEq rest else return False