module Data.Matrix.Banded.Internal (
Banded(..),
module BLAS.Tensor.Base,
module BLAS.Matrix.Base,
bandwidth,
numLower,
numUpper,
coerceBanded,
banded,
listsBanded,
unsafeBanded,
module BLAS.Tensor.Immutable,
zeroBanded,
constantBanded,
diagBanded,
unsafeDiagBanded,
listsFromBanded,
ldaOfBanded,
isHermBanded,
module BLAS.Matrix.Immutable,
) where
import Data.AEq
import System.IO.Unsafe
import BLAS.Internal ( diagLen, checkedDiag, inlinePerformIO )
import BLAS.Elem( BLAS1, BLAS2 )
import BLAS.Tensor.Base
import BLAS.Tensor.Immutable
import BLAS.Tensor.Read
import BLAS.UnsafeIOToM
import BLAS.Matrix.Base hiding ( BaseMatrix )
import BLAS.Matrix.Immutable
import BLAS.Matrix.Mutable
import qualified BLAS.Matrix.Base as BLAS
import Data.Ix( inRange, range )
import Data.Matrix.Banded.Class.Internal( BaseBanded(..), ReadBanded,
IOBanded, coerceBanded, numLower, numUpper, bandwidth, isHermBanded,
shapeBanded, boundsBanded, ldaOfBanded, gbmv, gbmm, unsafeGetRowBanded,
unsafeGetColBanded )
import Data.Matrix.Banded.Class.Creating( newListsBanded, unsafeNewBanded,
newBanded )
import Data.Matrix.Banded.Class.Elements( writeElem, unsafeWriteElem )
import Data.Matrix.Banded.Class.Special( newZeroBanded, newConstantBanded )
import Data.Matrix.Banded.Class.Views( unsafeDiagViewBanded )
import Data.Matrix.Banded.Class.Copying( newCopyBanded )
import Data.Vector.Dense( Vector, zeroVector )
import Data.Vector.Dense.ST( runSTVector )
import Data.Matrix.Dense.ST( runSTMatrix )
-- | An immutable banded matrix: a zero-cost wrapper around the mutable
-- 'IOBanded'.  The type itself does not stop the wrapped storage from being
-- mutated; the @unsafe...@ conversions below rely on callers not doing so.
newtype Banded mn e = B (IOBanded mn e)

-- | Re-label mutable storage as an immutable 'Banded' without copying.
-- The argument must not be written to afterwards.
unsafeFreezeIOBanded :: IOBanded mn e -> Banded mn e
unsafeFreezeIOBanded io = B io

-- | Expose the storage of a 'Banded' as an 'IOBanded' without copying.
-- The result must not be written to, or the "immutable" value changes.
unsafeThawIOBanded :: Banded mn e -> IOBanded mn e
unsafeThawIOBanded b = case b of
    B io -> io

-- | Apply a function on the underlying 'IOBanded' to a 'Banded'.
liftBanded :: (IOBanded mn e -> a) -> Banded mn e -> a
liftBanded f = f . unsafeThawIOBanded

-- | Run an 'IO' query against the storage of a 'Banded' and return its
-- result purely via 'inlinePerformIO'.  Every use in this module passes a
-- read-only query (size, elements, indices, associations, single reads).
inlineLiftBanded :: (IOBanded n e -> IO a) -> Banded n e -> a
inlineLiftBanded f b = inlinePerformIO (liftBanded f b)
-- | Build an immutable banded matrix from its shape, its bandwidth pair
-- (presumably @(numLower, numUpper)@ — confirm in 'newBanded') and an
-- association list of entries, using 'newBanded'.  'unsafePerformIO' is
-- used because the freshly allocated storage is wrapped immediately and
-- never exposed mutably.
banded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> [((Int,Int), e)] -> Banded (m,n) e
banded mn kl ijes =
    unsafePerformIO (fmap unsafeFreezeIOBanded (newBanded mn kl ijes))
-- | Like 'banded', but builds the storage with 'unsafeNewBanded'.  The name
-- suggests that the entries are not validated; confirm in
-- 'unsafeNewBanded'.
unsafeBanded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> [((Int,Int), e)] -> Banded (m,n) e
unsafeBanded mn kl ijes = unsafePerformIO $ do
    io <- unsafeNewBanded mn kl ijes
    return (unsafeFreezeIOBanded io)
-- | Build an immutable banded matrix from its shape, its
-- @(numLower, numUpper)@ bandwidths and one list per stored diagonal,
-- using 'newListsBanded'.  This is the inverse of 'listsFromBanded', whose
-- output 'tmap' feeds straight back into this function.
listsBanded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> [[e]] -> Banded (m,n) e
listsBanded mn kl xs = unsafeFreezeIOBanded storage
  where
    -- Fresh storage, never shared mutably, so unsafePerformIO is benign.
    storage = unsafePerformIO (newListsBanded mn kl xs)
-- | A banded matrix of the given shape and bandwidths whose storage is
-- created by 'newZeroBanded'.
zeroBanded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> Banded (m,n) e
zeroBanded mn kl =
    unsafePerformIO (unsafeFreezeIOBanded `fmap` newZeroBanded mn kl)
