{-# OPTIONS_GHC -Wall #-}
{-# Language BangPatterns #-}
{-# Language ScopedTypeVariables #-}

-- | This file is a modified version of Andy Gill's data-reify package.
-- It is modified to use Data.HashTable.IO, which gives a speed improvement
-- at the expense of portability. This also gives me a more convenient
-- sandbox to investigate other performance tweaks, though it is unclear
-- if I have made anything any faster.

module Dvda.Algorithm.Reify
       ( ReifyGraph(..)
       , Node(..)
       , reifyGraph
       ) where

import Control.Monad.State.Strict ( StateT(..), runStateT )
import Data.Hashable ( Hashable(..) )
import Control.Applicative ( pure )
import Data.Traversable ( Traversable )
import qualified Data.Traversable as T
import System.Mem.StableName ( StableName, makeStableName, hashStableName )
import Unsafe.Coerce ( unsafeCoerce )

import Dvda.Expr
import qualified Data.HashTable.IO as H

-- | The mutable table used to remember which expressions were already seen.
type HashTable k v = H.CuckooHashTable k v

-- | Identifier of a node in the reified graph.
newtype Node = Node Int deriving (Ord, Eq)

-- | Nodes are rendered as @\@@ followed by their integer identifier.
instance Show Node where
  show (Node ident) = '@' : show ident

-- | A reified graph: a list of nodes, each paired with an expression whose
-- children are references to other nodes.
data ReifyGraph e = ReifyGraph [(Node,e Node)]

-- | Traverse a structure while threading an accumulator through a monadic
-- step function. The step returns @(output, newAccumulator)@; the result is
-- the structure of outputs paired with the final accumulator.
mapAccumM' :: (Monad m, Functor m, Traversable t) => (a -> b -> m (c, a)) -> a -> t b -> m (t c, a)
mapAccumM' step acc0 xs = runStateT (T.traverse visit xs) acc0
  where
    -- one element of the traversal, expressed as a state transition
    visit x = StateT (\acc -> step acc x)
--{-# INLINE mapAccumM' #-}

-- | Same as 'mapAccumM'', except that the step function returns
-- @(newAccumulator, output)@ instead of @(output, newAccumulator)@.
mapAccumM :: (Monad m, Functor m, Traversable t) => (a -> b -> m (a, c)) -> a -> t b -> m (t c, a)
mapAccumM f' = mapAccumM' swapped
  where
    -- adapt the argument order of the result tuple to what mapAccumM' expects
    swapped acc z = do
      (acc', out) <- f' acc z
      return (out, acc')
--{-# INLINE mapAccumM #-}

-- | Rebuild one layer of an 'Expr' as a 'GExpr' whose children are 'Node's.
--
-- The supplied function is applied to each direct child, left to right,
-- threading the accumulator through the calls. Leaves (symbols, constants,
-- numeric literals) return the accumulator untouched.
mapDeRef :: forall acc a . (acc -> Expr a -> IO (acc, Node)) -> acc -> Expr a -> IO (acc, GExpr a Node)
mapDeRef visit acc0 expr = case expr of
  ESym name                     -> leaf (GSym name)
  EConst c                      -> leaf (GConst c)

  ENum (Mul x y)                -> binary (\p q -> GNum (Mul p q)) x y
  ENum (Add x y)                -> binary (\p q -> GNum (Add p q)) x y
  ENum (Sub x y)                -> binary (\p q -> GNum (Sub p q)) x y
  ENum (Negate x)               -> unary (GNum . Negate) x
  ENum (Abs x)                  -> unary (GNum . Abs) x
  ENum (Signum x)               -> unary (GNum . Signum) x
  ENum (FromInteger k)          -> leaf (GNum (FromInteger k))

  EFractional (Div x y)         -> binary (\p q -> GFractional (Div p q)) x y
  EFractional (FromRational x)  -> leaf (GFractional (FromRational x))

  EFloating (Pow x y)           -> binary (\p q -> GFloating (Pow p q)) x y
  EFloating (LogBase x y)       -> binary (\p q -> GFloating (LogBase p q)) x y
  EFloating (Exp x)             -> unary (GFloating . Exp) x
  EFloating (Log x)             -> unary (GFloating . Log) x
  EFloating (Sin x)             -> unary (GFloating . Sin) x
  EFloating (Cos x)             -> unary (GFloating . Cos) x
  EFloating (Tan x)             -> unary (GFloating . Tan) x
  EFloating (ASin x)            -> unary (GFloating . ASin) x
  EFloating (ATan x)            -> unary (GFloating . ATan) x
  EFloating (ACos x)            -> unary (GFloating . ACos) x
  EFloating (Sinh x)            -> unary (GFloating . Sinh) x
  EFloating (Cosh x)            -> unary (GFloating . Cosh) x
  EFloating (Tanh x)            -> unary (GFloating . Tanh) x
  EFloating (ASinh x)           -> unary (GFloating . ASinh) x
  EFloating (ATanh x)           -> unary (GFloating . ATanh) x
  EFloating (ACosh x)           -> unary (GFloating . ACosh) x
  where
    -- no children: the accumulator passes through unchanged
    leaf :: GExpr a Node -> IO (acc, GExpr a Node)
    leaf g = pure (acc0, g)

    -- one child: visit it, then wrap the resulting node
    unary :: (Node -> GExpr a Node) -> Expr a -> IO (acc, GExpr a Node)
    unary wrap x = do
      (acc1, fx) <- visit acc0 x
      return (acc1, wrap fx)

    -- two children: visit the left one first, then the right one
    binary :: (Node -> Node -> GExpr a Node) -> Expr a -> Expr a -> IO (acc, GExpr a Node)
    binary wrap x y = do
      (acc1, fx) <- visit acc0 x
      (acc2, fy) <- visit acc1 y
      return (acc2, wrap fx fy)
-- {-# INLINE mapDeRef #-}

-- | Reify a container of expressions into a graph.
--
-- Every expression reachable from the input is given a 'Node'. Expressions
-- are identified by stable name, so an expression that was already visited
-- is looked up in the hash table and reuses its existing 'Node'. Returns
-- the graph together with the input container in which each expression has
-- been replaced by its root 'Node'.
reifyGraph :: forall t a . Traversable t => t (Expr a) -> IO (ReifyGraph (GExpr a), t Node)
reifyGraph m = do
  seen <- H.new :: IO (HashTable DynStableName Node)

  let -- The accumulator is (graph entries so far, next unused node id).
      findNodes :: ([(Node, GExpr a Node)],Node) -> Expr a -> IO (([(Node, GExpr a Node)],Node), Node)
      findNodes !(!entries0, fresh@(Node freshId)) expr = do
        key <- makeDynStableName expr
        found <- H.lookup seen key
        case found of
          -- already visited: hand back the node assigned earlier
          Just known -> return ((entries0, fresh), known)
          Nothing -> do
            -- register the node before descending into the children
            H.insert seen key fresh
            ((entries1, fresh'), gexpr) <-
              mapDeRef findNodes (entries0, Node (freshId + 1)) expr
            let entries2 :: [(Node,GExpr a Node)]
                entries2 = (fresh, gexpr) : entries1
            return ((entries2, fresh'), fresh)
-- {-# INLINE findNodes #-}

  (roots, (entries, _)) <- mapAccumM findNodes ([], Node 0) m
  return (ReifyGraph entries, roots)

-- Stable names that do not use phantom types.
-- As suggested by Ganesh Sittampalam.
-- | A 'StableName' whose type parameter has been erased to @()@, so that
-- names can be compared and used as keys of a single table.
newtype DynStableName = DynStableName (StableName ()) deriving Eq

-- | Hashing mixes the salt with the hash of the wrapped stable name.
instance Hashable DynStableName where
  hashWithSalt salt dyn = hashWithSalt salt (hashDynStableName dyn)

-- | Hash of the wrapped stable name, as given by 'hashStableName'.
hashDynStableName :: DynStableName -> Int
hashDynStableName (DynStableName sn) = hashStableName sn

-- | Create a 'DynStableName' for a value. The bang pattern evaluates the
-- argument to weak head normal form before the stable name is taken.
makeDynStableName :: a -> IO DynStableName
makeDynStableName !a =
  -- unsafeCoerce here only retypes @StableName a@ as @StableName ()@;
  -- NOTE(review): presumably safe because that parameter is phantom - confirm.
  fmap (DynStableName . unsafeCoerce) (makeStableName a)
--{-# INLINE makeDynStableName #-}