{-# OPTIONS_GHC -Wall #-}

-- | Common subexpression elimination for 'FunGraph's.
--
-- The graph is rebuilt starting from its outputs: every reachable old node is
-- re-inserted into a hash table keyed by its (re-indexed) expression, so two
-- old nodes that produce the same expression end up with the same new index.
module Dvda.CSE
  ( cse
  ) where

import Control.Monad ( foldM )
import Control.Monad.ST ( ST, runST )
import Data.Foldable ( toList )
import Data.Hashable ( Hashable )
import Data.IntMap ( IntMap )
import qualified Data.IntMap as IM
import Data.Tuple ( swap )

import Dvda.Expr ( GExpr(..), Floatings(..), Fractionals(..), Nums(..) )
import Dvda.FunGraph

import qualified Data.HashTable.Class as HT
import qualified Data.HashTable.ST.Cuckoo as C

-- | Cuckoo hash table living in @ST s@, mapping keys @k@ to values @v@.
type HashTable s k v = C.HashTable s k v

-- | Rebuild a 'FunGraph' so that identical expressions share one node.
--
-- The inputs are passed through unchanged; the outputs are translated from
-- old node indices to new node indices.
cse :: (Eq a, Hashable a) => FunGraph a -> FunGraph a
cse fg = nodelistToFunGraph (map swap htList) (fgInputs fg) outputIndices
  where
    (htList, im) = cse' (fgLookupGExpr fg) (fgOutputs fg)

    -- since the fgInputs are all symbolic (GSym _) there is no need for
    -- mapping old inputs to new inputs
    outputIndices = map (fmap oldIndexToNewIndex) (fgOutputs fg)

    -- Every output index was inserted by cse', so a miss here is an
    -- internal error.
    oldIndexToNewIndex k = case IM.lookup k im of
      Just k' -> k'
      Nothing -> error $
        "CSE error, in mapping old output indices to new, "
        ++ "found an old one which was missing from "
        ++ "the old --> new Int mapping"

-- | Insert every output node (and, recursively, everything it depends on)
-- into a fresh hash table.
--
-- Returns the hash table contents as a list of @(expression, new index)@
-- pairs together with the map from old indices to new indices.
cse' :: (Eq a, Hashable a)
     => (Int -> Maybe (GExpr a Int)) -- ^ lookup of an old GExpr by old index
     -> [MVS Int]                    -- ^ old output indices
     -> ([(GExpr a Int, Int)], IntMap Int)
cse' lookupFun outputIndices = runST $ do
  ht <- HT.new
  let -- thread the (old --> new map, next free index) state through each
      -- output; the new index of the output itself is not needed here
      step (im0, n0) k = do
        (_, im, n) <- insertOldNode k lookupFun ht im0 n0
        return (im, n)
  (oldToNewIdx, _) <- foldM step (IM.empty, 0) (concatMap toList outputIndices)
  htList <- HT.toList ht
  return (htList, oldToNewIdx)

-- | Take an Int that represents a node in the original graph and make sure
-- it has been inserted in the new graph.
--
-- Returns the node's new index, the updated old --> new map and the next
-- free index. Calls 'error' if the old index is not present in the old graph.
insertOldNode :: (Eq a, Hashable a)
  => Int -- ^ Int to be inserted
  -> (Int -> Maybe (GExpr a Int)) -- ^ function to lookup old GExpr from old Int reference
  -> HashTable s (GExpr a Int) Int -- ^ hashmap of new GExprs to their new Int references
  -> IntMap Int -- ^ intmap of old int reference to new int references
  -> Int -- ^ next free index
  -> ST s (Int, IntMap Int, Int)
insertOldNode kOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0 =
  case IM.lookup kOld oldNodeToNewNode0 of
    -- if the int has already been inserted in the new graph, return it
    Just k -> return (k, oldNodeToNewNode0, nextFreeInt0)
    -- if the int has not yet been inserted, then insert it:
    -- get the old GExpr to which this node corresponds
    Nothing -> case lookupOldGExpr kOld of
      Nothing -> error $
        "in CSE, insertOldNode got an old key \"" ++ show kOld
        ++ "\" which was not found in the old graph"
      -- insert this old GExpr and remember where the old index went
      Just oldGExpr -> do
        (k, oldNodeToNewNode1, nextFreeInt1) <-
          insertOldGExpr oldGExpr lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0
        return (k, IM.insert kOld k oldNodeToNewNode1, nextFreeInt1)

-- | Insert an old GExpr into the new graph.
--
-- Leaf expressions are inserted as they are (the lookup function is not
-- needed for them). For unary and binary expressions the old child indices
-- are inserted first and the expression is rebuilt from the new child indices.
insertOldGExpr :: (Eq a, Hashable a)
  => GExpr a Int -- ^ GExpr to be inserted
  -> (Int -> Maybe (GExpr a Int)) -- ^ function to lookup old GExpr from old Int reference
  -> HashTable s (GExpr a Int) Int -- ^ hashmap of new GExprs to their new Int references
  -> IntMap Int -- ^ intmap of old int reference to new int references
  -> Int -- ^ next free index
  -> ST s (Int, IntMap Int, Int)
-- leaves
insertOldGExpr g@(GSym _)                       = \_ -> cseInsert g
insertOldGExpr g@(GConst _)                     = \_ -> cseInsert g
insertOldGExpr g@(GNum (FromInteger _))         = \_ -> cseInsert g
insertOldGExpr g@(GFractional (FromRational _)) = \_ -> cseInsert g
-- binary operations
insertOldGExpr (GNum (Mul x y))         = insertOldGExprBinary GNum Mul x y
insertOldGExpr (GNum (Add x y))         = insertOldGExprBinary GNum Add x y
insertOldGExpr (GNum (Sub x y))         = insertOldGExprBinary GNum Sub x y
insertOldGExpr (GFractional (Div x y))  = insertOldGExprBinary GFractional Div x y
insertOldGExpr (GFloating (Pow x y))    = insertOldGExprBinary GFloating Pow x y
insertOldGExpr (GFloating (LogBase x y)) = insertOldGExprBinary GFloating LogBase x y
-- unary operations
insertOldGExpr (GNum (Negate x))     = insertOldGExprUnary GNum Negate x
insertOldGExpr (GNum (Abs x))        = insertOldGExprUnary GNum Abs x
insertOldGExpr (GNum (Signum x))     = insertOldGExprUnary GNum Signum x
insertOldGExpr (GFloating (Exp x))   = insertOldGExprUnary GFloating Exp x
insertOldGExpr (GFloating (Log x))   = insertOldGExprUnary GFloating Log x
insertOldGExpr (GFloating (Sin x))   = insertOldGExprUnary GFloating Sin x
insertOldGExpr (GFloating (Cos x))   = insertOldGExprUnary GFloating Cos x
insertOldGExpr (GFloating (ASin x))  = insertOldGExprUnary GFloating ASin x
insertOldGExpr (GFloating (ATan x))  = insertOldGExprUnary GFloating ATan x
insertOldGExpr (GFloating (ACos x))  = insertOldGExprUnary GFloating ACos x
insertOldGExpr (GFloating (Sinh x))  = insertOldGExprUnary GFloating Sinh x
insertOldGExpr (GFloating (Cosh x))  = insertOldGExprUnary GFloating Cosh x
insertOldGExpr (GFloating (Tanh x))  = insertOldGExprUnary GFloating Tanh x
insertOldGExpr (GFloating (ASinh x)) = insertOldGExprUnary GFloating ASinh x
insertOldGExpr (GFloating (ATanh x)) = insertOldGExprUnary GFloating ATanh x
insertOldGExpr (GFloating (ACosh x)) = insertOldGExprUnary GFloating ACosh x

-- | Insert a binary expression: insert the first old child, then the second,
-- then the expression rebuilt from the two new child indices.
insertOldGExprBinary :: (Eq a, Hashable a)
  => (f -> GExpr a Int) -- ^ outer constructor (e.g. 'GNum')
  -> (Int -> Int -> f) -- ^ binary operation constructor (e.g. 'Mul')
  -> Int -- ^ old index of the first child
  -> Int -- ^ old index of the second child
  -> (Int -> Maybe (GExpr a Int)) -- ^ function to lookup old GExpr from old Int reference
  -> HashTable s (GExpr a Int) Int -- ^ hashmap of new GExprs to their new Int references
  -> IntMap Int -- ^ intmap of old int reference to new int references
  -> Int -- ^ next free index
  -> ST s (Int, IntMap Int, Int)
insertOldGExprBinary gnum mul kxOld kyOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0 = do
  (kx, oldNodeToNewNode1, nextFreeInt1) <-
    insertOldNode kxOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0
  (ky, oldNodeToNewNode2, nextFreeInt2) <-
    insertOldNode kyOld lookupOldGExpr ht oldNodeToNewNode1 nextFreeInt1
  let newGExpr = gnum (mul kx ky)
  cseInsert newGExpr ht oldNodeToNewNode2 nextFreeInt2

-- | Insert a unary expression: insert the old child, then the expression
-- rebuilt from the new child index.
insertOldGExprUnary :: (Eq a, Hashable a)
  => (f -> GExpr a Int) -- ^ outer constructor (e.g. 'GNum')
  -> (Int -> f) -- ^ unary operation constructor (e.g. 'Negate')
  -> Int -- ^ old index of the child
  -> (Int -> Maybe (GExpr a Int)) -- ^ function to lookup old GExpr from old Int reference
  -> HashTable s (GExpr a Int) Int -- ^ hashmap of new GExprs to their new Int references
  -> IntMap Int -- ^ intmap of old int reference to new int references
  -> Int -- ^ next free index
  -> ST s (Int, IntMap Int, Int)
insertOldGExprUnary gnum neg kxOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0 = do
  (kx, oldNodeToNewNode1, nextFreeInt1) <-
    insertOldNode kxOld lookupOldGExpr ht oldNodeToNewNode0 nextFreeInt0
  let newGExpr = gnum (neg kx)
  cseInsert newGExpr ht oldNodeToNewNode1 nextFreeInt1

-- | Insert an already re-indexed GExpr into the hash table.
--
-- If the expression is already present its existing index is returned and no
-- index is consumed; otherwise it is stored under the next free index and the
-- returned next free index is one larger. The old --> new map is passed
-- through unchanged.
cseInsert :: (Eq a, Hashable a)
  => GExpr a Int -- ^ GExpr (referring to new indices) to be inserted
  -> HashTable s (GExpr a Int) Int -- ^ hashmap of new GExprs to their new Int references
  -> IntMap Int -- ^ intmap of old int reference to new int references
  -> Int -- ^ next free index
  -> ST s (Int, IntMap Int, Int)
cseInsert gexpr ht oldNodeToNewNode0 nextFreeInt0 = do
  lu <- HT.lookup ht gexpr
  case lu of
    Just k -> return (k, oldNodeToNewNode0, nextFreeInt0)
    Nothing -> do
      HT.insert ht gexpr nextFreeInt0
      return (nextFreeInt0, oldNodeToNewNode0, nextFreeInt0 + 1)