-- | Tensor product (that is, pairs) of monomials

{-# LANGUAGE CPP, BangPatterns, TypeFamilies, UnicodeSyntax, KindSignatures, DataKinds #-}
module Math.Algebra.Polynomial.Monomial.Tensor where

--------------------------------------------------------------------------------

import Data.Typeable
import Data.Either
import Data.Proxy
import GHC.TypeLits

#if MIN_VERSION_base(4,11,0)
import Data.Semigroup
import Data.Monoid
#else
import Data.Monoid
#endif

import Math.Algebra.Polynomial.Class
import Math.Algebra.Polynomial.Pretty

--------------------------------------------------------------------------------

-- | Elementary tensors (basically pairs). The phantom type parameter
-- @symbol@ is used to render an infix symbol when pretty-printing.
-- Both components are strict fields.
data Tensor (symbol :: Symbol) (a :: *) (b :: *)
  = Tensor !a !b
  deriving (Eq,Ord,Show,Typeable)

-- The Semigroup class is only imported for base >= 4.11 (see the import
-- block above), so the instance is compiled under the same condition.
-- On older base versions the Monoid instance below stands on its own.
#if MIN_VERSION_base(4,11,0)
-- | Componentwise combination of the two factors.
instance (Semigroup a, Semigroup b) => Semigroup (Tensor sym a b) where
  (<>) (Tensor x1 y1) (Tensor x2 y2) = Tensor (x1 <> x2) (y1 <> y2)
#endif

-- | Componentwise monoid structure. @mappend@ is spelled out explicitly
-- (instead of being defined through the semigroup operation) so that the
-- instance does not rely on a Semigroup instance being present.
instance (Monoid a, Monoid b) => Monoid (Tensor sym a b) where
  mempty = Tensor mempty mempty
  mappend (Tensor x1 y1) (Tensor x2 y2) = Tensor (x1 `mappend` x2) (y1 `mappend` y2)

-- | Renders the left component, then the type-level symbol, then the
-- right component, concatenated without any extra spacing.
instance (KnownSymbol sym, Pretty a, Pretty b) => Pretty (Tensor sym a b) where
  pretty t@(Tensor a b) = pretty a ++ tensorSymbol t ++ pretty b

-- | The type-level symbol of a tensor, as a 'String'. The argument is
-- only used to fix the type; it is never inspected.
tensorSymbol :: KnownSymbol sym => Tensor sym a b -> String
tensorSymbol = symbolVal . symProxy where
  -- builds a proxy carrying the symbol of the (ignored) argument
  symProxy :: Tensor sym a b -> Proxy sym
  symProxy _ = Proxy

--------------------------------------------------------------------------------

-- | Swaps the two components. Note: this name coincides with the Prelude
-- function of the same name, so unqualified uses are ambiguous unless one
-- of the two is hidden or imported qualified.
flip :: Tensor sym a b -> Tensor sym b a
flip (Tensor x y) = Tensor y x

--------------------------------------------------------------------------------
-- * Injections

-- | Embeds a value as the left component; the right one is 'mempty'.
injLeft :: Monoid b => a -> Tensor sym a b
injLeft x = Tensor x mempty

-- | Embeds a value as the right component; the left one is 'mempty'.
injRight :: Monoid a => b -> Tensor sym a b
injRight x = Tensor mempty x

--------------------------------------------------------------------------------
-- * Projections

-- | The left component.
projLeft :: Tensor sym a b -> a
projLeft (Tensor x _) = x

-- | The right component.
projRight :: Tensor sym a b -> b
projRight (Tensor _ y) = y

--------------------------------------------------------------------------------
-- * differentiation

-- | Differentiation of a tensor: the variable selects which component
-- 'diffM' is applied to (with the same @k@); the other component is kept
-- unchanged. The result is 'Nothing' exactly when the underlying 'diffM'
-- call returns 'Nothing'.
diffTensor
  :: (Monomial a, Monomial b, Num c)
  => Either (VarM a) (VarM b) -> Int -> Tensor sym a b -> Maybe (Tensor sym a b, c)
diffTensor ei k (Tensor left right) = case ei of
  Left  v -> case diffM v k left of
    Just (left' ,c) -> Just (Tensor left' right , c)
    Nothing         -> Nothing
  Right v -> case diffM v k right of
    Just (right',c) -> Just (Tensor left  right', c)
    Nothing         -> Nothing

--------------------------------------------------------------------------------

-- | Monomial structure on tensors: every operation is delegated to the
-- two components. Variables of the tensor are the disjoint union
-- ('Either') of the variables of the components.
instance (KnownSymbol sym, Monomial a, Monomial b) => Monomial (Tensor sym a b) where

  type VarM (Tensor sym a b) = Either (VarM a) (VarM b)

  -- checking the invariant
  normalizeM (Tensor x y) = Tensor (normalizeM x) (normalizeM y)
  isNormalM  (Tensor x y) = isNormalM x && isNormalM y

  -- construction and deconstruction
  -- (the input list is split by the Left / Right tag of the variables)
  fromListM list = Tensor (fromListM list1) (fromListM list2) where
    (list1,list2) = partitionEithers $ map distEither list

  -- (entries of the left component come first, tagged with Left,
  -- followed by the entries of the right component, tagged with Right)
  toListM (Tensor x y) = map f (toListM x) ++ map g (toListM y) where
    f (v,e) = (Left  v, e)
    g (v,e) = (Right v, e)

  -- simple monomials
  emptyM = Tensor emptyM emptyM
  isEmptyM (Tensor x y) = isEmptyM x && isEmptyM y

  variableM ei = case ei of
    Left  v -> Tensor (variableM v) emptyM
    Right v -> Tensor emptyM (variableM v)

  singletonM ei k = case ei of
    Left  v -> Tensor (singletonM v k) emptyM
    Right v -> Tensor emptyM (singletonM v k)

  -- algebra
  mulM (Tensor x1 y1) (Tensor x2 y2) = Tensor (mulM x1 x2) (mulM y1 y2)

  productM tensors = Tensor (productM $ map projLeft  tensors)
                            (productM $ map projRight tensors)

  powM (Tensor x y) k = Tensor (powM x k) (powM y k)

  -- (division succeeds only if it succeeds in both components)
  divM (Tensor x1 y1) (Tensor x2 y2) = case (divM x1 x2, divM y1 y2) of
    (Just z1 , Just z2) -> Just (Tensor z1 z2)
    (_       , _      ) -> Nothing

  -- calculus
  diffM = diffTensor

  -- degrees
  maxDegM   (Tensor x y) = max (maxDegM x) (maxDegM y)
  totalDegM (Tensor x y) = totalDegM x + totalDegM y

  -- substitution and evaluation
  evalM f (Tensor x y) = evalM (f . Left) x * evalM (f . Right) y

  -- (the substitution is expected to send Left variables to Left and
  -- Right variables to Right; otherwise the result contains a call to
  -- error coming from unsafeFromLeft / unsafeFromRight)
  varSubsM f (Tensor x y) = Tensor x' y' where
    x' = varSubsM (unsafeFromLeft  . f . Left ) x
    y' = varSubsM (unsafeFromRight . f . Right) y

  -- (each component is substituted with unit coefficient, and the two
  -- resulting coefficients are multiplied onto the original one)
  termSubsM f (Tensor x y, c) = (Tensor x' y', c*d*e) where
    (x',d) = termSubsM (f . Left ) (x,1)
    (y',e) = termSubsM (f . Right) (y,1)

--------------------------------------------------------------------------------
-- * Helpers

-- | Distributes the second component of a pair over an 'Either' in the
-- first component.
distEither :: (Either a b, c) -> Either (a,c) (b,c)
distEither (ei, z) = case ei of
  Left  x -> Left  (x,z)
  Right y -> Right (y,z)

-- | Extracts the value from a 'Left'. Partial: calls 'error' on a 'Right'.
unsafeFromLeft :: Either a b -> a
unsafeFromLeft ei = case ei of
  Left  x -> x
  Right _ -> error "unsafeFromLeft: Right"

-- | Extracts the value from a 'Right'. Partial: calls 'error' on a 'Left'.
unsafeFromRight :: Either a b -> b
unsafeFromRight ei = case ei of
  Left  _ -> error "unsafeFromRight: Left"
  Right y -> y

--------------------------------------------------------------------------------