{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Vector (
   Vector,
   RealOf,
   ComplexOf,
   toList,
   fromList,
   autoFromList,
   CheckedArray.append, (+++),
   CheckedArray.take, CheckedArray.drop,
   CheckedArray.takeLeft, CheckedArray.takeRight,
   swap,
   CheckedArray.singleton,
   constant,
   zero,
   one,
   unit,
   dot, inner, (-*|),
   sum,
   absSum,
   norm1,
   norm2,
   norm2Squared,
   normInf,
   normInf1,
   argAbsMaximum,
   argAbs1Maximum,
   product,
   scale, scaleReal, (.*|),
   add, sub, (|+|), (|-|),
   negate, raise,
   mac,
   mul,
   divide, recip,
   minimum, argMinimum,
   maximum, argMaximum,
   limits, argLimits,
   CheckedArray.foldl,
   CheckedArray.foldl1,
   CheckedArray.foldMap,

   conjugate,
   fromReal,
   toComplex,
   realPart,
   imaginaryPart,
   zipComplex,
   unzipComplex,

   random, RandomDistribution(..),
   ) where

import qualified Numeric.LAPACK.Matrix.RowMajor as RowMajor
import qualified Numeric.LAPACK.Vector.Private as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Hermitian.Private
         (Determinant(Determinant, getDeterminant))
import Numeric.LAPACK.Linear.Private (withInfo)
import Numeric.LAPACK.Scalar (ComplexOf, RealOf, minusOne, absolute)
import Numeric.LAPACK.Private
         (ComplexPart(RealPart, ImaginaryPart), fill, copyConjugate, realPtr)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peek, peekElemOff, pokeElemOff)
import Foreign.C.Types (CInt)

import System.IO.Unsafe (unsafePerformIO)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.ST (runST)
import Control.Monad (fmap, return, (=<<))
import Control.Applicative (liftA3, (<$>))

import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as UMutArray
import qualified Data.Array.Comfort.Storable.Mutable as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), append, (!))
import Data.Array.Comfort.Shape ((:+:))

import Data.Function (id, flip, ($), (.))
import Data.Complex (Complex)
import Data.Maybe (Maybe(Nothing,Just), maybe)
import Data.Tuple.HT (mapFst, uncurry3)
import Data.Tuple (fst, snd)
import Data.Word (Word64)
import Data.Bits (shiftR, (.&.))
import Data.Ord (Ord, (>=))
import Data.Eq (Eq, (==))
import Data.Bool (Bool(False,True))

import Prelude (Int, fromIntegral, (+), (-), (*), Char, Show, Enum, error, IO)


{- |
Vectors are ordinary storable arrays from comfort-array;
the shape @sh@ determines the number of elements.
The synonym deliberately has no type parameters
so that it can be used partially applied, e.g. as @Vector sh@.
-}
type Vector = Array


-- | All elements of the vector as a list.
toList :: (Shape.C sh, Storable a) => Vector sh a -> [a]
toList xs = Array.toList xs

{- |
Build a vector of the given shape from a list,
using the checked 'CheckedArray.fromList'.
-}
fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Vector sh a
fromList sh xs = CheckedArray.fromList sh xs

{- |
Build a vector with a zero-based shape that is derived from the list itself.
-}
autoFromList :: (Storable a) => [a] -> Vector (Shape.ZeroBased Int) a
autoFromList xs = Array.vectorFromList xs


{- |
Vector of the given shape with every element set to the same value.

> constant () = singleton

However, singleton does not need 'Class.Floating' constraint.
-}
constant :: (Shape.C sh, Class.Floating a) => sh -> a -> Vector sh a
constant sh a =
   let fillAll size ptr = fill a size ptr
   in Array.unsafeCreateWithSize sh fillAll

-- | Vector of the given shape with all elements zero.
zero :: (Shape.C sh, Class.Floating a) => sh -> Vector sh a
zero sh = constant sh Scalar.zero

-- | Vector of the given shape with all elements one.
one :: (Shape.C sh, Class.Floating a) => sh -> Vector sh a
one sh = constant sh Scalar.one

{- |
Unit vector: zero everywhere except for a one at the given index.
The index is translated to a storage offset with 'Shape.offset'.
-}
unit ::
   (Shape.Indexed sh, Class.Floating a) =>
   sh -> Shape.Index sh -> Vector sh a
