module Data.Packed.Internal.Matrix(
Matrix(..), rows, cols,
MatrixOrder(..), orderOf,
createMatrix, mat,
cmat, fmat,
toLists, flatten, reshape,
Element(..),
trans,
fromRows, toRows, fromColumns, toColumns,
matrixFromVector,
subMatrix,
liftMatrix, liftMatrix2,
(@@>),
saveMatrix,
singleton,
size, shSize, conformVs, conformMs, conformVTo, conformMTo
) where
import Data.Complex(Complex)
import Foreign.C.String(newCString, withCString)
import Foreign.C.Types(CInt, CChar)
import Foreign.Marshal.Alloc(alloca, free)
import Foreign.Marshal.Array(newArray)
import Foreign.Ptr(Ptr, castPtr)
import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf)
import System.IO.Unsafe(unsafePerformIO)

import Data.Packed.Internal.Common
import Data.Packed.Internal.Signatures
import Data.Packed.Internal.Vector
-- | Storage layout of a matrix's elements inside its flat vector.
data MatrixOrder
  = RowMajor     -- ^ rows are contiguous (constructor 'MC')
  | ColumnMajor  -- ^ columns are contiguous (constructor 'MF')
  deriving (Eq, Show)
-- | A dense matrix backed by one flat 'Vector'.
--
-- The constructors differ only in how elements are laid out:
--
-- * 'MC' is row-major: element (i,j) sits at index i*cols + j.
--
-- * 'MF' is column-major: element (i,j) sits at index j*rows + i.
data Matrix t
  = MC
      { irows :: !Int         -- ^ number of rows
      , icols :: !Int         -- ^ number of columns
      , cdat  :: !(Vector t)  -- ^ row-major element data
      }
  | MF
      { irows :: !Int         -- ^ number of rows
      , icols :: !Int         -- ^ number of columns
      , fdat  :: !(Vector t)  -- ^ column-major element data
      }
-- | Number of rows of a matrix.
rows :: Matrix t -> Int
rows m = irows m
-- | Number of columns of a matrix.
cols :: Matrix t -> Int
cols m = icols m
-- | The underlying element vector, whichever layout the matrix uses.
--
-- The element order is row-major for 'MC' and column-major for 'MF';
-- use 'orderOf' to tell which one applies.
xdat :: Matrix t -> Vector t
xdat MC {cdat = d} = d
xdat MF {fdat = d} = d
-- | Report the storage layout of a matrix.
orderOf :: Matrix t -> MatrixOrder
orderOf m = case m of
  MC {} -> RowMajor
  MF {} -> ColumnMajor
-- | Transpose without copying.
--
-- The element vector is shared; only the dimensions are swapped and the
-- layout tag is flipped.
trans :: Matrix t -> Matrix t
trans m = case m of
  MC r c d -> MF c r d
  MF r c d -> MC c r d
-- | Convert a matrix to row-major ('MC') layout.
--
-- A column-major matrix has its elements physically reordered with
-- 'transdata'; a row-major matrix is returned as-is.
cmat :: (Element t) => Matrix t -> Matrix t
cmat m = case m of
  MC {}    -> m
  MF r c d -> MC r c (transdata r d c)
-- | Convert a matrix to column-major ('MF') layout.
--
-- A row-major matrix has its elements physically reordered with
-- 'transdata'; a column-major matrix is returned as-is.
fmat :: (Element t) => Matrix t -> Matrix t
fmat m = case m of
  MF {}    -> m
  MC r c d -> MF r c (transdata c d r)
-- | Hand a matrix's dimensions and raw data pointer to a continuation.
--
-- The callback @f@ receives a function which, given @g@, applies @g@ to
-- the row count, the column count (both as 'CInt') and the data pointer
-- obtained from 'unsafeWith'.
mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
mat a f = unsafeWith (xdat a) withPtr
  where
    withPtr p = f (\g -> g (fi (rows a)) (fi (cols a)) p)
-- | All elements of a matrix, in row-major order.
flatten :: Element t => Matrix t -> Vector t
flatten m = cdat (cmat m)
-- | Shape of a callback that takes a row count, a column count and a
-- data pointer.
type Mt a r = Int -> Int -> Ptr a -> r
-- | The rows of a matrix as nested lists.
toLists :: (Element t) => Matrix t -> [[t]]
toLists m = splitEvery width (toList (flatten m))
  where
    width = cols m
