{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Downhill.Internal.Graph.Graph
  (  -- * Graph type
    Graph (..), Node(..),
    SomeGraph (..),
    -- * Evaluate
    evalGraph,
    -- * Transpose
    transposeGraph,
    --transposeFwdGraph,
    --transposeBackGraph,
    -- * Construct
    unsafeFromOpenGraph,
  )
where

import Data.Either (partitionEithers)
import Data.Functor.Identity (Identity (Identity, runIdentity))
import Downhill.Internal.Graph.NodeMap
  ( IsNodeSet,
    NodeKey,
    NodeMap,
    KeyAndValue (KeyAndValue),
    SomeNodeMap (SomeNodeMap),
  )
import qualified Downhill.Internal.Graph.NodeMap as NodeMap
import Downhill.Internal.Graph.OpenGraph (OpenGraph (OpenGraph), OpenNode (OpenNode), OpenEdge (OpenEdge), OpenEndpoint (OpenSourceNode, OpenInnerNode))
import Downhill.Internal.Graph.Types (FwdFun (FwdFun), BackFun)
import Downhill.Linear.Expr (BasicVector (VecBuilder, sumBuilder))
import Prelude hiding (head, tail)
import GHC.Stack (callStack, prettyCallStack, HasCallStack)

-- | One end of an edge: either the graph's source node or an inner node.
--
-- Matching on 'SourceNode' refines the endpoint's value type @v@ to the
-- source type @a@.
data Endpoint s a v where
  -- | The source node; its value has type @a@.
  SourceNode :: Endpoint s a a
  -- | An inner node, referred to by its key in node set @s@.
  InnerNode :: NodeKey s v -> Endpoint s a v

-- | An edge as stored inside a 'Node': the edge label @e u v@ together with
-- its head endpoint (the @u@ end). The other endpoint is not stored; it is
-- implied by the node that holds the edge.
data Edge s e a v where
  Edge :: e u v -> Endpoint s a u -> Edge s e a v

{-| Inner node. This does not include the initial node. Contains a list
of ingoing edges.

The constructor packs the 'BasicVector' dictionary for the node's value
type @v@, so matching on 'Node' brings @BasicVector v@ into scope. -}
data Node s e a v where
  -- Explicit forall keeps the constructor's type-variable order (s e a v).
  Node :: forall s e a v. BasicVector v => [Edge s e a v] -> Node s e a v

-- | A graph over node set @s@ with edge labels @e@. The source node has
-- value type @a@ and the final node has value type @z@. The constructor
-- carries the 'BasicVector' dictionary for @a@.
data Graph s e a z = BasicVector a =>
  Graph
  { graphInnerNodes :: NodeMap s (Node s e a)
    -- ^ Every inner node; the source node is not among them.
  , graphFinalNode :: Node s e a z
    -- ^ The final node, whose value has type @z@.
  }

-- | A 'Graph' whose node set @s@ is hidden. The 'IsNodeSet' dictionary for
-- @s@ is packed alongside it.
data SomeGraph e a z where
  SomeGraph :: IsNodeSet s => Graph s e a z -> SomeGraph e a z

{- `Edge` stores the head endpoint only. `AnyEdge` stores both endpoints.

The tail is an 'Endpoint' whose 'SourceNode' has type @z@; the head is one
whose 'SourceNode' has type @a@. Swapping the two (and flipping the label)
turns an @AnyEdge s e a z@ into an @AnyEdge s e' z a@. -}
data AnyEdge s e a z = forall u v.
  AnyEdge
  { _edgeTail :: Endpoint s z v
  , _edgeLabel :: e u v
  , _edgeHead :: Endpoint s a u
  }

-- | Forward mode evaluation.
--
-- The source node takes the value @dz@. Each node's value combines, via
-- 'sumBuilder', the contributions of its ingoing edges. The result is the
-- value of the final node.
evalGraph :: forall s x z. Graph s FwdFun z x -> z -> x
evalGraph (Graph nodes finalNode) dz = evalNode finalNode
  where
    -- Value of every inner node. 'evalParent' looks values up in this very
    -- map, so the definition is recursive.
    -- NOTE(review): this relies on 'NodeMap.map' not forcing its elements;
    -- confirm NodeMap is lazy in its values.
    innerValues :: NodeMap s Identity
    innerValues = NodeMap.map (Identity . evalNode) nodes

    -- Value at an endpoint: the input for the source node, otherwise the
    -- entry of 'innerValues' for that key.
    evalParent :: forall v. Endpoint s z v -> v
    evalParent endpoint = case endpoint of
      SourceNode -> dz
      InnerNode key -> runIdentity (NodeMap.lookup innerValues key)

    -- Apply the edge's forward function to the value at its stored endpoint.
    evalEdge :: forall v. Edge s FwdFun z v -> VecBuilder v
    evalEdge (Edge (FwdFun fwd) endpoint) = fwd (evalParent endpoint)

    -- Combine the builders produced by all ingoing edges of a node.
    evalNode :: forall v. Node s FwdFun z v -> v
    evalNode (Node edges) = sumBuilder (mconcat (map evalEdge edges))

-- | Turn the ingoing edges of the node named @name@ into 'AnyEdge's whose
-- tail is that node ('InnerNode' @name@).
nodeEdges :: forall s f a z x. NodeKey s x -> Node s f a x -> [AnyEdge s f a z]
nodeEdges name (Node xs) = map withTail xs
  where
    withTail :: Edge s f a x -> AnyEdge s f a z
    withTail (Edge label headEnd) = AnyEdge (InnerNode name) label headEnd

-- | List every edge of the graph with both endpoints.
--
-- The final node's edges come first, with 'SourceNode' as their tail (on the
-- @z@ side, 'SourceNode' has type @z@, i.e. it stands for the final node).
-- They are followed by the edges of every inner node, whose tail is that
-- node's key.
allGraphEdges :: forall s f a z. Graph s f a z -> [AnyEdge s f a z]
allGraphEdges (Graph innerNodes (Node finalNodeEdges)) =
  map fromFinal finalNodeEdges
    ++ concat (NodeMap.toListWith nodeEdges innerNodes)
  where
    fromFinal :: Edge s f a z -> AnyEdge s f a z
    fromFinal (Edge label headEnd) = AnyEdge SourceNode label headEnd

-- | Classify an edge by its tail.
--
-- An edge whose tail is 'SourceNode' becomes @Left@, with the tail dropped.
-- An edge whose tail is an inner node becomes @Right@, tagged with that
-- node's key.
sortByTail ::
  forall s f da dz.
  AnyEdge s f da dz ->
  Either (Edge s f da dz) (KeyAndValue s (Edge s f da))
sortByTail (AnyEdge tailEnd label headEnd) =
  case tailEnd of
    SourceNode -> Left edge
    InnerNode key -> Right (KeyAndValue key edge)
  where
    -- The stored edge keeps only the label and head endpoint.
    edge = Edge label headEnd

-- | Reverse an edge: swap its tail and head, and flip its label with
-- @flipLabel@ so the label's type matches the new direction.
flipAnyEdge :: (forall u v. f u v -> g v u) -> AnyEdge s f a z -> AnyEdge s g z a
flipAnyEdge flipLabel edge =
  case edge of
    AnyEdge tailEnd label headEnd -> AnyEdge headEnd (flipLabel label) tailEnd

{- | Evidence that a node's value type is a 'BasicVector'.

The constraint is needed to construct a 'Node'. A @NodeMap s NodeDict@
therefore serves as a list of all nodes, each carrying its dictionary.
-}
data NodeDict x where
  NodeDict :: BasicVector x => NodeDict x

-- | Build a node map in which every listed node has no ingoing edges yet.
emptyNodeMap :: forall s e z. NodeMap s NodeDict -> NodeMap s (Node s e z)
emptyNodeMap = NodeMap.map withNoEdges
  where
    -- Matching on 'NodeDict' supplies the 'BasicVector' dictionary that
    -- 'Node' needs.
    withNoEdges :: forall x. NodeDict x -> Node s e z x
    withNoEdges NodeDict = Node []

-- | Assemble a graph from its node list and a flat list of edges.
--
-- Edges whose tail is 'SourceNode' become the ingoing edges of the final
-- node. Every other edge is added to the inner node named by its tail.
-- Nodes in @nodes@ that receive no edge keep an empty edge list.
edgeListToGraph ::
  forall s e a z.
  (IsNodeSet s, BasicVector a, BasicVector z) =>
  NodeMap s NodeDict ->
  [AnyEdge s e z a] ->
  Graph s e z a
edgeListToGraph nodes flippedEdges = Graph innerNodes finalNode
  where
    -- (edges into the final node, edges into inner nodes tagged by key)
    sorted :: ([Edge s e z a], [KeyAndValue s (Edge s e z)])
    sorted = partitionEithers (map sortByTail flippedEdges)

    finalNode :: Node s e z a
    finalNode = Node (fst sorted)

    -- foldr with prepending keeps each node's edges in input-list order.
    innerNodes :: NodeMap s (Node s e z)
    innerNodes = foldr addEdge (emptyNodeMap nodes) (snd sorted)

    addEdge ::
      KeyAndValue s (Edge s e z) ->
      NodeMap s (Node s e z) ->
      NodeMap s (Node s e z)
    addEdge (KeyAndValue key edge) = NodeMap.adjust (consEdge edge) key

    consEdge :: forall x. Edge s e z x -> Node s e z x -> Node s e z x
    consEdge newEdge (Node edges) = Node (newEdge : edges)
  
-- | Keep only the 'BasicVector' evidence of each inner node, discarding the
-- edges.
graphNodes :: Graph s f da dz -> NodeMap s NodeDict
graphNodes graph = NodeMap.map nodeDict (graphInnerNodes graph)
  where
    -- Matching on 'Node' exposes the 'BasicVector' dictionary it carries.
    nodeDict :: Node t e b v -> NodeDict v
    nodeDict (Node _) = NodeDict

-- | Reverse edges. Turns reverse mode evaluation into forward mode.
--
-- Each edge has its endpoints swapped and its label flipped with
-- @flipEdge@. The set of inner nodes stays the same. The source and final
-- nodes exchange roles: the result is a @Graph s g z a@.
transposeGraph ::
  forall s f g a z.
  IsNodeSet s =>
  (forall u v. f u v -> g v u) ->
  Graph s f a z ->
  Graph s g z a
transposeGraph flipEdge graph =
  case graph of
    -- Matching on 'Graph' and on the final 'Node' brings @BasicVector a@
    -- and @BasicVector z@ into scope; 'edgeListToGraph' needs both.
    Graph _ (Node _) ->
      edgeListToGraph
        (graphNodes graph)
        (flipAnyEdge flipEdge <$> allGraphEdges graph)

-- | Relabel every edge of a graph with @relabel@, keeping edge direction and
-- endpoints unchanged. The leading underscore silences GHC's
-- unused-binding warning.
_mapEdges :: forall s f g a z. (forall u v. f u v -> g u v) -> Graph s f a z -> Graph s g a z
_mapEdges relabel (Graph inner final) =
  Graph (NodeMap.map mapNode inner) (mapNode final)
  where
    mapNode :: forall v. Node s f a v -> Node s g a v
    mapNode (Node edges) = Node (map mapEdge edges)

    mapEdge :: forall v. Edge s f a v -> Edge s g a v
    mapEdge (Edge label endpoint) = Edge (relabel label) endpoint

-- | Convert open nodes into a 'Graph' over node set @s@.
--
-- Every 'OpenInnerNode' key is resolved against @m@ with
-- 'NodeMap.tryLookup'. If a key is not found, forcing that endpoint calls
-- 'error' with the current call stack appended to the message.
unsafeConstructGraph ::
  forall s a v.
  (IsNodeSet s, BasicVector a, HasCallStack) =>
  NodeMap s (OpenNode a) ->
  OpenNode a v ->
  Graph s BackFun a v
unsafeConstructGraph m x = Graph (NodeMap.map convertNode m) (convertNode x)
  where
    convertNode :: forall x. OpenNode a x -> Node s BackFun a x
    convertNode (OpenNode openEdges) = Node (map convertEdge openEdges)

    convertEdge :: forall x. OpenEdge a x -> Edge s BackFun a x
    convertEdge (OpenEdge f endpoint) = Edge f (convertEndpoint endpoint)

    convertEndpoint :: forall u. OpenEndpoint a u -> Endpoint s a u
    convertEndpoint OpenSourceNode = SourceNode
    convertEndpoint (OpenInnerNode key) =
      maybe
        (error ("Downhill: invalid key in constructGraph\n" ++ prettyCallStack callStack))
        (InnerNode . fst)
        (NodeMap.tryLookup m key)

-- | Convert an 'OpenGraph' into a 'Graph' whose node set is hidden inside
-- 'SomeGraph'.
--
-- Will crash (via 'error') if the graph has invalid keys, i.e. if an edge
-- refers to a node that is missing from the graph's node map.
unsafeFromOpenGraph :: (BasicVector a, HasCallStack) => OpenGraph a v -> SomeGraph BackFun a v
unsafeFromOpenGraph (OpenGraph finalNode openMap) =
  case NodeMap.fromOpenMap openMap of
    SomeNodeMap nodeMap -> SomeGraph (unsafeConstructGraph nodeMap finalNode)