{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE PatternSynonyms     #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}

-- |
-- Module      : Numeric.Backprop.Implicit
-- Copyright   : (c) Justin Le 2017
-- License     : BSD3
--
-- Maintainer  : justin@jle.im
-- Stability   : experimental
-- Portability : non-portable
--
-- Offers full functionality for implicit-graph back-propagation.  The
-- intended usage is to write a 'BPOp', which is a normal Haskell
-- function from 'BVar's to a result 'BVar'. These 'BVar's can be
-- manipulated using their 'Num' \/ 'Fractional' \/ 'Floating' instances.
--
-- The library can then perform back-propagation on the function (using
-- 'backprop' or 'grad') by using an implicitly built graph.
--
-- This should actually be powerful enough for most use cases, but falls
-- short for a couple of situations:
--
-- 1. If the result of a function on 'BVar's is used twice
-- (like @z@ in @let z = x * y in z + z@), this will allocate a new
-- redundant graph node for every usage site of @z@.  You can explicitly
-- /force/ @z@, but only using an explicit graph description using
-- "Numeric.Backprop".
--
-- 2. This can't handle sum types, like "Numeric.Backprop" can.  You can
-- never pattern match on the constructors of a value inside a 'BVar'.  I'm
-- not sure if this is a fundamental limitation (I suspect it might be) or
-- if I just can't figure out how to implement it.  Suggestions welcome!
--
-- As a comparison, this module offers functionality and an API very
-- similar to "Numeric.AD.Mode.Reverse" from the /ad/ library, except for
-- the fact that it can handle /heterogeneous/ values.
--
-- Note that every type involved has to be an instance of 'Num'.  This is
-- because gradients all need to be "summable" (which is implemented using
-- 'sum' and '+'), and we also need to be able to generate gradients of '1'
-- and '0'.


module Numeric.Backprop.Implicit (
  -- * Types
  -- ** Backprop types
    BPOp, BVar, Op, OpB
  -- ** Tuple types
  -- | See "Numeric.Backprop#prod" for a mini-tutorial on 'Prod' and
  -- 'Tuple'
  , Prod(..), Tuple, I(..)
  -- * back-propagation
  , backprop, grad, eval
  -- * Var manipulation
  , BP.constVar, BP.liftB, (BP..$), BP.liftB1, BP.liftB2, BP.liftB3
  -- ** As Parts
  , partsVar, withParts
  , splitVars, gSplit, gTuple
  , partsVar', withParts'
  , splitVars', gSplit'
  -- * Op
  , BP.op1, BP.op2, BP.op3, BP.opN
  , BP.op1', BP.op2', BP.op3'
  -- * Utility
  , pattern (:>), only, head'
  , pattern (::<), only_
  -- ** Numeric Ops
  -- | Optimized ops for numeric functions.  See
  -- "Numeric.Backprop.Op#numops" for more information.
  , (+.), (-.), (*.), negateOp, absOp, signumOp
  , (/.), recipOp
  , expOp, logOp, sqrtOp, (**.), logBaseOp
  , sinOp, cosOp, tanOp, asinOp, acosOp, atanOp
  , sinhOp, coshOp, tanhOp, asinhOp, acoshOp, atanhOp
  ) where

import           Data.Type.Combinator
import           Data.Type.Index
import           Data.Type.Length
import           Data.Type.Product
import           Data.Type.Util
import           Lens.Micro hiding         (ix)
import           Lens.Micro.Extras
import           Numeric.Backprop.Internal
import           Numeric.Backprop.Iso
import           Numeric.Backprop.Op
import           Type.Class.Higher
import           Type.Class.Known
import           Type.Class.Witness
import qualified Generics.SOP              as SOP
import qualified Numeric.Backprop          as BP

-- | An operation on 'BVar's that can be backpropagated. A value of type:
--
-- @
-- 'BPOp' rs a
-- @
--
-- takes a bunch of 'BVar's containing @rs@ and uses them to (purely) produce
-- a 'BVar' containing an @a@.
--
-- @
-- foo :: 'BPOp' '[ Double, Double ] Double
-- foo (x ':<' y ':<' 'Ø') = x + sqrt y
-- @
--
-- The alias is universally quantified over the @s@ parameter shared by
-- all of the 'BVar's involved, so a 'BPOp' has to work for every choice
-- of @s@.
--
-- 'BPOp' here is related to 'Numeric.Backprop.BPOpI' from the normal
-- explicit-graph backprop module "Numeric.Backprop".
type BPOp rs a = forall s. Prod (BVar s rs) rs -> BVar s rs a

-- | Back-propagate through a 'BPOp', returning the value it computes
-- paired with the gradient of that value with respect to each input.
--
-- @
-- foo :: 'BPOp' '[Double, Double] Double
-- foo (x :< y :< Ø) =
--   let z = x * sqrt y
--   in  z + x ** y
-- @
--
-- >>> backprop foo (2 ::< 3 ::< Ø)
-- (11.46, 13.73 ::< 6.12 ::< Ø)
--
-- The 'Length' of the inputs is read off of the input tuple itself, so no
-- 'Known' 'Length' constraint is needed here.
backprop
    :: Every Num rs
    => BPOp rs a
    -> Tuple rs
    -> (a, Tuple rs)
backprop f inps =
    BP.backprop (BP.withInps' inpLen (\vars -> return (f vars))) inps
  where
    -- structure of the input list, recovered from the inputs themselves
    inpLen = prodLength inps

-- | Compute only the gradient of a 'BPOp' at the given input tuple; this
-- is the second component of what 'backprop' returns.
--
-- @
-- foo :: 'BPOp' '[Double, Double] Double
-- foo (x :< y :< Ø) =
--   let z = x * sqrt y
--   in  z + x ** y
-- @
--
-- >>> grad foo (2 ::< 3 ::< Ø)
-- 13.73 ::< 6.12 ::< Ø
grad
    :: Every Num rs
    => BPOp rs a
    -> Tuple rs
    -> Tuple rs
grad f inps = snd (backprop f inps)

-- | Evaluate a 'BPOp' at the given input tuple and return just its
-- result, with no gradient produced.
--
-- @
-- foo :: 'BPOp' '[Double, Double] Double
-- foo (x :< y :< Ø) =
--   let z = x * sqrt y
--   in  z + x ** y
-- @
--
-- >>> eval foo (2 ::< 3 ::< Ø)
-- 11.46
eval
    :: (Known Length rs, Num a)
    => BPOp rs a
    -> Tuple rs
    -> a
eval f inps = BP.evalBPOp (BP.implicitly f) inps

-- | 'partsVar', but with the 'Length' of the tuple of parts passed in
-- explicitly instead of being found through a 'Known' instance.
--
-- This is mainly needed in rare "extremely polymorphic" code where GHC
-- cannot infer the type and length of the internal tuple.  Whenever @bs@
-- is written out as a concrete list of types, plain 'partsVar' should be
-- enough.
--
-- Every part is produced by its own lifted op (one per index of @bs@).
-- The gradient of such an op is built by taking a tuple of zeroes, placing
-- the incoming gradient (or @1@ when none is supplied) in the slot for
-- that part, and mapping the tuple back through the iso.
partsVar'
    :: forall s rs bs a. Every Num bs
    => Length bs
    -> Iso' a (Tuple bs)
    -> BVar s rs a
    -> Prod (BVar s rs) bs
partsVar' len tupIso whole =
    map1 (\idx -> every @_ @Num idx //
                    BP.liftB1 (BP.op1' (project idx)) whole
         ) slots
  where
    -- One index for every position in @bs@.
    slots :: Prod (Index bs) bs
    slots = indices' len
    -- A tuple holding @0@ in every position.
    blank :: Tuple bs
    blank = map1 (\idx -> I 0 \\ every @_ @Num idx) slots
    -- Forward value and gradient function for the part at one index.
    project
        :: Num b
        => Index bs b
        -> a
        -> (b, Maybe b -> a)
    project idx x =
        ( getI (index idx (view tupIso x))
        , \mGrad -> review tupIso
                      (set (indexP idx) (maybe (I 1) I mGrad) blank)
        )
-- | Use an 'Iso' (or a compatible 'Control.Lens.Iso.Iso' from the lens
-- library) to "pull out" the parts of a data type, giving a separate
-- 'BVar' for each part.
--
-- If there is an isomorphism between an @a@ and a @'Tuple' bs@ (that is,
-- if an @a@ is just a container for a bunch of @bs@), then this lets you
-- break out the @bs@ inside and work with those.
--
-- @
-- data Foo = F Int Bool
--
-- fooIso :: 'Iso'' Foo (Tuple '[Int, Bool])
-- fooIso = 'iso' (\\(F i b)         -\> i ::\< b ::\< Ø)
--              (\\(i ::\< b ::\< Ø) -\> F i b        )
--
-- 'partsVar' fooIso :: 'BVar' s rs Foo -> 'Prod' ('BVar' s rs) '[Int, Bool]
--
-- stuff :: 'BPOp' '[Foo] a
-- stuff (foo :< Ø) =
--     case 'partsVar' fooIso foo of
--       i :< b :< Ø ->
--         -- now, i is a 'BVar' pointing to the 'Int' inside foo
--         -- and b is a 'BVar' pointing to the 'Bool' inside foo
--         -- you can do stuff with the i and b here
-- @
--
-- This makes it possible to pass a product type in as an input of a
-- 'BPOp' and then break it up into its constituent fields.
--
-- For a type like @Foo@, @fooIso@ can be generated automatically with
-- 'GHC.Generics.Generic' from "GHC.Generics" and 'Generics.SOP.Generic'
-- from "Generics.SOP" and /generics-sop/, using the 'gTuple' iso.  See
-- 'gSplit' for more information.
--
-- If the input is literally a tuple (as in
-- @'BPOp' '[ Tuple '[Int, Bool] ] a@), pass the identity isomorphism
-- ('id') or just use 'splitVars'.
--
-- At the moment, this implicit 'partsVar' is less efficient than the
-- explicit 'Numeric.Backprop.partsVar', but this might change in the
-- future.
partsVar
    :: forall s rs bs a. (Every Num bs, Known Length bs)
    => Iso' a (Tuple bs)
    -> BVar s rs a
    -> Prod (BVar s rs) bs
partsVar tupIso whole = partsVar' known tupIso whole

-- | 'withParts', but with the 'Length' of the tuple of parts passed in
-- explicitly instead of being found through a 'Known' instance.
--
-- This is mainly needed in rare "extremely polymorphic" code where GHC
-- cannot infer the type and length of the internal tuple.  Whenever @bs@
-- is written out as a concrete list of types, plain 'withParts' should be
-- enough.
withParts'
    :: forall s rs bs a r. Every Num bs
    => Length bs
    -> Iso' a (Tuple bs)
    -> BVar s rs a
    -> (Prod (BVar s rs) bs -> r)
    -> r
withParts' len tupIso whole cont = cont $ partsVar' len tupIso whole

-- | A continuation-passing variant of 'partsVar'.  Rather than binding
-- the parts and using them in the rest of the block, you supply a
-- continuation that receives the parts and does something with them.
--
-- Building on the example from 'partsVar':
--
-- @
-- data Foo = F Int Bool
--
-- fooIso :: 'Iso'' Foo (Tuple '[Int, Bool])
-- fooIso = 'iso' (\\(F i b)         -\> i ::\< b ::\< Ø)
--              (\\(i ::\< b ::\< Ø) -\> F i b        )
--
-- stuff :: 'BPOp' '[Foo] a
-- stuff (foo :< Ø) = 'withParts' fooIso foo $ \\case
--     i :\< b :\< Ø -\>
--       -- now, i is a 'BVar' pointing to the 'Int' inside foo
--       -- and b is a 'BVar' pointing to the 'Bool' inside foo
--       -- you can do stuff with the i and b here
-- @
--
-- Mostly just a stylistic alternative to 'partsVar'.
withParts
    :: forall s rs bs a r. (Every Num bs, Known Length bs)
    => Iso' a (Tuple bs)
    -> BVar s rs a
    -> (Prod (BVar s rs) bs -> r)
    -> r
