{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Layout.Private where

import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Extent.Private (Extent)

import qualified Type.Data.Num.Unary.Literal as TypeNum
import qualified Type.Data.Num.Unary.Proof as Proof
import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Unary (unary, (:+:))
import Type.Data.Num (integralFromProxy)
import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape (triangleSize, triangleRoot)

import Control.DeepSeq (NFData, rnf)
import Control.Applicative ((<$>))

import Data.List (tails)
import Data.Tuple.HT (mapSnd, swap)
import Data.Bool.HT (if')


-- | Storage order of a dense matrix.  With 'RowMajor' the offset is
-- computed on the (height,width) pair directly; with 'ColumnMajor' it is
-- computed on the swapped (width,height) pair.
data Order = RowMajor | ColumnMajor
   deriving (Eq, Show)

-- Both constructors are nullary, so weak head normal form is already
-- full evaluation.
instance NFData Order where
   rnf order = order `seq` ()

-- | Switch to the other storage order.
flipOrder :: Order -> Order
flipOrder order =
   case order of
      RowMajor -> ColumnMajor
      ColumnMajor -> RowMajor

-- | Transpose flag for a matrix stored in the given order:
-- @'T'@ for 'RowMajor', @'N'@ for 'ColumnMajor'.
-- NOTE(review): presumably the BLAS/LAPACK @trans@ argument — confirm.
transposeFromOrder :: Order -> Char
transposeFromOrder order =
   case order of
      RowMajor -> 'T'
      ColumnMajor -> 'N'

-- | Swap the components of a pair exactly when the order is 'RowMajor'.
swapOnRowMajor :: Order -> (a,a) -> (a,a)
swapOnRowMajor RowMajor = swap
swapOnRowMajor ColumnMajor = id

-- | Attach a side character to a pair of dimensions:
-- @('L', (m0,n0))@ for 'ColumnMajor', and for 'RowMajor' the dimensions
-- are swapped and the side becomes @'R'@, giving @('R', (n0,m0))@.
sideSwapFromOrder :: Order -> (a,a) -> (Char, (a,a))
sideSwapFromOrder order (m0,n0) =
   let (side, (m,n)) =
         case order of
            RowMajor -> ('R', (n0,m0))
            ColumnMajor -> ('L', (m0,n0))
   in (side,(m,n))


-- | Apply a shape transformation and check that it preserves the number
-- of elements.  If it does not, 'error' is called with a message that
-- starts with @name@.
mapChecked ::
   (Shape.C sha, Shape.C shb) =>
   String -> (sha -> shb) -> sha -> shb
mapChecked name f sizeA
   | Shape.size sizeA == Shape.size sizeB = sizeB
   | otherwise = error $ name ++ ": sizes mismatch"
   where sizeB = f sizeA


-- | Shape of a dense matrix: a storage order together with the extent.
-- The type parameters @meas@, @vert@ and @horiz@ classify the relation
-- between height and width (see the 'General', 'Tall', 'Wide' and
-- 'Square' aliases).
data Full meas vert horiz height width =
   Full {
      -- | storage order of the elements
      fullOrder :: Order,
      -- | height and width of the matrix
      fullExtent :: Extent meas vert horiz height width
   } deriving (Eq, Show)

-- | Forces both the storage order and the extent.
instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    NFData height, NFData width) =>
      NFData (Full meas vert horiz height width) where
   rnf (Full order extent) = rnf order `seq` rnf extent

-- | The number of elements is the size of the (height,width) pair and does
-- not depend on the storage order.
instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
      Shape.C (Full meas vert horiz height width) where

   size (Full _order extent) = Shape.size $ Extent.dimensions extent

-- | Indices are (row, column) pairs.  For 'RowMajor' offsets are those of
-- the (height,width) pair.  For 'ColumnMajor' they are those of the swapped
-- (width,height) pair applied to the swapped index.
instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.Indexed width) =>
      Shape.Indexed (Full meas vert horiz height width) where

   type Index (Full meas vert horiz height width) =
            (Shape.Index height, Shape.Index width)

   indices (Full order extent) = fullIndices order extent

   unifiedOffset (Full order extent) =
      let dims = Extent.dimensions extent
      in case order of
            RowMajor -> Shape.unifiedOffset dims
            ColumnMajor -> Shape.unifiedOffset (swap dims) . swap

   unifiedSizeOffset (Full order extent) =
      let dims = Extent.dimensions extent
      in case order of
            RowMajor -> Shape.unifiedSizeOffset dims
            ColumnMajor ->
               mapSnd (\offsetOf -> offsetOf . swap) $
               Shape.unifiedSizeOffset (swap dims)

   -- Whether an index is valid does not depend on the storage order.
   inBounds (Full _order extent) = Shape.inBounds $ Extent.dimensions extent

-- | Inverse of the offset computation above, see 'fullIndexFromOffset'.
instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.InvIndexed height, Shape.InvIndexed width) =>
      Shape.InvIndexed (Full meas vert horiz height width) where

   unifiedIndexFromOffset (Full order extent) =
      fullIndexFromOffset order extent


-- | Transposed shape: height and width are exchanged and the storage
-- order is flipped.
transpose ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width -> Full meas horiz vert width height
transpose (Full order extent) =
   Full {
      fullOrder = flipOrder order,
      fullExtent = Extent.transpose extent
   }

-- | Shape with height and width exchanged but the same storage order.
-- NOTE(review): presumably the shape of the inverse of a matrix — confirm.
inverse ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width -> Full meas horiz vert width height
inverse (Full order extent) =
   Full {
      fullOrder = order,
      fullExtent = Extent.transpose extent
   }

-- | Sizes of height and width as a pair.  The pair is swapped for
-- 'RowMajor' storage.
-- NOTE(review): presumably the (rows, columns) seen by column-major
-- routines — confirm with callers.
dimensions ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Full meas vert horiz height width -> (Int, Int)
dimensions (Full order extent) =
   let m = Shape.size $ Extent.height extent
       n = Shape.size $ Extent.width extent
   in swapOnRowMajor order (m,n)

-- | Height shape of a full matrix.
fullHeight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width -> height
fullHeight full = Extent.height (fullExtent full)

-- | Width shape of a full matrix.
fullWidth ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width -> width
fullWidth full = Extent.width (fullExtent full)


-- | All (row, column) indices in storage order.  For 'RowMajor' storage
-- they are enumerated row by row; for 'ColumnMajor' storage they are
-- enumerated column by column.
fullIndices ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed a, Shape.Indexed b) =>
   Order -> Extent meas vert horiz a b -> [(Shape.Index a, Shape.Index b)]
fullIndices order extent =
   let dims = Extent.dimensions extent
   in case order of
         RowMajor -> Shape.indices dims
         ColumnMajor -> swap <$> Shape.indices (swap dims)

