{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Hermitian.Eigen (
   values,
   decompose,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Scalar as Scalar
import qualified Numeric.LAPACK.Shape as ExtShape
import Numeric.LAPACK.Matrix.Hermitian.Basic (Hermitian, HermitianP)
import Numeric.LAPACK.Matrix.Square.Basic (Square)
import Numeric.LAPACK.Matrix.Layout.Private (Order(ColumnMajor), uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (conjugatedOnRowMajor)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf)
import Numeric.LAPACK.Private
         (copyToTemp, copyCondConjugate, copyCondConjugateToTemp,
          withAutoWorkspaceInfo, withInfo, eigenMsg)

import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.LAPACK.FFI.Real as LapackReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape (triangleSize)

import Foreign.C.Types (CInt, CChar)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Eigenvalues of a Hermitian matrix, for either packed or unpacked
-- storage.
--
-- Packed matrices go through LAPACK's @?spev@/@?hpev@, unpacked ones
-- through @?syev@/@?heev@; in both cases only eigenvalues are requested.
-- The two branches of the real/complex match are identical: the match
-- only supplies the type evidence the workers' @Storable (RealOf a)@
-- constraint needs.
values ::
   (ExtShape.Permutable sh, Class.Floating a) =>
   HermitianP pack sh a -> Vector sh (RealOf a)
values m =
   case Scalar.complexSingletonOfFunctor m of
      Scalar.Real ->
         case Layout.mosaicPack (Array.shape m) of
            Layout.Packed -> valuesPacked m
            Layout.Unpacked -> valuesUnpacked m
      Scalar.Complex ->
         case Layout.mosaicPack (Array.shape m) of
            Layout.Packed -> valuesPacked m
            Layout.Unpacked -> valuesUnpacked m

-- | Eigenvalues of a packed Hermitian matrix via 'hpev'.
--
-- The packed triangle (@triangleSize dim@ elements) is copied to a
-- temporary buffer because LAPACK overwrites its input. No eigenvectors
-- are requested (@jobz = 'N'@), so a null pointer is passed for Z.
-- The input is not conjugated here; only eigenvalues are produced.
valuesPacked ::
   (ExtShape.Permutable sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Hermitian sh a -> Vector sh ar
valuesPacked (Array (Layout.Mosaic _pack _mirror _upper order size) fptr) =
   Array.unsafeCreateWithSize size $ \dim eigPtr -> evalContT $ do
      packedPtr <- copyToTemp (triangleSize dim) fptr
      jobz <- Call.char 'N'
      uplo <- Call.char (uploFromOrder order)
      ldz <- Call.leadingDim dim
      liftIO . withInfo eigenMsg "hpev" $
         hpev jobz uplo dim packedPtr eigPtr nullPtr ldz

-- | Eigenvalues of an unpacked (full @dim*dim@) Hermitian matrix via
-- 'heev'.
--
-- The full matrix is copied to a temporary buffer, since LAPACK destroys
-- its input. Only eigenvalues are requested (@jobz = 'N'@). The
-- workspace size is negotiated by 'withAutoWorkspaceInfo'.
valuesUnpacked ::
   (ExtShape.Permutable sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   HermitianP Layout.Unpacked sh a -> Vector sh ar
valuesUnpacked (Array (Layout.Mosaic _pack _mirror _upper order size) fptr) =
   Array.unsafeCreateWithSize size $ \dim eigPtr -> evalContT $ do
      fullPtr <- copyToTemp (dim*dim) fptr
      jobz <- Call.char 'N'
      uplo <- Call.char (uploFromOrder order)
      lda <- Call.leadingDim dim
      liftIO . withAutoWorkspaceInfo eigenMsg "heev" $
         heev jobz uplo dim fullPtr lda eigPtr


-- | Eigen-decomposition of a Hermitian matrix, for either packed or
-- unpacked storage.
--
-- Returns the eigenvectors as a column-major square matrix together with
-- the eigenvalues. As in 'values', the real/complex match only supplies
-- type evidence for the workers; both of its branches do the same thing.
decompose ::
   (ExtShape.Permutable sh, Class.Floating a) =>
   HermitianP pack sh a -> (Square sh a, Vector sh (RealOf a))
decompose m =
   case Scalar.complexSingletonOfFunctor m of
      Scalar.Real ->
         case Layout.mosaicPack (Array.shape m) of
            Layout.Packed -> decomposePacked m
            Layout.Unpacked -> decomposeUnpacked m
      Scalar.Complex ->
         case Layout.mosaicPack (Array.shape m) of
            Layout.Packed -> decomposePacked m
            Layout.Unpacked -> decomposeUnpacked m

-- | Eigenvalues and eigenvectors of a packed Hermitian matrix via 'hpev'.
--
-- The eigenvector matrix is allocated as a column-major square array and
-- passed to LAPACK as Z. The packed input is copied to a temporary
-- buffer, conjugated when 'conjugatedOnRowMajor' asks for it, so that the
-- eigenvectors correspond to the matrix itself rather than its conjugate.
decomposePacked ::
   (ExtShape.Permutable sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Hermitian sh a -> (Square sh a, Vector sh (RealOf a))
decomposePacked (Array (Layout.Mosaic _pack _mirror _upper order size) fptr) =
   Array.unsafeCreateWithSizeAndResult (Layout.square ColumnMajor size) $
      \_squareSize vecPtr ->
   ArrayIO.unsafeCreateWithSize size $ \dim eigPtr -> evalContT $ do
      packedPtr <-
         copyCondConjugateToTemp (conjugatedOnRowMajor order)
            (triangleSize dim) fptr
      jobz <- Call.char 'V'
      uplo <- Call.char (uploFromOrder order)
      ldz <- Call.leadingDim dim
      liftIO . withInfo eigenMsg "hpev" $
         hpev jobz uplo dim packedPtr eigPtr vecPtr ldz

-- | Dispatch to LAPACK's packed symmetric/Hermitian eigensolver.
--
-- Real scalars call @?spev@, which needs a work array of @3*n@ elements.
-- Complex scalars call @?hpev@, which needs @max 1 (2*n-1)@ complex work
-- elements and @max 1 (3*n-2)@ real work elements. The info pointer is
-- supplied by the caller so that it can check the result.
hpev ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CChar -> Int -> Ptr a -> Ptr (RealOf a) ->
   Ptr a -> Ptr CInt -> Ptr CInt -> IO ()
hpev jobz uplo n ap w z ldz info =
   case Scalar.complexSingletonOfFunctor ap of
      Scalar.Real -> evalContT $ do
         nc <- Call.cint n
         work <- Call.allocaArray (3*n)
         liftIO $ LapackReal.spev jobz uplo nc ap w z ldz work info
      Scalar.Complex -> evalContT $ do
         nc <- Call.cint n
         work <- Call.allocaArray (max 1 (2*n-1))
         rwork <- Call.allocaArray (max 1 (3*n-2))
         liftIO $
            LapackComplex.hpev jobz uplo nc ap w z ldz work rwork info


-- | Eigenvalues and eigenvectors of an unpacked Hermitian matrix via
-- 'heev'.
--
-- The input is first copied element-wise (stride 1, @squareSize@
-- elements) into the buffer of the resulting column-major square matrix.
-- It is conjugated when 'conjugatedOnRowMajor' asks for it. With
-- @jobz = 'V'@, LAPACK then overwrites that same buffer in place with
-- the eigenvectors.
decomposeUnpacked ::
   (ExtShape.Permutable sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   HermitianP Layout.Unpacked sh a -> (Square sh a, Vector sh (RealOf a))
decomposeUnpacked
      (Array (Layout.Mosaic _pack _mirror _upper order size) fptr) =
   Array.unsafeCreateWithSizeAndResult (Layout.square ColumnMajor size) $
      \squareSize vecPtr ->
   ArrayIO.unsafeCreateWithSize size $ \dim eigPtr -> evalContT $ do
      srcPtr <- ContT $ withForeignPtr fptr
      countPtr <- Call.cint squareSize
      stridePtr <- Call.cint 1
      jobz <- Call.char 'V'
      uplo <- Call.char (uploFromOrder order)
      lda <- Call.leadingDim dim
      liftIO $ do
         -- Seed the output matrix with the input; heev works in place.
         copyCondConjugate (conjugatedOnRowMajor order)
            countPtr srcPtr stridePtr vecPtr stridePtr
         withAutoWorkspaceInfo eigenMsg "heev" $
            heev jobz uplo dim vecPtr lda eigPtr

-- | Dispatch to LAPACK's full-storage symmetric/Hermitian eigensolver.
--
-- The work array and its length (@lwork@) come from the caller,
-- typically via 'withAutoWorkspaceInfo'. Real scalars call @?syev@.
-- Complex scalars call @?heev@, which additionally needs a real work
-- array of @max 1 (3*n-2)@ elements, allocated here.
heev ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CChar -> Int -> Ptr a -> Ptr CInt ->
   Ptr (RealOf a) -> Ptr a -> Ptr CInt -> Ptr CInt -> IO ()
heev jobz uplo n a lda w work lwork info = evalContT $ do
   nc <- Call.cint n
   case Scalar.complexSingletonOfFunctor a of
      Scalar.Real ->
         liftIO $ LapackReal.syev jobz uplo nc a lda w work lwork info
      Scalar.Complex -> do
         rwork <- Call.allocaArray (max 1 (3*n-2))
         liftIO $
            LapackComplex.heev jobz uplo nc a lda w work lwork rwork info