{-# language FlexibleInstances, DeriveFunctor #-}
{-# language ScopedTypeVariables #-}
{-# language RankNTypes #-}
{-# language ViewPatterns #-}
{-# language FlexibleContexts #-}
{-# language BangPatterns #-}
module Algorithm.SRTree.AD
( forwardMode
, forwardModeUnique
, reverseModeUnique
, reverseModeUniqueArr
, forwardModeUniqueJac
) where
import Control.Monad (forM_, foldM)
import Control.Monad.ST ( runST )
import Data.Bifunctor (bimap, first, second)
import qualified Data.DList as DL
import Data.Massiv.Array hiding (forM_, map, replicate, zipWith)
import qualified Data.Massiv.Array as M
import qualified Data.Massiv.Array.Unsafe as UMA
import Data.Massiv.Core.Operations (unsafeLiftArray)
import Data.SRTree.Derivative ( derivative )
import Data.SRTree.Eval
( SRVector, evalFun, evalOp, SRMatrix, PVector, replicateAs )
import Data.SRTree.Internal
import Data.SRTree.Print (showExpr)
import Data.SRTree.Recursion ( cataM, cata, accu )
import qualified Data.Vector as V
import Debug.Trace (trace, traceShow)
import GHC.IO (unsafePerformIO)
import qualified Data.IntMap.Strict as IntMap
import Data.List ( foldl' )
-- | Evaluate the unary function @f@ on an operand that is either an array
-- ('Left') or a scalar ('Right'), keeping the result on the same side of
-- the 'Either'. The array branch is mapped element-wise into a delayed
-- ('D') array; the scalar branch is evaluated directly.
applyUni :: (Index ix, Source r e, Floating e, Floating b)
         => Function -> Either (Array r ix e) b -> Either (Array D ix e) b
applyUni f = bimap (M.map (evalFun f)) (evalFun f)
{-# INLINE applyUni #-}
-- | Evaluate the derivative of the unary function @f@ (via 'derivative')
-- on an operand that is either an array ('Left') or a scalar ('Right'),
-- keeping the result on the same side of the 'Either'. The array branch is
-- mapped element-wise into a delayed ('D') array.
applyDer :: (Index ix, Source r e, Floating e, Floating b)
         => Function -> Either (Array r ix e) b -> Either (Array D ix e) b
applyDer f = bimap (M.map (derivative f)) (derivative f)
{-# INLINE applyDer #-}
-- | Negate an operand that is either an array ('Left', negated
-- element-wise into a delayed array) or a scalar ('Right').
negate' :: (Index ix, Source r e, Num e, Num b)
        => Either (Array r ix e) b -> Either (Array D ix e) b
negate' = bimap (M.map negate) negate
{-# INLINE negate' #-}
-- | Apply the binary operator @op@ to two operands, each of which is either
-- a delayed array ('Left') or a scalar ('Right').
--
-- * array/array: element-wise massiv operators;
-- * array/scalar and scalar/array: 'evalOp' is lifted over the array with
--   the scalar held fixed;
-- * scalar/scalar: plain 'evalOp'.
--
-- 'PowerAbs' is @abs l ** r@: the absolute value is taken of the base
-- /before/ exponentiation. This matches the value computed in
-- 'forwardModeUnique' (@abs v1 ** v2@) and the derivative rule used in
-- 'forwardMode'. The previous array/array form @abs (l ** r)@ yields NaN
-- for a negative base with a non-integer exponent.
--
-- 'AQ' (analytic quotient) is @l / sqrt (1 + r * r)@.
applyBin :: (Index ix, Floating b)
         => Op
         -> Either (Array D ix b) b
         -> Either (Array D ix b) b
         -> Either (Array D ix b) b
applyBin op (Left ly) (Left ry) =
  Left $ case op of
    Add      -> ly !+! ry
    Sub      -> ly !-! ry
    Mul      -> ly !*! ry
    Div      -> ly !/! ry
    Power    -> ly .** ry
    -- abs of the base, then the power (see the note above)
    PowerAbs -> M.map abs ly .** ry
    AQ       -> ly !/! M.map (\x -> sqrt (x * x + 1)) ry
applyBin op (Left ly) (Right ry) = Left $ unsafeLiftArray (\x -> evalOp op x ry) ly
applyBin op (Right ly) (Left ry) = Left $ unsafeLiftArray (\x -> evalOp op ly x) ry
applyBin op (Right ly) (Right ry) = Right $ evalOp op ly ry
{-# INLINE applyBin #-}
-- | Index an operand that is either a manifest array ('Left'), read with
-- the bounds-checked '!', or a scalar ('Right'), which is returned for
-- every index.
(!??) :: (Manifest r e, Index ix) => Either (Array r ix e) e -> ix -> e
operand !?? ix = either (\arr -> arr ! ix) id operand
{-# INLINE (!??) #-}
-- | Forward-mode automatic differentiation of an expression tree.
--
-- @forwardMode xss theta err tree@ evaluates @tree@ on the data matrix
-- @xss@ (variable @ix@ is read as @xss <! ix@) with parameters @theta@.
-- Next to every node's value it carries an @m x p@ tape whose cell
-- @(i :. j)@ holds the derivative of that node at row @i@ with respect to
-- parameter @j@. Here @m@ is the number of rows of @xss@ and @p@ the length
-- of @theta@.
--
-- Tapes are built bottom-up in 'ST':
--
-- * leaves allocate fresh tapes;
-- * unary nodes rescale their child's tape in place;
-- * binary nodes overwrite their left child's tape with the combined
--   derivative.
--
-- Returns:
--
-- * the predictions (a scalar result is broadcast to @m@ rows);
-- * the vector-matrix product @err ><! jacobian@, one entry per parameter.
forwardMode :: Array S Ix2 Double
            -> Array S Ix1 Double
            -> SRVector
            -> Fix SRTree
            -> (Array D Ix1 Double, Array S Ix1 Double)
forwardMode xss theta err tree = (fromEither yhat, computeAs S err ><! jacob)
  where
    -- Root value and root Jacobian tape.
    (yhat, jacob) = runST (cataM lToR alg tree)

    (Sz p)        = M.size theta
    (Sz (m :. _)) = M.size xss
    cmp           = getComp xss

    -- Broadcast a scalar result to a vector of length m; arrays pass through.
    fromEither v = either id (M.replicate cmp (Sz m)) v

    -- Allocate an m x p tape whose cell (i :. j) is @seed (i :. j)@, then
    -- freeze it.
    freshTape seed = M.makeMArrayS (Sz2 m p) (pure . seed) >>= UMA.unsafeFreeze cmp

    -- Leaves. Variables and constants do not depend on theta, so their tape
    -- is all zeros; parameter @ix@ has a 1 in column @ix@ of every row.
    alg (Var ix)   = (,) (Left (xss <! ix)) <$> freshTape (const 0)
    alg (Const c)  = (,) (Right c) <$> freshTape (const 0)
    alg (Param ix) = (,) (Right (theta ! ix))
                       <$> freshTape (\(_ :. j) -> if j == ix then 1 else 0)

    -- Unary node (chain rule): scale row i of the child's tape by f'
    -- evaluated at row i. The child's tape is reused in place.
    alg (Uni f (t, childTape)) = do
      let dy = computeAs S (fromEither (applyDer f t))
      buf <- UMA.unsafeThaw childTape
      forM_ [0 .. m - 1] $ \i -> do
        let dyi = dy ! i
        forM_ [0 .. p - 1] $ \j -> do
          v <- UMA.unsafeRead buf (i :. j)
          UMA.unsafeWrite buf (i :. j) (dyi * v)
      (,) (applyUni f t) <$> UMA.unsafeFreeze cmp buf

    -- Binary node: combine both children's tapes cell by cell with
    -- 'binRule', writing the result into the left child's buffer.
    alg (Bin op (l, leftTape) (r, rightTape)) = do
      acc   <- UMA.unsafeThaw leftTape
      other <- UMA.unsafeThaw rightTape
      let lS = first (computeAs S) l
          rS = first (computeAs S) r
      forM_ [0 .. m - 1] $ \i -> do
        let li = lS !?? i
            ri = rS !?? i
        forM_ [0 .. p - 1] $ \j -> do
          vl <- UMA.unsafeRead acc (i :. j)
          vr <- UMA.unsafeRead other (i :. j)
          UMA.unsafeWrite acc (i :. j) (binRule op li ri vl vr)
      (,) (applyBin op l r) <$> UMA.unsafeFreeze cmp acc

    -- Derivative of @l `op` r@ at one cell, given the operand values
    -- (li, ri) and the operand derivatives (vl, vr).
    binRule :: Op -> Double -> Double -> Double -> Double -> Double
    binRule Add      _  _  vl vr = vl + vr
    binRule Sub      _  _  vl vr = vl - vr
    binRule Mul      li ri vl vr = vl * ri + vr * li
    binRule Div      li ri vl vr = (vl * ri - vr * li) / ri ^ 2
    binRule Power    li ri vl vr = li ** (ri - 1) * (ri * vl + li * log li * vr)
    binRule PowerAbs li ri vl vr = (abs li ** ri) * (vr * log (abs li) + ri * vl / li)
    binRule AQ       li ri vl vr = ((1 + ri * ri) * vl - li * ri * vr) / (1 + ri * ri) ** 1.5

    -- Sequence the effects of a node's children (left child first) so that
    -- 'cataM' can run the algebra monadically.
    lToR node = case node of
      Var ix       -> pure (Var ix)
      Param ix     -> pure (Param ix)
      Const c      -> pure (Const c)
      Uni f mt     -> Uni f <$> mt
      Bin op ml mr -> Bin op <$> ml <*> mr
-- | Forward-mode gradient that tracks one derivative vector per parameter
-- occurrence.
--
-- The tree is folded ('cata') into each node's value plus a difference list
-- of derivative vectors, one per 'Param' leaf below the node. Every binary
-- rule emits the left operand's derivatives before the right operand's, so
-- the list follows the left-to-right order of 'Param' leaves. At the root,
-- each derivative vector is dotted with @err@ to give one entry of the
-- result.
--
-- NOTE(review): because entries follow leaf order, the result is only the
-- gradient with respect to @theta@ when every parameter occurs exactly once
-- and in index order (presumably the meaning of "Unique") — confirm against
-- callers.
--
-- Fix: the 'PowerAbs' rule previously emitted the right operand's
-- derivatives first, which swapped the gradient entries of the parameters
-- beneath it.
forwardModeUnique :: SRMatrix -> PVector -> SRVector -> Fix SRTree -> (SRVector, Array S Ix1 Double)
forwardModeUnique xss theta err = second (toGrad . DL.toList) . cata alg
  where
    -- Vector of ones shaped like the model output.
    one = replicateAs xss 1

    -- One gradient entry per derivative vector: its dot product with err.
    toGrad grad = M.fromList (getComp xss) [g !.! err | g <- grad]

    alg :: SRTree (SRVector, DL.DList SRVector) -> (SRVector, DL.DList SRVector)
    alg (Var ix)   = (xss <! ix, DL.empty)
    alg (Param ix) = (replicateAs xss (theta ! ix), DL.singleton one)
    alg (Const c)  = (replicateAs xss c, DL.empty)
    alg (Uni f (v, gs)) =
      let v' = evalFun f v
          dv = derivative f v
       in (v', DL.map (* dv) gs)
    alg (Bin Add (v1, l) (v2, r)) = (v1 + v2, DL.append l r)
    alg (Bin Sub (v1, l) (v2, r)) = (v1 - v2, DL.append l (DL.map negate r))
    alg (Bin Mul (v1, l) (v2, r)) =
      (v1 * v2, DL.append (DL.map (* v2) l) (DL.map (* v1) r))
    alg (Bin Div (v1, l) (v2, r)) =
      let dv = (-v1) / (v2 * v2)
       in (v1 / v2, DL.append (DL.map (/ v2) l) (DL.map (* dv) r))
    alg (Bin Power (v1, l) (v2, r)) =
      let dv1 = v1 ** (v2 - one)
          dv2 = v1 * log v1
       in (v1 ** v2, DL.map (* dv1) (DL.append (DL.map (* v2) l) (DL.map (* dv2) r)))
    alg (Bin PowerAbs (v1, l) (v2, r)) =
      let dv1 = abs v1 ** v2
          -- d/dl |l|^r = |l|^r * r / l ;  d/dr |l|^r = |l|^r * log |l|
          dl  = DL.map (* (v2 / v1)) l
          dr  = DL.map (* log (abs v1)) r
          -- left operand first, consistent with every other rule
       in (abs v1 ** v2, DL.map (* dv1) (DL.append dl dr))
    alg (Bin AQ (v1, l) (v2, r)) =
      let dv1 = DL.map (* (1 + v2 * v2)) l
          dv2 = DL.map (* (-v1 * v2)) r
       in (v1 / sqrt (1 + v2 * v2), DL.map (/ ((1 + v2 * v2) ** 1.5)) $ DL.append dv1 dv2)
-- | Base functor for a tree of values of type @a@. A node is either a leaf
-- ('Single'), a node with one child ('T'), or a node with two children
-- ('Branch'). The first field of every constructor is the node's own value.
data TupleF a b
  = Single a
  | T a b
  | Branch a b b
  deriving (Functor)

-- | The fixed point of 'TupleF': a tree carrying a value of type @a@ at
-- every node.
type Tuple a = Fix (TupleF a)
-- | Reverse-mode automatic differentiation for an expression tree in which
-- each parameter is expected to occur once.
--
-- The forward pass ('cata' with @forward@) annotates every node with its
-- value. Each value is either a per-sample vector ('Left') or a scalar
-- ('Right'). The backward pass ('accu' with @backward@) seeds the root with
-- @Right 1@ and pushes the adjoint down to the leaves. When a 'Param' leaf
-- is reached, @combine@ writes into slot @ix@ of the result the dot product
-- between that adjoint and the residual @f v - ys@.
--
-- Returns the root value @v@ and the per-parameter vector. Its length is
-- @countParams t@.
--
-- NOTE(review): the residual is @f v - ys@, but the chain rule through @f@
-- is not applied. The result equals the gradient of half the squared error
-- only when @f@ is the identity. Confirm this is intended.
--
-- NOTE(review): a 'Param' node /overwrites/ its slot rather than adding to
-- it. If a parameter occurs more than once, only the last visited
-- occurrence contributes.
reverseModeUnique :: SRMatrix
                  -> PVector
                  -> SRVector
                  -> (SRVector -> SRVector)
                  -> Fix SRTree
                  -> (Array D Ix1 Double, Array S Ix1 Double)
reverseModeUnique xss theta ys f t = unsafePerformIO $ do
    jacob <- M.newMArray (Sz p) 0
    -- The fold returns a Double (the sum over all leaves). Forcing it
    -- drives the whole backward pass, and every 'Param' leaf writes its
    -- entry into 'jacob' as a side effect of 'combine'.
    let !_ = accu backward (combine jacob) t (Right 1, fwdMode)
    j <- freezeS jacob
    pure (v, j)
  where
    -- Forward pass: the value tape and the root value.
    fwdMode = cata forward t
    v       = fromEither $ getTop fwdMode
    err     = f v - ys
    (Sz2 m _) = M.size xss
    p       = countParams t

    -- Broadcast a scalar to a vector with one entry per row of 'xss'.
    fromEither (Left x)  = x
    fromEither (Right x) = M.replicate (getComp xss) (Sz1 m) x

    -- Constructors for the value tape.
    oneTpl x     = Fix $ Single x
    tuple x y    = Fix $ T x y
    branch x y z = Fix $ Branch x y z

    -- Value stored at the top of a tape.
    getTop (Fix (Single x))     = x
    getTop (Fix (T x _))        = x
    getTop (Fix (Branch x _ _)) = x

    -- Child tapes. The tape mirrors the tree, so a mismatch is a bug.
    unCons (Fix (T _ sub)) = sub
    unCons _ = error "reverseModeUnique: value tape does not match a unary node"

    getBranches (Fix (Branch _ tl tr)) = (tl, tr)
    getBranches _ = error "reverseModeUnique: value tape does not match a binary node"

    -- One step of the forward pass: evaluate a node from its children's tapes.
    forward (Var ix)     = oneTpl (Left $ xss <! ix)
    forward (Param ix)   = oneTpl (Right $ theta ! ix)
    forward (Const c)    = oneTpl (Right c)
    forward (Uni g sub)  = tuple (applyUni g (getTop sub)) sub
    forward (Bin op l r) = branch (applyBin op (getTop l) (getTop r)) l r

    -- One step of the backward pass. Pair each child with its adjoint and
    -- its own sub-tape. It must stay polymorphic in the node payload
    -- because 'accu' takes it at a rank-2 type.
    backward (Var ix)   (_, _) = Var ix
    backward (Param ix) (_, _) = Param ix
    backward (Const c)  (_, _) = Const c
    backward (Uni g x) (dx, unCons -> sub) =
      Uni g (x, (applyBin Mul dx (applyDer g (getTop sub)), sub))
    backward (Bin op l r) (dx, getBranches -> (tl, tr)) =
      let (dl, dr) = diff op dx (getTop tl) (getTop tr)
      in Bin op (l, (dl, tl)) (r, (dr, tr))

    -- Adjoints of the left and right operands of a binary operator, given
    -- the incoming adjoint dx and the operand values fx and gy.
    diff Add dx _ _   = (dx, dx)
    diff Sub dx _ _   = (dx, negate' dx)
    diff Mul dx fx gy = (applyBin Mul dx gy, applyBin Mul dx fx)
    diff Div dx fx gy =
      ( applyBin Div dx gy
      , applyBin Mul dx (applyBin Div (negate' fx) (applyBin Mul gy gy)) )
    diff Power dx fx gy =
      -- d/dx x^y = y x^(y-1) ;  d/dy x^y = x^y ln x = x^(y-1) * x ln x
      let dxl = applyBin Mul dx (applyBin Power fx (applyBin Sub gy (Right 1)))
          dv2 = applyBin Mul fx (applyUni Log fx)
      in (applyBin Mul dxl gy, applyBin Mul dxl dv2)
    diff PowerAbs dx fx gy =
      -- d/dx |x|^y = y x |x|^(y-2) ;  d/dy |x|^y = ln|x| |x|^y
      let dxl = applyBin Mul (applyBin Mul gy fx)
                             (applyBin PowerAbs fx (applyBin Sub gy (Right 2)))
          dxr = applyBin Mul (applyUni LogAbs fx) (applyBin PowerAbs fx gy)
      in (applyBin Mul dxl dx, applyBin Mul dxr dx)
    diff AQ dx fx gy =
      -- aq x y = x / sqrt (1 + y^2)
      -- d/dx = 1 / sqrt (1 + y^2) ;  d/dy = -x y / (1 + y^2)^(3/2)
      -- The minus sign matches the forward-mode rule for AQ.
      let root = applyUni Sqrt (applyBin Add (applyUni Square gy) (Right 1))
          dxl  = applyUni Recip root
          dxy  = applyBin Div (negate' (applyBin Mul fx gy)) (applyUni Cube root)
      in (applyBin Mul dxl dx, applyBin Mul dxy dx)

    -- Algebra of the backward fold. Variables and constants contribute
    -- nothing. A parameter stores <adjoint, err> in its Jacobian slot and
    -- returns it. Inner nodes add up their children so that forcing the
    -- root forces every leaf. The inner unsafePerformIO performs the
    -- write at the moment the leaf value is demanded.
    combine _ (Var _)   _ = 0
    combine _ (Const _) _ = 0
    combine j (Param ix) (dx, _) = unsafePerformIO $ do
      case dx of
        Left dv  -> dotM dv err >>= UMA.unsafeWrite j ix
        Right dv -> UMA.unsafeWrite j ix $ M.foldrS (\x acc -> x * dv + acc) 0 err
      UMA.unsafeRead j ix
    combine _ (Uni _ g)   _ = g
    combine _ (Bin _ l r) _ = l + r
reverseModeUniqueArr :: Array r Ix2 Double
-> Array r Ix1 Double
-> Array r Ix1 Double
-> (Array D (Lower Ix2) Double -> SRVector)
-> [(Ix1, (a, Ix1, Ix1, Double))]
-> IntMap Ix1
-> (SRVector, Array S Ix1 Double)
reverseModeUniqueArr Array r Ix2 Double
xss Array r Ix1 Double
theta Array r Ix1 Double
ys Array D (Lower Ix2) Double -> SRVector
f [(Ix1, (a, Ix1, Ix1, Double))]
t IntMap Ix1
j2ix =
IO (SRVector, Array S Ix1 Double) -> (SRVector, Array S Ix1 Double)
forall a. IO a -> a
unsafePerformIO (IO (SRVector, Array S Ix1 Double)
-> (SRVector, Array S Ix1 Double))
-> IO (SRVector, Array S Ix1 Double)
-> (SRVector, Array S Ix1 Double)
forall a b. (a -> b) -> a -> b
$ do
fwd <- Sz Ix2 -> Double -> IO (MArray (PrimState IO) S Ix2 Double)
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Sz ix -> e -> m (MArray (PrimState m) r ix e)
forall ix (m :: * -> *).
(Index ix, PrimMonad m) =>
Sz ix -> Double -> m (MArray (PrimState m) S ix Double)
M.newMArray (Ix1 -> Ix1 -> Sz Ix2
Sz2 Ix1
m Ix1
n) Double
0
partial <- M.newMArray (Sz2 m n) 0
jacob <- M.newMArray (Sz p) 0
fwd' <- UMA.unsafeFreeze (getComp xss) fwd
let v = Array S Ix2 Double
fwd' Array S Ix2 Double -> Ix1 -> Array D (Lower Ix2) Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
Array r ix e -> Ix1 -> Array D (Lower ix) e
M.<! Ix1
0
err = S -> SRVector -> Array S Ix1 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (SRVector -> Array S Ix1 Double) -> SRVector -> Array S Ix1 Double
forall a b. (a -> b) -> a -> b
$ Array D (Lower Ix2) Double -> SRVector
f Array D (Lower Ix2) Double
v SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
- Array r Ix1 Double -> SRVector
forall ix r e.
(Index ix, Source r e) =>
Array r ix e -> Array D ix e
delay Array r Ix1 Double
ys
forward fwd
combine partial jacob err
j <- UMA.unsafeFreeze (getComp xss) jacob
pure (v, j)
where
(Sz2 Ix1
m Ix1
_) = Array r Ix2 Double -> Sz Ix2
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array r ix e -> Sz ix
M.size Array r Ix2 Double
xss
(Sz Ix1
p) = Array r Ix1 Double -> Sz Ix1
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array r ix e -> Sz ix
M.size Array r Ix1 Double
theta
n :: Ix1
n = [(Ix1, (a, Ix1, Ix1, Double))] -> Ix1
forall a. [a] -> Ix1
forall (t :: * -> *) a. Foldable t => t a -> Ix1
length [(Ix1, (a, Ix1, Ix1, Double))]
t
forward :: MArray (PrimState IO) S Ix2 Double -> IO ()
forward :: MArray (PrimState IO) S Ix2 Double -> IO ()
forward MArray (PrimState IO) S Ix2 Double
fwd = [(Ix1, (a, Ix1, Ix1, Double))]
-> ((Ix1, (a, Ix1, Ix1, Double)) -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Ix1, (a, Ix1, Ix1, Double))] -> [(Ix1, (a, Ix1, Ix1, Double))]
forall a. [a] -> [a]
Prelude.reverse [(Ix1, (a, Ix1, Ix1, Double))]
t) (Ix1, (a, Ix1, Ix1, Double)) -> IO ()
forall {m :: * -> *} {a}.
(PrimState m ~ RealWorld, Eq a, Num a, PrimMonad m) =>
(Ix1, (a, Ix1, Ix1, Double)) -> m ()
makeFwd
where
makeFwd :: (Ix1, (a, Ix1, Ix1, Double)) -> m ()
makeFwd (Ix1
j, (a
0, Ix1
0, Ix1
ix, Double
_)) = do let j' :: Ix1
j' = IntMap Ix1
j2ix IntMap Ix1 -> Ix1 -> Ix1
forall a. IntMap a -> Ix1 -> a
IntMap.! Ix1
j
[Ix1] -> (Ix1 -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Ix1
0..Ix1
mIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
-Ix1
1] ((Ix1 -> m ()) -> m ()) -> (Ix1 -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Ix1
i -> do
let val :: Double
val = Array r Ix2 Double
xss Array r Ix2 Double -> Ix2 -> Double
forall r ix e.
(HasCallStack, Manifest r e, Index ix) =>
Array r ix e -> ix -> e
M.! (Ix1
i Ix1 -> Ix1 -> Ix2
:. Ix1
ix)
MArray (PrimState m) S Ix2 Double -> Ix2 -> Double -> m ()
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> e -> m ()
UMA.unsafeWrite MArray (PrimState m) S Ix2 Double
MArray (PrimState IO) S Ix2 Double
fwd (Ix1
i Ix1 -> Ix1 -> Ix2
:. Ix1
j') Double
val
makeFwd (Ix1
j, (a
0, Ix1
1, Ix1
ix, Double
_)) = do let j' :: Ix1
j' = IntMap Ix1
j2ix IntMap Ix1 -> Ix1 -> Ix1
forall a. IntMap a -> Ix1 -> a
IntMap.! Ix1
j
v :: Double
v = Array r Ix1 Double
theta Array r Ix1 Double -> Ix1 -> Double
forall r ix e.
(HasCallStack, Manifest r e, Index ix) =>
Array r ix e -> ix -> e
M.! Ix1
ix
[Ix1] -> (Ix1 -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Ix1
0..Ix1
mIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
-Ix1
1] ((Ix1 -> m ()) -> m ()) -> (Ix1 -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Ix1
i -> do
MArray (PrimState m) S Ix2 Double -> Ix2 -> Double -> m ()
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> e -> m ()
UMA.unsafeWrite MArray (PrimState m) S Ix2 Double
MArray (PrimState IO) S Ix2 Double
fwd (Ix1
i Ix1 -> Ix1 -> Ix2
:. Ix1
j') Double
v
makeFwd (Ix1
j, (a
0, Ix1
2, Ix1
_, Double
x)) = do let j' :: Ix1
j' = IntMap Ix1
j2ix IntMap Ix1 -> Ix1 -> Ix1
forall a. IntMap a -> Ix1 -> a
IntMap.! Ix1
j
[Ix1] -> (Ix1 -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Ix1
0..Ix1
mIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
-Ix1
1] ((Ix1 -> m ()) -> m ()) -> (Ix1 -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Ix1
i -> do
MArray (PrimState m) S Ix2 Double -> Ix2 -> Double -> m ()
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> e -> m ()
UMA.unsafeWrite MArray (PrimState m) S Ix2 Double
MArray (PrimState IO) S Ix2 Double
fwd (Ix1
i Ix1 -> Ix1 -> Ix2
:. Ix1
j') Double
x
makeFwd (Ix1
j, (a
1, Ix1
f, Ix1
_, Double
_)) = do let j' :: Ix1
j' = IntMap Ix1
j2ix IntMap Ix1 -> Ix1 -> Ix1
forall a. IntMap a -> Ix1 -> a
IntMap.! Ix1
j
j2 :: Ix1
j2 = IntMap Ix1
j2ix IntMap Ix1 -> Ix1 -> Ix1
forall a. IntMap a -> Ix1 -> a
IntMap.! (Ix1
2Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
*Ix1
j Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+ Ix1
1)
[Ix1] -> (Ix1 -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Ix1
0..Ix1
mIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
-Ix1
1] ((Ix1 -> m ()) -> m ()) -> (Ix1 -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Ix1
i -> do
v <- MArray (PrimState m) S Ix2 Double -> Ix2 -> m Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> m e
UMA.unsafeRead MArray (PrimState m) S Ix2 Double
MArray (PrimState IO) S Ix2 Double
fwd (Ix1
i Ix1 -> Ix1 -> Ix2
:. Ix1
j2)
let val = Function -> Double -> Double
forall a. Floating a => Function -> a -> a
evalFun (Ix1 -> Function
forall a. Enum a => Ix1 -> a
toEnum Ix1
f) Double
v
UMA.unsafeWrite fwd (i :. j') val
makeFwd (Ix1
j, (a
2, Ix1
op, Ix1
_, Double
_)) = do let j' :: Ix1
j' = IntMap Ix1
j2ix IntMap Ix1 -> Ix1 -> Ix1
forall a. IntMap a -> Ix1 -> a
IntMap.! Ix1
j
j2 :: Ix1
j2 = IntMap Ix1
j2ix IntMap Ix1 -> Ix1 -> Ix1
forall a. IntMap a -> Ix1 -> a
IntMap.! (Ix1
2Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
*Ix1
j Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+ Ix1
1)
j3 :: Ix1
j3 = IntMap Ix1
j2ix IntMap Ix1 -> Ix1 -> Ix1
forall a. IntMap a -> Ix1 -> a
IntMap.! (Ix1
2Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
*Ix1
j Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+ Ix1
2)
[Ix1] -> (Ix1 -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Ix1
0..Ix1
mIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
-Ix1
1] ((Ix1 -> m ()) -> m ()) -> (Ix1 -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Ix1
i -> do
l <- MArray (PrimState m) S Ix2 Double -> Ix2 -> m Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> m e
UMA.unsafeRead MArray (PrimState m) S Ix2 Double
MArray (PrimState IO) S Ix2 Double
fwd (Ix1
i Ix1 -> Ix1 -> Ix2
:. Ix1
j2)
r <- UMA.unsafeRead fwd (i :. j3)
let val = Op -> Double -> Double -> Double
forall a. Floating a => Op -> a -> a -> a
evalOp (Ix1 -> Op
forall a. Enum a => Ix1 -> a
toEnum Ix1
op) Double
l Double
r
UMA.unsafeWrite fwd (i :. j') val
-- Reverse (adjoint) sweep over the node list @t@.
--
-- First writes 1 into column 0 of @partial@ for every row in [0 .. m-1].
-- Then, for each node of @t@ in list order, it pushes the adjoint stored at
-- the node's column (@j2ix IntMap.! j@) down to its child column(s), at keys
-- @2*j+1@ (and @2*j+2@ for binary nodes). It uses the operand values recorded
-- in @fwd@. Child adjoints are overwritten, not accumulated.
-- NOTE(review): this relies on @t@ listing parents before children and on the
-- root living in column 0; both are set up by the caller (not visible here).
reverseMode :: MArray (PrimState IO) S Ix2 Double -> MArray (PrimState IO) S Ix2 Double -> IO ()
reverseMode fwd partial = do
  forM_ allRows $ \i -> UMA.unsafeWrite partial (i :. 0) 1
  forM_ t propagate
  where
    allRows = [0 .. m - 1]

    -- storage column of node key k
    column k = j2ix IntMap.! k

    -- unary node: child adjoint = node adjoint * f'(child value)
    propagate (j, (1, f, _, _)) =
      let here  = column j
          child = column (2 * j + 1)
      in forM_ allRows $ \i -> do
           v  <- UMA.unsafeRead fwd (i :. child)
           dx <- UMA.unsafeRead partial (i :. here)
           UMA.unsafeWrite partial (i :. child) (dx * derivative (toEnum f) v)
    -- binary node: split the node adjoint between both operands via 'diff'
    propagate (j, (2, op, _, _)) =
      let here   = column j
          lChild = column (2 * j + 1)
          rChild = column (2 * j + 2)
      in forM_ allRows $ \i -> do
           lv <- UMA.unsafeRead fwd (i :. lChild)
           rv <- UMA.unsafeRead fwd (i :. rChild)
           dx <- UMA.unsafeRead partial (i :. here)
           let (dl, dr) = diff (toEnum op) dx lv rv
           UMA.unsafeWrite partial (i :. lChild) dl
           UMA.unsafeWrite partial (i :. rChild) dr
    -- leaves (and any other tag) carry nothing further down
    propagate _ = pure ()
-- Element-wise power of two arrays: result[k] = base[k] ** expo[k],
-- returned as a delayed array.
(!**!) base expo = M.zipWith (**) base expo
-- Local adjoint rule for a binary operator.
--
-- @diff op dx fx gy@ returns
--   (dx * d(fx `op` gy)/d fx, dx * d(fx `op` gy)/d gy)
-- where @fx@ / @gy@ are the left / right operand values and @dx@ is the
-- adjoint of the operator's result.
--
-- Fixes relative to the previous version:
--   * PowerAbs evaluates |fx| ** gy, so abs belongs on the base, not on the
--     exponent (the old rules used fx ** abs (...), which is wrong and NaN
--     for negative fx).
--   * AQ (fx / sqrt (1 + gy^2)) has a negative derivative w.r.t. gy for
--     positive fx and gy; the sign was missing.
diff Add dx _ _ = (dx, dx)
diff Sub dx _ _ = (dx, negate dx)
diff Mul dx fx gy = (dx * gy, dx * fx)
diff Div dx fx gy = (dx / gy, dx * (negate fx / (gy * gy)))
diff Power dx fx gy =
  let dxl = dx * (fx ** (gy - 1))   -- dx * fx^(gy-1)
      dv2 = fx * log fx             -- fx^(gy-1) * dv2 = fx^gy * log fx
  in (dxl * gy, dxl * dv2)
diff PowerAbs dx fx gy =
  let afx = abs fx
      -- d/dfx |fx|^gy = gy * sign(fx) * |fx|^(gy-1); the signum form avoids
      -- a spurious 0 * Inf = NaN at fx == 0 when gy > 1.
      dxl = gy * signum fx * (afx ** (gy - 1))
      -- d/dgy |fx|^gy = log|fx| * |fx|^gy
      dxr = log afx * (afx ** gy)
  in (dxl * dx, dxr * dx)
diff AQ dx fx gy =
  let dxl = recip (sqrt (gy * gy + 1))           -- 1 / sqrt(1 + gy^2)
      dxy = negate (fx * gy * (dxl ^ (3 :: Int))) -- -fx*gy / (1 + gy^2)^(3/2)
  in (dxl * dx, dxy * dx)
-- Reduces the per-row adjoints in @partial@ against the residual vector @err@.
--
-- For every entry of @t@ tagged (0, 1, ix, _), it computes
--   sum over i in [0 .. m-1] of err[i] * partial[i, j2ix ! j]
-- and writes the sum to @jacob@ at position @ix@. All other entries are
-- ignored.
-- NOTE(review): the (0, 1) tag presumably marks a parameter leaf with
-- parameter index @ix@; confirm against where @t@ is built.
--
-- The running sum is forced at every step (pure $!). The previous version
-- returned it lazily through foldM, which builds a chain of m thunks before
-- the final write. The summation order is unchanged.
combine :: MArray (PrimState IO) S Ix2 Double -> MArray (PrimState IO) S Ix1 Double -> Array S Ix1 Double -> IO ()
combine partial jacob err = forM_ t accumulate
  where
    accumulate (j, (0, 1, ix, _)) = do
      let col = j2ix IntMap.! j
          step acc i = do
            d <- UMA.unsafeRead partial (i :. col)
            pure $! (err M.! i) * d + acc
      total <- foldM step 0 [0 .. m - 1]
      UMA.unsafeWrite jacob ix total
    accumulate _ = pure ()
-- | Forward-mode derivatives of an expression w.r.t. its parameters.
--
-- The tree is folded bottom-up with 'cata'. Every node produces its value over
-- all rows of @xss@, together with a list of derivatives of that value. The
-- list has one entry per 'Param' leaf in the node's subtree, in left-to-right
-- leaf order. The result is the root's list, with each entry computed into a
-- strict vector.
--
-- NOTE(review): the entries line up with parameter indices only when each
-- parameter occurs once and parameters are numbered left to right; confirm at
-- the call sites.
--
-- Fix: the 'PowerAbs' rule used to emit the right operand's derivatives
-- before the left operand's, which swapped the order relative to every other
-- operator.
forwardModeUniqueJac :: SRMatrix -> PVector -> Fix SRTree -> [PVector]
forwardModeUniqueJac xss theta = map (M.computeAs M.S) . DL.toList . snd . cata alg
  where
    -- a vector of ones with one entry per row of xss
    one :: SRVector
    one = replicateAs xss 1

    alg :: SRTree (SRVector, DL.DList SRVector) -> (SRVector, DL.DList SRVector)
    alg (Var ix)   = (xss <! ix, DL.empty)
    alg (Param ix) = (replicateAs xss (theta ! ix), DL.singleton one)
    alg (Const c)  = (replicateAs xss c, DL.empty)
    alg (Uni f (v, gs)) =
      let dv = derivative f v
      in (evalFun f v, DL.map (* dv) gs)
    alg (Bin Add (v1, l) (v2, r)) = (v1 + v2, DL.append l r)
    alg (Bin Sub (v1, l) (v2, r)) = (v1 - v2, DL.append l (DL.map negate r))
    alg (Bin Mul (v1, l) (v2, r)) =
      (v1 * v2, DL.append (DL.map (* v2) l) (DL.map (* v1) r))
    alg (Bin Div (v1, l) (v2, r)) =
      let dv = (-v1) / (v2 * v2)
      in (v1 / v2, DL.append (DL.map (/ v2) l) (DL.map (* dv) r))
    alg (Bin Power (v1, l) (v2, r)) =
      -- d(v1^v2) = v1^(v2-1) * (v2 * dv1 + v1 * log v1 * dv2)
      let dv1 = v1 ** (v2 - one)
          dv2 = v1 * log v1
      in (v1 ** v2, DL.map (* dv1) (DL.append (DL.map (* v2) l) (DL.map (* dv2) r)))
    alg (Bin PowerAbs (v1, l) (v2, r)) =
      -- d(|v1|^v2) = |v1|^v2 * (v2 / v1 * dv1 + log |v1| * dv2)
      let dv1 = abs v1 ** v2
          dvl = DL.map (* (v2 / v1)) l
          dvr = DL.map (* log (abs v1)) r
      in (abs v1 ** v2, DL.map (* dv1) (DL.append dvl dvr))  -- left first
    alg (Bin AQ (v1, l) (v2, r)) =
      -- d(v1 / sqrt(1 + v2^2)) = ((1 + v2^2) * dv1 - v1 * v2 * dv2) / (1 + v2^2)^1.5
      let den = 1 + v2 * v2
          dvl = DL.map (* den) l
          dvr = DL.map (* (-v1 * v2)) r
      in (v1 / sqrt den, DL.map (/ den ** 1.5) (DL.append dvl dvr))