unit sh ix =
   Array.unsafeCreateWithSize sh $ \size ptr -> do
      -- clear the whole array first, then set the single one
      fill Scalar.zero size ptr
      pokeElemOff ptr (Shape.offset sh ix) Scalar.one


infixr 5 +++

{- |
Concatenate two vectors; the shape of the result is @shx:+:shy@.

The fixity is that of (List.++), i.e. right-associative with precedence 5,
which also matches '(Shape.:+:)'.
-}
(+++) ::
   (Shape.C shx, Shape.C shy, Storable a) =>
   Vector shx a -> Vector shy a -> Vector (shx:+:shy) a
xs +++ ys = append xs ys


{- |
Exchange the elements at two indices.
The input is thawed into a fresh mutable copy,
so the argument vector itself is not modified.
-}
swap ::
   (Shape.Indexed sh, Storable a) =>
   Shape.Index sh -> Shape.Index sh -> Vector sh a -> Vector sh a
swap i j x =
   runST (do
      arr <- MutArray.thaw x
      atI <- MutArray.read arr i
      atJ <- MutArray.read arr j
      MutArray.write arr j atI
      MutArray.write arr i atJ
      UMutArray.unsafeFreeze arr)


infixl 7 -*|, .*|

-- Wrapper of a two-vector reduction for the dispatch via 'Class.switchFloating'.
newtype Dot f a = Dot {runDot :: f a -> f a -> a}

{- |
Sum of the element-wise products, without complex conjugation.

> dot x y = Matrix.toScalar (singleRow x <#> singleColumn y)
-}
dot, (-*|) ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> a
dot = dotWithTransposition 'T'
(-*|) = dot

{- |
Scalar product that conjugates the first argument.

> inner x y = dot (conjugate x) y
-}
inner ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> a
inner = dotWithTransposition 'C'

{-
Shared implementation of 'dot' and 'inner'.
Real vectors always use the BLAS dot product.
For complex vectors the character selects plain ('T')
or conjugate ('C') transposition of the first vector.
-}
dotWithTransposition ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Char -> Vector sh a -> Vector sh a -> a
dotWithTransposition trans =
   runDot $
   Class.switchFloating
      (Dot dotReal)
      (Dot dotReal)
      (Dot $ dotComplex trans)
      (Dot $ dotComplex trans)

-- BLAS dot product of two real vectors; the shapes must be equal.
dotReal ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Vector sh a -> Vector sh a -> a
dotReal xs@(Array shX _) ys@(Array shY _) =
   unsafePerformIO $ do
      Call.assert "dot: shapes mismatch" (shX == shY)
      evalContT $ do
         (nPtr, xPtr, incxPtr) <- vectorArgs xs
         -- the element count of ys equals that of xs after the assertion
         (_, yPtr, incyPtr) <- vectorArgs ys
         liftIO $ BlasReal.dot nPtr xPtr incxPtr yPtr incyPtr

{-
Complex dot product expressed as a matrix-vector product:
the first vector is treated as an m×1 matrix,
whose transpose ('T') or conjugate transpose ('C')
is multiplied by the second vector.
We cannot use 'cdot' because Haskell's FFI
does not support Complex numbers as return values.
-}
dotComplex ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Char -> Vector sh (Complex a) -> Vector sh (Complex a) -> Complex a
dotComplex trans (Array shX x) (Array shY y) = unsafePerformIO $ do
   Call.assert "dot: shapes mismatch" (shX == shY)
   let m = Shape.size shX
   evalContT $ do
      transPtr <- Call.char trans
      mPtr <- Call.cint m
      nPtr <- Call.cint 1
      -- x as an m×1 matrix with leading dimension m
      xPtr <- ContT $ withForeignPtr x
      ldxPtr <- Call.leadingDim m
      yPtr <- ContT $ withForeignPtr y
      incyPtr <- Call.cint 1
      -- z := 1 * op(x) * y + 0 * z, a single complex number
      alphaPtr <- Call.number Scalar.one
      betaPtr <- Call.number Scalar.zero
      zPtr <- Call.alloca
      inczPtr <- Call.cint 1
      liftIO $ do
         Private.gemv
            transPtr mPtr nPtr alphaPtr xPtr ldxPtr
            yPtr incyPtr betaPtr zPtr inczPtr
         peek zPtr

-- | Sum of all elements.
sum :: (Shape.C sh, Class.Floating a) => Vector sh a -> a
sum (Array sh x) =
   let n = Shape.size sh
   in unsafePerformIO $ withForeignPtr x $ \ptr -> Private.sum n ptr 1

{- |
Sum of the absolute values of the elements.
For complex vectors this uses LAPACK's @sum1@ routine,
whereas 'absSum' adds the absolute values of the components.
-}
norm1 :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
norm1 arr = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incPtr) <- vectorArgs arr
   liftIO $ csum1 nPtr xPtr incPtr