-- | (row, column) index belonging to a storage offset.  This is the inverse
-- of the offset computation of the 'Full' shape.
fullIndexFromOffset ::
   (Shape.Checking check,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.InvIndexed a, Shape.InvIndexed b) =>
   Order -> Extent meas vert horiz a b -> Int ->
   Shape.Result check (Shape.Index a, Shape.Index b)
fullIndexFromOffset order extent =
   let dims = Extent.dimensions extent
   in case order of
         RowMajor -> Shape.unifiedIndexFromOffset dims
         ColumnMajor -> fmap swap . Shape.unifiedIndexFromOffset (swap dims)


-- | Full matrix without a size relation between height and width
-- (constructed by 'general').
type General height width = Full Extent.Size Extent.Big Extent.Big height width
-- | Full matrix whose height is at least its width (constructed by 'tall').
type Tall height width = Full Extent.Size Extent.Big Extent.Small height width
-- | Full matrix whose height is at most its width (constructed by 'wide').
type Wide height width = Full Extent.Size Extent.Small Extent.Big height width
-- | Square-sized matrix whose height and width may be different shape types
-- of equal size (constructed by 'liberalSquare').
type LiberalSquare height width = SquareMeas Extent.Size height width
-- | Square matrix with the same shape for height and width.
type Square size = SquareMeas Extent.Shape size size
-- | Full matrix that is 'Extent.Small' in both directions, parameterized by
-- the measure.
type SquareMeas meas height width =
         Full meas Extent.Small Extent.Small height width


-- | Transform the extent and keep the storage order.
fullMapExtent ::
   Extent.Map measA vertA horizA measB vertB horizB height width ->
   Full measA vertA horizA height width ->
   Full measB vertB horizB height width
fullMapExtent f (Full order extent) = Full order (f extent)

-- | General shape.  No relation between height and width is checked.
general :: Order -> height -> width -> General height width
general order height width = Full order (Extent.general height width)

-- | Tall shape.  Calls 'error' if the height is smaller than the width.
tall ::
   (Shape.C height, Shape.C width) =>
   Order -> height -> width -> Tall height width
tall order height width
   | Shape.size height >= Shape.size width =
      Full order $ Extent.tall height width
   | otherwise = error "Layout.tall: height smaller than width"

-- | Wide shape.  Calls 'error' if the width is smaller than the height.
wide ::
   (Shape.C height, Shape.C width) =>
   Order -> height -> width -> Wide height width
wide order height width
   | Shape.size height <= Shape.size width =
      Full order $ Extent.wide height width
   | otherwise = error "Layout.wide: width smaller than height"

-- | Square-sized shape with possibly different height and width types.
-- Calls 'error' if the sizes differ.
liberalSquare ::
   (Shape.C height, Shape.C width) =>
   Order -> height -> width -> LiberalSquare height width
liberalSquare order height width
   | Shape.size height == Shape.size width =
      Full order $ Extent.liberalSquare height width
   | otherwise = error "Layout.liberalSquare: height and width sizes differ"

-- | Square shape with the same shape for height and width.
square :: Order -> sh -> Square sh
square order = Full order . Extent.square


-- | Classify a full shape as 'Tall' (height size at least width size) or
-- otherwise as 'Wide'.  The storage order is kept.
caseTallWide ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Full meas vert horiz height width ->
   Either (Tall height width) (Wide height width)
caseTallWide (Full order extent) =
   case Extent.caseTallWide (\h w -> Shape.size h >= Shape.size w) extent of
      Left tallExtent -> Left $ Full order tallExtent
      Right wideExtent -> Right $ Full order wideExtent


-- | Shape of a matrix with a full storage layout whose elements are
-- partitioned into the strictly lower part, tagged with @lower@, and the
-- rest (diagonal and above), tagged with 'Triangle' (see 'splitPart').
-- NOTE(review): the meaning of the @lower@ tag (e.g. 'Reflector') is not
-- visible here — confirm with the users of this shape.
data Split lower meas vert horiz height width =
   Split {
      -- | tag for elements strictly below the diagonal
      splitLower :: lower,
      -- | storage order of the underlying full matrix
      splitOrder :: Order,
      -- | height and width of the matrix
      splitExtent :: Extent meas vert horiz height width
   } deriving (Eq, Show)

-- | Height shape of a split matrix.
splitHeight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Split lower meas vert horiz height width -> height
splitHeight split = Extent.height (splitExtent split)

-- | Width shape of a split matrix.
splitWidth ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Split lower meas vert horiz height width -> width
splitWidth split = Extent.width (splitExtent split)

-- | Transform the extent and keep the part tag and the storage order.
splitMapExtent ::
   Extent.Map measA vertA horizA measB vertB horizB height width ->
   Split lower measA vertA horizA height width ->
   Split lower measB vertB horizB height width
splitMapExtent f (Split lowerPart order extent) =
   Split lowerPart order (f extent)


-- | Classify a split shape as tall (height size at least width size) or
-- otherwise as wide.
caseTallWideSplit ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Split lower meas vert horiz height width ->
   Either
      (Split lower Extent.Size Extent.Big Extent.Small height width)
      (Split lower Extent.Size Extent.Small Extent.Big height width)
caseTallWideSplit (Split lowerPart order extent) =
   case Extent.caseTallWide (\h w -> Shape.size h >= Shape.size w) extent of
      Left tallExtent -> Left $ Split lowerPart order tallExtent
      Right wideExtent -> Right $ Split lowerPart order wideExtent

-- | Tag type with a single value.
-- NOTE(review): presumably used as the @lower@ tag of 'Split' — confirm.
data Reflector = Reflector deriving (Eq, Show)
-- | Tag for the elements on and above the diagonal of a 'Split' shape.
data Triangle = Triangle deriving (Eq, Show)

instance NFData Reflector where rnf Reflector = ()
instance NFData Triangle where rnf Triangle = ()

-- | Part of a split matrix that an index belongs to.  Indices whose row
-- offset exceeds the column offset (strictly below the diagonal) yield
-- the @lower@ tag; all others yield 'Triangle'.
splitPart ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.Indexed width) =>
   Split lower meas vert horiz height width ->
   (Shape.Index height, Shape.Index width) -> Either lower Triangle
splitPart (Split lowerPart _order extent) (r,c)
   | rowOffset > colOffset = Left lowerPart
   | otherwise = Right Triangle
   where
      rowOffset = Shape.offset (Extent.height extent) r
      colOffset = Shape.offset (Extent.width extent) c

-- | Forces the part tag, the storage order and the extent.
instance
   (NFData lower, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    NFData height, NFData width) =>
      NFData (Split lower meas vert horiz height width) where
   rnf (Split lowerPart order extent) =
      rnf lowerPart `seq` rnf order `seq` rnf extent

