{-# LANGUAGE DeriveGeneric #-}

module Torch.Typed.NN.Recurrent.Auxiliary where

import GHC.Generics
import Torch.Functional (mulScalar, subScalar)
import Torch.Tensor

-- | Selects which initialization scheme a recurrent layer uses.
--
-- NOTE(review): the exact meaning of each constructor is decided by the layers
-- that consume this type; the notes below are assumptions to confirm there.
data RNNInitialization
  = -- | Initialize from a constant (presumably a fixed, non-trainable value — confirm).
    ConstantInitialization
  | -- | Initialize from learned values (presumably a trainable parameter — confirm).
    LearnedInitialization
  deriving (Generic, Show)

-- TODO: This code is taken from the initializers example code and should be replaced with
-- canonical, tested versions. Even so, a potentially imperfect implementation will likely
-- perform better than an ad-hoc random-normal distribution.

-- | Fan-in / Fan-out scaling calculation
calculateFan :: [Int] -> (Int, Int)
calculateFan :: [Int] -> (Int, Int)
calculateFan [Int]
shape
  | Int
dimT Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
2 =
    String -> (Int, Int)
forall a. HasCallStack => String -> a
error
      String
"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
  | Int
dimT Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
2 =
    (Int
numInputFmaps, Int
numOutputFmaps)
  | Bool
otherwise =
    (Int
numInputFmaps Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
receptiveFieldSize, Int
numOutputFmaps Int -> Int -> Int
forall a. Num a => a -> a -> a
* Int
receptiveFieldSize)
  where
    dimT :: Int
dimT = [Int] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Int]
shape
    numInputFmaps :: Int
numInputFmaps = [Int]
shape [Int] -> Int -> Int
forall a. HasCallStack => [a] -> Int -> a
!! Int
1
    numOutputFmaps :: Int
numOutputFmaps = [Int]
shape [Int] -> Int -> Int
forall a. HasCallStack => [a] -> Int -> a
!! Int
0
    receptiveFieldSize :: Int
receptiveFieldSize = [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ [Int] -> [Int]
forall a. HasCallStack => [a] -> [a]
tail [Int]
shape

-- | Xavier Initialization - Uniform
xavierUniformFIXME :: Tensor -> Float -> [Int] -> IO Tensor
xavierUniformFIXME :: Tensor -> Float -> [Int] -> IO Tensor
xavierUniformFIXME Tensor
init Float
gain [Int]
shape =
  Tensor -> IO Tensor
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Tensor -> IO Tensor) -> Tensor -> IO Tensor
forall a b. (a -> b) -> a -> b
$
    Float -> Tensor -> Tensor
forall a. Scalar a => a -> Tensor -> Tensor
subScalar Float
bound (Tensor -> Tensor) -> Tensor -> Tensor
forall a b. (a -> b) -> a -> b
$
      Float -> Tensor -> Tensor
forall a. Scalar a => a -> Tensor -> Tensor
mulScalar (Float
bound Float -> Float -> Float
forall a. Num a => a -> a -> a
* Float
2.0) Tensor
init
  where
    (Int
fanIn, Int
fanOut) = [Int] -> (Int, Int)
calculateFan [Int]
shape
    std :: Float
std = Float
gain Float -> Float -> Float
forall a. Num a => a -> a -> a
* Float -> Float
forall a. Floating a => a -> a
sqrt (Float
2.0 Float -> Float -> Float
forall a. Fractional a => a -> a -> a
/ (Int -> Float
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
fanIn Float -> Float -> Float
forall a. Num a => a -> a -> a
+ Int -> Float
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
fanOut))
    bound :: Float
bound = Float -> Float
forall a. Floating a => a -> a
sqrt Float
3.0 Float -> Float -> Float
forall a. Num a => a -> a -> a
* Float
std