module Numeric.Jalla.Matrix
(
GMatrix(..),
CMatrix(..),
shapeTrans,
MatrixMatrix(..),
MatrixVector(..),
MatrixScalar(..),
module Numeric.Jalla.Indexable,
Matrix,
Order(..),
Transpose(..),
RefVector,
MMM,
createMatrix,
modifyMatrix,
setDiag,
setRow,
setColumn,
setBlock,
fillBlock,
scaleRow,
scaleColumn,
refRow,
refColumn,
matrixMap,
matrixBinMap,
matrixList,
matrixLists,
listMatrix,
matrixAssocs,
gmatrixAssocs,
row,
column,
rows,
columns,
module Numeric.Jalla.IMM,
prettyPrintMatrix,
prettyPrintMatrixIO,
solveLinearSystem,
invert,
pseudoInverse,
frobNorm,
idMatrix,
matrixMultDiag,
svd,
SVD(..),
SVDOpt(..),
SVDU(..),
SVDVT(..),
checkIndex,
inMatrixRange,
diagIndices,
matrixAlloc',
matrixElem,
matrixMult,
unsafeMatrixSetElem,
unsafeMatrixMult,
unsafeMatrixFill,
unsafeMatrixCopy,
unsafeSolveLinearSystem,
unsafeSVD,
unsafeMatrixMap,
unsafeMatrixBinMap,
withCMatrixRow,
withCMatrixColumn,
CFloat,
CDouble,
Complex
) where
import Numeric.Jalla.Foreign.BLAS
import Numeric.Jalla.Foreign.BlasOps
import Numeric.Jalla.Foreign.LAPACKE
import Numeric.Jalla.Foreign.LapackeOps
import Numeric.Jalla.Internal
import Numeric.Jalla.IMM
import Numeric.Jalla.Vector
import Numeric.Jalla.Indexable
import Numeric.Jalla.Types
import Foreign.C.Types
import Foreign.Marshal.Array
import Foreign hiding (unsafePerformIO)
import System.IO.Unsafe (unsafePerformIO)
import Data.Ix
import Data.Complex
import Data.List (partition)
import Data.Maybe (fromJust)
import Control.Applicative
import Control.Monad.State
import Data.Convertible
-- | Converts between the storage 'Order' and the CBLAS order enumeration.
instance BLASEnum Order CblasOrder where
  toBlas o = case o of
    RowMajor    -> CblasRowMajor
    ColumnMajor -> CblasColMajor
  fromBlas b = case b of
    CblasRowMajor -> RowMajor
    CblasColMajor -> ColumnMajor
-- | Converts between the 'Transpose' flag and the CBLAS transpose enumeration.
-- Only the 'Trans' / 'NoTrans' pair is handled.
instance BLASEnum Transpose CblasTranspose where
  toBlas t = case t of
    Trans   -> CblasTrans
    NoTrans -> CblasNoTrans
  fromBlas b = case b of
    CblasTrans   -> Trans
    CblasNoTrans -> NoTrans
-- | Converts between 'UpLo' and the CBLAS upper/lower enumeration.
instance BLASEnum UpLo CblasUplo where
  toBlas ul = case ul of
    Up -> CblasUpper
    Lo -> CblasLower
  fromBlas b = case b of
    CblasUpper -> Up
    CblasLower -> Lo
-- | LAPACKE takes the storage order as a plain 'Int'; it is obtained by going
-- through the 'CblasOrder' enumeration and its 'Enum' instance.
instance LAPACKEEnum Order Int where
  toLapacke = fromEnum . (toBlas :: Order -> CblasOrder)
  fromLapacke = fromBlas . (toEnum :: Int -> CblasOrder)
-- | Generic matrix interface: anything indexable by an 'IndexPair' that can
-- report its dimensions.
--
-- The default methods are defined in terms of each other, so an instance has
-- to provide either 'shape' or both 'rowCount' and 'colCount'.
class (Field1 e, Indexable (mat e) IndexPair e) => GMatrix mat e where
  -- | The pair (number of rows, number of columns).
  shape :: mat e -> Shape
  shape m = (rowCount m, colCount m)

  -- | First component of 'shape'.
  rowCount :: mat e -> Index
  rowCount m = fst (shape m)

  -- | Second component of 'shape'.
  colCount :: mat e -> Index
  colCount m = snd (shape m)
infixl 7 ##, ##!
infixl 6 ##-, ##+

