{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE CPP #-}
#if MIN_VERSION_base(4,12,0)
{-# LANGUAGE NoStarIsType #-}
#endif
module Torch.Initialization
( newLinear
, newConv2d
, xavierUniformWith_
, xavierUniform_
, xavierNormalWith_
, xavierNormal_
, Activation(..)
, FanMode(..)
, kaimingUniformWith_
, kaimingUniform_
, kaimingNormalWith_
, kaimingNormal_
) where
import Data.Maybe (fromMaybe)
import Data.Function ((&))
import GHC.Generics
import Prelude as P
import Data.Singletons.Prelude hiding (type (*), All)
import Data.Singletons.Prelude.List hiding (All)
import Numeric.Dimensions
import Control.Exception.Safe (throwString)
import Torch.Double
import qualified Torch.Double as Torch
import Torch.Double.NN.Linear (Linear(..))
import qualified Torch.Double.NN.Conv2d as NN
-- | Create a 'Linear' layer with freshly initialised parameters.
--
-- The weight tensor is initialised with 'kaimingUniformWith_' using
-- @LeakyReluFn (Just (sqrt 5))@ and 'FanIn'. The bias tensor is then passed to
-- '_uniform' with the interval @(-limit, limit)@, where
-- @limit = 1 / sqrt fan@ and @fan@ is 'calculateCorrectFan' of the weights in
-- 'FanIn' mode.
--
-- Both tensors come from 'new' (presumably an uninitialised allocation that
-- the initialisers fill in place -- confirm against "Torch.Double").
newLinear :: forall o i . All KnownDim '[i,o] => Generator -> IO (Linear i o)
newLinear gen = do
  let weight = new
  kaimingUniformWith_ (LeakyReluFn (Just (P.sqrt 5))) FanIn gen weight
  let limit = 1 / P.sqrt (calculateCorrectFan weight FanIn)
      offsets = new
      -- limit is a reciprocal square root, so (-limit, limit) is expected to
      -- be an ordered pair; the match is still partial if 'ord2Tuple' rejects it.
      Just interval = ord2Tuple (-limit, limit)
  _uniform offsets gen interval
  pure (Linear (weight, offsets))
-- | Create a 'Conv2d' layer with freshly initialised parameters.
--
-- The weight tensor is initialised with 'kaimingUniformWith_' using
-- @LeakyReluFn (Just (sqrt 5))@ and 'FanIn'. The bias tensor is then passed to
-- '_uniform' with the interval @(-limit, limit)@, where
-- @limit = 1 / sqrt fan@ and @fan@ is 'calculateCorrectFan' of the weights in
-- 'FanIn' mode.
--
-- Both tensors come from 'new' (presumably an uninitialised allocation that
-- the initialisers fill in place -- confirm against "Torch.Double").
newConv2d :: forall o i kH kW . All KnownDim '[i,o,kH,kW,kH*kW] => Generator -> IO (Conv2d i o '(kH,kW))
newConv2d gen = do
  let kernel = new
  kaimingUniformWith_ (LeakyReluFn (Just (P.sqrt 5))) FanIn gen kernel
  let limit = 1 / P.sqrt (calculateCorrectFan kernel FanIn)
      offsets = new
      -- limit is a reciprocal square root, so (-limit, limit) is expected to
      -- be an ordered pair; the match is still partial if 'ord2Tuple' rejects it.
      Just interval = ord2Tuple (-limit, limit)
  _uniform offsets gen interval
  pure (Conv2d (kernel, offsets))
-- | The kind of layer or non-linearity that follows the weights being
-- initialised. 'calculateGain' maps each constructor to a scaling factor.
data Activation
  = LinearFn
    -- ^ Treated as linear by 'isLinear'.
  | Conv1dFn
    -- ^ Treated as linear by 'isLinear'.
  | Conv2dFn
    -- ^ Treated as linear by 'isLinear'.
  | Conv3dFn
    -- ^ Treated as linear by 'isLinear'.
  | Conv1dTFn
    -- ^ Treated as linear by 'isLinear' (presumably the transposed 1d
    -- convolution -- confirm).
  | Conv2dTFn
    -- ^ Treated as linear by 'isLinear' (presumably the transposed 2d
    -- convolution -- confirm).
  | Conv3dTFn
    -- ^ Treated as linear by 'isLinear' (presumably the transposed 3d
    -- convolution -- confirm).
  | SigmoidFn
  | TanhFn
  | ReluFn
  | LeakyReluFn (Maybe Double)
    -- ^ Carries an optional slope; 'calculateGain' substitutes a default when
    -- it is 'Nothing'.
  deriving (Eq, Show)
-- | 'True' for the linear and convolutional constructors of 'Activation'
-- (including the @T@ variants); 'False' for every other constructor.
isLinear :: Activation -> Bool
isLinear LinearFn  = True
isLinear Conv1dFn  = True
isLinear Conv2dFn  = True
isLinear Conv3dFn  = True
isLinear Conv1dTFn = True
isLinear Conv2dTFn = True
isLinear Conv3dTFn = True
isLinear _         = False
-- | Scaling factor ("gain") associated with an 'Activation'.
--
-- * any constructor accepted by 'isLinear': @1@
-- * 'SigmoidFn': @1@
-- * 'TanhFn': @5 / 3@
-- * 'ReluFn': @sqrt 2@
-- * 'LeakyReluFn': @sqrt (2 / (1 + slope ^ 2))@, where a missing slope
--   defaults to @0.01@. That is the default negative slope used by PyTorch's
--   @calculate_gain@, which this table mirrors; the previous fallback of
--   @0.001@ was off by a factor of ten.
calculateGain
  :: Activation
  -> Double
calculateGain f
  | isLinear f = 1
  | otherwise =
    -- The linear constructors never reach this case: the guard above has
    -- already handled them.
    case f of
      SigmoidFn -> 1
      TanhFn -> 5 / 3
      ReluFn -> P.sqrt 2
      LeakyReluFn mslope ->
        P.sqrt $ 2 / (1 + fromMaybe defaultNegativeSlope mslope ** 2)
  where
    -- Slope assumed for @LeakyReluFn Nothing@.
    defaultNegativeSlope :: Double
    defaultNegativeSlope = 0.01
-- | Compute the @(fan_in, fan_out)@ pair from a tensor's type-level shape.
-- The tensor argument itself is never inspected.
--
-- With @trailing@ the product of all dimensions after the first two:
--
-- * @fan_in  = o * trailing@ (second dimension)
-- * @fan_out = i * trailing@ (first dimension)
--
-- NOTE(review): fan-in comes from the type variable named @o@ and fan-out
-- from @i@. Confirm this matches the weight layout of every caller.
fanInAndFanOut
  :: forall outs i o
  . (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Tensor (i:+o:+outs)
  -> (Double, Double)
fanInAndFanOut _ = (fanIn, fanOut)
  where
    fanIn, fanOut, firstDim, secondDim, trailing :: Double
    fanIn     = secondDim * trailing
    fanOut    = firstDim * trailing
    firstDim  = fromIntegral (dimVal (dim :: Dim i))
    secondDim = fromIntegral (dimVal (dim :: Dim o))
    trailing  = fromIntegral (dimVal (dim :: Dim (Product outs)))
-- | Xavier initialisation with a uniform distribution and an explicit gain.
--
-- 'xavierDistributedWith_' supplies the standard deviation; the tensor is
-- then handed to '_uniform' with the interval @(-a, a)@, where
-- @a = sqrt 3 * std@.
xavierUniformWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => HsReal
  -> Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
xavierUniformWith_ =
  xavierDistributedWith_ (\gen stdev tensor ->
    let halfWidth = P.sqrt 3 * positiveValue stdev
        -- stdev is a 'Positive' value, so halfWidth > 0 and the pair is
        -- expected to be accepted by 'ord2Tuple'.
        Just interval = ord2Tuple (-halfWidth, halfWidth)
     in _uniform tensor gen interval)
-- | 'xavierUniformWith_' with the gain fixed to @1@.
xavierUniform_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
xavierUniform_ gen tensor = xavierUniformWith_ 1 gen tensor
-- | Xavier initialisation with a normal distribution and an explicit gain.
--
-- 'xavierDistributedWith_' supplies the standard deviation; the tensor is
-- then handed to '_normal' with mean @0@ and that standard deviation.
xavierNormalWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => HsReal
  -> Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
xavierNormalWith_ =
  xavierDistributedWith_ (\gen stdev tensor -> _normal tensor gen 0 stdev)
-- | 'xavierNormalWith_' with the gain fixed to @1@.
xavierNormal_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
xavierNormal_ gen tensor = xavierNormalWith_ 1 gen tensor
-- | Shared driver for the Xavier initialisers.
--
-- Computes @gain * sqrt (2 / (fan_in + fan_out))@ from 'fanInAndFanOut' and,
-- if 'positive' accepts the result, runs the supplied sampling action with
-- it. Otherwise the failure is raised with 'throwString'.
xavierDistributedWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => (Generator -> Positive HsReal -> Tensor (i:+o:+outs) -> IO ())
  -> HsReal
  -> Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
xavierDistributedWith_ distribution gain g tensor =
  case positive rawStd of
    Nothing ->
      throwString
        ( "standard deviation is not positive. Found: "
            ++ show rawStd
            ++ ", most likely the gain is negative, which is incorrect: "
            ++ show gain
        )
    Just checkedStd -> distribution g checkedStd tensor
  where
    (fanIn, fanOut) = fanInAndFanOut tensor
    rawStd = gain * P.sqrt (2 / (fanIn + fanOut))
-- | Selects which component of 'fanInAndFanOut' is used by
-- 'calculateCorrectFan'.
data FanMode
  = FanIn   -- ^ Use the fan-in component.
  | FanOut  -- ^ Use the fan-out component.
  deriving (Eq, Ord, Show)
-- | Return the component of 'fanInAndFanOut' selected by the 'FanMode':
-- the first for 'FanIn', the second for 'FanOut'.
calculateCorrectFan
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Tensor (i:+o:+outs) -> FanMode -> Double
calculateCorrectFan t mode =
  let (fanIn, fanOut) = fanInAndFanOut t
   in case mode of
        FanIn  -> fanIn
        FanOut -> fanOut
-- | Kaiming initialisation with a uniform distribution, for a given
-- 'Activation' and 'FanMode'.
--
-- 'kaimingDisributedWith_' supplies the standard deviation; the tensor is
-- then handed to '_uniform' with the interval @(-a, a)@, where
-- @a = sqrt 3 * std@.
kaimingUniformWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Activation
  -> FanMode
  -> Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
kaimingUniformWith_ =
  kaimingDisributedWith_ (\gen stdev tensor ->
    let halfWidth = P.sqrt 3 * positiveValue stdev
        -- stdev is a 'Positive' value, so halfWidth > 0 and the pair is
        -- expected to be accepted by 'ord2Tuple'.
        Just interval = ord2Tuple (-halfWidth, halfWidth)
     in _uniform tensor gen interval)
-- | 'kaimingUniformWith_' with the activation fixed to
-- @LeakyReluFn (Just 0)@ and the mode fixed to 'FanIn'.
kaimingUniform_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
kaimingUniform_ gen tensor =
  kaimingUniformWith_ (LeakyReluFn (Just 0)) FanIn gen tensor
-- | Kaiming initialisation with a normal distribution, for a given
-- 'Activation' and 'FanMode'.
--
-- 'kaimingDisributedWith_' supplies the standard deviation; the tensor is
-- then handed to '_normal' with mean @0@ and that standard deviation.
kaimingNormalWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Activation
  -> FanMode
  -> Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
kaimingNormalWith_ =
  kaimingDisributedWith_ (\gen stdev tensor -> _normal tensor gen 0 stdev)
-- | 'kaimingNormalWith_' with the activation fixed to
-- @LeakyReluFn (Just 0)@ and the mode fixed to 'FanIn'.
kaimingNormal_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
kaimingNormal_ gen tensor =
  kaimingNormalWith_ (LeakyReluFn (Just 0)) FanIn gen tensor
-- | Shared driver for the Kaiming initialisers.
--
-- Computes @gain / sqrt fan@, where @gain@ is 'calculateGain' of the
-- activation and @fan@ is 'calculateCorrectFan' of the tensor for the given
-- mode. If 'positive' accepts the result, the supplied sampling action is run
-- with it. Otherwise the failure is raised with 'throwString'.
kaimingDisributedWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => (Generator -> Positive HsReal -> Tensor (i:+o:+outs) -> IO ())
  -> Activation
  -> FanMode
  -> Generator
  -> Tensor (i:+o:+outs)
  -> IO ()
kaimingDisributedWith_ distribution activation mode g t =
  case positive rawStd of
    Nothing ->
      throwString
        ( "standard deviation is not positive. Found: "
            ++ show rawStd
            ++ ", most likely the gain is negative, which is incorrect: "
            ++ show gain
        )
    Just checkedStd -> distribution g checkedStd t
  where
    gain = calculateGain activation
    rawStd = gain / P.sqrt (calculateCorrectFan t mode)