module Data.Eigen.Matrix (
Matrix(..),
MatrixXf,
MatrixXd,
MatrixXcf,
MatrixXcd,
I.Elem,
I.CComplex,
valid,
fromList,
toList,
generate,
empty,
null,
square,
zero,
ones,
identity,
constant,
random,
cols,
rows,
dims,
(!),
coeff,
unsafeCoeff,
col,
row,
block,
topRows,
bottomRows,
leftCols,
rightCols,
sum,
prod,
mean,
minCoeff,
maxCoeff,
trace,
norm,
squaredNorm,
blueNorm,
hypotNorm,
determinant,
fold,
fold',
ifold,
ifold',
fold1,
fold1',
all,
any,
count,
add,
sub,
mul,
map,
imap,
filter,
ifilter,
diagonal,
transpose,
inverse,
adjoint,
conjugate,
normalize,
modify,
convert,
TriangularMode(..),
triangularView,
lowerTriangle,
upperTriangle,
encode,
decode,
thaw,
freeze,
unsafeThaw,
unsafeFreeze,
unsafeWith,
) where
import qualified Prelude as P
import qualified Data.List as L
import Prelude hiding (null, sum, all, any, map, filter)
import Data.Tuple
import Data.Complex hiding (conjugate)
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.C.Types
import Foreign.C.String
import Foreign.Storable
import Foreign.Marshal.Alloc
import Text.Printf
import Control.Monad
import Control.Monad.ST
import Control.Monad.Primitive
#if __GLASGOW_HASKELL__ >= 710
#else
import Control.Applicative hiding (empty)
#endif
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import qualified Data.Eigen.Internal as I
import qualified Data.Eigen.Matrix.Mutable as M
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.Internal as BSI
-- | A dense matrix: row count, column count and a flat coefficient vector.
--
-- @a@ is the element type seen by Haskell callers; @b@ is the storage type
-- handed to the C routines (e.g. 'CFloat' for 'Float', per the aliases
-- below). The two are converted with 'I.cast'.
--
-- Coefficients are stored column-major: entry @(row, col)@ lives at index
-- @col * rows + row@. The constructor is exported, so this layout is not
-- enforced on construction; 'valid' checks it.
data Matrix a b where
    Matrix :: I.Elem a b => !Int -> !Int -> !(VS.Vector b) -> Matrix a b
-- | Single-precision real matrix.
type MatrixXf = Matrix Float CFloat
-- | Double-precision real matrix.
type MatrixXd = Matrix Double CDouble
-- | Single-precision complex matrix.
type MatrixXcf = Matrix (Complex Float) (I.CComplex CFloat)
-- | Double-precision complex matrix.
type MatrixXcd = Matrix (Complex Double) (I.CComplex CDouble)
-- | Renders a header line @Matrix <rows>x<cols>@, then one line per row with
-- tab-separated coefficients, then a trailing newline.
instance (I.Elem a b, Show a) => Show (Matrix a b) where
    show m@(Matrix nr nc _) = header ++ "\n" ++ body ++ "\n"
      where
        header = "Matrix " ++ show nr ++ "x" ++ show nc
        body = L.intercalate "\n" [L.intercalate "\t" (P.map show xs) | xs <- toList m]
-- | Arithmetic on matrices.
--
-- '(+)' and '(-)' are coefficient-wise and require operands of identical
-- dimensions (see 'add' and 'sub'). '(*)' is the matrix product (see 'mul').
-- 'abs', 'signum' and 'negate' are applied to every coefficient.
--
-- 'fromInteger' builds a 1x1 matrix, so a literal such as @m + 1@ only
-- succeeds at runtime when @m@ is itself 1x1.
instance I.Elem a b => Num (Matrix a b) where
    (*) = mul
    (+) = add
    (-) = sub
    fromInteger = constant 1 1 . fromInteger
    signum = map signum
    abs = map abs
    negate = map negate
-- | The 0x0 matrix, holding no coefficients.
empty :: I.Elem a b => Matrix a b
empty = Matrix 0 0 (VS.fromList [])

-- | 'True' exactly when the matrix has zero rows and zero columns.
null :: I.Elem a b => Matrix a b -> Bool
null m = case m of
    Matrix 0 0 _ -> True
    _ -> False

