{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.BLAS.Vector.Slice (
T,
shape,
Vector,
RealOf,
ComplexOf,
slice,
fromVector,
toVector,
extract,
access,
dot, inner,
sum,
absSum,
norm2,
norm2Squared,
normInf,
normInf1,
argAbsMaximum,
argAbs1Maximum,
product,
scale, scaleReal,
add, sub,
negate, raise,
mac,
mul, mulConj,
minimum, argMinimum,
maximum, argMaximum,
limits, argLimits,
conjugate,
fromReal,
toComplex,
realFromComplexVector,
realPart,
imaginaryPart,
zipComplex,
unzipComplex,
) where
import qualified Numeric.BLAS.Slice as Slice
import qualified Numeric.BLAS.Scalar as Scalar
import qualified Numeric.BLAS.Private as Private
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated, Conjugated))
import Numeric.BLAS.Scalar (ComplexOf, RealOf)
import Numeric.BLAS.Private (ComplexShape, copyConjugate, realPtr)
import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peek, peekElemOff)
import Foreign.C.Types (CInt)
import System.IO.Unsafe (unsafePerformIO)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (fmap, return, (=<<))
import Control.Applicative (liftA2, pure, (<$>), (<*>))
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Text.Show (Show)
import Data.Function (($), (.))
import Data.Complex (Complex)
import Data.Maybe (Maybe(Nothing,Just), maybe)
import Data.Tuple.HT (mapFst, uncurry3)
import Data.Tuple (fst, snd, uncurry)
import Data.Ord ((>=))
import Data.Eq (Eq, (==))
import Prelude (Int, fromIntegral, (-), (+), (*), error, IO)
-- | Contiguous vectors are plain comfort-array storable arrays.
type Vector = Array

-- | Logical shape of a slice, i.e. the shape of the vector that
-- 'toVector' would produce from it.
shape :: T sh slice a -> slice
shape x =
   case x of
      Cons (Slice.Cons _start _inc slc) _ -> slc

-- | Apply a function to the logical shape of a slice,
-- keeping start offset, increment and storage untouched.
mapShape :: (slice0 -> slice1) -> T sh slice0 a -> T sh slice1 a
mapShape f x =
   case x of
      Cons (Slice.Cons start inc slc) arr ->
         Cons (Slice.Cons start inc (f slc)) arr

-- | Distance, in elements of the underlying storage,
-- between two consecutive elements of the slice.
increment :: T sh slice a -> Int
increment x =
   case x of
      Cons (Slice.Cons _start inc _slc) _ -> inc
-- | Marshal the address of the first slice element: the storage base
-- pointer advanced by the slice's start offset.
startArg ::
   (Storable a) =>
   T sh slice a -> Call.FortranIO r (Ptr a)
startArg (Cons (Slice.Cons start _inc _slc) (Array _sh fptr)) =
   (`advancePtr` start) <$> ContT (withForeignPtr fptr)

-- | Marshal the (start pointer, increment pointer) pair that BLAS
-- routines expect for a strided vector argument.
sliceArg ::
   (Storable a) =>
   T sh slice a -> Call.FortranIO r (Ptr a, Ptr CInt)
sliceArg x = (,) <$> startArg x <*> Call.cint (increment x)

-- | Like 'sliceArg' but additionally marshals the element count
-- (size of the logical slice shape), giving the usual (n, x, incx) triple.
sizeSliceArg ::
   (Shape.C sh, Storable a) =>
   T shA sh a -> ContT r IO (Ptr CInt, Ptr a, Ptr CInt)
sizeSliceArg x = do
   nPtr <- Call.cint $ Shape.size $ shape x
   (xPtr, incxPtr) <- sliceArg x
   return (nPtr, xPtr, incxPtr)

infixl 4 <*|>

-- | Applicative helper: feed both the start pointer and the increment
-- pointer of a slice to a function expecting them as two arguments.
(<*|>) ::
   (Storable a) =>
   Call.FortranIO r (Ptr a -> Ptr CInt -> b) ->
   T sh slice a ->
   Call.FortranIO r b
