module Dvda.Algorithm.Reify
( ReifyGraph(..)
, Node(..)
, reifyGraph
) where
import Control.Monad.State.Strict ( StateT(..), runStateT )
import Data.Hashable ( Hashable(..) )
import Control.Applicative ( pure )
import Data.Traversable ( Traversable )
import qualified Data.Traversable as T
import System.Mem.StableName ( StableName, makeStableName, hashStableName )
import Unsafe.Coerce ( unsafeCoerce )
import Dvda.Expr
import qualified Data.HashTable.IO as H
-- | Mutable cuckoo hash table ('H.CuckooHashTable'); 'reifyGraph' uses one to
-- map the stable names of already-visited expressions to their assigned 'Node'.
type HashTable k v = H.CuckooHashTable k v
-- | Identifier of a vertex in a reified expression graph.
--
-- Rendered as the integer prefixed with an at sign, e.g. \"\@3\" for @Node 3@.
newtype Node = Node Int deriving (Eq, Ord)

instance Show Node where
  showsPrec _ (Node k) = showChar '@' . shows k
-- | A reified graph: a list of @(node, shallow expression)@ entries, where the
-- expression type @e@ is applied to 'Node'. In 'reifyGraph', @e@ is @GExpr a@,
-- so each entry refers to its children by 'Node' rather than by value.
data ReifyGraph e = ReifyGraph [(Node,e Node)]
-- | Monadic accumulating traversal, run through strict 'StateT'.
--
-- The callback receives the current accumulator and an element and returns
-- @(result, newAccumulator)@. Elements are visited in the container's
-- 'Traversable' order. The result is the container of results plus the final
-- accumulator.
mapAccumM' :: (Monad m, Functor m, Traversable t) =>
              (a -> b -> m (c, a)) -> a -> t b -> m (t c, a)
mapAccumM' f acc0 xs = runStateT (T.traverse step xs) acc0
  where
    -- the accumulator is the StateT state; the callback's pair is already in
    -- StateT's (result, state) order
    step x = StateT (\acc -> f acc x)
-- | Variant of 'mapAccumM'' whose callback returns @(newAccumulator, result)@.
-- This is the same order as the callback of 'Data.Traversable.mapAccumL', and
-- the opposite of the @(result, newAccumulator)@ order used by 'mapAccumM''.
-- The overall return value is still @(results, finalAccumulator)@.
mapAccumM :: (Monad m, Functor m, Traversable t) =>
             (a -> b -> m (a, c)) -> a -> t b -> m (t c, a)
mapAccumM step = mapAccumM' swapped
  where
    -- flip the callback's pair into the order mapAccumM' expects
    swapped acc x = step acc x >>= \(acc', y) -> return (y, acc')
