module Data.Matrix.Dense.IOBase
where
import Control.Monad
import Foreign
import System.IO.Unsafe
import BLAS.Internal( diagLen )
import Data.Elem.BLAS( Complex, Elem, BLAS1, conjugate )
import qualified Data.Elem.BLAS.Level1 as BLAS
import Data.Matrix.Class
import Data.Tensor.Class
import Data.Tensor.Class.MTensor
import Data.Vector.Dense.IOBase
-- | A dense matrix view over memory owned by a 'ForeignPtr'.
--
-- The type parameter @np@ is phantom (no field uses it); signatures in this
-- module instantiate it as a pair such as @(n,p)@, presumably a type-level
-- shape tag.  @e@ is the element type.
--
-- When the final flag is 'False', element @(i,j)@ is stored at offset
-- @i + j*lda@ from the base pointer (column-major).  When it is 'True', the
-- value is a conjugate-transpose view: element @(i,j)@ is stored at offset
-- @i*lda + j@, and elements are conjugated on read and on write.
data IOMatrix np e =
    IOMatrix !(ForeignPtr e) -- ^ owner of the storage; touched after raw accesses to keep it alive
             !(Ptr e)        -- ^ pointer to element (0,0) of the view
             !Int            -- ^ number of rows
             !Int            -- ^ number of columns
             !Int            -- ^ leading dimension (lda) of the underlying storage
             !Bool           -- ^ 'True' for a conjugate-transpose (herm) view
-- | View a region of memory as a dense, column-major matrix whose leading
-- dimension equals its row count.
--
-- The leading dimension is clamped to at least 1, matching 'newIOMatrix_'.
-- This gives a zero-row matrix a valid BLAS leading dimension, since BLAS
-- requires @lda >= max 1 m@.
matrixViewArray :: (Elem e)
                => ForeignPtr e
                -> Int          -- ^ offset, in elements, from the start of the storage
                -> (Int,Int)    -- ^ shape @(rows, columns)@
                -> IOMatrix (n,p) e
matrixViewArray f o (m,n) = matrixViewArrayWithLda (max 1 m) f o (m,n)
-- | View a region of memory as a dense, column-major matrix with an
-- explicit leading dimension.  The resulting view is not conjugated.
matrixViewArrayWithLda :: (Elem e)
                       => Int          -- ^ leading dimension
                       -> ForeignPtr e -- ^ underlying storage
                       -> Int          -- ^ offset, in elements
                       -> (Int,Int)    -- ^ shape @(rows, columns)@
                       -> IOMatrix (n,p) e
matrixViewArrayWithLda lda fptr offset (rows,cols) =
    IOMatrix fptr base rows cols lda False
  where
    base = advancePtr (unsafeForeignPtrToPtr fptr) offset
-- | Number of rows of the matrix view.
numRowsIOMatrix :: IOMatrix np e -> Int
numRowsIOMatrix a = case a of
    IOMatrix _ _ rows _ _ _ -> rows
-- | Number of columns of the matrix view.
numColsIOMatrix :: IOMatrix np e -> Int
numColsIOMatrix a = case a of
    IOMatrix _ _ _ cols _ _ -> cols
-- | Leading dimension of the storage underlying the matrix view.
ldaMatrixIOMatrix :: IOMatrix np e -> Int
ldaMatrixIOMatrix a = case a of
    IOMatrix _ _ _ _ lda _ -> lda
-- | Whether the view is a conjugate transpose of its storage.
isHermIOMatrix :: IOMatrix np e -> Bool
isHermIOMatrix a = case a of
    IOMatrix _ _ _ _ _ herm -> herm
-- | Conjugate transpose as a view over the same storage.  The row and
-- column counts are swapped and the conjugation flag is toggled; no
-- elements are copied.
hermIOMatrix :: IOMatrix np e -> IOMatrix nm e
hermIOMatrix (IOMatrix fptr ptr rows cols lda herm) =
    IOMatrix fptr ptr cols rows lda flipped
  where
    flipped = not herm
-- | View the @(m',n')@ submatrix whose top-left corner is element @(i,j)@.
-- The view shares storage and leading dimension with the parent.  No bounds
-- checking is performed.
unsafeSubmatrixViewIOMatrix :: (Elem e) =>
    IOMatrix np e -> (Int,Int) -> (Int,Int) -> IOMatrix np' e