f <*|> x = uncurry <$> f <*> sliceArg x
-- | A strided view onto a storable vector.
--
-- The 'Slice.T' descriptor carries start offset, increment and the
-- logical shape of the view.
-- The second field is the complete underlying vector of shape @sh@.
data T sh slice a = Cons (Slice.T slice) (Vector sh a)
   deriving Show
-- | Gather the slice elements into a fresh contiguous vector
-- using BLAS @copy@ (destination increment 1).
toVector ::
   (Shape.C slice, Class.Floating a) =>
   T sh slice a -> Vector slice a
toVector x =
   Array.unsafeCreateWithSize (shape x) $ \n yPtr ->
      evalContT $ Call.run $
         Blas.copy <$> Call.cint n <*|> x <*> pure yPtr <*> Call.cint 1
-- | View a whole vector as a slice whose descriptor is built by
-- 'Slice.fromShape' from the vector's shape (presumably start 0,
-- increment 1 — see "Numeric.BLAS.Slice").
fromVector :: (Shape.C sh) => Vector sh a -> T sh sh a
fromVector xs = Cons (Slice.fromShape (Array.shape xs)) xs

-- | Transform the slice descriptor; the underlying storage is shared.
slice :: (Slice.T shA -> Slice.T shB) -> T sh shA a -> T sh shB a
slice f x =
   case x of
      Cons slc xs -> Cons (f slc) xs

-- | Copy out the part of a vector selected by a slice transformation.
extract ::
   (Shape.C slice, Shape.C sh, Class.Floating a) =>
   (Slice.T sh -> Slice.T slice) -> Vector sh a -> Vector slice a
extract f = toVector . slice f . fromVector
-- | Read a single slice element by its logical index.
--
-- The storage offset is @start + inc * offset slc ix@.
-- It is read through a 'Shape.Deferred' view of the underlying array,
-- using @Array.!@ from the Unchecked module.
access, (!) ::
   (Shape.C shA, Shape.Indexed sh, Storable a) =>
   T shA sh a -> Shape.Index sh -> a
access x ix =
   case x of
      Cons (Slice.Cons start inc slc) (Array sh fptr) ->
         let storageOffset = start + inc * Shape.offset slc ix
         in  Array (Shape.Deferred sh) fptr
               Array.! Shape.DeferredIndex storageOffset

(!) = access
-- | Bilinear dot product @sum_i x_i * y_i@.
--
-- For complex elements neither argument is conjugated; use 'inner'
-- for the conjugating variant.
dot ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   T shA sh a -> T shB sh a -> a
dot =
   runDot
      (Class.switchFloating
         (Dot dotReal) (Dot dotReal)
         (Dot dotComplex) (Dot dotComplex))
-- | Inner product that conjugates the first argument for complex elements.
-- For real elements it coincides with 'dot'.
inner ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   T shA sh a -> T shB sh a -> a
inner =
   runDot $
   Class.switchFloating
      (Dot dotReal)
      (Dot dotReal)
      (Dot (\x y -> innerComplex (toVector x) y))
      (Dot (\x y -> innerComplex (toVector x) y))

-- | Wrapper that lets 'Class.switchFloating' select a binary
-- vector-to-scalar function per element type.
newtype Dot f g a = Dot {runDot :: f a -> g a -> a}
-- | Real dot product via BLAS @dot@.
-- Shape equality is checked with 'Call.assert' first.
dotReal ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   T shA sh a -> T shB sh a -> a
dotReal x y = unsafePerformIO $ do
   Call.assert "dot: shapes mismatch" (shape x == shape y)
   evalContT $ do
      (nPtr, xPtr, incxPtr) <- sizeSliceArg x
      (yPtr, incyPtr) <- sliceArg y
      liftIO $ BlasReal.dot nPtr xPtr incxPtr yPtr incyPtr
