{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-- {-# OPTIONS_HADDOCK hide, prune #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.AD.Internal.Composition
-- Copyright   :  (c) Edward Kmett 2010
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Composition
    ( ComposeFunctor(..)
    , ComposeMode(..)
    , composeMode
    , decomposeMode
    ) where

import Control.Applicative
import Data.Data (Data(..), mkDataType, DataType, mkConstr, Constr, constrIndex, Fixity(..))
import Data.Typeable (Typeable1(..), Typeable(..), TyCon, mkTyCon, mkTyConApp, typeOfDefault, gcast1)
import Data.Foldable (Foldable(foldMap))
import Data.Traversable (Traversable(traverse))
import Numeric.AD.Internal.Classes
import Numeric.AD.Internal.Types

-----------------------------------------------------------------------------
-- ComposeFunctor
-----------------------------------------------------------------------------

-- | Functor composition, used to nest the use of jacobian and grad.
--
-- A @ComposeFunctor f g a@ wraps a value of type @f (g a)@.
newtype ComposeFunctor f g a = ComposeFunctor { decomposeFunctor :: f (g a) }

-- | Maps through both layers: the outer @f@, then the inner @g@.
instance (Functor f, Functor g) => Functor (ComposeFunctor f g) where
    fmap f wrapped = ComposeFunctor (fmap (fmap f) (decomposeFunctor wrapped))

-- | Folds every inner @g@ structure and combines the results across @f@.
instance (Foldable f, Foldable g) => Foldable (ComposeFunctor f g) where
    foldMap f wrapped = foldMap (foldMap f) (decomposeFunctor wrapped)

-- | Traverses both layers and re-wraps the result.
instance (Traversable f, Traversable g) => Traversable (ComposeFunctor f g) where
    traverse f wrapped =
        ComposeFunctor <$> traverse (traverse f) (decomposeFunctor wrapped)

-- | The representation is 'composeFunctorTyCon' applied to the
-- representations of @f@ and @g@.
instance (Typeable1 f, Typeable1 g) => Typeable1 (ComposeFunctor f g) where
    typeOf1 witness =
        mkTyConApp composeFunctorTyCon
            [ typeOf1 (outerOf witness)
            , typeOf1 (innerOf witness)
            ]
      where
        -- Both helpers are 'undefined'; they exist only to hand 'typeOf1'
        -- an argument whose type mentions @f@ (respectively @g@).
        outerOf :: t f (g :: * -> *) a -> f a
        outerOf = undefined

        innerOf :: t (f :: * -> *) g a -> g a
        innerOf = undefined

-- | Type constructor representation for 'ComposeFunctor'.
composeFunctorTyCon :: TyCon
composeFunctorTyCon = mkTyCon "Numeric.AD.Internal.Composition.ComposeFunctor"
{-# NOINLINE composeFunctorTyCon #-}

-- | The sole constructor of 'composeFunctorDataType'.
--
-- NOTE(review): registered with an empty field-name list although the
-- constructor is declared with the record field 'decomposeFunctor' --
-- confirm whether that is intentional.
composeFunctorConstr :: Constr
composeFunctorConstr =
    mkConstr composeFunctorDataType "ComposeFunctor" [] Prefix
{-# NOINLINE composeFunctorConstr #-}

-- | 'DataType' description of 'ComposeFunctor'; defined mutually with
-- 'composeFunctorConstr'.
composeFunctorDataType :: DataType
composeFunctorDataType =
    mkDataType "Numeric.AD.Internal.Composition.ComposeFunctor"
        [composeFunctorConstr]
{-# NOINLINE composeFunctorDataType #-}

-- | Hand-written 'Data' instance treating 'ComposeFunctor' as a
-- one-constructor type with a single child of type @f (g a)@.
instance (Typeable1 f, Typeable1 g, Data (f (g a)), Data a) => Data (ComposeFunctor f g a) where
    gfoldl k z (ComposeFunctor inner) = k (z ComposeFunctor) inner
    toConstr _ = composeFunctorConstr
    -- Only constructor index 1 is known; anything else is an error.
    gunfold k z c
        | constrIndex c == 1 = k (z ComposeFunctor)
        | otherwise          = error "gunfold"
    dataTypeOf _ = composeFunctorDataType
    dataCast1 f = gcast1 f

-----------------------------------------------------------------------------
-- ComposeMode
-----------------------------------------------------------------------------

-- | The composition of two AD modes is an AD mode in its own right.
--
-- A @ComposeMode f g a@ wraps a value of type @f (AD g a)@.
newtype ComposeMode f g a = ComposeMode { runComposeMode :: f (AD g a) }

-- | Re-wrap a nested @AD f (AD g a)@ as a single 'AD' over 'ComposeMode'.
composeMode :: AD f (AD g a) -> AD (ComposeMode f g) a
composeMode (AD nested) = AD (ComposeMode nested)

-- | Inverse re-wrapping of 'composeMode'.
decomposeMode :: AD (ComposeMode f g) a -> AD f (AD g a)
decomposeMode (AD (ComposeMode nested)) = AD nested

-- | Applies 'primal' twice: once for the outer @f@ layer and once for the
-- 'AD' value it yields.
instance (Primal f, Mode g, Primal g) => Primal (ComposeMode f g) where
    primal (ComposeMode x) = primal (primal x)

-- | Mode operations are delegated to the outer mode @f@ at element type
-- @AD g a@; bare scalars are first 'lift'ed to that element type.
instance (Mode f, Mode g) => Mode (ComposeMode f g) where
    lift a = ComposeMode (lift (lift a))
    ComposeMode x <+> ComposeMode y = ComposeMode (x <+> y)
    k *^ ComposeMode x = ComposeMode (lift k *^ x)
    ComposeMode x ^* k = ComposeMode (x ^* lift k)
    ComposeMode x ^/ k = ComposeMode (x ^/ lift k)

-- | Every method unwraps its 'ComposeMode' arguments, calls the same method
-- on the underlying @f (AD g a)@ values, and re-wraps when the result is
-- itself a 'ComposeMode'.
instance (Mode f, Mode g) => Lifted (ComposeMode f g) where
    -- Show / Eq / Ord
    showsPrec1 n (ComposeMode x) = showsPrec1 n x
    ComposeMode x ==! ComposeMode y = x ==! y
    compare1 (ComposeMode x) (ComposeMode y) = compare1 x y

    -- Num
    fromInteger1 n = ComposeMode (lift (fromInteger1 n))
    ComposeMode x +! ComposeMode y = ComposeMode (x +! y)
    ComposeMode x -! ComposeMode y = ComposeMode (x -! y)
    ComposeMode x *! ComposeMode y = ComposeMode (x *! y)
    negate1 (ComposeMode x) = ComposeMode (negate1 x)
    abs1 (ComposeMode x) = ComposeMode (abs1 x)
    signum1 (ComposeMode x) = ComposeMode (signum1 x)

    -- Fractional
    ComposeMode x /! ComposeMode y = ComposeMode (x /! y)
    recip1 (ComposeMode x) = ComposeMode (recip1 x)
    fromRational1 r = ComposeMode (lift (fromRational1 r))

    -- Real
    toRational1 (ComposeMode x) = toRational1 x

    -- Floating
    pi1 = ComposeMode pi1
    exp1 (ComposeMode x) = ComposeMode (exp1 x)
    log1 (ComposeMode x) = ComposeMode (log1 x)
    sqrt1 (ComposeMode x) = ComposeMode (sqrt1 x)
    ComposeMode x **! ComposeMode y = ComposeMode (x **! y)
    logBase1 (ComposeMode x) (ComposeMode y) = ComposeMode (logBase1 x y)
    sin1 (ComposeMode x) = ComposeMode (sin1 x)
    cos1 (ComposeMode x) = ComposeMode (cos1 x)
    tan1 (ComposeMode x) = ComposeMode (tan1 x)
    asin1 (ComposeMode x) = ComposeMode (asin1 x)
    acos1 (ComposeMode x) = ComposeMode (acos1 x)
    atan1 (ComposeMode x) = ComposeMode (atan1 x)
    sinh1 (ComposeMode x) = ComposeMode (sinh1 x)
    cosh1 (ComposeMode x) = ComposeMode (cosh1 x)
    tanh1 (ComposeMode x) = ComposeMode (tanh1 x)
    asinh1 (ComposeMode x) = ComposeMode (asinh1 x)
    acosh1 (ComposeMode x) = ComposeMode (acosh1 x)
    atanh1 (ComposeMode x) = ComposeMode (atanh1 x)

    -- RealFrac
    properFraction1 (ComposeMode x) =
        let (whole, frac) = properFraction1 x
        in  (whole, ComposeMode frac)
    truncate1 (ComposeMode x) = truncate1 x
    round1 (ComposeMode x) = round1 x
    ceiling1 (ComposeMode x) = ceiling1 x
    floor1 (ComposeMode x) = floor1 x

    -- RealFloat
    floatRadix1 (ComposeMode x) = floatRadix1 x
    floatDigits1 (ComposeMode x) = floatDigits1 x
    floatRange1 (ComposeMode x) = floatRange1 x
    decodeFloat1 (ComposeMode x) = decodeFloat1 x
    encodeFloat1 m e = ComposeMode (encodeFloat1 m e)
    exponent1 (ComposeMode x) = exponent1 x
    significand1 (ComposeMode x) = ComposeMode (significand1 x)
    scaleFloat1 n (ComposeMode x) = ComposeMode (scaleFloat1 n x)
    isNaN1 (ComposeMode x) = isNaN1 x
    isInfinite1 (ComposeMode x) = isInfinite1 x
    isDenormalized1 (ComposeMode x) = isDenormalized1 x
    isNegativeZero1 (ComposeMode x) = isNegativeZero1 x
    isIEEE1 (ComposeMode x) = isIEEE1 x
    atan21 (ComposeMode x) (ComposeMode y) = ComposeMode (atan21 x y)

    -- Enum
    succ1 (ComposeMode x) = ComposeMode (succ1 x)
    pred1 (ComposeMode x) = ComposeMode (pred1 x)
    toEnum1 n = ComposeMode (toEnum1 n)
    fromEnum1 (ComposeMode x) = fromEnum1 x
    enumFrom1 (ComposeMode x) = map ComposeMode (enumFrom1 x)
    enumFromThen1 (ComposeMode x) (ComposeMode y) =
        map ComposeMode (enumFromThen1 x y)
    enumFromTo1 (ComposeMode x) (ComposeMode y) =
        map ComposeMode (enumFromTo1 x y)
    enumFromThenTo1 (ComposeMode x) (ComposeMode y) (ComposeMode z) =
        map ComposeMode (enumFromThenTo1 x y z)

    -- Bounded
    minBound1 = ComposeMode minBound1
    maxBound1 = ComposeMode maxBound1

-- | The representation is 'composeModeTyCon' applied to the
-- representations of @f@ and @g@.
instance (Typeable1 f, Typeable1 g) => Typeable1 (ComposeMode f g) where
    typeOf1 witness =
        mkTyConApp composeModeTyCon
            [ typeOf1 (outerOf witness)
            , typeOf1 (innerOf witness)
            ]
      where
        -- Both helpers are 'undefined'; they exist only to hand 'typeOf1'
        -- an argument whose type mentions @f@ (respectively @g@).
        outerOf :: t f (g :: * -> *) a -> f a
        outerOf = undefined

        innerOf :: t (f :: * -> *) g a -> g a
        innerOf = undefined

-- | Fully applied 'Typeable' instance, defined via 'typeOfDefault'.
instance (Typeable1 f, Typeable1 g, Typeable a) => Typeable (ComposeMode f g a) where
    typeOf = typeOfDefault

-- | Type constructor representation for 'ComposeMode'.
composeModeTyCon :: TyCon
composeModeTyCon = mkTyCon "Numeric.AD.Internal.Composition.ComposeMode"
{-# NOINLINE composeModeTyCon #-}

-- | The sole constructor of 'composeModeDataType'.
--
-- NOTE(review): registered with an empty field-name list although the
-- constructor is declared with the record field 'runComposeMode' --
-- confirm whether that is intentional.
composeModeConstr :: Constr
composeModeConstr = mkConstr composeModeDataType "ComposeMode" [] Prefix
{-# NOINLINE composeModeConstr #-}

-- | 'DataType' description of 'ComposeMode'; defined mutually with
-- 'composeModeConstr'.
composeModeDataType :: DataType
composeModeDataType =
    mkDataType "Numeric.AD.Internal.Composition.ComposeMode"
        [composeModeConstr]
{-# NOINLINE composeModeDataType #-}

-- | Hand-written 'Data' instance treating 'ComposeMode' as a
-- one-constructor type with a single child of type @f (AD g a)@.
instance (Typeable1 f, Typeable1 g, Data (f (AD g a)), Data a) => Data (ComposeMode f g a) where
    gfoldl k z (ComposeMode inner) = k (z ComposeMode) inner
    toConstr _ = composeModeConstr
    -- Only constructor index 1 is known; anything else is an error.
    gunfold k z c
        | constrIndex c == 1 = k (z ComposeMode)
        | otherwise          = error "gunfold"
    dataTypeOf _ = composeModeDataType
    dataCast1 f = gcast1 f