{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
module Internal.Conversion (
    Complexable(..), RealElement,
    module Data.Complex
) where
import Internal.Vector
import Internal.Matrix
import Internal.Vectorized
import Data.Complex
import Control.Arrow((***))
-- | Pairs a \"single\" element type @s@ with its \"double\" counterpart
-- @d@, together with vector conversions in both directions.
--
-- The functional dependencies (@s -> d@, @d -> s@) make each type of the
-- pair determine the other, so only one side needs to be known at a use
-- site.
class (Element s, Element d) => Precision s d | s -> d, d -> s where
    -- | Convert a vector of the double type @d@ to the single type @s@.
    double2FloatG :: Vector d -> Vector s

    -- | Convert a vector of the single type @s@ to the double type @d@.
    float2DoubleG :: Vector s -> Vector d
-- | Real floating-point vectors: 'Float' is the single type and 'Double'
-- the double type; both directions delegate to the specialised
-- 'float2DoubleV' / 'double2FloatV' conversions.
instance Precision Float Double where
    float2DoubleG xs = float2DoubleV xs
    double2FloatG xs = double2FloatV xs
-- | Complex vectors reuse the real conversions: the complex vector is viewed
-- as a real vector with 'asReal', converted, and viewed back as complex
-- with 'asComplex'.
instance Precision (Complex Float) (Complex Double) where
    float2DoubleG zs = asComplex (float2DoubleV (asReal zs))
    double2FloatG zs = asComplex (double2FloatV (asReal zs))
-- | Integer vectors: @I@ plays the single role and @Z@ the double role.
-- The conversions go through 'int2longV' and 'long2intV'.
instance Precision I Z where
    float2DoubleG ns = int2longV ns
    double2FloatG ns = long2intV ns
-- | Real element types @t@ for which both @t@ and @Complex t@ are
-- 'Element's and @t@ is a 'RealFloat'.
--
-- The class has no methods; it only bundles these constraints so that the
-- complex-conversion signatures can mention a single constraint.
class (Element t, Element (Complex t), RealFloat t) => RealElement t

instance RealElement Float

instance RealElement Double
-- | Containers of elements (this module provides instances for 'Vector'
-- and 'Matrix') that can be converted between real and complex element
-- types and between the two sides of a 'Precision' pair.
class Complexable c where
    -- | Build a complex container from a (real part, imaginary part) pair.
    toComplex' :: (RealElement e) => (c e, c e) -> c (Complex e)

    -- | Split a complex container into a (real part, imaginary part) pair.
    fromComplex' :: (RealElement e) => c (Complex e) -> (c e, c e)

    -- | Turn a real container into a complex one. (The instances in this
    -- module use a zero imaginary part.)
    comp' :: (RealElement e) => c e -> c (Complex e)

    -- | Convert from the double type @b@ to the single type @a@.
    single' :: Precision a b => c b -> c a

    -- | Convert from the single type @a@ to the double type @b@.
    double' :: Precision a b => c a -> c b
-- | Vectors: real/complex conversions use 'toComplexV' / 'fromComplexV';
-- precision conversions use the 'Precision' class methods.
instance Complexable Vector where
    single' = double2FloatG
    double' = float2DoubleG

    toComplex'   = toComplexV
    fromComplex' = fromComplexV

    -- Pair the real vector with an all-zero imaginary part of equal length.
    comp' v = toComplex' (v, zeros)
      where
        zeros = constantD 0 (dim v)
-- | Combine a (real part, imaginary part) pair into a complex vector.
--
-- The two parts become the two columns of a matrix; that matrix's
-- flattened data is then reinterpreted as complex values with 'asComplex'.
toComplexV :: (RealElement a) => (Vector a, Vector a) -> Vector (Complex a)
toComplexV (re, im) = asComplex interleaved
  where
    interleaved = flatten (fromColumns [re, im])
-- | Split a complex vector into a (real part, imaginary part) pair.
--
-- The complex data is viewed as a real vector with 'asReal', reshaped into
-- a two-column matrix, and split by columns; the first column is returned
-- as the real part and the second as the imaginary part.
--
-- The column list is matched with an explicit 'case' instead of a partial
-- list pattern, so the binding is total from the compiler's point of view
-- (no @-Wincomplete-uni-patterns@ warning). The tuple binding stays lazy,
-- as the original list-pattern binding was.
fromComplexV :: (RealElement a) => Vector (Complex a) -> (Vector a, Vector a)
fromComplexV z = (r, i)
  where
    (r, i) = case toColumns (reshape 2 (asReal z)) of
        [re, im] -> (re, im)
        _        -> error "fromComplexV: expected exactly two columns"
-- | Matrices: 'toComplex'', 'comp'', 'single'' and 'double'' are lifted
-- from the 'Vector' instance with 'liftMatrix2' / 'liftMatrix'.
-- 'fromComplex'' splits the flattened data and rebuilds both parts with
-- the input's row and column counts.
instance Complexable Matrix where
    toComplex' = uncurry $ liftMatrix2 $ curry toComplex'
    -- Rebuild with explicit rows/cols rather than @reshape (cols z)@:
    -- hmatrix defines @reshape 0@ as a 0x0 matrix, which would turn an
    -- r x 0 input into 0 x 0 parts. For @cols z > 0@ both give the same
    -- result, since each part has exactly @rows z * cols z@ elements.
    fromComplex' z = (rebuild *** rebuild) . fromComplex' . flatten $ z
      where
        rebuild = matrixFromVector RowMajor (rows z) (cols z)
    comp' = liftMatrix comp'
    single' = liftMatrix single'
    double' = liftMatrix double'