{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Comp.Thunk
(TermT
,CxtT
,thunk
,whnf
,whnf'
,whnfPr
,nf
,nfPr
,eval
,eval2
,deepEval
,deepEval2
,(#>)
,(#>>)
,AlgT
,cataT
,cataTM
,eqT
,strict
,strictAt) where
import Data.Comp.Algebra
import Data.Comp.Equality
import Data.Comp.Mapping
import Data.Comp.Ops ((:+:) (..), fromInr)
import Data.Comp.Sum
import Data.Comp.Term
import Data.Foldable hiding (and)
import qualified Data.IntSet as IntSet
import Control.Monad hiding (mapM, sequence)
import Data.Traversable
#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail)
#endif
import Prelude hiding (foldl, foldl1, foldr, foldr1, mapM, sequence)
-- | Terms whose nodes are either a thunk (the left summand @n@, a
-- monadic computation producing the rest of the term) or a regular node
-- of the signature @g@ (the right summand).
type TermT n g = Term (n :+: g)

-- | Contexts over the signature @g@ extended with thunks of type @n@;
-- the remaining parameters are passed straight through to 'Cxt'.
type CxtT n h g a = Cxt h (n :+: g) a
-- | Wrap a monadic computation as a thunk node by injecting it into the
-- left (thunk) summand of the signature. The computation is not run
-- here; it only runs when the node is forced (e.g. by 'whnf').
thunk :: m (CxtT m h f a) -> CxtT m h f a
thunk = inject_ Inl
-- | Force a term to weak head normal form. Thunk nodes are run
-- repeatedly, since a thunk may produce another thunk, until a regular
-- node of the signature @f@ is reached. That node is returned.
whnf :: Monad m => TermT m f -> m (f (TermT m f))
whnf (Term (Inl m)) = m >>= whnf  -- run the thunk, then keep forcing
whnf (Term (Inr t)) = return t    -- already a signature node
-- | Like 'whnf', but the forced top-level node is re-wrapped as a term
-- (in the right, non-thunk summand) instead of being returned as a bare
-- functor value.
whnf' :: Monad m => TermT m f -> m (TermT m f)
whnf' = liftM (inject_ Inr) . whnf
-- | Force the term to weak head normal form with 'whnf', then project
-- the top-level node into the sub-signature @g@. If the node is not in
-- @g@, the computation fails via 'fail' with the message
-- @"projection failed"@.
whnfPr :: (MonadFail m, g :<: f) => TermT m f -> m (g (TermT m f))
whnfPr t = do
  res <- whnf t
  maybe (fail "projection failed") return (proj res)
-- | Build a thunk that, when forced, evaluates @t@ to weak head normal
-- form (see 'whnf') and continues with @cont@ applied to the resulting
-- top-level node.
eval :: Monad m => (f (TermT m f) -> TermT m f) -> TermT m f -> TermT m f
-- 'liftM cont m' is by definition 'm >>= return . cont', which is
-- exactly the original 'cont' =<< whnf t' with cont' = return . cont.
eval cont t = thunk $ liftM cont (whnf t)
infixl 1 #>

-- | Flipped 'eval': @t #> cont@ is @eval cont t@. It has the same
-- fixity as '>>=' (@infixl 1@).
(#>) :: Monad m => TermT m f -> (f (TermT m f) -> TermT m f) -> TermT m f
t #> cont = eval cont t
-- | Build a thunk that forces @x@ and then @y@ to weak head normal form
-- (each via 'eval') and continues with @cont@ applied to the two
-- resulting top-level nodes.
eval2 :: Monad m => (f (TermT m f) -> f (TermT m f) -> TermT m f)
                 -> TermT m f -> TermT m f -> TermT m f
eval2 cont x y = eval (\x' -> eval (cont x') y) x
-- | Force a term completely. The top-level node is evaluated with
-- 'whnf', then every subterm is normalised recursively, producing a
-- thunk-free 'Term' over @f@.
nf :: (Monad m, Traversable f) => TermT m f -> m (Term f)
nf = liftM Term . mapM nf <=< whnf
-- | Like 'nf', but every node is projected into the sub-signature @g@
-- with 'whnfPr' while it is forced. If any node is not in @g@, the
-- computation fails in the monad, just as 'whnfPr' does.
nfPr :: (MonadFail m, Traversable g, g :<: f) => TermT m f -> m (Term g)
nfPr = liftM Term . mapM nfPr <=< whnfPr
-- | Continue with @cont@ applied to the fully evaluated (thunk-free)
-- form of @v@.
--
-- If every node of @v@ can be projected out of the non-thunk summand
-- (checked with 'deepProject_' and 'fromInr'), @cont@ is applied
-- immediately. Otherwise a thunk is built that normalises @v@ with
-- 'nf' when it is forced.
deepEval :: (Traversable f, Monad m) =>
            (Term f -> TermT m f) -> TermT m f -> TermT m f
deepEval cont v =
  case deepProject_ fromInr v of
    Just v' -> cont v'
    Nothing -> thunk $ liftM cont $ nf v
infixl 1 #>>

-- | Flipped 'deepEval': @t #>> cont@ is @deepEval cont t@. It has the
-- same fixity as '>>=' (@infixl 1@).
(#>>) :: (Monad m, Traversable f) => TermT m f -> (Term f -> TermT m f) -> TermT m f
t #>> cont = deepEval cont t
-- | Continue with @cont@ applied to the fully evaluated forms of @x@ and
-- @y@. Each argument is handled by 'deepEval', @x@ first and then @y@.
deepEval2 :: (Monad m, Traversable f) =>
             (Term f -> Term f -> TermT m f)
          -> TermT m f -> TermT m f -> TermT m f
deepEval2 cont x y = deepEval (\x' -> deepEval (cont x') y) x
-- | Algebras over the signature @f@ whose carrier is a term with thunks
-- (of type @n@) over the signature @g@.
type AlgT n f g = Alg f (TermT n g)
-- | Monadic catamorphism over a term with thunks.
--
-- A thunk node is run, and the term it yields is folded in turn. For a
-- regular node, the children are folded first, in the order given by
-- the 'Traversable' instance, and the results are then combined with
-- @alg@.
cataTM :: forall m f a . (Traversable f, Monad m) => AlgM m f a -> TermT m f -> m a
cataTM alg = run
  where
    -- The scoped type variables from the signature above fix m, f, a.
    run :: TermT m f -> m a
    run (Term (Inl m)) = m >>= run
    run (Term (Inr t)) = mapM run t >>= alg
-- | Catamorphism over a term with thunks with a pure algebra. It is
-- 'cataTM' with @alg@ lifted by 'return'.
cataT :: (Traversable f, Monad m) => Alg f a -> TermT m f -> m a
cataT alg = cataTM (return . alg)
-- | Build a node of the signature @g@ from @x@ so that all immediate
-- subterms of @x@ are forced to weak head normal form (via 'whnf'')
-- before the node is produced. The whole construction is delayed
-- inside a 'thunk'.
strict :: (f :<: g, Traversable f, Monad m) => f (TermT m g) -> TermT m g
strict x = thunk $ liftM (inject_ (Inr . inj)) $ mapM whnf' x
-- | A position selector for the functor @f@. Given any @f x@, it returns
-- some of its @x@ components as a list. Because it is polymorphic in
-- @x@, it can only select existing components; it cannot invent new
-- ones.
type Pos f = forall x . f x -> [x]
-- | Variant of 'strict' that forces only some immediate subterms.
--
-- The subterms of @s@ are numbered, and @p@ selects which numbered
-- subterms are strict. Those are forced to weak head normal form via
-- 'whnf''; the others are left as they are. The resulting node of @g@
-- is built inside a 'thunk'.
strictAt :: (f :<: g, Traversable f, Monad m) => Pos f -> f (TermT m g) -> TermT m g
strictAt p s = thunk $ liftM (inject_ (Inr . inj)) $ mapM run s'
    where
      s' = number s
      -- The set of strict positions does not depend on the subterm being
      -- examined, so it is built once here. Previously it was rebuilt
      -- inside isStrict for every subterm.
      strictPositions = IntSet.fromList [i | Numbered i _ <- p s']
      isStrict (Numbered i _) = IntSet.member i strictPositions
      run e | isStrict e = whnf' $ unNumbered e
            | otherwise  = return $ unNumbered e
-- | Decide equality of two terms with thunks.
--
-- Both terms are forced to weak head normal form with 'whnf'. If their
-- top-level nodes differ, as reported by 'eqMod', the result is
-- 'False'. Otherwise the corresponding children are compared
-- recursively, one pair at a time.
--
-- The comparison stops at the first unequal pair, so the remaining
-- children are not forced. The previous @liftM and . mapM@ form kept
-- evaluating every later pair (and running its thunks) in strict
-- monads, even after a mismatch had been found.
eqT :: (EqF f, Foldable f, Functor f, Monad m) => TermT m f -> TermT m f -> m Bool
eqT s t = do
  s' <- whnf s
  t' <- whnf t
  case eqMod s' t' of
    Nothing    -> return False
    Just pairs -> foldr andThen (return True) pairs
  where
    -- Compare one pair; continue with the rest only if it was equal.
    andThen (a, b) rest = do
      same <- eqT a b
      if same then rest else return False