{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Hermitian.Basic (
   Hermitian,
   Transposition(..),
   fromList,
   autoFromList,
   recheck,
   identity,
   diagonal,
   takeDiagonal,
   forceOrder,

   stack,
   takeTopLeft,
   takeTopRight,
   takeBottomRight,

   multiplyVector,
   multiplyFull,
   square, power,
   outer,
   sumRank1,
   sumRank2,

   toSquare,
   gramian,              gramianAdjoint,
   congruenceDiagonal,   congruenceDiagonalAdjoint,
   congruence,           congruenceAdjoint,
   scaledAnticommutator, scaledAnticommutatorAdjoint,
   addAdjoint,

   takeUpper,
   ) where

import qualified Numeric.LAPACK.Matrix.Symmetric.Private as Symmetric
import qualified Numeric.LAPACK.Matrix.Triangular.Private as Triangular
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Split as Split
import Numeric.LAPACK.Matrix.Hermitian.Private (Diagonal(..), TakeDiagonal(..))
import Numeric.LAPACK.Matrix.Triangular.Private
         (forPointers, pack, unpack, unpackToTemp,
          diagonalPointers, diagonalPointerPairs,
          rowMajorPointers, columnMajorPointers)
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor,ColumnMajor), flipOrder, sideSwapFromOrder,
          uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed, Transposed), transposeOrder,
          Conjugation(Conjugated), conjugatedOnRowMajor)
import Numeric.LAPACK.Matrix.Private
         (Full, General, Square, argSquare, ShapeInt, shapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero, one)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Private
         (fill, lacgv, realPtr,
          copyConjugate, condConjugate, conjugateToTemp, condConjugateToTemp)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))

import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)

import Data.Foldable (forM_)

import Data.Function.HT (powerAssociative)


-- | An array whose shape is 'MatrixShape.Hermitian' over the dimension @sh@.
type Hermitian sh = Array (MatrixShape.Hermitian sh)


{- |
Build a Hermitian array with the given element order and dimension
from a list of elements, via the checked 'CheckedArray.fromList'.
-}
fromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Hermitian sh a
fromList order sh xs =
   let shape = MatrixShape.Hermitian order sh
   in CheckedArray.fromList shape xs

{- |
Like 'fromList', but the dimension is computed from the list length
by 'MatrixShape.triangleExtent'.
-}
autoFromList :: (Storable a) => Order -> [a] -> Hermitian ShapeInt a
autoFromList order xs =
   let n = MatrixShape.triangleExtent "Hermitian.autoFromList" (length xs)
   in fromList order (shapeInt n) xs

{- |
Wrap the dimension of the shape in 'Unchecked'.
Only the shape changes, the array data is untouched.
-}
uncheck :: Hermitian sh a -> Hermitian (Unchecked sh) a
uncheck = Array.mapShape wrap
   where
      wrap (MatrixShape.Hermitian order sh) =
         MatrixShape.Hermitian order (Unchecked sh)

{- |
Inverse of 'uncheck':
strip the 'Unchecked' wrapper from the dimension of the shape.
-}
recheck :: Hermitian (Unchecked sh) a -> Hermitian sh a
recheck = Array.mapShape unwrap
   where
      unwrap (MatrixShape.Hermitian order (Unchecked sh)) =
         MatrixShape.Hermitian order sh


{- |
Identity matrix:
the whole packed buffer is zeroed and then 'one' is written
to every position returned by 'diagonalPointers'.
-}
identity :: (Shape.C sh, Class.Floating a) => Order -> sh -> Hermitian sh a
identity order sh =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \triSize aPtr -> do
         fill zero triSize aPtr
         forM_ (diagonalPointers order (Shape.size sh) aPtr) $
            \diagPtr -> poke diagPtr one

{- |
Hermitian matrix with the given real vector on the diagonal.
Dispatches over the four floating point types to 'diagonalAux'.
-}
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Vector sh (RealOf a) -> Hermitian sh a
diagonal order =
   -- Each branch is instantiated at a different element type,
   -- thus the four arguments cannot share one local binding.
   runDiagonal
      (Class.switchFloating
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order)))

{- |
Worker for 'diagonal'.
Zeroes the packed buffer and copies every element of the source vector
to the address 'realPtr' yields for the corresponding diagonal element.
-}
diagonalAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Order -> Vector sh ar -> Hermitian sh a
diagonalAux order (Array sh x) =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \triSize aPtr -> do
         fill zero triSize aPtr
         withForeignPtr x $ \xPtr -> do
            let pairs = diagonalPointerPairs order (Shape.size sh) xPtr aPtr
            forM_ pairs $ \(srcPtr,dstPtr) ->
               peek srcPtr >>= poke (realPtr dstPtr)


