{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Rank1.Newton
  (
  -- * Newton's Method (Forward)
    findZero
  , findZeroNoEq
  , inverse
  , inverseNoEq
  , fixedPoint
  , fixedPointNoEq
  , extremum
  , extremumNoEq
  -- * Gradient Ascent/Descent (Kahn)
  , gradientDescent
  , gradientAscent
  ) where

import Prelude hiding (all, mapM)
import Data.Foldable (all)
import Numeric.AD.Mode
import Numeric.AD.Rank1.Forward (Forward, diff, diff')
import Numeric.AD.Rank1.Kahn as Kahn (Kahn, gradWith')
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Combinators (takeWhileDifferent)

-- $setup
-- >>> import Data.Complex

-- | The 'findZero' function finds a zero of a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.) If the stream becomes constant
-- ("it converges"), no further elements are returned.
--
-- Examples:
--
-- >>> take 10 $ findZero (\x->x^2-4) 1
-- [1.0,2.5,2.05,2.000609756097561,2.0000000929222947,2.000000000000002,2.0]
--
-- >>> last $ take 10 $ findZero ((+1).(^2)) (1 :+ 1)
-- 0.0 :+ 1.0
findZero :: (Fractional a, Eq a) => (Forward a -> Forward a) -> a -> [a]
-- Same iteration as 'findZeroNoEq', cut off by 'takeWhileDifferent' once
-- the stream stops changing.
findZero f = takeWhileDifferent . findZeroNoEq f
{-# INLINE findZero #-}

-- | The 'findZeroNoEq' function behaves the same as 'findZero' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- The resulting list is produced with 'iterate', so it is infinite and
-- starts with the initial guess itself.
findZeroNoEq :: Fractional a => (Forward a -> Forward a) -> a -> [a]
findZeroNoEq f = iterate go
  where
    -- One Newton step: move from @x@ to @x - f x / f' x@.
    go x = xn
      where
        -- 'diff'' yields a pair; it is used here as the value of @f@ at @x@
        -- together with its derivative there.
        (y, y') = diff' f x
        xn = x - y / y'
{-# INLINE findZeroNoEq #-}

-- | The 'inverse' function inverts a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.) If the stream becomes
-- constant ("it converges"), no further elements are returned.
--
-- The arguments are the function to invert, the initial guess, and the
-- target value whose preimage is sought.
--
-- Example:
--
-- >>> last $ take 10 $ inverse sqrt 1 (sqrt 10)
-- 10.0
inverse :: (Fractional a, Eq a) => (Forward a -> Forward a) -> a -> a -> [a]
-- Point-free in the final (target value) argument: the stream produced by
-- 'inverseNoEq' is truncated by 'takeWhileDifferent'.
inverse f x0 = takeWhileDifferent . inverseNoEq f x0
{-# INLINE inverse #-}

-- | The 'inverseNoEq' function behaves the same as 'inverse' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
--
-- @inverseNoEq f x0 y@ searches, starting from @x0@, for a zero of
-- @\\x -> f x - y@.
inverseNoEq :: Fractional a => (Forward a -> Forward a) -> a -> a -> [a]
-- 'auto' lifts the plain target value @y@ into the 'Forward' mode so that it
-- can be subtracted from @f x@.
inverseNoEq f x0 y = findZeroNoEq (\x -> f x - auto y) x0
{-# INLINE inverseNoEq #-}

-- | The 'fixedPoint' function finds a fixed point of a scalar
-- function using Newton's method; its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
--
-- If the stream becomes constant ("it converges"), no further
-- elements are returned.
--
-- >>> last $ take 10 $ fixedPoint cos 1
-- 0.7390851332151607
fixedPoint :: (Fractional a, Eq a) => (Forward a -> Forward a) -> a -> [a]
-- Same iteration as 'fixedPointNoEq', cut off by 'takeWhileDifferent' once
-- the stream stops changing.
fixedPoint f = takeWhileDifferent . fixedPointNoEq f
{-# INLINE fixedPoint #-}

-- | The 'fixedPointNoEq' function behaves the same as 'fixedPoint' except that
-- it doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
fixedPointNoEq :: Fractional a => (Forward a -> Forward a) -> a -> [a]
-- A fixed point of @f@ is a zero of @\\x -> f x - x@, so the search is
-- delegated to 'findZeroNoEq'.
fixedPointNoEq f = findZeroNoEq (\x -> f x - x)
{-# INLINE fixedPointNoEq #-}

-- | The 'extremum' function finds an extremum of a scalar
-- function using Newton's method; produces a stream of increasingly
-- accurate results.  (Modulo the usual caveats.) If the stream
-- becomes constant ("it converges"), no further elements are returned.
--
-- >>> last $ take 10 $ extremum cos 1
-- 0.0
extremum :: (Fractional a, Eq a) => (On (Forward (Forward a)) -> On (Forward (Forward a))) -> a -> [a]
-- Same iteration as 'extremumNoEq', cut off by 'takeWhileDifferent' once
-- the stream stops changing.
extremum f = takeWhileDifferent . extremumNoEq f
{-# INLINE extremum #-}

-- | The 'extremumNoEq' function behaves the same as 'extremum' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
extremumNoEq :: Fractional a => (On (Forward (Forward a)) -> On (Forward (Forward a))) -> a -> [a]
-- The zero search of 'findZeroNoEq' is applied to @diff@ of the function
-- rather than to the function itself. 'On' and 'off' wrap and unwrap the
-- nested 'Forward' value so that @f@ can be passed to 'diff'.
extremumNoEq f = findZeroNoEq (diff (off . f . On))
{-# INLINE extremumNoEq #-}

-- | The 'gradientDescent' function performs a multivariate
-- optimization, based on the naive-gradient-descent in the file
-- @stalingrad\/examples\/flow-tests\/pre-saddle-1a.vlad@ from the
-- VLAD compiler Stalingrad sources.  Its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
--
-- It uses reverse mode automatic differentiation to compute the gradient.
--
-- The step size starts at @0.1@. It is halved whenever a candidate step
-- would increase the objective, and doubled once the counter of accepted
-- steps reaches 10. The stream ends when the step size or every gradient
-- component compares equal to zero. The starting point itself is not part
-- of the output.
gradientDescent :: (Traversable f, Fractional a, Ord a) => (f (Kahn a) -> Kahn a) -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    -- Objective value at the start, and each coordinate paired with its
    -- gradient component.
    (fx0, xgx0) = Kahn.gradWith' (,) f x0
    -- go point value (coordinate, gradient) pairs, step size, accepted-step counter
    go x fx xgx !eta !i
      | eta == 0     = [] -- step size is 0
      | fx1 > fx     = go x fx xgx (eta / 2) 0 -- we stepped too far
      | zeroGrad xgx = [] -- gradient is 0
      | otherwise    = x1 : if i == 10
                            then go x1 fx1 xgx1 (eta * 2) 0 -- grow the step, reset the counter
                            else go x1 fx1 xgx1 eta (i + 1)
      where
        zeroGrad = all (\(_, g) -> g == 0)
        -- Candidate point: every coordinate moved against its gradient.
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        (fx1, xgx1) = Kahn.gradWith' (,) f x1
{-# INLINE gradientDescent #-}

-- | Perform a gradient ascent using reverse mode automatic differentiation to compute the gradient.
--
-- Implemented by running 'gradientDescent' on the negated objective, so the
-- resulting stream has the same shape as that of 'gradientDescent'.
gradientAscent :: (Traversable f, Fractional a, Ord a) => (f (Kahn a) -> Kahn a) -> f a -> [f a]
gradientAscent f = gradientDescent (negate . f)
{-# INLINE gradientAscent #-}