-- |
-- Module: Data.Reify.Graph.CSE
-- Copyright: (c) 2009 Sebastiaan Visser
-- License: BSD3
--
-- Maintainer: Sebastiaan Visser <sfvisser@cs.uu.nl>
-- Stability: unstable
-- Portability: ghc
--
-- This module implements common sub-expression elimination for graphs
-- generated by the "Data.Reify" package. The algorithm performs a simple
-- fixed point iteration and is not optimized for speed.
--
-- As an illustration, take this simple datatype representing an embedded
-- language containing primitives and function application.  The datatype
-- abstracts over its recursive positions, which is common when using the
-- "Data.Reify" package.  A fixed point combinator can be used to tie the knot.
--
-- >data Val f = App f f | Prim String
-- >  deriving (Eq, Ord, Show)
-- >
-- >newtype Fix f = In { out :: f (Fix f) }
--
-- Now we can add some useful instances and make the fixed point combinator an
-- instance of the 'MuRef' class from "Data.Reify".
--
-- >instance Functor Val      ...
-- >instance Foldable Val     ...
-- >instance Traversable Val  ...
-- >
-- >instance Traversable a => MuRef (Fix a) where
-- >  type DeRef (Fix a) = a
-- >  mapDeRef f = traverse f . out
--
-- When we now take the following example term in our embedded language we can
-- see that the `cse` function can eliminate common terms without changing the
-- semantics. Of course, this assumes that our language is referentially transparent.
--
-- >myTerm :: Fix Val
-- >myTerm = In $ clc `mul` clc
-- >  where clc = Prim "2" `add` Prim "5"
-- >        add a b = Prim "+" `app` a `app` b
-- >        mul a b = Prim "*" `app` a `app` b
-- >        app a b = App (In a) (In b)
--
-- Running @fmap cse $ reifyGraph myTerm@ yields a smaller graph than the plain
-- result of 'reifyGraph':
--
-- >with CSE:       without CSE:
-- >
-- >(1,App 2 9)     (1,App 2 9)
-- >(2,App 3 9)     (9,App 10 13)
-- >(10,App 6 7)    (13,Prim "5")
-- >(9,App 10 8)    (10,App 11 12)
-- >(3,Prim "*")    (12,Prim "2")
-- >(6,Prim "+")    (11,Prim "+")
-- >(7,Prim "2")    (2,App 3 4)
-- >(8,Prim "5")    (4,App 5 8)
-- >                (8,Prim "5")
-- >                (5,App 6 7)
-- >                (7,Prim "2")
-- >                (6,Prim "+")
-- >                (3,Prim "*")

{-# LANGUAGE TypeFamilies #-}
module Data.Reify.Graph.CSE (cse) where

import Data.Map (Map, toList, fromListWith, filter, update, mapKeysWith)
import Data.Reify
import Prelude hiding (filter)

-- | Perform common sub-expression elimination on the input graph.
--
-- Nodes whose contents are identical (once references to already merged
-- nodes have been redirected) are collapsed into a single node. The root id
-- is redirected in the same way as every other reference. A root that gets
-- merged into another node is therefore replaced by the id of the node it was
-- merged into, instead of being left pointing at a removed node.
cse :: (Ord (f Unique), Functor f) => Graph f -> Graph f
cse (Graph xs root) = Graph nodes root'
  where
    (groups, root') = eliminate root xs
    -- At the fixed point every group holds exactly one id; matching on
    -- @i:_@ avoids the partial 'head'.
    nodes = [ (i, node) | (node, i:_) <- toList groups ]

-- | Group node ids by node contents. 'fromListWith' combines as
-- @(++) new old@, so within each group the ids appear in reverse input order.
groupById :: Ord k => [(a, k)] -> Map k [a]
groupById = fromListWith (++) . map (\(a, b) -> (b, [a]))

-- | Apply a function repeatedly until its result equals its argument, and
-- return that value. Diverges if no such value is ever reached.
fixpoint :: Eq a => (a -> a) -> a -> a
fixpoint f a = let b = f a in if b == a then a else fixpoint f b

-- | Group the nodes by contents and apply 'eliminate1' until nothing changes,
-- tracking the root id through every step.
--
-- Returns the final grouping together with the (possibly redirected) root.
eliminate :: (Eq a, Functor f, Ord (f a)) => a -> [(a, f a)] -> (Map (f a) [a], a)
eliminate root nodes = fixpoint step (groupById nodes, root)
  where
    step (g, r) = (eliminate1 g, follow g r)
    -- Mirrors the choice made by 'eliminate1': the first group (in key order)
    -- with more than one id is collapsed onto its first id, so the root moves
    -- to that id when it belongs to that group.
    follow g r = case toList (filter ((> 1) . length) g) of
      (_, ids@(keep : _)) : _ | r `elem` ids -> keep
      _ -> r

-- | Perform a single elimination step.
--
-- The first term (in key order) that is shared by more than one node id is
-- collapsed. Only the first id of its group is kept, and every reference to
-- an id of that group is redirected to the kept id. When no term is shared,
-- the map is returned unchanged, which is what lets 'fixpoint' stop.
eliminate1 :: (Eq a, Functor f, Ord (f a)) => Map (f a) [a] -> Map (f a) [a]
eliminate1 g =
  case toList $ filter ((>1) . length) g of
    -- The group is shrunk *before* the keys are rewritten. Rewriting can
    -- change 'term' itself (when it refers to one of 'ids'), or map another
    -- key onto 'term' and merge that key's id list into this group. Shrinking
    -- afterwards would then either miss the group entirely or keep an id
    -- whose references were never redirected while dropping the redirect
    -- target, leaving dangling edges.
    (term, ids):_ -> mapKeysWith (++) (elimVal ids) . update (elimKey ids) term $ g
    _ -> g

-- | Redirect the references inside a single node. Every child id that occurs
-- in @ids@ is replaced by the first element of @ids@; all other ids are kept.
elimVal :: (Eq a, Functor f) => [a] -> f a -> f a
elimVal ids = fmap redirect
  where
    redirect x = case ids of
      keep : _ | x `elem` ids -> keep
      _ -> x

-- | New id list for a collapsed group: only its first id survives. The second
-- argument (the map entry's current value) is ignored, and the result is
-- always 'Just', so the entry is never deleted when used with 'update'.
elimKey :: [a] -> b -> Maybe [a]
elimKey ids _ = Just (take 1 ids)