module Dvda.CSE ( cse
) where
import Control.Monad.ST ( ST, runST )
import Data.Foldable ( toList )
import Data.Hashable ( Hashable )
import Data.IntMap ( IntMap )
import qualified Data.IntMap as IM
import Data.Tuple ( swap )
import Dvda.Expr ( GExpr(..), Floatings(..), Fractionals(..), Nums(..) )
import Dvda.FunGraph
import qualified Data.HashTable.Class as HT
import qualified Data.HashTable.ST.Cuckoo as C
type HashTable s v k = C.HashTable s v k
cse :: (Eq a, Hashable a) => FunGraph a -> FunGraph a
cse fg = nodelistToFunGraph (map swap htList) (fgInputs fg) outputIndices
where
(htList, im) = cse' (fgLookupGExpr fg) (fgOutputs fg)
outputIndices = let
oldIndexToNewIndex k = case IM.lookup k im of
Just k' -> k'
Nothing -> error $
"CSE error, in mapping old output indices to new, found an old one which was missing from" ++
"the old --> new Int mapping"
in map (fmap oldIndexToNewIndex) (fgOutputs fg)
cse' ::
(Eq a, Hashable a)
=> (Int -> Maybe (GExpr a Int))
-> [MVS Int]
-> ([(GExpr a Int, Int)], IntMap Int)
cse' lookupFun outputIndices = runST $ do
ht <- HT.new
let
f (im,n) [] = return (im,n)
f (im0,n0) (k:ks) = do
(_,im,n) <- insertOldNode k lookupFun ht im0 n0
f (im,n) ks
(oldToNewIdx,_) <- f (IM.empty,0) (concatMap toList outputIndices)
htList <- HT.toList ht
return (htList, oldToNewIdx)
insertOldNode ::
(Eq a, Hashable a)
=> Int
-> (Int -> Maybe (GExpr a Int))
-> HashTable s (GExpr a Int) Int
-> IntMap Int
-> Int
-> ST s (Int, IntMap Int, Int)
insertOldNode kOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0 =
case IM.lookup kOld oldNodeToNewNode0 of
Just k -> return (k, oldNodeToNewNode0, nextFreeInt0)
Nothing -> case lookupOldGExpr kOld of
Nothing -> error $ "in CSE, insertOldNode got an old key \"" ++ show kOld ++
"\" with was not found in the old graph"
Just oldGExpr -> do
(k, oldNodeToNewNode1, nextFreeInt1) <- insertOldGExpr oldGExpr lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0
return (k, IM.insert kOld k oldNodeToNewNode1, nextFreeInt1)
insertOldGExpr ::
(Eq a, Hashable a)
=> GExpr a Int
-> (Int -> Maybe (GExpr a Int))
-> HashTable s (GExpr a Int) Int
-> IntMap Int
-> Int
-> ST s (Int, IntMap Int, Int)
insertOldGExpr g@(GSym _) = \_ -> cseInsert g
insertOldGExpr g@(GConst _) = \_ -> cseInsert g
insertOldGExpr g@(GNum (FromInteger _)) = \_ -> cseInsert g
insertOldGExpr g@(GFractional (FromRational _)) = \_ -> cseInsert g
insertOldGExpr (GNum (Mul x y)) = insertOldGExprBinary GNum Mul x y
insertOldGExpr (GNum (Add x y)) = insertOldGExprBinary GNum Add x y
insertOldGExpr (GNum (Sub x y)) = insertOldGExprBinary GNum Sub x y
insertOldGExpr (GFractional (Div x y)) = insertOldGExprBinary GFractional Div x y
insertOldGExpr (GFloating (Pow x y)) = insertOldGExprBinary GFloating Pow x y
insertOldGExpr (GFloating (LogBase x y)) = insertOldGExprBinary GFloating LogBase x y
insertOldGExpr (GNum (Negate x)) = insertOldGExprUnary GNum Negate x
insertOldGExpr (GNum (Abs x)) = insertOldGExprUnary GNum Abs x
insertOldGExpr (GNum (Signum x)) = insertOldGExprUnary GNum Signum x
insertOldGExpr (GFloating (Exp x)) = insertOldGExprUnary GFloating Exp x
insertOldGExpr (GFloating (Log x)) = insertOldGExprUnary GFloating Log x
insertOldGExpr (GFloating (Sin x)) = insertOldGExprUnary GFloating Sin x
insertOldGExpr (GFloating (Cos x)) = insertOldGExprUnary GFloating Cos x
insertOldGExpr (GFloating (ASin x)) = insertOldGExprUnary GFloating ASin x
insertOldGExpr (GFloating (ATan x)) = insertOldGExprUnary GFloating ATan x
insertOldGExpr (GFloating (ACos x)) = insertOldGExprUnary GFloating ACos x
insertOldGExpr (GFloating (Sinh x)) = insertOldGExprUnary GFloating Sinh x
insertOldGExpr (GFloating (Cosh x)) = insertOldGExprUnary GFloating Cosh x
insertOldGExpr (GFloating (Tanh x)) = insertOldGExprUnary GFloating Tanh x
insertOldGExpr (GFloating (ASinh x)) = insertOldGExprUnary GFloating ASinh x
insertOldGExpr (GFloating (ATanh x)) = insertOldGExprUnary GFloating ATanh x
insertOldGExpr (GFloating (ACosh x)) = insertOldGExprUnary GFloating ACosh x
insertOldGExprBinary ::
(Eq a, Hashable a)
=> (f -> GExpr a Int)
-> (Int -> Int -> f)
-> Int -> Int
-> (Int -> Maybe (GExpr a Int))
-> HashTable s (GExpr a Int) Int
-> IntMap Int
-> Int
-> ST s (Int, IntMap Int, Int)
insertOldGExprBinary gnum mul kxOld kyOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0 = do
(kx, oldNodeToNewNode1,nextFreeInt1) <- insertOldNode kxOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0
(ky, oldNodeToNewNode2,nextFreeInt2) <- insertOldNode kyOld lookupOldGExpr ht oldNodeToNewNode1 nextFreeInt1
let newGExpr = gnum (mul kx ky)
cseInsert newGExpr ht oldNodeToNewNode2 nextFreeInt2
insertOldGExprUnary ::
(Eq a, Hashable a)
=> (f -> GExpr a Int)
-> (Int -> f)
-> Int
-> (Int -> Maybe (GExpr a Int))
-> HashTable s (GExpr a Int) Int
-> IntMap Int
-> Int
-> ST s (Int, IntMap Int, Int)
insertOldGExprUnary gnum neg kxOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0 = do
(kx, oldNodeToNewNode1,nextFreeInt1) <- insertOldNode kxOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0
let newGExpr = gnum (neg kx)
cseInsert newGExpr ht oldNodeToNewNode1 nextFreeInt1
cseInsert :: (Eq a, Hashable a) => GExpr a Int -> HashTable s (GExpr a Int) Int -> IntMap Int -> Int
-> ST s (Int, IntMap Int, Int)
cseInsert gexpr ht oldNodeToNewNode0 nextFreeInt0 = do
lu <- HT.lookup ht gexpr
case lu of
Just k -> return (k, oldNodeToNewNode0, nextFreeInt0)
Nothing -> do
HT.insert ht gexpr nextFreeInt0
return (nextFreeInt0, oldNodeToNewNode0, nextFreeInt0+1)