{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Vector (
   Vector,
   RealOf,
   ComplexOf,
   Vector.toList,
   Vector.fromList,
   Vector.autoFromList,
   CheckedArray.append, (Vector.+++),
   CheckedArray.take, CheckedArray.drop,
   CheckedArray.takeLeft, CheckedArray.takeRight,
   Vector.swap,
   CheckedArray.singleton,
   Vector.constant,
   Vector.zero,
   Vector.one,
   Vector.unit,
   Vector.dot, Vector.inner, (Vector.-*|),
   Vector.sum,
   Vector.absSum,
   norm1,
   Vector.norm2,
   Vector.norm2Squared,
   normInf,
   Vector.normInf1,
   argAbsMaximum,
   Vector.argAbs1Maximum,
   Vector.product,
   Vector.scale, Vector.scaleReal, (Vector..*|),
   Vector.add, Vector.sub, (Vector.|+|), (Vector.|-|),
   Vector.negate, Vector.raise,
   Vector.mac,
   Vector.mul, Vector.mulConj,
   divide, recip,
   Vector.minimum, Vector.argMinimum,
   Vector.maximum, Vector.argMaximum,
   Vector.limits, Vector.argLimits,
   CheckedArray.foldl,
   CheckedArray.foldl1,
   CheckedArray.foldMap,

   conjugate,
   Vector.fromReal,
   Vector.toComplex,
   Vector.realPart,
   Vector.imaginaryPart,
   Vector.zipComplex,
   Vector.unzipComplex,

   random, RandomDistribution(..),
   ) where

import qualified Numeric.LAPACK.Vector.Private as VectorPriv
import qualified Numeric.LAPACK.Scalar as Scalar
import qualified Numeric.LAPACK.Private as Private
import qualified Numeric.BLAS.Vector as Vector
import Numeric.LAPACK.Linear.Private (withInfo)
import Numeric.LAPACK.Scalar (ComplexOf, RealOf, absolute)
import Numeric.LAPACK.Private (copyConjugate)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peekElemOff, pokeElemOff)
import Foreign.C.Types (CInt)

import System.IO.Unsafe (unsafePerformIO)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (fmap, return, (=<<))
import Control.Applicative (liftA3, (<$>))

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Data.Function (id, ($), (.))
import Data.Complex (Complex)
import Data.Maybe (Maybe(Nothing,Just), maybe)
import Data.Tuple.HT (mapFst, uncurry3)
import Data.Tuple (snd)
import Data.Word (Word64)
import Data.Bits (shiftR, (.&.))
import Data.Ord (Ord)
import Data.Eq (Eq, (==))
import Data.Bool (Bool(False,True))

import Prelude (Int, fromIntegral, (+), (-), (*), Show, Enum, error, IO)


-- | Vectors are plain storable 'Array's.  The routines in this module hand
-- all @Shape.size sh@ elements to BLAS/LAPACK as one sequence with stride 1.
-- The synonym takes no parameters, so it may also appear partially applied,
-- e.g. @Vector sh@.
type Vector = Array


-- | Sum of the absolute values of all elements (the 1-norm), computed by
-- 'csum1'.  For complex elements this uses LAPACK @sum1@, which sums the
-- genuine moduli rather than @|re| + |im|@ as BLAS @asum@ would.
norm1 :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
norm1 arr =
   -- unsafePerformIO: the action only reads the array's storage and
   -- allocates short-lived argument cells (see 'vectorArgs').
   unsafePerformIO $ evalContT $ do
      (nPtr, sxPtr, incxPtr) <- vectorArgs arr
      liftIO $ csum1 nPtr sxPtr incxPtr

-- | Sum of absolute values, dispatched on the element type: BLAS @asum@
-- for real elements and LAPACK @sum1@ for complex ones.  The arguments are
-- pointers to the element count, the data, and the stride.
csum1 :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
csum1 nPtr xPtr incPtr =
   getNorm
      (Class.switchFloating
         (Norm BlasReal.asum) (Norm BlasReal.asum)
         (Norm LapackComplex.sum1) (Norm LapackComplex.sum1))
      nPtr xPtr incPtr

-- | Wrapper that lets 'Class.switchFloating' choose a norm routine per
-- element type.  The wrapped function receives pointers to the element
-- count, the data and the stride, and yields a real-valued result.
newtype Norm a =
   Norm {getNorm :: Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)}


