{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DefaultSignatures #-}
#if __GLASGOW_HASKELL__ >= 706
{-# LANGUAGE PolyKinds #-}
#endif
#if __GLASGOW_HASKELL__ >= 702 && __GLASGOW_HASKELL__ < 710
{-# LANGUAGE Trustworthy #-}
#endif
---------------------------------------------------------------------------
-- |
-- Copyright   :  (C) 2012-2015 Edward Kmett
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Edward Kmett
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Simple matrix operations for low-dimensional primitives.
---------------------------------------------------------------------------
module Linear.Trace
  ( Trace(..)
  , frobenius
  ) where

import Control.Monad as Monad
import Linear.V0
import Linear.V1
import Linear.V2
import Linear.V3
import Linear.V4
import Linear.Plucker
import Linear.Quaternion
import Linear.V
import Linear.Vector
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ > 704
import Data.Complex
#endif
import Data.Distributive
import Data.Foldable as Foldable
import Data.Functor.Bind as Bind
import Data.Functor.Compose
import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Lazy
import Data.IntMap
import Data.Map

-- $setup
-- >>> import Data.Complex
-- >>> import Data.IntMap
-- >>> import Debug.SimpleReflect.Vars
-- >>> import Linear.V2

-- | Functors @m@ for which a nested @m (m a)@ can be treated as a square
-- matrix, supporting extraction of its diagonal and its trace.
class Functor m => Trace m where
  -- | Compute the trace of a matrix
  --
  -- >>> trace (V2 (V2 a b) (V2 c d))
  -- a + d
  trace :: Num a => m (m a) -> a
#ifndef HLINT
  -- By default the trace is the sum of the diagonal.
  default trace :: (Foldable m, Num a) => m (m a) -> a
  trace = Foldable.sum . diagonal
  {-# INLINE trace #-}
#endif

  -- | Compute the diagonal of a matrix
  --
  -- >>> diagonal (V2 (V2 a b) (V2 c d))
  -- V2 a d
  diagonal :: m (m a) -> m a
#ifndef HLINT
  -- For a representable-style 'Monad', 'Monad.join' picks out the diagonal.
  default diagonal :: Monad m => m (m a) -> m a
  diagonal = Monad.join
  {-# INLINE diagonal #-}
#endif

-- The keyed containers below have no 'Monad' instance, so the default
-- 'diagonal' does not apply; 'Bind.join' from semigroupoids is used instead.

instance Trace IntMap where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

instance Ord k => Trace (Map k) where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

instance (Eq k, Hashable k) => Trace (HashMap k) where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

instance Dim n => Trace (V n)
instance Trace V0
instance Trace V1
instance Trace V2
instance Trace V3
instance Trace V4
instance Trace Plucker
instance Trace Quaternion

#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ > 704
-- | Reads @(a :+ b) :+ (c :+ d)@ as the 2x2 matrix with rows @(a, b)@ and
-- @(c, d)@.
instance Trace Complex where
  trace ((a :+ _) :+ (_ :+ b)) = a + b
  {-# INLINE trace #-}
  diagonal ((a :+ _) :+ (_ :+ b)) = a :+ b
  {-# INLINE diagonal #-}
#endif

-- | A @Product f g@ matrix is handled through its two diagonal blocks: the
-- @f@ rows restricted to their @f@ columns, and the @g@ rows restricted to
-- their @g@ columns. The off-diagonal blocks do not contribute.
instance (Trace f, Trace g) => Trace (Product f g) where
  trace (Pair xx yy) = trace (pfst <$> xx) + trace (psnd <$> yy)
    where
      pfst (Pair x _) = x
      psnd (Pair _ y) = y
  {-# INLINE trace #-}
  diagonal (Pair xx yy) = diagonal (pfst <$> xx) `Pair` diagonal (psnd <$> yy)
    where
      pfst (Pair x _) = x
      psnd (Pair _ y) = y
  {-# INLINE diagonal #-}

-- | A @Compose g f@ matrix is a @g@-indexed matrix of @f@-indexed blocks.
-- 'distribute' swaps the middle @f@ and @g@ layers so that, along the outer
-- @g@ diagonal, each @f@ block can be reduced with the inner @f@ instance.
instance (Distributive g, Trace g, Trace f) => Trace (Compose g f) where
  trace = trace . fmap (fmap trace . distribute) . getCompose . fmap getCompose
  {-# INLINE trace #-}
  diagonal = Compose . fmap diagonal . diagonal . fmap distribute . getCompose . fmap getCompose
  {-# INLINE diagonal #-}

-- | Compute the square of the Frobenius norm of a matrix.
--
-- This is @trace (transpose m !*! m)@, i.e. the sum of the squares of all
-- entries. No square root is taken (only 'Num' is required), and entries
-- are not conjugated.
--
-- >>> frobenius (V2 (V2 1 2) (V2 3 4))
-- 30
frobenius
  :: (Num a, Foldable f, Additive f, Additive g, Distributive g, Trace g)
  => f (g a) -> a
frobenius m = trace $ fmap weighRows (distribute m)
  where
    -- Sum of the rows of @m@, each scaled by the matching entry of one
    -- column of @m@: this is one row of @transpose m !*! m@.
    weighRows col = Foldable.foldl' (^+^) zero $ liftI2 (*^) col m