{-| Types of nodes and edges of the computational graph.

Parameters:

  * @p@ - the parent node; may be either 'OpenKey' or 'NodeKey'

  * @e@ - the edge type

  * @a@ - the type of the initial node of the expression

  * @v@ - the type of the node.
-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
module Downhill.Internal.Graph.Types
(
  -- * Linear functions
  BackFun(..), FwdFun(..),
  flipBackFun, flipFwdFun
)
 where

import Downhill.Linear.Expr (BasicVector (VecBuilder))


-- | Edge type for backward mode evaluation
-- | Edge type for backward mode evaluation.
--
-- Wraps a function that takes a value of type @v@ and produces a
-- @'VecBuilder' u@. Use 'flipBackFun' to view the same function as a
-- 'FwdFun' with the type parameters swapped.
newtype BackFun u v = BackFun {unBackFun :: v -> VecBuilder u}

-- | Edge type for forward mode evaluation
-- | Edge type for forward mode evaluation.
--
-- Wraps a function that takes a value of type @u@ and produces a
-- @'VecBuilder' v@. Use 'flipFwdFun' to view the same function as a
-- 'BackFun' with the type parameters swapped.
newtype FwdFun u v = FwdFun {unFwdFun :: u -> VecBuilder v}

-- | Reinterpret a backward edge as a forward edge with the type
-- parameters swapped.
--
-- @'BackFun' u v@ and @'FwdFun' v u@ both wrap a function of type
-- @v -> 'VecBuilder' u@. Only the newtype wrapper changes; the
-- underlying function is passed through untouched.
flipBackFun :: BackFun u v -> FwdFun v u
flipBackFun (BackFun f) = FwdFun f

-- | Reinterpret a forward edge as a backward edge with the type
-- parameters swapped.
--
-- @'FwdFun' u v@ and @'BackFun' v u@ both wrap a function of type
-- @u -> 'VecBuilder' v@. Only the newtype wrapper changes; the
-- underlying function is passed through untouched.
flipFwdFun :: FwdFun u v -> BackFun v u
flipFwdFun (FwdFun f) = BackFun f