{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-|
Module      : Grenade.Layers.Reshape
Description : Multipurpose reshaping layer
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental
-}
module Grenade.Layers.Reshape (
    Reshape (..)
  ) where

import           Data.Serialize

import           Data.Singletons.TypeLits
import           GHC.TypeLits

import           Numeric.LinearAlgebra.Static
import           Numeric.LinearAlgebra.Data as LA ( flatten )

import           Grenade.Core

-- | Reshape layer.
--
-- Converts between shapes that hold the same number of activations:
--
--   * a 2D or 3D image can be flattened into a 1D vector;
--   * a 1D vector can be cast up to a 2D or 3D shape;
--   * a 3D image with exactly one channel can be turned into a 2D image,
--     and vice versa.
--
-- The layer is a single nullary constructor and carries no data of its own.
data Reshape
  = Reshape
  deriving (Show)

-- | 'Reshape' has nothing to update: its gradient is the unit type and an
--   update always yields the same 'Reshape' value.
instance UpdateLayer Reshape where
  type Gradient Reshape = ()

  -- Creating a "random" layer simply returns the only possible value.
  createRandom = pure Reshape

  -- All three arguments are ignored; the result is always 'Reshape'.
  runUpdate _ _ _ = Reshape

-- | Turn a 2D shape into a 1D shape with the same number of activations
--   (@a ~ x * y@).
instance ( KnownNat a
         , KnownNat x
         , KnownNat y
         , a ~ (x * y)
         ) => Layer Reshape ('D2 x y) ('D1 a) where
  -- The tape is the unit type: nothing is recorded on the forward pass.
  type Tape Reshape ('D2 x y) ('D1 a) = ()

  -- Forwards: unwrap the matrix, flatten it with 'LA.flatten', and hand the
  -- elements to 'fromStorable'. 'fromJust'' raises an error if that
  -- conversion yields 'Nothing'.
  runForwards _ (S2D mat) =
    ((), fromJust' (fromStorable (LA.flatten (extract mat))))

  -- Backwards: the elements of the 1D value are passed to 'fromStorable'
  -- as they are (no flattening step).
  runBackwards _ _ (S1D vec) =
    ((), fromJust' (fromStorable (extract vec)))

-- | Turn a 3D shape into a 1D shape with the same number of activations
--   (@a ~ x * y * z@).
instance ( KnownNat a
         , KnownNat x
         , KnownNat y
         , KnownNat (x * z)
         , KnownNat z
         , a ~ (x * y * z)
         ) => Layer Reshape ('D3 x y z) ('D1 a) where
  -- The tape is the unit type: nothing is recorded on the forward pass.
  type Tape Reshape ('D3 x y z) ('D1 a) = ()

  -- Forwards: unwrap the matrix held by 'S3D', flatten it with
  -- 'LA.flatten', and hand the elements to 'fromStorable'. 'fromJust''
  -- raises an error if that conversion yields 'Nothing'.
  runForwards _ (S3D mat) =
    ((), fromJust' (fromStorable (LA.flatten (extract mat))))

  -- Backwards: the elements of the 1D value are passed to 'fromStorable'
  -- as they are (no flattening step).
  runBackwards _ _ (S1D vec) =
    ((), fromJust' (fromStorable (extract vec)))

-- | Turn a 3D shape whose last dimension is one (@z ~ 1@) into the 2D shape
--   with the same @x@ and @y@.
instance ( KnownNat y
         , KnownNat x
         , KnownNat z
         , z ~ 1
         ) => Layer Reshape ('D3 x y z) ('D2 x y) where
  -- The tape is the unit type: nothing is recorded on the forward pass.
  type Tape Reshape ('D3 x y z) ('D2 x y) = ()

  -- Forwards: the payload of 'S3D' is rewrapped in 'S2D' untouched.
  runForwards _ (S3D mat) =
    ((), S2D mat)

  -- Backwards: the reverse rewrapping, again without touching the payload.
  runBackwards _ _ (S2D mat) =
    ((), S3D mat)

-- | Turn a 2D shape into the 3D shape with the same @x@ and @y@ and a last
--   dimension of one (@z ~ 1@).
instance ( KnownNat y
         , KnownNat x
         , KnownNat z
         , z ~ 1
         ) => Layer Reshape ('D2 x y) ('D3 x y z) where
  -- The tape is the unit type: nothing is recorded on the forward pass.
  type Tape Reshape ('D2 x y) ('D3 x y z) = ()

  -- Forwards: the payload of 'S2D' is rewrapped in 'S3D' untouched.
  runForwards _ (S2D mat) =
    ((), S3D mat)

  -- Backwards: the reverse rewrapping, again without touching the payload.
  runBackwards _ _ (S3D mat) =
    ((), S2D mat)

-- | Cast a 1D shape up to a 2D shape with the same number of activations
--   (@a ~ x * y@).
instance ( KnownNat a
         , KnownNat x
         , KnownNat y
         , a ~ (x * y)
         ) => Layer Reshape ('D1 a) ('D2 x y) where
  -- The tape is the unit type: nothing is recorded on the forward pass.
  type Tape Reshape ('D1 a) ('D2 x y) = ()

  -- Forwards: the elements of the 1D value are passed to 'fromStorable'
  -- as they are. 'fromJust'' raises an error if that conversion yields
  -- 'Nothing'.
  runForwards _ (S1D vec) =
    ((), fromJust' (fromStorable (extract vec)))

  -- Backwards: unwrap the matrix, flatten it with 'LA.flatten', and hand
  -- the elements to 'fromStorable'.
  runBackwards _ _ (S2D mat) =
    ((), fromJust' (fromStorable (LA.flatten (extract mat))))

-- | Cast a 1D shape up to a 3D shape with the same number of activations
--   (@a ~ x * y * z@).
instance ( KnownNat a
         , KnownNat x
         , KnownNat y
         , KnownNat (x * z)
         , KnownNat z
         , a ~ (x * y * z)
         ) => Layer Reshape ('D1 a) ('D3 x y z) where
  -- The tape is the unit type: nothing is recorded on the forward pass.
  type Tape Reshape ('D1 a) ('D3 x y z) = ()

  -- Forwards: the elements of the 1D value are passed to 'fromStorable'
  -- as they are. 'fromJust'' raises an error if that conversion yields
  -- 'Nothing'.
  runForwards _ (S1D vec) =
    ((), fromJust' (fromStorable (extract vec)))

  -- Backwards: unwrap the matrix held by 'S3D', flatten it with
  -- 'LA.flatten', and hand the elements to 'fromStorable'.
  runBackwards _ _ (S3D mat) =
    ((), fromJust' (fromStorable (LA.flatten (extract mat))))

-- | 'Reshape' has no fields, so encoding writes nothing and decoding
--   returns the constructor without reading any input.
instance Serialize Reshape where
  get = pure Reshape
  put = const (pure ())


-- | Unwrap a 'Maybe', raising a reshape-specific error on 'Nothing'.
--
-- Used on the results of shape conversions in this module, so a failure
-- is reported with a message that points at the reshape step.
fromJust' :: Maybe x -> x
fromJust' result =
  case result of
    Nothing    -> error "Reshape error: data shape couldn't be converted."
    Just value -> value