module Torch.Initializers where
import Torch.Functional hiding (sqrt)
import Torch.Tensor
import Torch.TensorFactories
-- | Non-linearities for which 'calculateGain' defines a gain factor.
--
-- 'LeakyRelu' carries a 'Float' parameter @p@ that enters the gain as
-- @sqrt (2 / (1 + p^2))@; the other constructors map to fixed constants.
data NonLinearity = Identity | Sigmoid | Tanh | Relu | LeakyRelu Float
-- | Selects one component of the @(fanIn, fanOut)@ pair returned by
-- 'calculateFan': 'getter' maps 'FanIn' to 'fst' and 'FanOut' to 'snd'.
data FanMode = FanIn | FanOut
-- | Newtype wrapper around a list of 'Int'.
--
-- NOTE(review): not referenced by any definition visible in this module;
-- the initializers here take plain @[Int]@ shapes instead.
newtype Shape = Shape [Int]
-- | Gain factor associated with a non-linearity.
--
-- * 'Identity' and 'Sigmoid' give @1@.
-- * 'Tanh' gives @5 / 3@.
-- * 'Relu' gives @sqrt 2@.
-- * @'LeakyRelu' p@ gives @sqrt (2 / (1 + p^2))@.
calculateGain :: NonLinearity -> Float
calculateGain nonLinearity = case nonLinearity of
  Identity -> 1.0
  Sigmoid -> 1.0
  Tanh -> 5.0 / 3
  Relu -> sqrt 2.0
  -- presumably the parameter is the negative slope of the leaky ReLU;
  -- verify against callers.
  LeakyRelu slope -> sqrt (2.0 / (1.0 + slope * slope))
-- | Compute the @(fanIn, fanOut)@ pair for a weight of the given shape.
--
-- The first dimension is taken as the number of output feature maps and the
-- second as the number of input feature maps. Every dimension /after/ the
-- second contributes to the receptive field size, which scales both fans:
--
-- > fanIn  = shape !! 1 * product (dimensions after the second)
-- > fanOut = shape !! 0 * product (dimensions after the second)
--
-- For a 2-dimensional shape the receptive field is the empty product (@1@),
-- so the result is simply @(shape !! 1, shape !! 0)@.
--
-- Calls 'error' when the shape has fewer than two dimensions.
calculateFan :: [Int] -> (Int, Int)
calculateFan (numOutputFmaps : numInputFmaps : kernelDims) =
  (numInputFmaps * receptiveFieldSize, numOutputFmaps * receptiveFieldSize)
  where
    -- Only the trailing (kernel) dimensions form the receptive field. Taking
    -- the product over everything but the first dimension would multiply the
    -- input feature-map count into both fans a second time.
    receptiveFieldSize :: Int
    receptiveFieldSize = product kernelDims
calculateFan _ =
  error "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
-- | Xavier-style initializer built from a 'randIO'' sample.
--
-- With @(fanIn, fanOut) = 'calculateFan' shape@:
--
-- > std   = gain * sqrt (2 / (fanIn + fanOut))
-- > bound = sqrt 3 * std
--
-- The sampled tensor is passed through @mulScalar (bound * 2)@ and then
-- @subScalar bound@.
--
-- NOTE(review): assuming 'randIO'' draws uniformly from @[0, 1)@ and
-- 'subScalar' subtracts its scalar from the tensor, the result is uniform on
-- @[-bound, bound)@ -- confirm against the library documentation.
xavierUniform :: Float -> [Int] -> IO Tensor
xavierUniform gain shape = stretch <$> randIO' shape
  where
    fans :: (Int, Int)
    fans = calculateFan shape

    fanTotal :: Float
    fanTotal = fromIntegral (fst fans) + fromIntegral (snd fans)

    std :: Float
    std = gain * sqrt (2.0 / fanTotal)

    bound :: Float
    bound = sqrt 3.0 * std

    -- Scale by the interval width first, then shift by the bound.
    stretch :: Tensor -> Tensor
    stretch sample = subScalar bound (mulScalar (bound * 2.0) sample)
-- | Xavier-style initializer built from a 'randnIO'' sample.
--
-- The sampled tensor is multiplied by
--
-- > std = gain * sqrt (2 / (fanIn + fanOut))
--
-- where @(fanIn, fanOut) = 'calculateFan' shape@.
--
-- NOTE(review): presumably 'randnIO'' draws standard-normal values, making
-- @std@ the standard deviation of the result -- confirm.
xavierNormal :: Float -> [Int] -> IO Tensor
xavierNormal gain shape = mulScalar std <$> randnIO' shape
  where
    fans :: (Int, Int)
    fans = calculateFan shape

    fanTotal :: Float
    fanTotal = fromIntegral (fst fans) + fromIntegral (snd fans)

    std :: Float
    std = gain * sqrt (2.0 / fanTotal)
