module Data.Matrix.Dense.IOBase
where
import Control.Monad
import Foreign
import System.IO.Unsafe
import BLAS.Internal( diagLen )
import Data.Elem.BLAS( Complex, Elem, BLAS1, conjugate )
import qualified Data.Elem.BLAS.Level1 as BLAS
import Data.Matrix.Class
import Data.Tensor.Class
import Data.Tensor.Class.MTensor
import Data.Vector.Dense.IOBase
-- | Dense matrix backed by a foreign array.  The @np@ index is a phantom
-- type parameter; it does not appear in any field.
--
-- With leading dimension @l@, element @(i,j)@ is stored at offset
-- @i + j*l@ from the data pointer in 'NoTrans' orientation.  In 'ConjTrans'
-- orientation it is stored, conjugated, at offset @i*l + j@.
data IOMatrix np e =
      IOMatrix !TransEnum      -- ^ orientation of the view ('NoTrans' or 'ConjTrans')
               !Int            -- ^ number of rows
               !Int            -- ^ number of columns
               !(ForeignPtr e) -- ^ owner of the storage; touched after raw
                               --   pointer accesses to keep it alive
               !(Ptr e)        -- ^ pointer to the storage of element @(0,0)@
               !Int            -- ^ leading dimension @l@ used in the offsets above
-- | View part of a foreign array as a column-major ('NoTrans') matrix.
-- The leading dimension equals the row count, clamped to at least 1.
matrixViewArray :: (Elem e)
                => ForeignPtr e -- ^ array holding the elements
                -> Int          -- ^ offset of element @(0,0)@ in the array
                -> (Int,Int)    -- ^ shape @(rows, columns)@
                -> IOMatrix (n,p) e
matrixViewArray fptr off shp@(rows,_) =
    matrixViewArrayWithLda lda fptr off shp
  where
    lda = max 1 rows
-- | Like 'matrixViewArray', but with an explicit leading dimension.  The view
-- is in 'NoTrans' orientation and starts @off@ elements into the array.
matrixViewArrayWithLda :: (Elem e)
                       => Int          -- ^ leading dimension
                       -> ForeignPtr e -- ^ array holding the elements
                       -> Int          -- ^ offset of element @(0,0)@ in the array
                       -> (Int,Int)    -- ^ shape @(rows, columns)@
                       -> IOMatrix (n,p) e
matrixViewArrayWithLda lda fptr off (rows,cols) =
    IOMatrix NoTrans rows cols fptr start lda
  where
    start = advancePtr (unsafeForeignPtrToPtr fptr) off
-- | Number of rows of the matrix.
numRowsIOMatrix :: IOMatrix np e -> Int
numRowsIOMatrix (IOMatrix _ rows _ _ _ _) = rows

-- | Number of columns of the matrix.
numColsIOMatrix :: IOMatrix np e -> Int
numColsIOMatrix (IOMatrix _ _ cols _ _ _) = cols

-- | Leading dimension of the underlying storage.
ldaMatrixIOMatrix :: IOMatrix np e -> Int
ldaMatrixIOMatrix (IOMatrix _ _ _ _ _ lda) = lda

-- | Orientation flag of the matrix view.
transEnumIOMatrix :: IOMatrix np e -> TransEnum
transEnumIOMatrix (IOMatrix trans _ _ _ _ _) = trans

-- | 'True' exactly when the view's orientation is 'ConjTrans'.
isHermIOMatrix :: IOMatrix np e -> Bool
isHermIOMatrix a = ConjTrans == transEnumIOMatrix a
-- | Conjugate transpose as a view.  It swaps the row and column counts and
-- applies 'flipTrans' to the orientation flag.  Storage pointer and leading
-- dimension are shared unchanged, so no elements are copied.
hermIOMatrix :: IOMatrix np e -> IOMatrix nm e
hermIOMatrix (IOMatrix trans rows cols fptr ptr lda) =
    IOMatrix (flipTrans trans) cols rows fptr ptr lda
-- | View the @(m',n')@ submatrix whose top-left corner is element @(i,j)@.
-- No bounds checking is done.  The view shares storage, orientation and
-- leading dimension with the input.
unsafeSubmatrixViewIOMatrix :: (Elem e) =>
    IOMatrix np e -> (Int,Int) -> (Int,Int) -> IOMatrix np' e
