{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Typed.Parameter
  ( module Torch.Typed.Parameter,
    Torch.NN.Randomizable (..),
  )
where

import Control.Monad.State.Strict
import Data.Kind (Type)
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import qualified Torch.Autograd (IndependentTensor (..), makeIndependent)
import Torch.DType (DType)
import Torch.Device (DeviceType)
import Torch.HList
import qualified Torch.NN (Parameter, Randomizable (..), sample)
import qualified Torch.Tensor (toType, _toDevice)
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Tensor

-- | A trainable tensor whose device, dtype and shape are tracked at the
-- type level.
--
-- This is a thin wrapper around an untyped
-- 'Torch.Autograd.IndependentTensor'. The constructor performs no check that
-- the wrapped tensor really matches the phantom @device@, @dtype@ and @shape@
-- indices. Whoever uses 'UnsafeMkParameter' is responsible for that invariant.
newtype
  Parameter
    (device :: (DeviceType, Nat))
    (dtype :: DType)
    (shape :: [Nat])
  = UnsafeMkParameter Torch.Autograd.IndependentTensor
  deriving (Show)

-- | Forget the type-level device, dtype and shape of a parameter, returning
-- the untyped 'Torch.NN.Parameter' it wraps.
untypeParam :: Parameter device dtype shape -> Torch.NN.Parameter
untypeParam (UnsafeMkParameter untyped) = untyped

-- | View a parameter as an ordinary typed tensor.
--
-- The underlying tensor is read out with 'Torch.Autograd.toDependent' and
-- re-wrapped with the same type-level device, dtype and shape.
toDependent ::
  forall shape dtype device.
  Parameter device dtype shape ->
  Tensor device dtype shape
toDependent (UnsafeMkParameter independent) =
  UnsafeMkTensor (Torch.Autograd.toDependent independent)

-- | Function symbol whose 'Apply'' instance applies 'toDependent' to a single
-- 'Parameter'. It is presumably meant for the heterogeneous-list combinators
-- of "Torch.HList"; confirm at call sites.
data ToDependent = ToDependent

instance Apply' ToDependent (Parameter device dtype shape) (Tensor device dtype shape) where
  -- The symbol argument is ignored (not pattern-matched), so it is never forced.
  apply' _ = toDependent

-- | Turn a typed tensor into a typed 'Parameter'.
--
-- The tensor's type indices are erased with 'toDynamic' and the tensor is
-- passed to 'Torch.Autograd.makeIndependent' (an 'IO' action). The result is
-- then re-wrapped with the original device, dtype and shape.
makeIndependent ::
  forall shape dtype device.
  Tensor device dtype shape ->
  IO (Parameter device dtype shape)
makeIndependent tensor = do
  independent <- Torch.Autograd.makeIndependent (toDynamic tensor)
  pure (UnsafeMkParameter independent)

-- | Function symbol whose 'Apply'' instance applies 'makeIndependent' to a
-- single typed 'Tensor', producing an 'IO' action. It is presumably meant for
-- the heterogeneous-list combinators of "Torch.HList"; confirm at call sites.
data MakeIndependent = MakeIndependent

instance
  Apply'
    MakeIndependent
    (Tensor device dtype shape)
    (IO (Parameter device dtype shape))
  where
  -- The symbol argument is ignored (not pattern-matched), so it is never forced.
  apply' _ = makeIndependent

-- | Move a parameter to the device @device'@, which is chosen by type
-- application (e.g. @parameterToDevice \@'( 'CUDA, 0)@).
--
-- The runtime device comes from 'deviceVal' for @device'@. The dependent
-- tensor is copied with 'Torch.Tensor._toDevice'.
--
-- NOTE(review): the moved tensor is wrapped directly with the
-- 'Torch.Autograd.IndependentTensor' constructor, not through
-- 'makeIndependent'. Whether the result is still a trainable leaf tensor
-- therefore depends on "Torch.Autograd" semantics. Confirm before relying on
-- gradients of the converted parameter.
parameterToDevice ::
  forall device' device dtype shape.
  KnownDevice device' =>
  Parameter device dtype shape ->
  Parameter device' dtype shape
