{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Accelerate.Data.Complex (
Complex(..),
real,
imag,
mkPolar,
cis,
polar,
magnitude,
phase,
conjugate,
) where
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Classes
import Data.Array.Accelerate.Data.Functor
import Data.Array.Accelerate.Prelude
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Type
import Prelude ( ($), undefined )
import Data.Complex ( Complex(..) )
import qualified Data.Complex as C
import qualified Prelude as P
-- Element representation of the supported complex types: a @Complex t@ is
-- stored as a 'V2' of its component type @t@. Only the five component types
-- listed here (Half, Float, Double, CFloat, CDouble) get a representation.
type instance EltRepr (Complex Half) = V2 Half
type instance EltRepr (Complex Float) = V2 Float
type instance EltRepr (Complex Double) = V2 Double
type instance EltRepr (Complex CFloat) = V2 CFloat
type instance EltRepr (Complex CDouble) = V2 CDouble
-- | A 'Complex' 'Half' is stored as a 'V2' 'Half' scalar: the first component
-- holds the real part and the second holds the imaginary part.
instance Elt (Complex Half) where
  eltType = \_ -> TypeRscalar scalarType
  toElt v = case v of
    V2 re im -> re :+ im
  fromElt (re :+ im) = V2 re im
-- | A 'Complex' 'Float' is stored as a 'V2' 'Float' scalar: the first
-- component holds the real part and the second holds the imaginary part.
instance Elt (Complex Float) where
  eltType = \_ -> TypeRscalar scalarType
  toElt v = case v of
    V2 re im -> re :+ im
  fromElt (re :+ im) = V2 re im
-- | A 'Complex' 'Double' is stored as a 'V2' 'Double' scalar: the first
-- component holds the real part and the second holds the imaginary part.
instance Elt (Complex Double) where
  eltType = \_ -> TypeRscalar scalarType
  toElt v = case v of
    V2 re im -> re :+ im
  fromElt (re :+ im) = V2 re im
-- | A 'Complex' 'CFloat' is stored as a 'V2' 'CFloat' scalar: the first
-- component holds the real part and the second holds the imaginary part.
instance Elt (Complex CFloat) where
  eltType = \_ -> TypeRscalar scalarType
  toElt v = case v of
    V2 re im -> re :+ im
  fromElt (re :+ im) = V2 re im
-- | A 'Complex' 'CDouble' is stored as a 'V2' 'CDouble' scalar: the first
-- component holds the real part and the second holds the imaginary part.
instance Elt (Complex CDouble) where
  eltType = \_ -> TypeRscalar scalarType
  toElt v = case v of
    V2 re im -> re :+ im
  fromElt (re :+ im) = V2 re im
-- | 'Complex' reuses the product representation of 'V2'. Conversions go
-- through a 'V2' whose first component is the real part and whose second
-- component is the imaginary part.
instance cst a => IsProduct cst (Complex a) where
  type ProdRepr (Complex a) = ProdRepr (V2 a)
  fromProd cst z = case z of
    re :+ im -> fromProd cst (V2 re im)
  toProd cst p = re :+ im
    where
      V2 re im = toProd cst p
  prod cst _ = prod cst (undefined :: V2 a)
-- | Lift a 'Complex' of liftable components into an 'Exp' by packing both
-- parts into a two-field tuple expression: the real part is the first field
-- and the imaginary part is the second.
instance (Lift Exp a, Elt (Plain a), Elt (Complex (Plain a))) => Lift Exp (Complex a) where
  type Plain (Complex a) = Complex (Plain a)
  lift (re :+ im) = Exp (Tuple fields)
    where
      fields = (NilTup `SnocTup` lift re) `SnocTup` lift im
-- | Split a complex expression into real and imaginary sub-expressions by
-- projecting the two tuple fields. The real part is the field one position
-- from the end, and the imaginary part is the last field.
instance (Elt a, Elt (Complex a)) => Unlift Exp (Complex (Exp a)) where
  unlift e = re :+ im
    where
      re = Exp (SuccTupIdx ZeroTupIdx `Prj` e)
      im = Exp (ZeroTupIdx `Prj` e)
-- | Component-wise equality on complex expressions. Two values are equal
-- exactly when both their real parts and their imaginary parts are equal.
instance (Eq a, Elt (Complex a)) => Eq (Complex a) where
  (unlift -> r1 :+ i1) == (unlift -> r2 :+ i2) = r1 == r2 && i1 == i2
  (unlift -> r1 :+ i1) /= (unlift -> r2 :+ i2) = r1 /= r2 || i1 /= i2
-- | Arithmetic on complex expressions. '+', '-', '*' and 'negate' reuse the
-- "Data.Complex" operations at type @Complex (Exp a)@. 'signum', 'abs' and
-- 'fromInteger' build their results from 'magnitude' and the component parts.
instance (RealFloat a, Elt (Complex a)) => P.Num (Exp (Complex a)) where
  (+)    = lift2 (\u v -> u + v :: Complex (Exp a))
  (-)    = lift2 (\u v -> u - v :: Complex (Exp a))
  (*)    = lift2 (\u v -> u * v :: Complex (Exp a))
  negate = lift1 (\u -> negate u :: Complex (Exp a))

  -- Zero maps to itself; any other value is divided by its magnitude.
  signum z =
    if z == 0
      then z
      else lift (re / m :+ im / m)
    where
      re :+ im = unlift z
      m        = magnitude z

  abs z = lift (magnitude z :+ 0)

  fromInteger n = lift (fromInteger n :+ 0)
-- | Division on complex expressions. The divisor's components are rescaled
-- by the negated larger of their binary exponents before forming the
-- denominator. Numerator and denominator are built from the same scaled
-- values, so the quotient is unchanged.
instance (RealFloat a, Elt (Complex a)) => P.Fractional (Exp (Complex a)) where
  fromRational q = lift (fromRational q :+ 0)

  num / den =
    let nr :+ ni = unlift num
        dr :+ di = unlift den
        k        = negate (max (exponent dr) (exponent di))
        dr'      = scaleFloat k dr
        di'      = scaleFloat k di
        q        = dr * dr' + di * di'
    in  lift ((nr * dr' + ni * di') / q :+ (ni * dr' - nr * di') / q)
-- | Elementary functions on complex expressions. Each method unpacks the
-- argument into real and imaginary parts with 'unlift', applies the
-- component formula, and rebuilds the result with 'complex'.
instance (RealFloat a, Elt (Complex a)) => P.Floating (Exp (Complex a)) where
  pi = complex (pi :+ 0)

  exp (unlift -> x :+ y) =
    let expx = exp x
    in  complex (expx * cos y :+ expx * sin y)

  -- Real part: log of the modulus. Imaginary part: the argument.
  log z = complex (log (magnitude z) :+ phase z)

  -- Zero is special-cased. Otherwise the real part of the result is the
  -- non-negative quantity selected by the sign of @x@, and the imaginary
  -- part takes the sign of @y@.
  sqrt z@(unlift -> x :+ y) =
    if z == 0
      then 0
      else complex (u :+ (y < 0 ? (-v, v)))
    where
      (u,v) = unlift (x < 0 ? (lift (v',u'), lift (u',v')))
      v'    = abs y / (u'*2)
      u'    = sqrt ((magnitude z + abs x) / 2)

  -- Special cases: any base raised to zero gives one. A zero base, or a base
  -- with an infinite component, is resolved from the sign of the exponent's
  -- real part (zero, infinity, or NaN). Every other case uses
  -- @exp (log x * y)@.
  x ** y =
    if y == 0
      then 1
      else if x == 0
        then (if exp_r > 0
                then 0
                else if exp_r < 0
                  then complex (inf :+ 0)
                  else complex (nan :+ nan))
        else if isInfinite r || isInfinite i
          then (if exp_r > 0
                  then complex (inf :+ 0)
                  else if exp_r < 0
                    then 0
                    else complex (nan :+ nan))
          else exp (log x * y)
    where
      r :+ i     = unlift x
      exp_r :+ _ = unlift y
      inf        = 1 / 0
      nan        = 0 / 0

  sin (unlift -> x :+ y) = complex (sin x * cosh y :+ cos x * sinh y)
  cos (unlift -> x :+ y) = complex (cos x * cosh y :+ (- sin x * sinh y))
  tan (unlift -> x :+ y) =
    complex (sinx * coshy :+ cosx * sinhy) / complex (cosx * coshy :+ (- sinx * sinhy))
    where
      sinx  = sin x
      cosx  = cos x
      sinhy = sinh y
      coshy = cosh y

  sinh (unlift -> x :+ y) = complex (cos y * sinh x :+ sin y * cosh x)
  cosh (unlift -> x :+ y) = complex (cos y * cosh x :+ sin y * sinh x)
  tanh (unlift -> x :+ y) =
    complex (cosy * sinhx :+ siny * coshx) / complex (cosy * coshx :+ siny * sinhx)
    where
      siny  = sin y
      cosy  = cos y
      sinhx = sinh x
      coshx = cosh x

  asin z@(unlift -> x :+ y) = complex (y' :+ (-x'))
    where
      x' :+ y' = unlift (log (complex ((-y) :+ x) + sqrt (1 - z*z)))

  acos z = complex (y'' :+ (-x''))
    where
      x'' :+ y'' = unlift (log (z + complex ((-y') :+ x')))
      x'  :+ y'  = unlift (sqrt (1 - z*z))

  atan z@(unlift -> x :+ y) = complex (y' :+ (-x'))
    where
      x' :+ y' = unlift (log (complex ((1-y) :+ x) / sqrt (1+z*z)))

  asinh z = log (z + sqrt (1+z*z))

  -- Uses two separate square roots rather than
  -- @log (z + (z+1) * sqrt ((z-1)/(z+1)))@. The quotient form divides by
  -- @0 :+ 0@ at @z = -1@ and returns NaN there instead of @0 :+ pi@.
  acosh z = log (z + sqrt (z+1) * sqrt (z-1))

  atanh z = 0.5 * log ((1.0+z) / (1.0-z))
