```{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS_GHC -fno-excess-precision #-}
-----------------------------------------------------------------------------
-- |
-- Module     : BLAS.C.Level1
-- Maintainer : Patrick Perry <patperry@stanford.edu>
-- Stability  : experimental
--

module BLAS.C.Level1
where

import Foreign ( Ptr, Storable, advancePtr, castPtr, peek, poke, with )
import Foreign.Storable.Complex ()
import Data.Complex

import BLAS.Elem.Base
import BLAS.C.Double
import BLAS.C.Zomplex

class (Elem a) => BLAS1 a where
dotu  :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO a
dotc  :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO a
nrm2  :: Int -> Ptr a -> Int -> IO Double
asum  :: Int -> Ptr a -> Int -> IO Double
iamax :: Int -> Ptr a -> Int -> IO Int

scal  :: Int -> a -> Ptr a -> Int -> IO ()

swap  :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
copy  :: Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
axpy  :: Int -> a -> Ptr a -> Int -> Ptr a -> Int -> IO ()

rotg  :: Ptr a -> Ptr a -> Ptr a -> Ptr a -> IO ()
rot   :: Int -> Ptr a -> Int -> Ptr a -> Int -> Double -> Double -> IO ()

-- | conjugate all elements of a vector
conj  :: Int -> Ptr a -> Int -> IO ()

-- | Replaces @y@ with @alpha (conj x) + y@
acxpy :: Int -> a -> Ptr a -> Int -> Ptr a -> Int -> IO ()

instance BLAS1 Double where
dotu  = ddot
dotc  = ddot
nrm2  = dnrm2
asum  = dasum
iamax = idamax
swap  = dswap
copy  = dcopy
axpy  = daxpy
scal  = dscal
rotg  = drotg
rot   = drot
conj _ _ _ = return ()
acxpy = daxpy

instance BLAS1 (Complex Double) where
dotu n pX incX pY incY =
with 0 \$ \pDotu -> do
zdotu_sub n pX incX pY incY pDotu
peek pDotu

dotc n pX incX pY incY =
with 0 \$ \pDotc -> do
zdotc_sub n pX incX pY incY pDotc
peek pDotc

nrm2  = znrm2
asum  = zasum
iamax = izamax
swap  = zswap
copy  = zcopy

axpy n alpha pX incX pY incY =
with alpha \$ \pAlpha ->
zaxpy n pAlpha pX incX pY incY

scal n alpha pX incX =
with alpha \$ \pAlpha ->
zscal n pAlpha pX incX

rotg  = zrotg

rot = zdrot

conj n pX incX =
let pXI   = (castPtr pX) `advancePtr` 1
alpha = -1
incXI = 2 * incX
in dscal n alpha pXI incXI

acxpy n a pX incX pY incY =
let pXR   = castPtr pX
pYR   = castPtr pY
incX' = 2 * incX
incY' = 2 * incY
in case a of
(ra :+  0) -> do
daxpy n ( ra) pXR incX' pYR incY'
daxpy n (-ra) pXI incX' pYI incY'
(0  :+ ia) -> do
daxpy n ( ia) pXR incX' pYI incY'
daxpy n ( ia) pXI incX' pYR incY'
_ -> go n pX pY
where
go n' pX' pY'
| n' `seq` pX' `seq` pY' `seq` False = undefined
| n' <= 0 =
return ()
| otherwise = do
x <- peek pX'
y <- peek pY'
poke pY' (a * (conjugate x) + y)

let n''  = n' - 1