{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE Trustworthy #-}
---------------------------------------------------------------------------
-- |
-- Copyright   :  (C) 2012-2015 Edward Kmett
-- License     :  BSD-style (see the file LICENSE)
--
-- Maintainer  :  Edward Kmett
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Simple matrix operations for low-dimensional primitives.
---------------------------------------------------------------------------
module Linear.Trace
  ( Trace(..)
  , frobenius
  ) where

import Control.Monad as Monad
import Linear.V0
import Linear.V1
import Linear.V2
import Linear.V3
import Linear.V4
import Linear.Plucker
import Linear.Quaternion
import Linear.V
import Linear.Vector
import Data.Complex
import Data.Distributive
import Data.Foldable as Foldable
import Data.Functor.Bind as Bind
import Data.Functor.Compose
import Data.Functor.Product
import Data.Hashable
import Data.HashMap.Lazy
import Data.IntMap (IntMap)
import Data.Map (Map)

-- $setup
-- >>> import Data.Complex
-- >>> import Debug.SimpleReflect.Vars
-- >>> import Linear.V2

-- | Square matrices stored as a nested container @m (m a)@ that can
-- expose their main diagonal and its sum.
--
-- Both methods have defaults: 'diagonal' is 'Monad.join' whenever @m@ is
-- a 'Monad', and 'trace' is the 'Foldable.sum' of 'diagonal' whenever @m@
-- is 'Foldable'. An empty instance therefore suffices for such functors.
class Functor m => Trace m where
  -- | Compute the trace of a matrix
  --
  -- >>> trace (V2 (V2 a b) (V2 c d))
  -- a + d
  trace :: Num a => m (m a) -> a
#ifndef HLINT
  default trace :: (Foldable m, Num a) => m (m a) -> a
  trace = Foldable.sum . diagonal
  {-# INLINE trace #-}
#endif

  -- | Compute the diagonal of a matrix
  --
  -- >>> diagonal (V2 (V2 a b) (V2 c d))
  -- V2 a d
  diagonal :: m (m a) -> m a
#ifndef HLINT
  default diagonal :: Monad m => m (m a) -> m a
  diagonal = Monad.join
  {-# INLINE diagonal #-}
#endif

-- The keyed containers below do not satisfy the 'Monad' default for
-- 'diagonal', so they take it from the semigroupoids 'Bind.join' instead;
-- 'trace' still uses the 'Foldable' default.

instance Trace IntMap where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

instance Ord k => Trace (Map k) where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

instance (Eq k, Hashable k) => Trace (HashMap k) where
  diagonal = Bind.join
  {-# INLINE diagonal #-}

-- Fixed-size vector types rely entirely on the class defaults.
instance Dim n => Trace (V n)
instance Trace V0
instance Trace V1
instance Trace V2
instance Trace V3
instance Trace V4
instance Trace Plucker
instance Trace Quaternion

-- A 'Complex' of 'Complex'es is treated as a 2x2 matrix whose diagonal is
-- the real part of the real row and the imaginary part of the imaginary row.
instance Trace Complex where
  trace ((a :+ _) :+ (_ :+ b)) = a + b
  {-# INLINE trace #-}
  diagonal ((a :+ _) :+ (_ :+ b)) = a :+ b
  {-# INLINE diagonal #-}

-- A matrix over @Product f g@ is block-structured; only the @f@-by-@f@ block
-- (first components of the first half) and the @g@-by-@g@ block (second
-- components of the second half) lie on the diagonal.
instance (Trace f, Trace g) => Trace (Product f g) where
  trace (Pair xx yy) = trace (pfst <$> xx) + trace (psnd <$> yy)
    where
      pfst (Pair x _) = x
      psnd (Pair _ y) = y
  {-# INLINE trace #-}
  diagonal (Pair xx yy) = diagonal (pfst <$> xx) `Pair` diagonal (psnd <$> yy)
    where
      pfst (Pair x _) = x
      psnd (Pair _ y) = y
  {-# INLINE diagonal #-}

-- Unwrapping both 'Compose' layers yields @g (f (g (f a)))@; 'distribute'
-- swaps the middle @f@ and @g@ so the two @g@ layers (and then the two @f@
-- layers) become adjacent and can be traced / diagonalised in turn.
instance (Distributive g, Trace g, Trace f) => Trace (Compose g f) where
  trace = trace . fmap (fmap trace . distribute) . getCompose . fmap getCompose
  {-# INLINE trace #-}
  diagonal = Compose . fmap diagonal . diagonal . fmap distribute . getCompose . fmap getCompose
  {-# INLINE diagonal #-}

-- | Compute the square of the Frobenius norm of a matrix, i.e. the sum of
-- the squares of all of its entries.
--
-- This is evaluated as the trace of @M^T M@: for every column of @m@
-- (obtained via 'distribute') the rows of @m@ are scaled by that column's
-- entries and summed, and the 'trace' of the resulting square matrix is
-- taken. No square root is applied (only 'Num' is required), and entries
-- are multiplied directly, without complex conjugation.
--
-- >>> frobenius (V2 (V2 1 2) (V2 3 4))
-- 30
frobenius :: (Num a, Foldable f, Additive f, Additive g, Distributive g, Trace g) => f (g a) -> a
frobenius m = trace $ fmap (\ f' -> Foldable.foldl' (^+^) zero $ liftI2 (*^) f' m) (distribute m)