{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.CArray.Conversion where

import qualified Data.Array.Accelerate.IO as AIO
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Array.Sugar (EltRepr)
import Data.Array.Accelerate (Array, DIM1, DIM2, Z(Z), (:.)((:.)))

import qualified Data.Array.CArray as CArray
import Data.Array.CArray (CArray, createCArray, withCArray, rangeSize)
import Data.Ix (Ix)
import Foreign.Storable.Complex ()

import Data.Array.Accelerate.Data.Complex (Complex)
import qualified Data.Complex as Complex

import System.IO.Unsafe (unsafePerformIO)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import Control.Monad ((<=<))
import Data.IORef (newIORef, writeIORef, readIORef)


-- | Build a one-dimensional Accelerate array from a 'CArray' of scalar
-- elements.  Only the element count of the 'CArray' is used; its index
-- bounds (e.g. a non-zero lower bound) are not carried over.
--
-- The raw pointer is only valid inside 'withForeignPtr', so this relies on
-- 'AIO.fromPtr' copying the elements into a fresh array rather than
-- aliasing the buffer.  That is what makes the 'unsafePerformIO' here
-- acceptable: the input is immutable and the output does not share memory.
accFromCArrayReal1D ::
   (A.Elt a, Storable a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   CArray Int a -> Array DIM1 a
accFromCArrayReal1D arr =
   unsafePerformIO (withForeignPtr fptr (AIO.fromPtr (Z :. len)))
   where
      (len, fptr) = CArray.toForeignPtr arr

-- | Build a two-dimensional Accelerate array from a 'CArray' of scalar
-- elements indexed by @(row, column)@.  The extents come from the 'CArray'
-- bounds via 'accDimsFromArrayBounds2D'; the lower bounds themselves are
-- dropped.
--
-- The pointer handed out by 'withCArray' is only valid inside the callback,
-- so this relies on 'AIO.fromPtr' copying the data into a new array.
accFromCArrayReal2D ::
   (A.Elt a, Storable a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   CArray (Int, Int) a -> Array DIM2 a
accFromCArrayReal2D arr =
   let dims = accDimsFromArrayBounds2D (CArray.bounds arr)
   in  unsafePerformIO (withCArray arr (AIO.fromPtr dims))


-- | Build a one-dimensional Accelerate array of complex numbers from a
-- 'CArray'.  The vector length is the number of elements covered by the
-- 'CArray' bounds.
accFromCArrayComplex1D ::
   (A.Elt a, Storable a, RealFloat a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   CArray Int (Complex a) -> Array DIM1 (Complex a)
accFromCArrayComplex1D arr =
   accFromCArrayComplex dims arr
   where
      dims = accDimsFromArrayBounds1D (CArray.bounds arr)

-- | Build a two-dimensional Accelerate array of complex numbers from a
-- 'CArray' indexed by @(row, column)@.  Row and column extents are derived
-- from the 'CArray' bounds.
accFromCArrayComplex2D ::
   (A.Elt a, Storable a, RealFloat a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   CArray (Int, Int) (Complex a) -> Array DIM2 (Complex a)
accFromCArrayComplex2D arr =
   accFromCArrayComplex dims arr
   where
      dims = accDimsFromArrayBounds2D (CArray.bounds arr)

-- | Build an Accelerate array of complex numbers of shape @sh@ from a
-- 'CArray'.
--
-- The complex elements are first split into two temporary 'CArray's, one
-- holding the real parts and one the imaginary parts.  Those two buffers
-- are handed to 'AIO.fromPtr' as the component pointers
-- @(((), real), imag)@.
--
-- NOTE(review): @sh@ is presumably expected to describe exactly as many
-- elements as @arr@ holds; this is not checked here.  A larger @sh@ would
-- make 'AIO.fromPtr' read past the end of the temporary buffers.
accFromCArrayComplex ::
   (Ix i, A.Shape sh,
    A.Elt a, Storable a, RealFloat a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   sh -> CArray i (Complex a) -> Array sh (Complex a)
accFromCArrayComplex sh arr =
   unsafePerformIO $
   withCArray realParts $ \ptrRe ->
   withCArray imagParts $ \ptrIm ->
      AIO.fromPtr sh (((), ptrRe), ptrIm)
   where
      realParts = CArray.amap Complex.realPart arr
      imagParts = CArray.amap Complex.imagPart arr


-- | Accelerate shape of a one-dimensional array whose index range is the
-- given bounds; the single extent is the number of indices in the range.
accDimsFromArrayBounds1D :: (CArray.Ix i) => (i, i) -> DIM1
accDimsFromArrayBounds1D = (Z :.) . rangeSize

-- | Accelerate shape @Z :. rows :. cols@ of a two-dimensional array with
-- the given @((rowLo, colLo), (rowHi, colHi))@ bounds.  Each extent is the
-- size of the corresponding one-dimensional index range.
accDimsFromArrayBounds2D ::
   (CArray.Ix i, CArray.Ix j) => ((i, j), (i, j)) -> DIM2
accDimsFromArrayBounds2D ((rowLo, colLo), (rowHi, colHi)) =
   let rows = rangeSize (rowLo, rowHi)
       cols = rangeSize (colLo, colHi)
   in  Z :. rows :. cols


-- | Copy a one-dimensional Accelerate array of scalars into a newly
-- allocated 'CArray' indexed from @0@ to @len - 1@.  The elements are
-- written straight into the fresh buffer by 'AIO.toPtr'.
cArrayFromAccReal1D ::
   (A.Elt a, Storable a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   Array DIM1 a -> CArray Int a
cArrayFromAccReal1D arr =
   unsafePerformIO (createCArray bnds (AIO.toPtr arr))
   where
      (Z :. len) = A.arrayShape arr
      bnds = (0, len - 1)

-- | Copy a two-dimensional Accelerate array of scalars into a newly
-- allocated 'CArray' with zero-based @(row, column)@ bounds, as computed by
-- 'arrayBounds2DFromAccDims'.
cArrayFromAccReal2D ::
   (A.Elt a, Storable a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   Array DIM2 a -> CArray (Int, Int) a
cArrayFromAccReal2D arr =
   let bnds = arrayBounds2DFromAccDims (A.arrayShape arr)
   in  unsafePerformIO (createCArray bnds (AIO.toPtr arr))

-- | Copy a one-dimensional Accelerate array of complex numbers into a
-- zero-based 'CArray' of the same length.
cArrayFromAccComplex1D ::
   (A.Elt a, Storable a, RealFloat a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   Array DIM1 (Complex a) -> CArray Int (Complex a)
cArrayFromAccComplex1D arr =
   cArrayFromAccComplex bnds arr
   where
      bnds = arrayBounds1DFromAccDims (A.arrayShape arr)

-- | Copy a two-dimensional Accelerate array of complex numbers into a
-- 'CArray' with zero-based @(row, column)@ bounds.
cArrayFromAccComplex2D ::
   (A.Elt a, Storable a, RealFloat a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   Array DIM2 (Complex a) -> CArray (Int,Int) (Complex a)
cArrayFromAccComplex2D arr =
   cArrayFromAccComplex bnds arr
   where
      bnds = arrayBounds2DFromAccDims (A.arrayShape arr)

-- | Copy an Accelerate array of complex numbers into a 'CArray' with the
-- given bounds.
--
-- 'AIO.toPtr' writes the real and imaginary components into two separate
-- freshly allocated 'CArray's.  These are then zipped with '(Complex.:+)'
-- into the final array.
--
-- NOTE(review): @bnds@ is presumably expected to cover exactly as many
-- elements as @arr@; a smaller range would let 'AIO.toPtr' write past the
-- end of the allocated buffers.
cArrayFromAccComplex ::
   (Ix i, A.Shape sh,
    A.Elt a, Storable a, RealFloat a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   (i,i) -> Array sh (Complex a) -> CArray i (Complex a)
cArrayFromAccComplex bnds arr =
   unsafePerformIO $ do
      (reals, imags) <-
         createCArrayAdd bnds $ \ptrRe ->
         createCArray bnds $ \ptrIm ->
            AIO.toPtr arr (((), ptrRe), ptrIm)
      return (CArray.liftArray2 (Complex.:+) reals imags)


-- | Apply a 'CArray' transformation to a complex Accelerate array.
--
-- Steps:
--
-- 1. Convert the array to a 'CArray' with bounds @bndsFromSh sh@.
-- 2. Pass it through @f@.
-- 3. Convert the result back using the /original/ Accelerate shape @sh@.
--
-- Because the shape is reused, @f@ must return an array with the same
-- number of elements as it was given.  This is checked: on a mismatch the
-- function fails with 'error' instead of converting a buffer of the wrong
-- size.  Without the check, a smaller result would be read past its end and
-- a larger one would be silently truncated.
--
-- @bndsFromSh@ should produce bounds covering as many elements as @sh@
-- (e.g. 'arrayBounds1DFromAccDims' or 'arrayBounds2DFromAccDims').
withCArrayComplex ::
   (Ix i, A.Shape sh,
    A.Elt a, Storable a, RealFloat a, AIO.BlockPtrs (EltRepr a) ~ Ptr a) =>
   (sh -> (i,i)) ->
   (CArray i (Complex a) -> CArray i (Complex a)) ->
   Array sh (Complex a) -> Array sh (Complex a)
withCArrayComplex bndsFromSh f arr =
   let sh = A.arrayShape arr
       bnds = bndsFromSh sh
       result = f (cArrayFromAccComplex bnds arr)
       expected = rangeSize bnds
       actual = rangeSize (CArray.bounds result)
   in  if actual == expected
         then accFromCArrayComplex sh result
         else error $
            "Data.Array.Accelerate.CArray.Conversion.withCArrayComplex: " ++
            "transformation changed the element count from " ++
            show expected ++ " to " ++ show actual


-- | Like 'createCArray', but the initialiser may produce an extra result,
-- which is returned alongside the new array.
--
-- The initialiser given to 'createCArray' only yields @()@, so the extra
-- result is passed out through an 'IORef'.  The ref starts as an 'error'
-- placeholder and is overwritten when the initialiser runs.
createCArrayAdd ::
   (Ix i, Storable e) =>
   (i,i) -> (Ptr e -> IO a) -> IO (CArray i e, a)
createCArrayAdd bnds create = do
   slot <- newIORef (error "uninitialized inner value")
   carr <- createCArray bnds (writeIORef slot <=< create)
   extra <- readIORef slot
   return (carr, extra)


-- | Zero-based 'CArray' bounds @(0, n - 1)@ for an Accelerate vector of
-- length @n@.
arrayBounds1DFromAccDims :: DIM1 -> (Int, Int)
arrayBounds1DFromAccDims sh =
   case sh of
      Z :. n -> (0, n - 1)

-- | Zero-based @(row, column)@ 'CArray' bounds for an Accelerate matrix of
-- shape @Z :. rows :. cols@.
arrayBounds2DFromAccDims :: DIM2 -> ((Int, Int), (Int, Int))
arrayBounds2DFromAccDims (Z :. rows :. cols) =
   (lower, upper)
   where
      lower = (0, 0)
      upper = (rows - 1, cols - 1)