{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Extent.Private where

import Numeric.LAPACK.Shape.Private (Unchecked(deconsUnchecked))
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))

import Data.Array.Comfort.Shape ((:+:)((:+:)))

import Control.DeepSeq (NFData, rnf)

import Data.Maybe.HT (toMaybe)
import Data.Eq.HT (equating)


-- | Matrix extent, indexed by a vertical and a horizontal tag
-- ('Small' or 'Big') and by the types of height and width.
-- The pair of tags selects the representation; see the data instances.
data family Extent vertical horizontal :: * -> * -> *

-- | Forces the extent constructor and every dimension stored in it.
instance
   (C vertical, C horizontal, NFData height, NFData width) =>
      NFData (Extent vertical horizontal height width) where
   rnf =
      getAccessor
         (switchTagPair
            (Accessor (\(Square size) -> rnf size))
            (Accessor (\(Wide rows cols) -> rnf (rows,cols)))
            (Accessor (\(Tall rows cols) -> rnf (rows,cols)))
            (Accessor (\(General rows cols) -> rnf (rows,cols))))


-- | Tag for the vertical or horizontal index of 'Extent';
-- for instance @Extent Big Big@ is the 'General' extent.
data Big = Big
   deriving (Eq,Show)

instance NFData Big where
   rnf Big = ()

-- | Counterpart of 'Big';
-- for instance @Extent Small Small@ is the square extent.
data Small = Small
   deriving (Eq,Show)

instance NFData Small where
   rnf Small = ()

-- | Extent with both tags 'Big'.
type General = Extent Big Big
-- | Extent with vertical tag 'Big' and horizontal tag 'Small'.
type Tall = Extent Big Small
-- | Extent with vertical tag 'Small' and horizontal tag 'Big'.
type Wide = Extent Small Big
-- | Extent with both tags 'Small'; height and width share the type @sh@.
type Square sh = Extent Small Small sh sh


-- | Representation for the tag pair @Big Big@:
-- height and width are stored independently.
data instance Extent Big Big height width =
   General {
      generalHeight :: height,
      generalWidth :: width
   }

-- | Representation for the tag pair @Big Small@:
-- height and width are stored independently.
data instance Extent Big Small height width =
   Tall {
      tallHeight :: height,
      tallWidth :: width
   }

-- | Representation for the tag pair @Small Big@:
-- height and width are stored independently.
data instance Extent Small Big height width =
   Wide {
      wideHeight :: height,
      wideWidth :: width
   }

-- | Representation for the tag pair @Small Small@:
-- the constructor requires that the types of height and width coincide
-- and stores a single size for both.
-- Pattern matching on 'Square' makes that type equality available.
data instance Extent Small Small height width =
   (height ~ width) =>
   Square {
      squareSize :: height
   }


-- | Construct a 'General' extent from height and width.
general :: height -> width -> General height width
general rows cols = General rows cols

-- | Construct a 'Tall' extent from height and width.
tall :: height -> width -> Tall height width
tall rows cols = Tall rows cols

-- | Construct a 'Wide' extent from height and width.
wide :: height -> width -> Wide height width
wide rows cols = Wide rows cols

-- | Construct a square extent from its single size.
square :: sh -> Square sh
square size = Square size


-- | Wrapper for a conversion between two extents
-- that may differ in their tags but share the height and width types.
newtype Map vertA horizA vertB horizB height width =
   Map {
      apply ::
         Extent vertA horizA height width ->
         Extent vertB horizB height width
   }


-- | Case analysis on the tags 'Small' and 'Big':
-- 'switchTag' selects its first argument for 'Small'
-- and its second argument for 'Big'.
class C tag where
   switchTag :: f Small -> f Big -> f tag

instance C Small where
   switchTag onSmall _ = onSmall

instance C Big where
   switchTag _ onBig = onBig


-- | Case analysis on a pair of tags.
-- The arguments handle, in this order, the tag pairs
-- @Small Small@, @Small Big@, @Big Small@ and @Big Big@.
switchTagPair ::
   (C vert, C horiz) =>
   f Small Small -> f Small Big -> f Big Small -> f Big Big -> f vert horiz
switchTagPair onSquare onWide onTall onGeneral =
   -- The inner 'switchTag's decide on the horizontal tag,
   -- the outer one decides on the vertical tag through 'Flip'.
   getFlip
      (switchTag
         (Flip (switchTag onSquare onWide))
         (Flip (switchTag onTall onGeneral)))


-- | Wrapper for the per-tag branches of 'caseTallWide'.
newtype CaseTallWide height width vert horiz =
   CaseTallWide {
      getCaseTallWide ::
         Extent vert horiz height width ->
         Either (Tall height width) (Wide height width)
   }

-- | Classify an extent as tall ('Left') or wide ('Right').
--
-- 'Tall' and 'Wide' extents are passed through,
-- a square extent is reported as tall,
-- and for a 'General' extent the given predicate on height and width
-- decides: 'True' yields a tall extent, 'False' a wide one.
caseTallWide ::
   (C vert, C horiz) =>
   (height -> width -> Bool) ->
   Extent vert horiz height width ->
   Either (Tall height width) (Wide height width)
caseTallWide isTall =
   getCaseTallWide
      (switchTagPair
         (CaseTallWide (\(Square size) -> Left (tall size size)))
         (CaseTallWide Right)
         (CaseTallWide Left)
         (CaseTallWide (\(General rows cols) ->
            case isTall rows cols of
               True -> Left (tall rows cols)
               False -> Right (wide rows cols))))


-- | Wrapper for the per-tag branches of 'genSquare'.
newtype GenSquare sh vert horiz =
   GenSquare {getGenSquare :: sh -> Extent vert horiz sh sh}

-- | Build an extent with arbitrary tags
-- where the given size is used for both height and width.
genSquare :: (C vert, C horiz) => sh -> Extent vert horiz sh sh
genSquare =
   getGenSquare
      (switchTagPair
         (GenSquare square)
         (GenSquare (\size -> wide size size))
         (GenSquare (\size -> tall size size))
         (GenSquare (\size -> general size size)))

-- | Wrapper for the per-tag branches of 'generalizeTall'.
newtype GenTall height width vert horiz =
   GenTall {
      getGenTall ::
         Extent vert Small height width -> Extent vert horiz height width
   }

-- | Turn the horizontal tag 'Small' into an arbitrary horizontal tag
-- while keeping the vertical tag and the dimensions.
generalizeTall :: (C vert, C horiz) =>
   Extent vert Small height width -> Extent vert horiz height width
generalizeTall =
   getGenTall
      (switchTagPair
         -- square stays square
         (GenTall id)
         -- square becomes wide
         (GenTall (\(Square size) -> wide size size))
         -- tall stays tall
         (GenTall id)
         -- tall becomes general
         (GenTall (\(Tall rows cols) -> general rows cols)))

-- | Wrapper for the per-tag branches of 'generalizeWide'.
newtype GenWide height width vert horiz =
   GenWide {
      getGenWide ::
         Extent Small horiz height width -> Extent vert horiz height width
   }

-- | Turn the vertical tag 'Small' into an arbitrary vertical tag
-- while keeping the horizontal tag and the dimensions.
generalizeWide :: (C vert, C horiz) =>
   Extent Small horiz height width -> Extent vert horiz height width
generalizeWide =
   getGenWide
      (switchTagPair
         -- square stays square
         (GenWide id)
         -- wide stays wide
         (GenWide id)
         -- square becomes tall
         (GenWide (\(Square size) -> tall size size))
         -- wide becomes general
         (GenWide (\(Wide rows cols) -> general rows cols)))


-- | Wrapper for the per-tag branches of 'genToTall'.
newtype GenToTall height width vert horiz =
   GenToTall {
      getGenToTall ::
         Extent vert horiz height width -> Extent Big horiz height width
   }

-- | Replace the vertical tag by 'Big'
-- while keeping the horizontal tag and the dimensions.
genToTall :: (C vert, C horiz) =>
   Extent vert horiz height width -> Extent Big horiz height width
genToTall =
   getGenToTall
      (switchTagPair
         (GenToTall (\(Square size) -> tall size size))
         (GenToTall (\(Wide rows cols) -> general rows cols))
         -- the vertical tag is already 'Big' in the remaining two cases
         (GenToTall id)
         (GenToTall id))


-- | Wrapper for the per-tag branches of 'genToWide'.
newtype GenToWide height width vert horiz =
   GenToWide {
      getGenToWide ::
         Extent vert horiz height width -> Extent vert Big height width
   }

-- | Replace the horizontal tag by 'Big'
-- while keeping the vertical tag and the dimensions.
genToWide :: (C vert, C horiz) =>
   Extent vert horiz height width -> Extent vert Big height width
genToWide =
   getGenToWide
      (switchTagPair
         (GenToWide (\(Square size) -> wide size size))
         -- horizontal tag is already 'Big'
         (GenToWide id)
         (GenToWide (\(Tall rows cols) -> general rows cols))
         -- horizontal tag is already 'Big'
         (GenToWide id))


-- | Wrapper for a function that consumes an extent and yields an @a@.
-- The tags are the last two type parameters
-- such that the wrapper can be passed to 'switchTagPair'.
newtype Accessor a height width vert horiz =
   Accessor {getAccessor :: Extent vert horiz height width -> a}

-- | Height of an extent with arbitrary tags.
-- For a square extent this is its single size.
height :: (C vert, C horiz) => Extent vert horiz height width -> height
height =
   getAccessor
      (switchTagPair
         (Accessor (\(Square size) -> size))
         (Accessor (\(Wide rows _) -> rows))
         (Accessor (\(Tall rows _) -> rows))
         (Accessor (\(General rows _) -> rows)))

-- | Width of an extent with arbitrary tags.
-- For a square extent this is its single size;
-- the pattern match on 'Square' provides the required type equality.
width :: (C vert, C horiz) => Extent vert horiz height width -> width
width =
   getAccessor
      (switchTagPair
         (Accessor (\(Square size) -> size))
         (Accessor (\(Wide _ cols) -> cols))
         (Accessor (\(Tall _ cols) -> cols))
         (Accessor (\(General _ cols) -> cols)))


-- | Pair of 'height' and 'width' of an extent.
dimensions ::
   (C vert, C horiz) => Extent vert horiz height width -> (height,width)
dimensions extent = (rows, cols)
   where
      rows = height extent
      cols = width extent


-- | Forget the tags:
-- rebuild the extent as a 'General' one from its 'height' and 'width'.
toGeneral ::
   (C vert, C horiz) => Extent vert horiz height width -> General height width
toGeneral extent = General (height extent) (width extent)

-- | Convert a square extent to an extent with arbitrary tags
-- via 'genSquare'.
fromSquare :: (C vert, C horiz) => Square size -> Extent vert horiz size size
fromSquare extent = genSquare (squareSize extent)

-- | Variant of 'fromSquare' whose signature keeps separate
-- height and width types; matching on 'Square' reveals that they coincide.
fromSquareLiberal :: (C vert, C horiz) =>
   Extent Small Small height width -> Extent vert horiz height width
fromSquareLiberal extent =
   case extent of
      Square size -> genSquare size

-- | Convert an extent with equal height and width to a square extent.
-- Calls 'error' if height and width differ.
squareFromGeneral ::
   (C vert, C horiz, Eq size) =>
   Extent vert horiz size size -> Square size
squareFromGeneral extent
   | rows == width extent = square rows
   | otherwise = error "Extent.squareFromGeneral: no square shape"
   where
      rows = height extent


-- | Wrapper for the per-tag branches of 'transpose'.
newtype Transpose height width vert horiz =
   Transpose {
      getTranspose ::
         Extent vert horiz height width ->
         Extent horiz vert width height
   }

-- | Swap height and width along with the vertical and horizontal tag.
transpose ::
   (C vert, C horiz) =>
   Extent vert horiz height width ->
   Extent horiz vert width height
transpose =
   getTranspose
      (switchTagPair
         (Transpose (\(Square size) -> square size))
         (Transpose (\(Wide rows cols) -> tall cols rows))
         (Transpose (\(Tall rows cols) -> wide cols rows))
         (Transpose (\(General rows cols) -> general cols rows)))


-- | Wrapper for the per-tag branches of the 'Eq' instance.
newtype Equal height width vert horiz =
   Equal {
      getEqual ::
         Extent vert horiz height width ->
         Extent vert horiz height width -> Bool
   }

-- | Two extents with the same tags are equal
-- if their heights and their widths are equal;
-- for square extents the single sizes are compared.
instance
   (C vert, C horiz, Eq height, Eq width) =>
      Eq (Extent vert horiz height width) where
   (==) =
      getEqual
         (switchTagPair
            (Equal (\(Square sizeL) (Square sizeR) -> sizeL == sizeR))
            (Equal (\l r ->
               equating wideHeight l r && equating wideWidth l r))
            (Equal (\l r ->
               equating tallHeight l r && equating tallWidth l r))
            (Equal (\l r ->
               equating generalHeight l r && equating generalWidth l r)))


-- | Shows an extent as an application of one of the functions
-- @Extent.square@, @Extent.wide@, @Extent.tall@ or @Extent.general@.
instance
   (C vert, C horiz, Show height, Show width) =>
      Show (Extent vert horiz height width) where
   showsPrec prec =
      getAccessor
         (switchTagPair
            (Accessor (showsPrecSquare prec))
            (Accessor (showsPrecAny "Extent.wide" prec))
            (Accessor (showsPrecAny "Extent.tall" prec))
            (Accessor (showsPrecAny "Extent.general" prec)))

-- | Show a square extent as @Extent.square size@,
-- parenthesised if the precedence is above 10.
showsPrecSquare ::
   (Show height) =>
   Int -> Extent Small Small height width -> ShowS
showsPrecSquare prec extent =
   showParen (prec >= 11)
      (showString "Extent.square " . showsPrec 11 (squareSize extent))

-- | Show an extent as @name height width@,
-- parenthesised if the precedence is above 10.
showsPrecAny ::
   (C vert, C horiz, Show height, Show width) =>
   String -> Int -> Extent vert horiz height width -> ShowS
showsPrecAny name prec extent =
   showParen (prec >= 11)
      (showString name . showArg (height extent) . showArg (width extent))
   where
      -- a single argument, preceded by a separating space
      showArg dim = showString " " . showsPrec 11 dim


-- | Wrapper for a function on extents with horizontal tag 'Big'
-- that may change the types of height and width.
-- The vertical tag is the last type parameter
-- such that the wrapper can be passed to 'switchTag'.
newtype Widen heightA widthA heightB widthB vert =
   Widen {
      getWiden ::
         Extent vert Big heightA widthA ->
         Extent vert Big heightB widthB
   }

-- | Replace the width of an extent with horizontal tag 'Big'.
-- The type of the width may change, the height is kept.
widen ::
   (C vert) =>
   widthB -> Extent vert Big height widthA -> Extent vert Big height widthB
widen newWidth =
   getWiden
      (switchTag
         (Widen (\(Wide rows _) -> Wide rows newWidth))
         (Widen (\(General rows _) -> General rows newWidth)))

-- | Replace the height of an extent with horizontal tag 'Big'.
-- The type of the height may change, the width is kept.
reduceWideHeight ::
   (C vert) =>
   heightB -> Extent vert Big heightA width -> Extent vert Big heightB width
reduceWideHeight newHeight =
   getWiden
      (switchTag
         (Widen (\(Wide _ cols) -> Wide newHeight cols))
         (Widen (\(General _ cols) -> General newHeight cols)))


-- | Wrapper for a function that keeps the tags of an extent
-- but may change the types of height and width.
-- The tags are the last two type parameters
-- such that the wrapper can be passed to 'switchTagPair'.
newtype Adapt heightA widthA heightB widthB vert horiz =
   Adapt {
      getAdapt ::
         Extent vert horiz heightA widthA ->
         Extent vert horiz heightB widthB
   }

-- | Replace both dimensions of an extent while keeping its tags.
-- The extent is still forced to its constructor.
-- For a square extent only the new height is stored
-- and the new width is ignored.
reduceConsistent ::
   (C vert, C horiz) =>
   height -> width ->
   Extent vert horiz height width -> Extent vert horiz height width
reduceConsistent newHeight newWidth =
   getAdapt
      (switchTagPair
         (Adapt (\(Square _) -> square newHeight))
         (Adapt (\(Wide _ _) -> wide newHeight newWidth))
         (Adapt (\(Tall _ _) -> tall newHeight newWidth))
         (Adapt (\(General _ _) -> general newHeight newWidth)))


-- | Case analysis on the tag pairs @Small Big@, @Big Small@ and @Big Big@.
-- There is no instance for @Small Small@.
class (C vert, C horiz) => GeneralTallWide vert horiz where
   switchTagGTW :: f Small Big -> f Big Small -> f Big Big -> f vert horiz

instance GeneralTallWide Small Big where
   switchTagGTW onWide _ _ = onWide

instance GeneralTallWide Big Small where
   switchTagGTW _ onTall _ = onTall

instance GeneralTallWide Big Big where
   switchTagGTW _ _ onGeneral = onGeneral

-- | Apply a function to the height of a wide, tall or general extent.
-- The width is kept.
mapHeight ::
   (GeneralTallWide vert horiz) =>
   (heightA -> heightB) ->
   Extent vert horiz heightA width -> Extent vert horiz heightB width
mapHeight change =
   getAdapt
      (switchTagGTW
         (Adapt (\(Wide rows cols) -> wide (change rows) cols))
         (Adapt (\(Tall rows cols) -> tall (change rows) cols))
         (Adapt (\(General rows cols) -> general (change rows) cols)))

-- | Apply a function to the width of a wide, tall or general extent.
-- The height is kept.
mapWidth ::
   (GeneralTallWide vert horiz) =>
   (widthA -> widthB) ->
   Extent vert horiz height widthA -> Extent vert horiz height widthB
mapWidth change =
   getAdapt
      (switchTagGTW
         (Adapt (\(Wide rows cols) -> wide rows (change cols)))
         (Adapt (\(Tall rows cols) -> tall rows (change cols)))
         (Adapt (\(General rows cols) -> general rows (change cols))))

-- | Apply a function to the single size of a square extent.
mapSquareSize :: (shA -> shB) -> Square shA -> Square shB
mapSquareSize change extent =
   case extent of
      Square size -> square (change size)


-- | Wrap height and width in the same type constructor @f@.
-- For a square extent only the function for the height is applied.
mapWrap ::
   (C vert, C horiz) =>
   (height -> f height) ->
   (width -> f width) ->
   Extent vert horiz height width ->
   Extent vert horiz (f height) (f width)
mapWrap wrapHeight wrapWidth =
   getAdapt
      (switchTagPair
         (Adapt (\(Square size) -> Square (wrapHeight size)))
         (Adapt (\(Wide rows cols) ->
            Wide (wrapHeight rows) (wrapWidth cols)))
         (Adapt (\(Tall rows cols) ->
            Tall (wrapHeight rows) (wrapWidth cols)))
         (Adapt (\(General rows cols) ->
            General (wrapHeight rows) (wrapWidth cols))))

{- only admissible since GHC-7.8
mapUnwrap ::
   (C vert, C horiz) =>
   (f height -> height) ->
   (f width -> width) ->
   Extent vert horiz (f height) (f width) ->
   Extent vert horiz height width
mapUnwrap fh fw =
   getAdapt $
   switchTagPair
      (Adapt $ \(Square h) -> Square (fh h))
      (Adapt $ \(Wide h w) -> Wide (fh h) (fw w))
      (Adapt $ \(Tall h w) -> Tall (fh h) (fw w))
      (Adapt $ \(General h w) -> General (fh h) (fw w))
-}

-- | Remove the 'Unchecked' wrapper from height and width
-- by applying 'deconsUnchecked' to every stored dimension.
recheck ::
   (C vert, C horiz) =>
   Extent vert horiz (Unchecked height) (Unchecked width) ->
   Extent vert horiz height width
recheck =
   getAdapt
      (switchTagPair
         (Adapt (\(Square size) -> Square (deconsUnchecked size)))
         (Adapt (\(Wide rows cols) ->
            Wide (deconsUnchecked rows) (deconsUnchecked cols)))
         (Adapt (\(Tall rows cols) ->
            Tall (deconsUnchecked rows) (deconsUnchecked cols)))
         (Adapt (\(General rows cols) ->
            General (deconsUnchecked rows) (deconsUnchecked cols))))



-- | Wrapper for the per-tag branches of 'fuse'.
newtype Fuse height fuse width vert horiz =
   Fuse {
      getFuse ::
         Extent vert horiz height fuse ->
         Extent vert horiz fuse width ->
         Maybe (Extent vert horiz height width)
   }

-- | Combine two extents with equal tags
-- where the width of the first and the height of the second one
-- have the same type @fuse@.
-- The result takes the height of the first and the width of the second
-- extent if the two @fuse@ values are equal, and is 'Nothing' otherwise.
-- For square extents the two sizes are compared.
fuse ::
   (C vert, C horiz, Eq fuse) =>
   Extent vert horiz height fuse ->
   Extent vert horiz fuse width ->
   Maybe (Extent vert horiz height width)
fuse =
   getFuse
      (switchTagPair
         (Fuse (\(Square sizeA) (Square sizeB) ->
            toMaybe (sizeA==sizeB) (Square sizeA)))
         (Fuse (\(Wide rows innerA) (Wide innerB cols) ->
            toMaybe (innerA==innerB) (Wide rows cols)))
         (Fuse (\(Tall rows innerA) (Tall innerB cols) ->
            toMaybe (innerA==innerB) (Tall rows cols)))
         (Fuse (\(General rows innerA) (General innerB cols) ->
            toMaybe (innerA==innerB) (General rows cols))))


-- | Pair up the heights and the widths of two extents with equal tags.
-- This is 'stackGen' with tuple construction for both dimensions.
kronecker ::
   (C vert, C horiz) =>
   Extent vert horiz heightA widthA ->
   Extent vert horiz heightB widthB ->
   Extent vert horiz (heightA,heightB) (widthA,widthB)
kronecker extentA extentB = stackGen (,) (,) extentA extentB



{-
Tag table for 'beside'.

Small Small  Small Small -> Small Big
Small Small  Small Big   -> Small Big
Small Small  Big   Small -> Small Big
Small Small  Big   Big   -> Small Big
Small Big    Small Small -> Small Big
Small Big    Small Big   -> Small Big
Small Big    Big   Small -> Small Big
Small Big    Big   Big   -> Small Big
Big   Small  Small Small -> Small Big
Big   Small  Small Big   -> Small Big
Big   Small  Big   Small -> Big   Big
Big   Small  Big   Big   -> Big   Big
Big   Big    Small Small -> Small Big
Big   Big    Small Big   -> Small Big
Big   Big    Big   Small -> Big   Big
Big   Big    Big   Big   -> Big   Big
-}
-- | Wrapper for a function that combines two extents of the same height
-- type and with horizontal tag 'Big' into an extent
-- whose width type is the sum (':+:') of both width types.
-- @vertA@ and @vertB@ are the vertical tags of the operands,
-- @vertC@ is the vertical tag of the result.
newtype AppendMode vertA vertB vertC height widthA widthB =
   AppendMode (
      Extent vertA Big height widthA ->
      Extent vertB Big height widthB ->
      Extent vertC Big height (widthA:+:widthB)
   )

-- | Append mode that keeps vertical tag and height of the left operand
-- and replaces its width by the sum of both widths.
appendLeftAux ::
   (C vertA, C vertB) => AppendMode vertA vertB vertA height widthA widthB
appendLeftAux =
   AppendMode
      (\left right ->
         let joined = width left :+: width right
         in widen joined left)

-- | Append mode for operands with the same vertical tag:
-- vertical tag and height are taken from the left operand,
-- the width is the sum of both widths.
appendSame :: (C vert) => AppendMode vert vert vert height widthA widthB
appendSame =
   AppendMode
      (\left right -> widen (width left :+: width right) left)

-- | Append mode for a right operand with vertical tag 'Big':
-- vertical tag and height are taken from the left operand,
-- the width is the sum of both widths.
appendLeft :: (C vert) => AppendMode vert Big vert height widthA widthB
appendLeft =
   AppendMode
      (\left right -> widen (width left :+: width right) left)

-- | Append mode for a left operand with vertical tag 'Big':
-- vertical tag and height are taken from the right operand,
-- the width is the sum of both widths.
appendRight :: (C vert) => AppendMode Big vert vert height widthA widthB
appendRight =
   AppendMode
      (\left right ->
         let joined = width left :+: width right
         in widen joined right)

-- | Vertical tag of the result of 'appendAny':
-- 'Small' if the first tag is 'Small', otherwise the second tag.
type family Append a b
type instance Append Small b = Small
type instance Append Big   b = b

-- | Wrapper for the per-tag branches of 'appendAny'.
-- The vertical tag of the left operand is the last type parameter
-- such that the wrapper can be passed to 'switchTag'.
newtype
   AppendAny vertB height widthA widthB vertA =
      AppendAny {
         getAppendAny ::
            AppendMode vertA vertB (Append vertA vertB) height widthA widthB
      }

-- | Append mode for arbitrary vertical tags;
-- the vertical tag of the result is given by 'Append'.
appendAny ::
   (C vertA, C vertB) =>
   AppendMode vertA vertB (Append vertA vertB) height widthA widthB
appendAny =
   getAppendAny
      (switchTag
         -- left tag 'Small': the result follows the left operand
         (AppendAny appendLeftAux)
         -- left tag 'Big': the result follows the right operand
         (AppendAny appendRight))


-- | Combine the heights and the widths of two extents with equal tags
-- using ':+:'. This is 'stackGen' with ':+:' for both dimensions.
stack ::
   (C vert, C horiz) =>
   Extent vert horiz heightA widthA ->
   Extent vert horiz heightB widthB ->
   Extent vert horiz (heightA:+:heightB) (widthA:+:widthB)
stack extentA extentB = stackGen (:+:) (:+:) extentA extentB

-- | Wrapper for the per-tag branches of 'stackGen'.
newtype Stack f heightA widthA heightB widthB vert horiz =
   Stack {
      getStack ::
         Extent vert horiz heightA widthA ->
         Extent vert horiz heightB widthB ->
         Extent vert horiz (f heightA heightB) (f widthA widthB)
   }

-- | Combine two extents with equal tags dimension-wise:
-- the first function merges the heights, the second one the widths.
-- For square extents only the function for the heights is applied.
stackGen ::
   (C vert, C horiz) =>
   (heightA -> heightB -> f heightA heightB) ->
   (widthA -> widthB -> f widthA widthB) ->
   Extent vert horiz heightA widthA ->
   Extent vert horiz heightB widthB ->
   Extent vert horiz (f heightA heightB) (f widthA widthB)
stackGen joinHeight joinWidth =
   getStack
      (switchTagPair
         (Stack (\(Square sizeA) (Square sizeB) ->
            Square (joinHeight sizeA sizeB)))
         (Stack (\(Wide rowsA colsA) (Wide rowsB colsB) ->
            Wide (joinHeight rowsA rowsB) (joinWidth colsA colsB)))
         (Stack (\(Tall rowsA colsA) (Tall rowsB colsB) ->
            Tall (joinHeight rowsA rowsB) (joinWidth colsA colsB)))
         (Stack (\(General rowsA colsA) (General rowsB colsB) ->
            General (joinHeight rowsA rowsB) (joinWidth colsA colsB))))



-- | Tag combination used in the result types of
-- 'unifyLeft' and 'unifyRight':
-- the second tag if the first tag is 'Small', otherwise 'Big'.
type family Multiply a b
type instance Multiply Small b = b
type instance Multiply Big   b = Big


-- | Value-level evidence for @C a@:
-- the constructor can only be built if the instance exists
-- and pattern matching on it brings the instance into scope.
data TagFact a = C a => TagFact

-- | Wrapper for the per-tag branches of 'multiplyTagLaw'.
newtype MultiplyTagLaw b a =
   MultiplyTagLaw {
      getMultiplyTagLaw :: TagFact a -> TagFact b -> TagFact (Multiply a b)
   }

-- | Derive evidence for @C (Multiply a b)@ from evidence for
-- @C a@ and @C b@.
multiplyTagLaw :: TagFact a -> TagFact b -> TagFact (Multiply a b)
multiplyTagLaw factA@TagFact =
   getMultiplyTagLaw
      (switchTag
         -- Multiply Small b = b: return the evidence for b
         (MultiplyTagLaw (\_ factB -> factB))
         -- Multiply Big b = Big: return the evidence for Big
         (MultiplyTagLaw (\factBig _ -> factBig)))
      factA

-- | Evidence for the 'C' instance of the vertical tag.
-- The extent is only used for its type.
heightFact :: (C vert) => Extent vert horiz height width -> TagFact vert
heightFact = const TagFact

-- | Evidence for the 'C' instance of the horizontal tag.
-- The extent is only used for its type.
widthFact :: (C horiz) => Extent vert horiz height width -> TagFact horiz
widthFact = const TagFact


-- | Wrapper for the per-tag branches of 'unifyLeft' and 'unifyRight'.
-- The tags of the first operand are the last two type parameters
-- such that the wrapper can be passed to 'switchTagPair'.
newtype Unify height fuse width heightC widthC vertB horizB vertA horizA =
   Unify {
      getUnify ::
         Extent vertA horizA height fuse ->
         Extent vertB horizB fuse width ->
         Extent (Multiply vertA vertB) (Multiply horizA horizB) heightC widthC
   }

-- | Convert the first extent to the tags given by 'Multiply'.
-- Its dimensions are kept;
-- the second extent only contributes its tags and is not inspected.
unifyLeft ::
   (C vertA, C horizA, C vertB, C horizB) =>
   Extent vertA horizA height fuse ->
   Extent vertB horizB fuse width ->
   Extent (Multiply vertA vertB) (Multiply horizA horizB) height fuse
unifyLeft =
   getUnify
      (switchTagPair
         (Unify (\left _ -> fromSquareLiberal left))
         (Unify (\left _ -> generalizeWide left))
         (Unify (\left _ -> generalizeTall left))
         (Unify (\left _ -> toGeneral left)))

-- | Convert the second extent to the tags given by 'Multiply'.
-- Its dimensions are kept;
-- the first extent only contributes its tags and is not inspected.
unifyRight ::
   (C vertA, C horizA, C vertB, C horizB) =>
   Extent vertA horizA height fuse ->
   Extent vertB horizB fuse width ->
   Extent (Multiply vertA vertB) (Multiply horizA horizB) fuse width
unifyRight =
   getUnify
      (switchTagPair
         (Unify (\_ right -> right))
         (Unify (\_ right -> genToWide right))
         (Unify (\_ right -> genToTall right))
         (Unify (\_ right -> toGeneral right)))


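{-
Tables for 'Multiply' as used in the result types
of 'unifyLeft' and 'unifyRight':
first in terms of shape names,
then in terms of pairs of vertical and horizontal tags.
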
Square  Square  -> Square
Square  Wide    -> Wide
Square  Tall    -> Tall
Square  General -> General
Wide    Square  -> Wide
Wide    Wide    -> Wide
Wide    Tall    -> General
Wide    General -> General
Tall    Square  -> Tall
Tall    Wide    -> General
Tall    Tall    -> Tall
Tall    General -> General
General Square  -> General
General Wide    -> General
General Tall    -> General
General General -> General

Small Small  Small Small -> Small Small
Small Small  Small Big   -> Small Big
Small Small  Big   Small -> Big   Small
Small Small  Big   Big   -> Big   Big
Small Big    Small Small -> Small Big
Small Big    Small Big   -> Small Big
Small Big    Big   Small -> Big   Big
Small Big    Big   Big   -> Big   Big
Big   Small  Small Small -> Big   Small
Big   Small  Small Big   -> Big   Big
Big   Small  Big   Small -> Big   Small
Big   Small  Big   Big   -> Big   Big
Big   Big    Small Small -> Big   Big
Big   Big    Small Big   -> Big   Big
Big   Big    Big   Small -> Big   Big
Big   Big    Big   Big   -> Big   Big
-}