{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Optim where

import Control.Monad.State
import Data.Kind
import System.Mem (performGC)
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import Torch.Internal.GC (mallocTrim)
import qualified Torch.Tensor as D
import Torch.Typed.Autograd
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (div, sqrt)

-- | Learning rate: a scalar (shape @'[]@) tensor with the same device and
-- dtype as the tensors it scales.
type LearningRate device dtype = Tensor device dtype '[]

-- | Loss: a scalar (shape @'[]@) tensor. 'runStep' differentiates it with
-- respect to the model parameters.
type Loss device dtype = Tensor device dtype '[]

-- | Function tag for 'hmap'' that maps every 'Parameter' to an all-zero
-- 'Tensor' with the same device, dtype and shape. 'mkGDM' and 'mkAdam' use
-- it to initialise optimizer momenta.
data ZerosLike = ZerosLike

instance
  ( parameter ~ Parameter device dtype shape,
    momentum ~ Tensor device dtype shape,
    TensorOptions shape dtype device
  ) =>
  Apply' ZerosLike parameter momentum
  where
  -- Only the parameter's type matters: 'zeros' is determined entirely by
  -- the result type, which the constraints above tie to the parameter's
  -- device, dtype and shape.
  apply' _ _ = zeros

-- | Interface shared by all optimizers in this module.
--
-- An optimizer state of type @optim@ consumes the gradients of a list of
-- tensors, produces the updated tensors and returns its own next state.
class Optimizer optim gradients tensors dtype device where
  -- | Perform one optimization step.
  step
    :: LearningRate device dtype -- ^ learning rate
    -> HList gradients -- ^ gradients of the tensors
    -> HList tensors -- ^ current tensor values
    -> optim -- ^ current optimizer state
    -> (HList tensors, optim) -- ^ updated tensors and next optimizer state

-- | Run one training step on a model:
--
-- 1. compute the gradients of @loss@ with respect to the model parameters,
-- 2. let the optimizer turn them into updated parameter tensors,
-- 3. convert the tensors back into parameters (via 'MakeIndependent') and
--    store them in the model.
--
-- 'performGC' and 'mallocTrim' are called first, presumably to release
-- memory held by tensors from earlier steps before new ones are allocated.
runStep ::
  forall model optim parameters gradients tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    HasGrad (HList parameters) (HList gradients),
    tensors ~ gradients,
    HMap' ToDependent parameters tensors,
    ATen.Castable (HList gradients) [D.ATenTensor],
    Optimizer optim gradients tensors dtype device,
    HMapM' IO MakeIndependent tensors parameters
  ) =>
  -- | model to update
  model ->
  -- | current optimizer state
  optim ->
  -- | loss to differentiate
  Loss device dtype ->
  -- | learning rate
  LearningRate device dtype ->
  -- | updated model and optimizer state
  IO (model, optim)
runStep model optim loss learningRate = do
  performGC
  mallocTrim 0
  let parameters :: HList parameters
      parameters = flattenParameters model
      gradients :: HList gradients
      gradients = grad loss parameters
      tensors :: HList tensors
      tensors = hmap' ToDependent parameters
      (tensors', optim') = step learningRate gradients tensors optim
  parameters' <- hmapM' MakeIndependent tensors'
  let model' = replaceParameters model parameters'
  pure (model', optim')

-- | Like 'runStep', but with externally supplied gradients instead of a
-- loss to differentiate.
--
-- The optimizer turns the gradients into updated parameter tensors, which
-- are converted back into parameters (via 'MakeIndependent') and stored in
-- the model. 'performGC' and 'mallocTrim' are called first, presumably to
-- release memory held by tensors from earlier steps.
runStep' ::
  forall model optim parameters gradients tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    tensors ~ gradients,
    HMap' ToDependent parameters tensors,
    Optimizer optim gradients tensors dtype device,
    HMapM' IO MakeIndependent tensors parameters
  ) =>
  -- | model to update
  model ->
  -- | current optimizer state
  optim ->
  -- | learning rate
  LearningRate device dtype ->
  -- | gradients of the model parameters
  HList gradients ->
  -- | updated model and optimizer state
  IO (model, optim)
runStep' model optim learningRate gradients = do
  performGC
  mallocTrim 0
  let parameters :: HList parameters
      parameters = flattenParameters model
      tensors :: HList tensors
      tensors = hmap' ToDependent parameters
      (tensors', optim') = step learningRate gradients tensors optim
  parameters' <- hmapM' MakeIndependent tensors'
  let model' = replaceParameters model parameters'
  pure (model', optim')

--
-- Gradient Descent (GD)
--

-- | Dummy state representation for GD Optimizer: plain gradient descent
-- carries no state between steps.
data GD = GD

-- | Construct the (stateless) gradient descent optimizer.
mkGD :: GD
mkGD = GD

-- | Function tag for 'hzipWith': one gradient descent update of a single
-- tensor, carrying the learning rate.
newtype GDStep device dtype = GDStep (LearningRate device dtype)

instance
  ( parameter ~ Tensor device dtype shape,
    gradient ~ Tensor device dtype shape,
    shape ~ Broadcast '[] shape,
    BasicArithmeticDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  Apply' (GDStep device dtype) (parameter, gradient) parameter
  where
  -- parameter - learningRate * gradient. The scalar learning rate is
  -- broadcast over the gradient's shape; the Broadcast constraint above
  -- makes the product have the parameter's shape again.
  apply' (GDStep learningRate) (parameter, gradient) =
    parameter - mul learningRate gradient

-- | Gradient descent step with a dummy state variable.
--
-- Every tensor is updated independently with 'GDStep', i.e. as
-- @tensor - learningRate * gradient@. The 'GD' state is returned unchanged.
gd ::
  forall gradients tensors dtype device.
  HZipWith (GDStep device dtype) tensors gradients tensors =>
  -- | learning rate
  LearningRate device dtype ->
  -- | model parameter gradient tensors
  HList gradients ->
  -- | model parameter tensors
  HList tensors ->
  -- | dummy optimizer state
  GD ->
  -- | returns updated parameters and the unchanged state
  (HList tensors, GD)
gd learningRate gradients parameters state =
  let parameters' = hzipWith (GDStep learningRate) parameters gradients
   in (parameters', state)

-- | Plain gradient descent as an 'Optimizer'; delegates to 'gd'.
instance
  ( HZipWith (GDStep device dtype) tensors gradients tensors
  ) =>
  Optimizer GD gradients tensors dtype device
  where
  step = gd

-- | 'GD' holds no tensors, so its parameter list is empty and replacing
-- parameters leaves it unchanged.
instance Parameterized GD where
  type Parameters GD = '[]
  flattenParameters _ = HNil
  replaceParameters = const

--
-- Gradient Descent with Momentum (GDM)
--

-- | State representation for GDM Optimizer
data GDM (momenta :: [Type]) = GDM
  { -- | moment forgetting factor
    beta :: Float,
    -- | momenta, one per model parameter (see 'mkGDM')
    momenta :: HList momenta
  }

-- | Create a GDM optimizer state with the given forgetting factor. All
-- momenta are initialised to zero tensors shaped like the given parameters.
mkGDM ::
  forall parameters momenta.
  (HMap' ZerosLike parameters momenta) =>
  -- | moment forgetting factor
  Float ->
  -- | model parameters; only their types (device, dtype, shape) matter
  HList parameters ->
  GDM momenta
mkGDM forgettingFactor parameters =
  GDM forgettingFactor (hmap' ZerosLike parameters)

-- | Function tag for 'hzipWith3': one gradient-descent-with-momentum update
-- of a single tensor, carrying the forgetting factor and the learning rate.
data GDMStep device dtype = GDMStep Float (LearningRate device dtype)

instance
  ( parameter ~ Tensor device dtype shape,
    gradient ~ Tensor device dtype shape,
    momentum ~ Tensor device dtype shape,
    shape ~ Broadcast '[] shape,
    KnownDevice device,
    BasicArithmeticDTypeIsValid device dtype
  ) =>
  Apply' (GDMStep device dtype) (parameter, gradient, momentum) (parameter, momentum)
  where
  -- momentum'  = beta * momentum + gradient
  -- parameter' = parameter - learningRate * momentum'
  apply' (GDMStep beta learningRate) (parameter, gradient, momentum) =
    let momentum' = mulScalar beta momentum + gradient
        parameter' = parameter - mul learningRate momentum'
     in (parameter', momentum')

-- | gradient descent with momentum step
--
-- Each (parameter, gradient, momentum) triple is updated with 'GDMStep'.
-- The resulting pairs are then split into the new parameters ('AFst') and
-- the new momenta ('ASnd'). The forgetting factor is carried over unchanged.
gdm ::
  forall gradients tensors momenta gdmStep dtype device.
  ( HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep,
    HMap' AFst gdmStep tensors,
    HMap' ASnd gdmStep momenta
  ) =>
  -- | learning rate
  LearningRate device dtype ->
  -- | model parameter gradient tensors
  HList gradients ->
  -- | model parameter tensors
  HList tensors ->
  -- | beta and model parameter momentum tensors
  GDM momenta ->
  -- | returns updated parameters and momenta
  (HList tensors, GDM momenta)
gdm learningRate gradients parameters (GDM forgettingFactor currentMomenta) =
  let updates :: HList gdmStep
      updates =
        hzipWith3
          (GDMStep forgettingFactor learningRate)
          parameters
          gradients
          currentMomenta
   in (hmap' AFst updates, GDM forgettingFactor (hmap' ASnd updates))

-- | Gradient descent with momentum as an 'Optimizer'; delegates to 'gdm'.
instance
  ( HZipWith3 (GDMStep device dtype) tensors gradients momenta gdmStep,
    HMap' AFst gdmStep tensors,
    HMap' ASnd gdmStep momenta
  ) =>
  Optimizer (GDM momenta) gradients tensors dtype device
  where
  step = gdm

-- | The tensors of a 'GDM' state are exactly its momenta; the forgetting
-- factor is not a tensor and is left untouched by 'replaceParameters'.
instance Parameterized (GDM momenta) where
  type Parameters (GDM momenta) = momenta
  flattenParameters = momenta
  replaceParameters state momenta' = state {momenta = momenta'}

--
-- Adam
-- https://arxiv.org/pdf/1412.6980.pdf
--

-- | Adam iteration counter: a scalar Int64 tensor on CPU device 0.
type AdamIter = Tensor '( 'D.CPU, 0) 'D.Int64 '[]

-- | State representation for Adam Optimizer
data Adam (momenta :: [Type]) = Adam
  { -- | iteration
    iter :: AdamIter,
    -- | 1st moment forgetting factor
    beta1 :: Float,
    -- | 2nd moment forgetting factor
    beta2 :: Float,
    -- | 1st momenta
    momenta1 :: HList momenta,
    -- | 2nd momenta
    momenta2 :: HList momenta
  }

-- | Create an Adam optimizer state from a starting iteration counter and
-- the two forgetting factors. Both sets of momenta are initialised to
-- zero tensors shaped like the given parameters.
mkAdam ::
  forall parameters momenta.
  (HMap' ZerosLike parameters momenta) =>
  -- | starting iteration counter
  AdamIter ->
  -- | 1st moment forgetting factor
  Float ->
  -- | 2nd moment forgetting factor
  Float ->
  -- | model parameters; only their types (device, dtype, shape) matter
  HList parameters ->
  Adam momenta
mkAdam initialIter firstDecay secondDecay parameters =
  -- The two zero lists are built separately, so the 1st and 2nd momenta do
  -- not share tensors.
  Adam
    initialIter
    firstDecay
    secondDecay
    (hmap' ZerosLike parameters)
    (hmap' ZerosLike parameters)

-- | Function tag for 'hzipWith': update of one first-moment estimate,
-- carrying the forgetting factor @beta1@.
newtype AdamMomentum1Update = AdamMomentum1Update Float

-- | decaying average of the first momenta:
-- @beta1 * momentum1 + (1 - beta1) * gradient@
instance
  ( gradient ~ Tensor device dtype shape,
    momentum1 ~ Tensor device dtype shape,
    KnownDevice device
  ) =>
  Apply' AdamMomentum1Update (momentum1, gradient) momentum1
  where
  apply' (AdamMomentum1Update decay) (momentum1, gradient) =
    mulScalar decay momentum1 + mulScalar (1 - decay) gradient

-- | Function tag for 'hzipWith': update of one second-moment estimate,
-- carrying the forgetting factor @beta2@.
newtype AdamMomentum2Update = AdamMomentum2Update Float

-- | decaying average of the second momenta:
-- @beta2 * momentum2 + (1 - beta2) * gradient * gradient@ (elementwise)
instance
  ( gradient ~ Tensor device dtype shape,
    momentum2 ~ Tensor device dtype shape,
    shape ~ Broadcast shape shape,
    KnownDevice device,
    BasicArithmeticDTypeIsValid device dtype
  ) =>
  Apply' AdamMomentum2Update (momentum2, gradient) momentum2
  where
  apply' (AdamMomentum2Update decay) (momentum2, gradient) =
    mulScalar decay momentum2 + mulScalar (1 - decay) (mul gradient gradient)

-- | Function tag for 'hmap'': bias adjustment of one moment estimate,
-- carrying the current iteration counter and the matching forgetting factor.
data AdamBiasAdjustment = AdamBiasAdjustment AdamIter Float

-- | bias adjustment: divides the momentum by @1 - pow (iter + 1) beta@.
-- NOTE(review): assuming 'pow' takes the exponent first, this is the usual
-- Adam correction @1 - beta^(t+1)@ — confirm against 'pow'.
instance
  ( momentum ~ Tensor device dtype shape,
    KnownDevice device,
    KnownDType dtype,
    -- presumably needed so that @Broadcast shape '[]@ reduces back to
    -- @shape@ in the division below
    shape ~ Reverse (Reverse shape),
    BasicArithmeticDTypeIsValid device dtype
  ) =>
  Apply' AdamBiasAdjustment momentum momentum
  where
  apply' (AdamBiasAdjustment iter beta) momentum =
    let -- The counter is an Int64 tensor on the CPU; convert it to the
        -- momentum's dtype and move it to the momentum's device.
        iter' = toDevice @device @'( 'D.CPU, 0) . toDType @dtype @'D.Int64 $ iter + 1
        beta' = full @'[] @dtype @device beta
     in momentum `div` (1 - pow iter' beta')

-- | Function tag for 'hzipWith3': Adam update of one parameter, carrying the
-- epsilon added to the denominator and the learning rate.
data AdamParameterUpdate device dtype = AdamParameterUpdate Float (LearningRate device dtype)

-- | parameter update:
-- @parameter - learningRate * m / (sqrt v + eps)@, where @m@ and @v@ are the
-- bias-adjusted first and second momenta.
instance
  ( parameter ~ Tensor device dtype shape,
    momentum ~ Tensor device dtype shape,
    shape ~ Broadcast '[] shape,
    KnownDevice device,
    BasicArithmeticDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device dtype
  ) =>
  Apply'
    (AdamParameterUpdate device dtype)
    (parameter, momentum, momentum)
    parameter
  where
  apply'
    (AdamParameterUpdate eps learningRate)
    (parameter, biasAdjustedMomentum1, biasAdjustedMomentum2) =
      -- Explicit parentheses: the whole quotient is subtracted.
      parameter
        - ( mul learningRate biasAdjustedMomentum1
              / addScalar eps (sqrt biasAdjustedMomentum2)
          )

-- | Adam step
--
-- 1. Update the decaying first and second momenta from the gradients.
-- 2. Bias-adjust both, using the iteration counter from before this step.
-- 3. Update every parameter with 'AdamParameterUpdate'.
-- 4. Return the new parameters and a state whose counter is incremented by one.
--
-- The type variable @adamStep@ does not occur in the signature; it is kept
-- as-is, presumably so that explicit type applications stay compatible.
adam ::
  forall gradients tensors momenta adamStep dtype device.
  ( HZipWith AdamMomentum1Update momenta gradients momenta,
    HZipWith AdamMomentum2Update momenta gradients momenta,
    HMap' AdamBiasAdjustment momenta momenta,
    HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors
  ) =>
  -- | learning rate
  LearningRate device dtype ->
  -- | model parameter gradient tensors
  HList gradients ->
  -- | model parameter tensors
  HList tensors ->
  -- | adam parameters - beta1, beta2, momenta1, momenta2, iteration
  Adam momenta ->
  -- | returns new parameters + updated adam parameters
  (HList tensors, Adam momenta)
adam learningRate gradients parameters Adam {..} =
  (parameters', Adam (iter + 1) beta1 beta2 momenta1' momenta2')
  where
    momenta1', momenta2' :: HList momenta
    momenta1' = hzipWith (AdamMomentum1Update beta1) momenta1 gradients
    momenta2' = hzipWith (AdamMomentum2Update beta2) momenta2 gradients

    biasAdjustedMomenta1, biasAdjustedMomenta2 :: HList momenta
    biasAdjustedMomenta1 = hmap' (AdamBiasAdjustment iter beta1) momenta1'
    biasAdjustedMomenta2 = hmap' (AdamBiasAdjustment iter beta2) momenta2'

    -- Added to sqrt of the 2nd moment in the denominator.
    -- NOTE(review): far smaller than the 1e-8 default suggested in the Adam
    -- paper; it is presumably rounded to zero in low-precision dtypes —
    -- confirm this value is intended.
    eps :: Float
    eps = 1e-37

    parameters' :: HList tensors
    parameters' =
      hzipWith3
        (AdamParameterUpdate eps learningRate)
        parameters
        biasAdjustedMomenta1
        biasAdjustedMomenta2

-- | Adam as an 'Optimizer'; delegates to 'adam'.
instance
  ( HZipWith AdamMomentum1Update momenta gradients momenta,
    HZipWith AdamMomentum2Update momenta gradients momenta,
    HMap' AdamBiasAdjustment momenta momenta,
    HZipWith3 (AdamParameterUpdate device dtype) tensors momenta momenta tensors
  ) =>
  Optimizer (Adam momenta) gradients tensors dtype device
  where
  step = adam

-- | The tensors of an 'Adam' state are the iteration counter, followed by
-- the 1st momenta and then the 2nd momenta. The forgetting factors are not
-- tensors and are left untouched by 'replaceParameters'.
instance
  HAppendFD momenta momenta (momenta ++ momenta) =>
  Parameterized (Adam momenta)
  where
  type Parameters (Adam momenta) = AdamIter ': (momenta ++ momenta)
  flattenParameters Adam {..} = iter :. (momenta1 `happendFD` momenta2)
  replaceParameters state (iter' :. allMomenta) =
    -- Split the concatenated list back into its two halves.
    let (momenta1', momenta2') = hunappendFD allMomenta
     in state {iter = iter', momenta1 = momenta1', momenta2 = momenta2'}