{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}

-- {-# OPTIONS_GHC -O -ddump-rule-firings #-}


-- | Provides the class `ParamSet` which is used to represent the set of
-- parameters of a particular model.  The goal of SGD is then to find the
-- parameter values which minimize a given objective function.


module Numeric.SGD.ParamSet
  (
  -- * Class
    ParamSet(..)
  -- * Generics
  , GPMap
  , GAdd
  , GSub
  , GDiv
  , GMul
  , GNorm2
  ) where


import           GHC.Generics
import           GHC.TypeNats (KnownNat)

import           Prelude hiding (div)

import qualified Data.Map.Strict as M

import qualified Numeric.LinearAlgebra.Static as LA


-- | Class of types that can be treated as parameter sets.  It provides basic
-- element-wise operations (addition, multiplication, mapping) which are
-- required to perform stochastic gradient descent.  Many of the operations
-- (`add`, `mul`, `sub`, `div`, etc.) have the same interpretation and follow
-- the same laws (e.g. associativity) as the corresponding operations in `Num`
-- and `Fractional`.
--
-- `zero` takes a parameter set as argument and "zeroes out" all its elements
-- (as in the backprop library).  This allows instances for `Maybe`, `M.Map`,
-- etc., where the structure of the parameter set is dynamic.  This leads to
-- the following property:
--
--     @add (zero x) x = x@
--
-- However, `zero` does not have to obey @(add (zero x) y = y)@.
--
-- A `ParamSet` can be also seen as a (structured) vector, hence `pmap` and
-- `norm_2`.  The latter is not strictly necessary to perform SGD, but it is
-- useful to control the training process.
--
-- `pmap` should obey the following laws:
--
--     @pmap id x = x@
--
--     @pmap f (pmap g x) = pmap (f . g) x@
--
-- The second (composition) law is assumed by the @"ParamSet pmap/pmap"@
-- rewrite rule in this module; an instance that breaks it may therefore
-- behave differently depending on whether the rule fires.
--
-- If you leave the body of an instance declaration blank, GHC Generics will be
-- used to derive the instance, provided the type has a `Generic` instance, a
-- single constructor, and each of its fields is an instance of `ParamSet`.
class ParamSet a where
  -- | Element-wise mapping
  pmap :: (Double -> Double) -> a -> a

  -- | Zero-out all elements.  Defaults to @pmap (const 0.0)@, which keeps
  -- the structure of the argument and replaces every element with @0@.
  zero :: a -> a
  zero = pmap (const 0.0)

--   -- | Element-wise negation
--   neg :: a -> a
--   neg = pmap (\x -> -x)

  -- | Element-wise addition
  add :: a -> a -> a
  -- | Element-wise subtraction
  sub :: a -> a -> a

  -- | Element-wise multiplication
  mul :: a -> a -> a
  -- | Element-wise division
  div :: a -> a -> a

  -- | L2 norm
  norm_2 :: a -> Double

--   default zero :: (Generic a, GZero (Rep a)) => a -> a
--   zero = genericZero
--   {-# INLINE zero #-}

  default pmap
    :: (Generic a, GPMap (Rep a))
    => (Double -> Double) -> a -> a
  pmap = genericPMap
  {-# INLINE pmap #-}


  default add :: (Generic a, GAdd (Rep a)) => a -> a -> a
  add = genericAdd
  {-# INLINE add #-}

  default sub :: (Generic a, GSub (Rep a)) => a -> a -> a
  sub = genericSub
  {-# INLINE sub #-}

  default mul :: (Generic a, GMul (Rep a)) => a -> a -> a
  mul = genericMul
  {-# INLINE mul #-}

  default div :: (Generic a, GDiv (Rep a)) => a -> a -> a
  div = genericDiv
  {-# INLINE div #-}

  default norm_2 :: (Generic a, GNorm2 (Rep a)) => a -> Double
  norm_2 = genericNorm2
  {-# INLINE norm_2 #-}


-- Fuse two consecutive element-wise maps into a single 'pmap'.  This rewrite
-- is only semantics-preserving for instances whose 'pmap' satisfies the
-- composition law @pmap f (pmap g x) = pmap (f . g) x@.
--
-- NOTE(review): the head of the LHS ('pmap') is a class method, so GHC may
-- warn (-Winline-rule-shadowing) that its built-in class-op rule can fire
-- first and this rule may never fire -- confirm with -ddump-rule-firings.
{-# RULES
"ParamSet pmap/pmap" forall f g p. pmap f (pmap g p) = pmap (f . g) p
  #-}


-- {-# RULES
-- "ParamSet pmap/add/pmap" forall f g p h q. 
--   pmap f (add (pmap g p) (pmap h q))
--   = add (pmap (f . g) p) (pmap (f . h) q)
--   #-}


-- -- | 'zero' using GHC Generics; works if all fields are instances of
-- -- 'ParamSet', but only for values with single constructors.
-- genericZero :: (Generic a, GZero (Rep a)) => a -> a
-- genericZero x = to $ gzero (from x)
-- {-# INLINE genericZero #-}


-- | Generic implementation of 'add': both arguments are converted to their
-- 'Generic' representation, combined field by field with 'gadd', and
-- converted back.  Usable when every field is a 'ParamSet' instance and the
-- type has a single constructor.
genericAdd :: (Generic a, GAdd (Rep a)) => a -> a -> a
genericAdd x y = to (gadd (from x) (from y))
{-# INLINE genericAdd #-}


-- | Generic implementation of 'sub': subtracts the 'Generic' representation
-- of the second argument from that of the first, field by field, via 'gsub'.
-- Usable when every field is a 'ParamSet' instance and the type has a single
-- constructor.
genericSub :: (Generic a, GSub (Rep a)) => a -> a -> a
genericSub x y = to (gsub repX repY)
  where
    repX = from x
    repY = from y
{-# INLINE genericSub #-}


-- | Generic implementation of 'div': divides the two 'Generic'
-- representations field by field with 'gdiv' and rebuilds the value.
-- Usable when every field is a 'ParamSet' instance and the type has a single
-- constructor.
genericDiv :: (Generic a, GDiv (Rep a)) => a -> a -> a
genericDiv x y = to (from x `gdiv` from y)
{-# INLINE genericDiv #-}


-- | Generic implementation of 'mul': multiplies the two 'Generic'
-- representations field by field with 'gmul' and rebuilds the value.
-- Usable when every field is a 'ParamSet' instance and the type has a single
-- constructor.
genericMul :: (Generic a, GMul (Rep a)) => a -> a -> a
genericMul x y =
  let lhs = from x
      rhs = from y
  in  to (gmul lhs rhs)
{-# INLINE genericMul #-}


-- | Generic implementation of 'norm_2': computes the norm of the value's
-- 'Generic' representation with 'gnorm_2'.  Usable when every field is a
-- 'ParamSet' instance and the type has a single constructor.
genericNorm2 :: (Generic a, GNorm2 (Rep a)) => a -> Double
genericNorm2 x = gnorm_2 rep
  where
    rep = from x
{-# INLINE genericNorm2 #-}


-- | Generic implementation of 'pmap': applies the element function to every
-- field of the value's 'Generic' representation via 'gpmap', then rebuilds
-- the value.  Usable when every field is a 'ParamSet' instance and the type
-- has a single constructor.
genericPMap :: (Generic a, GPMap (Rep a)) => (Double -> Double) -> a -> a
genericPMap f x =
  let mapped = gpmap f (from x)
  in  to mapped
{-# INLINE genericPMap #-}


--------------------------------------------------
-- Generics
--
-- Partially borrowed from the backprop library
--------------------------------------------------


-- -- | Helper class for automatically deriving 'zero' using GHC Generics.
-- class GZero f where
--     gzero :: f t -> f t
-- 
-- instance ParamSet p => GZero (K1 i p) where
--     gzero (K1 x) = K1 (zero x)
--     {-# INLINE gzero #-}
-- 
-- instance (GZero f, GZero g) => GZero (f :*: g) where
--     gzero (x1 :*: y1) = x2 :*: y2
--       where
--         !x2 = gzero x1
--         !y2 = gzero y1
--     {-# INLINE gzero #-}
-- 
-- instance GZero V1 where
--     gzero = \case {}
--     {-# INLINE gzero #-}
-- 
-- instance GZero U1 where
--     gzero _ = U1
--     {-# INLINE gzero #-}
-- 
-- instance GZero f => GZero (M1 i c f) where
--     gzero (M1 x) = M1 (gzero x)
--     {-# INLINE gzero #-}
-- 
-- -- instance GZero f => GZero (f :.: g) where
-- --     gzero = Comp1 gzero
-- --     {-# INLINE gzero #-}


-- | Generic worker behind the default 'add': adds two generic
-- representations structurally, delegating to 'add' at every field.
class GAdd f where
    gadd :: f t -> f t -> f t

-- A single field: combined with the field type's own 'add'.
instance ParamSet a => GAdd (K1 i a) where
    gadd l r = K1 (unK1 l `add` unK1 r)
    {-# INLINE gadd #-}

-- A product of fields: both halves are forced before being paired again.
instance (GAdd f, GAdd g) => GAdd (f :*: g) where
    gadd (l1 :*: r1) (l2 :*: r2) =
      let !l = gadd l1 l2
          !r = gadd r1 r2
      in  l :*: r
    {-# INLINE gadd #-}

-- No constructors: the first argument can only be bottom.
instance GAdd V1 where
    gadd v = case v of {}
    {-# INLINE gadd #-}

-- A nullary constructor: the result is always 'U1'.
instance GAdd U1 where
    gadd _ = const U1
    {-# INLINE gadd #-}

-- Metadata wrapper: unwrap both sides, recurse, rewrap.
instance GAdd f => GAdd (M1 i c f) where
    gadd l r = M1 (gadd (unM1 l) (unM1 r))
    {-# INLINE gadd #-}

-- instance GAdd f => GAdd (f :.: g) where
--     gadd (Comp1 x) (Comp1 y) = Comp1 (gadd x y)
--     {-# INLINE gadd #-}


-- | Generic worker behind the default 'sub': subtracts two generic
-- representations structurally, delegating to 'sub' at every field.
class GSub f where
    gsub :: f t -> f t -> f t

-- A single field: combined with the field type's own 'sub'.
instance ParamSet a => GSub (K1 i a) where
    gsub l r = K1 (unK1 l `sub` unK1 r)
    {-# INLINE gsub #-}

-- A product of fields: both halves are forced before being paired again.
instance (GSub f, GSub g) => GSub (f :*: g) where
    gsub (l1 :*: r1) (l2 :*: r2) =
      let !l = gsub l1 l2
          !r = gsub r1 r2
      in  l :*: r
    {-# INLINE gsub #-}

-- No constructors: the first argument can only be bottom.
instance GSub V1 where
    gsub v = case v of {}
    {-# INLINE gsub #-}

-- A nullary constructor: the result is always 'U1'.
instance GSub U1 where
    gsub _ = const U1
    {-# INLINE gsub #-}

-- Metadata wrapper: unwrap both sides, recurse, rewrap.
instance GSub f => GSub (M1 i c f) where
    gsub l r = M1 (gsub (unM1 l) (unM1 r))
    {-# INLINE gsub #-}

-- instance GSub f => GSub (f :.: g) where
--     gsub (Comp1 x) (Comp1 y) = Comp1 (gsub x y)
--     {-# INLINE gsub #-}


-- | Generic worker behind the default 'mul': multiplies two generic
-- representations structurally, delegating to 'mul' at every field.
class GMul f where
    gmul :: f t -> f t -> f t

-- A single field: combined with the field type's own 'mul'.
instance ParamSet a => GMul (K1 i a) where
    gmul l r = K1 (unK1 l `mul` unK1 r)
    {-# INLINE gmul #-}

-- A product of fields: both halves are forced before being paired again.
instance (GMul f, GMul g) => GMul (f :*: g) where
    gmul (l1 :*: r1) (l2 :*: r2) =
      let !l = gmul l1 l2
          !r = gmul r1 r2
      in  l :*: r
    {-# INLINE gmul #-}

-- No constructors: the first argument can only be bottom.
instance GMul V1 where
    gmul v = case v of {}
    {-# INLINE gmul #-}

-- A nullary constructor: the result is always 'U1'.
instance GMul U1 where
    gmul _ = const U1
    {-# INLINE gmul #-}

-- Metadata wrapper: unwrap both sides, recurse, rewrap.
instance GMul f => GMul (M1 i c f) where
    gmul l r = M1 (gmul (unM1 l) (unM1 r))
    {-# INLINE gmul #-}

-- instance GMul f => GMul (f :.: g) where
--     gmul (Comp1 x) (Comp1 y) = Comp1 (gmul x y)
--     {-# INLINE gmul #-}


-- | Generic worker behind the default 'div': divides two generic
-- representations structurally, delegating to 'div' at every field.
class GDiv f where
    gdiv :: f t -> f t -> f t

-- A single field: combined with the field type's own 'div'.
instance ParamSet a => GDiv (K1 i a) where
    gdiv l r = K1 (unK1 l `div` unK1 r)
    {-# INLINE gdiv #-}

-- A product of fields: both halves are forced before being paired again.
instance (GDiv f, GDiv g) => GDiv (f :*: g) where
    gdiv (l1 :*: r1) (l2 :*: r2) =
      let !l = gdiv l1 l2
          !r = gdiv r1 r2
      in  l :*: r
    {-# INLINE gdiv #-}

-- No constructors: the first argument can only be bottom.
instance GDiv V1 where
    gdiv v = case v of {}
    {-# INLINE gdiv #-}

-- A nullary constructor: the result is always 'U1'.
instance GDiv U1 where
    gdiv _ = const U1
    {-# INLINE gdiv #-}

-- Metadata wrapper: unwrap both sides, recurse, rewrap.
instance GDiv f => GDiv (M1 i c f) where
    gdiv l r = M1 (gdiv (unM1 l) (unM1 r))
    {-# INLINE gdiv #-}

-- instance GDiv f => GDiv (f :.: g) where
--     gdiv (Comp1 x) (Comp1 y) = Comp1 (gdiv x y)
--     {-# INLINE gdiv #-}


-- | Generic worker behind the default 'norm_2': computes the L2 norm of a
-- generic representation from the norms of its fields.
class GNorm2 f where
    gnorm_2 :: f t -> Double

-- A single field: its norm is the field type's own 'norm_2'.
instance ParamSet a => GNorm2 (K1 i a) where
    gnorm_2 k = norm_2 (unK1 k)
    {-# INLINE gnorm_2 #-}

-- A product of fields: the Euclidean combination of the two sub-norms,
-- both of which are forced first.
instance (GNorm2 f, GNorm2 g) => GNorm2 (f :*: g) where
    gnorm_2 (l :*: r) =
      let !nl = gnorm_2 l
          !nr = gnorm_2 r
          sq v = v ^ (2 :: Int)
      in  sqrt (sq nl + sq nr)
    {-# INLINE gnorm_2 #-}

-- No constructors: the argument can only be bottom.
instance GNorm2 V1 where
    gnorm_2 v = case v of {}
    {-# INLINE gnorm_2 #-}

-- A nullary constructor carries no parameters, so its norm is zero.
instance GNorm2 U1 where
    gnorm_2 = const 0
    {-# INLINE gnorm_2 #-}

-- Metadata wrapper: the norm of the wrapped representation.
instance GNorm2 f => GNorm2 (M1 i c f) where
    gnorm_2 m = gnorm_2 (unM1 m)
    {-# INLINE gnorm_2 #-}

-- -- TODO: Make sure this makes sense
-- instance GNorm2 f => GNorm2 (f :.: g) where
--     gnorm_2 (Comp1 x) = gnorm_2 x
--     {-# INLINE gnorm_2 #-}


-- | Generic worker behind the default 'pmap': applies an element function
-- throughout a generic representation, delegating to 'pmap' at every field.
class GPMap f where
    gpmap :: (Double -> Double) -> f t -> f t

-- A single field: mapped with the field type's own 'pmap'.
instance ParamSet a => GPMap (K1 i a) where
    gpmap h k = K1 (pmap h (unK1 k))
    {-# INLINE gpmap #-}

-- A product of fields: both halves are forced before being paired again.
instance (GPMap f, GPMap g) => GPMap (f :*: g) where
    gpmap h (l :*: r) =
      let !l' = gpmap h l
          !r' = gpmap h r
      in  l' :*: r'
    {-# INLINE gpmap #-}

-- No constructors: the argument can only be bottom.
instance GPMap V1 where
    gpmap _ v = case v of {}
    {-# INLINE gpmap #-}

-- A nullary constructor: nothing to map, the result is always 'U1'.
instance GPMap U1 where
    gpmap _ = const U1
    {-# INLINE gpmap #-}

-- Metadata wrapper: unwrap, recurse, rewrap.
instance GPMap f => GPMap (M1 i c f) where
    gpmap h m = M1 (gpmap h (unM1 m))
    {-# INLINE gpmap #-}

-- instance GPMap f => GPMap (f :.: g) where
--     gpmap f (Comp1 x) = Comp1 (gpmap f x)
--     {-# INLINE gpmap #-}


--------------------------------------------------
-- Basic instances
--------------------------------------------------


-- | A scalar parameter: every operation is the corresponding 'Double'
-- arithmetic, mapping applies the function directly, and the norm is the
-- absolute value.
instance ParamSet Double where
  pmap f = f
  zero _ = 0
  add x y = x + y
  sub x y = x - y
  mul x y = x * y
  div x y = x / y
  norm_2 = abs


-- | A pair of parameter sets, combined component by component.  The norm is
-- the Euclidean combination of the two component norms.
instance (ParamSet a, ParamSet b) => ParamSet (a, b) where
  pmap f (l, r) = (pmap f l, pmap f r)
  add (l1, r1) (l2, r2) = (add l1 l2, add r1 r2)
  sub (l1, r1) (l2, r2) = (sub l1 l2, sub r1 r2)
  mul (l1, r1) (l2, r2) = (mul l1 l2, mul r1 r2)
  div (l1, r1) (l2, r2) = (div l1 l2, div r1 r2)
  norm_2 (l, r) = sqrt (sum [sq (norm_2 l), sq (norm_2 r)])
    where
      sq v = v ^ (2 :: Int)


-- | Statically-sized hmatrix vectors.  Arithmetic is delegated to the 'Num'
-- and 'Fractional' instances of 'LA.R', mapping to 'LA.dvmap', and the norm
-- to 'LA.norm_2'.
instance KnownNat n => ParamSet (LA.R n) where
  pmap = LA.dvmap
  zero _ = 0
  add u v = u + v
  sub u v = u - v
  mul u v = u * v
  div u v = u / v
  norm_2 = LA.norm_2


-- | Statically-sized hmatrix matrices.  Arithmetic is delegated to the 'Num'
-- and 'Fractional' instances of 'LA.L', mapping to 'LA.dmmap', and the norm
-- to 'LA.norm_2'.
--
-- NOTE(review): for matrices, hmatrix's 'norm_2' may be the induced 2-norm
-- (largest singular value) rather than the Frobenius norm; if so, it does not
-- match the "structured vector" view of 'ParamSet' -- confirm against the
-- hmatrix documentation.
instance (KnownNat n, KnownNat m) => ParamSet (LA.L n m) where
  pmap = LA.dmmap
  zero _ = 0
  add p q = p + q
  sub p q = p - q
  mul p q = p * q
  div p q = p / q
  norm_2 = LA.norm_2


-- | `Nothing` represents a deactivated parameter set component. If `Nothing`
-- is given as an argument to one of the `ParamSet` operations, the result is
-- `Nothing` as well.
--
-- This differs from the corresponding instance in the backprop library, where
-- `Nothing` is equivalent to `Just 0`.  However, the implementation below
-- seems to correspond adequately enough to the notion that a particular
-- component is either active or not in both the parameter set and the
-- gradient, hence it doesn't make sense to combine `Just` with `Nothing`.
instance (ParamSet a) => ParamSet (Maybe a) where
  -- Unary operations act inside 'Just' and leave 'Nothing' alone.
  zero m = zero <$> m
  pmap f = fmap (pmap f)

  -- Binary operations yield 'Just' only when both operands are 'Just';
  -- the second operand is not inspected when the first is 'Nothing'.
  add mx my = add <$> mx <*> my
  sub mx my = sub <$> mx <*> my
  mul mx my = mul <$> mx <*> my
  div mx my = div <$> mx <*> my

  -- A deactivated component contributes nothing to the norm.
  norm_2 Nothing  = 0
  norm_2 (Just x) = norm_2 x


-- | A map with different parameter sets (of the same type) assigned to the
-- individual keys.
--
-- When combining two maps with different sets of keys, only their intersection
-- is preserved.
instance (Ord k, ParamSet a) => ParamSet (M.Map k a) where
  -- Unary operations act on every value (lazily, via 'fmap'); keys are kept.
  zero m = zero <$> m
  pmap f m = pmap f <$> m

  -- Binary operations keep only the keys present in both maps.
  add = M.intersectionWith add
  sub = M.intersectionWith sub
  mul = M.intersectionWith mul
  div = M.intersectionWith div

  -- Euclidean combination of the per-key norms, accumulated from 0 in
  -- ascending key order.
  norm_2 m = sqrt (M.foldl' step 0 m)
    where
      step acc v = acc + norm_2 v ^ (2 :: Int)