{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Retrie.Monad
  ( 
    Retrie
  , addImports
  , apply
  , applyWithStrategy
  , applyWithUpdate
  , applyWithUpdateAndStrategy
  , focus
  , ifChanged
  , iterateR
  , query
  , queryWithUpdate
  , topDownPrune
    
  , getGroundTerms
  , liftRWST
  , runRetrie
  ) where
import Control.Monad.IO.Class
import Control.Monad.State.Strict
import Control.Monad.RWS
import Control.Monad.Writer.Strict
import Data.Foldable
import Retrie.Context
import Retrie.CPP
import Retrie.ExactPrint
import Retrie.Fixity
import Retrie.GroundTerms
import Retrie.Query
import Retrie.Replace
import Retrie.Substitution
import Retrie.SYB
import Retrie.Types
import Retrie.Universe
-- | A 'Retrie' program, represented as a syntax tree of binds, primitive
-- instructions and pure values rather than as a function.
--
-- Because the tree is inert data, it can both be run (see 'getComp') and be
-- inspected without running it (see 'getGroundTerms').
data Retrie a where
  -- | Monadic bind: run the first program, feed its result to the
  -- continuation.
  Bind :: Retrie b -> (b -> Retrie a) -> Retrie a
  -- | A single primitive instruction.
  Inst :: RetrieInstruction a -> Retrie a
  -- | A pure value with no effects.
  Pure :: a -> Retrie a
-- | The primitive instructions a 'Retrie' program is built from.
data RetrieInstruction a where
  -- | Record the ground terms of a set of queries. 'getGroundTerms' reads
  -- these statically; at run time ('getComp') this instruction is a no-op.
  Focus :: [GroundTerms] -> RetrieInstruction ()
  -- | Emit a 'Change' to the writer output.
  Tell :: Change -> RetrieInstruction ()
  -- | Run the first program; run the second only if the first produced a
  -- 'Change' other than 'NoChange' (see 'ifChangedComp').
  IfChanged :: Retrie () -> Retrie () -> RetrieInstruction ()
  -- | An arbitrary 'RetrieComp' computation. It is opaque to
  -- 'getGroundTerms'.
  Compute :: RetrieComp a -> RetrieInstruction a
-- | The monad 'Retrie' programs are interpreted into. It reads a
-- 'FixityEnv', accumulates a 'Change' as writer output, threads the module
-- being processed ('CPP' 'AnnotatedModule') as state, and runs in 'IO'.
type RetrieComp = RWST FixityEnv Change (CPP AnnotatedModule) IO
-- | Turn one primitive instruction into a complete 'Retrie' program.
singleton :: RetrieInstruction a -> Retrie a
singleton instruction = Inst instruction
-- | Embed an arbitrary 'RetrieComp' computation in a 'Retrie' program.
--
-- The computation is opaque to 'getGroundTerms': analysis stops there.
liftRWST :: RetrieComp a -> Retrie a
liftRWST comp = singleton (Compute comp)
-- | A 'Retrie' program normalised by 'view' so that its first step is
-- visible.
data RetrieView a where
  -- | The program is finished and yields this value.
  Return :: a -> RetrieView a
  -- | The program starts with this instruction, followed by the
  -- continuation applied to the instruction's result.
  (:>>=) :: RetrieInstruction b -> (b -> Retrie a) -> RetrieView a
-- | Expose the first instruction of a 'Retrie' program.
--
-- Left-nested binds are reassociated: @(m >>= k1) >>= k2@ becomes
-- @m >>= (k1 >=> k2)@. This repeats until the head is a 'Pure' or an 'Inst'.
-- A bind whose head is 'Pure' is reduced by applying the continuation.
view :: Retrie a -> RetrieView a
view prog = case prog of
  Pure x -> Return x
  Inst i -> i :>>= Pure
  Bind m k -> case m of
    Pure x -> view (k x)
    Inst i -> i :>>= k
    Bind m' k' -> view (Bind m' (k' >=> k))
-- | Mapping binds the program and wraps the function's result in 'Pure'.
instance Functor Retrie where
  fmap f m = Bind m (\x -> Pure (f x))
-- | Sequential application: run the function program, then the argument
-- program, then apply one result to the other.
instance Applicative Retrie where
  pure = Pure
  mf <*> mx = Bind mf (\f -> Bind mx (\x -> Pure (f x)))
-- | 'Retrie' is a monad by construction: '>>=' only records a 'Bind' node.
-- The node is later interpreted by 'getComp' or inspected by
-- 'getGroundTerms'.
--
-- 'return' is deliberately not defined. The class default
-- (@return = pure@) already produces 'Pure'. A hand-written @return = Pure@
-- is a non-canonical definition that GHC flags with
-- @-Wnoncanonical-monad-instances@.
instance Monad Retrie where
  (>>=) = Bind
-- | 'IO' actions are embedded as opaque 'Compute' instructions.
instance MonadIO Retrie where
  liftIO io = Inst (Compute (liftIO io))