{- |
Extract the diagonal of a Hermitian matrix as a real vector.
Dispatches over the four floating point types to 'takeDiagonalAux'.
-}
takeDiagonal ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> Vector sh (RealOf a)
takeDiagonal =
   -- Each branch is instantiated at a different element type.
   runTakeDiagonal
      (Class.switchFloating
         (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux))

{- |
Worker for 'takeDiagonal'.
Reads every diagonal element through 'realPtr'
and stores it in the result vector.
-}
takeDiagonalAux ::
   (Shape.C sh, Storable a, RealOf a ~ ar, Storable ar) =>
   Hermitian sh a -> Vector sh ar
takeDiagonalAux (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreateWithSize sh $ \n xPtr ->
      withForeignPtr a $ \aPtr -> do
         let pairs = diagonalPointerPairs order n xPtr aPtr
         forM_ pairs $ \(dstPtr,srcPtr) ->
            peek (realPtr srcPtr) >>= poke dstPtr


{-
Convert a Hermitian matrix to the requested element order.
A matrix that already has that order is returned unchanged.

This is not maximally efficient.
The conversion fills up a whole square.
This wastes memory but enables more regular memory access patterns.
Additionally, it fills unused parts of the square with mirrored values.
-}
forceOrder ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Hermitian sh a -> Hermitian sh a
forceOrder newOrder a
   | MatrixShape.hermitianOrder (Array.shape a) == newOrder = a
   | otherwise = fromUpperPart (Basic.forceOrder newOrder (toSquare a))

{- |
Build a Hermitian matrix from a full matrix
by means of 'Triangular.fromUpperPart'
with 'MatrixShape.Hermitian' as the shape constructor.
-}
fromUpperPart ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert Extent.Small height width a -> Hermitian width a
fromUpperPart full = Triangular.fromUpperPart MatrixShape.Hermitian full

{-
Reinterpret a Hermitian matrix as a non-unit upper triangular matrix.
Only the shape is replaced, the array data is left as is.

The naming is inconsistent with Triangular.takeUpper,
because here Hermitian is the input
whereas in Triangular.takeUpper, Triangular is the output.
-}
takeUpper ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a ->
   Array (MatrixShape.UpperTriangular MatrixShape.NonUnit sh) a
takeUpper = Array.mapShape toTriangular
   where
      toTriangular (MatrixShape.Hermitian order sh) =
         MatrixShape.Triangular MatrixShape.NonUnit MatrixShape.upper order sh


{- |
Assemble a Hermitian matrix from a top-left Hermitian block @a@,
a top-right general block @b@ and a bottom-right Hermitian block @c@.
The element order of @b@ determines the order of the result;
@a@ and @c@ are converted to that order first.
-}
stack ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   Hermitian sh0 a -> General sh0 sh1 a -> Hermitian sh1 a ->
   Hermitian (sh0:+:sh1) a
stack a b c =
   Triangular.stack "Hermitian" (MatrixShape.Hermitian order)
      (forceOrder order a) b (forceOrder order c)
   where
      order = MatrixShape.fullOrder (Array.shape b)

{- |
Top-left Hermitian block of a matrix with a composed dimension,
extracted by 'Triangular.takeTopLeft'.
-}
takeTopLeft ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a -> Hermitian sh0 a
takeTopLeft = Triangular.takeTopLeft decompose
   where
      -- result shape together with order and full dimension of the source
      decompose (MatrixShape.Hermitian order sh@(sh0:+:_sh1)) =
         (MatrixShape.Hermitian order sh0, (order,sh))

{- |
Top-right general block of a matrix with a composed dimension,
extracted by 'Triangular.takeTopRight'.
-}
takeTopRight ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a -> General sh0 sh1 a
takeTopRight = Triangular.takeTopRight orderAndSize
   where
      orderAndSize (MatrixShape.Hermitian order sh) = (order,sh)

{- |
Bottom-right Hermitian block of a matrix with a composed dimension,
extracted by 'Triangular.takeBottomRight'.
-}
takeBottomRight ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a -> Hermitian sh1 a
takeBottomRight = Triangular.takeBottomRight decompose
   where
      -- result shape together with order and full dimension of the source
      decompose (MatrixShape.Hermitian order sh@(_sh0:+:sh1)) =
         (MatrixShape.Hermitian order sh1, (order,sh))


{- |
Multiply an optionally transposed Hermitian matrix with a vector
using the packed BLAS routine @hpmv@.

The packed data is always passed to BLAS as is.
Depending on element order and transposition
the input vector is conjugated into a temporary copy
and the result is conjugated in place afterwards.
The shapes of matrix and vector must match.
-}
multiplyVector ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Transposition -> Hermitian sh a -> Vector sh a -> Vector sh a
multiplyVector transposed
   (Array (MatrixShape.Hermitian order shA) a) (Array shX x) =
      Array.unsafeCreateWithSize shX $ \n yPtr -> do
   Call.assert "Hermitian.multiplyVector: width shapes mismatch" (shA == shX)
   let conj = conjugatedOnRowMajor (transposeOrder transposed order)
   evalContT $ do
      uploPtr <- Call.char (uploFromOrder order)
      nPtr <- Call.cint n
      alphaPtr <- Call.number one
      aPtr <- ContT (withForeignPtr a)
      -- input vector, conjugated into a temporary buffer if required
      xPtr <- condConjugateToTemp conj n x
      incxPtr <- Call.cint 1
      betaPtr <- Call.number zero
      incyPtr <- Call.cint 1
      liftIO $ do
         BlasGen.hpmv
            uploPtr nPtr alphaPtr aPtr xPtr incxPtr betaPtr yPtr incyPtr
         -- undo the conjugation on the result
         condConjugate conj nPtr yPtr incyPtr


{- |
Square of a Hermitian matrix,
computed by 'Symmetric.square' with 'Conjugated'.
The result has the same shape as the input.
-}
square :: (Shape.C sh, Class.Floating a) => Hermitian sh a -> Hermitian sh a
square (Array shape@(MatrixShape.Hermitian order sh) a) =
   let n = Shape.size sh
   in Array.unsafeCreate shape (Symmetric.square Conjugated order n a)

{-
Raise a Hermitian matrix to a non-negative integer power
by repeated multiplication via 'powerAssociative'.
The zeroth power is the identity matrix.

Requires frequent unpacking and packing of triangles.

The exponent must not be negative.
'powerAssociative' terminates only when repeated halving of the exponent
reaches zero, which never happens for a negative exponent.
We therefore reject negative exponents with an explicit error
instead of looping forever.

The shape is wrapped in 'Unchecked' for the computation,
because all factors are known to have the same shape.
-}
power ::
   (Shape.C sh, Class.Floating a) => Integer -> Hermitian sh a -> Hermitian sh a
power n a0@(Array (MatrixShape.Hermitian order sh) _)
   | n < 0 =
      error "Hermitian.power: negative exponent"
   | otherwise =
      recheck $
      powerAssociative
         (\a b -> fromUpperPart $ multiplyFull NonTransposed a $ toSquare b)
         (identity order $ Unchecked sh)
         (uncheck a0)
         n


{- |
Multiply an optionally transposed Hermitian matrix from the left
with a full matrix using the BLAS routine @hemm@.

The packed Hermitian matrix is unpacked into a temporary square buffer.
That buffer is conjugated with 'lacgv'
if the transposed order of the Hermitian matrix
differs from the order of the full matrix.
The result has the shape of the full matrix.
-}
multiplyFull ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Transposition -> Hermitian height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
multiplyFull transposed
   (Array        (MatrixShape.Hermitian orderA shA) a)
   (Array shapeB@(MatrixShape.Full orderB extentB) b) =
      Array.unsafeCreate shapeB $ \cPtr -> do
   let (height,width) = Extent.dimensions extentB
   Call.assert "Hermitian.multiplyFull: shapes mismatch" (shA == height)
   let m0 = Shape.size height
       n0 = Shape.size width
       squareSize = m0*m0
       (side,(m,n)) = sideSwapFromOrder orderB (m0,n0)
       needsConjugation = transposeOrder transposed orderA /= orderB
   evalContT $ do
      sidePtr <- Call.char side
      uploPtr <- Call.char (uploFromOrder orderA)
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      alphaPtr <- Call.number one
      -- temporary m0 x m0 buffer holding the unpacked Hermitian matrix
      aPtr <- unpackToTemp (unpack orderA) m0 a
      ldaPtr <- Call.leadingDim m0
      incaPtr <- Call.cint 1
      sizePtr <- Call.cint squareSize
      bPtr <- ContT (withForeignPtr b)
      ldbPtr <- Call.leadingDim m
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim m
      liftIO $ do
         when needsConjugation $
            lacgv sizePtr aPtr incaPtr
         BlasGen.hemm sidePtr uploPtr
            mPtr nPtr alphaPtr aPtr ldaPtr
            bPtr ldbPtr betaPtr cPtr ldcPtr



{- |
Common frame for the rank-k update functions.

Zeroes the packed buffer of @triSize@ elements,
runs the action with pointers to the @uplo@ character,
the matrix dimension and a unit increment,
and finally conjugates the buffer
according to 'conjugatedOnRowMajor' applied to the order.
-}
withConjBuffer ::
   (Shape.C sh, Class.Floating a) =>
   Order -> sh -> Int -> Ptr a ->
   (Ptr CChar -> Ptr CInt -> Ptr CInt -> IO ()) -> ContT r IO ()
withConjBuffer order sh triSize aPtr act = do
   uploPtr <- Call.char (uploFromOrder order)
   dimPtr <- Call.cint (Shape.size sh)
   unitIncPtr <- Call.cint 1
   triSizePtr <- Call.cint triSize
   let conj = conjugatedOnRowMajor order
   liftIO $ do
      fill zero triSize aPtr
      act uploPtr dimPtr unitIncPtr
      condConjugate conj triSizePtr aPtr unitIncPtr

{- |
Hermitian rank-1 matrix generated from a single vector.
Performs one 'hpr' update with factor one on a zeroed packed buffer
within the frame of 'withConjBuffer'.
-}
outer ::
   (Shape.C sh, Class.Floating a) => Order -> Vector sh a -> Hermitian sh a
outer order (Array sh x) =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \triSize aPtr ->
         evalContT $ do
            onePtr <- realOneArg aPtr
            xPtr <- ContT (withForeignPtr x)
            withConjBuffer order sh triSize aPtr $ \uploPtr nPtr incPtr ->
               hpr uploPtr nPtr onePtr xPtr incPtr aPtr


{- |
Sum of real-weighted rank-1 updates.
Dispatches over the four floating point types to 'sumRank1Aux'.
-}
sumRank1 ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(RealOf a, Vector sh a)] -> Hermitian sh a
sumRank1 =
   -- Each branch is instantiated at a different element type.
   getSumRank1
      (Class.switchFloating
         (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux))