unsafeSubmatrixViewIOMatrix (IOMatrix trans _ _ fptr ptr lda) (i,j) (rows,cols) =
    IOMatrix trans rows cols fptr (ptr `advancePtr` offset) lda
  where
    -- Storage offset of the new top-left element.
    offset | trans == ConjTrans = i*lda + j
           | otherwise          = i + j*lda
-- | View row @i@ as a vector, without bounds checking.  In 'NoTrans'
-- orientation the row is strided by the leading dimension.  In a
-- 'ConjTrans' view it is contiguous and flagged 'Conj'.
unsafeRowViewIOMatrix :: (Elem e) => IOMatrix np e -> Int -> IOVector p e
unsafeRowViewIOMatrix (IOMatrix trans _ cols fptr ptr lda) i
    | trans == ConjTrans = IOVector Conj   cols fptr (ptr `advancePtr` (i*lda)) 1
    | otherwise          = IOVector NoConj cols fptr (ptr `advancePtr` i)       lda
-- | View column @j@ as a vector, without bounds checking.  In 'NoTrans'
-- orientation the column is contiguous.  In a 'ConjTrans' view it is
-- strided by the leading dimension and flagged 'Conj'.
unsafeColViewIOMatrix :: (Elem e) => IOMatrix np e -> Int -> IOVector n e
unsafeColViewIOMatrix (IOMatrix trans rows _ fptr ptr lda) j
    | trans == ConjTrans = IOVector Conj   rows fptr (ptr `advancePtr` j)       lda
    | otherwise          = IOVector NoConj rows fptr (ptr `advancePtr` (j*lda)) 1
-- | View diagonal @i@ of the matrix as a vector, without bounds checking.
--
-- * @i == 0@ is the main diagonal.
-- * @i > 0@ is a superdiagonal starting at element @(0,i)@.
-- * @i < 0@ is a subdiagonal starting at element @(-i,0)@.
--
-- Consecutive diagonal elements are @l+1@ apart in storage in either
-- orientation.  The length comes from 'diagLen'.
unsafeDiagViewIOMatrix :: (Elem e) => IOMatrix np e -> Int -> IOVector k e
unsafeDiagViewIOMatrix (IOMatrix h m n f p l) i =
    let herm = h == ConjTrans
        -- Storage offset of the diagonal's first element.  A subdiagonal
        -- starts at row -i, so the offset must use (negate i).  Using i
        -- directly would give a negative offset before p.
        o | i >= 0    = if herm then i else i*l
          | otherwise = if herm then negate i * l else negate i
        c  = if herm then Conj else NoConj
        p' = p `advancePtr` o
        k  = diagLen (m,n) i
        s  = l+1
    in IOVector c k f p' s
-- | View a vector as a 1-by-n matrix when its layout allows it.
--
-- * A non-conjugated vector with stride @s@ becomes a 'NoTrans' matrix with
--   leading dimension @s@.
-- * A conjugated vector is viewable only when contiguous; it becomes a
--   'ConjTrans' matrix.
--
-- Any other vector gives 'Nothing'.
maybeViewVectorAsRowIOMatrix :: (Elem e) => IOVector p e -> Maybe (IOMatrix p1 e)
maybeViewVectorAsRowIOMatrix (IOVector conj len fptr ptr stride)
    | conj == NoConj              = Just (IOMatrix NoTrans 1 len fptr ptr stride)
    | conj == Conj && stride == 1 = Just (IOMatrix ConjTrans 1 len fptr ptr (max 1 len))
    | otherwise                   = Nothing
-- | View a vector as an n-by-1 matrix when its layout allows it.
--
-- * A conjugated vector with stride @s@ becomes a 'ConjTrans' matrix with
--   leading dimension @s@.
-- * A non-conjugated vector is viewable only when contiguous; it becomes a
--   'NoTrans' matrix.
--
-- Any other vector gives 'Nothing'.
maybeViewVectorAsColIOMatrix :: (Elem e) => IOVector n e -> Maybe (IOMatrix n1 e)
maybeViewVectorAsColIOMatrix (IOVector conj len fptr ptr stride) =
    if conj == Conj
        then Just $ IOMatrix ConjTrans len 1 fptr ptr stride
        else if stride == 1
            then Just $ IOMatrix NoTrans len 1 fptr ptr (max 1 len)
            else Nothing
