{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE RecursiveDo         #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Language.Haskell.TH.LetRec (
    letrecE,
) where

import Control.Monad.Fix              (MonadFix)
import Control.Monad.Trans.Class      (lift)
import Control.Monad.Trans.State.Lazy (StateT, get, modify, runStateT)
import Language.Haskell.TH.Lib        (letE, normalB, valD, varE, varP)
import Language.Haskell.TH.Syntax     (Exp, Name, Quote (newName))

import qualified Data.Map.Lazy as Map

-- $setup
-- >>> :set -XTemplateHaskell
-- >>> import Language.Haskell.TH.Syntax as TH
-- >>> import Language.Haskell.TH.Lib    as TH
-- >>> import Language.Haskell.TH.Ppr    as TH

-- | Generate potentially recursive let expression.
--
-- The bindings are identified by @tag@s.
--
-- * The first time a tag is requested, a fresh name is created for it with
--   'newName' applied to the result of the naming function.
-- * Its right-hand side is then produced by the bindings generator.
-- * Later requests for the same tag, including recursive ones made while
--   the binding is still being generated, just refer to that name.
-- * Each tag is therefore generated at most once.
-- * Tags that are never requested produce no binding.
-- * The bindings appear in the generated @let@ in ascending 'Ord' order of
--   their tags.
--
-- The 'Monad' constraint in the generators forces the binding generation
-- calls to be sequenced, which is what allows bindings to be generated
-- lazily, on demand.
--
-- Example of generating a list of alternating 'True' and 'False' values:
--
-- >>> let trueFalse = letrecE (\tag -> "go" ++ show tag) (\rec tag -> rec (not tag) >>= \next -> return [| $(TH.lift tag) : $next |]) ($ True)
--
-- The generated let-bindings look like:
--
-- >>> TH.ppr <$> trueFalse
-- let {goFalse_0 = GHC.Types.False GHC.Types.: goTrue_1;
--      goTrue_1 = GHC.Types.True GHC.Types.: goFalse_0}
--  in goTrue_1
--
-- And when spliced it produces a list of alternating 'True' and 'False' values:
--
-- >>> take 10 $trueFalse
-- [True,False,True,False,True,False,True,False,True,False]
--
-- Another example where the dynamic nature is visible is generating
-- Fibonacci numbers:
--
-- >>> let fibRec rec tag = case tag of { 0 -> return [| 1 |]; 1 -> return [| 1 |]; _ -> do { minus1 <- rec (tag - 1); minus2 <- rec (tag - 2); return [| $minus1 + $minus2 |] }}
-- >>> let fib n = letrecE (\tag -> "fib" ++ show tag) fibRec ($ n)
--
-- The generated let-bindings look like:
--
-- >>> TH.ppr <$> fib 7
-- let {fib0_0 = 1;
--      fib1_1 = 1;
--      fib2_2 = fib1_1 GHC.Num.+ fib0_0;
--      fib3_3 = fib2_2 GHC.Num.+ fib1_1;
--      fib4_4 = fib3_3 GHC.Num.+ fib2_2;
--      fib5_5 = fib4_4 GHC.Num.+ fib3_3;
--      fib6_6 = fib5_5 GHC.Num.+ fib4_4;
--      fib7_7 = fib6_6 GHC.Num.+ fib5_5}
--  in fib7_7
--
-- And the result is as expected:
--
-- >>> $(fib 7)
-- 21
--
letrecE
    :: forall q tag. (Ord tag, Quote q, MonadFix q)
    => (tag -> String)                                                   -- ^ tag naming function
    -> (forall m. Monad m => (tag -> m (q Exp)) -> (tag -> m (q Exp)))   -- ^ bindings generator (with recursive function)
    -> (forall m. Monad m => (tag -> m (q Exp)) -> m (q Exp))            -- ^ final expression generator
    -> q Exp                                                             -- ^ generated let expression.
letrecE nameOf recf exprf = do
    -- Run the final generator; every tag it (transitively) requests ends up
    -- in the state map together with its name and generated expression.
    (expr0, bindings) <- runStateT (exprf loop) Map.empty
    letE
        [ valD (varP name) (normalB expr) []
        | (_tag, (name, expr)) <- Map.toList bindings
        ]
        expr0
  where
    -- Resolve a tag to a reference to its let-bound variable, generating
    -- the binding the first time the tag is requested.
    loop :: tag -> StateT (Map.Map tag (Name, q Exp)) q (q Exp)
    loop tag = do
        m <- get
        case Map.lookup tag m of
            -- Name already allocated (its binding may still be under
            -- construction further up the call chain): just refer to it.
            -- The stored expression is deliberately not inspected here.
            Just (name, _exp) -> return (varE name)

            -- First request: allocate a fresh name and record the binding
            -- *before* running the generator, so that recursive requests
            -- for the same tag take the 'Just' branch above instead of
            -- looping forever. The expression itself is only produced by
            -- the generator afterwards; 'mdo' (backed by the 'MonadFix'
            -- instance of the lazy 'StateT', which needs 'MonadFix q')
            -- lets the map refer to it ahead of time.
            Nothing -> mdo
                name <- lift (newName (nameOf tag))
                modify (Map.insert tag (name, expr))
                expr <- recf loop tag
                return (varE name)