{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module Torch.Typed.NN.Linear where
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Type-level specification of a 'Linear' layer.
--
-- The value carries no runtime data. The input/output feature counts, the
-- element dtype and the device are all fixed by the type parameters. Pass a
-- 'LinearSpec' to 'sample' to obtain a 'Linear' layer of the matching type.
data
  LinearSpec
    (inputFeatures :: Nat)
    (outputFeatures :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LinearSpec
  deriving (Show, Eq)
-- | A linear layer whose parameter shapes are tracked in its type.
--
-- The weight has shape @[outputFeatures, inputFeatures]@ and the bias has
-- shape @[outputFeatures]@. Both are 'Parameter's on @device@ with element
-- type @dtype@. 'Parameterized' is derived through 'Generic' (via
-- @DeriveAnyClass@), which supplies 'flattenParameters' and
-- 'replaceParameters' for this layer.
data
  Linear
    (inputFeatures :: Nat)
    (outputFeatures :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  Linear ::
    forall inputFeatures outputFeatures dtype device.
    { -- | Weight matrix, shape @[outputFeatures, inputFeatures]@.
      weight :: Parameter device dtype '[outputFeatures, inputFeatures],
      -- | Bias vector, shape @[outputFeatures]@.
      bias :: Parameter device dtype '[outputFeatures]
    } ->
    Linear inputFeatures outputFeatures dtype device
  deriving (Show, Generic, Parameterized)
-- | Apply a 'Linear' layer to an input tensor.
--
-- The weight and bias parameters are unwrapped with 'toDependent' and passed
-- to 'linear''. 'linear'' computes the result shape at the type level: it
-- applies 'MatMul' to the input shape and @'[inputFeatures, outputFeatures]@,
-- then applies 'Broadcast' to that result.
--
-- The signature is deliberately partial (@_@ wildcards). GHC infers the full,
-- type-family-heavy type, and the module's @-Wno-partial-type-signatures@
-- option silences the resulting warning.
linearForward ::
  _ =>
  Linear _ _ _ _ ->
  Tensor _ _ _ ->
  Tensor _ _ _
linearForward Linear {weight = w, bias = b} input =
  linear' (toDependent w) (toDependent b) input
-- | Running a 'Linear' layer on a tensor.
--
-- 'forward' is 'linearForward'. 'forwardStoch' performs the same
-- computation and returns it in 'IO'. The instance constraints match those of
-- 'linear'', so the output shape @shape'@ is determined by the input shape.
instance
  ( shape'' ~ MatMul shape '[inputFeatures, outputFeatures],
    shape' ~ Broadcast shape'' shape''
  ) =>
  HasForward
    (Linear inputFeatures outputFeatures dtype device)
    (Tensor device dtype shape)
    (Tensor device dtype shape')
  where
  forward = linearForward

  -- The layer has no stochastic component, so the "stochastic" pass simply
  -- wraps the deterministic result.
  forwardStoch model input = pure (forward model input)
-- | Sample a freshly initialised 'Linear' layer from a 'LinearSpec'.
--
-- The weight and bias are each drawn with 'randn' and wrapped as independent
-- parameters with 'makeIndependent'. The weight is drawn first, then the
-- bias. Their shapes are fixed by the 'Linear' field types.
--
-- NOTE(review): the 'randn' draws are used unscaled (there is no
-- fan-in-based initialisation). Confirm this is intended.
instance
  ( KnownNat inputFeatures,
    KnownNat outputFeatures,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (LinearSpec inputFeatures outputFeatures dtype device)
    (Linear inputFeatures outputFeatures dtype device)
  where
  sample LinearSpec =
    Linear
      <$> (makeIndependent =<< randn)
      <*> (makeIndependent =<< randn)