{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE DuplicateRecordFields #-}
module Torch.NN where
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State.Strict
import Data.Foldable (toList)
import Data.Kind
import GHC.Generics
import System.IO.Unsafe (unsafePerformIO)
import Torch.Autograd
import Torch.Device
import Torch.Functional
import Torch.Initializers
import Torch.Internal.Cast (cast3)
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import Torch.Scalar
import Torch.Tensor
import Torch.TensorFactories (ones', randIO', randnIO', zeros')
-- | A trainable parameter of a model.
type Parameter = IndependentTensor

-- | A 'State' computation whose state is the list of 'Parameter's still
-- available; used when rebuilding a model with new parameters
-- (see 'replaceParameters').
type ParamStream = State [Parameter]
-- | Take the next 'Parameter' off the stream.
--
-- Calls 'error' when the stream has run dry, i.e. when fewer parameters were
-- supplied than the value being rebuilt needs.
nextParameter :: ParamStream Parameter
nextParameter = state pop
  where
    pop (p : rest) = (p, rest)
    pop [] = error "Not enough parameters supplied to replaceParameters"
-- | A model of type @f@ that can be applied to an input of type @a@,
-- producing @b@. The functional dependency @f a -> b@ lets the model and
-- input types determine the output type.
--
-- Both methods have 'Generic' defaults. These walk the generic
-- representations of the model and the input side by side through
-- 'GHasForward'.
class HasForward f a b | f a -> b where
  -- | Pure forward pass.
  forward :: f -> a -> b
  default forward ::
    (Generic f, Generic a, Generic b, GHasForward (Rep f) (Rep a) (Rep b)) =>
    f -> a -> b
  forward model = to . gForward (from model) . from

  -- | Forward pass in 'IO'. NOTE(review): the name suggests this is meant
  -- for layers whose output is random — confirm with the instance authors.
  forwardStoch :: f -> a -> IO b
  default forwardStoch ::
    (Generic f, Generic a, Generic b, GHasForward (Rep f) (Rep a) (Rep b)) =>
    f -> a -> IO b
  forwardStoch model input = fmap to (gForwardStoch (from model) (from input))
-- | Generic-representation counterpart of 'HasForward': @f@ is the model's
-- 'Rep', @a@ the input's 'Rep', and @b@ the output's 'Rep'.
--
-- Note that 'gForwardStoch' indexes its result by the model's index @c@,
-- whereas 'gForward' leaves the output index @c''@ free.
class
  GHasForward
    (f :: Type -> Type)
    (a :: Type -> Type)
    (b :: Type -> Type)
    | f a -> b
  where
  gForward :: forall c c' c''. f c -> a c' -> b c''
  gForwardStoch :: forall c c' c''. f c -> a c' -> IO (b c)
-- | Field-less constructors: both arguments are matched and the output is
-- also 'U1'.
instance GHasForward U1 U1 U1 where
  gForward U1 U1 = U1
  gForwardStoch U1 U1 = pure U1
-- | Sum types: the model and the input must be the same alternative ('L1'
-- with 'L1', or 'R1' with 'R1'). The result is wrapped in that same
-- alternative of @b :+: b'@.
--
-- A model/input pair using different alternatives has no meaningful result.
-- Such pairs fail with an explicit, descriptive 'error' rather than an
-- anonymous incomplete-pattern failure.
instance
  ( GHasForward f a b,
    GHasForward g a' b',
    b'' ~ (b :+: b')
  ) =>
  GHasForward (f :+: g) (a :+: a') b''
  where
  gForward (L1 f) (L1 a) = L1 $ gForward f a
  gForward (R1 g) (R1 a') = R1 $ gForward g a'
  gForward _ _ =
    error "gForward: model and input are different alternatives of a sum type"
  gForwardStoch (L1 f) (L1 a) = L1 <$> gForwardStoch f a
  gForwardStoch (R1 g) (R1 a') = R1 <$> gForwardStoch g a'
  gForwardStoch _ _ =
    error "gForwardStoch: model and input are different alternatives of a sum type"
-- | Product types: run each component of the model on the matching component
-- of the input, then pair up the results. In 'IO', the left component runs
-- first.
instance
  ( GHasForward f a b,
    GHasForward g a' b',
    b'' ~ (b :*: b')
  ) =>
  GHasForward (f :*: g) (a :*: a') b''
  where
  gForward (lModel :*: rModel) (lInput :*: rInput) =
    gForward lModel lInput :*: gForward rModel rInput
  gForwardStoch (lModel :*: rModel) (lInput :*: rInput) =
    (:*:) <$> gForwardStoch lModel lInput <*> gForwardStoch rModel rInput
-- | Fields: unwrap the model and input field values and delegate to their
-- own 'HasForward' instance.
instance HasForward f a b => GHasForward (K1 i f) (K1 i a) (K1 i b) where
  gForward model input = K1 (forward (unK1 model) (unK1 input))
  gForwardStoch model input = K1 <$> forwardStoch (unK1 model) (unK1 input)
-- | Metadata wrappers: strip the wrapper, recurse, then re-wrap the output
-- using the input's metadata @t'@.
instance GHasForward f a b => GHasForward (M1 i t f) (M1 i t' a) (M1 i t' b) where
  gForward model input = M1 (gForward (unM1 model) (unM1 input))
  gForwardStoch model input = M1 <$> gForwardStoch (unM1 model) (unM1 input)
-- | Values that contain trainable 'Parameter's.
--
-- Both methods default to a 'Generic' traversal through 'GParameterized'.
-- Callers normally use 'replaceParameters' rather than '_replaceParameters'
-- directly.
class Parameterized f where
  -- | Every parameter inside the value, listed in traversal order.
  flattenParameters :: f -> [Parameter]
  default flattenParameters :: (Generic f, GParameterized (Rep f)) => f -> [Parameter]
  flattenParameters = gFlattenParameters . from

  -- | Rebuild the value, drawing each replacement parameter from the stream.
  -- Expected to consume parameters in the order 'flattenParameters' lists
  -- them.
  _replaceParameters :: f -> ParamStream f
  default _replaceParameters :: (Generic f, GParameterized (Rep f)) => f -> ParamStream f
  _replaceParameters = fmap to . _gReplaceParameters . from
-- | Rebuild @model@ so that its parameters are taken, in order, from
-- @newParams@.
--
-- Calls 'error' if @newParams@ has entries left over once the model has been
-- rebuilt. Running out of parameters is reported by whichever instance
-- requested one (e.g. via 'nextParameter').
replaceParameters :: Parameterized f => f -> [Parameter] -> f
replaceParameters model newParams =
  case runState (_replaceParameters model) newParams of
    (rebuilt, []) -> rebuilt
    _ -> error "Some parameters in a call to replaceParameters haven't been consumed!"
-- | A plain 'Tensor' is treated as a constant: it has no parameters and is
-- returned unchanged when parameters are replaced.
instance Parameterized Tensor where
  flattenParameters = const []
  _replaceParameters = pure
-- | A 'Parameter' is its own single parameter. On replacement, the old value
-- is discarded and the next one is taken from the stream.
instance Parameterized Parameter where
  flattenParameters p = [p]
  _replaceParameters = const nextParameter
-- | Scalars carry no parameters and are kept as-is.
instance {-# OVERLAPS #-} (Scalar a) => Parameterized a where
  flattenParameters = const []
  _replaceParameters = pure
-- | Pairs: parameters of the first component come before those of the
-- second, both when flattening and when replacing.
instance {-# OVERLAPS #-} (Parameterized a, Parameterized b) => Parameterized (a, b) where
  flattenParameters (x, y) = flattenParameters x <> flattenParameters y
  _replaceParameters (x, y) =
    (,) <$> _replaceParameters x <*> _replaceParameters y
-- | Triples: parameters are ordered left to right across the three
-- components, both when flattening and when replacing.
instance {-# OVERLAPS #-} (Parameterized a, Parameterized b, Parameterized c) => Parameterized (a, b, c) where
  flattenParameters (x, y, z) =
    mconcat [flattenParameters x, flattenParameters y, flattenParameters z]
  _replaceParameters (x, y, z) =
    (,,)
      <$> _replaceParameters x
      <*> _replaceParameters y
      <*> _replaceParameters z
-- | Containers of parameterized values. Parameters are flattened in the
-- container's fold order and replaced by traversing in that same order, so
-- the two methods stay aligned.
--
-- Only 'Traversable' is required: it already implies 'Foldable', so the
-- former separate @Foldable t@ constraint was redundant.
instance {-# OVERLAPS #-} (Traversable t, Parameterized a) => Parameterized (t a) where
  flattenParameters = foldMap flattenParameters
  _replaceParameters = traverse _replaceParameters
-- | Endomorphisms @a -> a@ carry no parameters and are kept unchanged.
instance Parameterized (a -> a) where
  flattenParameters = const []
  _replaceParameters = pure
-- | Generic-representation counterpart of 'Parameterized', used by its
-- default methods.
class GParameterized f where
  gFlattenParameters :: f a -> [Parameter]
  _gReplaceParameters :: f a -> ParamStream (f a)
-- | Field-less constructors hold no parameters.
instance GParameterized U1 where
  gFlattenParameters U1 = []
  _gReplaceParameters U1 = pure U1
-- | Sum types: only the alternative actually present contributes parameters.
instance (GParameterized f, GParameterized g) => GParameterized (f :+: g) where
  gFlattenParameters sumRep = case sumRep of
    L1 l -> gFlattenParameters l
    R1 r -> gFlattenParameters r
  _gReplaceParameters sumRep = case sumRep of
    L1 l -> L1 <$> _gReplaceParameters l
    R1 r -> R1 <$> _gReplaceParameters r
-- | Product types: the left component's parameters come before the right
-- component's, both when flattening and when replacing.
instance (GParameterized f, GParameterized g) => GParameterized (f :*: g) where
  gFlattenParameters (l :*: r) = gFlattenParameters l <> gFlattenParameters r
  _gReplaceParameters (l :*: r) =
    (:*:) <$> _gReplaceParameters l <*> _gReplaceParameters r
-- | Fields: defer to the field type's own 'Parameterized' instance.
instance Parameterized c => GParameterized (K1 i c) where
  gFlattenParameters = flattenParameters . unK1
  _gReplaceParameters (K1 x) = K1 <$> _replaceParameters x
-- | Metadata wrappers are transparent: unwrap, recurse, re-wrap.
instance GParameterized f => GParameterized (M1 i t f) where
  gFlattenParameters = gFlattenParameters . unM1
  _gReplaceParameters = fmap M1 . _gReplaceParameters . unM1
-- | A specification @spec@ from which a randomly initialised value @f@ can be
-- drawn. The functional dependency makes the spec type determine the result.
class Randomizable spec f | spec -> f where
  sample ::
    spec ->
    IO f
-- | Hyperparameters used to 'sample' a 'Linear' layer.
data LinearSpec = LinearSpec
  { in_features :: Int
  -- ^ Input feature count (second dimension of the sampled weight).
  , out_features :: Int
  -- ^ Output feature count (first dimension of the weight, and bias length).
  }
  deriving (Eq, Show)
-- | Trainable state of a fully-connected layer.
data Linear = Linear
  { weight :: Parameter
  -- ^ Weight; 'sample' creates it with shape @[out_features, in_features]@.
  , bias :: Parameter
  -- ^ Bias; 'sample' creates it with shape @[out_features]@.
  }
  deriving (Show, Generic, Parameterized)
-- | Apply a 'Linear' layer to an input tensor.
--
-- The layer's weight and bias are read with 'toDependent' and passed, after
-- the input, to ATen's @linear@ kernel (PyTorch's @y = x W^T + b@).
linear :: Linear -> Tensor -> Tensor
linear layer input = linearTTT input weightT biasT
  where
    -- NOTE(review): 'unsafePerformIO' treats this FFI call as pure. That is
    -- assumed safe because @linear_ttt@ returns a freshly allocated tensor
    -- and does not mutate its arguments — confirm against the ATen binding.
    linearTTT :: Tensor -> Tensor -> Tensor -> Tensor
    linearTTT x w b = unsafePerformIO $ cast3 ATen.linear_ttt x w b
    weightT = toDependent layer.weight
    biasT = toDependent layer.bias
-- | Forward pass of a 'Linear' layer; an alias for 'linear'.
linearForward :: Linear -> Tensor -> Tensor
linearForward = linear
-- | A 'Linear' layer maps a 'Tensor' to a 'Tensor'. The 'IO' variant simply
-- wraps the pure 'linearForward' result.
instance HasForward Linear Tensor Tensor where
  forward = linearForward
  forwardStoch layer = pure . linearForward layer
-- | Sample a freshly initialised 'Linear' layer.
--
-- * The weight has shape @[out_features, in_features]@. It is drawn with
--   'kaimingUniform' in 'FanIn' mode with @LeakyRelu (sqrt 5)@.
-- * The bias has shape @[out_features]@. It is computed from a 'randIO''
--   sample @u@ as @subScalar bound (mulScalar (2 * bound) u)@, where
--   @bound = 1 / sqrt fanIn@ and @fanIn@ comes from 'calculateFan' on the
--   weight shape.
--
-- NOTE(review): this assumes 'randIO'' draws from U[0, 1) and that
-- @subScalar s t@ computes @t - s@. Under those assumptions the bias is
-- uniform on [-bound, bound), matching PyTorch's @nn.Linear@ default
-- — confirm.
instance Randomizable LinearSpec Linear where
  sample LinearSpec {..} = do
    -- The weight shape is bound once, so the Kaiming draw and the bias
    -- bound always agree on fan-in.
    let weightShape = [out_features, in_features]
        bound =
          1 / Prelude.sqrt (fromIntegral (getter FanIn (calculateFan weightShape)))
            :: Float
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu (Prelude.sqrt (5.0 :: Float))) weightShape
    -- Deliberately not named @init@, which would shadow 'Prelude.init'.
    unitSample <- randIO' [out_features]
    b <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitSample))
    return (Linear w b)
-- | Hyperparameters used to 'sample' a 'Conv1d' layer.
data Conv1dSpec = Conv1dSpec
  { inputChannelSize1d :: Int
  -- ^ Second dimension of the sampled weight.
  , outputChannelSize1d :: Int
  -- ^ First dimension of the sampled weight, and bias length.
  , kernelSize :: Int
  -- ^ Third dimension of the sampled weight.
  }
  deriving (Eq, Show)
-- | Trainable state of a 1-D convolution layer.
data Conv1d = Conv1d
  { weight :: Parameter
  -- ^ Kernel; 'sample' creates it with shape
  -- @[outputChannelSize1d, inputChannelSize1d, kernelSize]@.
  , bias :: Parameter
  -- ^ Bias; 'sample' creates it with shape @[outputChannelSize1d]@.
  }
  deriving (Show, Generic, Parameterized)
-- | Forward pass of a 'Conv1d' layer.
--
-- The two 'Int' arguments and the input are forwarded unchanged to
-- 'Torch.Functional.conv1d'' (presumably stride and padding — confirm there).
conv1dForward ::
  Conv1d ->
  Int ->
  Int ->
  Tensor ->
  Tensor
conv1dForward layer =
  Torch.Functional.conv1d' (toDependent layer.weight) (toDependent layer.bias)
-- | Sample a freshly initialised 'Conv1d' layer.
--
-- * The weight has shape
--   @[outputChannelSize1d, inputChannelSize1d, kernelSize]@. It is drawn
--   with 'kaimingUniform' in 'FanIn' mode with @LeakyRelu (sqrt 5)@.
-- * The bias has shape @[outputChannelSize1d]@. It is computed from a
--   'randIO'' sample @u@ as @subScalar bound (mulScalar (2 * bound) u)@, with
--   @bound = 1 / sqrt fanIn@ taken from 'calculateFan' on the weight shape.
--
-- NOTE(review): this assumes 'randIO'' draws from U[0, 1) and that
-- @subScalar s t@ computes @t - s@, which makes the bias uniform on
-- [-bound, bound) — confirm.
instance Randomizable Conv1dSpec Conv1d where
  sample Conv1dSpec {..} = do
    -- The weight shape is bound once, so the Kaiming draw and the bias
    -- bound always agree on fan-in.
    let weightShape = [outputChannelSize1d, inputChannelSize1d, kernelSize]
        bound =
          1 / Prelude.sqrt (fromIntegral (getter FanIn (calculateFan weightShape)))
            :: Float
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu (Prelude.sqrt (5.0 :: Float))) weightShape
    -- Deliberately not named @init@, which would shadow 'Prelude.init'.
    unitSample <- randIO' [outputChannelSize1d]
    b <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitSample))
    return (Conv1d w b)
-- | Hyperparameters used to 'sample' a 'Conv2d' layer.
data Conv2dSpec = Conv2dSpec
  { inputChannelSize2d :: Int
  -- ^ Second dimension of the sampled weight.
  , outputChannelSize2d :: Int
  -- ^ First dimension of the sampled weight, and bias length.
  , kernelHeight2d :: Int
  -- ^ Third dimension of the sampled weight.
  , kernelWidth2d :: Int
  -- ^ Fourth dimension of the sampled weight.
  }
  deriving (Eq, Show)
-- | Trainable state of a 2-D convolution layer.
data Conv2d = Conv2d
  { weight :: Parameter
  -- ^ Kernel; 'sample' creates it with shape
  -- @[outputChannelSize2d, inputChannelSize2d, kernelHeight2d, kernelWidth2d]@.
  , bias :: Parameter
  -- ^ Bias; 'sample' creates it with shape @[outputChannelSize2d]@.
  }
  deriving (Show, Generic, Parameterized)
-- | Forward pass of a 'Conv2d' layer.
--
-- The two @(Int, Int)@ arguments and the input are forwarded unchanged to
-- 'Torch.Functional.conv2d'' (presumably stride and padding — confirm there).
conv2dForward ::
  Conv2d ->
  (Int, Int) ->
  (Int, Int) ->
  Tensor ->
  Tensor
conv2dForward layer =
  Torch.Functional.conv2d' (toDependent layer.weight) (toDependent layer.bias)
-- | Sample a freshly initialised 'Conv2d' layer.
--
-- * The weight has shape
--   @[outputChannelSize2d, inputChannelSize2d, kernelHeight2d, kernelWidth2d]@.
--   It is drawn with 'kaimingUniform' in 'FanIn' mode with
--   @LeakyRelu (sqrt 5)@.
-- * The bias has shape @[outputChannelSize2d]@. It is computed from a
--   'randIO'' sample @u@ as @subScalar bound (mulScalar (2 * bound) u)@, with
--   @bound = 1 / sqrt fanIn@ taken from 'calculateFan' on the weight shape.
--
-- NOTE(review): this assumes 'randIO'' draws from U[0, 1) and that
-- @subScalar s t@ computes @t - s@, which makes the bias uniform on
-- [-bound, bound) — confirm.
instance Randomizable Conv2dSpec Conv2d where
  sample Conv2dSpec {..} = do
    -- The weight shape is bound once, so the Kaiming draw and the bias
    -- bound always agree on fan-in.
    let weightShape =
          [outputChannelSize2d, inputChannelSize2d, kernelHeight2d, kernelWidth2d]
        bound =
          1 / Prelude.sqrt (fromIntegral (getter FanIn (calculateFan weightShape)))
            :: Float
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu (Prelude.sqrt (5.0 :: Float))) weightShape
    -- Deliberately not named @init@, which would shadow 'Prelude.init'.
    unitSample <- randIO' [outputChannelSize2d]
    b <- makeIndependent (subScalar bound (mulScalar (bound * 2.0) unitSample))
    return (Conv2d w b)
-- | Hyper-parameters used to sample a 'Conv3d' layer.
data Conv3dSpec = Conv3dSpec
  { inputChannelSize3d :: Int -- ^ number of input channels
  , outputChannelSize3d :: Int -- ^ number of output channels
  , kernelHeight3d :: Int -- ^ kernel extent labelled "height"
  , kernelWidth3d :: Int -- ^ kernel extent labelled "width"
  , kernelDepth3d :: Int -- ^ kernel extent labelled "depth"
  }
  deriving (Eq, Show)
-- | A 3-D convolution layer holding a trainable kernel and bias.
data Conv3d = Conv3d
  { weight :: Parameter -- ^ convolution kernel
  , bias :: Parameter -- ^ additive bias, one entry per output channel
  }
  deriving (Generic, Show, Parameterized)
-- | Apply a 'Conv3d' layer to an input tensor.
--
-- The layer's weight and bias are unwrapped with 'toDependent' and passed,
-- followed by the two triple arguments and the input, to
-- @Torch.Functional.conv3d'@.
conv3dForward ::
  -- | layer supplying the convolution weight and bias
  Conv3d ->
  -- | first triple forwarded to @conv3d'@ (presumably the stride; confirm)
  (Int, Int, Int) ->
  -- | second triple forwarded to @conv3d'@ (presumably the padding; confirm)
  (Int, Int, Int) ->
  -- | input tensor
  Tensor ->
  -- | convolution result
  Tensor
conv3dForward layer =
  Torch.Functional.conv3d'
    (toDependent layer.weight)
    (toDependent layer.bias)
-- | Sample a fresh 'Conv3d' layer from its spec.
--
-- The weight has shape @[out, in, kernelHeight, kernelWidth, kernelDepth]@ and
-- is drawn by 'kaimingUniform' in 'FanIn' mode with a @LeakyRelu (sqrt 5)@
-- nonlinearity. The bias has shape @[out]@ and is computed as
-- @2 * bound * u - bound@, where @u@ comes from 'randIO'' and
-- @bound = 1 / sqrt fanIn@ (fan-in taken from 'calculateFan' on the weight
-- shape). Assuming 'randIO'' samples U[0, 1), the bias lies in [-bound, bound).
instance Randomizable Conv3dSpec Conv3d where
  sample Conv3dSpec {..} = do
    -- NOTE(review): libtorch reads a 5-D convolution weight as
    -- (out, in, kD, kH, kW). With the order below, kernelHeight3d lands on the
    -- depth axis, kernelWidth3d on height and kernelDepth3d on width. The order
    -- is kept for backward compatibility; confirm this labelling is intended.
    let weightShape =
          [ outputChannelSize3d,
            inputChannelSize3d,
            kernelHeight3d,
            kernelWidth3d,
            kernelDepth3d
          ]
        fanIn = getter FanIn (calculateFan weightShape)
        bound = 1 / Prelude.sqrt (fromIntegral fanIn) :: Float
    -- Keep the weight draw before the bias draw so RNG consumption order is
    -- unchanged.
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) weightShape
    unitSample <- randIO' [outputChannelSize3d]
    b <- makeIndependent $ subScalar bound $ mulScalar (bound * 2.0) unitSample
    pure $ Conv3d w b
-- | Hyper-parameters used to sample a 'ConvTranspose1d' layer.
data ConvTranspose1dSpec = ConvTranspose1dSpec
  { trInputChannelSize1d :: Int -- ^ number of input channels
  , trOutputChannelSize1d :: Int -- ^ number of output channels
  , trKernelSize :: Int -- ^ kernel length
  }
  deriving (Eq, Show)
-- | A 1-D transposed-convolution layer holding a trainable kernel and bias.
data ConvTranspose1d = ConvTranspose1d
  { weight :: Parameter -- ^ transposed-convolution kernel
  , bias :: Parameter -- ^ additive bias, one entry per output channel
  }
  deriving (Generic, Show, Parameterized)
-- | Apply a 'ConvTranspose1d' layer to an input tensor.
--
-- The layer's weight and bias are unwrapped with 'toDependent' and passed,
-- followed by the two integer arguments and the input, to @convTranspose1d'@.
convTranspose1dForward ::
  -- | layer supplying the kernel and bias
  ConvTranspose1d ->
  -- | first integer forwarded to @convTranspose1d'@ (presumably the stride; confirm)
  Int ->
  -- | second integer forwarded to @convTranspose1d'@ (presumably the padding; confirm)
  Int ->
  -- | input tensor
  Tensor ->
  -- | transposed-convolution result
  Tensor
convTranspose1dForward layer =
  convTranspose1d'
    (toDependent layer.weight)
    (toDependent layer.bias)
-- | Sample a fresh 'ConvTranspose1d' layer from its spec.
--
-- The weight has shape @[in, out, kernelSize]@. Input channels come first,
-- which is the layout libtorch uses for transposed convolutions. It is drawn
-- by 'kaimingUniform' in 'FanIn' mode with a @LeakyRelu (sqrt 5)@
-- nonlinearity. The bias has shape @[out]@ and is computed as
-- @2 * bound * u - bound@, where @u@ comes from 'randIO'' and
-- @bound = 1 / sqrt fanIn@ (fan-in taken from 'calculateFan' on the weight
-- shape). Assuming 'randIO'' samples U[0, 1), the bias lies in [-bound, bound).
instance Randomizable ConvTranspose1dSpec ConvTranspose1d where
  sample ConvTranspose1dSpec {..} = do
    let weightShape =
          [ trInputChannelSize1d,
            trOutputChannelSize1d,
            trKernelSize
          ]
        fanIn = getter FanIn (calculateFan weightShape)
        bound = 1 / Prelude.sqrt (fromIntegral fanIn) :: Float
    -- Keep the weight draw before the bias draw so RNG consumption order is
    -- unchanged.
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) weightShape
    unitSample <- randIO' [trOutputChannelSize1d]
    b <- makeIndependent $ subScalar bound $ mulScalar (bound * 2.0) unitSample
    pure $ ConvTranspose1d w b
-- | Hyper-parameters used to sample a 'ConvTranspose2d' layer.
data ConvTranspose2dSpec = ConvTranspose2dSpec
  { trInputChannelSize2d :: Int -- ^ number of input channels
  , trOutputChannelSize2d :: Int -- ^ number of output channels
  , trKernelHeight2d :: Int -- ^ kernel height
  , trKernelWidth2d :: Int -- ^ kernel width
  }
  deriving (Eq, Show)
-- | A 2-D transposed-convolution layer holding a trainable kernel and bias.
data ConvTranspose2d = ConvTranspose2d
  { weight :: Parameter -- ^ transposed-convolution kernel
  , bias :: Parameter -- ^ additive bias, one entry per output channel
  }
  deriving (Generic, Show, Parameterized)
-- | Apply a 'ConvTranspose2d' layer to an input tensor.
--
-- The layer's weight and bias are unwrapped with 'toDependent' and passed,
-- followed by the two tuple arguments and the input, to @convTranspose2d'@.
convTranspose2dForward ::
  -- | layer supplying the kernel and bias
  ConvTranspose2d ->
  -- | first tuple forwarded to @convTranspose2d'@ (presumably the stride; confirm)
  (Int, Int) ->
  -- | second tuple forwarded to @convTranspose2d'@ (presumably the padding; confirm)
  (Int, Int) ->
  -- | input tensor
  Tensor ->
  -- | transposed-convolution result
  Tensor
convTranspose2dForward layer =
  convTranspose2d'
    (toDependent layer.weight)
    (toDependent layer.bias)
-- | Sample a fresh 'ConvTranspose2d' layer from its spec.
--
-- The weight has shape @[in, out, kernelHeight, kernelWidth]@. Input channels
-- come first, which is the layout libtorch uses for transposed convolutions.
-- It is drawn by 'kaimingUniform' in 'FanIn' mode with a @LeakyRelu (sqrt 5)@
-- nonlinearity. The bias has shape @[out]@ and is computed as
-- @2 * bound * u - bound@, where @u@ comes from 'randIO'' and
-- @bound = 1 / sqrt fanIn@ (fan-in taken from 'calculateFan' on the weight
-- shape). Assuming 'randIO'' samples U[0, 1), the bias lies in [-bound, bound).
instance Randomizable ConvTranspose2dSpec ConvTranspose2d where
  sample ConvTranspose2dSpec {..} = do
    let weightShape =
          [ trInputChannelSize2d,
            trOutputChannelSize2d,
            trKernelHeight2d,
            trKernelWidth2d
          ]
        fanIn = getter FanIn (calculateFan weightShape)
        bound = 1 / Prelude.sqrt (fromIntegral fanIn) :: Float
    -- Keep the weight draw before the bias draw so RNG consumption order is
    -- unchanged.
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) weightShape
    unitSample <- randIO' [trOutputChannelSize2d]
    b <- makeIndependent $ subScalar bound $ mulScalar (bound * 2.0) unitSample
    pure $ ConvTranspose2d w b
