{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}

module LevMar.Utils.AD where

import Data.Derivative  ( (:~>), (:>), powVal, idD, pureD, derivAtBasis )
import Data.VectorSpace ( VectorSpace, Scalar, AdditiveGroup )
import Data.Basis       ( HasBasis, Basis )
import Data.MemoTrie    ( HasTrie )


-- | Extract the value of a /constant/ infinitely differentiable function.
--
-- The function is applied to a placeholder argument that must never be
-- evaluated, so this is only meaningful for functions that ignore their
-- input, such as those built with 'constant'. If the function does
-- inspect its argument, forcing the result raises an 'error' naming this
-- function instead of an anonymous @Prelude.undefined@.
value :: a :~> b -> b
value m = powVal (m (error msg))
  where
    msg = "LevMar.Utils.AD.value: the argument was forced; "
       ++ "only constant functions (e.g. 'constant x') are supported"

-- | Value of the first derivative stored in the derivative tower @f@.
--
-- The domain type must have the unit type @()@ as its basis, i.e. a single
-- basis direction. The tower's derivative is taken along that one basis
-- element ('derivAtBasis'), and the leading value of the resulting tower
-- is returned ('powVal').
firstDeriv :: (HasBasis a, Basis a ~ (), AdditiveGroup b)
           => (a :> b) -> b
firstDeriv = powVal . (`derivAtBasis` ())

-- | Lift a value into an infinitely differentiable function that ignores
-- its argument. The same tower, built once with 'pureD', is returned for
-- every input.
constant :: (AdditiveGroup b, HasBasis a, HasTrie (Basis a))
         => b -> a:~>b
constant x = \_ -> tower
  where
    tower = pureD x

-- | @idDAt n ps@ turns every parameter in @ps@ into a /constant/
-- infinitely differentiable function (via 'constant'). The exception is
-- the parameter at zero-based position @n@, which becomes the
-- differentiable /identity/ function 'idD'.
--
-- Positions are resolved by 'replace'. If @n@ is negative or not smaller
-- than the length of @ps@, every element is constant.
idDAt :: (HasBasis r, HasTrie (Basis r), VectorSpace (Scalar r))
      => Int -> [r] -> [r :~> r]
idDAt n ps = replace n idD constants
  where
    constants = map constant ps

-- | @replace i r xs@ puts @r@ in place of the element at zero-based
-- position @i@ in @xs@.
--
-- A negative @i@ returns @xs@ itself. An @i@ at or beyond the end of the
-- list yields a list equal to @xs@. The traversal is lazy: elements before
-- @i@ are produced on demand, and the tail after the replaced element is
-- shared with the input.
replace :: Int -> a -> [a] -> [a]
replace i r xs
  | i < 0     = xs
  | otherwise = go i xs
  where
    -- Counts down from i (never below 0) until the target cell is reached.
    go _ []         = []
    go 0 (_ : rest) = r : rest
    go k (y : rest) = y : go (k - 1) rest