module Numeric.AD.Internal.Composition
( (:.)(..)
, On
, compose
, decompose
) where
import Numeric.AD.Classes
import Numeric.AD.Internal
newtype (f :. g) a = O { runO :: f (AD g a) }
type On f g = g :. f
compose :: AD f (AD g a) -> AD (f :. g) a
compose (AD a) = AD (O a)
decompose :: AD (f :. g) a -> AD f (AD g a)
decompose (AD (O a)) = AD a
instance (Primal f, Mode g, Primal g) => Primal (f :. g) where
primal = primal . primal . runO
instance (Mode f, Mode g) => Mode (f :. g) where
lift = O . lift . lift
O a <+> O b = O (a <+> b)
a *^ O b = O (lift a *^ b)
O a ^* b = O (a ^* lift b)
O a ^/ b = O (a ^/ lift b)
instance (Mode f, Mode g) => Lifted (f :. g) where
showsPrec1 n (O a) = showsPrec1 n a
O a ==! O b = a ==! b
compare1 (O a) (O b) = compare1 a b
fromInteger1 = O . lift . fromInteger1
O a +! O b = O (a +! b)
O a -! O b = O (a -! b)
O a *! O b = O (a *! b)
negate1 (O a) = O (negate1 a)
abs1 (O a) = O (abs1 a)
signum1 (O a) = O (signum1 a)
O a /! O b = O (a /! b)
recip1 (O a) = O (recip1 a)
fromRational1 = O . lift . fromRational1
toRational1 (O a) = toRational1 a
pi1 = O pi1
exp1 (O a) = O (exp1 a)
log1 (O a) = O (log1 a)
sqrt1 (O a) = O (sqrt1 a)
O a **! O b = O (a **! b)
logBase1 (O a) (O b) = O (logBase1 a b)
sin1 (O a) = O (sin1 a)
cos1 (O a) = O (cos1 a)
tan1 (O a) = O (tan1 a)
asin1 (O a) = O (asin1 a)
acos1 (O a) = O (acos1 a)
atan1 (O a) = O (atan1 a)
sinh1 (O a) = O (sinh1 a)
cosh1 (O a) = O (cosh1 a)
tanh1 (O a) = O (tanh1 a)
asinh1 (O a) = O (asinh1 a)
acosh1 (O a) = O (acosh1 a)
atanh1 (O a) = O (atanh1 a)
properFraction1 (O a) = (b, O c) where
(b, c) = properFraction1 a
truncate1 (O a) = truncate1 a
round1 (O a) = round1 a
ceiling1 (O a) = ceiling1 a
floor1 (O a) = floor1 a
floatRadix1 (O a) = floatRadix1 a
floatDigits1 (O a) = floatDigits1 a
floatRange1 (O a) = floatRange1 a
decodeFloat1 (O a) = decodeFloat1 a
encodeFloat1 m e = O (encodeFloat1 m e)
exponent1 (O a) = exponent1 a
significand1 (O a) = O (significand1 a)
scaleFloat1 n (O a) = O (scaleFloat1 n a)
isNaN1 (O a) = isNaN1 a
isInfinite1 (O a) = isInfinite1 a
isDenormalized1 (O a) = isDenormalized1 a
isNegativeZero1 (O a) = isNegativeZero1 a
isIEEE1 (O a) = isIEEE1 a
atan21 (O a) (O b) = O (atan21 a b)
succ1 (O a) = O (succ1 a)
pred1 (O a) = O (pred1 a)
toEnum1 n = O (toEnum1 n)
fromEnum1 (O a) = fromEnum1 a
enumFrom1 (O a) = map O $ enumFrom1 a
enumFromThen1 (O a) (O b) = map O $ enumFromThen1 a b
enumFromTo1 (O a) (O b) = map O $ enumFromTo1 a b
enumFromThenTo1 (O a) (O b) (O c) = map O $ enumFromThenTo1 a b c
minBound1 = O minBound1
maxBound1 = O maxBound1