-- | Projection that picks the fan matching a 'FanMode' out of a
-- @(fanIn, fanOut)@ pair: the first component for 'FanIn', the second for
-- 'FanOut'.
getter :: FanMode -> ((Int, Int) -> Int)
getter mode = case mode of
  FanIn -> \(fanIn, _) -> fanIn
  FanOut -> \(_, fanOut) -> fanOut
-- | Kaiming-style initializer built from a 'randIO'' sample.
--
-- The fan is the component of @'calculateFan' shape@ chosen by @mode@, and
--
-- > std   = calculateGain nonlinearity / sqrt fan
-- > bound = sqrt 3 * std
--
-- The sampled tensor is passed through @mulScalar (bound * 2)@ and then
-- @subScalar bound@.
--
-- NOTE(review): assuming 'randIO'' draws uniformly from @[0, 1)@ and
-- 'subScalar' subtracts its scalar from the tensor, the result is uniform on
-- @[-bound, bound)@ -- confirm against the library documentation.
kaimingUniform :: FanMode -> NonLinearity -> [Int] -> IO Tensor
kaimingUniform mode nonlinearity shape = stretch <$> randIO' shape
  where
    selectedFan :: Float
    selectedFan = fromIntegral (getter mode (calculateFan shape))

    std :: Float
    std = calculateGain nonlinearity / sqrt selectedFan

    bound :: Float
    bound = sqrt 3.0 * std

    -- Scale by the interval width first, then shift by the bound.
    stretch :: Tensor -> Tensor
    stretch sample = subScalar bound (mulScalar (bound * 2.0) sample)
-- | Kaiming-style initializer built from a 'randnIO'' sample.
--
-- The sampled tensor is multiplied by
--
-- > std = calculateGain nonlinearity / sqrt fan
--
-- where the fan is the component of @'calculateFan' shape@ chosen by @mode@.
--
-- NOTE(review): presumably 'randnIO'' draws standard-normal values, making
-- @std@ the standard deviation of the result -- confirm.
kaimingNormal :: FanMode -> NonLinearity -> [Int] -> IO Tensor
kaimingNormal mode nonlinearity shape = do
  sample <- randnIO' shape
  pure (mulScalar std sample)
  where
    selectedFan :: Float
    selectedFan = fromIntegral (getter mode (calculateFan shape))

    std :: Float
    std = calculateGain nonlinearity / sqrt selectedFan
-- | Produce a @(weight, bias)@ pair for the given weight shape.
--
-- The weight comes from 'kaimingUniform''. The bias has shape
-- @[head weightShape]@ and is a 'randIO'' sample passed through
-- @mulScalar (bound * 2)@ and then @subScalar bound@, with
--
-- > bound = 1 / sqrt fanIn
--
-- where @fanIn@ is the first component of @'calculateFan' weightShape@.
-- The weight is sampled before the bias.
--
-- NOTE(review): the name suggests a fully connected layer with
-- @weightShape = [outFeatures, inFeatures]@ -- confirm with callers.
kaimingFC :: [Int] -> IO (Tensor, Tensor)
kaimingFC weightShape = do
  weight <- kaimingUniform' weightShape
  bias <- stretch <$> randIO' [head weightShape]
  pure (weight, bias)
  where
    fanIn :: Int
    fanIn = fst (calculateFan weightShape)

    bound :: Float
    bound = 1.0 / sqrt (fromIntegral fanIn)

    -- Scale by the interval width first, then shift by the bound.
    stretch :: Tensor -> Tensor
    stretch sample = subScalar bound (mulScalar (bound * 2.0) sample)
-- | 'kaimingUniform' with its defaults fixed to 'FanIn' and @'LeakyRelu' 0.0@.
kaimingUniform' :: [Int] -> IO Tensor
kaimingUniform' shape = kaimingUniform FanIn (LeakyRelu 0.0) shape
-- | 'kaimingNormal' with its defaults fixed to 'FanIn' and @'LeakyRelu' 0.0@.
kaimingNormal' :: [Int] -> IO Tensor
kaimingNormal' shape = kaimingNormal FanIn (LeakyRelu 0.0) shape
-- | 'xavierUniform' with the gain fixed to @1.0@.
xavierUniform' :: [Int] -> IO Tensor
xavierUniform' shape = xavierUniform 1.0 shape
-- | 'xavierNormal' with the gain fixed to @1.0@.
xavierNormal' :: [Int] -> IO Tensor
xavierNormal' shape = xavierNormal 1.0 shape