{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-|
Module      : Grenade.Layers.Pad
Description : Padding layer for 2D and 3D images
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental
-}
module Grenade.Layers.Pad (
    Pad (..)
  ) where

import           Data.Maybe
import           Data.Proxy
import           Data.Serialize
import           Data.Singletons.TypeLits
import           GHC.TypeLits

import           Grenade.Core
import           Grenade.Layers.Internal.Pad

import           Numeric.LinearAlgebra (konst, subMatrix, diagBlock)
import           Numeric.LinearAlgebra.Static (extract, create)

-- | A padding layer for a neural network.
--
--   Enlarges an image along its row (Y) and column (X) dimensions. The four
--   type-level naturals are, in order: the number of columns added on the
--   left, rows added on the top, columns added on the right and rows added
--   on the bottom.
--
--   The layer holds no runtime data; its single value is 'Pad'.
data Pad (padLeft   :: Nat)
         (padTop    :: Nat)
         (padRight  :: Nat)
         (padBottom :: Nat) where
  Pad :: Pad padLeft padTop padRight padBottom

-- | Renders the layer as its constructor name, @"Pad"@. The padding sizes
--   live only at the type level and are not part of the output.
instance Show (Pad padLeft padTop padRight padBottom) where
  showsPrec _ Pad rest = "Pad" ++ rest

-- | 'Pad' has no trainable parameters. Its gradient is the unit type, an
--   update hands back the layer unchanged, and "random" initialisation
--   simply yields 'Pad'.
instance UpdateLayer (Pad l t r b) where
  type Gradient (Pad l t r b) = ()
  runUpdate _params layer _gradient = layer
  createRandom = pure Pad

-- | 'Pad' carries no runtime data. It writes nothing when serialised and
--   reads nothing back when deserialised.
instance Serialize (Pad l t r b) where
  put = const (pure ())
  get = pure Pad

-- | A two-dimensional image can be padded.
--
--   'runForwards' builds a block-diagonal matrix from three blocks: a
--   @padTop × padLeft@ zero block, the input, and a @padBottom × padRight@
--   zero block. The input therefore sits at row offset @padTop@ and column
--   offset @padLeft@ inside an otherwise zero
--   @outputRows × outputColumns@ matrix.
--
--   'runBackwards' is the matching crop. It keeps the
--   @inputRows × inputColumns@ window of the incoming gradient at that same
--   offset and drops the gradient that fell on the padding. No tape is kept.
instance ( KnownNat padLeft
         , KnownNat padTop
         , KnownNat padRight
         , KnownNat padBottom
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , (inputRows + padTop + padBottom) ~ outputRows
         , (inputColumns + padLeft + padRight) ~ outputColumns
         ) => Layer (Pad padLeft padTop padRight padBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
  type Tape (Pad padLeft padTop padRight padBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns)  = ()
  runForwards Pad (S2D input) =
    let padl  = fromIntegral $ natVal (Proxy :: Proxy padLeft)
        padt  = fromIntegral $ natVal (Proxy :: Proxy padTop)
        padr  = fromIntegral $ natVal (Proxy :: Proxy padRight)
        padb  = fromIntegral $ natVal (Proxy :: Proxy padBottom)
        m     = extract input
        r     = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)]
        -- 'create' yields Nothing on a size mismatch, which the type
        -- equalities in the instance context rule out. If that invariant is
        -- ever broken, fail with a message naming this layer rather than
        -- fromJust's anonymous error.
    in  ( ()
        , S2D
            . fromMaybe (error "Grenade.Layers.Pad: 2D forward output has wrong dimensions")
            . create
            $ r
        )
  runBackwards Pad _ (S2D dEdy) =
    let padl  = fromIntegral $ natVal (Proxy :: Proxy padLeft)
        padt  = fromIntegral $ natVal (Proxy :: Proxy padTop)
        nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        ncols = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        m     = extract dEdy
        -- Keep only the window anchored at (padTop, padLeft) that matches
        -- the original input; the gradient on the padding is discarded.
        vs    = subMatrix (padt, padl) (nrows, ncols) m
    in  ( ()
        , S2D
            . fromMaybe (error "Grenade.Layers.Pad: 2D backward gradient has wrong dimensions")
            . create
            $ vs
        )

-- | A three-dimensional image (rows × columns × channels) can be padded.
--
--   The spatial sizes grow by @padTop + padBottom@ rows and
--   @padLeft + padRight@ columns, and the channel count is unchanged. The
--   work is delegated to 'pad' (forwards) and 'crop' (backwards) from
--   "Grenade.Layers.Internal.Pad". Each receives the channel count, all four
--   padding sizes and both the input and output spatial sizes. No tape is kept.
--
--   NOTE(review): the @KnownNat (inputRows * channels)@ constraint suggests
--   the channels are stacked vertically in one
--   @(inputRows * channels) × inputColumns@ matrix — TODO confirm against
--   the 'S3D' definition.
instance ( KnownNat padLeft
         , KnownNat padTop
         , KnownNat padRight
         , KnownNat padBottom
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , KnownNat channels
         , KnownNat (inputRows * channels)
         , KnownNat (outputRows * channels)
         , (inputRows + padTop + padBottom) ~ outputRows
         , (inputColumns + padLeft + padRight) ~ outputColumns
         ) => Layer (Pad padLeft padTop padRight padBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
  type Tape (Pad padLeft padTop padRight padBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels)  = ()
  runForwards Pad (S3D input) =
    let padl  = fromIntegral $ natVal (Proxy :: Proxy padLeft)
        padt  = fromIntegral $ natVal (Proxy :: Proxy padTop)
        padr  = fromIntegral $ natVal (Proxy :: Proxy padRight)
        padb  = fromIntegral $ natVal (Proxy :: Proxy padBottom)
        outr  = fromIntegral $ natVal (Proxy :: Proxy outputRows)
        outc  = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
        inr   = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        inc   = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        ch    = fromIntegral $ natVal (Proxy :: Proxy channels)
        m     = extract input
        padded = pad ch padl padt padr padb inr inc outr outc m
        -- 'create' yields Nothing on a size mismatch. That can only happen
        -- if 'pad' disagrees with the type-level sizes, so name the layer in
        -- the failure instead of using fromJust.
    in  ( ()
        , S3D
            . fromMaybe (error "Grenade.Layers.Pad: 3D forward output has wrong dimensions")
            . create
            $ padded
        )

  runBackwards Pad () (S3D gradient) =
    let padl  = fromIntegral $ natVal (Proxy :: Proxy padLeft)
        padt  = fromIntegral $ natVal (Proxy :: Proxy padTop)
        padr  = fromIntegral $ natVal (Proxy :: Proxy padRight)
        padb  = fromIntegral $ natVal (Proxy :: Proxy padBottom)
        outr  = fromIntegral $ natVal (Proxy :: Proxy outputRows)
        outc  = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
        inr   = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        inc   = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        ch    = fromIntegral $ natVal (Proxy :: Proxy channels)
        m     = extract gradient
        -- The gradient w.r.t. the input is the padded gradient with the
        -- border removed; 'crop' takes the same size arguments as 'pad'.
        cropped = crop ch padl padt padr padb inr inc outr outc m
    in  ( ()
        , S3D
            . fromMaybe (error "Grenade.Layers.Pad: 3D backward gradient has wrong dimensions")
            . create
            $ cropped
        )