{-# LANGUAGE PatternGuards #-}
-- |
-- Module    : Statistics.Matrix
-- Copyright : 2011 Aleksey Khudyakov, 2014 Bryan O'Sullivan
-- License   : BSD3
--
-- Basic matrix operations.
--
-- There isn't a widely used matrix package for Haskell yet, so
-- we implement the necessary minimum here.

module Statistics.Matrix
    ( -- * Data types
      Matrix(..)
    , Vector
      -- * Conversion from/to lists/vectors
    , fromVector
    , fromList
    , fromRowLists
    , fromRows
    , fromColumns
    , toVector
    , toList
    , toRows
    , toColumns
    , toRowLists
      -- * Other
    , generate
    , generateSym
    , ident
    , diag
    , dimension
    , center
    , multiply
    , multiplyV
    , transpose
    , power
    , norm
    , column
    , row
    , map
    , for
    , unsafeIndex
    , hasNaN
    , bounds
    , unsafeBounds
    ) where

import Prelude hiding (exponent, map, sum)
import Control.Applicative ((<$>))
import Control.Monad.ST
import qualified Data.Vector.Unboxed as U
import Data.Vector.Unboxed ((!))
import qualified Data.Vector.Unboxed.Mutable as UM

import Statistics.Function (for, square)
import Statistics.Matrix.Types
import Statistics.Matrix.Mutable (unsafeNew,unsafeWrite,unsafeFreeze)
import Statistics.Sample.Internal (sum)


----------------------------------------------------------------
-- Conversion to/from vectors/lists
----------------------------------------------------------------

-- | Convert from a row-major list.
--
-- Calls 'error' (via 'fromVector') if the list length is not equal to
-- the product of the two dimensions.
fromList :: Int                 -- ^ Number of rows.
         -> Int                 -- ^ Number of columns.
         -> [Double]            -- ^ Flat list of values, in row-major order.
         -> Matrix
fromList r c = fromVector r c . U.fromList

-- | Create a matrix from a list of lists, as rows.
--
-- The same restrictions as for 'fromRows' apply: the outer list must
-- be non-empty, and all rows must have the same, non-zero length.
fromRowLists :: [[Double]] -> Matrix
fromRowLists = fromRows . fmap U.fromList

-- | Convert from a row-major vector.
--
-- Calls 'error' if the vector length is not equal to the product of
-- the two dimensions.  The exponent field of the result is set to 0.
fromVector :: Int               -- ^ Number of rows.
           -> Int               -- ^ Number of columns.
           -> U.Vector Double   -- ^ Flat list of values, in row-major order.
           -> Matrix
fromVector r c v
  | r*c /= len = error "input size mismatch"
  | otherwise  = Matrix r c 0 v
  where
    len = U.length v

-- | Create a matrix from a list of vectors, as rows.
--
-- Calls 'error' if the list is empty, if the rows differ in length,
-- or if the rows have zero length.
fromRows :: [Vector] -> Matrix
fromRows xs
  | [] <- xs        = error "Statistics.Matrix.fromRows: empty list of rows!"
  | any (/=nCol) ns = error "Statistics.Matrix.fromRows: row sizes do not match"
  | nCol == 0       = error "Statistics.Matrix.fromRows: zero columns in matrix"
  | otherwise       = fromVector nRow nCol (U.concat xs)
  where
    -- This binding is lazy; it is only forced after the first guard
    -- has ruled out the empty list, so the pattern cannot fail.
    nCol:ns = U.length <$> xs
    nRow    = length xs

-- | Create a matrix from a list of vectors, as columns.
--
-- Implemented as the transpose of 'fromRows', so the same
-- restrictions on the input list apply.
fromColumns :: [Vector] -> Matrix
fromColumns = transpose . fromRows

-- | Convert to a row-major flat vector.
--
-- The stored vector is returned as is; the exponent field is not
-- applied to the values.
toVector :: Matrix -> U.Vector Double
toVector (Matrix _ _ _ v) = v

-- | Convert to a row-major flat list.
--
-- The exponent field is not applied to the values.
toList :: Matrix -> [Double]
toList = U.toList . toVector

-- | Convert to a list of lists, as rows.
--
-- The exponent field is not applied to the values.
toRowLists :: Matrix -> [[Double]]
toRowLists (Matrix _ nCol _ v) = chunks $ U.toList v
  where
    -- Peel off one row's worth of elements at a time.
    chunks [] = []
    chunks xs = case splitAt nCol xs of
      (rowE,rest) -> rowE : chunks rest

-- | Convert to a list of vectors, as rows.
--
-- The exponent field is not applied to the values.
toRows :: Matrix -> [Vector]
toRows (Matrix _ nCol _ v) = chunks v
  where
    -- Peel off one row's worth of elements at a time.
    chunks xs
      | U.null xs = []
      | otherwise = case U.splitAt nCol xs of
          (rowE,rest) -> rowE : chunks rest

-- | Convert to a list of vectors, as columns.
toColumns :: Matrix -> [Vector]
toColumns = toRows . transpose


----------------------------------------------------------------
-- Other
----------------------------------------------------------------

-- | Generate matrix using function.
--
-- The exponent field of the result is set to 0.
generate :: Int                 -- ^ Number of rows
         -> Int                 -- ^ Number of columns
         -> (Int -> Int -> Double)
            -- ^ Function which takes /row/ and /column/ as argument.
         -> Matrix
generate nRow nCol f =
  Matrix nRow nCol 0 $ U.generate (nRow*nCol) $ \i ->
    -- Recover the (row, column) pair from the flat row-major offset.
    let (r,c) = i `quotRem` nCol in f r c

-- | Generate symmetric square matrix using function.
--
-- The function is evaluated only for @column >= row@; each
-- off-diagonal value is written to both mirrored positions.
generateSym
  :: Int                 -- ^ Number of rows and columns
  -> (Int -> Int -> Double)
     -- ^ Function which takes /row/ and /column/ as argument. It must
     --   be symmetric in arguments: @f i j == f j i@
  -> Matrix
generateSym n f = runST $ do
  m <- unsafeNew n n
  for 0 n $ \r -> do
    unsafeWrite m r r (f r r)
    for (r+1) n $ \c -> do
      let x = f r c
      unsafeWrite m r c x
      unsafeWrite m c r x
  unsafeFreeze m

-- | Create the square identity matrix with given dimensions.
ident :: Int -> Matrix
ident n = diag $ U.replicate n 1.0

-- | Create a square matrix with given diagonal, other entries default to 0.
--
-- The side length of the result is the length of the input vector.
diag :: Vector -> Matrix
diag v = Matrix n n 0 $ U.create $ do
  arr <- UM.replicate (n*n) 0
  for 0 n $ \i ->
    -- Offset of the i-th diagonal element in row-major order.
    UM.unsafeWrite arr (i*n + i) (v ! i)
  return arr
  where
    n = U.length v

-- | Return the dimensions of this matrix, as a (row,column) pair.
dimension :: Matrix -> (Int, Int)
dimension (Matrix r c _ _) = (r, c)

-- | Avoid overflow in the matrix.
--
-- If the 'center' element exceeds @1e140@, every element is
-- multiplied by @1e-140@ and 140 is added to the exponent field;
-- otherwise the matrix is returned unchanged.
avoidOverflow :: Matrix -> Matrix
avoidOverflow m@(Matrix r c e v)
  | center m > 1e140 = Matrix r c (e + 140) (U.map (* 1e-140) v)
  | otherwise        = m

-- | Matrix-matrix multiplication.  Matrices must be of compatible
-- sizes (/note: not checked/).
--
-- The exponent field of the result is the sum of the exponent fields
-- of the arguments.
multiply :: Matrix -> Matrix -> Matrix
multiply m1@(Matrix r1 _ e1 _) m2@(Matrix _ c2 e2 _) =
  Matrix r1 c2 (e1 + e2) $ U.generate (r1*c2) go
  where
    -- Element at flat offset t is the dot product of row i of the
    -- first matrix with column j of the second.
    go t = sum $ U.zipWith (*) (row m1 i) (column m2 j)
      where (i,j) = t `quotRem` c2

-- | Matrix-vector multiplication.
--
-- Calls 'error' if the number of columns of the matrix differs from
-- the length of the vector.  The exponent field of the matrix is not
-- applied to the result.
multiplyV :: Matrix -> Vector -> Vector
multiplyV m v
  | cols m == c = U.generate (rows m) (sum . U.zipWith (*) v . row m)
  | otherwise   = error $ "matrix/vector unconformable " ++ show (cols m,c)
  where
    c = U.length v

-- | Raise matrix to /n/th power, by repeated squaring.  Power must be
-- positive; 'error' is called otherwise.
--
-- For powers greater than 1 the result is passed through an overflow
-- guard that may rescale the elements and adjust the exponent field.
power :: Matrix -> Int -> Matrix
power mat n
  -- Without this guard the recursion below never terminates for
  -- n < 1, since @n `quot` 2@ reaches 0 and stays there.
  | n < 1     = error "Statistics.Matrix.power: power must be positive"
  | n == 1    = mat
  | otherwise = avoidOverflow res
  where
    mat2 = power mat (n `quot` 2)
    pow  = multiply mat2 mat2
    res | odd n     = multiply pow mat
        | otherwise = pow

-- | Element in the center of matrix (not corrected for exponent).
--
-- The element at row @rows `quot` 2@ and column @cols `quot` 2@ is
-- read without bounds checking, so the matrix must not be empty.
center :: Matrix -> Double
center mat@(Matrix r c _ _) =
  unsafeBounds U.unsafeIndex mat (r `quot` 2) (c `quot` 2)

-- | Calculate the Euclidean norm of a vector.
norm :: Vector -> Double
norm = sqrt . sum . U.map square

-- | Return the given column.
column :: Matrix -> Int -> Vector
column (Matrix r c _ v) i = U.backpermute v $ U.enumFromStepN i c r
{-# INLINE column #-}

-- | Return the given row.
row :: Matrix -> Int -> Vector
row (Matrix _ c _ v) i = U.slice (c*i) c v

-- | Return the element at the given row and column, without bounds
-- checking (not corrected for exponent).
unsafeIndex :: Matrix
            -> Int              -- ^ Row.
            -> Int              -- ^ Column.
            -> Double
unsafeIndex = unsafeBounds U.unsafeIndex

-- | Apply function to every element of matrix.
--
-- Dimensions and the exponent field are left unchanged.
map :: (Double -> Double) -> Matrix -> Matrix
map f (Matrix r c e v) = Matrix r c e (U.map f v)

-- | Indicate whether any element of the matrix is @NaN@.
hasNaN :: Matrix -> Bool
hasNaN = U.any isNaN . toVector

-- | Given row and column numbers, calculate the offset into the flat
-- row-major vector, and pass the vector and the offset to the given
-- function.
--
-- Calls 'error' if the row or the column is out of bounds.
bounds :: (Vector -> Int -> r) -> Matrix -> Int -> Int -> r
bounds k (Matrix rs cs _ v) r c
  | r < 0 || r >= rs = error "row out of bounds"
  | c < 0 || c >= cs = error "column out of bounds"
  | otherwise        = k v $! r * cs + c
{-# INLINE bounds #-}

-- | Given row and column numbers, calculate the offset into the flat
-- row-major vector, without checking, and pass the vector and the
-- offset to the given function.
unsafeBounds :: (Vector -> Int -> r) -> Matrix -> Int -> Int -> r
unsafeBounds k (Matrix _ cs _ v) r c = k v $! r * cs + c
{-# INLINE unsafeBounds #-}

-- | Transpose a matrix.
--
-- Row and column counts are swapped; the exponent field is kept.
transpose :: Matrix -> Matrix
transpose m@(Matrix r0 c0 e _) = Matrix c0 r0 e . U.generate (r0*c0) $ \i ->
  -- The result has r0 columns, so (r,c) is the position in the
  -- transposed matrix; it is filled from position (c,r) of the input.
  let (r,c) = i `quotRem` r0
  in unsafeIndex m c r