{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Tower
  ( Tower(..)
  , zeroPad
  , zeroPadF
  , transposePadF
  , d
  , d'
  , withD
  , tangents
  , bundle
  , apply
  , getADTower
  , tower
  ) where

import Prelude hiding (all, sum)
import Control.Monad (join)
import Data.Foldable
import Data.Data (Data)
import Data.Number.Erf
import Data.Typeable (Typeable)
import Numeric.AD.Internal.Combinators
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | @Tower@ is an AD 'Mode' that calculates a tangent tower by forward AD, and provides fast 'diffsUU', 'diffsUF'
--
-- The wrapped list holds the primal value followed by its successive
-- derivatives. Entries missing from the end of the list are treated as
-- zero, so @Tower []@ is the zero tower and a singleton is a constant.
newtype Tower a = Tower { getTower :: [a] } deriving (Data, Typeable)

instance Show a => Show (Tower a) where
  -- Renders as @Tower [x0,x1,...]@, parenthesised above application precedence.
  showsPrec n (Tower as) = showParen (n > 10) $ showString "Tower " . showList as

-- Local combinators

-- | Extend a list with an infinite tail of zeros.
zeroPad :: Num a => [a] -> [a]
zeroPad xs = xs ++ repeat 0
{-# INLINE zeroPad #-}

-- | Extend a non-empty list of containers with an infinite tail of
-- all-zero containers, each shaped like the first element.
--
-- Calls 'error' on an empty list, because there is no element to take the
-- shape from.
zeroPadF :: (Functor f, Num a) => [f a] -> [f a]
zeroPadF fxs@(fx:_) = fxs ++ repeat (0 <$ fx)
zeroPadF _ = error "zeroPadF :: empty list"
{-# INLINE zeroPadF #-}

-- | Transpose a container of lists into a list of containers.
--
-- Lists that run out early are filled with @pad@. The result ends once
-- every list in the container is exhausted.
transposePadF :: (Foldable f, Functor f) => a -> f [a] -> [f a]
transposePadF pad fx
  | all null fx = []
  | otherwise = fmap headPad fx : transposePadF pad (drop1 <$> fx)
  where
    -- first element, or the padding value for an exhausted list
    headPad [] = pad
    headPad (x:_) = x
    -- total variant of 'tail': an empty list stays empty
    drop1 (_:xs) = xs
    drop1 xs = xs

-- | The second coefficient of a tower list (its first derivative), or @0@
-- when the list is shorter than two elements.
d :: Num a => [a] -> a
d (_:da:_) = da
d _ = 0
{-# INLINE d #-}

-- | The primal value and first derivative of a tower list, defaulting
-- missing entries to @0@.
d' :: Num a => [a] -> (a, a)
d' (a:da:_) = (a, da)
d' (a:_)    = (a, 0)
d' _        = (0, 0)
{-# INLINE d' #-}

-- | Drop the primal value, leaving the tower of derivatives.
-- The empty tower maps to itself.
tangents :: Tower a -> Tower a
tangents (Tower []) = Tower []
tangents (Tower (_:xs)) = Tower xs
{-# INLINE tangents #-}

-- | 'True' exactly when the tower is empty, i.e. identically zero.
truncated :: Tower a -> Bool
truncated (Tower []) = True
truncated _ = False
{-# INLINE truncated #-}

-- | Prepend a primal value to a tower of derivatives (inverse of
-- splitting into 'primal' and 'tangents' for non-empty towers).
bundle :: a -> Tower a -> Tower a
bundle a (Tower as) = Tower (a:as)
{-# INLINE bundle #-}

-- | Build a tower from a value and its first derivative.
withD :: (a, a) -> Tower a
withD (a, da) = Tower [a,da]
{-# INLINE withD #-}

-- | Apply a tower function to the independent variable @a@, seeded with
-- first derivative @1@.
apply :: Num a => (Tower a -> b) -> a -> b
apply f a = f (Tower [a,1])
{-# INLINE apply #-}

-- | Unwrap a tower into its list of coefficients.
getADTower :: Tower a -> [a]
getADTower = getTower
{-# INLINE getADTower #-}

-- | Wrap a list of coefficients as a tower.
tower :: [a] -> Tower a
tower = Tower
{-# INLINE tower #-}

-- | The primal (zeroth) coefficient, or @0@ for the empty tower.
primal :: Num a => Tower a -> a
primal (Tower (x:_)) = x
primal _ = 0

instance Num a => Mode (Tower a) where
  type Scalar (Tower a) = a

  -- A constant is a single-coefficient tower.
  auto a = Tower [a]

  -- Only the empty tower is known to be zero without inspecting values.
  isKnownZero (Tower xs) = null xs

  asKnownConstant (Tower []) = Just 0
  asKnownConstant (Tower [a]) = Just a
  asKnownConstant Tower{} = Nothing

  isKnownConstant (Tower []) = True
  isKnownConstant (Tower [_]) = True
  isKnownConstant Tower{} = False

  zero = Tower []

  -- Scaling acts coefficient-wise.
  a *^ Tower bs = Tower (map (a*) bs)
  Tower as ^* b = Tower (map (*b) as)
  Tower as ^/ b = Tower (map (/b) as)

infixr 6 <+>

-- | Add two towers coefficient-wise. The shorter tower is implicitly padded
-- with zeros, so the result is as long as the longer operand.
(<+>) :: Num a => Tower a -> Tower a -> Tower a
Tower [] <+> bs = bs
as <+> Tower [] = as
Tower (a:as) <+> Tower (b:bs) = Tower (c:cs)
  where
    c = a + b
    Tower cs = Tower as <+> Tower bs

instance Num a => Jacobian (Tower a) where
  type D (Tower a) = Tower a
  unary f dadb b = bundle (f (primal b)) (tangents b * dadb)
  lift1 f df b   = bundle (f (primal b)) (tangents b * df b)
  lift1_ f df b = a where
    -- knot-tied: the derivative function receives the result tower itself
    a = bundle (f (primal b)) (tangents b * df a b)

  binary f dadb dadc b c = bundle (f (primal b) (primal c)) (tangents b * dadb + tangents c * dadc)
  lift2 f df b c = bundle (f (primal b) (primal c)) tana
    where
      (dadb, dadc) = df b c
      tanb = tangents b
      tanc = tangents c
      -- skip the product term for an operand whose tangent tower is empty
      tana = case (truncated tanb, truncated tanc) of
        (False, False) -> tanb * dadb + tanc * dadc
        (True, False) -> tanc * dadc
        (False, True) -> tanb * dadb
        (True, True) -> zero
  lift2_ f df b c = a where
    a0 = f (primal b) (primal c)
    da = tangents b * dadb + tangents c * dadc
    -- knot-tied: the partials may depend on the result tower itself
    a = bundle a0 da
    (dadb, dadc) = df a b c

-- | Multiply two towers by the Leibniz rule: coefficient @k@ of the product
-- is the sum over @j@ of @(k choose j) * xs_j * ys_(k-j)@.
--
-- The reference specification is
--
-- > mul xs ys = [ sum [xs!!j * ys!!(k-j)*bin k j | j <- [0..k]] | k <- [0..] ]
--
-- adapted here for efficiency and to handle finite lists @xs@, @ys@. For
-- finite non-empty inputs, the result has @length xs + length ys - 1@
-- coefficients.
mul :: Num a => Tower a -> Tower a -> Tower a
mul (Tower []) _ = Tower []
mul (Tower (a:as)) (Tower bs) = Tower (convs' [1] [a] as bs)
  where
    -- Arguments of convs':
    --   ps  = current row of Pascal's triangle
    --   ars = left-operand coefficients consumed so far, in reverse order
    --   as  = left-operand coefficients not yet consumed
    --   bs  = the whole right operand
    convs' _ _ _ [] = []
    convs' ps ars as bs = sumProd3 ps ars bs :
      case as of
        [] -> convs'' (next' ps) ars bs
        a:as -> convs' (next ps) (a:ars) as bs
    -- After the left operand is exhausted, slide along the right operand
    -- until only its last element remains.
    convs'' _ _ [] = error "Numeric.AD.Internal.Tower.mul: unreachable, right operand exhausted"
    convs'' _ _ [_] = []
    convs'' ps ars (_:bs) = sumProd3 ps ars bs :
      convs'' (next' ps) ars bs
    -- next row in Pascal's triangle
    next xs = 1 : next' xs
    -- end part of next row in Pascal's triangle (rows are never empty, so
    -- drop 1 coincides with tail here while staying total)
    next' xs = zipWith (+) xs (drop 1 xs) ++ [1]
    sumProd3 as bs cs = sum (zipWith3 (\x y z -> x*y*z) as bs cs)

#define HEAD (Tower a)
#include <instances.h>