-- Per-type kernel for 'norm1': BLAS asum for reals, LAPACK sum1 for complex.
csum1 :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
csum1 =
   getNorm $
   Class.switchFloating
      (Norm BlasReal.asum) (Norm BlasReal.asum)
      (Norm LapackComplex.sum1) (Norm LapackComplex.sum1)


{- |
Sum of the absolute values of real numbers or components of complex numbers.
For real numbers it is equivalent to 'norm1'.
-}
absSum :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
absSum arr = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incPtr) <- vectorArgs arr
   liftIO $ asum nPtr xPtr incPtr

-- Per-type BLAS kernel for 'absSum': asum for reals, casum for complex.
asum :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
asum =
   getNorm $
   Class.switchFloating
      (Norm BlasReal.asum)
      (Norm BlasReal.asum)
      (Norm BlasComplex.casum)
      (Norm BlasComplex.casum)


{- |
Euclidean norm of a vector or Frobenius norm of a matrix.
-}
norm2 :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
norm2 arr = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incPtr) <- vectorArgs arr
   liftIO $ nrm2 nPtr xPtr incPtr

-- Per-type BLAS kernel for 'norm2': nrm2 for reals, cnrm2 for complex.
nrm2 :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
nrm2 =
   getNorm $
   Class.switchFloating
      (Norm BlasReal.nrm2)
      (Norm BlasReal.nrm2)
      (Norm BlasComplex.cnrm2)
      (Norm BlasComplex.cnrm2)

{-
Wrapper of a BLAS-style reduction that takes element count,
data pointer and stride and yields a real number;
used for the type dispatch via 'Class.switchFloating'.
-}
newtype Norm a =
   Norm {getNorm :: Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)}


{- |
Squared Euclidean norm, computed as the dot product of the vector with itself,
i.e. without taking a square root.
Complex vectors are treated as real vectors of twice the length.
-}
norm2Squared :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
norm2Squared =
   -- 'Determinant' is only borrowed as a wrapper of shape @f a -> RealOf a@
   -- for the type dispatch.
   getDeterminant $
   Class.switchFloating
      (Determinant norm2SquaredReal)
      (Determinant norm2SquaredReal)
      (Determinant $ \v -> norm2SquaredReal (decomplex v))
      (Determinant $ \v -> norm2SquaredReal (decomplex v))

norm2SquaredReal :: (Shape.C sh, Class.Real a) => Vector sh a -> a
norm2SquaredReal arr = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incPtr) <- vectorArgs arr
   -- the same array serves as both operands of the dot product
   liftIO $ BlasReal.dot nPtr xPtr incPtr xPtr incPtr


{- |
Maximum norm: the largest absolute value among the elements.
Yields zero if the search reports no element, as for an empty vector.
-}
normInf :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
normInf arr = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incPtr) <- vectorArgs arr
   liftIO $ do
      k <- Vector.absMax nPtr xPtr incPtr
      found <- peekElemOff1 xPtr k
      return $ absolute $ maybe Scalar.zero snd found

{- |
Computes (almost) the infinity norm of the vector.
For complex numbers every element is replaced
by the sum of the absolute component values first.
-}
normInf1 :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
normInf1 arr = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incPtr) <- vectorArgs arr
   liftIO $ do
      k <- BlasGen.iamax nPtr xPtr incPtr
      found <- peekElemOff1 xPtr k
      return $ Scalar.norm1 $ maybe Scalar.zero snd found


