{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}

-----------------------------------------------------------------------------
-- |
-- Copyright   : (c) Edward Kmett 2010-2021
-- License     : BSD3
-- Maintainer  : ekmett@gmail.com
-- Stability   : experimental
-- Portability : GHC only
--
-- Dense Forward AD. Useful when the result involves the majority of the input
-- elements. Do not use for 'Numeric.AD.Mode.Mixed.hessian' and beyond, since
-- they only contain a small number of unique @n@th derivatives --
-- @(n + k - 1) `choose` k@ for functions of @k@ inputs rather than the
-- @k^n@ that would be generated by using 'Dense', not to mention the redundant
-- intermediate derivatives that would be
-- calculated over and over during that process!
--
-- Assumes every @f a@ container involved in a single computation has the
-- same shape (the same number of elements), since tangents are combined
-- entrywise.
--
-- NB: We don't need the full power of 'Traversable' here, we could get
-- by with a notion of zippable that can plug in 0's for the missing
-- entries. This might allow for gradients where @f@ has exponentials like @((->) a)@
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Dense
  ( Dense(..)
  , ds
  , ds'
  , vars
  , apply
  ) where

import Control.Monad (join)
import Data.Typeable ()
import Data.Traversable (mapAccumL)
import Data.Data ()
import Data.Number.Erf
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | A forward-mode AD value whose tangent is stored densely, as one
-- partial derivative per input slot of the container @f@.
data Dense f a
  = -- | A constant: it has a primal value but no tangent vector.
    Lift !a
  | -- | A primal value together with its tangent vector.
    Dense !a (f a)
  | -- | The constant zero, represented without any payload.
    Zero

-- | Shows only the primal value; the tangent vector is never printed.
-- 'Zero' is rendered as the literal @0@ (no 'Num' constraint is needed).
instance Show a => Show (Dense f a) where
  showsPrec d (Lift a)    = showsPrec d a
  showsPrec d (Dense a _) = showsPrec d a
  showsPrec _ Zero        = showString "0"

