module Numeric.AD.Newton
(
findZero
, inverse
, fixedPoint
, extremum
, gradientDescent
, AD(..)
, Mode(..)
) where
import Prelude hiding (all)
import Numeric.AD.Classes
import Numeric.AD.Internal
import Data.Foldable (all)
import Data.Traversable (Traversable)
import Numeric.AD.Forward (diff, diff2)
import Numeric.AD.Reverse (grad2)
findZero :: Fractional a => (forall s. Mode s => AD s a -> AD s a) -> a -> [a]
findZero f x0 = iterate (\x -> let (y,y') = diff2 f x in x y/y') x0
inverse :: Fractional a => (forall s. Mode s => AD s a -> AD s a) -> a -> a -> [a]
inverse f x0 y = findZero (\x -> f x lift y) x0
fixedPoint :: Fractional a => (forall s. Mode s => AD s a -> AD s a) -> a -> [a]
fixedPoint f = findZero (\x -> f x x)
extremum :: Fractional a => (forall t s. (Mode t, Mode s) => AD t (AD s a) -> AD t (AD s a)) -> a -> [a]
extremum f x0 = findZero (diff f) x0
gradientDescent :: (Traversable f, Fractional a, Ord a) => (forall s. Mode s => f (AD s a) -> AD s a) -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 gx0 0.1 (0 :: Int)
where
(fx0, gx0) = grad2 f x0
go x fx gx !eta !i
| eta == 0 = []
| fx1 > fx = go x fx gx (eta/2) 0
| all (==0) gx = []
| otherwise = x1 : if i == 10
then go x1 fx1 gx1 (eta*2) 0
else go x1 fx1 gx1 eta (i+1)
where
x1 = zipWithT (\xi gxi -> xi eta * gxi) x gx
(fx1, gx1) = grad2 f x1