module Dvda.AD ( backprop
, rad
) where
import Data.Hashable ( Hashable )
import Dvda.Dual hiding ( fad, fad' )
import Dvda.Expr
import Dvda.HashMap ( HashMap )
import qualified Dvda.HashMap as HM
-- | Push a sensitivity through a two-argument operation.
--
-- The partial derivative with respect to each operand is read off the
-- perturbation part of @binop@ applied to dual numbers: the operand being
-- differentiated is seeded with perturbation 1, the other with 0. Each
-- operand is then back-propagated with @sens@ scaled by its partial, and the
-- two resulting lists are concatenated (first operand's entries first).
bpBinary :: (Ord a, Num a)
            => Expr a -> Expr a -> Expr a
            -> (Dual (Expr a) -> Dual (Expr a) -> Dual (Expr a))
            -> [(Expr a, Expr a)]
bpBinary sens g h binop =
  backpropNode (sens * partialWith 1 0) g ++ backpropNode (sens * partialWith 0 1) h
  where
    -- Perturbation of the result when g and h carry the given seeds.
    partialWith seedG seedH = dualPerturbation (binop (Dual g seedG) (Dual h seedH))
-- | Push a sensitivity through a one-argument operation.
--
-- The local derivative is the perturbation part of @unop@ applied to the
-- operand seeded with perturbation 1; the operand is then back-propagated
-- with @sens@ multiplied by that derivative.
bpUnary :: (Ord a, Num a)
           => Expr a -> Expr a
           -> (Dual (Expr a) -> Dual (Expr a))
           -> [(Expr a, Expr a)]
bpUnary sens g unop =
  let seeded    = unop (Dual g 1)
      localGrad = dualPerturbation seeded
  in backpropNode (sens * localGrad) g
-- | Walk an expression from the root towards its leaves, accumulating the
-- sensitivity (first argument) of the root with respect to each node.
--
-- Returns one @(symbol, sensitivity)@ pair per symbol occurrence reached.
-- Entries for the same symbol are /not/ merged here; a symbol that appears
-- several times in the expression yields several pairs. Constants and
-- numeric literals contribute nothing.
backpropNode :: (Ord a, Num a) => Expr a -> Expr a -> [(Expr a, Expr a)]
-- A dependent symbol records its own sensitivity and then continues into the
-- symbol it depends on, scaling by the same symbol with its index bumped to
-- k+1 (presumably the next-order derivative w.r.t. the dependency -- confirm
-- against the definition of SymDependent).
backpropNode sens e@(ESym (SymDependent name k dep_)) = (e, sens) : backpropNode (sens * primal') dep
  where
    primal' = ESym (SymDependent name (k + 1) dep_)
    dep     = ESym dep_
backpropNode sens e@(ESym (Sym _))                = [(e, sens)]
-- Leaves without symbols: nothing to report.
backpropNode _ (EConst _)                         = []
backpropNode _ (ENum (FromInteger _))             = []
backpropNode _ (EFractional (FromRational _))     = []
-- Arithmetic nodes.
backpropNode sens (ENum (Mul x y))                = bpBinary sens x y (*)
backpropNode sens (ENum (Add x y))                = bpBinary sens x y (+)
-- Subtraction must hand bpBinary the (-) operator; passing unit here does
-- not type-check.
backpropNode sens (ENum (Sub x y))                = bpBinary sens x y (-)
backpropNode sens (ENum (Abs x))                  = bpUnary sens x abs
backpropNode sens (ENum (Negate x))               = bpUnary sens x negate
backpropNode sens (ENum (Signum x))               = bpUnary sens x signum
backpropNode sens (EFractional (Div x y))         = bpBinary sens x y (/)
-- Floating-point nodes.
backpropNode sens (EFloating (Pow x y))           = bpBinary sens x y (**)
backpropNode sens (EFloating (LogBase x y))       = bpBinary sens x y logBase
backpropNode sens (EFloating (Exp x))             = bpUnary sens x exp
backpropNode sens (EFloating (Log x))             = bpUnary sens x log
backpropNode sens (EFloating (Sin x))             = bpUnary sens x sin
backpropNode sens (EFloating (Cos x))             = bpUnary sens x cos
backpropNode sens (EFloating (ASin x))            = bpUnary sens x asin
backpropNode sens (EFloating (ATan x))            = bpUnary sens x atan
backpropNode sens (EFloating (ACos x))            = bpUnary sens x acos
backpropNode sens (EFloating (Sinh x))            = bpUnary sens x sinh
backpropNode sens (EFloating (Cosh x))            = bpUnary sens x cosh
backpropNode sens (EFloating (Tanh x))            = bpUnary sens x tanh
backpropNode sens (EFloating (ASinh x))           = bpUnary sens x asinh
backpropNode sens (EFloating (ATanh x))           = bpUnary sens x atanh
backpropNode sens (EFloating (ACosh x))           = bpUnary sens x acosh
-- | Back-propagate from the given expression with an initial sensitivity of
-- 1 and collect the results into a map keyed by symbol. Multiple entries for
-- the same key are combined with '+' via 'HM.fromListWith'.
backprop :: (Num a, Ord a, Hashable a) => Expr a -> HashMap (Expr a) (Expr a)
backprop = HM.fromListWith (+) . backpropNode 1
-- | For each expression in @args@, look up its sensitivity in the map
-- produced by 'backprop' for @x@, in the same order as @args@. An argument
-- that is absent from the map yields 0.
rad :: (Num a, Ord a, Hashable a) => Expr a -> [Expr a] -> [Expr a]
rad x args = [ HM.lookupDefault 0 arg grads | arg <- args ]
  where
    -- Computed once and shared by every lookup.
    grads = backprop x