{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.HMatrix (
   toVector,
   fromVector,
   toGeneral,
   fromGeneral,
   toHermitian,
   fromHermitian,
   fromOrder,
   toOrder,
   ) where

import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Hermitian as Hermitian
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix.Layout as Layout
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Scalar as Scalar
import qualified Numeric.Netlib.Class as Class
import Numeric.LAPACK.Matrix (ShapeInt)
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.LinearAlgebra.Devel as HMatrixDevel
import qualified Numeric.LinearAlgebra.Data as HMatrixData
import qualified Numeric.LinearAlgebra as HMatrix

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import qualified Data.Vector.Storable as StVector

import Data.Functor.Compose (Compose (Compose))
import Data.Tuple.HT (mapPair)


-- | Wrap a storable vector as a one-dimensional lapack 'Vector' whose
-- 'ShapeInt' shape has the input's length.
--
-- No data is copied: the result refers to the same 'ForeignPtr' buffer
-- that 'StVector.unsafeToForeignPtr0' yields for the input.
toVector ::
   (StVector.Storable a) =>
   StVector.Vector a -> Vector ShapeInt a
toVector vec =
   Array (Matrix.shapeInt len) ptr
   where
      (ptr, len) = StVector.unsafeToForeignPtr0 vec

-- | View a lapack vector of any shape as a flat storable vector.
--
-- No data is copied: the underlying 'ForeignPtr' is reused. The length
-- is the element count of the shape ('Shape.size').
fromVector ::
   (Shape.C shape, StVector.Storable a) =>
   Vector shape a -> StVector.Vector a
fromVector (Array sh ptr) =
   StVector.unsafeFromForeignPtr0 ptr count
   where
      count = Shape.size sh


-- | Convert an hmatrix matrix to a general lapack matrix with 'ShapeInt'
-- height and width.
--
-- A row-major input is read directly by 'toRowMajor'. A column-major
-- input is first transposed (swapping its dimensions), then read
-- row-major. The result is transposed back with 'Matrix.transpose', so it
-- has the original dimensions.
--
-- NOTE(review): this presumably avoids a reordering copy when hmatrix
-- flattens column-major data. Confirm against 'HMatrix.flatten'.
toGeneral ::
   (Class.Floating a) =>
   HMatrix.Matrix a -> Matrix.General ShapeInt ShapeInt a
toGeneral m =
   case HMatrixDevel.orderOf m of
      HMatrixDevel.ColumnMajor ->
         Matrix.transpose . toRowMajor . transpose $ m
      HMatrixDevel.RowMajor ->
         toRowMajor m

-- | Plain (non-conjugating) transpose using hmatrix's 'HMatrix.tr''.
--
-- All four branches are textually identical. The case analysis on the
-- GADT witnesses from "Numeric.LAPACK.Scalar" refines @a@ to a concrete
-- element type in each branch: Float, Double, Complex Float or
-- Complex Double. That lets the hmatrix instance required by
-- 'HMatrix.tr'' be resolved per branch, which the abstract
-- 'Class.Floating' constraint does not supply on its own.
transpose :: (Class.Floating a) => HMatrix.Matrix a -> HMatrix.Matrix a
transpose a =
   case Scalar.complexSingletonOfFunctor a of
      Scalar.Real ->
         case Scalar.precisionOfFunctor a of
            Scalar.Float -> HMatrix.tr' a
            Scalar.Double -> HMatrix.tr' a
      Scalar.Complex ->
         -- Here a :: Matrix (Complex r). Wrapping it in 'Compose' exposes
         -- it as a functor over r, so we get the precision of the real
         -- component.
         case Scalar.precisionOfFunctor $ Compose a of
            Scalar.Float -> HMatrix.tr' a
            Scalar.Double -> HMatrix.tr' a

-- | Build a general lapack matrix from an hmatrix matrix.
--
-- The elements are read in row-major order (see 'flatten') and reshaped
-- into a (rows, columns) pair of 'ShapeInt' dimensions.
toRowMajor ::
   (Class.Floating a) =>
   HMatrix.Matrix a -> Matrix.General ShapeInt ShapeInt a
toRowMajor m =
   Matrix.fromRowMajor $ Array.reshape extents $ toVector elems
   where
      (dims, elems) = flatten m
      extents = mapPair (Matrix.shapeInt, Matrix.shapeInt) dims

-- | Return the matrix dimensions as (rows, columns), together with its
-- elements as produced by hmatrix's 'HMatrix.flatten'.
--
-- 'HMatrix.flatten' concatenates rows, giving row-major order, which is
-- the order 'toRowMajor' expects.
--
-- As in 'transpose', the identical branches exist only to refine @a@ to
-- a concrete element type. That lets the hmatrix instances needed by
-- 'HMatrixData.size' and 'HMatrix.flatten' be resolved.
flatten ::
   (Class.Floating a) => HMatrix.Matrix a -> ((Int,Int), StVector.Vector a)
flatten a =
   case Scalar.complexSingletonOfFunctor a of
      Scalar.Real ->
         case Scalar.precisionOfFunctor a of
            Scalar.Float -> (HMatrixData.size a, HMatrix.flatten a)
            Scalar.Double -> (HMatrixData.size a, HMatrix.flatten a)
      Scalar.Complex ->
         -- 'Compose' exposes Matrix (Complex r) as a functor over r, so we
         -- get the precision of the real component.
         case Scalar.precisionOfFunctor $ Compose a of
            Scalar.Float -> (HMatrixData.size a, HMatrix.flatten a)
            Scalar.Double -> (HMatrixData.size a, HMatrix.flatten a)


-- | Convert a general lapack matrix into an hmatrix matrix.
--
-- The lapack storage order is carried over via 'fromOrder', so no
-- reordering is done here. The element buffer is handed over through
-- 'fromVector', which does not copy. Height and width become the hmatrix
-- row and column counts.
fromGeneral ::
   (Shape.C height, Shape.C width, StVector.Storable a) =>
   Matrix.General height width a -> HMatrix.Matrix a
fromGeneral m =
   HMatrixDevel.matrixFromVector order nRows nCols elems
   where
      order = fromOrder $ ArrMatrix.order m
      nRows = Shape.size $ Matrix.height m
      nCols = Shape.size $ Matrix.width m
      elems = fromVector $ ArrMatrix.unwrap m


-- | Convert an hmatrix 'HMatrix.Herm' into a lapack Hermitian matrix.
--
-- Only the upper triangle of the underlying full matrix is kept, via
-- 'Triangular.takeUpper'. The lower triangle is discarded, and this
-- function does not check that it matches the upper one.
toHermitian ::
   (Class.Floating a) =>
   HMatrix.Herm a -> Matrix.Hermitian ShapeInt a
toHermitian herm =
   ArrMatrix.fromVector $ hermitianFromUpper $ ArrMatrix.toVector $
   Triangular.takeUpper $ Square.fromFull $ toGeneral $ HMatrix.unSym herm

-- | Reinterpret a packed upper-triangular array as a Hermitian one.
--
-- Only the shape descriptor is rewritten, via 'Array.mapShape'. Its
-- mirror tag becomes 'Layout.ConjugateMirror', and the record update
-- keeps all other shape fields. Because the signature goes from
-- 'Layout.UpperTriangular' to 'Layout.Hermitian' through this single
-- field change, the record update is type-changing. The stored elements
-- are not touched.
--
-- NOTE(review): presumably the lower half is then implied as the
-- conjugate mirror of the stored upper half. Confirm against
-- "Numeric.LAPACK.Matrix.Layout".
hermitianFromUpper ::
   Array (Layout.UpperTriangular sh) a -> Array (Layout.Hermitian sh) a
hermitianFromUpper =
   Array.mapShape (\sh -> sh{Layout.mosaicMirror = Layout.ConjugateMirror})

-- | Convert a lapack Hermitian matrix into an hmatrix 'HMatrix.Herm'.
--
-- The matrix is expanded to a full square matrix via 'Hermitian.toSquare'.
-- It is then declared Hermitian with 'HMatrix.trustSym'; no
-- symmetrization or check is performed here.
fromHermitian ::
   (Shape.C sh, Class.Floating a) =>
   Matrix.Hermitian sh a -> HMatrix.Herm a
fromHermitian herm =
   HMatrix.trustSym $ fromGeneral $ Square.toFull $ Hermitian.toSquare herm


-- | Map lapack's storage order onto the corresponding hmatrix order.
fromOrder :: Layout.Order -> HMatrixDevel.MatrixOrder
fromOrder Layout.RowMajor = HMatrixDevel.RowMajor
fromOrder Layout.ColumnMajor = HMatrixDevel.ColumnMajor

-- | Map hmatrix's storage order onto the corresponding lapack order.
-- This is the inverse of 'fromOrder'.
toOrder :: HMatrixDevel.MatrixOrder -> Layout.Order
toOrder HMatrixDevel.RowMajor = Layout.RowMajor
toOrder HMatrixDevel.ColumnMajor = Layout.ColumnMajor