{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}

-----------------------------------------------------------------------------
-- |
-- Copyright   : (c) Edward Kmett 2010-2021
-- License     : BSD3
-- Maintainer  : ekmett@gmail.com
-- Stability   : experimental
-- Portability : GHC only
--
-- Dense Forward AD. Useful when the result involves the majority of the input
-- elements. Do not use for 'Numeric.AD.Mode.Mixed.hessian' and beyond, since
-- they only contain a small number of unique @n@th derivatives --
-- @(n + k - 1) `choose` k@ for functions of @k@ inputs rather than the
-- @k^n@ that would be generated by using 'Dense', not to mention the redundant
-- intermediate derivatives that would be
-- calculated over and over during that process!
--
-- Assumes that all values of the container type @f@ have the same number of elements.
--
-- NB: We do not need the full power of 'Traversable' here; we could get
-- by with a notion of a zippable container that can plug in 0s for the missing
-- entries. This might allow for gradients where @f@ has exponentials like @((->) a)@.
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Dense
  ( Dense(..)
  , ds
  , ds'
  , vars
  , apply
  ) where

import Control.Monad (join)
import Data.Typeable ()
import Data.Traversable (mapAccumL)
import Data.Data ()
import Data.Number.Erf
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | A forward-mode AD value that carries its full tangent vector.
--
-- Constants are kept separate from variables so that arithmetic on them
-- never has to build or traverse a tangent container.
data Dense f a
  = Lift !a
    -- ^ A constant: a strict primal value with no tangent vector.
  | Dense !a (f a)
    -- ^ A strict primal value together with its lazily evaluated tangent
    -- vector (see 'vars' for how the tangents of the inputs are seeded).
  | Zero
    -- ^ The constant zero, with no tangent vector.

-- | Shows only the primal part of a value: the tangent vector is never
-- printed, and 'Zero' is rendered as the literal @0@.
instance Show a => Show (Dense f a) where
  showsPrec prec x = case x of
    Zero      -> showString "0"
    Lift v    -> showsPrec prec v
    Dense v _ -> showsPrec prec v