{- |
Returns the index and value of the element with the maximal absolute value.
Caution: It actually returns the value of the element, not its absolute value!
-}
argAbsMaximum ::
   (Shape.InvIndexed sh, Class.Floating a) =>
   Vector sh a -> (Shape.Index sh, a)
argAbsMaximum =
   argAbsMaximumWith "Vector.argAbsMaximum: empty vector" Vector.absMax


{- |
Returns the index and value of the element with the maximal absolute value.
The function does not strictly compare the absolute value of a complex number
but the sum of the absolute complex components.
Caution: It actually returns the value of the element, not its absolute value!
-}
argAbs1Maximum ::
   (Shape.InvIndexed sh, Class.Floating a) =>
   Vector sh a -> (Shape.Index sh, a)
argAbs1Maximum =
   argAbsMaximumWith "Vector.argAbs1Maximum: empty vector" BlasGen.iamax

{-
Common part of 'argAbsMaximum' and 'argAbs1Maximum'.
The search routine returns a one-based position, with 0 meaning "no element".
That position is turned back into a shape index.
The error for an empty vector is only raised when the result is demanded.
-}
argAbsMaximumWith ::
   (Shape.InvIndexed sh, Storable a) =>
   [Char] -> (Ptr CInt -> Ptr a -> Ptr CInt -> IO CInt) ->
   Vector sh a -> (Shape.Index sh, a)
argAbsMaximumWith emptyMsg search arr = unsafePerformIO $ evalContT $ do
   (nPtr, xPtr, incPtr) <- vectorArgs arr
   liftIO $ do
      k <- search nPtr xPtr incPtr
      found <- peekElemOff1 xPtr k
      return $
         maybe
            (error emptyMsg)
            (mapFst (Shape.uncheckedIndexFromOffset (Array.shape arr)))
            found

{-
Standard arguments of a BLAS vector routine:
element count, pointer to the first element, and stride 1.
The data pointer stays valid for the rest of the 'ContT' computation.
-}
vectorArgs ::
   (Shape.C sh) => Array sh a -> ContT r IO (Ptr CInt, Ptr a, Ptr CInt)
vectorArgs (Array sh x) = do
   nPtr <- Call.cint $ Shape.size sh
   xPtr <- ContT $ withForeignPtr x
   incPtr <- Call.cint 1
   return (nPtr, xPtr, incPtr)

{-
Interpret a one-based position, as returned by BLAS i?amax.
Position 0 means "no element" and yields 'Nothing'.
Otherwise the result is the zero-based offset paired with the element stored there.
-}
peekElemOff1 :: (Storable a) => Ptr a -> CInt -> IO (Maybe (Int, a))
peekElemOff1 ptr k1 =
   case fromIntegral k1 of
      0 -> return Nothing
      k -> fmap (\x -> Just (k-1, x)) $ peekElemOff ptr (k-1)


-- | Product of all elements.
product :: (Shape.C sh, Class.Floating a) => Vector sh a -> a
product (Array sh x) =
   let n = Shape.size sh
   in unsafePerformIO $ withForeignPtr x $ \ptr -> Private.product n ptr 1


{- |
Smallest and largest element, respectively.
For restrictions see 'limits'.
-}
minimum, maximum :: (Shape.C sh, Class.Real a) => Vector sh a -> a
minimum xs = fst (limits xs)
maximum xs = snd (limits xs)

{- |
Index and value of the smallest and largest element, respectively.
For restrictions see 'limits'.
-}
argMinimum, argMaximum ::
   (Shape.InvIndexed sh, Shape.Index sh ~ ix, Class.Real a) =>
   Vector sh a -> (ix,a)
argMinimum xs = fst (argLimits xs)
argMaximum xs = snd (argLimits xs)