-- | Unconjugated complex dot product computed with a single @gemv@ call.
--
-- @x@ is treated as a 1×n matrix whose leading dimension is the slice
-- increment. With trans 'N', alpha = 1 and beta = 0, the single result
-- entry is @sum_i x_i * y_i@.
dotComplex ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   T shA sh (Complex a) -> T shB sh (Complex a) -> Complex a
dotComplex x y = unsafePerformIO $ do
   Call.assert "dot: shapes mismatch" (shape x == shape y)
   evalContT $ do
      transPtr <- Call.char 'N'
      rowsPtr <- Call.cint 1
      colsPtr <- Call.cint $ Shape.size $ shape x
      alphaPtr <- Call.number Scalar.one
      -- the slice increment doubles as the matrix leading dimension
      (aPtr, ldaPtr) <- sliceArg x
      (yPtr, incyPtr) <- sliceArg y
      betaPtr <- Call.number Scalar.zero
      resultPtr <- Call.alloca
      incResultPtr <- Call.cint 1
      liftIO $ do
         Private.gemv
            transPtr rowsPtr colsPtr alphaPtr aPtr ldaPtr
            yPtr incyPtr betaPtr resultPtr incResultPtr
         peek resultPtr
-- | Conjugating complex inner product @sum_i conj(x_i) * y_i@.
--
-- The contiguous vector @x@ is treated as an m×1 matrix. A @gemv@ call
-- with trans 'C' multiplies its conjugate transpose with the slice @y@.
innerComplex ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Vector sh (Complex a) -> T shB sh (Complex a) -> Complex a
innerComplex (Array shX x) y = unsafePerformIO $ do
   Call.assert "dot: shapes mismatch" (shX == shape y)
   evalContT $ do
      transPtr <- Call.char 'C'
      rowsPtr <- Call.cint m
      colsPtr <- Call.cint 1
      alphaPtr <- Call.number Scalar.one
      aPtr <- ContT $ withForeignPtr x
      ldaPtr <- Call.leadingDim m
      (yPtr, incyPtr) <- sliceArg y
      betaPtr <- Call.number Scalar.zero
      resultPtr <- Call.alloca
      incResultPtr <- Call.cint 1
      liftIO $ do
         Private.gemv
            transPtr rowsPtr colsPtr alphaPtr aPtr ldaPtr
            yPtr incyPtr betaPtr resultPtr incResultPtr
         peek resultPtr
   where
      m = Shape.size shX
-- | Sum of all slice elements, delegated to 'Private.sum'
-- with the slice start pointer and increment.
sum :: (Shape.C sh, Class.Floating a) => T shA sh a -> a
sum x =
   unsafePerformIO $ evalContT $
      liftIO . (\ptr -> Private.sum n ptr inc) =<< startArg x
   where
      n = Shape.size $ shape x
      inc = increment x
-- | Sum of absolute values via BLAS @asum@ (real) / @casum@ (complex).
absSum :: (Shape.C sh, Class.Floating a) => T shA sh a -> RealOf a
absSum x = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incxPtr) <- sizeSliceArg x
   liftIO $ asum nPtr xPtr incxPtr

-- | Dispatch to the real or complex BLAS absolute-sum routine.
asum :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
asum =
   getNrm $
   Class.switchFloating
      (Nrm BlasReal.asum)
      (Nrm BlasReal.asum)
      (Nrm BlasComplex.casum)
      (Nrm BlasComplex.casum)
-- | Euclidean norm via BLAS @nrm2@ (real) / @cnrm2@ (complex).
norm2 :: (Shape.C sh, Class.Floating a) => T shA sh a -> RealOf a
norm2 x =
   unsafePerformIO $ evalContT $ do
      args <- sizeSliceArg x
      liftIO $ uncurry3 nrm2 args

-- | Dispatch to the real or complex BLAS 2-norm routine.
nrm2 :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
nrm2 =
   getNrm $
   Class.switchFloating
      (Nrm BlasReal.nrm2)
      (Nrm BlasReal.nrm2)
      (Nrm BlasComplex.cnrm2)
      (Nrm BlasComplex.cnrm2)

-- | Wrapper for a BLAS (n, x, incx) routine returning a real result,
-- so that 'Class.switchFloating' can pick one per element type.
newtype Nrm a = Nrm {getNrm :: Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)}
-- | Wrapper for a vector-to-real function, selectable via
-- 'Class.switchFloating'.
newtype Norm f a = Norm {getNorm :: f a -> RealOf a}

-- | Squared Euclidean norm, computed without a square root.
norm2Squared :: (Shape.C sh, Class.Floating a) => T shA sh a -> RealOf a
norm2Squared =
   getNorm
      (Class.switchFloating
         (Norm norm2SquaredReal) (Norm norm2SquaredReal)
         (Norm norm2SquaredComplex) (Norm norm2SquaredComplex))
