{-# LANGUAGE Rank2Types, BangPatterns, ScopedTypeVariables #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.AD.Newton
-- Copyright   :  (c) Edward Kmett 2010
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Newton
  (
  -- * Newton's Method (Forward AD)
    findZero
  , findZeroM
  , inverse
  , inverseM
  , fixedPoint
  , fixedPointM
  , extremum
  , extremumM
  -- * Gradient Ascent/Descent (Reverse AD)
  , gradientDescent
  , gradientDescentM
  , gradientAscent
  , gradientAscentM
  -- * Exposed Types
  , UU, UF, FU, FF
  , AD(..)
  , Mode(..)
  ) where

import Prelude hiding (all)
import Control.Monad (liftM)
import Data.MList
import Numeric.AD.Internal
import Data.Foldable (all)
import Data.Traversable (Traversable)
import Numeric.AD.Forward (diff, diff', diffM, diffM')
import Numeric.AD.Reverse (gradWith', gradWithM')
import Numeric.AD.Internal.Composition

-- | The 'findZero' function finds a zero of a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.)
--
-- The first element of the stream is the initial guess itself; every
-- following element is obtained from its predecessor @x@ as
-- @x - f x / f' x@.
--
-- Examples:
--
-- > take 10 $ findZero (\x -> x^2 - 4) 1  -- converges to 2.0
--
-- > import Data.Complex
-- > take 10 $ findZero ((+1).(^2)) (1 :+ 1)  -- converges to (0 :+ 1)
--
findZero :: Fractional a => UU a -> a -> [a]
findZero f = go
  where
    go x = x : go (x - y/y')
      where
        -- value and first derivative of f at x, computed in one pass
        (y,y') = diff' f x
{-# INLINE findZero #-}

-- | Monadic counterpart of 'findZero': the function is evaluated (together
-- with its derivative) through 'diffM'', and the successive approximations
-- are delivered as an 'MList'.  The initial guess is the first element.
findZeroM :: (Monad m, Fractional a) => UF m a -> a -> MList m a
findZeroM f x0 = MList (go x0)
  where
    go x = return $ MCons x $ MList $ do
      (y,y') <- diffM' f x
      go (x - y/y')
{-# INLINE findZeroM #-}

-- | The 'inverse' function inverts a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.)
--
-- Example:
--
-- > take 10 $ inverse sqrt 1 (sqrt 10)  -- converges to 10
--
inverse :: Fractional a => UU a -> a -> a -> [a]
inverse f x0 y = findZero (\x -> f x - lift y) x0
{-# INLINE inverse #-}

-- | Monadic counterpart of 'inverse': runs 'findZeroM' on the monadic
-- function with the (lifted) target value subtracted from its result.
inverseM :: (Monad m, Fractional a) => UF m a -> a -> a -> MList m a
inverseM f x0 y = findZeroM (\x -> subtract (lift y) `liftM` f x) x0
{-# INLINE inverseM #-}

-- | The 'fixedPoint' function finds a fixed point of a scalar
-- function using Newton's method; its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
--
-- > take 10 $ fixedPoint cos 1  -- converges to 0.7390851332151607
fixedPoint :: Fractional a => UU a -> a -> [a]
fixedPoint f = findZero (\x -> f x - x)
{-# INLINE fixedPoint #-}

-- | Monadic counterpart of 'fixedPoint': runs 'findZeroM' on the monadic
-- function with its argument subtracted from its result.
fixedPointM :: (Monad m, Fractional a) => UF m a -> a -> MList m a
fixedPointM f = findZeroM (\x -> subtract x `liftM` f x)
{-# INLINE fixedPointM #-}

-- | The 'extremum' function finds an extremum of a scalar
-- function using Newton's method; produces a stream of increasingly
-- accurate results.  (Modulo the usual caveats.)
--
-- It works by running 'findZero' on the derivative of the function.
--
-- > take 10 $ extremum cos 1  -- converges to 0
extremum :: Fractional a => UU a -> a -> [a]
extremum f = findZero (diff (decomposeMode . f . composeMode))
{-# INLINE extremum #-}

-- | Monadic counterpart of 'extremum': runs 'findZeroM' on the derivative
-- (taken with 'diffM') of the monadic function.
extremumM :: (Monad m, Fractional a) => UF m a -> a -> MList m a
extremumM f = findZeroM (diffM (liftM decomposeMode . f . composeMode))
{-# INLINE extremumM #-}

-- | The 'gradientDescent' function performs a multivariate
-- optimization, based on the naive-gradient-descent in the file
-- @stalingrad\/examples\/flow-tests\/pre-saddle-1a.vlad@ from the
-- VLAD compiler Stalingrad sources.  Its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
--
-- It uses reverse mode automatic differentiation to compute the gradient.
gradientDescent :: (Traversable f, Fractional a, Ord a) => FU f a -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    (fx0, xgx0) = gradWith' (,) f x0
    -- Loop state: current point @x@, objective value @fx@ at that point,
    -- @xgx@ (each coordinate paired with its gradient component), the step
    -- size @eta@ (initially 0.1) and @i@, the number of steps accepted since
    -- @eta@ was last changed.  The starting point itself is never emitted;
    -- only accepted steps are.
    go x fx xgx !eta !i
      | eta == 0     = [] -- step size is 0
      | fx1 > fx     = go x fx xgx (eta/2) 0 -- we stepped too far
      | zeroGrad xgx = [] -- gradient is 0
      | otherwise    = x1 : if i == 10
                            then go x1 fx1 xgx1 (eta*2) 0 -- steady progress: widen the step
                            else go x1 fx1 xgx1 eta (i+1)
      where
        zeroGrad = all (\(_,g) -> g == 0)
        -- candidate point: one step of size eta against the gradient
        x1 = fmap (\(xi,gxi) -> xi - eta * gxi) xgx
        (fx1, xgx1) = gradWith' (,) f x1
{-# INLINE gradientDescent #-}

-- | Gradient ascent: runs 'gradientDescent' on the negated objective.
gradientAscent :: (Traversable f, Fractional a, Ord a) => FU f a -> f a -> [f a]
gradientAscent f = gradientDescent (negate . f)
{-# INLINE gradientAscent #-}

-- | Monadic gradient descent.  Follows the same step-size scheme as
-- 'gradientDescent' (start at 0.1, halve when the objective increases,
-- double once the accepted-step counter reaches 10), but evaluates the
-- objective and its gradient through 'gradWithM'' and delivers the
-- accepted points as an 'MList'.
gradientDescentM :: (Traversable f, Monad m, Fractional a, Ord a) => FF f m a -> f a -> MList m (f a)
gradientDescentM f x0 = MList $ do
    (fx0, xgx0) <- gradWithM' (,) f x0
    go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    go x fx xgx !eta !i
      | eta == 0  = return MNil -- step size is 0
      | otherwise = do
          -- the objective is evaluated at the candidate point before any
          -- of the checks below are made
          (fx1, xgx1) <- gradWithM' (,) f x1
          case () of
            _ | fx1 > fx     -> go x fx xgx (eta/2) 0 -- we stepped too far
              | zeroGrad xgx -> return MNil -- gradient is 0
              | otherwise    -> return $ MCons x1 $ MList $
                  if i == 10
                    then go x1 fx1 xgx1 (eta*2) 0
                    else go x1 fx1 xgx1 eta (i+1)
      where
        x1 = fmap (\(xi,gxi) -> xi - eta * gxi) xgx
        zeroGrad = all (\(_,g) -> g == 0)
{-# INLINE gradientDescentM #-}

-- | Monadic gradient ascent: runs 'gradientDescentM' on the negated
-- objective.
gradientAscentM :: (Traversable f, Monad m, Fractional a, Ord a) => FF f m a -> f a -> MList m (f a)
gradientAscentM f = gradientDescentM (liftM negate . f)
{-# INLINE gradientAscentM #-}