{-# LANGUAGE Rank2Types, TypeFamilies, MultiParamTypeClasses, FunctionalDependencies, FlexibleInstances, FlexibleContexts, TemplateHaskell, UndecidableInstances, TypeOperators #-}
-- {-# OPTIONS_HADDOCK hide, prune #-}
-----------------------------------------------------------------------------
-- |
-- Module : Numeric.AD.Internal.Composition
-- Copyright : (c) Edward Kmett 2010
-- License : BSD3
-- Maintainer : ekmett@gmail.com
-- Stability : experimental
-- Portability : GHC only
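--
-- Functor composition ('ComposeFunctor') and AD mode composition
-- ('ComposeMode'), used to nest the use of jacobian and grad.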
-----------------------------------------------------------------------------
module Numeric.AD.Internal.Composition
( ComposeFunctor(..)
, ComposeMode(..)
, composeMode
, decomposeMode
) where
import Control.Applicative
import Data.Data (Data(..), mkDataType, DataType, mkConstr, Constr, constrIndex, Fixity(..))
import Data.Typeable (Typeable1(..), Typeable(..), TyCon, mkTyCon, mkTyConApp, typeOfDefault, gcast1)
import Data.Foldable (Foldable(foldMap))
import Data.Traversable (Traversable(traverse))
import Numeric.AD.Internal.Classes
import Numeric.AD.Internal.Types
-- | Functor composition: an @f@-shaped structure whose elements are
-- @g@-shaped structures of @a@. Used to nest the use of jacobian and grad.
newtype ComposeFunctor f g a = ComposeFunctor
  { decomposeFunctor :: f (g a) -- ^ the outer @f@ layer holding the inner @g@ layers
  }
-- | Maps through both layers: the outer 'fmap' walks @f@, and the inner one
-- walks each @g@ it contains.
instance (Functor f, Functor g) => Functor (ComposeFunctor f g) where
  fmap h = ComposeFunctor . fmap (fmap h) . decomposeFunctor
-- | Folds every @a@ by folding each inner @g@ and combining the results
-- across the outer @f@.
instance (Foldable f, Foldable g) => Foldable (ComposeFunctor f g) where
  foldMap h = foldMap (foldMap h) . decomposeFunctor
-- | Traverses the inner @g@ layers inside the outer @f@ layer, then rewraps
-- the rebuilt structure in 'ComposeFunctor'.
instance (Traversable f, Traversable g) => Traversable (ComposeFunctor f g) where
  traverse h = fmap ComposeFunctor . traverse (traverse h) . decomposeFunctor
-- | Runtime type representation of @ComposeFunctor f g@: the
-- 'composeFunctorTyCon' applied to the representations of @f@ and @g@.
instance (Typeable1 f, Typeable1 g) => Typeable1 (ComposeFunctor f g) where
  typeOf1 x = mkTyConApp composeFunctorTyCon layerReps
    where
      layerReps = [typeOf1 (outerOf x), typeOf1 (innerOf x)]
      -- Type-level projections picking out @f a@ and @g a@. They are bound
      -- to 'undefined', so this relies on 'typeOf1' never inspecting its
      -- argument. The kind annotations pin the other layer to @* -> *@.
      outerOf :: t f (g :: * -> *) a -> f a
      outerOf = undefined
      innerOf :: t (f :: * -> *) g a -> g a
      innerOf = undefined
-- | Type constructor representation for 'ComposeFunctor', consumed by its
-- 'Typeable1' instance. Kept as a NOINLINE top-level value so there is a
-- single shared instance of it.
composeFunctorTyCon :: TyCon
composeFunctorTyCon = mkTyCon qualifiedName
  where
    qualifiedName = "Numeric.AD.Internal.Composition.ComposeFunctor"
{-# NOINLINE composeFunctorTyCon #-}
-- | Reflected constructor of 'ComposeFunctor', written in prefix form.
--
-- The newtype is declared with the record field 'decomposeFunctor'. That
-- label is therefore listed here, as @deriving Data@ would list it, so that
-- 'constrFields' agrees with the declaration instead of reporting no
-- fields.
composeFunctorConstr :: Constr
composeFunctorConstr =
  mkConstr composeFunctorDataType "ComposeFunctor" ["decomposeFunctor"] Prefix
{-# NOINLINE composeFunctorConstr #-}
-- | Reflected datatype for 'ComposeFunctor'. Its single constructor is
-- 'composeFunctorConstr'.
composeFunctorDataType :: DataType
composeFunctorDataType = mkDataType qualifiedName constructors
  where
    qualifiedName = "Numeric.AD.Internal.Composition.ComposeFunctor"
    constructors = [composeFunctorConstr]
{-# NOINLINE composeFunctorDataType #-}
-- | 'Data' instance for 'ComposeFunctor': a single prefix constructor whose
-- one field is the wrapped @f (g a)@.
instance (Typeable1 f, Typeable1 g, Data (f (g a)), Data a) => Data (ComposeFunctor f g a) where
  -- Apply the fold step to the single wrapped field.
  gfoldl f z (ComposeFunctor a) = z ComposeFunctor `f` a
  toConstr _ = composeFunctorConstr
  -- 'composeFunctorDataType' declares exactly one constructor, so index 1 is
  -- the only valid input. Any other constructor is reported by name, with
  -- the type it was expected for, instead of a bare "gunfold".
  gunfold k z c = case constrIndex c of
    1 -> k (z ComposeFunctor)
    _ -> error $ "Numeric.AD.Internal.Composition.ComposeFunctor.gunfold: unexpected constructor "
              ++ show c
  dataTypeOf _ = composeFunctorDataType
  -- Casting support through the constructor application @ComposeFunctor f g@.
  dataCast1 f = gcast1 f
-- | The composition of two AD modes is an AD mode in its own right: the
-- outer mode @f@ runs over values that are themselves 'AD' values in the
-- inner mode @g@.
newtype ComposeMode f g a = ComposeMode
  { runComposeMode :: f (AD g a) -- ^ the outer-mode value over inner-mode 'AD' values
  }
-- | Repackage a nested AD value (outer mode @f@ over inner-mode 'AD' values
-- in @g@) as a single 'AD' value in the composite mode @'ComposeMode' f g@.
-- This is the inverse of 'decomposeMode'.
composeMode :: AD f (AD g a) -> AD (ComposeMode f g) a
composeMode nested = case nested of
  AD outer -> AD (ComposeMode outer)
-- | Split a composite-mode 'AD' value back into the nested form: an outer
-- @f@-mode value over inner @g@-mode 'AD' values. This is the inverse of
-- 'composeMode'.
decomposeMode :: AD (ComposeMode f g) a -> AD f (AD g a)
decomposeMode composite = case composite of
  AD (ComposeMode outer) -> AD outer
-- | The primal of a composite value: take the outer mode's primal, which is
-- an inner-mode 'AD' value, and then take that value's primal.
instance (Primal f, Mode g, Primal g) => Primal (ComposeMode f g) where
  primal (ComposeMode outer) = primal (primal outer)
-- | Mode operations on the composite. Constants are lifted through both
-- layers. For the scalar operations, the plain @a@ scalar is first lifted
-- into the inner mode, then combined with the outer-mode value.
instance (Mode f, Mode g) => Mode (ComposeMode f g) where
  lift x = ComposeMode (lift (lift x))
  x <+> y = ComposeMode (runComposeMode x <+> runComposeMode y)
  s *^ v = ComposeMode (lift s *^ runComposeMode v)
  v ^* s = ComposeMode (runComposeMode v ^* lift s)
  v ^/ s = ComposeMode (runComposeMode v ^/ lift s)
-- | Lifted numeric operations on the composite mode.
--
-- Every operation is delegated to the outer mode @f@, acting on its
-- inner-mode 'AD' values. Arguments are unwrapped with 'runComposeMode', and
-- results that live in the mode are rewrapped with 'ComposeMode'.
-- 'fromInteger1' and 'fromRational1' build the value as an inner-mode 'AD'
-- value and 'lift' it into @f@. 'pi1', 'toEnum1', 'encodeFloat1',
-- 'minBound1' and 'maxBound1' are built directly by the outer mode.
instance (Mode f, Mode g) => Lifted (ComposeMode f g) where
  showsPrec1 d = showsPrec1 d . runComposeMode
  x ==! y = runComposeMode x ==! runComposeMode y
  compare1 x y = compare1 (runComposeMode x) (runComposeMode y)
  fromInteger1 n = ComposeMode (lift (fromInteger1 n))
  x +! y = ComposeMode (runComposeMode x +! runComposeMode y)
  x -! y = ComposeMode (runComposeMode x -! runComposeMode y)
  x *! y = ComposeMode (runComposeMode x *! runComposeMode y)
  negate1 = ComposeMode . negate1 . runComposeMode
  abs1 = ComposeMode . abs1 . runComposeMode
  signum1 = ComposeMode . signum1 . runComposeMode
  x /! y = ComposeMode (runComposeMode x /! runComposeMode y)
  recip1 = ComposeMode . recip1 . runComposeMode
  fromRational1 r = ComposeMode (lift (fromRational1 r))
  toRational1 = toRational1 . runComposeMode
  pi1 = ComposeMode pi1
  exp1 = ComposeMode . exp1 . runComposeMode
  log1 = ComposeMode . log1 . runComposeMode
  sqrt1 = ComposeMode . sqrt1 . runComposeMode
  x **! y = ComposeMode (runComposeMode x **! runComposeMode y)
  logBase1 x y = ComposeMode (logBase1 (runComposeMode x) (runComposeMode y))
  sin1 = ComposeMode . sin1 . runComposeMode
  cos1 = ComposeMode . cos1 . runComposeMode
  tan1 = ComposeMode . tan1 . runComposeMode
  asin1 = ComposeMode . asin1 . runComposeMode
  acos1 = ComposeMode . acos1 . runComposeMode
  atan1 = ComposeMode . atan1 . runComposeMode
  sinh1 = ComposeMode . sinh1 . runComposeMode
  cosh1 = ComposeMode . cosh1 . runComposeMode
  tanh1 = ComposeMode . tanh1 . runComposeMode
  asinh1 = ComposeMode . asinh1 . runComposeMode
  acosh1 = ComposeMode . acosh1 . runComposeMode
  atanh1 = ComposeMode . atanh1 . runComposeMode
  -- The lazy 'let' pattern keeps the result pair available without first
  -- forcing the inner 'properFraction1'.
  properFraction1 x =
    let (whole, frac) = properFraction1 (runComposeMode x)
    in (whole, ComposeMode frac)
  truncate1 = truncate1 . runComposeMode
  round1 = round1 . runComposeMode
  ceiling1 = ceiling1 . runComposeMode
  floor1 = floor1 . runComposeMode
  floatRadix1 = floatRadix1 . runComposeMode
  floatDigits1 = floatDigits1 . runComposeMode
  floatRange1 = floatRange1 . runComposeMode
  decodeFloat1 = decodeFloat1 . runComposeMode
  encodeFloat1 m = ComposeMode . encodeFloat1 m
  exponent1 = exponent1 . runComposeMode
  significand1 = ComposeMode . significand1 . runComposeMode
  scaleFloat1 k = ComposeMode . scaleFloat1 k . runComposeMode
  isNaN1 = isNaN1 . runComposeMode
  isInfinite1 = isInfinite1 . runComposeMode
  isDenormalized1 = isDenormalized1 . runComposeMode
  isNegativeZero1 = isNegativeZero1 . runComposeMode
  isIEEE1 = isIEEE1 . runComposeMode
  atan21 x y = ComposeMode (atan21 (runComposeMode x) (runComposeMode y))
  succ1 = ComposeMode . succ1 . runComposeMode
  pred1 = ComposeMode . pred1 . runComposeMode
  toEnum1 = ComposeMode . toEnum1
  fromEnum1 = fromEnum1 . runComposeMode
  enumFrom1 = map ComposeMode . enumFrom1 . runComposeMode
  enumFromThen1 x y =
    map ComposeMode (enumFromThen1 (runComposeMode x) (runComposeMode y))
  enumFromTo1 x y =
    map ComposeMode (enumFromTo1 (runComposeMode x) (runComposeMode y))
  enumFromThenTo1 x y z =
    map ComposeMode
        (enumFromThenTo1 (runComposeMode x) (runComposeMode y) (runComposeMode z))
  minBound1 = ComposeMode minBound1
  maxBound1 = ComposeMode maxBound1
-- | Runtime type representation of @ComposeMode f g@: the
-- 'composeModeTyCon' applied to the representations of @f@ and @g@.
instance (Typeable1 f, Typeable1 g) => Typeable1 (ComposeMode f g) where
  typeOf1 x = mkTyConApp composeModeTyCon layerReps
    where
      layerReps = [typeOf1 (outerOf x), typeOf1 (innerOf x)]
      -- Type-level projections picking out @f a@ and @g a@. They are bound
      -- to 'undefined', so this relies on 'typeOf1' never inspecting its
      -- argument. The kind annotations pin the other layer to @* -> *@.
      outerOf :: t f (g :: * -> *) a -> f a
      outerOf = undefined
      innerOf :: t (f :: * -> *) g a -> g a
      innerOf = undefined
-- | Value-level 'Typeable' for a fully applied 'ComposeMode', obtained from
-- the 'Typeable1' instance through 'typeOfDefault'.
instance (Typeable1 f, Typeable1 g, Typeable a) => Typeable (ComposeMode f g a) where
  typeOf x = typeOfDefault x
-- | Type constructor representation for 'ComposeMode', consumed by its
-- 'Typeable1' instance. Kept as a NOINLINE top-level value so there is a
-- single shared instance of it.
composeModeTyCon :: TyCon
composeModeTyCon = mkTyCon qualifiedName
  where
    qualifiedName = "Numeric.AD.Internal.Composition.ComposeMode"
{-# NOINLINE composeModeTyCon #-}
-- | Reflected constructor of 'ComposeMode', written in prefix form.
--
-- The newtype is declared with the record field 'runComposeMode'. That
-- label is therefore listed here, as @deriving Data@ would list it, so that
-- 'constrFields' agrees with the declaration instead of reporting no
-- fields.
composeModeConstr :: Constr
composeModeConstr =
  mkConstr composeModeDataType "ComposeMode" ["runComposeMode"] Prefix
{-# NOINLINE composeModeConstr #-}
-- | Reflected datatype for 'ComposeMode'. Its single constructor is
-- 'composeModeConstr'.
composeModeDataType :: DataType
composeModeDataType = mkDataType qualifiedName constructors
  where
    qualifiedName = "Numeric.AD.Internal.Composition.ComposeMode"
    constructors = [composeModeConstr]
{-# NOINLINE composeModeDataType #-}
-- | 'Data' instance for 'ComposeMode': a single prefix constructor whose one
-- field is the wrapped @f (AD g a)@.
instance (Typeable1 f, Typeable1 g, Data (f (AD g a)), Data a) => Data (ComposeMode f g a) where
  -- Apply the fold step to the single wrapped field.
  gfoldl f z (ComposeMode a) = z ComposeMode `f` a
  toConstr _ = composeModeConstr
  -- 'composeModeDataType' declares exactly one constructor, so index 1 is
  -- the only valid input. Any other constructor is reported by name, with
  -- the type it was expected for, instead of a bare "gunfold".
  gunfold k z c = case constrIndex c of
    1 -> k (z ComposeMode)
    _ -> error $ "Numeric.AD.Internal.Composition.ComposeMode.gunfold: unexpected constructor "
              ++ show c
  dataTypeOf _ = composeModeDataType
  -- Casting support through the constructor application @ComposeMode f g@.
  dataCast1 f = gcast1 f