{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Layout.Private (
   module Numeric.LAPACK.Matrix.Layout.Private,
   module Numeric.BLAS.Matrix.Layout,
   ) where

import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Extent.Private (Extent)
import Numeric.BLAS.Matrix.Layout (Order(..), flipOrder, transposeFromOrder)

import qualified Type.Data.Num.Unary.Literal as TypeNum
import qualified Type.Data.Num.Unary.Proof as Proof
import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Unary (unary, (:+:))
import Type.Data.Num (integralFromProxy)
import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape (triangleSize, triangleRoot)

import Control.DeepSeq (NFData, rnf)
import Control.Applicative ((<$>))

import Data.List (tails)
import Data.Tuple.HT (mapSnd, swap)
import Data.Bool.HT (if')


-- | Exchange the components of a pair for 'RowMajor'; leave the pair
-- untouched for 'ColumnMajor'.
swapOnRowMajor :: Order -> (a,a) -> (a,a)
swapOnRowMajor RowMajor = swap
swapOnRowMajor ColumnMajor = id

-- | Pair a side character with a possibly swapped pair.
--
-- 'ColumnMajor' yields @('L', (m0,n0))@ and 'RowMajor' yields
-- @('R', (n0,m0))@. NOTE(review): the character presumably feeds a LAPACK
-- @SIDE@ argument; confirm at the call sites.
sideSwapFromOrder :: Order -> (a,a) -> (Char, (a,a))
sideSwapFromOrder order (m0,n0) = (side, (m,n))
   where
      ((side,m), (_,n)) = swapOnRowMajor order (('L', m0), ('R', n0))


-- | Apply a shape transformation and check that it keeps the element count.
-- Calls 'error' (prefixed with @name@) when the sizes differ.
mapChecked ::
   (Shape.C sha, Shape.C shb) =>
   String -> (sha -> shb) -> sha -> shb
mapChecked name f sizeA
   | Shape.size sizeA == Shape.size sizeB = sizeB
   | otherwise = error $ name ++ ": sizes mismatch"
   where
      sizeB = f sizeA


-- | Shape of a dense matrix: a storage 'Order' together with an 'Extent'
-- that carries the height and width shapes.
data Full meas vert horiz height width =
   Full {
      -- | Row-major or column-major storage.
      fullOrder :: Order,
      -- | Height and width of the matrix.
      fullExtent :: Extent meas vert horiz height width
   } deriving (Eq, Show)

-- | Deep evaluation forces the storage order and then the extent.
instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    NFData height, NFData width) =>
      NFData (Full meas vert horiz height width) where
   rnf (Full order extent) = rnf order `seq` rnf extent

-- | The element count is that of the @(height, width)@ pair shape and does
-- not depend on the storage order.
instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
      Shape.C (Full meas vert horiz height width) where

   size (Full _order ext) = Shape.size $ Extent.dimensions ext

-- | Index a 'Full' layout by (row, column) pairs.
--
-- For 'RowMajor' the pair shape @(height, width)@ is used as-is. For
-- 'ColumnMajor' both the pair shape and the incoming index are swapped.
-- NOTE(review): this relies on comfort-array's pair shape enumerating its
-- second component fastest; confirm against the library.
instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.Indexed width) =>
      Shape.Indexed (Full meas vert horiz height width) where

   -- An index is a (row, column) pair.
   type Index (Full meas vert horiz height width) =
            (Shape.Index height, Shape.Index width)
   -- Enumeration in storage order, see 'fullIndices'.
   indices (Full order extent) = fullIndices order extent

   unifiedOffset (Full RowMajor extent) =
      Shape.unifiedOffset (Extent.dimensions extent)
   -- Column-major: use the (width, height) pair shape with a swapped index.
   unifiedOffset (Full ColumnMajor extent) =
      Shape.unifiedOffset (swap $ Extent.dimensions extent) . swap

   unifiedSizeOffset (Full RowMajor extent) =
      Shape.unifiedSizeOffset (Extent.dimensions extent)
   -- Same swap as in 'unifiedOffset', applied to the returned offset function.
   unifiedSizeOffset (Full ColumnMajor extent) =
      mapSnd (.swap) $
      Shape.unifiedSizeOffset (swap $ Extent.dimensions extent)

   -- Bounds do not depend on the storage order.
   inBounds (Full _ extent) = Shape.inBounds (Extent.dimensions extent)

-- | Offset-to-index conversion delegates to 'fullIndexFromOffset'.
instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.InvIndexed height, Shape.InvIndexed width) =>
      Shape.InvIndexed (Full meas vert horiz height width) where

   unifiedIndexFromOffset (Full ord ext) = fullIndexFromOffset ord ext


-- | Transpose a 'Full' shape: swap height and width in the extent and flip
-- the storage order.
transpose ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width -> Full meas horiz vert width height
transpose (Full order extent) =
   Full {
      fullOrder = flipOrder order,
      fullExtent = Extent.transpose extent
   }

-- | Swap height and width in the extent while keeping the storage order,
-- unlike 'transpose'.
inverse ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width -> Full meas horiz vert width height
inverse (Full order extent) =
   Full {
      fullOrder = order,
      fullExtent = Extent.transpose extent
   }

-- | Element counts of height and width: @(height, width)@ for 'ColumnMajor'
-- and @(width, height)@ for 'RowMajor'.
dimensions ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Full meas vert horiz height width -> (Int, Int)
dimensions (Full order extent) =
   let rows = Shape.size (Extent.height extent)
       cols = Shape.size (Extent.width extent)
   in swapOnRowMajor order (rows, cols)

-- | Height shape of a 'Full' layout.
fullHeight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width -> height
fullHeight full = Extent.height (fullExtent full)

-- | Width shape of a 'Full' layout.
fullWidth ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width -> width
fullWidth full = Extent.width (fullExtent full)


-- | All (row, column) indices of an extent in storage order.
--
-- 'RowMajor' enumerates the @(height, width)@ pair shape directly.
-- 'ColumnMajor' enumerates the swapped pair shape and swaps each index back.
fullIndices ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed a, Shape.Indexed b) =>
   Order -> Extent meas vert horiz a b -> [(Shape.Index a, Shape.Index b)]
fullIndices RowMajor extent = Shape.indices (Extent.dimensions extent)
fullIndices ColumnMajor extent =
   map swap (Shape.indices (swap (Extent.dimensions extent)))

-- | Map a storage offset back to a (row, column) index. This mirrors the
-- offset computation of the 'Shape.Indexed' instance of 'Full': it uses
-- the pair shape as-is for 'RowMajor' and swapped for 'ColumnMajor'.
fullIndexFromOffset ::
   (Shape.Checking check,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.InvIndexed a, Shape.InvIndexed b) =>
   Order -> Extent meas vert horiz a b -> Int ->
   Shape.Result check (Shape.Index a, Shape.Index b)
