-- | Matrices stored as a single flat vector, laid out row after row.
--   Rows and columns are addressed starting from 1.
module Data.Matrix (
-- * Matrix type
Matrix , prettyMatrix
, nrows , ncols
-- * Builders
, zero
, identity
, matrix
-- * Accessing elements
, getElem , (!)
-- * Whole-matrix transformations
, transpose , extendTo
-- * Working with blocks
, submatrix
, splitBlocks
, (<|>) , (<->)
, joinBlocks
) where
import Data.Monoid
import Control.DeepSeq
import qualified Data.Vector as V
-- | A matrix whose elements live in one flat vector.
--   Element @(i,j)@ (1-based) is found through 'encode'.
data Matrix a = M
  { nrows :: !Int        -- ^ Number of rows.
  , ncols :: !Int        -- ^ Number of columns.
  , mvect :: V.Vector a  -- ^ The elements, stored row after row.
  }
  deriving (Eq)
-- | Render a pair of dimensions as @rows@, an @x@, then @columns@.
--   Used to build the error messages of the size-checking operations.
sizeStr :: Int -> Int -> String
sizeStr rows cols = concat [show rows, "x", show cols]
-- | Render a matrix as text: one line per row, each wrapped in @( ... )@,
--   with every element right-aligned to the width of the widest element.
prettyMatrix :: Show a => Matrix a -> String
prettyMatrix m@(M _ _ v) = unlines
  [ "( " <> unwords (fmap (\j -> fill mx $ show $ m ! (i,j)) [1..ncols m]) <> " )"
  | i <- [1..nrows m]
  ]
  where
    -- Width of the widest rendered element. It is only demanded when at
    -- least one element is printed, so an empty vector never reaches
    -- 'V.maximum'.
    mx = V.maximum $ fmap (length . show) v
    -- Left-pad a string with spaces up to width k.
    fill k str = replicate (k - length str) ' ' ++ str
-- | Matrices are shown through 'prettyMatrix'.
instance Show a => Show (Matrix a) where
  show mat = prettyMatrix mat
-- | Fully evaluating a matrix means fully evaluating its element vector;
--   the two dimension fields are already strict.
instance NFData a => NFData (Matrix a) where
  rnf = rnf . mvect
-- | Translate a 1-based @(row, column)@ position into the 0-based offset
--   of that element inside the flat, row-after-row element vector.
--   This is the inverse of 'decode'.
encode :: Int        -- ^ Number of columns of the matrix
       -> (Int,Int)  -- ^ Position as @(row, column)@, both counted from 1
       -> Int        -- ^ Offset into the element vector, counted from 0
encode m (i,j) = (i - 1) * m + j - 1
-- | Translate a 0-based offset into the flat element vector back into a
--   1-based @(row, column)@ position.
decode :: Int        -- ^ Number of columns of the matrix
       -> Int        -- ^ Offset into the element vector, counted from 0
       -> (Int,Int)  -- ^ Position as @(row, column)@, both counted from 1
decode cols offset =
  let (rowIx, colIx) = offset `quotRem` cols
  in  (rowIx + 1, colIx + 1)
-- | The matrix of the given size whose every element is 0.
zero :: Num a
     => Int       -- ^ Rows
     -> Int       -- ^ Columns
     -> Matrix a
zero rows cols = M
  { nrows = rows
  , ncols = cols
  , mvect = V.replicate (rows * cols) 0
  }
-- | Build a matrix of the given size from a generator function that
--   receives each 1-based @(row, column)@ position.
matrix :: Int                -- ^ Rows
       -> Int                -- ^ Columns
       -> ((Int,Int) -> a)   -- ^ Generator, applied to every position
       -> Matrix a
matrix rows cols gen = M rows cols elems
  where
    -- Each vector offset is turned back into a position before the
    -- generator sees it.
    elems = V.generate (rows * cols) (\offset -> gen (decode cols offset))
-- | The square matrix of the given size with 1 on the main diagonal
--   and 0 everywhere else.
identity :: Num a => Int -> Matrix a
identity n = matrix n n diagonal
  where
    diagonal (i, j)
      | i == j    = 1
      | otherwise = 0
-- | Get the element at row @i@ and column @j@, both counted from 1.
--
--   Calls 'error' when the position lies outside the matrix. Indices
--   below 1 are rejected as well: without that lower-bound check a zero
--   or negative index could silently address a different element of the
--   flat vector (for example column 0 of row 2 is the last element of
--   row 1).
getElem :: Int       -- ^ Row
        -> Int       -- ^ Column
        -> Matrix a  -- ^ Matrix
        -> a
