{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.LAPACK.Matrix.Wrapper where

import qualified Numeric.LAPACK.Matrix.Type.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Strict as ExtentStrict
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Array.Unpacked as Unpacked
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Class as MatrixClass
import qualified Numeric.LAPACK.Matrix.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import Numeric.LAPACK.Matrix.Type.Private (Matrix)
import Numeric.LAPACK.Matrix.Shape (Filled)

import qualified Type.Data.Num.Unary as Unary

import Data.Tuple.HT (mapPair)



{- |
Type tag for a wrapper that combines a matrix of type tag @typ@
with an 'Extent.Map'.
-}
data MapExtent typ meas
{-
The only constructor stores an extent map together with the wrapped matrix.
The wrapped matrix carries the extent tags @meas0 vert0 horiz0@,
whereas the wrapper itself is indexed by @meas1 vert1 horiz1@.
The tags @vert0@ and @horiz0@ of the wrapped matrix
are recorded as second components of the extra parameters,
next to the extra parameters @xl@ and @xu@ of the wrapped matrix.
-}
data instance
   Matrix (MapExtent typ meas)
      extraLower extraUpper lower upper meas1 vert1 horiz1 height width a
         where
      MapExtent ::
         (Extent.C vert0, Extent.C horiz0) =>
         Extent.Map meas0 vert0 horiz0 meas1 vert1 horiz1 height width ->
         Matrix typ xl xu lower upper meas0 vert0 horiz0 height width a ->
         Matrix (MapExtent typ meas0) (xl,vert0) (xu,horiz0)
            lower upper meas1 vert1 horiz1 height width a

{- |
First component of the pair used as extra parameter of a 'MapExtent' matrix,
i.e. the extra parameter of the wrapped matrix.
-}
type family MapExtentExtra xl
type instance MapExtentExtra (xl,ex) = xl
{- |
Second component of the pair used as extra parameter of a 'MapExtent' matrix,
i.e. the extent tag of the wrapped matrix.
-}
type family MapExtentExtent xl
type instance MapExtentExtent (xl,ex) = ex


{- |
The extent of a 'MapExtent' matrix is the extent of the wrapped matrix
after applying the stored extent map.
-}
instance
   (Matrix.Box typ, Extent.Measure meas) =>
      Matrix.Box (MapExtent typ meas) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type BoxExtra (MapExtent typ meas) extra =
      Matrix.BoxExtra typ (MapExtentExtra extra)
   extent (MapExtent extentMap wrapped) =
      ExtentStrict.apply extentMap (Matrix.extent wrapped)

{- |
Transposition transposes both the extent map and the wrapped matrix.
-}
instance
   (Matrix.Transpose typ, Extent.Measure meas) =>
      Matrix.Transpose (MapExtent typ meas) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type TransposeExtra (MapExtent typ meas) extra =
      Matrix.TransposeExtra typ (MapExtentExtra extra)
   transpose (MapExtent extentMap wrapped) =
      MapExtent
         (ExtentStrict.transpose extentMap)
         (Matrix.transpose wrapped)

{- |
The layout is taken from the wrapped matrix; the extent map is ignored.
-}
instance
   (Matrix.Layout typ, Extent.Measure meas) =>
      Matrix.Layout (MapExtent typ meas) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type LayoutExtra (MapExtent typ meas) extra =
      Matrix.LayoutExtra typ (MapExtentExtra extra)
   layout (MapExtent _ wrapped) = Matrix.layout wrapped

{- |
Formatting is delegated to the wrapped matrix; the extent map is ignored.
-}
instance
   (Matrix.Format typ, Extent.Measure meas) =>
      Matrix.Format (MapExtent typ meas) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type FormatExtra (MapExtent typ meas) extra =
      Matrix.FormatExtra typ (MapExtentExtra extra)
   format cfg (MapExtent _ wrapped) = Matrix.format cfg wrapped

{- |
Unpacking converts the wrapped matrix with 'MatrixClass.toFull',
applies the extent map to the result via 'ArrMatrix.mapExtent'
and finally passes it through @ArrMatrix.liftUnpacked1 id@.
-}
instance
   (MatrixClass.Unpack typ, Extent.Measure meas) =>
      MatrixClass.Unpack (MapExtent typ meas) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type UnpackExtra (MapExtent typ meas) extra =
      MatrixClass.UnpackExtra typ (MapExtentExtra extra)
   unpack (MapExtent extentMap wrapped) =
      ArrMatrix.liftUnpacked1 id
         (ArrMatrix.mapExtent extentMap (MatrixClass.toFull wrapped))

{- |
The wrapped matrices are multiplied and the product is wrapped again.
The product gets the extent map of the first operand;
the extent map of the second operand is ignored.
-}
instance
   (Matrix.MultiplySame typ, Extent.Measure meas) =>
      Matrix.MultiplySame (MapExtent typ meas) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type MultiplySameExtra (MapExtent typ meas) extra =
      Matrix.MultiplySameExtra typ (MapExtentExtra extra)
   multiplySame (MapExtent extentMap lhs) (MapExtent _ rhs) =
      MapExtent extentMap (Matrix.multiplySame lhs rhs)

{- |
Every method is applied to the wrapped matrix
and the result is wrapped again with the unchanged extent map.
-}
instance
   (MatrixClass.Complex typ, Extent.Measure meas) =>
      MatrixClass.Complex (MapExtent typ meas) where
   conjugate (MapExtent extentMap wrapped) =
      MapExtent extentMap (MatrixClass.conjugate wrapped)
   fromReal (MapExtent extentMap wrapped) =
      MapExtent extentMap (MatrixClass.fromReal wrapped)
   toComplex (MapExtent extentMap wrapped) =
      MapExtent extentMap (MatrixClass.toComplex wrapped)


{- |
Products with vectors are computed by the wrapped matrix;
the extent map is ignored.
-}
instance
   (Multiply.MultiplyVector typ, Matrix.ToQuadratic typ, Extent.Measure meas) =>
      Multiply.MultiplyVector (MapExtent typ meas) where
   -- Requires the multiplication and the box constraints
   -- of the wrapped matrix' extra parameter.
   type MultiplyVectorExtra (MapExtent typ meas) extra =
      (Multiply.MultiplyVectorExtra typ (MapExtentExtra extra),
       Matrix.BoxExtra typ (MapExtentExtra extra))
   matrixVector (MapExtent _ wrapped) vec =
      Multiply.matrixVector wrapped vec
   vectorMatrix vec (MapExtent _ wrapped) =
      Multiply.vectorMatrix vec wrapped



{- |
Type tag for a wrapper that presents a matrix of type tag @typ@
with both strip tags set to 'Filled'.
-}
data FillStrips typ
{-
The only constructor stores the wrapped matrix.
The wrapped matrix has the strip tags @lower@ and @upper@,
whereas the wrapper is indexed by @Filled Filled@.
The original strip tags are recorded
as second components of the extra parameters,
next to the extra parameters @xl@ and @xu@ of the wrapped matrix.
-}
data instance
   Matrix (FillStrips typ)
      extraLower extraUpper lower upper meas vert horiz height width a
         where
      FillStrips ::
         (Omni.Strip lower, Omni.Strip upper) =>
         Matrix typ xl xu lower upper meas vert horiz height width a ->
         Matrix (FillStrips typ) (xl,lower) (xu,upper)
            Filled Filled meas vert horiz height width a

{- |
First component of the pair used as extra parameter of a 'FillStrips' matrix,
i.e. the extra parameter of the wrapped matrix.
-}
type family FillStripsExtra xl
type instance FillStripsExtra (xl,lower) = xl
{- |
Second component of the pair used as extra parameter of a 'FillStrips' matrix,
i.e. the strip tag of the wrapped matrix.
-}
type family FillStripsStrip xl
type instance FillStripsStrip (xl,lower) = lower


{- |
A 'FillStrips' matrix has the extent of the wrapped matrix.
-}
instance
   Matrix.Box typ =>
      Matrix.Box (FillStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type BoxExtra (FillStrips typ) extra =
      Matrix.BoxExtra typ (FillStripsExtra extra)
   extent (FillStrips wrapped) = Matrix.extent wrapped

{- |
Transposition is applied to the wrapped matrix
and the result is wrapped again.
-}
instance
   Matrix.Transpose typ =>
      Matrix.Transpose (FillStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type TransposeExtra (FillStrips typ) extra =
      Matrix.TransposeExtra typ (FillStripsExtra extra)
   transpose (FillStrips wrapped) = FillStrips (Matrix.transpose wrapped)

{- |
The layout is taken from the wrapped matrix.
-}
instance
   Matrix.Layout typ =>
      Matrix.Layout (FillStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type LayoutExtra (FillStrips typ) extra =
      Matrix.LayoutExtra typ (FillStripsExtra extra)
   layout (FillStrips wrapped) = Matrix.layout wrapped

{- |
Formatting is delegated to the wrapped matrix.
-}
instance
   Matrix.Format typ =>
      Matrix.Format (FillStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type FormatExtra (FillStrips typ) extra =
      Matrix.FormatExtra typ (FillStripsExtra extra)
   format cfg (FillStrips wrapped) = Matrix.format cfg wrapped

{- |
Both conversions are applied to the wrapped matrix
and the result is wrapped again.
-}
instance
   Matrix.ToQuadratic typ =>
      Matrix.ToQuadratic (FillStrips typ) where
   heightToQuadratic (FillStrips wrapped) =
      FillStrips (Matrix.heightToQuadratic wrapped)
   widthToQuadratic (FillStrips wrapped) =
      FillStrips (Matrix.widthToQuadratic wrapped)

{- |
Unpacking unpacks the wrapped matrix
and passes the result to 'Unpacked.fillBoth'.
-}
instance
   MatrixClass.Unpack typ =>
      MatrixClass.Unpack (FillStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type UnpackExtra (FillStrips typ) extra =
      MatrixClass.UnpackExtra typ (FillStripsExtra extra)
   unpack (FillStrips wrapped) =
      Unpacked.fillBoth (MatrixClass.unpack wrapped)

{- |
Every method is applied to the wrapped matrix
and the result is wrapped again.
-}
instance
   MatrixClass.Complex typ =>
      MatrixClass.Complex (FillStrips typ) where
   conjugate (FillStrips wrapped) =
      FillStrips (MatrixClass.conjugate wrapped)
   fromReal (FillStrips wrapped) =
      FillStrips (MatrixClass.fromReal wrapped)
   toComplex (FillStrips wrapped) =
      FillStrips (MatrixClass.toComplex wrapped)


{- |
The wrapped matrices are multiplied and the product is wrapped again.
-}
instance
   Matrix.MultiplySame typ =>
      Matrix.MultiplySame (FillStrips typ) where
   -- In addition to the constraint of the wrapped matrix
   -- the recorded strip tag must satisfy 'MatrixShape.PowerStrip'.
   type MultiplySameExtra (FillStrips typ) extra =
      (Matrix.MultiplySameExtra typ (FillStripsExtra extra),
       MatrixShape.PowerStrip (FillStripsStrip extra))
   multiplySame (FillStrips lhs) (FillStrips rhs) =
      FillStrips (Matrix.multiplySame lhs rhs)

{- |
Products with vectors are computed by the wrapped matrix.
-}
instance
   (Multiply.MultiplyVector typ, Matrix.ToQuadratic typ) =>
      Multiply.MultiplyVector (FillStrips typ) where
   -- Requires the multiplication and box constraints of the wrapped matrix
   -- and 'Omni.Strip' for the recorded strip tag.
   type MultiplyVectorExtra (FillStrips typ) extra =
      (Multiply.MultiplyVectorExtra typ (FillStripsExtra extra),
       Matrix.BoxExtra typ (FillStripsExtra extra),
       Omni.Strip (FillStripsStrip extra))
   matrixVector (FillStrips wrapped) vec =
      Multiply.matrixVector wrapped vec
   vectorMatrix vec (FillStrips wrapped) =
      Multiply.vectorMatrix vec wrapped

{- |
Products with full matrices are computed by the wrapped matrix.
-}
instance
   (Multiply.MultiplySquare typ, Matrix.ToQuadratic typ) =>
      Multiply.MultiplySquare (FillStrips typ) where
   -- Requires the multiplication and box constraints of the wrapped matrix
   -- and 'Omni.Strip' for the recorded strip tag.
   type MultiplySquareExtra (FillStrips typ) extra =
      (Multiply.MultiplySquareExtra typ (FillStripsExtra extra),
       Matrix.BoxExtra typ (FillStripsExtra extra),
       Omni.Strip (FillStripsStrip extra))
   transposableSquare trans (FillStrips wrapped) =
      Multiply.transposableSquare trans wrapped
   squareFull (FillStrips wrapped) full =
      Multiply.squareFull wrapped full
   fullSquare full (FillStrips wrapped) =
      Multiply.fullSquare full wrapped

{- |
Powers are computed by the wrapped matrix
and every resulting matrix is wrapped again.
-}
instance
   Multiply.Power typ =>
      Multiply.Power (FillStrips typ) where
   -- In addition to the constraint of the wrapped matrix
   -- the recorded strip tag must satisfy 'MatrixShape.PowerStrip'.
   type PowerExtra (FillStrips typ) extra =
      (Multiply.PowerExtra typ (FillStripsExtra extra),
       MatrixShape.PowerStrip (FillStripsStrip extra))
   square (FillStrips wrapped) =
      FillStrips (Multiply.square wrapped)
   power expo (FillStrips wrapped) =
      FillStrips (Multiply.power expo wrapped)
   powers1 (FillStrips wrapped) =
      fmap FillStrips (Multiply.powers1 wrapped)


{- |
The determinant is the determinant of the wrapped matrix.
-}
instance
   Divide.Determinant typ =>
      Divide.Determinant (FillStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type DeterminantExtra (FillStrips typ) extra =
      Divide.DeterminantExtra typ (FillStripsExtra extra)
   determinant (FillStrips wrapped) = Divide.determinant wrapped

{- |
All solvers are delegated to the wrapped matrix.
-}
instance
   (Divide.Solve typ, Matrix.ToQuadratic typ) =>
      Divide.Solve (FillStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type SolveExtra (FillStrips typ) extra =
      Divide.SolveExtra typ (FillStripsExtra extra)
   solve trans (FillStrips wrapped) =
      Divide.solve trans wrapped
   solveRight (FillStrips wrapped) rhs =
      Divide.solveRight wrapped rhs
   solveLeft lhs (FillStrips wrapped) =
      Divide.solveLeft lhs wrapped

{- |
The wrapped matrix is inverted and the inverse is wrapped again.
-}
instance
   (Divide.Inverse typ, Matrix.ToQuadratic typ) =>
      Divide.Inverse (FillStrips typ) where
   -- In addition to the constraint of the wrapped matrix
   -- the recorded strip tag must satisfy 'MatrixShape.PowerStrip'.
   type InverseExtra (FillStrips typ) extra =
      (Divide.InverseExtra typ (FillStripsExtra extra),
       MatrixShape.PowerStrip (FillStripsStrip extra))
   inverse (FillStrips wrapped) = FillStrips (Divide.inverse wrapped)



{- |
Type tag for a wrapper that presents a matrix of type tag @typ@
with the strip tags @Fill lower@ and @Fill upper@
instead of @lower@ and @upper@.

I do not know whether you will ever need this.
For diagonal matrices you may not need a wrapper at all,
and for other matrices you may use 'FillStrips'.
-}
data PowerStrips typ
{-
The only constructor stores the wrapped matrix.
The strip tags @lowerf@ and @upperf@ of the wrapper
are tied to the strip tags of the wrapped matrix
by the equality constraints @Fill lower ~ lowerf@ and @Fill upper ~ upperf@.
The constructor also carries 'Omni.PowerStrip' evidence
for @lowerf@ and @upperf@.
The original strip tags are recorded
as second components of the extra parameters.
-}
data instance
   Matrix (PowerStrips typ)
      extraLower extraUpper lowerf upperf meas vert horiz height width a
         where
      PowerStrips ::
         (Omni.Strip lower, Fill lower ~ lowerf, Omni.PowerStrip lowerf,
          Omni.Strip upper, Fill upper ~ upperf, Omni.PowerStrip upperf) =>
         Matrix typ xl xu lower upper meas vert horiz height width a ->
         Matrix (PowerStrips typ) (xl,lower) (xu,upper)
            lowerf upperf meas vert horiz height width a

{- |
Wrap a matrix in 'PowerStrips'.
The 'Omni.PowerStrip' evidence required by the constructor
is derived from the strip singletons of the argument
by means of 'filledPowerStripFact'.
-}
powerStrips ::
   (Omni.Strip lower, Omni.Strip upper) =>
   Matrix.QuadraticMeas typ xl xu lower upper meas height width a ->
   Matrix.QuadraticMeas (PowerStrips typ) (xl,lower) (xu,upper)
      (Fill lower) (Fill upper) meas height width a
powerStrips a =
   let facts =
          mapPair
             (filledPowerStripFact, filledPowerStripFact)
             (Matrix.strips a)
   in case facts of
         -- Matching both facts brings the 'Omni.PowerStrip' constraints
         -- for the lower and the upper strip into scope.
         (PowerStripFact, PowerStripFact) -> PowerStrips a

{- |
Type-level function on strip tags
that determines the strip tags of a 'PowerStrips' matrix.
A strip with zero bands is left as it is,
a strip with at least one band becomes 'Layout.Filled',
and 'Layout.Filled' stays 'Layout.Filled'.
-}
type family Fill offDiag
type instance Fill (Layout.Bands Unary.Zero) = Layout.Bands Unary.Zero
type instance Fill (Layout.Bands (Unary.Succ k)) = Layout.Filled
type instance Fill Layout.Filled = Layout.Filled

{- |
First component of the pair used as extra parameter of a 'PowerStrips' matrix,
i.e. the extra parameter of the wrapped matrix.
-}
type family PowerStripsExtra xl
type instance PowerStripsExtra (xl,lower) = xl
{- |
Second component of the pair used as extra parameter of a 'PowerStrips' matrix,
i.e. the strip tag of the wrapped matrix.
-}
type family PowerStripsStrip xl
type instance PowerStripsStrip (xl,lower) = lower

{- |
Evidence for the constraint @Omni.PowerStrip c@.
Pattern matching on 'PowerStripFact' brings that constraint into scope.
-}
data PowerStripFact c = (Omni.PowerStrip c) => PowerStripFact

{- |
Provide 'Omni.PowerStrip' evidence for @Fill c@
by inspecting the singleton of the strip @c@.
Every equation returns 'PowerStripFact';
there is one equation per instance of the 'Fill' type family.
-}
filledPowerStripFact ::
   (Omni.Strip c) => Omni.StripSingleton c -> PowerStripFact (Fill c)
filledPowerStripFact Omni.StripFilled = PowerStripFact
filledPowerStripFact (Omni.StripBands Unary.Zero) = PowerStripFact
filledPowerStripFact (Omni.StripBands Unary.Succ) = PowerStripFact


{- |
A 'PowerStrips' matrix has the extent of the wrapped matrix.
-}
instance
   Matrix.Box typ =>
      Matrix.Box (PowerStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type BoxExtra (PowerStrips typ) extra =
      Matrix.BoxExtra typ (PowerStripsExtra extra)
   extent (PowerStrips wrapped) = Matrix.extent wrapped

{- |
Transposition is applied to the wrapped matrix
and the result is wrapped again.
-}
instance
   Matrix.Transpose typ =>
      Matrix.Transpose (PowerStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type TransposeExtra (PowerStrips typ) extra =
      Matrix.TransposeExtra typ (PowerStripsExtra extra)
   transpose (PowerStrips wrapped) = PowerStrips (Matrix.transpose wrapped)

{- |
The layout is taken from the wrapped matrix.
-}
instance
   Matrix.Layout typ =>
      Matrix.Layout (PowerStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type LayoutExtra (PowerStrips typ) extra =
      Matrix.LayoutExtra typ (PowerStripsExtra extra)
   layout (PowerStrips wrapped) = Matrix.layout wrapped

{- |
Formatting is delegated to the wrapped matrix.
-}
instance
   Matrix.Format typ =>
      Matrix.Format (PowerStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type FormatExtra (PowerStrips typ) extra =
      Matrix.FormatExtra typ (PowerStripsExtra extra)
   format cfg (PowerStrips wrapped) = Matrix.format cfg wrapped

{- |
The wrapped matrices are multiplied and the product is wrapped again.
-}
instance
   Matrix.MultiplySame typ =>
      Matrix.MultiplySame (PowerStrips typ) where
   -- In addition to the constraint of the wrapped matrix
   -- the recorded strip tag must satisfy 'MatrixShape.PowerStrip'.
   type MultiplySameExtra (PowerStrips typ) extra =
      (Matrix.MultiplySameExtra typ (PowerStripsExtra extra),
       MatrixShape.PowerStrip (PowerStripsStrip extra))
   multiplySame (PowerStrips lhs) (PowerStrips rhs) =
      PowerStrips (Matrix.multiplySame lhs rhs)

{- |
Both conversions are applied to the wrapped matrix
and the result is wrapped again.
-}
instance
   Matrix.ToQuadratic typ =>
      Matrix.ToQuadratic (PowerStrips typ) where
   heightToQuadratic (PowerStrips wrapped) =
      PowerStrips (Matrix.heightToQuadratic wrapped)
   widthToQuadratic (PowerStrips wrapped) =
      PowerStrips (Matrix.widthToQuadratic wrapped)

{- |
Every method is applied to the wrapped matrix
and the result is wrapped again.
-}
instance
   MatrixClass.Complex typ =>
      MatrixClass.Complex (PowerStrips typ) where
   conjugate (PowerStrips wrapped) =
      PowerStrips (MatrixClass.conjugate wrapped)
   fromReal (PowerStrips wrapped) =
      PowerStrips (MatrixClass.fromReal wrapped)
   toComplex (PowerStrips wrapped) =
      PowerStrips (MatrixClass.toComplex wrapped)


{- |
Products with vectors are computed by the wrapped matrix.
-}
instance
   (Multiply.MultiplyVector typ, Matrix.ToQuadratic typ) =>
      Multiply.MultiplyVector (PowerStrips typ) where
   -- Requires the multiplication and box constraints of the wrapped matrix
   -- and 'Omni.Strip' for the recorded strip tag.
   -- No method of this instance solves a linear system,
   -- hence 'Divide.SolveExtra' is not demanded.
   -- This matches the requirements of the 'FillStrips' instance.
   type MultiplyVectorExtra (PowerStrips typ) extra =
         (Multiply.MultiplyVectorExtra typ (PowerStripsExtra extra),
          Matrix.BoxExtra typ (PowerStripsExtra extra),
          Omni.Strip (PowerStripsStrip extra))
   matrixVector (PowerStrips a) x = Multiply.matrixVector a x
   vectorMatrix x (PowerStrips a) = Multiply.vectorMatrix x a

{- |
Products with full matrices are computed by the wrapped matrix.
-}
instance
   (Multiply.MultiplySquare typ, Matrix.ToQuadratic typ) =>
      Multiply.MultiplySquare (PowerStrips typ) where
   -- Requires the multiplication and box constraints of the wrapped matrix
   -- and 'Omni.Strip' for the recorded strip tag.
   -- No method of this instance solves a linear system,
   -- hence 'Divide.SolveExtra' is not demanded.
   -- This matches the requirements of the 'FillStrips' instance.
   type MultiplySquareExtra (PowerStrips typ) extra =
         (Multiply.MultiplySquareExtra typ (PowerStripsExtra extra),
          Matrix.BoxExtra typ (PowerStripsExtra extra),
          Omni.Strip (PowerStripsStrip extra))
   transposableSquare trans (PowerStrips a) =
      Multiply.transposableSquare trans a
   squareFull (PowerStrips a) b = Multiply.squareFull a b
   fullSquare b (PowerStrips a) = Multiply.fullSquare b a

{- |
Powers are computed by the wrapped matrix
and every resulting matrix is wrapped again.
-}
instance
   Multiply.Power typ =>
      Multiply.Power (PowerStrips typ) where
   -- In addition to the constraint of the wrapped matrix
   -- the recorded strip tag must satisfy 'MatrixShape.PowerStrip'.
   type PowerExtra (PowerStrips typ) extra =
      (Multiply.PowerExtra typ (PowerStripsExtra extra),
       MatrixShape.PowerStrip (PowerStripsStrip extra))
   square (PowerStrips wrapped) =
      PowerStrips (Multiply.square wrapped)
   power expo (PowerStrips wrapped) =
      PowerStrips (Multiply.power expo wrapped)
   powers1 (PowerStrips wrapped) =
      fmap PowerStrips (Multiply.powers1 wrapped)


{- |
The determinant is the determinant of the wrapped matrix.
-}
instance
   Divide.Determinant typ =>
      Divide.Determinant (PowerStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type DeterminantExtra (PowerStrips typ) extra =
      Divide.DeterminantExtra typ (PowerStripsExtra extra)
   determinant (PowerStrips wrapped) = Divide.determinant wrapped

{- |
All solvers are delegated to the wrapped matrix.
-}
instance
   (Divide.Solve typ, Matrix.ToQuadratic typ) =>
      Divide.Solve (PowerStrips typ) where
   -- Only the wrapped matrix' part of the extra parameter is constrained.
   type SolveExtra (PowerStrips typ) extra =
      Divide.SolveExtra typ (PowerStripsExtra extra)
   solve trans (PowerStrips wrapped) =
      Divide.solve trans wrapped
   solveRight (PowerStrips wrapped) rhs =
      Divide.solveRight wrapped rhs
   solveLeft lhs (PowerStrips wrapped) =
      Divide.solveLeft lhs wrapped

{- |
The wrapped matrix is inverted and the inverse is wrapped again.
-}
instance
   (Divide.Inverse typ, Matrix.ToQuadratic typ) =>
      Divide.Inverse (PowerStrips typ) where
   -- In addition to the constraint of the wrapped matrix
   -- the recorded strip tag must satisfy 'MatrixShape.PowerStrip'.
   type InverseExtra (PowerStrips typ) extra =
      (Divide.InverseExtra typ (PowerStripsExtra extra),
       MatrixShape.PowerStrip (PowerStripsStrip extra))
   inverse (PowerStrips wrapped) = PowerStrips (Divide.inverse wrapped)