{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Banded (
   Banded,
   FlexBanded,
   Banded.General,
   Banded.Square,
   Banded.Upper, Banded.UnitUpper,
   Banded.Lower, Banded.UnitLower,
   Diagonal,
   Banded.FlexDiagonal,
   Banded.RectangularDiagonal,
   Banded.Hermitian,
   Banded.HermitianPosSemidef,
   Banded.HermitianPosDef,
   Banded.FlexHermitian,
   height, width,
   fromList,
   squareFromList,
   lowerFromList,
   upperFromList,
   mapExtent,
   diagonal,
   takeDiagonal,
   forceOrder,
   noUnit,
   toFull,
   toLowerTriangular,
   toUpperTriangular,
   takeTopLeftSquare,
   takeBottomRightSquare,
   transpose,
   adjoint,
   multiplyVector,
   multiply,
   multiplyFull,

   solve,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Banded.Linear as Linear
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Basic

import qualified Numeric.LAPACK.Matrix.Array.Banded as Banded
import qualified Numeric.LAPACK.Matrix.Array.Mosaic as Triangular
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Strict as ExtentStrict
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Array.Banded (FlexBanded, Banded, Diagonal)
import Numeric.LAPACK.Matrix.Array.Private (Full, diagTag)
import Numeric.LAPACK.Matrix.Layout.Private (Order, UnaryProxy)
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary.Proof as Proof
import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Unary ((:+:))

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((::+))

import Foreign.Storable (Storable)


-- | The row shape of a banded matrix, read off its shape descriptor.
height ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Banded sub super meas vert horiz height width a -> height
height m = Omni.height (ArrMatrix.shape m)

-- | The column shape of a banded matrix, read off its shape descriptor.
width ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Banded sub super meas vert horiz height width a -> width
width m = Omni.width (ArrMatrix.shape m)



-- | Build a general banded matrix from a list of elements.
--
-- The pair of proxies fixes the numbers of sub- and super-diagonals.
-- The 'Order', height and width are passed through to 'Basic.fromList',
-- which defines the expected layout of the element list.
fromList ::
   (Unary.Natural sub, Unary.Natural super,
    Shape.C height, Shape.C width, Storable a) =>
   (UnaryProxy sub, UnaryProxy super) -> Order -> height -> width -> [a] ->
   Banded.General sub super height width a
fromList offDiag order height_ width_ xs =
   ArrMatrix.lift0 (Basic.fromList offDiag order height_ width_ xs)

-- | Build a square banded matrix of the given size from a list of elements.
--
-- The element-list layout is the one expected by 'Basic.squareFromList'.
squareFromList ::
   (Unary.Natural sub, Unary.Natural super, Shape.C size, Storable a) =>
   (UnaryProxy sub, UnaryProxy super) -> Order -> size -> [a] ->
   Banded.Square sub super size a
squareFromList offDiag order size elems =
   ArrMatrix.lift0 $ Basic.squareFromList offDiag order size elems

-- | Build a lower banded matrix with @sub@ sub-diagonals from a list of
-- elements.
--
-- The element-list layout is the one expected by 'Basic.lowerFromList'.
lowerFromList ::
   (Unary.Natural sub, Shape.C size, Storable a) =>
   UnaryProxy sub -> Order -> size -> [a] -> Banded.Lower sub size a
lowerFromList numOff order size elems =
   ArrMatrix.lift0 (Basic.lowerFromList numOff order size elems)

-- | Build an upper banded matrix with @super@ super-diagonals from a list
-- of elements.
--
-- The element-list layout is the one expected by 'Basic.upperFromList'.
upperFromList ::
   (Unary.Natural super, Shape.C size, Storable a) =>
   UnaryProxy super -> Order -> size -> [a] -> Banded.Upper super size a
upperFromList numOff order size elems =
   ArrMatrix.lift0 $ Basic.upperFromList numOff order size elems

-- | Change the extent tags (measure, vertical and horizontal) of a banded
-- matrix according to an 'Extent.Map'.
--
-- The height, width and off-diagonal counts stay the same.
mapExtent ::
   (Extent.C vertA, Extent.C horizA) =>
   (Extent.C vertB, Extent.C horizB) =>
   (Unary.Natural sub, Unary.Natural super) =>
   Extent.Map measA vertA horizA measB vertB horizB height width ->
   Banded sub super measA vertA horizA height width a ->
   Banded sub super measB vertB horizB height width a
mapExtent extentMap =
   ArrMatrix.lift1 (Basic.mapExtent (ExtentStrict.apply extentMap))

-- | Transpose a banded matrix.
--
-- The result type swaps the sub- and super-diagonal counts, the height
-- and width, and the vertical and horizontal extent tags.
-- Both diagonal kinds are handled by 'Basic.transpose'.
transpose ::
   (Omni.TriDiag diag, Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   FlexBanded diag sub super meas vert horiz height width a ->
   FlexBanded diag super sub meas horiz vert width height a
transpose a =
   case diagTag a of
      Omni.Arbitrary -> ArrMatrix.lift1 Basic.transpose a
      Omni.Unit ->
         -- NOTE(review): this extra match on the shape constructor looks
         -- like it exists only for GADT type refinement, so that the
         -- unit-diagonal case type-checks; confirm before simplifying.
         case ArrMatrix.shape a of
            Omni.UnitBandedTriangular _ -> ArrMatrix.lift1 Basic.transpose a

-- | Adjoint of a banded matrix.
--
-- Like 'transpose', the result type swaps the sub- and super-diagonal
-- counts, the height and width, and the extent tags.
-- General and unit-triangular banded shapes go through 'Basic.adjoint'.
-- A Hermitian banded matrix is returned unchanged, because a Hermitian
-- matrix equals its own adjoint.
adjoint ::
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   FlexBanded diag sub super meas vert horiz height width a ->
   FlexBanded diag super sub meas horiz vert width height a
adjoint a =
   -- Matching on the shape constructor refines the type indices per case.
   case ArrMatrix.shape a of
      Omni.Banded _ -> ArrMatrix.lift1 Basic.adjoint a
      Omni.UnitBandedTriangular _ -> ArrMatrix.lift1 Basic.adjoint a
      Omni.BandedHermitian _ -> a

-- | Build a 'Diagonal' matrix of shape @sh@ from a vector, stored in the
-- given 'Order' (delegates to 'Basic.diagonal').
diagonal ::
   (Shape.C sh, Class.Floating a) => Order -> Vector sh a -> Diagonal sh a
diagonal order v = ArrMatrix.lift0 (Basic.diagonal order v)

-- | Extract the main diagonal of a square banded matrix as a vector
-- (delegates to 'Basic.takeDiagonal' on the underlying storage).
takeDiagonal ::
   (Unary.Natural sub, Unary.Natural super, Shape.C sh, Class.Floating a) =>
   Banded.Square sub super sh a -> Vector sh a
takeDiagonal m = Basic.takeDiagonal $ ArrMatrix.toVector m

-- | Multiply a banded matrix by a vector.
--
-- The vector shape must match the matrix width; the result has the
-- matrix height.
multiplyVector ::
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width,
    Class.Floating a) =>
   Banded sub super meas vert horiz height width a ->
   Vector width a -> Vector height a
multiplyVector m x = Basic.multiplyVector (ArrMatrix.toVector m) x

-- | Multiply two banded matrices.
--
-- The product's sub-diagonal count is the sum of the operands'
-- sub-diagonal counts (@subA :+: subB@), and likewise for the
-- super-diagonals.
multiply ::
   (Unary.Natural subA, Unary.Natural superA,
    Unary.Natural subB, Unary.Natural superB,
    (subA :+: subB) ~ subC,
    (superA :+: superB) ~ superC,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   Banded subA superA meas vert horiz height fuse a ->
   Banded subB superB meas vert horiz fuse width a ->
   Banded subC superC meas vert horiz height width a
multiply a b =
   -- NOTE(review): the 'Proof.Nat' patterns presumably bring the
   -- 'Unary.Natural' evidence for the summed counts (subC, superC) into
   -- scope, which the result type needs; the summed counts themselves
   -- (_numOffC) are not used here.
   case Layout.addOffDiagonals
         (MatrixShape.bandedOffDiagonals $ ArrMatrix.shape a)
         (MatrixShape.bandedOffDiagonals $ ArrMatrix.shape b) of
      ((Proof.Nat, Proof.Nat), _numOffC) -> ArrMatrix.lift2 Basic.multiply a b

-- | Multiply a banded matrix by a 'Full' matrix, giving a 'Full' result
-- whose height comes from the banded operand and whose width comes from
-- the full operand.
multiplyFull ::
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   Banded sub super meas vert horiz height fuse a ->
   Full meas vert horiz fuse width a -> Full meas vert horiz height width a
multiplyFull banded full = ArrMatrix.lift2 Basic.multiplyFull banded full


-- | Request the given storage 'Order' for a banded matrix
-- (delegates to 'Basic.forceOrder'); the matrix type is unchanged.
-- NOTE(review): presumably the layout is converted only when it differs
-- from the requested order — confirm in 'Basic.forceOrder'.
forceOrder ::
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Order ->
   Banded sub super meas vert horiz height width a ->
   Banded sub super meas vert horiz height width a
forceOrder order m = ArrMatrix.lift1 (Basic.forceOrder order) m

-- | Convert a lower banded matrix into a 'Triangular.Lower' matrix of the
-- same size.
toLowerTriangular ::
   (Unary.Natural sub, Shape.C sh, Class.Floating a) =>
   Banded.Lower sub sh a -> Triangular.Lower sh a
toLowerTriangular m = ArrMatrix.lift1 Basic.toLowerTriangular m

-- | Convert an upper banded matrix into a 'Triangular.Upper' matrix of the
-- same size.
toUpperTriangular ::
   (Unary.Natural super, Shape.C sh, Class.Floating a) =>
   Banded.Upper super sh a -> Triangular.Upper sh a
toUpperTriangular m = ArrMatrix.lift1 Basic.toUpperTriangular m

-- | Convert a banded matrix into a 'Full' matrix with the same extent,
-- height and width.
toFull ::
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Banded sub super meas vert horiz height width a ->
   Full meas vert horiz height width a
toFull m = ArrMatrix.lift1 Basic.toFull m


-- | Take the top-left square block of a square banded matrix whose size is
-- the concatenation @sh0 ::+ sh1@; the result has size @sh0@ and keeps the
-- off-diagonal counts.
takeTopLeftSquare ::
   (Unary.Natural sub, Unary.Natural super,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Banded.Square sub super (sh0 ::+ sh1) a ->
   Banded.Square sub super sh0 a
takeTopLeftSquare m = ArrMatrix.lift1 Basic.takeTopLeftSquare m

-- | Take the bottom-right square block of a square banded matrix whose size
-- is the concatenation @sh0 ::+ sh1@; the result has size @sh1@ and keeps
-- the off-diagonal counts.
takeBottomRightSquare ::
   (Unary.Natural sub, Unary.Natural super,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Banded.Square sub super (sh0 ::+ sh1) a ->
   Banded.Square sub super sh1 a
takeBottomRightSquare m = ArrMatrix.lift1 Basic.takeBottomRightSquare m


-- | Forget the unit-diagonal tag of a unit-triangular banded matrix and
-- treat it as a general square banded matrix.
--
-- Only the shape is replaced: the underlying array is relabelled with an
-- 'Omni.Banded' shape built from the same inner shape, using the unchecked
-- 'Array.reshape'.
-- NOTE(review): this assumes the unit-triangular storage holds the diagonal
-- entries explicitly, so that reinterpreting it as a plain banded matrix
-- gives the right values — confirm against the storage convention.
noUnit :: Banded.UnitTriangular sub super sh a -> Banded.Square sub super sh a
noUnit a =
   case ArrMatrix.shape a of
      Omni.UnitBandedTriangular sh ->
         ArrMatrix.Array $ Array.reshape (Omni.Banded sh) (ArrMatrix.unwrap a)


-- | Singleton witnesses for the sub- and super-diagonal counts of a square
-- banded matrix.
--
-- The matrix argument only fixes the types. Matching on the result
-- ('Unary.Zero' or 'Unary.Succ') lets callers branch on whether each count
-- is zero.
offDiagonals ::
   (Unary.Natural sub, Unary.Natural super) =>
   Banded.Quadratic diag sub super sh a ->
   (Unary.HeadSingleton sub, Unary.HeadSingleton super)
offDiagonals = const (Unary.headSingleton, Unary.headSingleton)

-- | Solve a linear system with a square banded coefficient matrix for the
-- right-hand sides given as columns of a 'Full' matrix.
--
-- Dispatch:
--
-- * Unit-diagonal matrices use 'Linear.solveTriangular'.
--
-- * For arbitrary-diagonal matrices, if either off-diagonal count is zero
--   (diagonal, upper or lower banded), 'Linear.solveTriangular' is used.
--
-- * Only when both counts are non-zero does the general 'Linear.solve'
--   run.
solve ::
   (Omni.TriDiag diag, Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Banded.Quadratic diag sub super sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve a =
   case diagTag a of
      diag@Omni.Unit ->
         -- The shape match refines the types for the unit-triangular case.
         case ArrMatrix.shape a of
            Omni.UnitBandedTriangular _ ->
               ArrMatrix.lift2 (Linear.solveTriangular diag) a
      diag@Omni.Arbitrary ->
         -- Each alternative below is kept separate on purpose: the
         -- Zero/Succ patterns refine @sub@ and @super@ differently per case,
         -- so they cannot be merged into a single wildcard branch.
         case offDiagonals a of
            -- no sub- and no super-diagonals: diagonal matrix
            (Unary.Zero, Unary.Zero) ->
               ArrMatrix.lift2 (Linear.solveTriangular diag) a
            -- no sub-diagonals: upper banded
            (Unary.Zero, Unary.Succ) ->
               ArrMatrix.lift2 (Linear.solveTriangular diag) a
            -- no super-diagonals: lower banded
            (Unary.Succ, Unary.Zero) ->
               ArrMatrix.lift2 (Linear.solveTriangular diag) a
            -- both sides present: general banded solver
            (Unary.Succ, Unary.Succ) ->
               ArrMatrix.lift2 Linear.solve a

-- | Determinant of a square banded matrix.
--
-- A matrix tagged as unit-diagonal ('Omni.Unit') yields 'Scalar.one'
-- without any computation. Otherwise the work is delegated to
-- 'Linear.determinant' on the underlying storage.
--
-- Uses the unqualified 'diagTag' and the 'Omni' constructors, matching
-- 'transpose' and 'solve' in this module.
determinant ::
   (Omni.TriDiag diag, Unary.Natural sub, Unary.Natural super,
    Shape.C sh, Class.Floating a) =>
   Banded.Quadratic diag sub super sh a -> a
determinant a =
   case diagTag a of
      Omni.Unit -> Scalar.one
      Omni.Arbitrary -> Linear.determinant $ ArrMatrix.toVector a