module Numeric.AD.Internal.Dense
( Dense(..)
, ds
, ds'
, vars
, apply
) where
import Language.Haskell.TH
import Data.Typeable ()
import Data.Traversable (Traversable, mapAccumL)
import Data.Data ()
import Numeric.AD.Internal.Types
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Classes
import Numeric.AD.Internal.Identity
-- | A value in dense forward mode: a primal, optionally paired with a
-- container of tangent components.
data Dense f a
  = Lift !a
    -- ^ A primal with no stored tangent. The 'Mode' operations treat it as
    -- contributing nothing to the tangent of a result.
  | Dense !a (f a)
    -- ^ A strict primal together with a lazily evaluated tangent container.
    -- 'vars' builds this container with the same shape as its input, one
    -- component per input element.
-- | Renders only the primal value at the given precedence; any tangent
-- container is omitted from the output.
instance Show a => Show (Dense f a) where
  showsPrec d x = case x of
    Lift a    -> showsPrec d a
    Dense a _ -> showsPrec d a
-- | Extract the tangent container of a dense AD value.
--
-- A 'Lift' value stores no tangent, so the caller-supplied fallback @z@ is
-- returned in its place.
ds :: f a -> AD (Dense f) a -> f a
ds z (AD x) = case x of
  Dense _ da -> da
  Lift _     -> z
-- | Extract both the primal value and the tangent container.
--
-- For a 'Lift' value the caller-supplied fallback @z@ stands in for the
-- missing tangent.
ds' :: f a -> AD (Dense f) a -> (a, f a)
ds' z (AD x) = case x of
  Lift a     -> (a, z)
  Dense a da -> (a, da)
-- | Seed every element of the input as an independent variable.
--
-- Elements are numbered from 0 in traversal order. The element at position
-- @k@ is paired with a tangent container of the same shape as the input,
-- holding @1@ at position @k@ and @0@ everywhere else. That is one basis
-- direction per input element. Forcing every tangent therefore costs time
-- quadratic in the number of elements.
vars :: (Traversable f, Num a) => f a -> f (AD (Dense f) a)
vars xs = snd (mapAccumL seed (0 :: Int) xs)
  where
    -- Wrap element @x@, found at position @k@, with its basis tangent.
    seed !k x = (k + 1, AD (Dense x (basis k)))
    -- Same shape as @xs@, with 1 at position @k@ and 0 elsewhere.
    basis k = snd (mapAccumL (\ !j _ -> (j + 1, if j == k then 1 else 0)) 0 xs)
-- | Run @f@ on the input after seeding each element as an independent
-- variable with 'vars'.
apply :: (Traversable f, Num a) => (f (AD (Dense f) a) -> b) -> f a -> b
apply f = f . vars
-- | The primal value is stored directly in both constructors.
instance Primal (Dense f) where
  primal x = case x of
    Lift a    -> a
    Dense a _ -> a
-- | Basic arithmetic on dense AD values.
--
-- Sums add the primals. When both operands carry tangents, the tangents are
-- added component-wise with 'zipWithT'. When only one operand carries a
-- tangent, that tangent is kept unchanged. Scaling by a scalar (on the left,
-- on the right, or by division) acts on the primal and on every tangent
-- component.
instance (Traversable f, Lifted (Dense f)) => Mode (Dense f) where
  lift = Lift

  x <+> y = case (x, y) of
    (Lift a,     Lift b)     -> Lift (a + b)
    (Lift a,     Dense b db) -> Dense (a + b) db
    (Dense a da, Lift b)     -> Dense (a + b) da
    (Dense a da, Dense b db) -> Dense (a + b) (zipWithT (+) da db)

  s *^ v = case v of
    Lift b     -> Lift (s * b)
    Dense b db -> Dense (s * b) (fmap (s *) db)

  v ^* s = case v of
    Lift a     -> Lift (a * s)
    Dense a da -> Dense (a * s) (fmap (* s) da)

  v ^/ s = case v of
    Lift a     -> Lift (a / s)
    Dense a da -> Dense (a / s) (fmap (/ s) da)