{- |
It should hold @limits x = Array.limits x@.
The function is based on fast BLAS functions.
It should be faster than @Array.minimum@ and @Array.maximum@
although it is certainly not as fast as possible.
It is less precise if minimum and maximum differ considerably in magnitude
and there are several minimum or maximum candidates of similar value.
E.g. you cannot rely on the property
that @raise (- minimum x) x@ has only non-negative elements.
-}
limits :: (Shape.C sh, Class.Real a) => Vector sh a -> (a,a)
limits xs0 =
   if x0 >= 0 then (x1, x0) else (x0, x1)
   where
      -- 'Shape.Deferred' provides the index inversion that
      -- 'argAbsMaximum' requires, for any shape
      xs = Array.mapShape Shape.Deferred xs0
      -- one extreme: the element of largest magnitude
      x0 = snd (argAbsMaximum xs)
      -- the other extreme: the element farthest away from x0
      x1 = xs ! fst (argAbsMaximum (raise (-x0) xs))

{- |
Index and value of the minimum and of the maximum, in this order.
Same approach and restrictions as 'limits'.
-}
argLimits ::
   (Shape.InvIndexed sh, Shape.Index sh ~ ix, Class.Real a) =>
   Vector sh a -> ((ix,a),(ix,a))
argLimits xs =
   if x0 >= 0 then (p1, p0) else (p0, p1)
   where
      -- element of largest magnitude
      p0@(_, x0) = argAbsMaximum xs
      -- element farthest away from x0
      i1 = fst (argAbsMaximum (raise (-x0) xs))
      p1 = (i1, xs ! i1)


{- |
Multiply every element by a scalar.
-}
scale, _scale, (.*|) ::
   (Shape.C sh, Class.Floating a) =>
   a -> Vector sh a -> Vector sh a
(.*|) = scale

scale alpha (Array sh x) =
   Array.unsafeCreateWithSize sh $ \n yPtr -> evalContT $ do
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      alphaPtr <- Call.number alpha
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      liftIO $ do
         -- copy the input into the fresh array, then scale the copy in place
         BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
         BlasGen.scal nPtr alphaPtr yPtr incyPtr

-- Alternative implementation of 'scale' via a general matrix product:
-- C (1×n) := 1 * A (1×1) * B (1×n) + 0 * C. It is not exported.
_scale a (Array sh b) =
   Array.unsafeCreateWithSize sh $ \n cPtr -> evalContT $ do
      let m = 1
          k = 1
      transaPtr <- Call.char 'N'
      transbPtr <- Call.char 'N'
      mPtr <- Call.cint m
      kPtr <- Call.cint k
      nPtr <- Call.cint n
      alphaPtr <- Call.number Scalar.one
      aPtr <- Call.number a
      ldaPtr <- Call.leadingDim m
      bPtr <- ContT $ withForeignPtr b
      ldbPtr <- Call.leadingDim k
      betaPtr <- Call.number Scalar.zero
      ldcPtr <- Call.leadingDim m
      liftIO $
         BlasGen.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


{- |
Multiply every element by a real scalar.
For complex vectors, the real and imaginary components are scaled
together as one real vector of twice the length.
-}
scaleReal ::
   (Shape.C sh, Class.Floating a) =>
   RealOf a -> Vector sh a -> Vector sh a
scaleReal =
   getScaleReal $
   Class.switchFloating
      (ScaleReal scale)
      (ScaleReal scale)
      (ScaleReal $ \alpha v -> recomplex (scale alpha (decomplex v)))
      (ScaleReal $ \alpha v -> recomplex (scale alpha (decomplex v)))

-- Wrapper for the type dispatch in 'scaleReal'.
newtype ScaleReal f a = ScaleReal {getScaleReal :: RealOf a -> f a -> f a}


{-
View a complex vector as a real matrix
with rows indexed by @sh@ and the columns 'RealPart' and 'ImaginaryPart'.
No data is copied; the storage is merely reinterpreted.
This presumably relies on 'Complex' being stored
as consecutive real and imaginary parts.
-}
decomplex ::
   (Class.Real a) =>
   Vector sh (Complex a) -> Vector (sh, Shape.Enumeration ComplexPart) a
decomplex (Array sh ptr) = Array (sh, Shape.Enumeration) (castForeignPtr ptr)

-- Inverse of 'decomplex': reinterpret the real matrix as a complex vector.
recomplex ::
   (Class.Real a) =>
   Vector (sh, Shape.Enumeration ComplexPart) a -> Vector sh (Complex a)
recomplex (Array (sh, Shape.Enumeration) ptr) = Array sh (castForeignPtr ptr)



infixl 6 |+|, |-|, `add`, `sub`


