{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-full-laziness #-}
{-# OPTIONS_HADDOCK not-home #-}
module Numeric.AD.Internal.Kahn
( Kahn(..)
, Tape(..)
, partials
, partialArray
, partialMap
, derivative
, derivative'
, vgrad, vgrad'
, Grad(..)
, bind
, unbind
, unbindMap
, unbindWith
, unbindMapWithDefault
, primal
, var
, varId
) where
import Control.Monad.ST
import Control.Monad hiding (mapM)
import Control.Monad.Trans.State
import Data.List (foldl')
import Data.Array.ST
import Data.Array
import Data.IntMap (IntMap, fromListWith, findWithDefault)
import Data.Graph (Vertex, transposeG, Graph)
import Data.Number.Erf
import Data.Reify (reifyGraph, MuRef(..))
import qualified Data.Reify.Graph as Reified
import System.IO.Unsafe (unsafePerformIO)
import Data.Data (Data)
import Data.Typeable (Typeable)
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
-- | One node of a reverse-mode tape.
--
-- @a@ is the primal (value) type and @t@ is how a node refers to its
-- operands: a 'Kahn' value while the term is being built, and an 'Int'
-- vertex id once the term has been reified into a graph (see 'partials').
data Tape a t
  = Zero
    -- ^ the constant zero
  | Lift !a
    -- ^ a constant that does not depend on any variable
  | Var !a {-# UNPACK #-} !Int
    -- ^ an input variable: its value and its variable id
  | Binary !a a a t t
    -- ^ value, partial w.r.t. the first operand, partial w.r.t. the
    -- second operand, first operand, second operand
  | Unary !a a t
    -- ^ value, partial w.r.t. the operand, operand
  deriving (Show, Data, Typeable)
-- | A reverse-mode AD value: a 'Tape' node whose operands are themselves
-- 'Kahn' values. 'partials' turns such a term into an explicit graph with
-- 'reifyGraph' before propagating sensitivities.
newtype Kahn a = Kahn (Tape a (Kahn a)) deriving (Show, Typeable)
-- | Lets "Data.Reify" observe a 'Kahn' term. Each node is mapped to its
-- 'Tape' shape, with every operand replaced by whatever the traversal
-- function returns for it ('partials' consumes the result as 'Int' ids).
-- Leaf nodes have no operands, so the traversal function is not called.
instance MuRef (Kahn a) where
  type DeRef (Kahn a) = Tape a

  mapDeRef _ (Kahn Zero)                     = pure Zero
  mapDeRef _ (Kahn (Lift a))                 = pure (Lift a)
  mapDeRef _ (Kahn (Var a v))                = pure (Var a v)
  mapDeRef f (Kahn (Binary a dadb dadc b c)) = Binary a dadb dadc <$> f b <*> f c
  mapDeRef f (Kahn (Unary a dadb b))         = Unary a dadb <$> f b
-- | 'Kahn' values form a 'Mode' over their primal type. Constants are
-- recognised structurally ('Zero' and 'Lift' nodes). Scaling by a scalar
-- goes through 'lift1', with the scale factor as the recorded partial.
instance Num a => Mode (Kahn a) where
  type Scalar (Kahn a) = a

  isKnownZero (Kahn Zero) = True
  isKnownZero _           = False

  asKnownConstant (Kahn Zero)     = Just 0
  asKnownConstant (Kahn (Lift n)) = Just n
  asKnownConstant _               = Nothing

  isKnownConstant (Kahn Zero)     = True
  isKnownConstant (Kahn (Lift _)) = True
  isKnownConstant _               = False

  auto a = Kahn (Lift a)
  zero   = Kahn Zero

  -- d(a * x)/dx = a, d(x * b)/dx = b, d(x / b)/dx = recip b
  a *^ b = lift1 (a *) (\_ -> auto a) b
  a ^* b = lift1 (* b) (\_ -> auto b) a
  a ^/ b = lift1 (/ b) (\_ -> auto (recip b)) a
-- | Addition of two 'Kahn' terms; the partial w.r.t. each operand is 1.
(<+>) :: Num a => Kahn a -> Kahn a -> Kahn a
(<+>) = binary (+) 1 1
-- | The value carried by a 'Kahn' term, ignoring its derivative
-- information. A 'Zero' node has value 0; every other node stores its
-- value in its first field.
primal :: Num a => Kahn a -> a
primal (Kahn Zero)               = 0
primal (Kahn (Lift a))           = a
primal (Kahn (Var a _))          = a
primal (Kahn (Binary a _ _ _ _)) = a
primal (Kahn (Unary a _ _))      = a
-- | Derivative-recording primitives for 'Kahn'.
--
-- Operands that are known constants ('Zero' / 'Lift') are folded away:
-- * an operation whose operands are all constants yields a 'Lift';
-- * a binary operation with exactly one constant operand records only
--   a 'Unary' node for the other operand;
-- * otherwise a 'Unary' / 'Binary' node is recorded.
instance Num a => Jacobian (Kahn a) where
  type D (Kahn a) = Id a

  unary f _ (Kahn Zero)     = Kahn (Lift (f 0))
  unary f _ (Kahn (Lift a)) = Kahn (Lift (f a))
  unary f (Id dadb) b       = Kahn (Unary (f (primal b)) dadb b)

  lift1 f df b = unary f (df (Id pb)) b where
    pb = primal b

  -- Like lift1, but the derivative function also receives the result.
  lift1_ f df b = unary (const a) (df (Id a) (Id pb)) b where
    pb = primal b
    a = f pb

  -- Both operands constant.
  binary f _ _ (Kahn Zero)     (Kahn Zero)     = Kahn (Lift (f 0 0))
  binary f _ _ (Kahn Zero)     (Kahn (Lift c)) = Kahn (Lift (f 0 c))
  binary f _ _ (Kahn (Lift b)) (Kahn Zero)     = Kahn (Lift (f b 0))
  binary f _ _ (Kahn (Lift b)) (Kahn (Lift c)) = Kahn (Lift (f b c))

  -- Exactly one operand constant: only the other one goes on the tape.
  binary f _ (Id dadc) (Kahn Zero)     c = Kahn (Unary (f 0 (primal c)) dadc c)
  binary f _ (Id dadc) (Kahn (Lift b)) c = Kahn (Unary (f b (primal c)) dadc c)
  binary f (Id dadb) _ b (Kahn Zero)     = Kahn (Unary (f (primal b) 0) dadb b)
  binary f (Id dadb) _ b (Kahn (Lift c)) = Kahn (Unary (f (primal b) c) dadb b)

  -- General case.
  binary f (Id dadb) (Id dadc) b c =
    Kahn (Binary (f (primal b) (primal c)) dadb dadc b c)

  lift2 f df b c = binary f dadb dadc b c where
    (dadb, dadc) = df (Id (primal b)) (Id (primal c))

  -- Like lift2, but the derivative function also receives the result.
  lift2_ f df b c = binary (\_ _ -> a) dadb dadc b c where
    pb = primal b
    pc = primal c
    a = f pb pc
    (dadb, dadc) = df (Id a) (Id pb) (Id pc)
-- | Multiplication: the partial w.r.t. @x@ is @y@ and w.r.t. @y@ is @x@.
mul :: Num a => Kahn a -> Kahn a -> Kahn a
mul = lift2 (*) (\x y -> (y, x))
#define HEAD (Kahn a)
#include <instances.h>
-- | Sum of the contributions of every 'Var' node reported by 'partials',
-- regardless of variable id. For a term in a single input variable this
-- is its ordinary derivative.
derivative :: Num a => Kahn a -> a
derivative = sum . map snd . partials
{-# INLINE derivative #-}
-- | The value of a term paired with its 'derivative'.
derivative' :: Num a => Kahn a -> (a, a)
derivative' r = (primal r, derivative r)
{-# INLINE derivative' #-}
-- | Push the sensitivity accumulated at vertex @v@ down to its operands.
--
-- Each operand's entry in @ss@ is incremented by
-- (partial w.r.t. that operand) * (sensitivity of @v@).
-- 'Zero', 'Lift' and 'Var' nodes have no operands and leave @ss@ unchanged.
backPropagate :: Num a => (Vertex -> (Tape a Int, Int, [Int])) -> STArray s Int a -> Vertex -> ST s ()
backPropagate vmap ss v = case node of
  Unary _ g b -> do
    da <- readArray ss i
    db <- readArray ss b
    writeArray ss b (db + g*da)
  Binary _ gb gc b c -> do
    da <- readArray ss i
    db <- readArray ss b
    writeArray ss b (db + gb*da)
    -- Read c only after b has been written, so that when both operands
    -- are the same vertex (b == c) both contributions are accumulated.
    dc <- readArray ss c
    writeArray ss c (dc + gc*da)
  _ -> return ()
  where
    (node, i, _) = vmap v
-- | Topological order of an acyclic graph: a vertex is emitted only after
-- every vertex with an edge into it.
--
-- The traversal starts from the vertices with no incoming edges. A vertex
-- becomes ready once all of its sources have been emitted, and ready
-- vertices are pushed onto the front of the work list. A vertex lying on a
-- cycle never becomes ready and is therefore never emitted.
topSortAcyclic :: Graph -> [Vertex]
topSortAcyclic g = reverse $ runST $ do
    -- del ! v becomes True once v has been emitted.
    del <- newArray (bounds g) False :: ST s (STUArray s Int Bool)
    let tg = transposeG g
        -- Vertices with no incoming edges in g.
        starters = [ n | (n, []) <- assocs tg ]
        loop [] rs = return rs
        loop (n:ns) rs = do
            writeArray del n True
            let -- Prepend to the work list those targets of n that are now ready.
                add [] = return ns
                add (m:ms) = do
                    b <- ok (tg ! m)
                    ms' <- add ms
                    return $ if b then m : ms' else ms'
                -- True when every listed vertex has already been emitted.
                ok [] = return True
                ok (x:xs) = do
                    b <- readArray del x
                    if b then ok xs else return False
            ns' <- add (g ! n)
            loop ns' (n : rs)
    loop starters []
{-# SPECIALIZE partials :: Kahn Double -> [(Int, Double)] #-}
-- | The contribution of every 'Var' node in the reified term to its
-- derivative, as @(variable id, partial)@ pairs.
--
-- Variable ids are not guaranteed to be unique in the result: several
-- 'Var' nodes may report the same id. For that reason 'partialArray' and
-- 'partialMap' sum the contributions per id.
partials :: forall a. Num a => Kahn a -> [(Int, a)]
partials tape =
  [ let v = sensitivities ! ix in v `seq` (ident, v)
  | (ix, Var _ ident) <- xs
  ]
  where
    -- reifyGraph lives in IO; its result is treated as pure here.
    Reified.Graph xs start = unsafePerformIO $ reifyGraph tape

    -- Operand adjacency and node lookup, both indexed by vertex id.
    g = array xsBounds [ (i, successors t) | (i, t) <- xs ]
    vertexMap = array xsBounds xs
    vmap i = (vertexMap ! i, i, [])
    xsBounds = sbounds xs

    -- Seed the output vertex with sensitivity 1, then propagate in
    -- topological order so that every node is handled before its operands.
    sensitivities = runSTArray $ do
      ss <- newArray xsBounds 0
      writeArray ss start 1
      forM_ (topSortAcyclic g) $
        backPropagate vmap ss
      return ss

    -- Smallest and largest vertex id, computed in one strict pass.
    sbounds ((a,_):as) = foldl' step (a,a) as
      where
        step (lo,hi) (b,_) =
          let lo' = min lo b
              hi' = max hi b
          in lo' `seq` hi' `seq` (lo', hi')
    -- Unreachable in practice: the reified graph contains at least the
    -- start (output) vertex.
    sbounds [] = error "Numeric.AD.Internal.Kahn.partials: reified graph has no nodes"

    -- Distinct operand vertices of a node.
    successors :: Tape a Int -> [Int]
    successors (Unary _ _ b) = [b]
    successors (Binary _ _ _ b c) = if b == c then [b] else [b,c]
    successors _ = []
-- | The 'partials' of a term gathered into an array over @vbounds@,
-- indexed by variable id. Contributions with the same id are summed, and
-- ids without contributions are 0. As with 'accumArray', a variable id
-- outside @vbounds@ is an error.
partialArray :: Num a => (Int, Int) -> Kahn a -> Array Int a
partialArray vbounds tape = accumArray (+) 0 vbounds (partials tape)
{-# INLINE partialArray #-}
-- | The 'partials' of a term gathered into an 'IntMap' keyed by variable
-- id, summing contributions that share an id.
partialMap :: Num a => Kahn a -> IntMap a
partialMap = fromListWith (+) . partials
{-# INLINE partialMap #-}
-- | Relates a curried function of 'Kahn' arguments (@i@) to the shapes of
-- its gradient functions. @o@ takes the plain arguments and yields the list
-- of partials; @o'@ yields the value paired with that list.
--
-- * 'pack' applies @i@ to a list of arguments.
-- * 'unpack' and 'unpack'' turn a list-consuming function back into a
--   curried one.
class Num a => Grad i o o' a | i -> a o o', o -> a i o', o' -> a i o where
  pack :: i -> [Kahn a] -> Kahn a
  unpack :: ([a] -> [a]) -> o
  unpack' :: ([a] -> (a, [a])) -> o'
-- | Base case: no arguments remain. 'pack' ignores the (empty) argument
-- list, and 'unpack' / 'unpack'' close the accumulated difference list.
instance Num a => Grad (Kahn a) [a] (a, [a]) a where
  pack i _ = i
  unpack f = f []
  unpack' f = f []
-- | Inductive case: one more argument. 'pack' consumes the head of the
-- argument list, and 'unpack' / 'unpack'' record one plain argument in the
-- difference list before recursing.
instance Grad i o o' a => Grad (Kahn a -> i) (a -> o) (a -> o') a where
  pack f (a:as) = pack (f a) as
  pack _ [] = error "Grad.pack: logic error"
  unpack f a = unpack (f . (a:))
  unpack' f a = unpack' (f . (a:))
-- | Gradient of a curried function of 'Kahn' arguments, returned as a
-- curried function of plain arguments that produces the list of partials.
vgrad :: Grad i o o' a => i -> o
vgrad i = unpack (unsafeGrad (pack i)) where
  -- Bind the arguments to fresh variables, evaluate, and read each
  -- variable's partial back out.
  unsafeGrad f as = unbind vs (partialArray bds $ f vs) where
    (vs,bds) = bind as
-- | Like 'vgrad', but the resulting function also returns the value of the
-- original function, paired with the list of partials.
vgrad' :: Grad i o o' a => i -> o'
vgrad' i = unpack' (unsafeGrad' (pack i)) where
  unsafeGrad' f as = (primal r, unbind vs (partialArray bds r)) where
    r = f vs
    (vs,bds) = bind as
-- | An input variable with the given value and variable id.
var :: a -> Int -> Kahn a
var a v = Kahn (Var a v)
-- | The id of a variable node (as built by 'var'). Any other kind of node
-- is an error.
varId :: Kahn a -> Int
varId (Kahn (Var _ v)) = v
varId _ = error "varId: not a Var"
-- | Replace every element of a container with a fresh 'var', numbered
-- from 0 in traversal order.
--
-- Also returns the bounds @(0, n)@, where @n@ is the number of variables
-- created. Note that the upper bound is one more than the largest id issued.
bind :: Traversable f => f a -> (f (Kahn a), (Int,Int))
bind xs = (r,(0,hi)) where
  (r,hi) = runState (mapM freshVar xs) 0
  -- The next counter value is forced whenever the returned pair is.
  freshVar a = state $ \s -> let s' = s + 1 in s' `seq` (var a s, s')
-- | Replace each variable in a container by the array entry at its
-- 'varId'.
unbind :: Functor f => f (Kahn a) -> Array Int a -> f a
unbind xs ys = fmap (\v -> ys ! varId v) xs
-- | Combine each variable's value with the array entry at its 'varId'.
unbindWith :: (Functor f, Num a) => (a -> b -> c) -> f (Kahn a) -> Array Int b -> f c
unbindWith f xs ys = fmap (\v -> f (primal v) (ys ! varId v)) xs
-- | Replace each variable by the map entry at its 'varId'; a missing
-- entry yields 0.
unbindMap :: (Functor f, Num a) => f (Kahn a) -> IntMap a -> f a
unbindMap xs ys = fmap (\v -> findWithDefault 0 (varId v) ys) xs
-- | Combine each variable's value with the map entry at its 'varId',
-- using @z@ when the map has no entry for that id.
unbindMapWithDefault :: (Functor f, Num a) => b -> (a -> b -> c) -> f (Kahn a) -> IntMap b -> f c
unbindMapWithDefault z f xs ys =
  fmap (\v -> f (primal v) $ findWithDefault z (varId v) ys) xs