-- | Build a row-major matrix by stacking vectors as its rows.
--
-- The common row length is decided by 'compatdim'. A row whose dimension
-- differs from that length is replaced by a constant vector filled with
-- its first element. Presumably only length-1 rows reach that branch;
-- TODO confirm against 'compatdim'.
fromRows :: Element t => [Vector t] -> Matrix t
fromRows vs =
  case compatdim (map dim vs) of
    Just c  -> reshape c (join [ expand c v | v <- vs ])
    Nothing -> error "fromRows applied to [] or to vectors with different sizes"
  where
    expand c v
      | dim v /= c = constantD (v @> 0) c
      | otherwise  = v
-- | Split a matrix into its rows.
--
-- The rows are consecutive slices of the row-major element vector.
toRows :: Element t => Matrix t -> [Vector t]
toRows m = sliceFrom 0
  where
    flat  = flatten m
    width = cols m
    total = rows m * width
    sliceFrom off
      | off == total = []
      | otherwise    = subVector off width flat : sliceFrom (off + width)
-- | Build a matrix whose columns are the given vectors.
fromColumns :: Element t => [Vector t] -> Matrix t
fromColumns = trans . fromRows
-- | Split a matrix into its columns.
toColumns :: Element t => Matrix t -> [Vector t]
toColumns = toRows . trans
-- | Element access by a 0-based (row, column) pair.
--
-- Bounds are checked only when 'safe' is True. The element offset is
-- computed from the matrix layout.
(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
infixl 9 @@>
m @@> (i,j)
  | safe && outOfRange = error "matrix indexing out of range"
  | otherwise          = xdat m `at` offset
  where
    r = irows m
    c = icols m
    outOfRange = i < 0 || i >= r || j < 0 || j >= c
    offset = case m of
      MC {} -> i*c + j
      MF {} -> j*r + i
-- | Element (i,j) read through 'at'' rather than 'at', with the offset
-- chosen by the matrix layout.
--
-- No bounds check is done here. Whether 'at'' itself checks is not
-- visible in this module; presumably it does not (TODO confirm).
atM' m i j = case m of
  MC {icols = c, cdat = v} -> v `at'` (i*c + j)
  MF {irows = r, fdat = v} -> v `at'` (j*r + i)
-- | Wrap a flat vector as a matrix with @c@ columns in the given layout.
--
-- The row count is @dim v `div` c@. It is an error if @c@ is not
-- positive, or if @dim v@ is not a multiple of @c@. Because the row count
-- is a strict field, that error is raised as soon as the resulting matrix
-- is evaluated to WHNF.
--
-- The non-positive-@c@ case is rejected explicitly. Previously @c == 0@
-- crashed with a bare divide-by-zero, and a negative @c@ silently
-- produced a negative row count.
matrixFromVector order c v =
  case order of
    RowMajor    -> MC { irows = r, icols = c, cdat = v }
    ColumnMajor -> MF { irows = r, icols = c, fdat = v }
  where
    n = dim v
    r | c <= 0         = error $ "matrixFromVector: invalid number of columns "
                                 ++ show c ++ " for a vector of dim " ++ show n
      | n `mod` c /= 0 = error $ "matrixFromVector: vector of dim " ++ show n
                                 ++ " cannot be split into rows of " ++ show c
      | otherwise      = n `div` c
-- | Allocate an r x c matrix in the given layout.
--
-- The buffer comes straight from 'createVector'; this function does not
-- initialise its contents.
createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
createMatrix order r c = fmap (matrixFromVector order c) (createVector (r*c))
-- | View a vector as a row-major matrix with the given number of columns.
reshape :: Storable t => Int -> Vector t -> Matrix t
reshape = matrixFromVector RowMajor
-- | A 1x1 matrix holding the given element.
singleton :: Storable a => a -> Matrix a
singleton x = reshape 1 (fromList [x])
-- | Apply a vector function to a matrix's element vector.
--
-- The column count and layout are kept. The new vector's dimension must
-- be a multiple of the column count (see 'matrixFromVector').
liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
liftMatrix f m = matrixFromVector (orderOf m) (cols m) (f (xdat m))
-- | Combine two same-shaped matrices through a vector function.
--
-- The result takes the layout of the first matrix. The second matrix is
-- first converted to that layout so that elements line up index by
-- index.
liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
liftMatrix2 f m1 m2
  | compat m1 m2 = case m1 of
      MC {cdat = d} -> matrixFromVector RowMajor    width (f d (cdat (cmat m2)))
      MF {fdat = d} -> matrixFromVector ColumnMajor width (f d (fdat (fmat m2)))
  | otherwise = error "nonconformant matrices in liftMatrix2"
  where
    width = cols m1
-- | True when both matrices have the same number of rows and columns.
compat :: Matrix a -> Matrix b -> Bool
compat m1 m2 = (rows m1, cols m1) == (rows m2, cols m2)
-- | Element types that matrices can hold.
--
-- The default methods are generic implementations that work from
-- 'sizeOf' alone. Instances may override them with type-specialised
-- routines.
class (Storable a) => Element a where
  -- | Extract the block with the given top-left (row, col) corner and
  -- (rows, cols) extent. 'subMatrix' validates the arguments before
  -- calling this.
  subMatrixD :: (Int,Int) -> (Int,Int) -> Matrix a -> Matrix a
  subMatrixD = subMatrix'

  -- | @transdata c1 v c2@ reorders the elements of @v@, read as rows of
  -- @c1@ columns, into the transposed order, written as rows of @c2@
  -- columns.
  transdata :: Int -> Vector a -> Int -> Vector a
  transdata = transdataP

  -- | @constantD x n@ is a vector of @n@ copies of @x@.
  constantD :: a -> Int -> Vector a
  constantD = constantP
-- These four element types override both defaults with dedicated C
-- routines, which are brought in by the foreign imports of this module.
-- 'subMatrixD' keeps its generic default.

instance Element Float where
  constantD = constantAux cconstantF
  transdata = transdataAux ctransF

instance Element Double where
  constantD = constantAux cconstantR
  transdata = transdataAux ctransR

instance Element (Complex Float) where
  constantD = constantAux cconstantQ
  transdata = transdataAux ctransQ

instance Element (Complex Double) where
  constantD = constantAux cconstantC
  transdata = transdataAux ctransC
-- | Element-by-element transposition written in Haskell.
--
-- @v@ is read as rows of @c1@ columns and written, transposed, as rows
-- of @c2@ columns. A vector with a single row or a single column is
-- returned unchanged, since its transpose has the same element order.
--
-- 'unsafePerformIO' only wraps the filling of a vector freshly created
-- here. Every slot is written before the vector is returned.
transdata' :: Storable a => Int -> Vector a -> Int -> Vector a
transdata' c1 v c2 =
  if noneed
    then v
    else unsafePerformIO $ do
      w <- createVector (r2*c2)
      unsafeWith v $ \p ->
        unsafeWith w $ \q -> do
          -- Walk from (r1-1, c1-1) down to (0, 0). Matching against the
          -- literal (-1) already forces i and j, so no bang patterns are
          -- needed.
          let go (-1) _    = return ()
              go i    (-1) = go (i-1) (c1-1)
              go i    j    = do
                x <- peekElemOff p (i*c1+j)
                pokeElemOff q (j*c2+i) x
                go i (j-1)
          go (r1-1) (c1-1)
      return w
  where
    r1 = dim v `div` c1
    r2 = dim v `div` c2
    noneed = r1 == 1 || c1 == 1
-- | Transpose element data through a type-specialised C routine @fun@,
-- one of the @ctrans*@ foreign imports.
--
-- @d@ is read as rows of @c1@ columns and written, transposed, as rows of
-- @c2@ columns. A single-row or single-column input is returned
-- unchanged. The status code returned by @fun@ is passed to 'check'.
--
-- 'unsafePerformIO' wraps only the filling of a freshly allocated output
-- vector.
transdataAux fun c1 d c2
  | r1 == 1 || c1 == 1 = d
  | otherwise = unsafePerformIO $ do
      out <- createVector n
      unsafeWith d $ \src ->
        unsafeWith out $ \dst ->
          fun (fi r1) (fi c1) src (fi r2) (fi c2) dst // check "transdataAux"
      return out
  where
    n  = dim d
    r1 = n `div` c1
    r2 = n `div` c2
-- | Transposition for any 'Storable' element type, delegating to the C
-- routine @transP@.
--
-- @transP@ receives untyped pointers plus the element size in bytes.
-- @d@ is read as rows of @c1@ columns and written, transposed, as rows of
-- @c2@ columns. A single-row or single-column input is returned
-- unchanged.
transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
transdataP c1 d c2
  | r1 == 1 || c1 == 1 = d
  | otherwise = unsafePerformIO $ do
      out <- createVector n
      unsafeWith d $ \src ->
        unsafeWith out $ \dst ->
          ctransP (fi r1) (fi c1) (castPtr src) (fi sz)
                  (fi r2) (fi c2) (castPtr dst) (fi sz) // check "transdataP"
      return out
  where
    n  = dim d
    r1 = n `div` c1
    r2 = n `div` c2
    -- The indexing only fixes the element type. 'sizeOf' is not expected
    -- to evaluate its argument.
    sz = sizeOf (d @> 0)
-- Type-specialised transposition routines, used by the Element instances
-- through transdataAux:
--   transF for Float, transR for Double,
--   transQ for Complex Float, transC for Complex Double.
-- TFMFM, TMM, TQMQM and TCMCM are type synonyms defined outside this
-- module (presumably in Data.Packed.Internal.Signatures; verify there).
foreign import ccall "transF" ctransF :: TFMFM
foreign import ccall "transR" ctransR :: TMM
foreign import ccall "transQ" ctransQ :: TQMQM
foreign import ccall "transC" ctransC :: TCMCM
-- Generic routine used by transdataP. At that call site the arguments are
-- (rows, cols, source pointer, element size, rows, cols, destination
-- pointer, element size). The CInt result is handed to check.
foreign import ccall "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
-- | A vector of @n@ copies of @v@, filled in Haskell with 'pokeElemOff'
-- from the last index down to 0.
--
-- 'unsafePerformIO' only wraps the initialisation of a vector freshly
-- created here, before it is returned.
constant' :: Storable a => a -> Int -> Vector a
constant' v n = unsafePerformIO $ do
  w <- createVector n
  unsafeWith w $ \p -> do
    -- Matching against the literal (-1) already forces k, so no bang
    -- pattern is needed.
    let go (-1) = return ()
        go k    = pokeElemOff p k v >> go (k-1)
    go (n-1)
  return w
-- | A vector of @n@ copies of @x@, filled by a type-specialised C routine
-- @fun@ (one of the @cconstant*@ foreign imports).
--
-- @fun@ is given a pointer to @x@, and the call is run through 'app1' with
-- the label "constantAux".
--
-- @x@ is staged in a temporary buffer scoped by 'alloca'. That buffer is
-- released even if the checked C call throws, whereas a 'newArray'/'free'
-- pair would leak it.
constantAux fun x n = unsafePerformIO $ do
  v <- createVector n
  alloca $ \px -> do
    poke px x
    app1 (fun px) vec v "constantAux"
  return v
-- Monomorphic fill helpers, one per element type that has a dedicated C
-- routine. Each is constantAux specialised to the matching foreign import.

constantF :: Float -> Int -> Vector Float
constantF x n = constantAux cconstantF x n
foreign import ccall "constantF" cconstantF :: Ptr Float -> TF

constantR :: Double -> Int -> Vector Double
constantR x n = constantAux cconstantR x n
foreign import ccall "constantR" cconstantR :: Ptr Double -> TV

constantQ :: Complex Float -> Int -> Vector (Complex Float)
constantQ x n = constantAux cconstantQ x n
foreign import ccall "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV

constantC :: Complex Double -> Int -> Vector (Complex Double)
constantC x n = constantAux cconstantC x n
foreign import ccall "constantC" cconstantC :: Ptr (Complex Double) -> TCV
-- | A vector of @n@ copies of @a@ for any 'Storable' type, produced by
-- the C routine @constantP@.
--
-- @a@ is written into a temporary 'alloca' buffer, and a pointer to it is
-- passed together with @n@, the output pointer and @sizeOf a@.
-- Presumably the C side copies that value into every slot (TODO confirm).
-- Its status is passed to 'check'.
constantP :: Storable a => a -> Int -> Vector a
constantP a n = unsafePerformIO $ do
  v <- createVector n
  unsafeWith v $ \p ->
    alloca $ \k ->
      poke k a >> (cconstantP (castPtr k) (fi n) (castPtr p) (fi (sizeOf a)) // check "constantP")
  return v
foreign import ccall "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
-- | @subMatrix (r0,c0) (rt,ct) m@ extracts the rt x ct block of @m@ whose
-- top-left element is at row @r0@, column @c0@ (both 0-based).
--
-- The block must be non-empty and lie entirely inside @m@; otherwise an
-- error naming the request and the matrix size is raised.
subMatrix :: Element a
          => (Int,Int)  -- ^ (row, column) of the top-left corner
          -> (Int,Int)  -- ^ number of (rows, columns) to take
          -> Matrix a   -- ^ source matrix
          -> Matrix a
subMatrix (r0,c0) (rt,ct) m
  | rowsOk && colsOk = subMatrixD (r0,c0) (rt,ct) m
  | otherwise = error $ "wrong subMatrix " ++ show ((r0,c0),(rt,ct))
                     ++ " of " ++ show (rows m) ++ "x" ++ show (cols m)
  where
    rowsOk = 0 <= r0 && 0 < rt && r0 + rt <= rows m
    colsOk = 0 <= c0 && 0 < ct && c0 + ct <= cols m
-- | Copy the rt x ct block starting at row @r0@, column @c0@ out of @v@,
-- read as a row-major matrix with @c@ columns, into a new row-major
-- vector.
--
-- Bounds are not checked here. When reached through 'subMatrix', they
-- have already been validated.
subMatrix'' :: Storable a => (Int,Int) -> (Int,Int) -> Int -> Vector a -> Vector a
subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
  w <- createVector (rt*ct)
  unsafeWith v $ \p ->
    unsafeWith w $ \q -> do
      -- Walk from (rt-1, ct-1) down to (0, 0). Matching against the
      -- literal (-1) already forces i and j, so no bang patterns are
      -- needed.
      let go (-1) _    = return ()
          go i    (-1) = go (i-1) (ct-1)
          go i    j    = do
            x <- peekElemOff p ((i+r0)*c+j+c0)
            pokeElemOff q (i*ct+j) x
            go i (j-1)
      go (rt-1) (ct-1)
  return w
-- | Default implementation of 'subMatrixD'.
--
-- A row-major matrix has its block copied directly. Any other matrix is
-- transposed (which copies nothing), the transposed block is extracted,
-- and the result is transposed back.
subMatrix' (r0,c0) (rt,ct) m = case m of
  MC _ c v -> MC rt ct (subMatrix'' (r0,c0) (rt,ct) c v)
  _        -> trans (subMatrix' (c0,r0) (ct,rt) (trans m))
-- | Write a 'Double' matrix to a file through the C routine
-- @matrix_fprintf@.
--
-- @fmt@ is the element format string; presumably it is printf-style
-- (TODO confirm against the C side). The layout flag passed to C is 1
-- for row-major and 0 for column-major.
--
-- Both C strings are scoped with 'withCString', so they are released even
-- if the checked C call throws. The earlier 'newCString'/'free' pairs
-- leaked them in that case.
saveMatrix :: FilePath       -- ^ output file name
           -> String         -- ^ element format string
           -> Matrix Double
           -> IO ()
saveMatrix filename fmt m =
  withCString filename $ \charname ->
    withCString fmt $ \charfmt ->
      app1 (matrix_fprintf charname charfmt o) mat m "matrix_fprintf"
  where
    o = if orderOf m == RowMajor then 1 else 0
foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM
-- | Expand every matrix to a common shape, as done by 'conformMTo'.
--
-- The target shape is the largest row count and the largest column
-- count found in the list. For an empty list the result is empty and the
-- maxima are never evaluated.
conformMs ms = [ conformMTo (r,c) m | m <- ms ]
  where
    r = maximum [ rows m | m <- ms ]
    c = maximum [ cols m | m <- ms ]
-- | Expand every vector to the largest dimension in the list, as done by
-- 'conformVTo'.
--
-- For an empty list the result is empty and the maximum is never
-- evaluated.
conformVs vs = [ conformVTo target v | v <- vs ]
  where
    target = maximum [ dim v | v <- vs ]
-- | Expand matrix @m@ to shape (r,c) where possible:
--
-- * a matrix already of that shape is returned unchanged;
--
-- * a 1x1 matrix is replicated into every position;
--
-- * an r x 1 column is repeated c times;
--
-- * a 1 x c row is repeated r times.
--
-- Any other shape is an error.
conformMTo (r,c) m
  | sh == (r,c) = m
  | sh == (1,1) = reshape c (constantD (m @@> (0,0)) (r*c))
  | sh == (r,1) = repCols c m
  | sh == (1,c) = repRows r m
  | otherwise   = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><" ++ show c ++ ")"
  where
    sh = size m
-- | Expand vector @v@ to dimension @n@.
--
-- A vector already of dimension @n@ is returned unchanged, and a
-- length-1 vector is replicated. Any other dimension is an error.
conformVTo n v
  | d == n    = v
  | d == 1    = constantD (v @> 0) n
  | otherwise = error $ "vector of dim=" ++ show d ++ " cannot be expanded to dim=" ++ show n
  where
    d = dim v
-- | @repRows n x@ is a matrix made of @n@ copies of @x@'s elements (in
-- row-major order), each copy forming one row.
repRows n x = let row = flatten x in fromRows (replicate n row)

-- | @repCols n x@ is a matrix made of @n@ copies of @x@'s elements (in
-- row-major order), each copy forming one column.
repCols n x = let col = flatten x in fromColumns (replicate n col)
-- | Dimensions of a matrix as a (rows, columns) pair.
size m = (rows m, cols m)

-- | The shape as text, for error messages: the row count and the column
-- count joined by "><" and wrapped in parentheses.
shSize m = concat ["(", show (rows m), "><", show (cols m), ")"]