-- | First-order chain rule for 'Dense'.
--
-- Partial derivatives are exchanged as scalars wrapped in 'Id'
-- (@D (Dense f) = Id@). In every method an all-'Lift' input yields a 'Lift'
-- result without consulting the derivative argument. For 'Dense' inputs,
-- each tangent component is multiplied by the matching partial. In the
-- two-argument methods, when both inputs are dense, the two scaled tangents
-- are summed component-wise with 'zipWithT'.
instance (Traversable f, Lifted (Dense f)) => Jacobian (Dense f) where
  type D (Dense f) = Id

  -- The partial @dadb@ is supplied directly by the caller.
  unary f (Id dadb) x = case x of
    Lift b     -> Lift (f b)
    Dense b db -> Dense (f b) (fmap (dadb *) db)

  -- The partial is obtained by evaluating @df@ at the input's primal.
  lift1 f df x = case x of
    Lift b     -> Lift (f b)
    Dense b db ->
      let Id dadb = df (Id b)
      in Dense (f b) (fmap (dadb *) db)

  -- Like 'lift1', but @df@ also receives the already-computed result.
  lift1_ f df x = case x of
    Lift b     -> Lift (f b)
    Dense b db ->
      let a       = f b
          Id dadb = df (Id a) (Id b)
      in Dense a (fmap (dadb *) db)

  -- Both partials are supplied directly by the caller.
  binary f (Id dadb) (Id dadc) x y = case (x, y) of
    (Lift b,     Lift c)     -> Lift (f b c)
    (Lift b,     Dense c dc) -> Dense (f b c) (fmap (* dadc) dc)
    (Dense b db, Lift c)     -> Dense (f b c) (fmap (dadb *) db)
    (Dense b db, Dense c dc) ->
      Dense (f b c) (zipWithT (\dbi dci -> dadb * dbi + dci * dadc) db dc)

  -- Both partials are obtained by evaluating @df@ at the input primals.
  lift2 f df x y = case (x, y) of
    (Lift b, Lift c) -> Lift (f b c)
    (Lift b, Dense c dc) ->
      let (_, Id dadc) = df (Id b) (Id c)
      in Dense (f b c) (fmap (* dadc) dc)
    (Dense b db, Lift c) ->
      let (Id dadb, _) = df (Id b) (Id c)
      in Dense (f b c) (fmap (dadb *) db)
    (Dense b db, Dense c dc) ->
      let (Id dadb, Id dadc) = df (Id b) (Id c)
      in Dense (f b c) (zipWithT (\dbi dci -> dadb * dbi + dci * dadc) db dc)

  -- Like 'lift2', but @df@ also receives the already-computed result.
  lift2_ f df x y = case (x, y) of
    (Lift b, Lift c) -> Lift (f b c)
    (Lift b, Dense c dc) ->
      let a            = f b c
          (_, Id dadc) = df (Id a) (Id b) (Id c)
      in Dense a (fmap (* dadc) dc)
    (Dense b db, Lift c) ->
      let a            = f b c
          (Id dadb, _) = df (Id a) (Id b) (Id c)
      in Dense a (fmap (dadb *) db)
    (Dense b db, Dense c dc) ->
      let a                  = f b c
          (Id dadb, Id dadc) = df (Id a) (Id b) (Id c)
      in Dense a (zipWithT (\dbi dci -> dadb * dbi + dci * dadc) db dc)
-- Template Haskell splice that calls 'deriveLifted' for the type @Dense f@.
-- This presumably produces the @Lifted (Dense f)@ instance that the 'Mode'
-- and 'Jacobian' instances above require; see 'deriveLifted' to confirm.
-- The first argument prepends a @Traversable f@ predicate to a list of
-- predicates, presumably the generated instance context. The second argument
-- is the type @Dense f@, where @f@ is the same named type variable.
$(let tyF = varT (mkName "f")
  in deriveLifted (classP ''Traversable [tyF] :) (appT (conT ''Dense) tyF))