{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Data.Matrix.Banded.Operations
-- Copyright  : Copyright (c) , Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--
-- Scaling and conjugation of banded matrices, both as pure\/'IO' operations
-- that produce a new matrix and as in-place updates of mutable banded
-- matrices.
--

module Data.Matrix.Banded.Operations
    ( module BLAS.Matrix.Immutable
    , module BLAS.Matrix.ReadOnly

    -- * Matrix Arithmetic
    -- ** Pure
    , scale
    , invScale

    -- ** Impure
    , getScaled
    , getInvScaled

    -- * In-place operations
    , doConj
    , scaleBy
    , invScaleBy
    ) where

import System.IO.Unsafe
import Unsafe.Coerce

import BLAS.Access
import BLAS.C ( CBLASTrans, colMajor, noTrans, conjTrans )
import qualified BLAS.C as BLAS
import BLAS.Elem ( BLAS1, BLAS2 )
import qualified BLAS.Elem as E
import BLAS.Matrix.Immutable
import BLAS.Matrix.ReadOnly

import Data.Matrix.Banded.Internal
import Data.Matrix.Dense.Internal ( DMatrix, IOMatrix)
import qualified Data.Matrix.Dense.Internal as M
import qualified Data.Matrix.Dense.Operations as M
import Data.Vector.Dense.Internal hiding ( unsafeWithElemPtr, unsafeThaw, unsafeFreeze )
import qualified Data.Vector.Dense.Internal as V
import qualified Data.Vector.Dense.Operations as V

infixl 7 `scale`, `invScale`

-- | Build a new matrix holding every element of the argument multiplied by
-- the given value.  The argument is left unchanged; the work is done on a
-- copy.
getScaled :: (BLAS1 e) => e -> BMatrix t (m,n) e -> IO (BMatrix r (m,n) e)
getScaled = unaryOp . scaleBy

-- | Build a new matrix holding every element of the argument divided by the
-- given value.  The argument is left unchanged; the work is done on a copy.
getInvScaled :: (BLAS1 e) => e -> BMatrix t (m,n) e -> IO (BMatrix r (m,n) e)
getInvScaled = unaryOp . invScaleBy

-- | Conjugate, in place, every element of a mutable banded matrix.
doConj :: (BLAS1 e) => IOBanded (m,n) e -> IO ()
doConj a = M.doConj raw
  where
    -- Conjugation is applied to the whole underlying dense storage.
    (_, _, raw, _) = toRawMatrix a

-- | Multiply, in place, every element of a mutable banded matrix by @k@.
scaleBy :: (BLAS1 e) => e -> IOBanded (m,n) e -> IO ()
scaleBy k a = M.scaleBy factor raw
  where
    (_, _, raw, herm) = toRawMatrix a
    -- When the view is flagged as Hermitian, the factor applied to the
    -- underlying storage is the conjugate of @k@.
    factor
        | herm      = E.conj k
        | otherwise = k

-- | Divide, in place, every element of a mutable banded matrix by @k@.
invScaleBy :: (BLAS1 e) => e -> IOBanded (m,n) e -> IO ()
invScaleBy k a = M.invScaleBy k' a'
  where
    (_,_,a',h) = toRawMatrix a
    -- A view flagged as Hermitian (@h@) is handed to BLAS with 'conjTrans'
    -- (see 'blasTransOf').  Dividing the logical matrix by @k@ therefore
    -- means dividing its storage by @conj k@.
    k' = if h then E.conj k else k

-- | The BLAS transpose flag for a banded view.  It is 'conjTrans' when the
-- view is Hermitian and 'noTrans' otherwise.
blasTransOf :: BMatrix t (m,n) e -> CBLASTrans
blasTransOf a
    | isHerm a  = conjTrans
    | otherwise = noTrans

-- | Swap the row and column counts of a @(rows, cols)@ shape.
flipShape :: (Int,Int) -> (Int,Int)
flipShape (m,n) = (n,m)

-- | @gbmv alpha a x beta y@ replaces @y := alpha a * x + beta y@.
--
-- Degenerate shapes are handled before BLAS is called:
--
--   * if @a@ has no rows, @y@ is empty and nothing is done;
--
--   * if @a@ has no columns, @a * x@ is the zero vector, so @y@ is just
--     scaled by @beta@.
--
-- Conjugated vector views are converted to unconjugated storage before the
-- BLAS call, because the call takes plain pointers and strides.
gbmv :: (BLAS2 e) => e -> BMatrix s (m,n) e -> DVector t n e -> e -> IOVector m e -> IO ()
gbmv alpha a x beta y
    | numRows a == 0 =
        return ()
    | numCols a == 0 =
        -- Previously this case returned early and left y unchanged, which is
        -- wrong whenever beta /= 1.
        V.scaleBy beta y
    | isConj x = do
        -- Materialise an unconjugated copy holding the same logical values.
        x' <- V.getConj (conj x)
        gbmv alpha a x' beta y
    | isConj y = do
        -- Temporarily conjugate y's storage so that the unconjugated view
        -- holds y's logical values, update through that view, and then
        -- conjugate back.
        -- NOTE(review): this is not exception-safe.  If the inner call
        -- throws, y is left conjugated.
        V.doConj y
        gbmv alpha a x beta (conj y)
        V.doConj y
    | otherwise =
        let order   = colMajor
            transA  = blasTransOf a
            -- BLAS describes the *stored* matrix.  For a Hermitian view,
            -- that matrix has the logical shape flipped and the lower and
            -- upper bandwidths swapped.
            (m,n)   = if isHerm a then flipShape (shape a) else shape a
            (kl,ku) = if isHerm a
                          then (numUpper a, numLower a)
                          else (numLower a, numUpper a)
            ldA     = ldaOf a
            incX    = V.strideOf x
            incY    = V.strideOf y
        in unsafeWithBasePtr a $ \pA ->
           V.unsafeWithElemPtr x 0 $ \pX ->
           V.unsafeWithElemPtr y 0 $ \pY ->
               BLAS.gbmv order transA m n kl ku alpha pA ldA pX incX beta pY incY

-- | @gbmm alpha a b beta c@ replaces @c := alpha a * b + beta c@.  Each
-- column of @c@ is updated by a 'gbmv' call with the matching column of @b@.
gbmm :: (BLAS2 e) => e -> BMatrix s (m,k) e -> DMatrix t (k,n) e -> e -> IOMatrix (m,n) e -> IO ()
gbmm alpha a b beta c =
    sequence_ [ gbmv alpha a x beta y | (x, y) <- zip (M.cols b) (M.cols c) ]

-- | Run an in-place operation on a copy of a matrix made with 'newCopy', and
-- return that copy.  The copy is thawed so that it can be mutated, then
-- reinterpreted at access type @r@ via 'unsafeCoerce'.
unaryOp :: (BLAS1 e) => (IOBanded (m,n) e -> IO ()) -> BMatrix t (m,n) e -> IO (BMatrix r (m,n) e)
unaryOp f a = do
    a' <- newCopy a
    f (unsafeThaw a')
    return (unsafeCoerce a')

-- | Create a new matrix by scaling another matrix by the given value.
scale :: (BLAS1 e) => e -> Banded (m,n) e -> Banded (m,n) e
-- 'getScaled' mutates only the copy it allocates, so exposing the result
-- purely is intended to be observably pure.  NOINLINE keeps the
-- unsafePerformIO call from being duplicated or floated by the optimiser.
scale k a = unsafePerformIO $ getScaled k a
{-# NOINLINE scale #-}

-- | Form a new matrix by dividing every element by a value.
invScale :: (BLAS1 e) => e -> Banded (m,n) e -> Banded (m,n) e
-- 'getInvScaled' mutates only the copy it allocates, so the result is
-- returned purely.  NOINLINE keeps the unsafePerformIO call from being
-- duplicated or floated.
invScale d b = unsafePerformIO (getInvScaled d b)
{-# NOINLINE invScale #-}

-- Banded views of any access type @s@ implement the read-only
-- multiply-and-add operations with the 'gbmm' and 'gbmv' wrappers.
instance (BLAS2 e) => RMatrix (BMatrix s) e where
    unsafeDoSApplyAddMat = gbmm
    unsafeDoSApplyAdd    = gbmv

-- Immutable banded matrices rely entirely on the class defaults.
instance (BLAS2 e) => IMatrix (BMatrix Imm) e