-- | Extract the tangent vector of a value. Only a 'Dense' node stores one;
-- for 'Lift' and 'Zero' the supplied fallback container is returned.
ds :: f a -> Dense f a -> f a
ds fallback x = case x of
  Dense _ tangent -> tangent
  _               -> fallback
{-# INLINE ds #-}

-- | Split a value into its primal part and its tangent vector. Constants
-- have no tangent of their own, so the supplied fallback container is used
-- for them, and 'Zero' contributes a primal of @0@.
ds' :: Num a => f a -> Dense f a -> (a, f a)
ds' fallback x = case x of
  Dense p tangent -> (p, tangent)
  Lift p          -> (p, fallback)
  Zero            -> (0, fallback)
{-# INLINE ds' #-}

-- | Bind every element of the input as an independent variable.
--
-- The element at position @k@ (counting from 0 in traversal order) becomes
-- a 'Dense' node whose tangent has the same shape as the input, with @1@ at
-- position @k@ and @0@ everywhere else.
vars :: (Traversable f, Num a) => f a -> f (Dense f a)
vars xs = snd (mapAccumL seed (0 :: Int) xs)
  where
    -- Number the current element and attach its unit tangent.
    seed !k x = (k + 1, Dense x (basis k))
    -- The tangent for variable k, built by walking the input once more.
    basis k = snd (mapAccumL (pick k) 0 xs)
    pick !k !pos _ = (pos + 1, if k == pos then 1 else 0)
{-# INLINE vars #-}

-- | Run a function on the input after binding each of its elements as an
-- independent variable with 'vars'.
apply :: (Traversable f, Num a) => (f (Dense f a) -> b) -> f a -> b
apply k xs = k $ vars xs
{-# INLINE apply #-}

-- | The primal (value) part of a number; 'Zero' maps to @0@.
primal :: Num a => Dense f a -> a
primal x = case x of
  Dense p _ -> p
  Lift p    -> p
  Zero      -> 0

-- | Constants are 'Lift' and 'Zero'; only 'Dense' values depend on the
-- inputs. Scaling acts on the primal and on every tangent entry alike, and
-- 'Zero' is absorbing for all three scaling operators.
instance (Traversable f, Num a) => Mode (Dense f a) where
  type Scalar (Dense f a) = a

  asKnownConstant x = case x of
    Zero    -> Just 0
    Lift c  -> Just c
    Dense{} -> Nothing

  isKnownConstant x = case x of
    Dense{} -> False
    _       -> True

  isKnownZero x = case x of
    Zero -> True
    _    -> False

  auto = Lift
  zero = Zero

  -- Scalar on the left: keep the scalar as the left factor throughout.
  s *^ x = case x of
    Zero       -> Zero
    Lift v     -> Lift (s * v)
    Dense v dv -> Dense (s * v) (fmap (s *) dv)

  -- Scalar on the right: keep the scalar as the right factor throughout.
  x ^* s = case x of
    Zero       -> Zero
    Lift v     -> Lift (v * s)
    Dense v dv -> Dense (v * s) (fmap (* s) dv)

  x ^/ s = case x of
    Zero       -> Zero
    Lift v     -> Lift (v / s)
    Dense v dv -> Dense (v / s) (fmap (/ s) dv)

-- | Add two numbers, summing primals and tangent vectors elementwise.
-- 'Zero' is the additive identity (the other operand is returned as is,
-- and is not forced when the left operand is 'Zero'); a 'Lift' only
-- contributes to the primal.
(<+>) :: (Traversable f, Num a) => Dense f a -> Dense f a -> Dense f a
x <+> y = case x of
  Zero -> y
  Lift p -> case y of
    Zero       -> x
    Lift q     -> Lift (p + q)
    Dense q dq -> Dense (p + q) dq
  Dense p dp -> case y of
    Zero       -> x
    Lift q     -> Dense (p + q) dp
    Dense q dq -> Dense (p + q) (zipWithT (+) dp dq)

-- | First-order chain-rule propagation for dense forward mode.
--
-- Derivatives arrive wrapped in 'Id'. Only a 'Dense' operand has a tangent
-- vector:
--
-- * If every operand is 'Lift' or 'Zero', the result is a plain 'Lift' of
--   the scalar result, with 'Zero' passed to the scalar function as @0@.
-- * If exactly one operand is 'Dense', its tangent is scaled by the
--   matching partial derivative.
-- * If both operands are 'Dense', their tangents are combined elementwise
--   with the product rule @dadb * dbi + dci * dadc@.
instance (Traversable f, Num a) => Jacobian (Dense f a) where
  type D (Dense f a) = Id a

  unary f pb x = case x of
      Zero       -> Lift (f 0)
      Lift b     -> Lift (f b)
      Dense b db -> Dense (f b) (fmap (dadb *) db)
    where
      dadb = runId pb

  lift1 f df x = case x of
      Zero       -> Lift (f 0)
      Lift b     -> Lift (f b)
      Dense b db ->
        let dadb = runId (df (Id b))
        in Dense (f b) (fmap (dadb *) db)

  -- The derivative function also receives the result of f.
  lift1_ f df x = case x of
      Zero       -> Lift (f 0)
      Lift b     -> Lift (f b)
      Dense b db ->
        let r    = f b
            dadb = runId (df (Id r) (Id b))
        in Dense r (fmap (dadb *) db)

  binary f pb pc x y = case x of
      Zero -> case y of
        Zero       -> Lift (f 0 0)
        Lift c     -> Lift (f 0 c)
        Dense c dc -> Dense (f 0 c) (fmap (* dadc) dc)
      Lift b -> case y of
        Zero       -> Lift (f b 0)
        Lift c     -> Lift (f b c)
        Dense c dc -> Dense (f b c) (fmap (* dadc) dc)
      Dense b db -> case y of
        Zero       -> Dense (f b 0) (fmap (dadb *) db)
        Lift c     -> Dense (f b c) (fmap (dadb *) db)
        Dense c dc -> Dense (f b c) (zipWithT productRule db dc)
    where
      dadb = runId pb
      dadc = runId pc
      productRule dbi dci = dadb * dbi + dci * dadc

  lift2 f df x y = case x of
      Zero -> case y of
        Zero       -> Lift (f 0 0)
        Lift c     -> Lift (f 0 c)
        Dense c dc -> rightOnly 0 c dc
      Lift b -> case y of
        Zero       -> Lift (f b 0)
        Lift c     -> Lift (f b c)
        Dense c dc -> rightOnly b c dc
      Dense b db -> case y of
        Zero       -> leftOnly b db 0
        Lift c     -> leftOnly b db c
        Dense c dc ->
          -- One call to df yields both partials; the binding stays lazy.
          let (Id dadb, Id dadc) = df (Id b) (Id c)
          in Dense (f b c) (zipWithT (productRule dadb dadc) db dc)
    where
      -- Only the right operand has a tangent: scale it by the right partial.
      rightOnly b c dc =
        let dadc = runId (snd (df (Id b) (Id c)))
        in Dense (f b c) (fmap (* dadc) dc)
      -- Only the left operand has a tangent: scale it by the left partial.
      leftOnly b db c =
        let dadb = runId (fst (df (Id b) (Id c)))
        in Dense (f b c) (fmap (dadb *) db)
      productRule dadb dadc dbi dci = dadb * dbi + dci * dadc

  -- As lift2, but the derivative function also receives the result of f.
  lift2_ f df x y = case x of
      Zero -> case y of
        Zero       -> Lift (f 0 0)
        Lift c     -> Lift (f 0 c)
        Dense c dc -> rightOnly 0 c dc
      Lift b -> case y of
        Zero       -> Lift (f b 0)
        Lift c     -> Lift (f b c)
        Dense c dc -> rightOnly b c dc
      Dense b db -> case y of
        Zero       -> leftOnly b db 0
        Lift c     -> leftOnly b db c
        Dense c dc ->
          let r = f b c
              (Id dadb, Id dadc) = df (Id r) (Id b) (Id c)
          in Dense r (zipWithT (productRule dadb dadc) db dc)
    where
      rightOnly b c dc =
        let r = f b c
            (_, Id dadc) = df (Id r) (Id b) (Id c)
        in Dense r (fmap (* dadc) dc)
      leftOnly b db c =
        let r = f b c
            (Id dadb, _) = df (Id r) (Id b) (Id c)
        in Dense r (fmap (dadb *) db)
      productRule dadb dadc dbi dci = dadb * dbi + dci * dadc

-- | Multiply two numbers using the product rule: the partial derivatives
-- of @x * y@ with respect to @x@ and @y@ are @y@ and @x@ respectively,
-- which is exactly the argument swap performed by @flip (,)@.
mul :: (Traversable f, Num a) => Dense f a -> Dense f a -> Dense f a
mul = lift2 (*) (flip (,))

-- Instance template parameters for the textually included file below.
-- BODY1(x) and BODY2(x,y) expand to an instance context that always
-- contains Traversable f plus one or two extra constraints, and HEAD
-- expands to the instance type (Dense f a).
-- NOTE(review): the included file is not visible here; it presumably
-- declares the numeric class instances for this type and relies on the
-- helpers defined above (such as primal, the addition operator and mul)
-- as well as imports that look unused in this module, such as
-- Data.Number.Erf and join -- confirm before removing any of them.
#define BODY1(x)   (Traversable f, x) =>
#define BODY2(x,y) (Traversable f, x, y) =>
#define HEAD (Dense f a)
#include "instances.h"