{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies          #-}

-- | Symmetric square matrices that keep only the upper triangle
-- (diagonal included) in a single row-major vector.
module Data.Matrix.Symmetric
    ( SymMatrix(..)
    , dim
    , rows
    , cols
    , unsafeIndex
    , (!)
    , flatten
    , unsafeFromVector
    , fromVector
    , takeRow
    , thaw
    , unsafeThaw
    , freeze
    , unsafeFreeze
    , create
    , Data.Matrix.Symmetric.map
    , imap
    , zip
    , zipWith
    ) where

import           Control.Monad                 (liftM)
import           Data.Bits                     (shiftR)
import qualified Data.Vector.Generic           as G
import           GHC.Generics                  (Generic)
import           Prelude                       hiding (zip, zipWith)

import           Data.Matrix.Generic
import           Data.Matrix.Symmetric.Mutable (SymMMatrix (..), new,
                                                unsafeWrite)

-- The mutable counterpart of 'SymMatrix' is 'SymMMatrix'.
type instance Mutable SymMatrix = SymMMatrix

-- | Symmetric square matrix.
--
-- The 'Int' field is the dimension @n@ (the matrix is @n x n@); the vector
-- field holds the upper triangle, diagonal included, in row-major order.
data SymMatrix v a = SymMatrix !Int !(v a)
    deriving (Show, Read, Generic, Eq)

--------------------------------------------------------------------------------
-- Instances
--------------------------------------------------------------------------------

instance G.Vector v a => Matrix SymMatrix v a where
    -- Always square: both dimensions equal the stored size.
    dim (SymMatrix n _) = (n, n)
    {-# INLINE dim #-}

    -- No bounds check; @(i,j)@ and @(j,i)@ resolve to the same stored slot
    -- (see 'idx').
    unsafeIndex (SymMatrix n vec) (i, j) = vec `G.unsafeIndex` idx n i j
    {-# INLINE unsafeIndex #-}

    -- Builds a symmetric matrix from a full row-major @r x c@ vector by
    -- keeping, for each row @i@, the elements from the diagonal to the end
    -- of the row. The packed result has @r * (r + 1) / 2@ elements.
    -- Calls 'error' when the requested shape is not square. The length of
    -- the input vector is not verified here.
    unsafeFromVector (r, c) vec
        | r /= c    = error "columns /= rows"
        | otherwise = SymMatrix r (G.concat [upperPart i | i <- [0 .. r - 1]])
      where
        -- Row @i@ of the full matrix starts at @i * c@, so its diagonal
        -- element sits at @i * c + i = i * (c + 1)@; @c - i@ elements
        -- remain from there to the end of the row.
        upperPart i = G.slice (i * (c + 1)) (c - i) vec
    {-# INLINE unsafeFromVector #-}

    -- Copying conversion to the mutable representation.
    thaw (SymMatrix n v) = SymMMatrix n `liftM` G.thaw v
    {-# INLINE thaw #-}

    -- Non-copying conversion to the mutable representation, mirroring
    -- 'unsafeFreeze' below. Previously this delegated to 'G.thaw', which
    -- copied the whole backing vector and made it indistinguishable from
    -- 'thaw'. The immutable matrix must not be used after this call.
    unsafeThaw (SymMatrix n v) = SymMMatrix n `liftM` G.unsafeThaw v
    {-# INLINE unsafeThaw #-}

    -- Copying conversion from the mutable representation.
    freeze (SymMMatrix n v) = SymMatrix n `liftM` G.freeze v
    {-# INLINE freeze #-}

    -- Non-copying conversion from the mutable representation. The mutable
    -- matrix must not be modified after this call.
    unsafeFreeze (SymMMatrix n v) = SymMatrix n `liftM` G.unsafeFreeze v
    {-# INLINE unsafeFreeze #-}

--------------------------------------------------------------------------------

-- | Apply a function to every stored element. Each symmetric pair
-- @(i,j)@ / @(j,i)@ shares one stored element, so the function is applied
-- once per pair.
map :: (G.Vector v a, G.Vector v b) => (a -> b) -> SymMatrix v a -> SymMatrix v b
map f (SymMatrix n vec) = SymMatrix n (G.map f vec)
{-# INLINE map #-}

-- | Upper triangular imap, i.e., the function is only called with indices
-- @(i,j)@ where @i <= j@.
imap :: (G.Vector v a, G.Vector v b)
     => ((Int, Int) -> a -> b) -> SymMatrix v a -> SymMatrix v b
imap f mat = create $ do
    mat' <- new (n, n)
    -- Walk the upper triangle row by row: each row starts on its diagonal
    -- element, and running off the end of a row moves to the next row.
    let loop m !i !j
            | i >= n    = return ()
            | j >= n    = loop m (i + 1) (i + 1)
            | otherwise = unsafeWrite m (i, j) (f (i, j) x) >> loop m i (j + 1)
          where
            x = unsafeIndex mat (i, j)
    loop mat' 0 0
    return mat'
  where
    n = rows mat
{-# INLINE imap #-}

-- | Pair up the stored elements of two symmetric matrices of the same
-- dimension. Calls 'error' when the dimensions differ.
zip :: (G.Vector v a, G.Vector v b, G.Vector v (a,b))
    => SymMatrix v a -> SymMatrix v b -> SymMatrix v (a,b)
zip (SymMatrix n1 v1) (SymMatrix n2 v2)
    | n1 /= n2  = error "imcompatible size"
    | otherwise = SymMatrix n1 (G.zip v1 v2)
{-# INLINE zip #-}

-- | Combine the stored elements of two symmetric matrices of the same
-- dimension with the given function. Calls 'error' when the dimensions
-- differ.
zipWith :: (G.Vector v a, G.Vector v b, G.Vector v c)
        => (a -> b -> c) -> SymMatrix v a -> SymMatrix v b -> SymMatrix v c
zipWith f (SymMatrix n1 v1) (SymMatrix n2 v2)
    | n1 /= n2  = error "imcompatible size"
    | otherwise = SymMatrix n1 (G.zipWith f v1 v2)
{-# INLINE zipWith #-}

-- helper

-- Row major upper triangular indexing: maps a matrix position to its
-- offset in the packed vector. Row @i@ of the packed layout begins at
-- @i * (2n - i + 1) / 2@ and element @(i,j)@ lies @j - i@ slots further,
-- which simplifies to @i * (2n - i - 1) / 2 + j@. Positions below the
-- diagonal are mirrored onto the upper triangle by swapping the indices.
-- The product is always even, so the right shift is an exact halving.
idx :: Int -> Int -> Int -> Int
idx n i j
    | i <= j    = (i * (2 * n - i - 1)) `shiftR` 1 + j
    | otherwise = (j * (2 * n - j - 1)) `shiftR` 1 + i
{-# INLINE idx #-}