-------------------------------------------------------------------------------
-- |
-- Module    :  Torch.Initialization
-- Copyright :  (c) Sam Stites 2017
-- License   :  BSD3
-- Maintainer:  sam@stites.io
-- Stability :  experimental
-- Portability: non-portable
--
-- Helper functions which might end up migrating to the -indef codebase
-------------------------------------------------------------------------------
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE CPP #-}

#if MIN_VERSION_base(4,12,0)
{-# LANGUAGE NoStarIsType #-}
#endif

module Torch.Initialization
  ( newLinear
  , newConv2d
  , xavierUniformWith_
  , xavierUniform_
  , xavierNormalWith_
  , xavierNormal_

  , Activation(..)
  , FanMode(..)
  , kaimingUniformWith_
  , kaimingUniform_
  , kaimingNormalWith_
  , kaimingNormal_
  ) where

import Data.Maybe (fromMaybe)
import Data.Function ((&))
import GHC.Generics
import Prelude as P
import Data.Singletons.Prelude hiding (type (*), All)
import Data.Singletons.Prelude.List hiding (All)
import Numeric.Dimensions
import Control.Exception.Safe (throwString)

import Torch.Double
import qualified Torch.Double as Torch
import Torch.Double.NN.Linear (Linear(..))
import qualified Torch.Double.NN.Conv2d as NN


-- Layer initialization: these depend on random functions which are not yet
-- unified, so it is trickier to fold them back into their respective NN modules.

-- | Initialize a new linear layer, mirroring PyTorch's default @nn.Linear@
-- initialization: Kaiming-uniform weights and a uniform bias in
-- @(-1/sqrt fan_in, 1/sqrt fan_in)@.
newLinear :: forall o i . All KnownDim '[i,o] => Generator -> IO (Linear i o)
newLinear g = fmap Linear $ do
  let w = new
  kaimingUniformWith_ (LeakyReluFn (Just $ P.sqrt 5)) FanIn g w

  let
    fanin = calculateCorrectFan w FanIn
    bound = 1 / P.sqrt fanin
    bias = new
    Just pair = ord2Tuple (-bound, bound)
  _uniform bias g pair
  pure (w, bias)
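
-- A minimal usage sketch (a 'Generator' is assumed to come from the library's
-- RNG constructor, e.g. @newRNG@; the layer sizes are illustrative):
--
-- >>> g <- newRNG
-- >>> layer :: Linear 10 2 <- newLinear g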


-- | Initialize a new conv2d layer, mirroring PyTorch's default @nn.Conv2d@
-- initialization: Kaiming-uniform weights and a uniform bias in
-- @(-1/sqrt fan_in, 1/sqrt fan_in)@.
newConv2d :: forall o i kH kW . All KnownDim '[i,o,kH,kW,kH*kW] => Generator -> IO (Conv2d i o '(kH,kW))
newConv2d g = fmap Conv2d $ do
  let w = new
  kaimingUniformWith_ (LeakyReluFn (Just $ P.sqrt 5)) FanIn g w

  let
    fanin = calculateCorrectFan w FanIn
    bound = 1 / P.sqrt fanin
    bias = new
    Just pair = ord2Tuple (-bound, bound)
  _uniform bias g pair
  pure (w, bias)
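
-- A minimal usage sketch, mirroring the linear example above (a 'Generator'
-- @g@ is assumed to be in scope; channel and kernel sizes are illustrative):
--
-- >>> conv :: Conv2d 3 16 '(5,5) <- newConv2d g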


data Activation
  -- linear functions
  = LinearFn   -- ^ Linear activation
  | Conv1dFn   -- ^ Conv1d activation
  | Conv2dFn   -- ^ Conv2d activation
  | Conv3dFn   -- ^ Conv3d activation
  | Conv1dTFn  -- ^ Conv1d transpose activation
  | Conv2dTFn  -- ^ Conv2d transpose activation
  | Conv3dTFn  -- ^ Conv3d transpose activation

  -- non-linear
  | SigmoidFn                   -- ^ Sigmoid activation
  | TanhFn                      -- ^ Tanh activation
  | ReluFn                      -- ^ ReLU activation
  | LeakyReluFn (Maybe Double)  -- ^ Leaky ReLU with an optional negative slope
                                --   ('calculateGain' falls back to 0.01, the PyTorch default)
  deriving (Eq, Show)

isLinear :: Activation -> Bool
isLinear = \case
  LinearFn  -> True
  Conv1dFn  -> True
  Conv2dFn  -> True
  Conv3dFn  -> True
  Conv1dTFn -> True
  Conv2dTFn -> True
  Conv3dTFn -> True
  _         -> False



-- |
-- Return the recommended gain value for the given nonlinearity function.
-- The values are as follows:
-- ================= ====================================================
-- nonlinearity      gain
-- ================= ====================================================
-- Linear / Identity :math:`1`
-- Conv{1,2,3}D      :math:`1`
-- Sigmoid           :math:`1`
-- Tanh              :math:`\frac{5}{3}`
-- ReLU              :math:`\sqrt{2}`
-- Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
-- ================= ====================================================
-- Note: the optional parameter of the non-linear function (e.g. the negative
-- slope of a leaky ReLU) is carried by the 'Activation' constructor itself.
-- Examples:
--     >>> calculateGain (LeakyReluFn Nothing)
calculateGain
  :: Activation  -- ^ the non-linear function (analogue of the @nn.functional@ name)
  -> Double
calculateGain f
  | isLinear f = 1
  | otherwise =
    case f of
      SigmoidFn -> 1
      TanhFn -> 5 / 3
      ReluFn -> P.sqrt 2
      LeakyReluFn mslope -> P.sqrt $ 2 / (1 + fromMaybe 0.01 mslope ** 2)  -- 0.01 is PyTorch's default negative slope

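-- | Compute the (fan-in, fan-out) pair of a weight tensor. Each of the first
-- two dimensions is scaled by the receptive-field size, i.e. the product of
-- the remaining (kernel) dimensions; for a 2-d weight this factor is 1.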
fanInAndFanOut
  :: forall outs i o
  .  (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Tensor (i:+o:+outs)
  -> (Double, Double)
fanInAndFanOut = const (fan_in, fan_out)
 where
  fan_in  = fromIntegral (dimVal (dim :: Dim o)) * rest
  fan_out = fromIntegral (dimVal (dim :: Dim i)) * rest
  rest    = fromIntegral (dimVal (dim :: Dim (Product outs)))

-- |
-- Fills the input `Tensor` with values according to the method
-- described in "Understanding the difficulty of training deep feedforward
-- neural networks" - Glorot, X. & Bengio, Y. (2010), using a uniform
-- distribution. The resulting tensor will have values sampled from
-- :math:`\mathcal{U}(-a, a)` where
-- .. math::
--     a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
-- Also known as Glorot initialization.
-- Examples:
--     >>> :set -XDataKinds -XScopedTypeVariables
--     >>> let w = new :: Tensor '[3, 5]
--     >>> xavierUniformWith_ (calculateGain ReluFn) g w   -- given some g :: Generator
xavierUniformWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => HsReal              -- ^ gain: an optional scaling factor
  -> Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
xavierUniformWith_ = xavierDistributedWith_ $ \g pstd t -> do
  let std = positiveValue pstd
      a = P.sqrt 3 * std   -- Calculate uniform bounds from standard deviation
      Just pair = ord2Tuple (-a, a)
  _uniform t g pair

-- | xavierUniformWith_ with default of gain = 1
xavierUniform_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
xavierUniform_ = xavierUniformWith_ 1

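-- |
-- Fills the input `Tensor` using a normal distribution according to the same
-- Glorot & Bengio (2010) scheme as 'xavierUniformWith_'. The resulting tensor
-- will have values sampled from :math:`\mathcal{N}(0, \text{std})` where
-- .. math::
--     \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}
-- Also known as Glorot initialization.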
xavierNormalWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => HsReal              -- ^ gain: an optional scaling factor
  -> Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
xavierNormalWith_ = xavierDistributedWith_ $ \g std t -> _normal t g 0 std

-- | 'xavierNormalWith_' with default of gain = 1
xavierNormal_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
xavierNormal_ = xavierNormalWith_ 1


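-- | Shared driver for the Glorot initializers: computes
-- std = gain * sqrt(2 / (fan_in + fan_out)) and hands it to the supplied
-- randomizing fill, throwing if the resulting standard deviation is not
-- positive (e.g. when the gain is negative).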
xavierDistributedWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => (Generator -> Positive HsReal -> Tensor (i:+o:+outs) -> IO ())
  -> HsReal              -- ^ gain: an optional scaling factor
  -> Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
xavierDistributedWith_ distribution gain g tensor = do
  let
    (fan_in, fan_out) = fanInAndFanOut tensor
    std = gain * P.sqrt (2 / (fan_in + fan_out))
  case positive std of
    Just pstd -> distribution g pstd tensor
    Nothing -> throwString $
      "standard deviation is not positive. Found: " ++ show std
      ++ "; most likely the gain is negative, which is incorrect: " ++ show gain



data FanMode
  = FanIn   -- ^ preserve the magnitude of the variance of the weights in the forward pass
  | FanOut  -- ^ preserve the magnitudes in the backwards pass
  deriving (Eq, Ord, Show)


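-- | Select the component of 'fanInAndFanOut' that matches the requested
-- 'FanMode'.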
calculateCorrectFan
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Tensor (i:+o:+outs) -> FanMode -> Double
calculateCorrectFan t = \case
  FanIn -> fan_in
  FanOut -> fan_out
 where
  (fan_in, fan_out) = fanInAndFanOut t


-- |
-- Fills the input `Tensor` with values according to the method
-- described in "Delving deep into rectifiers: Surpassing human-level
-- performance on ImageNet classification" - He, K. et al. (2015), using a
-- uniform distribution. The resulting tensor will have values sampled from
-- :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
-- .. math::
--     \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan\_in}}}
-- Also known as He initialization.
-- Args:
--     activation: the non-linear function used after this layer; for
--         'LeakyReluFn' the optional value is the negative slope of the
--         rectifier (0 for ReLU). Recommended to use only with 'ReluFn' or
--         'LeakyReluFn'.
--     mode: either 'FanIn' (the PyTorch default) or 'FanOut'. Choosing 'FanIn'
--         preserves the magnitude of the variance of the weights in the
--         forward pass. Choosing 'FanOut' preserves the magnitudes in the
--         backwards pass.
-- Examples:
--     >>> let w = new :: Tensor '[3, 5]
--     >>> kaimingUniformWith_ ReluFn FanIn g w   -- given some g :: Generator
kaimingUniformWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Activation
  -> FanMode
  -> Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
kaimingUniformWith_ = kaimingDistributedWith_ $ \g pstd t -> do
  let a = P.sqrt 3 * (positiveValue pstd)   -- Calculate uniform bounds from standard deviation
      Just pair = ord2Tuple (-a, a)
  _uniform t g pair

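-- | 'kaimingUniformWith_' with the PyTorch defaults: a leaky ReLU with
-- negative slope 0 and 'FanIn'.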
kaimingUniform_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
kaimingUniform_ = kaimingUniformWith_ (LeakyReluFn (Just 0)) FanIn

-- |
-- Fills the input `Tensor` with values according to the method
-- described in "Delving deep into rectifiers: Surpassing human-level
-- performance on ImageNet classification" - He, K. et al. (2015), using a
-- normal distribution. The resulting tensor will have values sampled from
-- :math:`\mathcal{N}(0, \text{std})` where
-- .. math::
--     \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}
-- Also known as He initialization.
-- Args:
--     activation: the non-linear function used after this layer; for
--         'LeakyReluFn' the optional value is the negative slope of the
--         rectifier (0 for ReLU). Recommended to use only with 'ReluFn' or
--         'LeakyReluFn'.
--     mode: either 'FanIn' (the PyTorch default) or 'FanOut'. Choosing 'FanIn'
--         preserves the magnitude of the variance of the weights in the
--         forward pass. Choosing 'FanOut' preserves the magnitudes in the
--         backwards pass.
-- Examples:
--     >>> let w = new :: Tensor '[3, 5]
--     >>> kaimingNormalWith_ ReluFn FanOut g w   -- given some g :: Generator
kaimingNormalWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Activation
  -> FanMode
  -> Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
kaimingNormalWith_ = kaimingDistributedWith_ $ \g std t -> _normal t g 0 std

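-- | 'kaimingNormalWith_' with the PyTorch defaults: a leaky ReLU with
-- negative slope 0 and 'FanIn'.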
kaimingNormal_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
kaimingNormal_ = kaimingNormalWith_ (LeakyReluFn (Just 0)) FanIn


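-- | Shared driver for the Kaiming (He) initializers: computes
-- std = gain / sqrt(fan) for the requested 'FanMode' and hands it to the
-- supplied randomizing fill, throwing if the standard deviation is not
-- positive.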
kaimingDistributedWith_
  :: (Dimensions outs, All KnownDim '[i, o, Product outs])
  => (Generator -> Positive HsReal -> Tensor (i:+o:+outs) -> IO ()) -- ^ randomizing fill which takes a standard deviation
  -> Activation
  -> FanMode
  -> Generator
  -> Tensor (i:+o:+outs) -- ^ tensor: an n-dimensional `torch.Tensor` (minimum length 2)
  -> IO ()
kaimingDistributedWith_ distribution activation mode g t =
  case positive std of
    Just pstd -> distribution g pstd t
    Nothing -> throwString $
      "standard deviation is not positive. Found: " ++ show std
      ++ "; most likely the gain is negative, which is incorrect: " ++ show gain
 where
  fan  = calculateCorrectFan t mode
  gain = calculateGain activation
  std  = gain / P.sqrt fan