-- | Function type of 'sumRank1' with the real scalar type @ar@ made explicit.
type SumRank1_ sh ar a = Order -> sh -> [(ar, Vector sh a)] -> Hermitian sh a

-- | Wrapper that lets 'Class.switchFloating' select a 'SumRank1_'
-- implementation with @ar@ fixed to @RealOf a@.
newtype SumRank1 sh a = SumRank1 {getSumRank1 :: SumRank1_ sh (RealOf a) a}

{- |
Worker for 'sumRank1'.
Accumulates one 'hpr' update per list element into a zeroed packed buffer
within the frame of 'withConjBuffer'.
Every vector must have the shape @sh@.
-}
sumRank1Aux ::
   (Shape.C sh, Eq sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   SumRank1_ sh ar a
sumRank1Aux order sh xs =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \triSize aPtr ->
         evalContT $ do
            -- single scalar cell that is overwritten for every term
            weightPtr <- Call.alloca
            withConjBuffer order sh triSize aPtr $ \uploPtr nPtr incPtr ->
               let accumulate (weight, Array shX x) =
                      withForeignPtr x $ \xPtr -> do
                         Call.assert
                            "Hermitian.sumRank1: non-matching vector size"
                            (sh==shX)
                         poke weightPtr weight
                         hpr uploPtr nPtr weightPtr xPtr incPtr aPtr
               in mapM_ accumulate xs


-- | Common signature of the packed rank-1 update routines
-- with a real-valued scaling factor.
type HPR_ a =
   Ptr CChar -> Ptr CInt ->
   Ptr (RealOf a) -> Ptr a -> Ptr CInt -> Ptr a -> IO ()

-- | Wrapper for dispatching 'HPR_' with 'Class.switchFloating'.
newtype HPR a = HPR {getHPR :: HPR_ a}

-- | Packed rank-1 update:
-- @spr@ for the real types and @hpr@ for the complex types.
hpr :: Class.Floating a => HPR_ a
hpr =
   getHPR
      (Class.switchFloating
         (HPR BlasReal.spr)
         (HPR BlasReal.spr)
         (HPR BlasComplex.hpr)
         (HPR BlasComplex.hpr))


{- |
Sum of weighted rank-2 updates.
Accumulates one @hpr2@ update per list element
into a zeroed packed buffer within the frame of 'withConjBuffer'.
Both vectors of every element must have the shape @sh@.
-}
sumRank2 ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, (Vector sh a, Vector sh a))] -> Hermitian sh a
sumRank2 order sh xys =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \triSize aPtr ->
         evalContT $ do
            -- single scalar cell that is overwritten for every term
            weightPtr <- Call.alloca
            withConjBuffer order sh triSize aPtr $ \uploPtr nPtr incPtr ->
               let accumulate (weight, (Array shX x, Array shY y)) =
                      withForeignPtr x $ \xPtr ->
                      withForeignPtr y $ \yPtr -> do
                         Call.assert
                            "Hermitian.sumRank2: non-matching x vector size"
                            (sh==shX)
                         Call.assert
                            "Hermitian.sumRank2: non-matching y vector size"
                            (sh==shY)
                         poke weightPtr weight
                         BlasGen.hpr2
                            uploPtr nPtr weightPtr xPtr incPtr yPtr incPtr aPtr
               in mapM_ accumulate xys