-- | Maximum norm: the absolute value of the element selected by
-- 'VectorPriv.absMax'.  When no element is reported (a 0 position, which
-- 'peekElemOff1' turns into 'Nothing'), the result is the absolute value
-- of zero.
normInf :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
normInf arr =
   -- unsafePerformIO: the action presumably only reads the array's
   -- storage and allocates short-lived argument cells.
   unsafePerformIO $ evalContT $ do
      (nPtr, sxPtr, incxPtr) <- vectorArgs arr
      liftIO $ do
         k1 <- VectorPriv.absMax nPtr sxPtr incxPtr
         found <- peekElemOff1 sxPtr k1
         -- the case stays inside 'absolute' so that 'Scalar.zero' gets the
         -- element type (RealOf alone would not determine it)
         return $ absolute $
            case found of
               Nothing -> Scalar.zero
               Just (_, x) -> x

{- |
Returns the index and value of the element with the maximal absolute value.
Caution: It actually returns the value of the element, not its absolute value!

The element is chosen by 'VectorPriv.absMax'; its position is interpreted as
1-based by 'peekElemOff1' and mapped back to an index of the vector's shape.
An empty vector results in an error.
-}
argAbsMaximum ::
   (Shape.InvIndexed sh, Class.Floating a) =>
   Vector sh a -> (Shape.Index sh, a)
argAbsMaximum arr = unsafePerformIO $ evalContT $ do
   (nPtr, sxPtr, incxPtr) <- vectorArgs arr
   liftIO $ do
      k1 <- VectorPriv.absMax nPtr sxPtr incxPtr
      found <- peekElemOff1 sxPtr k1
      return $
         case found of
            Nothing -> error "Vector.argAbsMaximum: empty vector"
            Just (offset, x) ->
               (Shape.uncheckedIndexFromOffset (Array.shape arr) offset, x)


-- | Marshal a vector into the usual BLAS triple: a pointer to the element
-- count (@Shape.size@ of the shape), a pointer to the data, and a pointer
-- to a stride of 1.  The pointers are valid for the rest of the 'ContT'
-- computation.
vectorArgs ::
   (Shape.C sh) => Array sh a -> ContT r IO (Ptr CInt, Ptr a, Ptr CInt)
vectorArgs (Array sh x) = do
   countPtr <- Call.cint $ Shape.size sh
   dataPtr <- ContT $ withForeignPtr x
   stridePtr <- Call.cint 1
   return (countPtr, dataPtr, stridePtr)

-- | Read the element at a 1-based position, as returned by BLAS-style
-- index routines.  Position 0 means "no element" and yields 'Nothing';
-- otherwise the result pairs the 0-based offset with the element read.
peekElemOff1 :: (Storable a) => Ptr a -> CInt -> IO (Maybe (Int, a))
peekElemOff1 ptr k1 =
   case fromIntegral k1 of
      0 -> return Nothing
      pos1 -> do
         let offset = pos1 - 1
         x <- peekElemOff ptr offset
         return $ Just (offset, x)


{- |
Element-wise division: @divide b a@ yields the vector with elements
@b_i / a_i@.

The quotient is obtained by treating @a@ as a band matrix without sub- or
super-diagonals (@kl = ku = 0@, leading dimension 1), i.e. a diagonal matrix,
and solving @diag(a) * x = b@ with LAPACK @gbsv@.  Both vectors must have
equal shapes, which is checked with 'Call.assert'.
A zero element of @a@ should make @gbsv@ report a singular factor
(INFO > 0); that status goes through 'withInfo', which presumably raises an
error (confirm).
-}
divide ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> Vector sh a
divide (Array shNumer numer) (Array shDenom denom) =
   Array.unsafeCreateWithSize shNumer $ \n quotPtr -> do
      Call.assert "divide: shapes mismatch" (shDenom == shNumer)
      evalContT $ do
         nPtr <- Call.cint n
         -- a diagonal matrix is a band matrix without off-diagonals
         klPtr <- Call.cint 0
         kuPtr <- Call.cint 0
         nrhsPtr <- Call.cint 1
         -- gbsv overwrites the matrix with its factorization,
         -- so it works on a temporary copy of the divisor
         diagPtr <- Private.copyToTemp n denom
         ldDiagPtr <- Call.leadingDim 1
         ipivPtr <- Call.allocaArray n
         numerPtr <- ContT $ withForeignPtr numer
         ldQuotPtr <- Call.leadingDim n
         liftIO $ do
            -- gbsv solves in place: seed the result buffer with the dividend
            Private.copyBlock n numerPtr quotPtr
            withInfo "gbsv" $
               LapackGen.gbsv nPtr klPtr kuPtr nrhsPtr
                  diagPtr ldDiagPtr ipivPtr quotPtr ldQuotPtr