{- |
Element-wise sum and difference of two vectors of equal shape.
A shape mismatch is reported by 'mac'.
-}
add, sub, (|+|), (|-|) ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> Vector sh a
add x y = mac Scalar.one x y
-- x - y is computed as (-1)*y + x
sub x y = mac Scalar.minusOne y x

(|+|) = add
(|-|) = sub

{- |
Multiply-accumulate of two vectors of equal shape.
Judging from 'add' and 'sub' it computes @alpha*x + y@.
Fails with an error if the shapes differ.
-}
mac ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   a -> Vector sh a -> Vector sh a -> Vector sh a
mac alpha x y =
   case Array.shape x == Array.shape y of
      False -> error "mac: shapes mismatch"
      True -> Vector.mac alpha x y

{- |
Negate every element.

All four branches are identical.
The dispatch via 'Class.switchFloating' presumably only serves to reduce
@RealOf a@ to a concrete real type,
so that the real factor 'Scalar.minusOne' can be passed to 'scaleReal'.
For complex vectors, 'scaleReal' scales the real and imaginary components
as one real vector, so both components are negated independently.
'Conjugate' is merely reused here as a wrapper of type @f a -> f a@.
-}
negate :: (Shape.C sh, Class.Floating a) => Vector sh a -> Vector sh a
negate =
   getConjugate $
   Class.switchFloating
      (Conjugate $ scaleReal Scalar.minusOne)
      (Conjugate $ scaleReal Scalar.minusOne)
      (Conjugate $ scaleReal Scalar.minusOne)
      (Conjugate $ scaleReal Scalar.minusOne)


{- |
Add a constant to every element.
-}
raise :: (Shape.C sh, Class.Floating a) => a -> Array sh a -> Array sh a
raise c (Array shX x) =
   Array.unsafeCreateWithSize shX $ \n yPtr -> evalContT $ do
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      cPtr <- Call.number c
      onePtr <- Call.number Scalar.one
      incPtr <- Call.cint 1
      -- stride 0 makes axpy read the single value c for every element
      zeroIncPtr <- Call.cint 0
      liftIO $ do
         BlasGen.copy nPtr xPtr incPtr yPtr incPtr
         BlasGen.axpy nPtr onePtr cPtr zeroIncPtr yPtr incPtr


{- |
Element-wise product of two vectors of equal shape.
-}
mul ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> Vector sh a
mul (Array shL l) (Array shR r) =
   Array.unsafeCreateWithSize shR $ \n resultPtr -> do
      Call.assert "mul: shapes mismatch" (shL == shR)
      evalContT $ do
         lPtr <- ContT $ withForeignPtr l
         rPtr <- ContT $ withForeignPtr r
         -- all three arrays are traversed with stride 1
         liftIO $ Private.mul n lPtr 1 rPtr 1 resultPtr 1

{- |
Element-wise quotient: @divide b a@ divides each element of @b@
by the corresponding element of @a@.
It solves the diagonal system @diag(a) * x = b@
with LAPACK's banded solver, using zero sub- and super-diagonals.
Solver failures (presumably a zero divisor) are reported via 'withInfo'.
-}
divide ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> Vector sh a
divide (Array shB b) (Array shA a) =
   Array.unsafeCreateWithSize shB $ \n xPtr -> do
      Call.assert "divide: shapes mismatch" (shA == shB)
      evalContT $ do
         -- band storage of diag(a): one row, no off-diagonals;
         -- the solver works on a temporary copy of a
         nPtr <- Call.cint n
         klPtr <- Call.cint 0
         kuPtr <- Call.cint 0
         abPtr <- Private.copyToTemp n a
         ldabPtr <- Call.leadingDim 1
         ipivPtr <- Call.allocaArray n
         -- one right-hand side, overwritten in place with the solution
         nrhsPtr <- Call.cint 1
         bPtr <- ContT $ withForeignPtr b
         ldxPtr <- Call.leadingDim n
         liftIO $ do
            Private.copyBlock n bPtr xPtr
            withInfo "gbsv" $
               LapackGen.gbsv nPtr klPtr kuPtr nrhsPtr
                  abPtr ldabPtr ipivPtr xPtr ldxPtr