{-
Convert a packed Hermitian matrix to a full square matrix
with the same element order.

It is not strictly necessary to keep the 'order'.
It would be neither more complicated nor less efficient
to change the order via the conversion.

'toSquare' is the variant that is exported;
it delegates to 'Symmetric.unpack' with 'Conjugated'.
'_toSquare' is neither exported nor used within this module.
Its author marked the first 'unpack' call as wrong.
-}
toSquare, _toSquare ::
   (Shape.C sh, Class.Floating a) => Hermitian sh a -> Square sh a
_toSquare (Array (MatrixShape.Hermitian order sh) a) =
      Array.unsafeCreate (MatrixShape.square order sh) $ \bPtr ->
   evalContT $ do
      let n = Shape.size sh
      aPtr <- ContT $ withForeignPtr a
      -- temporary conjugated copy of the packed triangle
      conjPtr <- conjugateToTemp (Shape.triangleSize n) a
      liftIO $ do
         -- Marked as wrong by the author.
         -- NOTE(review): presumably the packed layout of one order
         -- cannot simply be unpacked using the flipped order - confirm.
         unpack (flipOrder order) n conjPtr bPtr
         unpack order n aPtr bPtr

toSquare (Array (MatrixShape.Hermitian order sh) a) =
      Array.unsafeCreate (MatrixShape.square order sh) $ \bPtr ->
   withForeignPtr a $ \aPtr ->
      Symmetric.unpack Conjugated order (Shape.size sh) aPtr bPtr