unsafeSubmatrixViewIOMatrix (IOMatrix fptr ptr _ _ lda herm) (i,j) (rows,cols) =
    IOMatrix fptr (ptr `advancePtr` offset) rows cols lda herm
  where
    -- offset of element (i,j) in the parent's layout
    offset | herm      = j + i*lda
           | otherwise = i + j*lda
-- | View row @i@ as a vector of length equal to the column count.  For a
-- conjugate-transpose view the row is contiguous (stride 1); otherwise its
-- stride is the leading dimension.  No bounds checking is performed.
unsafeRowViewIOMatrix :: (Elem e) => IOMatrix np e -> Int -> IOVector p e
unsafeRowViewIOMatrix (IOMatrix fptr ptr _ cols lda herm) i
    | herm      = IOVector fptr (ptr `advancePtr` (i*lda)) cols 1   herm
    | otherwise = IOVector fptr (ptr `advancePtr` i)       cols lda herm
-- | View column @j@ as a vector of length equal to the row count.  For an
-- ordinary view the column is contiguous (stride 1); for a
-- conjugate-transpose view its stride is the leading dimension.  No bounds
-- checking is performed.
unsafeColViewIOMatrix :: (Elem e) => IOMatrix np e -> Int -> IOVector n e
unsafeColViewIOMatrix (IOMatrix fptr ptr rows _ lda herm) j
    | herm      = IOVector fptr (ptr `advancePtr` j)       rows lda herm
    | otherwise = IOVector fptr (ptr `advancePtr` (j*lda)) rows 1   herm
-- | View diagonal @i@ of the matrix as a vector.
--
-- @i = 0@ is the main diagonal.  For @i > 0@ the diagonal starts at element
-- @(0,i)@, above the main diagonal.  For @i < 0@ it starts at element
-- @(-i,0)@, below the main diagonal.  In both storage layouts, consecutive
-- diagonal entries are @lda + 1@ elements apart.  The length comes from
-- 'diagLen'.  No bounds checking is performed.
unsafeDiagViewIOMatrix :: (Elem e) => IOMatrix np e -> Int -> IOVector k e
unsafeDiagViewIOMatrix (IOMatrix f p m n l h) i =
    let -- first element of the diagonal; for a subdiagonal the row index
        -- is -i, which is positive
        (row,col) = if i >= 0 then (0,i) else (negate i, 0)
        -- storage offset of (row,col), using the same layout rule as
        -- unsafeReadElemIOMatrix
        o  = if h then row*l + col else row + col*l
        p' = p `advancePtr` o
        k  = diagLen (m,n) i
        s  = l+1
    in IOVector f p' k s h
-- | View a vector as a @1 x n@ matrix, when possible.
--
-- An unconjugated vector of any stride becomes an ordinary matrix whose
-- leading dimension is the stride.  A conjugated vector can only be viewed
-- (as a conjugate-transpose matrix) when its stride is 1; for any other
-- stride the result is 'Nothing'.
maybeViewVectorAsRowIOMatrix :: (Elem e) => IOVector p e -> Maybe (IOMatrix p1 e)
maybeViewVectorAsRowIOMatrix (IOVector fptr ptr len stride conj)
    | not conj    = Just (IOMatrix fptr ptr 1 len stride False)
    | stride == 1 = Just (IOMatrix fptr ptr 1 len (max 1 len) True)
    | otherwise   = Nothing
-- | View a vector as an @n x 1@ matrix, when possible.
--
-- A conjugated vector of any stride becomes a conjugate-transpose matrix
-- whose leading dimension is the stride.  An unconjugated vector can only
-- be viewed when its stride is 1; for any other stride the result is
-- 'Nothing'.
maybeViewVectorAsColIOMatrix :: (Elem e) => IOVector n e -> Maybe (IOMatrix n1 e)
maybeViewVectorAsColIOMatrix (IOVector fptr ptr len stride conj)
    | conj        = Just (IOMatrix fptr ptr len 1 stride True)
    | stride /= 1 = Nothing
    | otherwise   = Just (IOMatrix fptr ptr len 1 (max 1 len) False)
-- | View a matrix as one unit-stride vector of all its elements in
-- column-major order.  This is only possible for an unconjugated matrix
-- whose leading dimension equals its row count; otherwise the result is
-- 'Nothing'.
maybeViewIOMatrixAsVector :: (Elem e) => IOMatrix np e -> Maybe (IOVector k e)
maybeViewIOMatrixAsVector (IOMatrix fptr ptr rows cols lda herm)
    | herm || lda /= rows = Nothing
    | otherwise           = Just (IOVector fptr ptr (rows*cols) 1 False)
