{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Numeric.LAPACK.Matrix.Multiply where

import qualified Numeric.LAPACK.Matrix.Plain.Multiply as ArrMultiply
import qualified Numeric.LAPACK.Matrix.Array.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix.Type as Type
import qualified Numeric.LAPACK.Matrix.Modifier as Mod
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Array (Full)
import Numeric.LAPACK.Matrix.Type (Matrix, scaleWithCheck)
import Numeric.LAPACK.Matrix.Modifier (Transposition(NonTransposed,Transposed))
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape



-- Both vector operators bind at precedence 7, the level of Prelude's (*).
-- '-*#' associates to the left, so  v -*# a -*# b  reads  (v -*# a) -*# b;
-- '#*|' associates to the right, so  a #*| b #*| v  reads  a #*| (b #*| v).
infixl 7 -*#
infixr 7 #*|

-- | Infix synonym for 'matrixVector':
-- the matrix stands on the left, the vector on the right.
--
-- The vector is indexed by the width shape of the matrix type
-- and the result is indexed by its height shape.
(#*|) ::
   (MultiplyVector typ, Type.WidthOf typ ~ width, Eq width,
    Class.Floating a) =>
   Matrix typ a -> Vector width a -> Vector (Type.HeightOf typ) a
matrix #*| vector = matrixVector matrix vector

-- | Infix synonym for 'vectorMatrix':
-- the vector stands on the left, the matrix on the right.
--
-- The vector is indexed by the height shape of the matrix type
-- and the result is indexed by its width shape.
(-*#) ::
   (MultiplyVector typ, Type.HeightOf typ ~ height, Eq height,
    Class.Floating a) =>
   Vector height a -> Matrix typ a -> Vector (Type.WidthOf typ) a
vector -*# matrix = vectorMatrix vector matrix


-- | Matrix types that can be multiplied with a vector from either side.
class (Type.Box typ) => MultiplyVector typ where
   -- | Matrix times vector.
   -- The vector must carry the width shape of the matrix type;
   -- the result carries its height shape.
   matrixVector ::
      (Type.WidthOf typ ~ width, Eq width, Class.Floating a) =>
      Matrix typ a ->
      Vector width a ->
      Vector (Type.HeightOf typ) a

   -- | Vector times matrix.
   -- The vector must carry the height shape of the matrix type;
   -- the result carries its width shape.
   vectorMatrix ::
      (Type.HeightOf typ ~ height, Eq height, Class.Floating a) =>
      Vector height a ->
      Matrix typ a ->
      Vector (Type.WidthOf typ) a

-- | Multiplying a vector with a scaling matrix.
-- Both directions go through 'scaleWithCheck' with 'Array.shape'
-- as the accessor for the vector shape and 'Vector.scale' as the operation;
-- they differ only in the label string and the argument order.
instance (Shape.C shape) => MultiplyVector (Type.Scale shape) where
   matrixVector scaling vec =
      scaleWithCheck "Matrix.Multiply.matrixVector Scale"
         Array.shape Vector.scale scaling vec
   -- The helper expects the scaling matrix first,
   -- so the two arguments are passed in swapped order.
   vectorMatrix vec scaling =
      scaleWithCheck "Matrix.Multiply.vectorMatrix Scale"
         Array.shape Vector.scale scaling vec

-- | Multiplying a vector with a permutation matrix
-- by means of 'PermMatrix.multiplyVector'.
instance (Shape.C shape) => MultiplyVector (PermMatrix.Permutation shape) where
   matrixVector perm vec =
      PermMatrix.multiplyVector Mod.NonInverted perm vec
   -- A vector on the left is handled with the inverted permutation;
   -- the transpose of a permutation matrix is its inverse.
   vectorMatrix vec perm =
      PermMatrix.multiplyVector Mod.Inverted perm vec

-- | Array based matrices: strip the 'ArrMatrix.Array' wrapper
-- and forward to the equally named functions of "ArrMultiply".
instance
   (ArrMultiply.MultiplyVector shape) =>
      MultiplyVector (ArrMatrix.Array shape) where
   matrixVector (ArrMatrix.Array arr) vec =
      arr `ArrMultiply.matrixVector` vec
   vectorMatrix vec (ArrMatrix.Array arr) =
      vec `ArrMultiply.vectorMatrix` arr



-- | Matrix types with identical height and width shape
-- that can be multiplied with a full matrix from either side,
-- where the result has the same type as the full operand.
--
-- An instance must define 'transposableSquare'
-- or both of 'fullSquare' and 'squareFull'.
-- The defaults express the missing methods by the provided ones.
class
   (Type.Box typ, Type.HeightOf typ ~ Type.WidthOf typ) =>
      MultiplySquare typ where
   {-# MINIMAL transposableSquare | fullSquare,squareFull #-}

   -- | Multiply the square matrix from the left to a full matrix,
   -- where the 'Transposition' argument selects
   -- whether the square matrix is used as is or transposed.
   --
   -- The default for 'Transposed' multiplies from the right
   -- to the transposed full matrix and transposes the result.
   transposableSquare ::
      (Type.HeightOf typ ~ height, Eq height, Shape.C width,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Transposition -> Matrix typ a ->
      Full vert horiz height width a -> Full vert horiz height width a
   transposableSquare trans a b =
      case trans of
         NonTransposed -> squareFull a b
         Transposed ->
            Basic.transpose (fullSquare (Basic.transpose b) a)

   -- | Square matrix times full matrix.
   -- Defaults to 'transposableSquare' with 'NonTransposed'.
   squareFull ::
      (Type.HeightOf typ ~ height, Eq height, Shape.C width,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Matrix typ a ->
      Full vert horiz height width a -> Full vert horiz height width a
   squareFull a b = transposableSquare NonTransposed a b

   -- | Full matrix times square matrix.
   --
   -- The default multiplies the transposed square matrix from the left
   -- to the transposed full matrix and transposes the result.
   fullSquare ::
      (Type.WidthOf typ ~ width, Eq width, Shape.C height,
       Extent.C horiz, Extent.C vert, Class.Floating a) =>
      Full vert horiz height width a ->
      Matrix typ a -> Full vert horiz height width a
   fullSquare b a =
      Basic.transpose (transposableSquare Transposed a (Basic.transpose b))

-- The matrix product operators bind at precedence 7,
-- the level of Prelude's (*).
-- '##*#' and '#*#' associate to the left, '#*##' associates to the right.
infixl 7 ##*#, #*#
infixr 7 #*##

-- | Infix synonym for 'squareFull':
-- the square matrix stands on the left, the full matrix on the right.
-- The result has the same type as the full operand.
(#*##) ::
   (MultiplySquare typ, Type.HeightOf typ ~ height, Eq height, Shape.C width,
    Extent.C horiz, Extent.C vert, Class.Floating a) =>
   Matrix typ a ->
   Full vert horiz height width a -> Full vert horiz height width a
square_ #*## full = squareFull square_ full

-- | Infix synonym for 'fullSquare':
-- the full matrix stands on the left, the square matrix on the right.
-- The result has the same type as the full operand.
(##*#) ::
   (MultiplySquare typ, Type.WidthOf typ ~ width, Eq width, Shape.C height,
    Extent.C horiz, Extent.C vert, Class.Floating a) =>
   Full vert horiz height width a ->
   Matrix typ a -> Full vert horiz height width a
full ##*# square_ = fullSquare full square_

-- | Multiplying a full matrix with a scaling matrix.
-- Only 'transposableSquare' is defined;
-- the other methods fall back to the class defaults.
instance (Shape.C shape) => MultiplySquare (Type.Scale shape) where
   -- The transposition flag is not inspected
   -- (presumably because a scaling matrix equals its transpose).
   -- 'Type.height' of the full matrix is the shape handed to the check,
   -- and the scaling itself is 'Vector.scale' lifted to the matrix wrapper.
   transposableSquare _trans scaling full =
      scaleWithCheck "Matrix.Multiply.transposableSquare" Type.height
         (\factor -> ArrMatrix.lift1 (Vector.scale factor))
         scaling full

-- | Multiplying a full matrix with a permutation matrix.
-- Only 'transposableSquare' is defined;
-- the other methods fall back to the class defaults.
instance (Shape.C shape) => MultiplySquare (PermMatrix.Permutation shape) where
   -- The transposition flag is translated to an inversion flag
   -- before it is handed to 'PermMatrix.multiplyFull'.
   transposableSquare trans perm full =
      PermMatrix.multiplyFull (Perm.inversionFromTransposition trans) perm full

-- | Array based matrices: every method is the equally named function
-- from "ArrMultiply", lifted to the matrix wrapper with 'ArrMatrix.lift2'.
-- All three methods are given explicitly, so no class default is used.
instance
   (ArrMultiply.MultiplySquare shape) =>
      MultiplySquare (ArrMatrix.Array shape) where
   transposableSquare trans a b =
      ArrMatrix.lift2 (ArrMultiply.transposableSquare trans) a b
   fullSquare b a = ArrMatrix.lift2 ArrMultiply.fullSquare b a
   squareFull a b = ArrMatrix.lift2 ArrMultiply.squareFull a b



-- | Matrix types with identical height and width shape
-- that support squaring and raising to an 'Int' power
-- without leaving the matrix type.
class (Type.Box typ, Type.HeightOf typ ~ Type.WidthOf typ) => Power typ where
   -- | Product of a matrix with itself.
   square ::
      (Class.Floating a) =>
      Matrix typ a -> Matrix typ a

   -- | Power of a matrix; the exponent is the first argument.
   power ::
      (Class.Floating a) =>
      Int -> Matrix typ a -> Matrix typ a

-- | Powers of a scaling matrix:
-- the shape is kept and only the stored factor is changed.
instance (Shape.C shape) => Power (Type.Scale shape) where
   square (Type.Scale shape factor) =
      Type.Scale shape $ factor * factor
   -- Uses Prelude's (^), which raises an error for a negative exponent.
   power n (Type.Scale shape factor) =
      Type.Scale shape $ factor ^ n

-- | Powers of a permutation matrix:
-- unwrap the permutation, apply 'Perm.square' or 'Perm.power'
-- and wrap the result again.
instance (Shape.C shape) => Power (PermMatrix.Permutation shape) where
   square (Type.Permutation perm) =
      Type.Permutation (Perm.square perm)
   -- The 'Int' exponent is converted with 'fromIntegral'
   -- to whatever integral type 'Perm.power' expects.
   power n (Type.Permutation perm) =
      Type.Permutation (Perm.power (fromIntegral n) perm)

-- | Array based matrices: 'ArrMultiply.square' and 'ArrMultiply.power'
-- lifted to the matrix wrapper with 'ArrMatrix.lift1'.
instance (ArrMatrix.Power shape) => Power (ArrMatrix.Array shape) where
   square a = ArrMatrix.lift1 ArrMultiply.square a
   power n a = ArrMatrix.lift1 (ArrMultiply.power n) a



-- | Infix synonym for 'matrixMatrix'.
-- The type of the product is determined from the operand types
-- by the 'Multiplied' type family.
(#*#) ::
   (Multiply typA typB, Class.Floating a) =>
   Matrix typA a -> Matrix typB a -> Matrix (Multiplied typA typB) a
left #*# right = matrixMatrix left right

-- | Pairs of matrix types for which a matrix product is defined.
class (Type.Box typA, Type.Box typB) => Multiply typA typB where
   -- | Matrix type of the product of a @typA@ matrix and a @typB@ matrix.
   type Multiplied typA typB

   -- | Product of two matrices; the left factor is the first argument.
   matrixMatrix ::
      (Class.Floating a) =>
      Matrix typA a ->
      Matrix typB a ->
      Matrix (Multiplied typA typB) a

-- | Product of two array based matrices.
-- Both operands are unwrapped, multiplied by 'ArrMultiply.matrixMatrix'
-- and the result is wrapped again;
-- the result shape is given by 'ArrMultiply.Multiplied'.
instance
   (Box.Box shapeA, Box.Box shapeB, ArrMultiply.Multiply shapeA shapeB) =>
      Multiply (ArrMatrix.Array shapeA) (ArrMatrix.Array shapeB) where
   type Multiplied (ArrMatrix.Array shapeA) (ArrMatrix.Array shapeB) =
         ArrMatrix.Array (ArrMultiply.Multiplied shapeA shapeB)
   matrixMatrix (ArrMatrix.Array arrA) (ArrMatrix.Array arrB) =
      ArrMatrix.Array (arrA `ArrMultiply.matrixMatrix` arrB)


-- | Product of two scaling matrices.
-- The constraint @shapeA ~ shapeB@ forces both operands to the same type,
-- so the product is delegated to 'Type.multiplySame'.
instance
   (Shape.C shapeA, Eq shapeA, shapeA ~ shapeB, Shape.C shapeB) =>
      Multiply (Type.Scale shapeA) (Type.Scale shapeB) where
   type Multiplied (Type.Scale shapeA) (Type.Scale shapeB) = Type.Scale shapeB
   matrixMatrix left right = Type.multiplySame left right

-- | Scaling matrix times array based matrix.
-- The shape of the scaling matrix must be the height shape of the array
-- matrix, and the product has the type of the array matrix.
instance
   (Shape.C shapeA, Eq shapeA, shapeA ~ Box.HeightOf shapeB,
    Box.Box shapeB, ArrMultiply.Scale shapeB) =>
      Multiply (Type.Scale shapeA) (ArrMatrix.Array shapeB) where
   type Multiplied (Type.Scale shapeA) (ArrMatrix.Array shapeB) =
         ArrMatrix.Array shapeB
   -- 'Type.height' of the array matrix is the shape handed to the check;
   -- the scaling itself is done by 'ArrMatrix.scale'.
   -- NOTE(review): the label names "multiply" although the method is
   -- 'matrixMatrix', and the instance with swapped operands passes
   -- the very same label - confirm whether that is intended.
   matrixMatrix scaling arr =
      scaleWithCheck "Matrix.Multiply.multiply Scale" Type.height
         ArrMatrix.scale scaling arr

-- | Array based matrix times scaling matrix.
-- The shape of the scaling matrix must be the width shape of the array
-- matrix, and the product has the type of the array matrix.
instance
   (Box.Box shapeA, ArrMultiply.Scale shapeA, Box.WidthOf shapeA ~ shapeB,
    Shape.C shapeB, Eq shapeB) =>
      Multiply (ArrMatrix.Array shapeA) (Type.Scale shapeB) where
   type Multiplied (ArrMatrix.Array shapeA) (Type.Scale shapeB) =
         ArrMatrix.Array shapeA
   -- The helper expects the scaling matrix first,
   -- so the two arguments are passed in swapped order.
   -- 'Type.width' of the array matrix is the shape handed to the check.
   -- NOTE(review): the label names "multiply" although the method is
   -- 'matrixMatrix', and the instance with swapped operands passes
   -- the very same label - confirm whether that is intended.
   matrixMatrix arr scaling =
      scaleWithCheck "Matrix.Multiply.multiply Scale" Type.width
         ArrMatrix.scale scaling arr


-- | Product of two permutation matrices.
-- The constraint @shapeA ~ shapeB@ forces both operands to the same type,
-- so the product is delegated to 'Type.multiplySame'.
instance
   (Shape.C shapeA, Eq shapeA, shapeA ~ shapeB, Shape.C shapeB) =>
      Multiply (Perm.Permutation shapeA) (Perm.Permutation shapeB) where
   type Multiplied (Perm.Permutation shapeA) (Perm.Permutation shapeB) =
         Perm.Permutation shapeB
   matrixMatrix left right = Type.multiplySame left right