{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.Autograd where

import Foreign.ForeignPtr
import GHC.Generics
import System.IO.Unsafe
import Torch.Internal.Cast
import Torch.Internal.Class
import qualified Torch.Internal.Managed.Autograd
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Tensor

-- | A tensor wrapped for use as one of the inputs that 'grad'
-- differentiates with respect to.
--
-- Note: create an 'IndependentTensor' with 'makeIndependent' (or
-- 'makeIndependentWithRequiresGrad') rather than the bare constructor;
-- otherwise Torch will complain that the parameter does not require a
-- gradient.
newtype IndependentTensor = IndependentTensor
  { -- | Unwrap to the underlying 'Tensor'.
    toDependent :: Tensor
  }
  deriving (Show, Generic)

-- | Gradients of @y@ with respect to each of the given independent tensors,
-- computed by the libtorch autograd binding
-- 'Torch.Internal.Managed.Autograd.grad'.
--
-- NOTE(review): the result list presumably lines up one-to-one with
-- @inputs@ — confirm against the binding.
--
-- The FFI call runs under 'unsafePerformIO' so that 'grad' can be used as a
-- pure function; this is only sound if the call has no observable side
-- effects beyond producing its result (TODO confirm).
grad :: Tensor -> [IndependentTensor] -> [Tensor]
grad y inputs =
  unsafePerformIO (cast2 Torch.Internal.Managed.Autograd.grad y dependents)
  where
    -- Unwrap every input; cast2 marshals the resulting [Tensor] for the binding.
    dependents = [toDependent input | input <- inputs]

-- | Query the @requires_grad@ setting of a tensor through
-- 'ATen.tensor_requires_grad'.
--
-- Exposed as a pure function via 'unsafePerformIO'. NOTE(review): if the
-- flag can be changed in place elsewhere, an already-evaluated result will
-- not reflect that change — confirm this is acceptable for callers.
requiresGrad :: Tensor -> Bool
requiresGrad tensor =
  unsafePerformIO (cast1 ATen.tensor_requires_grad tensor)

-- | Set the @requires_grad@ setting of tensor @t@ to @flag@ through
-- 'ATen.tensor_set_requires_grad_b', returning the tensor the binding
-- gives back.
--
-- NOTE(review): if the underlying call updates @t@ in place rather than
-- returning a fresh tensor, running it under 'unsafePerformIO' makes the
-- change observable through other references to @t@ — confirm before
-- relying on purity.
setRequiresGrad :: Bool -> Tensor -> Tensor
setRequiresGrad flag t = unsafePerformIO action
  where
    action = cast2 ATen.tensor_set_requires_grad_b t flag

-- | Wrap a tensor as an 'IndependentTensor' with @requires_grad@ enabled.
-- Equivalent to @makeIndependentWithRequiresGrad tensor True@.
makeIndependent :: Tensor -> IO IndependentTensor
makeIndependent = (`makeIndependentWithRequiresGrad` True)

-- | Wrap a tensor as an 'IndependentTensor', passing the given
-- @requires_grad@ setting to the libtorch binding
-- 'Torch.Internal.Managed.Autograd.makeIndependent'.
--
-- NOTE(review): presumably the binding returns a new tensor rather than
-- mutating the argument — confirm.
makeIndependentWithRequiresGrad :: Tensor -> Bool -> IO IndependentTensor
makeIndependentWithRequiresGrad tensor needsGrad =
  fmap IndependentTensor
    (cast2 Torch.Internal.Managed.Autograd.makeIndependent tensor needsGrad)