-- | Hyper-parameters used to sample a 'ConvTranspose3d' layer.
data ConvTranspose3dSpec = ConvTranspose3dSpec
  { trInputChannelSize3d :: Int -- ^ number of input channels
  , trOutputChannelSize3d :: Int -- ^ number of output channels
  , trKernelHeight3d :: Int -- ^ kernel extent labelled "height"
  , trKernelWidth3d :: Int -- ^ kernel extent labelled "width"
  , trKernelDepth3d :: Int -- ^ kernel extent labelled "depth"
  }
  deriving (Eq, Show)
-- | A 3-D transposed-convolution layer holding a trainable kernel and bias.
data ConvTranspose3d = ConvTranspose3d
  { weight :: Parameter -- ^ transposed-convolution kernel
  , bias :: Parameter -- ^ additive bias, one entry per output channel
  }
  deriving (Generic, Show, Parameterized)
-- | Apply a 'ConvTranspose3d' layer to an input tensor.
--
-- The layer's weight and bias are unwrapped with 'toDependent' and passed,
-- followed by the two triple arguments and the input, to @convTranspose3d'@.
convTranspose3dForward ::
  -- | layer supplying the kernel and bias
  ConvTranspose3d ->
  -- | first triple forwarded to @convTranspose3d'@ (presumably the stride; confirm)
  (Int, Int, Int) ->
  -- | second triple forwarded to @convTranspose3d'@ (presumably the padding; confirm)
  (Int, Int, Int) ->
  -- | input tensor
  Tensor ->
  -- | transposed-convolution result
  Tensor
