{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Downhill.Internal.Graph.Graph
(
Graph (..), Node(..),
SomeGraph (..),
evalGraph,
transposeGraph,
unsafeFromOpenGraph,
)
where
import Data.Either (partitionEithers)
import Data.Functor.Identity (Identity (Identity, runIdentity))
import Downhill.Internal.Graph.NodeMap
( IsNodeSet,
NodeKey,
NodeMap,
KeyAndValue (KeyAndValue),
SomeNodeMap (SomeNodeMap),
)
import qualified Downhill.Internal.Graph.NodeMap as NodeMap
import Downhill.Internal.Graph.OpenGraph (OpenGraph (OpenGraph), OpenNode (OpenNode), OpenEdge (OpenEdge), OpenEndpoint (OpenSourceNode, OpenInnerNode))
import Downhill.Internal.Graph.Types (FwdFun (FwdFun), BackFun)
import Downhill.Linear.Expr (BasicVector (VecBuilder, sumBuilder))
import Prelude hiding (head, tail)
import GHC.Stack (callStack, prettyCallStack, HasCallStack)
-- | The endpoint of an edge: either the graph's source node, whose value
-- has the source type @a@, or an inner node identified by its key.
data Endpoint s a v where
  -- | The source node. Its type index is forced to equal @a@.
  SourceNode :: Endpoint s a a
  -- | An inner node, looked up by key in the graph's 'NodeMap'.
  InnerNode :: NodeKey s v -> Endpoint s a v
-- | An edge stored in a node of type @v@. It pairs a label of type @e u v@
-- with the endpoint at its @u@ end. The type @u@ is existential: each edge
-- of the same node may have a different one.
data Edge s e a v where
  Edge :: e u v -> Endpoint s a u -> Edge s e a v
-- | A node of type @v@ together with the list of edges stored in it.
-- The constructor captures the 'BasicVector' dictionary for @v@, so
-- pattern matching on 'Node' makes that instance available.
data Node s e a v where
  Node :: BasicVector v => [Edge s e a v] -> Node s e a v
-- | A graph over the node set @s@, with edge labels @e@, source type @a@
-- and final node type @z@. The constructor captures the 'BasicVector'
-- dictionary for the source type @a@.
data Graph s e a z = BasicVector a =>
  Graph
  { -- | Every inner node of the graph, keyed by the node set @s@.
    graphInnerNodes :: NodeMap s (Node s e a),
    -- | The final node, of type @z@.
    graphFinalNode :: Node s e a z
  }
-- | A 'Graph' whose node set @s@ is hidden. The constructor keeps the
-- 'IsNodeSet' dictionary for that hidden node set.
data SomeGraph e a z where
  SomeGraph :: IsNodeSet s => Graph s e a z -> SomeGraph e a z
-- | An edge detached from the node that stores it, so that the edges of
-- every node can be collected into a single list.
--
-- The tail is the node that stored the edge. Because the tail endpoint is
-- indexed by @z@, 'SourceNode' in tail position stands for the graph's
-- final node (see 'allGraphEdges').
--
-- The field names only document the layout: @u@ and @v@ are existential,
-- so the selectors cannot be used as functions.
data AnyEdge s e a z = forall u v.
  AnyEdge
  { -- | The node the edge is stored in.
    _edgeTail :: Endpoint s z v,
    -- | The edge label.
    _edgeLabel :: e u v,
    -- | The endpoint at the label's @u@ end.
    _edgeHead :: Endpoint s a u
  }
-- | Evaluate a graph of forward functions on the input @dz@.
--
-- A node's value is 'sumBuilder' applied to the combined 'VecBuilder's of
-- its edges. An edge applies its forward function to the value at its
-- endpoint, which is either the input itself ('SourceNode') or the value
-- of another inner node.
--
-- Inner-node values are read from the single shared table 'innerValues'
-- rather than rebuilt for each edge. NOTE(review): this avoids
-- recomputation only if 'NodeMap.map' is lazy in its values; confirm in
-- "Downhill.Internal.Graph.NodeMap". The graph presumably must be acyclic,
-- otherwise evaluation would not terminate.
evalGraph :: forall s x z. Graph s FwdFun z x -> z -> x
evalGraph (Graph nodes finalNode) dz = evalNode finalNode
  where
    -- Value at an endpoint: the input, or the tabulated inner-node value.
    evalParent :: forall v. Endpoint s z v -> v
    evalParent SourceNode = dz
    evalParent (InnerNode nodeName) =
      runIdentity (NodeMap.lookup innerValues nodeName)

    -- Push the endpoint's value through the edge's forward function.
    evalEdge :: Edge s FwdFun z v -> VecBuilder v
    evalEdge (Edge (FwdFun f) parent) = f (evalParent parent)

    -- Combine the contributions of all edges of the node.
    evalNode :: Node s FwdFun z v -> v
    evalNode (Node xs) = sumBuilder (mconcat (map evalEdge xs))

    -- One value per inner node. Every 'InnerNode' lookup reads from here.
    innerValues :: NodeMap s Identity
    innerValues = NodeMap.map (Identity . evalNode) nodes
-- | Detach the edges stored in the inner node @name@. Each resulting
-- 'AnyEdge' records @InnerNode name@ as its tail.
nodeEdges :: forall s f a z x. NodeKey s x -> Node s f a x -> [AnyEdge s f a z]
nodeEdges name (Node xs) = map detach xs
  where
    detach :: Edge s f a x -> AnyEdge s f a z
    detach (Edge label headEnd) = AnyEdge (InnerNode name) label headEnd
-- | Every edge of the graph as a flat list. The edges stored in the final
-- node come first, then the edges of each inner node, in the order given
-- by 'NodeMap.toListWith'.
allGraphEdges :: forall s f a z. Graph s f a z -> [AnyEdge s f a z]
allGraphEdges (Graph innerNodes (Node finalNodeEdges)) = finalEdges ++ innerEdges
  where
    innerEdges :: [AnyEdge s f a z]
    innerEdges = concat (NodeMap.toListWith nodeEdges innerNodes)

    finalEdges :: [AnyEdge s f a z]
    finalEdges = map wrapFinalEdge finalNodeEdges

    -- The tail endpoint of an 'AnyEdge' is indexed by @z@, so 'SourceNode'
    -- in tail position marks edges stored in the final node.
    wrapFinalEdge :: Edge s f a z -> AnyEdge s f a z
    wrapFinalEdge (Edge label headEnd) = AnyEdge SourceNode label headEnd
-- | Classify a detached edge by its tail and re-attach it as an 'Edge'.
--
-- * Tail is 'SourceNode': the result is 'Left' the edge (the tail type is
--   then @dz@).
-- * Tail is an inner node: the result is 'Right' the edge, paired with that
--   node's key.
sortByTail ::
  forall s f da dz.
  AnyEdge s f da dz ->
  Either (Edge s f da dz) (KeyAndValue s (Edge s f da))
sortByTail (AnyEdge tailEnd label headEnd) = case tailEnd of
  SourceNode -> Left (Edge label headEnd)
  InnerNode key -> Right (KeyAndValue key (Edge label headEnd))
-- | Reverse a detached edge: swap its tail and head, and flip its label
-- with @flipLabel@. The graph's source and final types trade places in the
-- result.
flipAnyEdge :: (forall u v. f u v -> g v u) -> AnyEdge s f a z -> AnyEdge s g z a
flipAnyEdge flipLabel (AnyEdge tailEnd label headEnd) =
  AnyEdge headEnd (flipLabel label) tailEnd
-- | A captured 'BasicVector' dictionary for @x@ with no other payload.
-- It remembers a node's type while its edges are discarded.
data NodeDict x where
  NodeDict :: BasicVector x => NodeDict x
-- | Turn every entry of a 'NodeDict' map into a node with no edges.
emptyNodeMap :: forall s e z. NodeMap s NodeDict -> NodeMap s (Node s e z)
emptyNodeMap = NodeMap.map emptyNode
  where
    -- Matching on 'NodeDict' supplies the 'BasicVector' instance that the
    -- 'Node' constructor requires.
    emptyNode :: forall x. NodeDict x -> Node s e z x
    emptyNode NodeDict = Node []
-- | Build a graph from a node set and a flat list of detached edges.
--
-- * Edges whose tail is 'SourceNode' become the edges of the final node.
-- * Every other edge is attached to the inner node named by its tail.
-- * Nodes that receive no edge stay empty.
--
-- Because of the right fold with prepending, each node keeps its edges in
-- their order in the input list.
edgeListToGraph ::
  forall s e a z.
  (IsNodeSet s, BasicVector a, BasicVector z) =>
  NodeMap s NodeDict ->
  [AnyEdge s e z a] ->
  Graph s e z a
edgeListToGraph nodes flippedEdges = Graph innerNodes (Node initialEdges)
  where
    initialEdges :: [Edge s e z a]
    innerEdges :: [KeyAndValue s (Edge s e z)]
    (initialEdges, innerEdges) = partitionEithers (map sortByTail flippedEdges)

    innerNodes :: NodeMap s (Node s e z)
    innerNodes = foldr prependToMap (emptyNodeMap nodes) innerEdges

    prependToMap ::
      KeyAndValue s (Edge s e z) ->
      NodeMap s (Node s e z) ->
      NodeMap s (Node s e z)
    prependToMap (KeyAndValue key edge) = NodeMap.adjust (prependToNode edge) key

    -- The node type @x@ is existential in the 'KeyAndValue' pattern. Passing
    -- the edge explicitly lets this signature quantify @x@ itself, instead
    -- of mentioning a pattern-bound type variable that cannot be named here.
    prependToNode :: forall x. Edge s e z x -> Node s e z x -> Node s e z x
    prependToNode edge (Node edges) = Node (edge : edges)
-- | The graph's inner node set with all edges discarded. Only each node's
-- 'BasicVector' dictionary is kept.
graphNodes :: Graph s f da dz -> NodeMap s NodeDict
graphNodes (Graph env _) = NodeMap.map forgetEdges env
  where
    -- Matching on 'Node' brings the 'BasicVector' instance into scope for
    -- 'NodeDict'.
    forgetEdges :: Node s f da dv -> NodeDict dv
    forgetEdges (Node _) = NodeDict
-- | Reverse every edge of the graph, flipping edge labels with @flipEdge@.
-- The input's final node (type @z@) becomes the result's source, and the
-- input's source (type @a@) becomes the result's final node. The inner
-- node set is unchanged.
transposeGraph ::
  forall s f g a z.
  IsNodeSet s =>
  (forall u v. f u v -> g v u) ->
  Graph s f a z ->
  Graph s g z a
-- Matching both 'Graph' and the final 'Node' brings the 'BasicVector'
-- dictionaries for @a@ and @z@ into scope; 'edgeListToGraph' needs both.
transposeGraph flipEdge graph@(Graph _ (Node _)) =
  edgeListToGraph (graphNodes graph) flippedEdges
  where
    edges :: [AnyEdge s f a z]
    edges = allGraphEdges graph

    flippedEdges :: [AnyEdge s g z a]
    flippedEdges = map (flipAnyEdge flipEdge) edges
-- | Relabel every edge of the graph, inner and final nodes alike. The
-- graph's structure is unchanged.
_mapEdges ::
  forall s f g a z.
  (forall u v. f u v -> g u v) ->
  Graph s f a z ->
  Graph s g a z
_mapEdges relabel (Graph inner final) = Graph (NodeMap.map go inner) (go final)
  where
    go :: Node s f a v -> Node s g a v
    go (Node xs) = Node (map goEdge xs)

    goEdge :: Edge p f a x -> Edge p g a x
    goEdge (Edge label endpoint) = Edge (relabel label) endpoint
-- | Convert open nodes, which reference each other by 'OpenKey', into a
-- 'Graph' over the node set @s@. Each open key is resolved with
-- 'NodeMap.tryLookup' against @m@.
--
-- Unsafe: calls 'error', with the caller's call stack, if an endpoint
-- refers to a key that is not in @m@.
unsafeConstructGraph ::
  forall s a v.
  (IsNodeSet s, BasicVector a, HasCallStack) =>
  NodeMap s (OpenNode a) ->
  OpenNode a v ->
  Graph s BackFun a v
unsafeConstructGraph m x = Graph (NodeMap.map mkExpr m) (mkExpr x)
  where
    mkExpr :: forall x. OpenNode a x -> Node s BackFun a x
    mkExpr (OpenNode terms) = Node (map mkTerm terms)

    mkTerm :: forall x. OpenEdge a x -> Edge s BackFun a x
    mkTerm (OpenEdge f x') = Edge f (mkArg x')

    mkArg :: forall u. OpenEndpoint a u -> Endpoint s a u
    mkArg OpenSourceNode = SourceNode
    mkArg (OpenInnerNode key) = case NodeMap.tryLookup m key of
      Just (key', _value) -> InnerNode key'
      Nothing ->
        -- 'callStack' here is the one given to 'unsafeConstructGraph'.
        error ("Downhill: invalid key in constructGraph\n" ++ prettyCallStack callStack)
-- | Convert an 'OpenGraph' into a graph over a fresh, hidden node set.
--
-- Unsafe for the same reason as 'unsafeConstructGraph': an inner key that
-- is missing from the open map triggers 'error'. The 'IsNodeSet'
-- dictionary for the hidden node set comes from matching 'SomeNodeMap'.
unsafeFromOpenGraph ::
  forall a v.
  (BasicVector a, HasCallStack) =>
  OpenGraph a v ->
  SomeGraph BackFun a v
unsafeFromOpenGraph (OpenGraph x m) =
  case NodeMap.fromOpenMap m of
    SomeNodeMap m' -> SomeGraph (unsafeConstructGraph m' x)