-- | Run a 'Retrie' program against a module.
--
-- The program reads the given fixity environment and uses the module as its
-- initial state. The result is the program's value, the final module state,
-- and the accumulated 'Change', in that order.
runRetrie
  :: FixityEnv
  -> Retrie a
  -> CPP AnnotatedModule
  -> IO (a, CPP AnnotatedModule, Change)
runRetrie fixities retrie cpp = runRWST (getComp retrie) fixities cpp
-- | Statically find the ground terms of a 'Retrie' program, without
-- running it.
--
-- The program is walked in execution order, returning the payload of the
-- first 'Focus' reached:
--
-- * 'Tell' instructions are skipped.
-- * For 'IfChanged', the result is the first non-empty answer from the
--   guarded program, then the conditional program, then the continuation.
-- * The walk gives up with @[]@ at a 'Compute', because its result is only
--   known by running it.
-- * Reaching the end of the program also yields @[]@.
--
-- NOTE(review): @[]@ therefore means "no ground terms could be determined".
-- How callers interpret that is defined outside this module.
getGroundTerms :: Retrie a -> [GroundTerms]
getGroundTerms retrie =
  case view retrie of
    Return _ -> []
    Focus gts :>>= _ -> gts
    Tell _ :>>= rest -> getGroundTerms (rest ())
    IfChanged r1 r2 :>>= rest ->
      firstNonEmpty
        [getGroundTerms r1, getGroundTerms r2, getGroundTerms (rest ())]
    Compute _ :>>= _ -> []
  where
    -- Candidates are examined lazily, left to right.
    firstNonEmpty :: [[GroundTerms]] -> [GroundTerms]
    firstNonEmpty = concat . take 1 . filter (not . null)
-- | Interpret a 'Retrie' program as a 'RetrieComp' computation.
--
-- The program is unfolded one instruction at a time via 'view'. Each
-- instruction is executed and its result is fed to the continuation.
getComp :: Retrie a -> RetrieComp a
getComp = eval . view
  where
    -- The helpers match on GADT constructors that refine the type index
    -- ('Focus', 'Tell' and 'IfChanged' fix it to @()@). GHC performs this
    -- refinement reliably only against user-supplied signatures, so both
    -- helpers carry explicit rigid types instead of relying on inference.
    eval :: RetrieView b -> RetrieComp b
    eval (Return x) = return x
    eval (inst :>>= k) = evalInst inst >>= getComp . k

    evalInst :: RetrieInstruction b -> RetrieComp b
    -- Focus only matters to the static analysis in 'getGroundTerms'; at run
    -- time it does nothing.
    evalInst (Focus _) = return ()
    evalInst (Tell c) = tell c
    evalInst (IfChanged r1 r2) = ifChangedComp (getComp r1) (getComp r2)
    evalInst (Compute m) = m
-- | Record the ground terms of the given queries as a 'Focus' instruction.
--
-- An empty query list records nothing, so 'getGroundTerms' keeps looking
-- further along the program.
focus :: Data k => [Query k v] -> Retrie ()
focus queries
  | null queries = return ()
  | otherwise = singleton (Focus (groundTerms <$> queries))
-- | Apply rewrites using the default 'updateContext' and the 'topDownPrune'
-- strategy.
apply :: [Rewrite Universe] -> Retrie ()
apply rewrites = applyWithUpdateAndStrategy updateContext topDownPrune rewrites
-- | Apply rewrites with a caller-supplied 'ContextUpdater' and the
-- 'topDownPrune' strategy.
applyWithUpdate
  :: ContextUpdater -> [Rewrite Universe] -> Retrie ()
applyWithUpdate updCtxt rewrites =
  applyWithUpdateAndStrategy updCtxt topDownPrune rewrites
-- | Apply rewrites with a caller-supplied traversal strategy and the default
-- 'updateContext'.
applyWithStrategy
  :: Strategy (TransformT (WriterT Change IO))
  -> [Rewrite Universe]
  -> Retrie ()
applyWithStrategy strategy rewrites =
  applyWithUpdateAndStrategy updateContext strategy rewrites
-- | Apply rewrites with both the 'ContextUpdater' and the traversal
-- 'Strategy' chosen by the caller.
--
-- An empty rewrite list does nothing at all, not even a 'focus'. Otherwise:
--
-- * The rewrites' ground terms are recorded with 'focus'.
-- * Every module in the CPP state is rewritten with 'replace', using
--   'everywhereMWithContextBut'.
-- * The traversal starts from a context built from the fixity environment,
--   the combined rewriter of @rewrites@, and the combined rewriter of
--   @rewritesWithDependents rewrites@.
-- * The rewritten modules replace the state, and the emitted 'Change' goes
--   to the writer output (via 'rs').
applyWithUpdateAndStrategy
  :: ContextUpdater
  -> Strategy (TransformT (WriterT Change IO))
  -> [Rewrite Universe]
  -> Retrie ()
