{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.Divide where

import qualified Numeric.LAPACK.Matrix.Plain.Divide as ArrDivide
import qualified Numeric.LAPACK.Matrix.Array.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Type as Type
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Array (Full)
import Numeric.LAPACK.Matrix.Type (Matrix, scaleWithCheck)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed,Transposed),
          Inversion(Inverted))
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape

import Data.Semigroup ((<>))


-- | Matrix types whose determinant can be computed.
--
-- The superclass constraint restricts instances to matrix types whose
-- height shape type coincides with their width shape type.
class (Type.Box typ, Type.HeightOf typ ~ Type.WidthOf typ)
      => Determinant typ where
   -- | Determinant of the given matrix.
   determinant :: (Class.Floating a) => Matrix typ a -> a

-- | Matrix types that can serve as the coefficient matrix of a linear
-- system.
--
-- A minimal instance defines either 'solve', or both 'solveLeft' and
-- 'solveRight'; the remaining methods fall back to the defaults below,
-- which are expressed in terms of one another.
class (Type.Box typ, Type.HeightOf typ ~ Type.WidthOf typ) => Solve typ where
   {-# MINIMAL solve | solveLeft,solveRight #-}

   -- | @solve trans a b@ solves with the coefficient matrix @a@ and the
   -- right-hand sides @b@. For 'NonTransposed' this is 'solveRight';
   -- for 'Transposed' the transpose of @a@ is used as coefficient matrix.
   solve ::
      (Type.HeightOf typ ~ height, Eq height, Shape.C width,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Transposition -> Matrix typ a ->
      Full vert horiz height width a -> Full vert horiz height width a
   solve trans a b =
      case trans of
         NonTransposed -> solveRight a b
         Transposed ->
            -- Reduce to 'solveLeft': transpose the right-hand sides,
            -- solve with a on their right, then transpose the result back.
            let bT = Basic.transpose b
            in  Basic.transpose (solveLeft bT a)

   -- | Solve with the coefficient matrix given first and the right-hand
   -- sides second. The default is @'solve' 'NonTransposed'@.
   solveRight ::
      (Type.HeightOf typ ~ height, Eq height, Shape.C width,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Matrix typ a ->
      Full vert horiz height width a -> Full vert horiz height width a
   solveRight a b = solve NonTransposed a b

   -- | Counterpart of 'solveRight' with the right-hand sides given first and
   -- the coefficient matrix second; the width shape of the right-hand sides
   -- must match the coefficient matrix. The default reuses
   -- @'solve' 'Transposed'@ via 'Basic.swapMultiply'.
   solveLeft ::
      (Type.WidthOf typ ~ width, Eq width, Shape.C height,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Full vert horiz height width a ->
      Matrix typ a -> Full vert horiz height width a
   solveLeft b a = Basic.swapMultiply (solve Transposed) b a

-- | Matrix types closed under inversion: the inverse of a matrix of type
-- @typ@ is again a matrix of type @typ@.
class (Solve typ, Multiply.Power typ) =>
      Inverse typ where
   -- | Inverse of the given matrix.
   inverse :: (Class.Floating a) => Matrix typ a -> Matrix typ a

infixl 7 ##/#
infixr 7 #\##

-- | Infix synonym for 'solveRight' (coefficient matrix on the left).
(#\##) ::
   (Solve typ, Type.HeightOf typ ~ height, Eq height, Shape.C width,
    Extent.C horiz, Extent.C vert, Class.Floating a) =>
   Matrix typ a ->
   Full vert horiz height width a -> Full vert horiz height width a
m #\## rhs = solveRight m rhs

-- | Infix synonym for 'solveLeft' (coefficient matrix on the right).
(##/#) ::
   (Solve typ, Type.WidthOf typ ~ width, Eq width, Shape.C height,
    Extent.C horiz, Extent.C vert, Class.Floating a) =>
   Full vert horiz height width a ->
   Matrix typ a -> Full vert horiz height width a
rhs ##/# m = solveLeft rhs m


-- | Like 'solve', but with a single right-hand-side vector instead of a
-- matrix of right-hand sides.
--
-- The vector is passed through 'ArrMatrix.unliftColumn' with a column-major
-- layout, presumably viewing it as a one-column matrix for 'solve'.
solveVector ::
   (Solve typ, Type.HeightOf typ ~ height, Eq height, Class.Floating a) =>
   Transposition -> Matrix typ a -> Vector height a -> Vector height a
solveVector trans a =
   ArrMatrix.unliftColumn MatrixShape.ColumnMajor (solve trans a)

infixl 7 -/#
infixr 7 #\|

-- | Infix synonym for @'solveVector' 'NonTransposed'@; the matrix is the
-- left operand and the vector the right operand.
(#\|) ::
   (Solve typ, Type.HeightOf typ ~ height, Eq height, Class.Floating a) =>
   Matrix typ a -> Vector height a -> Vector height a
m #\| v = solveVector NonTransposed m v

-- | Vector-on-the-left counterpart: the vector is the left operand, and
-- the solve is done by @'solveVector' 'Transposed'@ with the arguments
-- swapped.
(-/#) ::
   (Solve typ, Type.HeightOf typ ~ height, Eq height, Class.Floating a) =>
   Vector height a -> Matrix typ a -> Vector height a
v -/# m = solveVector Transposed m v


-- | The determinant of a 'Type.Scale' matrix is its scalar factor raised to
-- the size of its shape.
instance (Shape.C shape, Eq shape) => Determinant (Type.Scale shape) where
   determinant (Type.Scale sh factor) = factor ^ dim
      where dim = Shape.size sh

-- | Solving with a 'Type.Scale' matrix scales the right-hand sides by the
-- reciprocal of the factor, after checking the shape against the height of
-- the right-hand sides. The transposition argument is ignored.
instance (Shape.C shape, Eq shape) => Solve (Type.Scale shape) where
   solve _ =
      scaleWithCheck "Matrix.Scale.solve" Type.height $
         \factor -> ArrMatrix.lift1 (Vector.scale (recip factor))

-- | Inverting a 'Type.Scale' matrix keeps the shape and takes the
-- reciprocal of the factor.
instance (Shape.C shape, Eq shape) => Inverse (Type.Scale shape) where
   inverse (Type.Scale sh factor) = Type.Scale sh (recip factor)


-- | Delegates to 'PermMatrix.determinant'.
instance (Shape.C shape) => Determinant (PermMatrix.Permutation shape) where
   determinant perm = PermMatrix.determinant perm

-- | Solving with a permutation multiplies by it with an inversion that
-- combines 'Inverted' with the inversion derived from the requested
-- transposition.
instance (Shape.C shape) => Solve (PermMatrix.Permutation shape) where
   solve trans = PermMatrix.multiplyFull inversion
      where
         inversion = Inverted <> PermMatrix.inversionFromTransposition trans

-- | The inverse of a permutation is computed as its transpose.
instance (Shape.C shape) => Inverse (PermMatrix.Permutation shape) where
   inverse perm = PermMatrix.transpose perm


-- | Array-backed matrices compute their determinant from the underlying
-- storage via 'ArrDivide.determinant'.
instance
      (ArrDivide.Determinant shape) => Determinant (ArrMatrix.Array shape) where
   determinant matrix = ArrDivide.determinant (ArrMatrix.toVector matrix)

-- | Array-backed matrices delegate every solver method to the corresponding
-- 'ArrDivide' function, lifted with 'ArrMatrix.lift2'.
instance (ArrDivide.Solve shape) => Solve (ArrMatrix.Array shape) where
   solveRight a b = ArrMatrix.lift2 ArrDivide.solveRight a b
   solveLeft b a = ArrMatrix.lift2 ArrDivide.solveLeft b a
   solve trans = ArrMatrix.lift2 (ArrDivide.solve trans)

-- | Array-backed matrices invert via 'ArrDivide.inverse', lifted with
-- 'ArrMatrix.lift1'.
instance (ArrDivide.Inverse shape) => Inverse (ArrMatrix.Array shape) where
   inverse matrix = ArrMatrix.lift1 ArrDivide.inverse matrix