{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Symmetric.Basic (
   Symmetric,
   SymmetricP,
   sumRank1,
   congruenceDiagonal, congruenceDiagonalTransposed,
   ) where

import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import Numeric.LAPACK.Matrix.Symmetric.Unified
         (skipCheckCongruence, spr, syr, complement,
          scaledAnticommutator, scaledAnticommutatorTransposed)
import Numeric.LAPACK.Matrix.Mosaic.Private
         (withPacking, noLabel, applyFuncPair, triArg)
import Numeric.LAPACK.Matrix.Layout.Private
         (MirrorSingleton(SimpleMirror), Order, uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero)
import Numeric.LAPACK.Private (fill)

import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Storable (poke)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Foldable (forM_)


-- | Symmetric matrix as an 'Array' over a 'Layout.SymmetricP' shape.
-- The @pack@ type parameter selects the storage packing.
type SymmetricP pack sh = Array (Layout.SymmetricP pack sh)

-- | Symmetric matrix with 'Layout.Unpacked' storage.
type Symmetric sh = Array (Layout.SymmetricP Layout.Unpacked sh)


-- | Sum of scaled rank-1 terms: @sumRank1 order sh xs@ accumulates
-- @alpha * x * x^T@ for every @(alpha, x)@ in @xs@ into a symmetric
-- matrix of shape @sh@ stored with the given 'Order'.
--
-- The packing is chosen by the result type via 'Layout.autoPacking'.
-- The storage is zeroed first, so an empty list gives the zero matrix.
-- Each term is added with the BLAS rank-1 update: @spr@ for packed
-- storage, @syr@ for unpacked. 'complement' then post-processes the
-- storage for the given packing and order.
--
-- Every vector must have shape @sh@; otherwise 'Call.assert' fails with
-- \"Symmetric.sumRank1: non-matching vector size\".
sumRank1 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, Vector sh a)] -> SymmetricP pack sh a
sumRank1 order sh xs =
   Array.unsafeCreateWithSize (Layout.symmetricP pack order sh) accumulate
   where
      pack = Layout.autoPacking
      n = Shape.size sh

      -- Fill the freshly allocated storage (of triSize elements at aPtr).
      accumulate triSize aPtr = do
         evalContT $ do
            -- BLAS call arguments shared by every rank-1 update.
            uploArg <- Call.char $ uploFromOrder order
            sizeArg <- Call.cint n
            alphaArg <- Call.alloca
            incArg <- Call.cint 1
            -- Add one scaled term alpha * x * x^T into the triangle.
            let addTerm (alpha, Array shX x) = do
                   Call.assert "Symmetric.sumRank1: non-matching vector size" (sh==shX)
                   poke alphaArg alpha
                   evalContT $ do
                      vecPtr <- ContT $ withForeignPtr x
                      withPacking pack $
                         applyFuncPair (noLabel spr) (noLabel syr)
                            uploArg sizeArg alphaArg vecPtr incArg (triArg aPtr n)
            liftIO $ do
               fill zero triSize aPtr
               forM_ xs addTerm
         complement pack NonConjugated order n aPtr


-- | Congruence transform of a diagonal matrix.
--
-- @congruenceDiagonal d a@ yields a symmetric matrix whose shape is the
-- width of @a@. It combines @a@ and @scaleRows d a@ (rows of @a@ scaled
-- by @d@, i.e. @diag d * a@) with 'scaledAnticommutator' using factor 0.5.
--
-- NOTE(review): assuming @scaledAnticommutator m alpha a b@ forms
-- @alpha * (a^T*b + b^T*a)@, the result is @a^T * diag d * a@. Confirm
-- against the definition in "Numeric.LAPACK.Matrix.Symmetric.Unified".
--
-- The computation is wrapped in 'skipCheckCongruence' with
-- 'Basic.mapWidth'. Its exact effect on the width shape is defined
-- elsewhere (presumably it suppresses a shape check); confirm there.
congruenceDiagonal ::
   (Layout.Packing pack,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix.General height width a -> SymmetricP pack width a
congruenceDiagonal d =
   skipCheckCongruence Basic.mapWidth $ \a ->
      let scaled = Basic.scaleRows d a
      in  scaledAnticommutator SimpleMirror 0.5 a scaled

-- | Transposed counterpart of 'congruenceDiagonal', taking the matrix
-- first and the diagonal vector second.
--
-- @congruenceDiagonalTransposed a d@ yields a symmetric matrix whose
-- shape is the height of @a@. It combines @a@ and @scaleColumns d a@
-- (columns of @a@ scaled by @d@, i.e. @a * diag d@) with
-- 'scaledAnticommutatorTransposed' using factor 0.5.
--
-- NOTE(review): by analogy with 'congruenceDiagonal' the result is
-- presumably @a * diag d * a^T@. Confirm against the definition of
-- 'scaledAnticommutatorTransposed'.
congruenceDiagonalTransposed ::
   (Layout.Packing pack,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix.General height width a -> Vector width a -> SymmetricP pack height a
congruenceDiagonalTransposed matrix d =
   skipCheckCongruence Basic.mapHeight
      (\a ->
         let scaled = Basic.scaleColumns d a
         in  scaledAnticommutatorTransposed SimpleMirror 0.5 a scaled)
      matrix