{-# LANGUAGE BangPatterns, TemplateHaskell, ScopedTypeVariables, DeriveDataTypeable, FlexibleContexts, UndecidableInstances #-}
-- {-# OPTIONS_HADDOCK hide #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.AD.Internal.Iterated
-- Copyright   :  (c) Edward Kmett 2010
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Iterated
    ( Iterated(..)
    , tailI
    , unfoldI
    , bundle
    , bind
    ) where

import Control.Applicative
import Data.Monoid
import Data.Foldable
import Data.Traversable
import Data.Data (Data(..), mkDataType, DataType, mkConstr, Constr, constrIndex, Fixity(Infix))
import Data.Typeable (Typeable1(..), Typeable(..), TyCon, mkTyCon, mkTyConApp, typeOfDefault, gcast1)
import Numeric.AD.Internal.Types
import Numeric.AD.Internal.Classes
import Numeric.AD.Internal.Comonad
import Numeric.AD.Internal.Combinators (on)
-- import qualified Numeric.AD.Internal.Forward
import Numeric.AD.Internal.Forward (Forward(..))
import Language.Haskell.TH

-- ':|' associates to the left at precedence 3, i.e. it binds more loosely
-- than '<$>' and '<*>' (precedence 4) and than the arithmetic operators.
infixl 3 :|

-- | A head value of type @a@ followed by an @f@-shaped container holding
-- further 'Iterated' values of the same type.
data Iterated f a = a :| f (Iterated f a)

-- | Build an @'AD' ('Iterated' 'Forward')@ value whose head is the first
-- argument and whose tail is a 'Forward' holding the 'lift' of each argument.
-- NOTE(review): the two 'Forward' fields are presumably primal and tangent;
-- confirm against "Numeric.AD.Internal.Forward".
bundle :: Num a => a -> a -> AD (Iterated Forward) a
bundle x dx = AD $ x :| Forward (lift x) (lift dx)

-- | Run the given function once per element of the input container.
-- For the run at position @i@, every input element is wrapped with 'bundle':
-- the element at position @i@ is paired with 1, every other element with 0.
-- The results are collected in a container of the same shape as the input.
bind :: (Traversable f, Num a) => (f (AD (Iterated Forward) a) -> b) -> f a -> f b
bind f xs = snd (mapAccumL outer (0 :: Int) xs)
  where
    -- One step per output slot; the counter records the current position.
    outer !i _ = (i + 1, f (seeded i))
    -- The whole input, re-labelled for the run at position i.
    seeded i = snd (mapAccumL (inner i) 0 xs)
    -- Marks exactly the element whose position matches the current run.
    inner !i !j x = (j + 1, bundle x (if i == j then 1 else 0))

-- | Maps over the head and, recursively, over every value in the tail.
instance Functor f => Functor (Iterated f) where
    fmap g (x :| xs) = g x :| fmap (fmap g) xs

-- | 'extract' returns the head and ignores the tail.
instance Functor f => Copointed (Iterated f) where
    extract (x :| _) = x

-- | Each position is replaced by (a function of) the whole structure rooted
-- at that position.
instance Functor f => Comonad (Iterated f) where
    duplicate w@(_ :| ws) = w :| fmap duplicate ws
    extend g w@(_ :| ws) = g w :| fmap (extend g) ws

-- | Combines the head first, then folds through the tail container.
instance Foldable f => Foldable (Iterated f) where
    foldMap g (x :| xs) = mappend (g x) (foldMap (foldMap g) xs)

-- | Visits the head first, then traverses the tail container.
instance Traversable f => Traversable (Iterated f) where
    traverse g (x :| xs) = (:|) <$> g x <*> traverse (traverse g) xs

-- | tails of the f-branching stream comonad/cofree comonad:
-- drops the head and returns the tail container.
tailI :: (Iterated f a) -> f (Iterated f a)
tailI (_ :| rest) = rest

-- | Grow an 'Iterated' from a seed. The step function yields the head for
-- the current seed plus a container of seeds from which the tail is grown.
unfoldI :: Functor f => (a -> (b, f a)) -> a -> Iterated f b
unfoldI step seed =
    let (hd, seeds) = step seed
    in hd :| fmap (unfoldI step) seeds

-- | 'primal' is the head of the structure.
instance Primal (Iterated f) where
    primal (x :| _) = x

-- | Head components are combined with the ordinary numeric operators; tail
-- components are combined with the corresponding 'Mode' operation of @f@.
instance Mode f => Mode (Iterated f) where
    -- Knot-tying definition: the tail is the 'lift' of the result itself.
    lift x = let xs = x :| lift xs in xs
    (x :| xs) <+> (y :| ys) = (x + y) :| (xs <+> ys)
    -- A bare scalar has to be 'lift'ed before it can scale the tail.
    k *^ (y :| ys) = (k * y) :| (lift k *^ ys)
    (x :| xs) ^* k = (x * k) :| (xs ^* lift k)
    (x :| xs) ^/ k = (x / k) :| (xs ^/ lift k)

-- | Structure-producing operations act on the head with the plain Prelude
-- function and on the tail with the matching 'Lifted' method of @f@.
-- Operations that produce a non-'Iterated' result (comparison, rounding,
-- float inspection, ...) look at 'primal' only.
instance Mode f => Lifted (Iterated f) where
    -- Only the head is shown.
    showsPrec1 d (x :| _) = showsPrec d x
    (==!) = on (==) primal
    compare1 = on compare primal
    fromInteger1 i = fromInteger i :| fromInteger1 i
    (x :| xs) +! (y :| ys) = (x + y) :| (xs +! ys)
    (x :| xs) -! (y :| ys) = (x - y) :| (xs -! ys)
    (x :| xs) *! (y :| ys) = (x * y) :| (xs *! ys)
    negate1 (x :| xs) = negate x :| negate1 xs
    abs1 (x :| xs) = abs x :| abs1 xs
    signum1 (x :| xs) = signum x :| signum1 xs
    (x :| xs) /! (y :| ys) = (x / y) :| (xs /! ys)
    recip1 (x :| xs) = recip x :| recip1 xs
    fromRational1 r = fromRational r :| fromRational1 r
    toRational1 xs = toRational (primal xs)
    pi1 = pi :| pi1
    exp1 (x :| xs) = exp x :| exp1 xs
    log1 (x :| xs) = log x :| log1 xs
    sqrt1 (x :| xs) = sqrt x :| sqrt1 xs
    (x :| xs) **! (y :| ys) = (x ** y) :| (xs **! ys)
    logBase1 (x :| xs) (y :| ys) = logBase x y :| logBase1 xs ys
    sin1 (x :| xs) = sin x :| sin1 xs
    cos1 (x :| xs) = cos x :| cos1 xs
    tan1 (x :| xs) = tan x :| tan1 xs
    asin1 (x :| xs) = asin x :| asin1 xs
    acos1 (x :| xs) = acos x :| acos1 xs
    atan1 (x :| xs) = atan x :| atan1 xs
    sinh1 (x :| xs) = sinh x :| sinh1 xs
    cosh1 (x :| xs) = cosh x :| cosh1 xs
    tanh1 (x :| xs) = tanh x :| tanh1 xs
    asinh1 (x :| xs) = asinh x :| asinh1 xs
    acosh1 (x :| xs) = acosh x :| acosh1 xs
    atanh1 (x :| xs) = atanh x :| atanh1 xs
    -- The integral part comes from the head alone; the integral part that
    -- the tail computes is discarded (pinned to Int so its type is fixed).
    properFraction1 (x :| xs) = (whole, frac :| fracs)
      where
        (whole, frac) = properFraction x
        (_ :: Int, fracs) = properFraction1 xs
    truncate1 xs = truncate (primal xs)
    round1 xs = round (primal xs)
    ceiling1 xs = ceiling (primal xs)
    floor1 xs = floor (primal xs)
    floatRadix1 xs = floatRadix (primal xs)
    floatDigits1 xs = floatDigits (primal xs)
    floatRange1 xs = floatRange (primal xs)
    decodeFloat1 xs = decodeFloat (primal xs)
    encodeFloat1 mant ex = encodeFloat mant ex :| encodeFloat1 mant ex
    exponent1 xs = exponent (primal xs)
    significand1 (x :| xs) = significand x :| significand1 xs
    scaleFloat1 k (x :| xs) = scaleFloat k x :| scaleFloat1 k xs
    isNaN1 xs = isNaN (primal xs)
    isInfinite1 xs = isInfinite (primal xs)
    isDenormalized1 xs = isDenormalized (primal xs)
    isNegativeZero1 xs = isNegativeZero (primal xs)
    isIEEE1 xs = isIEEE (primal xs)
    atan21 (x :| xs) (y :| ys) = atan2 x y :| atan21 xs ys
    succ1 (x :| xs) = succ x :| succ1 xs
    pred1 (x :| xs) = pred x :| pred1 xs
    toEnum1 i = toEnum i :| toEnum1 i
    fromEnum1 xs = fromEnum (primal xs)
    -- The enumeration ranges are not implemented; using them raises an error.
    enumFrom1 = error "TODO"
    enumFromThen1 = error "TODO"
    enumFromTo1 = error "TODO"
    enumFromThenTo1 = error "TODO"
    minBound1 = minBound :| minBound1
    maxBound1 = maxBound :| maxBound1

-- TODO:
-- instance (Mode f, Foo a) => Foo (Iterated f) ...
-- Template Haskell splice. 'deriveNumeric' receives a function that prepends
-- a @Mode f@ constraint to a context, and the type @Iterated f@.
-- NOTE(review): presumably this generates the Prelude numeric instances for
-- @Iterated f a@ from its 'Lifted' instance; confirm against the definition
-- of 'deriveNumeric'.
deriveNumeric
    (classP (mkName "Mode") [varT (mkName "f")] :)
    (appT (conT (mkName "Iterated")) (varT (mkName "f")))

-- | The type representation is 'iteratedTyCon' applied to the representation
-- of @f@.
instance Typeable1 f => Typeable1 (Iterated f) where
    typeOf1 proxy = mkTyConApp iteratedTyCon [typeOf1 (argOf undefined proxy)]
      where
        -- Used only for its type: it ties the type of the (never evaluated)
        -- first argument to the @f@ found inside the second argument.
        argOf :: f a -> t f a -> f a
        argOf witness _ = witness

-- | Defined through the 'Typeable1' instance via 'typeOfDefault'.
instance (Typeable1 f, Typeable a) => Typeable (Iterated f a) where
    typeOf = typeOfDefault

-- | Type constructor representation for 'Iterated', keyed by its fully
-- qualified name.
{-# NOINLINE iteratedTyCon #-}
iteratedTyCon :: TyCon
iteratedTyCon = mkTyCon "Numeric.AD.Internal.Iterated.Iterated"

-- | Representation of the single infix constructor ':|'; it has no field
-- labels.
{-# NOINLINE consConstr #-}
consConstr :: Constr
consConstr = mkConstr iteratedDataType "(:|)" [] Infix

-- | Data type representation for 'Iterated'; 'consConstr' is its only
-- constructor.
{-# NOINLINE iteratedDataType #-}
iteratedDataType :: DataType
iteratedDataType = mkDataType "Numeric.AD.Internal.Iterated.Iterated" [consConstr]

-- | Hand-written generic-programming instance over the one constructor ':|'.
instance (Typeable1 f, Data (f (Iterated f a)), Data a) => Data (Iterated f a) where
    -- Fold over the constructor applied first to the head, then to the tail.
    gfoldl app inj (x :| xs) = (inj (:|) `app` x) `app` xs
    toConstr _ = consConstr
    -- Only constructor index 1 is accepted; any other index is an error.
    gunfold app inj con
        | constrIndex con == 1 = app (app (inj (:|)))
        | otherwise = error "gunfold"
    dataTypeOf _ = iteratedDataType
    dataCast1 cast = gcast1 cast