convTranspose3dForward layer =
  convTranspose3d'
    (toDependent layer.weight)
    (toDependent layer.bias)
-- | Sample a fresh 'ConvTranspose3d' layer from its spec.
--
-- The weight has shape @[in, out, kernelHeight, kernelWidth, kernelDepth]@.
-- Input channels come first, which is the layout libtorch uses for transposed
-- convolutions. It is drawn by 'kaimingUniform' in 'FanIn' mode with a
-- @LeakyRelu (sqrt 5)@ nonlinearity. The bias has shape @[out]@ and is
-- computed as @2 * bound * u - bound@, where @u@ comes from 'randIO'' and
-- @bound = 1 / sqrt fanIn@ (fan-in taken from 'calculateFan' on the weight
-- shape). Assuming 'randIO'' samples U[0, 1), the bias lies in [-bound, bound).
instance Randomizable ConvTranspose3dSpec ConvTranspose3d where
  sample ConvTranspose3dSpec {..} = do
    -- NOTE(review): libtorch reads the three trailing kernel dims as
    -- (D, H, W), so the fields labelled height/width/depth land on
    -- D/H/W respectively. The order is kept for compatibility; confirm intent.
    let weightShape =
          [ trInputChannelSize3d,
            trOutputChannelSize3d,
            trKernelHeight3d,
            trKernelWidth3d,
            trKernelDepth3d
          ]
        fanIn = getter FanIn (calculateFan weightShape)
        bound = 1 / Prelude.sqrt (fromIntegral fanIn) :: Float
    -- Keep the weight draw before the bias draw so RNG consumption order is
    -- unchanged.
    w <-
      makeIndependent
        =<< kaimingUniform FanIn (LeakyRelu $ Prelude.sqrt (5.0 :: Float)) weightShape
    unitSample <- randIO' [trOutputChannelSize3d]
    b <- makeIndependent $ subScalar bound $ mulScalar (bound * 2.0) unitSample
    pure $ ConvTranspose3d w b
