{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase                 #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE PolyKinds                  #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TemplateHaskell            #-}
{-# LANGUAGE TypeApplications           #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE TypeInType                 #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE UndecidableInstances       #-}

-- |
-- Module      : Numeric.Backprop.Internal
-- Copyright   : (c) Justin Le 2017
-- License     : BSD3
--
-- Maintainer  : justin@jle.im
-- Stability   : experimental
-- Portability : non-portable
--
-- Provides the types and instances used for the graph
-- building/back-propagation for the library.

module Numeric.Backprop.Internal
  ( Summer(..), summers, summers'
  , Unity(..), unities, unities'
  , OpB
  , BPState(..), bpsSources
  , BP(..)
  , BPInpRef(..)
  , BPNode(..), bpnOut, bpnRes, bpnGradFunc, bpnGradCache, bpnSummer
  , BPPipe(..), bppOut, bppRes, bppGradFunc, bppGradCache
  , BVar(..)
  , ForwardRefs(..), _FRInternal
  ) where

import           Control.Monad.Reader
import           Control.Monad.ST
import           Control.Monad.State
import           Data.Kind
import           Data.STRef
import           Data.Semigroup                   (Semigroup(..))
import           Data.Type.Index
import           Data.Type.Product
import           Lens.Micro hiding                (ix)
import           Lens.Micro.TH
import           Numeric.Backprop.Internal.Helper
import           Numeric.Backprop.Op

-- | A subclass of 'OpM' (and superclass of 'Op'), representing 'Op's that
-- the /backprop/ library uses to perform backpropagation.  Concretely, it
-- is 'OpM' specialized to the @'ST' s@ monad.
--
-- An
--
-- @
-- 'OpB' s as a
-- @
--
-- represents a differentiable function that takes a tuple of @as@ and
-- produces an @a@, which can be run on @'BVar' s@s and also inside @'BP'
-- s@s.  For example, an @'OpB' s '[ Int, Double ] Bool@ takes an 'Int' and
-- a 'Double' and produces a 'Bool', and does it in a differentiable way.
--
-- 'OpB' is a /superset/ of 'Op', so, if you see any function
-- that expects an 'OpB' (like 'Numeric.Backprop.opVar'' and
-- 'Numeric.Backprop.~$', for example), you can give them an 'Op', as well.
--
-- You can think of 'OpB' as a superclass/parent class of 'Op' in this
-- sense, and of 'Op' as a subclass of 'OpB'.
type OpB s as a = OpM (ST s) as a

-- | Records the /usage sites/ of an entity; these are consulted to obtain
-- its partial or total derivative.
data ForwardRefs s rs a
    = FRInternal ![BPInpRef s rs a]
      -- ^ Every place that uses the entity, each of which supplies a
      -- partial derivative.
    | FRTerminal !(Maybe a)
      -- ^ The entity is the terminal result of a 'BP', so its total
      -- derivative is fixed.

-- | Combines two 'ForwardRefs':
--
-- * Two 'FRInternal's combine by concatenating their reference lists
--   (left list first).
-- * If either input is an 'FRTerminal', the other input is thrown away and
--   the terminal (with its fixed total derivative) is kept.
-- * If both inputs are 'FRTerminal's, the left one is kept (left bias).
--
-- With these rules the operation is associative and @'FRInternal' []@ is a
-- two-sided identity.
instance Semigroup (ForwardRefs s rs a) where
    (<>) = \case
        FRInternal rs -> \case
          FRInternal rs'   -> FRInternal (rs ++ rs')
          t@(FRTerminal _) -> t
        -- A terminal on the left absorbs whatever is on the right.
        -- Returning the right operand here would drop the terminal, break
        -- the left bias, and make 'mempty' fail to be a right identity.
        t@(FRTerminal _) -> const t

-- | 'mempty' is an empty list of internal references; 'mappend' is the
-- 'Semigroup' operation.
instance Monoid (ForwardRefs s rs a) where
    mempty  = FRInternal []
    mappend = (<>)

-- | State carried through a 'BP' action: for every input in @rs@, the
-- 'ForwardRefs' recording which nodes (if any) refer to that input.
data BPState :: Type -> [Type] -> Type where
    BPS :: { _bpsSources :: !(Prod (ForwardRefs s rs) rs) } -> BPState s rs

-- | A Monad allowing you to explicitly build heterogeneous data
-- dependency graphs that the library can perform back-propagation on.
--
-- A @'BP' s rs a@ is a 'BP' action that uses an environment of @rs@ and
-- returns an @a@.  When "run", it will compute a gradient that is a tuple
-- of @rs@.  (The @s@ parameter is the 'ST' state thread; it is used to
-- ensure that any 'BVar's aren't leaked out of the monad.)
--
-- Note that you can only "run" a @'BP' s rs@ that produces a 'BVar' --
-- that is, things of the form
--
-- @
-- 'BP' s rs ('BVar' s rs a)
-- @
--
-- The above is a 'BP' action that returns a 'BVar' containing an @a@.
-- When this is run, it'll produce a result of type @a@ and a gradient
-- that is a tuple of @rs@.  (This form has a type synonym,
-- 'Numeric.Backprop.BPOp', for convenience.)
--
-- For example, a @'BP' s '[ Int, Double, Double ]@ is a monad that
-- represents a computation with an 'Int', 'Double', and 'Double' as
-- inputs.  And, if you ran a
--
-- @
-- 'BP' s '[ Int, Double, Double ] ('BVar' s '[ Int, Double, Double ] Double)
-- @
--
-- Or, using the 'Numeric.Backprop.BPOp' type synonym:
--
-- @
-- 'Numeric.Backprop.BPOp' s '[ Int, Double, Double ] Double
-- @
--
-- with 'Numeric.Backprop.backprop' or 'Numeric.Backprop.gradBPOp', it'll
-- return a gradient on the inputs ('Int', 'Double', and 'Double') and
-- produce a value of type 'Double'.
--
-- Now, one powerful thing about this type is that a 'BP' is itself an
-- 'Op' (or more precisely, an 'OpB', which is 'OpM' specialized to
-- @'ST' s@).  So, once you create your fancy 'BP' computation, you can
-- transform it into an 'OpM' using 'Numeric.Backprop.bpOp'.
newtype BP s rs a = BP { bpST :: ReaderT (Tuple rs) (StateT (BPState s rs) (ST s)) a }
      deriving ( Functor
               , Applicative
               , Monad
               )

-- | The basic unit of manipulation inside 'BP' (or inside an
-- implicit-graph backprop function).  Instead of directly working with
-- values, you work with 'BVar's containing those values.  When you work
-- with a 'BVar', the /backprop/ library can keep track of what values
-- refer to which other values, and so can perform back-propagation to
-- compute gradients.
--
-- A @'BVar' s rs a@ refers to a value of type @a@, with an environment
-- of values of the types @rs@.  The @s@ parameter (the 'ST' state thread
-- of the references it may hold) is used to ensure that stray 'BVar's
-- don't leak outside of the backprop process.
--
-- (That is, if you're using implicit backprop, it ensures that you interact
-- with 'BVar's in a polymorphic way.  And, if you're using explicit
-- backprop, it ensures that a @'BVar' s rs a@ never leaves the @'BP' s rs@
-- that it was created in.)
--
-- 'BVar's have 'Num', 'Fractional', 'Floating', etc. instances, so they
-- can be manipulated using polymorphic functions and numeric functions in
-- Haskell.  You can add them, subtract them, etc., in "implicit" backprop
-- style.
--
-- (However, note that if you directly manipulate 'BVar's using those
-- instances or using 'Numeric.Backprop.liftB', it delays evaluation, so
-- every usage site has to re-compute the result/create a new node.  If
-- you want to re-use a 'BVar' you created using '+' or '-' or
-- 'Numeric.Backprop.liftB', use 'Numeric.Backprop.bindVar' to force it
-- first.  See documentation for 'Numeric.Backprop.bindVar' for more
-- details.)
data BVar :: Type -> [Type] -> Type -> Type where
    -- | A BVar referring to one of the outputs (selected by the 'Index')
    -- of a 'BPNode'
    BVNode  :: !(Index bs a)
            -> !(STRef s (BPNode s rs as bs))
            -> BVar s rs a
    -- | A BVar referring to an environment input variable (selected by the
    -- 'Index' into @rs@)
    BVInp   :: !(Index rs a)
            -> BVar s rs a
    -- | A constant BVar that refers to a specific Haskell value
    BVConst :: !a
            -> BVar s rs a
    -- | A BVar that combines several other BVars using a function (an
    -- 'Op').  Essentially a branch of a tree.
    BVOp    :: !(Prod (BVar s rs) as)
            -> !(OpB s as a)
            -> BVar s rs a

-- | A pointer stored inside 'ForwardRefs' that says /where/ and /how/ to
-- find the partial derivative contributed by one usage site of an entity.
data BPInpRef :: Type -> [Type] -> Type -> Type where
    -- | Consumed by a 'BPNode', at the input position given by the 'Index'.
    IRNode  :: !(Index bs a) -> !(STRef s (BPNode s rs bs cs)) -> BPInpRef s rs a
    -- | Consumed by a 'BPPipe', at the input position given by the 'Index'.
    IRPipe  :: !(Index bs a) -> !(STRef s (BPPipe s rs bs cs)) -> BPInpRef s rs a
    -- | Used somehow in the terminal result of a 'BP', and so contributes
    -- a fixed partial derivative.
    IRConst :: !a -> BPInpRef s rs a

-- | A stateful node in the graph of operations/data dependencies built by
-- 'BP', taking inputs @as@ and producing outputs @bs@.  'BVar's point at a
-- node to read its results; 'BPInpRef's point at a node to read partial
-- derivatives from it.
data BPNode :: Type -> [Type] -> [Type] -> [Type] -> Type where
    BPN :: { _bpnOut       :: !(Prod (ForwardRefs s rs) bs)
             -- ^ the usage sites of each output
           , _bpnRes       :: !(Tuple bs)
             -- ^ the computed output values
           , _bpnGradFunc  :: !(Prod Maybe bs -> ST s (Tuple as))
             -- ^ turns (possibly missing) output gradients into input
             -- gradients
           , _bpnGradCache :: !(Maybe (Tuple as))
             -- ^ 'Nothing' if this node is the "final output"
           , _bpnSummer    :: !(Prod Summer bs)
             -- ^ one 'Summer' per output (presumably used to add up
             -- partial derivatives -- confirm against usage)
           }
        -> BPNode s rs as bs

-- | A "single-usage" variant of 'BPNode': a stateful node whose every
-- output has exactly one consumer, so each output's total derivative is
-- just the one partial derivative it receives.  Used to keep track of
-- 'BVOp's.
data BPPipe :: Type -> [Type] -> [Type] -> [Type] -> Type where
    BPP :: { _bppOut       :: !(Prod (BPInpRef s rs) bs)
             -- ^ the single usage site of each output
           , _bppRes       :: !(Tuple bs)
             -- ^ the computed output values
           , _bppGradFunc  :: !(Tuple bs -> ST s (Tuple as))
             -- ^ turns output gradients into input gradients
           , _bppGradCache :: !(Maybe (Tuple as))
             -- ^ presumably caches the input gradients once computed
           }
        -> BPPipe s rs as bs

-- Generate the record lenses (the exported 'bpsSources', 'bpnOut',
-- 'bppRes', ...) for all three stateful types in a single splice.
$(concat <$> mapM makeLenses [''BPState, ''BPNode, ''BPPipe])

-- | Traversal (fake prism) focusing on the list of internal references of
-- a 'ForwardRefs' that is not associated with a terminal entity.  An
-- 'FRTerminal' is rebuilt unchanged and the function is never called.
_FRInternal
    :: Traversal (ForwardRefs s as a) (ForwardRefs t bs a)
                 [BPInpRef s as a]    [BPInpRef t bs a]
_FRInternal f (FRInternal refs) = fmap FRInternal (f refs)
_FRInternal _ (FRTerminal d)    = pure (FRTerminal d)




-- | Arithmetic on 'BVar's.  Every operation builds a deferred 'BVOp' (and
-- integer literals become 'BVConst's), so a 'BVar' produced this way is
-- recomputed, with a separate graph node, at each location it is used.
-- If you use such a 'BVar' in several places, force it once with
-- 'Numeric.Backprop.bindVar' first to prevent the recomputation.
instance Num a => Num (BVar s rs a) where
    (+) x y     = BVOp (x :< y :< Ø) (op2 (+))
    (-) x y     = BVOp (x :< y :< Ø) (op2 (-))
    (*) x y     = BVOp (x :< y :< Ø) (op2 (*))
    negate x    = BVOp (x :< Ø) (op1 negate)
    signum x    = BVOp (x :< Ø) (op1 signum)
    abs x       = BVOp (x :< Ø) (op1 abs)
    fromInteger = BVConst . fromInteger

-- | Deferred exactly like the 'Num' instance: each use recomputes the
-- value unless it is forced first with 'Numeric.Backprop.bindVar'.
instance Fractional a => Fractional (BVar s rs a) where
    (/) x y      = BVOp (x :< y :< Ø) (op2 (/))
    recip x      = BVOp (x :< Ø) (op1 recip)
    fromRational = BVConst . fromRational

-- | Deferred exactly like the 'Num' instance: each use recomputes the
-- value (creating a new graph node) unless it is forced first with
-- 'Numeric.Backprop.bindVar'.
instance Floating a => Floating (BVar s rs a) where
    pi          = BVConst pi
    -- single-argument functions become one-input 'BVOp's
    exp   x     = BVOp (x :< Ø) (op1 exp)
    log   x     = BVOp (x :< Ø) (op1 log)
    sqrt  x     = BVOp (x :< Ø) (op1 sqrt)
    sin   x     = BVOp (x :< Ø) (op1 sin)
    cos   x     = BVOp (x :< Ø) (op1 cos)
    tan   x     = BVOp (x :< Ø) (op1 tan)
    asin  x     = BVOp (x :< Ø) (op1 asin)
    acos  x     = BVOp (x :< Ø) (op1 acos)
    atan  x     = BVOp (x :< Ø) (op1 atan)
    sinh  x     = BVOp (x :< Ø) (op1 sinh)
    cosh  x     = BVOp (x :< Ø) (op1 cosh)
    tanh  x     = BVOp (x :< Ø) (op1 tanh)
    asinh x     = BVOp (x :< Ø) (op1 asinh)
    acosh x     = BVOp (x :< Ø) (op1 acosh)
    atanh x     = BVOp (x :< Ø) (op1 atanh)
    -- two-argument functions become two-input 'BVOp's (argument order kept)
    (**)    x y = BVOp (x :< y :< Ø) (op2 (**))
    logBase x y = BVOp (x :< y :< Ø) (op2 logBase)