{-# LANGUAGE PatternGuards, FlexibleContexts, RankNTypes #-}
module HFusion.Internal.Compositions(fuseDefinitions) where

import Data.List(find,elemIndex,(\\),partition,elemIndex,sortBy,maximumBy,nubBy,intersperse)
import Data.Function(on)
import Data.Maybe(catMaybes)
import qualified Data.Map as M(Map,fromList,lookup,insert)
import Control.Applicative(Applicative(..))
import Control.Monad.State(State,StateT(..),get,put,evalStateT,MonadState(..),modify)
import Control.Monad.Writer(WriterT(..),MonadWriter(..))
import Control.Monad.Error(ErrorT(..))
import Control.Monad.Trans(MonadTrans(..))
import Control.Monad(liftM2,liftM3,zipWithM,foldM,MonadPlus(..))
import Control.Arrow((***),second,first,(&&&))

import HFusion.Internal.HsSyn
import HFusion.Internal.FuseFace
import HFusion.Internal.Parsing.HyloParser
import HFusion.Internal.Utils(VarGenState)

-- | Eliminates compositions of recursive functions from definitions.
--
-- Hylomorphisms are first derived from the definitions in scope. Then every
-- definition body is traversed with 'fuseComp'. The extra definitions it emits
-- are returned without duplicates (compared by definition name).
fuseDefinitions :: [Def] -- ^ Definitions in scope. Hylomorphisms will be derived from them.
                -> [Def] -- ^ Definitions containing compositions which must be eliminated.
                -> VarGenState ([Def],[Def]) -- ^ The transformed definitions without the compositions that
                                             -- were successfully fused and the additional definitions
                                             -- introduced as a result of fusion.
fuseDefinitions hdfs dfs = do
    (_, hylos) <- deriveHylos hdfs
    let fuseDef (Defvalue v t) = fuseComp [] t >>= \body -> return (Defvalue v body)
    result <- runStateT (runWriterT (mapM fuseDef dfs)) (createHyloSet hylos)
    -- the final HyloSet is discarded; only the Writer output is kept
    let (fused, extra) = fst result
    return (fused, nubBy ((==) `on` getDefName) extra)

-- | Compositions of hylomorphisms found in terms.
data Composition = Comp FHistory Int HyloT [Composition]
                   -- ^ @Comp fh i h cs@ is the application of the component @i@ of hylo @h@
                   -- with history @fh@ to arguments @cs@.
                 | CompTerm Term
                   -- ^ A term that is not recognised as an application of a hylo in scope.
                 | CompSplit Composition
                   -- ^ This constructor is intended to appear in the deeper levels of a 'Composition' value.
                   -- When fusion cannot be performed, the split constructor indicates that no more attempts
                   -- should be made to fuse the hylos on opposite sides of the split.

-- | Renders a 'Composition' as a string. Intended for debugging only.
showComp :: Composition -> String
showComp (Comp fh i _ cs) = "Comp ("++show fh++") "++show i++" ["++concat (intersperse ", " (map showComp cs))++"]"
showComp (CompTerm t) = "CompTerm ("++show t++")"
showComp (CompSplit c) = "CompSplit ("++showComp c++")"

-- | History of fusions. It tells which hylomorphisms have
-- been fused together.
data FHistory = FHNode FHistory [(Int,FHistory)]
                -- ^ @FHNode h [(i,k)]@ describes the fusion of the hylo with history @h@
                -- with the hylo with history @k@ in argument @i@. Every pair in the
                -- list records one such fusion on top of @h@.
              | FHLeaf Variable
                -- ^ An unfused hylo, identified by the name of one of its components.
  deriving (Eq,Ord,Show)

-- | Extends the history @fh@ with the fusion of history @fh'@ in argument @ai@.
-- An existing node gets the new pair prepended to its list. Any other history
-- becomes the base of a new node with a single pair.
fusedH :: FHistory -> Int -> FHistory -> FHistory
fusedH fh ai fh' =
  case fh of
    FHNode base fused -> FHNode base ((ai, fh') : fused)
    _                 -> FHNode fh [(ai, fh')]

-- | Counts the fusions recorded in a history. Every pair of a node counts
-- one, plus the fusions in the node's base and in each fused sub-history.
-- Leaves count zero.
rank :: FHistory -> Int
rank (FHLeaf _) = 0
rank (FHNode base fused) = length fused + rank base + sum [ rank sub | (_, sub) <- fused ]

-- | Records a fusion in the HyloSet in the state.
--
-- For every entry @(i1, [(ari, i2), ...])@ of the fused components, the history
-- of component @i1@ of the left hylo is extended with the history of component
-- @i2@ of the right hylo at argument @ari@. The resulting histories are then
-- registered for the components of the fused hylo via 'addHylo'.
-- Calls 'error' when a component name is not present in the set.
recordH :: MonadState HyloSet m => HyloT -- ^ the hylo at the left of the composition
                                -> HyloT -- ^ the hylo at the right of the composition
                                -> [Def] -- ^ the inlined result
                                -> HyloT -- ^ the hylo which resulted from fusion
                                -> [(Int,[(Int,Int)])] -- ^ the hylo components which were fused
                                -> m ()
recordH h h' dfs hr his = modify register
  where
    register hset = addHylo dfs hr [ historyOf hset entry | entry <- his ] hset
    leftNames  = getNames h
    rightNames = getNames h'
    -- history of a left component, extended with the right components fused into it
    historyOf hset (i1, fusedArgs) =
      case lookupHylo (leftNames !! i1) hset of
        Nothing -> error $ "recordH: Hylo1 " ++ show (leftNames !! i1) ++ " not found!"
        Just (base, _, _, _) -> foldr (addRight hset) base (sortBy (compare `on` fst) fusedArgs)
    addRight hset (ari, i2) acc =
      case lookupHylo (rightNames !! i2) hset of
        Nothing -> error $ "recordH: Hylo2 " ++ show (rightNames !! i2) ++ " not found!"
        Just (fh', _, _, _) -> fusedH acc ari fh'

-- | Examines a term looking for compositions of hylomorphisms.
-- The result is 'Just' only when the term itself is an application of a hylo.
findComposition :: [Variable] -- ^ variables in scope which are used to avoid confusing a hylo with a local binding
                -> HyloSet -- ^ the set of hylos which can appear in the given term
                -> Term    -- ^ the term to scan for compositions
                -> Maybe Composition
findComposition vs hs t = onlyComp (scan t)
  where
    onlyComp c@(Comp {}) = Just c
    onlyComp _           = Nothing
    -- an application of a known hylo that is not shadowed by a local variable
    scan (Tfapp v args)
      | v `notElem` vs
      , Just (fh, i, _, h) <- lookupHylo v hs
      = Comp fh i h (map scan args)
    -- further arguments applied to a hylo application extend its argument list
    scan (Tapp f x)
      | Comp fh i h cs <- scan f
      = Comp fh i h (cs ++ [scan x])
    scan other = CompTerm other

-- | Transforms a term by fusing the compositions of hylos that it could have.
-- The Writer monad yields the definitions produced during fusion.
--
-- The list of variables holds the names bound locally at the current position.
-- 'findComposition' never treats them as hylos. The list grows as the traversal
-- goes under case alternatives, let bindings and lambdas.
fuseComp :: [Variable] -> Term -> WriterT [Def] (StateT HyloSet VarGenState) Term
fuseComp vs t = get >>= \hs ->
    case findComposition vs hs t of
     -- The leaves of a composition ('CompTerm') are arbitrary terms that may
     -- contain further compositions, so they are traversed with 'fuseComp'
     -- too rather than being returned untouched.
     Just c -> lift (fuse c) >>= \(c',dfs) ->
                  let compCase _ i h = return . Tfapp (getNames h!!i)
                   in tell dfs >> foldCompositionM compCase (fuseComp vs) return c'
     Nothing -> descend vs t
  where
    -- Structural recursion into sub-terms, extending the scope with the
    -- variables bound by each binder.
    descend scope (Ttuple b ts) = fmap (Ttuple b) $ mapM (fuseComp scope) ts
    descend scope (Tfapp v ts) = fmap (Tfapp v) $ mapM (fuseComp scope) ts
    descend scope (Tcapp c ts) = fmap (Tcapp c) $ mapM (fuseComp scope) ts
    descend scope (Tapp t0 t1) = liftM2 Tapp (fuseComp scope t0) (fuseComp scope t1)
    descend scope (Tcase t0 ps ts) = liftM2 (flip Tcase ps) (fuseComp scope t0)
                                            (zipWithM (\p -> fuseComp (vars p ++ scope)) ps ts)
    descend scope (Tif t0 t1 t2) = liftM3 Tif (fuseComp scope t0) (fuseComp scope t1) (fuseComp scope t2)
    -- NOTE(review): the bound variable is added only to the scope of the first
    -- sub-term; confirm this matches the meaning of Tlet's fields.
    descend scope (Tlet v t0 t1) = liftM2 (Tlet v) (fuseComp (v:scope) t0) (fuseComp scope t1)
    descend scope (Tpar t0) = fmap Tpar (fuseComp scope t0)
    descend scope (Tlamb bv t0) = fmap (Tlamb bv) (fuseComp (vars bv ++ scope) t0)
    descend _ other = return other

-- | Transforms a 'Composition' value by fusing the hylos it could have.
--
-- 'fusings' builds the tree of fusion alternatives, and 'findFusion' extracts
-- the candidates. The candidate with the highest total 'rank' is kept. For each
-- fused hylo in it, the definitions recorded in the state are returned. Definitions
-- coming from its arguments are kept only when their names occur in those recorded
-- definitions.
fuse :: Composition -> StateT HyloSet VarGenState (Composition,[Def])
fuse c = do
    candidates <- findFusion (runC (fusings c) (Tree . return . MTReturn))
    let best = maximumBy (compare `on` score) candidates
    defs <- foldCompositionM gather (\_ -> return []) return best
    return (best, defs)
  where
    -- total rank of all the fusion histories in a composition
    score (Comp fh _ _ args) = rank fh + sum (map score args)
    score (CompSplit inner)  = score inner
    score _                  = 0
    -- unfused hylos contribute no definitions of their own
    gather (FHLeaf _) _ _ below = return (concat below)
    gather fh _ _ below = do
      hset <- get
      let collected = concat below
      return $ case lookupFusion fh hset of
        Nothing           -> collected
        Just (_, hdfs, _) -> hdfs ++ filter ((`elem` vars hdfs) . getDefName) collected

-- | Extracts candidate results from a tree by always following the left
-- branch of a node and skipping tags. Right now this yields just the
-- leftmost leaf.
findFusion :: Monad m => Tree b m a -> m [a]
findFusion tree = runTree tree >>= step
  where
    step (MTReturn x)    = return [x]
    step (MTNode left _) = findFusion left
    step (MTTag _ sub)   = findFusion sub

-- | Produces a tree with all possible ways to fuse a composition.
--
-- For an application of a hylo, the result is a 'tnode' with two alternatives:
--
--   * left: fuse inside the arguments first, then fuse the hylo with the
--     resulting arguments;
--
--   * right: fuse the hylo with its arguments first, then fuse inside the
--     arguments that remain.
--
-- When a pair of hylos cannot be fused, the argument is wrapped in 'CompSplit'
-- and the subtree is tagged with 'treturn'. Any other 'Composition' is returned
-- unchanged.
fusings :: Composition -> TreeT () (StateT HyloSet VarGenState) Composition
fusings (Comp fh i h cs) =
    tnode (mapM fusings cs >>= fmap comp . fuseOneComp fh h)
          (do (fh',i',h',cs') <- fuseOneComp fh h cs
              mapM fusings cs' >>= return . Comp fh' i' h')
        -- Positions of the recursive (non-constant) arguments of hylo @h@.
  where ris h cs = maybe [getConstantArgCount h..length cs] ([0..length cs]\\)$ getConstantArgPos h
        -- Index of argument @ai@ among the recursive arguments, if it is one.
        isRecArg h cs ai = elemIndex ai (ris h cs)
        -- Tries to fuse the hylo with every recursive argument that is itself
        -- a hylo application, starting from the last such argument.
        fuseOneComp :: FHistory -> HyloT -> [Composition]
                        -> TreeT () (StateT HyloSet VarGenState) (FHistory,Int,HyloT,[Composition])
        fuseOneComp fc h cs =
             let rargs = [ ((ai,ari),(fc',i',h',cs')) | (ai,Comp fc' i' h' cs') <- zip [0..] cs
                                                       , Just ari<-[ isRecArg h cs ai ] ]
              in  foldM applyFusion (fc,i,h,cs) (reverse rargs)
        comp (fc,i,h,cs) = Comp fc i h cs
        -- Fuses one argument hylo into the current hylo. A fusion already recorded
        -- for the resulting history is reused. Otherwise fusion is attempted, and on
        -- failure the argument is marked with 'CompSplit' in a tagged subtree.
        applyFusion :: (FHistory,Int,HyloT,[Composition]) -> ((Int,Int),(FHistory,Int,HyloT,[Composition]))
                        -> TreeT () (StateT HyloSet VarGenState) (FHistory,Int,HyloT,[Composition])
        applyFusion (fh,i,h,cs) ((ai,ari),(fh',i',h',cs')) = 
          let fh'' = fusedH fh ari fh'
              splitNRR h cs = case getConstantArgPos h of 
                                Nothing -> splitAt (getConstantArgCount h) cs 
                                Just is -> (map snd *** map snd)$ partition (flip elem is . fst) (zip [0..] cs)
              (csa,csb) = splitNRR h cs
              (csb',_:csb'') = splitAt ari csb
              (csah',csbh') = splitNRR h' cs'
           in lift get >>= \ hs ->
              case lookupFusion fh'' hs of
                Just (i,dfs,hr) -> return (fh'',0,hr,csa++csah'++csb'++csbh'++csb'')
                _ -> lift (lift (runErrorT (fusionar' [] h i ari h' i'))) >>= \e ->
                  case e of
                    Right (r,his,hr) | r>0 ->
                             do dfs <- lift$ lift (inline hr)
                                ehr' <- lift$ lift$ (runErrorT (deriveHylo dfs))
                                let hr' = either (const$ hr) id ehr'
                                lift$ recordH h h' dfs hr' his
                                return (fh'',0,hr',csa++csah'++csb'++csbh'++csb'')
                    _ -> let args c = case getConstantArgPos h of
                                   Nothing -> 
                                     let (a,(b,_:c')) = second (splitAt ari)$ splitAt (getConstantArgCount h) cs 
                                      in a++b++c:c'
                                   _ -> let (a,_:b) = splitAt ai cs 
                                         in a++c:b
                          in treturn () (fh,i,h,args$ CompSplit (Comp fh' i' h' cs'))
fusings c = return c

-- | Monadic bottom-up fold over a 'Composition'.
-- The first function combines a hylo application with its already folded
-- arguments. The second handles a plain term. The third post-processes the
-- folded contents of a split.
foldCompositionM :: Monad m => (FHistory -> Int -> HyloT -> [b] -> m b) -> (Term -> m b) -> (b -> m b) -> Composition -> m b
foldCompositionM onComp onTerm onSplit = go
  where
    go (Comp fc i h cs) = do
      folded <- mapM go cs
      onComp fc i h folded
    go (CompTerm t)  = onTerm t
    go (CompSplit c) = go c >>= onSplit

-- | The first map tells which history of fusions produced the hylo
-- component with a given name. The second map tells whether a hylo has been
-- produced for a given history. Associations are stored for each component of
-- every hylo, together with the hylo's inlined form.
--
-- In the second map, the 'Int' is the position of the component within the
-- hylo's names (see 'createHyloSet' and 'addHylo'). The @[Def]@ are the
-- definitions stored for that hylo.
type HyloSet = (M.Map Variable FHistory,M.Map FHistory (Int,[Def],HyloT))

-- | Retrieves the information for a hylo from the name of any of its components.
-- The result is the component's fusion history, its position, the stored
-- definitions and the hylo itself.
lookupHylo :: Variable -> HyloSet -> Maybe (FHistory,Int,[Def],HyloT)
lookupHylo n (nm,hs) =
    M.lookup n nm >>= \fh ->
    M.lookup fh hs >>= \(i, dfs, h) ->
    Just (fh, i, dfs, h)

-- | Retrieves the hylo produced for a fusion history, together with the
-- component position and stored definitions, if one has been recorded.
lookupFusion :: FHistory -> HyloSet -> Maybe (Int,[Def],HyloT)
lookupFusion fh (_, byHistory) = M.lookup fh byHistory

-- | Builds the initial 'HyloSet'. Every component of every hylo gets an
-- unfused history: 'FHLeaf' of the component's own name.
createHyloSet :: [([Def],HyloT)] -> HyloSet
createHyloSet hs = (M.fromList names, M.fromList histories)
  where
    entries   = [ (v, (i, dfs, h)) | (dfs, h) <- hs, (i, v) <- zip [0..] (getNames h) ]
    names     = [ (v, FHLeaf v) | (v, _) <- entries ]
    histories = [ (FHLeaf v, info) | (v, info) <- entries ]

-- | Registers the hylo @h@, with definitions @dfs@, in a 'HyloSet'. The
-- component at position @i@ of @h@ is associated with the @i@-th history of
-- @fhs@. The names and the histories are zipped, so any surplus in either
-- list is ignored.
addHylo :: [Def] -> HyloT -> [FHistory] -> HyloSet -> HyloSet
addHylo dfs h fhs (nm,hm) =
    let assocs = [ ((v,fh),(i,dfs,h)) | (i,v,fh) <- zip3 [0..] (getNames h) fhs ]
     in ( foldr (uncurry M.insert) nm $ map fst assocs
        , foldr (uncurry M.insert) hm $ map (first snd) assocs )

-- A tree monad
--
-- 'TreeT' is a continuation-passing wrapper over 'Tree'. Each computation
-- receives the continuation that builds the rest of the tree.

-- | A tree computation in continuation-passing style. 'runC' supplies the
-- continuation that turns each result into a 'Tree'.
newtype TreeT b m a = C { runC :: forall r. (a -> Tree b m r) -> Tree b m r }
-- | A tree whose layers are obtained by running an action in @m@.
newtype Tree b m a = Tree { runTree :: m (MTree b m a) }
-- | One layer of a 'Tree': a result leaf ('MTReturn'), a binary choice
-- ('MTNode', built by 'tnode'), or a subtree tagged with a @b@ ('MTTag',
-- built by 'treturn').
data MTree b m a = MTReturn a | MTNode (Tree b m a) (Tree b m a) | MTTag b (Tree b m a)

-- | Mapping composes the function in front of the continuation.
instance Functor (TreeT b m) where
  fmap f m = C (\k -> runC m (k . f))

-- | Required as a superclass of 'Monad' since GHC 7.10. 'pure' hands the
-- value straight to the continuation. '<*>' runs the function computation,
-- then the argument computation.
instance Applicative (TreeT b m) where
  pure a = C (\k -> k a)
  C ff <*> C fa = C (\k -> ff (\f -> fa (k . f)))

-- | Binding feeds the result of the first computation to @fc@. The tree of
-- @fc@'s computation is then built with the outer continuation.
instance Monad (TreeT b m) where
  return = pure
  C f >>= fc = C$ \k -> f (\a -> runC (fc a) k)

-- | Lifting runs the action and passes its result to the continuation.
instance MonadTrans (TreeT b) where
  lift action = C (\k -> Tree (action >>= \x -> runTree (k x)))

-- | A tagged return. The subtree that results from @treturn b a >>= f@ is
-- tagged with the value @b@. The tag may be useful for deciding how to
-- traverse the tree.
treturn :: Monad m => b -> a -> TreeT b m a
treturn tag x = C (\k -> Tree (return (MTTag tag (k x))))

-- | Makes a binary choice node in the tree. Both alternatives are continued
-- with the same continuation.
tnode :: Monad m => TreeT b m a -> TreeT b m a -> TreeT b m a
tnode left right = C (\k -> Tree (return (MTNode (runC left k) (runC right k))))