{- |
Element-wise reciprocal, computed as @divide (one sh) x@.
The shapes are wrapped with 'Vector.uncheck' for the division
and restored with 'Vector.recheck' afterwards.
This presumably lets 'divide' be used although @sh@ need not have an 'Eq' instance.
-}
recip :: (Shape.C sh, Class.Floating a) => Vector sh a -> Vector sh a
recip x =
   let sh = Array.shape x
   in Vector.recheck $ divide (Vector.uncheck (one sh)) (Vector.uncheck x)


-- Wrapper of an @f a -> f a@ transformation for the type dispatch.
newtype Conjugate f a = Conjugate {getConjugate :: f a -> f a}

{- |
Complex conjugate of every element; the identity for real vectors.
-}
conjugate ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh a -> Vector sh a
conjugate =
   getConjugate $
   Class.switchFloating
      (Conjugate id) (Conjugate id)
      (Conjugate complexConjugate) (Conjugate complexConjugate)

-- Copy a complex vector into a fresh array while conjugating every element.
complexConjugate ::
   (Shape.C sh, Class.Real a) =>
   Vector sh (Complex a) -> Vector sh (Complex a)
complexConjugate (Array sh x) =
   Array.unsafeCreateWithSize sh $ \n yPtr -> evalContT $ do
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      liftIO $ copyConjugate nPtr xPtr incxPtr yPtr incyPtr


{- |
Convert a vector of real numbers to the element type @a@:
the identity for real types,
zero imaginary parts for complex types.
-}
fromReal ::
   (Shape.C sh, Class.Floating a) => Vector sh (RealOf a) -> Vector sh a
fromReal =
   getFromReal $
   Class.switchFloating
      (FromReal id) (FromReal id)
      (FromReal complexFromReal) (FromReal complexFromReal)

-- Wrapper for the type dispatch in 'fromReal'.
newtype FromReal f a = FromReal {getFromReal :: f (RealOf a) -> f a}

{- |
Convert to the corresponding complex type:
real vectors get zero imaginary parts,
complex vectors are returned unchanged.
-}
toComplex ::
   (Shape.C sh, Class.Floating a) => Vector sh a -> Vector sh (ComplexOf a)
toComplex =
   getToComplex $
   Class.switchFloating
      (ToComplex complexFromReal) (ToComplex complexFromReal)
      (ToComplex id) (ToComplex id)

-- Wrapper for the type dispatch in 'toComplex'.
newtype ToComplex f a = ToComplex {getToComplex :: f a -> f (ComplexOf a)}

{-
Embed real numbers as complex numbers with zero imaginary parts.
Viewed as reals, the result is written with stride 2:
the input values start at the real pointer,
and zeros (read from one value with stride 0) start one position later.
-}
complexFromReal ::
   (Shape.C sh, Class.Real a) => Vector sh a -> Vector sh (Complex a)
complexFromReal (Array sh x) =
   Array.unsafeCreateWithSize sh $ \n yPtr -> evalContT $ do
      let rePtr = realPtr yPtr
          imPtr = advancePtr rePtr 1
      nPtr <- Call.cint n
      xPtr <- ContT $ withForeignPtr x
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 2
      inczPtr <- Call.cint 0
      zeroPtr <- Call.number Scalar.zero
      liftIO $ do
         BlasGen.copy nPtr xPtr incxPtr rePtr incyPtr
         BlasGen.copy nPtr zeroPtr inczPtr imPtr incyPtr


{- |
Real parts of the elements; the identity for real vectors.
-}
realPart ::
   (Shape.C sh, Class.Floating a) => Vector sh a -> Vector sh (RealOf a)
realPart =
   getToReal $
   Class.switchFloating
      (ToReal id)
      (ToReal id)
      (ToReal $ \v -> RowMajor.takeColumn RealPart (decomplex v))
      (ToReal $ \v -> RowMajor.takeColumn RealPart (decomplex v))

-- Wrapper for the type dispatch in 'realPart'.
newtype ToReal f a = ToReal {getToReal :: f a -> f (RealOf a)}

{- |
Imaginary parts of the elements of a complex vector.
-}
imaginaryPart ::
   (Shape.C sh, Class.Real a) => Vector sh (Complex a) -> Vector sh a
