module Numeric.AD.Internal.Sparse
( Index(..)
, emptyIndex
, addToIndex
, indices
, Sparse(..)
, apply
, vars
, d, d', ds
, skeleton
, spartial
, partial
, vgrad
, vgrad'
, vgrads
, Grad(..)
, Grads(..)
) where
import Prelude hiding (lookup)
import Control.Applicative
import Numeric.AD.Internal.Classes
import Data.Stream.Branching
import Numeric.AD.Internal.Types
import Data.Data
import Data.Typeable ()
import qualified Data.IntMap as IntMap
import Data.IntMap (IntMap, mapWithKey, unionWith, findWithDefault, toAscList, singleton, insertWith, lookup)
import Data.Traversable
import Language.Haskell.TH
-- | A multiset of variable indices, stored as a map from each variable index
-- to the number of times it occurs.
newtype Index = Index (IntMap Int)

-- | The multiset that contains no indices.
emptyIndex :: Index
emptyIndex = Index IntMap.empty

-- | Record one more occurrence of index @k@.
addToIndex :: Int -> Index -> Index
addToIndex k (Index counts) = Index (IntMap.alter bump k counts)
  where
    -- An absent index starts at 1; a present one is incremented.
    bump = Just . maybe 1 (+ 1)

-- | Expand the multiset into a list in ascending index order, repeating each
-- index once per occurrence.
indices :: Index -> [Int]
indices (Index counts) = concat [replicate n k | (k, n) <- toAscList counts]
-- | A sparse tower of derivatives: a primal value paired with a map from
-- variable index to the (again sparse) tower of partials with respect to that
-- variable. Indices absent from the map are read back as zero by 'partial'
-- and 'd'.
--
-- Both fields are lazy (no strictness annotations). The Jacobian instance's
-- @lift1_@ and @lift2_@ define a tower in terms of itself, which relies on this.
data Sparse a = Sparse a (IntMap (Sparse a)) deriving (Show, Data, Typeable)
-- | Keep only the entries whose key is at least @n@.
--
-- 'IntMap.split' @k@ returns the entries strictly below @k@ and the entries
-- strictly above @k@. Splitting at @n - 1@ and keeping the upper half
-- therefore retains exactly the keys @>= n@. This is what 'times' relies on
-- when it restricts a product-rule step to indices not below @n@.
--
-- NOTE(review): @n - 1@ wraps around when @n == minBound@, which would yield
-- an empty map instead of the whole map. The indices created by 'vars' start
-- at 0, so presumably that case does not arise. Confirm if other producers of
-- keys exist.
dropMap :: Int -> IntMap a -> IntMap a
dropMap n = snd . IntMap.split (n - 1)
-- | Product-rule step used by the Jacobian instance.
--
-- The result's primal is the product of the two primals. Its derivative map
-- is @dx * y + x * dy@, taking from each operand only the entries whose
-- variable index is at least @n@ (via 'dropMap'). Entries present on both
-- sides are merged with '<+>'.
times :: Num a => Sparse a -> Int -> Sparse a -> Sparse a
times (Sparse x dx) n (Sparse y dy) = Sparse (x * y) (unionWith (<+>) scaledX scaledY)
  where
    scaledX = (^* y) <$> dropMap n dx
    scaledY = (x *^) <$> dropMap n dy
-- | Turn every element of a container into an independent variable.
--
-- Elements are numbered from 0 in traversal order. Element @i@ becomes a
-- tower whose primal is the element and whose only first-order entry,
-- stored at key @i@, is @lift 1@.
vars :: (Traversable f, Num a) => f a -> f (AD Sparse a)
vars xs = snd (mapAccumL mkVar 0 xs)
  where
    -- The counter is forced at each step so no chain of (+ 1) thunks builds up.
    mkVar !i x = (i + 1, AD (Sparse x (singleton i (lift 1))))
-- | Run a function of sparse-mode variables on concrete inputs, seeding each
-- input as its own variable via 'vars'.
apply :: (Traversable f, Num a) => (f (AD Sparse a) -> b) -> f a -> b
apply f xs = f (vars xs)
-- | Replace every element of a container by its position (0, 1, 2, ...) in
-- traversal order, keeping the container's shape.
skeleton :: Traversable f => f a -> f Int
skeleton xs = snd (mapAccumL number 0 xs)
  where
    number !i _ = (i + 1, i)
