{-# OPTIONS_GHC -Wall #-}

module Dvda.AD ( backprop
               , rad
               ) where

import Data.Hashable ( Hashable )

import Dvda.Dual hiding ( fad, fad' )
import Dvda.Expr
import Dvda.HashMap ( HashMap )
import qualified Dvda.HashMap as HM

--fad :: Num a => (Dual a -> [Dual a]) -> a -> [a]
--fad f x = map dualPerturbation $ f (Dual x 1)

-- | Push a sensitivity through a two-argument operation.
--
-- The local partial derivative with respect to each operand is read off as
-- the perturbation part of @binop@ applied to dual numbers: the operand of
-- interest is seeded with perturbation 1 and the other with 0. Each partial
-- is multiplied into @sens@ (with @sens@ on the left) and passed on to
-- 'backpropNode' for that operand. Contributions reached through @g@ come
-- before those reached through @h@ in the returned list.
bpBinary :: (Ord a, Num a)
            => Expr a -> Expr a -> Expr a
            -> (Dual (Expr a) -> Dual (Expr a) -> Dual (Expr a))
            -> [(Expr a, Expr a)]
bpBinary sens g h binop =
  backpropNode (sens * partialWith 1 0) g ++ backpropNode (sens * partialWith 0 1) h
  where
    -- Perturbation part of @binop@ for the given seeds on @g@ and @h@.
    partialWith seedG seedH =
      dualPerturbation (binop (Dual g seedG) (Dual h seedH))

-- | Push a sensitivity through a one-argument operation.
--
-- The local derivative is the perturbation part of @unop@ applied to the
-- dual number built from @g@ with perturbation 1. It is multiplied into
-- @sens@ (with @sens@ on the left) and the product is passed on to
-- 'backpropNode' for @g@.
bpUnary :: (Ord a, Num a)
           => Expr a -> Expr a
           -> (Dual (Expr a) -> Dual (Expr a))
           -> [(Expr a, Expr a)]
bpUnary sens g unop =
  let seeded = unop (Dual g 1)
      slope  = dualPerturbation seeded
  in backpropNode (sens * slope) g

-- | Walk an expression, collecting @(symbol, sensitivity)@ pairs.
--
-- The first argument is the sensitivity accumulated so far for the node
-- given as second argument. The same symbol may appear several times in the
-- result; this function does not merge the entries.
--
-- * Symbols are recorded together with the incoming sensitivity.
-- * Constants and numeric literals contribute nothing.
-- * Every operation node is delegated to 'bpUnary' or 'bpBinary' with the
--   matching numeric-class method, which in turn recurse back into here.
backpropNode :: (Ord a, Num a) => Expr a -> Expr a -> [(Expr a, Expr a)]
-- Leaves.
backpropNode sens node@(ESym (Sym _)) = [(node, sens)]
-- A dependent symbol is recorded itself, then the walk continues into the
-- symbol it depends on. The factor is the same dependent symbol with its
-- counter bumped by one (presumably the derivative order -- confirm against
-- the definition in Dvda.Expr).
backpropNode sens node@(ESym (SymDependent nm order parent)) =
  (node, sens) : backpropNode (sens * bumped) (ESym parent)
  where
    bumped = ESym (SymDependent nm (order + 1) parent)
backpropNode _ (EConst _)                         = []
backpropNode _ (ENum (FromInteger _))             = []
backpropNode _ (EFractional (FromRational _))     = []
-- Num operations.
backpropNode sens (ENum (Add lhs rhs))            = bpBinary sens lhs rhs (+)
backpropNode sens (ENum (Sub lhs rhs))            = bpBinary sens lhs rhs (-)
backpropNode sens (ENum (Mul lhs rhs))            = bpBinary sens lhs rhs (*)
backpropNode sens (ENum (Negate arg))             = bpUnary sens arg negate
backpropNode sens (ENum (Abs arg))                = bpUnary sens arg abs
backpropNode sens (ENum (Signum arg))             = bpUnary sens arg signum
-- Fractional operations.
backpropNode sens (EFractional (Div lhs rhs))     = bpBinary sens lhs rhs (/)
-- Floating operations.
backpropNode sens (EFloating (Pow lhs rhs))       = bpBinary sens lhs rhs (**)
backpropNode sens (EFloating (LogBase lhs rhs))   = bpBinary sens lhs rhs logBase
backpropNode sens (EFloating (Exp arg))           = bpUnary sens arg exp
backpropNode sens (EFloating (Log arg))           = bpUnary sens arg log
backpropNode sens (EFloating (Sin arg))           = bpUnary sens arg sin
backpropNode sens (EFloating (Cos arg))           = bpUnary sens arg cos
backpropNode sens (EFloating (ASin arg))          = bpUnary sens arg asin
backpropNode sens (EFloating (ACos arg))          = bpUnary sens arg acos
backpropNode sens (EFloating (ATan arg))          = bpUnary sens arg atan
backpropNode sens (EFloating (Sinh arg))          = bpUnary sens arg sinh
backpropNode sens (EFloating (Cosh arg))          = bpUnary sens arg cosh
backpropNode sens (EFloating (Tanh arg))          = bpUnary sens arg tanh
backpropNode sens (EFloating (ASinh arg))         = bpUnary sens arg asinh
backpropNode sens (EFloating (ACosh arg))         = bpUnary sens arg acosh
backpropNode sens (EFloating (ATanh arg))         = bpUnary sens arg atanh

-- | Build a map of sensitivities for an expression.
--
-- The walk starts at @x@ with an initial sensitivity of 1. The resulting
-- list of @(symbol, sensitivity)@ pairs is turned into a map with
-- @HM.fromListWith (+)@, so entries sharing a key are combined using '+'.
backprop :: (Num a, Ord a, Hashable a) => Expr a -> HashMap (Expr a) (Expr a)
backprop x = HM.fromListWith (+) contributions
  where
    contributions = backpropNode 1 x

-- | Look up the sensitivity of @x@ for each of the given arguments.
--
-- The sensitivity map is computed once with 'backprop'. The result has one
-- entry per element of @args@, in the same order; an argument that is not a
-- key of the map yields 0.
rad :: (Num a, Ord a, Hashable a) => Expr a -> [Expr a] -> [Expr a]
rad x args = [ HM.lookupDefault 0 arg table | arg <- args ]
  where
    table = backprop x