{-# LANGUAGE Rank2Types, TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.AD.Mode.Mixed
-- Copyright   :  (c) Edward Kmett 2010
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Mixed-Mode Automatic Differentiation.
--
-- Each combinator exported from this module chooses an appropriate AD mode.
-----------------------------------------------------------------------------

module Numeric.AD.Mode.Mixed
    (
    -- * Gradients (Reverse Mode)
      grad
    , grad'
    , gradWith
    , gradWith'

    -- * Jacobians (Mixed Mode)
    , jacobian
    , jacobian'
    , jacobianWith
    , jacobianWith'

    -- * Monadic Gradient/Jacobian (Reverse Mode)
    , gradM
    , gradM'
    , gradWithM
    , gradWithM'

    -- * Functorial Gradient/Jacobian (Reverse Mode)
    , gradF
    , gradF'
    , gradWithF
    , gradWithF'

    -- * Transposed Jacobians (Forward Mode)
    , jacobianT
    , jacobianWithT

    -- * Hessian (Forward-On-Reverse)
    , hessian

    -- * Hessian Tensors (Forward-On-Mixed)
    , hessianTensor

    -- * Hessian Vector Products (Forward-On-Reverse)
    , hessianProduct
    , hessianProduct'

    -- * Higher Order Gradients/Hessians (Sparse Forward)
    , gradients

    -- * Derivatives (Forward Mode)
    , diff
    , diffF

    , diff'
    , diffF'

    -- * Derivatives (Tower)
    , diffs
    , diffsF

    , diffs0
    , diffs0F

    -- * Directional Derivatives (Forward Mode)
    , du
    , du'
    , duF
    , duF'

    -- * Directional Derivatives (Tower)
    , dus
    , dus0
    , dusF
    , dus0F

    -- * Taylor Series (Tower)
    , taylor
    , taylor0

    -- * Maclaurin Series (Tower)
    , maclaurin
    , maclaurin0

    -- * Monadic Combinators (Forward Mode)
    , diffM
    , diffM'

    -- * Exposed Types
    , module Numeric.AD.Types
    , Mode(..)
    ) where

import Data.Traversable (Traversable)
import Data.Foldable (Foldable, foldr')
import Control.Applicative

import Numeric.AD.Types
import Numeric.AD.Internal.Identity (probed, unprobe)
import Numeric.AD.Internal.Composition
import Numeric.AD.Classes (Mode(..))

import qualified Numeric.AD.Mode.Forward as Forward
import Numeric.AD.Mode.Forward 
    ( diff, diff', diffF, diffF'
    , du, du', duF, duF'
    , diffM, diffM'
    , jacobianT, jacobianWithT
    ) 

import Numeric.AD.Mode.Tower 
    ( diffsF, diffs0F, diffs, diffs0
    , taylor, taylor0, maclaurin, maclaurin0
    , dus, dus0, dusF, dus0F
    )

import qualified Numeric.AD.Mode.Reverse as Reverse
import Numeric.AD.Mode.Reverse 
    ( grad, grad', gradWith, gradWith'
    , gradM, gradM', gradWithM, gradWithM'
    , gradF, gradF', gradWithF, gradWithF'
    )

-- temporary until we make a full sparse mode
import qualified Numeric.AD.Internal.Sparse as Sparse

-- | Calculate the Jacobian of a non-scalar-to-non-scalar function, automatically choosing between forward and reverse mode AD based on the number of inputs and outputs.
--
-- This is 'jacobian'' with the answers discarded, keeping only the Jacobian rows.
--
-- If you need to support functions where the output is only a 'Functor' or 'Monad', consider 'Numeric.AD.Mode.Reverse.jacobian' or 'Numeric.AD.Mode.Reverse.gradM' from "Numeric.AD.Mode.Reverse".
jacobian :: (Traversable f, Traversable g, Num a) => FF f g a -> f a -> g (f a)
jacobian f bs = fmap snd (jacobian' f bs)
{-# INLINE jacobian #-}

-- | Calculate both the answer and Jacobian of a non-scalar-to-non-scalar function, automatically choosing between forward and reverse mode AD based on the relative number of inputs and outputs.
--
-- Reverse mode is used when there are more inputs than outputs; otherwise forward mode is used.
-- When there are no inputs at all, only the answers are computed and each Jacobian row is the (empty) input container.
--
-- If you need to support functions where the output is only a 'Functor' or 'Monad', consider 'Numeric.AD.Mode.Reverse.jacobian'' or 'Numeric.AD.Mode.Reverse.gradM'' from "Numeric.AD.Mode.Reverse".
jacobian' :: (Traversable f, Traversable g, Num a) => FF f g a -> f a -> g (a, f a)
jacobian' f bs
    | inputs == 0      = fmap answerOnly probes
    | inputs > outputs = Reverse.jacobian' f bs
    | otherwise        = Forward.jacobian' f bs
    where
        -- Evaluate f once on probe values: this is how the number of outputs
        -- is discovered, and in the no-input case it also yields the answers.
        probes = f (probed bs)
        answerOnly p = (unprobe p, bs)
        inputs = countOf bs
        outputs = countOf probes
        countOf :: Foldable t => t x -> Int
        countOf = foldr' (const (1 +)) 0
{-# INLINE jacobian' #-}

-- | @'jacobianWith' g f@ calculates the Jacobian of a non-scalar-to-non-scalar function, automatically choosing between forward and reverse mode AD based on the number of inputs and outputs.
--
-- The resulting Jacobian matrix is then recombined element-wise with the input using @g@.
-- This is 'jacobianWith'' with the answers discarded.
--
-- If you need to support functions where the output is only a 'Functor' or 'Monad', consider 'Numeric.AD.Mode.Reverse.jacobianWith' or 'Numeric.AD.Mode.Reverse.gradWithM' from "Numeric.AD.Mode.Reverse".
jacobianWith :: (Traversable f, Traversable g, Num a) => (a -> a -> b) -> FF f g a -> f a -> g (f b)
jacobianWith g f bs = fmap snd (jacobianWith' g f bs)
{-# INLINE jacobianWith #-}

-- | @'jacobianWith'' g f@ calculates the answer and Jacobian of a non-scalar-to-non-scalar function, automatically choosing between forward and reverse mode AD based on the number of inputs and outputs.
--
-- The resulting Jacobian matrix is then recombined element-wise with the input using @g@.
--
-- Reverse mode is used when there are more inputs than outputs; otherwise forward mode is used.
-- When there are no inputs at all, only the answers are computed and every Jacobian row is empty.
--
-- If you need to support functions where the output is only a 'Functor' or 'Monad', consider 'Numeric.AD.Mode.Reverse.jacobianWith'' or 'Numeric.AD.Mode.Reverse.gradWithM'' from "Numeric.AD.Mode.Reverse".
jacobianWith' :: (Traversable f, Traversable g, Num a) => (a -> a -> b) -> FF f g a -> f a -> g (a, f b)
jacobianWith' g f bs
    | n == 0    = fmap (\x -> (unprobe x, emptyRow)) as
    | n > m     = Reverse.jacobianWith' g f bs
    | otherwise = Forward.jacobianWith' g f bs
    where
        -- Evaluate f once on probe values: this is how the number of outputs
        -- is discovered, and in the no-input case it also yields the answers.
        as = f (probed bs)
        -- Only used when 'bs' is empty, so this row has no elements. Pairing
        -- each input with a zero partial keeps the definition total instead of
        -- relying on 'undefined' never being forced.
        emptyRow = fmap (\i -> g i 0) bs
        n = size bs
        m = size as
        size :: Foldable t => t x -> Int
        size = foldr' (\_ k -> 1 + k) 0
{-# INLINE jacobianWith' #-}

-- | @'hessianProduct' f wv@ computes the product of the Hessian @H@ of a non-scalar-to-scalar function @f@ at @w = 'fst' <$> wv@ with a vector @v = 'snd' <$> wv@ using \"Pearlmutter\'s method\" from <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.29.6143>, which states:
--
-- > H v = (d/dr) grad_w (w + r v) | r = 0
--
-- In other words, we take the forward-mode directional derivative ('duF') of the reverse-mode gradient ('grad').
hessianProduct :: (Traversable f, Num a) => FU f a -> f (a, a) -> f a
hessianProduct f wv = duF (grad (\ws -> decomposeMode (f (fmap composeMode ws)))) wv

-- | @'hessianProduct'' f wv@ computes both the gradient of a non-scalar-to-scalar @f@ at @w = 'fst' <$> wv@ and the product of the Hessian @H@ at @w@ with a vector @v = 'snd' <$> wv@ using \"Pearlmutter\'s method\".
-- The two results are returned pairwise, wrapped in the same functor as the input.
--
-- > H v = (d/dr) grad_w (w + r v) | r = 0
--
-- In other words, we take the forward-mode directional derivative ('duF'') of the reverse-mode gradient ('grad').
hessianProduct' :: (Traversable f, Num a) => FU f a -> f (a, a) -> f (a, a)
hessianProduct' f wv = duF' (grad (\ws -> decomposeMode (f (fmap composeMode ws)))) wv

-- hessianProductWith' :: (Traversable f, Num a) => (a -> a -> a -> a -> b) -> (forall s. Mode s => f (AD s a) -> AD s a) -> f (a, a) -> f b

-- | Compute the Hessian of a non-scalar-to-scalar function as the Jacobian of its gradient.
-- The gradient is computed in reverse mode, and the Jacobian of that gradient is computed in forward mode.
hessian :: (Traversable f, Num a) => FU f a -> f a -> f (f a)
hessian f xs = Forward.jacobian (grad (\ws -> decomposeMode (f (fmap composeMode ws)))) xs

-- | Compute the order 3 Hessian tensor of a non-scalar-to-non-scalar function.
--
-- The inner Jacobian is taken with this module's mixed-mode 'jacobian'; the outer one with forward mode.
-- 'ComposeFunctor' / 'decomposeFunctor' temporarily fuse the two outer layers so the inner Jacobian can be treated as a single container of outputs.
hessianTensor :: (Traversable f, Traversable g, Num a) => FF f g a -> f a -> g (f (f a))
hessianTensor f xs =
    decomposeFunctor
        (Forward.jacobian
            (\ys -> ComposeFunctor (jacobian (\zs -> decomposeMode <$> f (composeMode <$> zs)) ys))
            xs)

-- data f :> a = a :< f (f :> a)
-- data f :- a = a :- (f :- f a) | Zero
{-
flatten :: (f :> a) -> (f :- a)
grads :: (Traversable f, Num a) => (forall s. Mode s => f (AD s a) -> AD s a) -> f a -> (f :- a) 
grads f b = a :- da :- d2a :- Zero
    (a, da) = grad2 f a
    dda = Forward.jacobian (grad (decomposeMode . f . fmap composeMode)
    ddda = Forward
-}

-- | Compute the higher-order gradients/Hessians of a non-scalar-to-scalar function at the point @xs@ using sparse forward mode, returned as a 'Stream'.
--
-- The function is evaluated once on the sparse variables for @xs@, and the resulting derivatives are read off with 'Sparse.ds'.
-- NOTE(review): the exact layout of the stream is defined by 'Sparse.ds'; confirm there before relying on it.
gradients :: (Traversable f, Num a) => FU f a -> f a -> Stream f a
gradients f xs = Sparse.ds xs (f (Sparse.vars xs))