-- | Read the first-order partials out of a tower, one per position of @fs@.
--
-- The position numbered @i@ by 'skeleton' receives the primal of the entry
-- stored at key @i@, or 0 when there is no such entry. Only the shape of
-- @fs@ is used; its elements are ignored.
d :: (Traversable f, Num a) => f b -> AD Sparse a -> f a
d fs (AD (Sparse _ da)) = firstPartial <$> skeleton fs
  where
    firstPartial i = maybe 0 primal (lookup i da)
-- | Like 'd', but also return the tower's primal value.
d' :: (Traversable f, Num a) => f a -> AD Sparse a -> (a, f a)
d' fs tower@(AD (Sparse a _)) = (a, d fs tower)
-- | Unfold a tower into a branching stream of all higher-order partials.
--
-- The root holds the primal. Below any node, the branch for position @i@ of
-- @fs@ (numbered by 'skeleton') holds the partial for the multiset of indices
-- chosen along the path so far plus @i@. The multiset is looked up in
-- ascending order via 'indices'.
ds :: (Traversable f, Num a) => f b -> AD Sparse a -> Stream f a
ds fs (AD tower@(Sparse a0 _)) = a0 :< fmap (branch emptyIndex) slots
  where
    slots = skeleton fs
    branch seen i = partial (indices seen') tower :< fmap (branch seen') slots
      where
        seen' = addToIndex i seen
-- | Follow a path of variable indices down a tower and return the primal
-- reached.
--
-- A missing entry is replaced by the constant tower @lift 0@, so the result
-- is the primal of @lift 0@ from that point on.
--
-- NOTE(review): 'ds' passes ascending paths (built by 'indices'). Presumably
-- other orders can miss entries, because 'times' drops keys below the current
-- index. Confirm before relying on unordered paths.
partial :: Num a => [Int] -> Sparse a -> a
partial path (Sparse a da) = case path of
  []       -> a
  (n : ns) -> partial ns (findWithDefault (lift 0) n da)
-- | Like 'partial', but returns 'Nothing' as soon as an index along the path
-- has no entry, instead of substituting zero.
spartial :: Num a => [Int] -> Sparse a -> Maybe a
spartial path (Sparse a da) = case path of
  []       -> Just a
  (n : ns) -> lookup n da >>= spartial ns
-- | The primal of a sparse tower is its zeroth-order value.
instance Primal Sparse where
  primal s = case s of
    Sparse x _ -> x
-- | Arithmetic on sparse towers.
--
-- * 'lift' makes a constant tower with an empty derivative map.
-- * '<+>' adds primals and merges the derivative maps key-wise, recursively
--   adding the sub-towers that share a key.
-- * Scaling and division by a scalar apply to the primal and, recursively, to
--   every derivative entry.
instance Lifted Sparse => Mode Sparse where
  lift x = Sparse x IntMap.empty
  (<+>) (Sparse x dx) (Sparse y dy) = Sparse (x + y) (unionWith (<+>) dx dy)
  (^*) (Sparse x dx) k = Sparse (x * k) ((^* k) <$> dx)
  (*^) k (Sparse y dy) = Sparse (k * y) ((k *^) <$> dy)
  (^/) (Sparse x dx) k = Sparse (x / k) ((^/ k) <$> dx)
-- | Chain-rule plumbing for sparse towers.
--
-- In every method, the result's primal is @f@ applied to the argument
-- primal(s). Each derivative entry (key @n@, sub-tower @t@) of an argument is
-- combined with that argument's derivative factor via @times factor n t@.
-- Binary methods merge the two arguments' contributions with '<+>'.
instance Lifted Sparse => Jacobian Sparse where
  type D Sparse = Sparse
  -- Unary operation whose derivative factor is supplied directly.
  unary f dadb (Sparse pb bs) = Sparse (f pb) $ mapWithKey (times dadb) bs
  -- Unary operation whose derivative factor is computed from the argument.
  lift1 f df b@(Sparse pb bs) = Sparse (f pb) $ mapWithKey (times (df b)) bs
  -- The derivative factor may depend on the result: @a@ is defined in terms
  -- of @df a b@. This relies on lazy evaluation, so keep the binding recursive.
  lift1_ f df b@(Sparse pb bs) = a where
    a = Sparse (f pb) $ mapWithKey (times (df a b)) bs
  -- Binary operation with both partial factors supplied directly.
  binary f dadb dadc (Sparse pb db) (Sparse pc dc) = Sparse (f pb pc) $
    unionWith (<+>)
      (mapWithKey (times dadb) db)
      (mapWithKey (times dadc) dc)
  -- Binary operation whose pair of partial factors is computed from the
  -- arguments.
  lift2 f df b@(Sparse pb db) c@(Sparse pc dc) = Sparse (f pb pc) da where
    (dadb, dadc) = df b c
    da = unionWith (<+>)
      (mapWithKey (times dadb) db)
      (mapWithKey (times dadc) dc)
  -- As 'lift2', but the factors may depend on the result @a@ itself
  -- (self-referential, lazily evaluated binding).
  lift2_ f df b@(Sparse pb db) c@(Sparse pc dc) = a where
    (dadb, dadc) = df a b c
    a = Sparse (f pb pc) da
    da = unionWith (<+>)
      (mapWithKey (times dadb) db)
      (mapWithKey (times dadc) dc)
