module Data.Matrix.Banded.Class.Copying (
newCopyBanded,
copyBanded,
unsafeCopyBanded,
) where
import BLAS.C.Level1( BLAS1 )
import qualified BLAS.C.Level1 as BLAS
import BLAS.Matrix
import BLAS.UnsafeIOToM
import Control.Monad( zipWithM_ )
import Data.Ix( range )
import Foreign( advancePtr )
import Data.Matrix.Banded.Class.Internal
import Data.Matrix.Banded.Class.Views
import Data.Vector.Dense.Class( unsafeCopyVector )
-- | Create a new banded matrix holding a copy of the argument.
--
-- When the argument is a herm view, the copy is made from its
-- un-hermed form and the result is hermed again.  Otherwise a new matrix
-- is created with 'newBanded_' using the argument's shape and its
-- lower/upper band counts, and is then filled by 'unsafeCopyBanded'.
newCopyBanded :: (BLAS1 e, ReadBanded a x m, WriteBanded b y m) =>
    a mn e -> m (b mn e)
newCopyBanded a
    | isHermBanded a = do
        -- Copy the flipped view, then flip the fresh copy back.
        flipped <- newCopyBanded (herm (coerceBanded a))
        return (coerceBanded (herm flipped))
    | otherwise = do
        fresh <- newBanded_ (shapeBanded a) (numLower a, numUpper a)
        unsafeCopyBanded fresh a
        return fresh
-- | Copy the elements of the second argument into the first.
--
-- Calls 'error' if the two matrices differ in shape or in bandwidth;
-- the shape is checked first.  Once both checks pass the work is
-- delegated to 'unsafeCopyBanded'.
copyBanded :: (BLAS1 e, WriteBanded b y m, ReadBanded a x m) =>
    b mn e -> a mn e -> m ()
copyBanded dst src
    | not sameShape = error "Shape mismatch in copyBanded."
    | not sameBands = error "Bandwidth mismatch in copyBanded."
    | otherwise     = unsafeCopyBanded dst src
  where
    sameShape = shapeBanded dst == shapeBanded src
    sameBands = bandwidth dst == bandwidth src
-- | Copy the elements of @src@ into @dst@ without checking that the two
-- matrices agree in shape or bandwidth (see 'copyBanded' for the checked
-- variant).
--
-- Three cases are handled:
--
-- * @dst@ is a herm view: both arguments are hermed and the copy is
--   retried, so the remaining cases only see a non-herm @dst@.
--
-- * neither argument is a herm view: the underlying storage is copied
--   through the raw pointers with 'BLAS.copy'.
--
-- * only @src@ is a herm view: the copy is done one diagonal at a time
--   with 'unsafeCopyVector'.
unsafeCopyBanded :: (BLAS1 e, WriteBanded b y m, ReadBanded a x m) =>
    b mn e -> a mn e -> m ()
unsafeCopyBanded dst src
    | isHermBanded dst =
        unsafeCopyBanded ((herm . coerceBanded) dst)
                         ((herm . coerceBanded) src)

    | (not . isHermBanded) src =
        unsafeIOToM $
        withBandedPtr dst $ \pDst ->
        withBandedPtr src $ \pSrc ->
            -- When both leading dimensions equal the number of stored
            -- rows, everything is moved by a single BLAS call; otherwise
            -- the copy proceeds column by column.
            if ldDst == m && ldSrc == m
                then copyBlock pDst pSrc
                else copyCols  pDst pSrc n

    | otherwise =
        zipWithM_ unsafeCopyVector (diagViews dst) (diagViews src)

  where
    -- Rows of band storage per column.  Only demanded in the raw-pointer
    -- branch, which is reached after the herm test on @dst@ has failed.
    m     = numLower dst + numUpper dst + 1
    n     = numCols dst
    ldDst = ldaOfBanded dst
    ldSrc = ldaOfBanded src

    -- Copy all m*n stored elements in one unit-stride call.
    copyBlock pDst pSrc =
        BLAS.copy (m*n) pSrc 1 pDst 1

    -- Copy m elements per column, stepping each pointer by its own
    -- leading dimension, until no columns are left.
    copyCols pDst pSrc nleft
        | nleft == 0 = return ()
        | otherwise = do
            BLAS.copy m pSrc 1 pDst 1
            -- Fixed: the counter must decrease on each step; the original
            -- referenced the unbound name @nleft1@ here.
            copyCols (pDst `advancePtr` ldDst) (pSrc `advancePtr` ldSrc)
                     (nleft - 1)

    -- One diagonal view for every index in the matrix's bandwidth range.
    diagViews a = map (unsafeDiagViewBanded a) $ (range . bandwidth) a