module HFusion.Internal.Compositions(fuseDefinitions) where
import Data.List(find,elemIndex,(\\),partition,elemIndex,sortBy,maximumBy,nubBy,intersperse)
import Data.Function(on)
import Data.Maybe(catMaybes)
import qualified Data.Map as M(Map,fromList,lookup,insert)
-- Needed on GHC < 7.10, where the Prelude does not export Applicative.
import Control.Applicative(Applicative(..))
import Control.Monad.State(State,StateT(..),get,put,evalStateT,MonadState(..),modify)
import Control.Monad.Writer(WriterT(..),MonadWriter(..))
import Control.Monad.Error(ErrorT(..))
import Control.Monad.Trans(MonadTrans(..))
import Control.Monad(liftM2,liftM3,zipWithM,foldM,MonadPlus(..))
import Control.Arrow((***),second,first,(&&&))
import HFusion.Internal.HsSyn
import HFusion.Internal.FuseFace
import HFusion.Internal.Parsing.HyloParser
import HFusion.Internal.Utils(VarGenState)
-- | Fuse the hylomorphism compositions occurring in a list of definitions.
--
-- Hylomorphisms are first derived from @hdfs@.  Every definition in @dfs@
-- then has its body rewritten by 'fuseComp', starting with no locally bound
-- variables.  NOTE(review): the rewrite only matches 'Defvalue', so every
-- element of @dfs@ is assumed to be one.
--
-- The result pairs the rewritten definitions with the auxiliary definitions
-- emitted while fusing.  Auxiliary definitions with the same name are
-- collapsed, keeping the first occurrence.
fuseDefinitions :: [Def]
                -> [Def]
                -> VarGenState ([Def],[Def])
fuseDefinitions hdfs dfs = do
    (_, hylos) <- deriveHylos hdfs
    let rewriteDef = \(Defvalue v t) -> fmap (Defvalue v) (fuseComp [] t)
        dedupe     = nubBy ((==) `on` getDefName)
        pipeline   = runWriterT (mapM rewriteDef dfs)
    fmap (second dedupe . fst) (runStateT pipeline (createHyloSet hylos))
-- | A nest of hylomorphism applications discovered in a term.
-- It is built by 'findComposition' and rewritten by 'fusings'.
data Composition
    = Comp FHistory Int HyloT [Composition]
      -- ^ Application of function number @i@ of a hylomorphism, carrying
      --   that function's fusion history, to the listed arguments.
    | CompTerm Term
      -- ^ An argument that is not itself a hylomorphism application.
    | CompSplit Composition
      -- ^ An argument whose fusion into the enclosing application failed.
-- | Render a 'Composition' for debugging.
--
-- The hylomorphism carried by a 'Comp' node is omitted; only its fusion
-- history, function index and arguments are shown.
showComp :: Composition -> String
showComp (Comp fh i _ cs) =
    "Comp (" ++ show fh ++ ") " ++ show i
        ++ " [" ++ concat (intersperse ", " (map showComp cs)) ++ "]"
showComp (CompTerm t) = "CompTerm (" ++ show t ++ ")"
showComp (CompSplit c) = "CompSplit (" ++ showComp c ++ ")"
-- | Fusion history of a hylomorphism function.
--
-- Histories are the keys under which fusion results are recorded and looked
-- up in a 'HyloSet'.  That is why they derive 'Ord'; the constructor order
-- determines that ordering.
data FHistory
    = FHNode FHistory [(Int,FHistory)]
      -- ^ A base history plus, for each recursive-argument index that was
      --   fused, the history of the function fused into it (see 'fusedH').
    | FHLeaf Variable
      -- ^ An original, unfused function, identified by its name.
    deriving (Eq,Ord,Show)
-- | Extend history @outer@ with the fact that history @inner@ was fused into
-- its recursive argument @argIx@.
--
-- The new entry is prepended to an existing node's list.  A leaf becomes the
-- base of a new node.
fusedH :: FHistory -> Int -> FHistory -> FHistory
fusedH outer argIx inner =
    case outer of
        FHNode base fused -> FHNode base ((argIx, inner) : fused)
        _                 -> FHNode outer [(argIx, inner)]