fullIndexFromOffset RowMajor extent =
   Shape.unifiedIndexFromOffset (Extent.dimensions extent)
fullIndexFromOffset ColumnMajor extent =
   fmap swap . Shape.unifiedIndexFromOffset (swap (Extent.dimensions extent))


-- | Full matrix whose extent has 'Extent.Big' in both size positions.
type General height width = Full Extent.Size Extent.Big Extent.Big height width
-- | Full matrix with 'Extent.Big' vertical and 'Extent.Small' horizontal
-- parameter; 'tall' builds it only when height >= width (by size).
type Tall height width = Full Extent.Size Extent.Big Extent.Small height width
-- | Full matrix with 'Extent.Small' vertical and 'Extent.Big' horizontal
-- parameter; 'wide' builds it only when height <= width (by size).
type Wide height width = Full Extent.Size Extent.Small Extent.Big height width
-- | Square-sized matrix whose height and width shapes may differ in type.
type LiberalSquare height width = SquareMeas Extent.Size height width
-- | Square matrix with one shape used for both height and width.
type Square size = SquareMeas Extent.Shape size size
-- | Full matrix with 'Extent.Small' in both size positions, parameterized by
-- the measure.
type SquareMeas meas height width =
         Full meas Extent.Small Extent.Small height width


-- | Transform the extent of a 'Full' layout and keep its storage order.
fullMapExtent ::
   Extent.Map measA vertA horizA measB vertB horizB height width ->
   Full measA vertA horizA height width ->
   Full measB vertB horizB height width
fullMapExtent f (Full ord ext) = Full ord (f ext)

-- | General full matrix shape with arbitrary height and width.
general :: Order -> height -> width -> General height width
general order height width =
   Full {fullOrder = order, fullExtent = Extent.general height width}

-- | Tall full matrix shape. Calls 'error' unless the height has at least
-- as many elements as the width.
tall ::
   (Shape.C height, Shape.C width) =>
   Order -> height -> width -> Tall height width
tall order height width
   | Shape.size height >= Shape.size width =
      Full order $ Extent.tall height width
   | otherwise = error "Layout.tall: height smaller than width"

-- | Wide full matrix shape. Calls 'error' unless the height has at most
-- as many elements as the width.
wide ::
   (Shape.C height, Shape.C width) =>
   Order -> height -> width -> Wide height width
wide order height width
   | Shape.size height <= Shape.size width =
      Full order $ Extent.wide height width
   | otherwise = error "Layout.wide: width smaller than height"

-- | Square-sized full matrix shape with possibly different height and width
-- shape types. Calls 'error' unless both have the same element count.
liberalSquare ::
   (Shape.C height, Shape.C width) =>
   Order -> height -> width -> LiberalSquare height width
liberalSquare order height width
   | Shape.size height == Shape.size width =
      Full order $ Extent.liberalSquare height width
   | otherwise = error "Layout.liberalSquare: height and width sizes differ"

-- | Square full matrix shape using one shape for height and width.
square :: Order -> sh -> Square sh
square order = Full order . Extent.square


-- | Classify a full layout as tall (height size >= width size, 'Left') or
-- wide ('Right'), keeping the storage order.
caseTallWide ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Full meas vert horiz height width ->
   Either (Tall height width) (Wide height width)
caseTallWide (Full order extent) =
   case Extent.caseTallWide isTall extent of
      Left tallExtent -> Left (Full order tallExtent)
      Right wideExtent -> Right (Full order wideExtent)
   where
      isTall h w = Shape.size h >= Shape.size w


-- | Shape of a matrix divided into two parts (see 'splitPart'): positions
-- strictly below the diagonal (by offset) carry the @lower@ tag, all others
-- belong to the 'Triangle' part.
data Split lower meas vert horiz height width =
   Split {
      -- | Tag attached to indices of the strictly lower part.
      splitLower :: lower,
      -- | Row-major or column-major storage.
      splitOrder :: Order,
      -- | Height and width of the matrix.
      splitExtent :: Extent meas vert horiz height width
   } deriving (Eq, Show)

-- | Height shape of a 'Split' layout.
splitHeight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Split lower meas vert horiz height width -> height
splitHeight split = Extent.height (splitExtent split)

-- | Width shape of a 'Split' layout.
splitWidth ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Split lower meas vert horiz height width -> width
splitWidth split = Extent.width (splitExtent split)

-- | Transform the extent of a 'Split' layout and keep the tag and order.
splitMapExtent ::
   Extent.Map measA vertA horizA measB vertB horizB height width ->
   Split lower measA vertA horizA height width ->
   Split lower measB vertB horizB height width
splitMapExtent f (Split lw ord ext) = Split lw ord (f ext)


-- | Classify a 'Split' layout as tall (height size >= width size, 'Left')
-- or wide ('Right'), keeping the lower tag and the storage order.
caseTallWideSplit ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Split lower meas vert horiz height width ->
   Either
      (Split lower Extent.Size Extent.Big Extent.Small height width)
      (Split lower Extent.Size Extent.Small Extent.Big height width)
caseTallWideSplit (Split lowerPart order extent) =
   case Extent.caseTallWide isTall extent of
      Left tallExtent -> Left (Split lowerPart order tallExtent)
      Right wideExtent -> Right (Split lowerPart order wideExtent)
   where
      isTall h w = Shape.size h >= Shape.size w

-- | Unit tag, presumably used as the @lower@ tag of a 'Split' layout.
-- NOTE(review): confirm against the users of 'Split'.
data Reflector = Reflector deriving (Eq, Show)
-- | Tag of the non-lower part of a 'Split' layout (see 'splitPart').
data Triangle = Triangle deriving (Eq, Show)

instance NFData Reflector where rnf Reflector = ()
instance NFData Triangle where rnf Triangle = ()

-- | Determine which part of a 'Split' layout an index belongs to.
-- @Left lower@ is returned when the row offset exceeds the column offset
-- (strictly below the diagonal); otherwise @Right Triangle@.
splitPart ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.Indexed width) =>
   Split lower meas vert horiz height width ->
   (Shape.Index height, Shape.Index width) -> Either lower Triangle
splitPart (Split lowerPart _ extent) (r,c)
   | rowOffset > colOffset = Left lowerPart
   | otherwise = Right Triangle
   where
      rowOffset = Shape.offset (Extent.height extent) r
      colOffset = Shape.offset (Extent.width extent) c

-- | Deep evaluation forces the lower tag, the order and the extent in turn.
instance
   (NFData lower, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    NFData height, NFData width) =>
      NFData (Split lower meas vert horiz height width) where
   rnf (Split lowerPart order extent) =
      rnf lowerPart `seq` rnf order `seq` rnf extent

