module Data.Matrix.Dense.Class.Views (
submatrixView,
splitRowsAt,
splitColsAt,
unsafeSubmatrixView,
unsafeSplitRowsAt,
unsafeSplitColsAt,
rowViews,
colViews,
rowView,
colView,
diagView,
unsafeRowView,
unsafeColView,
unsafeDiagView,
getDiag,
unsafeGetDiag,
) where
import BLAS.Elem( BLAS1 )
import BLAS.Internal( checkedSubmatrix, checkedRow, checkedCol, checkedDiag )
import BLAS.Tensor( shape )
import BLAS.Matrix.Base( herm )
import Data.Matrix.Dense.Class.Internal
import Data.Vector.Dense.Class.Internal( WriteVector, newCopyVector )
import Foreign
-- | @submatrixView a ij mn@ is a view of the block of @a@ of size @mn@
-- whose top-left element is at index @ij@.
--
-- The request is handed to 'checkedSubmatrix' together with @shape a@
-- (presumably validating the index and size against that shape) before
-- delegating to 'unsafeSubmatrixView'.
submatrixView :: (BaseMatrix a x, Storable e) => a mn e -> (Int,Int) -> (Int,Int) -> a mn' e
submatrixView a ij mn = checkedSubmatrix dims (unsafeSubmatrixView a) ij mn
  where
    dims = shape a
-- | Unchecked version of 'submatrixView': a view of the @(m,n)@ block of
-- @a@ starting at @(i,j)@, with no validation of the arguments.
--
-- * If @a@ is flagged Hermitian ('isHermMatrix'), the block is taken from
--   @herm@ of @a@ with the index and size transposed to @(j,i)@ and @(n,m)@.
--   The result is then passed back through 'herm'.
--
-- * Otherwise the view reuses @a@'s storage ('arrayFromMatrix'). Its base
--   pointer is advanced by @indexOfMatrix a (i,j)@ elements, and it keeps
--   @a@'s leading dimension.  The trailing 'False' passed to
--   'matrixViewArray' presumably marks the view as not Hermitian; confirm.
unsafeSubmatrixView :: (BaseMatrix a x, Storable e) =>
    a mn e -> (Int,Int) -> (Int,Int) -> a mn' e
unsafeSubmatrixView a (i,j) (m,n)
    | isHermMatrix a = coerceMatrix . herm $ swappedView
    | otherwise      = matrixViewArray fp (p `advancePtr` indexOfMatrix a (i,j)) m n ld False
  where
    -- Sub-view of the conjugate-transposed matrix, with rows/cols swapped.
    swappedView     = unsafeSubmatrixView (herm (coerceMatrix a)) (j,i) (n,m)
    -- Underlying storage of @a@; only forced on the non-Hermitian path.
    (fp,p,_,_,ld,_) = arrayFromMatrix a
-- | @splitRowsAt m1 a@ splits @a@ into two row blocks, where @(m,n)@ is
-- @shape a@:
--
-- * the first @m1@ rows;
-- * the remaining @m - m1@ rows.
--
-- Both pieces are produced by 'submatrixView', so they are views onto @a@
-- and go through the checked path.
splitRowsAt :: (BaseMatrix a x, Storable e) =>
    Int -> a (m,n) e -> (a (m1,n) e, a (m2,n) e)
splitRowsAt m1 a = ( submatrixView a (0,0)  (m1,n)
                   , submatrixView a (m1,0) (m2,n)
                   )
  where
    (m,n) = shape a
    -- Number of rows in the second (lower) piece.
    m2    = m - m1
-- | Unchecked version of 'splitRowsAt'. @unsafeSplitRowsAt m1 a@ returns
-- views of:
--
-- * the first @m1@ rows of @a@;
-- * the remaining @m - m1@ rows, where @(m,n)@ is @shape a@.
--
-- It uses 'unsafeSubmatrixView', so @m1@ is not validated.
unsafeSplitRowsAt :: (BaseMatrix a x, Storable e) =>
    Int -> a (m,n) e -> (a (m1,n) e, a (m2,n) e)
