module LevMar.Utils.AD where
import Data.Derivative ( (:~>), (:>), powVal, idD, pureD, derivAtBasis )
import Data.VectorSpace ( VectorSpace, Scalar, AdditiveGroup )
import Data.Basis ( HasBasis, Basis )
import Data.MemoTrie ( HasTrie )
-- | Extract the plain value (the zeroth element of the derivative tower)
-- of a differentiable function.
--
-- The function is applied to 'undefined'. This is only safe when the
-- function never inspects its argument, as functions produced by
-- 'constant' (@const . pureD@) do. A function that looks at its input
-- will hit 'undefined' when the result is forced.
value :: a :~> b -> b
value m = powVal (m undefined)
-- | Value of the first derivative of a derivative tower over a domain
-- whose basis is the single direction @()@ (the @Basis a ~ ()@
-- constraint).
--
-- Takes the derivative along basis element @()@ with 'derivAtBasis',
-- then reads off its value with 'powVal'.
firstDeriv
  :: (HasBasis a, Basis a ~ (), AdditiveGroup b)
  => (a :> b)
  -> b
firstDeriv f = powVal (derivAtBasis f ())
-- | Lift a plain value into a differentiable function that ignores its
-- input and always yields the constant tower 'pureD' of that value.
constant
  :: (AdditiveGroup b, HasBasis a, HasTrie (Basis a))
  => b
  -> a :~> b
constant b _ = pureD b
-- | Turn every point of the list into a 'constant' differentiable
-- function, except the one at zero-based index @n@, which becomes the
-- identity 'idD'.
--
-- If @n@ is negative or past the end of the list, every element stays
-- constant (this is the behaviour of 'replace').
-- NOTE(review): presumably used to seed differentiation with respect to
-- the n-th coordinate; confirm against callers.
idDAt
  :: (HasBasis r, HasTrie (Basis r), VectorSpace (Scalar r))
  => Int
  -> [r]
  -> [r :~> r]
idDAt n xs = replace n idD (constant <$> xs)
-- | @replace i r xs@ replaces the element at zero-based index @i@ of
-- @xs@ with @r@.
--
-- A negative index, or an index at or beyond the end of the list, leaves
-- the list unchanged. The prefix before index @i@ is produced lazily.
replace :: Int -> a -> [a] -> [a]
replace i r xs
  | i < 0     = xs
  | otherwise = rep i xs
  where
    -- Walk down the list, counting the remaining distance to index i.
    -- The counter starts at i >= 0 and only decreases while positive,
    -- so it never goes negative.
    rep _ [] = []
    rep j (y : ys)
      | j > 0     = y : rep (j - 1) ys
      | otherwise = r : ys