-- | Hyper-parameters used to sample a 'BatchNorm' layer.
data BatchNormSpec = BatchNormSpec
  { numFeatures :: Int -- ^ length of the weight, bias and running-statistic vectors
  }
  deriving (Eq, Show)
-- | A batch-normalisation layer.
--
-- 'weight' and 'bias' are trainable parameters. 'runningMean' and
-- 'runningVar' are 'MutableTensor's handed to @batchNormIO@ (presumably
-- updated in place while training; confirm against @batchNormIO@).
data BatchNorm = BatchNorm
  { weight :: Parameter -- ^ per-feature scale
  , bias :: Parameter -- ^ per-feature shift
  , runningMean :: MutableTensor -- ^ running mean statistic
  , runningVar :: MutableTensor -- ^ running variance statistic
  }
  deriving (Generic, Show)
-- | Run a 'BatchNorm' layer over an input tensor.
--
-- This forwards the layer's weight and bias (unwrapped with 'toDependent'),
-- its two running-statistic tensors, and the @train@ flag, @momentum@, @eps@
-- and input unchanged to @Torch.Functional.batchNormIO@. It runs in 'IO'
-- because the running statistics are 'MutableTensor's.
batchNormForwardIO :: BatchNorm -> Bool -> Double -> Double -> Tensor -> IO Tensor
batchNormForwardIO params train momentum eps input =
  Torch.Functional.batchNormIO
    scale
    shift
    params.runningMean
    params.runningVar
    train
    momentum
    eps
    input
  where
    scale = toDependent params.weight
    shift = toDependent params.bias
