-- | Synthetising attributes, partly motivated by Attribute Grammars.

{-# LANGUAGE CPP #-}
module Data.Generics.Fixplate.Attributes
  ( Attrib(..)
  , annMap
  -- * Synthetised attributes
  , synthetise , synthetise' , synthetiseList , synthetiseM
  -- * Inherited attributes
  , inherit , inherit'
  -- * Traversals
  , synthAccumL    , synthAccumR
  , synthAccumL_   , synthAccumR_
  , enumerateNodes , enumerateNodes_
  -- * Stacking attributes
  , annZip  , annZipWith
  , annZip3 , annZipWith3  
#ifdef WITH_QUICKCHECK
  -- * Tests
  , runtests_Attributes  
  , prop_synthAccumL
  , prop_synthAccumR
  , prop_synthetise
#endif  
  )
  where

--------------------------------------------------------------------------------

import Control.Monad (liftM)
import Data.Foldable
import Data.Traversable
import Prelude hiding (foldl,foldr,mapM,mapM_,concat,concatMap,sum)

import Data.Generics.Fixplate.Base

#ifdef WITH_QUICKCHECK
import Test.QuickCheck
import Data.Generics.Fixplate.Traversals
import Data.Generics.Fixplate.Test.Tools
#endif 

--------------------------------------------------------------------------------

-- | Apply a function to every annotation of an annotated tree.
-- This is simply 'fmap' through the 'Attrib' newtype wrapper:
--
-- > annMap f = unAttrib . fmap f . Attrib
annMap :: Functor f => (a -> b) -> Attr f a -> Attr f b
annMap f tree = unAttrib (fmap f (Attrib tree))

--------------------------------------------------------------------------------
-- Synthetised attributes

-- | /Synthetised/ attributes are created in a bottom-up manner:
-- the attribute of every node is computed by @h@ from the attributes
-- already computed for its children.
-- As an example, the @sizes@ function computes the sizes of all
-- subtrees:
--
-- > sizes :: (Functor f, Foldable f) => Mu f -> Attr f Int
-- > sizes = synthetise (\t -> 1 + sum t)
--
-- (note that @sum@ here is @Data.Foldable.sum == Prelude.sum . Data.Foldable.toList@)
--
synthetise :: Functor f => (f a -> a) -> Mu f -> Attr f a
synthetise h = go where
  -- The annotated children are bound once and shared between the
  -- result node and the argument passed to @h@.
  go (Fix layer) = Fix (Ann (h (fmap attribute kids)) kids) where
    kids = fmap go layer
 
-- | Generalization of @scanr@ for trees: the new attribute of a node is
-- computed by @h@ from the node's old annotation together with the new
-- attributes of its children (which are computed first, bottom-up).
synthetise' :: Functor f => (a -> f b -> b) -> Attr f a -> Attr f b
synthetise' h = go where
  go (Fix (Ann old layer)) = Fix (Ann (h old (fmap attribute kids)) kids) where
    kids = fmap go layer

-- | Like 'synthetise', but the combining function receives the
-- children's attributes as a list (in 'toList' order) rather than
-- inside the functor layer.
synthetiseList :: (Functor f, Foldable f) => ([a] -> a) -> Mu f -> Attr f a
synthetiseList h = synthetise (\layer -> h (toList layer))

-- | Monadic version of 'synthetise'. The subtrees are processed first
-- (sequenced by @mapM@), then @act@ is run on the children's attributes
-- to produce the attribute of the current node.
synthetiseM  ::  (Traversable f, Monad m) =>  (f a -> m a) -> Mu f -> m (Attr f a)
synthetiseM act = go where
  go (Fix layer) =
    mapM go layer >>= \kids ->
      liftM (\a -> Fix (Ann a kids)) (act (fmap attribute kids))

--------------------------------------------------------------------------------
-- Inherited attributes

-- | /Inherited/ attributes are created in a top-down manner: the
-- attribute of each node is computed by @h@ from the node itself and
-- the attribute of its parent (the root sees the given initial value).
-- As an example, the @depths@ function computes the depth
-- (the distance from the root, incremented by 1) of all subtrees:
--
-- > depths :: Functor f => Mu f -> Attr f Int
-- > depths = inherit (\_ i -> i+1) 0
--
inherit :: Functor f => (Mu f -> a -> a) -> a -> Mu f -> Attr f a
inherit h = go where
  go parent node@(Fix layer) = Fix (Ann here (fmap (go here) layer)) where
    here = h node parent

-- | Generalization of @scanl@ for trees: the new attribute of a node is
-- computed by @h@ from its parent's new attribute and the node's old
-- annotation (the root's "parent attribute" is the given initial value).
inherit' :: Functor f => (a -> b -> a) -> a -> Attr f b -> Attr f a
inherit' h = go where
  go parent (Fix (Ann old layer)) = Fix (Ann new (fmap (go new) layer)) where
    new = h parent old

--------------------------------------------------------------------------------
-- Traversals

-- | Synthetising attributes via an accumulating map in a left-to-right fashion
-- (the order is the same as in @foldl@). At each node @h@ is applied first,
-- and the updated accumulator is then threaded through the children from
-- left to right with 'mapAccumL'. The final accumulator is returned together
-- with the annotated tree.
synthAccumL :: Traversable f => (a -> Mu f -> (a,b)) -> a -> Mu f -> (a, Attr f b)
synthAccumL h = go where
  -- 'where'-bound tuple patterns are lazy, matching the original 'let'.
  go acc node@(Fix layer) = (acc'', Fix (Ann b layer')) where
    (acc' , b     ) = h acc node
    (acc'', layer') = mapAccumL go acc' layer

-- | Synthetising attributes via an accumulating map in a right-to-left fashion
-- (the order is the same as in @foldr@). At each node the accumulator is
-- first threaded through the children from right to left with 'mapAccumR',
-- and only then is @h@ applied to the node itself. The final accumulator is
-- returned together with the annotated tree.
synthAccumR :: Traversable f => (a -> Mu f -> (a,b)) -> a -> Mu f -> (a, Attr f b)
synthAccumR h = go where
  -- 'where'-bound tuple patterns are lazy, matching the original 'let'.
  go acc node@(Fix layer) = (acc'', Fix (Ann b layer')) where
    (acc' , layer') = mapAccumR go acc layer
    (acc'', b     ) = h acc' node
    
-- | Like 'synthAccumL', but discards the final accumulator.
synthAccumL_ :: Traversable f => (a -> Mu f -> (a,b)) -> a -> Mu f -> Attr f b
synthAccumL_ h x = snd . synthAccumL h x

-- | Like 'synthAccumR', but discards the final accumulator.
synthAccumR_ :: Traversable f => (a -> Mu f -> (a,b)) -> a -> Mu f -> Attr f b
synthAccumR_ h x = snd . synthAccumR h x

-- | We use 'synthAccumL' to number the nodes from @0@ to @(n-1)@ in
-- a left-to-right traversal fashion, where
-- @n == length (universe tree)@ is the number of substructures,
-- which is also returned.
enumerateNodes :: Traversable f => Mu f -> (Int, Attr f Int)
enumerateNodes = synthAccumL step 0 where
  -- Label the node with the current counter, then bump it.
  step counter _ = (counter + 1, counter)

-- | Like 'enumerateNodes', but without the node count.
enumerateNodes_ :: Traversable f => Mu f -> Attr f Int
enumerateNodes_ tree = snd (enumerateNodes tree)

--------------------------------------------------------------------------------
-- Stacking attributes

-- | Merges two layers of annotations into a single one. The inner
-- annotation becomes the first component of the pair, the outer
-- annotation the second.
annZip :: Functor f => Mu (Ann (Ann f a) b) -> Attr f (a,b)
annZip = go where
  go (Fix (Ann outer (Ann inner layer))) = Fix (Ann (inner, outer) (fmap go layer))

-- | Merges two layers of annotations with a combining function, which
-- receives the inner annotation first and the outer one second.
annZipWith :: Functor f => (a -> b -> c) -> Mu (Ann (Ann f a) b) -> Attr f c
annZipWith h (Fix (Ann outer (Ann inner layer))) =
  Fix (Ann (h inner outer) (fmap (annZipWith h) layer))

-- | Merges three layers of annotations into a single one. The
-- components of the triple are ordered from the innermost annotation
-- to the outermost.
annZip3 :: Functor f => Mu (Ann (Ann (Ann f a) b) c) -> Attr f (a,b,c)
annZip3 = go where
  go (Fix (Ann c (Ann b (Ann a layer)))) = Fix (Ann (a, b, c) (fmap go layer))

-- | Merges three layers of annotations with a combining function, which
-- receives the annotations ordered from innermost to outermost.
annZipWith3 :: Functor f => (a -> b -> c -> d) -> Mu (Ann (Ann (Ann f a) b) c) -> Attr f d
annZipWith3 h (Fix (Ann c (Ann b (Ann a layer)))) =
  Fix (Ann (h a b c) (fmap (annZipWith3 h) layer))

--------------------------------------------------------------------------------
-- Tests
#ifdef WITH_QUICKCHECK

-- | Run all QuickCheck properties of this module.
-- (Explicit signature added: the top-level binding previously had none,
-- which triggers @-Wmissing-signatures@ under @-Wall@.)
runtests_Attributes :: IO ()
runtests_Attributes = do
  quickCheck prop_synthAccumL
  quickCheck prop_synthAccumR
  quickCheck prop_synthetise
 
-- | Numbering nodes from 1 with 'synthAccumL_' yields @[1..n]@ in the
-- 'toList' order of 'Attrib', where @n@ is the number of subtrees.
prop_synthAccumL :: FixT Label -> Bool
prop_synthAccumL tree = numbers == [1 .. nodeCount] where
  numbers   = toList (Attrib (synthAccumL_ (\i _ -> (i+1,i)) 1 tree))
  nodeCount = length (universe tree)

-- | Numbering nodes from 1 with 'synthAccumR_' yields @[n, n-1 .. 1]@ in
-- the 'toList' order of 'Attrib', where @n@ is the number of subtrees.
prop_synthAccumR :: FixT Label -> Bool
prop_synthAccumR tree = numbers == reverse [1 .. nodeCount] where
  numbers   = toList (Attrib (synthAccumR_ (\i _ -> (i+1,i)) 1 tree))
  nodeCount = length (universe tree)

-- | Concatenating labels bottom-up with 'synthetise' gives, at every
-- subtree, the same string as concatenating that subtree's labels with
-- 'foldLeft'.
prop_synthetise :: FixT Label -> Bool
prop_synthetise tree =
  map attribute (universe $ synthetise (\(TreeF (Label l) xs) -> l ++ concat xs) tree)
  ==
  map concatLabels (universe tree)
  where
    -- Deliberately not named @fold@: that would shadow 'Data.Foldable.fold',
    -- which this module imports unqualified.
    concatLabels = foldLeft (\s (Fix (TreeF (Label l) _)) -> s++l) []
  
#endif
--------------------------------------------------------------------------------