{-# language FlexibleInstances, DeriveFunctor #-}
{-# language ScopedTypeVariables #-}
{-# language RankNTypes #-}
{-# language ViewPatterns #-}
{-# language FlexibleContexts #-}
{-# language BangPatterns #-}
module Algorithm.SRTree.AD
( forwardMode
, forwardModeUnique
, reverseModeUnique
, forwardModeUniqueJac
) where
import Control.Monad (forM_)
import Control.Monad.ST ( runST )
import Data.Bifunctor (bimap, first, second)
import qualified Data.DList as DL
import Data.Massiv.Array hiding (forM_, map, replicate, zipWith)
import qualified Data.Massiv.Array as M
import qualified Data.Massiv.Array.Unsafe as UMA
import Data.Massiv.Core.Operations (unsafeLiftArray)
import Data.SRTree.Derivative ( derivative )
import Data.SRTree.Eval
( SRVector, evalFun, evalOp, SRMatrix, PVector, replicateAs )
import Data.SRTree.Internal
import Data.SRTree.Print (showExpr)
import Data.SRTree.Recursion ( cataM, cata, accu )
import qualified Data.Vector as V
import Debug.Trace (trace, traceShow)
import GHC.IO (unsafePerformIO)
-- | Apply the unary function @f@ to a value that is either a whole array
-- (@Left@) or a single scalar (@Right@).
--
-- An array is mapped element-wise with 'evalFun' and comes back as a delayed
-- (@D@) array; a scalar is passed to 'evalFun' directly. The @Left@/@Right@
-- tag of the input is always preserved.
applyUni :: (Index ix, Source r e, Floating e, Floating b) => Function -> Either (Array r ix e) b -> Either (Array D ix e) b
applyUni f = bimap (M.map (evalFun f)) (evalFun f)
{-# INLINE applyUni #-}
-- | Apply the derivative of the unary function @f@ to a value that is either
-- a whole array (@Left@) or a single scalar (@Right@).
--
-- Mirrors 'applyUni', but uses 'derivative' instead of 'evalFun': arrays are
-- mapped element-wise into a delayed (@D@) array, scalars are transformed
-- directly, and the @Left@/@Right@ tag is preserved.
applyDer :: (Index ix, Source r e, Floating e, Floating b) => Function -> Either (Array r ix e) b -> Either (Array D ix e) b
applyDer f = bimap (M.map (derivative f)) (derivative f)
{-# INLINE applyDer #-}
-- | Negate a value that is either a whole array (@Left@, negated
-- element-wise into a delayed array) or a single scalar (@Right@).
negate' :: (Index ix, Source r e, Num e, Num b) => Either (Array r ix e) b -> Either (Array D ix e) b
negate' = bimap (M.map negate) negate
{-# INLINE negate' #-}
-- | Apply the binary operator @op@ to two operands, each of which is either
-- a delayed array (@Left@) or a scalar (@Right@).
--
-- * array / array: the operator is applied element-wise with the massiv
--   numeric operators (presumably both arrays must have the same size --
--   confirm against the massiv documentation of @!+!@ and friends).
-- * array / scalar and scalar / array: the scalar is combined with every
--   element of the array through 'evalOp'.
-- * scalar / scalar: plain 'evalOp'.
--
-- The result is an array as soon as at least one operand is an array.
applyBin :: (Index ix, Floating b) => Op -> Either (Array D ix b) b -> Either (Array D ix b) b -> Either (Array D ix b) b
applyBin op (Left ly) (Left ry) =
    Left $ case op of
             Add      -> ly !+! ry
             Sub      -> ly !-! ry
             Mul      -> ly !*! ry
             Div      -> ly !/! ry
             Power    -> ly .** ry
             -- |l| ** r: take the absolute value of the base *before*
             -- exponentiating, so a negative base with a fractional exponent
             -- does not turn into NaN. This matches the derivative rule used
             -- by 'forwardMode' (which differentiates @abs l ** r@).
             -- NOTE(review): assumes 'evalOp' PowerAbs is also @abs l ** r@.
             PowerAbs -> M.map abs ly .** ry
             -- analytic quotient: l / sqrt (1 + r^2)
             AQ       -> ly !/! M.map (sqrt . (+1)) (ry !*! ry)
applyBin op (Left ly) (Right ry) =
    -- scalar on the right: combine it with every element of the array
    Left $ unsafeLiftArray (\x -> evalOp op x ry) ly
applyBin op (Right ly) (Left ry) =
    -- scalar on the left: combine it with every element of the array
    Left $ unsafeLiftArray (\x -> evalOp op ly x) ry
applyBin op (Right ly) (Right ry) =
    Right $ evalOp op ly ry
{-# INLINE applyBin #-}
-- | Index into a value that is either an array or a scalar.
--
-- For an array (@Left@) the element at @ix@ is returned; a scalar (@Right@)
-- is returned as is and the index is ignored.
(!??) :: (Manifest r e, Index ix) => Either (Array r ix e) e -> ix -> e
src !?? ix =
    case src of
      Left arr  -> arr ! ix
      Right val -> val
{-# INLINE (!??) #-}
-- | Forward-mode automatic differentiation of an expression tree with
-- respect to its parameters.
--
-- Arguments:
--
--   * @xss@   - data matrix with @m@ rows; @Var ix@ reads the slice @xss <! ix@
--   * @theta@ - parameter vector of length @p@; @Param ix@ reads @theta ! ix@
--   * @err@   - vector of length @m@ that is multiplied with the Jacobian
--   * @tree@  - expression to evaluate and differentiate
--
-- Returns the value of the expression for every row of @xss@ together with
-- @err ><! jacob@, where @jacob@ is the @m x p@ matrix holding, for every row,
-- the partial derivative of the expression with respect to each parameter.
--
-- Every node of the tree carries a pair @(value, tape)@: @value@ is either a
-- vector (depends on the data) or a scalar (does not), and @tape@ is the
-- @m x p@ Jacobian of that sub-expression. Tapes are updated in place through
-- 'UMA.unsafeThaw' / 'UMA.unsafeFreeze'; each tape is created at a leaf and
-- consumed by exactly one parent node.
forwardMode :: Array S Ix2 Double -> Array S Ix1 Double -> SRVector -> Fix SRTree -> (Array D Ix1 Double, Array S Ix1 Double)
forwardMode xss theta err tree =
    let (yhat, jacob) = runST (cataM lToR alg tree)
    in  (fromEither yhat, computeAs S err ><! jacob)
  where
    (Sz p)        = M.size theta   -- number of parameters
    (Sz (m :. _)) = M.size xss     -- number of rows of the data matrix
    cmp           = getComp xss

    -- Expand a scalar result into a vector with one entry per row.
    fromEither (Left y)  = y
    fromEither (Right y) = M.replicate cmp (Sz m) y

    -- Variables and constants do not depend on the parameters: zero tape.
    alg (Var ix) = do
        tape <- M.newMArray (Sz2 m p) 0 >>= UMA.unsafeFreeze cmp
        pure (Left (xss <! ix), tape)
    alg (Const c) = do
        tape <- M.newMArray (Sz2 m p) 0 >>= UMA.unsafeFreeze cmp
        pure (Right c, tape)
    -- d theta_ix / d theta_j is 1 when j == ix and 0 otherwise, for every row.
    alg (Param ix) = do
        tape <- M.makeMArrayS (Sz2 m p) (\(_ :. j) -> pure $ if j == ix then 1 else 0)
                  >>= UMA.unsafeFreeze cmp
        pure (Right (theta ! ix), tape)
    -- Chain rule: scale row i of the child's tape by f'(t_i).
    alg (Uni f (t, tape')) = do
        let y = computeAs S . fromEither $ applyDer f t
        tape <- UMA.unsafeThaw tape'
        forM_ [0 .. m - 1] $ \i -> do
            let yi = y ! i
            forM_ [0 .. p - 1] $ \j -> do
                v <- UMA.unsafeRead tape (i :. j)
                UMA.unsafeWrite tape (i :. j) (yi * v)
        tapeF <- UMA.unsafeFreeze cmp tape
        pure (applyUni f t, tapeF)
    -- Combine the tapes of both children; the result overwrites the left tape.
    alg (Bin op (l, tl') (r, tr')) = do
        tl <- UMA.unsafeThaw tl'
        tr <- UMA.unsafeThaw tr'
        -- materialize delayed vectors once so they can be indexed cheaply
        let l' = first (computeAs S) l
            r' = first (computeAs S) r
        forM_ [0 .. m - 1] $ \i -> do
            let li = l' !?? i
                ri = r' !?? i
            forM_ [0 .. p - 1] $ \j -> do
                vl <- UMA.unsafeRead tl (i :. j)
                vr <- UMA.unsafeRead tr (i :. j)
                UMA.unsafeWrite tl (i :. j) $ case op of
                    Add      -> vl + vr
                    Sub      -> vl - vr
                    Mul      -> vl * ri + vr * li
                    Div      -> (vl * ri - vr * li) / ri ^ 2
                    -- When the exponent does not depend on this parameter
                    -- (vr == 0) the log term contributes nothing; skipping it
                    -- avoids 0 * log li turning into NaN for li <= 0.
                    Power    -> li ** (ri - 1) * (ri * vl + (if vr == 0 then 0 else li * log li * vr))
                    PowerAbs -> (abs li ** ri) * (vr * log (abs li) + ri * vl / li)
                    AQ       -> ((1 + ri * ri) * vl - li * ri * vr) / (1 + ri * ri) ** 1.5
        tlF <- UMA.unsafeFreeze cmp tl
        pure (applyBin op l r, tlF)

    -- Sequence the effects stored in a node, left child before right child.
    lToR (Var ix)       = pure (Var ix)
    lToR (Param ix)     = pure (Param ix)
    lToR (Const c)      = pure (Const c)
    lToR (Uni f mt)     = Uni f <$> mt
    lToR (Bin op ml mr) = Bin op <$> ml <*> mr
forwardModeUnique :: SRMatrix -> PVector -> SRVector -> Fix SRTree -> (SRVector, Array S Ix1 Double)
forwardModeUnique :: Array S Ix2 Double
-> Array S Ix1 Double
-> SRVector
-> Fix SRTree
-> (SRVector, Array S Ix1 Double)
forwardModeUnique Array S Ix2 Double
xss Array S Ix1 Double
theta SRVector
err = (DList SRVector -> Array S Ix1 Double)
-> (SRVector, DList SRVector) -> (SRVector, Array S Ix1 Double)
forall b c a. (b -> c) -> (a, b) -> (a, c)
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second ([SRVector] -> Array S Ix1 Double
forall {r}. Manifest r Double => [SRVector] -> Vector r Double
toGrad ([SRVector] -> Array S Ix1 Double)
-> (DList SRVector -> [SRVector])
-> DList SRVector
-> Array S Ix1 Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DList SRVector -> [SRVector]
forall a. DList a -> [a]
DL.toList) ((SRVector, DList SRVector) -> (SRVector, Array S Ix1 Double))
-> (Fix SRTree -> (SRVector, DList SRVector))
-> Fix SRTree
-> (SRVector, Array S Ix1 Double)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SRTree (SRVector, DList SRVector) -> (SRVector, DList SRVector))
-> Fix SRTree -> (SRVector, DList SRVector)
forall (f :: * -> *) a. Functor f => (f a -> a) -> Fix f -> a
cata SRTree (SRVector, DList SRVector) -> (SRVector, DList SRVector)
alg
where
(Sz Ix1
n) = Array S Ix1 Double -> Sz Ix1
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size Array S Ix1 Double
theta
one :: SRVector
one = Array S Ix2 Double -> Double -> SRVector
replicateAs Array S Ix2 Double
xss Double
1
toGrad :: [SRVector] -> Vector r Double
toGrad [SRVector]
grad = Comp -> [Double] -> Vector r Double
forall r e. Manifest r e => Comp -> [e] -> Vector r e
M.fromList (Array S Ix2 Double -> Comp
forall r ix e. Strategy r => Array r ix e -> Comp
forall ix e. Array S ix e -> Comp
getComp Array S Ix2 Double
xss) [SRVector
g SRVector -> SRVector -> Double
forall r e.
(Numeric r e, Source r e) =>
Vector r e -> Vector r e -> e
!.! SRVector
err | SRVector
g <- [SRVector]
grad]
alg :: SRTree (SRVector, DList SRVector) -> (SRVector, DList SRVector)
alg (Var Ix1
ix) = (Array S Ix2 Double
xss Array S Ix2 Double -> Ix1 -> Array D (Lower Ix2) Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
Array r ix e -> Ix1 -> Array D (Lower ix) e
<! Ix1
ix, DList SRVector
forall a. DList a
DL.empty)
alg (Param Ix1
ix) = (Array S Ix2 Double -> Double -> SRVector
replicateAs Array S Ix2 Double
xss (Double -> SRVector) -> Double -> SRVector
forall a b. (a -> b) -> a -> b
$ Array S Ix1 Double
theta Array S Ix1 Double -> Ix1 -> Double
forall r ix e.
(HasCallStack, Manifest r e, Index ix) =>
Array r ix e -> ix -> e
! Ix1
ix, SRVector -> DList SRVector
forall a. a -> DList a
DL.singleton SRVector
one)
alg (Const Double
c) = (Array S Ix2 Double -> Double -> SRVector
replicateAs Array S Ix2 Double
xss Double
c, DList SRVector
forall a. DList a
DL.empty)
alg (Uni Function
f (SRVector
v, DList SRVector
gs)) = let v' :: SRVector
v' = Function -> SRVector -> SRVector
forall a. Floating a => Function -> a -> a
evalFun Function
f SRVector
v
dv :: SRVector
dv = Function -> SRVector -> SRVector
forall a. Floating a => Function -> a -> a
derivative Function
f SRVector
v
in (SRVector
v', (SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
dv) DList SRVector
gs)
alg (Bin Op
Add (SRVector
v1, DList SRVector
l) (SRVector
v2, DList SRVector
r)) = (SRVector
v1SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
+SRVector
v2, DList SRVector -> DList SRVector -> DList SRVector
forall a. DList a -> DList a -> DList a
DL.append DList SRVector
l DList SRVector
r)
alg (Bin Op
Sub (SRVector
v1, DList SRVector
l) (SRVector
v2, DList SRVector
r)) = (SRVector
v1SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
-SRVector
v2, DList SRVector -> DList SRVector -> DList SRVector
forall a. DList a -> DList a -> DList a
DL.append DList SRVector
l ((SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map SRVector -> SRVector
forall a. Num a => a -> a
negate DList SRVector
r))
alg (Bin Op
Mul (SRVector
v1, DList SRVector
l) (SRVector
v2, DList SRVector
r)) = (SRVector
v1SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
v2, DList SRVector -> DList SRVector -> DList SRVector
forall a. DList a -> DList a -> DList a
DL.append ((SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
v2) DList SRVector
l) ((SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
v1) DList SRVector
r))
alg (Bin Op
Div (SRVector
v1, DList SRVector
l) (SRVector
v2, DList SRVector
r)) = let dv :: SRVector
dv = ((-SRVector
v1)SRVector -> SRVector -> SRVector
forall a. Fractional a => a -> a -> a
/(SRVector
v2SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
v2))
in (SRVector
v1SRVector -> SRVector -> SRVector
forall a. Fractional a => a -> a -> a
/SRVector
v2, DList SRVector -> DList SRVector -> DList SRVector
forall a. DList a -> DList a -> DList a
DL.append ((SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Fractional a => a -> a -> a
/SRVector
v2) DList SRVector
l) ((SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
dv) DList SRVector
r))
alg (Bin Op
Power (SRVector
v1, DList SRVector
l) (SRVector
v2, DList SRVector
r)) = let dv1 :: SRVector
dv1 = SRVector
v1 SRVector -> SRVector -> SRVector
forall a. Floating a => a -> a -> a
** (SRVector
v2 SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
- SRVector
one)
dv2 :: SRVector
dv2 = SRVector
v1 SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* SRVector -> SRVector
forall a. Floating a => a -> a
log SRVector
v1
in (SRVector
v1 SRVector -> SRVector -> SRVector
forall a. Floating a => a -> a -> a
** SRVector
v2, (SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
dv1) (DList SRVector -> DList SRVector -> DList SRVector
forall a. DList a -> DList a -> DList a
DL.append ((SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
v2) DList SRVector
l) ((SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
dv2) DList SRVector
r)))
alg (Bin Op
PowerAbs (SRVector
v1, DList SRVector
l) (SRVector
v2, DList SRVector
r)) = let dv1 :: SRVector
dv1 = SRVector -> SRVector
forall a. Num a => a -> a
abs SRVector
v1 SRVector -> SRVector -> SRVector
forall a. Floating a => a -> a -> a
** SRVector
v2
dv2 :: DList SRVector
dv2 = (SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
* (SRVector -> SRVector
forall a. Floating a => a -> a
log (SRVector -> SRVector
forall a. Num a => a -> a
abs SRVector
v1))) DList SRVector
r
dv3 :: DList SRVector
dv3 = (SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*(SRVector
v2 SRVector -> SRVector -> SRVector
forall a. Fractional a => a -> a -> a
/ SRVector
v1)) DList SRVector
l
in (SRVector -> SRVector
forall a. Num a => a -> a
abs SRVector
v1 SRVector -> SRVector -> SRVector
forall a. Floating a => a -> a -> a
** SRVector
v2, (SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
dv1) (DList SRVector -> DList SRVector -> DList SRVector
forall a. DList a -> DList a -> DList a
DL.append DList SRVector
dv2 DList SRVector
dv3))
alg (Bin Op
AQ (SRVector
v1, DList SRVector
l) (SRVector
v2, DList SRVector
r)) = let dv1 :: DList SRVector
dv1 = (SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*(SRVector
1 SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
+ SRVector
v2SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
v2)) DList SRVector
l
dv2 :: DList SRVector
dv2 = (SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*(-SRVector
v1SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
v2)) DList SRVector
r
in (SRVector
v1SRVector -> SRVector -> SRVector
forall a. Fractional a => a -> a -> a
/SRVector -> SRVector
forall a. Floating a => a -> a
sqrt(SRVector
1 SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
+ SRVector
v2SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
v2), (SRVector -> SRVector) -> DList SRVector -> DList SRVector
forall a b. (a -> b) -> DList a -> DList b
DL.map (SRVector -> SRVector -> SRVector
forall a. Fractional a => a -> a -> a
/(SRVector
1 SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
+ SRVector
v2SRVector -> SRVector -> SRVector
forall a. Num a => a -> a -> a
*SRVector
v2)SRVector -> SRVector -> SRVector
forall a. Floating a => a -> a -> a
**SRVector
1.5) (DList SRVector -> DList SRVector)
-> DList SRVector -> DList SRVector
forall a b. (a -> b) -> a -> b
$ DList SRVector -> DList SRVector -> DList SRVector
forall a. DList a -> DList a -> DList a
DL.append DList SRVector
dv1 DList SRVector
dv2)
-- | Base functor of a small annotated tree: every node carries a payload of
-- type @a@ and has zero, one or two recursive positions of type @b@.
-- 'reverseModeUnique' uses it to keep the forward-pass value of every
-- sub-expression next to the values of its children.
data TupleF a b
  = Single a      -- ^ node without children
  | T a b         -- ^ node with exactly one child
  | Branch a b b  -- ^ node with a left and a right child
  deriving (Functor)
-- | Fixed point of 'TupleF': a tree in which every node stores a value of
-- type @a@ and has zero ('Single'), one ('T') or two ('Branch') sub-trees.
type Tuple a = Fix (TupleF a)
-- | Reverse-mode differentiation of an expression tree.
--
-- Arguments, in order:
--
--   * @xss@   - data matrix; @Var ix@ reads the slice @xss <! ix@
--   * @theta@ - parameter vector; @Param ix@ reads @theta ! ix@
--   * @ys@    - vector subtracted from @f v@ to build the weighting vector
--   * @f@     - function applied to the evaluated tree before subtracting @ys@
--   * @t@     - the expression tree
--
-- Returns @(v, j)@ where @v@ is the tree evaluated on @xss@ / @theta@
-- (before @f@ is applied) and @j@ has one slot per parameter
-- (@countParams t@ slots).  Slot @ix@ holds the dot product between the
-- partial derivative of the tree with respect to @Param ix@ and
-- @err = f v - ys@.
--
-- Each @Param ix@ node /overwrites/ slot @ix@, so if the same index occurs
-- more than once only the contribution written last is kept.
--
-- The result is wrapped in 'unsafePerformIO': the mutable array is created,
-- filled and frozen inside this single call and never escapes unfrozen.
reverseModeUnique :: SRMatrix
                  -> PVector
                  -> SRVector
                  -> (SRVector -> SRVector)
                  -> Fix SRTree
                  -> (Array D Ix1 Double, Array S Ix1 Double)
reverseModeUnique xss theta ys f t = unsafePerformIO $
  do jacob <- M.newMArray (Sz p) 0
     -- The bang forces the whole traversal, which is what triggers the
     -- writes performed by 'combine'; the numeric result itself is unused.
     let !_ = accu reverse (combine jacob) t ((Right 1), fwdMode)
     j <- freezeS jacob
     pure (v, j)
  where
    -- Forward pass: a tree mirroring @t@ that caches every partial value.
    fwdMode = cata forward t
    -- Value of the whole expression, broadcast to a vector if it is scalar.
    v = fromEither $ getTop fwdMode
    -- Weights used when collapsing each parameter's derivative to a scalar.
    err = f v - ys
    (Sz2 m _) = M.size xss
    p = countParams t

    -- Scalars ('Right') are replicated to the number of rows of @xss@.
    fromEither (Left x)  = x
    fromEither (Right x) = M.replicate (getComp xss) (Sz1 m) x

    -- Smart constructors for the cached-value tree.
    oneTpl x     = Fix $ Single x
    tuple x y    = Fix $ T x y
    branch x y z = Fix $ Branch x y z

    -- Payload stored at the root of a cached-value tree.
    getTop (Fix (Single x))     = x
    getTop (Fix (T x _))        = x
    getTop (Fix (Branch x _ _)) = x

    -- Partial on purpose: 'forward' builds a 'T' for every 'Uni' node and a
    -- 'Branch' for every 'Bin' node, and 'reverse' only calls these on the
    -- matching node kind.
    unCons (Fix (T _ y)) = y
    getBranches (Fix (Branch _ y z)) = (y, z)

    -- Algebra of the forward pass.
    forward (Var ix)   = oneTpl (Left $ xss <! ix)
    forward (Param ix) = oneTpl (Right $ theta ! ix)
    forward (Const c)  = oneTpl (Right c)
    forward (Uni g child) =
      let val = getTop child
      in tuple (applyUni g val) child
    forward (Bin op l r) =
      let vl = getTop l
          vr = getTop r
      in branch (applyBin op vl vr) l r

    -- Top-down step: push the incoming adjoint @dx@ to the children,
    -- together with the cached forward sub-tree of each child.
    reverse (Var ix)   (_, _) = Var ix
    reverse (Param ix) (_, _) = Param ix
    reverse (Const c)  (_, _) = Const c
    reverse (Uni g child) (dx, unCons -> cached) =
      let g' = applyDer g (getTop cached)
      in Uni g (child, (applyBin Mul dx g', cached))
    reverse (Bin op l r) (dx, getBranches -> (vl, vr)) =
      let (dxl, dxr) = diff op dx (getTop vl) (getTop vr)
      in Bin op (l, (dxl, vl)) (r, (dxr, vr))

    -- @diff op dx fx gy@: adjoints of the left and right operands of
    -- @fx `op` gy@ given the adjoint @dx@ of the result.
    diff Add dx _ _ = (dx, dx)
    diff Sub dx _ _ = (dx, negate' dx)
    diff Mul dx fx gy = (applyBin Mul dx gy, applyBin Mul dx fx)
    diff Div dx fx gy =
      ( applyBin Div dx gy
      , applyBin Mul dx (applyBin Div (negate' fx) (applyBin Mul gy gy))
      )
    diff Power dx fx gy =
      let dxl = applyBin Mul dx (applyBin Power fx (applyBin Sub gy (Right 1)))
          dv2 = applyBin Mul fx (applyUni Log fx)
      in (applyBin Mul dxl gy, applyBin Mul dxl dv2)
    diff PowerAbs dx fx gy =
      let dxl = applyBin Mul (applyBin Mul gy fx)
                             (applyBin PowerAbs fx (applyBin Sub gy (Right 2)))
          dxr = applyBin Mul (applyUni LogAbs fx) (applyBin PowerAbs fx gy)
      in (applyBin Mul dxl dx, applyBin Mul dxr dx)
    diff AQ dx fx gy =
      -- AQ is fx / sqrt (1 + gy^2), hence
      --   d/dfx =  1 / sqrt (1 + gy^2)
      --   d/dgy = -fx * gy / (1 + gy^2) ** 1.5
      -- The minus sign on the second partial matches the forward-mode rule.
      let root = applyUni Sqrt (applyBin Add (applyUni Square gy) (Right 1))
          dxl  = applyUni Recip root
          dxr  = negate' (applyBin Div (applyBin Mul fx gy) (applyUni Cube root))
      in (applyBin Mul dxl dx, applyBin Mul dxr dx)

    -- Bottom-up step: at a parameter leaf, collapse the adjoint against
    -- @err@ and store it in slot @ix@ of the mutable array @j@.
    combine _ (Var _)   _ = 0
    combine _ (Const _) _ = 0
    combine j (Param ix) s = unsafePerformIO $ do
      case fst s of
        -- vector adjoint: dot product with err
        Left dv -> do dv' <- dotM dv err
                      UMA.unsafeWrite j ix dv'
        -- scalar adjoint: the same scalar multiplies every entry of err
        Right c -> UMA.unsafeWrite j ix $ M.foldrS (\x acc -> x*c + acc) 0 err
      UMA.unsafeRead j ix
    combine _ (Uni _ gs) _ = gs
    combine _ (Bin _ l r) _ = l+r
-- | Forward-mode Jacobian of an expression with respect to its parameters.
--
-- The tree is folded bottom-up.  Every node produces a pair: the node's value
-- for each sample, and a difference list with one derivative column per
-- 'Param' leaf found below it.  The columns are returned as computed storable
-- vectors.
--
-- Arguments:
--
-- * @xss@   - the data matrix; @'Var' ix@ selects @xss <! ix@.
-- * @theta@ - the parameter vector; @'Param' ix@ reads @theta ! ix@.
--
-- Columns are emitted in left-to-right leaf order, one per 'Param'
-- occurrence.  NOTE(review): the \"Unique\" in the name suggests each
-- parameter is expected to occur at most once in the tree -- confirm with
-- the callers.
forwardModeUniqueJac :: SRMatrix -> PVector -> Fix SRTree -> [PVector]
forwardModeUniqueJac xss theta = map (M.computeAs M.S) . DL.toList . snd . cata alg
  where
    -- Constant vector of ones built from @xss@; it seeds every parameter column.
    one :: SRVector
    one = replicateAs xss 1

    alg :: SRTree (SRVector, DL.DList SRVector) -> (SRVector, DL.DList SRVector)
    alg (Var ix)   = (xss <! ix, DL.empty)
    alg (Param ix) = (replicateAs xss $ theta ! ix, DL.singleton one)
    alg (Const c)  = (replicateAs xss c, DL.empty)

    -- Chain rule: scale every column by f'(v).
    alg (Uni f (v, gs)) =
      let dv = derivative f v
       in (evalFun f v, DL.map (* dv) gs)

    alg (Bin Add (v1, l) (v2, r)) = (v1 + v2, DL.append l r)
    alg (Bin Sub (v1, l) (v2, r)) = (v1 - v2, DL.append l (DL.map negate r))
    alg (Bin Mul (v1, l) (v2, r)) =
      (v1 * v2, DL.append (DL.map (* v2) l) (DL.map (* v1) r))

    -- d(v1/v2) = dl / v2 - v1 / v2^2 * dr
    alg (Bin Div (v1, l) (v2, r)) =
      let dv = (-v1) / (v2 * v2)
       in (v1 / v2, DL.append (DL.map (/ v2) l) (DL.map (* dv) r))

    -- d(v1 ** v2) = v1 ** (v2 - 1) * (v2 * dl + v1 * log v1 * dr)
    alg (Bin Power (v1, l) (v2, r)) =
      let dv1 = v1 ** (v2 - one)
          dv2 = v1 * log v1
       in ( v1 ** v2
          , DL.map (* dv1) (DL.append (DL.map (* v2) l) (DL.map (* dv2) r))
          )

    -- d(|v1| ** v2) = |v1| ** v2 * (v2 / v1 * dl + log |v1| * dr)
    -- The left operand's columns come first, as for every other operator, so
    -- the column order keeps following the left-to-right parameter order.
    alg (Bin PowerAbs (v1, l) (v2, r)) =
      let pw = abs v1 ** v2
          dl = DL.map (* (v2 / v1)) l
          dr = DL.map (* log (abs v1)) r
       in (pw, DL.map (* pw) (DL.append dl dr))

    -- AQ v1 v2 = v1 / sqrt (1 + v2^2)
    -- d = ((1 + v2^2) * dl - v1 * v2 * dr) / (1 + v2^2) ** 1.5
    alg (Bin AQ (v1, l) (v2, r)) =
      let q  = 1 + v2 * v2
          dl = DL.map (* q) l
          dr = DL.map (* (-v1 * v2)) r
       in (v1 / sqrt q, DL.map (/ q ** 1.5) $ DL.append dl dr)