-- | Dereference one layer of an expression.
--
-- Each immediate child is handed to the callback, which returns the 'Node'
-- standing for that child and an updated accumulator. The accumulator is
-- threaded through the children left to right, so for binary operators the
-- first operand is visited before the second. The same operator is then
-- rebuilt as a 'GExpr' over those nodes.
--
-- Leaves ('ESym', 'EConst', 'FromInteger', 'FromRational') are copied through
-- without invoking the callback and return the accumulator unchanged.
--
-- NOTE(review): any constructor of Expr / Nums / Fractionals / Floatings that
-- is not matched below raises a pattern-match failure; confirm this list is
-- complete against Dvda.Expr.
mapDeRef :: (acc -> Expr a -> IO (acc, Node)) -> acc -> Expr a -> IO (acc, GExpr a Node)
mapDeRef visit acc expr = case expr of
  ESym name                    -> leaf (GSym name)
  EConst c                     -> leaf (GConst c)
  ENum (Mul x y)               -> binary (\p q -> GNum (Mul p q)) x y
  ENum (Add x y)               -> binary (\p q -> GNum (Add p q)) x y
  ENum (Sub x y)               -> binary (\p q -> GNum (Sub p q)) x y
  ENum (Negate x)              -> unary (GNum . Negate) x
  ENum (Abs x)                 -> unary (GNum . Abs) x
  ENum (Signum x)              -> unary (GNum . Signum) x
  ENum (FromInteger k)         -> leaf (GNum (FromInteger k))
  EFractional (Div x y)        -> binary (\p q -> GFractional (Div p q)) x y
  EFractional (FromRational r) -> leaf (GFractional (FromRational r))
  EFloating (Pow x y)          -> binary (\p q -> GFloating (Pow p q)) x y
  EFloating (LogBase x y)      -> binary (\p q -> GFloating (LogBase p q)) x y
  EFloating (Exp x)            -> unary (GFloating . Exp) x
  EFloating (Log x)            -> unary (GFloating . Log) x
  EFloating (Sin x)            -> unary (GFloating . Sin) x
  EFloating (Cos x)            -> unary (GFloating . Cos) x
  EFloating (Tan x)            -> unary (GFloating . Tan) x
  EFloating (ASin x)           -> unary (GFloating . ASin) x
  EFloating (ATan x)           -> unary (GFloating . ATan) x
  EFloating (ACos x)           -> unary (GFloating . ACos) x
  EFloating (Sinh x)           -> unary (GFloating . Sinh) x
  EFloating (Cosh x)           -> unary (GFloating . Cosh) x
  EFloating (Tanh x)           -> unary (GFloating . Tanh) x
  EFloating (ASinh x)          -> unary (GFloating . ASinh) x
  EFloating (ATanh x)          -> unary (GFloating . ATanh) x
  EFloating (ACosh x)          -> unary (GFloating . ACosh) x
  where
    -- no children: accumulator passes through untouched
    leaf g = pure (acc, g)

    -- one child: visit it, then rebuild around its node
    unary build x = do
      (acc', n) <- visit acc x
      return (acc', build n)

    -- two children: visit the left one first, threading the accumulator
    binary build x y = do
      (acc1, nx) <- visit acc x
      (acc2, ny) <- visit acc1 y
      return (acc2, build nx ny)
-- | Convert a traversable collection of expression trees into an explicit
-- graph in which sharing is visible.
--
-- Each expression is forced to WHNF and identified by its stable name. The
-- first visit to an object assigns it a fresh 'Node'. Any later visit to the
-- same object, whether reached through sharing or from another root, reuses
-- that 'Node' and is not descended into again.
--
-- Nodes are numbered from @Node 0@ in first-visit order. Roots are taken in
-- the container's traversal order, and children are visited left to right by
-- 'mapDeRef'.
--
-- Returns the graph, with one @(node, shallow expression)@ entry per distinct
-- node, and the original container with each expression replaced by its root
-- 'Node'. An entry is prepended only after its children have been visited, so
-- for acyclic input each node's entry precedes the entries of its children.
reifyGraph :: forall t a . Traversable t => t (Expr a) -> IO (ReifyGraph (GExpr a), t Node)
reifyGraph exprs = do
  seen <- H.new :: IO (HashTable DynStableName Node)
  let visit :: ([(Node, GExpr a Node)], Node) -> Expr a
            -> IO (([(Node, GExpr a Node)], Node), Node)
      visit !(!entries, here@(Node n)) e = do
        name <- makeDynStableName e
        hit <- H.lookup seen name
        case hit of
          Just node ->
            -- already visited: reuse its node and consume no fresh ids
            return ((entries, here), node)
          Nothing -> do
            -- register before descending, so that a repeat visit reached
            -- from inside the children resolves to this node
            H.insert seen name here
            ((entries', next), shallow) <- mapDeRef visit (entries, Node (n + 1)) e
            return (((here, shallow) : entries', next), here)
  (roots, (graph, _)) <- mapAccumM visit ([], Node 0) exprs
  return (ReifyGraph graph, roots)
-- | A 'StableName' whose type parameter has been erased to @()@ (see
-- 'makeDynStableName'). This lets stable names serve as keys of a single
-- hash table. Equality is the derived equality of the wrapped stable names.
newtype DynStableName = DynStableName (StableName ()) deriving Eq
-- | Salted hash of the wrapped stable name's 'hashStableName'. Different stable
-- names may share a hash; the derived 'Eq' tells them apart.
instance Hashable DynStableName where
  hashWithSalt s dsn = hashWithSalt s (hashDynStableName dsn)
-- | The 'hashStableName' of the wrapped stable name.
hashDynStableName :: DynStableName -> Int
hashDynStableName dsn = case dsn of
  DynStableName sn -> hashStableName sn
makeDynStableName :: a -> IO DynStableName
makeDynStableName !a = do
st <- makeStableName a
return $ DynStableName (unsafeCoerce st)