{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}

module Torch.Typed.Autograd
  ( Torch.Typed.Autograd.HasGrad,
    Torch.Typed.Autograd.grad,
  )
where

import Data.Kind
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Autograd as LibTorch
import qualified Torch.Tensor as D
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Values that gradients can be computed against.
--
-- In this module, @a@ is either a single 'Parameter' (with @b@ the 'Tensor'
-- of the same device, dtype and shape) or an 'HList' of such values (with @b@
-- the 'HList' of the corresponding results). The functional dependency
-- @a -> b@ makes the result type determined by the input type.
class HasGrad a b | a -> b where
  -- | Calculate the gradients of a zero-dimensional (scalar) tensor with
  -- respect to the parameter(s) in @a@.
  grad
    :: forall dtype device.
       Tensor device dtype '[] -- ^ zero-dimensional loss tensor
    -> a                       -- ^ parameter(s) to differentiate against
    -> b                       -- ^ gradient(s), mirroring the structure of @a@

  -- | Convert the parameter(s) in @a@ into the tensor(s) they depend on.
  toDependent :: a -> b

-- instance HasGrad (Tensor device dtype shape) (Tensor device dtype shape) where
--   grad loss input = head . unsafePerformIO $ ATen.cast2
--     LibTorch.grad
--     loss
--     [Torch.Typed.Autograd.toDependent input]
--   toDependent = id

-- | A single 'Parameter': its dependent tensor is passed to 'LibTorch.grad'
-- as a one-element list, and the first tensor of the returned list is the
-- gradient.
instance HasGrad (Parameter device dtype shape) (Tensor device dtype shape) where
  grad loss input =
    case gradients of
      (gradient : _) -> gradient
      [] ->
        error
          "Torch.Typed.Autograd.grad: LibTorch.grad returned an empty gradient list for a single parameter"
    where
      -- One dependent tensor goes in, so one gradient is expected back. The
      -- list is matched explicitly (rather than with the partial 'head') so
      -- that an unexpected empty result produces a descriptive error.
      -- NOTE(review): 'unsafePerformIO' exposes the FFI call as a pure
      -- function; this assumes LibTorch.grad is safe to run lazily, at an
      -- unspecified time — confirm.
      gradients =
        unsafePerformIO $
          ATen.cast2
            LibTorch.grad
            loss
            [Torch.Typed.Autograd.toDependent input]

  -- Delegates to the function of the same name in "Torch.Typed.Parameter".
  toDependent = Torch.Typed.Parameter.toDependent

-- | Base case: an empty 'HList' holds no parameters. Both its gradients and
-- its dependent tensors are the empty list it was given, and the loss is
-- ignored.
instance HasGrad (HList ('[] :: [Type])) (HList ('[] :: [Type])) where
  grad _loss noParameters = noParameters

  toDependent noParameters = noParameters

-- | Inductive case: a non-empty 'HList' of differentiable values.
--
-- 'grad' converts the whole list to its dependent tensors and hands them to
-- 'LibTorch.grad' in a single call. The 'ATen.Castable' constraint is what
-- lets the resulting @HList (b ': bs)@ be marshalled as a list of ATen
-- tensors across the FFI boundary.
instance
  ( HasGrad a b,
    HasGrad (HList as) (HList bs),
    ATen.Castable (HList (b ': bs)) [D.ATenTensor]
  ) =>
  HasGrad (HList (a ': as)) (HList (b ': bs))
  where
  grad loss inputs = unsafePerformIO (ATen.cast2 LibTorch.grad loss dependents)
    where
      -- NOTE(review): 'unsafePerformIO' presents the FFI gradient call as a
      -- pure function; confirm LibTorch.grad tolerates lazy evaluation.
      dependents = Torch.Typed.Autograd.toDependent inputs

  -- Element-wise: convert the head with its own instance, then recurse on
  -- the tail.
  toDependent (x :. xs) =
    let headDependent = Torch.Typed.Autograd.toDependent x
        tailDependents = Torch.Typed.Autograd.toDependent xs
     in headDependent :. tailDependents