module Data.Elem.BLAS.Level1
where
import Prelude hiding ( div )
import Foreign ( Ptr, Storable, advancePtr, castPtr, peek, poke, with )
import Foreign.Storable.Complex ()
import Data.Complex ( Complex(..) )
import Data.Elem.BLAS.Base
import BLAS.CTypes
import Data.Elem.BLAS.Double
import Data.Elem.BLAS.Zomplex
class (Elem a) => BLAS1 a where
dotu :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO a
dotc :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO a
nrm2 :: Int -> Ptr a -> Int -> IO Double
asum :: Int -> Ptr a -> Int -> IO Double
iamax :: Int -> Ptr a -> Int -> IO Int
scal :: Int -> a -> Ptr a -> Int -> IO ()
swap :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
copy :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
axpy :: Int -> a -> Ptr a -> Int -> Ptr a -> Int -> IO ()
rotg :: Ptr a -> Ptr a -> Ptr a -> Ptr a -> IO ()
rot :: Int -> Ptr a -> Int -> Ptr a -> Int -> Double -> Double -> IO ()
vconj :: Int -> Ptr a -> Int -> IO ()
acxpy :: Int -> a -> Ptr a -> Int -> Ptr a -> Int -> IO ()
vmul :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
vcmul :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
vdiv :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
vcdiv :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
instance BLAS1 Double where
dotu = ddot
dotc = ddot
nrm2 = dnrm2
asum = dasum
iamax = idamax
swap = dswap
copy = dcopy
axpy = daxpy
scal = dscal
rotg = drotg
rot = drot
vconj _ _ _ = return ()
acxpy = daxpy
vmul n = dtbmv upper noTrans nonUnit n 0
vcmul = vmul
vdiv n = dtbsv upper noTrans nonUnit n 0
vcdiv = vdiv
instance BLAS1 (Complex Double) where
dotu n pX incX pY incY =
with 0 $ \pDotu -> do
zdotu_sub n pX incX pY incY pDotu
peek pDotu
dotc n pX incX pY incY =
with 0 $ \pDotc -> do
zdotc_sub n pX incX pY incY pDotc
peek pDotc
nrm2 = znrm2
asum = zasum
iamax = izamax
swap = zswap
copy = zcopy
axpy n alpha pX incX pY incY =
with alpha $ \pAlpha ->
zaxpy n pAlpha pX incX pY incY
scal n alpha pX incX =
with alpha $ \pAlpha ->
zscal n pAlpha pX incX
rotg = zrotg
rot = zdrot
vconj n pX incX =
let pXI = (castPtr pX) `advancePtr` 1
alpha = 1
incXI = 2 * incX
in dscal n alpha pXI incXI
acxpy n a pX incX pY incY =
let pXR = castPtr pX
pYR = castPtr pY
pXI = pXR `advancePtr` 1
pYI = pYR `advancePtr` 1
incX' = 2 * incX
incY' = 2 * incY
in case a of
(ra :+ 0) -> do
daxpy n ( ra) pXR incX' pYR incY'
daxpy n (ra) pXI incX' pYI incY'
(0 :+ ia) -> do
daxpy n ( ia) pXR incX' pYI incY'
daxpy n ( ia) pXI incX' pYR incY'
_ -> go n pX pY
where
go n' pX' pY'
| n' `seq` pX' `seq` pY' `seq` False = undefined
| n' <= 0 =
return ()
| otherwise = do
x <- peek pX'
y <- peek pY'
poke pY' (a * (conjugate x) + y)
let n'' = n' 1
pX'' = pX' `advancePtr` incX
pY'' = pY' `advancePtr` incY
go n'' pX'' pY''
vmul n = ztbmv upper noTrans nonUnit n 0
vcmul n = ztbmv upper conjTrans nonUnit n 0
vdiv n = ztbsv upper noTrans nonUnit n 0
vcdiv n = ztbsv upper conjTrans nonUnit n 0