-- | A split matrix stores as many elements as the full matrix.
instance
   (Eq lower, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
      Shape.C (Split lower meas vert horiz height width) where

   size (Split _lower _order extent) = Shape.size $ Extent.dimensions extent

-- | An index of a split matrix is the (row, column) index tagged with the
-- part it lies in (see 'splitPart').  Offsets are those of the full matrix
-- in the given storage order.  The tag is only checked for consistency.
instance
   (Eq lower, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.Indexed width) =>
      Shape.Indexed (Split lower meas vert horiz height width) where

   type Index (Split lower meas vert horiz height width) =
            (Either lower Triangle,
             (Shape.Index height, Shape.Index width))

   -- Same enumeration as for the full layout, each index tagged with its part.
   indices sh@(Split _ order extent) =
      map (\ix -> (splitPart sh ix, ix)) $ fullIndices order extent

   -- The part tag is asserted first, then the full-matrix offset is computed.
   -- NOTE(review): 'splitPart' uses 'Shape.offset' (not the unchecked
   -- variant), so an out-of-range index may fail inside 'splitPart' rather
   -- than through the offset computation — confirm this is intended.
   unifiedOffset sh@(Split _ order extent) (part,ix) = do
      Shape.assert "Shape.Split.offset: wrong matrix part" $
         part == splitPart sh ix
      case order of
         RowMajor -> Shape.unifiedOffset (Extent.dimensions extent) ix
         ColumnMajor ->
            Shape.unifiedOffset (swap $ Extent.dimensions extent) (swap ix)

   -- Size and offset function of the full layout.  Here the offset is
   -- computed first and the part tag is asserted afterwards ('check'
   -- passes the offset through unchanged).
   unifiedSizeOffset sh@(Split _ order extent) =
      let check (part,ix) a = do
            Shape.assert "Shape.Split.sizeOffset: wrong matrix part" $
               part == splitPart sh ix
            return a
      in case order of
            RowMajor ->
               mapSnd
                  (\getOffset (part,ix) -> check (part,ix) =<< getOffset ix) $
               Shape.unifiedSizeOffset (Extent.dimensions extent)
            ColumnMajor ->
               mapSnd
                  (\getOffset (part,ix) ->
                     check (part,ix) =<< getOffset (swap ix)) $
               Shape.unifiedSizeOffset (swap $ Extent.dimensions extent)

   -- The bounds check runs first, so 'splitPart' only sees valid indices.
   inBounds sh@(Split _ _ extent) (part,ix) =
      Shape.inBounds (Extent.dimensions extent) ix
      &&
      part == splitPart sh ix

-- | Index of the full matrix at the given offset, tagged with its part.
instance
   (Eq lower, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.InvIndexed height, Shape.InvIndexed width) =>
      Shape.InvIndexed (Split lower meas vert horiz height width) where

   unifiedIndexFromOffset sh@(Split _ order extent) k =
      fullIndexFromOffset order extent k >>=
         \ix -> return (splitPart sh ix, ix)



-- | Shape of a square matrix with triangular structure:
--
-- * @pack@ selects whether only one triangle is stored ('Packed') or the
--   full square ('Unpacked'),
-- * @mirror@ distinguishes triangular ('NoMirror'), symmetric
--   ('SimpleMirror') and Hermitian ('ConjugateMirror') matrices, see the
--   'TriangularP', 'SymmetricP' and 'HermitianP' aliases,
-- * @uplo@ selects the stored triangle ('Shape.Lower' or 'Shape.Upper').
data Mosaic pack mirror uplo size =
   Mosaic {
      mosaicPack :: PackingSingleton pack,
      mosaicMirror :: MirrorSingleton mirror,
      mosaicUplo :: UpLoSingleton uplo,
      mosaicOrder :: Order,
      mosaicSize :: size
   } deriving (Eq, Show)

-- | Type-level tag: only one triangle is stored.
data Packed
-- | Type-level tag: the full square is stored.
data Unpacked

-- | Value-level witness of the packing tag.
data PackingSingleton pack where
   Packed :: PackingSingleton Packed
   Unpacked :: PackingSingleton Unpacked

deriving instance Eq (PackingSingleton pack)
deriving instance Show (PackingSingleton pack)

instance NFData (PackingSingleton pack) where
   rnf Packed = ()
   rnf Unpacked = ()

-- | Packing tags whose singleton can be obtained from the type alone.
class Packing pack where autoPacking :: PackingSingleton pack
instance Packing Unpacked where autoPacking = Unpacked
instance Packing Packed where autoPacking = Packed

-- | View an unpacked mosaic as a full square shape with the same order and
-- size.  Mirror and triangle information are dropped.
squareFromMosaic :: Mosaic Unpacked mirror uplo size -> Square size
squareFromMosaic (Mosaic _pack _mirror _uplo order size) = square order size

-- | Turn a full square shape into an unpacked mosaic.  Mirror and triangle
-- tags are chosen by the result type.
mosaicFromSquare ::
   (Mirror mirror, UpLo uplo) => Square size -> Mosaic Unpacked mirror uplo size
mosaicFromSquare (Full order extent) =
   Mosaic Unpacked autoMirror autoUplo order (Extent.squareSize extent)


-- | Type-level tag used for triangular mosaics (see 'TriangularP').
data NoMirror
-- | Type-level tag used for symmetric mosaics (see 'SymmetricP').
data SimpleMirror
-- | Type-level tag used for Hermitian mosaics (see 'HermitianP').
data ConjugateMirror

-- | Value-level witness of the mirror tag.
data MirrorSingleton mirror where
   NoMirror :: MirrorSingleton NoMirror
   SimpleMirror :: MirrorSingleton SimpleMirror
   ConjugateMirror :: MirrorSingleton ConjugateMirror

deriving instance Eq (MirrorSingleton mirror)
deriving instance Show (MirrorSingleton mirror)

instance NFData (MirrorSingleton mirror) where
   rnf NoMirror = ()
   rnf SimpleMirror = ()
   rnf ConjugateMirror = ()

-- | Mirror tags whose singleton can be obtained from the type alone.
class Mirror mirror where autoMirror :: MirrorSingleton mirror
instance Mirror NoMirror where autoMirror = NoMirror
instance Mirror SimpleMirror where autoMirror = SimpleMirror
instance Mirror ConjugateMirror where autoMirror = ConjugateMirror


-- | Triangular mosaic with selectable packing.
type TriangularP pack = Mosaic pack NoMirror
-- | Packed triangular mosaic.
type Triangular = TriangularP Packed

-- | Lower triangular mosaic with selectable packing.
type LowerTriangularP pack = TriangularP pack Shape.Lower
-- | Packed lower triangular mosaic.
type LowerTriangular = Triangular Shape.Lower

-- | Upper triangular mosaic with selectable packing.
type UpperTriangularP pack = TriangularP pack Shape.Upper
-- | Packed upper triangular mosaic.
type UpperTriangular = Triangular Shape.Upper

-- | Packed triangular shape; the triangle is chosen by the singleton.
triangular :: UpLoSingleton uplo -> Order -> size -> Triangular uplo size
triangular = triangularP Packed

-- | Packed upper triangular shape.
upperTriangular :: Order -> size -> UpperTriangular size
upperTriangular = upperTriangularP Packed

-- | Packed lower triangular shape.
lowerTriangular :: Order -> size -> LowerTriangular size
lowerTriangular = lowerTriangularP Packed


-- | Triangular shape with the given packing and triangle.
triangularP ::
   PackingSingleton pack ->
   UpLoSingleton uplo -> Order -> size -> TriangularP pack uplo size
triangularP pack uplo = Mosaic pack NoMirror uplo

-- | Upper triangular shape with the given packing.
upperTriangularP ::
   PackingSingleton pack -> Order -> size -> UpperTriangularP pack size
upperTriangularP pack = triangularP pack Upper

-- | Lower triangular shape with the given packing.
lowerTriangularP ::
   PackingSingleton pack -> Order -> size -> LowerTriangularP pack size
lowerTriangularP pack = triangularP pack Lower


-- | Symmetric mosaic (upper triangle stored) with selectable packing.
type SymmetricP pack = Mosaic pack SimpleMirror Shape.Upper
-- | Packed symmetric mosaic.
type Symmetric = SymmetricP Packed

-- | Packed symmetric shape.
symmetric :: Order -> size -> Symmetric size
symmetric order size = symmetricP Packed order size

-- | Symmetric shape with the given packing.
symmetricP :: PackingSingleton pack -> Order -> size -> SymmetricP pack size
symmetricP pack order size = Mosaic pack SimpleMirror Upper order size

-- | Reinterpret a Hermitian shape as symmetric; only the mirror tag changes.
symmetricFromHermitian :: HermitianP pack size -> SymmetricP pack size
symmetricFromHermitian (Mosaic pack ConjugateMirror uplo order size) =
   Mosaic {
      mosaicPack = pack,
      mosaicMirror = SimpleMirror,
      mosaicUplo = uplo,
      mosaicOrder = order,
      mosaicSize = size
   }


-- | Hermitian mosaic (upper triangle stored) with selectable packing.
type HermitianP pack = Mosaic pack ConjugateMirror Shape.Upper
-- | Packed Hermitian mosaic.
type Hermitian = HermitianP Packed

-- | Packed Hermitian shape.
hermitian :: Order -> size -> Hermitian size
hermitian order size = hermitianP Packed order size

-- | Hermitian shape with the given packing.
hermitianP :: PackingSingleton pack -> Order -> size -> HermitianP pack size
hermitianP pack order size = Mosaic pack ConjugateMirror Upper order size

-- | Reinterpret a symmetric shape as Hermitian; only the mirror tag changes.
hermitianFromSymmetric :: SymmetricP pack size -> HermitianP pack size
hermitianFromSymmetric (Mosaic pack SimpleMirror uplo order size) =
   Mosaic {
      mosaicPack = pack,
      mosaicMirror = ConjugateMirror,
      mosaicUplo = uplo,
      mosaicOrder = order,
      mosaicSize = size
   }

-- | Triangle character for a storage order: @'L'@ for 'RowMajor' and
-- @'U'@ for 'ColumnMajor'.
-- NOTE(review): presumably the LAPACK @uplo@ flag under which an upper
-- triangle stored in this order appears — confirm with callers.
uploFromOrder :: Order -> Char
uploFromOrder order =
   case order of
      RowMajor -> 'L'
      ColumnMajor -> 'U'



-- | Band tag carrying a type-level unary number of off-diagonals.
newtype Bands offDiag = Bands (UnaryProxy offDiag) deriving (Eq, Show)

-- | Extract the number of off-diagonals from a 'Bands' tag.
type family GetBands strip
type instance GetBands (Bands offDiag) = offDiag

-- | Band tag with zero off-diagonals.
type Empty = Bands TypeNum.U0
-- | Tag for a completely filled part.
data Filled = Filled deriving (Eq, Show)

-- | Proxy of the unary number zero.
u0 :: UnaryProxy TypeNum.U0
u0 = unary TypeNum.u0

-- | The 'Empty' band tag.
empty :: Empty
empty = Bands u0

-- | The triangle tag after transposition: lower becomes upper and vice versa.
type family TriTransposed uplo
type instance TriTransposed Shape.Lower = Shape.Upper
type instance TriTransposed Shape.Upper = Shape.Lower

-- | Transpose a mosaic shape: the triangle tag is exchanged and the storage
-- order is flipped.  Packing, mirror and size are kept.
triangularTranspose ::
   (UpLo uplo) =>
   Mosaic pack mirror uplo sh ->
   Mosaic pack mirror (TriTransposed uplo) sh
triangularTranspose (Mosaic pack mirror uplo order size) =
   Mosaic pack mirror
      -- The GADT match relies on the expected type of this argument for
      -- inference, so it is kept inline.
      (case uplo of
         Lower -> Upper
         Upper -> Lower)
      (flipOrder order)
      size


-- | The triangle singleton selected by the type.
autoUplo :: (UpLo uplo) => UpLoSingleton uplo
autoUplo = switchUpLo Upper Lower

-- | Flip the order for a lower triangle; keep it for an upper triangle.
uploOrder :: UpLoSingleton uplo -> Order -> Order
uploOrder Lower = flipOrder
uploOrder Upper = id


-- | Type-level triangle tags.  'switchUpLo' returns its first argument for
-- 'Shape.Upper' and its second for 'Shape.Lower'.
class UpLo uplo where
   switchUpLo :: f Shape.Upper -> f Shape.Lower -> f uplo

instance UpLo Shape.Upper where
   switchUpLo upper _lower = upper

instance UpLo Shape.Lower where
   switchUpLo _upper lower = lower

-- | Value-level witness of the triangle tag.
data UpLoSingleton uplo where
   Lower :: UpLoSingleton Shape.Lower
   Upper :: UpLoSingleton Shape.Upper

-- Derived like the instances of 'PackingSingleton' and 'MirrorSingleton'.
-- The derived 'Show' prints the bare constructor names, as the
-- former hand-written instance did.
deriving instance Eq (UpLoSingleton uplo)
deriving instance Show (UpLoSingleton uplo)

instance NFData (UpLoSingleton uplo) where
   rnf Lower = ()
   rnf Upper = ()

-- | Character for the triangle: @'L'@ for 'Lower', @'U'@ for 'Upper'.
uploChar :: UpLoSingleton uplo -> Char
uploChar uplo =
   case uplo of
      Lower -> 'L'
      Upper -> 'U'


-- | Forces all five components of the mosaic shape.
instance
   (UpLo uplo, NFData size) =>
      NFData (Mosaic pack mirror uplo size) where
   rnf (Mosaic pack mirror uplo order size) =
      rnf pack `seq` rnf mirror `seq` rnf uplo `seq` rnf order `seq` rnf size

-- | A packed mosaic stores a triangle (triangular number of elements).  An
-- unpacked one stores the full square.
instance
   (UpLo uplo, Shape.C size) =>
      Shape.C (Mosaic pack mirror uplo size) where

   size (Mosaic Packed _mirror _uplo _order sh) =
      triangleSize $ Shape.size sh
   size (Mosaic Unpacked _mirror _uplo order sh) =
      Shape.size $ square order sh

-- | Indices are (row, column) pairs.  Unpacked mosaics use the full square
-- layout.  Packed mosaics store only the selected triangle.  A packed lower
-- triangle is handled as the upper triangle in the flipped order with
-- swapped indices.
instance
   (UpLo uplo, Shape.Indexed size) =>
      Shape.Indexed (Mosaic pack mirror uplo size) where
   type Index (Mosaic pack mirror uplo size) =
         (Shape.Index size, Shape.Index size)

   indices (Mosaic pack _mirror uplo order size) =
      case (pack,uplo) of
         (Unpacked,_) -> Shape.indices $ square order size
         (Packed,Upper) -> triangleIndices order size
         (Packed,Lower) -> map swap $ triangleIndices (flipOrder order) size

   unifiedOffset (Mosaic pack _mirror uplo order size) =
      case (pack,uplo) of
         (Unpacked,_) -> Shape.unifiedOffset $ square order size
         (Packed,Upper) -> triangleOffset order size
         (Packed,Lower) -> triangleOffset (flipOrder order) size . swap

   unifiedSizeOffset (Mosaic pack _mirror uplo order size) =
      case (pack,uplo) of
         (Unpacked,_) -> Shape.unifiedSizeOffset $ square order size
         (Packed,Upper) -> triangleSizeOffset order size
         (Packed,Lower) ->
            mapSnd (.swap) $ triangleSizeOffset (flipOrder order) size

   -- Every index of the square is in bounds for unpacked storage.  For
   -- packed storage only the stored triangle (including the diagonal) is.
   inBounds (Mosaic pack _mirror uplo _ size) ix@(r,c) =
      Shape.inBounds (size,size) ix
      &&
      case (pack,uplo) of
         (Unpacked,_) -> True
         (Packed,Upper) -> Shape.offset size r <= Shape.offset size c
         (Packed,Lower) -> Shape.offset size r >= Shape.offset size c

-- | Inverse of the offset computation above, using the same reduction of a
-- packed lower triangle to an upper triangle in the flipped order.
instance
   (UpLo uplo, Shape.InvIndexed size) =>
      Shape.InvIndexed (Mosaic pack mirror uplo size) where

   unifiedIndexFromOffset (Mosaic pack _mirror uplo order size) k =
      case (pack,uplo) of
         (Unpacked,_) ->
            Shape.unifiedIndexFromOffset (square order size) k
         (Packed,Upper) -> triangleIndexFromOffset order size k
         (Packed,Lower) ->
            swap <$> triangleIndexFromOffset (flipOrder order) size k


-- | Square root of an 'Int', computed in 'Double' precision.
squareRootDouble :: Int -> Double
squareRootDouble n = sqrt (fromIntegral n)

-- | Side length of a square with @size@ elements.  If @size@ is not a
-- perfect square, 'error' is called with a message that starts with @name@.
squareExtent :: String -> Int -> Int
squareExtent name size
   | root*root == size = root
   | otherwise = error (name ++ ": no square number of elements")
   where root = round (squareRootDouble size)


-- | 'triangleRoot' of an 'Int', computed in 'Double' precision.
triangleRootDouble :: Int -> Double
triangleRootDouble n = triangleRoot (fromIntegral n)

-- | Side length of a triangle with @size@ elements.  If @size@ is not a
-- triangular number, 'error' is called with a message that starts with
-- @name@.
triangleExtent :: String -> Int -> Int
triangleExtent name size
   | triangleSize root == size = root
   | otherwise = error (name ++ ": no triangular number of elements")
   where root = round (triangleRootDouble size)

-- | Indices of a packed upper triangle in storage order.  'RowMajor' uses
-- the upper-triangular shape directly.  'ColumnMajor' uses the
-- lower-triangular shape with swapped indices.
triangleIndices ::
   (Shape.Indexed sh) => Order -> sh -> [(Shape.Index sh, Shape.Index sh)]
triangleIndices order size =
   case order of
      RowMajor -> Shape.indices $ Shape.upperTriangular size
      ColumnMajor -> swap <$> Shape.indices (Shape.lowerTriangular size)

-- | Storage offset of an index into a packed upper triangle.
triangleOffset ::
   (Shape.Checking check, Shape.Indexed sh) =>
   Order -> sh -> (Shape.Index sh, Shape.Index sh) -> Shape.Result check Int
triangleOffset RowMajor size =
   Shape.unifiedOffset (Shape.upperTriangular size)
triangleOffset ColumnMajor size =
   Shape.unifiedOffset (Shape.lowerTriangular size) . swap

-- | Number of elements and offset function of a packed upper triangle.
triangleSizeOffset ::
   (Shape.Checking check, Shape.Indexed sh) =>
   Order -> sh ->
   (Int, (Shape.Index sh, Shape.Index sh) -> Shape.Result check Int)
triangleSizeOffset RowMajor size =
   Shape.unifiedSizeOffset (Shape.upperTriangular size)
triangleSizeOffset ColumnMajor size =
   mapSnd (\offsetOf -> offsetOf . swap) $
   Shape.unifiedSizeOffset (Shape.lowerTriangular size)

-- | Index of a packed upper triangle at a storage offset.
triangleIndexFromOffset ::
   (Shape.Checking check, Shape.InvIndexed sh) =>
   Order -> sh -> Int -> Shape.Result check (Shape.Index sh, Shape.Index sh)
triangleIndexFromOffset RowMajor size =
   Shape.unifiedIndexFromOffset (Shape.upperTriangular size)
triangleIndexFromOffset ColumnMajor size =
   fmap swap . Shape.unifiedIndexFromOffset (Shape.lowerTriangular size)


-- | Proxy for a type-level unary natural number.
type UnaryProxy a = Proxy (Unary.Un a)

-- | Shape of a band matrix with @sub@ subdiagonals and @super@
-- superdiagonals (type-level numbers), a storage order and an extent.
data Banded sub super meas vert horiz height width =
   Banded {
      -- | proxies for the numbers of sub- and superdiagonals
      bandedOffDiagonals :: (UnaryProxy sub, UnaryProxy super),
      -- | whether the band is stored row by row or column by column
      bandedOrder :: Order,
      -- | height and width of the matrix
      bandedExtent :: Extent meas vert horiz height width
   } deriving (Eq, Show)

-- | Band matrix without a size relation between height and width.
type BandedGeneral sub super =
      Banded sub super Extent.Size Extent.Big Extent.Big
-- | Band matrix that is 'Extent.Small' in both directions.
type BandedSquareMeas sub super meas height width =
      Banded sub super meas Extent.Small Extent.Small height width
-- | Square band matrix with the same shape for height and width.
type BandedSquare sub super size =
      BandedSquareMeas sub super Extent.Shape size size

-- | Square band matrix without superdiagonals.
type BandedLowerTriangular sub size = BandedSquare sub TypeNum.U0 size
-- | Square band matrix without subdiagonals.
type BandedUpperTriangular super size = BandedSquare TypeNum.U0 super size

-- | Square band matrix consisting of the main diagonal only.
type Diagonal size = BandedSquare TypeNum.U0 TypeNum.U0 size
-- | Possibly non-square band matrix consisting of the main diagonal only.
type RectangularDiagonal = Banded TypeNum.U0 TypeNum.U0


-- | Height shape of a band matrix.
bandedHeight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Banded sub super meas vert horiz height width -> height
bandedHeight banded = Extent.height (bandedExtent banded)

-- | Width shape of a band matrix.
bandedWidth ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Banded sub super meas vert horiz height width -> width
bandedWidth banded = Extent.width (bandedExtent banded)

-- | Transform the extent and keep off-diagonals and storage order.
bandedMapExtent ::
   Extent.Map measA vertA horizA measB vertB horizB height width ->
   Banded sub super measA vertA horizA height width ->
   Banded sub super measB vertB horizB height width
bandedMapExtent f (Banded offDiag order extent) =
   Banded {
      bandedOffDiagonals = offDiag,
      bandedOrder = order,
      bandedExtent = f extent
   }

-- | Number of stored elements per row or column: sub- plus superdiagonals
-- plus the main diagonal.
bandedBreadth ::
   (Unary.Natural sub, Unary.Natural super) =>
   (UnaryProxy sub, UnaryProxy super) -> Int
bandedBreadth (sub,super) =
   let kl = integralFromProxy sub
       ku = integralFromProxy super
   in kl + ku + 1

-- | Numbers of (sub, super) diagonals, swapped for 'RowMajor'.
numOffDiagonals ::
   (Unary.Natural sub, Unary.Natural super) =>
   Order -> (UnaryProxy sub, UnaryProxy super) -> (Int,Int)
numOffDiagonals order (sub,super) =
   let kl = integralFromProxy sub
       ku = integralFromProxy super
   in case order of
         RowMajor -> (ku, kl)
         ColumnMajor -> (kl, ku)

-- | Turn a proxy of a natural number into a proof value.
natFromProxy :: (Unary.Natural n) => UnaryProxy n -> Proof.Nat n
natFromProxy Proxy = Proof.Nat

-- | Add the off-diagonal counts of two band shapes.  This returns proofs
-- that the sums are natural, together with proxies for the sums.
addOffDiagonals ::
   (Unary.Natural subA, Unary.Natural superA,
    Unary.Natural subB, Unary.Natural superB,
    (subA :+: subB) ~ subC,
    (superA :+: superB) ~ superC) =>
   (UnaryProxy subA, UnaryProxy superA) ->
   (UnaryProxy subB, UnaryProxy superB) ->
   ((Proof.Nat subC, Proof.Nat superC),
    (UnaryProxy subC, UnaryProxy superC))
addOffDiagonals (subA,superA) (subB,superB) =
   let subNat = Proof.addNat (natFromProxy subA) (natFromProxy subB)
       superNat = Proof.addNat (natFromProxy superA) (natFromProxy superB)
   in ((subNat, superNat), (Proxy, Proxy))

-- | Transposed band shape: sub- and superdiagonals are exchanged, the
-- storage order is flipped and the extent is transposed.
bandedTranspose ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Banded sub super meas vert horiz height width ->
   Banded super sub meas horiz vert width height
bandedTranspose (Banded (sub,super) order extent) =
   Banded {
      bandedOffDiagonals = (super,sub),
      bandedOrder = flipOrder order,
      bandedExtent = Extent.transpose extent
   }

-- | Shape of a diagonal matrix with height and width exchanged.  The
-- storage order is kept.
diagonalInverse ::
   (Extent.Measure meas) =>
   BandedSquareMeas TypeNum.U0 TypeNum.U0 meas height width ->
   BandedSquareMeas TypeNum.U0 TypeNum.U0 meas width height
diagonalInverse (Banded (sub,super) order extent) =
   Banded {
      bandedOffDiagonals = (super,sub),
      bandedOrder = order,
      bandedExtent = Extent.transpose extent
   }


-- | General band shape.  No relation between height and width is checked.
bandedGeneral ::
   (UnaryProxy sub, UnaryProxy super) -> Order -> height -> width ->
   BandedGeneral sub super height width
bandedGeneral offDiag order height =
   Banded offDiag order . Extent.general height

-- | Square band shape.
bandedSquare ::
   (UnaryProxy sub, UnaryProxy super) -> Order -> size ->
   BandedSquare sub super size
bandedSquare offDiag order size = Banded offDiag order (Extent.square size)

-- | Diagonal shape over an arbitrary extent, together with the length of
-- the diagonal (the smaller of height and width).  The storage order is
-- chosen so that the band size equals that length: 'RowMajor' when the
-- height is not larger than the width, otherwise 'ColumnMajor'.
rectangularDiagonal ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Extent meas vert horiz height width ->
   (Int, RectangularDiagonal meas vert horiz height width)
rectangularDiagonal extent =
   let m = Shape.size $ Extent.height extent
       n = Shape.size $ Extent.width extent
       order
          | m <= n = RowMajor
          | otherwise = ColumnMajor
   in (min m n, Banded (u0,u0) order extent)


-- | Index into band storage.
--
-- * 'InsideBox': an ordinary (row, column) position of the matrix.
-- * 'VertOutsideBox': a row position outside the height, given as an
--   'Int'.  Negative values lie before the first row; a non-negative @k@
--   lies @k@ rows after the last row (see 'outsideOffset').
-- * 'HorizOutsideBox': likewise for a column position outside the width.
data BandedIndex row column =
     InsideBox row column
   | VertOutsideBox Int column
   | HorizOutsideBox row Int
   deriving (Eq, Show)

-- | Forces the storage order and the extent; the off-diagonal proxies are
-- matched so they are evaluated too.
instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    NFData height, NFData width) =>
      NFData (Banded sub super meas vert horiz height width) where
   rnf (Banded (Proxy,Proxy) order extent) = rnf order `seq` rnf extent