-- | Operations combining two matrices of the same matrix type.
--
-- All methods have defaults, so an instance may be declared with an empty body.
class (Field1 e, BlasOps e, GMatrix mat e, CMatrix mat e) => MatrixMatrix mat e where
  -- | Matrix product. Calls 'error' when the column count of the left
  -- operand differs from the row count of the right operand.
  (##) :: mat e -> mat e -> mat e
  m1 ## m2
    | colCount m1 /= rowCount m2 = error "(##): shape mismatch!"
    | otherwise = unsafePerformIO $ matrixMult 1 NoTrans m1 NoTrans m2

  -- | Matrix product where each operand is paired with a 'Transpose' flag
  -- that is applied before multiplying. Calls 'error' on a shape mismatch
  -- (checked after taking the flags into account).
  (##!) :: (mat e, Transpose) -> (mat e, Transpose) -> mat e
  (m1, t1) ##! (m2, t2)
    | colCountTrans t1 s1 /= rowCountTrans t2 s2 = error "(##!): shape mismatch!"
    | otherwise = unsafePerformIO $ matrixMult 1 t1 m1 t2 m2
    where
      s1 = shape m1
      s2 = shape m2

  -- | Element-wise sum (via 'matrixBinMap').
  (##+) :: mat e -> mat e -> mat e
  m1 ##+ m2 = matrixBinMap (+) m1 m2

  -- | Element-wise difference (via 'matrixBinMap').
  (##-) :: mat e -> mat e -> mat e
  m1 ##- m2 = matrixBinMap (-) m1 m2
infixl 7 #|, |#

-- | Operations between a matrix and a vector that yield a vector.
-- No default implementations are given here.
-- NOTE(review): presumably these are the matrix-vector products; confirm
-- against the instances.
class (CMatrix mat e, CVector vec e) => MatrixVector mat vec e where
  -- | Matrix on the left, vector on the right.
  (#|) :: mat e -> vec e -> vec e

  -- | Vector on the left, matrix on the right.
  (|#) :: vec e -> mat e -> vec e
infixl 7 #.*, #./
infixl 6 #.+, #.-

-- | Element-wise operations between a matrix (left) and a scalar (right).
-- Every method is implemented with 'matrixMap' by default.
class (Storable e, CMatrix mat e) => MatrixScalar mat e where
  -- | Multiply every element by the scalar.
  (#.*) :: mat e -> e -> mat e
  a #.* b = matrixMap (* b) a

  -- | Divide every element by the scalar.
  (#./) :: mat e -> e -> mat e
  a #./ b = matrixMap (/ b) a

  -- | Add the scalar to every element.
  (#.+) :: mat e -> e -> mat e
  a #.+ b = matrixMap (+ b) a

  -- | Subtract the scalar from every element.
  -- 'subtract' is used because @(- b)@ would parse as a negation and
  -- @((-) b)@ would compute @b - x@ instead of @x - b@.
  (#.-) :: mat e -> e -> mat e
  a #.- b = matrixMap (subtract b) a
-- | Matrices backed by a foreign memory block that can be handed to
-- BLAS/LAPACK routines.
class (Storable e, BlasOps e, GMatrix mat e) => CMatrix mat e where
  -- | Vector type associated with this matrix type.
  type CMatrixVector mat e :: *

  -- | Vector type associated with this matrix type for scalar-valued results.
  type CMatrixVectorS mat e :: *

  -- | Allocate a matrix of the given shape.
  matrixAlloc :: Shape -> IO (mat e)

  -- | Run an action on the pointer to the matrix data.
  withCMatrix :: mat e -> (Ptr e -> IO a) -> IO a

  -- | Leading dimension: the element stride between consecutive rows of a
  -- row-major matrix, or consecutive columns of a column-major one.
  lda :: mat e -> Index

  -- | Storage order of the matrix.
  order :: mat e -> Order

  -- | The foreign pointer owning the matrix data.
  matrixForeignPtr :: mat e -> ForeignPtr e
-- | Apply a function to every element of a matrix, producing a freshly
-- allocated matrix of the same shape. Element and matrix types of the result
-- may differ from those of the input.
matrixMap :: (Storable e1, Storable e2, CMatrix mat1 e1, CMatrix mat2 e2) =>
             (e1 -> e2)
          -> mat1 e1
          -> mat2 e2
matrixMap f mat = unsafePerformIO $ do
  result <- matrixAlloc (shape mat)
  unsafeMatrixMap f mat result
  return result
-- | Combine two matrices element by element with a binary function.
-- The result is freshly allocated; its shape is the component-wise minimum
-- of the two input shapes.
matrixBinMap :: (Storable e1, Storable e2, Storable e3, CMatrix mat1 e1, CMatrix mat2 e2, CMatrix mat3 e3) =>
                (e1 -> e2 -> e3)
             -> mat1 e1
             -> mat2 e2
             -> mat3 e3
matrixBinMap f mat1 mat2 = unsafePerformIO $
  matrixAlloc resultShape >>= \out ->
    unsafeMatrixBinMap f mat1 mat2 out >> return out
  where
    (rowsA, colsA) = shape mat1
    (rowsB, colsB) = shape mat2
    resultShape = (min rowsA rowsB, min colsA colsB)
-- | Existential wrapper so that matrices of different matrix and element
-- types can be put into one list.
data CMatrixContainer
  = forall mat a. CMatrix mat a => CMatrixContainer (mat a)
-- | Compute the 'lengthAndInc' tuple for every matrix in the list.
-- When row-major matrices are not in the strict majority, each tuple has its
-- two lengths and its two increments swapped.
lengthAndInc' :: [CMatrixContainer] -> [(Index, Index, Index, Index)]
lengthAndInc' mas
  | rowMajorCount > colMajorCount = perMatrix
  | otherwise = map swapAxes perMatrix
  where
    perMatrix = [lengthAndInc a | CMatrixContainer a <- mas]
    swapAxes (a, b, c, d) = (b, a, d, c)
    orders = [order m | CMatrixContainer m <- mas]
    (rowMajors, colMajors) = partition (== RowMajor) orders
    rowMajorCount = length rowMajors
    colMajorCount = length colMajors
-- | For a matrix, return @(columns, rows, inc1, inc2)@ where the increments
-- are @(1, lda)@ for a row-major matrix and @(lda, 1)@ otherwise.
lengthAndInc :: forall mat a. (CMatrix mat a) => mat a -> (Index, Index, Index, Index)
lengthAndInc ma
  | order ma == RowMajor = (nCols, nRows, 1, ld)
  | otherwise = (nCols, nRows, ld, 1)
  where
    (nRows, nCols) = shape ma
    ld = lda ma
-- | Map a function over the elements of the source matrix, writing the
-- results into the destination matrix in place. Only the region common to
-- both matrices (minimum of the respective lengths) is processed.
unsafeMatrixMap :: (Storable e1, Storable e2, CMatrix mat1 e1, CMatrix mat2 e2) => (e1 -> e2) -> mat1 e1 -> mat2 e2 -> IO ()
unsafeMatrixMap f src dst =
  withCMatrix src $ \srcPtr ->
    withCMatrix dst $ \dstPtr ->
      unsafePtrMapInc2 srcIncs dstIncs f srcPtr dstPtr (min n1 n2, min m1 m2)
  where
    [(n1, m1, i11, i12), (n2, m2, i21, i22)] =
      lengthAndInc' [CMatrixContainer src, CMatrixContainer dst]
    srcIncs = (i11, i12)
    dstIncs = (i21, i22)
-- | Combine two source matrices element by element and write the results
-- into the third matrix in place. Only the region common to all three
-- matrices is processed.
unsafeMatrixBinMap :: (Storable e1, Storable e2, Storable e3, CMatrix mat1 e1, CMatrix mat2 e2, CMatrix mat3 e3) => (e1 -> e2 -> e3) -> mat1 e1 -> mat2 e2 -> mat3 e3 -> IO ()
unsafeMatrixBinMap f srcA srcB dst =
  withCMatrix srcA $ \pA ->
    withCMatrix srcB $ \pB ->
      withCMatrix dst $ \pOut ->
        unsafe2PtrMapInc2 incsA incsB incsOut f pA pB pOut (commonN, commonM)
  where
    [(n1, m1, i11, i12), (n2, m2, i21, i22), (n3, m3, i31, i32)] =
      lengthAndInc' [CMatrixContainer srcA, CMatrixContainer srcB, CMatrixContainer dst]
    incsA = (i11, i12)
    incsB = (i21, i22)
    incsOut = (i31, i32)
    commonN = minimum [n1, n2, n3]
    commonM = minimum [m1, m2, m3]
-- | Dense matrix stored in foreign memory.
data Matrix e = Matrix
  { matP :: !(ForeignPtr e) -- ^ Pointer to the element data.
  , matShape :: !Shape      -- ^ (rows, columns).
  , matLDA :: !Index        -- ^ Leading dimension of the storage.
  , matOrder :: !Order      -- ^ Storage order.
  }
-- | The shape is read from the 'matShape' field; 'rowCount' and 'colCount'
-- use the class defaults.
instance (Num e, Field1 e, BlasOps e) => GMatrix Matrix e where
  shape m = matShape m
-- | 'Matrix' uses the default implementations of all 'MatrixMatrix' methods.
instance BlasOps e => MatrixMatrix Matrix e
-- | Element lookup by (row, column); delegates to 'matrixElem'.
instance BlasOps e => Indexable (Matrix e) IndexPair e where
  (!) m = unsafePerformIO . matrixElem m
-- | 'Matrix' as a foreign-memory matrix; every method is a thin wrapper
-- around a record field or a primed helper.
instance (Num e, Field1 e, BlasOps e) => CMatrix Matrix e where
  type CMatrixVector Matrix e = Vector e
  type CMatrixVectorS Matrix e = Vector (FieldScalar e)
  matrixAlloc sh = matrixAlloc' sh
  withCMatrix m act = withMatrix' m act
  lda m = matLDA m
  order m = matOrder m
  matrixForeignPtr m = matP m
-- | Run an action on the raw pointer of a 'Matrix', keeping the underlying
-- foreign pointer alive for the duration of the action.
withMatrix' :: (BlasOps e) => Matrix e -> (Ptr e -> IO a) -> IO a
withMatrix' m act = withForeignPtr (matP m) act
-- | Renders a matrix as @listMatrix (rows,cols) [elements in row-major order]@.
instance (BlasOps e, Show e) => Show (Matrix e) where
  show mat =
    concat ["listMatrix (", show m, ",", show n, ") ", show elems]
    where
      (m, n) = shape mat
      elems = matrixList RowMajor mat
-- | Two matrices are equal when their shapes agree and all corresponding
-- elements are equal.
instance (BlasOps e, Eq e) => Eq (Matrix e) where
  a == b =
    shape a == shape b
      && and (zipWith (==) (matrixList RowMajor a) (matrixList RowMajor b))
-- | Element-wise numeric operations on matrices.
--
-- Note that '*' is the element-wise product, not the matrix product (use
-- '##' for that). 'fromInteger' builds a 1x1 matrix holding the value.
instance (BlasOps e, Num e) => Num (Matrix e) where
  a + b = a ##+ b
  a - b = a ##- b
  a * b = matrixBinMap (*) a b
  negate = matrixMap (* (-1))
  abs = matrixMap abs
  signum = matrixMap signum
  fromInteger i = createMatrix (1, 1) $ setElem (0, 0) (fromIntegral i)
-- | Element-wise division and reciprocal. 'fromRational' builds a 1x1 matrix
-- holding the value.
instance (BlasOps e, Num e, Fractional e) => Fractional (Matrix e) where
  (/) = matrixBinMap (/)
  recip m = matrixMap recip m
  fromRational r = createMatrix (1, 1) (setElem (0, 0) (fromRational r))
-- | Element-wise transcendental functions. 'pi' is a 1x1 matrix holding the
-- scalar 'pi'; binary operations go through 'matrixBinMap'.
instance (BlasOps e, Num e, Fractional e) => Floating (Matrix e) where
  pi = createMatrix (1, 1) (setElem (0, 0) pi)
  exp m = matrixMap exp m
  sqrt m = matrixMap sqrt m
  log m = matrixMap log m
  (**) = matrixBinMap (**)
  logBase a b = matrixBinMap logBase a b
  sin m = matrixMap sin m
  tan m = matrixMap tan m
  cos m = matrixMap cos m
  asin m = matrixMap asin m
  atan m = matrixMap atan m
  acos m = matrixMap acos m
  sinh m = matrixMap sinh m
  tanh m = matrixMap tanh m
  cosh m = matrixMap cosh m
  asinh m = matrixMap asinh m
  atanh m = matrixMap atanh m
  acosh m = matrixMap acosh m
-- | All (index, element) pairs of a generic matrix, with indices running
-- from @(0,0)@ to @(rows-1, cols-1)@ in the order produced by 'range'.
-- An empty matrix yields an empty list.
gmatrixAssocs :: (GMatrix mat e) => mat e -> [(IndexPair,e)]
gmatrixAssocs m = [(ij, m ! ij) | ij <- range ((0, 0), (r - 1, c - 1))]
  where
    (r, c) = shape m
-- | All (index, element) pairs of a matrix, enumerated in the requested
-- order: row by row for 'RowMajor', column by column otherwise.
matrixAssocs :: (BlasOps e, CMatrix mat e) => Order -> mat e -> [(IndexPair, e)]
matrixAssocs o mat = zip idxs (matrixList o mat)
  where
    (r, c) = shape mat
    idxs
      | o == RowMajor = [(i, j) | i <- [0 .. r - 1], j <- [0 .. c - 1]]
      | otherwise = [(i, j) | j <- [0 .. c - 1], i <- [0 .. r - 1]]
-- | Compute @alpha * op(a) * op(b)@ into a freshly allocated matrix, where
-- @op@ applies the respective 'Transpose' flag. The result shape is derived
-- from the (possibly transposed) operand shapes; no shape check is done here.
matrixMult :: (BlasOps e, CMatrix mat e) =>
              e
           -> Transpose
           -> mat e
           -> Transpose
           -> mat e
           -> IO (mat e)
matrixMult alpha transA a transB b = do
  ret <- matrixAlloc resultShape
  unsafeMatrixMult alpha transA a transB b 0 ret
  return ret
  where
    resultShape = (rowCountTrans transA (shape a), colCountTrans transB (shape b))
-- | In-place general matrix multiply via @gemm@: the arguments are
-- @alpha transA a transB b beta c@, and @c@ is overwritten.
--
-- Calls 'error' when @a@ and @b@ have different storage orders. The shape of
-- @c@ is not checked.
unsafeMatrixMult :: (BlasOps e, CMatrix mat e) =>
                    e
                 -> Transpose
                 -> mat e
                 -> Transpose
                 -> mat e
                 -> e
                 -> mat e
                 -> IO ()
unsafeMatrixMult alpha transA a transB b beta c = do
  when (order a /= order b) $ error "unsafeMatrixMult: order of matrices must be equal."
  withCMatrix a $ \pa ->
    withCMatrix b $ \pb ->
      withCMatrix c $ \pc ->
        gemm (toBlas (order a)) (toBlas transA) (toBlas transB)
             m n k alpha pa (lda a) pb (lda b) beta pc (lda c)
  where
    (m, k) = shapeTrans transA (shape a)
    n = colCountTrans transB (shape b)
-- | Solve a linear system in place with @gesv@. Both argument matrices are
-- passed to the routine by pointer and may be overwritten.
--
-- Calls 'error' when @a@ is not square, when the row counts of @a@ and @b@
-- differ, or when @gesv@ returns a non-zero status.
unsafeSolveLinearSystem :: (BlasOps e, LapackeOps e se, CMatrix mat e) =>
                           mat e
                        -> mat e
                        -> IO ()
unsafeSolveLinearSystem a b
  | not shapesAgree = error "unsafeSolveLinearSystem: The shapes of the arguments do not match."
  | otherwise =
      withCMatrix a $ \pa ->
        withCMatrix b $ \pb ->
          allocaArray n $ \pipiv -> do
            ret <- gesv (fromEnum (toBlas (order a) :: CblasOrder)) n nrhs pa (lda a) pipiv pb (lda b)
            when (ret /= 0) $ error "unsafeSolveLinearSystem: ret /= 0"
  where
    shapesAgree = rowCount a == colCount a && rowCount a == rowCount b
    n = colCount a
    nrhs = colCount b
-- | Pure wrapper around 'unsafeSolveLinearSystem': both arguments are copied
-- first, the solver runs on the copies, and the copy of @b@ is returned.
solveLinearSystem :: (BlasOps e, LapackeOps e se, CMatrix mat e) =>
                     mat e
                  -> mat e
                  -> mat e
solveLinearSystem a b = unsafePerformIO $ do
  x <- matrixCopy b NoTrans
  a' <- matrixCopy a NoTrans
  unsafeSolveLinearSystem a' x
  return x
-- | The @n@ x @n@ identity matrix: filled with zeros, then ones are written
-- on diagonal 0.
idMatrix :: (BlasOps e, CMatrix mat e) => Index -> mat e
idMatrix n = createMatrix (n, n) $ do
  fill 0
  setDiag 0 (repeat 1)
-- | Invert a square matrix by solving against the identity matrix.
-- Calls 'error' for a non-square argument.
invert' :: (BlasOps e, LapackeOps e se, CMatrix mat e) => mat e -> mat e
invert' a
  | colCount a /= rowCount a = error "Cannot invert non-square matrix."
  | otherwise = solveLinearSystem a (idMatrix (colCount a))
-- | Invert a matrix. Returns 'Nothing' for a non-square argument or when
-- 'unsafeInvert' fails; the input is copied first and left untouched.
invert :: (BlasOps e, LapackeOps e se, CMatrix mat e) => mat e -> Maybe (mat e)
invert a
  | colCount a /= rowCount a = Nothing
  | otherwise = unsafePerformIO $ do
      a' <- matrixCopy a NoTrans
      unsafeInvert a'
-- | Pseudo-inverse computed from a thin SVD: the transposed @VT@ factor has
-- its columns scaled by the reciprocal singular values (zero stays zero) and
-- is then multiplied by the transposed @U@ factor.
pseudoInverse :: (BlasOps e, se ~ FieldScalar e, BlasOps se, Real se, LapackeOps e se, MatrixMatrix mat e, CMatrix mat e)
                 => mat e -> mat e
pseudoInverse a = (scaledV, NoTrans) ##! (u, Trans)
  where
    decomposition = svd a (SVDU SVDThin, SVDVT SVDThin)
    reciprocal x
      | x /= 0 = 1 / realToFrac x
      | otherwise = 0
    invS = map reciprocal (svdS decomposition)
    -- Both factors were requested as 'SVDThin', so 'svd' returns them as 'Just'.
    u = fromJust (svdU decomposition)
    vt = fromJust (svdVT decomposition)
    scaledV = matrixMultDiag (vt, Trans) invS
-- | A strided view onto elements stored in someone else's memory (in this
-- module: a row or column of a matrix).
data RefVector e = RefVector
  { refRefP :: !(ForeignPtr e)  -- ^ Foreign pointer of the owning storage.
  , refVecP :: !(Ptr e)         -- ^ Pointer to the first element of the view.
  , refVecInc :: !Index         -- ^ Element stride between consecutive entries.
  , refVecLength :: !Index      -- ^ Number of entries.
  }
-- | Renders the view as @listVector [elements]@.
instance (Show e, Field1 e, Storable e, BlasOps e) => Show (RefVector e) where
  show = ("listVector " ++) . show . vectorList
-- | A 'RefVector' can be used wherever a C vector is expected, but it cannot
-- be allocated: 'vectorAlloc' is an 'error'.
instance (BlasOps e, Storable e) => CVector RefVector e where
  vectorAlloc = error "No vectorAlloc for RefVector."
  withCVector v act = act (refVecP v)
  inc v = refVecInc v
-- | Bounds-checked element access; the i-th entry lives @i * refVecInc@
-- elements after the start pointer. Calls 'error' for an out-of-range index.
instance (BlasOps e, Storable e) => Indexable (RefVector e) Index e where
  v ! i
    | i < 0 || i >= refVecLength v = error "RefVector range violation."
    | otherwise =
        unsafePerformIO $ withCVector v $ \p -> peekElemOff p (i * refVecInc v)
-- | The length of the view is its 'refVecLength' field.
instance (Field1 e, Storable e, BlasOps e) => GVector RefVector e where
  vectorLength v = refVecLength v
-- | 'RefVector' relies entirely on the default 'VectorVector' methods.
instance BlasOps e => VectorVector RefVector e
-- | 'RefVector' relies entirely on the default 'VectorScalar' methods.
instance BlasOps e => VectorScalar RefVector e
-- | Run an action on a 'RefVector' that views row @i@ of the matrix without
-- copying. Calls 'error' when @i@ is outside @[0, rows)@.
withCMatrixRow :: Storable e => CMatrix mat e => mat e -> Index -> (RefVector e -> IO a) -> IO a
withCMatrixRow mat i act = withCMatrix mat $ \base -> do
  when (i < 0 || i >= m) $ error "withCMatrixRow range violation."
  act (RefVector { refRefP = matrixForeignPtr mat
                 , refVecP = advancePtr base (i * rowStride)
                 , refVecInc = colStride
                 , refVecLength = n })
  where
    (m, n) = shape mat
    -- Element strides for stepping to the next row / next column.
    (rowStride, colStride) =
      if order mat == RowMajor then (lda mat, 1) else (1, lda mat)
-- | Run an action on a 'RefVector' that views column @i@ of the matrix
-- without copying. Calls 'error' when @i@ is outside @[0, columns)@.
withCMatrixColumn :: Storable e => CMatrix mat e => mat e -> Index -> (RefVector e -> IO a) -> IO a
withCMatrixColumn mat i act = withCMatrix mat $ \base -> do
  when (i < 0 || i >= n) $ error "withCMatrixColumn range violation."
  act (RefVector { refRefP = matrixForeignPtr mat
                 , refVecP = advancePtr base (i * colStride)
                 , refVecInc = rowStride
                 , refVecLength = m })
  where
    (m, n) = shape mat
    -- Element strides for stepping to the next row / next column.
    (rowStride, colStride) =
      if order mat == RowMajor then (lda mat, 1) else (1, lda mat)
-- | A 'RefVector' viewing column @i@ of the matrix. The view shares the
-- matrix's memory (it carries the matrix's foreign pointer in 'refRefP').
columnRef :: (CMatrix mat e) => mat e -> Index -> RefVector e
columnRef m i = unsafePerformIO (withCMatrixColumn m i (\ref -> return ref))

-- | A 'RefVector' viewing row @i@ of the matrix; see 'columnRef'.
rowRef :: (CMatrix mat e) => mat e -> Index -> RefVector e
rowRef m i = unsafePerformIO (withCMatrixRow m i (\ref -> return ref))
-- | Views onto all rows (resp. all columns) of a matrix, in index order.
-- An empty dimension yields an empty list.
rowsRef, columnsRef :: (CMatrix mat e) => mat e -> [RefVector e]
rowsRef m = map (rowRef m) [0 .. rowCount m - 1]
columnsRef m = map (columnRef m) [0 .. colCount m - 1]
-- | Copy of row @i@ of the matrix as a vector.
row :: (CMatrix mat e, CVector vec e) => mat e -> Index -> vec e
row m i = unsafePerformIO (withCMatrixRow m i copyVector)

-- | Copy of column @i@ of the matrix as a vector.
column :: (CMatrix mat e, CVector vec e) => mat e -> Index -> vec e
column m i = unsafePerformIO (withCMatrixColumn m i copyVector)
-- | Copies of all rows (resp. all columns) of a matrix, in index order.
-- An empty dimension yields an empty list.
rows, columns :: (CMatrix mat e, CVector vec e) => mat e -> [vec e]
rows m = map (row m) [0 .. rowCount m - 1]
columns m = map (column m) [0 .. colCount m - 1]
-- | Multiply the (optionally transposed) matrix from the right by a diagonal
-- matrix given as a list of its diagonal entries, i.e. scale column @k@ of
-- the result by the @k@-th list element. Surplus list elements, or columns
-- beyond the end of the list, are left untouched.
matrixMultDiag :: (BlasOps e) => CMatrix mat e =>
                  (mat e, Transpose)
               -> [e]
               -> mat e
matrixMultDiag (a, t) d = modifyMatrix a t $ zipWithM_ scaleColumn d [0 .. c - 1]
  where
    (_, c) = shapeTrans t (shape a)
-- | How much of an SVD factor is wanted: the full factor, the thin factor,
-- or none at all.
data SVDOpt
  = SVDFull
  | SVDThin
  | SVDNone
  deriving (Ord, Eq)

-- | Option for the @U@ factor.
data SVDU = SVDU SVDOpt
  deriving (Ord, Eq)

-- | Option for the @VT@ factor.
data SVDVT = SVDVT SVDOpt
  deriving (Ord, Eq)
-- | Job character handed to the SVD routine for an option:
-- @\'A\'@ for full, @\'S\'@ for thin, @\'N\'@ for none.
svdJob :: SVDOpt -> CChar
svdJob opt = toEnum (fromEnum jobChar)
  where
    jobChar = case opt of
      SVDFull -> 'A'
      SVDThin -> 'S'
      SVDNone -> 'N'
-- | Job characters for the @U@ and @VT@ options, in that order.
svdJobs :: (SVDU, SVDVT) -> (CChar,CChar)
svdJobs (SVDU uOpt, SVDVT vtOpt) = (jobU, jobVT)
  where
    jobU = svdJob uOpt
    jobVT = svdJob vtOpt
-- | Result of a singular value decomposition.
data SVD mat e = SVD
  { svdU :: Maybe (mat e)     -- ^ The @U@ factor, if it was requested.
  , svdVT :: Maybe (mat e)    -- ^ The @VT@ factor, if it was requested.
  , svdS :: [FieldScalar e]   -- ^ The singular values.
  }
-- | Run @gesvd@ on the given matrix, writing the singular values into @s@
-- and the factors into @u@ and @vt@. All four arguments are passed to the
-- routine by pointer; the routine's status code is returned.
--
-- Calls 'error' when @s@ does not have increment 1. The shapes of @s@, @u@
-- and @vt@ are not checked here.
unsafeSVD :: (BlasOps e, LapackeOps e se, CVector vec se, CMatrix mat e) =>
             mat e
          -> (SVDU, SVDVT)
          -> vec se
          -> mat e
          -> mat e
          -> IO Int
unsafeSVD a opts s u vt = do
  when (inc s /= 1) $ error $ "unsafeSVD: s must have increment 1, but has " ++ show (inc s)
  withCMatrix a $ \aPtr ->
    withCVector s $ \sPtr ->
      withCMatrix u $ \uPtr ->
        withCMatrix vt $ \vtPtr -> do
          superb <- mallocForeignPtrArray superbSize
          withForeignPtr superb $ \superbPtr ->
            gesvd mOrder jobu jobvt m n aPtr (lda a) sPtr uPtr (lda u) vtPtr (lda vt) superbPtr
  where
    (jobu, jobvt) = svdJobs opts
    mOrder = toLapacke $ order a
    (m, n) = shape a
    -- Work array of min(m,n) - 1 entries for the routine's @superb@ argument.
    superbSize = min m n - 1
-- | Singular value decomposition of a matrix. The input is copied before
-- 'unsafeSVD' runs on it. A factor is returned as 'Nothing' when its option
-- is 'SVDNone'. The status code of 'unsafeSVD' is discarded.
svd :: (BlasOps e, se ~ FieldScalar e, BlasOps se, LapackeOps e se, CMatrix mat e) =>
       mat e
    -> (SVDU, SVDVT)
    -> SVD mat e
svd a opts@(SVDU optu, SVDVT optvt) = unsafePerformIO $ do
  acopy <- matrixCopy a NoTrans
  u <- matrixAlloc (shapeU optu)
  vt <- matrixAlloc (shapeVT optvt)
  (s :: Vector se) <- vectorAlloc (min m n)
  _ <- unsafeSVD acopy opts s u vt
  return SVD { svdU = wanted optu u
             , svdVT = wanted optvt vt
             , svdS = vectorList s }
  where
    (m, n) = shape a
    wanted opt x
      | opt /= SVDNone = Just x
      | otherwise = Nothing
    -- Shape of the U factor for each option.
    shapeU opt = case opt of
      SVDFull -> (m, m)
      SVDThin -> (m, min m n)
      _ -> (0, 0)
    -- Shape of the VT factor for each option.
    shapeVT opt = case opt of
      SVDFull -> (n, n)
      SVDThin -> (min m n, n)
      _ -> (0, 0)
-- | Invert a matrix in place: @getrf@ followed by @getri@ on the same
-- storage. Returns 'Nothing' when either routine reports a non-zero status,
-- otherwise the (now overwritten) argument wrapped in 'Just'.
unsafeInvert :: (BlasOps e, LapackeOps e se, CMatrix mat e) => mat e -> IO (Maybe (mat e))
unsafeInvert mat = withCMatrix mat $ \mp ->
  allocaArray (min m n) $ \ipiv -> do
    factorStatus <- getrf o m n mp ldA ipiv
    if factorStatus /= 0
      then return Nothing
      else do
        invertStatus <- getri o n mp ldA ipiv
        return (if invertStatus /= 0 then Nothing else Just mat)
  where
    o = toLapacke (order mat)
    ldA = lda mat
    (m, n) = shape mat
-- | Frobenius norm: the square root of the sum, over all rows, of each row
-- combined with itself using '||*'.
frobNorm :: (BlasOps e, CMatrix mat e) => mat e -> e
frobNorm mat = sqrt (sum [v ||* v | v <- rowsRef mat])
-- | Allocate an uninitialised 'Matrix' of the given shape. The result is
-- column-major with a leading dimension equal to the row count.
matrixAlloc' :: (BlasOps e) => Shape -> IO (Matrix e)
matrixAlloc' (r, c) = do
  p <- mallocForeignPtrArray (r * c)
  return Matrix { matP = p, matShape = (r, c), matLDA = r, matOrder = ColumnMajor }
-- | 'True' when the (row, column) index lies inside a matrix of the given
-- shape, i.e. @0 <= i < r@ and @0 <= j < c@.
checkIndex :: Shape -> IndexPair -> Bool
checkIndex (r, c) (i, j) = inRange (0, r - 1) i && inRange (0, c - 1) j
-- | 'True' when the index is valid for the given matrix.
inMatrixRange :: (BlasOps e, GMatrix mat e) => mat e -> IndexPair -> Bool
inMatrixRange m = checkIndex (shape m)
-- | Read the element at (row, column). Calls 'error' for an index outside
-- the matrix. The storage offset takes the matrix order and leading
-- dimension into account.
matrixElem :: (Num e, BlasOps e, CMatrix mat e) => mat e -> IndexPair -> IO e
matrixElem m (i, j)
  | checkIndex (shape m) (i, j) = withCMatrix m $ \p -> peekElemOff p offset
  | otherwise = error $ "matrixElem out of bounds"
  where
    offset
      | order m == RowMajor = i * lda m + j
      | otherwise = j * lda m + i
-- | Overwrite the element at (row, column) in place. Calls 'error' for an
-- index outside the matrix.
unsafeMatrixSetElem :: (BlasOps e, CMatrix mat e) =>
                       mat e
                    -> IndexPair
                    -> e
                    -> IO ()
unsafeMatrixSetElem mat (i, j) he
  | checkIndex (shape mat) (i, j) = withCMatrix mat $ \p -> pokeElemOff p offset he
  | otherwise = error $ "unsafeMatrixSetElem out of bounds"
  where
    offset
      | order mat == RowMajor = i * lda mat + j
      | otherwise = j * lda mat + i
-- | Overwrite several elements in place, given as (index, value) pairs.
-- Unlike 'unsafeMatrixSetElem', the indices are NOT bounds-checked.
unsafeMatrixSetElems :: (BlasOps e, CMatrix mat e) =>
                        mat e
                     -> [(IndexPair, e)]
                     -> IO ()
unsafeMatrixSetElems mat els = withCMatrix mat $ \p ->
  mapM_ (\((i, j), e) -> pokeElemOff p (offset i j) e) els
  where
    ld = lda mat
    rowMajor = order mat == RowMajor
    offset i j
      | rowMajor = i * ld + j
      | otherwise = j * ld + i
-- | Overwrite the first @rows * cols@ consecutive storage elements of the
-- matrix with the given value, in place. The leading dimension is not
-- consulted; a matrix with no elements is left untouched.
unsafeMatrixFill :: (Num e, BlasOps e, CMatrix mat e) =>
                    mat e
                 -> e
                 -> IO ()
unsafeMatrixFill m e = withCMatrix m $ \p ->
  mapM_ (\k -> pokeElemOff p k e) [0 .. r * c - 1]
  where
    (r, c) = shape m
-- | Copy the source matrix, optionally transposed, into the destination in
-- place. The copy proceeds vector by vector into the rows of the
-- destination: from the source rows for 'NoTrans', from the source columns
-- for 'Trans'. Calls 'error' when the shapes do not match.
unsafeMatrixCopy :: (BlasOps e, CMatrix mat e) =>
                    mat e
                 -> Transpose
                 -> mat e
                 -> IO ()
unsafeMatrixCopy src t dst
  | shapeTrans t (shape src) /= shape dst = error "unsafeMatrixCopy: shape mismatch."
  | otherwise = zipWithM_ unsafeCopyVector srcVectors (rowsRef dst)
  where
    srcVectors = case t of
      NoTrans -> rowsRef src
      Trans -> columnsRef src
-- | Allocate a new matrix and copy the (optionally transposed) argument
-- into it.
matrixCopy :: (BlasOps e, CMatrix mat e) => mat e -> Transpose -> IO (mat e)
matrixCopy a t = do
  ret <- matrixAlloc (shapeTrans t (shape a))
  unsafeMatrixCopy a t ret
  return ret
-- | Map a function over @rows * cols@ consecutive storage elements of the
-- source into a freshly allocated matrix of the same shape.
matrixMap' :: (BlasOps e1, BlasOps e2, CMatrix mat1 e1, CMatrix mat2 e2) => (e1 -> e2) -> mat1 e1 -> IO (mat2 e2)
matrixMap' f mat = do
  mRet <- matrixAlloc (shape mat)
  withCMatrix mat $ \p1 ->
    withCMatrix mRet $ \p2 ->
      unsafePtrMap f p1 p2 (r * c) >> return mRet
  where
    (r, c) = shape mat
-- | All elements of a matrix as a flat list, enumerated row by row for
-- 'RowMajor' and column by column otherwise. The enumeration order is
-- independent of how the matrix itself is stored.
matrixList :: (GMatrix mat e) => Order -> mat e -> [e]
matrixList o mat
  | o == RowMajor = [mat ! (i, j) | i <- [0 .. r - 1], j <- [0 .. c - 1]]
  | otherwise = [mat ! (i, j) | j <- [0 .. c - 1], i <- [0 .. r - 1]]
  where
    (r, c) = shape mat
-- | The elements of a matrix as a list of rows, each row being the list of
-- its elements in column order.
matrixLists :: (GMatrix mat e) => mat e -> [[e]]
matrixLists mat = [[mat ! (i, j) | j <- [0 .. c - 1]] | i <- [0 .. r - 1]]
  where
    (r, c) = shape mat
-- | Build a matrix of the given shape from a list of elements in row-major
-- order. If the list is shorter than @rows * cols@ the remaining elements
-- are not written; surplus list elements are ignored. Calls 'error' for a
-- negative dimension.
listMatrix :: (BlasOps e, CMatrix mat e) =>
              Shape
           -> [e]
           -> mat e
listMatrix (r, c) l
  | r < 0 || c < 0 = error "Negative matrix shape??"
  | otherwise = createMatrix (r, c) $ setElems' (zip indices l)
  where
    indices = [(i, j) | i <- [0 .. r - 1], j <- [0 .. c - 1]]
-- | One string per matrix row; every element is rendered with 'show' and
-- followed by a single space.
prettyPrintMatrix :: (GMatrix mat e) => mat e -> [String]
prettyPrintMatrix m = [concatMap cell r | r <- matrixLists m]
  where
    cell a = show a ++ " "
-- | Print the lines produced by 'prettyPrintMatrix' to standard output.
prettyPrintMatrixIO :: (GMatrix mat e) => mat e -> IO ()
prettyPrintMatrixIO = mapM_ putStrLn . prettyPrintMatrix
-- | The underlying monad of 'MMM': 'IO' with the matrix being modified as
-- state.
type MMonad mat e = StateT (mat e) IO

-- | Matrix modification monad. The @s@ parameter is a phantom type.
newtype MMM s mat e a = MMM
  { unMMM :: MMonad mat e a
  } deriving (Monad, Applicative, Functor)
-- | Run a modification on a copy of the given matrix and return the
-- modified copy; the argument itself is not written to.
runMMM :: (BlasOps e, CMatrix mat e) => mat e -> MMM s mat e a -> IO (mat e)
runMMM mat m = do
  ret <- matrixAlloc (shape mat)
  unsafeMatrixCopy mat NoTrans ret
  execStateT (unMMM m) ret
-- | 'MMM' implements the generic modification interface by delegating to
-- the primed functions of this module.
instance (BlasOps e, CMatrix mat e) => IMM (MMM s mat e) IndexPair (mat e) e where
  setElem ij x = setElem' ij x
  setElems els = setElems' els
  fill x = fill' x
  getElem ij = getElem' ij
-- | Allocate a matrix of the given shape and run the modification on it.
-- The storage is not initialised before the modification runs.
createMatrix :: (BlasOps e, CMatrix mat e) =>
                Shape
             -> MMM s mat e a
             -> mat e
createMatrix s m = unsafePerformIO $ do
  fresh <- matrixAlloc s
  execStateT (unMMM m) fresh
-- | Scale a vector in place by @alpha@ using @scal@.
unsafeScale :: (BlasOps e, CVector vec e) => e -> vec e -> IO ()
unsafeScale alpha x =
  withCVector x $ \xp -> scal (vectorLength x) alpha xp (inc x)
-- | Synonym for 'unsafeVectorAdd'.
unsafeAccum :: (BlasOps e, CVector vec e) => e -> vec e -> vec e -> IO ()
unsafeAccum alpha x y = unsafeVectorAdd alpha x y
-- | Copy the matrix (optionally transposed), run the modification on the
-- copy, and return the copy. The argument is not written to.
modifyMatrix :: (BlasOps e, CMatrix mat e) => mat e -> Transpose -> MMM s mat e a -> mat e
modifyMatrix mat t m = unsafePerformIO $ do
  ret <- matrixAlloc (shapeTrans t (shape mat))
  unsafeMatrixCopy mat t ret
  execStateT (unMMM m) ret
-- | The matrix currently being modified (the state of the monad).
-- Not listed in the module's export list.
getMatrix :: (BlasOps e, CMatrix mat e) => MMM s mat e (mat e)
getMatrix = MMM get
-- | A 'RefVector' viewing row @i@ of the matrix being modified.
refRow :: CMatrix mat e => Index -> MMM s mat e (RefVector e)
refRow i = MMM $ do
  m <- get
  liftIO (withCMatrixRow m i return)
-- | A 'RefVector' viewing column @i@ of the matrix being modified.
refColumn :: CMatrix mat e => Index -> MMM s mat e (RefVector e)
refColumn i = MMM $ do
  m <- get
  liftIO (withCMatrixColumn m i return)
-- | Multiply row @i@ of the matrix being modified by @alpha@, in place.
scaleRow :: CMatrix mat e =>
            e
         -> Index
         -> MMM s mat e ()
scaleRow alpha i = MMM $ do
  m <- get
  liftIO (withCMatrixRow m i (unsafeScale alpha))
-- | Multiply column @i@ of the matrix being modified by @alpha@, in place.
scaleColumn :: CMatrix mat e => e -> Index -> MMM s mat e ()
scaleColumn alpha i = MMM $ do
  m <- get
  liftIO (withCMatrixColumn m i (unsafeScale alpha))
-- | Set a single element of the matrix being modified (bounds-checked by
-- 'unsafeMatrixSetElem').
setElem' :: (BlasOps e, CMatrix mat e) => IndexPair -> e -> MMM s mat e ()
setElem' ij a = MMM $ do
  m <- get
  liftIO (unsafeMatrixSetElem m ij a)
-- | Fill the matrix being modified with a single value.
fill' :: (BlasOps e, CMatrix mat e) => e -> MMM s mat e ()
fill' a = MMM $ do
  m <- get
  liftIO (unsafeMatrixFill m a)
-- | Write the given values along diagonal @d@ of the matrix being modified.
-- The index list comes from 'diagIndices'; values are paired with indices
-- until either list runs out. Nothing is written when the diagonal has no
-- indices.
setDiag :: (BlasOps e, CMatrix mat e) =>
           Index
        -> [e]
        -> MMM s mat e ()
setDiag d as = MMM $ do
  m <- get
  case diagIndices (shape m) d of
    [] -> return ()
    ijs -> liftIO (unsafeMatrixSetElems m (zip ijs as))
-- | Set several elements of the matrix being modified. The indices are
-- passed to 'unsafeMatrixSetElems', which does not bounds-check them.
setElems' :: (BlasOps e, CMatrix mat e) => [(IndexPair,e)] -> MMM s mat e ()
setElems' els = MMM $ do
  m <- get
  liftIO (unsafeMatrixSetElems m els)
-- | Overwrite row @i@ of the matrix being modified with the given values,
-- starting at column 0. Values beyond the row length are ignored; a short
-- list leaves the remaining elements untouched.
setRow :: (BlasOps e, CMatrix mat e) =>
          Index
       -> [e]
       -> MMM s mat e ()
setRow i as =
  fmap shape getMatrix >>= \(_, c) ->
    setElems $ zip (range ((i, 0), (i, c - 1))) as
-- | Overwrite column @i@ of the matrix being modified with the given values,
-- starting at row 0. Values beyond the column length are ignored; a short
-- list leaves the remaining elements untouched.
setColumn :: (BlasOps e, CMatrix mat e) =>
             Index
          -> [e]
          -> MMM s mat e ()
setColumn i as =
  fmap shape getMatrix >>= \(r, _) ->
    setElems $ zip (range ((0, i), (r - 1, i))) as
-- | Copy the given matrix into the matrix being modified so that its
-- element @(0,0)@ lands at position @(i,j)@. Elements that would fall
-- outside the target matrix are silently dropped.
setBlock :: (BlasOps e, CMatrix mat e) =>
            IndexPair
         -> mat e
         -> MMM s mat e ()
setBlock (i, j) mat = getMatrix >>= \m -> setElems (visible m)
  where
    (blockRows, blockCols) = shape mat
    -- Block indices in row-major order, matching 'matrixList RowMajor'.
    blockIdxs = range ((0, 0), (blockRows - 1, blockCols - 1))
    shifted = [(a + i, b + j) | (a, b) <- blockIdxs]
    es = matrixList RowMajor mat
    -- Keep only the assignments whose target index lies inside @m@.
    visible m = filter (\(ij, _) -> inRange ((0, 0), (r - 1, c - 1)) ij) (zip shifted es)
      where
        (r, c) = shape m
-- | Set every element in the rectangular index range from @start@ to @end@
-- (both inclusive) to the given value.
fillBlock :: (BlasOps e, CMatrix mat e) =>
             IndexPair
          -> IndexPair
          -> e
          -> MMM s mat e ()
fillBlock start end x = setElems (zip (range (start, end)) (repeat x))
-- | Read an element of the matrix being modified (bounds-checked by
-- 'matrixElem').
getElem' :: (BlasOps e, CMatrix mat e) => IndexPair -> MMM s mat e e
getElem' ij = MMM $ do
  m <- get
  liftIO (matrixElem m ij)