{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ExistentialQuantification #-}
module Math.FFT.MaskTF (
   Mask, Complex, Real,
   IODim(..), TSpec,
   keep, trans, realTrans,
   mask2, mask3, (<++>),
   example,
   ) where

import qualified Data.Array.IArray as IArray
import qualified Data.Ix as Ix
import Data.Array.IArray (IArray)
import Data.Ix (Ix)

import qualified Control.Monad.Trans.State as MS
import Control.Applicative (Applicative, liftA2)
import Data.Functor.Reverse (Reverse(Reverse), getReverse)

import qualified Data.Complex as Cpl
import Prelude hiding (Real)


-- | Extent of one array axis as computed by 'dim': 'ioDimNum' is the number
-- of indices along the axis and 'ioDimStride' is the distance, counted in
-- elements, between neighbouring indices of that axis.
--
-- 'Eq' and 'Show' are derived so computed layouts can be inspected and
-- compared (e.g. in tests or while debugging a mask).
data IODim = IODim {ioDimNum, ioDimStride :: Int}
   deriving (Eq, Show)
-- | Layout produced by applying a 'Mask' to array bounds:
-- @((transformed axes, boxed real-axis marker), kept axes)@.
type TSpec t =
   ( ([IODim], Box t)
   , [IODim]
   )
-- | Per-axis description of a transform over arrays indexed by @ix@.
--
-- Given the array bounds it yields a 'TSpec'. The stride of each axis is
-- threaded through the 'MS.State' counter. Because effects run through
-- 'Reverse', masks of later tuple components are evaluated first, so the
-- right-most index component receives the innermost stride.
newtype Mask typ ix =
   Mask ((ix, ix) -> Reverse (MS.State Int) (TSpec typ))


-- | Marker for a mask that contains exactly one real-to-complex axis; the
-- marker carries that axis' 'IODim' (built by 'realTrans').
data Real
   = Real IODim

-- | Marker for a mask whose axes are all kept or complex-transformed
-- (built by 'keep' and 'trans').
data Complex
   = Complex

-- | Marker produced by 'Combine' when two 'Real' axes meet; it absorbs any
-- further combination.
data Invalid
   = Invalid

-- | Type-level combination of two axis markers.
--
-- 'Complex' is the identity, two 'Real' markers give 'Invalid' (only one
-- real axis is allowed), and 'Invalid' absorbs everything.
type family Combine typ0 typ1

type instance Combine Complex Complex = Complex
type instance Combine Complex Real    = Real
type instance Combine Complex Invalid = Invalid

type instance Combine Real    Complex = Real
type instance Combine Real    Real    = Invalid
type instance Combine Real    Invalid = Invalid

type instance Combine Invalid Complex = Invalid
type instance Combine Invalid Real    = Invalid
type instance Combine Invalid Invalid = Invalid

-- | Implemented by the axis markers 'Complex', 'Real' and 'Invalid'.
--
-- 'switch' is a type-directed case analysis: given one alternative per
-- marker, it returns the alternative that matches @typ@.
class Type typ where
   switch :: f Complex -> f Real -> f Invalid -> f typ

instance Type Complex where
   switch whenComplex _ _ = whenComplex

instance Type Real where
   switch _ whenReal _ = whenReal

instance Type Invalid where
   switch _ _ whenInvalid = whenInvalid

-- | Continuation used by 'combineType' once the first operand's marker is
-- known. The type arguments are deliberately swapped (@typ1@ first) so that
-- @CombineType0 typ1@ has the shape @f typ0@ that 'switch' abstracts over.
newtype CombineType0 typ1 typ0 =
   CombineType0 {runCombineType0 :: Box typ0 -> CombineType1 typ0 typ1}

-- | Second stage of 'combineType': with the first operand fixed, maps the
-- second operand's box to the box of the combined marker.
newtype CombineType1 typ0 typ1 =
   CombineType1 {runCombineType1 :: Box typ1 -> Box (Combine typ0 typ1)}

-- | A marker value packaged with its 'Type' dictionary; pattern matching on
-- 'Box' brings @Type typ@ into scope (as 'combineType' relies on).
data Box typ =
   Type typ => Box typ


-- | Combination step whose result is the first operand, unchanged.
keepFirst ::
   (Combine typ0 typ1 ~ typ0) =>
   Box typ0 -> CombineType1 typ0 typ1
keepFirst first = CombineType1 (\_ -> first)

-- | Combination step whose result is the second operand, unchanged; the
-- first operand is ignored.
keepSecond ::
   (Combine typ0 typ1 ~ typ1) =>
   Box typ0 -> CombineType1 typ0 typ1
keepSecond _ = CombineType1 (\second -> second)

-- | Combination step whose result is always 'Invalid', regardless of either
-- operand.
resultInvalid ::
   (Combine typ0 typ1 ~ Invalid) =>
   Box typ0 -> CombineType1 typ0 typ1
resultInvalid _ = CombineType1 (\_ -> Box Invalid)

-- | Value-level counterpart of the 'Combine' type family.
--
-- Both 'Box' patterns are matched only to bring their 'Type' dictionaries
-- into scope. The outer 'switch' selects a continuation by the first
-- operand's marker; that continuation uses 'switch' again on the second
-- operand's marker to decide which box survives.
combineType :: Box typ0 -> Box typ1 -> Box (Combine typ0 typ1)
combineType box0@(Box _) box1@(Box _) =
   runCombineType1
      (runCombineType0 (switch afterComplex afterReal afterInvalid) box0)
      box1
   where
      -- First operand 'Complex': the result is always the second operand.
      afterComplex :: (Type t) => CombineType0 t Complex
      afterComplex =
         CombineType0 $ \x ->
            switch (keepSecond x) (keepSecond x) (keepSecond x)

      -- First operand 'Real': stays 'Real' against 'Complex', a second
      -- 'Real' is invalid, and 'Invalid' propagates.
      afterReal :: (Type t) => CombineType0 t Real
      afterReal =
         CombineType0 $ \x ->
            switch (keepFirst x) (resultInvalid x) (keepSecond x)

      -- First operand 'Invalid': it absorbs whatever comes second.
      afterInvalid :: (Type t) => CombineType0 t Invalid
      afterInvalid =
         CombineType0 $ \x ->
            switch (keepFirst x) (keepFirst x) (keepFirst x)


-- | Measure one axis from its bounds and assign it the current stride.
--
-- The state holds the stride for the axis being measured. Afterwards it is
-- multiplied by this axis' size, so axes measured later get larger strides.
dim :: (Ix ix) => (ix,ix) -> MS.State Int IODim
dim bnds =
   MS.state $ \stride ->
      let num = Ix.rangeSize bnds
      in  (IODim num stride, num * stride)

-- | Build a single-axis 'Mask' from a function that classifies the axis.
--
-- The axis' 'IODim' comes from 'dim' (size from the bounds, stride from the
-- state counter). The classifier's marker value is then wrapped in a 'Box'.
makeMask ::
   (Ix ix, Type typ) =>
   (IODim -> (([IODim], typ), [IODim])) -> Mask typ ix
