{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-|
Module      : Grenade.Core.Network
Description : Core definition of a Neural Network
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental

This module defines the core data types and functions
for non-recurrent neural networks.
-}

module Grenade.Core.Network (
    Network (..)
  , Gradients (..)
  , Tapes (..)

  , runNetwork
  , runGradient
  , applyUpdate

  , randomNetwork
  ) where

import           Control.Monad.Random ( MonadRandom )

import           Data.Singletons
import           Data.Singletons.Prelude
import           Data.Serialize

import           Grenade.Core.Layer
import           Grenade.Core.LearningParameters
import           Grenade.Core.Shape

-- | Type of a network.
--
--   The @[*]@ type specifies the types of the layers.
--
--   The @[Shape]@ type specifies the shapes of data passed between the layers.
--
--   Can be considered to be a heterogeneous list of layers which are able to
--   transform the data shapes of the network.
data Network :: [*] -> [Shape] -> * where
    NNil  :: SingI i
          => Network '[] '[i]

    (:~>) :: (SingI i, SingI h, Layer x i h)
          => !x
          -> !(Network xs (h ': hs))
          -> Network (x ': xs) (i ': h ': hs)
infixr 5 :~>

instance Show (Network '[] '[i]) where
  show NNil = "NNil"
instance (Show x, Show (Network xs rs)) => Show (Network (x ': xs) (i ': rs)) where
  show (x :~> xs) = show x ++ "\n~>\n" ++ show xs

-- | Gradient of a network.
--
--   Parameterised on the layers of the network.
data Gradients :: [*] -> * where
   GNil  :: Gradients '[]

   (:/>) :: UpdateLayer x
         => Gradient x
         -> Gradients xs
         -> Gradients (x ': xs)

-- | Wegnert Tape of a network.
--
--   Parameterised on the layers and shapes of the network.
data Tapes :: [*] -> [Shape] -> * where
   TNil  :: SingI i
         => Tapes '[] '[i]

   (:\>) :: (SingI i, SingI h, Layer x i h)
         => !(Tape x i h)
         -> !(Tapes xs (h ': hs))
         -> Tapes (x ': xs) (i ': h ': hs)


-- | Running a network forwards with some input data.
--
--   This gives the output, and the Wengert tape required for back
--   propagation.
runNetwork :: forall layers shapes.
              Network layers shapes
           -> S (Head shapes)
           -> (Tapes layers shapes, S (Last shapes))
runNetwork =
  go
    where
  go  :: forall js ss. (Last js ~ Last shapes)
      => Network ss js
      -> S (Head js)
      -> (Tapes ss js, S (Last js))
  go (layer :~> n) !x =
    let (tape, forward) = runForwards layer x
        (tapes, answer) = go n forward
    in  (tape :\> tapes, answer)

  go NNil !x
      = (TNil, x)


-- | Running a loss gradient back through the network.
--
--   This requires a Wengert tape, generated with the appropriate input
--   for the loss.
--
--   Gives the gradients for the layer, and the gradient across the
--   input (which may not be required).
runGradient :: forall layers shapes.
               Network layers shapes
            -> Tapes layers shapes
            -> S (Last shapes)
            -> (Gradients layers, S (Head shapes))
runGradient net tapes o =
  go net tapes
    where
  go  :: forall js ss. (Last js ~ Last shapes)
      => Network ss js
      -> Tapes ss js
      -> (Gradients ss, S (Head js))
  go (layer :~> n) (tape :\> nt) =
    let (gradients, feed)  = go n nt
        (layer', backGrad) = runBackwards layer tape feed
    in  (layer' :/> gradients, backGrad)

  go NNil TNil
      = (GNil, o)


-- | Apply one step of stochastic gradient decent across the network.
applyUpdate :: LearningParameters
            -> Network layers shapes
            -> Gradients layers
            -> Network layers shapes
applyUpdate rate (layer :~> rest) (gradient :/> grest)
  = runUpdate rate layer gradient :~> applyUpdate rate rest grest

applyUpdate _ NNil GNil
  = NNil

-- | A network can easily be created by hand with (:~>), but an easy way to
--   initialise a random network is with the randomNetwork.
class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
  -- | Create a network with randomly initialised weights.
  --
  --   Calls to this function will not compile if the type of the neural
  --   network is not sound.
  randomNetwork :: MonadRandom m => m (Network xs ss)

instance SingI i => CreatableNetwork '[] '[i] where
  randomNetwork = return NNil

instance (SingI i, SingI o, Layer x i o, CreatableNetwork xs (o ': rs)) => CreatableNetwork (x ': xs) (i ': o ': rs) where
  randomNetwork = (:~>) <$> createRandom <*> randomNetwork

-- | Add very simple serialisation to the network
instance SingI i => Serialize (Network '[] '[i]) where
  put NNil = pure ()
  get = return NNil

instance (SingI i, SingI o, Layer x i o, Serialize x, Serialize (Network xs (o ': rs))) => Serialize (Network (x ': xs) (i ': o ': rs)) where
  put (x :~> r) = put x >> put r
  get = (:~>) <$> get <*> get


-- | Ultimate composition.
--
--   This allows a complete network to be treated as a layer in a larger network.
instance CreatableNetwork sublayers subshapes => UpdateLayer (Network sublayers subshapes) where
  type Gradient (Network sublayers subshapes) = Gradients sublayers
  runUpdate    = applyUpdate
  createRandom = randomNetwork

-- | Ultimate composition.
--
--   This allows a complete network to be treated as a layer in a larger network.
instance (CreatableNetwork sublayers subshapes, i ~ (Head subshapes), o ~ (Last subshapes)) => Layer (Network sublayers subshapes) i o where
  type Tape (Network sublayers subshapes) i o = Tapes sublayers subshapes
  runForwards  = runNetwork
  runBackwards = runGradient