-- | Band storage holds 'bandedBreadth' elements per row ('RowMajor') or
-- per column ('ColumnMajor').
instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
      Shape.C (Banded sub super meas vert horiz height width) where

   size (Banded offDiag order extent) =
      let lineCount =
            case order of
               RowMajor -> Shape.size $ Extent.height extent
               ColumnMajor -> Shape.size $ Extent.width extent
      in bandedBreadth offDiag * lineCount

-- | Indices of band storage.  Each row ('RowMajor') or column
-- ('ColumnMajor') owns @kl+1+ku@ slots.  Slots that fall outside the
-- matrix are addressed with 'HorizOutsideBox' (row-major) or
-- 'VertOutsideBox' (column-major).
instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.Indexed width) =>
      Shape.Indexed (Banded sub super meas vert horiz height width) where

   type Index (Banded sub super meas vert horiz height width) =
            BandedIndex (Shape.Index height) (Shape.Index width)
   -- Column-major enumeration is row-major enumeration of the transposed
   -- band (sub and super exchanged), with row and column swapped back.
   indices (Banded (sub,super) order extent) =
      let (height,width) = Extent.dimensions extent
      in case order of
            RowMajor ->
               map (\(r,c) -> either (HorizOutsideBox r) (InsideBox r) c) $
               bandedIndicesRowMajor (sub,super) (height,width)
            ColumnMajor ->
               map (\(c,r) ->
                     either (flip VertOutsideBox c) (flip InsideBox c) r) $
               bandedIndicesRowMajor (super,sub) (width,height)

   -- The band membership is asserted here; 'bandedOffset' itself uses
   -- unchecked offsets.
   unifiedOffset shape@(Banded (sub,super) order extent) ix = do
      Shape.assert "Banded.offset: index outside band" $
         Shape.inBounds shape ix
      let (height,width) = Extent.dimensions extent
          kl = integralFromProxy sub
          ku = integralFromProxy super
      return $ bandedOffset (kl,ku) order (height,width) ix

   -- An index is in bounds if column offset minus row offset lies in
   -- [-kl, ku].  Outside-box indices are accepted only in the direction
   -- that matches the storage order; their in-box coordinate must be valid.
   inBounds (Banded (sub,super) order extent) ix =
      let (height,width) = Extent.dimensions extent
          kl = integralFromProxy sub
          ku = integralFromProxy super
          insideBand r c = Shape.inBounds (Shape.Range (-kl) ku) (c-r)
      in case (order,ix) of
            (_, InsideBox r c) ->
               Shape.inBounds (height,width) (r,c)
               &&
               insideBand (Shape.offset height r) (Shape.offset width c)
            (RowMajor, HorizOutsideBox r c) ->
               Shape.inBounds height r
               &&
               insideBand (Shape.offset height r) (outsideOffset width c)
            (ColumnMajor, VertOutsideBox r c) ->
               Shape.inBounds width c
               &&
               insideBand (outsideOffset height r) (Shape.offset width c)
            _ -> False