getElem i j (M n m v)
  | i < 1 || i > n || j < 1 || j > m =
      error $ "Trying to get the " ++ show (i,j) ++ " element from a "
           ++ sizeStr n m ++ " matrix."
  | otherwise = v V.! encode m (i,j)
-- | Operator form of 'getElem': @mat ! (row, column)@.
(!) :: Matrix a -> (Int,Int) -> a
mat ! position =
  case position of
    (row, col) -> getElem row col mat
-- | Transpose a matrix: element @(i,j)@ of the result is element @(j,i)@
--   of the argument.
transpose :: Matrix a -> Matrix a
transpose (M n m v) = M m n (V.backpermute v sourceOffsets)
  where
    -- For every offset of the result, the offset of the element to take
    -- from the original vector.
    sourceOffsets = V.generate (V.length v) sourceOf
    -- The result has n columns, so offset k sits at result row q and
    -- result column r; that element comes from original row r, column q.
    sourceOf k =
      let (q, r) = k `quotRem` n
      in  r * m + q
-- | Grow a matrix to at least the given size by appending rows and
--   columns of zeroes below and to the right. A dimension that is
--   already large enough is left untouched; the matrix is never shrunk.
extendTo :: Num a
         => Int  -- ^ Minimal number of rows
         -> Int  -- ^ Minimal number of columns
         -> Matrix a -> Matrix a
extendTo n m a = a''
  where
    -- Number of rows still missing (non-positive when none are needed).
    n'  = n - nrows a
    a'  = if n' <= 0 then a else a <-> zero n' (ncols a)
    -- Number of columns still missing (non-positive when none are needed).
    m'  = m - ncols a
    a'' = if m' <= 0 then a' else a' <|> zero (nrows a') m'
-- | Extract the block spanning rows @r1..r2@ and columns @c1..c2@
--   (1-based, inclusive on both ends).
--
--   No bounds are checked here; each row is cut directly out of the
--   flat element vector with 'V.slice'. An empty row range
--   (@r2 == r1 - 1@) or column range (@c2 == c1 - 1@) produces an empty
--   block, which 'splitBlocks' relies on when splitting at an edge.
submatrix :: Int       -- ^ Starting row
          -> Int       -- ^ Ending row
          -> Int       -- ^ Starting column
          -> Int       -- ^ Ending column
          -> Matrix a
          -> Matrix a
