{-# LANGUAGE MultiParamTypeClasses, ForeignFunctionInterface, ScopedTypeVariables, FunctionalDependencies, FlexibleInstances #-}
module Data.Eigen.Internal where

import Control.Applicative
import Control.Exception (finally)
import Control.Monad
import Data.Complex
import Foreign.C.String
import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
import System.IO.Unsafe

-- | Pairs a Haskell element type @a@ with the C type @b@ passed across the
-- FFI. The functional dependency @a -> b@ makes the C type determined by the
-- Haskell type; the superclasses bundle conversions in both directions,
-- 'Storable' marshalling of @b@ and its numeric type tag ('Code').
class (Num a, Cast a b, Cast b a, Storable b, Code b) => Elem a b | a -> b

instance Elem Float            CFloat
instance Elem Double           CDouble
instance Elem (Complex Float)  (CComplex CFloat)
instance Elem (Complex Double) (CComplex CDouble)

-- | Conversion from type @a@ to type @b@. In this module it is used to move
-- scalars between their Haskell and C representations (in both directions).
class Cast a b where
    cast
        :: a  -- ^ value to convert
        -> b  -- ^ converted value

-- | C-side complex number: a strict real part followed by a strict
-- imaginary part.
-- NOTE(review): presumably mirrors the complex layout expected by
-- eigen-proxy.h — confirm there.
data CComplex a = CComplex !a !a

-- | Marshals a 'CComplex' as two adjacent values of type @a@ (real part at
-- element index 0, imaginary part at index 1), aligned like a single @a@.
instance Storable a => Storable (CComplex a) where
    sizeOf _ = 2 * sizeOf (undefined :: a)
    alignment _ = alignment (undefined :: a)
    poke ptr (CComplex re im) = do
        let elems = castPtr ptr :: Ptr a
        pokeElemOff elems 0 re
        pokeElemOff elems 1 im
    peek ptr = do
        let elems = castPtr ptr :: Ptr a
        re <- peekElemOff elems 0
        im <- peekElemOff elems 1
        return (CComplex re im)

-- | Integer conversions between 'CInt' and 'Int'.
-- NOTE(review): 'fromIntegral' from 'Int' to 'CInt' silently wraps values
-- outside the 'CInt' range; confirm callers never pass such sizes.
instance Cast CInt Int where
    cast = fromIntegral

instance Cast Int CInt where
    cast = fromIntegral

-- | Floating-point conversions: unwrap or wrap the 'CFloat' / 'CDouble'
-- newtypes, with no numeric conversion involved.
instance Cast CFloat Float where
    cast c = case c of CFloat f -> f

instance Cast Float CFloat where
    cast f = CFloat f

instance Cast CDouble Double where
    cast c = case c of CDouble d -> d

instance Cast Double CDouble where
    cast d = CDouble d

-- | Complex conversions, applied component-wise through the scalar
-- instances above.
instance Cast (CComplex CFloat) (Complex Float) where
    cast (CComplex re im) = (:+) (cast re) (cast im)

instance Cast (Complex Float) (CComplex CFloat) where
    cast (re :+ im) = CComplex (cast re) (cast im)

instance Cast (CComplex CDouble) (Complex Double) where
    cast (CComplex re im) = (:+) (cast re) (cast im)

instance Cast (Complex Double) (CComplex CDouble) where
    cast (re :+ im) = CComplex (cast re) (cast im)


-- | Runs an 'IO' action as a pure value via 'unsafeDupablePerformIO'.
-- As documented in "System.IO.Unsafe", the action may be executed more than
-- once if several threads force the same thunk concurrently, so it should
-- only wrap actions that are safe to duplicate.
-- NOTE(review): confirm every caller's action is duplication-safe.
performIO :: IO a -> a
performIO io = unsafeDupablePerformIO io

-- | Binding to the C @free@ symbol (via eigen-proxy.h), typed for C strings.
-- 'call' uses it to release error messages after copying them to Haskell.
foreign import ccall "eigen-proxy.h free" c_freeString :: CString -> IO ()

-- | Runs a proxy action whose result is either 'nullPtr' (success) or a C
-- error string. On a non-null result the message is copied into Haskell, the
-- C buffer is released with 'c_freeString', and the message is thrown as an
-- 'IOError' built with 'userError'.
call :: IO CString -> IO ()
call func = do
    c_str <- func
    unless (c_str == nullPtr) $ do
        -- Release the C buffer even if decoding it throws, so it never leaks.
        msg <- peekCString c_str `finally` c_freeString c_str
        ioError (userError msg)

-- | Binding to the C @free@ symbol (via eigen-proxy.h) for arbitrary
-- pointers; same C function as 'c_freeString' but polymorphic in the pointee.
-- NOTE(review): only pass pointers the C side allocated with malloc — confirm
-- at call sites.
foreign import ccall "eigen-proxy.h free" free :: Ptr a -> IO ()

-- | Binding to @eigen_setNbThreads@; presumably sets the number of threads
-- Eigen uses (confirm in eigen-proxy.h).
foreign import ccall "eigen-proxy.h eigen_setNbThreads" c_setNbThreads :: CInt -> IO ()
-- | Binding to @eigen_getNbThreads@; presumably returns the thread count set
-- above (confirm in eigen-proxy.h).
foreign import ccall "eigen-proxy.h eigen_getNbThreads" c_getNbThreads :: IO CInt

-- Raw bindings to eigen-proxy.h. In each of them the first CInt is the
-- element-type tag, which the typed wrappers in this module fill in from
-- 'code'. The CString result is presumably consumed by 'call', which treats
-- a non-NULL result as an error message to free and throw.
-- NOTE(review): the meaning of the remaining arguments (apparently
-- pointer/rows/cols groups) is defined in eigen-proxy.h; confirm there.
-- Signature shape: one (Ptr a, CInt, CInt) group.
foreign import ccall "eigen-proxy.h eigen_random"      c_random      :: CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_identity"    c_identity    :: CInt -> Ptr a -> CInt -> CInt -> IO CString
-- Signature shape: three (Ptr a, CInt, CInt) groups.
foreign import ccall "eigen-proxy.h eigen_add"         c_add         :: CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_sub"         c_sub         :: CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_mul"         c_mul         :: CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
-- Signature shape: two (Ptr a, CInt, CInt) groups.
foreign import ccall "eigen-proxy.h eigen_diagonal"    c_diagonal    :: CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_transpose"   c_transpose   :: CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_inverse"     c_inverse     :: CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_adjoint"     c_adjoint     :: CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_conjugate"   c_conjugate   :: CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
-- Signature shape: one (Ptr a, CInt, CInt) group.
foreign import ccall "eigen-proxy.h eigen_normalize"   c_normalize   :: CInt -> Ptr a -> CInt -> CInt -> IO CString
-- Signature shape: a lone Ptr a (presumably a scalar output — confirm)
-- followed by one (Ptr a, CInt, CInt) group.
foreign import ccall "eigen-proxy.h eigen_sum"         c_sum         :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_prod"        c_prod        :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_mean"        c_mean        :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_norm"        c_norm        :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_trace"       c_trace       :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_squaredNorm" c_squaredNorm :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_blueNorm"    c_blueNorm    :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_hypotNorm"   c_hypotNorm   :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_determinant" c_determinant :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> IO CString

-- Raw bindings whose first CInt is the element-type tag from 'code'.
-- rank, image, kernel and solve take a second CInt after the tag, which the
-- typed wrappers in this module expose as their own first argument.
-- NOTE(review): that extra CInt presumably selects a decomposition/solver
-- method — confirm in eigen-proxy.h.
-- NOTE(review): the Ptr CInt / Ptr (Ptr a) arguments look like output
-- parameters (a result pointer and its dimensions); whether the returned
-- buffer must be released with 'free' is not visible here — confirm.
foreign import ccall "eigen-proxy.h eigen_rank"         c_rank       :: CInt -> CInt -> Ptr CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_image"        c_image      :: CInt -> CInt -> Ptr (Ptr a) -> Ptr CInt -> Ptr CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_kernel"       c_kernel     :: CInt -> CInt -> Ptr (Ptr a) -> Ptr CInt -> Ptr CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_solve"        c_solve       :: CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
foreign import ccall "eigen-proxy.h eigen_relativeError" c_relativeError :: CInt -> Ptr a -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString


-- | Numeric tag identifying an element type to the eigen-proxy.h functions:
-- 0 = 'CFloat', 1 = 'CDouble', 2 = complex 'CFloat', 3 = complex 'CDouble'.
-- The argument is never inspected; only its type selects the tag, so
-- @code (undefined :: t)@ is safe.
class Code a where
    code :: a -> CInt

instance Code CFloat where
    code = const 0

instance Code CDouble where
    code = const 1

instance Code (CComplex CFloat) where
    code = const 2

instance Code (CComplex CDouble) where
    code = const 3

-- Typed wrappers over the raw c_* imports.
--
-- Each wrapper takes exactly the arguments of its c_* counterpart minus the
-- leading element-type tag. It supplies that tag itself as
-- @code (undefined :: a)@, where @a@ is the pointer element type (scoped by
-- the @forall@). The 'Code' instances ignore their argument, so the
-- 'undefined' is never forced.
-- NOTE(review): argument meanings (pointer/rows/cols groups, output
-- pointers) are defined by eigen-proxy.h and are not visible here.

-- Signature shape: one (Ptr a, CInt, CInt) group.

random :: forall a . Code a => Ptr a -> CInt -> CInt -> IO CString
random = c_random $ code (undefined :: a)

identity :: forall a . Code a => Ptr a -> CInt -> CInt -> IO CString
identity = c_identity $ code (undefined :: a)

-- Signature shape: three (Ptr a, CInt, CInt) groups.

add :: forall a . Code a => Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
add = c_add $ code (undefined :: a)

sub :: forall a . Code a => Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
sub = c_sub $ code (undefined :: a)

mul :: forall a . Code a => Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
mul = c_mul $ code (undefined :: a)

-- Signature shape: two (Ptr a, CInt, CInt) groups.

diagonal :: forall a . Code a => Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
diagonal = c_diagonal $ code (undefined :: a)

transpose :: forall a . Code a => Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
transpose = c_transpose $ code (undefined :: a)

inverse :: forall a . Code a => Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
inverse = c_inverse $ code (undefined :: a)

adjoint :: forall a . Code a => Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
adjoint = c_adjoint $ code (undefined :: a)

conjugate :: forall a . Code a => Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
conjugate = c_conjugate $ code (undefined :: a)

-- Signature shape: one (Ptr a, CInt, CInt) group.

normalize :: forall a . Code a => Ptr a -> CInt -> CInt -> IO CString
normalize = c_normalize $ code (undefined :: a)

-- Signature shape: a lone Ptr a (presumably a scalar output — confirm)
-- followed by one (Ptr a, CInt, CInt) group.

sum :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> IO CString
sum = c_sum $ code (undefined :: a)

prod :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> IO CString
prod = c_prod $ code (undefined :: a)

mean :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> IO CString
mean = c_mean $ code (undefined :: a)

norm :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> IO CString
norm = c_norm $ code (undefined :: a)

trace :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> IO CString
trace = c_trace $ code (undefined :: a)

squaredNorm :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> IO CString
squaredNorm = c_squaredNorm $ code (undefined :: a)

blueNorm :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> IO CString
blueNorm = c_blueNorm $ code (undefined :: a)

hypotNorm :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> IO CString
hypotNorm = c_hypotNorm $ code (undefined :: a)

determinant :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> IO CString
determinant = c_determinant $ code (undefined :: a)

-- Wrappers whose own first CInt is forwarded as the second argument of the
-- raw binding (right after the element-type tag).

rank :: forall a . Code a => CInt -> Ptr CInt -> Ptr a -> CInt -> CInt -> IO CString
rank = c_rank $ code (undefined :: a)

image :: forall a . Code a => CInt -> Ptr (Ptr a) -> Ptr CInt -> Ptr CInt -> Ptr a -> CInt -> CInt -> IO CString
image = c_image $ code (undefined :: a)

kernel :: forall a . Code a => CInt -> Ptr (Ptr a) -> Ptr CInt -> Ptr CInt -> Ptr a -> CInt -> CInt -> IO CString
kernel = c_kernel $ code (undefined :: a)

solve :: forall a . Code a => CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
solve = c_solve $ code (undefined :: a)

-- Signature shape: a lone Ptr a followed by three (Ptr a, CInt, CInt) groups.

relativeError :: forall a . Code a => Ptr a -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> Ptr a -> CInt -> CInt -> IO CString
relativeError = c_relativeError $ code (undefined :: a)