-----------------------------------------------------------------------------
-- |
-- Module     : Data.Matrix.Banded.Class.Copying
-- Copyright  : Copyright (c) , Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--

module Data.Matrix.Banded.Class.Copying (
    -- * Copying Banded matrices
    newCopyBanded,
    copyBanded,
    unsafeCopyBanded,
    ) where

import BLAS.C.Level1( BLAS1 )
import qualified BLAS.C.Level1 as BLAS
import BLAS.Matrix
import BLAS.UnsafeIOToM

import Control.Monad( zipWithM_ )
import Data.Ix( range )
import Foreign( advancePtr )

import Data.Matrix.Banded.Class.Internal
import Data.Matrix.Banded.Class.Views
import Data.Vector.Dense.Class( unsafeCopyVector )


-- | Allocate a new banded matrix with the same shape ('shapeBanded') and
-- the same numbers of lower and upper diagonals ('numLower', 'numUpper')
-- as the argument, and fill it with the argument's elements.
--
-- When 'isHermBanded' holds for the argument, the copy is made of
-- @herm@ of the argument, and @herm@ is then applied to the freshly
-- allocated result.
newCopyBanded :: (BLAS1 e, ReadBanded a x m, WriteBanded b y m) =>
    a mn e -> m (b mn e)
newCopyBanded a =
    if isHermBanded a
        then do
            -- Copy through the 'herm' view, then re-apply 'herm' to the copy.
            c <- newCopyBanded (herm (coerceBanded a))
            return (coerceBanded (herm c))
        else do
            b <- newBanded_ (shapeBanded a) (numLower a, numUpper a)
            unsafeCopyBanded b a
            return b


-- | Copy the elements of the second (source) argument into the first
-- (destination) argument.
--
-- Calls 'error' when the two matrices differ in shape, or, failing that,
-- when they differ in bandwidth.
copyBanded :: (BLAS1 e, WriteBanded b y m, ReadBanded a x m) =>
    b mn e -> a mn e -> m ()
copyBanded dst src
    | shapeMismatch = error "Shape mismatch in copyBanded."
    | bandMismatch  = error "Bandwidth mismatch in copyBanded."
    | otherwise     = unsafeCopyBanded dst src
  where
    shapeMismatch = shapeBanded dst /= shapeBanded src
    bandMismatch  = bandwidth dst /= bandwidth src


-- | Copy the elements of the second (source) argument into the first
-- (destination) argument without checking that their shapes or
-- bandwidths agree.
--
-- The work is split into three cases:
--
--   * the destination is a herm view: apply @herm@ to both arguments and
--     recurse, so the destination seen by the recursive call is not herm;
--
--   * neither argument is a herm view: copy the raw band storage with
--     BLAS @copy@, either as one block (when both column strides equal
--     the number of stored diagonals) or one column at a time;
--
--   * only the source is a herm view: copy diagonal by diagonal with
--     'unsafeCopyVector'.
unsafeCopyBanded :: (BLAS1 e, WriteBanded b y m, ReadBanded a x m) =>
    b mn e -> a mn e -> m ()
unsafeCopyBanded dst src
    | isHermBanded dst =
        unsafeCopyBanded (herm (coerceBanded dst)) (herm (coerceBanded src))
    | isHermBanded src =
        zipWithM_ unsafeCopyVector (diagonals dst) (diagonals src)
    | otherwise =
        unsafeIOToM $
            withBandedPtr dst $ \pDst ->
            withBandedPtr src $ \pSrc ->
                if contiguous
                    then BLAS.copy (bandRows * cols) pSrc 1 pDst 1
                    else copyColumns pDst pSrc 0
  where
    -- Number of stored diagonals (lower + upper + main). These bindings are
    -- lazy and are only forced once the first guard has ruled out a herm
    -- destination.
    bandRows   = numLower dst + numUpper dst + 1
    cols       = numCols dst
    ldDst      = ldaOfBanded dst
    ldSrc      = ldaOfBanded src

    -- When neither storage has padding between columns, the whole band is
    -- one contiguous run of bandRows * cols elements.
    contiguous = ldDst == bandRows && ldSrc == bandRows

    -- Copy column j (and every later one) by offsetting each base pointer by
    -- j times its own column stride.
    copyColumns pDst pSrc j
        | j == cols = return ()
        | otherwise = do
            BLAS.copy bandRows (pSrc `advancePtr` (j * ldSrc)) 1
                               (pDst `advancePtr` (j * ldDst)) 1
            copyColumns pDst pSrc (j + 1)

    -- One vector view per diagonal index in the matrix's bandwidth range.
    diagonals x = map (unsafeDiagViewBanded x) (range (bandwidth x))