-- | View the whole matrix as one contiguous, non-conjugated vector in
-- column-major order.  Returns 'Nothing' for a 'ConjTrans' view and for any
-- matrix whose leading dimension differs from its row count.
maybeViewIOMatrixAsVector :: (Elem e) => IOMatrix np e -> Maybe (IOVector k e)
maybeViewIOMatrixAsVector (IOMatrix trans rows cols fptr ptr lda) =
    if trans == ConjTrans || lda /= rows
        then Nothing
        else Just (IOVector NoConj (rows*cols) fptr ptr 1)
-- | View a vector of dimension @m*n@ as an @m@-by-@n@ 'NoTrans' matrix in
-- column-major order.
--
-- Returns 'Nothing' when the vector is conjugated or not contiguous
-- (stride other than 1).  Calls 'error' if the vector dimension is not
-- @m*n@.
maybeViewVectorAsIOMatrix :: (Elem e) => (Int,Int) -> IOVector k e -> Maybe (IOMatrix np e)
maybeViewVectorAsIOMatrix (m,n) (IOVector c k f p inc)
    | m*n /= k =
        error $ "maybeViewVectorAsMatrix " ++ show (m,n)
             ++ " <vector of dim " ++ show k ++ ">: vector dimension"
             ++ " must equal product of specified dimensions"
    | c == Conj = Nothing
    | inc /= 1  = Nothing
      -- Clamp the leading dimension to at least 1, as every other matrix
      -- constructor in this module does.  BLAS routines require
      -- lda >= max(1,m), and a 0-by-n view would otherwise get lda = 0.
    | otherwise = Just $ IOMatrix NoTrans m n f p (max 1 m)
-- | Run a vector action over every element of a matrix.
--
-- If the storage is contiguous, the action runs once on a single vector
-- spanning all @m*n@ elements.  Otherwise it runs once per contiguous
-- storage run, each run being @l@ elements after the previous one.
-- Vectors taken from a 'ConjTrans' view are flagged 'Conj'.
liftIOMatrix :: (Elem e) => (forall n. IOVector n e -> IO ()) -> IOMatrix np e -> IO ()
liftIOMatrix g (IOMatrix h m n f p l)
    | contiguous = g (IOVector cflag (m*n) f p 1)
    | otherwise  = mapM_ (\q -> g (IOVector cflag runLen f q 1)) runStarts
  where
    herm = h == ConjTrans
    contiguous = (herm && l == n) || (h == NoTrans && l == m)
    -- A ConjTrans view of an m-by-n matrix is backed by m runs of n
    -- contiguous elements.  A NoTrans view has n runs of m elements.
    (cflag, runLen, runCount) = if herm then (Conj, n, m) else (NoConj, m, n)
    stop = p `advancePtr` (runCount * l)
    runStarts = takeWhile (/= stop) (iterate (`advancePtr` l) p)
-- | Run a binary vector action over corresponding elements of two matrices
-- of the same shape.
--
-- If both matrices have the same orientation and each can be viewed as a
-- single contiguous vector, the action runs once on those two vectors.
-- Otherwise it runs pairwise over row views (when @a@ is a 'ConjTrans'
-- view) or over column views.
liftIOMatrix2 :: (Elem e, Elem f) =>
    (forall k. IOVector k e -> IOVector k f -> IO ()) ->
        IOMatrix np e -> IOMatrix np f -> IO ()
liftIOMatrix2 f a b =
    if isHermIOMatrix a == isHermIOMatrix b
        then case (maybeViewIOMatrixAsVector a, maybeViewIOMatrixAsVector b) of
                 (Just x, Just y) -> f x y
                 _                -> elementwise
        else elementwise
  where
    -- Both matrices are split the same way (rows or columns), so the
    -- vectors handed to f have equal length and pair up index-for-index.
    elementwise
        | isHermIOMatrix a = zipWithM_ f (rowViews a) (rowViews b)
        | otherwise        = zipWithM_ f (colViews a) (colViews b)

    -- Ranges end at count - 1, so an empty dimension yields no views.
    rowViews c = [ unsafeRowViewIOMatrix c i | i <- [ 0..numRowsIOMatrix c - 1 ] ]
    colViews c = [ unsafeColViewIOMatrix c j | j <- [ 0..numColsIOMatrix c - 1 ] ]
