{-# LANGUAGE RecordWildCards #-}

module Torch.Optim where

import Control.Monad.State
import Control.Monad (foldM)
import System.Mem (performGC)
import Torch.Autograd
import Torch.Functional
import Torch.Internal.GC (mallocTrim)
import Torch.NN
import Torch.Tensor
import Torch.TensorFactories
import Prelude hiding (sqrt)

type LearningRate = Tensor

type Loss = Tensor

newtype Gradients = Gradients [Tensor] deriving (Show)

newtype OptimizerState option = OptimizerState option

grad' :: Loss -> [Parameter] -> Gradients
grad' t p = Gradients (grad t p)

class Optimizer optimizer where
  step :: LearningRate -> Gradients -> [Tensor] -> optimizer -> ([Tensor], optimizer)

  -- | run a single iteration of an optimizer, returning new parameters and updated optimizer state
  runStep :: (Parameterized model) => model -> optimizer -> Loss -> LearningRate -> IO (model, optimizer)
  runStep paramState optState lossValue = runStep' paramState optState (grad' lossValue $ flattenParameters paramState)

  -- | run a single iteration of an optimizer, returning new parameters and updated optimizer state
  runStep' :: (Parameterized model) => model -> optimizer -> Gradients -> LearningRate -> IO (model, optimizer)
  runStep' paramState optState gradients lr = do
    performGC
    mallocTrim 0
    let (flatParameters', optState') = step lr gradients depParameters optState
    newFlatParam <- mapM makeIndependent flatParameters'
    pure (replaceParameters paramState newFlatParam, optState')
    where
      flatParameters = flattenParameters paramState
      depParameters = fmap toDependent flatParameters

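-- Typical use of 'runStep' (a sketch; @model@, @opt@, @computeLoss@, @batch@,
-- @numIters@, and @learningRate@ are assumed to be supplied by the caller and
-- are not part of this module):
--
-- > (model', opt') <- foldLoop (model, opt) numIters $ \(m, o) _ -> do
-- >   let loss = computeLoss m batch
-- >   runStep m o loss learningRate
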
--
-- Gradient Descent
--

data GD = GD deriving (Show)

-- | Stateless gradient descent step
gd :: LearningRate -> Gradients -> [Tensor] -> [Tensor]
gd lr (Gradients gradients) parameters = zipWith step parameters gradients
  where
    step p dp = p - (lr * dp)

-- | Gradient descent step with a dummy state variable
gd' :: LearningRate -> Gradients -> [Tensor] -> GD -> ([Tensor], GD)
gd' lr gradients depParameters dummy = (gd lr gradients depParameters, dummy)

instance Optimizer GD where
  step = gd'

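-- | Gradient descent step taking the parameters as 'Parameter' values; they
-- are converted with 'toDependent' before the update is applied.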
sgd :: LearningRate -> [Parameter] -> [Tensor] -> [Tensor]
sgd lr parameters = zipWith step depParameters
  where
    step p dp = p - (lr * dp)
    depParameters = map toDependent parameters
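
-- The following is a minimal usage sketch, not part of the original API: it
-- minimizes sum ((x - target)^2) over a single tensor parameter with plain
-- gradient descent, using 'grad'', 'gd', and 'foldLoop' (defined below).
-- The name 'gdExample' and the constants are illustrative assumptions only.
gdExample :: IO Tensor
gdExample = do
  -- the parameter we optimize, tracked for autograd
  p0 <- makeIndependent (ones' [4])
  let target = asTensor ([3, 3, 3, 3] :: [Float])
      lr = asTensor (0.1 :: Float)
  trained <- foldLoop p0 100 $ \p _ -> do
    let x = toDependent p
        loss = sumAll ((x - target) * (x - target))
        gradients = grad' loss [p]
        [x'] = gd lr gradients [x]
    -- re-wrap the updated tensor so the next iteration can differentiate through it
    makeIndependent x'
  pure (toDependent trained)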

--
-- Gradient Descent with Momentum
--

data GDM = GDM {beta :: Float, momentum :: [Tensor]} deriving (Show)

-- | Gradient descent with momentum step
gdm ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | beta & momentum
  GDM ->
  -- | returns new parameters + updated momentum
  ([Tensor], GDM)
gdm lr (Gradients gradients) parameters (GDM beta momentum) =
  (fmap fst runStep, GDM beta (fmap snd runStep))
  where
    step p dp z = let z' = mulScalar beta z + dp in (p - lr * z', z')
    runStep = zipWith3 step parameters gradients momentum

instance Optimizer GDM where
  step = gdm
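
-- Constructing an initial GDM state (a sketch; unlike 'Adam' there is no
-- helper for this, so the momentum buffers start as zero tensors shaped like
-- the parameters of an assumed 'Parameterized' value @model@):
--
-- > let gdm0 = GDM 0.9 (map (zerosLike . toDependent) (flattenParameters model))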

--
-- Adam
--

-- | State representation for Adam Optimizer
data Adam = Adam
  { beta1 :: Float, -- 1st moment forgetting factor
    beta2 :: Float, -- 2nd moment forgetting factor
    m1 :: [Tensor], -- 1st moment
    m2 :: [Tensor], -- 2nd moment
    iter :: Int -- iteration
  }
  deriving (Show)

-- | Construct an initial Adam state from the model parameters
mkAdam ::
  -- | initial iteration counter
  Int ->
  -- | beta1 (1st moment forgetting factor)
  Float ->
  -- | beta2 (2nd moment forgetting factor)
  Float ->
  -- | model parameters
  [Parameter] ->
  Adam
mkAdam iter beta1 beta2 parameters =
  Adam
    beta1
    beta2
    (initZeros <$> parameters)
    (initZeros <$> parameters)
    iter
  where
    initZeros = zerosLike . toDependent

-- | Adam step
adam ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | adam parameters - beta1, beta2, moments, iteration
  Adam ->
  -- | returns new parameters + updated adam parameters
  ([Tensor], Adam)
adam lr (Gradients gradients) parameters Adam {..} = (parameters', Adam beta1 beta2 m1' m2' (iter + 1))
  where
    -- decaying averages of 1st & 2nd moments
    f1 m1 dp = mulScalar beta1 m1 + mulScalar (1 - beta1) dp
    f2 m2 dp = mulScalar beta2 m2 + mulScalar (1 - beta2) (dp * dp)
    m1' = zipWith f1 m1 gradients
    m2' = zipWith f2 m2 gradients
    -- bias adjustment
    a beta = divScalar (1 - beta ^ (iter + 1))
    a1 = fmap (a beta1) m1'
    a2 = fmap (a beta2) m2'
    -- parameter update
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / (sqrt a2' + eps)
    parameters' = zipWith3 update parameters a1 a2

instance Optimizer Adam where
  step = adam
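
-- Wiring Adam into a training loop (a sketch; @model@ and @lossFor@ are
-- assumed to be supplied by the caller, not defined in this module):
--
-- > let opt0 = mkAdam 0 0.9 0.999 (flattenParameters model)
-- > (model', opt') <- runStep model opt0 (lossFor model) (asTensor (1e-3 :: Float))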

--
-- Adagrad
--

-- | State representation for Adagrad Optimizer
data Adagrad = Adagrad {gsum :: [Tensor]} -- sum of squared gradients
  deriving (Show)

-- | Adagrad step
adagrad ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | adagrad parameters - running sum of squared gradients
  Adagrad ->
  -- | returns new parameters + updated adagrad parameters
  ([Tensor], Adagrad)
adagrad lr (Gradients gradients) parameters Adagrad {..} = (parameters', Adagrad gsum')
  where
    -- add gradient squared to running total
    f gsum dp = gsum + dp * dp
    gsum' = zipWith f gsum gradients

    -- parameter update
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / (sqrt (a2' + eps))
    parameters' = zipWith3 update parameters gradients gsum'

instance Optimizer Adagrad where
  step = adagrad
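
-- Constructing an initial Adagrad state (a sketch; the running sums of squared
-- gradients start as zero tensors shaped like the parameters of an assumed
-- 'Parameterized' value @model@):
--
-- > let ada0 = Adagrad (map (zerosLike . toDependent) (flattenParameters model))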

-- | syntactic sugar for looping with foldM
foldLoop :: a -> Int -> (a -> Int -> IO a) -> IO a
foldLoop x count block = foldM block x [1 .. count]
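
-- Example (a sketch; 'trainIter' is assumed, not defined here):
--
-- > trained <- foldLoop initialModel 10 $ \m i -> trainIter i m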