{-# OPTIONS_HADDOCK show-extensions #-}

{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE Arrows #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

{-|
Module      : Numeric.Neural.Model
Description : "neural" components and models
Copyright   : (c) Lars Brünjes, 2016
License     : MIT
Maintainer  : brunjlar@gmail.com
Stability   : experimental
Portability : portable

This module defines /parameterized functions/, /components/ and /models/. The parameterized functions and components
are instances of the 'Arrow' typeclass and can therefore be combined easily and flexibly.

/Models/ contain a component, can measure their error with regard to samples, and can be trained by gradient descent
(using backpropagation).
-}

module Numeric.Neural.Model
    ( ParamFun(..)
    , Component(..)
    , _weights
    , activate
    , Model(..)
    , _component
    , model
    , modelR
    , modelError
    , descent
    , StdModel
    , mkStdModel
    ) where

import Control.Arrow
import Control.Category
import Data.Profunctor
import Data.MyPrelude
import Prelude                hiding (id, (.))
import Data.Utils.Analytic
import Data.Utils.Arrow
import Data.Utils.Statistics  (mean)
import Data.Utils.Traversable

-- | The type @'ParamFun' t a b@ describes parameterized functions from @a@ to @b@, where the
--   parameters are of type @t 'Analytic'@.
--   When such parameterized functions are composed, they all share the /same/ parameters.
--
newtype ParamFun t a b = ParamFun
    { runPF :: a -> t Analytic -> b -- ^ runs the function on an input and a parameter collection
    }

instance Category (ParamFun t) where

    -- The identity ignores the parameters and returns its input unchanged.
    id = ParamFun const

    -- Both functions are run on the /same/ parameter collection @ts@.
    ParamFun outer . ParamFun inner = ParamFun $ \x ts -> outer (inner x ts) ts

instance Arrow (ParamFun t) where

    -- A lifted pure function simply ignores the parameters.
    arr f = ParamFun (const . f)

    -- The second tuple component is passed through untouched.
    first (ParamFun f) = ParamFun onFirst
      where
        onFirst (x, y) ts = (f x ts, y)

instance ArrowChoice (ParamFun t) where

    -- 'Left' inputs are fed through the function; 'Right' inputs pass through unchanged.
    left (ParamFun f) = ParamFun $ \input ts ->
        either (\x -> Left (f x ts)) Right input

instance ArrowConvolve (ParamFun t) where

    -- Every element of the container is processed with the same parameters @ts@.
    convolve (ParamFun f) = ParamFun $ \xs ts -> fmap (\x -> f x ts) xs

-- These instances delegate to the generic arrow helpers 'fmapArr', 'pureArr',
-- 'apArr' and 'dimapArr'.

instance Functor (ParamFun t a) where
    fmap = fmapArr

instance Applicative (ParamFun t a) where
    pure  = pureArr
    (<*>) = apArr

instance Profunctor (ParamFun t) where
    dimap = dimapArr

-- | A @'Component' a b@ is a parameterized function from @a@ to @b@, bundled with /its own/
--   collection of parameters.
--   In contrast to 'ParamFun', parameters are not shared when components are composed:
--   each component carries its own collection of parameters instead.
--
data Component a b = forall t. (Traversable t, Applicative t) => Component
    { -- | the specific parameter values
      weights :: t Double
      -- | the encapsulated parameterized function
    , compute :: ParamFun t a b
      -- | randomly sets the parameters
    , initR   :: forall m. MonadRandom m => m (t Double)
    }

-- | A 'Lens'' to get or set the weights of a component.
--   The shape of the parameter collection is hidden by existential quantification,
--   so this lens has to use simple generic lists.
--
--   Setting uses 'fromList' to rebuild the parameter collection from the list. If 'fromList'
--   returns 'Nothing' (presumably when the list does not fit the collection's shape; TODO
--   confirm against "Data.Utils.Traversable"), an error with a descriptive message is raised.
--   Previously this case failed with an opaque irrefutable-pattern error.
--
_weights :: Lens' (Component a b) [Double]
_weights = lens
    (\(Component ws _ _) -> toList ws)
    (\(Component _ c i) ws -> case fromList ws of
        Just ws' -> Component ws' c i
        Nothing  -> error "Numeric.Neural.Model._weights: the given list cannot be converted into the component's parameter collection")

-- | Activates a component, i.e. applies it to the specified input, using the current parameter values.
--   The stored 'Double' weights are lifted to 'Analytic' before the function is run.
--
activate :: Component a b -> a -> b
activate (Component ws f _) x = runPF f x analyticWs
  where
    analyticWs = fmap fromDouble ws

-- | The empty parameter collection, holding no values at all.
--   It serves as the parameters of components built with 'arr'.
data Empty a = Empty
    deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

instance Applicative Empty where

    pure _ = Empty

    -- Both arguments are pattern-matched, so both are forced.
    Empty <*> Empty = Empty

-- | The product of two parameter collections. Composed components store their
--   parameters as a 'Pair' of the two sub-collections.
data Pair s t a = Pair (s a) (t a)
    deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)

instance (Applicative s, Applicative t) => Applicative (Pair s t) where

    -- Both halves are filled with the same value.
    pure x = Pair (pure x) (pure x)

    -- Application works independently on each half.
    Pair sf tf <*> Pair sx tx = Pair (sf <*> sx) (tf <*> tx)

instance Category Component where

    id = arr id

    -- The composite owns both parameter collections side by side in a 'Pair';
    -- each half is handed only to the component it belongs to.
    Component outerWs outerF outerR . Component innerWs innerF innerR =
        Component (Pair outerWs innerWs)
                  (ParamFun $ \x (Pair outerTs innerTs) -> runPF outerF (runPF innerF x innerTs) outerTs)
                  (Pair <$> outerR <*> innerR)

instance Arrow Component where

    -- A lifted pure function has no parameters, hence the empty collection.
    arr f = Component Empty (arr f) (return Empty)

    -- Only the wrapped function changes; weights and their initializer are reused.
    first (Component ws c i) = Component ws (first c) i

instance ArrowChoice Component where

    -- The parameters are unaffected; only the function is lifted to 'Either'.
    left (Component ws c i) = Component { weights = ws, compute = left c, initR = i }

instance ArrowConvolve Component where

    -- One shared parameter collection is used for every element of the container.
    convolve (Component ws c i) = Component { weights = ws, compute = convolve c, initR = i }

-- These instances delegate to the generic arrow helpers 'fmapArr', 'pureArr',
-- 'apArr' and 'dimapArr'.

instance Functor (Component a) where
    fmap = fmapArr

instance Applicative (Component a) where
    pure  = pureArr
    (<*>) = apArr

instance Profunctor Component where
    dimap = dimapArr

-- | A @'Model' f g a b c@ wraps a @'Component' (f 'Analytic') (g 'Analytic')@
--   and models functions @b -> c@, using /samples/ of type @a@ to determine the model error.
--
--   The arguments of the 'Model' constructor are, in order:
--
--   * the wrapped component;
--
--   * a function splitting a sample into a component input and an error function,
--     which is applied to the component output for that sample;
--
--   * a function converting a model input of type @b@ into a component input;
--
--   * a function converting a component output into a model output of type @c@.
--
data Model :: (* -> *) -> (* -> *) -> * -> * -> * -> * where

    Model :: (Functor f, Functor g)
          => Component (f Analytic) (g Analytic)        -- wrapped component
          -> (a -> (f Double, g Analytic -> Analytic)) -- sample -> (input, error function)
          -> (b -> f Double)                           -- model input -> component input
          -> (g Double -> c)                           -- component output -> model output
          -> Model f g a b c

-- | Pre-composes the input conversion and post-composes the output conversion;
--   the component and the sample error function are left untouched.
instance Profunctor (Model f g a) where

    dimap before after (Model comp err inp out) = Model comp err (inp . before) (after . out)

-- | A 'Lens'' for getting or replacing the component embedded in a model.
--   Setting keeps the model's sample, input and output functions.
--
_component :: Lens' (Model f g a b c) (Component (f Analytic) (g Analytic))
_component = lens
    (\(Model comp _ _ _) -> comp)
    (\(Model _ e i o) comp -> Model comp e i o)

-- | Computes the modelled function: the input is converted, lifted to 'Analytic',
--   run through the component, and the result is converted back to 'Double's
--   and then to the model output.
model :: Model f g a b c -> b -> c
model (Model c _ i o) = o . fmap (fromJust . fromAnalytic) . activate c . fmap fromDouble . i

-- | Generates a model with randomly initialized weights, drawn with the component's own
--   'initR'. All other properties are copied from the provided model.
modelR :: MonadRandom m => Model f g a b c -> m (Model f g a b c)
modelR (Model (Component _ f r) e i o) = do
    freshWs <- r
    return (Model (Component freshWs f r) e i o)

-- | Builds the mean sample error of a mini-batch as a function of the parameters.
--   Each sample is split by @e@ into a component input (lifted to 'Analytic') and an
--   error function, which is applied to the output of @f@ on that input.
errFun :: (Functor f, Foldable h, Traversable t)
          => (a -> (f Double, g Analytic -> Analytic))
          -> h a
          -> ParamFun t (f Analytic) (g Analytic)
          -> (t Analytic -> Analytic)
errFun e xs f = \ts -> mean [sampleError x ts | x <- toList xs]
  where
    -- error of one sample under parameters @ts@
    sampleError x ts =
        let (input, errorOf) = e x
        in  errorOf (runPF f (fromDouble <$> input) ts)

-- | Calculates the average model error for a "mini-batch" of samples,
--   evaluated at the model's current weights.
--
modelError :: Foldable h => Model f g a b c -> h a -> Double
modelError (Model (Component ws f _) e _ _) xs =
    fromJust (fromAnalytic (errFun e xs f (fromDouble <$> ws)))

-- | Performs one step of gradient descent (using backpropagation) on the model.
--   Each weight @w@ with gradient @dw@ is replaced by @w - eta * dw@.
--
descent :: Foldable h
        => Model f g a b c           -- ^ the model whose error should be decreased
        -> Double                    -- ^ the learning rate
        -> h a                       -- ^ a mini-batch of samples
        -> (Double, Model f g a b c) -- ^ returns the average sample error and the improved model
descent (Model (Component ws f r) e i o) eta xs = (err, Model (Component ws' f r) e i o)
  where
    step w dw  = w - eta * dw
    (err, ws') = gradient step (errFun e xs f) ws

-- | A type abbreviation for the most common kind of model, whose samples are plain input-output pairs.
type StdModel f g x y = Model f g (x, y) x y

-- | Creates a 'StdModel', using the simplifying assumption that the error can be computed from the expected
--   output alone: a sample @(x, y)@ is turned into the converted input @x@ and the error function for @y@.
--
mkStdModel :: (Functor f, Functor g)
              => Component (f Analytic) (g Analytic)
              -> (c -> g Analytic -> Analytic)
              -> (b -> f Double)
              -> (g Double -> c)
              -> StdModel f g b c
mkStdModel c e i o = Model c (\(x, y) -> (i x, e y)) i o