{-# LANGUAGE Rank2Types, GeneralizedNewtypeDeriving, TemplateHaskell, FlexibleContexts, FlexibleInstances, MultiParamTypeClasses, DeriveFunctor, DeriveFoldable, DeriveTraversable #-} -- {-# OPTIONS_HADDOCK hide, prune #-} ----------------------------------------------------------------------------- -- | -- Module : Numeric.AD.Internal -- Copyright : (c) Edward Kmett 2010 -- License : BSD3 -- Maintainer : ekmett@gmail.com -- Stability : experimental -- Portability : GHC only -- ----------------------------------------------------------------------------- module Numeric.AD.Internal ( module Numeric.AD.Internal.Classes , UU, UF, FU, FF , zipWithT , zipWithDefaultT , on , AD(..) , Id(..) , probe , unprobe , probed , unprobed , Pair(..) ) where import Control.Applicative import Language.Haskell.TH import Numeric.AD.Internal.Classes import Data.Monoid import Data.Traversable (Traversable, mapAccumL) import Data.Foldable (Foldable, toList) -- | A scalar-to-scalar automatically-differentiable function. type UU a = forall s. Mode s => AD s a -> AD s a -- | A scalar-to-non-scalar automatically-differentiable function. type UF f a = forall s. Mode s => AD s a -> f (AD s a) -- | A non-scalar-to-scalar automatically-differentiable function. type FU f a = forall s. Mode s => f (AD s a) -> AD s a -- | A non-scalar-to-non-scalar automatically-differentiable function. type FF f g a = forall s. Mode s => f (AD s a) -> g (AD s a) on :: (a -> a -> b) -> (c -> a) -> c -> c -> b on f g a b = f (g a) (g b) data Pair a b = Pair a b deriving (Eq, Ord, Show, Read, Functor, Foldable, Traversable) zipWithT :: (Foldable f, Traversable g) => (a -> b -> c) -> f a -> g b -> g c zipWithT f as = snd . mapAccumL (\(a:as') b -> (as', f a b)) (toList as) zipWithDefaultT :: (Foldable f, Traversable g) => a -> (a -> b -> c) -> f a -> g b -> g c zipWithDefaultT z f as = zipWithT f (toList as ++ repeat z) class Iso a b where iso :: f a -> f b osi :: f b -> f a instance Iso a a where iso = id osi = id -- | 'AD' serves as a common wrapper for different 'Mode' instances, exposing a traditional -- numerical tower. Universal quantification is used to limit the actions in user code to -- machinery that will return the same answers under all AD modes, allowing us to use modes -- interchangeably as both the type level \"brand\" and dictionary, providing a common API. newtype AD f a = AD { runAD :: f a } deriving (Iso (f a), Lifted, Mode, Primal) -- > instance (Lifted f, Num a) => Num (AD f a) -- etc. let f = varT (mkName "f") in deriveNumeric (classP ''Lifted [f]:) (conT ''AD `appT` f) newtype Id a = Id a deriving (Iso a, Eq, Ord, Show, Enum, Bounded, Num, Real, Fractional, Floating, RealFrac, RealFloat, Monoid) probe :: a -> AD Id a probe a = AD (Id a) unprobe :: AD Id a -> a unprobe (AD (Id a)) = a pid :: f a -> f (Id a) pid = iso unpid :: f (Id a) -> f a unpid = osi probed :: f a -> f (AD Id a) probed = iso . pid unprobed :: f (AD Id a) -> f a unprobed = unpid . osi instance Functor Id where fmap f (Id a) = Id (f a) instance Applicative Id where pure = Id Id f <*> Id a = Id (f a) instance Monad Id where return = Id Id a >>= f = f a instance Lifted Id where (==!) = (==) compare1 = compare showsPrec1 = showsPrec fromInteger1 = fromInteger (+!) = (+) (-!) = (-) (*!) = (*) negate1 = negate abs1 = abs signum1 = signum (/!) = (/) recip1 = recip fromRational1 = fromRational toRational1 = toRational pi1 = pi exp1 = exp log1 = log sqrt1 = sqrt (**!) = (**) logBase1 = logBase sin1 = sin cos1 = cos tan1 = tan asin1 = asin acos1 = acos atan1 = atan sinh1 = sinh cosh1 = cosh tanh1 = tanh asinh1 = asinh acosh1 = acosh atanh1 = atanh properFraction1 = properFraction truncate1 = truncate round1 = round ceiling1 = ceiling floor1 = floor floatRadix1 = floatRadix floatDigits1 = floatDigits floatRange1 = floatRange decodeFloat1 = decodeFloat encodeFloat1 = encodeFloat exponent1 = exponent significand1 = significand scaleFloat1 = scaleFloat isNaN1 = isNaN isInfinite1 = isInfinite isDenormalized1 = isDenormalized isNegativeZero1 = isNegativeZero isIEEE1 = isIEEE atan21 = atan2 succ1 = succ pred1 = pred toEnum1 = toEnum fromEnum1 = fromEnum enumFrom1 = enumFrom enumFromThen1 = enumFromThen enumFromTo1 = enumFromTo enumFromThenTo1 = enumFromThenTo minBound1 = minBound maxBound1 = maxBound instance Mode Id where lift = Id Id a ^* b = Id (a * b) a *^ Id b = Id (a * b) Id a <+> Id b = Id (a + b) instance Primal Id where primal (Id a) = a