-- | Inverse of the offset computation, see 'bandedIndexFromOffset'.
instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.InvIndexed height, Shape.InvIndexed width) =>
      Shape.InvIndexed (Banded sub super meas vert horiz height width) where

   unifiedIndexFromOffset (Banded (sub,super) order extent) j =
      bandedIndexFromOffset
         (integralFromProxy sub, integralFromProxy super) order
         (Extent.dimensions extent) j

-- | Linear position of an outside-the-box coordinate.  Negative values are
-- kept (before the first element).  A non-negative @k@ is counted from the
-- end of @size@.
outsideOffset :: Shape.C sh => sh -> Int -> Int
outsideOffset size k
   | k < 0 = k
   | otherwise = Shape.size size + k

-- | Storage offset of a band index.  The index must already be known to be
-- in bounds; only unchecked offsets are used here.
--
-- In row-major storage, row @i@ occupies @kl+1+ku@ slots starting at
-- @(kl+1+ku)*i@, and element @(i,j)@ sits at slot @kl + (j-i)@ within it.
-- That sums to @(kl+ku)*i + kl + j@.  Column-major storage is the same with
-- rows and columns exchanged and @ku@ in place of @kl@.  Outside-box
-- indices only occur in the matching order, so they use the corresponding
-- formula directly.
bandedOffset ::
   (Shape.Indexed height, Shape.Indexed width) =>
   (Int, Int) -> Order -> (height, width) ->
   BandedIndex (Shape.Index height) (Shape.Index width) -> Int
