-- {-# OPTIONS_HADDOCK hide, prune #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.AD.Internal.Var
-- Copyright   :  (c) Edward Kmett 2012
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Variables used by reverse-mode automatic differentiation.
--
-- This module defines the 'Var' class, which marks values with an 'Int'
-- identifier for inspection during the reverse pass, together with helpers
-- for tagging every element of a structure with a fresh identifier ('bind')
-- and for reading per-identifier results back into the shape of that
-- structure ('unbind' and its variants).
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Var
    ( Var(..)
    , bind
    , unbind
    , unbindMap
    , unbindWith
    , unbindMapWithDefault
    , Variable(..)
    , vary
    ) where

import Prelude hiding (mapM)
import Control.Applicative (Applicative(..))
import Data.Array
import Data.IntMap (IntMap, findWithDefault)
import Data.Traversable (Traversable, mapM)
import Numeric.AD.Internal.Types
import Numeric.AD.Internal.Classes

-- | Used to mark variables for inspection during the reverse pass
class Primal v => Var v where
    -- | Build a variable from a value and an 'Int' identifier.
    var   :: a -> Int -> v a
    -- | The 'Int' identifier carried by a variable.
    varId :: v a -> Int

-- | 'AD' wraps another 'Var': construction and identifier lookup are both
-- delegated to the wrapped type.
instance Var f => Var (AD f) where
    var a = AD . var a
    varId wrapped = case wrapped of
        AD inner -> varId inner

-- | A simple fresh variable supply monad: a lazy state monad whose state is
-- the next unused 'Int'.
newtype S a = S { runS :: Int -> (a,Int) }

-- 'Functor' and 'Applicative' are superclasses of 'Monad' from base-4.8
-- (GHC 7.10) onwards, so both instances are required for the 'Monad'
-- instance below to be accepted. All three use lazy @let@ patterns so the
-- counter is threaded with the same laziness throughout.

-- | Apply a function to the result, leaving the counter untouched.
instance Functor S where
    fmap f (S g) = S (\s -> let (a,s') = g s in (f a, s'))

-- | Run the function-producing action first, then the argument action,
-- threading the counter from one to the other.
instance Applicative S where
    pure a = S (\s -> (a,s))
    S mf <*> S mx = S (\s ->
        let (f,s')  = mf s
            (a,s'') = mx s'
        in (f a, s''))

-- | Sequence two actions, feeding the counter left behind by the first
-- into the second.
instance Monad S where
    return = pure
    S g >>= f = S (\s -> let (a,s') = g s in runS (f a) s')

-- | Tag every element of a structure with a fresh identifier.
--
-- Identifiers are handed out in traversal order starting from @0@. The
-- second component of the result is @(0, n)@, where @n@ is the number of
-- identifiers handed out (that is, one more than the largest identifier
-- used).
bind :: (Traversable f, Var v) => f a -> (f (v a), (Int,Int))
bind xs = (tagged, (0, next))
  where
    (tagged, next) = runS (mapM fresh xs) 0
    -- '$!' forces the incremented counter before the pair is built, so no
    -- chain of @(+ 1)@ thunks accumulates in the state.
    fresh a = S (\n -> (,) (var a n) $! (n + 1))

-- | Replace every variable in the structure by the array element stored at
-- its identifier. The lookup uses '!', so an identifier outside the array's
-- bounds is an error.
unbind :: (Functor f, Var v)  => f (v a) -> Array Int a -> f a
unbind xs ys = fmap ((ys !) . varId) xs

-- | Like 'unbind', but combine each variable's 'primal' value with the array
-- element stored at its identifier using the supplied function.
unbindWith :: (Functor f, Var v, Num a) => (a -> b -> c) -> f (v a) -> Array Int b -> f c
unbindWith f xs ys = fmap combine xs
  where
    combine v = primal v `f` (ys ! varId v)

-- | Replace every variable in the structure by the map entry stored under
-- its identifier, using @0@ for identifiers that are absent from the map.
unbindMap :: (Functor f, Var v, Num a) => f (v a) -> IntMap a -> f a
unbindMap xs ys = fmap entryOrZero xs
  where
    entryOrZero v = findWithDefault 0 (varId v) ys

-- | Combine each variable's 'primal' value with the map entry stored under
-- its identifier, substituting the given default when the identifier is
-- absent from the map.
unbindMapWithDefault :: (Functor f, Var v, Num a) => b -> (a -> b -> c) -> f (v a) -> IntMap b -> f c
unbindMapWithDefault z f xs ys = fmap combine xs
  where
    entryOrDefault v = findWithDefault z (varId v) ys
    combine v = f (primal v) (entryOrDefault v)

-- | A value paired with a strict, unpacked 'Int' identifier.
data Variable a = Variable a {-# UNPACK #-} !Int

-- | A 'Variable' stores its identifier directly in its second field.
instance Var Variable where
  var a i = Variable a i
  varId v = case v of
    Variable _ i -> i

-- | The primal value of a 'Variable' is the value stored in its first field.
instance Primal Variable where
  primal v = case v of
    Variable a _ -> a

-- | Rebuild a 'Variable' as any other 'Var', keeping the same value and
-- identifier.
vary :: Var f => Variable a -> f a
vary v = case v of
  Variable a i -> var a i