applyWithUpdateAndStrategy updCtxt strategy rewrites
  | null rewrites = return ()
  | otherwise = do
      focus rewrites
      singleton (Compute (rs rewriteModules))
  where
    -- Rewrite each module held in the CPP state. The predicate is
    -- (const False), which presumably means "never stop early" — confirm
    -- against everywhereMWithContextBut.
    rewriteModules fixityEnv = traverse $ \modl ->
      transformA modl $
        everywhereMWithContextBut strategy (const False) updCtxt replace
          (emptyContext fixityEnv primaryRewriter dependentRewriter)
    primaryRewriter = foldMap mkRewriter rewrites
    dependentRewriter = foldMap mkRewriter (rewritesWithDependents rewrites)
-- | Collect every match of the given queries, using the default
-- 'updateContext'.
query :: [Query Universe v] -> Retrie [(Context, Substitution, v)]
query queries = queryWithUpdate updateContext queries
-- | Like 'query', but with a caller-supplied 'ContextUpdater'.
--
-- An empty query list returns @[]@ immediately, without any 'focus'.
-- Otherwise the queries' ground terms are recorded with 'focus', and every
-- module in the CPP state is searched with the combined matcher. Results
-- are concatenated in the order 'toList' yields the modules.
--
-- The module state is only read, never written back: the tree produced by
-- 'transformA' is discarded and only the collected matches are kept.
queryWithUpdate
  :: ContextUpdater
  -> [Query Universe v]
  -> Retrie [(Context, Substitution, v)]
queryWithUpdate _ [] = return []
queryWithUpdate updCtxt qs = do
  focus qs
  singleton $ Compute $ do
    fixityEnv <- ask
    modules <- gets toList
    perModule <- lift $ mapM (searchModule fixityEnv) modules
    return (concat perModule)
  where
    matcher = foldMap mkMatcher qs

    -- Run the matcher over one module and keep only the collected matches.
    searchModule fixityEnv modl = do
      annotated <- transformA modl $
        everythingMWithContextBut
          (const False)
          updCtxt
          (genericQ matcher)
          (emptyContext fixityEnv mempty mempty)
      return (astA annotated)
-- | @ifChanged r1 r2@ always runs @r1@. It then runs @r2@ only if @r1@
-- emitted a change (see 'ifChangedComp').
ifChanged :: Retrie () -> Retrie () -> Retrie ()
ifChanged r1 r2 = Inst (IfChanged r1 r2)
-- | Run @r1@ and observe its writer output with 'listen'. If that output is
-- a 'Change' (anything other than 'NoChange'), run @r2@ as well.
--
-- 'listen' still passes @r1@'s output to the enclosing writer, so @r1@'s
-- change is kept either way.
ifChangedComp :: RetrieComp () -> RetrieComp () -> RetrieComp ()
ifChangedComp r1 r2 = do
  (_, delta) <- listen r1
  when (madeChange delta) r2
  where
    madeChange NoChange = False
    madeChange Change{} = True
-- | Run @r@ at most @n@ times. Stop early after the first run that produces
-- no change. A non-positive @n@ does nothing.
iterateR :: Int -> Retrie () -> Retrie ()
iterateR n r
  | n <= 0 = return ()
  | otherwise = ifChanged r (iterateR (n - 1) r)
-- | Emit a 'Change' that carries no replacements and exactly one
-- 'AnnotatedImports' entry.
addImports :: AnnotatedImports -> Retrie ()
addImports imports = singleton . Tell $ Change [] (pure imports)
-- | A traversal strategy that first tries the node-level action on a node.
--
-- If that action emits a change, its result is returned and the second
-- action is never run for this node. Otherwise the second action runs on
-- the original node. Presumably the second action descends into the
-- node's children — confirm against the 'Strategy' definition.
topDownPrune :: Monad m => Strategy (TransformT (WriterT Change m))
topDownPrune here descend node = do
  (rewritten, delta) <- listenTransformT (here node)
  case delta of
    NoChange -> descend node
    Change{} -> return rewritten
-- | 'listen' lifted through the 'TransformT' newtype.
--
-- Runs the action and returns its result paired with the output it wrote
-- to the underlying 'WriterT'. The inner RWST state and writer output of
-- 'TransformT' are passed through unchanged.
listenTransformT
  :: (Monad m, Monoid w)
  => TransformT (WriterT w m) a -> TransformT (WriterT w m) (a, w)
listenTransformT (TransformT inner) = TransformT . RWST $ \env st ->
  listen (runRWST inner env st) >>= \((result, st', innerOut), w) ->
    return ((result, w), st', innerOut)
-- | Lift a step function into an 'RWST' action that returns @()@.
--
-- The step function receives the reader environment and the current state
-- and runs in @WriterT w m@. Its result becomes the new state, and its
-- writer output becomes the 'RWST' output.
rs :: Monad m => (r -> s -> WriterT w m s) -> RWST r w s m ()
rs step = RWST $ \env st ->
  runWriterT (step env st) >>= \(st', out) -> return ((), st', out)