{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Typed parameters.
--
-- 'Parameter' wraps an untyped 'Torch.Autograd.IndependentTensor' and records
-- its device, dtype and shape at the type level. The 'Parameterized' class
-- extracts every 'Parameter' contained in a value into an 'HList'
-- ('flattenParameters') and rebuilds the value from such a list
-- ('replaceParameters').
module Torch.Typed.Parameter
  ( module Torch.Typed.Parameter,
    Torch.NN.Randomizable (..),
  )
where

import Control.Monad.State.Strict
import Data.Kind (Type)
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import qualified Torch.Autograd (IndependentTensor (..), makeIndependent)
import Torch.DType (DType)
import Torch.Device (DeviceType)
import Torch.HList
import qualified Torch.NN (Parameter, Randomizable (..), sample)
import qualified Torch.Tensor (toType, _toDevice)
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Tensor

-- | A 'Torch.Autograd.IndependentTensor' indexed by device, dtype and shape.
--
-- The constructor performs no check that the wrapped tensor actually has the
-- device, dtype and shape named by the type indices (hence \"Unsafe\");
-- 'makeIndependent' builds a 'Parameter' from an already-typed 'Tensor'.
newtype
  Parameter
    (device :: (DeviceType, Nat))
    (dtype :: DType)
    (shape :: [Nat])
  = UnsafeMkParameter Torch.Autograd.IndependentTensor
  deriving (Show)

-- | Drop the type-level indices and return the underlying untyped parameter.
untypeParam :: Parameter device dtype shape -> Torch.NN.Parameter
untypeParam (UnsafeMkParameter param) = param

-- | View a parameter as a typed 'Tensor' (via 'Torch.Autograd.toDependent'),
-- keeping the same device, dtype and shape indices.
toDependent ::
  forall shape dtype device.
  Parameter device dtype shape ->
  Tensor device dtype shape
toDependent (UnsafeMkParameter t) =
  UnsafeMkTensor $ Torch.Autograd.toDependent t

-- | Defunctionalised symbol for 'toDependent', for use wherever an 'Apply''
-- instance is required (see "Torch.HList").
data ToDependent = ToDependent

instance Apply' ToDependent (Parameter device dtype shape) (Tensor device dtype shape) where
  apply' _ = toDependent

-- | Turn a typed 'Tensor' into a 'Parameter' with the same indices by calling
-- 'Torch.Autograd.makeIndependent' on its untyped form. Runs in 'IO' because
-- 'Torch.Autograd.makeIndependent' does.
makeIndependent ::
  forall shape dtype device.
  Tensor device dtype shape ->
  IO (Parameter device dtype shape)
makeIndependent t =
  UnsafeMkParameter <$> Torch.Autograd.makeIndependent (toDynamic t)

-- | Defunctionalised symbol for 'makeIndependent', for use wherever an
-- 'Apply'' instance is required (see "Torch.HList").
data MakeIndependent = MakeIndependent

instance
  Apply'
    MakeIndependent
    (Tensor device dtype shape)
    (IO (Parameter device dtype shape))
  where
  apply' _ = makeIndependent

-- | Move a parameter to the device @device'@ (given by type application).
--
-- The underlying tensor is moved with 'Torch.Tensor._toDevice' and re-wrapped
-- directly with the 'Torch.Autograd.IndependentTensor' constructor.
-- NOTE(review): this does not go through 'Torch.Autograd.makeIndependent';
-- whether the result is still a leaf for autograd depends on '_toDevice' —
-- confirm before relying on gradients of the moved parameter.
parameterToDevice ::
  forall device' device dtype shape.
  KnownDevice device' =>
  Parameter device dtype shape ->
  Parameter device' dtype shape
parameterToDevice (UnsafeMkParameter t) =
  UnsafeMkParameter
    . Torch.Autograd.IndependentTensor
    . Torch.Tensor._toDevice (deviceVal @device')
    . Torch.Autograd.toDependent
    $ t

-- | Convert a parameter to the dtype @dtype'@ (given by type application).
--
-- The underlying tensor is converted with 'Torch.Tensor.toType' and
-- re-wrapped directly with the 'Torch.Autograd.IndependentTensor' constructor.
-- NOTE(review): as with 'parameterToDevice', 'Torch.Autograd.makeIndependent'
-- is not called — confirm the autograd implications.
parameterToDType ::
  forall dtype' dtype device shape.
  KnownDType dtype' =>
  Parameter device dtype shape ->
  Parameter device dtype' shape
parameterToDType (UnsafeMkParameter t) =
  UnsafeMkParameter
    . Torch.Autograd.IndependentTensor
    . Torch.Tensor.toType (dtypeVal @dtype')
    . Torch.Autograd.toDependent
    $ t

-- | Values that contain a statically known list of parameters.
--
-- 'flattenParameters' collects them into an 'HList' and 'replaceParameters'
-- rebuilds the value from an 'HList' of the same type. Both methods and the
-- 'Parameters' family default to a "GHC.Generics" implementation, so a type
-- with a 'Generic' instance whose fields are all 'Parameterized' can use an
-- empty instance declaration. Only product ('(:*:)'), field ('K1'), metadata
-- ('M1') and unit ('U1') representations have 'GParameterized' instances in
-- this module; there is no sum ('(:+:)') instance here.
class Parameterized (f :: Type) where
  -- | The type-level list of parameters contained in @f@.
  type Parameters f :: [Type]

  type Parameters f = GParameters (Rep f)

  -- | Collect all parameters of a value, in field order.
  flattenParameters :: f -> HList (Parameters f)
  default flattenParameters ::
    (Generic f, GParameterized (Rep f), Parameters f ~ GParameters (Rep f)) =>
    f ->
    HList (Parameters f)
  flattenParameters f = gFlattenParameters (from f)

  -- | Rebuild a value, taking its parameters from the given list in the same
  -- order 'flattenParameters' produces them.
  replaceParameters :: f -> HList (Parameters f) -> f
  default replaceParameters ::
    (Generic f, GParameterized (Rep f), Parameters f ~ GParameters (Rep f)) =>
    f ->
    HList (Parameters f) ->
    f
  replaceParameters f as = to (gReplaceParameters (from f) as)

-- | Generic-representation counterpart of 'Parameterized', used to implement
-- its default methods.
class GParameterized (f :: Type -> Type) where
  type GParameters f :: [Type]
  gFlattenParameters :: forall a. f a -> HList (GParameters f)
  gReplaceParameters :: forall a. f a -> HList (GParameters f) -> f a

-- | Products: the left component's parameters come first, followed by the
-- right component's; replacement splits the list at that boundary.
instance
  ( GParameterized l,
    GParameterized r,
    HAppendFD (GParameters l) (GParameters r) (GParameters l ++ GParameters r)
  ) =>
  GParameterized (l :*: r)
  where
  type GParameters (l :*: r) = (GParameters l) ++ (GParameters r)
  gFlattenParameters (l :*: r) =
    let as = gFlattenParameters l
        bs = gFlattenParameters r
     in as `happendFD` bs
  gReplaceParameters (l :*: r) cs =
    let (as, bs) = hunappendFD cs
        l' = gReplaceParameters l as
        r' = gReplaceParameters r bs
     in l' :*: r'

-- | Fields: delegate to the field type's 'Parameterized' instance.
instance
  Parameterized f =>
  GParameterized (K1 i f)
  where
  type GParameters (K1 i f) = Parameters f
  gFlattenParameters = flattenParameters . unK1
  gReplaceParameters (K1 f) = K1 . replaceParameters f

-- | Metadata wrappers are transparent.
instance GParameterized f => GParameterized (M1 i t f) where
  type GParameters (M1 i t f) = GParameters f
  gFlattenParameters = gFlattenParameters . unM1
  gReplaceParameters (M1 f) = M1 . gReplaceParameters f

-- | Constructors without fields hold no parameters.
instance GParameterized U1 where
  type GParameters U1 = '[]
  gFlattenParameters _ = HNil
  gReplaceParameters = const

-- | Plain tensors are not parameters and contribute nothing.
instance Parameterized (Tensor device dtype shape) where
  type Parameters (Tensor device dtype shape) = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | A parameter is its own single-element parameter list; replacement returns
-- the supplied parameter and ignores the old one.
instance Parameterized (Parameter device dtype shape) where
  type Parameters (Parameter device dtype shape) = '[Parameter device dtype shape]
  flattenParameters = (:. HNil)
  replaceParameters _ (parameter :. HNil) = parameter

-- | Scalar hyper-parameters hold no parameters.
instance Parameterized Int where
  type Parameters Int = '[]
  flattenParameters _ = HNil
  replaceParameters = const

instance Parameterized Float where
  type Parameters Float = '[]
  flattenParameters _ = HNil
  replaceParameters = const

instance Parameterized Double where
  type Parameters Double = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | The empty 'HList' holds no parameters.
instance Parameterized (HList '[]) where
  type Parameters (HList '[]) = '[]
  flattenParameters _ = HNil
  replaceParameters = const

-- | Non-empty 'HList's: the head's parameters come first, followed by the
-- tail's; replacement splits the list at that boundary.
instance
  ( Parameterized f,
    Parameterized (HList fs),
    HAppendFD (Parameters f) (Parameters (HList fs)) (Parameters f ++ Parameters (HList fs))
  ) =>
  Parameterized (HList (f ': fs))
  where
  type Parameters (HList (f ': fs)) = Parameters f ++ Parameters (HList fs)
  flattenParameters (f :. fs) =
    flattenParameters f `happendFD` flattenParameters fs
  replaceParameters (f :. fs) cs =
    let (as, bs) = hunappendFD cs
        f' = replaceParameters f as
        fs' = replaceParameters fs bs
     in f' :. fs'

-- | Sampling from an empty list of specs yields the empty list.
instance Torch.NN.Randomizable (HList ('[] :: [Type])) (HList ('[] :: [Type])) where
  sample = pure

-- | Sample each element from its spec, head first, then the tail.
instance
  ( Torch.NN.Randomizable xSpec x,
    Torch.NN.Randomizable (HList xsSpec) (HList xs)
  ) =>
  Torch.NN.Randomizable (HList (xSpec ': xsSpec)) (HList (x ': xs))
  where
  -- Applicative IO runs the head's effects before the tail's, matching the
  -- original do-block order.
  sample (xSpec :. xsSpec) =
    (:.) <$> Torch.NN.sample xSpec <*> Torch.NN.sample xsSpec