{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK not-home #-}

-----------------------------------------------------------------------------
-- |
-- Copyright   : (c) Edward Kmett 2010-2021
-- License     : BSD3
-- Maintainer  : ekmett@gmail.com
-- Stability   : experimental
-- Portability : GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Tower
  ( Tower(..)
  , zeroPad
  , zeroPadF
  , transposePadF
  , d
  , d'
  , withD
  , tangents
  , bundle
  , apply
  , getADTower
  , tower
  ) where

import Prelude hiding (all, sum)
import Control.Monad (join)
import Data.Foldable
import Data.Data (Data)
import Data.Number.Erf
import Data.Typeable (Typeable)
import Numeric.AD.Internal.Combinators
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | @Tower@ is an AD 'Mode' that calculates a tangent tower by forward AD,
-- and provides fast 'diffsUU', 'diffsUF'.
--
-- The wrapped list holds the primal value followed by successive
-- derivatives. The empty list is the canonical zero tower.
newtype Tower a = Tower { getTower :: [a] } deriving (Data, Typeable)

-- | Renders as @Tower [x0,x1,...]@, wrapped in parentheses when shown
-- above function-application precedence.
instance Show a => Show (Tower a) where
  showsPrec n (Tower as) =
    showParen (n > 10) $ showString "Tower " . showList as

-- Local combinators

-- | Extend a list with an infinite tail of zeros.
zeroPad :: Num a => [a] -> [a]
zeroPad xs = xs ++ repeat 0
{-# INLINE zeroPad #-}

-- | Extend a list of containers with an infinite tail of containers that
-- have the shape of the first element, with every entry replaced by zero.
--
-- The first element supplies the shape, so an empty input is rejected
-- with 'error'.
zeroPadF :: (Functor f, Num a) => [f a] -> [f a]
zeroPadF fxs@(fx:_) = fxs ++ repeat (0 <$ fx)
zeroPadF _ = error "zeroPadF :: empty list"
{-# INLINE zeroPadF #-}

-- | Transpose a container of lists into a list of containers.
--
-- At each step the head of every list is taken. Lists that have already
-- run out contribute @pad@ instead. The result ends once every list in
-- the container is empty.
transposePadF :: (Foldable f, Functor f) => a -> f [a] -> [f a]
transposePadF pad fx
  | all null fx = []
  | otherwise   = fmap headOr fx : transposePadF pad (drop 1 <$> fx)
  where
    -- head of the list, or the padding value once it is exhausted
    headOr []    = pad
    headOr (x:_) = x

-- | The element at index 1 of a primal-then-derivatives list, i.e. the
-- first derivative. Returns 0 when the list is too short to contain it.
d :: Num a => [a] -> a
d (_:da:_) = da
d _        = 0
{-# INLINE d #-}

-- | The primal value and first derivative of a primal-then-derivatives
-- list. Missing entries default to 0.
d' :: Num a => [a] -> (a, a)
d' (a:da:_) = (a, da)
d' (a:_)    = (a, 0)
d' _        = (0, 0)
{-# INLINE d' #-}

-- | Drop the primal value, leaving the tower of derivatives.
-- The empty tower maps to itself.
tangents :: Tower a -> Tower a
tangents (Tower xs) = Tower (drop 1 xs)
{-# INLINE tangents #-}

-- | True exactly when the tower is empty, i.e. it carries no values at all.
truncated :: Tower a -> Bool
truncated (Tower xs) = null xs
{-# INLINE truncated #-}

-- | Prepend a primal value to a tower of tangents.
-- For a non-empty tower @t@, @bundle (primal t) (tangents t) == t@.
bundle :: a -> Tower a -> Tower a
bundle a (Tower as) = Tower (a : as)
{-# INLINE bundle #-}

-- | Build a two-element tower from a primal value and its first derivative.
withD :: (a, a) -> Tower a
withD (a, da) = Tower [a, da]
{-# INLINE withD #-}

-- | Apply @f@ to the tower @[a, 1]@, i.e. @a@ with a unit first derivative.
apply :: Num a => (Tower a -> b) -> a -> b
apply f a = f (Tower [a, 1])
{-# INLINE apply #-}

-- | Unwrap a tower into its list of values (same as 'getTower').
getADTower :: Tower a -> [a]
getADTower = getTower
{-# INLINE getADTower #-}

-- | Wrap a list of values as a tower; the inverse of 'getADTower'.
tower :: [a] -> Tower a
tower = Tower

-- | The primal (first) value of a tower. The empty tower has primal 0.
primal :: Num a => Tower a -> a
primal (Tower (x:_)) = x
primal _             = 0

-- | A lifted constant is a one-element tower and 'zero' is the empty tower.
-- Scalar multiplication and division act on every element of the tower.
instance Num a => Mode (Tower a) where
  type Scalar (Tower a) = a

  auto a = Tower [a]

  -- Only the empty tower is recognised as zero; @Tower [0]@ is not.
  isKnownZero (Tower xs) = null xs

  -- Towers with no derivative entries are constants.
  asKnownConstant (Tower [])  = Just 0
  asKnownConstant (Tower [a]) = Just a
  asKnownConstant Tower{}     = Nothing

  isKnownConstant (Tower [])  = True
  isKnownConstant (Tower [_]) = True
  isKnownConstant Tower{}     = False

  zero = Tower []

  a *^ Tower bs = Tower (map (a *) bs)
  Tower as ^* b = Tower (map (* b) as)
  Tower as ^/ b = Tower (map (/ b) as)

infixr 6 <+>

-- | Element-wise sum of two towers. When one tower is shorter, its missing
-- entries count as zero, so the result is as long as the longer input.
(<+>) :: Num a => Tower a -> Tower a -> Tower a
Tower xs0 <+> Tower ys0 = Tower (go xs0 ys0)
  where
    -- Works on the raw lists so the result is built lazily, one element
    -- at a time, without re-wrapping each tail in 'Tower'.
    go [] ys = ys
    go xs [] = xs
    go (x:xs) (y:ys) = x + y : go xs ys

-- | Derivative propagation for towers. The partial derivatives ('D') are
-- themselves towers. Multiplying them with an argument's 'tangents'
-- therefore carries the higher-order entries along as well.
instance Num a => Jacobian (Tower a) where
  type D (Tower a) = Tower a

  -- Primal: @f@ applied to the argument's primal.
  -- Tangents: the argument's tangents times the supplied derivative @dadb@.
  unary f dadb b = bundle (f (primal b)) (tangents b * dadb)

  -- As 'unary', but the derivative tower is computed as @df b@.
  lift1 f df b = bundle (f (primal b)) (tangents b * df b)

  -- As 'lift1', but @df@ also receives the result itself. The result is
  -- defined recursively through a lazy local binding.
  lift1_ f df b = a
    where
      a = bundle (f (primal b)) (tangents b * df a b)

  binary f dadb dadc b c =
    bundle (f (primal b) (primal c)) (tangents b * dadb + tangents c * dadc)

  lift2 f df b c = bundle (f (primal b) (primal c)) tana
    where
      (dadb, dadc) = df b c
      tanb = tangents b
      tanc = tangents c
      -- An empty (truncated) tangent tower contributes nothing. Its product
      -- is skipped, so the matching partial derivative is never used.
      tana = case (truncated tanb, truncated tanc) of
        (False, False) -> tanb * dadb + tanc * dadc
        (True,  False) -> tanc * dadc
        (False, True)  -> tanb * dadb
        (True,  True)  -> zero

  -- As 'lift2', but @df@ also receives the result itself (lazy knot).
  lift2_ f df b c = a
    where
      a0 = f (primal b) (primal c)
      da = tangents b * dadb + tangents c * dadc
      a = bundle a0 da
      (dadb, dadc) = df a b c

-- | Multiply two towers using the general Leibniz rule. The reference
-- definition is
--
-- > mul xs ys = [ sum [xs!!j * ys!!(k-j)*bin k j | j <- [0..k]] | k <- [0..] ]
--
-- The implementation below is adapted for efficiency and handles finite
-- lists xs, ys:
--
-- * An empty tower on either side gives an empty tower.
-- * Non-empty towers of lengths @m@ and @n@ give a tower of length @m + n - 1@.
mul :: Num a => Tower a -> Tower a -> Tower a
mul (Tower []) _ = Tower []
mul (Tower (x0:xs0)) (Tower ys0) = Tower (convs' [1] [x0] xs0 ys0)
  where
    -- Phase 1 runs while the left tower still has unread elements.
    --   ps:  current row of Pascal's triangle
    --   rxs: left-tower prefix read so far, most recent first
    --   xs:  unread rest of the left tower
    --   ys:  the right tower
    convs' _ _ _ [] = []
    convs' ps rxs xs ys = sumProd3 ps rxs ys :
      case xs of
        []      -> convs'' (next' ps) rxs ys
        x : xs' -> convs' (next ps) (x : rxs) xs' ys

    -- Phase 2 runs once the left tower is exhausted. It drops one element
    -- of the right tower per emitted coefficient and stops when only one
    -- element remains. It is always entered with a non-empty list, because
    -- convs' returns [] before reaching it when ys is empty.
    convs'' _ _ [] =
      error "Numeric.AD.Internal.Tower.mul: impossible, convs'' reached with an empty list"
    convs'' _ _ [_] = []
    convs'' ps rxs (_ : ys) = sumProd3 ps rxs ys : convs'' (next' ps) rxs ys

    -- next row in Pascal's triangle
    next xs = 1 : next' xs
    -- end part of next row in Pascal's triangle (the row minus its leading 1)
    next' xs = zipWith (+) xs (drop 1 xs) ++ [1]

    sumProd3 as bs cs = sum (zipWith3 (\x y z -> x * y * z) as bs cs)

#define HEAD (Tower a)
#include <instances.h>