-- | Extract the tangent vector of a 'Dense' value.
--
-- 'Lift' and 'Zero' carry no tangent, so the caller-supplied default @z@
-- is returned for them.
ds :: f a -> Dense f a -> f a
ds _ (Dense _ da) = da
ds z _            = z
{-# INLINE ds #-}

-- | Extract the primal value and the tangent vector together.
--
-- For tangent-free values ('Lift', 'Zero') the caller-supplied default
-- @z@ is used as the tangent; the primal of 'Zero' is @0@.
ds' :: Num a => f a -> Dense f a -> (a, f a)
ds' _ (Dense a da) = (a, da)
ds' z (Lift a)     = (a, z)
ds' z Zero         = (0, z)
{-# INLINE ds' #-}

-- | Bind variables and count inputs.
--
-- Each element of the container becomes an independent input: the element
-- at position @i@ keeps its value as the primal, and its tangent is the
-- @i@-th basis vector (1 at position @i@, 0 everywhere else). Every tangent
-- is built by traversing @as@ again, so seeding is quadratic in the number
-- of elements.
vars :: (Traversable f, Num a) => f a -> f (Dense f a)
vars as = snd $ mapAccumL outer (0 :: Int) as
  where
    -- Number the elements left to right and seed each with its basis vector.
    outer !i a = (i + 1, Dense a $ snd $ mapAccumL (inner i) 0 as)
    -- Walk positions @j@, emitting 1 exactly where @j == i@.
    inner !i !j _ = (j + 1, if i == j then 1 else 0)
{-# INLINE vars #-}

-- | Run a function on a container whose elements have been seeded as
-- independent 'Dense' variables by 'vars'.
apply :: (Traversable f, Num a) => (f (Dense f a) -> b) -> f a -> b
apply f as = f (vars as)
{-# INLINE apply #-}

-- | The primal (non-derivative) value of a 'Dense' number; 'Zero' is @0@.
primal :: Num a => Dense f a -> a
primal Zero        = 0
primal (Lift a)    = a
primal (Dense a _) = a

-- | 'Dense' values form an AD mode whose scalars are the underlying @a@.
--
-- Scaling by a scalar multiplies (or divides) the primal and every entry
-- of the tangent vector; 'Zero' is returned unchanged by all three
-- scaling operators.
instance (Num a, Traversable f) => Mode (Dense f a) where
  type Scalar (Dense f a) = a

  -- Only tangent-free values have a statically known value.
  asKnownConstant (Lift a) = Just a
  asKnownConstant Zero     = Just 0
  asKnownConstant _        = Nothing

  isKnownConstant Dense{} = False
  isKnownConstant _       = True

  -- Only the 'Zero' constructor counts; a @Lift 0@ is not detected.
  isKnownZero Zero = True
  isKnownZero _    = False

  auto = Lift
  zero = Zero

  _ *^ Zero       = Zero
  a *^ Lift b     = Lift (a * b)
  a *^ Dense b db = Dense (a * b) $ fmap (a *) db

  Zero       ^* _ = Zero
  Lift a     ^* b = Lift (a * b)
  Dense a da ^* b = Dense (a * b) $ fmap (* b) da

  Zero       ^/ _ = Zero
  Lift a     ^/ b = Lift (a / b)
  Dense a da ^/ b = Dense (a / b) $ fmap (/ b) da

-- | Addition of 'Dense' values.
--
-- 'Zero' acts as the identity. A 'Lift' only contributes to the primal,
-- so the other operand's tangent (if any) is kept as-is. Two tangent
-- vectors are summed entrywise with 'zipWithT'.
(<+>) :: (Traversable f, Num a) => Dense f a -> Dense f a -> Dense f a
Zero       <+> a          = a
a          <+> Zero       = a
Lift a     <+> Lift b     = Lift (a + b)
Lift a     <+> Dense b db = Dense (a + b) db
Dense a da <+> Lift b     = Dense (a + b) da
Dense a da <+> Dense b db = Dense (a + b) $ zipWithT (+) da db

-- | Chain-rule plumbing for 'Dense'. Derivatives of the lifted scalar
-- functions are exchanged as 'Id'-wrapped scalars (@'D' = 'Id' a@).
--
-- Every method has the same shape. When no argument carries a tangent
-- ('Zero' or 'Lift'), only the primal is computed and the result is a
-- 'Lift'. Otherwise each tangent vector is scaled by the corresponding
-- partial derivative, and when both arguments are 'Dense' the two scaled
-- tangents are summed entrywise with 'zipWithT'. 'Zero' enters the scalar
-- computations as @0@.
instance (Traversable f, Num a) => Jacobian (Dense f a) where
  type D (Dense f a) = Id a

  -- unary: the derivative at the argument is supplied directly.
  unary f _         Zero         = Lift (f 0)
  unary f _         (Lift b)     = Lift (f b)
  unary f (Id dadb) (Dense b db) = Dense (f b) (fmap (dadb *) db)

  -- lift1: the derivative is obtained by applying @df@ to the argument.
  lift1 f _  Zero         = Lift (f 0)
  lift1 f _  (Lift b)     = Lift (f b)
  lift1 f df (Dense b db) = Dense (f b) (fmap (dadb *) db)
    where
      Id dadb = df (Id b)

  -- lift1_: like 'lift1', but @df@ receives the result first and the
  -- argument second.
  lift1_ f _  Zero         = Lift (f 0)
  lift1_ f _  (Lift b)     = Lift (f b)
  lift1_ f df (Dense b db) = Dense a (fmap (dadb *) db)
    where
      a = f b
      Id dadb = df (Id a) (Id b)

  -- binary: both partial derivatives are supplied directly; a partial is
  -- only used when its argument carries a tangent.
  binary f _         _         Zero         Zero         = Lift (f 0 0)
  binary f _         _         Zero         (Lift c)     = Lift (f 0 c)
  binary f _         _         (Lift b)     Zero         = Lift (f b 0)
  binary f _         _         (Lift b)     (Lift c)     = Lift (f b c)
  binary f _         (Id dadc) Zero         (Dense c dc) = Dense (f 0 c) $ fmap (* dadc) dc
  binary f _         (Id dadc) (Lift b)     (Dense c dc) = Dense (f b c) $ fmap (* dadc) dc
  binary f (Id dadb) _         (Dense b db) Zero         = Dense (f b 0) $ fmap (dadb *) db
  binary f (Id dadb) _         (Dense b db) (Lift c)     = Dense (f b c) $ fmap (dadb *) db
  binary f (Id dadb) (Id dadc) (Dense b db) (Dense c dc) = Dense (f b c) $ zipWithT productRule db dc
    where
      productRule dbi dci = dadb * dbi + dci * dadc

  -- lift2: @df@ returns both partials as a pair; when only one argument is
  -- 'Dense', only the matching component of the pair is used.
  lift2 f _  Zero         Zero         = Lift (f 0 0)
  lift2 f _  Zero         (Lift c)     = Lift (f 0 c)
  lift2 f _  (Lift b)     Zero         = Lift (f b 0)
  lift2 f _  (Lift b)     (Lift c)     = Lift (f b c)
  lift2 f df Zero         (Dense c dc) = Dense (f 0 c) $ fmap (* dadc) dc
    where
      dadc = runId (snd (df (Id 0) (Id c)))
  lift2 f df (Lift b)     (Dense c dc) = Dense (f b c) $ fmap (* dadc) dc
    where
      dadc = runId (snd (df (Id b) (Id c)))
  lift2 f df (Dense b db) Zero         = Dense (f b 0) $ fmap (dadb *) db
    where
      dadb = runId (fst (df (Id b) (Id 0)))
  lift2 f df (Dense b db) (Lift c)     = Dense (f b c) $ fmap (dadb *) db
    where
      dadb = runId (fst (df (Id b) (Id c)))
  lift2 f df (Dense b db) (Dense c dc) = Dense (f b c) da
    where
      (Id dadb, Id dadc) = df (Id b) (Id c)
      da = zipWithT productRule db dc
      productRule dbi dci = dadb * dbi + dci * dadc

  -- lift2_: like 'lift2', but @df@ receives the result first, followed by
  -- both arguments.
  lift2_ f _  Zero     Zero     = Lift (f 0 0)
  lift2_ f _  Zero     (Lift c) = Lift (f 0 c)
  lift2_ f _  (Lift b) Zero     = Lift (f b 0)
  lift2_ f _  (Lift b) (Lift c) = Lift (f b c)
  lift2_ f df Zero     (Dense c dc) = Dense a $ fmap (* dadc) dc
    where
      a = f 0 c
      (_, Id dadc) = df (Id a) (Id 0) (Id c)
  lift2_ f df (Lift b) (Dense c dc) = Dense a $ fmap (* dadc) dc
    where
      a = f b c
      (_, Id dadc) = df (Id a) (Id b) (Id c)
  lift2_ f df (Dense b db) Zero = Dense a $ fmap (dadb *) db
    where
      a = f b 0
      (Id dadb, _) = df (Id a) (Id b) (Id 0)
  lift2_ f df (Dense b db) (Lift c) = Dense a $ fmap (dadb *) db
    where
      a = f b c
      (Id dadb, _) = df (Id a) (Id b) (Id c)
  lift2_ f df (Dense b db) (Dense c dc) = Dense a $ zipWithT productRule db dc
    where
      a = f b c
      (Id dadb, Id dadc) = df (Id a) (Id b) (Id c)
      productRule dbi dci = dadb * dbi + dci * dadc

-- | Multiplication of 'Dense' values via 'lift2' (product rule).
--
-- The partial derivatives of @b * c@ with respect to @b@ and @c@ are @c@
-- and @b@ respectively, which is why the derivative function swaps its
-- arguments.
mul :: (Traversable f, Num a) => Dense f a -> Dense f a -> Dense f a
mul = lift2 (*) (\x y -> (y, x))

#define BODY1(x)   (Traversable f, x) =>
#define BODY2(x,y) (Traversable f, x, y) =>
#define HEAD (Dense f a)
#include "instances.h"