{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DeriveDataTypeable #-}
module Data.Array.Accelerate.Fourier.Sign where

import Data.Array.Accelerate.Data.Complex (Complex((:+)), )

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Lift(lift), Unlift(unlift), Plain, )
import Data.Array.Accelerate.Smart (Exp(Exp), PreExp(Tuple, Prj), )
import Data.Array.Accelerate.Tuple
          (IsTuple(TupleRepr, fromTuple, toTuple),
           Tuple(NilTup, SnocTup), TupleIdx(ZeroTupIdx), )
import Data.Array.Accelerate.Array.Sugar
          (Elt(eltType, toElt, fromElt, eltType', toElt', fromElt'),
           EltRepr, EltRepr', )

import Data.Typeable (Typeable, )

import qualified Test.QuickCheck as QC


-- | A numeric sign factor, meant to hold @-1@ ('forward') or @+1@
-- ('inverse').  'cis' multiplies the imaginary part of its result by the
-- wrapped value, so the sign selects the rotation direction.
newtype Sign a =
   Sign {getSign :: a}
   deriving (Typeable, Show, Eq)

-- | 'Sign' is a newtype, so its Accelerate element representation is
-- declared to be exactly that of the wrapped type.
type instance EltRepr  (Sign a) = EltRepr  a
type instance EltRepr' (Sign a) = EltRepr' a

-- | Element instance that unwraps / rewraps the newtype and delegates every
-- method to the instance of the wrapped type.
--
-- Matching on the 'Sign' constructor does not force the argument (newtype
-- patterns are irrefutable), so these methods stay exactly as lazy in their
-- argument as a plain @f . getSign@ composition would be.
instance Elt a => Elt (Sign a) where
   eltType  (Sign x) = eltType x
   eltType' (Sign x) = eltType' x

   fromElt  (Sign x) = fromElt x
   fromElt' (Sign x) = fromElt' x

   toElt  repr = Sign (toElt repr)
   toElt' repr = Sign (toElt' repr)

-- | Tuple view of 'Sign': a one-component tuple holding the wrapped value.
-- The 'Lift' instance builds this shape with 'Tuple', and the 'Unlift'
-- instance takes it apart with 'Prj'.
--
-- NOTE(review): the 'Elt' instance represents @Sign a@ exactly like @a@,
-- while this view is @((), a)@.  Presumably the two only line up when the
-- representation of @a@ itself has the shape @((), EltRepr' a)@ (e.g. scalar
-- element types).  Confirm before using 'Sign' with tuple-valued payloads.
instance IsTuple (Sign a) where
   type TupleRepr (Sign a) = ((), a)
   fromTuple s     = ((), getSign s)
   toTuple ((), x) = Sign x

-- | Embed a 'Sign' by lifting its payload and packing it as the single
-- component of a tuple expression (matching the 'IsTuple' view @((), a)@).
instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Sign a) where
   type Plain (Sign a) = Sign (Plain a)
   lift (Sign x) = Exp (Tuple (SnocTup NilTup (lift x)))

-- | Take an embedded 'Sign' apart by projecting its only tuple component
-- ('ZeroTupIdx') and wrapping the resulting scalar expression in 'Sign'.
instance Elt a => Unlift Exp (Sign (Exp a)) where
   unlift = Sign . Exp . Prj ZeroTupIdx


-- | Host-side sign constants: 'forward' wraps @-1@ and 'inverse' wraps @+1@.
forward, inverse :: Num a => Sign a
inverse = Sign 1
forward = Sign (negate 1)

-- | Embedded counterparts of 'forward' (@-1@) and 'inverse' (@+1@).  Each is
-- built from an @'Exp' 'Int'@ literal that is converted to the element type
-- with 'A.fromIntegral' and then lifted into an embedded 'Sign'.
forwardExp, inverseExp :: (Elt a, A.IsNum a) => Exp (Sign a)
inverseExp = lift (Sign (A.fromIntegral (1 :: Exp Int)))
forwardExp = lift (Sign (A.fromIntegral (negate 1 :: Exp Int)))

-- | Unwrap an embedded 'Sign' into the plain scalar expression it carries.
toSign :: (Elt a) => Exp (Sign a) -> Exp a
toSign signExp = getSign (unlift signExp)

-- | @cis sign w@ is the complex number @cos w + i * s * sin w@, where @s@ is
-- the scalar carried by @sign@.  For @s = +1@ this is @exp (i*w)@; for
-- @s = -1@ it is @exp (-i*w)@.
cis ::
   (Elt a, A.IsFloating a) =>
   Exp (Sign a) -> Exp a -> Exp (Complex a)
cis sign w = A.lift (re :+ im)
   where
      re = cos w
      im = toSign sign * sin w

-- | @cisRat sign denom numer@ is @cis sign (2*pi * numer / denom)@: the unit
-- complex number at the rational angle @numer/denom@ of a full turn, with the
-- rotation direction selected by @sign@.
--
-- The numerator is reduced with @numer `rem` denom@ before conversion.
-- Since @numer = q*denom + r@ and 'cis' is 2*pi-periodic, this does not
-- change the exact result.  It does keep the integer handed to
-- 'A.fromIntegral' below @|denom|@ and the angle within one turn.  Without
-- it, large numerators are rounded on conversion (a 'Float' cannot represent
-- every 'Int' above 2^24), and 'cos' / 'sin' are evaluated at needlessly
-- large arguments.
--
-- Precondition: @denom /= 0@ (the result is meaningless otherwise).
cisRat ::
   (Elt a, A.IsFloating a) =>
   Exp (Sign a) -> Exp Int -> Exp Int -> Exp (Complex a)
cisRat sign denom numer =
   cis sign $ 2*pi * A.fromIntegral (numer `rem` denom) / A.fromIntegral denom


-- | Picks 'forward' or 'inverse' uniformly at random.  No 'QC.shrink' is
-- defined, so QuickCheck's default (no shrinking) applies.
instance (Num a) => QC.Arbitrary (Sign a) where
   arbitrary = QC.elements bothSigns
      where bothSigns = [forward, inverse]