-- | View a vector of length @m*n@ as an @m x n@ column-major matrix.
--
-- Calls 'error' if the vector length is not @m*n@.  Returns 'Nothing' if
-- the vector is conjugated or does not have unit stride.  The leading
-- dimension is clamped to at least 1, matching 'newIOMatrix_', so a
-- zero-row matrix still has a valid BLAS leading dimension.
maybeViewVectorAsIOMatrix :: (Elem e) => (Int,Int) -> IOVector k e -> Maybe (IOMatrix np e)
maybeViewVectorAsIOMatrix (m,n) (IOVector f p k inc c)
    | m*n /= k =
        error $ "maybeViewVectorAsMatrix " ++ show (m,n)
              ++ " <vector of dim " ++ show k ++ ">: vector dimension"
              ++ " must equal product of specified dimensions"
    | c = Nothing
    | inc /= 1 = Nothing
    | otherwise = Just $ IOMatrix f p m n (max 1 m) False
-- | Apply a vector action to all elements of a matrix.
--
-- When the storage is contiguous (leading dimension equal to the stored
-- column length), the action runs once on a single vector covering the
-- whole matrix.  Otherwise it runs once per stored column, which is one per
-- logical row for a conjugate-transpose view.  Every vector carries the
-- matrix's conjugation flag.
--
-- NOTE(review): the per-column loop stops when the column pointer reaches
-- @p + storedCols*l@.  With a leading dimension of 0 in that branch, the
-- loop runs zero times.
liftIOMatrix :: (Elem e) => (forall n. IOVector n e -> IO ()) -> IOMatrix np e -> IO ()
liftIOMatrix g (IOMatrix f p m n l h)
    | l == storedRows = g (IOVector f p (m*n) 1 h)
    | otherwise       = loop p
  where
    -- shape of the underlying column-major storage
    (storedRows, storedCols) = if h then (n,m) else (m,n)
    stop = p `advancePtr` (storedCols*l)
    loop q
        | q == stop = return ()
        | otherwise = do
            g (IOVector f q storedRows 1 h)
            loop (q `advancePtr` l)
-- | Apply a binary vector action elementwise to two matrices of the same
-- shape.
--
-- If both matrices have the same conjugation flag and both can be viewed
-- as one contiguous vector, the action runs once on those two vectors.
-- Otherwise it runs once per logical column, or once per logical row when
-- @a@ is a conjugate-transpose view.  Slice @k@ of @a@ is always paired with
-- slice @k@ of @b@.
liftIOMatrix2 :: (Elem e, Elem f) =>
    (forall k. IOVector k e -> IOVector k f -> IO ()) ->
    IOMatrix np e -> IOMatrix np f -> IO ()
liftIOMatrix2 f a b =
    if isHermIOMatrix a == isHermIOMatrix b
        then case (maybeViewIOMatrixAsVector a, maybeViewIOMatrixAsVector b) of
                 (Just x, Just y) -> f x y
                 _                -> elementwise
        else elementwise
  where
    -- Both matrices are sliced the same way so that slice k of a lines up
    -- with slice k of b.  The slicing follows a's layout, so a's slices are
    -- the contiguous ones.
    elementwise
        | isHermIOMatrix a = zipWithM_ f (rowViews a) (rowViews b)
        | otherwise        = zipWithM_ f (colViews a) (colViews b)

    -- Valid indices are 0 .. count-1.  Including the count itself would
    -- produce a view past the end of the matrix.
    rowViews c = [ unsafeRowViewIOMatrix c i | i <- [ 0 .. numRows c - 1 ] ]
    colViews c = [ unsafeColViewIOMatrix c j | j <- [ 0 .. numCols c - 1 ] ]
-- | Run an action with a raw pointer to element (0,0) of the matrix.  The
-- storage is touched after the action, so it stays alive until the action
-- has finished.
withIOMatrix :: IOMatrix (n,p) e -> (Ptr e -> IO a) -> IO a
withIOMatrix (IOMatrix fptr ptr _ _ _ _) act =
    act ptr >>= \result -> touchForeignPtr fptr >> return result
-- | Allocate a new @m x n@ column-major matrix whose elements are not
-- initialised.  The leading dimension is @max 1 m@.  Calls 'fail' in 'IO'
-- if either dimension is negative.
newIOMatrix_ :: (Elem e) => (Int,Int) -> IO (IOMatrix np e)
newIOMatrix_ (m,n)
    | negativeShape =
        fail $ "Tried to create a matrix with shape `" ++ show (m,n) ++ "'"
    | otherwise = do
        fptr <- mallocForeignPtrArray (m*n)
        let base = unsafeForeignPtrToPtr fptr
        return (IOMatrix fptr base m n (max 1 m) False)
  where
    negativeShape = m < 0 || n < 0
