module Data.Comp.Param.Thunk
(TermT
,TrmT
,CxtT
,Thunk
,thunk
,whnf
,whnf'
,whnfPr
,nf
,nfT
,nfPr
,nfTPr
,evalStrict
,AlgT
,strict
,strict')
where
import Data.Comp.Param.Term
import Data.Comp.Param.Sum
import Data.Comp.Param.Ops
import Data.Comp.Param.Algebra
import Data.Comp.Param.Ditraversable
import Data.Comp.Param.Difunctor
import Control.Monad
-- | Closed terms over the signature @f@, extended with 'Thunk' nodes:
-- deferred computations in the monad @m@ that produce further terms.
type TermT m f = Term (Thunk m :+: f)
-- | Terms with variables of type @a@ over @f@, extended with 'Thunk' nodes.
type TrmT m f a = Trm (Thunk m :+: f) a
-- | Contexts (with hole-kind @h@) over @f@, extended with 'Thunk' nodes.
type CxtT h m f a = Cxt h (Thunk m :+: f) a
-- | A deferred computation in the monad @m@ producing a @b@.
--
-- The parameter @a@ does not occur on the right-hand side. It exists so
-- that @Thunk m@ has the same two-parameter shape as the signatures it is
-- summed with via ':+:' (see 'TermT').
newtype Thunk m a b = Thunk (m b)
-- | Turn a monadic computation that yields a term into a term, by wrapping
-- it in a 'Thunk' node and injecting that node into the signature @f@.
-- The computation itself is not run here; evaluators such as 'whnf' run it.
thunk :: (Thunk m :<: f) => m (Cxt h f a b) -> Cxt h f a b
thunk action = inject (Thunk action)
-- | Evaluate a term to weak head normal form.
--
-- Thunks at the root are run repeatedly, since a thunk may itself yield
-- another thunk, until the root is either a variable ('Left') or a
-- constructor of @f@ ('Right'). Subterms are left untouched.
whnf :: Monad m => TrmT m f a -> m (Either a (f a (TrmT m f a)))
whnf term =
  case term of
    Var v                   -> return (Left v)
    In (Inr node)           -> return (Right node)
    In (Inl (Thunk action)) -> action >>= whnf
-- | Like 'whnf', but embeds the result back into a term.
-- A variable head becomes 'Var'; an @f@ head is re-injected into the
-- thunk-extended signature.
whnf' :: Monad m => TrmT m f a -> m (TrmT m f a)
whnf' term = do
  headForm <- whnf term
  return (either Var inject headForm)
-- | Evaluate a term to weak head normal form (see 'whnf') and project its
-- top-level node into the sub-signature @g@.
--
-- Fails in @m@ (via 'fail') with the message @cannot project variable@ if
-- the head is a variable. Fails with @projection failed@ if the head
-- constructor does not belong to @g@.
--
-- NOTE(review): since GHC 8.8 (base 4.13), 'fail' is a method of
-- 'MonadFail', not 'Monad'. This @Monad m@ signature therefore only
-- type-checks against older base. Consider moving to a @MonadFail m@
-- constraint, together with 'nfPr' and 'nfTPr'.
whnfPr :: (Monad m, g :<: f) => TrmT m f a -> m (g a (TrmT m f a))
whnfPr t = whnf t >>= either onVariable onNode
  where
    onVariable _ = fail "cannot project variable"
    onNode node  = maybe (fail "projection failed") return (proj node)
-- | Fully evaluate a closed thunked term, removing every 'Thunk' node.
-- This is 'nf' lifted to closed terms via 'termM'.
nfT :: (ParamFunctor m, Monad m, Ditraversable f) => TermT m f -> m (Term f)
nfT term = termM (nf (unTerm term))
-- | Evaluate a term to normal form.
--
-- The root is reduced with 'whnf', then every subterm is normalised
-- recursively. The result type 'Trm' @f@ @a@ contains no 'Thunk' nodes.
nf :: (Monad m, Ditraversable f) => TrmT m f a -> m (Trm f a)
nf term = do
  headForm <- whnf term
  case headForm of
    Left v     -> return (Var v)
    Right node -> liftM In (dimapM nf node)
-- | Fully evaluate a closed thunked term while projecting every node into
-- the sub-signature @g@. This is 'nfPr' lifted to closed terms via 'termM'.
-- Fails in @m@ under the same conditions as 'nfPr'.
nfTPr :: (ParamFunctor m, Monad m, Ditraversable g, g :<: f) => TermT m f -> m (Term g)
nfTPr term = termM (nfPr (unTerm term))
-- | Evaluate a term to normal form, projecting every node into the
-- sub-signature @g@.
--
-- Each node is reduced and projected with 'whnfPr', and the same is then
-- done recursively for its subterms.
--
-- NOTE(review): unlike 'nf', this fails (through 'whnfPr') as soon as it
-- reaches a variable subterm. Confirm that this is intended for terms
-- containing bound variables.
nfPr :: (Monad m, Ditraversable g, g :<: f) => TrmT m f a -> m (Trm g a)
nfPr term = do
  node <- whnfPr term
  liftM In (dimapM nfPr node)
-- | Build a thunk that first evaluates every (covariant) argument of @t@ to
-- weak head normal form, then decides what to produce:
--
-- * if every argument reduced to an @f@ constructor, @cont@ is applied to
--   @t@ with those evaluated heads in place of the arguments;
-- * if any argument reduced to a variable, the original node @t@ is
--   returned (injected) unchanged.
evalStrict :: (Ditraversable g, Monad m, g :<: f) =>
              (g (TrmT m f a) (f a (TrmT m f a)) -> TrmT m f a)
           -> g (TrmT m f a) (TrmT m f a) -> TrmT m f a
evalStrict cont t = thunk (dimapM headOf t >>= decide)
  where
    -- Just the evaluated head for a constructor, Nothing for a variable.
    headOf arg = liftM (either (const Nothing) Just) (whnf arg)
    -- Collapse the per-argument Maybes: all present -> cont, otherwise keep t.
    decide evaluated =
      maybe (return (inject' t)) (return . cont) (disequence evaluated)
-- | An algebra over the signature @f@ whose carrier is closed thunked terms
-- over the signature @g@ (see 'TermT').
type AlgT m f g = Alg f (TermT m g)
-- | Make a constructor strict in its arguments.
--
-- The result is a thunk which, when run, evaluates every (covariant)
-- argument with 'whnf'' and then injects the node into the signature @g@.
strict :: (f :<: g, Ditraversable f, Monad m) => f a (TrmT m g a) -> TrmT m g a
strict x = thunk $ do
  evaluatedArgs <- dimapM whnf' x
  return (inject evaluatedArgs)
-- | Variant of 'strict' for nodes whose contravariant (binder) position
-- holds terms rather than raw variables. Binders are first mapped to terms
-- with 'Var' (via 'dimap'), and then 'strict' is applied.
strict' :: (f :<: g, Ditraversable f, Monad m) => f (TrmT m g a) (TrmT m g a) -> TrmT m g a
strict' x = strict (dimap Var id x)