-- | The element count is that of the @(height, width)@ pair shape and does
-- not depend on the tag or the order.
instance
   (Eq lower, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
      Shape.C (Split lower meas vert horiz height width) where

   size (Split _lower _order ext) = Shape.size $ Extent.dimensions ext

-- | Index a 'Split' layout: each (row, column) index is paired with the
-- part that 'splitPart' assigns to it. Offsets are computed as for a 'Full'
-- layout of the same order, and a mismatching part tag is rejected.
instance
   (Eq lower, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.Indexed width) =>
      Shape.Indexed (Split lower meas vert horiz height width) where

   -- The part tag is determined by the position and validated on use.
   type Index (Split lower meas vert horiz height width) =
            (Either lower Triangle,
             (Shape.Index height, Shape.Index width))

   -- Same enumeration order as 'fullIndices' for this order and extent.
   indices sh@(Split _ order extent) =
      map (\ix -> (splitPart sh ix, ix)) $ fullIndices order extent

   -- Assert the part tag first, then compute the offset like 'Full' does
   -- (swapped pair shape and index for column-major).
   unifiedOffset sh@(Split _ order extent) (part,ix) = do
      Shape.assert "Shape.Split.offset: wrong matrix part" $
         part == splitPart sh ix
      case order of
         RowMajor -> Shape.unifiedOffset (Extent.dimensions extent) ix
         ColumnMajor ->
            Shape.unifiedOffset (swap $ Extent.dimensions extent) (swap ix)

   -- Built on the pair shape's 'Shape.unifiedSizeOffset'. The part check
   -- runs after the offset lookup ('check' is bound with '=<<').
   unifiedSizeOffset sh@(Split _ order extent) =
      let check (part,ix) a = do
            Shape.assert "Shape.Split.sizeOffset: wrong matrix part" $
               part == splitPart sh ix
            return a
      in case order of
            RowMajor ->
               mapSnd
                  (\getOffset (part,ix) -> check (part,ix) =<< getOffset ix) $
               Shape.unifiedSizeOffset (Extent.dimensions extent)
            ColumnMajor ->
               mapSnd
                  (\getOffset (part,ix) ->
                     check (part,ix) =<< getOffset (swap ix)) $
               Shape.unifiedSizeOffset (swap $ Extent.dimensions extent)

   -- In bounds only if the position lies inside the matrix and the part
   -- tag is the one 'splitPart' assigns to it.
   inBounds sh@(Split _ _ extent) (part,ix) =
      Shape.inBounds (Extent.dimensions extent) ix
      &&
      part == splitPart sh ix

-- | Decode the position with 'fullIndexFromOffset' and attach the part tag
-- computed by 'splitPart'.
instance
   (Eq lower, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.InvIndexed height, Shape.InvIndexed width) =>
      Shape.InvIndexed (Split lower meas vert horiz height width) where

   unifiedIndexFromOffset sh@(Split _ order extent) k =
      (\ix -> (splitPart sh ix, ix)) <$> fullIndexFromOffset order extent k



-- | Shape of a square matrix described by one triangle.
--
-- The element count is @triangleSize n@ for 'Packed' storage and that of
-- the full 'Square' for 'Unpacked' storage (see the 'Shape.C' instance).
data Mosaic pack mirror uplo size =
   Mosaic {
      -- | Packed or unpacked storage.
      mosaicPack :: PackingSingleton pack,
      -- | Mirroring tag: 'NoMirror' for triangular, 'SimpleMirror' for
      -- symmetric and 'ConjugateMirror' for Hermitian layouts.
      mosaicMirror :: MirrorSingleton mirror,
      -- | Which triangle ('Lower' or 'Upper').
      mosaicUplo :: UpLoSingleton uplo,
      -- | Row-major or column-major storage.
      mosaicOrder :: Order,
      -- | Shape used for both height and width.
      mosaicSize :: size
   } deriving (Eq, Show)

-- | Type-level tag for packed storage (element count 'triangleSize').
data Packed
-- | Type-level tag for unpacked storage (element count of a full square).
data Unpacked

-- | Value-level witness of the packing tag.
data PackingSingleton pack where
   Packed :: PackingSingleton Packed
   Unpacked :: PackingSingleton Unpacked

deriving instance Eq (PackingSingleton pack)
deriving instance Show (PackingSingleton pack)

instance NFData (PackingSingleton pack) where
   rnf Packed = ()
   rnf Unpacked = ()

-- | Obtain the packing witness from the type.
class Packing pack where autoPacking :: PackingSingleton pack
instance Packing Unpacked where autoPacking = Unpacked
instance Packing Packed where autoPacking = Packed

-- | View an unpacked mosaic as a full 'Square' shape with the same order and
-- size. The mirror and triangle tags are dropped.
squareFromMosaic :: Mosaic Unpacked mirror uplo size -> Square size
squareFromMosaic (Mosaic _pack _mirror _uplo order size) = square order size

-- | Turn a 'Square' shape into an unpacked mosaic. The mirror and triangle
-- tags are taken from the requested types.
mosaicFromSquare ::
   (Mirror mirror, UpLo uplo) => Square size -> Mosaic Unpacked mirror uplo size
mosaicFromSquare (Full order extent) =
   Mosaic Unpacked autoMirror autoUplo order (Extent.squareSize extent)


-- | Mirror tag used by triangular layouts.
data NoMirror
-- | Mirror tag used by symmetric layouts.
data SimpleMirror
-- | Mirror tag used by Hermitian layouts.
data ConjugateMirror

-- | Value-level witness of the mirror tag.
data MirrorSingleton mirror where
   NoMirror :: MirrorSingleton NoMirror
   SimpleMirror :: MirrorSingleton SimpleMirror
   ConjugateMirror :: MirrorSingleton ConjugateMirror

deriving instance Eq (MirrorSingleton mirror)
deriving instance Show (MirrorSingleton mirror)

instance NFData (MirrorSingleton mirror) where
   rnf NoMirror = ()
   rnf SimpleMirror = ()
   rnf ConjugateMirror = ()

-- | Obtain the mirror witness from the type.
class Mirror mirror where autoMirror :: MirrorSingleton mirror
instance Mirror NoMirror where autoMirror = NoMirror
instance Mirror SimpleMirror where autoMirror = SimpleMirror
instance Mirror ConjugateMirror where autoMirror = ConjugateMirror


-- | Triangular layout ('Mosaic' without mirroring) with selectable packing.
type TriangularP pack = Mosaic pack NoMirror
-- | Packed triangular layout.
type Triangular = TriangularP Packed

-- | Lower triangular layout with selectable packing.
type LowerTriangularP pack = TriangularP pack Shape.Lower
-- | Packed lower triangular layout.
type LowerTriangular = Triangular Shape.Lower

-- | Upper triangular layout with selectable packing.
type UpperTriangularP pack = TriangularP pack Shape.Upper
-- | Packed upper triangular layout.
type UpperTriangular = Triangular Shape.Upper

-- | Packed triangular shape for the given triangle, order and size.
triangular :: UpLoSingleton uplo -> Order -> size -> Triangular uplo size
triangular uplo order size = Mosaic Packed NoMirror uplo order size

-- | Packed upper triangular shape.
upperTriangular :: Order -> size -> UpperTriangular size
upperTriangular order size = triangular Upper order size

-- | Packed lower triangular shape.
lowerTriangular :: Order -> size -> LowerTriangular size
lowerTriangular order size = triangular Lower order size


-- | Triangular shape with explicit packing, triangle, order and size.
triangularP ::
   PackingSingleton pack ->
   UpLoSingleton uplo -> Order -> size -> TriangularP pack uplo size
triangularP pack uplo order size = Mosaic pack NoMirror uplo order size

-- | Upper triangular shape with explicit packing.
upperTriangularP ::
   PackingSingleton pack -> Order -> size -> UpperTriangularP pack size
upperTriangularP pack order size = triangularP pack Upper order size

-- | Lower triangular shape with explicit packing.
lowerTriangularP ::
   PackingSingleton pack -> Order -> size -> LowerTriangularP pack size
lowerTriangularP pack order size = triangularP pack Lower order size


-- | Symmetric layout: a 'Mosaic' with 'SimpleMirror' and the 'Shape.Upper'
-- triangle, with selectable packing.
type SymmetricP pack = Mosaic pack SimpleMirror Shape.Upper
-- | Packed symmetric layout.
type Symmetric = SymmetricP Packed

-- | Packed symmetric shape.
symmetric :: Order -> size -> Symmetric size
symmetric order size = symmetricP Packed order size

-- | Symmetric shape with explicit packing (always the 'Upper' triangle).
symmetricP :: PackingSingleton pack -> Order -> size -> SymmetricP pack size
symmetricP pack order size = Mosaic pack SimpleMirror Upper order size

-- | Retag a Hermitian layout as symmetric. Packing, triangle, order and
-- size are kept.
symmetricFromHermitian :: HermitianP pack size -> SymmetricP pack size
symmetricFromHermitian (Mosaic packing ConjugateMirror uplo ord sh) =
   Mosaic {
      mosaicPack = packing,
      mosaicMirror = SimpleMirror,
      mosaicUplo = uplo,
      mosaicOrder = ord,
      mosaicSize = sh
   }


-- | Hermitian layout: a 'Mosaic' with 'ConjugateMirror' and the
-- 'Shape.Upper' triangle, with selectable packing.
type HermitianP pack = Mosaic pack ConjugateMirror Shape.Upper
-- | Packed Hermitian layout.
type Hermitian = HermitianP Packed

-- | Packed Hermitian shape.
hermitian :: Order -> size -> Hermitian size
hermitian order size = hermitianP Packed order size

-- | Hermitian shape with explicit packing (always the 'Upper' triangle).
hermitianP :: PackingSingleton pack -> Order -> size -> HermitianP pack size
hermitianP pack order size = Mosaic pack ConjugateMirror Upper order size

-- | Retag a symmetric layout as Hermitian. Packing, triangle, order and
-- size are kept.
hermitianFromSymmetric :: SymmetricP pack size -> HermitianP pack size
hermitianFromSymmetric (Mosaic packing SimpleMirror uplo ord sh) =
   Mosaic {
      mosaicPack = packing,
      mosaicMirror = ConjugateMirror,
      mosaicUplo = uplo,
      mosaicOrder = ord,
      mosaicSize = sh
   }

-- | Triangle character for an order: @'L'@ for 'RowMajor', @'U'@ for
-- 'ColumnMajor'.
uploFromOrder :: Order -> Char
uploFromOrder order =
   case order of
      RowMajor -> 'L'
      ColumnMajor -> 'U'



-- | Strip descriptor carrying a type-level count of off-diagonals.
newtype Bands offDiag = Bands (UnaryProxy offDiag) deriving (Eq, Show)

-- Extracts the off-diagonal count from a 'Bands' strip type.
type family GetBands strip
type instance GetBands (Bands offDiag) = offDiag

-- | Strip with zero off-diagonals.
type Empty = Bands TypeNum.U0
-- | Alternative strip tag to 'Bands'. NOTE(review): its exact meaning is
-- not visible in this module; confirm with its users.
data Filled = Filled deriving (Eq, Show)

-- | Proxy for the unary type-level number zero.
u0 :: UnaryProxy TypeNum.U0
u0 = unary TypeNum.u0

-- | The 'Empty' strip value.
empty :: Empty
empty = Bands u0

-- | Triangle tag after transposition: 'Shape.Lower' and 'Shape.Upper' are
-- exchanged.
type family TriTransposed uplo
type instance TriTransposed Shape.Lower = Shape.Upper
type instance TriTransposed Shape.Upper = Shape.Lower

-- | Transpose a mosaic layout: exchange the triangle tag and flip the order.
-- Packing, mirror and size are kept.
triangularTranspose ::
   (UpLo uplo) =>
   Mosaic pack mirror uplo sh ->
   Mosaic pack mirror (TriTransposed uplo) sh
triangularTranspose (Mosaic pack mirror uplo order size) =
   Mosaic pack mirror (flipUplo uplo) (flipOrder order) size
   where
      flipUplo :: UpLoSingleton u -> UpLoSingleton (TriTransposed u)
      flipUplo Lower = Upper
      flipUplo Upper = Lower


-- | The 'UpLoSingleton' witness matching the type-level tag @uplo@.
autoUplo :: (UpLo uplo) => UpLoSingleton uplo
autoUplo = switchUpLo Upper Lower

-- | Flip the order for 'Lower' triangles; keep it for 'Upper' ones.
uploOrder :: UpLoSingleton uplo -> Order -> Order
uploOrder Lower = flipOrder
uploOrder Upper = id


-- | Type-level triangle tags. 'switchUpLo' selects its first argument for
-- 'Shape.Upper' and its second for 'Shape.Lower'.
class UpLo uplo where
   switchUpLo :: f Shape.Upper -> f Shape.Lower -> f uplo

instance UpLo Shape.Upper where
   switchUpLo f _ = f

instance UpLo Shape.Lower where
   switchUpLo _ f = f

-- | Value-level witness of the type-level triangle tag
-- ('Shape.Lower' or 'Shape.Upper').
data UpLoSingleton uplo where
   Lower :: UpLoSingleton Shape.Lower
   Upper :: UpLoSingleton Shape.Upper

-- Derived the same way as for 'PackingSingleton' and 'MirrorSingleton'.
-- The derived 'Show' renders the constructor names "Lower" and "Upper".
deriving instance Eq (UpLoSingleton uplo)
deriving instance Show (UpLoSingleton uplo)

instance NFData (UpLoSingleton uplo) where
   rnf Lower = ()
   rnf Upper = ()

-- | Triangle character: @'L'@ for 'Lower', @'U'@ for 'Upper'.
uploChar :: UpLoSingleton uplo -> Char
uploChar uplo =
   case uplo of
      Lower -> 'L'
      Upper -> 'U'


-- | Deep evaluation forces every field from packing to size in turn.
instance
   (UpLo uplo, NFData size) =>
      NFData (Mosaic pack mirror uplo size) where
   rnf (Mosaic pack mirror uplo order size) =
      rnf pack `seq` rnf mirror `seq` rnf uplo `seq` rnf order `seq` rnf size

-- | Packed storage holds @triangleSize n@ elements. Unpacked storage holds
-- as many as the full 'Square' of the same size.
instance
   (UpLo uplo, Shape.C size) =>
      Shape.C (Mosaic pack mirror uplo size) where

   size (Mosaic Packed _ _ _ sh) = triangleSize (Shape.size sh)
   size (Mosaic Unpacked _ _ order sh) = Shape.size (square order sh)

-- | Index a 'Mosaic' layout by (row, column) pairs.
--
-- Unpacked layouts index like a full 'Square'. A packed upper triangle
-- uses the 'triangleIndices' / 'triangleOffset' helpers directly. A packed
-- lower triangle is handled as the upper triangle of the transposed
-- matrix: the order is flipped and the index is swapped.
instance
   (UpLo uplo, Shape.Indexed size) =>
      Shape.Indexed (Mosaic pack mirror uplo size) where
   type Index (Mosaic pack mirror uplo size) =
         (Shape.Index size, Shape.Index size)

   indices (Mosaic pack _mirror uplo order size) =
      case (pack,uplo) of
         (Unpacked,_) -> Shape.indices $ square order size
         (Packed,Upper) -> triangleIndices order size
         (Packed,Lower) -> map swap $ triangleIndices (flipOrder order) size

   unifiedOffset (Mosaic pack _mirror uplo order size) =
      case (pack,uplo) of
         (Unpacked,_) -> Shape.unifiedOffset $ square order size
         (Packed,Upper) -> triangleOffset order size
         (Packed,Lower) -> triangleOffset (flipOrder order) size . swap

   unifiedSizeOffset (Mosaic pack _mirror uplo order size) =
      case (pack,uplo) of
         (Unpacked,_) -> Shape.unifiedSizeOffset $ square order size
         (Packed,Upper) -> triangleSizeOffset order size
         (Packed,Lower) ->
            mapSnd (.swap) $ triangleSizeOffset (flipOrder order) size

   -- Packed layouts additionally require the index to lie in the stored
   -- triangle: row offset <= column offset (upper) or >= (lower).
   inBounds (Mosaic pack _mirror uplo _ size) ix@(r,c) =
      Shape.inBounds (size,size) ix
      &&
      case (pack,uplo) of
         (Unpacked,_) -> True
         (Packed,Upper) -> Shape.offset size r <= Shape.offset size c
         (Packed,Lower) -> Shape.offset size r >= Shape.offset size c

-- | Inverse of the offset computation of the 'Shape.Indexed' instance. A
-- packed lower triangle is decoded via the transposed upper triangle.
instance
   (UpLo uplo, Shape.InvIndexed size) =>
      Shape.InvIndexed (Mosaic pack mirror uplo size) where

   unifiedIndexFromOffset (Mosaic Unpacked _ _ order sh) k =
      Shape.unifiedIndexFromOffset (square order sh) k
   unifiedIndexFromOffset (Mosaic Packed _ Upper order sh) k =
      triangleIndexFromOffset order sh k
   unifiedIndexFromOffset (Mosaic Packed _ Lower order sh) k =
      swap <$> triangleIndexFromOffset (flipOrder order) sh k


-- | Floating-point square root of an integer.
squareRootDouble :: Int -> Double
squareRootDouble n = sqrt (fromIntegral n)

-- | Side length of a square with @size@ elements. Calls 'error' (prefixed
-- with @name@) when @size@ is not a perfect square.
squareExtent :: String -> Int -> Int
squareExtent name size
   | root*root == size = root
   | otherwise = error (name ++ ": no square number of elements")
   where
      root = round $ squareRootDouble size


-- | Floating-point inverse of 'triangleSize', computed with 'triangleRoot'.
triangleRootDouble :: Int -> Double
triangleRootDouble n = triangleRoot (fromIntegral n)

-- | Side length @n@ with @triangleSize n == size@. Calls 'error' (prefixed
-- with @name@) when @size@ is not a triangular number.
triangleExtent :: String -> Int -> Int
triangleExtent name size
   | triangleSize root == size = root
   | otherwise = error (name ++ ": no triangular number of elements")
   where
      root = round $ triangleRootDouble size

-- | Indices of the upper triangle in storage order. 'RowMajor' enumerates
-- 'Shape.upperTriangular'. 'ColumnMajor' enumerates
-- 'Shape.lowerTriangular' and swaps each pair.
triangleIndices ::
   (Shape.Indexed sh) => Order -> sh -> [(Shape.Index sh, Shape.Index sh)]
triangleIndices order =
   case order of
      RowMajor -> Shape.indices . Shape.upperTriangular
      ColumnMajor -> map swap . Shape.indices . Shape.lowerTriangular

-- | Offset of an upper-triangle index. Column-major storage is treated as
-- the row-major lower triangle of the transposed index.
triangleOffset ::
   (Shape.Checking check, Shape.Indexed sh) =>
   Order -> sh -> (Shape.Index sh, Shape.Index sh) -> Shape.Result check Int
triangleOffset RowMajor size =
   Shape.unifiedOffset (Shape.upperTriangular size)
triangleOffset ColumnMajor size =
   Shape.unifiedOffset (Shape.lowerTriangular size) . swap

-- | Element count and offset function of an upper triangle, with the same
-- column-major treatment as 'triangleOffset'.
triangleSizeOffset ::
   (Shape.Checking check, Shape.Indexed sh) =>
   Order -> sh ->
   (Int, (Shape.Index sh, Shape.Index sh) -> Shape.Result check Int)
triangleSizeOffset RowMajor size =
   Shape.unifiedSizeOffset (Shape.upperTriangular size)
triangleSizeOffset ColumnMajor size =
   let (count, getOffset) =
          Shape.unifiedSizeOffset (Shape.lowerTriangular size)
   in (count, getOffset . swap)

-- | Inverse of 'triangleOffset'.
triangleIndexFromOffset ::
   (Shape.Checking check, Shape.InvIndexed sh) =>
   Order -> sh -> Int -> Shape.Result check (Shape.Index sh, Shape.Index sh)
triangleIndexFromOffset RowMajor size =
   Shape.unifiedIndexFromOffset (Shape.upperTriangular size)
triangleIndexFromOffset ColumnMajor size =
   fmap swap . Shape.unifiedIndexFromOffset (Shape.lowerTriangular size)


-- | Proxy for a unary type-level natural number.
type UnaryProxy a = Proxy (Unary.Un a)

-- | Shape of a banded matrix with @sub@ sub-diagonals and @super@
-- super-diagonals (type-level counts).
data Banded sub super meas vert horiz height width =
   Banded {
      -- | Proxies for the numbers of sub- and super-diagonals.
      bandedOffDiagonals :: (UnaryProxy sub, UnaryProxy super),
      -- | Row-major or column-major band storage.
      bandedOrder :: Order,
      -- | Height and width of the matrix.
      bandedExtent :: Extent meas vert horiz height width
   } deriving (Eq, Show)

-- | Banded matrix with 'Extent.Big' in both size positions.
type BandedGeneral sub super =
      Banded sub super Extent.Size Extent.Big Extent.Big
-- | Banded matrix with 'Extent.Small' in both size positions,
-- parameterized by the measure.
type BandedSquareMeas sub super meas height width =
      Banded sub super meas Extent.Small Extent.Small height width
-- | Square banded matrix with one shape for height and width.
type BandedSquare sub super size =
      BandedSquareMeas sub super Extent.Shape size size

-- | Square band without super-diagonals.
type BandedLowerTriangular sub size = BandedSquare sub TypeNum.U0 size
-- | Square band without sub-diagonals.
type BandedUpperTriangular super size = BandedSquare TypeNum.U0 super size

-- | Square band with neither sub- nor super-diagonals.
type Diagonal size = BandedSquare TypeNum.U0 TypeNum.U0 size
-- | Band with neither sub- nor super-diagonals and arbitrary extent.
type RectangularDiagonal = Banded TypeNum.U0 TypeNum.U0


-- | Height shape of a 'Banded' layout.
bandedHeight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Banded sub super meas vert horiz height width -> height
bandedHeight banded = Extent.height (bandedExtent banded)

-- | Width shape of a 'Banded' layout.
bandedWidth ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Banded sub super meas vert horiz height width -> width
bandedWidth banded = Extent.width (bandedExtent banded)

-- | Transform the extent of a 'Banded' layout and keep the off-diagonal
-- counts and the order.
bandedMapExtent ::
   Extent.Map measA vertA horizA measB vertB horizB height width ->
   Banded sub super measA vertA horizA height width ->
   Banded sub super measB vertB horizB height width
bandedMapExtent f (Banded offs ord ext) = Banded offs ord (f ext)

-- | Number of stored cells per band line: @sub + 1 + super@.
bandedBreadth ::
   (Unary.Natural sub, Unary.Natural super) =>
   (UnaryProxy sub, UnaryProxy super) -> Int
bandedBreadth (subProxy, superProxy) =
   let below = integralFromProxy subProxy
       above = integralFromProxy superProxy
   in below + 1 + above

-- | Sub- and super-diagonal counts as @Int@s, swapped for 'RowMajor'.
numOffDiagonals ::
   (Unary.Natural sub, Unary.Natural super) =>
   Order -> (UnaryProxy sub, UnaryProxy super) -> (Int,Int)
numOffDiagonals order (sub,super) =
   let kl = integralFromProxy sub
       ku = integralFromProxy super
   in swapOnRowMajor order (kl, ku)

-- | Turn a proxy for a unary natural into a 'Proof.Nat' witness.
natFromProxy :: (Unary.Natural n) => UnaryProxy n -> Proof.Nat n
natFromProxy Proxy = Proof.Nat

-- | Add two (sub, super) off-diagonal counts component-wise at the type
-- level.
--
-- Returns the 'Proof.addNat' witnesses that the sums are naturals,
-- together with proxies for the summed counts.
addOffDiagonals ::
   (Unary.Natural subA, Unary.Natural superA,
    Unary.Natural subB, Unary.Natural superB,
    (subA :+: subB) ~ subC,
    (superA :+: superB) ~ superC) =>
   (UnaryProxy subA, UnaryProxy superA) ->
   (UnaryProxy subB, UnaryProxy superB) ->
   ((Proof.Nat subC, Proof.Nat superC),
    (UnaryProxy subC, UnaryProxy superC))
addOffDiagonals (subA,superA) (subB,superB) =
   ((Proof.addNat (natFromProxy subA) (natFromProxy subB),
     Proof.addNat (natFromProxy superA) (natFromProxy superB)),
    (Proxy,Proxy))

-- | Transpose a banded layout: exchange sub- and super-diagonal counts,
-- flip the order and transpose the extent.
bandedTranspose ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Banded sub super meas vert horiz height width ->
   Banded super sub meas horiz vert width height
bandedTranspose (Banded (subProxy,superProxy) order extent) =
   Banded {
      bandedOffDiagonals = (superProxy,subProxy),
      bandedOrder = flipOrder order,
      bandedExtent = Extent.transpose extent
   }

-- | Swap height and width of a diagonal layout while keeping the order.
diagonalInverse ::
   (Extent.Measure meas) =>
   BandedSquareMeas TypeNum.U0 TypeNum.U0 meas height width ->
   BandedSquareMeas TypeNum.U0 TypeNum.U0 meas width height
diagonalInverse (Banded (subProxy,superProxy) order extent) =
   Banded {
      bandedOffDiagonals = (superProxy,subProxy),
      bandedOrder = order,
      bandedExtent = Extent.transpose extent
   }


-- | General banded shape with arbitrary height and width.
bandedGeneral ::
   (UnaryProxy sub, UnaryProxy super) -> Order -> height -> width ->
   BandedGeneral sub super height width
bandedGeneral offDiag order height width =
   Banded {
      bandedOffDiagonals = offDiag,
      bandedOrder = order,
      bandedExtent = Extent.general height width
   }

-- | Square banded shape with one shape for height and width.
bandedSquare ::
   (UnaryProxy sub, UnaryProxy super) -> Order -> size ->
   BandedSquare sub super size
bandedSquare offDiag order size = Banded offDiag order (Extent.square size)

-- | Diagonal layout for an arbitrary extent, together with the diagonal
-- length @min m n@.
--
-- The order is 'RowMajor' when the height size does not exceed the width
-- size and 'ColumnMajor' otherwise. Either way the element count
-- ('bandedBreadth' 1 times the shorter side) equals the diagonal length.
rectangularDiagonal ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width) =>
   Extent meas vert horiz height width ->
   (Int, RectangularDiagonal meas vert horiz height width)