-- | Number of fusion steps recorded in a history.
--
-- Each fused argument counts one, plus the steps inside its own history and
-- inside the base history.  A leaf has rank 0.
rank :: FHistory -> Int
rank (FHLeaf _)          = 0
rank (FHNode base fused) = rank base + sum [ 1 + rank sub | (_, sub) <- fused ]
-- | Register a freshly fused hylomorphism in the 'HyloSet' state.
--
-- Arguments:
--
-- * @h@, @h'@: the two hylomorphisms that were fused;
-- * @dfs@: the definitions of the result;
-- * @hr@: the resulting hylomorphism.
--
-- The k-th element @(i1, is)@ of @his@ describes the k-th function of @hr@.
-- Its history starts from the history of function @i1@ of @h@.  It is then
-- extended, for every @(ari, i2)@ in @is@ (in ascending order of @ari@), with
-- the history of function @i2@ of @h'@ fused into recursive argument @ari@.
--
-- Errors: raises 'error' if a referenced function of @h@ or @h'@ is not
-- present in the state.
recordH :: MonadState HyloSet m => HyloT
        -> HyloT
        -> [Def]
        -> HyloT
        -> [(Int,[(Int,Int)])]
        -> m ()
recordH h h' dfs hr his = modify register
  where
    register hs = addHylo dfs hr (map (historyOf hs) his) hs
    outerNames  = getNames h
    innerNames  = getNames h'
    historyOf hs (i1, is) =
        case lookupHylo (outerNames !! i1) hs of
            Nothing ->
                error $ "recordH: Hylo1 " ++ show (outerNames !! i1) ++ " not found!"
            Just (fh, _, _, _) ->
                foldr (extend hs) fh (sortBy (compare `on` fst) is)
    extend hs (ari, i2) fh =
        case lookupHylo (innerNames !! i2) hs of
            Nothing ->
                error $ "recordH: Hylo2 " ++ show (innerNames !! i2) ++ " not found!"
            Just (fh', _, _, _) ->
                fusedH fh ari fh'
-- | Recognise a term as a nest of hylomorphism applications.
--
-- A function call counts as a hylomorphism application when its name is
-- registered in @hs@ and is not among the locally bound variables @vs@.
-- Curried applications of such a call extend its argument list.  Any other
-- subterm becomes a 'CompTerm' leaf.
--
-- Returns 'Nothing' unless the term itself is such an application.
findComposition :: [Variable]
                -> HyloSet
                -> Term
                -> Maybe Composition
findComposition vs hs t =
    case collect t of
        c@Comp{} -> Just c
        _        -> Nothing
  where
    collect (Tfapp v ts)
        | v `notElem` vs
        , Just (fh, i, _, h) <- lookupHylo v hs
        = Comp fh i h (map collect ts)
    collect (Tapp t0 t1)
        | Comp fh i h cs <- collect t0
        = Comp fh i h (cs ++ [collect t1])
    collect other = CompTerm other
-- | Fuse the hylomorphism compositions found in a term.
--
-- @vs@ lists the variables bound locally around the term.  A call to one of
-- them is never mistaken for a call to a global hylomorphism of the same name.
--
-- If the term itself is a composition (see 'findComposition'), it is fused
-- with 'fuse'.  The auxiliary definitions produced are emitted through the
-- writer, and every fused 'Comp' node becomes a call to the corresponding
-- function of its hylomorphism.
--
-- Otherwise the function recurses into the subterms.  At each binder
-- (case patterns, lambdas, let) it extends @vs@ with the variables that
-- binder introduces.
fuseComp :: [Variable] -> Term -> WriterT [Def] (StateT HyloSet VarGenState) Term
fuseComp vs t = get >>= \hs ->
    case findComposition vs hs t of
        Just c -> do
            (c', dfs) <- lift (fuse c)
            tell dfs
            foldCompositionM compCase return return c'
        Nothing -> descend t
  where
    -- A fused composition node becomes a call to the i-th function of its hylo.
    compCase _ i h = return . Tfapp (getNames h !! i)

    descend (Ttuple b ts)    = fmap (Ttuple b) $ mapM (fuseComp vs) ts
    descend (Tfapp v ts)     = fmap (Tfapp v) $ mapM (fuseComp vs) ts
    descend (Tcapp c ts)     = fmap (Tcapp c) $ mapM (fuseComp vs) ts
    descend (Tapp t0 t1)     = liftM2 Tapp (fuseComp vs t0) (fuseComp vs t1)
    descend (Tcase t0 ps ts) = liftM2 (flip Tcase ps) (fuseComp vs t0)
                                      (zipWithM (\p -> fuseComp (vars p ++ vs)) ps ts)
    descend (Tif t0 t1 t2)   = liftM3 Tif (fuseComp vs t0) (fuseComp vs t1) (fuseComp vs t2)
    -- Haskell let is recursive: the bound variable is in scope both in its
    -- definition and in the body.  Previously only one of the two subterms
    -- saw @v@ as bound, so a local @v@ that shadowed a global hylo could be
    -- mis-fused.  NOTE(review): assumes Tlet v t0 t1 is a recursive let of v.
    descend (Tlet v t0 t1)   = liftM2 (Tlet v) (fuseComp (v:vs) t0) (fuseComp (v:vs) t1)
    descend (Tpar t0)        = fmap Tpar (fuseComp vs t0)
    descend (Tlamb bv t0)    = fmap (Tlamb bv) (fuseComp (vars bv ++ vs) t0)
    descend other            = return other
-- | Fuse a composition and collect the definitions the result depends on.
--
-- The candidate results come from running the 'fusings' search tree through
-- 'findFusion'.  The candidate with the greatest total 'rank' is kept; on
-- ties, 'maximumBy' keeps the last one.  'findFusion' always returns exactly
-- one candidate, so 'maximumBy' never sees an empty list.
--
-- For every fused node, the definitions recorded for its history are
-- returned, followed by those child definitions that they reference.
fuse :: Composition -> StateT HyloSet VarGenState (Composition,[Def])
fuse c = do
    candidates <- findFusion $ runC (fusings c) (Tree . return . MTReturn)
    let ranked = [ (rankC cand, cand) | cand <- candidates ]
        best   = snd (maximumBy (compare `on` fst) ranked)
    defs <- foldCompositionM collectDfs (const (return [])) return best
    return (best, defs)
  where
    rankC (Comp fh _ _ cs) = rank fh + sum (map rankC cs)
    rankC (CompSplit sub)  = rankC sub
    rankC _                = 0

    collectDfs (FHLeaf _) _ _ childDfs = return (concat childDfs)
    collectDfs fh _ _ childDfs = do
        hs <- get
        case lookupFusion fh hs of
            Nothing -> return (concat childDfs)
            Just (_, hdfs, _) ->
                return (hdfs ++ filter ((`elem` vars hdfs) . getDefName) (concat childDfs))
-- | Run a search tree along its leftmost path.
--
-- Tags are skipped.  At every 'MTNode' only the left alternative is executed;
-- the right one is discarded without running its effects.  The result is
-- therefore always a singleton list holding the first leaf reached.
findFusion :: Monad m => Tree b m a -> m [a]
findFusion tree = runTree tree >>= step
  where
    step (MTReturn a)  = return [a]
    step (MTNode l _)  = findFusion l
    step (MTTag _ sub) = findFusion sub
-- | Enumerate the ways of fusing a 'Composition', as a 'TreeT' search tree.
--
-- For a 'Comp' node the tree offers two alternatives (via 'tnode'):
--
-- * left: first process every argument recursively, then fuse this node
--   with its (already processed) arguments;
-- * right: first fuse this node with its arguments, then process the
--   arguments of the result recursively.
--
-- When fusing an argument into its consumer fails, that argument is wrapped
-- in 'CompSplit' and the path is tagged with @treturn ()@.  Any other
-- composition is returned unchanged.
--
-- NOTE: 'findFusion', the consumer used by 'fuse', only follows the left
-- alternative of each 'tnode'.
fusings :: Composition -> TreeT () (StateT HyloSet VarGenState) Composition
fusings (Comp fh i h cs) =
    (mapM fusings cs >>= fmap comp . fuseOneComp fh h)
    `tnode`
    (do (fh',i,h,cs') <- fuseOneComp fh h cs
        mapM fusings cs' >>= return . Comp fh' i h
    )
  where
    -- Positions of the recursive (non-constant) arguments of hylo h.
    -- Without explicit constant positions, these are all positions from
    -- getConstantArgCount h onwards; otherwise they are the complement of the
    -- constant positions.  The upper bound (length cs) is one past the last
    -- argument index and is never matched.
    ris h cs = maybe [getConstantArgCount h..length cs] ([0..length cs]\\)$ getConstantArgPos h
    -- Index of argument position ai among the recursive positions, or Nothing
    -- when ai is a constant argument.
    isRecArg h cs ai = elemIndex ai (ris h cs)
    -- Try to fuse, one at a time, every argument that is in a recursive
    -- position and is itself a hylo application (Comp).  Candidates are
    -- folded from the last position to the first.
    fuseOneComp :: FHistory -> HyloT -> [Composition]
                -> TreeT () (StateT HyloSet VarGenState) (FHistory,Int,HyloT,[Composition])
    fuseOneComp fc h cs =
      let rargs = [ ((ai,ari),(fc',i',h',cs')) | (ai,Comp fc' i' h' cs') <- zip [0..] cs
                                                , Just ari<-[ isRecArg h cs ai ] ]
      in foldM applyFusion (fc,i,h,cs) (reverse rargs)
    comp (fc,i,h,cs) = Comp fc i h cs
    -- Fuse the application (fh',i',h',cs') found at argument position ai
    -- (recursive index ari) into the consumer (fh,i,h,cs).  fh'' is the
    -- history of the would-be result and is also its lookup key.
    -- On success the new argument list is: constant args of h ++ constant
    -- args of h' ++ h's recursive args before ari ++ recursive args of h' ++
    -- h's recursive args after ari.
    applyFusion :: (FHistory,Int,HyloT,[Composition]) -> ((Int,Int),(FHistory,Int,HyloT,[Composition]))
                -> TreeT () (StateT HyloSet VarGenState) (FHistory,Int,HyloT,[Composition])
    applyFusion (fh,i,h,cs) ((ai,ari),(fh',i',h',cs')) =
      let fh'' = fusedH fh ari fh'
          -- Split a hylo's arguments into (constant, recursive) ones.
          splitNRR h cs = case getConstantArgPos h of
                            Nothing -> splitAt (getConstantArgCount h) cs
                            Just is -> (map snd *** map snd)$ partition (flip elem is . fst) (zip [0..] cs)
          (csa,csb) = splitNRR h cs
          -- Remove h's ari-th recursive argument, i.e. the one being fused in.
          (csb',_:csb'') = splitAt ari csb
          (csah',csbh') = splitNRR h' cs'
      in lift get >>= \ hs ->
         case lookupFusion fh'' hs of
           -- This history was fused before: reuse the recorded hylo.
           -- NOTE(review): the looked-up index i (and dfs) are unused and 0 is
           -- returned, as in the fresh-fusion branch; confirm that fh'' is
           -- always recorded at index 0.
           Just (i,dfs,hr) -> return (fh'',0,hr,csa++csah'++csb'++csbh'++csb'')
           _ -> lift (lift (runErrorT (fusionar' [] h i ari h' i'))) >>= \e ->
                case e of
                  -- Fusion succeeded with a positive r: inline the result,
                  -- re-derive a hylo from it (keeping hr if derivation fails)
                  -- and record it in the HyloSet.
                  Right (r,his,hr) | r>0 ->
                    do dfs <- lift$ lift (inline hr)
                       ehr' <- lift$ lift$ (runErrorT (deriveHylo dfs))
                       let hr' = either (const$ hr) id ehr'
                       lift$ recordH h h' dfs hr' his
                       return (fh'',0,hr',csa++csah'++csb'++csbh'++csb'')
                  -- Fusion failed (error, or r <= 0): keep the consumer as is
                  -- and replace the argument with a CompSplit marker.
                  _ -> let args c = case getConstantArgPos h of
                                      Nothing ->
                                        let (a,(b,_:c')) = second (splitAt ari)$ splitAt (getConstantArgCount h) cs
                                        in a++b++c:c'
                                      _ -> let (a,_:b) = splitAt ai cs
                                           in a++c:b
                       in treturn () (fh,i,h,args$ CompSplit (Comp fh' i' h' cs'))
fusings c = return c
-- | Monadic bottom-up fold over a 'Composition'.
--
-- The arguments of a 'Comp' are folded left to right before the node's own
-- handler runs.  'CompTerm' leaves go to the term handler.  A 'CompSplit'
-- folds its inner composition and then passes the result through the split
-- handler.
foldCompositionM :: Monad m => (FHistory -> Int -> HyloT -> [b] -> m b) -> (Term -> m b) -> (b -> m b) -> Composition -> m b
foldCompositionM onComp onTerm onSplit = go
  where
    go (Comp fc i h cs) = mapM go cs >>= onComp fc i h
    go (CompTerm t)     = onTerm t
    go (CompSplit c)    = go c >>= onSplit
-- | Bookkeeping for known hylomorphism functions.
--
-- The first map sends a function name to its fusion history.  The second
-- sends a fusion history to the function's index within its hylomorphism,
-- the definitions stored with it, and the hylomorphism itself.
type HyloSet = ( M.Map Variable FHistory
               , M.Map FHistory (Int,[Def],HyloT)
               )
-- | Look up a function by name.
--
-- Returns its fusion history together with everything recorded for that
-- history.  Yields 'Nothing' if either lookup fails.
lookupHylo :: Variable -> HyloSet -> Maybe (FHistory,Int,[Def],HyloT)
lookupHylo n (byName, byHistory) =
    case M.lookup n byName of
        Nothing -> Nothing
        Just fh ->
            case M.lookup fh byHistory of
                Nothing          -> Nothing
                Just (i, dfs, h) -> Just (fh, i, dfs, h)
-- | Look up what is recorded for a fusion history, if anything.
lookupFusion :: FHistory -> HyloSet -> Maybe (Int,[Def],HyloT)
lookupFusion fh (_, byHistory) = M.lookup fh byHistory
-- | Build the initial 'HyloSet' from derived hylomorphisms.
--
-- Every function of every hylomorphism is given a leaf history named after
-- itself.  It is stored with its index, the hylo's definitions and the hylo.
-- For duplicate keys, 'M.fromList' keeps the last entry.
createHyloSet :: [([Def],HyloT)] -> HyloSet
createHyloSet hs = (M.fromList byName, M.fromList byHistory)
  where
    entries   = [ (v, FHLeaf v, (i, dfs, h)) | (dfs, h) <- hs, (i, v) <- zip [0..] (getNames h) ]
    byName    = [ (v, fh)    | (v, fh, _)    <- entries ]
    byHistory = [ (fh, info) | (_, fh, info) <- entries ]
-- | Register the functions of hylomorphism @h@ under the histories @fhs@.
--
-- Histories are paired positionally with @getNames h@; surplus elements of
-- either list are ignored.  Existing entries with the same keys are
-- overwritten.
addHylo :: [Def] -> HyloT -> [FHistory] -> HyloSet -> HyloSet
addHylo dfs h fhs (nm, hm) = (insertAll nameEntries nm, insertAll histEntries hm)
  where
    triples         = zip3 [0..] (getNames h) fhs
    nameEntries     = [ (v, fh)           | (_, v, fh) <- triples ]
    histEntries     = [ (fh, (i, dfs, h)) | (i, _, fh) <- triples ]
    insertAll kvs m = foldr (uncurry M.insert) m kvs
-- | Continuation-passing builder for 'Tree' values.
--
-- A computation receives the continuation that builds the rest of the tree
-- from its result.
newtype TreeT b m a = C
    { runC :: forall r. (a -> Tree b m r) -> Tree b m r }

-- | A tree whose layers are produced by actions in the monad @m@.
newtype Tree b m a = Tree
    { runTree :: m (MTree b m a) }

-- | One layer of a 'Tree'.
data MTree b m a
    = MTReturn a                        -- ^ A finished result.
    | MTNode (Tree b m a) (Tree b m a)  -- ^ Two alternatives (built by 'tnode').
    | MTTag b (Tree b m a)              -- ^ A tagged subtree (built by 'treturn').
-- | Mapping post-composes the function onto the continuation's input.
instance Functor (TreeT b m) where
    fmap f (C fc) = C $ \k -> fc (k . f)

-- | Required as a superclass of 'Monad' since GHC 7.10 (AMP); without it the
-- 'Monad' instance below does not compile on current GHCs.  Agrees with
-- @ap@ for the 'Monad' instance.
instance Applicative (TreeT b m) where
    pure a = C ($ a)
    C ff <*> C fa = C $ \k -> ff (\f -> fa (k . f))

-- | Continuation-passing bind: feed each result into the next computation
-- before handing it to the final continuation.
instance Monad (TreeT b m) where
    return a = C ($ a)
    C f >>= fc = C $ \k -> f (\a -> runC (fc a) k)
-- | Lift an action of the underlying monad.  The action runs when the tree
-- layer is demanded, and its result is passed to the continuation.
instance MonadTrans (TreeT b) where
    lift action = C $ \k -> Tree (action >>= \a -> runTree (k a))
-- | Return a value, wrapping the rest of the tree in an 'MTTag' node that
-- carries the given tag.
treturn :: Monad m => b -> a -> TreeT b m a
treturn tag a = C $ \k -> Tree (return (MTTag tag (k a)))
-- | Offer two alternatives as an 'MTNode'.  Both branches are built with the
-- same continuation.
tnode :: Monad m => TreeT b m a -> TreeT b m a -> TreeT b m a
tnode left right = C $ \k -> Tree (return (MTNode (runC left k) (runC right k)))