-- | Synthetising attributes, partly motivated by Attribute Grammars, and partly by recursion schemes.

{-# LANGUAGE CPP #-}
module Data.Generics.Fixplate.Attributes
  ( Attrib(..)
  , annMap
  -- * Synthetised attributes
  , synthetise  
  , synthetise' , synthetiseList
  , synthetiseM
  -- * Synthetised attributes as generalized cata- and paramorphisms
  , synthCata   , scanCata
  , synthPara   , synthPara' 
  , scanPara
  , synthZygo_  , synthZygo  , synthZygoWith
  , synthAccumCata , synthAccumPara'
  , mapAccumCata
  , synthCataM  , synthParaM , synthParaM'
  -- * Inherited attributes
  , inherit , inherit'
  , inherit2
  , inheritM , inheritM_
  -- * Top-down folds
  , topDownSweepM , topDownSweepM'
  -- * Traversals
  , synthAccumL    , synthAccumR
  , synthAccumL_   , synthAccumR_
  , enumerateNodes , enumerateNodes_
  -- * Resynthetising transformations
  , synthTransform  , synthTransform'
  , synthRewrite    , synthRewrite'   
  -- * Stacking attributes
  , annZip  , annZipWith
  , annZip3 , annZipWith3  
#ifdef WITH_QUICKCHECK
  -- * Tests
  , runtests_Attributes  
  , scanCataNaive , mapAccumCataNaive
  , prop_synthAccumL
  , prop_synthAccumR
  , prop_synthetise
  , prop_synthCata
  , prop_synthPara
  , prop_synthPara'
  , prop_scanCata 
  , prop_mapAccumCata
  -- * Morphism tests which are here to avoid circular imports
  , zygoNaive_
  , zygoNaive
  , prop_zygo
  , prop_zygo_
#endif  
  )
  where

--------------------------------------------------------------------------------

import Control.Monad (liftM)
import Data.Foldable
import Data.Traversable
import Prelude hiding (foldl,foldr,mapM,mapM_,concat,concatMap,sum,and,or)

import Data.Generics.Fixplate.Base
import Data.Generics.Fixplate.Open

#ifdef WITH_QUICKCHECK
import Test.QuickCheck
import Data.List (intercalate)
import Data.Char (ord)
import qualified Prelude
import Data.Generics.Fixplate.Misc
import Data.Generics.Fixplate.Morphisms
import Data.Generics.Fixplate.Traversals
import Data.Generics.Fixplate.Test.Tools
#endif 

--------------------------------------------------------------------------------

-- | Apply a function to every annotation of an attributed tree; the
-- underlying tree structure is left unchanged.
--
-- > annMap f = unAttrib . fmap f . Attrib
annMap :: Functor f => (a -> b) -> Attr f a -> Attr f b
annMap f tree = unAttrib (fmap f (Attrib tree))

--------------------------------------------------------------------------------
-- Synthetised attributes

-- | /Synthetised/ attributes are built bottom-up: the attribute of every node
-- is computed from the attributes of its children. For example, @sizes@
-- annotates every subtree with its size:
--
-- > sizes :: (Functor f, Foldable f) => Mu f -> Attr f Int
-- > sizes = synthetise (\t -> 1 + sum t)
--
-- (here @sum@ is @Data.Foldable.sum == Prelude.sum . Data.Foldable.toList@)
--
-- This is the same function as 'synthCata'.
synthetise :: Functor f => (f a -> a) -> Mu f -> Attr f a
synthetise h = synthCata h
 
-- | Generalization of @scanr@ for trees: each node gets the result of the
-- function applied to its old annotation and to the new annotations of its
-- children. This is the same function as 'scanCata'.
synthetise' :: Functor f => (a -> f b -> b) -> Attr f a -> Attr f b
synthetise' = scanCata

-- | List version of 'synthetise' (compare with Uniplate): the function receives
-- the children's attributes as a list, in 'toList' order.
synthetiseList :: (Functor f, Foldable f) => ([a] -> a) -> Mu f -> Attr f a
synthetiseList h = synthetise (\children -> h (toList children))

-- | Monadic version of 'synthetise'; the same function as 'synthCataM'.
synthetiseM  ::  (Traversable f, Monad m) =>  (f a -> m a) -> Mu f -> m (Attr f a)
synthetiseM act = synthCataM act

--------------------------------------------------------------------------------
-- Synthetised attributes as generalized cata- and paramorphisms

-- | Annotates every subtree with the result of the corresponding catamorphism:
--
-- >  attribute . synthCata f == cata f
--
-- 'synthetise' is a synonym.
synthCata :: Functor f => (f a -> a) -> Mu f -> Attr f a
synthCata h = go where
  go (Fix layer) =
    let children = fmap go layer
    in  Fix (Ann (h (fmap attribute children)) children)

-- | Re-annotates an attributed tree bottom-up. Each node's new annotation is
-- computed from its old annotation and the new annotations of its children.
-- 'synthetise'' is a synonym. Note that this could be a special case of
-- 'synthCata':
--
-- > scanCata f == annZipWith (flip const) . synthCata (\(Ann a x) -> f a x)
--
-- A catamorphism ('cata') generalizes @foldr@ from lists to trees.
-- 'synthCata' is one generalization of @scanr@, and 'scanCata' is another.
scanCata :: Functor f => (a -> f b -> b) -> Attr f a -> Attr f b
scanCata h = go where
  go (Fix (Ann old layer)) =
    let children = fmap go layer
    in  Fix (Ann (h old (fmap attribute children)) children)

-- | Annotates every subtree with the result of the corresponding paramorphism:
--
-- >  attribute . synthPara f == para f
--
-- The function sees, for every child, the original child subtree paired with
-- that child's attribute.
synthPara :: Functor f => (f (Mu f, a) -> a) -> Mu f -> Attr f a
synthPara h = snd . go where
  -- Pairs the original subtree with its attributed version.
  go orig@(Fix layer) =
    let pairs = fmap go layer                      -- :: f (Mu f, Attr f a)
        val   = h (fmap (fmap attribute) pairs)    -- second components -> attributes
    in  (orig, Fix (Ann val (fmap snd pairs)))

-- | Variant of 'synthPara' where the function receives the whole original
-- subtree and the attributes of its children:
--
-- >  attribute . synthPara' f == para' f
--
synthPara' :: Functor f => (Mu f -> f a -> a) -> Mu f -> Attr f a
synthPara' h = go where
  go whole@(Fix layer) =
    let children = fmap go layer
    in  Fix (Ann (h whole (fmap attribute children)) children)


-- | Paramorphic variant of 'scanCata'. Each node's new annotation is computed
-- from the whole original attributed subtree rooted at that node and the new
-- annotations of its children.
scanPara :: Functor f => (Attr f a -> f b -> b) -> Attr f a -> Attr f b
scanPara h = go where
  go node@(Fix (Ann _ layer)) =
    let children = fmap go layer
    in  Fix (Ann (h node (fmap attribute children)) children)

-- | Synthetising zygomorphism, keeping only the main result @a@ in the
-- annotations (the auxiliary @b@ values are discarded).
--
-- > attribute . synthZygo_ g h == zygo_ g h
synthZygo_ :: Functor f => (f b -> b) -> (f (b,a) -> a) -> Mu f -> Attr f a
synthZygo_ g h = synthZygoWith (\_ a -> a) g h

-- | Synthetising zygomorphism that annotates every subtree with the pair of
-- the auxiliary result (computed by @g@) and the main result (computed by @h@).
synthZygo :: Functor f => (f b -> b) -> (f (b,a) -> a) -> Mu f -> Attr f (b,a)
synthZygo g h = synthZygoWith (\b a -> (b, a)) g h

-- | Generic synthetising zygomorphism.
--
-- * @g@ computes an auxiliary attribute bottom-up.
-- * @h@ computes the main attribute; it sees both the auxiliary and the main
--   result of every child.
-- * @u@ combines the two into the annotation that is stored.
synthZygoWith :: Functor f => (b -> a -> c) -> (f b -> b) -> (f (b,a) -> a) -> Mu f -> Attr f c
synthZygoWith u g h = snd . go where
  go (Fix layer) =
    let (pairs, kids) = unzipF (fmap go layer)   -- :: ( f (b,a) , f (Attr f c) )
        aux = g (fmap fst pairs)
        res = h pairs
    in  ((aux, res), Fix (Ann (u aux res) kids))

-- | Accumulating catamorphism; a generalization of 'mapAccumR' from lists to
-- trees. At every node @h@ receives the children's accumulators and returns
-- the node's accumulator together with its annotation. The root accumulator
-- is returned alongside the attributed tree.
synthAccumCata :: Functor f => (f acc -> (acc,b)) -> Mu f -> (acc, Attr f b)
synthAccumCata h = go where
  go (Fix layer) =
    let results    = fmap go layer
        (acc, ann) = h (fmap fst results)
    in  (acc, Fix (Ann ann (fmap snd results)))

-- | Accumulating paramorphism. Like 'synthAccumCata', but @h@ also receives
-- the whole original subtree at the current node.
synthAccumPara' :: Functor f => (Mu f -> f acc -> (acc,b)) -> Mu f -> (acc, Attr f b)
synthAccumPara' h = go where
  go whole@(Fix layer) =
    let results    = fmap go layer
        (acc, ann) = h whole (fmap fst results)
    in  (acc, Fix (Ann ann (fmap snd results)))

-- | Accumulating catamorphism over an already attributed tree. At every node,
-- @h@ receives the accumulators produced by the children and the node's old
-- annotation. It returns the node's accumulator and its new annotation. The
-- root's accumulator is returned alongside the re-annotated tree.
--
-- Could be a special case of 'synthAccumCata':
--
-- > mapAccumCata f == second (annZipWith (flip const)) . synthAccumCata (\(Ann b t) -> f t b)
-- >   where second g (x,y) = (x, g y)
--
mapAccumCata :: Functor f => (f acc -> b -> (acc,c)) -> Attr f b -> (acc, Attr f c)
mapAccumCata h = go where
  go (Fix (Ann b x)) = (acc, Fix (Ann c (fmap snd y))) where
    y = fmap go x
    (acc,c) = h (fmap fst y) b

-- | Monadic version of 'synthCata' (and a synonym for 'synthetiseM').
-- The children are processed first, then @act@ runs on the node, so effects
-- happen bottom-up. If you don't need the result, use 'cataM_' instead.
synthCataM  ::  (Traversable f, Monad m) =>  (f a -> m a) -> Mu f -> m (Attr f a)
synthCataM act = go where
  go (Fix layer) =
    mapM go layer                  >>= \children ->
    act (fmap attribute children)  >>= \a ->
    return (Fix (Ann a children))

-- | Monadic version of 'synthPara'. The children are processed first, then
-- @act@ runs on the node. If you don't need the result, use 'paraM_' instead.
synthParaM  ::  (Traversable f, Monad m) =>  (f (Mu f, a) -> m a) -> Mu f -> m (Attr f a)
synthParaM act = liftM snd . go where
  -- Pairs the original subtree with its attributed version.
  go orig@(Fix layer) = do
    pairs <- mapM go layer                          -- :: f (Mu f, Attr f a)
    val   <- act (fmap (fmap attribute) pairs)
    return (orig, Fix (Ann val (fmap snd pairs)))

-- | Monadic version of 'synthPara''. The children are processed first, then
-- @act@ runs on the whole original subtree and the children's attributes.
synthParaM'  ::  (Traversable f, Monad m) => (Mu f -> f a -> m a) -> Mu f -> m (Attr f a)
synthParaM' act = go where
  go whole@(Fix layer) =
    mapM go layer                        >>= \children ->
    act whole (fmap attribute children)  >>= \a ->
    return (Fix (Ann a children))

--------------------------------------------------------------------------------
-- Inherited attributes

-- | /Inherited/ attributes are created top-down. A node's attribute is
-- computed from the node itself and its parent's attribute; the supplied
-- initial value plays the role of the root's "parent" attribute. For example,
-- @depths@ annotates every subtree with its depth (the distance from the root,
-- plus 1):
--
-- > depths :: Functor f => Mu f -> Attr f Int
-- > depths = inherit (\_ i -> i+1) 0
--
inherit :: Functor f => (Mu f -> a -> a) -> a -> Mu f -> Attr f a
inherit h = go where
  go parent node@(Fix layer) = Fix (Ann here (fmap (go here) layer)) where
    here = h node parent

-- | Generalization of 'inherit'. At each node, @h@ returns a pair: the first
-- component becomes the node's annotation, and the second is the value passed
-- down to the children. TODO: better name?
inherit2 :: Functor f => (Mu f -> a -> (b,a)) -> a -> Mu f -> Attr f b
inherit2 h = go where
  go inh node@(Fix layer) = Fix (Ann out (fmap (go inh') layer)) where
    (out, inh') = h node inh

-- | Generalization of @scanl@ from lists to trees. Each node's new annotation
-- is computed from its parent's new annotation (the initial value for the
-- root) and the node's old annotation.
inherit' :: Functor f => (a -> b -> a) -> a -> Attr f b -> Attr f a
inherit' h = go where
  go acc (Fix (Ann x layer)) = Fix (Ann acc' (fmap (go acc') layer)) where
    acc' = h acc x

-- | Monadic version of 'inherit'. @act@ runs on a node before any of its
-- children are visited, so effects happen top-down.
inheritM :: (Traversable f, Monad m) => (Mu f -> a -> m a) -> a -> Mu f -> m (Attr f a)
inheritM act = go where
  go inh node@(Fix layer) =
    act node inh          >>= \here ->
    mapM (go here) layer  >>= \kids ->
    return (Fix (Ann here kids))

-- | Like 'inheritM', but run only for the monadic effects. The inherited
-- values are threaded down the tree (parent before children), but no
-- attributed tree is built.
inheritM_ :: (Traversable f, Monad m) => (Mu f -> a -> m a) -> a -> Mu f -> m ()
inheritM_ act root = go root where
  go p s@(Fix t) = do
    a <- act s p
    -- 'mapM_' (from Data.Foldable) discards the per-child results directly,
    -- instead of rebuilding an @f ()@ layer that is thrown away.
    mapM_ (go a) t

--------------------------------------------------------------------------------
-- Top-down folds

-- | Monadic top-down \"sweep\" of a tree; a more complicated, folding version
-- of 'inheritM'. At each node, @act@ sees the shape of the node (with the
-- children erased to @()@) and the value passed down from the parent. It
-- returns one value per child. These are zipped with the children via
-- 'unsafeZipWithFM'. This is unsafe in the sense that the user is responsible
-- for keeping the shape of the node. TODO: better name?
topDownSweepM :: (Traversable f, Monad m) => (f () -> a -> m (f a)) -> a -> Mu f -> m ()
topDownSweepM act root = go root where
  go inh (Fix layer) =
    act (fmap (const ()) layer) inh  >>= \outs ->
    unsafeZipWithFM go outs layer    >>
    return ()

-- | An attributed version of 'topDownSweepM', and probably the more useful one.
-- @act@ receives the node's own annotation, the annotations of its children,
-- and the value passed down from the parent. It returns one value per child.
-- The user is responsible for keeping the shape of the node.
topDownSweepM' :: (Traversable f, Monad m) => (b -> f b -> a -> m (f a)) -> a -> Attr f b -> m ()
topDownSweepM' act root = go root where
  go inh (Fix (Ann here layer)) =
    act here (fmap attribute layer) inh  >>= \outs ->
    unsafeZipWithFM go outs layer        >>
    return ()

--------------------------------------------------------------------------------
-- Traversals

-- | Synthetising attributes via an accumulating map in a left-to-right fashion
-- (the order is the same as in @foldl@). A node is visited before its
-- children, and the accumulator is threaded through the children with
-- 'mapAccumL'. The final accumulator is returned alongside the tree.
synthAccumL :: Traversable f => (a -> Mu f -> (a,b)) -> a -> Mu f -> (a, Attr f b)
synthAccumL h = go where
  go acc node@(Fix layer) = (acc'', Fix (Ann val layer')) where
    (acc' , val   ) = h acc node
    (acc'', layer') = mapAccumL go acc' layer

-- | Synthetising attributes via an accumulating map in a right-to-left fashion
-- (the order is the same as in @foldr@). The children are visited first,
-- right-to-left via 'mapAccumR', and then the node itself. The final
-- accumulator is returned alongside the tree.
synthAccumR :: Traversable f => (a -> Mu f -> (a,b)) -> a -> Mu f -> (a, Attr f b)
synthAccumR h = go where
  go acc node@(Fix layer) = (acc'', Fix (Ann val layer')) where
    (acc' , layer') = mapAccumR go acc layer
    (acc'', val   ) = h acc' node
    
-- | 'synthAccumL', discarding the final accumulator.
synthAccumL_ :: Traversable f => (a -> Mu f -> (a,b)) -> a -> Mu f -> Attr f b
synthAccumL_ h x = snd . synthAccumL h x

-- | 'synthAccumR', discarding the final accumulator.
synthAccumR_ :: Traversable f => (a -> Mu f -> (a,b)) -> a -> Mu f -> Attr f b
synthAccumR_ h x = snd . synthAccumR h x

-- | Numbers the nodes from @0@ to @(n-1)@ in left-to-right ('synthAccumL')
-- order, where @n == length (universe tree)@ is the number of substructures.
-- @n@ is also returned as the first component.
enumerateNodes :: Traversable f => Mu f -> (Int, Attr f Int)
enumerateNodes = synthAccumL (\n _ -> (n + 1, n)) 0

-- | 'enumerateNodes', discarding the node count.
enumerateNodes_ :: Traversable f => Mu f -> Attr f Int
enumerateNodes_ tree = snd (enumerateNodes tree)

--------------------------------------------------------------------------------
-- Resynthetising transformations

-- | Bottom-up transformations that automatically resynthetise attributes
-- wherever something changed. @calc@ computes a node's attribute from its
-- children's attributes. The transformation receives a freshly attributed node
-- and returns either @Nothing@ (no change) or a replacement layer.
synthTransform :: Traversable f => (f a -> a) -> (Attr f a -> Maybe (f (Attr f a))) -> Attr f a -> Attr f a 
synthTransform calc h = synthTransform' (\kids -> calc (fmap attribute kids)) h

-- | Like 'synthTransform', but @calc@ sees the whole (already attributed)
-- children, not just their attributes.
--
-- The transformation is applied bottom-up: all children are processed first,
-- then the transformation is tried once on the freshly attributed node. A
-- subtree in which nothing changed is returned as-is, sharing the original.
-- Only nodes with a changed descendant (or that were transformed themselves)
-- are resynthetised.
synthTransform' :: Traversable f => (f (Attr f a) -> a) -> (Attr f a -> Maybe (f (Attr f a))) -> Attr f a -> Attr f a 
synthTransform' calc h0 = snd . go where
  -- Attach a freshly computed attribute to a layer.
  synth x = Fix $ Ann (calc x) x
  -- Try the user transformation on the freshly attributed node.
  hsynth x = case h0 (synth x) of 
    Nothing -> Nothing
    Just y  -> Just (synth y)
  -- Returns whether this particular subtree changed, and the resulting subtree.
  -- The flag is computed per subtree (not threaded across siblings), so a
  -- change in one child does not force unchanged siblings to be rebuilt.
  go old@(Fix (Ann _ x)) =
    let ys      = fmap go x
        changed = or (fmap fst ys)
        y       = fmap snd ys
    in  case hsynth y of
          Nothing -> (changed, if changed then synth y else old)
          Just z  -> (True, z)

-- | Bottom-up transformation to normal form: the transformation is applied
-- exhaustively, and attributes are resynthetised automatically wherever
-- something changed. @calc@ computes a node's attribute from its children's
-- attributes.
synthRewrite :: Traversable f => (f a -> a) -> (Attr f a -> Maybe (f (Attr f a))) -> Attr f a -> Attr f a 
synthRewrite calc h = synthRewrite' (\kids -> calc (fmap attribute kids)) h

-- | Like 'synthRewrite', but @calc@ sees the whole (already attributed)
-- children, not just their attributes.
--
-- Whenever the transformation fires on a node, the result is rewritten again
-- from scratch. This need not terminate if the transformation keeps firing.
-- A subtree in which nothing changed is returned as-is, sharing the original.
synthRewrite' :: Traversable f => (f (Attr f a) -> a) -> (Attr f a -> Maybe (f (Attr f a))) -> Attr f a -> Attr f a 
synthRewrite' calc h0 = rewrite where
  rewrite = snd . go
  -- Attach a freshly computed attribute to a layer.
  synth x = Fix $ Ann (calc x) x
  -- Try the user transformation on the freshly attributed node.
  hsynth x = case h0 (synth x) of 
    Nothing -> Nothing
    Just y  -> Just (synth y)
  -- Returns whether this particular subtree changed, and the resulting subtree.
  -- The flag is computed per subtree (not threaded across siblings), so a
  -- change in one child does not force unchanged siblings to be rebuilt.
  go old@(Fix (Ann _ x)) =
    let ys      = fmap go x
        changed = or (fmap fst ys)
        y       = fmap snd ys
    in  case hsynth y of
          Nothing -> (changed, if changed then synth y else old)
          Just z  -> (True, rewrite z)
    
--------------------------------------------------------------------------------
-- Stacking attributes

-- | Merges two layers of annotations into a single one, pairing the inner
-- annotation (first) with the outer one (second).
annZip :: Functor f => Mu (Ann (Ann f a) b) -> Attr f (a,b)
annZip = annZipWith (,)

-- | Merges two layers of annotations with a combining function, which
-- receives the inner annotation first and the outer one second.
annZipWith :: Functor f => (a -> b -> c) -> Mu (Ann (Ann f a) b) -> Attr f c
annZipWith h = go where 
  go (Fix (Ann outer (Ann inner layer))) = Fix (Ann (h inner outer) (fmap go layer))

-- | Merges three layers of annotations into a single one, ordered from the
-- innermost to the outermost annotation.
annZip3 :: Functor f => Mu (Ann (Ann (Ann f a) b) c) -> Attr f (a,b,c)
annZip3 = annZipWith3 (,,)

-- | Merges three layers of annotations with a combining function, whose
-- arguments are ordered from the innermost to the outermost annotation.
annZipWith3 :: Functor f => (a -> b -> c -> d) -> Mu (Ann (Ann (Ann f a) b) c) -> Attr f d
annZipWith3 h = go where 
  go (Fix (Ann outer (Ann middle (Ann inner layer)))) =
    Fix (Ann (h inner middle outer) (fmap go layer))

--------------------------------------------------------------------------------
-- Tests
#ifdef WITH_QUICKCHECK

-- | Runs all QuickCheck properties of this module, in order.
runtests_Attributes :: IO ()
runtests_Attributes = Data.Foldable.sequence_
  [ quickCheck prop_synthAccumL
  , quickCheck prop_synthAccumR
  , quickCheck prop_synthetise
  , quickCheck prop_synthCata
  , quickCheck prop_synthPara
  , quickCheck prop_synthPara'
  , quickCheck prop_scanCata
  , quickCheck prop_mapAccumCata
  , quickCheck prop_zygo
  , quickCheck prop_zygo_
  ]

-- | Numbering with 'synthAccumL_' yields @1..n@ in 'toList' order.
prop_synthAccumL :: FixT Label -> Bool
prop_synthAccumL tree = labels == [1 .. length (universe tree)] where
  labels = toList (Attrib (synthAccumL_ (\i _ -> (i+1,i)) 1 tree))

-- | Numbering with 'synthAccumR_' yields @n..1@ in 'toList' order.
prop_synthAccumR :: FixT Label -> Bool
prop_synthAccumR tree = labels == reverse [1 .. length (universe tree)] where
  labels = toList (Attrib (synthAccumR_ (\i _ -> (i+1,i)) 1 tree))

-- | Concatenating labels with 'synthetise' agrees, subtree by subtree, with a
-- left fold over the same subtree.
prop_synthetise :: FixT Label -> Bool
prop_synthetise tree = actual == expected where
  actual   = map attribute (universe (synthetise concatLabels tree))
  expected = map (foldLeft appendLabel []) (universe tree)
  concatLabels (TreeF (Label l) xs) = l ++ concat xs
  appendLabel s (Fix (TreeF (Label l) _)) = s ++ l

-- | The root attribute of 'synthCata' equals 'cata'.
prop_synthCata :: FixT Label -> Bool
prop_synthCata tree = attribute (synthCata render tree) == cata render tree where 
  render :: TreeF Label String -> String
  render (TreeF (Label lbl) kids) = concat [lbl, "(", intercalate "," kids, ")"]

-- | The root attribute of 'synthPara'' equals 'para''.
prop_synthPara' :: FixT Label -> Bool
prop_synthPara' tree = attribute (synthPara' h tree) == para' h tree where 
  h :: FixT Label -> TreeF Label String -> String
  h node@(Fix (TreeF label ts)) ys =
    unLabel label ++ "_" ++ show (size node)
      ++ "(" ++ intercalate "," (zipWith tag (toList ys) (map size ts)) ++ ")"
  -- number of nodes in a subtree
  size :: FixT Label -> Integer
  size = cata (\t -> 1 + Data.Foldable.sum t)
  tag str j = str ++ "<" ++ show j ++ ">"

-- | The root attribute of 'synthPara' equals 'para'.
prop_synthPara :: FixT Label -> Bool
prop_synthPara tree = attribute (synthPara g tree) == para g tree where 
  g :: TreeF Label (FixT Label , String) -> String
  g (TreeF (Label lbl) kids) = lbl ++ "(" ++ intercalate "," (map describe kids) ++ ")"
  describe (sub, s) = show (size sub) ++ "_" ++ s
  -- number of nodes in a subtree
  size :: FixT Label -> Integer
  size = cata (\t -> 1 + Data.Foldable.sum t)

-- | Reference implementation of 'scanCata' via 'synthCata' and 'annZipWith'.
scanCataNaive :: Functor f => (a -> f b -> b) -> Attr f a -> Attr f b
scanCataNaive f tree = annZipWith (\_ new -> new) (synthCata (\(Ann a x) -> f a x) tree)

-- | 'scanCata' agrees with its naive reference implementation.
prop_scanCata :: Attr (TreeF Label) String -> Bool
prop_scanCata tree = scanCata step tree == scanCataNaive step tree where
  step :: String -> TreeF Label Integer -> Integer
  step str t = charSum str + Prelude.product (toList t)
  charSum :: String -> Integer
  charSum = fromIntegral . Prelude.sum . map ord 

-- | Reference implementation of 'mapAccumCata' via 'synthAccumCata' and
-- 'annZipWith'.
mapAccumCataNaive :: Functor f => (f acc -> b -> (acc,c)) -> Attr f b -> (acc, Attr f c)
mapAccumCataNaive f tree =
  second (annZipWith (\_ new -> new)) (synthAccumCata (\(Ann b t) -> f t b) tree)

-- | 'mapAccumCata' agrees with its naive reference implementation.
prop_mapAccumCata :: Attr (TreeF Label) String -> Bool
prop_mapAccumCata tree = mapAccumCata step tree == mapAccumCataNaive step tree where
  step :: TreeF Label Integer -> String -> (Integer,String)
  step t str = (acc, "<" ++ show k ++ "," ++ str ++ ">") where
    k   = Prelude.product (toList t)
    acc = k - fromIntegral (length str) + charSum str
  charSum :: String -> Integer
  charSum = fromIntegral . Prelude.sum . map ord 

--------------------------------------------------------------------------------
-- Morphism tests which are here to avoid circular imports

-- | Reference implementation of 'zygo_': first synthetise the auxiliary
-- attribute, then run a paramorphism over the annotated tree.
zygoNaive_ :: Functor f => (f b -> b) -> (f (b,a) -> a) -> Mu f -> a
zygoNaive_ g h tree = para step (synthCata g tree) where
  step = h . fmap (first attribute) . unAnn

-- | Reference implementation of 'zygo', returning both the auxiliary and the
-- main result.
zygoNaive :: Functor f => (f b -> b) -> (f (b,a) -> a) -> Mu f -> (b,a)
zygoNaive g h tree =
  let annotated = synthCata g tree
      step      = h . fmap (first attribute) . unAnn
  in  (attribute annotated, para step annotated)

-- | 'zygo' agrees with its naive reference implementation.
prop_zygo :: FixT Label -> Bool
prop_zygo tree = zygo g h tree == zygoNaive g h tree where
  g :: TreeF Label Integer -> Integer
  g (TreeF (Label lbl) kids) = Prelude.product kids + charSum lbl

  h :: TreeF Label (Integer,String) -> String
  h (TreeF (Label lbl) kids) = concat ["[", lbl, "]<", intercalate "," (map render kids), ">"]

  render (k,s) = show k ++ "_" ++ s

  charSum :: String -> Integer
  charSum = fromIntegral . Prelude.sum . map ord 

-- | 'zygo_' agrees with its naive reference implementation.
prop_zygo_ :: FixT Label -> Bool
prop_zygo_ tree = zygo_ g h tree == zygoNaive_ g h tree where
  g :: TreeF Label Integer -> Integer
  g (TreeF (Label label) child) = Prelude.product child + prodchar label

  h :: TreeF Label (Integer,String) -> String
  h (TreeF (Label label) child) = "<" ++ intercalate "," (map f child) ++ ">" ++ "[" ++ label ++ "]"

  f (k,s) = s ++ "_" ++ show k

  -- Multiply the character codes in 'Integer'. Multiplying the 'Int's returned
  -- by 'ord' directly silently wraps around for longer labels.
  prodchar :: String -> Integer
  prodchar = Prelude.product . map (fromIntegral . ord)

#endif
--------------------------------------------------------------------------------