unsafeSplitRowsAt m1 a = ( unsafeSubmatrixView a (0,0)  (m1,n)
                         , unsafeSubmatrixView a (m1,0) (m2,n)
                         )
  where
    (m,n) = shape a
    -- Number of rows in the second (lower) piece.
    m2    = m - m1
-- | @splitColsAt n1 a@ splits @a@ into two column blocks, where @(m,n)@ is
-- @shape a@:
--
-- * the first @n1@ columns;
-- * the remaining @n - n1@ columns.
--
-- Both pieces are produced by 'submatrixView', so they are views onto @a@
-- and go through the checked path.
splitColsAt :: (BaseMatrix a x, Storable e) =>
    Int -> a (m,n) e -> (a (m,n1) e, a (m,n2) e)
splitColsAt n1 a = ( submatrixView a (0,0)  (m,n1)
                   , submatrixView a (0,n1) (m,n2)
                   )
  where
    (m,n) = shape a
    -- Number of columns in the second (right-hand) piece.
    n2    = n - n1
-- | Unchecked version of 'splitColsAt'. @unsafeSplitColsAt n1 a@ returns
-- views of:
--
-- * the first @n1@ columns of @a@;
-- * the remaining @n - n1@ columns, where @(m,n)@ is @shape a@.
--
-- It uses 'unsafeSubmatrixView', so @n1@ is not validated.
unsafeSplitColsAt :: (BaseMatrix a x, Storable e) =>
    Int -> a (m,n) e -> (a (m,n1) e, a (m,n2) e)
unsafeSplitColsAt n1 a = ( unsafeSubmatrixView a (0,0)  (m,n1)
                         , unsafeSubmatrixView a (0,n1) (m,n2)
                         )
  where
    (m,n) = shape a
    -- Number of columns in the second (right-hand) piece.
    n2    = n - n1
-- | @diagView a k@ is a vector view of diagonal @k@ of @a@.  The diagonal
-- index is passed to 'checkedDiag' along with @shape a@ before
-- 'unsafeDiagView' is used.
diagView :: (BaseMatrix a x, Storable e) => a mn e -> Int -> x k e
diagView a k = checkedDiag dims (unsafeDiagView a) k
  where
    dims = shape a
-- | @rowView a i@ is a length-@n@ vector view of row @i@ of @a@.  The row
-- index is passed to 'checkedRow' along with @shape a@ before
-- 'unsafeRowView' is used.
rowView :: (BaseMatrix a x, Storable e) => a (m,n) e -> Int -> x n e
rowView a i = checkedRow dims (unsafeRowView a) i
  where
    dims = shape a
-- | @colView a j@ is a length-@m@ vector view of column @j@ of @a@.  The
-- column index is passed to 'checkedCol' along with @shape a@ before
-- 'unsafeColView' is used.
colView :: (BaseMatrix a x, Storable e) => a (m,n) e -> Int -> x m e
colView a j = checkedCol dims (unsafeColView a) j
  where
    dims = shape a
-- | @getDiag a k@ returns, in the monad @m@, a new vector copied from
-- diagonal @k@ of @a@.  The index is passed to 'checkedDiag' along with
-- @shape a@ before 'unsafeGetDiag' is used.
getDiag :: (ReadMatrix a x m, WriteVector y m, BLAS1 e) =>
    a mn e -> Int -> m (y k e)
getDiag a k = checkedDiag dims (unsafeGetDiag a) k
  where
    dims = shape a
-- | Unchecked version of 'getDiag'. It takes the diagonal view from
-- 'unsafeDiagView' and copies it into a new vector with 'newCopyVector'.
-- The diagonal index is not validated.
unsafeGetDiag :: (ReadMatrix a x m, WriteVector y m, BLAS1 e) =>
    a mn e -> Int -> m (y k e)
unsafeGetDiag a = newCopyVector . unsafeDiagView a