{-# LANGUAGE Rank2Types
           , TypeFamilies
           , MultiParamTypeClasses
           , FunctionalDependencies
           , FlexibleInstances
           , FlexibleContexts
           , TemplateHaskell
           , UndecidableInstances
           , TypeOperators #-}
-----------------------------------------------------------------------------
-- |
-- Module      : Numeric.AD.Internal.Composition
-- Copyright   : (c) Edward Kmett 2010
-- License     : BSD3
-- Maintainer  : ekmett@gmail.com
-- Stability   : experimental
-- Portability : GHC only
--
-- Defines the composition of two AD modes as an AD mode in its own right
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Composition
  ( (:.)(..)
  , On
  , compose
  , decompose
  ) where

import Numeric.AD.Classes
import Numeric.AD.Internal

-- | Composition of two modes: an @f@ structure whose contents are
-- @'AD' g a@ values. 'O' wraps such a value and 'runO' unwraps it.
newtype (f :. g) a = O { runO :: f (AD g a) }

-- | The same composition as ':.' with the two arguments written in the
-- opposite order: @On f g@ is @g :. f@.
type On f g = g :. f

-- | Re-tag a nested @'AD' f ('AD' g a)@ as a single 'AD' value over the
-- composite @f ':.' g@ by inserting the 'O' wrapper.
compose :: AD f (AD g a) -> AD (f :. g) a
compose (AD inner) = AD $ O inner

-- | Inverse of 'compose': strip the 'O' wrapper, exposing the nested
-- @'AD' f ('AD' g a)@ again.
decompose :: AD (f :. g) a -> AD f (AD g a)
decompose (AD composed) = AD (runO composed)

-- | The primal of a composite is obtained by taking 'primal' of the outer
-- layer and then 'primal' of the resulting inner value.
instance (Primal f, Mode g, Primal g) => Primal (f :. g) where
  primal (O x) = primal (primal x)

-- | Mode operations are delegated to the outer layer @f@; bare scalars are
-- first 'lift'ed so that they match the outer layer's element type.
instance (Mode f, Mode g) => Mode (f :. g) where
  lift x = O (lift (lift x))
  O x <+> O y = O $ x <+> y
  k *^ O v = O $ lift k *^ v
  O v ^* k = O $ v ^* lift k
  O v ^/ k = O $ v ^/ lift k

-- | Every 'Lifted' method unwraps its 'O' arguments, calls the same method
-- on the wrapped values, and re-wraps the result with 'O' when the result
-- is itself a composite.
instance (Mode f, Mode g) => Lifted (f :. g) where
  -- Show / Eq / Ord
  showsPrec1 d = showsPrec1 d . runO
  O x ==! O y = x ==! y
  compare1 (O x) (O y) = compare1 x y

  -- Num
  fromInteger1 n = O (lift (fromInteger1 n))
  O x +! O y = O $ x +! y
  O x -! O y = O $ x -! y
  O x *! O y = O $ x *! y
  negate1 = O . negate1 . runO
  abs1 = O . abs1 . runO
  signum1 = O . signum1 . runO

  -- Fractional
  O x /! O y = O $ x /! y
  recip1 = O . recip1 . runO
  fromRational1 r = O (lift (fromRational1 r))

  -- Real
  toRational1 = toRational1 . runO

  -- Floating
  pi1 = O pi1
  exp1 = O . exp1 . runO
  log1 = O . log1 . runO
  sqrt1 = O . sqrt1 . runO
  O x **! O y = O $ x **! y
  logBase1 (O x) (O y) = O $ logBase1 x y
  sin1 = O . sin1 . runO
  cos1 = O . cos1 . runO
  tan1 = O . tan1 . runO
  asin1 = O . asin1 . runO
  acos1 = O . acos1 . runO
  atan1 = O . atan1 . runO
  sinh1 = O . sinh1 . runO
  cosh1 = O . cosh1 . runO
  tanh1 = O . tanh1 . runO
  asinh1 = O . asinh1 . runO
  acosh1 = O . acosh1 . runO
  atanh1 = O . atanh1 . runO

  -- RealFrac
  -- The pair is bound lazily in a where clause (not scrutinised with
  -- case) so the result tuple is produced without forcing the call.
  properFraction1 (O x) = (whole, O frac)
    where
      (whole, frac) = properFraction1 x
  truncate1 = truncate1 . runO
  round1 = round1 . runO
  ceiling1 = ceiling1 . runO
  floor1 = floor1 . runO

  -- RealFloat
  floatRadix1 = floatRadix1 . runO
  floatDigits1 = floatDigits1 . runO
  floatRange1 = floatRange1 . runO
  decodeFloat1 = decodeFloat1 . runO
  encodeFloat1 mantissa expo = O $ encodeFloat1 mantissa expo
  exponent1 = exponent1 . runO
  significand1 = O . significand1 . runO
  scaleFloat1 k = O . scaleFloat1 k . runO
  isNaN1 = isNaN1 . runO
  isInfinite1 = isInfinite1 . runO
  isDenormalized1 = isDenormalized1 . runO
  isNegativeZero1 = isNegativeZero1 . runO
  isIEEE1 = isIEEE1 . runO
  atan21 (O x) (O y) = O $ atan21 x y

  -- Enum
  succ1 = O . succ1 . runO
  pred1 = O . pred1 . runO
  toEnum1 = O . toEnum1
  fromEnum1 = fromEnum1 . runO
  enumFrom1 = map O . enumFrom1 . runO
  enumFromThen1 (O x) (O y) = map O (enumFromThen1 x y)
  enumFromTo1 (O x) (O y) = map O (enumFromTo1 x y)
  enumFromThenTo1 (O x) (O y) (O z) = map O (enumFromThenTo1 x y z)

  -- Bounded
  minBound1 = O minBound1
  maxBound1 = O maxBound1

-- deriveNumeric (conT `appT` varT (mkName "f") `appT` varT (mkName "g"))