{- |
Gramian of a general matrix.
The result has the width of the input as dimension
and the same element order as the input.
BLAS arguments are chosen by 'gramianParameters',
the computation is done by 'gramianIO'.
-}
gramian ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Hermitian width a
gramian (Array (MatrixShape.Full order extent) a) =
   let shape = MatrixShape.Hermitian order (Extent.width extent)
       params = gramianParameters order extent
   in Array.unsafeCreate shape $ \bPtr -> gramianIO order a bPtr params

{- |
BLAS arguments for 'gramian' and 'scaledAnticommutator'.
Returns @((n,k), (uplo,trans,lda))@
where @n@ is the size of the width and @k@ is the size of the height.
-}
gramianParameters ::
   (Extent.C horiz, Extent.C vert, Shape.C height, Shape.C width) =>
   Order ->
   Extent.Extent vert horiz height width ->
   ((Int, Int), (Char, Char, Int))
gramianParameters order extent =
   ((n,k), blasArgs)
   where
      (height, width) = Extent.dimensions extent
      n = Shape.size width
      k = Shape.size height
      blasArgs =
         case order of
            ColumnMajor -> ('U', 'C', k)
            RowMajor -> ('L', 'N', n)


{- |
Counterpart of 'gramian' for the adjoint.
The result has the height of the input as dimension
and the same element order as the input.
BLAS arguments are chosen by 'gramianAdjointParameters',
the computation is done by 'gramianIO'.
-}
gramianAdjoint ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Hermitian height a
gramianAdjoint (Array (MatrixShape.Full order extent) a) =
   let shape = MatrixShape.Hermitian order (Extent.height extent)
       params = gramianAdjointParameters order extent
   in Array.unsafeCreate shape $ \bPtr -> gramianIO order a bPtr params

{- |
BLAS arguments for 'gramianAdjoint' and 'scaledAnticommutatorAdjoint'.
Returns @((n,k), (uplo,trans,lda))@
where @n@ is the size of the height and @k@ is the size of the width.
-}
gramianAdjointParameters ::
   (Extent.C horiz, Extent.C vert, Shape.C height, Shape.C width) =>
   Order ->
   Extent.Extent vert horiz height width ->
   ((Int, Int), (Char, Char, Int))
gramianAdjointParameters order extent =
   ((n,k), blasArgs)
   where
      (height, width) = Extent.dimensions extent
      n = Shape.size height
      k = Shape.size width
      blasArgs =
         case order of
            ColumnMajor -> ('U', 'N', n)
            RowMajor -> ('L', 'C', k)