parameterToDevice (UnsafeMkParameter independent) =
  let dependent = Torch.Autograd.toDependent independent
      moved = Torch.Tensor._toDevice (deviceVal @device') dependent
   in UnsafeMkParameter (Torch.Autograd.IndependentTensor moved)

-- | Convert a parameter to the dtype @dtype'@, which is chosen by type
-- application (e.g. @parameterToDType \@'Float@).
--
-- The runtime dtype comes from 'dtypeVal' for @dtype'@. The dependent tensor
-- is converted with 'Torch.Tensor.toType'.
--
-- NOTE(review): the converted tensor is wrapped directly with the
-- 'Torch.Autograd.IndependentTensor' constructor, not through
-- 'makeIndependent'. Whether the result is still a trainable leaf tensor
-- therefore depends on "Torch.Autograd" semantics. Confirm before relying on
-- gradients of the converted parameter.
parameterToDType ::
  forall dtype' dtype device shape.
  KnownDType dtype' =>
  Parameter device dtype shape ->
  Parameter device dtype' shape
parameterToDType (UnsafeMkParameter independent) =
  let dependent = Torch.Autograd.toDependent independent
      converted = Torch.Tensor.toType (dtypeVal @dtype') dependent
   in UnsafeMkParameter (Torch.Autograd.IndependentTensor converted)

-- | Types that contain 'Parameter's.
--
-- The associated type 'Parameters' is the type-level list of everything the
-- value holds. 'flattenParameters' extracts those parameters into an 'HList',
-- and 'replaceParameters' rebuilds the value from such a list.
--
-- By default both methods, and 'Parameters' itself, are derived through the
-- 'Generic' representation using 'GParameterized'.
--
-- NOTE(review): the intended law is presumably
-- @replaceParameters x (flattenParameters x) == x@. It is not enforced here.
class Parameterized (f :: Type) where
  type Parameters f :: [Type]
  type Parameters f = GParameters (Rep f)

  -- | Collect every parameter contained in the value.
  flattenParameters :: f -> HList (Parameters f)
  default flattenParameters ::
    (Generic f, GParameterized (Rep f), Parameters f ~ GParameters (Rep f)) =>
    f ->
    HList (Parameters f)
  flattenParameters x = gFlattenParameters (from x)

  -- | Rebuild the value, taking its parameters from the given list.
  replaceParameters :: f -> HList (Parameters f) -> f
  default replaceParameters ::
    (Generic f, GParameterized (Rep f), Parameters f ~ GParameters (Rep f)) =>
    f ->
    HList (Parameters f) ->
    f
  replaceParameters x params = to (gReplaceParameters (from x) params)

-- | Generic counterpart of 'Parameterized', defined over "GHC.Generics"
-- representation types.
--
-- This module gives instances for products, fields ('K1'), metadata wrappers
-- ('M1') and nullary constructors ('U1').
--
-- NOTE(review): no instance for sums appears here, so types with several
-- constructors presumably cannot derive 'Parameterized'. Confirm this against
-- the rest of the module.
class GParameterized (f :: Type -> Type) where
  -- | Type-level list of the parameters held by representation @f@.
  type GParameters f :: [Type]

  -- | Collect the parameters of a representation value.
  gFlattenParameters :: f a -> HList (GParameters f)

  -- | Rebuild a representation value from a list of parameters.
  gReplaceParameters :: f a -> HList (GParameters f) -> f a

-- | Products: the left component's parameters come first, followed by the
-- right component's.
--
-- Replacement splits the incoming list with 'hunappendFD' at the boundary
-- fixed by the types, then recurses into each side.
instance
  ( GParameterized l,
    GParameterized r,
    HAppendFD (GParameters l) (GParameters r) (GParameters l ++ GParameters r)
  ) =>
  GParameterized (l :*: r)
  where
  type GParameters (l :*: r) = (GParameters l) ++ (GParameters r)
  gFlattenParameters (left :*: right) =
    let leftParams = gFlattenParameters left
        rightParams = gFlattenParameters right
     in leftParams `happendFD` rightParams
  gReplaceParameters (left :*: right) params =
    let (leftParams, rightParams) = hunappendFD params
        left' = gReplaceParameters left leftParams
        right' = gReplaceParameters right rightParams
     in left' :*: right'

-- | Fields: delegate to the field type's own 'Parameterized' instance.
--
-- Every field of a generically-derived type therefore needs a 'Parameterized'
-- instance.
instance
  Parameterized f =>
  GParameterized (K1 i f)
  where
  type GParameters (K1 i f) = Parameters f
  gFlattenParameters = flattenParameters . unK1
  gReplaceParameters (K1 field) = K1 . replaceParameters field

-- | Metadata wrappers (datatype, constructor and selector information) are
-- transparent. The instance unwraps and rewraps around the inner
-- representation.
instance GParameterized f => GParameterized (M1 i t f) where
  type GParameters (M1 i t f) = GParameters f
  gFlattenParameters = gFlattenParameters . unM1
  gReplaceParameters (M1 inner) = M1 . gReplaceParameters inner

-- | Nullary constructors hold no parameters. Replacement returns the value
-- unchanged.
instance GParameterized U1 where
  type GParameters U1 = '[]
  gFlattenParameters _ = HNil
  gReplaceParameters unit _ = unit

-- | A plain typed 'Tensor' is not a 'Parameter'. It contributes no parameters
-- and is returned unchanged on replacement.
instance Parameterized (Tensor device dtype shape) where
  type Parameters (Tensor device dtype shape) = '[]
  flattenParameters _ = HNil
  replaceParameters tensor _ = tensor

-- | A 'Parameter' is its own single-element parameter list.
--
-- Replacement ignores the old value and returns the one supplied.
instance Parameterized (Parameter device dtype shape) where
  type Parameters (Parameter device dtype shape) = '[Parameter device dtype shape]
  flattenParameters parameter = parameter :. HNil
  replaceParameters _ (parameter :. HNil) = parameter

-- | 'Int' holds no parameters.
--
-- This lets generically-derived instances, through the 'K1' instance's
-- @Parameterized f@ constraint, accept 'Int' fields.
instance Parameterized Int where
  type Parameters Int = '[]
  flattenParameters _ = HNil
  replaceParameters n _ = n

-- | 'Float' holds no parameters.
--
-- This lets generically-derived instances, through the 'K1' instance's
-- @Parameterized f@ constraint, accept 'Float' fields.
instance Parameterized Float where
  type Parameters Float = '[]
  flattenParameters _ = HNil
  replaceParameters x _ = x

-- | 'Double' holds no parameters.
--
-- This lets generically-derived instances, through the 'K1' instance's
-- @Parameterized f@ constraint, accept 'Double' fields.
instance Parameterized Double where
  type Parameters Double = '[]
  flattenParameters _ = HNil
  replaceParameters x _ = x

-- | The empty heterogeneous list holds no parameters. This is the base case
-- of the 'HList' instance.
instance Parameterized (HList '[]) where
  type Parameters (HList '[]) = '[]
  flattenParameters _ = HNil
  replaceParameters empty _ = empty

-- | Non-empty heterogeneous lists: the head's parameters come first, followed
-- by the tail's.
--
-- Replacement splits the incoming list with 'hunappendFD' at the boundary
-- fixed by the types, then rebuilds the head and tail separately.
instance
  ( Parameterized f,
    Parameterized (HList fs),
    HAppendFD (Parameters f) (Parameters (HList fs)) (Parameters f ++ Parameters (HList fs))
  ) =>
  Parameterized (HList (f ': fs))
  where
  type Parameters (HList (f ': fs)) = Parameters f ++ Parameters (HList fs)
  flattenParameters (x :. xs) =
    flattenParameters x `happendFD` flattenParameters xs
  replaceParameters (x :. xs) params =
    let (headParams, tailParams) = hunappendFD params
        x' = replaceParameters x headParams
        xs' = replaceParameters xs tailParams
     in x' :. xs'

-- | Sampling an empty list of specifications yields the empty list, with no
-- other effect.
instance Torch.NN.Randomizable (HList ('[] :: [Type])) (HList ('[] :: [Type])) where
  sample = pure

-- | Sample a heterogeneous list of values from a matching list of
-- specifications, element by element.
--
-- The head is sampled before the tail. 'IO''s applicative sequencing runs
-- effects left to right, so the order is the same as the original do-block.
instance
  ( Torch.NN.Randomizable xSpec x,
    Torch.NN.Randomizable (HList xsSpec) (HList xs)
  ) =>
  Torch.NN.Randomizable (HList (xSpec ': xsSpec)) (HList (x ': xs))
  where
  sample (xSpec :. xsSpec) =
    (:.) <$> Torch.NN.sample xSpec <*> Torch.NN.sample xsSpec