rectangularDiagonal extent = (min m n, Banded (u0,u0) order extent)
   where
      m = Shape.size $ Extent.height extent
      n = Shape.size $ Extent.width extent
      order
         | m <= n = RowMajor
         | otherwise = ColumnMajor


-- | Index into band storage.
--
-- 'InsideBox' addresses a cell inside the matrix. 'VertOutsideBox' and
-- 'HorizOutsideBox' address padding cells whose row (resp. column) lies
-- outside the matrix. The 'Int' is negative for positions before the first
-- row/column, while @k >= 0@ denotes position @size + k@ (see
-- 'outsideOffset').
data BandedIndex row column =
     InsideBox row column
   | VertOutsideBox Int column
   | HorizOutsideBox row Int
   deriving (Eq, Show)

-- | Deep evaluation matches the two proxies, then forces order and extent.
instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    NFData height, NFData width) =>
      NFData (Banded sub super meas vert horiz height width) where
   rnf (Banded (Proxy,Proxy) order extent) = rnf order `seq` rnf extent

-- | Band storage holds 'bandedBreadth' cells per line. There is one line
-- per row for 'RowMajor' and one per column for 'ColumnMajor'.
instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
      Shape.C (Banded sub super meas vert horiz height width) where

   size (Banded offDiag order extent) =
      let lineCount =
             case order of
                RowMajor -> Shape.size (Extent.height extent)
                ColumnMajor -> Shape.size (Extent.width extent)
      in bandedBreadth offDiag * lineCount