bandedOffset (kl,ku) order (height,width) ix =
   let k = kl+ku
   in case ix of
         InsideBox r c ->
            let i = Shape.uncheckedOffset height r
                j = Shape.uncheckedOffset width c
            in case order of
                  RowMajor -> k*i + kl+j
                  ColumnMajor -> k*j + ku+i
         VertOutsideBox r c ->
            let i = outsideOffset height r
                j = Shape.uncheckedOffset width c
            in  k*j + ku+i
         HorizOutsideBox r c ->
            let i = Shape.uncheckedOffset height r
                j = outsideOffset width c
            in  k*i + kl+j

-- | Row-major enumeration of band storage.  Row @i@ is paired with the
-- @kl+1+ku@ column candidates @i-kl .. i+ku@.  Candidates before the first
-- column are @Left@ negative numbers; those after the last column are
-- @Left@ numbers counting from 0.  Valid columns are @Right@ indices.
bandedIndicesRowMajor ::
   (Unary.Natural sub, Unary.Natural super,
    Shape.Indexed height, Shape.Indexed width) =>
   (UnaryProxy sub, UnaryProxy super) ->
   (height, width) ->
   [(Shape.Index height, Either Int (Shape.Index width))]
bandedIndicesRowMajor (sub,super) (height,width) =
   let kl = integralFromProxy sub
       ku = integralFromProxy super
       -- all column candidates from -kl up to infinitely far right
       columns =
          map Left [-kl .. -1] ++
          map Right (Shape.indices width) ++
          map Left [0 ..]
       -- the window of each row starts one candidate further right
       windows = map (take (kl+1+ku)) (tails columns)
   in concat $
      zipWith (\r -> map (\c -> (r,c))) (Shape.indices height) windows