-- | Allocate a fresh matrix holding a copy of the given one.
--
-- The copy has the same shape and conjugation flag.  Its storage is packed
-- with leading dimension @max 1 rows@, where @rows@ is the stored column
-- length.  If the source is already packed, it is copied with a single
-- 'BLAS.copy'; otherwise it is copied one stored column at a time.
newCopyIOMatrix :: (BLAS1 e) => IOMatrix np e -> IO (IOMatrix np e)
newCopyIOMatrix (IOMatrix f p m n l h) = do
    IOMatrix f' p' _ _ _ _ <- newIOMatrix_ (rows, cols)
    if l == rows
        then BLAS.copy (m*n) p 1 p' 1
        else copyColumns p p' 0
    -- keep both buffers alive until the raw-pointer copies are done
    touchForeignPtr f
    touchForeignPtr f'
    return (IOMatrix f' p' m n lda' h)
  where
    -- shape of the underlying column-major storage
    (rows, cols) = if h then (n,m) else (m,n)
    lda' = max 1 rows
    copyColumns src dst col
        | col == cols = return ()
        | otherwise = do
            BLAS.copy rows src 1 dst 1
            copyColumns (src `advancePtr` l) (dst `advancePtr` lda') (col+1)
-- | Shape of the view as @(rows, columns)@.
shapeIOMatrix :: IOMatrix np e -> (Int,Int)
shapeIOMatrix a = case a of
    IOMatrix _ _ rows cols _ _ -> (rows, cols)
-- | Inclusive index bounds of the matrix: @((0,0), (m-1,n-1))@ for an
-- @m x n@ matrix.  For an empty matrix the upper bound is below the lower
-- bound.
boundsIOMatrix :: IOMatrix np e -> ((Int,Int), (Int,Int))
boundsIOMatrix a = ((0,0), (m-1, n-1)) where (m,n) = shapeIOMatrix a
-- | Number of elements in the view (rows times columns).
sizeIOMatrix :: IOMatrix np e -> Int
sizeIOMatrix a = case a of
    IOMatrix _ _ rows cols _ _ -> rows * cols
-- | 'sizeIOMatrix' lifted into 'IO'.
getSizeIOMatrix :: IOMatrix np e -> IO Int
getSizeIOMatrix a = return (sizeIOMatrix a)
-- | Maximum number of elements.  For a dense matrix this is the same as
-- 'getSizeIOMatrix'.
getMaxSizeIOMatrix :: IOMatrix np e -> IO Int
getMaxSizeIOMatrix a = getSizeIOMatrix a
-- | All valid indices, in storage order.  A conjugate-transpose view
-- yields them row by row; an ordinary view yields them column by column.
-- This is the same order in which 'getElemsIOMatrix' returns elements.
indicesIOMatrix :: IOMatrix np e -> [(Int,Int)]
indicesIOMatrix (IOMatrix _ _ m n _ h)
    | h         = [ (i,j) | i <- rows, j <- cols ]
    | otherwise = [ (i,j) | j <- cols, i <- rows ]
  where
    rows = [ 0 .. m-1 ]
    cols = [ 0 .. n-1 ]
-- | 'indicesIOMatrix' lifted into 'IO'.
getIndicesIOMatrix :: IOMatrix np e -> IO [(Int,Int)]
getIndicesIOMatrix a = return (indicesIOMatrix a)
-- | Strict variant of 'getIndicesIOMatrix'.  The indices do not depend on
-- the matrix contents, so it simply delegates.
getIndicesIOMatrix' :: IOMatrix np e -> IO [(Int,Int)]
getIndicesIOMatrix' a = getIndicesIOMatrix a
-- | Lazily read all elements of the matrix.
--
-- Elements come back in the order given by 'indicesIOMatrix': column-major
-- for an ordinary matrix, row-major for a conjugate-transpose view.  A
-- conjugate-transpose view is read by reinterpreting its storage as the
-- unconjugated transposed matrix and conjugating each element.
--
-- In the per-column branch, each column read is deferred with
-- 'unsafeInterleaveIO'.  Those reads therefore happen when the list is
-- forced, and may observe writes made after this call.  Whether
-- 'getElemsIOVector' is lazy in the same way is not visible here — TODO
-- confirm.
getElemsIOMatrix :: (Elem e) => IOMatrix np e -> IO [e]
getElemsIOMatrix (IOMatrix f p m n l h)
    | h = liftM (map conjugate) $
              getElemsIOMatrix (IOMatrix f p n m l False)
      -- packed storage: read everything as one unit-stride vector
    | l == m = getElemsIOVector (IOVector f p (m*n) 1 False)
      -- otherwise read one column at a time, stepping the pointer by lda
    | otherwise =
        let end = p `advancePtr` (n*l)
            go p' | p' == end = return []
                  | otherwise = unsafeInterleaveIO $ do
                        c <- getElemsIOVector (IOVector f p' m 1 False)
                        cs <- go (p' `advancePtr` l)
                        return (c ++ cs)
        in go p
-- | Strictly read all elements of the matrix.  The order is column-major
-- for an ordinary matrix and row-major for a conjugate-transpose view.  A
-- conjugate-transpose view is read as its unconjugated transposed storage,
-- and each element is then conjugated.
getElemsIOMatrix' :: (Elem e) => IOMatrix np e -> IO [e]
getElemsIOMatrix' (IOMatrix fptr ptr rows cols lda herm)
    | herm = do
        es <- getElemsIOMatrix' (IOMatrix fptr ptr cols rows lda False)
        return (map conjugate es)
    | lda == rows = getElemsIOVector' (IOVector fptr ptr (rows*cols) 1 False)
    | otherwise = readFrom ptr
  where
    stop = ptr `advancePtr` (cols*lda)
    -- read one column, then advance by the leading dimension
    readFrom q
        | q == stop = return []
        | otherwise = do
            column <- getElemsIOVector' (IOVector fptr q rows 1 False)
            rest <- readFrom (q `advancePtr` lda)
            return (column ++ rest)
-- | Pair each index with its element, using the lazy 'getElemsIOMatrix'.
-- Indices and elements are produced in the same order.
getAssocsIOMatrix :: (Elem e) => IOMatrix np e -> IO [((Int,Int),e)]
getAssocsIOMatrix a =
    liftM2 zip (getIndicesIOMatrix a) (getElemsIOMatrix a)
-- | Pair each index with its element, using the strict 'getElemsIOMatrix''.
getAssocsIOMatrix' :: (Elem e) => IOMatrix np e -> IO [((Int,Int),e)]
getAssocsIOMatrix' a =
    liftM2 zip (getIndicesIOMatrix' a) (getElemsIOMatrix' a)
-- | Read element @(i,j)@, conjugating it for a conjugate-transpose view.
-- No bounds checking is performed.
unsafeReadElemIOMatrix :: (Elem e) => IOMatrix np e -> (Int,Int) -> IO e
unsafeReadElemIOMatrix (IOMatrix fptr ptr _ _ lda herm) (i,j) = do
    raw <- peekElemOff ptr offset
    touchForeignPtr fptr
    return (adjust raw)
  where
    offset | herm      = i*lda + j
           | otherwise = i + j*lda
    adjust | herm      = conjugate
           | otherwise = id
-- | Every element of a dense matrix is writable, so this always returns
-- 'True'.
canModifyElemIOMatrix :: IOMatrix np e -> (Int,Int) -> IO Bool
canModifyElemIOMatrix = \_ _ -> return True
-- | Write element @(i,j)@.  For a conjugate-transpose view the conjugate is
-- stored.  No bounds checking is performed.
unsafeWriteElemIOMatrix :: (Elem e) =>
    IOMatrix np e -> (Int,Int) -> e -> IO ()
unsafeWriteElemIOMatrix (IOMatrix fptr ptr _ _ lda herm) (i,j) e = do
    pokeElemOff ptr offset stored
    touchForeignPtr fptr
  where
    (offset, stored)
        | herm      = (i*lda + j, conjugate e)
        | otherwise = (i + j*lda, e)
-- | Apply a function to element @(i,j)@ in place.  For a
-- conjugate-transpose view the function sees and produces logical
-- (unconjugated) values.  No bounds checking is performed.
unsafeModifyElemIOMatrix :: (Elem e) =>
    IOMatrix n e -> (Int,Int) -> (e -> e) -> IO ()
unsafeModifyElemIOMatrix (IOMatrix fptr ptr _ _ lda herm) (i,j) g = do
    old <- peekElemOff ptr offset
    pokeElemOff ptr offset (update old)
    touchForeignPtr fptr
  where
    offset | herm      = i*lda + j
           | otherwise = i + j*lda
    update | herm      = conjugate . g . conjugate
           | otherwise = g
-- | Swap two elements in place.  Stored values are exchanged as-is, which
-- is correct for both layouts because both are stored the same way.  No
-- bounds checking is performed.
unsafeSwapElemsIOMatrix :: (Elem e) =>
    IOMatrix n e -> (Int,Int) -> (Int,Int) -> IO ()
unsafeSwapElemsIOMatrix (IOMatrix fptr ptr _ _ lda herm) (i1,j1) (i2,j2) = do
    x <- peekElemOff ptr o1
    y <- peekElemOff ptr o2
    pokeElemOff ptr o2 x
    pokeElemOff ptr o1 y
    touchForeignPtr fptr
  where
    offsetOf (i,j) | herm      = i*lda + j
                   | otherwise = i + j*lda
    o1 = offsetOf (i1,j1)
    o2 = offsetOf (i2,j2)
-- | Apply a function to every element in place, via 'modifyWithIOVector'
-- on the matrix's vector slices.
modifyWithIOMatrix :: (Elem e) => (e -> e) -> IOMatrix np e -> IO ()
modifyWithIOMatrix g a = liftIOMatrix (modifyWithIOVector g) a
-- | Set every element to zero, via 'setZeroIOVector' on each vector slice.
setZeroIOMatrix :: (Elem e) => IOMatrix np e -> IO ()
setZeroIOMatrix a = liftIOMatrix setZeroIOVector a
-- | Set every element to @k@, via 'setConstantIOVector' on each vector
-- slice.
setConstantIOMatrix :: (Elem e) => e -> IOMatrix np e -> IO ()
setConstantIOMatrix k a = liftIOMatrix (setConstantIOVector k) a
-- | Conjugate every element in place, via 'doConjIOVector' on each vector
-- slice.
doConjIOMatrix :: (BLAS1 e) => IOMatrix np e -> IO ()
doConjIOMatrix a = liftIOMatrix doConjIOVector a
-- | Scale every element by @k@ in place, via 'scaleByIOVector' on each
-- vector slice.
scaleByIOMatrix :: (BLAS1 e) => e -> IOMatrix np e -> IO ()
scaleByIOMatrix k a = liftIOMatrix (scaleByIOVector k) a
-- | Shift every element by @k@ in place, via 'shiftByIOVector' on each
-- vector slice.
shiftByIOMatrix :: (Elem e) => e -> IOMatrix np e -> IO ()
shiftByIOMatrix k a = liftIOMatrix (shiftByIOVector k) a
-- | Shape and index bounds of dense matrix views.
instance Shaped IOMatrix (Int,Int) where
    bounds = boundsIOMatrix
    shape  = shapeIOMatrix
-- | Read access to dense matrix views.  Primed methods use the strict
-- readers; unprimed ones use the lazy readers.
instance (Elem e) => ReadTensor IOMatrix (Int,Int) e IO where
    -- size and single-element access
    getSize        = getSizeIOMatrix
    unsafeReadElem = unsafeReadElemIOMatrix
    -- lazy bulk readers
    getIndices     = getIndicesIOMatrix
    getElems       = getElemsIOMatrix
    getAssocs      = getAssocsIOMatrix
    -- strict bulk readers
    getIndices'    = getIndicesIOMatrix'
    getElems'      = getElemsIOMatrix'
    getAssocs'     = getAssocsIOMatrix'
-- | Write access to dense matrix views.
instance (BLAS1 e) => WriteTensor IOMatrix (Int,Int) e IO where
    -- single-element operations
    canModifyElem    = canModifyElemIOMatrix
    unsafeWriteElem  = unsafeWriteElemIOMatrix
    unsafeModifyElem = unsafeModifyElemIOMatrix
    -- whole-matrix operations
    getMaxSize       = getMaxSizeIOMatrix
    setZero          = setZeroIOMatrix
    setConstant      = setConstantIOMatrix
    modifyWith       = modifyWithIOMatrix
    doConj           = doConjIOMatrix
    scaleBy          = scaleByIOMatrix
    shiftBy          = shiftByIOMatrix
-- | Declares 'IOVector' as the vector-view type of 'IOMatrix'.  This is the
-- type returned by the row, column and diagonal views in this module.
instance HasVectorView IOMatrix where
    type VectorView IOMatrix = IOVector
-- | The conjugate transpose of an 'IOMatrix' is the storage-sharing view
-- built by 'hermIOMatrix'.
instance MatrixShaped IOMatrix where
    herm = hermIOMatrix