module Data.Matrix.Banded.Internal (
BMatrix(..),
Banded,
IOBanded,
module BLAS.Matrix.Base,
module BLAS.Tensor,
toForeignPtr,
fromForeignPtr,
ldaOf,
isHerm,
toMatrix,
fromMatrix,
bandwidth,
numLower,
numUpper,
banded,
listsBanded,
newBanded_,
newBanded,
newListsBanded,
row,
col,
getRow,
getCol,
diag,
rowView,
colView,
coerceMatrix,
unsafeBanded,
unsafeNewBanded,
unsafeFreeze,
unsafeThaw,
unsafeWithElemPtr,
unsafeDiag,
unsafeGetRow,
unsafeGetCol,
unsafeRow,
unsafeCol,
unsafeRowView,
unsafeColView,
) where
import Control.Arrow ( second )
import Control.Monad ( zipWithM_ )
import Data.Ix ( inRange, range )
import Data.List ( foldl' )
import Foreign
import System.IO.Unsafe
import Unsafe.Coerce
import BLAS.Access
import BLAS.Elem ( Elem, BLAS1 )
import qualified BLAS.Elem as E
import BLAS.Internal ( checkedRow, checkedCol, checkedDiag, diagStart,
diagLen, clearArray, inlinePerformIO )
import BLAS.Matrix.Base hiding ( Matrix )
import qualified BLAS.Matrix.Base as C
import BLAS.Tensor
import Data.Matrix.Dense.Internal ( DMatrix )
import qualified Data.Matrix.Dense.Internal as M
import Data.Vector.Dense.Internal ( DVector, Vector, conj, dim, newListVector )
import qualified Data.Vector.Dense.Internal as V
-- | A banded matrix backed by a 'ForeignPtr'.
--
-- Entries inside the band are kept in column-major band storage: entry
-- @(i,j)@ of a 'BM' value lives at element offset
-- @offset + upBW + (i - j) + j * lda@ of the buffer, i.e. in storage row
-- @upBW + i - j@ of storage column @j@.  Storage slots that do not
-- correspond to an entry of the matrix are padding.
--
-- The parameter @t@ is a phantom tag selecting immutable ('Imm') or
-- mutable ('Mut') access, and @mn@ is a phantom shape tag; neither one
-- appears in the run-time representation.
data BMatrix t mn e
    = BM { fptr :: !(ForeignPtr e)   -- ^ the element buffer
         , offset :: !Int            -- ^ element offset of the band store within 'fptr'
         , size1 :: !Int             -- ^ number of rows (m)
         , size2 :: !Int             -- ^ number of columns (n)
         , lowBW :: !Int             -- ^ number of stored sub-diagonals (kl)
         , upBW :: !Int              -- ^ number of stored super-diagonals (ku)
         , lda :: !Int               -- ^ leading dimension: elements between consecutive storage columns
         }
    | H !(BMatrix t mn e)
      -- ^ conjugate-transpose (Hermitian) view of the wrapped matrix;
      -- it shares the wrapped matrix's storage

-- | Immutable banded matrices.
type Banded = BMatrix Imm

-- | Mutable banded matrices, modified in 'IO'.
type IOBanded = BMatrix Mut
-- | Wrap existing storage as a banded matrix.  Nothing is copied and no
-- argument is validated.
--
-- Arguments, in order: the buffer, the element offset into it, the
-- dimensions @(m,n)@, the bandwidths @(kl,ku)@ and the leading dimension
-- of the band store.
fromForeignPtr :: ForeignPtr e -> Int -> (Int,Int) -> (Int,Int) -> Int
               -> BMatrix t (m,n) e
fromForeignPtr buf off (rows,cols) (lower,upper) ld =
    BM { fptr   = buf
       , offset = off
       , size1  = rows
       , size2  = cols
       , lowBW  = lower
       , upBW   = upper
       , lda    = ld
       }
-- | Expose the underlying storage as
-- @(buffer, offset, (m,n), (kl,ku), lda)@.
--
-- Hermitian views are unwrapped first.  For an 'H' value the dimensions
-- and bandwidths are therefore those of the stored matrix, not of the view.
toForeignPtr :: BMatrix t (m,n) e -> (ForeignPtr e, Int, (Int,Int), (Int,Int), Int)
toForeignPtr a = case a of
    H inner -> toForeignPtr inner
    BM { fptr = buf, offset = off, size1 = rows, size2 = cols
       , lowBW = lower, upBW = upper, lda = ld } ->
        (buf, off, (rows,cols), (lower,upper), ld)
-- | Leading dimension of the band store, looking through Hermitian views.
ldaOf :: BMatrix t (m,n) e -> Int
ldaOf a = case a of
    H inner         -> ldaOf inner
    BM { lda = ld } -> ld
-- | Whether the value denotes the Hermitian of its storage.  Each 'H'
-- layer toggles the answer.
isHerm :: BMatrix t (m,n) e -> Bool
isHerm a = case a of
    H inner -> not (isHerm inner)
    BM {}   -> False
-- | Reinterpret a matrix as immutable without copying.
--
-- The access tag @t@ is a phantom: it does not occur in any field of
-- 'BMatrix', so the coercion leaves the run-time value unchanged.  The
-- result shares storage with the argument.  Writes made later through a
-- mutable alias are therefore visible through the frozen value.
unsafeFreeze :: BMatrix t mn e -> Banded mn e
unsafeFreeze a = unsafeCoerce a
-- | Reinterpret a matrix as mutable without copying.
--
-- Only the phantom access tag changes, so the run-time value is untouched.
-- Writes through the result alter the storage seen by every alias of the
-- argument, including immutable ones.
unsafeThaw :: BMatrix t mn e -> IOBanded mn e
unsafeThaw a = unsafeCoerce a
-- | Change the phantom shape tag of a matrix.
--
-- The run-time value, including its stored dimensions, is unchanged.
-- Only the type-level tag is replaced.
coerceMatrix :: BMatrix t mn e -> BMatrix t kl e
coerceMatrix a = unsafeCoerce a
-- | View the band store of a matrix as a dense matrix.  No data are copied.
--
-- Returns @(store, (m,n), (kl,ku))@.  For a plain matrix, @store@ is the
-- @(kl+1+ku)@-by-@n@ dense matrix over the band buffer, and the other two
-- components are the logical dimensions and bandwidths.  For a Hermitian
-- view, the dense view is the Hermitian of the wrapped matrix's dense
-- view, with dimensions and bandwidths swapped.
toMatrix :: (Elem e) => BMatrix t (m,n) e -> (DMatrix t (m',n') e, (Int,Int), (Int,Int))
toMatrix (H a) =
    -- Recurse on the wrapped matrix itself.  Recursing on @herm a@ would
    -- rebuild @H a@ when @a@ is a 'BM' and never terminate.
    case toMatrix a of
        (b, (m,n), (kl,ku)) -> (herm b, (n,m), (ku,kl))
toMatrix (BM f o m n kl ku ld) =
    (M.fromForeignPtr f o (kl+1+ku,n) ld, (m,n), (kl,ku))
-- | Reinterpret a dense band-storage matrix as a banded matrix, without
-- copying.
--
-- @fromMatrix store (m,n) (kl,ku)@ requires @store@ to have exactly
-- @kl+1+ku@ rows (one per diagonal) and @n@ columns; otherwise it calls
-- 'error'.  @m@ is not validated.  A dense Hermitian view is handled by
-- building the banded matrix over the wrapped dense matrix, with
-- dimensions and bandwidths swapped, and then taking its Hermitian.
fromMatrix :: (Elem e) => DMatrix t (m,n) e -> (Int,Int) -> (Int,Int) -> BMatrix t (m',n') e
fromMatrix a (m,n) (kl,ku) = case a of
    (M.H a') -> herm (fromMatrix a' (n,m) (ku,kl))
    _        -> fromStore (M.toForeignPtr a)
  where
    -- Validate the dense store's shape, then wrap its buffer directly.
    fromStore (f, o, (m',n'), ld)
        | m' /= kl+1+ku =
            error "fromMatrix: number of rows must be equal to number of diagonals"
        | n' /= n =
            error "fromMatrix: numbers of columns must be equal"
        | otherwise =
            BM f o m n kl ku ld
-- | Band extent as @(-kl, ku)@: the index of the lowest stored diagonal
-- and the index of the highest one.  Because the first component is
-- negated, this pair is not a valid @(kl,ku)@ argument for the
-- constructors in this module.
bandwidth :: BMatrix t (m,n) e -> (Int,Int)
bandwidth a = (negate (numLower a), numUpper a)
-- | Number of sub-diagonals (kl).  For a Hermitian view this is the
-- wrapped matrix's number of super-diagonals.
numLower :: BMatrix t (m,n) e -> Int
numLower a = case a of
    H inner           -> numUpper inner
    BM { lowBW = kl } -> kl
-- | Number of super-diagonals (ku).  For a Hermitian view this is the
-- wrapped matrix's number of sub-diagonals.
numUpper :: BMatrix t (m,n) e -> Int
numUpper a = case a of
    H inner          -> numLower inner
    BM { upBW = ku } -> ku
-- | Allocate an @m@-by-@n@ banded matrix with bandwidths @(kl,ku)@.  The
-- storage is obtained with 'mallocForeignPtrArray' and is NOT initialised.
--
-- The band store is packed: @kl+1+ku@ rows (one per diagonal) by @n@
-- columns, at offset 0, with leading dimension @kl+1+ku@.
--
-- Throws an 'IOError' ('userError') for negative dimensions or
-- bandwidths, for @kl >= m@ when @m /= 0@, and for @ku >= n@ when
-- @n /= 0@.
newBanded_ :: (Elem e) => (Int,Int) -> (Int,Int) -> IO (BMatrix t (m,n) e)
newBanded_ (m,n) (kl,ku)
    | m < 0 || n < 0    = failWith "dimensions must be non-negative."
    | kl < 0            = failWith "lower bandwdth must be non-negative."
    | m /= 0 && kl >= m = failWith "lower bandwidth must be less than m."
    | ku < 0            = failWith "upper bandwidth must be non-negative."
    | n /= 0 && ku >= n = failWith "upper bandwidth must be less than n."
    | otherwise         = do
        storage <- mallocForeignPtrArray (numDiags * n)
        return (fromForeignPtr storage 0 (m,n) (kl,ku) numDiags)
  where
    -- One storage row per diagonal; the store is packed, so its leading
    -- dimension equals the number of diagonals.
    numDiags = kl + 1 + ku
    failWith msg = ioError (userError (
        "newBanded_ " ++ show (m,n) ++ " " ++ show (kl,ku) ++ ": " ++ msg))
-- | Build an immutable banded matrix of the given shape and bandwidths
-- from @((i,j), value)@ pairs.  Every other entry is zero.  Indices are
-- written with the checked 'writeElem'.
--
-- 'unsafePerformIO' is acceptable here because the matrix is freshly
-- allocated by 'newBanded' and no mutable alias escapes.
banded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> [((Int,Int), e)] -> Banded (m,n) e
banded shp bw pairs = unsafePerformIO (newBanded shp bw pairs)
-- | Like 'banded', but entries are written with 'unsafeWriteElem', so
-- indices are not checked.  The matrix is freshly allocated inside, which
-- is why 'unsafePerformIO' is used.
unsafeBanded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> [((Int,Int), e)] -> Banded (m,n) e
unsafeBanded shp bw pairs = unsafePerformIO (unsafeNewBanded shp bw pairs)
-- | Allocate a zero-filled banded matrix and write the given
-- @((i,j), value)@ pairs into it with the checked 'writeElem'.
newBanded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> [((Int,Int), e)] -> IO (BMatrix t (m,n) e)
newBanded shp bw pairs = newBandedHelp writeElem shp bw pairs
-- | Allocate a zero-filled banded matrix and write the given
-- @((i,j), value)@ pairs into it with the unchecked 'unsafeWriteElem'.
unsafeNewBanded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> [((Int,Int), e)] -> IO (BMatrix t (m,n) e)
unsafeNewBanded shp bw pairs = newBandedHelp unsafeWriteElem shp bw pairs
-- | Shared worker for 'newBanded' and 'unsafeNewBanded'.
--
-- Allocates the matrix with 'newBanded_', zeroes the whole band store
-- (@(kl+1+ku)*n@ elements), then applies the supplied element writer to
-- each @((i,j), value)@ pair in list order.
newBandedHelp :: (BLAS1 e) =>
       (IOBanded (m,n) e -> (Int,Int) -> e -> IO ())
    -> (Int,Int) -> (Int,Int) -> [((Int,Int),e)] -> IO (BMatrix t (m,n) e)
newBandedHelp set (m,n) (kl,ku) ijes = do
    result <- newBanded_ (m,n) (kl,ku)
    let storeLen = (kl + 1 + ku) * n
        writable = unsafeThaw result
    withForeignPtr (fptr result) $ \p -> clearArray p storeLen
    mapM_ (uncurry (set writable)) ijes
    return result
-- | Pure front end to 'newListsBanded'.  The result is freshly allocated
-- inside, so 'unsafePerformIO' does not expose shared mutable state.
listsBanded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> [[e]] -> Banded (m,n) e
listsBanded shp bw diags = unsafePerformIO (newListsBanded shp bw diags)
-- | Build a banded matrix from one list per diagonal.
--
-- The lists are paired, in order, with diagonals @-kl, ..., ku@.  For
-- diagonal @i@, the first @max 0 (-i)@ elements of its list are dropped.
-- The rest are written to the diagonal from its start; surplus elements
-- are ignored.  (Presumably each list is indexed by row, so a subdiagonal
-- list begins with @-i@ unused slots — TODO confirm against callers.)
-- Extra lists beyond @kl+1+ku@ are ignored.  Entries not supplied by any
-- list are zero.
newListsBanded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> [[e]] -> IO (BMatrix t (m,n) e)
newListsBanded (m,n) (kl,ku) xs = do
    a <- newBanded_ (m,n) (kl,ku)
    -- newBanded_ leaves the store uninitialised.  Zero it, as
    -- newBandedHelp does, so that short lists and padding slots do not
    -- expose garbage.
    withForeignPtr (fptr a) $ flip clearArray ((kl+1+ku)*n)
    zipWithM_ (writeDiagElems (unsafeThaw a)) [(negate kl)..ku] xs
    return a
  where
    writeDiagElems a i es =
        let d   = diag a i
            nb  = max 0 (negate i)
            es' = drop nb es
        -- Index range [0 .. dim d - 1] caps the number of writes at the
        -- diagonal's length.
        in zipWithM_ (unsafeWriteElem d) [0..(dim d - 1)] es'
-- | Diagonal @d@ as a vector aliasing the band store (no bounds check).
--
-- In band storage a diagonal occupies one storage row, so consecutive
-- entries are @lda@ elements apart.  Diagonal @d@ of a Hermitian view is
-- the conjugate of diagonal @-d@ of the wrapped matrix.
unsafeDiag :: (Elem e) => BMatrix t (m,n) e -> Int -> DVector t k e
unsafeDiag a d = case a of
    H inner -> conj (unsafeDiag inner (negate d))
    BM { fptr = buf, lda = ld } ->
        let start = indexOf a (diagStart d)
            count = diagLen (shape a) d
        in V.fromForeignPtr buf start count ld
-- | Diagonal @d@ as a vector aliasing the band store.  The diagonal index
-- is checked against the matrix shape.
diag :: (Elem e) => BMatrix t (m,n) e -> Int -> DVector t k e
diag a d = checkedDiag (shape a) (unsafeDiag a) d
-- | Element offset of entry @(i,j)@ from the start of the buffer.  There
-- is no bounds check.
--
-- Entry @(i,j)@ sits in storage row @ku + i - j@ of storage column @j@,
-- the column-major band layout used by LAPACK.  For a Hermitian view this
-- is the offset of the wrapped matrix's entry @(j,i)@.  Any conjugation
-- must be applied by the caller.
indexOf :: BMatrix t mn e -> (Int,Int) -> Int
indexOf (H a) (i,j) = indexOf a (j,i)
indexOf (BM _ off _ _ _ ku ld) (i,j) =
    off + ku + (i - j) + j * ld
-- | Run an action with a pointer to the storage slot of entry @(i,j)@.
-- There is no bounds check, and the pointer is only valid during the action.
--
-- For a Hermitian view the pointer addresses the wrapped matrix's stored
-- entry @(j,i)@.  That value is NOT conjugated.
unsafeWithElemPtr :: (Elem e) => BMatrix t (m,n) e -> (Int,Int) -> (Ptr e -> IO a) -> IO a
unsafeWithElemPtr a (i,j) f = case a of
    H inner -> unsafeWithElemPtr inner (j,i) f
    BM {}   -> withForeignPtr (fptr a) $ \base ->
                   f (advancePtr base (indexOf a (i,j)))
-- | Row @i@ of an immutable banded matrix as a dense vector of length @n@.
-- Entries outside the band are zero.  The row index is checked.
row :: (BLAS1 e) => Banded (m,n) e -> Int -> Vector n e
row a i = checkedRow (shape a) (unsafeRow a) i
-- | Row @i@ of an immutable banded matrix as a freshly built dense vector.
-- The row index is NOT checked.
--
-- This calls 'unsafeGetRow' directly.  Going through the checked
-- 'getRow', as before, re-validated an index that 'row' has already
-- checked, and meant the "unsafe" variant was not actually unchecked.
unsafeRow :: (BLAS1 e) => Banded (m,n) e -> Int -> Vector n e
unsafeRow a i = unsafePerformIO $ unsafeGetRow a i
-- | Copy row @i@ into a new dense vector of length @n@.  Entries outside
-- the band are zero.  The row index is checked.
getRow :: (BLAS1 e) => BMatrix t (m,n) e -> Int -> IO (DVector r n e)
getRow a i = checkedRow (shape a) (unsafeGetRow a) i
-- | Copy row @i@ into a new dense vector without checking the index.
--
-- The in-band entries are read from the row view.  They are then padded
-- with zeros on the left and right, to the full column count.
unsafeGetRow :: (BLAS1 e) => BMatrix t (m,n) e -> Int -> IO (DVector r n e)
unsafeGetRow a i = do
    stored <- getElems view
    newListVector (numCols a) (zeros before ++ stored ++ zeros after)
  where
    (before, view, after) = unsafeRowView a i
    zeros k = replicate k 0
-- | Column @j@ of an immutable banded matrix as a dense vector of length
-- @m@.  Entries outside the band are zero.  The column index is checked.
col :: (BLAS1 e) => Banded (m,n) e -> Int -> Vector m e
col a j = checkedCol (shape a) (unsafeCol a) j
-- | Column @j@ of an immutable banded matrix as a freshly built dense
-- vector.  The column index is NOT checked.
--
-- This calls 'unsafeGetCol' directly.  Going through the checked
-- 'getCol', as before, re-validated an index that 'col' has already
-- checked.
unsafeCol :: (BLAS1 e) => Banded (m,n) e -> Int -> Vector m e
unsafeCol a j = unsafePerformIO $ unsafeGetCol a j
-- | Copy column @j@ into a new dense vector of length @m@.  Entries
-- outside the band are zero.  The column index is checked.
getCol :: (BLAS1 e) => BMatrix t (m,n) e -> Int -> IO (DVector r m e)
getCol a j = checkedCol (shape a) (unsafeGetCol a) j
-- | Copy column @j@ without checking the index.  Column @j@ of @A@ is the
-- conjugate of row @j@ of @A^H@.
unsafeGetCol :: (BLAS1 e) => BMatrix t (m,n) e -> Int -> IO (DVector r m e)
unsafeGetCol a j = fmap conj (unsafeGetRow (herm a) j)
-- | View of the stored part of column @j@, without bounds checking.
--
-- Returns @(nb, v, na)@.  Here @v@ aliases the in-band entries of column
-- @j@, which are contiguous (stride 1) in band storage.  @nb@ counts the
-- zero entries above them and @na@ those below, so that
-- @nb + length v + na = m@.
--
-- When the two zero counts would overlap (the column lies wholly outside
-- the band), the result is normalised to @(m, <empty vector>, 0)@.  For a
-- Hermitian view, the column is the conjugate of the corresponding row of
-- the wrapped matrix.
unsafeColView :: (Elem e) => BMatrix t (m,n) e -> Int -> (Int, DVector t k e, Int)
unsafeColView (BM f off m _ kl ku ld) j =
    let nb     = max (j - ku) 0          -- rows above the band
        na     = max (m - 1 - j - kl) 0  -- rows below the band
        r      = max (ku - j) 0          -- storage row of the first in-band entry
        c      = j
        off'   = off + r + c * ld
        stride = 1
        len    = m - (nb + na)
    in if len >= 0
        then (nb, V.fromForeignPtr f off' len stride, na)
        else (m , V.fromForeignPtr f off' 0 stride, 0)
unsafeColView a j =
    case unsafeRowView (herm a) j of (nb, v, na) -> (nb, conj v, na)
-- | View of the stored part of row @i@, without bounds checking.
--
-- Returns @(nb, v, na)@.  Here @v@ aliases the in-band entries of row
-- @i@, stepping through the band store with stride @lda - 1@.  @nb@
-- counts the zero entries to their left and @na@ those to their right, so
-- that @nb + length v + na = n@.
--
-- When the two zero counts would overlap (the row lies wholly outside the
-- band), the result is normalised to @(n, <empty vector>, 0)@.  For a
-- Hermitian view, the row is the conjugate of the corresponding column of
-- the wrapped matrix.
unsafeRowView :: (Elem e) => BMatrix t (m,n) e -> Int -> (Int, DVector t k e, Int)
unsafeRowView (BM f off _ n kl ku ld) i =
    let nb     = max (i - kl) 0          -- columns left of the band
        na     = max (n - 1 - i - ku) 0  -- columns right of the band
        r      = min (ku + i) (kl + ku)  -- storage row of the first in-band entry
        c      = max (i - kl) 0          -- ... and its column
        off'   = off + r + c * ld
        stride = ld - 1                  -- (i,j) -> (i,j+1): one storage row up, one column right
        len    = n - (nb + na)
    in if len >= 0
        then (nb, V.fromForeignPtr f off' len stride, na)
        else (n , V.fromForeignPtr f off' 0 stride, 0)
unsafeRowView a i =
    case unsafeColView (herm a) i of (nb, v, na) -> (nb, conj v, na)
-- | Like 'unsafeRowView', but the row index is checked against the shape.
rowView :: (Elem e) => BMatrix t (m,n) e -> Int -> (Int, DVector t k e, Int)
rowView a i = checkedRow (shape a) (unsafeRowView a) i
-- | Like 'unsafeColView', but the column index is checked against the shape.
colView :: (Elem e) => BMatrix t (m,n) e -> Int -> (Int, DVector t k e, Int)
colView a j = checkedCol (shape a) (unsafeColView a) j
-- Generic matrix operations.  Taking the Hermitian of a Hermitian view
-- unwraps it rather than nesting another 'H'.
instance C.Matrix (BMatrix t) where
    numRows a = fst (shape a)
    numCols a = snd (shape a)

    herm (H inner) = coerceMatrix inner
    herm a         = H (coerceMatrix a)
instance Tensor (BMatrix t (m,n)) (Int,Int) e where
    -- Logical shape; a Hermitian view swaps the wrapped dimensions.
    shape a = case a of
        (H a') -> case shape a' of (m,n) -> (n,m)
        _      -> (size1 a, size2 a)
    -- Inclusive index bounds (0,0) .. (m-1,n-1).  The original referenced
    -- the unbound names m1/n1 here.
    bounds a = let (m,n) = shape a in ((0,0), (m-1,n-1))
-- Read-only tensor interface for immutable banded matrices.  The reads go
-- through the IO accessors of the RTensor instance; inlinePerformIO is
-- used because the storage of an immutable matrix is not modified.
instance (BLAS1 e) => ITensor (BMatrix Imm (m,n)) (Int,Int) e where
    size = inlinePerformIO . getSize
    unsafeAt a = inlinePerformIO . (unsafeReadElem a)
    indices = inlinePerformIO . getIndices
    elems = inlinePerformIO . getElems
    assocs = inlinePerformIO . getAssocs
    (//) = replaceHelp writeElem
    unsafeReplace = replaceHelp unsafeWriteElem
    -- Map over the in-band entries, producing a matrix with the same shape
    -- and bandwidths.  'bandwidth' returns (-kl, ku), which newBanded_
    -- rejects whenever kl > 0.  The constructor needs the non-negative
    -- pair (kl, ku).
    amap f a = banded (shape a) (numLower a, numUpper a) ies
      where
        ies = map (second f) (assocs a)
-- | Copy-and-update for immutable matrices.  The argument is copied with
-- 'newCopy', each @((i,j), value)@ pair is written into the copy with the
-- supplied writer, and the copy is frozen.  The argument itself is never
-- written, and the copy does not escape in mutable form.
replaceHelp :: (BLAS1 e) =>
       (IOBanded (m,n) e -> (Int,Int) -> e -> IO ())
    -> Banded (m,n) e -> [((Int,Int), e)] -> Banded (m,n) e
replaceHelp set x ies = unsafeFreeze (unsafePerformIO build)
  where
    build = do
        copy <- newCopy (unsafeThaw x)
        mapM_ (uncurry (set copy)) ies
        return copy
-- IO-based read access for banded matrices (mutable or immutable).
instance (BLAS1 e) => RTensor (BMatrix t (m,n)) (Int,Int) e IO where
    -- Copy the band store as a dense matrix and wrap the copy again.
    newCopy b =
        let (a,mn,kl) = toMatrix b
        in do
            a' <- newCopy a
            return $ fromMatrix a' mn kl

    -- Number of stored entries: the summed lengths of diagonals -kl .. ku.
    -- (The range previously started at +kl, which counts the wrong
    -- diagonals whenever kl > 0.)
    getSize a = case a of
        (H a') -> getSize a'
        (BM _ _ m n kl ku _) -> return $ foldl' (+) 0 $
            map (diagLen (m,n)) [(negate kl)..ku]

    -- A Hermitian view reads the wrapped entry (j,i) and conjugates it.
    unsafeReadElem a (i,j) = case a of
        (H a') -> unsafeReadElem a' (j,i) >>= return . E.conj
        _ -> withForeignPtr (fptr a) $ \ptr ->
            peekElemOff ptr (indexOf a (i,j))

    -- In-range indices that fall inside the band, in 'range' order.
    getIndices a =
        return $ filter (\ij -> inlinePerformIO $ canModifyElem (unsafeThaw a) ij)
                        (range $ bounds a)

    getElems a = getAssocs a >>= return . (map snd)

    getAssocs a = do
        is <- unsafeInterleaveIO $ getIndices a
        mapM (\i -> unsafeReadElem a i >>= \e -> return (i,e)) is
-- Mutable access for IOBanded matrices.  setZero, setConstant and
-- modifyWith act on the whole dense view of the band store, which also
-- contains padding slots outside the matrix.  getIndices never reports
-- those slots, because canModifyElem rejects them.
instance (BLAS1 e) => MTensor (BMatrix Mut (m,n)) (Int,Int) e IO where
    setZero a = case toMatrix a of (a',_,_) -> setZero a'
    setConstant e a = case toMatrix a of (a',_,_) -> setConstant e a'
    -- (i,j) is writable iff it is inside the matrix and inside the band,
    -- i.e. max 0 (j-ku) <= i <= min (m-1) (j+kl).  The original referenced
    -- the unbound names m1/n1/jku here.
    canModifyElem a (i,j) = case a of
        (H a') -> canModifyElem a' (j,i)
        (BM _ _ m n kl ku _) -> return $ inRange (0,m-1) i
                                      && inRange (0,n-1) j
                                      && inRange (max 0 (j-ku), min (m-1) (j+kl)) i
    -- A Hermitian view stores the conjugate at the wrapped entry (j,i).
    unsafeWriteElem a (i,j) e = case a of
        (H a') -> unsafeWriteElem a' (j,i) (E.conj e)
        _ -> withForeignPtr (fptr a) $ \ptr ->
            pokeElemOff ptr (indexOf a (i,j)) e
    modifyWith f a = case toMatrix a of (a',_,_) -> modifyWith f a'