makeMask classify =
   Mask $ \bnds -> Reverse (fmap (boxMarker . classify) (dim bnds))
   where
      boxMarker :: (Type t) => ((a, t), b) -> ((a, Box t), b)
      boxMarker ((transDims, marker), keptDims) =
         ((transDims, Box marker), keptDims)

-- | Leave this axis untransformed: its dimension goes to the kept list.
keep :: (Ix ix) => Mask Complex ix
keep = makeMask keptAxis
   where
      keptAxis d = (([], Complex), [d])

-- | Complex transform along this axis: its dimension goes to the
-- transformed list.
trans :: Mask Complex Int
trans = makeMask transformedAxis
   where
      transformedAxis d = (([d], Complex), [])

-- | Real-to-complex transform along this axis. Its dimension is stored in
-- the 'Real' marker rather than in the transformed list.
realTrans :: Mask Real Int
realTrans = makeMask realAxis
   where
      realAxis d = (([], Real d), [])


-- | Merge two layouts. The axis lists are concatenated (left operand first)
-- and the real-axis markers are combined with 'combineType'.
combineTSpec :: TSpec typ0 -> TSpec typ1 -> TSpec (Combine typ0 typ1)
combineTSpec ((transA, realA), keptA) ((transB, realB), keptB) =
   ((transDims, realMarker), keptDims)
   where
      transDims  = transA ++ transB
      realMarker = combineType realA realB
      keptDims   = keptA ++ keptB


infixl 6 <++>

-- | 'combineTSpec' lifted into an 'Applicative'. Both effects are run in
-- the order the 'Applicative' instance dictates; under 'Reverse' the right
-- operand's effects run first.
(<++>) ::
   (Applicative m) =>
   m (TSpec typ0) -> m (TSpec typ1) -> m (TSpec (Combine typ0 typ1))
specA <++> specB = liftA2 combineTSpec specA specB

-- | Combine two per-axis masks into a mask for the pair index @(ix0, ix1)@.
mask2 :: Mask typ0 ix0 -> Mask typ1 ix1 -> Mask (Combine typ0 typ1) (ix0,ix1)
mask2 (Mask opA) (Mask opB) =
   Mask specFor
   where
      specFor ((lo0, lo1), (hi0, hi1)) =
         opA (lo0, hi0) <++> opB (lo1, hi1)

-- | Combine three per-axis masks into a mask for the triple index
-- @(ix0, ix1, ix2)@. Markers are combined left to right.
mask3 ::
   Mask typ0 ix0 -> Mask typ1 ix1 -> Mask typ2 ix2 ->
   Mask (Combine (Combine typ0 typ1) typ2) (ix0,ix1,ix2)
mask3 (Mask opA) (Mask opB) (Mask opC) =
   Mask specFor
   where
      specFor ((lo0, lo1, lo2), (hi0, hi1, hi2)) =
         opA (lo0, hi0) <++> opB (lo1, hi1) <++> opC (lo2, hi2)

-- | Multi-dimensional real-to-complex DFT over the axes selected by a mask.
--
-- The mask type must be 'Real'. With this module's combinators, that means
-- exactly one 'realTrans' axis. The layout is computed from the array
-- bounds with an initial stride of 1; because 'Mask' runs through
-- 'Reverse', the right-most tuple component is laid out innermost.
--
-- NOTE: not implemented yet. The layout is computed but not consumed, and
-- evaluating the result raises a descriptive error instead of the former
-- uninformative @Prelude.undefined@.
dftRCN ::
   (Ix i, IArray array a) =>
   Mask Real i -> array i a -> array i (Cpl.Complex a)
dftRCN (Mask op) arr =
   let _spec = MS.evalState (getReverse $ op (IArray.bounds arr)) 1
   in  error "Math.FFT.MaskTF.dftRCN: not implemented"


-- | Demonstration mask: keep the 'Char' axis, real-transform the 'Int' axis
-- and keep the 'Integer' axis. 'dftRCN' has no implementation yet, so
-- forcing the result fails.
example ::
   (IArray array a) =>
   array (Char, Int, Integer) a ->
   array (Char, Int, Integer) (Cpl.Complex a)
example = dftRCN layout
   where
      layout :: Mask Real (Char, Int, Integer)
      layout = mask3 keep realTrans keep