-- | A banded matrix of the given shape and bandwidths whose storage is
-- filled by 'newConstantBanded' with the value @e@.
constantBanded :: (BLAS1 e) => (Int,Int) -> (Int,Int) -> e -> Banded (m,n) e
constantBanded mn kl e =
    let io = unsafePerformIO (newConstantBanded mn kl e)
    in  unsafeFreezeIOBanded io
-- | Diagonal number @i@ of a banded matrix.  The index is validated
-- against the matrix shape by 'checkedDiag' before delegating to
-- 'unsafeDiagBanded'.
diagBanded :: (BLAS1 e) => Banded mn e -> Int -> Vector k e
diagBanded a i = checkedDiag mn (unsafeDiagBanded a) i
  where
    mn = shape a
-- | Diagonal number @i@ without checking it against the shape.
--
-- If @i@ lies inside the stored band ('bandwidth'), the diagonal comes
-- from 'unsafeDiagViewBanded'.  Otherwise the diagonal holds no stored
-- entries, and a zero vector of length @'diagLen' (shape a) i@ is
-- returned.
unsafeDiagBanded :: (BLAS1 e) => Banded mn e -> Int -> Vector k e
unsafeDiagBanded a i =
    if inRange (bandwidth a) i
        then unsafeDiagViewBanded a i
        else zeroVector (diagLen (shape a) i)
-- | Shape and bounds are read directly from the wrapped 'IOBanded'.
instance BaseTensor Banded (Int,Int) where
    shape  (B a) = shapeBanded a
    bounds (B a) = boundsBanded a