-- | Squared 2-norm of a real slice: BLAS @dot@ of the slice with itself.
norm2SquaredReal :: (Shape.C sh, Class.Real a) => T shA sh a -> a
norm2SquaredReal x =
   unsafePerformIO $ evalContT $ do
      (nPtr, xPtr, incPtr) <- sizeSliceArg x
      liftIO $ BlasReal.dot nPtr xPtr incPtr xPtr incPtr
-- | Squared 2-norm of a complex slice: sum of squares of the real parts
-- plus sum of squares of the imaginary parts.
--
-- The complex storage is read as reals with doubled stride.
-- The real parts start at 'realPtr' and the imaginary parts one real later.
norm2SquaredComplex :: (Shape.C sh, Class.Real a) => T shA sh (Complex a) -> a
norm2SquaredComplex x =
   unsafePerformIO $ evalContT $ do
      nPtr <- Call.cint $ Shape.size $ shape x
      startPtr <- startArg x
      incPtr <- Call.cint (2 * increment x)
      let rePtr = realPtr startPtr
          imPtr = advancePtr rePtr 1
          sumOfSquares p = BlasReal.dot nPtr p incPtr p incPtr
      liftIO $ liftA2 (+) (sumOfSquares rePtr) (sumOfSquares imPtr)
-- | Maximum absolute value of the slice elements, or zero for an empty slice.
normInf :: (Shape.C sh, Class.Floating a) => T shA sh a -> RealOf a
normInf x =
   unsafePerformIO $
      Scalar.absolute . maybe Scalar.zero snd <$> absMax x
-- | 'Scalar.norm1' of the entry selected by BLAS @iamax@, or zero for an
-- empty slice.
-- Per the BLAS spec, i?amax ranks complex entries by |re| + |im|.
normInf1 :: (Shape.C sh, Class.Floating a) => T shA sh a -> RealOf a
normInf1 x = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incxPtr) <- sizeSliceArg x
   liftIO $ do
      k1 <- Blas.iamax nPtr xPtr incxPtr
      found <- peekElemOff1 xPtr (increment x) k1
      return (Scalar.norm1 (maybe Scalar.zero snd found))
-- | Logical index and value of an entry with maximal absolute value.
-- Calls 'error' when the slice is empty.
argAbsMaximum ::
   (Shape.InvIndexed sh, Class.Floating a) =>
   T shA sh a -> (Shape.Index sh, a)
argAbsMaximum x =
   unsafePerformIO $ do
      found <- absMax x
      return $
         case found of
            Nothing -> error "Vector.argAbsMaximum: empty vector"
            Just p -> mapFst (Shape.uncheckedIndexFromOffset (shape x)) p
-- | Zero-based position and value of an entry with maximal absolute value.
-- Returns 'Nothing' for an empty slice.
--
-- Real slices use BLAS @iamax@ directly.
-- Complex slices go through 'absMaxComplex', which ranks by squared
-- modulus rather than BLAS's |re| + |im|.
absMax ::
   (Shape.C sh, Class.Floating a) =>
   T shA sh a -> IO (Maybe (Int, a))
absMax x =
   case Scalar.complexSingletonOfFunctor x of
      Scalar.Real -> evalContT $ do
         (nPtr, xPtr, incxPtr) <- sizeSliceArg x
         liftIO $ peekElemOff1 xPtr inc =<< Blas.iamax nPtr xPtr incxPtr
      Scalar.Complex -> evalContT $ do
         xPtr <- startArg x
         liftIO $
            peekElemOff1 xPtr inc
               =<< absMaxComplex (Shape.size (shape x)) xPtr inc
   where
      inc = increment x