withParts tupIso whole cont = withParts' known tupIso whole cont

-- | 'splitVars', but with the 'Length' of the tuple passed in explicitly
-- instead of being found through a 'Known' instance.
--
-- This is mainly needed in rare "extremely polymorphic" code where GHC
-- cannot infer the type and length of the internal tuple.  Whenever @as@
-- is written out as a concrete list of types, plain 'splitVars' should be
-- enough.
splitVars'
    :: forall s rs as. Every Num as
    => Length as
    -> BVar s rs (Tuple as)
    -> Prod (BVar s rs) as
splitVars' len tup = partsVar' len id tup

-- | Turn a 'BVar' holding a tuple into a tuple ('Prod') of 'BVar's, one
-- for each component.
--
-- @
-- -- the environment is a single Int-Bool tuple, tup
-- stuff :: 'BPOp' '[ Tuple '[Int, Bool] ] a
-- stuff (tup :< Ø) =
--   case 'splitVars' tup of
--     i :< b :< Ø ->
--       -- now, i is a 'BVar' pointing to the 'Int' inside tup
--       -- and b is a 'BVar' pointing to the 'Bool' inside tup
--       -- you can do stuff with the i and b here
-- @
--
-- Note that
--
-- @
-- 'splitVars' = 'partsVar' 'id'
-- @
splitVars
    :: forall s rs as. (Every Num as, Known Length as)
    => BVar s rs (Tuple as)
    -> Prod (BVar s rs) as