-- | Inverse of 'bandedOffset'.  Maps a storage offset of a band matrix with
-- @kl@ subdiagonals and @ku@ superdiagonals to a 'BandedIndex'.
--
-- Every row ('RowMajor') or column ('ColumnMajor') occupies @kl+1+ku@
-- consecutive slots.  A slot that falls outside the matrix is returned as
-- 'HorizOutsideBox' or 'VertOutsideBox' with an 'Int' position.  The
-- position is negative before the first column/row and counts from 0
-- after the last one.
bandedIndexFromOffset ::
   (Shape.Checking check, Shape.InvIndexed height, Shape.InvIndexed width) =>
   (Int,Int) -> Order -> (height,width) -> Int ->
   Shape.Result check (BandedIndex (Shape.Index height) (Shape.Index width))
bandedIndexFromOffset (kl,ku) order (height,width) =
   case order of
      RowMajor -> let n = Shape.size width in \j -> do
         -- rb: row number, cb: slot within that row's band
         let (rb,cb) = divMod j (kl+1+ku)
         r <- Shape.unifiedIndexFromOffset height rb
         let ci = rb+cb-kl
         if' (ci<0) (return $ HorizOutsideBox r ci) $
            if' (ci>=n) (return $ HorizOutsideBox r (ci-n)) $
            (InsideBox r <$> Shape.unifiedIndexFromOffset width ci)
      -- The height size is computed once per shape, like the width size in
      -- the row-major case, instead of once per offset.
      ColumnMajor -> let m = Shape.size height in \j -> do
         -- cb: column number, rb: slot within that column's band
         let (cb,rb) = divMod j (kl+1+ku)
         c <- Shape.unifiedIndexFromOffset width cb
         let ri = rb+cb-ku
         if' (ri<0) (return $ VertOutsideBox ri c) $
            if' (ri>=m) (return $ VertOutsideBox (ri-m) c) $
            (flip InsideBox c <$> Shape.unifiedIndexFromOffset height ri)