-- | One-based BLAS-style index of the complex entry with largest squared
-- modulus (0 when @n@ is 0).
--
-- Computes @re^2 + im^2@ for every entry into a temporary real buffer,
-- then applies @iamax@ to that buffer.
-- NOTE(review): squaring may overflow for very large magnitudes.
absMaxComplex :: (Class.Real a) => Int -> Ptr (Complex a) -> Int -> IO CInt
absMaxComplex n xPtr incx =
   ForeignArray.alloca n $ \sqPtr -> do
      let rePtr = realPtr xPtr
          imPtr = advancePtr rePtr 1
          stride = 2 * incx
      -- sq := re .* re
      Private.mul NonConjugated n rePtr stride rePtr stride sqPtr 1
      -- sq := im .* im + sq
      Private.mulAdd NonConjugated n imPtr stride imPtr stride Scalar.one sqPtr 1
      evalContT $ do
         nPtr <- Call.cint n
         incPtr <- Call.cint 1
         liftIO $ Blas.iamax nPtr sqPtr incPtr
-- | Logical index and value of the entry selected by BLAS @iamax@
-- (for complex entries BLAS ranks by |re| + |im|).
-- Calls 'error' when the slice is empty.
argAbs1Maximum ::
   (Shape.InvIndexed sh, Class.Floating a) =>
   T shA sh a -> (Shape.Index sh, a)
argAbs1Maximum x = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incxPtr) <- sizeSliceArg x
   k1 <- liftIO $ Blas.iamax nPtr xPtr incxPtr
   found <- liftIO $ peekElemOff1 xPtr (increment x) k1
   return $
      maybe
         (error "Vector.argAbs1Maximum: empty vector")
         (mapFst (Shape.uncheckedIndexFromOffset (shape x)))
         found
-- | Read the element at a one-based BLAS index from a strided array.
--
-- Index 0 (what BLAS i?amax reports for an empty vector) yields 'Nothing'.
-- Otherwise the zero-based position is returned together with the
-- element at offset @position * inc@.
peekElemOff1 :: (Storable a) => Ptr a -> Int -> CInt -> IO (Maybe (Int, a))
peekElemOff1 ptr inc k1 =
   case fromIntegral k1 of
      0 -> return Nothing
      k1i ->
         let k = k1i - 1
         in  fmap (\v -> Just (k, v)) (peekElemOff ptr (k * inc))
-- | Product of all slice elements, delegated to 'Private.product'
-- with the slice start pointer and increment.
product :: (Shape.C sh, Class.Floating a) => T shA sh a -> a
product x =
   unsafePerformIO $ evalContT $
      liftIO . (\ptr -> Private.product n ptr inc) =<< startArg x
   where
      n = Shape.size $ shape x
      inc = increment x
-- | Smallest / largest element, both taken from 'limits'.
minimum, maximum :: (Shape.C shA, Shape.C sh, Class.Real a) => T shA sh a -> a
minimum x = fst (limits x)
maximum x = snd (limits x)

-- | Index and value of the smallest / largest element, both taken from
-- 'argLimits'.
argMinimum, argMaximum ::
   (Shape.C shA, Shape.InvIndexed sh, Shape.Index sh ~ ix, Class.Real a) =>
   T shA sh a -> (ix,a)
argMinimum x = fst (argLimits x)
argMaximum x = snd (argLimits x)
-- | (minimum, maximum) of a real slice, using two BLAS @iamax@ passes.
--
-- The entry @x0@ of largest magnitude is one of the two extremes.
-- After shifting every entry by @-x0@, the entry of largest magnitude is
-- the one farthest from @x0@, i.e. the other extreme.
-- The sign of @x0@ decides which extreme is which.
-- An empty slice triggers the error from 'argAbs1Maximum'.
limits :: (Shape.C shA, Shape.C sh, Class.Real a) => T shA sh a -> (a,a)
limits xs0 =
   let xs = mapShape Shape.Deferred xs0
       x0 = snd (argAbs1Maximum xs)
       farthest = argAbs1Maximum (fromVector (raise (-x0) xs))
       x1 = xs ! fst farthest
   in  if x0 >= 0 then (x1, x0) else (x0, x1)
-- | Index/value pairs of the minimum and maximum, using the same two-pass
-- idea as 'limits'. Returns @((argmin, min), (argmax, max))@.
argLimits ::
   (Shape.C shA, Shape.InvIndexed sh, Shape.Index sh ~ ix, Class.Real a) =>
   T shA sh a -> ((ix,a),(ix,a))
argLimits xs =
   let p0 = argAbs1Maximum xs
       x0 = snd p0
       i1 = fst $ argAbs1Maximum $ fromVector $ raise (-x0) xs
       p1 = (i1, xs ! i1)
   in  if x0 >= 0 then (p1, p0) else (p0, p1)
