{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE ParallelListComp #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Newton
  (
  -- * Newton's Method (Forward AD)
    findZero
  , findZeroNoEq
  , inverse
  , inverseNoEq
  , fixedPoint
  , fixedPointNoEq
  , extremum
  , extremumNoEq
  -- * Gradient Ascent/Descent (Reverse AD)
  , gradientDescent, constrainedDescent, CC(..), eval
  , gradientAscent
  , conjugateGradientDescent
  , conjugateGradientAscent
  , stochasticGradientDescent
  ) where

import Data.Foldable (all, sum)
import Data.Reflection (Reifies)
import Data.Traversable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Forward (Forward)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Or
import Numeric.AD.Internal.Reverse (Reverse, Tape)
import Numeric.AD.Internal.Type (AD(..))
import Numeric.AD.Mode
import Numeric.AD.Mode.Reverse as Reverse (gradWith, gradWith', grad')
import Numeric.AD.Rank1.Kahn as Kahn (Kahn, grad)
import qualified Numeric.AD.Rank1.Newton as Rank1
import Prelude hiding (all, mapM, sum)

-- $setup
-- >>> import Data.Complex

-- | The 'findZero' function finds a zero of a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.) If the stream becomes constant
-- ("it converges"), no further elements are returned.
--
-- Examples:
--
-- >>> take 10 $ findZero (\x->x^2-4) 1
-- [1.0,2.5,2.05,2.000609756097561,2.0000000929222947,2.000000000000002,2.0]
--
-- >>> last $ take 10 $ findZero ((+1).(^2)) (1 :+ 1)
-- 0.0 :+ 1.0
findZero :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
-- Strip the 'AD' wrapper from the rank-2 argument and delegate to the
-- rank-1 implementation.
findZero f = Rank1.findZero (runAD . f . AD)
{-# INLINE findZero #-}

-- | The 'findZeroNoEq' function behaves the same as 'findZero' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
findZeroNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
-- Strip the 'AD' wrapper from the rank-2 argument and delegate to the
-- rank-1 implementation.
findZeroNoEq f = Rank1.findZeroNoEq (runAD . f . AD)
{-# INLINE findZeroNoEq #-}

-- | The 'inverse' function inverts a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.) If the stream becomes
-- constant ("it converges"), no further elements are returned.
--
-- Example:
--
-- >>> last $ take 10 $ inverse sqrt 1 (sqrt 10)
-- 10.0
inverse :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> a -> [a]
-- Strip the 'AD' wrapper from the rank-2 argument and delegate to the
-- rank-1 implementation.
inverse f = Rank1.inverse (runAD . f . AD)
{-# INLINE inverse #-}

-- | The 'inverseNoEq' function behaves the same as 'inverse' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
inverseNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> a -> [a]
-- Strip the 'AD' wrapper from the rank-2 argument and delegate to the
-- rank-1 implementation.
inverseNoEq f = Rank1.inverseNoEq (runAD . f . AD)
{-# INLINE inverseNoEq #-}

-- | The 'fixedPoint' function finds a fixed point of a scalar
-- function using Newton's method; its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
--
-- If the stream becomes constant ("it converges"), no further
-- elements are returned.
--
-- >>> last $ take 10 $ fixedPoint cos 1
-- 0.7390851332151607
fixedPoint :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
-- Strip the 'AD' wrapper from the rank-2 argument and delegate to the
-- rank-1 implementation.
fixedPoint f = Rank1.fixedPoint (runAD . f . AD)
{-# INLINE fixedPoint #-}

-- | The 'fixedPointNoEq' function behaves the same as 'fixedPoint' except that
-- it doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
fixedPointNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
-- Strip the 'AD' wrapper from the rank-2 argument and delegate to the
-- rank-1 implementation.
fixedPointNoEq f = Rank1.fixedPointNoEq (runAD . f . AD)
{-# INLINE fixedPointNoEq #-}

-- | The 'extremum' function finds an extremum of a scalar
-- function using Newton's method; produces a stream of increasingly
-- accurate results.  (Modulo the usual caveats.) If the stream
-- becomes constant ("it converges"), no further elements are returned.
--
-- >>> last $ take 10 $ extremum cos 1
-- 0.0
extremum :: (Fractional a, Eq a) => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a)))) -> a -> [a]
-- Strip the 'AD' wrapper from the rank-2 argument and delegate to the
-- rank-1 implementation.
extremum f = Rank1.extremum (runAD . f . AD)
{-# INLINE extremum #-}

-- | The 'extremumNoEq' function behaves the same as 'extremum' except that it
-- doesn't truncate the list once the results become constant. This means it
-- can be used with types without an 'Eq' instance.
extremumNoEq :: Fractional a => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a)))) -> a -> [a]
-- Strip the 'AD' wrapper from the rank-2 argument and delegate to the
-- rank-1 implementation.
extremumNoEq f = Rank1.extremumNoEq (runAD . f . AD)
{-# INLINE extremumNoEq #-}

-- | The 'gradientDescent' function performs a multivariate
-- optimization, based on the naive-gradient-descent in the file
-- @stalingrad\/examples\/flow-tests\/pre-saddle-1a.vlad@ from the
-- VLAD compiler Stalingrad sources.  Its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
--
-- It uses reverse mode automatic differentiation to compute the gradient.
--
-- The step size starts at @0.1@. It is halved whenever a candidate step
-- would increase the objective, and doubled once the count of accepted
-- steps since the last step-size change reaches 10. The stream ends when
-- the step size becomes @0@ or every component of the gradient is @0@.
-- The starting point itself is not part of the output; the first element
-- is the first accepted step.
gradientDescent :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    -- Objective value at the start, and each coordinate paired with its
    -- partial derivative.
    (fx0, xgx0) = Reverse.gradWith' (,) f x0
    -- go point value (coordinate, partial) pairs, step size, accepted-step count
    go x fx xgx !eta !i
      | eta == 0     = [] -- step size is 0
      | fx1 > fx     = go x fx xgx (eta / 2) 0 -- we stepped too far
      | zeroGrad xgx = [] -- gradient is 0
      | otherwise    = x1 : if i == 10
                            then go x1 fx1 xgx1 (eta * 2) 0
                            else go x1 fx1 xgx1 eta (i + 1)
      where
        zeroGrad = all (\(_, g) -> g == 0)
        -- Candidate point: move against the gradient by the current step size.
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        (fx1, xgx1) = Reverse.gradWith' (,) f x1
{-# INLINE gradientDescent #-}

-- An environment @f a@ extended with one extra scalar. 'constrainedDescent'
-- uses 'sValue' as a slack variable that is subtracted from every
-- constraint, and 'origEnv' to recover the caller's environment afterwards.
data SEnv (f :: * -> *) a = SEnv { sValue :: a, origEnv :: f a }
  deriving (Functor, Foldable, Traversable)

-- | Convex constraint, CC, is a GADT wrapper around a function that is
-- polymorphic in the tape-tracking type variable @s@ (a rank-2 type), so
-- that @s@ does not appear in the type @CC f a@ itself.  This is an
-- engineering convenience for managing the skolems, e.g. it allows such
-- functions to be kept in an ordinary list.
data CC f a where
  CC :: forall f a. (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> CC f a

-- |@constrainedDescent obj fs env@ optimizes the convex function @obj@
-- subject to the convex constraints @f <= 0@ where @f `elem` fs@. This is
-- done using a log barrier to model constraints (i.e. Boyd, Chapter 11.3).
-- The returned optimal point for the objective function must satisfy @fs@,
-- but the initial environment, @env@, needn't be feasible.
--
-- Each element of the result pairs a value with the point it belongs to.
-- With no constraints this is plain 'gradientDescent', each iterate paired
-- with the objective evaluated there. With constraints, a feasible point
-- is searched for first; if none turns up within @2^20@ results of that
-- search, the result is the empty list.
constrainedDescent :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
                   => (forall s. Reifies s Tape => f (Reverse s a)
                                                -> Reverse s a)
                   -> [CC f a]
                   -> f a
                   -> [(a,f a)]
constrainedDescent objF [] env =
  map (\x -> (eval objF x, x)) (gradientDescent objF env)
constrainedDescent objF cs env =
    -- Phase 1: add a slack variable s and minimise it subject to
    -- f_i - s <= 0, starting from a point that is feasible by construction.
    let s0       = 1 + maximum [eval c env | CC c <- cs]
        -- s0 = 1 + max_i f_i(env); 'maximum' is safe because the empty
        -- constraint list is handled by the equation above.
        cs'      = [CC (\(SEnv sVal rest) -> c rest - sVal) | CC c <- cs]
        -- f_i' = f_i - s, so at (s0, env) every f_i' <= -1 < 0
        envS     = SEnv s0 env
        -- feasible point for f_i', use gd to find feasibility for f_i
        cc       = constrainedConvex' (CC sValue) cs' envS ((<=0) . sValue)
        -- NOTE(review): the last argument is presumably a stopping
        -- predicate (slack no longer positive) -- confirm against
        -- constrainedConvex'.
    -- Skip results whose first component is still positive; the first
    -- remaining one supplies the starting point for phase 2.
    in case dropWhile ((0 <) . fst) (take (2^(20::Int)) cc) of
        []                  -> []
        (_,envFeasible) : _ ->
            -- Phase 2: optimise the real objective from the feasible point.
            constrainedConvex' (CC objF) cs (origEnv envFeasible) (const True)
{-# INLINE constrainedDescent #-}

-- | Evaluate an objective written against the 'Reverse' mode type at a
-- concrete point and return only its value.
--
-- The value is the first component of the pair produced by 'grad''; the
-- gradient component of that pair is discarded.
eval
  :: (Traversable f, Fractional a, Ord a)
  => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
  -> f a
  -> a
eval f e =
  case grad' f e of
    (value, _) -> value
{-# INLINE eval #-}

-- | Like 'constrainedDescent' except the initial point must be feasible.
--
-- The objective and the constraints are combined by 'mkOpt' into one
-- unconstrained problem per barrier weight (64 weights: 2, 4, 8, ...).
-- Each problem is handed to 'gradientDescent', started from the last point
-- kept from the previous stage.  Every stage except the final one
-- contributes at most 20 points.  Each returned point is paired with the
-- value of its stage's objective (via 'eval'), and the leading points that
-- do not satisfy the supplied predicate are dropped.
constrainedConvex' :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
                   => CC f a
                   -> [CC f a]
                   -> f a
                   -> (f a -> Bool)
                   -> [(a,f a)]
constrainedConvex' objF cs env term =
  dropWhile (\(_, point) -> not (term point)) (concat (drop 1 stages))
  where
    -- Barrier weights, doubling from 2.
    barrierWeights :: [a]
    barrierWeights = map realToFrac (take 64 (iterate (* 2) (2 :: a)))

    -- 1. One log-barrier problem per weight.
    barrierProblems :: [CC f a]
    barrierProblems = [mkOpt objF cs t | t <- barrierWeights]

    -- One truncation per barrier weight, followed by a single 'id'.
    truncations :: [[(a, f a)] -> [(a, f a)]]
    truncations = replicate (length barrierWeights) (take 20) ++ [id]

    -- 2. Run gradient descent on each problem in turn.  The list is seeded
    -- with the caller's starting point; the 'undefined' paired with it is
    -- never inspected, because the seed stage is removed by @drop 1@ above
    -- and is otherwise only read through 'snd'.
    rawStages :: [[(a, f a)]]
    rawStages = [(undefined, env)] : zipWith nextStage barrierProblems stages

    -- The stages after truncation.
    stages :: [[(a, f a)]]
    stages = zipWith ($) truncations rawStages

    -- Continue from the final point kept from the previous stage.
    nextStage :: CC f a -> [(a, f a)] -> [(a, f a)]
    nextStage problem previous = descendWithValues (snd (last previous)) problem

    -- Gradient descent from @start@ (which is itself the first element),
    -- with every point paired with the evaluated objective.
    descendWithValues :: f a -> CC f a -> [(a, f a)]
    descendWithValues start (CC g) =
      [(eval g p, p) | p <- start : gradientDescent g start]
{-# INLINE constrainedConvex' #-}

-- @mkOpt u fs t@ turns the inequality-constrained convex problem (@u,fs@)
-- into a single unconstrained objective: the original objective @u@ plus one
-- log-barrier term ('iHat', i.e. @-(1/t)log(-f_i)@) for every constraint in
-- @fs@.
-- Larger @t@ gives a more accurate approximation but a smaller gradient,
-- which makes the gradient descent more expensive.
mkOpt :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
      => CC f a -> [CC f a]
      -> a -> CC f a
mkOpt (CC o) xs t = CC penalised
  where
    -- Objective value plus the sum of all barrier terms at the point @e@.
    penalised :: forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a
    penalised e = o e + sum [iHat t c e | CC c <- xs]
{-# INLINE mkOpt #-}

-- Log-barrier term for a single inequality constraint.
--
-- @iHat t c e@ evaluates the constraint @c@ at the point @e@.  If the result
-- is non-negative or NaN, the term is @1 / 0@; otherwise it is
-- @(-1 / t) * log (- c e)@.
--
-- The constraint is evaluated once, and that single result is shared by the
-- feasibility test and the logarithm, so its work is not recorded twice.
iHat :: forall a f. (Traversable f, RealFloat a, Floating a, Ord a)
     => a
     -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
     -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
iHat t c e
  | violated  = 1 / 0
  | otherwise = (-1 / auto t) * log (negate r)
  where
    -- Value of the constraint at the current point.
    r = c e
    -- The barrier is only finite where the constraint is strictly negative.
    violated = r >= 0 || isNaN r
{-# INLINE iHat #-}

-- | The 'stochasticGradientDescent' function approximates
-- the true gradient of the cost function by a gradient at
-- a single example. As the algorithm sweeps through the training
-- set, it performs the update for each training example.
--
-- It uses reverse mode automatic differentiation to compute the gradient.
-- The learning rate is constant throughout, and is set to 0.001.
--
-- The training set is cycled indefinitely, so for a non-empty training set
-- the result is an infinite stream of successive parameter estimates.
-- An empty training set yields an empty result.
stochasticGradientDescent :: (Traversable f, Fractional a, Ord a)
  => (forall s. Reifies s Tape => e -> f (Reverse s a) -> Reverse s a)
  -> [e]
  -> f a
  -> [f a]
stochasticGradientDescent errorSingle d0 x0 =
  case d0 of
    []        -> []
    first : _ -> go (gradientAt first x0) 0.001 (drop 1 (cycle d0))
  where
    -- Pair every coordinate of @point@ with the matching partial derivative
    -- of the error on a single training example.
    gradientAt sample point = Reverse.gradWith (,) (errorSingle sample) point

    -- @go xgx eta samples@: @xgx@ holds the current point zipped with its
    -- gradient, @eta@ is the learning rate, @samples@ the examples still
    -- to be visited.
    go xgx !eta samples
      | eta == 0  = []
      | otherwise = x1 : rest
      where
        -- Step against the gradient.
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        -- Differentiate at the new point using the next example.
        rest = case samples of
          []              -> []
          sample : others -> go (gradientAt sample x1) eta others
{-# INLINE stochasticGradientDescent #-}

-- | Perform a gradient ascent using reverse mode automatic differentiation to compute the gradient.
--
-- This is 'gradientDescent' applied to the negated objective.
gradientAscent
  :: (Traversable f, Fractional a, Ord a)
  => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
  -> f a
  -> [f a]
gradientAscent f = gradientDescent (\xs -> negate (f xs))
{-# INLINE gradientAscent #-}

-- | Perform a conjugate gradient descent using reverse mode automatic differentiation to compute the gradient, and using forward-on-forward mode for computing extrema.
--
-- This is 'conjugateGradientAscent' applied to the negated objective.
--
-- >>> let sq x = x * x
-- >>> let rosenbrock [x,y] = sq (1 - x) + 100 * sq (y - sq x)
-- >>> rosenbrock [0,0]
-- 1
-- >>> rosenbrock (conjugateGradientDescent rosenbrock [0, 0] !! 5) < 0.1
-- True
conjugateGradientDescent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
conjugateGradientDescent f = conjugateGradientAscent (\xs -> negate (f xs))
{-# INLINE conjugateGradientDescent #-}

-- Run a function on the left alternative of 'Or': every input is wrapped
-- with 'L', the function is applied, and the result is unwrapped with 'runL'.
lfu :: Functor f => (f (Or F a b) -> Or F a b) -> f a -> a
lfu f xs = runL (f (fmap L xs))

-- Run a function on the right alternative of 'Or': every input is wrapped
-- with 'R', the function is applied, and the result is unwrapped with 'runR'.
rfu :: Functor f => (f (Or T a b) -> Or T a b) -> f b -> b
rfu f xs = runR (f (fmap R xs))

-- | Perform a conjugate gradient ascent using reverse mode automatic differentiation to compute the gradient.
--
-- Gradients are taken with 'Kahn.grad'.  The step length along the current
-- direction is the last of at most 20 iterates of 'Rank1.extremum' started
-- at 0.  The output stops at the first point that has a coordinate which is
-- not equal to itself (e.g. a NaN).
conjugateGradientAscent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
conjugateGradientAscent f x0 =
  takeWhile (all isNumber) (climb x0 g0 (normSq g0))
  where
    -- False exactly for values that are not equal to themselves.
    isNumber v = v == v

    -- Gradient of the objective, using its 'Kahn' interpretation.
    gradAt = Kahn.grad (rfu f)

    -- Dot product of a vector with itself.
    normSq v = sum (zipWithT (*) v v)

    -- The first search direction is the gradient at the starting point.
    g0 = gradAt x0

    -- @climb point dir gradNormSq@ emits @point@ and continues from the
    -- next iterate; @gradNormSq@ is the squared norm of the gradient at
    -- @point@.
    climb point dir gradNormSq = point : climb point' dir' gradNormSq'
      where
        -- The objective restricted to the ray @point + a * dir@.
        alongRay a = lfu f (zipWithT (\x d -> auto x + a * auto d) point dir)
        -- Step length found by the one-dimensional extremum search.
        stepLen = last (take 20 (Rank1.extremum alongRay 0))
        point' = zipWithT (\x d -> x + stepLen * d) point dir
        gradient' = gradAt point'
        gradNormSq' = normSq gradient'
        -- Ratio of the new to the old squared gradient norm.
        beta = gradNormSq' / gradNormSq
        dir' = zipWithT (\g d -> g + beta * d) gradient' dir
{-# INLINE conjugateGradientAscent #-}