-- | Index band storage by 'BandedIndex'.
--
-- Every stored cell has an index, including padding cells outside the
-- matrix: those become 'HorizOutsideBox' in row-major and 'VertOutsideBox'
-- in column-major layouts.
instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.Indexed height, Shape.Indexed width) =>
      Shape.Indexed (Banded sub super meas vert horiz height width) where

   type Index (Banded sub super meas vert horiz height width) =
            BandedIndex (Shape.Index height) (Shape.Index width)
   -- Column-major enumeration reuses the row-major enumerator on the
   -- transposed band (super, sub) over (width, height), then swaps back.
   indices (Banded (sub,super) order extent) =
      let (height,width) = Extent.dimensions extent
      in case order of
            RowMajor ->
               map (\(r,c) -> either (HorizOutsideBox r) (InsideBox r) c) $
               bandedIndicesRowMajor (sub,super) (height,width)
            ColumnMajor ->
               map (\(c,r) ->
                     either (flip VertOutsideBox c) (flip InsideBox c) r) $
               bandedIndicesRowMajor (super,sub) (width,height)

   -- Checks 'inBounds' first, then computes the offset with 'bandedOffset'.
   unifiedOffset shape@(Banded (sub,super) order extent) ix = do
      Shape.assert "Banded.offset: index outside band" $
         Shape.inBounds shape ix
      let (height,width) = Extent.dimensions extent
          kl = integralFromProxy sub
          ku = integralFromProxy super
      return $ bandedOffset (kl,ku) order (height,width) ix

   -- A cell is in bounds if its column-minus-row offset lies in
   -- [-kl, ku]. Outside-box indices are only valid for the matching order
   -- (horizontal for row-major, vertical for column-major).
   inBounds (Banded (sub,super) order extent) ix =
      let (height,width) = Extent.dimensions extent
          kl = integralFromProxy sub
          ku = integralFromProxy super
          insideBand r c = Shape.inBounds (Shape.Range (-kl) ku) (c-r)
      in case (order,ix) of
            (_, InsideBox r c) ->
               Shape.inBounds (height,width) (r,c)
               &&
               insideBand (Shape.offset height r) (Shape.offset width c)
            (RowMajor, HorizOutsideBox r c) ->
               Shape.inBounds height r
               &&
               insideBand (Shape.offset height r) (outsideOffset width c)
            (ColumnMajor, VertOutsideBox r c) ->
               Shape.inBounds width c
               &&
               insideBand (outsideOffset height r) (Shape.offset width c)
            _ -> False

