{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE ParallelListComp #-}
module Numeric.AD.Newton
(
findZero
, findZeroNoEq
, inverse
, inverseNoEq
, fixedPoint
, fixedPointNoEq
, extremum
, extremumNoEq
, gradientDescent, constrainedDescent, CC(..), eval
, gradientAscent
, conjugateGradientDescent
, conjugateGradientAscent
, stochasticGradientDescent
) where
import Data.Foldable (all, sum)
import Data.Reflection (Reifies)
import Data.Traversable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Forward (Forward)
import Numeric.AD.Internal.On
import Numeric.AD.Internal.Or
import Numeric.AD.Internal.Reverse (Reverse, Tape)
import Numeric.AD.Internal.Type (AD(..))
import Numeric.AD.Mode
import Numeric.AD.Mode.Reverse as Reverse (gradWith, gradWith', grad')
import Numeric.AD.Rank1.Kahn as Kahn (Kahn, grad)
import qualified Numeric.AD.Rank1.Newton as Rank1
import Prelude hiding (all, mapM, sum)
-- | Rank-N front end to 'Rank1.findZero'.
--
-- The user function is written against @AD s (Forward a)@ and must be
-- polymorphic in @s@. It is adapted to the @Forward a -> Forward a@ form
-- that 'Rank1.findZero' expects: the argument is wrapped with 'AD' and the
-- result is unwrapped with 'runAD'. The @a@ argument and the resulting
-- list are passed through unchanged.
--
-- NOTE(review): per the module name this is presumably a Newton-iteration
-- root finder whose @a@ argument is the initial guess, with @Eq a@ used to
-- stop once the iterates stop changing. Confirm in "Numeric.AD.Rank1.Newton".
findZero :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
findZero f = Rank1.findZero (runAD . f . AD)
{-# INLINE findZero #-}
-- | Variant of 'findZero' without the @Eq a@ constraint; delegates to
-- 'Rank1.findZeroNoEq'.
--
-- The user function is adapted in the same way: its argument is wrapped
-- with 'AD' and its result is unwrapped with 'runAD'. The @a@ argument and
-- the resulting list are passed through unchanged.
--
-- NOTE(review): without 'Eq' the underlying stream presumably cannot be cut
-- off when it converges. Confirm in "Numeric.AD.Rank1.Newton".
findZeroNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
findZeroNoEq f = Rank1.findZeroNoEq (runAD . f . AD)
{-# INLINE findZeroNoEq #-}
-- | Rank-N front end to 'Rank1.inverse'.
--
-- The user function (polymorphic in @s@) is adapted to
-- @Forward a -> Forward a@: the argument is wrapped with 'AD' and the
-- result is unwrapped with 'runAD'. Both @a@ arguments and the resulting
-- list are passed through unchanged.
--
-- NOTE(review): presumably the two @a@ arguments are an initial guess and
-- the target output value. Confirm their order in "Numeric.AD.Rank1.Newton".
inverse :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> a -> [a]
inverse f = Rank1.inverse (runAD . f . AD)
{-# INLINE inverse #-}
-- | Variant of 'inverse' without the @Eq a@ constraint; delegates to
-- 'Rank1.inverseNoEq'.
--
-- The user function's argument is wrapped with 'AD' and its result is
-- unwrapped with 'runAD'. Both @a@ arguments and the resulting list are
-- passed through unchanged.
inverseNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> a -> [a]
inverseNoEq f = Rank1.inverseNoEq (runAD . f . AD)
{-# INLINE inverseNoEq #-}
-- | Rank-N front end to 'Rank1.fixedPoint'.
--
-- The user function (polymorphic in @s@) is adapted to
-- @Forward a -> Forward a@: the argument is wrapped with 'AD' and the
-- result is unwrapped with 'runAD'. The @a@ argument and the resulting
-- list are passed through unchanged.
--
-- NOTE(review): the @a@ argument is presumably the starting guess. Confirm
-- in "Numeric.AD.Rank1.Newton".
fixedPoint :: (Fractional a, Eq a) => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
fixedPoint f = Rank1.fixedPoint (runAD . f . AD)
{-# INLINE fixedPoint #-}
-- | Variant of 'fixedPoint' without the @Eq a@ constraint; delegates to
-- 'Rank1.fixedPointNoEq'.
--
-- The user function's argument is wrapped with 'AD' and its result is
-- unwrapped with 'runAD'. The @a@ argument and the resulting list are
-- passed through unchanged.
fixedPointNoEq :: Fractional a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> a -> [a]
fixedPointNoEq f = Rank1.fixedPointNoEq (runAD . f . AD)
{-# INLINE fixedPointNoEq #-}
-- | Rank-N front end to 'Rank1.extremum'.
--
-- The user function works over @On (Forward (Forward a))@ and must be
-- polymorphic in @s@. Its argument is wrapped with 'AD' and its result is
-- unwrapped with 'runAD'. The @a@ argument and the resulting list are
-- passed through unchanged.
--
-- NOTE(review): the nested forward mode presumably provides the second
-- derivative that a Newton search for a stationary point needs. Confirm in
-- "Numeric.AD.Rank1.Newton".
extremum :: (Fractional a, Eq a) => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a)))) -> a -> [a]
extremum f = Rank1.extremum (runAD . f . AD)
{-# INLINE extremum #-}
-- | Variant of 'extremum' without the @Eq a@ constraint; delegates to
-- 'Rank1.extremumNoEq'.
--
-- The user function's argument is wrapped with 'AD' and its result is
-- unwrapped with 'runAD'. The @a@ argument and the resulting list are
-- passed through unchanged.
extremumNoEq :: Fractional a => (forall s. AD s (On (Forward (Forward a))) -> AD s (On (Forward (Forward a)))) -> a -> [a]
extremumNoEq f = Rank1.extremumNoEq (runAD . f . AD)
{-# INLINE extremumNoEq #-}
-- | Gradient descent with an adaptive step size. Gradients come from
-- reverse mode via 'Reverse.gradWith''.
--
-- The search starts from @x0@ with step size @eta = 0.1@. Each candidate
-- step is @x1 = x - eta * grad f x@, applied componentwise. Then:
--
-- * if @eta@ has shrunk to 0, the stream ends;
-- * if @f x1 > f x@, the step is rejected and retried from @x@ with @eta@
--   halved and the success counter reset;
-- * if every gradient component at @x@ is 0, the stream ends;
-- * otherwise @x1@ is emitted and becomes the current point. Once the
--   counter of steps accepted since @eta@ last changed reaches 10, @eta@ is
--   doubled for the next step and the counter is reset.
--
-- The starting point @x0@ itself is not part of the output.
gradientDescent :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> [f a]
gradientDescent f x0 = go x0 fx0 xgx0 0.1 (0 :: Int)
  where
    -- Value at x0, plus (component, partial derivative) pairs.
    (fx0, xgx0) = Reverse.gradWith' (,) f x0

    -- go x fx xgx eta i: current point, its value, its
    -- (component, partial) pairs, the step size, and the number of steps
    -- accepted since eta last changed.
    go x fx xgx !eta !i
      | eta == 0     = []                       -- step size has shrunk to 0
      | fx1 > fx     = go x fx xgx (eta / 2) 0  -- overshot: halve and retry
      | zeroGrad xgx = []                       -- gradient vanished at x
      | otherwise    = x1 : if i == 10
                              then go x1 fx1 xgx1 (eta * 2) 0
                              else go x1 fx1 xgx1 eta (i + 1)
      where
        zeroGrad = all (\(_, g) -> g == 0)
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) xgx
        (fx1, xgx1) = Reverse.gradWith' (,) f x1
{-# INLINE gradientDescent #-}
-- | Augmented search point used by the feasibility phase of
-- 'constrainedDescent'. It holds an extra scalar 'sValue' alongside the
-- original point 'origEnv'.
--
-- The derived 'Foldable' and 'Traversable' instances visit 'sValue' first,
-- then the elements of 'origEnv'. Traversal-based routines therefore treat
-- the extra scalar as one more coordinate.
data SEnv f a = SEnv { sValue :: a, origEnv :: f a }
  deriving (Functor, Foldable, Traversable)
-- | A scalar function of an @f@-shaped point that works in reverse mode at
-- any tape brand @s@.
--
-- Wrapping the rank-2 function in a constructor lets objectives and
-- constraints be kept in ordinary lists, such as the @[CC f a]@ argument of
-- 'constrainedDescent'. Judging from 'iHat', a constraint is treated as
-- satisfied where it evaluates to a strictly negative value.
data CC f a where
  CC :: forall f a. (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> CC f a
-- | Minimise @objF@ subject to the constraints @cs@, starting from @env@.
-- The constrained case uses the log-barrier scheme in 'constrainedConvex''.
--
-- * With no constraints, this is plain 'gradientDescent' with each point
--   paired with its objective value.
--
-- * Otherwise there are two phases. Phase 1 looks for a feasible point.
--   It adds a slack coordinate @s@, initially one more than the largest
--   constraint value at @env@, and replaces each constraint @c@ by
--   @c x - s@. It then minimises @s@ itself, keeping the output from the
--   first point with @s <= 0@. Within at most 2^20 of those entries, it
--   takes the first whose paired value is @<= 0@; if there is none, the
--   result is empty. Phase 2 runs the barrier method on @objF@ with the
--   original constraints, starting from that point.
--
-- NOTE(review): in phase 2 the value paired with each point is the
-- barrier-augmented objective computed by 'constrainedConvex'', not the raw
-- @objF@ value. Confirm this is intended.
--
-- NOTE(review): the 2^20 cap applies only after 'constrainedConvex'' has
-- already skipped every point with @s > 0@. If such points never end and
-- the last barrier stage keeps producing steps, this may not return.
constrainedDescent :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
                   => (forall s. Reifies s Tape => f (Reverse s a)
                                                -> Reverse s a)
                   -> [CC f a]
                   -> f a
                   -> [(a,f a)]
constrainedDescent objF [] env =
  map (\x -> (eval objF x, x)) (gradientDescent objF env)
constrainedDescent objF cs env =
  case dropWhile ((0 <) . fst) (take (2 ^ (20 :: Int)) phase1) of
    []                   -> []
    (_, envFeasible) : _ ->
      constrainedConvex' (CC objF) cs (origEnv envFeasible) (const True)
  where
    -- Initial slack: strictly larger than every constraint value at env,
    -- so each shifted constraint c x - s starts out negative.
    -- (cs is non-empty here, so 'maximum' is safe.)
    s0 :: a
    s0 = 1 + maximum [eval c env | CC c <- cs]

    -- Constraints shifted by the slack coordinate.
    cs' :: [CC (SEnv f) a]
    cs' = [CC (\(SEnv sVal rest) -> c rest - sVal) | CC c <- cs]

    -- Phase 1: minimise the slack, skipping points until it is <= 0.
    phase1 :: [(a, SEnv f a)]
    phase1 = constrainedConvex' (CC sValue) cs' (SEnv s0 env) ((<= 0) . sValue)
{-# INLINE constrainedDescent #-}
-- | Value of @f@ at the point @e@. It is computed with reverse-mode
-- 'grad'', and the gradient half of the result is discarded.
eval :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> a
eval f e = fst (grad' f e)
{-# INLINE eval #-}
-- | Log-barrier minimisation of @objF@ subject to @cs@, starting at @env@.
--
-- For each barrier weight @t@ in 2, 4, ..., 2^64, 'mkOpt' augments the
-- objective and 'gradientDescent' minimises it. Each stage starts from the
-- last point of the previous stage. Every stage except the last is cut to
-- its first 20 entries: the stage's starting point plus up to 19 descent
-- steps. The last stage runs until 'gradientDescent' stops.
--
-- The result concatenates every stage's (augmented value, point) pairs,
-- with the leading entries whose point fails @term@ dropped.
constrainedConvex' :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
                   => CC f a
                   -> [CC f a]
                   -> f a
                   -> (f a -> Bool)
                   -> [(a,f a)]
constrainedConvex' objF cs env term =
    dropWhile (not . term . snd) (concat (drop 1 limEnvs))
  where
    -- Barrier weights 2^1 .. 2^64.
    tValues :: [a]
    tValues = take 64 (iterate (* 2) 2)

    -- One barrier-augmented objective per weight.
    os :: [CC f a]
    os = map (mkOpt objF cs) tValues

    -- Stage 0 is only the starting point; its value slot is never forced
    -- (stage 0 is dropped from the output, and only its point is read).
    -- Stage k+1 descends on os !! k from the last point of limited stage k.
    -- Every stage list is non-empty, so 'last' is safe.
    envs :: [[(a, f a)]]
    envs = [(error "constrainedConvex': stage-0 value is never inspected", env)]
         : [ gD (snd (last e)) o | o <- os | e <- limEnvs ]

    -- Stage lists after applying the per-stage length caps.
    limEnvs :: [[(a, f a)]]
    limEnvs = zipWith id nrSteps envs

    -- Cap stages 0..63 at 20 entries; leave the final stage uncapped.
    nrSteps :: [[(a, f a)] -> [(a, f a)]]
    nrSteps = replicate (length tValues) (take 20) ++ [id]

    -- Descent trace for objective g from e, including e itself.
    gD :: f a -> CC f a -> [(a, f a)]
    gD e (CC g) = [ (eval g x, x) | x <- e : gradientDescent g e ]
{-# INLINE constrainedConvex' #-}
-- | Barrier-augmented objective for weight @t@. At a point @e@ its value is
-- @o e@ plus the sum, over every constraint @c@ in @xs@, of @iHat t c e@.
mkOpt :: forall f a. (Traversable f, RealFloat a, Floating a, Ord a)
      => CC f a -> [CC f a]
      -> a -> CC f a
mkOpt (CC o) xs t = CC (\e -> o e + sum [iHat t c e | CC c <- xs])
{-# INLINE mkOpt #-}
-- | Log-barrier term for the constraint @c@ at weight @t@, evaluated at @e@:
--
-- * @-(1/t) * log (-(c e))@ where @c e < 0@;
-- * @1/0@ (positive infinity for IEEE types) where @c e >= 0@ or @c e@ is
--   NaN.
iHat :: forall a f. (Traversable f, RealFloat a, Floating a, Ord a)
     => a
     -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
     -> (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a)
iHat t c e
  | r >= 0 || isNaN r = 1 / 0
  | otherwise         = negate (1 / auto t) * log (negate r)
  where
    -- Evaluate the constraint once and reuse it in the log term.
    -- The original evaluated @c e@ a second time there.
    r = c e
{-# INLINE iHat #-}
-- | Stochastic gradient descent over a list of samples.
--
-- @errorSingle e x@ is the error contributed by the single sample @e@ at
-- the point @x@. The samples are visited round-robin (via 'cycle'), one per
-- step. Each output element is @x - 0.001 * grad (errorSingle e) x@,
-- applied componentwise, where @x@ is the previous output. The first step
-- starts from @x0@, which is not itself emitted.
--
-- The step size is fixed at 0.001. If that literal is 0 in the numeric
-- type @a@, the result is empty. An empty sample list also yields an empty
-- result, because no gradient can be taken; previously it crashed on
-- @head@/@cycle@.
stochasticGradientDescent :: forall f a e. (Traversable f, Fractional a, Ord a)
                          => (forall s. Reifies s Tape => e -> f (Reverse s a) -> Reverse s a)
                          -> [e]
                          -> f a
                          -> [f a]
stochasticGradientDescent errorSingle d0 x0
  | null d0 || eta == 0 = []
  | otherwise           = go x0 (cycle d0)
  where
    eta :: a
    eta = 0.001

    -- One step per sample: take the gradient at the current point with
    -- respect to that sample only.
    go :: f a -> [e] -> [f a]
    go x (e : es) = x1 : go x1 es
      where
        x1 = fmap (\(xi, gxi) -> xi - eta * gxi) (Reverse.gradWith (,) (errorSingle e) x)
    go _ [] = []  -- unreachable: 'cycle' of a non-empty list is infinite
{-# INLINE stochasticGradientDescent #-}
-- | Gradient ascent: 'gradientDescent' applied to @negate . f@.
--
-- The same step-size rules apply, and the starting point is likewise not
-- emitted.
gradientAscent :: (Traversable f, Fractional a, Ord a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> [f a]
gradientAscent f = gradientDescent (negate . f)
{-# INLINE gradientAscent #-}
-- | Conjugate-gradient descent: 'conjugateGradientAscent' applied to
-- @negate . f@.
conjugateGradientDescent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
conjugateGradientDescent f = conjugateGradientAscent (negate . f)
{-# INLINE conjugateGradientDescent #-}
-- | Runs an 'Or'-valued function on its left alternative.
--
-- Each input component is wrapped with 'L', @f@ is applied, and the result
-- is unwrapped with 'runL'. The @F@ index is what 'L' and 'runL' require.
lfu :: Functor f => (f (Or F a b) -> Or F a b) -> f a -> a
lfu f = runL . f . fmap L
-- | Runs an 'Or'-valued function on its right alternative.
--
-- Each input component is wrapped with 'R', @f@ is applied, and the result
-- is unwrapped with 'runR'. The @T@ index is what 'R' and 'runR' require.
rfu :: Functor f => (f (Or T a b) -> Or T a b) -> f b -> b
rfu f = runR . f . fmap R
-- | Nonlinear conjugate-gradient ascent starting at @x0@.
--
-- The function is supplied once, polymorphic in the 'Chosen' index @s@,
-- and used in two ways:
--
-- * through 'rfu' (the @Kahn a@ alternative), to take gradients with
--   'Kahn.grad';
-- * through 'lfu' (the @On (Forward (Forward a))@ alternative), for a
--   one-dimensional line search with 'Rank1.extremum'.
--
-- With @r0 = d0 = grad f x0@, each iteration computes:
--
-- * @alpha_i@: the last of the first 20 iterates of 'Rank1.extremum' on
--   @alpha -> f (x_i + alpha * d_i)@, started at 0;
-- * @x_{i+1} = x_i + alpha_i * d_i@ and @r_{i+1} = grad f x_{i+1}@;
-- * @beta_{i+1} = (r_{i+1} . r_{i+1}) / (r_i . r_i)@, the Fletcher-Reeves
--   formula;
-- * @d_{i+1} = r_{i+1} + beta_{i+1} * d_i@.
--
-- The result is @x0, x1, ...@, cut before the first iterate that has a
-- component not equal to itself (a NaN, for IEEE types).
--
-- NOTE(review): if a gradient is exactly zero, @beta@ divides by zero. For
-- IEEE types this appears to feed NaNs into later iterates and end the
-- stream; for exact types such as 'Rational' it would throw. Confirm.
conjugateGradientAscent
  :: (Traversable f, Ord a, Fractional a)
  => (forall s. Chosen s => f (Or s (On (Forward (Forward a))) (Kahn a)) -> Or s (On (Forward (Forward a))) (Kahn a))
  -> f a -> [f a]
conjugateGradientAscent f x0 = takeWhile (all (\v -> v == v)) (go x0 d0 delta0)
  where
    -- Inner product of two equally shaped points.
    dot x y = sum (zipWithT (*) x y)

    d0 = Kahn.grad (rfu f) x0
    delta0 = dot d0 d0

    -- go x_i d_i delta_i, where delta_i = r_i . r_i. The residual r_i
    -- itself is not needed after delta_i is known, so it is not carried.
    go xi di deltai = xi : go xi1 di1 deltai1
      where
        -- Line search along d_i, starting at step length 0.
        -- NOTE(review): 'last' assumes Rank1.extremum yields at least one
        -- iterate (presumably the start point); confirm.
        ai = last $ take 20 $ Rank1.extremum
               (\alpha -> lfu f $ zipWithT (\x d -> auto x + alpha * auto d) xi di) 0
        xi1 = zipWithT (\x d -> x + ai * d) xi di
        ri1 = Kahn.grad (rfu f) xi1
        deltai1 = dot ri1 ri1
        bi1 = deltai1 / deltai
        di1 = zipWithT (\r d -> r + bi1 * d) ri1 di
{-# INLINE conjugateGradientAscent #-}