-- Template Haskell declaration splice for 'Sparse'.
-- NOTE(review): presumably this generates the @Lifted Sparse@ instance that
-- the Mode and Jacobian instance contexts require. Confirm against
-- @deriveLifted@.
$(deriveLifted id (conT ''Sparse))
-- | Relates a curried function of sparse variables (@i@) to two curried
-- functions over concrete arguments:
--
-- * @o@ returns the list of first-order partials.
-- * @o'@ returns the value paired with those partials.
--
-- 'pack' feeds a list of variables to @i@ one argument at a time. 'unpack' and
-- 'unpack'' collect curried arguments, in order, into the list given to a
-- list-consuming function.
class Num a => Grad i o o' a | i -> a o o', o -> a i o', o' -> a i o where
  pack :: i -> [AD Sparse a] -> AD Sparse a
  unpack :: ([a] -> [a]) -> o
  unpack' :: ([a] -> (a, [a])) -> o'

-- | Base case: a bare result consumes no further arguments.
instance Num a => Grad (AD Sparse a) [a] (a, [a]) a where
  pack = const
  unpack = ($ [])
  unpack' = ($ [])

-- | Inductive case: peel off one argument and recurse.
instance Grad i o o' a => Grad (AD Sparse a -> i) (a -> o) (a -> o') a where
  pack f args = case args of
    x : rest -> pack (f x) rest
    []       -> error "Grad.pack: logic error"
  unpack f x = unpack (\xs -> f (x : xs))
  unpack' f x = unpack' (\xs -> f (x : xs))
-- | Gradient of a curried function of sparse variables.
--
-- The result is a curried function that takes the same number of concrete
-- arguments and returns the list of first-order partials.
vgrad :: Grad i o o' a => i -> o
vgrad i = unpack gradAt
  where
    gradAt xs = d xs (apply (pack i) xs)
-- | Like 'vgrad', but the curried result returns the function's value
-- together with its list of first-order partials.
vgrad' :: Grad i o o' a => i -> o'
vgrad' i = unpack' valueAndGradAt
  where
    valueAndGradAt xs = d' xs (apply (pack i) xs)
-- | Relates a curried function of sparse variables (@i@) to a curried
-- function over concrete arguments (@o@) that returns the branching stream of
-- all higher-order partials.
--
-- 'packs' feeds a list of variables to @i@ one argument at a time. 'unpacks'
-- collects curried arguments, in order, into the list given to a
-- list-consuming function.
class Num a => Grads i o a | i -> a o, o -> a i where
  packs :: i -> [AD Sparse a] -> AD Sparse a
  unpacks :: ([a] -> Stream [] a) -> o

-- | Base case: a bare result consumes no further arguments.
instance Num a => Grads (AD Sparse a) (Stream [] a) a where
  packs i _ = i
  unpacks f = f []

-- | Inductive case: peel off one argument and recurse.
instance Grads i o a => Grads (AD Sparse a -> i) (a -> o) a where
  packs f (a:as) = packs (f a) as
  -- Fewer variables than arguments: 'unpacks' should always supply exactly
  -- as many as were curried in, so reaching this is an internal error. The
  -- message names this class and method (it previously said "Grad.pack").
  packs _ [] = error "Grads.packs: logic error"
  unpacks f a = unpacks (f . (a:))
-- | All higher-order partials of a curried function of sparse variables.
--
-- The result is a curried function over concrete arguments that returns the
-- branching stream built by 'ds'.
vgrads :: Grads i o a => i -> o
vgrads i = unpacks towerAt
  where
    towerAt xs = ds xs (apply (packs i) xs)