-- | Offset-to-index conversion delegates to 'bandedIndexFromOffset' with
-- the off-diagonal counts converted to @Int@s.
instance
   (Unary.Natural sub, Unary.Natural super,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.InvIndexed height, Shape.InvIndexed width) =>
      Shape.InvIndexed (Banded sub super meas vert horiz height width) where

   unifiedIndexFromOffset (Banded (subProxy,superProxy) order extent) j =
      let offDiags = (integralFromProxy subProxy, integralFromProxy superProxy)
      in bandedIndexFromOffset offDiags order (Extent.dimensions extent) j

-- | Absolute position of an outside-box coordinate. Negative values are
-- taken as they are; @k >= 0@ denotes @Shape.size size + k@.
outsideOffset :: Shape.C sh => sh -> Int -> Int
outsideOffset size k
   | k < 0 = k
   | otherwise = Shape.size size + k

-- | Storage offset of a 'BandedIndex' for a band with @kl@ sub- and @ku@
-- super-diagonals.
--
-- A row-major line holds @kl+1+ku@ cells, so row @i@, column @j@ lives at
-- @(kl+ku)*i + kl + j@. Column-major is symmetric: @(kl+ku)*j + ku + i@.
-- The bounds of the index are not checked here.
bandedOffset ::
   (Shape.Indexed height, Shape.Indexed width) =>
   (Int, Int) -> Order -> (height, width) ->
   BandedIndex (Shape.Index height) (Shape.Index width) -> Int
