{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Typed.Optim.CppOptim
( module Torch.Typed.Optim.CppOptim,
AdagradOptions (..),
AdamOptions (..),
AdamwOptions (..),
LbfgsOptions (..),
RmspropOptions (..),
SGDOptions (..),
)
where
import Data.Default.Class
import Data.Foldable (for_)
import Data.Kind (Type)
import qualified Debug.Trace as Debug
import Foreign.ForeignPtr
import System.Mem (performGC)
import qualified Torch as TD
import Torch.HList
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppObject (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import Torch.Internal.GC (mallocTrim)
import qualified Torch.Internal.Managed.Optim as LibTorch
import qualified Torch.Internal.Type as ATen
import Torch.Optim.CppOptim
( AdagradOptions (..),
AdamOptions (..),
AdamwOptions (..),
LbfgsOptions (..),
RmspropOptions (..),
SGDOptions (..),
)
import Torch.Typed.Autograd
import Torch.Typed.NN
import qualified Torch.Typed.Optim as Optim
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Handle to a C++-side optimizer object ('ATen.Optimizer'), held in a
-- 'ForeignPtr'. NOTE(review): presumably the libtorch bindings attach a
-- finalizer on creation, so the C++ object's lifetime follows this value.
type CppOptimizerRef = ForeignPtr ATen.Optimizer
-- | State of an optimizer whose update rule runs in C++.
--
-- It holds the user-supplied option record together with the handle to the
-- C++ optimizer built from it. The @params@ index does not appear on the
-- right-hand side: it is a phantom that ties the state to the parameter
-- list (@'Parameters' model@) of the model it was initialised for.
data CppOptimizerState option (params :: [Type])
  = CppOptimizerState option CppOptimizerRef
-- | Marker passed to @hmap'@ to rewrap each plain 'Tensor' as a
-- 'Parameter' of the same device, dtype and shape (see the @Apply'@
-- instance that follows).
data ToParameter = ToParameter
-- | Turn a typed 'Tensor' back into a typed 'Parameter', keeping its
-- device, dtype and shape indices.
--
-- The raw tensor is wrapped with the 'TD.IndependentTensor' constructor
-- directly. No copy is made and gradient tracking is not touched.
-- NOTE(review): this assumes the tensors handed back by the C++ optimizer
-- are already the leaf tensors it updates in place — confirm.
instance Apply' ToParameter (Tensor dev dtype shape) (Parameter dev dtype shape) where
  apply' _ (UnsafeMkTensor rawTensor) =
    UnsafeMkParameter (TD.IndependentTensor rawTensor)
-- | Optimizers whose state and update rule live in libtorch (C++).
--
-- An instance only has to describe how to build the C++ optimizer from its
-- option record ('initOptimizer'). Stepping is shared through the default
-- 'unsafeStep'.
class CppOptimizer option where
  -- | Build a C++ optimizer over the model's current parameters.
  --
  -- The model's parameters are flattened and converted to plain tensors
  -- before being handed to the C++ constructor.
  initOptimizer ::
    forall model tensors.
    ( Parameterized model,
      HMap' ToDependent (Parameters model) tensors,
      Castable (HList tensors) [TD.ATenTensor]
    ) =>
    option ->
    model ->
    IO (CppOptimizerState option (Parameters model))

  -- | Run one step of the C++ optimizer and rebuild the model from the
  -- parameter tensors it returns.
  --
  -- The optimizer handle and the loss are passed to @LibTorch.unsafeStep@,
  -- which returns a list of tensors. That list is uncast to the model's
  -- tensor 'HList', rewrapped with 'ToParameter', and written back with
  -- 'replaceParameters'. The optimizer state is returned unchanged.
  -- NOTE(review): that relies on all per-step optimizer state living in the
  -- C++ object behind 'CppOptimizerRef'.
  --
  -- The loss shape is unconstrained here. Callers normally go through
  -- 'runStep', which forces a GC and trims the C heap before delegating to
  -- this method. The type variable @res@ is unused; it stays in the forall
  -- so the order of type arguments for @TypeApplications@ is unchanged.
  unsafeStep ::
    forall model dev dtype lossShape tensors res.
    ( Parameterized model,
      HMap' ToDependent (Parameters model) tensors,
      HMap' ToParameter tensors (Parameters model),
      Castable (HList tensors) [TD.ATenTensor]
    ) =>
    model ->
    CppOptimizerState option (Parameters model) ->
    Tensor dev dtype lossShape ->
    IO (model, CppOptimizerState option (Parameters model))
  unsafeStep model o@(CppOptimizerState _ optimizer) loss = do
    v :: [TD.ATenTensor] <- cast2 LibTorch.unsafeStep optimizer loss
    -- The annotation fixes which HList type the raw tensor list decodes to.
    newParamTensors :: HList tensors <- uncast v pure
    let newParams = hmap' ToParameter newParamTensors
        newModel = replaceParameters model newParams
    pure (newModel, o)
-- | Adam, backed by @LibTorch.adam@.
--
-- Arguments passed positionally: learning rate, the two components of
-- 'adamBetas', epsilon, weight decay, the AMSGrad flag, and the model's
-- current parameters as plain tensors.
instance CppOptimizer AdamOptions where
  initOptimizer opt@AdamOptions {..} model = do
    let (beta1, beta2) = adamBetas
    ref <-
      cast7
        LibTorch.adam
        adamLr
        beta1
        beta2
        adamEps
        adamWeightDecay
        adamAmsgrad
        initParams
    pure (CppOptimizerState opt ref)
    where
      -- Dependent (plain) tensors of the parameters, in flattenParameters order.
      initParams = hmap' ToDependent (flattenParameters model)
-- | AdamW, backed by @LibTorch.adamw@.
--
-- Arguments passed positionally: learning rate, the two components of
-- 'adamwBetas', epsilon, weight decay, the AMSGrad flag, and the model's
-- current parameters as plain tensors.
instance CppOptimizer AdamwOptions where
  initOptimizer opt@AdamwOptions {..} model = do
    let (beta1, beta2) = adamwBetas
    ref <-
      cast7
        LibTorch.adamw
        adamwLr
        beta1
        beta2
        adamwEps
        adamwWeightDecay
        adamwAmsgrad
        initParams
    pure (CppOptimizerState opt ref)
    where
      -- Dependent (plain) tensors of the parameters, in flattenParameters order.
      initParams = hmap' ToDependent (flattenParameters model)
-- | L-BFGS, backed by @LibTorch.lbfgs@.
--
-- Arguments passed positionally: learning rate, max iterations, max
-- evaluations, gradient tolerance, change tolerance, history size, the
-- optional line-search function name, and the model's current parameters
-- as plain tensors.
instance CppOptimizer LbfgsOptions where
  initOptimizer opt@LbfgsOptions {..} model = do
    ref <-
      cast8
        LibTorch.lbfgs
        lbfgsLr
        lbfgsMaxIter
        lbfgsMaxEval
        lbfgsToleranceGrad
        lbfgsToleranceChange
        lbfgsHistorySize
        lbfgsLineSearchFn
        initParams
    pure (CppOptimizerState opt ref)
    where
      -- Dependent (plain) tensors of the parameters, in flattenParameters order.
      initParams = hmap' ToDependent (flattenParameters model)
-- | RMSprop, backed by @LibTorch.rmsprop@.
--
-- Arguments passed positionally: learning rate, alpha, epsilon, weight
-- decay, momentum, the centered flag, and the model's current parameters as
-- plain tensors.
instance CppOptimizer RmspropOptions where
  initOptimizer opt@RmspropOptions {..} model = do
    ref <-
      cast7
        LibTorch.rmsprop
        rmspropLr
        rmspropAlpha
        rmspropEps
        rmspropWeightDecay
        rmspropMomentum
        rmspropCentered
        initParams
    pure (CppOptimizerState opt ref)
    where
      -- Dependent (plain) tensors of the parameters, in flattenParameters order.
      initParams = hmap' ToDependent (flattenParameters model)
-- | Stochastic gradient descent, backed by @LibTorch.sgd@.
--
-- Arguments passed positionally: learning rate, momentum, dampening, weight
-- decay, the Nesterov flag, and the model's current parameters as plain
-- tensors.
instance CppOptimizer SGDOptions where
  initOptimizer opt@SGDOptions {..} model = do
    ref <-
      cast6
        LibTorch.sgd
        sgdLr
        sgdMomentum
        sgdDampening
        sgdWeightDecay
        sgdNesterov
        initParams
    pure (CppOptimizerState opt ref)
    where
      -- Dependent (plain) tensors of the parameters, in flattenParameters order.
      initParams = hmap' ToDependent (flattenParameters model)
-- | One optimizer step, preceded by memory housekeeping.
--
-- First forces a major garbage collection ('performGC'), then calls
-- @mallocTrim 0@, then delegates to 'unsafeStep'. NOTE(review): the intent
-- is presumably to let tensors dropped since the previous step release
-- their native memory before libtorch allocates new buffers. A major GC on
-- every step can dominate runtime for small models.
--
-- The loss is an 'Optim.Loss', which this module passes straight to
-- 'unsafeStep' as a 'Tensor'.
runStep ::
  ( CppOptimizer option,
    Parameterized model,
    HMap' ToDependent (Parameters model) tensors,
    HMap' ToParameter tensors (Parameters model),
    Castable (HList tensors) [TD.ATenTensor]
  ) =>
  model ->
  CppOptimizerState option (Parameters model) ->
  Optim.Loss dev dtype ->
  IO (model, CppOptimizerState option (Parameters model))
runStep model optim loss = do
  performGC
  mallocTrim 0
  unsafeStep model optim loss