-- | Square symmetric matrices kept in packed form: a 'SymMatrix' holds its
-- dimension together with a single flat vector of elements.
--
-- 'zip' and 'zipWith' are this module's own element-wise combinators and
-- shadow the Prelude functions of the same name, which are hidden on import.
--
-- NOTE(review): apart from 'SymMatrix', 'zip' and 'zipWith', the exported
-- names are not declared at top level in this file; presumably they are the
-- matrix API brought in by "Data.Matrix.Generic" -- confirm.
module Data.Matrix.Symmetric
( SymMatrix(..)
, dim
, rows
, cols
, unsafeIndex
, (!)
, flatten
, unsafeFromVector
, fromVector
, takeRow
, thaw
, unsafeThaw
, freeze
, unsafeFreeze
, zip
, zipWith
) where
import Prelude hiding (zip, zipWith)
import Control.Monad (liftM)
import Data.Bits (shiftR)
import Data.Binary
import qualified Data.Vector.Generic as G
import Data.Matrix.Generic
import Data.Matrix.Symmetric.Mutable (SymMMatrix(..))
-- | Pairs the immutable 'SymMatrix' with 'SymMMatrix' as its mutable
-- counterpart; the @thaw@/@freeze@ methods of the 'Matrix' instance in this
-- module convert between the two.
type instance Mutable SymMatrix = SymMMatrix
-- | A square symmetric matrix stored as its dimension plus one flat vector
-- of elements. Both fields are strict.
data SymMatrix v a
    = SymMatrix
        !Int    -- ^ Dimension @n@; the matrix is reported as @n@ by @n@.
        !(v a)  -- ^ Flat element storage, indexed through 'idx'.
    deriving (Show)
-- | Binary serialisation of a 'SymMatrix'.
--
-- Wire layout, in order: the magic number @52@ (as an 'Int'), the dimension
-- @n@, then every element of the storage vector.
instance (G.Vector v a, Binary a) => Binary (SymMatrix v a) where
    -- Write the magic number, the dimension and then each stored element.
    put (SymMatrix n vec) = do
        put (52 :: Int)
        put n
        G.mapM_ put vec

    -- Read back what 'put' wrote. A wrong magic number is reported through
    -- the parser's own failure mechanism ('fail') rather than 'error', so
    -- callers using @decodeOrFail@ receive it as an ordinary decode failure.
    get = do
        magic <- get :: Get Int
        if magic == 52
            then do
                n <- get
                -- n * (n + 1) / 2 packed elements. The product must be
                -- parenthesised: `shiftR` is infixl 8 and binds tighter than
                -- (*), so without the parentheses the halving applies to @n@
                -- alone and the count is wrong for every odd @n@.
                vec <- G.replicateM (((n + 1) * n) `shiftR` 1) get
                return $ SymMatrix n vec
            else fail "not a valid file"
-- | 'Matrix' operations for the packed symmetric representation.
instance G.Vector v a => Matrix SymMatrix v a where
    -- A symmetric matrix is square: both extents equal the stored dimension.
    dim (SymMatrix n _) = (n, n)

    -- Look up element (i, j) through the packed offset computed by 'idx';
    -- the underlying vector access is unchecked.
    unsafeIndex (SymMatrix n vec) (i, j) = vec `G.unsafeIndex` idx n i j

    -- Build from the flat row-major vector of an r-by-c matrix. For each row
    -- @i@ only the @c - i@ entries starting at the diagonal position
    -- @i * (c + 1)@ are kept, so entries below the diagonal are never read.
    -- Calls 'error' when the requested shape is not square.
    unsafeFromVector (r, c) vec
        | r /= c = error "columns /= rows"
        | otherwise = SymMatrix r . G.concat . map f $ [0 .. r - 1]
      where
        f i = G.slice (i * (c + 1)) (c - i) vec

    -- Mutable copy of the matrix (the storage vector is copied).
    thaw (SymMatrix n v) = SymMMatrix n `liftM` G.thaw v

    -- Mutable view without copying, mirroring 'unsafeFreeze' below; the
    -- immutable matrix must not be used after this call.
    unsafeThaw (SymMatrix n v) = SymMMatrix n `liftM` G.unsafeThaw v

    -- Immutable copy of a mutable matrix.
    freeze (SymMMatrix n v) = SymMatrix n `liftM` G.freeze v

    -- Immutable view without copying; the mutable matrix must not be
    -- modified after this call.
    unsafeFreeze (SymMMatrix n v) = SymMatrix n `liftM` G.unsafeFreeze v
-- | Offset of element @(i, j)@ inside the packed storage of an @n@ by @n@
-- symmetric matrix.
--
-- Only the upper triangle (diagonal included) is stored, row after row, so
-- row @i@ starts at offset @i * (2n - i + 1) / 2@ and element @(i, j)@ with
-- @i <= j@ lives @j - i@ places further on, which simplifies to
-- @i * (2n - i - 1) / 2 + j@. An index below the diagonal is mirrored to
-- its transpose. The product @i * (2n - i - 1)@ is always even, so the
-- shift is an exact division by two.
idx :: Int -> Int -> Int -> Int
idx n i j
    | i <= j = (i * (2 * n - i - 1)) `shiftR` 1 + j
    | otherwise = (j * (2 * n - j - 1)) `shiftR` 1 + i
-- | Pair up the corresponding elements of two symmetric matrices.
--
-- Calls 'error' when the two matrices do not have the same dimension.
zip :: (G.Vector v a, G.Vector v b, G.Vector v (a,b))
    => SymMatrix v a -> SymMatrix v b -> SymMatrix v (a,b)
zip (SymMatrix sizeA elemsA) (SymMatrix sizeB elemsB) =
    if sizeA == sizeB
        then SymMatrix sizeA (G.zip elemsA elemsB)
        else error "imcompatible size"
-- | Combine two symmetric matrices element by element with the given
-- function.
--
-- Calls 'error' when the two matrices do not have the same dimension.
zipWith :: (G.Vector v a, G.Vector v b, G.Vector v c)
        => (a -> b -> c) -> SymMatrix v a -> SymMatrix v b -> SymMatrix v c
zipWith f (SymMatrix sizeA elemsA) (SymMatrix sizeB elemsB) =
    if sizeA == sizeB
        then SymMatrix sizeA (G.zipWith f elemsA elemsB)
        else error "imcompatible size"