{- |
Element-wise reciprocal, computed as 'divide' of @Vector.one@ (presumably a
vector of ones) of the same shape by @x@.
'divide' requires @Eq sh@ for its shape check, while this function does not;
both operands are wrapped with 'VectorPriv.uncheck' (and the result unwrapped
with 'VectorPriv.recheck'), presumably to sidestep that requirement since the
shapes are equal by construction.
-}
recip :: (Shape.C sh, Class.Floating a) => Vector sh a -> Vector sh a
recip x =
   let ones = VectorPriv.uncheck $ Vector.one $ Array.shape x
   in  VectorPriv.recheck $ divide ones (VectorPriv.uncheck x)


-- | Wrapper that lets 'Class.switchFloating' choose an element-type-specific
-- transformation of a container @f a@ (used by 'conjugate').
newtype Conjugate f a = Conjugate {getConjugate :: f a -> f a}

{- |
Complex conjugate of every element.  Vectors with real elements are returned
unchanged; complex vectors are conjugated element-wise by 'complexConjugate'.
-}
conjugate ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh a -> Vector sh a
conjugate xs =
   getConjugate
      (Class.switchFloating
         (Conjugate id) (Conjugate id)
         (Conjugate complexConjugate) (Conjugate complexConjugate))
      xs

-- | Create a new complex vector of the same shape whose elements are the
-- conjugates of the input's, copied with stride 1 on both sides via
-- 'copyConjugate'.
complexConjugate ::
   (Shape.C sh, Class.Real a) =>
   Vector sh (Complex a) -> Vector sh (Complex a)
complexConjugate (Array sh src) =
   Array.unsafeCreateWithSize sh $ \n dstPtr -> evalContT $ do
      nPtr <- Call.cint n
      srcPtr <- ContT $ withForeignPtr src
      srcIncPtr <- Call.cint 1
      dstIncPtr <- Call.cint 1
      liftIO $ copyConjugate nPtr srcPtr srcIncPtr dstPtr dstIncPtr


-- | Distributions accepted by 'random'.  Each constructor is translated to a
-- LAPACK @larnv@ distribution code (IDIST); the descriptions below follow the
-- LAPACK documentation of that routine.  For complex elements, codes 1-3
-- draw the real and imaginary parts independently.
data RandomDistribution =
     UniformBox01
     -- ^ IDIST 1: uniform on (0,1)
   | UniformBoxPM1
     -- ^ IDIST 2: uniform on (-1,1)
   | Normal
     -- ^ IDIST 3: standard normal distribution N(0,1)
   | UniformDisc
     -- ^ complex elements: IDIST 4, uniform on the disc @|z| < 1@;
     -- real elements: falls back to IDIST 2, uniform on (-1,1)
   | UniformCircle
     -- ^ complex elements: IDIST 5, uniform on the circle @|z| = 1@;
     -- real elements: not supported, 'random' raises an error
   deriving (Eq, Ord, Show, Enum)

{- |
@random distribution shape seed@ fills a vector of the given shape with
pseudo-random numbers generated by LAPACK @larnv@.

Only the least significant 47 bits of @seed@ are used.  Bits 35-46, 23-34
and 11-22 become the first three 12-bit words of @larnv@'s @ISEED@; bits 0-10
form the fourth word, shifted left by one with the lowest bit set so that it
is odd, as @larnv@ requires.

For real element types 'UniformDisc' uses the same code as 'UniformBoxPM1',
and 'UniformCircle' is rejected with an error.
-}
random ::
   (Shape.C sh, Class.Floating a) =>
   RandomDistribution -> sh -> Word64 -> Vector sh a
random dist sh seed = Array.unsafeCreateWithSize sh $ \n xPtr ->
   evalContT $ do
      let isComplex = Private.caseRealComplexFunc xPtr False True
          distCode =
             case dist of
                UniformBox01 -> 1
                UniformBoxPM1 -> 2
                Normal -> 3
                UniformDisc -> if isComplex then 4 else 2
                UniformCircle ->
                   if isComplex
                     then 5
                     else
                        error
                           "Vector.random: UniformCircle not supported for real numbers"
          -- twelve seed bits starting at the given bit position
          seedWord k = fromIntegral ((seed `shiftR` k) .&. 0xFFF)
      nPtr <- Call.cint n
      distPtr <- Call.cint distCode
      iseedPtr <- Call.allocaArray 4
      liftIO $ do
         pokeElemOff iseedPtr 0 $ seedWord 35
         pokeElemOff iseedPtr 1 $ seedWord 23
         pokeElemOff iseedPtr 2 $ seedWord 11
         pokeElemOff iseedPtr 3 $ fromIntegral ((seed .&. 0x7FF) * 2 + 1)
         LapackGen.larnv distPtr iseedPtr nPtr xPtr