{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.LAPACK.Matrix.Multiply where

import qualified Numeric.LAPACK.Matrix.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Array.Multiply as ArrMultiply
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix.Type as Type
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Permutation as Perm
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Array (Full)
import Numeric.LAPACK.Matrix.Divide (transposeFull)
import Numeric.LAPACK.Matrix.Type (Matrix, scaleWithCheck)
import Numeric.LAPACK.Matrix.Modifier (Transposition(NonTransposed,Transposed))
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape



infixr 7 #*|

-- | Multiply a matrix with a vector from the right.
-- This is the operator form of 'matrixVector'.
--
-- The operator is right-associative, so a chain of such applications
-- groups towards the vector on the far right.
(#*|) ::
   (MultiplyRight typ, Class.Floating a) =>
   Matrix typ a -> Vector (Type.WidthOf typ) a -> Vector (Type.HeightOf typ) a
m #*| x = matrixVector m x

infixl 7 -*#

-- | Multiply a vector with a matrix from the left.
-- This is the operator form of 'vectorMatrix'.
--
-- The operator is left-associative, so a chain of such applications
-- groups towards the vector on the far left.
(-*#) ::
   (MultiplyLeft typ, Class.Floating a) =>
   Vector (Type.HeightOf typ) a -> Matrix typ a -> Vector (Type.WidthOf typ) a
x -*# m = vectorMatrix x m


-- | Matrix types that can be multiplied with a vector from the right.
class Type.Box typ => MultiplyRight typ where
   -- | Matrix-vector product.  The vector is indexed by the matrix's
   -- width shape, and the result is indexed by its height shape.
   matrixVector ::
      Class.Floating a =>
      Matrix typ a ->
      Vector (Type.WidthOf typ) a -> Vector (Type.HeightOf typ) a

-- | Matrix types that can be multiplied with a vector from the left.
class Type.Box typ => MultiplyLeft typ where
   -- | Vector-matrix product.  The vector is indexed by the matrix's
   -- height shape, and the result is indexed by its width shape.
   vectorMatrix ::
      Class.Floating a =>
      Vector (Type.HeightOf typ) a ->
      Matrix typ a -> Vector (Type.WidthOf typ) a

-- A scale matrix acts on a vector through 'Vector.scale'.  'scaleWithCheck'
-- is handed 'Array.shape' to read the vector's shape.  Presumably it compares
-- that shape with the scale's shape (hence @Eq shape@); TODO confirm.
instance (Shape.C shape, Eq shape) => MultiplyRight (Type.Scale shape) where
   matrixVector s x =
      scaleWithCheck "Matrix.Multiply.matrixVector Scale"
         Array.shape Vector.scale s x

-- Left multiplication by a scale matrix runs the same check-and-scale routine
-- as right multiplication.  Only the argument order is swapped, because the
-- scale matrix comes second here.
instance (Shape.C shape, Eq shape) => MultiplyLeft (Type.Scale shape) where
   vectorMatrix x s =
      scaleWithCheck "Matrix.Multiply.vectorMatrix Scale"
         Array.shape Vector.scale s x

-- Multiplying @Inverse a@ with a vector from the right is delegated to
-- 'Divide.solveVector' on the untransposed @a@.
instance
   (Divide.Solve typ,
    Type.WidthOf typ ~ Type.HeightOf typ, Eq (Type.HeightOf typ)) =>
      MultiplyRight (Type.Inverse typ) where
   matrixVector inv x =
      case inv of
         Type.Inverse m -> Divide.solveVector NonTransposed m x

-- Left multiplication with an inverse follows from x^T * A^-1 = (A^-T * x)^T.
-- It is therefore a transposed solve with @a@ via 'Divide.solveVector'.
instance
   (Divide.Solve typ,
    Type.WidthOf typ ~ Type.HeightOf typ, Eq (Type.HeightOf typ)) =>
      MultiplyLeft (Type.Inverse typ) where
   vectorMatrix x inv =
      case inv of
         Type.Inverse m -> Divide.solveVector Transposed m x

-- A permutation matrix applied from the right uses
-- 'PermMatrix.multiplyVector' in its non-inverted direction.
instance
   (Shape.C shape, Eq shape) =>
      MultiplyRight (PermMatrix.Permutation shape) where
   matrixVector p x = PermMatrix.multiplyVector Perm.NonInverted p x

-- Left multiplication follows from x^T * P = (P^T * x)^T.  For a permutation
-- matrix the transpose is the inverse, hence the 'Perm.Inverted' direction.
instance
   (Shape.C shape, Eq shape) =>
      MultiplyLeft (PermMatrix.Permutation shape) where
   vectorMatrix x p = PermMatrix.multiplyVector Perm.Inverted p x

-- Array-backed matrices unwrap the 'ArrMatrix.Array' constructor and forward
-- to the array-level 'ArrMultiply.matrixVector'.
instance
   (ArrMultiply.MultiplyRight shape) =>
      MultiplyRight (ArrMatrix.Array shape) where
   matrixVector m vec =
      case m of
         ArrMatrix.Array arr -> ArrMultiply.matrixVector arr vec

-- Array-backed matrices unwrap the 'ArrMatrix.Array' constructor and forward
-- to the array-level 'ArrMultiply.vectorMatrix'.
instance
   (ArrMultiply.MultiplyLeft shape) =>
      MultiplyLeft (ArrMatrix.Array shape) where
   vectorMatrix vec m =
      case m of
         ArrMatrix.Array arr -> ArrMultiply.vectorMatrix vec arr



-- | Square matrix types that can be multiplied with a full matrix from
-- either side.
--
-- An instance defines either 'transposableSquare', or both 'squareFull' and
-- 'fullSquare'.  The missing methods are derived from the identities
-- @A^T * B = (B^T * A)^T@ and @B * A = (A^T * B^T)^T@.
class
   (Type.Box typ, Type.HeightOf typ ~ Type.WidthOf typ) =>
      MultiplySquare typ where
   {-# MINIMAL transposableSquare | fullSquare,squareFull #-}

   -- | Multiply the square matrix, optionally transposed, with a full matrix
   -- from the left: @NonTransposed@ gives @a * b@ and @Transposed@ gives
   -- @a^T * b@.
   transposableSquare ::
      (Type.HeightOf typ ~ height, Eq height, Shape.C width,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Transposition -> Matrix typ a ->
      Full vert horiz height width a -> Full vert horiz height width a
   transposableSquare trans a b =
      case trans of
         NonTransposed -> squareFull a b
         -- a^T * b computed as (b^T * a)^T
         Transposed -> transposeFull (fullSquare (transposeFull b) a)

   -- | Multiply the square matrix with a full matrix from the left: @a * b@.
   squareFull ::
      (Type.HeightOf typ ~ height, Eq height, Shape.C width,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Matrix typ a ->
      Full vert horiz height width a -> Full vert horiz height width a
   squareFull a b = transposableSquare NonTransposed a b

   -- | Multiply a full matrix with the square matrix from the right: @b * a@.
   fullSquare ::
      (Type.WidthOf typ ~ width, Eq width, Shape.C height,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Full vert horiz height width a ->
      Matrix typ a -> Full vert horiz height width a
   -- b * a computed as (a^T * b^T)^T
   fullSquare b a =
      transposeFull (transposableSquare Transposed a (transposeFull b))

infixr 7 #*##

-- | Multiply a square matrix with a full matrix from the left.
-- This is the operator form of 'squareFull'.
(#*##) ::
   (MultiplySquare typ, Type.HeightOf typ ~ height, Eq height, Shape.C width,
    Extent.C horiz, Extent.C vert, Class.Floating a) =>
   Matrix typ a -> Full vert horiz height width a ->
   Full vert horiz height width a
a #*## b = squareFull a b

infixl 7 ##*#, #*#

-- | Multiply a full matrix with a square matrix from the right.
-- This is the operator form of 'fullSquare'.
(##*#) ::
   (MultiplySquare typ, Type.WidthOf typ ~ width, Eq width, Shape.C height,
    Extent.C horiz, Extent.C vert, Class.Floating a) =>
   Full vert horiz height width a -> Matrix typ a ->
   Full vert horiz height width a
b ##*# a = fullSquare b a

-- A scale matrix equals its own transpose, so the transposition flag is
-- ignored.  The full matrix is scaled with 'Vector.scale', lifted to
-- matrices by 'ArrMatrix.lift1'.  'scaleWithCheck' receives 'Type.height'
-- of the full matrix, presumably to compare it with the scale's shape;
-- TODO confirm.
instance (Shape.C shape, Eq shape) => MultiplySquare (Type.Scale shape) where
   transposableSquare _ s b =
      scaleWithCheck "Matrix.Multiply.transposableSquare" Type.height
         (\k -> ArrMatrix.lift1 (Vector.scale k)) s b

-- Multiplying with @Inverse a@ is delegated to 'Divide.solve' on @a@, and the
-- transposition flag is passed through unchanged.
instance (Divide.Solve typ) => MultiplySquare (Type.Inverse typ) where
   transposableSquare trans inv b =
      case inv of
         Type.Inverse m -> Divide.solve trans m b

-- For permutation matrices, transposing means inverting.  The flag is
-- converted by 'Perm.inversionFromTransposition' before the permutation is
-- applied with 'PermMatrix.multiplyFull'.
instance (Shape.C shape) => MultiplySquare (PermMatrix.Permutation shape) where
   transposableSquare trans p b =
      PermMatrix.multiplyFull (Perm.inversionFromTransposition trans) p b

-- Array-backed matrices forward every method to its array-level counterpart
-- in 'ArrMultiply'.  Each counterpart is lifted to the matrix wrapper with
-- 'ArrMatrix.lift2'.
instance
   (ArrMultiply.MultiplySquare shape) =>
      MultiplySquare (ArrMatrix.Array shape) where
   squareFull a b = ArrMatrix.lift2 ArrMultiply.squareFull a b
   fullSquare b a = ArrMatrix.lift2 ArrMultiply.fullSquare b a
   transposableSquare trans =
      ArrMatrix.lift2 (ArrMultiply.transposableSquare trans)



-- | General matrix-matrix product; operator form of 'matrixMatrix'.
-- The matrix type of the result is given by the 'Multiplied' type family.
(#*#) ::
   (Multiply typA typB, Class.Floating a) =>
   Matrix typA a -> Matrix typB a ->
   Matrix (Multiplied typA typB) a
x #*# y = matrixMatrix x y

-- | Pairs of matrix types that can be multiplied with each other.
class (Type.Box typA, Type.Box typB) => Multiply typA typB where
   -- | Matrix type of the product of a @typA@ and a @typB@ matrix.
   type Multiplied typA typB
   -- | Matrix-matrix product.
   matrixMatrix ::
      Class.Floating a =>
      Matrix typA a -> Matrix typB a ->
      Matrix (Multiplied typA typB) a

-- Two array-backed matrices are multiplied by the array-level
-- 'ArrMultiply.matrixMatrix'.  The result type is wrapped from the
-- array-level 'ArrMultiply.Multiplied'.
instance
   (Box.Box shapeA, Box.Box shapeB, ArrMultiply.Multiply shapeA shapeB) =>
      Multiply (ArrMatrix.Array shapeA) (ArrMatrix.Array shapeB) where
   type Multiplied (ArrMatrix.Array shapeA) (ArrMatrix.Array shapeB) =
          ArrMatrix.Array (ArrMultiply.Multiplied shapeA shapeB)
   matrixMatrix (ArrMatrix.Array arrA) (ArrMatrix.Array arrB) =
      ArrMatrix.Array (ArrMultiply.matrixMatrix arrA arrB)


-- Two scale matrices over the same shape multiply into another scale matrix
-- via 'Type.multiplySame'.
instance
   (Shape.C shapeA, Eq shapeA, shapeA ~ shapeB, Shape.C shapeB) =>
      Multiply (Type.Scale shapeA) (Type.Scale shapeB) where
   type Multiplied (Type.Scale shapeA) (Type.Scale shapeB) = Type.Scale shapeB
   matrixMatrix x y = Type.multiplySame x y

-- A scale matrix times an array-backed matrix scales the latter with
-- 'ArrMatrix.scale'.  'scaleWithCheck' receives the right factor's
-- 'Type.height', presumably to compare it with the scale's shape; TODO confirm.
instance
   (Shape.C shapeA, Eq shapeA, shapeA ~ Box.HeightOf shapeB,
    Box.Box shapeB, ArrMultiply.Scale shapeB) =>
      Multiply (Type.Scale shapeA) (ArrMatrix.Array shapeB) where
   type Multiplied (Type.Scale shapeA) (ArrMatrix.Array shapeB) =
          ArrMatrix.Array shapeB
   matrixMatrix s m =
      scaleWithCheck "Matrix.Multiply.multiply Scale" Type.height
         ArrMatrix.scale s m

-- An array-backed matrix times a scale matrix scales the former with
-- 'ArrMatrix.scale'.  The scale comes second here, so it is passed to
-- 'scaleWithCheck' first, together with the left factor's 'Type.width'.
instance
   (Box.Box shapeA, ArrMultiply.Scale shapeA, Box.WidthOf shapeA ~ shapeB,
    Shape.C shapeB, Eq shapeB) =>
      Multiply (ArrMatrix.Array shapeA) (Type.Scale shapeB) where
   type Multiplied (ArrMatrix.Array shapeA) (Type.Scale shapeB) =
          ArrMatrix.Array shapeA
   matrixMatrix m s =
      scaleWithCheck "Matrix.Multiply.multiply Scale" Type.width
         ArrMatrix.scale s m


-- Two permutation matrices over the same shape multiply into another
-- permutation matrix via 'Type.multiplySame'.
--
-- The instance is indexed by the matrix-level type 'PermMatrix.Permutation',
-- like the other permutation instances in this module.  The plain
-- permutation value type 'Perm.Permutation' is not the matrix index.
instance
   (Shape.C shapeA, Eq shapeA, shapeA ~ shapeB, Shape.C shapeB) =>
      Multiply
         (PermMatrix.Permutation shapeA)
         (PermMatrix.Permutation shapeB) where
   type Multiplied
           (PermMatrix.Permutation shapeA)
           (PermMatrix.Permutation shapeB) =
          PermMatrix.Permutation shapeB
   matrixMatrix = Type.multiplySame