imaginaryPart v = RowMajor.takeColumn ImaginaryPart (decomplex v)


{- |
Combine real and imaginary parts, given as two vectors of equal shape,
into one complex vector.
-}
zipComplex ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Vector sh a -> Vector sh a -> Vector sh (Complex a)
zipComplex (Array shr xr) (Array shi xi) =
   Array.unsafeCreateWithSize shr $ \n yPtr -> evalContT $ do
      liftIO $ Call.assert "zipComplex: shapes mismatch" (shr==shi)
      let rePtr = realPtr yPtr
          imPtr = advancePtr rePtr 1
      nPtr <- Call.cint n
      xrPtr <- ContT $ withForeignPtr xr
      xiPtr <- ContT $ withForeignPtr xi
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 2
      liftIO $ do
         -- interleave real and imaginary parts with stride 2
         BlasGen.copy nPtr xrPtr incxPtr rePtr incyPtr
         BlasGen.copy nPtr xiPtr incxPtr imPtr incyPtr

{- |
Split a complex vector into the vectors of its real and imaginary parts.
-}
unzipComplex ::
   (Shape.C sh, Class.Real a) =>
   Vector sh (Complex a) -> (Vector sh a, Vector sh a)
unzipComplex x =
   let re = realPart x
       im = imaginaryPart x
   in (re, im)


{- |
Distribution of the numbers generated by 'random'.
Each constructor selects a distribution code of LAPACK's @?larnv@.
The constructor order matters because of the derived 'Ord' and 'Enum' instances.
-}
data RandomDistribution =
     UniformBox01
     -- ^ uniform in (0,1); for complex numbers independently per component
   | UniformBoxPM1
     -- ^ uniform in (-1,1); for complex numbers independently per component
   | Normal
     -- ^ standard normal; for complex numbers independently per component
   | UniformDisc
     -- ^ complex: uniform on the unit disc;
     --   real: the same as 'UniformBoxPM1'
   | UniformCircle
     -- ^ complex: uniform on the unit circle;
     --   not supported for real numbers
   deriving (Eq, Ord, Show, Enum)

{- |
@random distribution shape seed@ creates a pseudo-random vector
using LAPACK's @?larnv@.
The result is a pure function of its arguments:
equal distribution, shape and seed yield equal vectors.

Only the least significant 47 bits of @seed@ are used.

For real element types, 'UniformDisc' behaves like 'UniformBoxPM1',
and 'UniformCircle' is rejected with an error.
-}
random ::
   (Shape.C sh, Class.Floating a) =>
   RandomDistribution -> sh -> Word64 -> Vector sh a
random dist sh seed = Array.unsafeCreateWithSize sh $ \n xPtr ->
   evalContT $ do
      nPtr <- Call.cint n
      -- distribution code (IDIST) of ?larnv;
      -- the first tuple component is False for real and True for complex elements,
      -- and the codes 4 and 5 exist only for complex numbers
      distPtr <-
         Call.cint $
         case (Private.caseRealComplexFunc xPtr False True, dist) of
            (_, UniformBox01) -> 1
            (_, UniformBoxPM1) -> 2
            (_, Normal) -> 3
            (True, UniformDisc) -> 4
            (True, UniformCircle) -> 5
            (False, UniformDisc) -> 2
            (False, UniformCircle) ->
               error
                  "Vector.random: UniformCircle not supported for real numbers"
      -- ?larnv expects four seed values in [0,4095] with the last one odd:
      -- bits 35-46, 23-34 and 11-22 of the seed give the first three values;
      -- bits 0-10, shifted left with the lowest bit set, give the fourth.
      iseedPtr <- Call.allocaArray 4
      liftIO $ do
         pokeElemOff iseedPtr 0 $ fromIntegral ((seed `shiftR` 35) .&. 0xFFF)
         pokeElemOff iseedPtr 1 $ fromIntegral ((seed `shiftR` 23) .&. 0xFFF)
         pokeElemOff iseedPtr 2 $ fromIntegral ((seed `shiftR` 11) .&. 0xFFF)
         pokeElemOff iseedPtr 3 $ fromIntegral ((seed.&.0x7FF)*2+1)
         LapackGen.larnv distPtr iseedPtr nPtr xPtr