{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK not-home #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2010-2015
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Unsafe and often partial combinators intended for internal usage.
--
-- Handle with care.
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Forward.Double
  ( ForwardDouble(..)
  , bundle
  , unbundle
  , apply
  , bind
  , bind'
  , bindWith
  , bindWith'
  , transposeWith
  ) where

#if __GLASGOW_HASKELL__ < 710
import Control.Applicative hiding ((<**>))
import Data.Foldable (Foldable, toList)
import Data.Traversable (Traversable, mapAccumL)
#else
import Data.Foldable (toList)
import Data.Traversable (mapAccumL)
#endif
import Control.Monad (join)
import Data.Function (on)
import Data.Number.Erf
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode

-- | A forward-mode value specialised to 'Double': the 'primal' value
-- together with the 'tangent' (derivative) carried alongside it.
-- Both fields are strict and unpacked.
data ForwardDouble = ForwardDouble
  { primal, tangent :: {-# UNPACK #-} !Double
  } deriving (Read, Show)

-- | Split a 'ForwardDouble' into its @(primal, tangent)@ pair.
unbundle :: ForwardDouble -> (Double, Double)
unbundle (ForwardDouble a da) = (a, da)
{-# INLINE unbundle #-}

-- | Build a 'ForwardDouble' from a primal value and a tangent.
bundle :: Double -> Double -> ForwardDouble
bundle = ForwardDouble
{-# INLINE bundle #-}

-- | Apply @f@ to @a@ seeded with tangent @1@, so the tangent of the
-- result is the derivative of @f@ at @a@.
apply :: (ForwardDouble -> b) -> Double -> b
apply f a = f (bundle a 1)
{-# INLINE apply #-}

instance Mode ForwardDouble where
  type Scalar ForwardDouble = Double

  -- A constant: tangent 0.
  auto = flip ForwardDouble 0

  zero = ForwardDouble 0 0

  isKnownZero (ForwardDouble 0 0) = True
  isKnownZero _ = False

  isKnownConstant (ForwardDouble _ 0) = True
  isKnownConstant _ = False

  -- Scaling by a scalar scales primal and tangent alike.
  a *^ ForwardDouble b db = ForwardDouble (a * b) (a * db)

  ForwardDouble a da ^* b = ForwardDouble (a * b) (da * b)

  ForwardDouble a da ^/ b = ForwardDouble (a / b) (da / b)

-- | Addition: primals and tangents add componentwise.
(<+>) :: ForwardDouble -> ForwardDouble -> ForwardDouble
ForwardDouble a da <+> ForwardDouble b db = ForwardDouble (a + b) (da + db)

instance Jacobian ForwardDouble where
  type D ForwardDouble = Id Double

  unary f (Id dadb) (ForwardDouble b db) = ForwardDouble (f b) (dadb * db)

  -- The derivative function receives the input.
  lift1 f df (ForwardDouble b db) = ForwardDouble (f b) (dadb * db)
    where
      Id dadb = df (Id b)

  -- The derivative function receives the output first, then the input.
  lift1_ f df (ForwardDouble b db) = ForwardDouble a da
    where
      a = f b
      Id da = df (Id a) (Id b) ^* db

  binary f (Id dadb) (Id dadc) (ForwardDouble b db) (ForwardDouble c dc) =
    ForwardDouble (f b c) $ dadb * db + dc * dadc

  lift2 f df (ForwardDouble b db) (ForwardDouble c dc) = ForwardDouble a da
    where
      a = f b c
      (Id dadb, Id dadc) = df (Id b) (Id c)
      da = dadb * db + dc * dadc

  lift2_ f df (ForwardDouble b db) (ForwardDouble c dc) = ForwardDouble a da
    where
      a = f b c
      (Id dadb, Id dadc) = df (Id a) (Id b) (Id c)
      da = dadb * db + dc * dadc

-- Equality and ordering look only at the primal value.
instance Eq ForwardDouble where
  (==) = on (==) primal

instance Ord ForwardDouble where
  compare = on compare primal

instance Num ForwardDouble where
  fromInteger 0 = zero
  fromInteger n = auto (fromInteger n)
  (+) = (<+>) -- binary (+) 1 1
  (-) = binary (-) (auto 1) (auto (-1)) -- TODO: <-> ? as it is, this might be pretty bad for Tower
  (*) = lift2 (*) (\x y -> (y, x))
  negate = lift1 negate (const (auto (-1)))
  abs = lift1 abs signum
  signum a = lift1 signum (const zero) a

instance Fractional ForwardDouble where
  fromRational 0 = zero
  fromRational r = auto (fromRational r)
  x / y = x * recip y
  -- d/dx (1/x) = -(1/x)^2, written in terms of the output.
  recip = lift1_ recip (const . negate . join (*))

instance Floating ForwardDouble where
  pi = auto pi
  exp = lift1_ exp const
  log = lift1 log recip
  logBase x y = log y / log x
  sqrt = lift1_ sqrt (\z _ -> recip (auto 2 * z))
  -- A constant zero base or a constant zero exponent yields a constant.
  ForwardDouble 0 0 ** ForwardDouble a _ = ForwardDouble (0 ** a) 0
  _ ** ForwardDouble 0 0 = ForwardDouble 1 0
  -- Constant exponent: d/dx x**y = y * x**(y-1).
  x ** ForwardDouble y 0 = lift1 (**y) (\z -> y *^ z ** Id (y - 1)) x
  -- General case, with z = x**y: partials are (y*z/x, z*log x).
  x ** y = lift2_ (**) (\z xi yi -> (yi * z / xi, z * log xi)) x y
  sin = lift1 sin cos
  cos = lift1 cos $ negate . sin
  tan = lift1 tan $ recip . join (*) . cos
  asin = lift1 asin $ \x -> recip (sqrt (auto 1 - join (*) x))
  acos = lift1 acos $ \x -> negate (recip (sqrt (1 - join (*) x)))
  atan = lift1 atan $ \x -> recip (1 + join (*) x)
  sinh = lift1 sinh cosh
  cosh = lift1 cosh sinh
  tanh = lift1 tanh $ recip . join (*) . cosh
  asinh = lift1 asinh $ \x -> recip (sqrt (1 + join (*) x))
  acosh = lift1 acosh $ \x -> recip (sqrt (join (*) x - 1))
  atanh = lift1 atanh $ \x -> recip (1 - join (*) x)

instance Enum ForwardDouble where
  succ = lift1 succ (const 1)
  pred = lift1 pred (const 1)
  toEnum = auto . toEnum
  fromEnum = fromEnum . primal
  enumFrom a = withPrimal a <$> enumFrom (primal a)
  enumFromTo a b = withPrimal a <$> enumFromTo (primal a) (primal b)
  enumFromThen a b =
    zipWith (fromBy a delta) [0..] $ enumFromThen (primal a) (primal b)
    where
      delta = b - a
  enumFromThenTo a b c =
    zipWith (fromBy a delta) [0..] $ enumFromThenTo (primal a) (primal b) (primal c)
    where
      delta = b - a

instance Real ForwardDouble where
  toRational = toRational . primal

instance RealFloat ForwardDouble where
  floatRadix = floatRadix . primal
  floatDigits = floatDigits . primal
  floatRange = floatRange . primal
  decodeFloat = decodeFloat . primal
  encodeFloat m e = auto (encodeFloat m e)
  isNaN = isNaN . primal
  isInfinite = isInfinite . primal
  isDenormalized = isDenormalized . primal
  isNegativeZero = isNegativeZero . primal
  isIEEE = isIEEE . primal
  -- Must project to the primal; a bare @exponent = exponent@ would call
  -- this same method again and never terminate.
  exponent = exponent . primal
  -- scaleFloat n x = x * 2^n, so its derivative is the constant 2^n.
  scaleFloat n = unary (scaleFloat n) (scaleFloat n 1)
  -- Within a binade, significand x = x * 2^^(-exponent x), so the
  -- derivative is 2^^(-exponent x) (not 2^^(-floatDigits x)).
  significand x =
    unary significand (scaleFloat (negate (exponent (primal x))) 1) x
  -- atan2 y x: partials are (x, -y) / (x^2 + y^2); here vx is the first
  -- argument (y) and vy the second (x).
  atan2 = lift2 atan2 $ \vx vy ->
    let r = recip (join (*) vx + join (*) vy)
    in (vy * r, negate vx * r)

instance RealFrac ForwardDouble where
  properFraction a = (w, a `withPrimal` pb)
    where
      pa = primal a
      (w, pb) = properFraction pa
  truncate = truncate . primal
  round = round . primal
  ceiling = ceiling . primal
  floor = floor . primal

instance Erf ForwardDouble where
  erf = lift1 erf $ \x -> (2 / sqrt pi) * exp (negate x * x)
  erfc = lift1 erfc $ \x -> ((-2) / sqrt pi) * exp (negate x * x)
  -- d/dx normcdf x is the standard normal density exp(-x^2/2) / sqrt(2*pi).
  normcdf = lift1 normcdf $ \x -> recip (sqrt (2 * pi)) * exp (negate (x * x) / 2)

-- The derivative of an inverse function g = f^-1 at x is 1 / f'(g x), so
-- each derivative is expressed in terms of the output y via 'lift1_'.
instance InvErf ForwardDouble where
  -- 1 / erf'(y) = (sqrt pi / 2) * exp (y^2)
  inverf = lift1_ inverf $ \y _ -> (sqrt pi / 2) * exp (y * y)
  -- 1 / erfc'(y) = -(sqrt pi / 2) * exp (y^2)
  inverfc = lift1_ inverfc $ \y _ -> negate (sqrt pi / 2) * exp (y * y)
  -- 1 / normcdf'(y) = sqrt (2*pi) * exp (y^2 / 2)
  invnormcdf = lift1_ invnormcdf $ \y _ -> sqrt (2 * pi) * exp (y * y / 2)

-- | For each position @i@ of @as@, evaluate @f@ on a copy of @as@ in which
-- only position @i@ carries tangent 1 and every other element is a
-- constant. Returns one result per position, in traversal order.
bind :: Traversable f => (f ForwardDouble -> b) -> f Double -> f b
bind f as = snd $ mapAccumL outer (0 :: Int) as
  where
    outer !i _ = (i + 1, f $ snd $ mapAccumL (inner i) 0 as)
    inner !i !j a = (j + 1, if i == j then bundle a 1 else auto a)

-- | Like 'bind', but also returns one extra result. It is the value of
-- @f@ at the last seeded position, or of @f@ on the all-constant container
-- when @as@ is empty. Presumably callers read only its primal part, which
-- does not depend on seeding — NOTE(review): confirm at call sites.
bind' :: Traversable f => (f ForwardDouble -> b) -> f Double -> (b, f b)
bind' f as = dropIx $ mapAccumL outer (0 :: Int, b0) as
  where
    outer (!i, _) _ =
      let b = f $ snd $ mapAccumL (inner i) (0 :: Int) as
      in ((i + 1, b), b)
    inner !i !j a = (j + 1, if i == j then bundle a 1 else auto a)
    b0 = f (auto <$> as)
    dropIx ((_, b), bs) = (b, bs)

-- | Like 'bind', but combines each result with the corresponding original
-- input element using @g@.
bindWith :: Traversable f => (Double -> b -> c) -> (f ForwardDouble -> b) -> f Double -> f c
bindWith g f as = snd $ mapAccumL outer (0 :: Int) as
  where
    outer !i a = (i + 1, g a $ f $ snd $ mapAccumL (inner i) 0 as)
    inner !i !j a = (j + 1, if i == j then bundle a 1 else auto a)

-- | Combination of 'bind'' and 'bindWith': per-position results are combined
-- with their input element via @g@, and the extra result is returned as
-- in 'bind''.
bindWith' :: Traversable f => (Double -> b -> c) -> (f ForwardDouble -> b) -> f Double -> (b, f c)
-- Walk the container once. At position k, evaluate f on the input with only
-- position k seeded (tangent 1), pair the result with the original element
-- through g, and remember it. The remembered value starts as f applied to the
-- all-constant container, so it is returned unchanged for an empty input.
bindWith' g f as =
  case mapAccumL step (0 :: Int, f (fmap auto as)) as of
    ((_, lastResult), combined) -> (lastResult, combined)
  where
    step (!k, _) x =
      let r = f (seededAt k)
      in ((k + 1, r), g x r)
    -- The input container in which only position k carries tangent 1.
    seededAt k = snd (mapAccumL (mark k) (0 :: Int) as)
    mark !k !j x = (j + 1, if k == j then bundle x 1 else auto x)

-- | Traverse @gb@ left to right, handing each element @b@ to @f@ together
-- with the next not-yet-consumed element of every inner structure in @as@.
-- This effectively transposes @as@ against the shape of @gb@.
--
-- Partial: it uses 'head' and 'tail' on the flattened inner structures. If an
-- inner structure is shorter than @gb@, an error is raised when the missing
-- element is demanded.
transposeWith :: (Functor f, Foldable f, Traversable g) => (b -> f a -> c) -> f (g a) -> g b -> g c
transposeWith f as gb = snd (mapAccumL step columns gb)
  where
    columns = fmap toList as
    step remaining b = (fmap tail remaining, f b (fmap head remaining))