{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.LAPACK.Matrix.Divide where

import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Array.Divide as ArrDivide
import qualified Numeric.LAPACK.Matrix.Array.Unpacked as Unpacked
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix.Type.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Array.Private (Full)
import Numeric.LAPACK.Matrix.Type.Private (scaleWithCheck)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed,Transposed),
          Inversion(Inverted))
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape

import Data.Semigroup ((<>))

import GHC.Exts (Constraint)


-- | Matrix types that support computing a determinant.
--
-- The associated constraint 'DeterminantExtra' is imposed on both extra
-- type parameters (@xl@ and @xu@) of the matrix. The instances for
-- 'Matrix.Scale', 'Matrix.Permutation' and array matrices all set it to
-- @extra ~ ()@.
class (Matrix.Box typ) => Determinant typ where
   -- Additional constraint required of each extra type parameter.
   type DeterminantExtra typ extra :: Constraint
   -- | Determinant of a square matrix over the shape @sh@.
   determinant ::
      (DeterminantExtra typ xl, DeterminantExtra typ xu) =>
      (Omni.Strip lower, Omni.Strip upper) =>
      (Shape.C sh, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper sh a -> a

-- | Matrix types that can serve as the square coefficient matrix of a
-- linear system.
--
-- An instance must define either 'solve' or both 'solveLeft' and
-- 'solveRight'; each side has a default written in terms of the other.
class (Matrix.Box typ) => Solve typ where
   -- Additional constraint required of each extra type parameter
   -- (@xl@ and @xu@) of the coefficient matrix.
   type SolveExtra typ extra :: Constraint
   {-# MINIMAL solve | solveLeft,solveRight #-}
   -- | @solve NonTransposed a b@ computes @x@ with @a * x = b@;
   -- @solve Transposed a b@ computes @x@ with @a^T * x = b@.
   -- The right-hand side @b@ may have several columns (shape @width@).
   solve ::
      (SolveExtra typ xl, SolveExtra typ xu) =>
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Eq height, Class.Floating a) =>
      Transposition ->
      Matrix.Quadratic typ xl xu lower upper height a ->
      Full meas vert horiz height width a ->
      Full meas vert horiz height width a
   -- Default: the plain case is 'solveRight'; the transposed case uses
   -- the identity  a^-T * b = (b^T * a^-1)^T  and thus 'solveLeft'.
   solve trans a b =
      case trans of
         NonTransposed -> solveRight a b
         Transposed ->
            Unpacked.transpose (solveLeft (Unpacked.transpose b) a)

   -- | @solveRight a b@ computes @a^-1 * b@, i.e. @x@ with @a * x = b@.
   solveRight ::
      (SolveExtra typ xl, SolveExtra typ xu) =>
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Eq height, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper height a ->
      Full meas vert horiz height width a ->
      Full meas vert horiz height width a
   -- Default: the non-transposed variant of 'solve'.
   solveRight a b = solve NonTransposed a b

   -- | @solveLeft b a@ computes @b * a^-1@, i.e. @x@ with @x * a = b@.
   solveLeft ::
      (SolveExtra typ xl, SolveExtra typ xu) =>
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
      Full meas vert horiz height width a ->
      Matrix.Quadratic typ xl xu lower upper width a ->
      Full meas vert horiz height width a
   -- Default: reuse the transposed solver on the transposed problem;
   -- 'Unpacked.swapMultiply' adapts the argument order.
   solveLeft b a = Unpacked.swapMultiply (solve Transposed) b a

-- | Matrix types that can be inverted.
--
-- Every invertible type must also support 'Solve'.
class (Solve typ) => Inverse typ where
   -- Additional constraint required of each extra type parameter
   -- (@xl@ and @xu@).
   type InverseExtra typ extra :: Constraint
   -- | Inverse of a quadratic matrix. The result type swaps the @height@
   -- and @width@ shape parameters. Unlike the 'Solve' methods, the strip
   -- parameters must be instances of 'MatrixShape.PowerStrip'.
   inverse ::
      (InverseExtra typ xl, InverseExtra typ xu) =>
      (MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper,
       Extent.Measure meas, Shape.C height, Shape.C width, Class.Floating a) =>
      Matrix.QuadraticMeas typ xl xu lower upper meas height width a ->
      Matrix.QuadraticMeas typ xl xu lower upper meas width height a

infixl 7 ##/#
infixr 7 #\##

-- | Solve @a * x = b@ for @x@, i.e. compute @a^-1 * b@.
--
-- The coefficient matrix @a@ is quadratic ('Matrix.QuadraticMeas') but
-- may use different shape types for its rows (@height@) and columns
-- (@width@). 'Multiply.factorIdentityRight' splits it into a square
-- factor over @height@ and an identity factor. The system is solved with
-- the square factor, and the rows of the result are then reshaped to the
-- @width@ shape via the transposed identity factor.
(#\##) ::
   (Solve typ, Matrix.ToQuadratic typ,
    SolveExtra typ xl, SolveExtra typ xu,
    Matrix.BoxExtra typ xl, Matrix.BoxExtra typ xu,
    Omni.Strip lower, Omni.Strip upper,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Extent.Measure measA, Extent.Measure measB, Extent.Measure measC,
    Extent.MultiplyMeasure measA measB ~ measC,
    Extent.C vert, Extent.C horiz, Class.Floating a) =>
   Matrix.QuadraticMeas typ xl xu lower upper measA height width a ->
   Full measB vert horiz height nrhs a -> Full measC vert horiz width nrhs a
a #\## b =
   case Multiply.factorIdentityRight a of
      (square, eye) ->
         Multiply.reshapeHeight (Matrix.transpose eye) $ solveRight square b

-- | Solve @x * a = b@ for @x@, i.e. compute @b * a^-1@.
--
-- 'Multiply.factorIdentityLeft' splits the quadratic matrix @a@ into an
-- identity factor and a square factor over @width@. The system is solved
-- with the square factor, and the columns of the result are then reshaped
-- to the @height@ shape via the transposed identity factor.
(##/#) ::
   (Solve typ, Matrix.ToQuadratic typ,
    SolveExtra typ xl, SolveExtra typ xu,
    Matrix.BoxExtra typ xl, Matrix.BoxExtra typ xu,
    Omni.Strip lower, Omni.Strip upper,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs,
    Extent.Measure measA, Extent.Measure measB, Extent.Measure measC,
    Extent.MultiplyMeasure measA measB ~ measC,
    Extent.C vert, Extent.C horiz, Class.Floating a) =>
   Full measB vert horiz nrhs width a ->
   Matrix.QuadraticMeas typ xl xu lower upper measA height width a ->
   Full measC vert horiz nrhs height a
b ##/# a =
   case Multiply.factorIdentityLeft a of
      (eye, square) ->
         Multiply.reshapeWidth (solveLeft b square) $ Matrix.transpose eye


-- | Like 'solve', but the right-hand side is a single vector.
--
-- The vector is treated as a one-column matrix in column-major layout,
-- solved, and unwrapped again.
solveVector ::
   (Solve typ, SolveExtra typ xl, SolveExtra typ xu,
    Omni.Strip lower, Omni.Strip upper,
    Shape.C sh, Eq sh, Class.Floating a) =>
   Transposition ->
   Matrix.Quadratic typ xl xu lower upper sh a ->
   Vector sh a -> Vector sh a
solveVector trans a v =
   ArrMatrix.unliftColumn Layout.ColumnMajor (solve trans a) v


infixl 7 -/#
infixr 7 #\|

-- | Solve @a * x = v@ for the vector @x@, i.e. compute @a^-1 * v@.
--
-- The matrix factorization is done once per @a@; the returned function
-- can be reused for several right-hand sides.
(#\|) ::
   (Solve typ, Matrix.ToQuadratic typ,
    SolveExtra typ xl, SolveExtra typ xu,
    Matrix.BoxExtra typ xl, Matrix.BoxExtra typ xu,
    Omni.Strip lower, Omni.Strip upper, Extent.Measure meas,
    Shape.C height, Shape.C width, Eq height, Class.Floating a) =>
   Matrix.QuadraticMeas typ xl xu lower upper meas height width a ->
   Vector height a -> Vector width a
(#\|) a =
   case Multiply.factorIdentityRight a of
      (square, eye) ->
         -- kept point-free so both partial applications are shared
         -- between calls of the returned function
         let solveSquare = solveVector NonTransposed square
             backToWidth = reshapeVector (Matrix.transpose eye)
         in  backToWidth . solveSquare

-- | Solve @x * a = v@ for the vector @x@, i.e. compute @v * a^-1@.
--
-- Implemented as a transposed solve with the square factor obtained from
-- 'Multiply.factorIdentityLeft', followed by reshaping to @height@.
(-/#) ::
   (Solve typ, Matrix.ToQuadratic typ,
    SolveExtra typ xl, SolveExtra typ xu,
    Matrix.BoxExtra typ xl, Matrix.BoxExtra typ xu,
    Omni.Strip lower, Omni.Strip upper, Extent.Measure meas,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a ->
   Matrix.QuadraticMeas typ xl xu lower upper meas height width a ->
   Vector height a
v -/# a =
   case Multiply.factorIdentityLeft a of
      (eye, square) -> reshapeVector eye (solveVector Transposed square v)


-- | Reinterpret a vector over shape @width@ as a vector over shape
-- @height@, taking the new shape from the height of the identity
-- matrix's extent.
reshapeVector ::
   (Extent.Measure meas, Shape.C height, Shape.C width) =>
   Multiply.IdentityMaes meas height width a ->
   Vector width a -> Vector height a
reshapeVector (Matrix.Identity ext) v =
   Array.reshape (Extent.height ext) v


-- A scaled identity matrix has no extra type parameters.
instance Determinant Matrix.Scale where
   type DeterminantExtra Matrix.Scale extra = extra ~ ()
   -- The factor is raised to the power of the shape's size, as for
   -- s * I of dimension n, whose determinant is s^n.
   determinant (Matrix.Scale shape factor) = factor ^ Shape.size shape

instance Solve Matrix.Scale where
   type SolveExtra Matrix.Scale extra = extra ~ ()
   -- The right-hand side is multiplied by the reciprocal of the scale
   -- factor. The transposition is irrelevant for a scaled identity and is
   -- ignored. scaleWithCheck is given Matrix.height, presumably to check
   -- the shape against the right-hand side (NOTE(review): confirm).
   solve _ =
      scaleWithCheck "Matrix.Scale.solve" Matrix.height
         (\s -> ArrMatrix.lift1 (Vector.scale (recip s)))

instance Inverse Matrix.Scale where
   type InverseExtra Matrix.Scale extra = extra ~ ()
   -- Same shape, reciprocal scale factor.
   inverse (Matrix.Scale shape factor) = Matrix.Scale shape (recip factor)


-- Permutation matrices delegate to the permutation module.
instance Determinant Matrix.Permutation where
   type DeterminantExtra Matrix.Permutation extra = extra ~ ()
   determinant perm = PermMatrix.determinant perm

instance Solve Matrix.Permutation where
   type SolveExtra Matrix.Permutation extra = extra ~ ()
   -- Solving with a permutation P means multiplying by its inverse, so an
   -- extra 'Inverted' is combined with the inversion implied by the
   -- requested transposition before multiplying.
   solve trans =
      let inversion = Inverted <> PermMatrix.inversionFromTransposition trans
      in  PermMatrix.multiplyFull inversion

-- The inverse of a permutation matrix is its transpose.
instance Inverse Matrix.Permutation where
   type InverseExtra Matrix.Permutation extra = extra ~ ()
   -- NOTE(review): matching on the 'Matrix.Permutation' constructor
   -- presumably brings type equalities into scope (e.g. height ~ width)
   -- that the branch returning @a@ unchanged relies on; confirm.
   inverse a@(Matrix.Permutation _) =
      case Matrix.powerStrips a of
         -- both strips filled: a general permutation, inverted by
         -- transposing it
         (MatrixShape.Filled, MatrixShape.Filled) -> PermMatrix.transpose a
         _ -> a -- identity matrix

-- Array-backed matrices delegate to the array implementation.
instance
   (Layout.Packing pack, Omni.Property property) =>
      Determinant (ArrMatrix.Array pack property) where
   type DeterminantExtra (ArrMatrix.Array pack property) extra = extra ~ ()
   determinant m = ArrDivide.determinant m

-- Array-backed matrices define the two one-sided solvers; 'solve' uses
-- the class default built from them.
instance
   (Layout.Packing pack, Omni.Property property) =>
      Solve (ArrMatrix.Array pack property) where
   type SolveExtra (ArrMatrix.Array pack property) extra = extra ~ ()
   -- a^-1 * b is computed directly by the array solver.
   solveRight a b = ArrDivide.solve a b
   -- b * a^-1 is obtained by solving with the transposed coefficient
   -- matrix; Matrix.swapMultiply adapts the operand order (presumably
   -- transposing b and the result as well).
   solveLeft b a =
      Matrix.swapMultiply
         (\q -> ArrDivide.solve (Matrix.transpose q))
         b a

-- Array-backed matrices delegate inversion to the array implementation.
instance
   (Layout.Packing pack, Omni.Property property) =>
      Inverse (ArrMatrix.Array pack property) where
   type InverseExtra (ArrMatrix.Array pack property) extra = extra ~ ()
   inverse m = ArrDivide.inverse m