{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Triangular.Eigen (
   values,
   decompose,
   ) where

import qualified Numeric.LAPACK.Matrix.Mosaic.Basic as Mosaic
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Mosaic.Private
         (unpackZero, unpackToTemp, fillTriangle,
          forPointers, rowMajorPointers)
import Numeric.LAPACK.Matrix.Triangular.Basic (TriangularP)
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(ColumnMajor,RowMajor), uploOrder)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero)
import Numeric.LAPACK.Private (copyToColumnMajorTemp, withInfo, errorCodeMsg)

import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.LAPACK.FFI.Real as LapackReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr, nullPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Complex (Complex)
import Data.Tuple.HT (swap, mapPair)


-- | Eigenvalues of a triangular matrix.
--
-- For a triangular matrix these are simply the diagonal entries, so this
-- just extracts the diagonal.
values ::
   (Layout.UpLo uplo, Shape.C sh, Class.Floating a) =>
   TriangularP pack uplo sh a -> Vector sh a
values tri = Mosaic.takeDiagonal tri


-- | Eigenvectors of a triangular matrix, computed by LAPACK @trevc@
-- with @side = \'B\'@ (left and right eigenvectors) and
-- @howmny = \'A\'@ (all eigenvectors).
--
-- LAPACK writes both eigenvector matrices as @n×n@ column-major buffers
-- with the same leading dimension @n@. The left-eigenvector array is
-- created with base order 'RowMajor' and the right-eigenvector array
-- with base order 'ColumnMajor'; both base orders are passed through
-- 'uploOrder'. The two arrays are therefore interpreted with different
-- storage orders. The left-eigenvector array is additionally conjugated.
-- Both results are then repacked to the requested packing @vpack@.
--
-- The input reaches LAPACK in the order @'uploOrder' uplo order@.
-- NOTE(review): presumably this presents a lower triangular input
-- transposed, i.e. in the upper triangular form @trevc@ expects;
-- confirm against 'uploOrder'.
-- 'swapUpper' puts the result derived from LAPACK's VR first for
-- 'Layout.Upper' input, and the result derived from VL first for
-- 'Layout.Lower' input.
decompose ::
   (Layout.UpLo uplo, Layout.Packing vpack,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack uplo sh a ->
   (TriangularP vpack uplo sh a, TriangularP vpack uplo sh a)
decompose (Array (Layout.Mosaic packing mirror uplo order sh) a) =
   -- Shape of an unpacked n×n result with the input's mirror/uplo and
   -- a storage order derived from the given base order.
   let triShape ord =
         Layout.Mosaic
            Layout.Unpacked mirror uplo (uploOrder uplo ord) sh
       n = Shape.size sh

   in swapUpper uplo $
      mapPair (Vector.conjugate . Mosaic.repack, Mosaic.repack) $
      -- The outer array receives VL; the inner one (the "result") receives VR.
      Array.unsafeCreateWithSizeAndResult (triShape RowMajor) $ \_ vlPtr ->
      ArrayIO.unsafeCreate (triShape ColumnMajor) $ \vrPtr ->

   evalContT $ do
      -- 'B': compute both right and left eigenvectors.
      sidePtr <- Call.char 'B'
      -- 'A': compute all eigenvectors; LAPACK then does not reference
      -- SELECT, hence the null pointer.
      howManyPtr <- Call.char 'A'
      let selectPtr = nullPtr
      -- Column-major n×n form of the input, used as T
      -- (see 'toColumnMajorTemp').
      aPtr <- toColumnMajorTemp packing (uploOrder uplo order) n a
      -- T, VL and VR are all n×n, so one leading dimension serves all three.
      ldaPtr <- Call.leadingDim n
      -- Room for n eigenvector columns.  LAPACK writes the number of
      -- columns actually used to M, which is not inspected here.
      mmPtr <- Call.cint n
      mPtr <- Call.alloca
      liftIO $ withInfo errorCodeMsg "trevc" $
         trevc sidePtr howManyPtr selectPtr n
            aPtr ldaPtr vlPtr ldaPtr vrPtr ldaPtr mmPtr mPtr

-- | Exchange the two components of a pair for 'Layout.Upper' and leave
-- the pair untouched for 'Layout.Lower'.
swapUpper :: Layout.UpLoSingleton uplo -> (a,a) -> (a,a)
swapUpper Layout.Lower = id
swapUpper Layout.Upper = swap

-- | Provide the triangular matrix stored in @a@ as an @n×n@ column-major
-- matrix in a temporary buffer, suitable as the @T@ argument of @trevc@.
--
-- Packed input is unpacked into a temporary.  For row-major packed data
-- this goes through 'unpackZeroRowMajor', which zero-fills a triangle
-- first.  Unpacked input goes through 'copyToColumnMajorTemp' for both
-- storage orders.
--
-- Why the unpacked column-major case must not hand out the input's own
-- buffer: complex @trevc@ documents that @T@ is modified during the call
-- and only restored on exit.  Passing the storage of the immutable input
-- array would expose those intermediate writes to any concurrent reader
-- of that array.
-- NOTE(review): this relies on 'copyToColumnMajorTemp' copying into a
-- temporary also for 'ColumnMajor', as its name indicates; confirm.
toColumnMajorTemp ::
   (Class.Floating a) =>
   Layout.PackingSingleton pack -> Order ->
   Int -> ForeignPtr a -> ContT () IO (Ptr a)
toColumnMajorTemp packing order n a =
   case packing of
      Layout.Packed ->
         let unpk =
               case order of
                  ColumnMajor -> unpackZero ColumnMajor
                  RowMajor -> unpackZeroRowMajor
         in unpackToTemp unpk n a
      Layout.Unpacked ->
         -- Always work on a private copy, so LAPACK never writes into
         -- the (immutable) input array.
         copyToColumnMajorTemp order n n a

-- | Zero-fill a triangle of the full @n×n@ buffer with 'fillTriangle',
-- then scatter the packed row-major data into it with 'unpackRowMajor'.
-- NOTE(review): presumably the zeroed triangle is exactly the part the
-- packed data does not cover; confirm against 'fillTriangle'.
unpackZeroRowMajor :: Class.Floating a => Int -> Ptr a -> Ptr a -> IO ()
unpackZeroRowMajor n packedPtr fullPtr =
   fillTriangle zero RowMajor n fullPtr >>
   unpackRowMajor n packedPtr fullPtr

-- | Copy packed rows into a full @n×n@ buffer.
--
-- For every row given by 'rowMajorPointers', BLAS @copy@ reads the
-- packed elements contiguously (stride 1).  It writes them with stride
-- @n@, so consecutive row elements land @n@ slots apart, i.e. in
-- column-major layout.
unpackRowMajor :: Class.Floating a => Int -> Ptr a -> Ptr a -> IO ()
unpackRowMajor n packedPtr fullPtr = evalContT $ do
   srcStridePtr <- Call.cint 1
   dstStridePtr <- Call.cint n
   let copyRow lenPtr (dstPtr, srcPtr) =
         BlasGen.copy lenPtr srcPtr srcStridePtr dstPtr dstStridePtr
   liftIO $ forPointers (rowMajorPointers n fullPtr packedPtr) copyRow


-- | Common type of the real and complex @trevc@ wrappers.
--
-- Arguments, in order: side, howmny, select, the order @n@ (as a plain
-- 'Int'), T, ldt, VL, ldvl, VR, ldvr, mm, m, info.  Workspace arrays are
-- not part of this type; each wrapper allocates its own.
type TREVC_ a =
   Ptr CChar -> Ptr CChar -> Ptr Bool ->
   Int -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr CInt -> Ptr CInt -> Ptr CInt -> IO ()

-- | Wrapper around 'TREVC_' so that 'Class.switchFloating' can select an
-- implementation per element type.
newtype TREVC a = TREVC {getTREVC :: TREVC_ a}

-- | @trevc@ for any supported element type.  Real element types are
-- served by 'trevcReal', complex ones by 'trevcComplex'.
trevc :: Class.Floating a => TREVC_ a
trevc = getTREVC implementation
   where
      implementation =
         Class.switchFloating
            (TREVC trevcReal) (TREVC trevcReal)
            (TREVC trevcComplex) (TREVC trevcComplex)

-- | @trevc@ for real element types.
--
-- The order @n@ arrives as a plain 'Int' and is handed to LAPACK through
-- a 'CInt' pointer.  All other arguments are forwarded unchanged.  A
-- workspace of @3*n@ elements is allocated for the duration of the call.
trevcReal :: Class.Real a => TREVC_ a
trevcReal side howMany select n
      t ldt vl ldvl vr ldvr mm m info =
   evalContT $ do
      work <- Call.allocaArray (3*n)
      sizePtr <- Call.cint n
      liftIO $
         LapackReal.trevc side howMany select sizePtr
            t ldt vl ldvl vr ldvr mm m work info

-- | @trevc@ for complex element types.
--
-- The order @n@ arrives as a plain 'Int' and is handed to LAPACK through
-- a 'CInt' pointer.  All other arguments are forwarded unchanged.  A
-- complex workspace of @2*n@ elements and a real workspace of @n@
-- elements are allocated for the duration of the call.
trevcComplex :: Class.Real a => TREVC_ (Complex a)
trevcComplex side howMany select n
      t ldt vl ldvl vr ldvr mm m info =
   evalContT $ do
      rwork <- Call.allocaArray n
      work <- Call.allocaArray (2*n)
      sizePtr <- Call.cint n
      liftIO $
         LapackComplex.trevc side howMany select sizePtr
            t ldt vl ldvl vr ldvr mm m work rwork info