{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralisedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Downhill.Internal.Graph.OpenGraph
  ( OpenEdge (..),
    OpenEndpoint (..),
    OpenNode (..),
    OpenGraph (..),
    recoverSharing,
  )
where

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT (..), get, modify)
import Downhill.Internal.Graph.OpenMap (OpenKey, OpenMap)
import qualified Downhill.Internal.Graph.OpenMap as OpenMap
import Downhill.Internal.Graph.Types (BackFun (BackFun))
import Downhill.Linear.Expr (BasicVector, Expr (ExprSum, ExprVar), Term (..))
import Prelude hiding (lookup)

-- | The far end of an 'OpenEdge'.
--
-- The first type index @a@ is shared by the whole graph; the second index is
-- the type associated with the endpoint itself.
data OpenEndpoint a v where
  -- | The source node of the graph. Its second index is forced to be @a@.
  OpenSourceNode :: OpenEndpoint a a
  -- | An inner node, referred to by its 'OpenKey'.
  OpenInnerNode :: OpenKey v -> OpenEndpoint a v

-- | An edge belonging to a node indexed by @v@.
--
-- It pairs a 'BackFun' with the 'OpenEndpoint' the edge connects to. The
-- endpoint's index @u@ does not appear in the result type, so it is hidden
-- from users of the edge.
data OpenEdge a v where
  OpenEdge :: BackFun u v -> OpenEndpoint a u -> OpenEdge a v

-- | A node of the graph under construction: the list of its edges, packaged
-- together with the 'BasicVector' instance for @v@.
data OpenNode a v = BasicVector v => OpenNode [OpenEdge a v]

-- | Monad in which an 'Expr' tree is converted into a graph.
--
-- It is a 'StateT' layer over 'IO'. The state is the 'OpenMap' of nodes built
-- so far, which doubles as a cache of visited 'Expr's.
newtype TreeBuilder a r = TreeCache {unTreeCache :: StateT (OpenMap (OpenNode a)) IO r}
  deriving (Functor, Applicative, Monad)

-- | Record @value@ in the cache under the key @name@.
insertIntoCache :: OpenKey dv -> OpenNode a dv -> TreeBuilder a ()
insertIntoCache name value = TreeCache $ modify (OpenMap.insert name value)

-- | @buildExpr action key@ will run @action@, associate result with @key@ and
-- store it in cache. If @key@ is already in cache, @action@ will not be run.
--
-- Returns the 'OpenKey' created for @key@ together with the node associated
-- with it (either the cached one or the freshly built one).
buildExpr ::
  TreeBuilder a (OpenNode a v) ->
  Expr a v ->
  TreeBuilder a (OpenKey v, OpenNode a v)
buildExpr action key = do
  -- Creating the key is an 'IO' action, hence the 'lift'.
  name <- TreeCache (lift (OpenMap.makeOpenKey key))
  cache <- TreeCache get
  case OpenMap.lookup cache name of
    Just x -> return (name, x)
    Nothing -> do
      -- The node is inserted only after @action@ has completed.
      value <- action
      insertIntoCache name value
      return (name, value)

-- | Run a 'TreeBuilder' starting from an empty cache. Returns the result of
-- the computation together with the final cache.
runTreeBuilder :: forall a g dv. TreeBuilder a (g dv) -> IO (g dv, OpenMap (OpenNode a))
runTreeBuilder builder = runStateT (unTreeCache builder) OpenMap.empty

-- | Computational graph under construction. "Open" refers to the set of nodes: new nodes can
-- still be added to this graph. Once the graph is complete the set of nodes will be frozen
-- and the type of the graph will become 'Graph' ("Downhill.Internal.Graph" module).
--
-- As built by 'recoverSharing', the first field is the node made from the
-- top-level terms and the second field is the map of nodes collected while
-- traversing them.
data OpenGraph a z = OpenGraph (OpenNode a z) (OpenMap (OpenNode a))

-- | Convert every 'Term' of a sum into an 'OpenEdge' (in list order) and
-- wrap the resulting edges into an 'OpenNode'.
goEdges :: BasicVector v => [Term a v] -> TreeBuilder a (OpenNode a v)
goEdges xs = OpenNode <$> traverse goSharing4term xs

-- | Convert an 'Expr' into the 'OpenEndpoint' that refers to it.
--
-- 'ExprVar' maps to 'OpenSourceNode'. An 'ExprSum' is built through
-- 'buildExpr', so an expression already present in the cache is not
-- traversed again, and the endpoint refers to it by key.
goSharing4arg :: forall a v. Expr a v -> TreeBuilder a (OpenEndpoint a v)
goSharing4arg key = case key of
  ExprVar -> return OpenSourceNode
  ExprSum xs -> do
    -- Only the key is needed here; the node itself lives in the cache.
    (gRef, _) <- buildExpr (goEdges xs) key
    return (OpenInnerNode gRef)

-- | Convert a 'Term' into an 'OpenEdge': the term's function is wrapped in
-- 'BackFun' and its argument is converted with 'goSharing4arg'.
goSharing4term :: forall a v. Term a v -> TreeBuilder a (OpenEdge a v)
goSharing4term = \case
  Term f arg -> do
    arg' <- goSharing4arg arg
    return (OpenEdge (BackFun f) arg')

-- | Collects duplicate nodes in 'Expr' tree and converts it to a graph.
--
-- The given terms become the edges of the final node of the resulting
-- 'OpenGraph'; every inner node encountered along the way is stored in the
-- graph's 'OpenMap'. The result is in 'IO' because keys for expressions are
-- created by an 'IO' action.
recoverSharing :: forall a z. BasicVector z => [Term a z] -> IO (OpenGraph a z)
recoverSharing xs = do
  (finalNode, graph) <- runTreeBuilder (goEdges xs)
  return (OpenGraph finalNode graph)