bandedOffset (kl,ku) order (height,width) ix =
   case ix of
      InsideBox r c ->
         let i = Shape.uncheckedOffset height r
             j = Shape.uncheckedOffset width c
         in case order of
               RowMajor -> rowMajorOffset i j
               ColumnMajor -> columnMajorOffset i j
      VertOutsideBox r c ->
         columnMajorOffset
            (outsideOffset height r) (Shape.uncheckedOffset width c)
      HorizOutsideBox r c ->
         rowMajorOffset
            (Shape.uncheckedOffset height r) (outsideOffset width c)
   where
      k = kl+ku
      rowMajorOffset i j = k*i + kl + j
      columnMajorOffset i j = k*j + ku + i

-- | All band cells in row-major storage order, as (row, column) pairs.
--
-- The column line is padded with @kl@ negative positions on the left and an
-- unbounded run of non-negative positions on the right ('Left' values).
-- Row number @t@ takes the @kl+1+ku@ cells starting at padded position @t@.
bandedIndicesRowMajor ::
   (Unary.Natural sub, Unary.Natural super,
    Shape.Indexed height, Shape.Indexed width) =>
   (UnaryProxy sub, UnaryProxy super) ->
   (height, width) ->
   [(Shape.Index height, Either Int (Shape.Index width))]
bandedIndicesRowMajor (sub,super) (height,width) =
   let kl = integralFromProxy sub
       ku = integralFromProxy super
       paddedColumns =
          map Left [-kl .. -1] ++
          map Right (Shape.indices width) ++
          map Left [0 ..]
       bandWindows = map (take (kl+1+ku)) (tails paddedColumns)
       labelRow r = map (\col -> (r, col))
   in concat (zipWith labelRow (Shape.indices height) bandWindows)