-- | Shape of a square band matrix with @off@ off-diagonals on one side of
-- the main diagonal.  Indexing only admits positions on or above the
-- diagonal.
-- NOTE(review): presumably the other half is implied by Hermitian
-- symmetry — confirm.
data BandedHermitian off size =
   BandedHermitian {
      -- | proxy for the number of off-diagonals
      bandedHermitianOffDiagonals :: UnaryProxy off,
      -- | whether the band is stored row by row or column by column
      bandedHermitianOrder :: Order,
      -- | shape of both height and width
      bandedHermitianSize :: size
   } deriving (Eq, Show)

instance (Unary.Natural off, NFData size) =>
      NFData (BandedHermitian off size) where
   rnf (BandedHermitian Proxy order size) = rnf (order, size)

-- | @1 + off@ slots per row or column, independent of the storage order.
instance (Unary.Natural off, Shape.C size) =>
      Shape.C (BandedHermitian off size) where

   size (BandedHermitian offDiag _order size) =
      (1 + integralFromProxy offDiag) * Shape.size size

-- | Indices of a Hermitian band.  Each row ('RowMajor') or column
-- ('ColumnMajor') owns @off+1@ slots covering the diagonal and the
-- positions up to @off@ steps above it.
instance (Unary.Natural off, Shape.Indexed size) =>
      Shape.Indexed (BandedHermitian off size) where
   type Index (BandedHermitian off size) =
            BandedIndex (Shape.Index size) (Shape.Index size)
   -- Row-major: no subdiagonals, @off@ superdiagonals.  Column-major: the
   -- roles are exchanged before swapping row and column back.
   indices (BandedHermitian offDiag order size) =
      case order of
         RowMajor ->
            map (\(r,c) -> either (HorizOutsideBox r) (InsideBox r) c) $
            bandedIndicesRowMajor (u0, offDiag) (size,size)
         ColumnMajor ->
            map (\(c,r) ->
                  either (flip VertOutsideBox c) (flip InsideBox c) r) $
            bandedIndicesRowMajor (offDiag, u0) (size,size)

   -- Same layout as a general band with zero subdiagonals.
   unifiedOffset shape@(BandedHermitian offDiag order size) ix = do
      Shape.assert "BandedHermitian.offset: index outside band" $
         Shape.inBounds shape ix
      let k = integralFromProxy offDiag
      return $ bandedOffset (0,k) order (size,size) ix

   -- Column offset minus row offset must lie in [0, off].  Outside-box
   -- indices are accepted only in the direction matching the order.
   inBounds (BandedHermitian offDiag order size) ix =
      let ku = integralFromProxy offDiag
          insideBand r c = Shape.inBounds (Shape.Range 0 ku) (c-r)
      in case (order,ix) of
            (_, InsideBox r c) ->
               Shape.inBounds (size,size) (r,c)
               &&
               insideBand (Shape.offset size r) (Shape.offset size c)
            (RowMajor, HorizOutsideBox r c) ->
               Shape.inBounds size r
               &&
               insideBand (Shape.offset size r) (outsideOffset size c)
            (ColumnMajor, VertOutsideBox r c) ->
               Shape.inBounds size c
               &&
               insideBand (outsideOffset size r) (Shape.offset size c)
            _ -> False

-- | Inverse of the offset computation, see 'bandedHermitianIndexFromOffset'.
instance (Unary.Natural off, Shape.InvIndexed size) =>
      Shape.InvIndexed (BandedHermitian off size) where

   unifiedIndexFromOffset (BandedHermitian offDiag order size) j =
      bandedHermitianIndexFromOffset
         (integralFromProxy offDiag) order size j

-- | Map a storage offset of a Hermitian band with @k@ off-diagonals to a
-- 'BandedIndex'.  Every row ('RowMajor') or column ('ColumnMajor') owns
-- @k+1@ consecutive slots.
--
-- Row-major: slot @cb@ of row @rb@ is column @rb+cb@.  It can only
-- overflow past the last column ('HorizOutsideBox' counted from 0).
--
-- Column-major: slot @rb@ of column @cb@ is row @rb+cb-k@.  It can only
-- underflow before the first row ('VertOutsideBox' with a negative row).
bandedHermitianIndexFromOffset ::
   (Shape.Checking check, Shape.InvIndexed sh, Shape.Index sh ~ ix) =>
   Int -> Order -> sh -> Int -> Shape.Result check (BandedIndex ix ix)
bandedHermitianIndexFromOffset k order size =
   case order of
      RowMajor -> let n = Shape.size size in \j -> do
         let (rb,cb) = divMod j (k+1)
         r <- Shape.unifiedIndexFromOffset size rb
         let ci = rb+cb
         if ci<n
            then InsideBox r <$> Shape.unifiedIndexFromOffset size ci
            else return $ HorizOutsideBox r (ci-n)
      ColumnMajor -> \j -> do
         let (cb,rb) = divMod j (k+1)
         c <- Shape.unifiedIndexFromOffset size cb
         let ri = rb+cb-k
         if ri>=0
            then flip InsideBox c <$> Shape.unifiedIndexFromOffset size ri
            else return $ VertOutsideBox ri c