-- | Convert an integral expression to a complex expression whose real part is
-- the converted value and whose imaginary part is zero.
instance (FromIntegral a b, Num b, Elt (Complex b)) => FromIntegral a (Complex b) where
  fromIntegral n = lift $ fromIntegral n :+ 0
-- | Apply a function to the real part and to the imaginary part separately.
instance Functor Complex where
  fmap f z = lift (f re :+ f im)
    where
      re :+ im = unlift z
-- | Build a complex expression from a pair of component expressions. This is
-- 'lift' restricted to @Complex (Exp a)@, so callers get the component type
-- fixed without an annotation.
complex :: (Elt a, Elt (Complex a)) => Complex (Exp a) -> Exp (Complex a)
complex z = lift z
-- | Modulus of a complex expression. Both parts are scaled by the negated
-- larger of their binary exponents before squaring, and the square root is
-- scaled back afterwards. This keeps the intermediate sum of squares near
-- unit size, so it neither overflows nor underflows.
magnitude :: (RealFloat a, Elt (Complex a)) => Exp (Complex a) -> Exp a
magnitude (unlift -> re :+ im) = scaleFloat e (sqrt (sq re' + sq im'))
  where
    e    = max (exponent re) (exponent im)
    re'  = scaleFloat (negate e) re
    im'  = scaleFloat (negate e) im
    sq v = v * v