-- | Inverse of 'bandedOffset': map a storage offset back to a 'BandedIndex'
-- for a band with @kl@ sub- and @ku@ super-diagonals.
--
-- Each storage line (a row for 'RowMajor', a column for 'ColumnMajor') has
-- @kl+1+ku@ cells. Cells whose column (resp. row) falls outside the matrix
-- become 'HorizOutsideBox' (resp. 'VertOutsideBox'). The position is kept
-- when negative and otherwise reduced by the matrix size.
--
-- The line breadth and both dimension sizes are bound once, outside the
-- returned function, so a partial application shares them across offsets
-- in both branches.
bandedIndexFromOffset ::
   (Shape.Checking check, Shape.InvIndexed height, Shape.InvIndexed width) =>
   (Int,Int) -> Order -> (height,width) -> Int ->
   Shape.Result check (BandedIndex (Shape.Index height) (Shape.Index width))
bandedIndexFromOffset (kl,ku) order (height,width) =
   let breadth = kl+1+ku
       m = Shape.size height
       n = Shape.size width
   in case order of
         RowMajor -> \j -> do
            let (rb,cb) = divMod j breadth
            r <- Shape.unifiedIndexFromOffset height rb
            let ci = rb+cb-kl
            if' (ci<0) (return $ HorizOutsideBox r ci) $
               if' (ci>=n) (return $ HorizOutsideBox r (ci-n)) $
               (InsideBox r <$> Shape.unifiedIndexFromOffset width ci)
         ColumnMajor -> \j -> do
            let (cb,rb) = divMod j breadth
            c <- Shape.unifiedIndexFromOffset width cb
            let ri = rb+cb-ku
            if' (ri<0) (return $ VertOutsideBox ri c) $
               if' (ri>=m) (return $ VertOutsideBox (ri-m) c) $
               (flip InsideBox c <$> Shape.unifiedIndexFromOffset height ri)


-- | Shape of a banded Hermitian matrix.
--
-- Only the diagonal and the @off@ diagonals above it are stored, so each
-- storage line holds @off+1@ cells (see the 'Shape.C' instance).
data BandedHermitian off size =
   BandedHermitian {
      -- | Proxy for the number of stored super-diagonals.
      bandedHermitianOffDiagonals :: UnaryProxy off,
      -- | Row-major or column-major band storage.
      bandedHermitianOrder :: Order,
      -- | Shape used for both height and width.
      bandedHermitianSize :: size
   } deriving (Eq, Show)

-- | Deep evaluation matches the proxy, then forces order and size.
instance (Unary.Natural off, NFData size) =>
      NFData (BandedHermitian off size) where
   rnf (BandedHermitian Proxy order size) = rnf order `seq` rnf size

-- | @off+1@ stored cells for each of the @n@ lines.
instance (Unary.Natural off, Shape.C size) =>
      Shape.C (BandedHermitian off size) where

   size (BandedHermitian offDiag _order sh) =
      Shape.size sh * (integralFromProxy offDiag + 1)

-- | Index the upper band of a 'BandedHermitian' layout by 'BandedIndex'.
--
-- Padding cells beyond the matrix become 'HorizOutsideBox' (row-major) or
-- 'VertOutsideBox' (column-major).
instance (Unary.Natural off, Shape.Indexed size) =>
      Shape.Indexed (BandedHermitian off size) where
   type Index (BandedHermitian off size) =
            BandedIndex (Shape.Index size) (Shape.Index size)
   -- Row-major: a band with no sub- and @off@ super-diagonals. Column-major:
   -- the transposed band (@off@ sub-diagonals), with each pair swapped back.
   indices (BandedHermitian offDiag order size) =
      case order of
         RowMajor ->
            map (\(r,c) -> either (HorizOutsideBox r) (InsideBox r) c) $
            bandedIndicesRowMajor (u0, offDiag) (size,size)
         ColumnMajor ->
            map (\(c,r) ->
                  either (flip VertOutsideBox c) (flip InsideBox c) r) $
            bandedIndicesRowMajor (offDiag, u0) (size,size)

   -- Checks 'inBounds' first, then uses 'bandedOffset' with (0, off).
   unifiedOffset shape@(BandedHermitian offDiag order size) ix = do
      Shape.assert "BandedHermitian.offset: index outside band" $
         Shape.inBounds shape ix
      let k = integralFromProxy offDiag
      return $ bandedOffset (0,k) order (size,size) ix

   -- A cell is in bounds if its column-minus-row offset lies in [0, off].
   -- Outside-box indices are only valid for the matching order.
   inBounds (BandedHermitian offDiag order size) ix =
      let ku = integralFromProxy offDiag
          insideBand r c = Shape.inBounds (Shape.Range 0 ku) (c-r)
      in case (order,ix) of
            (_, InsideBox r c) ->
               Shape.inBounds (size,size) (r,c)
               &&
               insideBand (Shape.offset size r) (Shape.offset size c)
            (RowMajor, HorizOutsideBox r c) ->
               Shape.inBounds size r
               &&
               insideBand (Shape.offset size r) (outsideOffset size c)
            (ColumnMajor, VertOutsideBox r c) ->
               Shape.inBounds size c
               &&
               insideBand (outsideOffset size r) (Shape.offset size c)
            _ -> False

-- | Offset-to-index conversion delegates to
-- 'bandedHermitianIndexFromOffset'.
instance (Unary.Natural off, Shape.InvIndexed size) =>
      Shape.InvIndexed (BandedHermitian off size) where

   unifiedIndexFromOffset (BandedHermitian offDiag order sh) j =
      let k = integralFromProxy offDiag
      in bandedHermitianIndexFromOffset k order sh j

-- | Inverse of 'bandedOffset' with @(0,k)@ off-diagonals for the
-- 'BandedHermitian' layout. Each storage line has @k+1@ cells.
--
-- * 'RowMajor': cell @cb@ of row @rb@ refers to column @rb+cb@. Columns
--   past the matrix yield 'HorizOutsideBox' with the excess.
--
-- * 'ColumnMajor': cell @rb@ of column @cb@ refers to row @rb+cb-k@.
--   Negative rows yield 'VertOutsideBox'.
bandedHermitianIndexFromOffset ::
   (Shape.Checking check, Shape.InvIndexed sh, Shape.Index sh ~ ix) =>
   Int -> Order -> sh -> Int -> Shape.Result check (BandedIndex ix ix)
bandedHermitianIndexFromOffset k order size =
   case order of
      RowMajor ->
         let n = Shape.size size
         in \j -> do
               let (rb,cb) = divMod j (k+1)
                   ci = rb+cb
               r <- Shape.unifiedIndexFromOffset size rb
               if ci >= n
                  then return $ HorizOutsideBox r (ci-n)
                  else InsideBox r <$> Shape.unifiedIndexFromOffset size ci
      ColumnMajor -> \j -> do
         let (cb,rb) = divMod j (k+1)
             ri = rb+cb-k
         c <- Shape.unifiedIndexFromOffset size cb
         if ri < 0
            then return $ VertOutsideBox ri c
            else flip InsideBox c <$> Shape.unifiedIndexFromOffset size ri