-- | Build a fresh 'BatchNorm' layer from its spec.
--
-- No random draws are made. The weight comes from @ones'@ and the bias from
-- @zeros'@, both as trainable parameters. The running mean (@zeros'@) and
-- running variance (@ones'@) are created with
-- @makeIndependentWithRequiresGrad ... False@ and unwrapped into
-- 'MutableTensor's.
instance Randomizable BatchNormSpec BatchNorm where
  sample BatchNormSpec {..} = do
    let featureShape = [numFeatures]
        -- Wrap a tensor as a non-grad-tracking mutable running statistic.
        frozen t = MutableTensor . toDependent <$> makeIndependentWithRequiresGrad t False
    gamma <- makeIndependent (ones' featureShape)
    beta <- makeIndependent (zeros' featureShape)
    runMean <- frozen (zeros' featureShape)
    runVar <- frozen (ones' featureShape)
    pure (BatchNorm gamma beta runMean runVar)
-- | Hyper-parameters used to sample an 'InstanceNorm' layer.
data InstanceNormSpec = InstanceNormSpec
  { numFeatures :: Int -- ^ length of the weight, bias and running-statistic vectors
  }
  deriving (Eq, Show)
-- | An instance-normalisation layer.
--
-- 'weight' and 'bias' are trainable parameters. 'runningMean' and
-- 'runningVar' are 'MutableTensor's handed to @instanceNormIO@ (presumably
-- updated in place while training; confirm against @instanceNormIO@).
data InstanceNorm = InstanceNorm
  { weight :: Parameter -- ^ per-feature scale
  , bias :: Parameter -- ^ per-feature shift
  , runningMean :: MutableTensor -- ^ running mean statistic
  , runningVar :: MutableTensor -- ^ running variance statistic
  }
  deriving (Generic, Show)
-- | Run an 'InstanceNorm' layer over an input tensor.
--
-- This forwards the layer's weight and bias (unwrapped with 'toDependent'),
-- its two running-statistic tensors, and the @train@ flag, @momentum@, @eps@
-- and input unchanged to @Torch.Functional.instanceNormIO@. It runs in 'IO'
-- because the running statistics are 'MutableTensor's.
instanceNormForwardIO :: InstanceNorm -> Bool -> Double -> Double -> Tensor -> IO Tensor
instanceNormForwardIO params train momentum eps input =
  Torch.Functional.instanceNormIO
    scale
    shift
    params.runningMean
    params.runningVar
    train
    momentum
    eps
    input
  where
    scale = toDependent params.weight
    shift = toDependent params.bias
-- | Build a fresh 'InstanceNorm' layer from its spec.
--
-- No random draws are made. The weight comes from @ones'@ and the bias from
-- @zeros'@, both as trainable parameters. The running mean (@zeros'@) and
-- running variance (@ones'@) are created with
-- @makeIndependentWithRequiresGrad ... False@ and unwrapped into
-- 'MutableTensor's.
instance Randomizable InstanceNormSpec InstanceNorm where
  sample InstanceNormSpec {..} = do
    let featureShape = [numFeatures]
        -- Wrap a tensor as a non-grad-tracking mutable running statistic.
        frozen t = MutableTensor . toDependent <$> makeIndependentWithRequiresGrad t False
    gamma <- makeIndependent (ones' featureShape)
    beta <- makeIndependent (zeros' featureShape)
    runMean <- frozen (zeros' featureShape)
    runVar <- frozen (ones' featureShape)
    pure (InstanceNorm gamma beta runMean runVar)
-- | Hyper-parameters for an 'UpSample' layer.
data UpSampleSpec = UpSampleSpec
  { upsampleInputFilters :: Int -- ^ input filter count
  , upsampleStride :: Int -- ^ integer scale factor applied to the spatial dims
  }
  deriving (Eq, Show)
-- | An 'UpSampleSpec' holds only plain integers, so it has no trainable
-- tensors. Flattening yields nothing, and replacement returns the spec
-- untouched without consuming from the parameter stream.
instance Parameterized UpSampleSpec where
  flattenParameters = const []
  _replaceParameters = pure
-- | A parameter-free upsampling layer that wraps its 'UpSampleSpec'.
data UpSample = UpSample
  { upsampleSpec :: UpSampleSpec -- ^ configuration used by the forward pass
  }
  deriving (Generic, Show, Parameterized)
-- | Building an 'UpSample' draws no random values: the spec is simply
-- wrapped in the layer constructor and returned in 'IO'.
instance Randomizable UpSampleSpec UpSample where
  -- Redundant `do` and `fmap` over `pure` removed; equivalent to the
  -- original `do UpSample <$> pure s`.
  sample s = pure (UpSample s)
-- | Nearest-neighbour 2-D upsampling by an integer factor.
--
-- The last two dimensions of the input are read as (height, width), and
-- each is multiplied by 'upsampleStride'; the same stride is also passed
-- as both scale factors. NOTE(review): assumes a channels-first layout
-- such as (N, C, H, W) — confirm against callers.
instance HasForward UpSample Tensor Tensor where
  forward (UpSample (UpSampleSpec {..})) input =
    -- Height comes first: ATen's upsample_nearest2d reads output_size as
    -- [H_out, W_out] (assumes upsampleNearest2d forwards its tuple in
    -- order). Passing (width, height) transposed non-square outputs.
    upsampleNearest2d
      (inputHeight * upsampleStride, inputWidth * upsampleStride)
      strideD
      strideD
      input
    where
      strideD :: Double
      strideD = fromIntegral upsampleStride

      -- Explicit match so a low-rank input fails with a descriptive
      -- message rather than an opaque irrefutable-pattern failure.
      (inputHeight, inputWidth) = case reverse (shape input) of
        w : h : _ -> (h, w)
        _ -> error "UpSample.forward: input must have at least two dimensions"

  forwardStoch m x = pure $ forward m x