module Numeric.AD.Internal.Stream
( (:>)(..)
, Comonad(..)
, unfold
, tails
) where
import Control.Applicative
import Data.Monoid
import Data.Foldable
import Data.Traversable
import Numeric.AD.Internal
import Language.Haskell.TH
-- Stream constructor and stream type operator share one left-associative
-- precedence level (3), below '<$>' (4), so "x :< g <$> xs" groups as
-- "x :< (g <$> xs)".
infixl 3 :>, :<
-- | Comonad-style operations on streams built with ':>'.
--
-- NOTE(review): every method is typed over @w :> x@ rather than @w x@, so the
-- class parameter plays the role of the stream's branching functor; confirm
-- this is the intended design.
class Functor w => Comonad w where
  -- | Read a single value out of a stream.
  extract :: (w :> x) -> x
  -- | Turn a stream into a stream of streams.
  duplicate :: (w :> x) -> (w :> (w :> x))
  -- | Apply a stream-consuming function across a stream.
  extend :: ((w :> x) -> y) -> (w :> x) -> (w :> y)
-- | A non-empty stream whose branching is described by @g@: a head value
-- followed by a @g@-shaped collection of sub-streams. Both fields are lazy,
-- which allows self-referential (infinite) streams to be defined.
data (:>) g x = x :< g (g :> x)
-- | Map over every value in the stream: the head directly, and each
-- sub-stream recursively.
instance Functor f => Functor ((:>) f) where
  fmap g (x :< xs) = g x :< fmap (fmap g) xs
-- | Stream instance: 'extract' returns the head; 'duplicate' and 'extend'
-- pair the whole stream (or @g@ of it) with the recursively processed
-- sub-streams.
instance Functor f => Comonad ((:>) f) where
  extract (x :< _) = x
  duplicate w = case w of
    _ :< ws -> w :< fmap duplicate ws
  extend g w = case w of
    _ :< ws -> g w :< fmap (extend g) ws
-- | Fold the head first, then every sub-stream in the order @f@ folds them.
instance Foldable f => Foldable ((:>) f) where
  foldMap g (x :< xs) = mappend (g x) (foldMap (foldMap g) xs)
-- | Traverse the head, then the sub-streams, and rebuild the stream.
instance Traversable f => Traversable ((:>) f) where
  traverse g (x :< xs) = fmap (:<) (g x) <*> traverse (traverse g) xs
-- | Drop the head of a stream, returning its @f@-shaped sub-streams.
tails :: (f :> a) -> f (f :> a)
tails s = case s of
  _ :< rest -> rest
-- | Build a stream from a seed. The step function yields the head value and
-- an @f@ of new seeds, each of which is unfolded in turn.
--
-- The step result is bound lazily, so the outer ':<' is available before
-- the step function is evaluated.
unfold :: Functor f => (a -> (b, f a)) -> a -> (f :> b)
unfold step = go
  where
    go seed =
      let (hd, rest) = step seed
      in hd :< fmap go rest
-- | The primal value of a stream is its head.
instance Primal ((:>) f) where
  primal = \(x :< _) -> x
-- | Lifting and scalar operations on streams, given a 'Mode' on @f@.
instance Mode f => Mode ((:>) f) where
  -- The lifted constant is self-referential: its tail is the @f@-level
  -- 'lift' of the very stream being defined.
  lift x = let self = x :< lift self in self
  (<+>) (x :< xs) (y :< ys) = (x + y) :< (xs <+> ys)
  -- Scalar operations: the head uses the raw scalar, while the tail uses the
  -- scalar lifted to a stream so that it matches the tail's element type.
  (*^) s (y :< ys) = (s * y) :< (lift s *^ ys)
  (^*) (x :< xs) s = (x * s) :< (xs ^* lift s)
  (^/) (x :< xs) s = (x / s) :< (xs ^/ lift s)
