{-# LANGUAGE UndecidableInstances, LambdaCase, ParallelListComp, ViewPatterns, ImpredicativeTypes #-}
module Data.Syntax where

import Definitive

-- | A transformation on unit-valued thunks, paired with an optional label.
-- The label is only consulted by the 'Show' instance.
newtype Lambda n m = Lambda (Maybe String,ThunkT n m () -> ThunkT n m ())
instance Show (Lambda n m) where
  -- Only the label is shown; the function itself cannot be rendered.
  show (Lambda (label,_)) = show label
-- | Sequential composition: in @a + b@ the function of @a@ runs first and
-- the function of @b@ runs on its result. Labels are discarded, so the
-- composite is always unlabelled.
instance Semigroup (Lambda n m) where
  Lambda (_,earlier) + Lambda (_,later) = Lambda (Nothing,\thunk -> later (earlier thunk))

-- | Constraint bundle for node functors: @n (Lambda n m)@ must be both
-- 'Foldable' and 'Traversable'. The class declares no methods, so it serves
-- purely as a shorthand constraint. No instances are visible in this chunk —
-- NOTE(review): presumably they are declared elsewhere; verify.
class (Foldable (n (Lambda n m)),Traversable (n (Lambda n m))) => NodeFunctor n m 

-- | Syntax trees: the free monad over the node functor @n (Lambda n m)@.
-- A 'Join' layer is one node of @n@ (its lambda slots holding 'Lambda'
-- values) whose children are further syntax; a 'Pure' leaf carries an @a@.
newtype SyntaxT n m a = SyntaxT (Free (n (Lambda n m)) a)
                      deriving Unit
-- Functor/Applicative/Foldable are standalone-derived from the underlying
-- 'Free'; Monad, Traversable and MonadFree are assembled from the coerce*
-- helpers applied to the 'SyntaxT' constructor (presumably reusing the
-- 'Free' instances through the newtype — verify in Definitive).
deriving instance NodeFunctor n m => Functor (SyntaxT n m)
deriving instance NodeFunctor n m => Applicative (SyntaxT n m)
instance NodeFunctor n m => Monad (SyntaxT n m) where join = coerceJoin SyntaxT
deriving instance NodeFunctor n m => Foldable (SyntaxT n m)
instance NodeFunctor n m => Traversable (SyntaxT n m) where sequence = coerceSeq SyntaxT
instance MonadFree (n (Lambda n m)) (SyntaxT n m) where
  step = coerceStep SyntaxT ; perform = coercePerform SyntaxT ; liftF = coerceLiftF SyntaxT

-- | Thunks: the free monad over @m :.: SyntaxT n m@. Each layer is an @m@
-- computation yielding a syntax tree whose leaves are the next layer, so
-- evaluating a thunk interleaves effects in @m@ with exposing syntax
-- (see 'force').
newtype ThunkT n m a = ThunkT (Free (m:.:SyntaxT n m) a)
                        deriving Unit

-- Structural instances, standalone-derived from the underlying 'Free'.
deriving instance (NodeFunctor n m,Functor m) => Functor (ThunkT n m)
deriving instance (NodeFunctor n m,Functor m) => Applicative (ThunkT n m)
deriving instance (NodeFunctor n m,Foldable m) => Foldable (ThunkT n m)
-- | Thunks whose effect monad is 'Id' (assumed to be the identity functor).
type Thunk n a = ThunkT n Id a
-- Monad, Traversable and MonadFree are assembled from the coerce* helpers
-- applied to the 'ThunkT' constructor.
instance (NodeFunctor n m,Monad m) => Monad (ThunkT n m) where join = coerceJoin ThunkT
instance (NodeFunctor n m,Traversable m) => Traversable (ThunkT n m) where sequence = coerceSeq ThunkT
instance (NodeFunctor n m,Monad m) => MonadFree (m:.:SyntaxT n m) (ThunkT n m) where
  step = coerceStep ThunkT ; perform = coercePerform ThunkT ; liftF = coerceLiftF ThunkT
-- Effect classes of @m@ exposed on thunks via standalone deriving.
-- NOTE(review): presumably newtype-derived, which requires the matching
-- instance on @Free (m:.:SyntaxT n m)@ — confirm it exists in Definitive.
deriving instance (NodeFunctor n m,MonadReader r m) => MonadReader r (ThunkT n m)
deriving instance (NodeFunctor n m,MonadState s m) => MonadState s (ThunkT n m)
deriving instance (NodeFunctor n m,MonadWriter w m) => MonadWriter w (ThunkT n m)

-- | One syntax node: its lambda slots hold 'Lambda' values and its children
-- are thunks. This is the shape 'force' exposes.
type NodeT n m a = n (Lambda n m) (ThunkT n m a)

-- | Evaluate a thunk until its outermost syntax node is exposed.
--
-- The helper @force'@ peels one layer off the thunk with 'step' and 'emerge'
-- and binds on the resulting syntax:
--
-- * a 'Pure' leaf is just another thunk, which is forced in turn;
-- * a 'Join' layer is a real node, whose children are rewrapped as thunks
--   with @liftS . SyntaxT@ and returned.
--
-- The result of @force'@ is lifted back into 'ThunkT' with 'liftF'.
--
-- NOTE(review): if a thunk keeps yielding 'Pure' thunks (e.g. one that
-- refers to itself) this recursion never reaches a 'Join' and does not
-- terminate — confirm that is acceptable to callers.
force :: (NodeFunctor n m,Monad m) => ThunkT n m a -> ThunkT n m (NodeT n m a)
force = liftF . force'
  -- force' recurses until a 'Join' layer is found.
  where force' t = emerge (step t) >>= \(SyntaxT s) -> case s of
          Pure t' -> force' t'  -- leaf is another thunk: keep forcing
          Join j -> pure (liftS . SyntaxT<$>j)  -- real node: rewrap children as thunks
-- | @forcing k@ forces a thunk down to its outermost node (see 'force') and
-- continues with @k@ applied to that node.
forcing :: (NodeFunctor n m,Monad m) => (NodeT n m a -> ThunkT n m a) -> ThunkT n m a -> ThunkT n m a
forcing onNode = (>>= onNode) . force

-- | Turn a single node into a one-layer syntax tree whose leaves are the
-- node's child thunks.
liftN :: (NodeFunctor n m,Unit m) => NodeT n m a -> SyntaxT n m (ThunkT n m a)
liftN node = SyntaxT (liftF node)
-- | Embed a syntax tree whose leaves are thunks as a single thunk.
--
-- Each leaf thunk is unwrapped to its underlying 'Free' value. The whole
-- tree is then placed under one 'pure' @m@ layer and becomes a single
-- 'Join' layer of the thunk.
liftS :: (NodeFunctor n m,Unit m) => SyntaxT n m (ThunkT n m a) -> ThunkT n m a
liftS syntax = ThunkT (Join (Compose (pure (map peel syntax))))
  where peel (ThunkT inner) = inner
-- | Wrap a single node directly as a thunk: the composition of 'liftN'
-- (node to syntax) and 'liftS' (syntax to thunk).
liftNS :: (NodeFunctor n m,Unit m) => NodeT n m a -> ThunkT n m a
liftNS node = liftS (liftN node)