-- | Argument (angle) of a complex expression, computed with 'atan2' on the
-- imaginary and real parts. Zero is special-cased to give zero.
phase :: (RealFloat a, Elt (Complex a)) => Exp (Complex a) -> Exp a
phase z =
  if z == 0
    then 0
    else atan2 im re
  where
    re :+ im = unlift z
-- | Polar form of a complex expression as a @(magnitude, phase)@ pair.
polar :: (RealFloat a, Elt (Complex a)) => Exp (Complex a) -> Exp (a,a)
polar z = lift (m, theta)
  where
    m     = magnitude z
    theta = phase z
-- | Construct a complex expression from a magnitude and a phase angle by
-- applying "Data.Complex"'s @mkPolar@ to the component expressions. The
-- constraint follows base: GHC 7.8 and older require 'RealFloat', newer
-- versions only 'Floating'.
#if __GLASGOW_HASKELL__ <= 708
mkPolar :: forall a. (RealFloat a, Elt (Complex a)) => Exp a -> Exp a -> Exp (Complex a)
#else
mkPolar :: forall a. (Floating a, Elt (Complex a)) => Exp a -> Exp a -> Exp (Complex a)
#endif
mkPolar = lift2 fromPolar
  where
    fromPolar :: Exp a -> Exp a -> Complex (Exp a)
    fromPolar = C.mkPolar
-- | Unit-magnitude complex expression with the given phase angle, using
-- "Data.Complex"'s @cis@ on the component expression. The constraint follows
-- base: GHC 7.8 and older require 'RealFloat', newer versions only 'Floating'.
#if __GLASGOW_HASKELL__ <= 708
cis :: forall a. (RealFloat a, Elt (Complex a)) => Exp a -> Exp (Complex a)
#else
cis :: forall a. (Floating a, Elt (Complex a)) => Exp a -> Exp (Complex a)
#endif
cis = lift1 fromAngle
  where
    fromAngle :: Exp a -> Complex (Exp a)
    fromAngle = C.cis
-- | Real part of a complex expression.
real :: (Elt a, Elt (Complex a)) => Exp (Complex a) -> Exp a
real z = re
  where
    re :+ _ = unlift z
-- | Imaginary part of a complex expression.
imag :: (Elt a, Elt (Complex a)) => Exp (Complex a) -> Exp a
imag z = im
  where
    _ :+ im = unlift z
-- | Complex conjugate: keeps the real part and negates the imaginary part.
conjugate :: (Num a, Elt (Complex a)) => Exp (Complex a) -> Exp (Complex a)
conjugate z = lift (re :+ negate im)
  where
    re = real z
    im = imag z