-- | Multiply every slice element by a scalar, producing a contiguous vector.
--
-- 'scale' copies the slice with BLAS @copy@ and then applies @scal@.
-- '_scale' is an alternative via a one-row @gemm@; it is not referenced
-- elsewhere in the visible code.
scale, _scale ::
   (Shape.C sh, Class.Floating a) =>
   a -> T shA sh a -> Vector sh a
scale alpha x =
   Array.unsafeCreateWithSize (shape x) $ \n yPtr -> evalContT $ do
      alphaPtr <- Call.number alpha
      nPtr <- Call.cint n
      (xPtr, incxPtr) <- sliceArg x
      incyPtr <- Call.cint 1
      liftIO $ do
         Blas.copy nPtr xPtr incxPtr yPtr incyPtr
         Blas.scal nPtr alphaPtr yPtr incyPtr

-- C (1×n) := 1 * A (1×1, holding the factor) * B (1×n, the slice)
_scale factor x =
   Array.unsafeCreateWithSize (shape x) $ \n cPtr -> do
      let m = 1
          k = 1
      evalContT $ do
         transaPtr <- Call.char 'N'
         transbPtr <- Call.char 'N'
         mPtr <- Call.cint m
         kPtr <- Call.cint k
         nPtr <- Call.cint n
         alphaPtr <- Call.number Scalar.one
         aPtr <- Call.number factor
         ldaPtr <- Call.leadingDim m
         -- slice increment serves as the leading dimension of B
         (bPtr, ldbPtr) <- sliceArg x
         betaPtr <- Call.number Scalar.zero
         ldcPtr <- Call.leadingDim m
         liftIO $
            Blas.gemm
               transaPtr transbPtr mPtr nPtr kPtr alphaPtr
               aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
-- | Scale by a real factor.
-- For complex element types the factor is first lifted with
-- 'Scalar.fromReal'.
scaleReal ::
   (Shape.C sh, Class.Floating a) =>
   RealOf a -> T shA sh a -> Vector sh a
scaleReal =
   getScaleReal $
   Class.switchFloating
      (ScaleReal scale)
      (ScaleReal scale)
      (ScaleReal (\r -> scale (Scalar.fromReal r)))
      (ScaleReal (\r -> scale (Scalar.fromReal r)))

-- | Wrapper letting 'Class.switchFloating' pick a real-factor scaling
-- function per element type.
newtype ScaleReal f g a = ScaleReal {getScaleReal :: RealOf a -> f a -> g a}
infixl 6 `add`, `sub`

-- | Element-wise sum @x + y@ and difference @x - y@, both via 'mac'.
add, sub ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   T shA sh a -> T shB sh a -> Vector sh a
add x y = mac Scalar.one x y
-- x - y  ==  (-1) * y + x
sub x y = mac Scalar.minusOne y x
-- | Multiply-accumulate @alpha * x + y@ into a fresh contiguous vector.
--
-- @y@ is copied into the result with BLAS @copy@, then @axpy@ adds
-- @alpha * x@. Shape equality is checked with 'Call.assert'.
mac ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   a -> T shA sh a -> T shB sh a -> Vector sh a
mac alpha x y =
   Array.unsafeCreateWithSize (shape x) $ \n zPtr -> do
      Call.assert "mac: shapes mismatch" (shape x == shape y)
      evalContT $ do
         nPtr <- Call.cint n
         alphaPtr <- Call.number alpha
         (xPtr, incxPtr) <- sliceArg x
         (yPtr, incyPtr) <- sliceArg y
         inczPtr <- Call.cint 1
         liftIO $ do
            Blas.copy nPtr yPtr incyPtr zPtr inczPtr
            Blas.axpy nPtr alphaPtr xPtr incxPtr zPtr inczPtr
-- | Element-wise negation, implemented as 'scaleReal' by -1.
-- The element-type switch is needed so that @Scalar.minusOne@ can be
-- typed at the concrete @RealOf a@.
negate :: (Shape.C sh, Class.Floating a) => T shA sh a -> Vector sh a
negate =
   getConjugate
      (Class.switchFloating
         (Conjugate (scaleReal Scalar.minusOne))
         (Conjugate (scaleReal Scalar.minusOne))
         (Conjugate (scaleReal Scalar.minusOne))
         (Conjugate (scaleReal Scalar.minusOne)))