{-
Shared implementation of 'gramian' and 'gramianAdjoint'.
Runs 'herk' with factor one and zero beta
into a temporary n x n buffer
and packs that buffer into the result afterwards.

Another way to unify 'gramian' and 'gramianAdjoint'
would have been this function:

> gramianConjugation ::
>    Conjugation -> General height width a -> Hermitian width a

with

> gramianAdjoint a = gramianConjugation (transpose a)

but I would like to have

> order (gramianAdjoint a) = order a
-}
gramianIO ::
   (Class.Floating a) =>
   Order ->
   ForeignPtr a -> Ptr a ->
   ((Int, Int), (Char, Char, Int)) -> IO ()
gramianIO order a packedPtr ((n,k), (uplo,trans,lda)) =
   evalContT $ do
      uploPtr <- Call.char uplo
      transPtr <- Call.char trans
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      onePtr <- realOneArg a
      aPtr <- ContT (withForeignPtr a)
      ldaPtr <- Call.leadingDim lda
      zeroPtr <- realZeroArg a
      -- temporary full square that receives the result of herk
      fullPtr <- Call.allocaArray (n*n)
      ldcPtr <- Call.leadingDim n
      liftIO $ do
         herk uploPtr transPtr
            nPtr kPtr onePtr aPtr ldaPtr zeroPtr fullPtr ldcPtr
         pack order n fullPtr packedPtr


-- | Common signature of the rank-k update routines
-- with real-valued factors alpha and beta.
type HERK_ a =
   Ptr CChar -> Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr (RealOf a) -> Ptr a ->
   Ptr CInt -> Ptr (RealOf a) -> Ptr a -> Ptr CInt -> IO ()

-- | Wrapper for dispatching 'HERK_' with 'Class.switchFloating'.
newtype HERK a = HERK {getHERK :: HERK_ a}

-- | Rank-k update:
-- @syrk@ for the real types and @herk@ for the complex types.
herk :: Class.Floating a => HERK_ a
herk =
   getHERK
      (Class.switchFloating
         (HERK BlasReal.syrk)
         (HERK BlasReal.syrk)
         (HERK BlasComplex.herk)
         (HERK BlasComplex.herk))


{- |
Run @f@ on a matrix whose relevant dimension was wrapped in 'Unchecked'
by @mapSize@ and remove the wrapper from the result with 'recheck'.
-}
skipCheckCongruence ::
   ((sh -> Unchecked sh) -> matrix0 -> matrix1) ->
   (matrix1 -> Hermitian (Unchecked sh) a) -> matrix0 -> Hermitian sh a
skipCheckCongruence mapSize f =
   recheck . f . mapSize Unchecked


{- |
Congruence transformation with a real diagonal matrix given as vector.
Computed as half the anticommutator of the matrix
and its row-scaled variant.
-}
congruenceDiagonal ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height (RealOf a) -> General height width a -> Hermitian width a
congruenceDiagonal d =
   skipCheckCongruence Basic.mapWidth
      (\a -> scaledAnticommutator 0.5 a (Basic.scaleRowsReal d a))

{- |
Adjoint counterpart of 'congruenceDiagonal'.
Computed as half the adjoint anticommutator of the matrix
and its column-scaled variant.
-}
congruenceDiagonalAdjoint ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Vector width (RealOf a) -> Hermitian height a
congruenceDiagonalAdjoint matrix d =
   skipCheckCongruence Basic.mapHeight
      (\a -> scaledAnticommutatorAdjoint 0.5 a (Basic.scaleColumnsReal d a))
      matrix


{- |
Congruence transformation of a Hermitian matrix @b@ by a general matrix.
Computed as anticommutator of the general matrix
and the product of the triangular half of @b@ (see 'takeHalf')
with the general matrix.
-}
congruence ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Hermitian height a -> General height width a -> Hermitian width a
congruence b =
   skipCheckCongruence Basic.mapWidth
      (\a ->
         scaledAnticommutator one a
            (Split.tallMultiplyR NonTransposed (takeHalf b) a))

{- |
Adjoint counterpart of 'congruence'.
Computed as adjoint anticommutator of the general matrix
and its product with the triangular half of @b@ (see 'takeHalf').
-}
congruenceAdjoint ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Hermitian width a -> Hermitian height a
congruenceAdjoint matrix b =
   skipCheckCongruence Basic.mapHeight
      (\a ->
         scaledAnticommutatorAdjoint one a
            (Basic.swapMultiply
               (Split.tallMultiplyR Transposed) a (takeHalf b)))
      matrix


-- | Tag that 'takeHalf' puts into the 'MatrixShape.Split' shape of its result.
data Corrupt = Corrupt
   deriving (Eq)

