{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Diagonal (
   Diagonal, FlexDiagonal,
   fromList, autoFromList,
   fromVector,
   lift,

   stack, (%%%),
   split,

   multiply,

   solve,
   inverse,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Quadratic as Quad
import qualified Numeric.LAPACK.Matrix.Banded as Banded

import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Basic as FullBasic
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Array.Banded (Diagonal, FlexDiagonal)
import Numeric.LAPACK.Matrix.Layout.Private (Order)
import Numeric.LAPACK.Matrix.Private (ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))

import qualified Numeric.Netlib.Class as Class

import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import Data.Array.Comfort.Shape ((::+))

import Foreign.Storable (Storable)



-- | Build a diagonal matrix with the given storage 'Order' and shape @sh@
-- from a list of entries, by delegating to 'Banded.squareFromList'.
-- The two 'Proxy' arguments carry the band widths, which are fixed by the
-- 'Diagonal' result type.
fromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Diagonal sh a
fromList = Banded.squareFromList (Proxy, Proxy)

-- | Like 'fromList', but the shape is derived from the list itself via
-- 'Vector.autoFromList', yielding a 'ShapeInt'-indexed diagonal matrix.
autoFromList :: (Storable a) => Order -> [a] -> Diagonal ShapeInt a
autoFromList order xs = fromVector order (Vector.autoFromList xs)

-- | Reinterpret a vector as a diagonal matrix with the given 'Order'.
-- The data is not copied. Only the array shape is replaced by the
-- quadratic matrix shape built by 'Omni.quadratic'.
fromVector :: (Shape.C sh, Storable a) => Order -> Vector sh a -> Diagonal sh a
fromVector order vec =
   ArrMatrix.Array $ Array.mapShape (Omni.quadratic order) vec


-- | View the storage of a (packed) diagonal matrix as a vector indexed by
-- the matrix's square size. Only the shape is changed; the data is reused.
takeDiagonal ::
   (Omni.TriDiag diag) =>
   FlexDiagonal diag sh a -> Vector sh a
takeDiagonal m = Array.mapShape Omni.squareSize $ ArrMatrix.unwrap m


-- | Transform the diagonal of a diagonal matrix (with arbitrary diagonal
-- entries) by a function on arrays, then rebuild a diagonal matrix with
-- the same 'Order' as the input from the result.
--
-- The diagonal is extracted with 'Quad.takeDiagonal' and the result is
-- rebuilt with 'Quad.diagonal'. The shape may change from @sha@ to @shb@.
lift ::
   (Layout.Packing pack,
    Shape.C sha, Shape.C shb, Class.Floating a, Class.Floating b) =>
   (Array sha a -> Array shb b) ->
   FlexDiagonalP pack Omni.Arbitrary sha a ->
   FlexDiagonalP pack Omni.Arbitrary shb b
lift f a =
   -- NOTE(review): both branches are textually identical. The match on
   -- 'ArrMatrix.packTag' presumably exists only to refine the @pack@ type
   -- variable so that the constraints of 'Quad.diagonal' /
   -- 'Quad.takeDiagonal' can be resolved per branch. Confirm with the
   -- type checker before merging the branches.
   case ArrMatrix.packTag a of
      Layout.Packed ->
         Quad.diagonal (ArrMatrix.order a) $ f $ Quad.takeDiagonal a
      Layout.Unpacked ->
         Quad.diagonal (ArrMatrix.order a) $ f $ Quad.takeDiagonal a


-- | A quadratic array matrix whose lower and upper strips are both
-- 'Layout.Empty', parameterised over its packing and its diagonal kind.
-- Presumably this is a diagonal matrix whose packing is left open
-- (packed or unpacked).
type FlexDiagonalP packing diagKind shape = ArrMatrix.Quadratic packing diagKind Layout.Empty Layout.Empty shape

infixr 2 %%%

-- | Block-diagonal concatenation of two diagonal matrices: the result has
-- shape @sh0::+sh1@, with @a@ in the top-left and @b@ in the bottom-right.
-- '(%%%)' is an infix synonym for 'stack'.
--
-- The storage 'Order' of the result is taken from @b@.
(%%%), stack ::
   (Layout.Packing pack) =>
   (Omni.TriDiag diag, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexDiagonalP pack diag sh0 a ->
   FlexDiagonalP pack diag sh1 a ->
   FlexDiagonalP pack diag (sh0::+sh1) a
(%%%) = stack
stack a b =
   let order = Omni.order $ ArrMatrix.shape b in
   case ArrMatrix.packTag a of
      Layout.Packed ->
         -- Packed storage holds only the diagonal, so appending the two
         -- diagonal vectors and re-labelling the shape is sufficient.
         ArrMatrix.Array $
         Array.mapShape (Omni.uncheckedDiagonal order) $
         Vector.append (takeDiagonal a) (takeDiagonal b)
      Layout.Unpacked ->
         -- Unpacked storage is a full square matrix, so build the 2x2 block
         -- matrix [[a, 0], [0, b]]. 'shc' is the sh0-by-sh1 shape of the
         -- top-right zero block; its 'Layout.inverse' is the bottom-left one.
         let shc =
               Layout.general order
                  (Unchecked $ Quad.size a) (Unchecked $ Quad.size b)
         in ArrMatrix.liftUnpacked2
               (\a_ b_ ->
                  -- The blocks are stacked with 'Unchecked' shapes.
                  -- 'Extent.recheckAppend' presumably restores the checked
                  -- appended extent afterwards — TODO confirm.
                  FullBasic.mapExtent Extent.recheckAppend $
                  FullBasic.stack
                     (FullBasic.uncheck a_) (Vector.zero shc)
                     (Vector.zero $ Layout.inverse shc) (FullBasic.uncheck b_))
               a b

-- | Inverse of 'stack': cut a diagonal matrix over @sh0::+sh1@ into its
-- top-left block over @sh0@ and its bottom-right block over @sh1@.
split ::
   (Omni.TriDiag diag, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexDiagonalP pack diag (sh0::+sh1) a ->
   (FlexDiagonalP pack diag sh0 a, FlexDiagonalP pack diag sh1 a)
split a = (upperLeft, lowerRight)
   where
      upperLeft = Quad.takeTopLeft a
      lowerRight = Quad.takeBottomRight a


-- | Product of two diagonal matrices of the same shape, computed as the
-- element-wise product ('Vector.mul') of their stored entries.
-- The @Eq sh@ constraint is presumably used by 'ArrMatrix.liftOmni2' to
-- check that both shapes agree.
multiply ::
   (Omni.TriDiag diag, Shape.C sh, Eq sh, Class.Floating a) =>
   FlexDiagonal diag sh a -> FlexDiagonal diag sh a -> FlexDiagonal diag sh a
multiply x y = ArrMatrix.liftOmni2 Vector.mul x y


-- | Solve @a * x = b@ for @x@, where @a@ is diagonal.
--
-- * For an arbitrary diagonal, the work is delegated to 'Banded.solve'.
--
-- * For a unit diagonal, @x = b@. The only check is that the square size
--   of @a@ equals the height of @b@; otherwise this calls 'error'.
solve ::
   (Omni.TriDiag diag,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   FlexDiagonal diag height a ->
   ArrMatrix.Full meas vert horiz height width a ->
   ArrMatrix.Full meas vert horiz height width a
solve a b =
   case ArrMatrix.diagTag a of
      Omni.Arbitrary -> Banded.solve a b
      Omni.Unit
         | sizeA == heightB -> b
         | otherwise -> error "Diagonal.solve: height shapes mismatch"
         where
            sizeA = Omni.squareSize (ArrMatrix.shape a)
            heightB = Omni.height (ArrMatrix.shape b)

-- | Inverse of a diagonal matrix.
--
-- An arbitrary diagonal gets the reciprocal ('Vector.recip') of each
-- stored entry. A unit diagonal is its own inverse and is returned
-- unchanged.
inverse ::
   (Omni.TriDiag diag, Shape.C sh, Class.Floating a) =>
   FlexDiagonal diag sh a -> FlexDiagonal diag sh a
inverse a =
   case ArrMatrix.diagTag a of
      Omni.Arbitrary -> ArrMatrix.liftOmni1 Vector.recip a
      Omni.Unit -> a

-- | Determinant of a diagonal matrix: the product of its stored entries
-- for an arbitrary diagonal, and 'Scalar.one' for a unit diagonal.
determinant ::
   (Omni.TriDiag diag, Shape.C sh, Class.Floating a) =>
   FlexDiagonal diag sh a -> a
determinant a =
   case ArrMatrix.diagTag a of
      Omni.Arbitrary -> Vector.product (ArrMatrix.unwrap a)
      Omni.Unit -> Scalar.one