module Numeric.AD.Internal.Tensors
( Tensors(..)
, headT
, tailT
, tensors
) where
import Control.Applicative
import Data.Foldable
import Data.Traversable
import Data.Monoid
import Data.Typeable (Typeable1(..), Typeable(..), TyCon, mkTyCon, mkTyConApp, typeOfDefault, gcast1)
import Numeric.AD.Internal.Comonad
import Numeric.AD.Internal.Stream
-- ':-' is left-associative at precedence 3, so it binds more loosely than
-- '<$>' and '<*>' (both precedence 4).
infixl 3 :-
-- | A value of type @a@ followed by a 'Tensors' whose elements are wrapped in
-- one more layer of @f@: @a@, then @f a@, then @f (f a)@, and so on.
--
-- There is no terminating constructor, so every fully defined value of this
-- type is infinite and relies on laziness.
data Tensors f a = a :- Tensors f (f a)
-- | Maps a function over every level: directly over the head, and through one
-- additional layer of @f@ for each step further down the tail.
instance Functor f => Functor (Tensors f) where
  fmap g (x :- rest) = g x :- (fmap g <$> rest)
-- | Folds the head first, then the tail, where each tail element is itself
-- folded through its @f@ layer.
instance Foldable f => Foldable (Tensors f) where
  foldMap g (x :- rest) = mappend (g x) (foldMap (foldMap g) rest)
-- | Runs the effect for the head before the effects of the tail; tail
-- elements are traversed through their @f@ layer.
instance Traversable f => Traversable (Tensors f) where
  traverse g (x :- rest) = liftA2 (:-) (g x) (traverse (traverse g) rest)
-- | 'extract' returns the head element, exactly as 'headT' does.
instance Functor f => Copointed (Tensors f) where
  extract = headT
-- | Drop the head element, returning the remaining levels (whose elements
-- carry one extra layer of @f@).
tailT :: Tensors f a -> Tensors f (f a)
tailT t = case t of
  _ :- rest -> rest
-- | Return the head element, discarding all further levels.
headT :: Tensors f a -> a
headT t = case t of
  x :- _ -> x
-- | Convert a 'Stream' into 'Tensors'.
--
-- The head of the stream becomes the head of the result. The tail is built by
-- converting every sub-stream inside the @f@ layer and then moving that @f@
-- layer inward, level by level, with the local helper.
tensors :: Functor f => Stream f a -> Tensors f a
tensors (x :< rest) = x :- pushInward (fmap tensors rest)
  where
    -- Turns a container of 'Tensors' into 'Tensors' of containers: the heads
    -- of all elements form the new head, and the tails are processed the same
    -- way. The recursive call is at a different element type, so the explicit
    -- signature is required.
    pushInward :: Functor g => g (Tensors g b) -> Tensors g (g b)
    pushInward layer = fmap headT layer :- pushInward (fmap tailT layer)
-- | The representation is 'tensorsTyCon' applied to the representation of @f@.
instance Typeable1 f => Typeable1 (Tensors f) where
  typeOf1 t = mkTyConApp tensorsTyCon [typeOf1 (innerProxy t)]
    where
      -- The argument is never inspected; it only pins down the type of the
      -- (undefined) value handed to the inner 'typeOf1'.
      innerProxy :: t f a -> f a
      innerProxy _ = undefined
-- | Defined through 'typeOfDefault', which builds on the 'Typeable1' instance.
instance (Typeable1 f, Typeable a) => Typeable (Tensors f a) where
  typeOf x = typeOfDefault x
-- | Type constructor representation for 'Tensors', named by its fully
-- qualified name. Used by the 'Typeable1' instance; not exported.
tensorsTyCon :: TyCon
tensorsTyCon = mkTyCon "Numeric.AD.Internal.Tensors.Tensors"