splitVars tup = splitVars' known tup

-- | 'gSplit', but with the 'Length' of the list of fields passed in
-- explicitly instead of being found through a 'Known' instance.
--
-- This is mainly needed in rare "extremely polymorphic" code where GHC
-- cannot infer the type and length of the internal tuple.  Whenever @as@
-- is written out as a concrete list of types, plain 'gSplit' should be
-- enough.
gSplit'
    :: forall s rs as a. (SOP.Generic a, SOP.Code a ~ '[as], Every Num as)
    => Length as
    -> BVar s rs a
    -> Prod (BVar s rs) as
gSplit' len whole = partsVar' len gTuple whole

-- | /Split/ a 'BVar' containing a product type into a tuple ('Prod') of
-- 'BVar's, one pointing to each value inside, using
-- 'GHC.Generics.Generic' from "GHC.Generics" and 'Generics.SOP.Generic'
-- from "Generics.SOP".
--
-- Building on the example from 'partsVar':
--
-- @
-- import qualified Generics.SOP as SOP
--
-- data Foo = F Int Bool
--   deriving Generic
--
-- instance SOP.Generic Foo
--
-- 'gSplit' :: 'BVar' s rs Foo -> 'Prod' ('BVar' s rs) '[Int, Bool]
--
-- stuff :: 'BPOp' '[Foo] a
-- stuff (foo :< Ø) =
--     case 'gSplit' foo of
--       i :< b :< Ø ->
--         -- now, i is a 'BVar' pointing to the 'Int' inside foo
--         -- and b is a 'BVar' pointing to the 'Bool' inside foo
--         -- you can do stuff with the i and b here
-- @
--
-- Because @Foo@ is a straight up product type (a single constructor, as
-- required by the @'SOP.Code' a ~ '[as]@ constraint), 'gSplit' can use
-- "GHC.Generics" and take out the items inside.
--
-- Note that
--
-- @
-- 'gSplit' = 'partsVar' 'gTuple'
-- @
gSplit
    :: forall s rs as a. (SOP.Generic a, SOP.Code a ~ '[as], Every Num as, Known Length as)
    => BVar s rs a
    -> Prod (BVar s rs) as
gSplit whole = gSplit' known whole

-- TODO: figure out how to split sums
-- TODO: refactor these out to not need Known Length