-- | 'True' when the row count equals the column count.
square :: I.Elem a b => Matrix a b -> Bool
square (Matrix nr nc _) = compare nr nc == EQ
-- | A @rows@ x @cols@ matrix whose coefficients all equal the given value.
constant :: I.Elem a b => Int -> Int -> a -> Matrix a b
constant nr nc x = Matrix nr nc filled
  where
    filled = VS.replicate (nr * nc) (I.cast x)

-- | A @rows@ x @cols@ matrix of zeros.
zero :: I.Elem a b => Int -> Int -> Matrix a b
zero nr nc = constant nr nc 0

-- | A @rows@ x @cols@ matrix of ones.
ones :: I.Elem a b => Int -> Int -> Matrix a b
ones nr nc = constant nr nc 1
-- | A @rows@ x @cols@ matrix filled in by the C routine 'I.identity'.
-- Presumably ones on the main diagonal and zeros elsewhere, as in Eigen;
-- confirm against Data.Eigen.Internal.
identity :: I.Elem a b => Int -> Int -> Matrix a b
identity nr nc = I.performIO $
    M.new nr nc >>= \mm -> I.call (M.unsafeWith mm I.identity) >> unsafeFreeze mm

-- | A @rows@ x @cols@ matrix filled in by the C routine 'I.random'.
-- The distribution is whatever that routine produces (not visible here).
random :: I.Elem a b => Int -> Int -> IO (Matrix a b)
random nr nc =
    M.new nr nc >>= \mm -> I.call (M.unsafeWith mm I.random) >> unsafeFreeze mm
-- | Number of rows.
rows :: I.Elem a b => Matrix a b -> Int
rows (Matrix n _ _) = n

-- | Number of columns.
cols :: I.Elem a b => Matrix a b -> Int
cols (Matrix _ n _) = n

-- | Dimensions as @(rows, cols)@.
dims :: I.Elem a b => Matrix a b -> (Int, Int)
dims m@Matrix{} = (rows m, cols m)
-- | Indexing operator: @m ! (row, col)@ is @'coeff' row col m@, bounds
-- checks included.
(!) :: I.Elem a b => Matrix a b -> (Int, Int) -> a
(!) m ix = case ix of
    (r, c) -> coeff r c m
-- | The coefficient at @(row, col)@ (zero-based).
--
-- Raises an error if the matrix layout is invalid (see 'valid') or if
-- either index falls outside @[0, rows)@ / @[0, cols)@.
coeff :: I.Elem a b => Int -> Int -> Matrix a b -> a
coeff r c m@(Matrix nr nc _)
    | not (valid m) = error "matrix is not valid"
    | outOfRange r nr = error $ printf "Matrix.coeff: row %d is out of bounds [0..%d)" r nr
    | outOfRange c nc = error $ printf "Matrix.coeff: col %d is out of bounds [0..%d)" c nc
    | otherwise = unsafeCoeff r c m
  where
    outOfRange i n = i < 0 || i >= n
-- | Like 'coeff' but with no validity or bounds checking. The index is
-- computed for column-major storage.
unsafeCoeff :: I.Elem a b => Int -> Int -> Matrix a b -> a
unsafeCoeff r c (Matrix nr _ v) = I.cast (v `VS.unsafeIndex` (r + c * nr))
-- | The coefficients of column @c@, top to bottom (bounds-checked via 'coeff').
col :: I.Elem a b => Int -> Matrix a b -> [a]
col c m@(Matrix nr _ _) = P.map (\r -> coeff r c m) [0 .. pred nr]

-- | The coefficients of row @r@, left to right (bounds-checked via 'coeff').
row :: I.Elem a b => Int -> Matrix a b -> [a]
row r m@(Matrix _ nc _) = P.map (\c -> coeff r c m) [0 .. pred nc]
-- | @block startRow startCol blockRows blockCols m@ copies the
-- @blockRows@ x @blockCols@ sub-matrix whose top-left corner is at
-- @(startRow, startCol)@. Out-of-range indices surface as 'coeff' errors.
block :: I.Elem a b => Int -> Int -> Int -> Int -> Matrix a b -> Matrix a b
block r0 c0 nr nc m = generate nr nc pick
  where
    pick r c = coeff (r0 + r) (c0 + c) m
