{-# LANGUAGE DataKinds #-}
module Torch.Internal.Managed.Optim where

import Foreign
import Foreign.C.String
import Foreign.C.Types
import Foreign.ForeignPtr.Unsafe
import Torch.Internal.Cast
import Torch.Internal.Class
import Torch.Internal.Objects
import Torch.Internal.Type
import qualified Torch.Internal.Unmanaged.Optim as Unmanaged
import Control.Concurrent.MVar (MVar(..), newEmptyMVar, putMVar, takeMVar)

-- optimizerWithAdam
--   :: CDouble
--   -> CDouble
--   -> CDouble
--   -> CDouble
--   -> CDouble
--   -> CBool
--   -> ForeignPtr TensorList
--   -> (ForeignPtr TensorList -> IO (ForeignPtr Tensor))
--   -> Int
--   -> IO (ForeignPtr TensorList)
-- optimizerWithAdam adamLr adamBetas0 adamBetas1 adamEps adamWeightDecay adamAmsgrad initParams loss numIter = _cast2 (\i n -> Unmanaged.optimizerWithAdam adamLr adamBetas0 adamBetas1 adamEps adamWeightDecay adamAmsgrad i (trans loss) n) initParams numIter
--   where
--     trans :: (ForeignPtr TensorList -> IO (ForeignPtr Tensor)) -> Ptr TensorList -> IO (Ptr Tensor)
--     trans func inputs = do
--       inputs' <- newForeignPtr_ inputs
--       ret <- func inputs'
--       return $ unsafeForeignPtrToPtr ret

-- | Create an optimizer via 'Unmanaged.adagrad' for the given tensor list.
--
-- Thin managed wrapper: '_cast6' converts each argument to its unmanaged
-- representation, calls 'Unmanaged.adagrad', and converts the returned
-- @Ptr Optimizer@ back to a @ForeignPtr Optimizer@.
--
-- NOTE(review): the five 'CDouble' hyperparameters are positional; their
-- meaning and order are defined by 'Unmanaged.adagrad' (presumably
-- libtorch's Adagrad options) — confirm there. The 'TensorList' is
-- presumably the parameters to optimize.
adagrad
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
adagrad = _cast6 Unmanaged.adagrad

-- | Create an optimizer via 'Unmanaged.rmsprop' for the given tensor list.
--
-- Thin managed wrapper: '_cast7' converts each argument to its unmanaged
-- representation, calls 'Unmanaged.rmsprop', and converts the returned
-- @Ptr Optimizer@ back to a @ForeignPtr Optimizer@.
--
-- NOTE(review): the five 'CDouble' values and the 'CBool' flag are
-- positional; their meaning is defined by 'Unmanaged.rmsprop' (presumably
-- libtorch's RMSprop options) — confirm there.
rmsprop
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
rmsprop = _cast7 Unmanaged.rmsprop

-- | Create an optimizer via 'Unmanaged.sgd' for the given tensor list.
--
-- Thin managed wrapper: '_cast6' converts each argument to its unmanaged
-- representation, calls 'Unmanaged.sgd', and converts the returned
-- @Ptr Optimizer@ back to a @ForeignPtr Optimizer@.
--
-- NOTE(review): the four 'CDouble' values and the 'CBool' flag are
-- positional; their meaning is defined by 'Unmanaged.sgd' (presumably
-- libtorch's SGD options) — confirm there.
sgd
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
sgd = _cast6 Unmanaged.sgd

-- | Create an optimizer via 'Unmanaged.adam' for the given tensor list.
--
-- Thin managed wrapper: '_cast7' converts each argument to its unmanaged
-- representation, calls 'Unmanaged.adam', and converts the returned
-- @Ptr Optimizer@ back to a @ForeignPtr Optimizer@.
--
-- NOTE(review): the arguments are positional. The older commented-out
-- @optimizerWithAdam@ in this module used the same shape with names
-- lr, betas0, betas1, eps, weightDecay, amsgrad — presumably the same
-- order here; confirm against 'Unmanaged.adam'.
adam
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
adam = _cast7 Unmanaged.adam

-- | Create an optimizer via 'Unmanaged.adamw' for the given tensor list.
--
-- Thin managed wrapper: '_cast7' converts each argument to its unmanaged
-- representation, calls 'Unmanaged.adamw', and converts the returned
-- @Ptr Optimizer@ back to a @ForeignPtr Optimizer@.
--
-- NOTE(review): same argument shape as 'adam'; the meaning and order of
-- the positional hyperparameters is defined by 'Unmanaged.adamw' — confirm
-- there.
adamw
  :: CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CDouble
  -> CBool
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
adamw = _cast7 Unmanaged.adamw

-- | Create an optimizer via 'Unmanaged.lbfgs' for the given tensor list.
--
-- Thin managed wrapper: '_cast8' converts each argument to its unmanaged
-- representation (the optional string becomes @Maybe (Ptr StdString)@),
-- calls 'Unmanaged.lbfgs', and converts the returned @Ptr Optimizer@ back
-- to a @ForeignPtr Optimizer@.
--
-- NOTE(review): the numeric arguments are positional and defined by
-- 'Unmanaged.lbfgs'. The optional 'StdString' is presumably a line-search
-- function name ('Nothing' = no line search) — confirm there.
lbfgs
  :: CDouble
  -> CInt
  -> CInt
  -> CDouble
  -> CDouble
  -> CInt
  -> Maybe (ForeignPtr StdString)
  -> ForeignPtr TensorList
  -> IO (ForeignPtr Optimizer)
lbfgs = _cast8 Unmanaged.lbfgs

-- | Fetch the tensor list reported by 'Unmanaged.getParams' for an optimizer
-- (presumably the parameters it optimizes).
--
-- '_cast1' unwraps the 'ForeignPtr' for the call and converts the returned
-- @Ptr TensorList@ back to a @ForeignPtr TensorList@.
getParams :: ForeignPtr Optimizer -> IO (ForeignPtr TensorList)
getParams = _cast1 Unmanaged.getParams

-- | Run one optimizer step, supplying a Haskell loss closure to the C++ side.
--
-- @loss@ receives a tensor list from the C++ side and returns a loss
-- tensor. The callback hands C++ a raw pointer obtained with
-- 'unsafeForeignPtrToPtr'. Every 'ForeignPtr' the closure returns is
-- therefore recorded in @ref@ and touched after 'Unmanaged.step' returns,
-- so none can be finalized while the step is still running.
--
-- The closure may be invoked more than once per step (libtorch's LBFGS
-- re-evaluates its closure). Results are accumulated in a list rather than
-- a single-slot 'MVar': repeated invocations no longer block on a full
-- 'MVar', and a step that never calls the closure no longer blocks on an
-- empty one.
step
  :: ForeignPtr Optimizer
  -> (ForeignPtr TensorList -> IO (ForeignPtr Tensor))
  -> IO (ForeignPtr Tensor)
step optimizer loss = do
  ref <- newEmptyMVar
  putMVar ref []
  ret <- cast1 (\opt -> Unmanaged.step opt (trans ref loss)) optimizer
  rets <- takeMVar ref
  -- Keep every loss tensor handed to C++ alive until the step has finished.
  mapM_ touchForeignPtr rets
  return ret
  where
    trans
      :: MVar [ForeignPtr Tensor]
      -> (ForeignPtr TensorList -> IO (ForeignPtr Tensor))
      -> Ptr TensorList
      -> IO (Ptr Tensor)
    trans ref func inputs = do
      -- No finalizer is attached, so Haskell never frees this list
      -- (presumably owned by the C++ caller).
      inputs' <- newForeignPtr_ inputs
      ret <- func inputs'
      -- take/put pair: append this result without blocking on later calls.
      rets <- takeMVar ref
      putMVar ref (ret : rets)
      return $ unsafeForeignPtrToPtr ret

-- | Run one optimizer step with a loss closure that also threads a
-- 'Generator' through the C++ side.
--
-- @loss@ receives a tensor list and a generator from the C++ side and
-- returns a @(Tensor, Generator)@ tuple. The callback hands C++ a raw
-- pointer obtained with 'unsafeForeignPtrToPtr'. Every tuple the closure
-- returns is therefore recorded in @ref@ and touched after
-- 'Unmanaged.stepWithGenerator' returns, so none can be finalized while the
-- step is still running.
--
-- The closure may be invoked more than once per step (libtorch's LBFGS
-- re-evaluates its closure). Results are accumulated in a list rather than
-- a single-slot 'MVar': repeated invocations no longer block on a full
-- 'MVar', and a step that never calls the closure no longer blocks on an
-- empty one.
stepWithGenerator
  :: ForeignPtr Optimizer
  -> ForeignPtr Generator
  -> (ForeignPtr TensorList -> ForeignPtr Generator -> IO (ForeignPtr (StdTuple '(Tensor,Generator))))
  -> IO (ForeignPtr (StdTuple '(Tensor,Generator)))
stepWithGenerator optimizer generator loss = do
  ref <- newEmptyMVar
  putMVar ref []
  ret <-
    cast2
      (\opt gen -> Unmanaged.stepWithGenerator opt gen (trans ref loss))
      optimizer
      generator
  rets <- takeMVar ref
  -- Keep every result handed to C++ alive until the step has finished.
  mapM_ touchForeignPtr rets
  return ret
  where
    trans
      :: MVar [ForeignPtr (StdTuple '(Tensor,Generator))]
      -> (ForeignPtr TensorList -> ForeignPtr Generator -> IO (ForeignPtr (StdTuple '(Tensor,Generator))))
      -> Ptr TensorList
      -> Ptr Generator
      -> IO (Ptr (StdTuple '(Tensor,Generator)))
    trans ref func inputs genPtr = do
      -- No finalizers are attached, so Haskell never frees these objects
      -- (presumably owned by the C++ caller).
      inputs' <- newForeignPtr_ inputs
      gen' <- newForeignPtr_ genPtr
      ret <- func inputs' gen'
      -- take/put pair: append this result without blocking on later calls.
      rets <- takeMVar ref
      putMVar ref (ret : rets)
      return $ unsafeForeignPtrToPtr ret


-- | Call 'Unmanaged.unsafeStep' with an optimizer and a tensor (presumably
-- an already-computed loss) and return the resulting tensor list.
--
-- '_cast2' unwraps both 'ForeignPtr' arguments for the call and converts
-- the returned @Ptr TensorList@ back to a @ForeignPtr TensorList@.
--
-- NOTE(review): what makes this step "unsafe" is defined by
-- 'Unmanaged.unsafeStep' — confirm there before relying on it.
unsafeStep :: ForeignPtr Optimizer -> ForeignPtr Tensor -> IO (ForeignPtr TensorList)
unsafeStep = _cast2 Unmanaged.unsafeStep

-- | Call 'Unmanaged.save' on an optimizer with a 'StdString' argument
-- (presumably a file path to write the optimizer state to — confirm).
--
-- '_cast2' unwraps both 'ForeignPtr' arguments for the call.
save :: ForeignPtr Optimizer -> ForeignPtr StdString -> IO ()
save = _cast2 Unmanaged.save

-- | Call 'Unmanaged.load' on an optimizer with a 'StdString' argument
-- (presumably a file path to read optimizer state from — confirm).
--
-- '_cast2' unwraps both 'ForeignPtr' arguments for the call.
load :: ForeignPtr Optimizer -> ForeignPtr StdString -> IO ()
load = _cast2 Unmanaged.load