module Arithmetic.Matrix where
import Numeric.Algebra hiding ((+), (*))
import Data.List
import Data.Function (on)
-- | A dense matrix stored row-major as a list of rows.
--
-- Nothing in the type enforces that every row has the same length.
newtype TMatrix a = TMatrix
  { toList :: [[a]] -- ^ The rows of the matrix, in order.
  }
  deriving (Show, Eq)
-- | Apply a function to the complete row list of a matrix and rewrap it.
--
-- The function receives the whole @[[a]]@ rather than individual elements,
-- so it may reshape the matrix (for example @fmap' transpose@).
fmap' :: ([[a]] -> [[b]]) -> TMatrix a -> TMatrix b
fmap' f matrix = case matrix of
  TMatrix rows -> TMatrix (f rows)
-- | Transpose a matrix by applying 'transpose' to its row list, so the
-- columns of the input become the rows of the result.
transp :: TMatrix a -> TMatrix a
transp matrix = transpose `fmap'` matrix
-- | Element-wise sum of two matrices.
--
-- The row lists, and each pair of corresponding rows, are combined with
-- 'zipWith'. When the shapes differ, the result is silently cropped to
-- the overlapping top-left region rather than raising an error.
msum :: (Num a) => TMatrix a -> TMatrix a -> TMatrix a
msum (TMatrix xs) (TMatrix ys) = TMatrix (zipWith addRow xs ys)
  where
    addRow r s = zipWith (+) r s
-- | Matrix product of two row-major matrices.
--
-- Entry @(i, j)@ of the result is the dot product of row @i@ of the first
-- matrix with column @j@ of the second, accumulated starting from 'zero'.
--
-- Inner dimensions are not checked: 'zipWith' truncates each dot product
-- to the shorter of the row and the column, so a mismatch produces a
-- result instead of an error.
mmult :: (Rig a, Num a) => TMatrix a -> TMatrix a -> TMatrix a
mmult a b = TMatrix [ [ dot row col | col <- cols ] | row <- toList a ]
  where
    -- Columns of the right-hand matrix. Bound once here so they are shared
    -- by every row of the left-hand matrix instead of being re-transposed
    -- per row.
    cols = transpose (toList b)
    -- Strict left fold: accumulates each entry without building a chain
    -- of pending (+) applications.
    dot xs ys = foldl' (+) zero (zipWith (*) xs ys)
-- | 'Num' instance for matrices: '+' is element-wise addition ('msum') and
-- '*' is the matrix product ('mmult').
--
-- 'negate' is element-wise, so the class default @x - y = x + negate y@
-- yields element-wise subtraction. 'abs', 'signum' and 'fromInteger' have
-- no lawful definition here: a bare integer carries no dimensions, and
-- element-wise abs/signum would break @abs x * signum x == x@ under the
-- matrix product. They therefore fail with an explicit message instead of
-- GHC's generic missing-method error.
instance (Num a, Rig a) => Num (TMatrix a) where
  (+) = msum
  (*) = mmult
  -- Qualified: the Numeric.Algebra import hides only (+) and (*), so an
  -- unqualified 'negate' could clash with the additive-group one it exports.
  negate = fmap' (map (map Prelude.negate))
  abs _ = error "Arithmetic.Matrix: abs is not defined for TMatrix"
  signum _ = error "Arithmetic.Matrix: signum is not defined for TMatrix"
  fromInteger _ =
    error "Arithmetic.Matrix: fromInteger is not defined for TMatrix"