-- | 'True' when both dimensions are non-negative and the coefficient vector
-- holds exactly @rows * cols@ elements.
valid :: I.Elem a b => Matrix a b -> Bool
valid (Matrix nr nc v) = and [nr >= 0, nc >= 0, VS.length v == nr * nc]
-- | The largest coefficient. Fails on an empty matrix (via 'fold1'').
maxCoeff :: (I.Elem a b, Ord a) => Matrix a b -> a
maxCoeff m = fold1' max m

-- | The smallest coefficient. Fails on an empty matrix (via 'fold1'').
minCoeff :: (I.Elem a b, Ord a) => Matrix a b -> a
minCoeff m = fold1' min m
-- | The first @n@ rows, as an @n@ x @cols@ matrix.
-- @n@ is not range-checked here; an @n@ larger than the row count surfaces
-- as a 'coeff' bounds error.
topRows :: I.Elem a b => Int -> Matrix a b -> Matrix a b
topRows n m@(Matrix _ nc _) = block 0 0 n nc m

-- | The last @n@ rows, as an @n@ x @cols@ matrix.
-- The block starts at row @rows - n@.
bottomRows :: I.Elem a b => Int -> Matrix a b -> Matrix a b
bottomRows n m@(Matrix nr nc _) = block (nr - n) 0 n nc m

-- | The first @n@ columns, as a @rows@ x @n@ matrix.
leftCols :: I.Elem a b => Int -> Matrix a b -> Matrix a b
leftCols n m@(Matrix nr _ _) = block 0 0 nr n m

-- | The last @n@ columns, as a @rows@ x @n@ matrix.
-- The block starts at column @cols - n@.
rightCols :: I.Elem a b => Int -> Matrix a b -> Matrix a b
rightCols n m@(Matrix nr nc _) = block 0 (nc - n) nr n m
-- | Builds a matrix from a list of rows.
--
-- The column count is the length of the longest row. Shorter rows are
-- padded with zeros.
fromList :: I.Elem a b => [[a]] -> Matrix a b
fromList xss = Matrix nr nc vals
  where
    nr = length xss
    nc = L.foldl' max 0 (P.map length xss)
    vals = VS.create $ do
        -- `asTypeOf` only fixes the type of the zero literal; it never
        -- forces `head`, so an empty input list is safe.
        mv <- VSM.replicate (nr * nc) (I.cast (0 `asTypeOf` head (head xss)))
        zipWithM_ (\r xs -> zipWithM_ (\c x -> VSM.write mv (c * nr + r) (I.cast x)) [0 ..] xs) [0 ..] xss
        return mv
-- | The coefficients as a list of rows. Errors if the layout is invalid.
toList :: I.Elem a b => Matrix a b -> [[a]]
toList m@(Matrix nr nc v)
    | valid m = P.map rowAt [0 .. pred nr]
    | otherwise = error "matrix is not valid"
  where
    rowAt r = [I.cast (VS.unsafeIndex v (r + c * nr)) | c <- [0 .. pred nc]]
-- | @generate rows cols f@ builds a matrix whose @(r, c)@ coefficient is
-- @f r c@. Cells are written row by row, so the first failing cell in that
-- order determines which error (if any) is raised.
generate :: I.Elem a b => Int -> Int -> (Int -> Int -> a) -> Matrix a b
generate nr nc f = Matrix nr nc $ VS.create $ do
    mv <- VSM.new (nr * nc)
    sequence_ [VSM.write mv (c * nr + r) (I.cast (f r c)) | r <- [0 .. pred nr], c <- [0 .. pred nc]]
    return mv
-- | Sum of all coefficients, computed by the C routine 'I.sum'.
sum :: I.Elem a b => Matrix a b -> a
sum m = _prop I.sum m

-- | Product of all coefficients, computed by the C routine 'I.prod'.
prod :: I.Elem a b => Matrix a b -> a
prod m = _prop I.prod m

-- | Mean of all coefficients, computed by the C routine 'I.mean'.
mean :: I.Elem a b => Matrix a b -> a
mean m = _prop I.mean m

-- | Trace, computed by the C routine 'I.trace'.
trace :: I.Elem a b => Matrix a b -> a
trace m = _prop I.trace m
-- | 'True' if every coefficient satisfies the predicate.
all :: I.Elem a b => (a -> Bool) -> Matrix a b -> Bool
all p (Matrix _ _ v) = VS.all (p . I.cast) v

-- | 'True' if at least one coefficient satisfies the predicate.
any :: I.Elem a b => (a -> Bool) -> Matrix a b -> Bool
any p (Matrix _ _ v) = VS.any (p . I.cast) v

-- | Number of coefficients satisfying the predicate (strict fold).
count :: I.Elem a b => (a -> Bool) -> Matrix a b -> Int
count p (Matrix _ _ v) = VS.foldl' step 0 v
  where
    step n x
        | p (I.cast x) = succ n
        | otherwise = n
-- | Norm computed by the C routine 'I.norm' (presumably the Frobenius
-- norm, as in Eigen; confirm).
norm :: I.Elem a b => Matrix a b -> a
norm m = _prop I.norm m

-- | Squared norm, computed by the C routine 'I.squaredNorm'.
squaredNorm :: I.Elem a b => Matrix a b -> a
squaredNorm m = _prop I.squaredNorm m

-- | Norm computed by the C routine 'I.blueNorm'.
blueNorm :: I.Elem a b => Matrix a b -> a
blueNorm m = _prop I.blueNorm m

-- | Norm computed by the C routine 'I.hypotNorm'.
hypotNorm :: I.Elem a b => Matrix a b -> a
hypotNorm m = _prop I.hypotNorm m
-- | Determinant, computed by the C routine 'I.determinant'.
-- Raises an error for non-square matrices.
determinant :: I.Elem a b => Matrix a b -> a
determinant m
    | not (square m) = error "Matrix.determinant: non-square matrix"
    | otherwise = _prop I.determinant m
-- | Coefficient-wise sum of two matrices.
-- Raises an error unless both have the same dimensions.
add :: I.Elem a b => Matrix a b -> Matrix a b -> Matrix a b
add m1 m2
    | dims m1 == dims m2 = _binop const I.add m1 m2
    | otherwise = error "Matrix.add: matrices should have the same size"

-- | Coefficient-wise difference @m1 - m2@.
-- Raises an error unless both have the same dimensions.
sub :: I.Elem a b => Matrix a b -> Matrix a b -> Matrix a b
sub m1 m2
    | dims m1 == dims m2 = _binop const I.sub m1 m2
    | otherwise = error "Matrix.sub: matrices should have the same size"

-- | Matrix product. The result is @rows m1@ x @cols m2@.
-- Raises an error unless @cols m1 == rows m2@.
mul :: I.Elem a b => Matrix a b -> Matrix a b -> Matrix a b
mul m1 m2
    | cols m1 == rows m2 = _binop (\(r, _) (_, c) -> (r, c)) I.mul m1 m2
    | otherwise = error "Matrix.mul: number of columns for lhs matrix should be the same as number of rows for rhs matrix"
-- | Applies a function to every coefficient, keeping the dimensions.
map :: I.Elem a b => (a -> a) -> Matrix a b -> Matrix a b
map f (Matrix nr nc v) = Matrix nr nc (VS.map (\x -> I.cast (f (I.cast x))) v)

-- | Like 'map', but the function also receives the @row@ and @col@ of each
-- coefficient. These are recovered from the column-major storage index.
imap :: I.Elem a b => (Int -> Int -> a -> a) -> Matrix a b -> Matrix a b
imap f (Matrix nr nc v) = Matrix nr nc (VS.imap step v)
  where
    step i x = let (c, r) = i `divMod` nr in I.cast (f r c (I.cast x))
-- | Selects which part of a matrix 'triangularView' keeps.
data TriangularMode
    -- | Keep entries on and below the main diagonal; zero the rest.
    = Lower
    -- | Keep entries on and above the main diagonal; zero the rest.
    | Upper
    -- | Keep entries strictly below the main diagonal; zero the rest, including the diagonal.
    | StrictlyLower
    -- | Keep entries strictly above the main diagonal; zero the rest, including the diagonal.
    | StrictlyUpper
    -- | Keep entries strictly below the diagonal, set the diagonal to one, zero the rest.
    | UnitLower
    -- | Keep entries strictly above the diagonal, set the diagonal to one, zero the rest.
    | UnitUpper deriving (Eq, Enum, Show, Read)
-- | Masks a matrix according to a 'TriangularMode'. Each coefficient is kept,
-- zeroed, or (for the unit modes, on the diagonal) replaced by one,
-- depending on how its row compares to its column.
triangularView :: I.Elem a b => TriangularMode -> Matrix a b -> Matrix a b
triangularView mode = case mode of
    Lower         -> byPosition $ \o x -> case o of { LT -> 0; _ -> x }
    Upper         -> byPosition $ \o x -> case o of { GT -> 0; _ -> x }
    StrictlyLower -> byPosition $ \o x -> case o of { GT -> x; _ -> 0 }
    StrictlyUpper -> byPosition $ \o x -> case o of { LT -> x; _ -> 0 }
    UnitLower     -> byPosition $ \o x -> case o of { GT -> x; LT -> 0; EQ -> 1 }
    UnitUpper     -> byPosition $ \o x -> case o of { LT -> x; GT -> 0; EQ -> 1 }
  where
    -- `compare row col`: LT above the diagonal, EQ on it, GT below it.
    byPosition g = imap (\r c x -> g (compare r c) x)
-- | Shorthand for @'triangularView' 'Lower'@.
lowerTriangle :: I.Elem a b => Matrix a b -> Matrix a b
lowerTriangle m = triangularView Lower m

-- | Shorthand for @'triangularView' 'Upper'@.
upperTriangle :: I.Elem a b => Matrix a b -> Matrix a b
upperTriangle m = triangularView Upper m
-- | Replaces every coefficient that fails the predicate with zero.
-- The dimensions are unchanged.
filter :: I.Elem a b => (a -> Bool) -> Matrix a b -> Matrix a b
filter p = map keep
  where
    keep x
        | p x = x
        | otherwise = 0

-- | Like 'filter', but the predicate also sees the @row@ and @col@.
ifilter :: I.Elem a b => (Int -> Int -> a -> Bool) -> Matrix a b -> Matrix a b
ifilter p = imap keep
  where
    keep r c x
        | p r c x = x
        | otherwise = 0
-- | Lazy left fold over coefficients in storage (column-major) order.
fold :: I.Elem a b => (c -> a -> c) -> c -> Matrix a b -> c
fold f z (Matrix _ _ v) = VS.foldl step z v
  where
    step acc x = f acc (I.cast x)

-- | Strict variant of 'fold'.
fold' :: I.Elem a b => (c -> a -> c) -> c -> Matrix a b -> c
fold' f z (Matrix _ _ v) = VS.foldl' step z v
  where
    step acc x = f acc (I.cast x)

-- | Lazy left fold whose function also receives @row@ and @col@.
ifold :: I.Elem a b => (Int -> Int -> c -> a -> c) -> c -> Matrix a b -> c
ifold f z (Matrix nr _ v) = VS.ifoldl step z v
  where
    step acc i x = let (c, r) = i `divMod` nr in f r c acc (I.cast x)

-- | Strict variant of 'ifold'.
ifold' :: I.Elem a b => (Int -> Int -> c -> a -> c) -> c -> Matrix a b -> c
ifold' f z (Matrix nr _ v) = VS.ifoldl' step z v
  where
    step acc i x = let (c, r) = i `divMod` nr in f r c acc (I.cast x)
-- | Lazy left fold over coefficients in storage order, seeded with the
-- first coefficient.
--
-- Raises @"Matrix.fold1: empty matrix"@ when there are no coefficients,
-- instead of the generic list error.
fold1 :: I.Elem a b => (a -> a -> a) -> Matrix a b -> a
fold1 f (Matrix _ _ v)
    | VS.null v = error "Matrix.fold1: empty matrix"
    | otherwise = foldl1 f (P.map I.cast (VS.toList v))

-- | Strict variant of 'fold1'.
-- Raises @"Matrix.fold1': empty matrix"@ when there are no coefficients.
fold1' :: I.Elem a b => (a -> a -> a) -> Matrix a b -> a
fold1' f (Matrix _ _ v)
    | VS.null v = error "Matrix.fold1': empty matrix"
    | otherwise = L.foldl1' f (P.map I.cast (VS.toList v))
-- | The main diagonal as a @min rows cols@ x 1 column, produced by the C
-- routine 'I.diagonal'.
diagonal :: I.Elem a b => Matrix a b -> Matrix a b
diagonal = _unop (\(nr, nc) -> (min nr nc, 1)) I.diagonal

-- | Inverse computed by the C routine 'I.inverse'.
-- Raises an error for non-square matrices.
inverse :: I.Elem a b => Matrix a b -> Matrix a b
inverse m
    | not (square m) = error "Matrix.inverse: non-square matrix"
    | otherwise = _unop id I.inverse m

-- | Adjoint (conjugate transpose) via 'I.adjoint'. The result is @cols@ x @rows@.
adjoint :: I.Elem a b => Matrix a b -> Matrix a b
adjoint = _unop (\(nr, nc) -> (nc, nr)) I.adjoint

-- | Transpose via 'I.transpose'. The result is @cols@ x @rows@.
transpose :: I.Elem a b => Matrix a b -> Matrix a b
transpose = _unop (\(nr, nc) -> (nc, nr)) I.transpose

-- | Coefficient-wise conjugate via 'I.conjugate'. The dimensions are unchanged.
conjugate :: I.Elem a b => Matrix a b -> Matrix a b
conjugate = _unop id I.conjugate
-- | Returns a copy of the matrix rescaled in place by the C routine
-- 'I.normalize' (presumably to unit norm, as in Eigen; confirm).
--
-- The buffer is passed straight to C, so the layout is checked first, as
-- 'unsafeWith' does for every other FFI call. The constructor is exported,
-- so an inconsistent @rows * cols@ would otherwise let C read or write
-- past the end of the buffer.
normalize :: I.Elem a b => Matrix a b -> Matrix a b
normalize m@(Matrix nr nc v)
    | not (valid m) = error "Matrix.normalize: matrix layout is invalid"
    | otherwise = I.performIO $ do
        -- Work on a fresh copy so the input matrix is never mutated.
        mv <- VS.thaw v
        VSM.unsafeWith mv $ \p ->
            I.call $ I.normalize p (I.cast nr) (I.cast nc)
        Matrix nr nc <$> VS.unsafeFreeze mv
-- | Runs a destructive 'ST' action on a mutable view of the matrix and
-- returns the result. Built on 'VS.modify', which copies the vector unless
-- in-place mutation is safe, so the input is not observably changed.
modify :: I.Elem a b => (forall s. M.MMatrix a b s -> ST s ()) -> Matrix a b -> Matrix a b
modify f (Matrix nr nc v) = Matrix nr nc (VS.modify (\mv -> f (M.MMatrix nr nc mv)) v)
-- | Maps every coefficient into another element type, keeping the dimensions.
convert :: (I.Elem a b, I.Elem c d) => (a -> c) -> Matrix a b -> Matrix c d
convert f (Matrix nr nc v) = Matrix nr nc (VS.map (\x -> I.cast (f (I.cast x))) v)
-- | Copies a mutable matrix into an immutable one.
freeze :: I.Elem a b => PrimMonad m => M.MMatrix a b (PrimState m) -> m (Matrix a b)
freeze (M.MMatrix nr nc mv) = liftM (Matrix nr nc) (VS.freeze mv)

-- | Copies an immutable matrix into a fresh mutable one.
thaw :: I.Elem a b => PrimMonad m => Matrix a b -> m (M.MMatrix a b (PrimState m))
thaw (Matrix nr nc v) = liftM (M.MMatrix nr nc) (VS.thaw v)

-- | Zero-copy freeze; the mutable matrix must not be used afterwards.
unsafeFreeze :: I.Elem a b => PrimMonad m => M.MMatrix a b (PrimState m) -> m (Matrix a b)
unsafeFreeze (M.MMatrix nr nc mv) = liftM (Matrix nr nc) (VS.unsafeFreeze mv)

-- | Zero-copy thaw; the immutable matrix must not be used afterwards.
unsafeThaw :: I.Elem a b => PrimMonad m => Matrix a b -> m (M.MMatrix a b (PrimState m))
unsafeThaw (Matrix nr nc v) = liftM (M.MMatrix nr nc) (VS.unsafeThaw v)
-- | Passes the coefficient buffer, row count and column count to an 'IO'
-- action. Fails if the layout is invalid (see 'valid').
unsafeWith :: I.Elem a b => Matrix a b -> (Ptr b -> CInt -> CInt -> IO c) -> IO c
unsafeWith m@(Matrix nr nc v) act
    | valid m = VS.unsafeWith v $ \p -> act p (I.cast nr) (I.cast nc)
    | otherwise = fail "Matrix.unsafeWith: matrix layout is invalid"
-- | Serialises a matrix into a lazy 'BSL.ByteString'.
--
-- Chunks, in order: the element-type code from 'I.magicCode', the row
-- count, the column count (each written with 'I.encodeInt'), then the raw
-- coefficient buffer (column-major).
--
-- The last chunk is a 'BSI.PS' built directly over the vector's
-- 'ForeignPtr', so it shares memory with the matrix instead of copying it.
--
-- NOTE(review): 'VS.head' is used only for its type (by 'I.magicCode' and
-- 'sizeOf'). For a valid empty matrix this is safe only if neither forces
-- its argument. 'sizeOf' on the C types should not; confirm for
-- 'I.magicCode'.
--
-- Raises an error if the matrix fails 'valid'.
encode :: I.Elem a b => Matrix a b -> BSL.ByteString
encode m@(Matrix rows cols vals)
    | valid m = BSL.fromChunks [
        I.encodeInt (I.magicCode $ VS.head vals),
        I.encodeInt (I.cast rows),
        I.encodeInt (I.cast cols),
        let (fp, fs) = VS.unsafeToForeignPtr0 vals in BSI.PS (castForeignPtr fp) 0 (fs * sizeOf (VS.head vals))]
    | otherwise = error "Matrix.encode: matrix layout is invalid"
-- | Inverse of 'encode'. Reads the element-type code, the row count, the
-- column count and then @rows * cols@ coefficients from the input.
--
-- Fails (via 'fail' inside 'I.performIO') when the stored type code differs
-- from 'I.magicCode' for the requested element type.
--
-- The binding is deliberately recursive. @vals@ is mentioned through
-- 'VS.head' only to pick the element type for 'I.magicCode' and 'sizeOf',
-- before it has actually been read. This works only while neither forces
-- its argument, so do not reorder these statements or make them strict.
-- The inner @st@ (the opened stream) shadows the input bytestring.
--
-- NOTE(review): the result vector wraps the stream's buffer without
-- copying ('VS.unsafeFromForeignPtr0'). Whether truncated input is detected
-- depends on 'I.readStream'; confirm.
decode :: I.Elem a b => BSL.ByteString -> Matrix a b
decode st = Matrix rows cols vals where
    (rows, cols, vals) = I.performIO $ do
        st <- I.openStream st
        code <- I.readInt st
        when (code /= I.magicCode (VS.head vals)) $
            fail "Matrix.decode: wrong matrix type"
        rows <- I.cast <$> I.readInt st
        cols <- I.cast <$> I.readInt st
        BSI.PS fp fo _ <- I.readStream st (rows * cols * sizeOf (VS.head vals))
        I.closeStream st
        return (rows, cols, VS.unsafeFromForeignPtr0 (I.plusForeignPtr fp fo) (rows * cols))
-- | Runs a C routine that writes one scalar result through its first
-- pointer argument, given the matrix buffer and dimensions. Returns that
-- scalar converted back to @a@.
_prop :: I.Elem a b => (Ptr b -> Ptr b -> CInt -> CInt -> IO CString) -> Matrix a b -> a
_prop f m = I.cast (I.performIO (alloca (\p -> I.call (unsafeWith m (f p)) >> peek p)))
-- | Allocates an output matrix whose size is derived from both operands'
-- dimensions, then lets a C routine fill it. The routine receives
-- (buffer, rows, cols) for the output, then for each operand, in that order.
_binop :: I.Elem a b => ((Int, Int) -> (Int, Int) -> (Int, Int)) -> (Ptr b -> CInt -> CInt -> Ptr b -> CInt -> CInt -> Ptr b -> CInt -> CInt -> IO CString) -> Matrix a b -> Matrix a b -> Matrix a b
_binop shape g m1 m2 = I.performIO $ do
    out <- uncurry M.new (shape (dims m1) (dims m2))
    M.unsafeWith out $ \pOut rOut cOut ->
        unsafeWith m1 $ \pA rA cA ->
            unsafeWith m2 $ \pB rB cB ->
                I.call $ g pOut rOut cOut pA rA cA pB rB cB
    unsafeFreeze out
-- | Allocates an output matrix whose size is derived from the input's
-- dimensions, then lets a C routine fill it. The routine receives
-- (buffer, rows, cols) for the output, then for the input.
_unop :: I.Elem a b => ((Int,Int) -> (Int,Int)) -> (Ptr b -> CInt -> CInt -> Ptr b -> CInt -> CInt -> IO CString) -> Matrix a b -> Matrix a b
_unop shape g m1 = I.performIO $ do
    out <- uncurry M.new (shape (dims m1))
    M.unsafeWith out $ \pOut rOut cOut ->
        unsafeWith m1 $ \pIn rIn cIn ->
            I.call $ g pOut rOut cOut pIn rIn cIn
    unsafeFreeze out
-- | The raw column-major coefficient vector.
_vals :: I.Elem a b => Matrix a b -> VS.Vector b
_vals m = case m of
    Matrix _ _ v -> v