-- | Add the constant @c@ to every element.
--
-- The slice is copied with BLAS @copy@. Then @axpy@ with alpha = 1 reads
-- @c@ through an increment of 0, so every entry gets the same constant.
raise :: (Shape.C sh, Class.Floating a) => a -> T shA sh a -> Vector sh a
raise c x =
   Array.unsafeCreateWithSize (shape x) $ \n yPtr -> evalContT $ do
      nPtr <- Call.cint n
      (xPtr, incxPtr) <- sliceArg x
      cPtr <- Call.number c
      zeroIncPtr <- Call.cint 0
      onePtr <- Call.number Scalar.one
      unitIncPtr <- Call.cint 1
      liftIO $ do
         Blas.copy nPtr xPtr incxPtr yPtr unitIncPtr
         Blas.axpy nPtr onePtr cPtr zeroIncPtr yPtr unitIncPtr
-- | Element-wise product of two slices of equal shape.
mul ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   T shA sh a -> T shB sh a -> Vector sh a
mul a x = mulConjugation NonConjugated a x

-- | Element-wise product passing 'Conjugated' to 'Private.mul'.
-- Presumably this conjugates the first operand; confirm in
-- "Numeric.BLAS.Private".
mulConj ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   T shA sh a -> T shB sh a -> Vector sh a
mulConj a x = mulConjugation Conjugated a x

-- | Shared implementation of 'mul' and 'mulConj'.
-- Shape equality is checked with 'Call.assert'.
mulConjugation ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Conjugation -> T shA sh a -> T shB sh a -> Vector sh a
mulConjugation conj a x =
   Array.unsafeCreateWithSize (shape x) $ \n yPtr -> do
      Call.assert "mul: shapes mismatch" (shape a == shape x)
      evalContT $ do
         (aPtr, xPtr) <- liftA2 (,) (startArg a) (startArg x)
         liftIO $
            Private.mul conj n aPtr (increment a) xPtr (increment x) yPtr 1
-- | Wrapper letting 'Class.switchFloating' pick a unary vector function
-- per element type.
newtype Conjugate f g a = Conjugate {getConjugate :: f a -> g a}

-- | Complex conjugate of every element into a contiguous vector.
-- For real element types this is just a copy ('toVector').
conjugate ::
   (Shape.C sh, Class.Floating a) =>
   T shA sh a -> Vector sh a
conjugate =
   getConjugate
      (Class.switchFloating
         (Conjugate toVector) (Conjugate toVector)
         (Conjugate complexConjugate) (Conjugate complexConjugate))
-- | Copy a complex slice into a contiguous vector while conjugating,
-- via 'copyConjugate' from "Numeric.BLAS.Private".
complexConjugate ::
   (Shape.C sh, Class.Real a) =>
   T shA sh (Complex a) -> Vector sh (Complex a)
complexConjugate x =
   Array.unsafeCreateWithSize (shape x) $ \n yPtr ->
      evalContT $ Call.run $
         copyConjugate <$> Call.cint n <*|> x <*> pure yPtr <*> Call.cint 1
-- | Embed a slice of real parts into the element type @a@.
-- For real @a@ this is a copy; for complex @a@ the imaginary parts are zero.
fromReal ::
   (Shape.C sh, Class.Floating a) => T shA sh (RealOf a) -> Vector sh a
fromReal =
   getFromReal
      (Class.switchFloating
         (FromReal toVector) (FromReal toVector)
         (FromReal complexFromReal) (FromReal complexFromReal))

-- | Wrapper letting 'Class.switchFloating' pick a real-to-@a@ conversion.
newtype FromReal f g a = FromReal {getFromReal :: f (RealOf a) -> g a}
-- | Convert to the complex counterpart of the element type.
-- Real slices get zero imaginary parts; complex slices are copied.
toComplex ::
   (Shape.C sh, Class.Floating a) => T shA sh a -> Vector sh (ComplexOf a)
toComplex =
   getToComplex
      (Class.switchFloating
         (ToComplex complexFromReal) (ToComplex complexFromReal)
         (ToComplex toVector) (ToComplex toVector))

