{-# LANGUAGE TypeFamilies #-}
module Data.Array.Comfort.Storable.Internal (
   Array(Array, shape, buffer),
   reshape,
   mapShape,

   (!),
   unsafeCreate,
   unsafeCreateWithSize,
   unsafeCreateWithSizeAndResult,
   toList,
   fromList,
   vectorFromList,

   map,
   copyIO,
   (//),

   createIO,
   createWithSizeIO,
   createWithSizeAndResultIO,
   showIO,
   readIO,
   toListIO,
   fromListIO,
   vectorFromListIO,
   ) where

import qualified Data.Array.Comfort.Shape as Shape

import qualified Foreign.Marshal.Array.Guarded as Alloc
import Foreign.Marshal.Array (copyArray, pokeArray, peekArray, advancePtr, )
import Foreign.Storable (Storable, poke, pokeElemOff, peek, peekElemOff, )
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, )
import Foreign.Ptr (Ptr, )

import System.IO.Unsafe (unsafePerformIO, )

import Control.Applicative ((<$>))

import Data.Tuple.HT (mapFst)

import Prelude hiding (map, readIO, )


-- | An array: a shape description paired with a foreign buffer that holds
-- the elements.
--
-- Buffers built by this module contain exactly 'Shape.size' of the shape
-- elements, and elements are addressed through 'Shape.offset'.  The pure
-- API reads the buffer via 'unsafePerformIO', so the buffer is presumably
-- meant to stay unmodified once the array is built (NOTE(review): confirm).
data Array sh a = Array {
      -- | Index space of the array.
      shape :: sh,
      -- | Storage for the elements.
      buffer :: ForeignPtr a
   }

-- | Renders an array as @fromList <shape> <elements>@, i.e. an expression
-- in terms of 'fromList'.
--
-- 'showsPrec' is defined (rather than only 'show') so that the rendering is
-- parenthesised when it appears in argument position.  For example,
-- @show (Just arr)@ yields @Just (fromList ...)@ instead of the unparsable
-- @Just fromList ...@.  At precedence 0, 'show' output is unchanged.
instance (Shape.C sh, Show sh, Storable a, Show a) => Show (Array sh a) where
   showsPrec prec arr =
      -- The rendered text is an application (@fromList sh xs@), so wrap it
      -- in parentheses above application precedence (10).
      -- NOTE(review): reading the buffer outside IO assumes it is not
      -- mutated after the array was built.
      showParen (prec > 10) $ showString $ unsafePerformIO $ showIO arr

-- | Replace the shape of an array while sharing its buffer (no copy).
--
-- No consistency check is made between the old and new shapes.
-- Presumably the caller ensures the new shape addresses no more elements
-- than the buffer holds.
reshape :: sh1 -> Array sh0 a -> Array sh1 a
reshape newShape arr = arr {shape = newShape}

-- | Transform the shape of an array with a function, sharing the buffer.
--
-- Like 'reshape', this does not check the result shape against the
-- buffer's size.
mapShape :: (sh0 -> sh1) -> Array sh0 a -> Array sh1 a
mapShape f arr = arr {shape = f (shape arr)}


-- Fixity of the element-access operator '!' defined further down.
infixl 9 !

-- | Pure counterpart of 'createIO': allocate an array for @sh@ and let the
-- callback initialise the buffer.
--
-- NOTE(review): the callback runs under 'unsafePerformIO'.  This is only
-- sound if it writes nothing but the fresh buffer and is deterministic.
unsafeCreate ::
   (Shape.C sh, Storable a) => sh -> (Ptr a -> IO ()) -> Array sh a
unsafeCreate sh fill = unsafePerformIO (createIO sh fill)

-- | Pure counterpart of 'createWithSizeIO'.  The callback receives the
-- element count of @sh@ along with the fresh buffer.
--
-- NOTE(review): same 'unsafePerformIO' caveat as 'unsafeCreate'.
unsafeCreateWithSize ::
   (Shape.C sh, Storable a) => sh -> (Int -> Ptr a -> IO ()) -> Array sh a
unsafeCreateWithSize sh fill = unsafePerformIO (createWithSizeIO sh fill)

-- | Pure counterpart of 'createWithSizeAndResultIO'.  Returns the built
-- array together with whatever the initialising callback produced.
--
-- NOTE(review): same 'unsafePerformIO' caveat as 'unsafeCreate'.
unsafeCreateWithSizeAndResult ::
   (Shape.C sh, Storable a) => sh -> (Int -> Ptr a -> IO b) -> (Array sh a, b)
unsafeCreateWithSizeAndResult sh fill =
   unsafePerformIO (createWithSizeAndResultIO sh fill)

-- | Read the element at an index.  This is the pure counterpart of
-- 'readIO'; the index is turned into a buffer offset by 'Shape.offset'.
(!) :: (Shape.Indexed sh, Storable a) => Array sh a -> Shape.Index sh -> a
arr ! ix = unsafePerformIO (readIO arr ix)

-- | All elements in buffer order.  This is the pure counterpart of
-- 'toListIO'.
toList :: (Shape.C sh, Storable a) => Array sh a -> [a]
toList arr = unsafePerformIO (toListIO arr)

-- | Build an array of shape @sh@ from a list.  This is the pure counterpart
-- of 'fromListIO'.
fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Array sh a
fromList sh xs = unsafePerformIO (fromListIO sh xs)

-- | Build a zero-based one-dimensional array holding exactly the list's
-- elements.  This is the pure counterpart of 'vectorFromListIO'.
vectorFromList :: (Storable a) => [a] -> Array (Shape.ZeroBased Int) a
vectorFromList xs = unsafePerformIO (vectorFromListIO xs)


-- | Apply a function to every element, producing a fresh array with the
-- same shape.
--
-- Each element is read from the source buffer, transformed, and written to
-- the same offset of the destination, visiting offsets in ascending order.
map ::
   (Shape.C sh, Storable a, Storable b) =>
   (a -> b) -> Array sh a -> Array sh b
map f (Array sh src) =
   unsafeCreate sh $ \dstPtr ->
   withForeignPtr src $ \srcPtr ->
      let step k = pokeElemOff dstPtr k . f =<< peekElemOff srcPtr k
      in mapM_ step [0 .. Shape.size sh - 1]

-- | Make an independent copy of an array: a new buffer with the same shape
-- and a bitwise copy of every element.
--
-- The source 'ForeignPtr' is held for the entire creation, so the source
-- buffer stays alive while it is being copied.
copyIO :: (Shape.C sh, Storable a) => Array sh a -> IO (Array sh a)
copyIO (Array sh srcFPtr) =
   withForeignPtr srcFPtr $ \from ->
      let copyAll count to = copyArray to from count
      in createWithSizeIO sh copyAll

-- | Return a copy of the array with the listed elements replaced.
--
-- The whole source buffer is first copied into a fresh one.  The updates
-- are then applied left to right, so a later entry for the same index
-- overwrites an earlier one.  Offsets come from 'Shape.offset'.
(//) ::
   (Shape.Indexed sh, Storable a) =>
   Array sh a -> [(Shape.Index sh, a)] -> Array sh a
(Array sh fptr) // updates =
   unsafeCreateWithSize sh $ \count dst ->
   withForeignPtr fptr $ \src -> do
      let write (ix, x) = pokeElemOff dst (Shape.offset sh ix) x
      copyArray dst src count
      mapM_ write updates



-- | Allocate an array for @sh@ and initialise it with a callback that only
-- needs the buffer pointer (the element count is discarded).
createIO ::
   (Shape.C sh, Storable a) => sh -> (Ptr a -> IO ()) -> IO (Array sh a)
createIO sh f = createWithSizeIO sh (\_count ptr -> f ptr)

-- | Allocate an array for @sh@ and initialise it with a callback that gets
-- the element count and the buffer pointer.  The callback's @()@ result is
-- dropped.
createWithSizeIO ::
   (Shape.C sh, Storable a) => sh -> (Int -> Ptr a -> IO ()) -> IO (Array sh a)
createWithSizeIO sh f =
   createWithSizeAndResultIO sh f >>= return . fst

-- | Allocate a buffer of 'Shape.size' @sh@ elements via 'Alloc.create', run
-- the callback on it with that element count, and wrap the buffer into an
-- 'Array' of shape @sh@ next to the callback's result.
createWithSizeAndResultIO ::
   (Shape.C sh, Storable a) =>
   sh -> (Int -> Ptr a -> IO b) -> IO (Array sh a, b)
createWithSizeAndResultIO sh f =
   mapFst (Array sh) <$> Alloc.create count (f count)
   where
      count = Shape.size sh

-- | Render an array as @fromList <shape> <elements>@.  The shape is shown
-- at application-argument precedence (11), so compound shapes get
-- parenthesised.
showIO :: (Shape.C sh, Show sh, Storable a, Show a) => Array sh a -> IO String
showIO arr = do
   elems <- toListIO arr
   let render =
         showString "fromList " .
         showsPrec 11 (shape arr) .
         showChar ' '
   return (render (show elems))

-- | Read the element at an index.  The index is mapped to a buffer offset by
-- 'Shape.offset'.
readIO :: (Shape.Indexed sh, Storable a) => Array sh a -> Shape.Index sh -> IO a
readIO (Array sh fptr) ix =
   withForeignPtr fptr $ \ptr -> peekElemOff ptr (Shape.offset sh ix)

-- | Read all 'Shape.size' elements of the array in buffer order.
toListIO :: (Shape.C sh, Storable a) => Array sh a -> IO [a]
toListIO (Array sh fptr) =
   let count = Shape.size sh
   in withForeignPtr fptr (\ptr -> peekArray count ptr)

-- | Build an array of shape @sh@ from a list.
--
-- Only the first 'Shape.size' @sh@ list elements are stored, and extra
-- elements are ignored.  If the list is too short, the missing slots are
-- filled with an 'error' value.  That error fires when 'pokeArray' forces
-- the element during writing.
fromListIO ::
   (Shape.C sh, Storable a) =>
   sh -> [a] -> IO (Array sh a)
fromListIO sh xs =
   createWithSizeIO sh $ \size ptr -> pokeArray ptr (take size padded)
   where
      padded =
         xs ++
         repeat
            (error "Array.Comfort.Storable.fromList: list too short for shape")

-- | Build a zero-based one-dimensional array whose length is the length of
-- the list.  The list is traversed once to count it and once to write it.
vectorFromListIO :: (Storable a) => [a] -> IO (Array (Shape.ZeroBased Int) a)
vectorFromListIO xs =
   let sh = Shape.ZeroBased (length xs)
   in createIO sh (\ptr -> pokeArray ptr xs)