-- | Pure tensor interface.
--
-- * Queries read the underlying storage through 'inlineLiftBanded'.
-- * Updates go through 'replaceHelp', which writes into a copy.
-- * 'tmap' round-trips through the per-diagonal list representation.
instance ITensor Banded (Int,Int) where
    size    = inlineLiftBanded getSize
    elems   = inlineLiftBanded getElems
    indices = inlineLiftBanded getIndices
    assocs  = inlineLiftBanded getAssocs

    unsafeAt x i = inlineLiftBanded (\io -> unsafeReadElem io i) x

    (//)          = replaceHelp writeElem
    unsafeReplace = replaceHelp unsafeWriteElem

    -- Map over every stored diagonal, then rebuild with the same shape
    -- and bandwidths.
    tmap f a =
        let (mn, bw, ds) = listsFromBanded a
        in  coerceBanded (listsBanded mn bw [ map f d | d <- ds ])
-- | Decompose a banded matrix into @(shape, (numLower, numUpper), diags)@.
-- This is the same triple that 'listsBanded' consumes ('tmap' and 'show'
-- rely on that).
--
-- @diags@ lists the stored diagonals from the lowest sub-diagonal @-kl@
-- up to the highest super-diagonal @ku@.  Each diagonal is padded with
-- zeros:
--
-- * sub-diagonal @i < 0@ gets @-i@ leading zeros;
-- * every diagonal @i@ gets @max 0 (m - n + i)@ trailing zeros.
--
-- NOTE(review): assuming diagonal @i@ holds the entries @(r, r+i)@ and has
-- length @'diagLen' (m,n) i@, entry @r@ of each padded list is row @r@ and
-- every padded list has length @m@.
listsFromBanded :: (BLAS1 e) => Banded mn e -> ((Int,Int), (Int,Int),[[e]])
listsFromBanded a = ( (m,n)
                    , (kl,ku)
                    , map paddedDiag [(-kl)..ku]
                    )
  where
    (m,n)   = shape a
    (kl,ku) = (numLower a, numUpper a)

    -- Zeros for the rows above the first entry of diagonal i; only
    -- sub-diagonals (i < 0) start below row 0.
    padBegin i = replicate (max (-i) 0) 0
    -- Zeros for the rows below the last entry of diagonal i.
    padEnd i   = replicate (max (m-n+i) 0) 0

    paddedDiag i =  padBegin i
                 ++ elems (unsafeDiagViewBanded a i)
                 ++ padEnd i
-- | Shared implementation of '(//)' and 'unsafeReplace'.
--
-- It copies the matrix with 'newCopyBanded', applies every
-- @(index, value)@ update to the copy using the supplied writer, and
-- freezes the copy.  The input matrix is never written to.
-- 'unsafePerformIO' is acceptable because the mutated copy stays private
-- until it is frozen.
replaceHelp :: (BLAS1 e) =>
       (IOBanded mn e -> (Int,Int) -> e -> IO ())
    -> Banded mn e -> [((Int,Int), e)] -> Banded mn e
replaceHelp set x ies = unsafeFreezeIOBanded (unsafePerformIO updatedCopy)
  where
    updatedCopy = do
        copy <- newCopyBanded (unsafeThawIOBanded x)
        mapM_ (uncurry (set copy)) ies
        return copy
-- | Reading an immutable 'Banded' needs no effects.  In any monad, each
-- query simply returns the corresponding pure 'ITensor' result.
instance (Monad m) => ReadTensor Banded (Int,Int) m where
    getSize x    = return (size x)
    getElems x   = return (elems x)
    getIndices x = return (indices x)
    getAssocs x  = return (assocs x)

    -- The primed variants are defined to be exactly the unprimed ones here.
    getElems'   = getElems
    getIndices' = getIndices
    getAssocs'  = getAssocs

    unsafeReadElem x i = return (unsafeAt x i)
-- | 'herm' is applied to the wrapped 'IOBanded' and the result is
-- re-wrapped; no data is copied here.
instance BLAS.BaseMatrix Banded where
    herm a = unsafeFreezeIOBanded (herm (unsafeThawIOBanded a))
-- | Structural access delegates to the wrapped 'IOBanded'.  Views are
-- built as 'IOBanded' and then wrapped; the raw array is read from the
-- unwrapped storage.
instance BaseBanded Banded Vector where
    arrayFromBanded b = arrayFromBanded (unsafeThawIOBanded b)

    bandedViewArray f p m n kl ku l h =
        unsafeFreezeIOBanded (bandedViewArray f p m n kl ku l h)
-- | 'ReadBanded' for immutable banded matrices.  The instance defines no
-- methods, so it relies entirely on the class's default methods
-- (presumably built from the @BaseBanded Banded Vector@ instance —
-- confirm in "Data.Matrix.Banded.Class.Internal").
instance (UnsafeIOToM m) => ReadBanded Banded Vector m where
-- | Pure matrix operations.  Each one runs the matching monadic 'MMatrix'
-- operation and extracts the result with 'runSTVector' / 'runSTMatrix'.
instance (BLAS2 e) => IMatrix Banded e where
    unsafeRow a i = runSTVector (unsafeGetRow a i)
    unsafeCol a j = runSTVector (unsafeGetCol a j)

    unsafeSApply alpha a x =
        runSTVector (unsafeGetSApply alpha a x)
    unsafeSApplyMat alpha a b =
        runSTMatrix (unsafeGetSApplyMat alpha a b)
-- | Monadic matrix operations.  Row and column access use the banded
-- getters; scaled multiply-add uses the banded BLAS wrappers 'gbmv'
-- (matrix-vector) and 'gbmm' (matrix-matrix).
instance (BLAS2 e, UnsafeIOToM m) => MMatrix Banded e m where
    unsafeGetRow a i = unsafeGetRowBanded a i
    unsafeGetCol a j = unsafeGetColBanded a j

    unsafeDoSApplyAdd    = gbmv
    unsafeDoSApplyAddMat = gbmm
-- | Renders a matrix in the form of a @listsBanded@ call: its shape,
-- bandwidths and padded diagonals as given by 'listsFromBanded'.  A matrix
-- flagged by 'isHermBanded' is rendered as @herm (...)@ around the
-- rendering of its 'herm'.
instance (BLAS1 e) => Show (Banded mn e) where
    show a =
        if isHermBanded a
            then "herm (" ++ show (herm $ coerceBanded a) ++ ")"
            else
                let (mn, kl, es) = listsFromBanded a
                in  concat ["listsBanded ", show mn, " ", show kl, " ", show es]
-- | Entry-wise comparison of two banded matrices using a caller-supplied
-- scalar predicate.  The 'Eq' and 'AEq' instances share it.
--
-- * Matrices of different shapes never compare equal.
-- * If both have the same herm flag and the same 'bandwidth', their
--   stored elements are compared directly.  When both are flagged, the
--   elements of their 'herm' views are compared instead.
-- * Otherwise every diagonal from @-l@ to @u@ is compared, where @l@ and
--   @u@ are the larger of the two lower and upper bandwidths.  A diagonal
--   outside one matrix's band reads as zeros via 'diagBanded'.
compareHelp :: (BLAS1 e) =>
       (e -> e -> Bool) -> Banded mn e -> Banded mn e -> Bool
compareHelp cmp a b
    | shape a /= shape b =
        False
    | isHermBanded a == isHermBanded b && bandwidth a == bandwidth b =
        let elems' = if isHermBanded a then elems . herm . coerceBanded
                                       else elems
        in  and $ zipWith cmp (elems' a) (elems' b)
    | otherwise =
        let l = max (numLower a) (numLower b)
            u = max (numUpper a) (numUpper b)
            -- Sub-diagonals have negative indices, so the combined band is
            -- (-l, u).  Using (l, u) would skip the main diagonal and all
            -- sub-diagonals, and would be empty when l > u, so unequal
            -- matrices would compare equal.
            band = (-l, u)
        in  and $ zipWith cmp (diagElems band a) (diagElems band b)
  where
    diagElems bw c = concatMap elems [ diagBanded c i | i <- range bw ]
-- | Equality delegates to 'compareHelp', using the element type's '=='
-- as the entry-wise predicate.
instance (BLAS1 e, Eq e) => Eq (Banded mn e) where
    a == b = compareHelp (==) a b
-- | Approximate equality delegates to 'compareHelp', using the element
-- type's '===' (exact) and '~==' (approximate) as entry-wise predicates.
instance (BLAS1 e, AEq e) => AEq (Banded mn e) where
    a === b = compareHelp (===) a b
    a ~== b = compareHelp (~==) a b