-- | Wrapper letting 'Class.switchFloating' pick an @a@-to-complex conversion.
newtype ToComplex f g a = ToComplex {getToComplex :: f a -> g (ComplexOf a)}
-- | Build a complex vector whose real parts come from the slice and whose
-- imaginary parts are zero.
--
-- The result is written as interleaved reals with stride 2: one BLAS
-- @copy@ fills the real parts; a second copies a single zero
-- (increment 0) into every imaginary part.
complexFromReal ::
   (Shape.C sh, Class.Real a) => T shA sh a -> Vector sh (Complex a)
complexFromReal x =
   Array.unsafeCreateWithSize (shape x) $ \n yPtr -> evalContT $ do
      nPtr <- Call.cint n
      (xPtr, incxPtr) <- sliceArg x
      stridePtr <- Call.cint 2
      zeroIncPtr <- Call.cint 0
      zeroPtr <- Call.number Scalar.zero
      let rePtr = realPtr yPtr
          imPtr = advancePtr rePtr 1
      liftIO $ do
         Blas.copy nPtr xPtr incxPtr rePtr stridePtr
         Blas.copy nPtr zeroPtr zeroIncPtr imPtr stridePtr
-- | Reinterpret a complex vector as a real slice over (shape, ComplexShape).
-- No data is copied; the storage pointer is cast with 'castForeignPtr'.
realFromComplexVector ::
   (Shape.C sh) =>
   Vector sh (Complex a) -> T (sh, ComplexShape) (sh, ComplexShape) a
realFromComplexVector (Array sh ptr) =
   Cons (Slice.fromShape pairedShape) (Array pairedShape (castForeignPtr ptr))
   where
      pairedShape = (sh, Shape.static)
-- | View the real parts of a complex slice as a real slice (no copy).
realPart ::
   (Shape.C sh, Class.Real a) =>
   T shA sh (Complex a) -> T (shA, ComplexShape) sh a
realPart = complexComponent 0

-- | View the imaginary parts of a complex slice as a real slice (no copy).
imaginaryPart ::
   (Shape.C sh, Class.Real a) =>
   T shA sh (Complex a) -> T (shA, ComplexShape) sh a
imaginaryPart = complexComponent 1

-- | Shared implementation of 'realPart' and 'imaginaryPart'.
--
-- The storage is reinterpreted as reals via 'castForeignPtr', so start
-- and increment are doubled.
-- @part@ selects the component offset: 0 for real, 1 for imaginary.
complexComponent :: Int -> T shA sh (Complex a) -> T (shA, ComplexShape) sh a
complexComponent part (Cons (Slice.Cons start inc slc) (Array sh ptr)) =
   Cons
      (Slice.Cons (2*start + part) (2*inc) slc)
      (Array (sh, Shape.static) (castForeignPtr ptr))
-- | Combine two real slices of equal shape into a complex vector.
--
-- Shape equality is checked with 'Call.assert'. Two BLAS @copy@ calls
-- with destination stride 2 fill the interleaved real and imaginary parts.
zipComplex ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   T shA sh a -> T shB sh a -> Vector sh (Complex a)
zipComplex xr xi =
   Array.unsafeCreateWithSize (shape xr) $ \n yPtr -> do
      Call.assert "zipComplex: shapes mismatch" (shape xr == shape xi)
      evalContT $ do
         nPtr <- Call.cint n
         (xrPtr, incxrPtr) <- sliceArg xr
         (xiPtr, incxiPtr) <- sliceArg xi
         incyPtr <- Call.cint 2
         let yrPtr = realPtr yPtr
             yiPtr = advancePtr yrPtr 1
         liftIO $ do
            Blas.copy nPtr xrPtr incxrPtr yrPtr incyPtr
            Blas.copy nPtr xiPtr incxiPtr yiPtr incyPtr
-- | Split a complex slice into real-part and imaginary-part views (no copy).
unzipComplex ::
   (Shape.C sh, Class.Real a) =>
   T shA sh (Complex a) ->
   (T (shA,ComplexShape) sh a, T (shA,ComplexShape) sh a)
unzipComplex = liftA2 (,) realPart imaginaryPart