-- | Uniplate-style traversals.

{-# LANGUAGE CPP #-}
module Data.Generics.Fixplate.Traversals where

--------------------------------------------------------------------------------

import Control.Monad (liftM)
import Data.Foldable
import Data.Traversable
import Prelude hiding (foldl,foldr,mapM,mapM_,concat,concatMap)

import Data.Generics.Fixplate.Base 
import Data.Generics.Fixplate.Open
import Data.Generics.Fixplate.Misc

#ifdef WITH_QUICKCHECK
import Test.QuickCheck
import Data.Generics.Fixplate.Test.Tools
#endif

--------------------------------------------------------------------------------
-- * Queries

-- | The immediate subtrees of a node, in the left-to-right order given by
-- the 'Foldable' instance of the base functor.
children :: Foldable f => Mu f -> [Mu f]
children (Fix layer) = foldr (:) [] layer

-- | The list of all substructures, in pre-order (every node is listed
-- before its descendants, and siblings appear left to right). Together
-- with list-comprehension syntax this is a powerful query tool. For example
-- the following is how you get the list of all variable names in an
-- expression:
--
-- > variables expr = [ s | Fix (Var s) <- universe expr ]
--
-- The list is produced lazily and is built with an accumulator, so each
-- node is consed onto the result exactly once instead of being re-copied
-- by a chain of appends at every enclosing level.
universe :: Foldable f => Mu f -> [Mu f]
universe root = go root [] where
  -- @go node rest@ is the pre-order listing of @node@ followed by @rest@.
  go node rest = node : foldr go rest (unFix node)

--------------------------------------------------------------------------------
-- * Traversals

-- | Bottom-up transformation: the children of a node are transformed
-- first, and the function is then applied to the rebuilt node.
transform :: Functor f => (Mu f -> Mu f) -> Mu f -> Mu f
transform h = go where
  go (Fix layer) = h (Fix (fmap go layer))

-- | Monadic bottom-up transformation. The children of a node are processed
-- first (sequenced by the 'Traversable' instance of the base functor), and
-- the action is then run on the rebuilt node.
transformM  ::  (Traversable f, Monad m)
            =>  (Mu f -> m (Mu f)) -> Mu f -> m (Mu f)
transformM action = go where
  go (Fix layer) = mapM go layer >>= action . Fix
    
-- | Top-down transformation: the function is applied to a node first, and
-- the traversal then continues into the children of the /result/.
-- This is provided only for completeness; usually 'transform' is what you
-- want to use instead.
topDownTransform :: Functor f => (Mu f -> Mu f) -> Mu f -> Mu f
topDownTransform h = go where
  go node = Fix (fmap go (unFix (h node)))

-- | Monadic top-down transformation: the action is run on a node first, and
-- the traversal then continues into the children of the node it returned.
topDownTransformM :: (Traversable f, Monad m) => (Mu f -> m (Mu f)) -> Mu f -> m (Mu f)
topDownTransformM h = go where
  go node = h node >>= \(Fix layer) -> liftM Fix (mapM go layer)
  
-- | Non-recursive transformation: applies the function to each direct
-- child of the node and rebuilds the node; neither the node itself nor any
-- deeper descendant is passed to the function.
descend :: Functor f => (Mu f -> Mu f) -> Mu f -> Mu f
descend h (Fix layer) = Fix (fmap h layer)

-- | Monadic version of 'descend': runs the action on each direct child
-- (sequenced by the 'Traversable' instance) and rebuilds the node from the
-- results.
descendM :: (Traversable f, Monad m) => (Mu f -> m (Mu f)) -> Mu f -> m (Mu f)
descendM action (Fix layer) = liftM Fix (mapM action layer)

-- | Bottom-up transformation until a normal form is reached. The rule
-- returns @Nothing@ when it does not apply to a node, and @Just@ a
-- replacement when it does; every replacement is itself rewritten again,
-- so a rule that never stops firing makes this loop.
rewrite :: Functor f => (Mu f -> Maybe (Mu f)) -> Mu f -> Mu f
rewrite h = go where
  go     = transform step
  -- Keep the node if the rule does not fire, otherwise rewrite the result.
  step x = maybe x go (h x)

-- | Monadic version of 'rewrite': the rule runs in the monad and yields
-- @Nothing@ when it does not apply, or @Just@ a replacement, which is then
-- rewritten again.
rewriteM :: (Traversable f, Monad m) => (Mu f -> m (Maybe (Mu f))) -> Mu f -> m (Mu f)
rewriteM h = go where
  go     = transformM step
  -- Keep the node if the rule does not fire, otherwise rewrite the result.
  step x = h x >>= maybe (return x) go
  
--------------------------------------------------------------------------------
-- * Context

-- | We /annotate/ the nodes of the tree with functions which replace that
-- particular subtree: the annotation at a node takes a new subtree and
-- returns the whole tree with that node replaced by it. The root is
-- annotated with 'id'.
--
-- This relies on the imported 'holes', which is used here as yielding, for
-- each child, the child paired with a function that rebuilds the enclosing
-- layer around a replacement (NOTE: confirm against its definition).
context :: Traversable f => Mu f -> Attr f (Mu f -> Mu f)
context = go id where
  -- @rebuild@ turns a replacement for the current node into the whole tree.
  go rebuild = Fix . Ann rebuild . fmap descendInto . holes . unFix where
    descendInto (child, plug) = go (rebuild . Fix . plug) child

-- | Flattened version of 'context': every subtree (in the order produced by
-- 'universe') paired with the function that replaces it in the whole tree.
contextList :: Traversable f => Mu f -> [(Mu f, Mu f -> Mu f)]
contextList = map unpack . universe . context where
  -- Strip the annotations off the subtree and keep the node's own one.
  unpack node@(Fix (Ann replace _)) = (forget node, replace)

--------------------------------------------------------------------------------
-- * Folds

-- | (Strict) left fold over all nodes of the tree, in pre-order: the
-- function is applied to a node before its children are folded, left to
-- right. Since @Mu f@ is not a functor, but a data type, we cannot make
-- it an instance of the @Foldable@ type class.
--
-- The accumulator and the current node are forced (to weak head normal
-- form) at every step. This is written with 'seq' rather than bang
-- patterns so that it does not depend on the @BangPatterns@ extension
-- being enabled, and the same definition serves every compiler.
foldLeft :: Foldable f => (a -> Mu f -> a) -> a -> Mu f -> a
foldLeft h x0 t0 = go x0 t0 where
  go acc node = acc `seq` node `seq` foldl' go (h acc node) (unFix node)

-- | Lazy left fold over all nodes of the tree, in pre-order: the function
-- is applied to a node before its children are folded, left to right.
-- Unlike 'foldLeft', neither the accumulator nor the nodes are forced.
foldLeftLazy :: Foldable f => (a -> Mu f -> a) -> a -> Mu f -> a
foldLeftLazy h = go where
  go acc node = foldl go (h acc node) (unFix node)

-- | Right fold over all nodes of the tree: a node is combined with the
-- result of folding its children (right to left) over the incoming
-- accumulator. For instance, @foldRight (:) []@ lists the nodes in
-- pre-order.
foldRight :: Foldable f => (Mu f -> a -> a) -> a -> Mu f -> a
foldRight h x0 root = go root x0 where
  go node rest = h node (foldr go rest (unFix node))

--------------------------------------------------------------------------------
#ifdef WITH_QUICKCHECK
-- * Tests

-- | Straightforward reference implementation of 'universe', used only as a
-- test oracle: a node followed by the concatenated listings of its children.
universeNaive :: Foldable f => Mu f -> [Mu f]
universeNaive node = node : concat [ universeNaive c | c <- children node ]

-- | Runs all QuickCheck properties of this module, printing the results.
runtests_Traversals :: IO ()
runtests_Traversals = do
  quickCheck prop_leftFold
  quickCheck prop_leftFoldLazy
  quickCheck prop_rightFold
  quickCheck prop_universe1
  quickCheck prop_universe2
  
-- | 'universe' agrees with the reference implementation 'universeNaive'.
prop_universe1 :: FixT Label -> Bool
prop_universe1 tree = actual == expected where
  actual   = universe tree
  expected = universeNaive tree

-- | 'universe' is the same as collecting all nodes with 'foldRight'.
prop_universe2 :: FixT Label -> Bool
prop_universe2 tree = actual == expected where
  actual   = universe tree
  expected = foldRight (:) [] tree
  
-- | Collecting the labels with 'foldLeft' agrees with a list-style left
-- fold over the converted tree (via @fromFixT@ from the test tools).
prop_leftFold :: FixT Label -> Bool
prop_leftFold tree = viaFold == viaTree where
  viaFold = foldLeft (\acc (Fix (TreeF label _)) -> label : acc) [] tree
  viaTree = foldl (flip (:)) [] (fromFixT tree)

-- | Collecting the labels with 'foldLeftLazy' agrees with a list-style left
-- fold over the converted tree (via @fromFixT@ from the test tools).
prop_leftFoldLazy :: FixT Label -> Bool
prop_leftFoldLazy tree = viaFold == viaTree where
  viaFold = foldLeftLazy (\acc (Fix (TreeF label _)) -> label : acc) [] tree
  viaTree = foldl (flip (:)) [] (fromFixT tree)
  
-- | Collecting the labels with 'foldRight' agrees with a list-style right
-- fold over the converted tree (via @fromFixT@ from the test tools).
prop_rightFold :: FixT Label -> Bool
prop_rightFold tree = viaFold == viaTree where
  viaFold = foldRight (\(Fix (TreeF label _)) acc -> label : acc) [] tree
  viaTree = foldr (:) [] (fromFixT tree)

#endif
--------------------------------------------------------------------------------