submatrix r1 r2 c1 c2 (M _ m v) = M (r2 - r1 + 1) m' $
  mconcat [ V.slice (encode m (r,c1)) m' v | r <- [r1 .. r2] ]
  where
    -- Number of columns in the extracted block.
    m' = c2 - c1 + 1
-- | Cut a matrix into four blocks around a pivot position. Row @i@ and
--   column @j@ are the last row and column of the top-left block. The
--   result is @(top-left, top-right, bottom-left, bottom-right)@.
splitBlocks :: Int       -- ^ Last row of the upper blocks
            -> Int       -- ^ Last column of the left blocks
            -> Matrix a
            -> (Matrix a,Matrix a
               ,Matrix a,Matrix a)
splitBlocks i j a@(M n m _) = (topLeft, topRight, bottomLeft, bottomRight)
  where
    topLeft     = submatrix 1       i 1       j a
    topRight    = submatrix 1       i (j + 1) m a
    bottomLeft  = submatrix (i + 1) n 1       j a
    bottomRight = submatrix (i + 1) n (j + 1) m a
-- | Assemble four blocks, given as
--   @(top-left, top-right, bottom-left, bottom-right)@, into one matrix.
--   The two halves are joined horizontally first, then stacked.
joinBlocks :: (Matrix a,Matrix a
              ,Matrix a,Matrix a)
           -> Matrix a
joinBlocks (tl,tr,bl,br) = upperHalf <-> lowerHalf
  where
    upperHalf = tl <|> tr
    lowerHalf = bl <|> br
-- | Horizontal join: place the second matrix to the right of the first.
--   Both must have the same number of rows, otherwise 'error' is called.
(<|>) :: Matrix a -> Matrix a -> Matrix a
(M n m v) <|> (M n' m' v')
  | n == n'   = M n (m + m') (V.concat (concatMap joinedRow [1 .. n]))
  | otherwise = error $ "Horizontal join of " ++ sizeStr n m ++ " and "
                     ++ sizeStr n' m' ++ " matrices."
  where
    -- Row r of the result: row r of the left matrix followed by row r
    -- of the right matrix.
    joinedRow r =
      [ V.slice (encode m  (r,1)) m  v
      , V.slice (encode m' (r,1)) m' v'
      ]
-- | Vertical join: place the second matrix below the first.
--   Both must have the same number of columns, otherwise 'error' is
--   called.
(<->) :: Matrix a -> Matrix a -> Matrix a
(M n m v) <-> (M n' m' v')
  | m == m'   = M (n + n') m (v V.++ v')
  | otherwise = error $ "Vertical join of " ++ sizeStr n m ++ " and "
                     ++ sizeStr n' m' ++ " matrices."
-- | Mapping a function over a matrix applies it to every element and
--   keeps the dimensions.
instance Functor Matrix where
  fmap f mat = mat { mvect = V.map f (mvect mat) }
-- | Matrix product by Strassen's algorithm.
--
--   Both arguments are split in half along each dimension, using half
--   the row count of the first argument, until 1x1 matrices are reached.
--   The recursion therefore only bottoms out for square matrices of the
--   same size whose order is a power of two; the multiplication in the
--   'Num' instance pads its operands accordingly before calling this.
strassen :: Num a => Matrix a -> Matrix a -> Matrix a
strassen (M 1 1 v) (M 1 1 v') = M 1 1 $ V.zipWith (*) v v'
strassen a b = joinBlocks (c11,c12,c21,c22)
  where
    -- Size of each quadrant.
    n = div (nrows a) 2
    -- Quadrants of both operands.
    (a11,a12,a21,a22) = splitBlocks n n a
    (b11,b12,b21,b22) = splitBlocks n n b
    -- The seven recursive products.
    p1 = strassen (a11 + a22) (b11 + b22)
    p2 = strassen (a21 + a22)  b11
    p3 = strassen  a11        (b12 - b22)
    p4 = strassen  a22        (b21 - b11)
    p5 = strassen (a11 + a12)  b22
    p6 = strassen (a21 - a11) (b11 + b12)
    p7 = strassen (a12 - a22) (b21 + b22)
    -- Quadrants of the product.
    c11 = p1 + p4 - p5 + p7
    c12 = p3 + p5
    c21 = p2 + p4
    c22 = p1 - p2 + p3 + p6
-- | Return the first element of a list that satisfies the predicate.
--   Elements after the match are never inspected, so this also works on
--   infinite lists. Calls 'error' when no element matches.
first :: (a -> Bool) -> [a] -> a
first f = foldr pick (error "first: no element match the condition.")
  where
    -- Keep the current element if it matches, otherwise defer to the
    -- result for the rest of the list.
    pick x rest
      | f x       = x
      | otherwise = rest
-- | Element-wise 'negate', 'abs', 'signum' and addition; matrix product
--   for '*'. An integer literal denotes the 1x1 matrix holding that
--   number. Addition and multiplication call 'error' on incompatible
--   sizes.
instance Num a => Num (Matrix a) where
  fromInteger k = M 1 1 (V.singleton (fromInteger k))
  negate mat = fmap negate mat
  abs mat = fmap abs mat
  signum mat = fmap signum mat

  -- Addition requires both operands to have exactly the same size.
  (M n m v) + (M n' m' v')
    | n == n' && m == m' = M n m (V.zipWith (+) v v')
    | otherwise = error $ "Addition of " ++ sizeStr n m ++ " and "
                       ++ sizeStr n' m' ++ " matrices."

  -- 1x1 matrices are multiplied directly.
  (M 1 1 v) * (M 1 1 v') = M 1 1 (V.zipWith (*) v v')
  -- Everything else is padded with zeroes to a square whose side is a
  -- power of two, multiplied with 'strassen', and cut back to size.
  a1@(M n m _) * a2@(M n' m' _)
    | m == n' = submatrix 1 n 1 m' (strassen padded1 padded2)
    | otherwise = error $ "Multiplication of " ++ sizeStr n m ++ " and "
                       ++ sizeStr n' m' ++ " matrices."
    where
      largest = maximum [n, m, n', m']
      -- Smallest power of two that covers every dimension involved.
      side    = first (>= largest) [ 2 ^ e | e <- [(0 :: Int) ..] ]
      padded1 = extendTo side side a1
      padded2 = extendTo side side a2