-- | Lifted numeric operations on streams.
--
-- Stream-producing methods apply the ordinary operation to the head(s) and the
-- same-named lifted operation of @f@ (at element type @f :> a@) to the
-- tail(s). 'showsPrec1', the comparisons, and every method returning a
-- non-stream result look only at the head ('primal').
instance Mode f => Lifted ((:>) f) where
  showsPrec1 n (a :< _) = showsPrec n a
  (==!) = (==) `on` primal
  compare1 = compare `on` primal

  -- Num
  fromInteger1 a = fromInteger a :< fromInteger1 a
  (a :< as) +! (b :< bs) = (a + b) :< (as +! bs)
  -- Pointwise subtraction of the heads (they must be subtracted, not applied
  -- to one another).
  (a :< as) -! (b :< bs) = (a - b) :< (as -! bs)
  (a :< as) *! (b :< bs) = (a * b) :< (as *! bs)
  negate1 (a :< as) = negate a :< negate1 as
  abs1 (a :< as) = abs a :< abs1 as
  signum1 (a :< as) = signum a :< signum1 as

  -- Fractional / Real
  (a :< as) /! (b :< bs) = (a / b) :< (as /! bs)
  recip1 (a :< as) = recip a :< recip1 as
  fromRational1 n = fromRational n :< fromRational1 n
  toRational1 = toRational . primal

  -- Floating
  pi1 = pi :< pi1
  exp1 (a :< as) = exp a :< exp1 as
  log1 (a :< as) = log a :< log1 as
  sqrt1 (a :< as) = sqrt a :< sqrt1 as
  (a :< as) **! (b :< bs) = (a ** b) :< (as **! bs)
  logBase1 (a :< as) (b :< bs) = logBase a b :< logBase1 as bs
  sin1 (a :< as) = sin a :< sin1 as
  cos1 (a :< as) = cos a :< cos1 as
  tan1 (a :< as) = tan a :< tan1 as
  asin1 (a :< as) = asin a :< asin1 as
  acos1 (a :< as) = acos a :< acos1 as
  atan1 (a :< as) = atan a :< atan1 as
  sinh1 (a :< as) = sinh a :< sinh1 as
  cosh1 (a :< as) = cosh a :< cosh1 as
  tanh1 (a :< as) = tanh a :< tanh1 as
  asinh1 (a :< as) = asinh a :< asinh1 as
  acosh1 (a :< as) = acosh a :< acosh1 as
  atanh1 (a :< as) = atanh a :< atanh1 as

  -- RealFrac: the integral part is taken from the head only.
  properFraction1 (a :< as) = (b, c :< cs)
    where
      (b, c) = properFraction a
      -- The tail's integral part is discarded; it is annotated as Int only
      -- so that its type is determined.
      (_ :: Int, cs) = properFraction1 as
  truncate1 = truncate . primal
  round1 = round . primal
  ceiling1 = ceiling . primal
  floor1 = floor . primal

  -- RealFloat
  floatRadix1 = floatRadix . primal
  floatDigits1 = floatDigits . primal
  floatRange1 = floatRange . primal
  decodeFloat1 = decodeFloat . primal
  encodeFloat1 m e = encodeFloat m e :< encodeFloat1 m e
  exponent1 = exponent . primal
  significand1 (a :< as) = significand a :< significand1 as
  scaleFloat1 n (a :< as) = scaleFloat n a :< scaleFloat1 n as
  isNaN1 = isNaN . primal
  isInfinite1 = isInfinite . primal
  isDenormalized1 = isDenormalized . primal
  isNegativeZero1 = isNegativeZero . primal
  isIEEE1 = isIEEE . primal
  atan21 (a :< as) (b :< bs) = atan2 a b :< atan21 as bs

  -- Enum
  succ1 (a :< as) = succ a :< succ1 as
  pred1 (a :< as) = pred a :< pred1 as
  toEnum1 n = toEnum n :< toEnum1 n
  fromEnum1 = fromEnum . primal
  -- NOTE(review): the enumeration-range methods are not implemented; calling
  -- any of them raises an error with message "TODO".
  enumFrom1 = error "TODO"
  enumFromThen1 = error "TODO"
  enumFromTo1 = error "TODO"
  enumFromThenTo1 = error "TODO"

  -- Bounded
  minBound1 = minBound :< minBound1
  maxBound1 = maxBound :< maxBound1
-- Template Haskell declaration splice. It calls the project helper
-- 'deriveNumeric' for the type @(:>) f@ and passes a context transformer that
-- prepends a @Mode f@ class predicate to whatever context it receives.
-- NOTE(review): presumably this generates the standard numeric class
-- instances for streams in terms of 'Lifted'; confirm against 'deriveNumeric'.
$(let fTy = varT (mkName "f")
  in deriveNumeric
       (\ctx -> classP (mkName "Mode") [fTy] : ctx)
       (appT (conT (mkName ":>")) fTy))