{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- Module      : Prelude.Backprop.Explicit
-- Copyright   : (c) Justin Le 2018
-- License     : BSD3
--
-- Maintainer  : justin@jle.im
-- Stability   : experimental
-- Portability : non-portable
--
-- Provides "explicit" versions of all of the functions in
-- "Prelude.Backprop".  Instead of relying on a 'Backprop' instance, allows
-- you to manually provide 'zero', 'add', and 'one' on a per-value basis.
--
-- @since 0.2.0.0

module Prelude.Backprop.Explicit (
  -- * Foldable and Traversable
    sum
  , product
  , length
  , minimum
  , maximum
  , traverse
  -- * Functor and Applicative
  , fmap
  , pure
  , liftA2
  , liftA3
  -- * Misc
  , coerce
  ) where

import           Numeric.Backprop.Explicit
import           Prelude             (Num(..), Fractional(..), Eq(..), Ord(..), Functor, Foldable, Traversable, Applicative, (.), ($))
import qualified Control.Applicative as P
import qualified Data.Coerce         as C
import qualified Data.Foldable       as P
import qualified Prelude             as P

-- | Lifted 'P.sum'
sum :: forall t a s. (Foldable t, Functor t, Num a, Reifies s W)
    => AddFunc (t a)
    -> ZeroFunc a
    -> BVar s (t a)
    -> BVar s a
sum af zf = liftOp1 af zf . op1 $ \xs ->
    ( P.sum xs
    , (P.<$ xs)
    )
{-# INLINE sum #-}

-- | Lifted 'P.pure'.
pure
    :: forall t a s. (Foldable t, Applicative t, Reifies s W)
    => AddFunc a
    -> ZeroFunc a
    -> ZeroFunc (t a)
    -> BVar s a
    -> BVar s (t a)
pure af zfa zf = liftOp1 af zf . op1 $ \x ->
    ( P.pure x
    , P.foldl' (runAF af) (runZF zfa x)
    )
{-# INLINE pure #-}

-- | Lifted 'P.product'
product
    :: forall t a s. (Foldable t, Functor t, Fractional a, Reifies s W)
    => AddFunc (t a)
    -> ZeroFunc a
    -> BVar s (t a)
    -> BVar s a
product af zf = liftOp1 af zf . op1 $ \xs ->
    let p = P.product xs
    in ( p
       , \d -> (\x -> p * d / x) P.<$> xs
       )
{-# INLINE product #-}

-- | Lifted 'P.length'.
length
    :: forall t a b s. (Foldable t, Num b, Reifies s W)
    => AddFunc (t a)
    -> ZeroFunc (t a)
    -> ZeroFunc b
    -> BVar s (t a)
    -> BVar s b
length af zfa zf = liftOp1 af zf . op1 $ \xs ->
    ( P.fromIntegral (P.length xs)
    , P.const (runZF zfa xs)
    )
{-# INLINE length #-}

-- | Lifted 'P.minimum'.  Undefined for situations where 'P.minimum' would
-- be undefined.
minimum
    :: forall t a s. (Foldable t, Functor t, Ord a, Reifies s W)
    => AddFunc (t a)
    -> ZeroFunc a
    -> BVar s (t a)
    -> BVar s a
minimum af zf = liftOp1 af zf . op1 $ \xs ->
    let m = P.minimum xs
    in  ( m
        , \d -> (\x -> if x == m then d else runZF zf x) P.<$> xs
        )
{-# INLINE minimum #-}

-- | Lifted 'P.maximum'.  Undefined for situations where 'P.maximum' would
-- be undefined.
maximum
    :: forall t a s. (Foldable t, Functor t, Ord a, Reifies s W)
    => AddFunc (t a)
    -> ZeroFunc a
    -> BVar s (t a)
    -> BVar s a
maximum af zf = liftOp1 af zf . op1 $ \xs ->
    let m = P.maximum xs
    in  ( m
        , \d -> (\x -> if x == m then d else runZF zf x) P.<$> xs
        )
{-# INLINE maximum #-}

-- | Lifted 'P.fmap'.  Lifts backpropagatable functions to be
-- backpropagatable functions on 'Traversable' 'Functor's.
fmap
    :: forall f a b s. (Traversable f, Reifies s W)
    => AddFunc a
    -> AddFunc b
    -> ZeroFunc a
    -> ZeroFunc b
    -> ZeroFunc (f b)
    -> (BVar s a -> BVar s b)
    -> BVar s (f a)
    -> BVar s (f b)
fmap afa afb zfa zfb zfbs f = collectVar afb zfb zfbs . P.fmap f . sequenceVar afa zfa
{-# INLINE fmap #-}

-- | Lifted 'P.traverse'.  Lifts backpropagatable functions to be
-- backpropagatable functions on 'Traversable' 'Functor's.
traverse
    :: forall t f a b s. (Traversable t, Applicative f, Foldable f, Reifies s W)
    => AddFunc a
    -> AddFunc b
    -> AddFunc (t b)
    -> ZeroFunc a
    -> ZeroFunc b
    -> ZeroFunc (t b)
    -> ZeroFunc (f (t b))
    -> (BVar s a -> f (BVar s b))
    -> BVar s (t a)
    -> BVar s (f (t b))
traverse afa afb aftb zfa zfb zftb zfftb f
        = collectVar aftb zftb zfftb
        . P.fmap (collectVar afb zfb zftb)
        . P.traverse f
        . sequenceVar afa zfa
{-# INLINE traverse #-}

-- | Lifted 'P.liftA2'.  Lifts backpropagatable functions to be
-- backpropagatable functions on 'Traversable' 'Applicative's.
liftA2
    :: forall f a b c s.
       ( Traversable f
       , Applicative f
       , Reifies s W
       )
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> ZeroFunc a
    -> ZeroFunc b
    -> ZeroFunc c
    -> ZeroFunc (f c)
    -> (BVar s a -> BVar s b -> BVar s c)
    -> BVar s (f a)
    -> BVar s (f b)
    -> BVar s (f c)
liftA2 afa afb afc zfa zfb zfc zffc f x y
    = collectVar afc zfc zffc
    $ f P.<$> sequenceVar afa zfa x
        P.<*> sequenceVar afb zfb y
{-# INLINE liftA2 #-}

-- | Lifted 'P.liftA3'.  Lifts backpropagatable functions to be
-- backpropagatable functions on 'Traversable' 'Applicative's.
liftA3
    :: forall f a b c d s.
       ( Traversable f
       , Applicative f
       , Reifies s W
       )
    => AddFunc a
    -> AddFunc b
    -> AddFunc c
    -> AddFunc d
    -> ZeroFunc a
    -> ZeroFunc b
    -> ZeroFunc c
    -> ZeroFunc d
    -> ZeroFunc (f d)
    -> (BVar s a -> BVar s b -> BVar s c -> BVar s d)
    -> BVar s (f a)
    -> BVar s (f b)
    -> BVar s (f c)
    -> BVar s (f d)
liftA3 afa afb afc afd zfa zfb zfc zfd zffd f x y z
    = collectVar afd zfd zffd
    $ f P.<$> sequenceVar afa zfa x
        P.<*> sequenceVar afb zfb y
        P.<*> sequenceVar afc zfc z
{-# INLINE liftA3 #-}

-- | Coerce items inside a 'BVar'.
coerce
    :: forall a b s. C.Coercible a b
    => BVar s a
    -> BVar s b
coerce = coerceVar
{-# INLINE coerce #-}