module Numeric.AD.Internal.Reverse
( Reverse(..)
, Tape(..)
, partials
, partialArray
, partialMap
, derivative
, derivative'
, vgrad, vgrad'
, Grad(..)
) where
import Prelude hiding (mapM)
import Control.Applicative (Applicative(..),(<$>))
import Control.Monad.ST
import Control.Monad (forM_)
import Data.List (foldl')
import Data.Array.ST
import Data.Array
import Data.IntMap (IntMap, fromListWith)
import Data.Graph (Vertex, transposeG, Graph)
import Data.Reify (reifyGraph, MuRef(..))
import qualified Data.Reify.Graph as Reified
import System.IO.Unsafe (unsafePerformIO)
import Language.Haskell.TH
import Data.Data (Data)
import Data.Typeable (Typeable)
import Numeric.AD.Internal.Types
import Numeric.AD.Internal.Classes
import Numeric.AD.Internal.Identity
import Numeric.AD.Internal.Var
data Tape a t
= Zero
| Lift !a
| Var !a !Int
| Binary !a a a t t
| Unary !a a t
deriving (Show, Data, Typeable)
newtype Reverse a = Reverse (Tape a (Reverse a)) deriving (Show, Typeable)
instance MuRef (Reverse a) where
type DeRef (Reverse a) = Tape a
mapDeRef _ (Reverse Zero) = pure Zero
mapDeRef _ (Reverse (Lift a)) = pure (Lift a)
mapDeRef _ (Reverse (Var a v)) = pure (Var a v)
mapDeRef f (Reverse (Binary a dadb dadc b c)) = Binary a dadb dadc <$> f b <*> f c
mapDeRef f (Reverse (Unary a dadb b)) = Unary a dadb <$> f b
instance Lifted Reverse => Mode Reverse where
isKnownZero (Reverse Zero) = True
isKnownZero _ = False
isKnownConstant (Reverse Zero) = True
isKnownConstant (Reverse (Lift _)) = True
isKnownConstant _ = False
auto a = Reverse (Lift a)
zero = Reverse Zero
(<+>) = binary (+) one one
a *^ b = lift1 (a *) (\_ -> auto a) b
a ^* b = lift1 (* b) (\_ -> auto b) a
a ^/ b = lift1 (/ b) (\_ -> auto (recip b)) a
Reverse Zero <**> y = auto (0 ** primal y)
_ <**> Reverse Zero = auto 1
x <**> Reverse (Lift y) = lift1 (**y) (\z -> (y *^ z ** Id (y1))) x
x <**> y = lift2_ (**) (\z xi yi -> (yi *! z /! xi, z *! log1 xi)) x y
instance Primal Reverse where
primal (Reverse Zero) = 0
primal (Reverse (Lift a)) = a
primal (Reverse (Var a _)) = a
primal (Reverse (Binary a _ _ _ _)) = a
primal (Reverse (Unary a _ _)) = a
instance Lifted Reverse => Jacobian Reverse where
type D Reverse = Id
unary f _ (Reverse Zero) = Reverse (Lift (f 0))
unary f _ (Reverse (Lift a)) = Reverse (Lift (f a))
unary f (Id dadb) b = Reverse (Unary (f (primal b)) dadb b)
lift1 f df b = unary f (df (Id pb)) b
where pb = primal b
lift1_ f df b = unary (const a) (df (Id a) (Id pb)) b
where pb = primal b
a = f pb
binary f _ _ (Reverse Zero) (Reverse Zero) = Reverse (Lift (f 0 0))
binary f _ _ (Reverse Zero) (Reverse (Lift c)) = Reverse (Lift (f 0 c))
binary f _ _ (Reverse (Lift b)) (Reverse Zero) = Reverse (Lift (f b 0))
binary f _ _ (Reverse (Lift b)) (Reverse (Lift c)) = Reverse (Lift (f b c))
binary f _ (Id dadc) (Reverse Zero) c = Reverse (Unary (f 0 (primal c)) dadc c)
binary f _ (Id dadc) (Reverse (Lift b)) c = Reverse (Unary (f b (primal c)) dadc c)
binary f (Id dadb) _ b (Reverse Zero) = Reverse (Unary (f (primal b) 0) dadb b)
binary f (Id dadb) _ b (Reverse (Lift c)) = Reverse (Unary (f (primal b) c) dadb b)
binary f (Id dadb) (Id dadc) b c = Reverse (Binary (f (primal b) (primal c)) dadb dadc b c)
lift2 f df b c = binary f dadb dadc b c
where (dadb, dadc) = df (Id (primal b)) (Id (primal c))
lift2_ f df b c = binary (\_ _ -> a) dadb dadc b c
where
pb = primal b
pc = primal c
a = f pb pc
(dadb, dadc) = df (Id a) (Id pb) (Id pc)
deriveLifted id (conT ''Reverse)
derivative :: Num a => AD Reverse a -> a
derivative = sum . map snd . partials
derivative' :: Num a => AD Reverse a -> (a, a)
derivative' r = (primal r, derivative r)
backPropagate :: Num a => (Vertex -> (Tape a Int, Int, [Int])) -> STArray s Int a -> Vertex -> ST s ()
backPropagate vmap ss v = do
case node of
Unary _ g b -> do
da <- readArray ss i
db <- readArray ss b
writeArray ss b (db + g*da)
Binary _ gb gc b c -> do
da <- readArray ss i
db <- readArray ss b
writeArray ss b (db + gb*da)
dc <- readArray ss c
writeArray ss c (dc + gc*da)
_ -> return ()
where
(node, i, _) = vmap v
topSortAcyclic :: Graph -> [Vertex]
topSortAcyclic g = reverse $ runST $ do
del <- newArray (bounds g) False :: ST s (STUArray s Int Bool)
let tg = transposeG g
starters = [ n | (n, []) <- assocs tg ]
loop [] rs = return rs
loop (n:ns) rs = do
writeArray del n True
let add [] = return ns
add (m:ms) = do
b <- ok (tg!m)
ms' <- add ms
if b then return (m:ms') else return ms'
ok [] = return True
ok (x:xs) = do b <- readArray del x; if b then ok xs else return False
ns' <- add (g!n)
loop ns' (n : rs)
loop starters []
partials :: forall a . Num a => AD Reverse a -> [(Int, a)]
partials (AD tape) = [ let v = sensitivities ! ix in seq v (ident, v) | (ix, Var _ ident) <- xs ]
where
Reified.Graph xs start = unsafePerformIO $ reifyGraph tape
g = array xsBounds [ (i, successors t) | (i, t) <- xs ]
vertexMap = array xsBounds xs
vmap i = (vertexMap ! i, i, [])
xsBounds = sbounds xs
sensitivities = runSTArray $ do
ss <- newArray xsBounds 0
writeArray ss start 1
forM_ (topSortAcyclic g) $
backPropagate vmap ss
return ss
sbounds ((a,_):as) = foldl' (\(lo,hi) (b,_) -> let lo' = min lo b; hi' = max hi b in lo' `seq` hi' `seq` (lo', hi')) (a,a) as
sbounds _ = undefined
successors :: Tape a t -> [t]
successors (Unary _ _ b) = [b]
successors (Binary _ _ _ b c) = [b,c]
successors _ = []
partialArray :: Num a => (Int, Int) -> AD Reverse a -> Array Int a
partialArray vbounds tape = accumArray (+) 0 vbounds (partials tape)
partialMap :: Num a => AD Reverse a -> IntMap a
partialMap = fromListWith (+) . partials
newtype S a = S { runS :: Int -> (a,Int) }
instance Monad S where
return a = S (\s -> (a,s))
S g >>= f = S (\s -> let (a,s') = g s in runS (f a) s')
instance Var Reverse where
var a v = Reverse (Var a v)
varId (Reverse (Var _ v)) = v
varId _ = error "varId: not a Var"
class Num a => Grad i o o' a | i -> a o o', o -> a i o', o' -> a i o where
pack :: i -> [AD Reverse a] -> AD Reverse a
unpack :: ([a] -> [a]) -> o
unpack' :: ([a] -> (a, [a])) -> o'
instance Num a => Grad (AD Reverse a) [a] (a, [a]) a where
pack i _ = i
unpack f = f []
unpack' f = f []
instance Grad i o o' a => Grad (AD Reverse a -> i) (a -> o) (a -> o') a where
pack f (a:as) = pack (f a) as
pack _ [] = error "Grad.pack: logic error"
unpack f a = unpack (f . (a:))
unpack' f a = unpack' (f . (a:))
vgrad :: Grad i o o' a => i -> o
vgrad i = unpack (unsafeGrad (pack i))
where
unsafeGrad f as = unbind vs (partialArray bds $ f vs)
where
(vs,bds) = bind as
vgrad' :: Grad i o o' a => i -> o'
vgrad' i = unpack' (unsafeGrad' (pack i))
where
unsafeGrad' f as = (primal r, unbind vs (partialArray bds r))
where
r = f vs
(vs,bds) = bind as