-- | Run an action on the pointer to element @(0,0)@'s storage.  Afterwards,
-- touch the owning 'ForeignPtr' so the storage is kept alive until the
-- action has returned.
withIOMatrix :: IOMatrix (n,p) e -> (Ptr e -> IO a) -> IO a
withIOMatrix (IOMatrix _ _ _ fptr ptr _) act =
    act ptr >>= \result -> touchForeignPtr fptr >> return result
-- | Allocate an uninitialised @m@-by-@n@ matrix.  It has 'NoTrans'
-- orientation and leading dimension @max 1 m@.  Fails in 'IO' when either
-- dimension is negative.
newIOMatrix_ :: (Elem e) => (Int,Int) -> IO (IOMatrix np e)
newIOMatrix_ (m,n) =
    if m < 0 || n < 0
        then fail $
            "Tried to create a matrix with shape `" ++ show (m,n) ++ "'"
        else do
            fptr <- mallocForeignPtrArray (m*n)
            let ptr = unsafeForeignPtrToPtr fptr
            return (IOMatrix NoTrans m n fptr ptr (max 1 m))
-- | Allocate a new matrix with the same shape, orientation and contents as
-- the argument.  The copy's storage is packed: its leading dimension is the
-- storage run length, clamped to at least 1.
newCopyIOMatrix :: (BLAS1 e) => IOMatrix np e -> IO (IOMatrix np e)
newCopyIOMatrix (IOMatrix h m n f p l) = do
    (IOMatrix _ _ _ f' p' _) <- newIOMatrix_ (rows, cols)
    if l == rows
        then BLAS.copy (m*n) p 1 p' 1
        else copyRuns p p' 0
    touchForeignPtr f
    touchForeignPtr f'
    return (IOMatrix h m n f' p' l')
  where
    -- Dimensions of the underlying column-major storage.
    (rows, cols) = if h == ConjTrans then (n,m) else (m,n)
    l' = max 1 rows
    -- Copy one storage column at a time when the source is not packed.
    copyRuns src dst i
        | i == cols = return ()
        | otherwise = do
            BLAS.copy rows src 1 dst 1
            copyRuns (src `advancePtr` l) (dst `advancePtr` l') (i+1)
-- | Shape of the matrix as @(rows, columns)@.
shapeIOMatrix :: IOMatrix np e -> (Int,Int)
shapeIOMatrix a@IOMatrix{} = (numRowsIOMatrix a, numColsIOMatrix a)
-- | Inclusive index bounds of the matrix: @((0,0), (m-1,n-1))@ for an
-- @m@-by-@n@ matrix.
boundsIOMatrix :: IOMatrix np e -> ((Int,Int), (Int,Int))
boundsIOMatrix a = ((0,0), (m-1,n-1)) where (m,n) = shapeIOMatrix a
-- | Number of elements: rows times columns.
sizeIOMatrix :: IOMatrix np e -> Int
sizeIOMatrix = uncurry (*) . shapeIOMatrix

-- | 'sizeIOMatrix' returned in 'IO'.
getSizeIOMatrix :: IOMatrix np e -> IO Int
getSizeIOMatrix a = return (sizeIOMatrix a)

-- | Maximum number of storable elements.  For a dense matrix this is the
-- same as 'getSizeIOMatrix'.
getMaxSizeIOMatrix :: IOMatrix np e -> IO Int
getMaxSizeIOMatrix a = getSizeIOMatrix a
-- | All valid indices of the matrix, in storage order.  That order is
-- row-major for a 'ConjTrans' view and column-major otherwise, and it is
-- the same order in which 'getElemsIOMatrix' returns elements.
indicesIOMatrix :: IOMatrix np e -> [(Int,Int)]
indicesIOMatrix (IOMatrix h m n _ _ _)
    | h == ConjTrans = [ (i,j) | i <- [ 0..m-1 ], j <- [ 0..n-1 ] ]
    | otherwise      = [ (i,j) | j <- [ 0..n-1 ], i <- [ 0..m-1 ] ]
-- | 'indicesIOMatrix' returned in 'IO'.
getIndicesIOMatrix :: IOMatrix np e -> IO [(Int,Int)]
getIndicesIOMatrix a = return (indicesIOMatrix a)

-- | Primed variant of 'getIndicesIOMatrix'.  The indices depend only on
-- orientation and shape, so this is the same action.
getIndicesIOMatrix' :: IOMatrix np e -> IO [(Int,Int)]
getIndicesIOMatrix' a = getIndicesIOMatrix a
-- | All elements of the matrix, in storage order.
--
-- A 'ConjTrans' view is read as its 'NoTrans' storage and then conjugated.
-- Packed storage is read as a single vector.  Otherwise each storage column
-- is read inside 'unsafeInterleaveIO`, so later columns are read only when
-- the list is demanded.
getElemsIOMatrix :: (Elem e) => IOMatrix np e -> IO [e]
getElemsIOMatrix (IOMatrix h m n f p l)
    | h == ConjTrans =
        fmap (map conjugate) (getElemsIOMatrix (IOMatrix NoTrans n m f p l))
    | l == m =
        getElemsIOVector (IOVector NoConj (m*n) f p 1)
    | otherwise =
        lazyCols p
  where
    stop = p `advancePtr` (n*l)
    lazyCols q
        | q == stop = return []
        | otherwise = unsafeInterleaveIO $
            liftM2 (++) (getElemsIOVector (IOVector NoConj m f q 1))
                        (lazyCols (q `advancePtr` l))
-- | All elements of the matrix, in storage order.  Unlike 'getElemsIOMatrix',
-- every storage column is read before the action returns (no
-- 'unsafeInterleaveIO').
getElemsIOMatrix' :: (Elem e) => IOMatrix np e -> IO [e]
getElemsIOMatrix' (IOMatrix h m n f p l)
    | h == ConjTrans =
        fmap (map conjugate) (getElemsIOMatrix' (IOMatrix NoTrans n m f p l))
    | l == m =
        getElemsIOVector' (IOVector NoConj (m*n) f p 1)
    | otherwise =
        liftM concat (mapM readCol colStarts)
  where
    readCol q = getElemsIOVector' (IOVector NoConj m f q 1)
    stop = p `advancePtr` (n*l)
    colStarts = takeWhile (/= stop) (iterate (`advancePtr` l) p)
-- | All @(index, element)@ pairs, built by zipping 'getIndicesIOMatrix'
-- with 'getElemsIOMatrix'.
getAssocsIOMatrix :: (Elem e) => IOMatrix np e -> IO [((Int,Int),e)]
getAssocsIOMatrix a = liftM2 zip (getIndicesIOMatrix a) (getElemsIOMatrix a)

-- | All @(index, element)@ pairs, built by zipping 'getIndicesIOMatrix''
-- with 'getElemsIOMatrix''.
getAssocsIOMatrix' :: (Elem e) => IOMatrix np e -> IO [((Int,Int),e)]
getAssocsIOMatrix' a = liftM2 zip (getIndicesIOMatrix' a) (getElemsIOMatrix' a)
-- | Read element @(i,j)@ without bounds checking.  A 'ConjTrans' view
-- stores conjugated values, so the value read is conjugated back.
unsafeReadElemIOMatrix :: (Elem e) => IOMatrix np e -> (Int,Int) -> IO e
unsafeReadElemIOMatrix (IOMatrix h _ _ f p l) (i,j) = do
    x <- if h == ConjTrans
             then fmap conjugate (peekElemOff p (i*l+j))
             else peekElemOff p (i+j*l)
    touchForeignPtr f
    return x
-- | Every element of a dense matrix is writable, so this always returns
-- 'True'.
canModifyElemIOMatrix :: IOMatrix np e -> (Int,Int) -> IO Bool
canModifyElemIOMatrix _ = const (return True)
-- | Write element @(i,j)@ without bounds checking.  A 'ConjTrans' view
-- stores the conjugate of the given value.
unsafeWriteElemIOMatrix :: (Elem e) =>
    IOMatrix np e -> (Int,Int) -> e -> IO ()
unsafeWriteElemIOMatrix (IOMatrix h _ _ f p l) (i,j) e = do
    if h == ConjTrans
        then pokeElemOff p (i*l+j) (conjugate e)
        else pokeElemOff p (i+j*l) e
    touchForeignPtr f
-- | Apply a function to element @(i,j)@ in place, without bounds checking.
unsafeModifyElemIOMatrix :: (Elem e) =>
    IOMatrix n e -> (Int,Int) -> (e -> e) -> IO ()
unsafeModifyElemIOMatrix (IOMatrix h _ _ f p l) (i,j) g = do
    old <- peek ptr
    poke ptr (g' old)
    touchForeignPtr f
  where
    herm = h == ConjTrans
    ptr  = p `advancePtr` (if herm then i*l+j else i+j*l)
    -- A ConjTrans view stores conjugated values.  Undo the conjugation
    -- before applying g and redo it afterwards.
    g' | herm      = conjugate . g . conjugate
       | otherwise = g
-- | Swap two elements in place, without bounds checking.  The raw stored
-- values are exchanged; both positions use the same orientation, so no
-- conjugation is needed.
unsafeSwapElemsIOMatrix :: (Elem e) =>
    IOMatrix n e -> (Int,Int) -> (Int,Int) -> IO ()
unsafeSwapElemsIOMatrix (IOMatrix h _ _ f p l) (i1,j1) (i2,j2) = do
    x1 <- peek ptr1
    x2 <- peek ptr2
    poke ptr2 x1
    poke ptr1 x2
    touchForeignPtr f
  where
    offset (i,j) | h == ConjTrans = i*l + j
                 | otherwise      = i + j*l
    ptr1 = p `advancePtr` offset (i1,j1)
    ptr2 = p `advancePtr` offset (i2,j2)
-- The element-wise operations below apply the corresponding vector
-- operation to the whole matrix through 'liftIOMatrix'.

-- | Apply 'modifyWithIOVector' with the given function.
modifyWithIOMatrix :: (Elem e) => (e -> e) -> IOMatrix np e -> IO ()
modifyWithIOMatrix g a = liftIOMatrix (modifyWithIOVector g) a

-- | Apply 'setZeroIOVector'.
setZeroIOMatrix :: (Elem e) => IOMatrix np e -> IO ()
setZeroIOMatrix a = liftIOMatrix setZeroIOVector a

-- | Apply 'setConstantIOVector' with the given constant.
setConstantIOMatrix :: (Elem e) => e -> IOMatrix np e -> IO ()
setConstantIOMatrix k a = liftIOMatrix (setConstantIOVector k) a

-- | Apply 'doConjIOVector'.
doConjIOMatrix :: (BLAS1 e) => IOMatrix np e -> IO ()
doConjIOMatrix a = liftIOMatrix doConjIOVector a

-- | Apply 'scaleByIOVector' with the given factor.
scaleByIOMatrix :: (BLAS1 e) => e -> IOMatrix np e -> IO ()
scaleByIOMatrix k a = liftIOMatrix (scaleByIOVector k) a

-- | Apply 'shiftByIOVector' with the given shift.
shiftByIOMatrix :: (Elem e) => e -> IOMatrix np e -> IO ()
shiftByIOMatrix k a = liftIOMatrix (shiftByIOVector k) a
-- | Shape queries for dense matrices.
instance Shaped IOMatrix (Int,Int) where
    bounds = boundsIOMatrix
    shape  = shapeIOMatrix

-- | Element reads.  Each unprimed/primed method maps to the matching
-- unprimed/primed implementation above.
instance (Elem e) => ReadTensor IOMatrix (Int,Int) e IO where
    getAssocs      = getAssocsIOMatrix
    getAssocs'     = getAssocsIOMatrix'
    getElems       = getElemsIOMatrix
    getElems'      = getElemsIOMatrix'
    getIndices     = getIndicesIOMatrix
    getIndices'    = getIndicesIOMatrix'
    getSize        = getSizeIOMatrix
    unsafeReadElem = unsafeReadElemIOMatrix

-- | Element writes and whole-matrix updates.
instance (BLAS1 e) => WriteTensor IOMatrix (Int,Int) e IO where
    canModifyElem    = canModifyElemIOMatrix
    doConj           = doConjIOMatrix
    getMaxSize       = getMaxSizeIOMatrix
    modifyWith       = modifyWithIOMatrix
    scaleBy          = scaleByIOMatrix
    setConstant      = setConstantIOMatrix
    setZero          = setZeroIOMatrix
    shiftBy          = shiftByIOMatrix
    unsafeModifyElem = unsafeModifyElemIOMatrix
    unsafeWriteElem  = unsafeWriteElemIOMatrix

-- | Rows, columns and diagonals of an 'IOMatrix' are viewed as 'IOVector's.
instance HasVectorView IOMatrix where
    type VectorView IOMatrix = IOVector
instance MatrixShaped IOMatrix where
herm = hermIOMatrix