{- |
Unpack the stored triangle into a square buffer and halve its diagonal.

> let b = takeHalf a
> ==>
> isTriangular b && a == addAdjoint b
-}
takeHalf ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> Split.Square Corrupt sh a
takeHalf (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreate (MatrixShape.Split Corrupt order (Extent.square sh)) $
      \bPtr -> evalContT $ do
         let n = Shape.size sh
         aPtr <- ContT (withForeignPtr a)
         nPtr <- Call.cint n
         halfPtr <- Call.number 0.5
         -- stride n+1 steps along the diagonal of the n x n buffer
         diagIncPtr <- Call.cint (n+1)
         liftIO $ do
            unpack order n aPtr bPtr
            BlasGen.scal nPtr halfPtr bPtr diagIncPtr


{- |
Scaled anticommutator of two full matrices of equal extent,
computed by 'her2k' with the arguments from 'gramianParameters'.
The first matrix is converted to the element order of the second one.
The result has the width of the inputs as dimension.
-}
scaledAnticommutator ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   a ->
   Full vert horiz height width a ->
   Full vert horiz height width a -> Hermitian width a
scaledAnticommutator alpha arr (Array (MatrixShape.Full order extentB) b) =
   let Array (MatrixShape.Full _ extentA) a = Basic.forceOrder order arr
       shape = MatrixShape.Hermitian order (Extent.width extentB)
   in Array.unsafeCreate shape $ \cpPtr -> do
         Call.assert "Hermitian.anticommutator: extents mismatch"
            (extentA==extentB)
         scaledAnticommutatorIO alpha order a b cpPtr
            (gramianParameters order extentB)

{- |
Adjoint counterpart of 'scaledAnticommutator',
computed by 'her2k' with the arguments from 'gramianAdjointParameters'.
The first matrix is converted to the element order of the second one.
The result has the height of the inputs as dimension.
-}
scaledAnticommutatorAdjoint ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   a ->
   Full vert horiz height width a ->
   Full vert horiz height width a -> Hermitian height a
scaledAnticommutatorAdjoint
      alpha arr (Array (MatrixShape.Full order extentB) b) =
   let Array (MatrixShape.Full _ extentA) a = Basic.forceOrder order arr
       shape = MatrixShape.Hermitian order (Extent.height extentB)
   in Array.unsafeCreate shape $ \cpPtr -> do
         Call.assert "Hermitian.anticommutatorAdjoint: extents mismatch"
            (extentA==extentB)
         scaledAnticommutatorIO alpha order a b cpPtr
            (gramianAdjointParameters order extentB)

{- |
Shared implementation of 'scaledAnticommutator'
and 'scaledAnticommutatorAdjoint'.
Runs 'her2k' with the given alpha and zero beta
into a temporary n x n buffer
and packs that buffer into the result afterwards.
Both input matrices use the same leading dimension.
-}
scaledAnticommutatorIO ::
   (Class.Floating a) =>
   a ->
   Order -> ForeignPtr a -> ForeignPtr a -> Ptr a ->
   ((Int, Int), (Char, Char, Int)) -> IO ()
scaledAnticommutatorIO alpha order a b packedPtr ((n,k), (uplo,trans,lda)) =
   evalContT $ do
      uploPtr <- Call.char uplo
      transPtr <- Call.char trans
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      alphaPtr <- Call.number alpha
      aPtr <- ContT (withForeignPtr a)
      -- used as leading dimension of both a and b
      ldPtr <- Call.leadingDim lda
      bPtr <- ContT (withForeignPtr b)
      zeroPtr <- realZeroArg aPtr
      -- temporary full square that receives the result of her2k
      fullPtr <- Call.allocaArray (n*n)
      ldcPtr <- Call.leadingDim n
      liftIO $ do
         her2k uploPtr transPtr nPtr kPtr alphaPtr
            aPtr ldPtr bPtr ldPtr zeroPtr fullPtr ldcPtr
         pack order n fullPtr packedPtr


-- | Common signature of the rank-2k update routines
-- with a factor alpha of element type and a real-valued factor beta.
type HER2K_ a =
   Ptr CChar -> Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr a ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr (RealOf a) -> Ptr a -> Ptr CInt -> IO ()

-- | Wrapper for dispatching 'HER2K_' with 'Class.switchFloating'.
newtype HER2K a = HER2K {getHER2K :: HER2K_ a}

-- | Rank-2k update:
-- @syr2k@ for the real types and @her2k@ for the complex types.
her2k :: Class.Floating a => HER2K_ a
her2k =
   getHER2K
      (Class.switchFloating
         (HER2K BlasReal.syr2k)
         (HER2K BlasReal.syr2k)
         (HER2K BlasComplex.her2k)
         (HER2K BlasComplex.her2k))


{-
Add a square matrix and its adjoint,
yielding a packed Hermitian matrix of the same element order.

'addAdjoint' is the variant that is exported.
'_addAdjoint' is neither exported nor used within this module.
Its author marked the second 'pack' call as wrong.
-}
addAdjoint, _addAdjoint ::
   (Shape.C sh, Class.Floating a) => Square sh a -> Hermitian sh a
_addAdjoint =
   argSquare $ \order sh a ->
      Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $ \bSize bPtr -> do
   let n = Shape.size sh
   evalContT $ do
      alphaPtr <- Call.number one
      incxPtr <- Call.cint 1
      aPtr <- ContT $ withForeignPtr a
      sizePtr <- Call.cint bSize
      -- temporary packed buffer for the part that is to be conjugated
      conjPtr <- Call.allocaArray bSize
      liftIO $ do
         pack order n aPtr bPtr
         -- Marked as wrong by the author.
         -- NOTE(review): presumably packing with the flipped order
         -- does not yield the transposed triangle in matching layout - confirm.
         pack (flipOrder order) n aPtr conjPtr
         lacgv sizePtr conjPtr incxPtr
         BlasGen.axpy sizePtr alphaPtr conjPtr incxPtr bPtr incxPtr

addAdjoint =
   argSquare $ \order sh a ->
      Array.unsafeCreate (MatrixShape.Hermitian order sh) $ \bPtr -> do
   let n = Shape.size sh
   evalContT $ do
      alphaPtr <- Call.number one
      -- unit stride
      incxPtr <- Call.cint 1
      -- stride of n elements within the square source
      incnPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      -- For every pointer group the destination run is first written by
      -- 'copyConjugate' from an n-strided source run
      -- and then a unit-strided source run is added by @axpy@.
      liftIO $ case order of
         RowMajor ->
            forPointers (rowMajorPointers n aPtr bPtr) $
               \nPtr (srcPtr,dstPtr) -> do
            -- same start address for the strided and the unit-strided run
            copyConjugate nPtr srcPtr incnPtr dstPtr incxPtr
            BlasGen.axpy nPtr alphaPtr srcPtr incxPtr dstPtr incxPtr
         ColumnMajor ->
            forPointers (columnMajorPointers n aPtr bPtr) $
               \nPtr ((srcRowPtr,srcColumnPtr),dstPtr) -> do
            -- separate start addresses for the strided and the unit-strided run
            copyConjugate nPtr srcRowPtr incnPtr dstPtr incxPtr
            BlasGen.axpy nPtr alphaPtr srcColumnPtr incxPtr dstPtr incxPtr


{- |
Copy the runs designated by 'columnMajorPointers' or 'rowMajorPointers'
from a full square buffer into a packed buffer using BLAS @copy@.
Not used within this module.
-}
_pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_pack order n fullPtr packedPtr =
   evalContT $ do
      unitIncPtr <- Call.cint 1
      -- copy one contiguous run from the square to the packed buffer
      let copyRun countPtr srcPtr dstPtr =
             BlasGen.copy countPtr srcPtr unitIncPtr dstPtr unitIncPtr
      liftIO $
         case order of
            RowMajor ->
               forPointers (rowMajorPointers n fullPtr packedPtr) $
                  \countPtr (srcPtr,dstPtr) -> copyRun countPtr srcPtr dstPtr
            ColumnMajor ->
               forPointers (columnMajorPointers n fullPtr packedPtr) $
                  \countPtr ((_,srcPtr),dstPtr) ->
                     copyRun countPtr srcPtr dstPtr


{- |
Allocate a real-valued constant (zero or one) as BLAS argument.
The argument is ignored;
it only fixes the element type @a@ that determines @RealOf a@.
-}
realZeroArg, realOneArg ::
   (Class.Floating a) => f a -> ContT r IO (Ptr (RealOf a))
realZeroArg =
   -- Each branch is instantiated at a different element type.
   runRealArg
      (Class.switchFloating
         (RealArg (\_ -> Call.number zero))
         (RealArg (\_ -> Call.number zero))
         (RealArg (\_ -> Call.number zero))
         (RealArg (\_ -> Call.number zero)))
realOneArg =
   runRealArg
      (Class.switchFloating
         (RealArg (\_ -> Call.number one))
         (RealArg (\_ -> Call.number one))
         (RealArg (\_ -> Call.number one))
         (RealArg (\_ -> Call.number one)))

-- | Wrapper for dispatching the above functions with 'Class.switchFloating'.
newtype RealArg f g a = RealArg {runRealArg :: f a -> g (Ptr (RealOf a))}