{-# LANGUAGE FlexibleInstances #-}

module Foreign.MathLink.Expressible ( putString
                                    , getString
                                    , putInt16List
                                    , getInt16List
                                    , putInt32List
                                    , getInt32List
                                    , putIntList
                                    , getIntList
                                    , putReal32List
                                    , getReal32List
                                    , putReal64List
                                    , getReal64List
                                    ) where

import Foreign
import Foreign.Storable
import Foreign.C
import Data.Int
import Data.Ix
import Control.Monad.Error
import Control.Exception (bracket)
import qualified Foreign.MathLink.IO as MLIO
import Foreign.MathLink.Types
import Foreign.MathLink.ML
import qualified Data.Array.Unboxed as A

-- | Characters travel over the link as one-character strings.
instance Expressible Char where
    -- Note: this instance really shouldn't be used. It exists only to ensure
    -- that 'String' is an instance of Expressible. The specialization rules
    -- for lists should obviate its use in a 'String' context.
    put c = putStringWith MLIO.mlPutString [c]
    -- Only the first character of the received string is kept (as before).
    -- An empty string is reported through 'throwError' rather than crashing
    -- later in a partial 'head'.
    get = do
      str <- getStringWith MLIO.mlGetString MLIO.mlReleaseString
      case str of
        (c:_) -> return c
        []    -> throwError "Expected a character, but received an empty string."

-- | 'Int16' values use the MathLink 16-bit integer primitives, converting
-- with 'fromIntegral' in both directions.
instance Expressible Int16 where
    get   = getScalarWith MLIO.mlGetInt16 fromIntegral
    put n = putScalarWith MLIO.mlPutInt16 fromIntegral n

-- | 'Int32' values use the MathLink 32-bit integer primitives, converting
-- with 'fromIntegral' in both directions.
instance Expressible Int32 where
    get   = getScalarWith MLIO.mlGetInt32 fromIntegral
    put n = putScalarWith MLIO.mlPutInt32 fromIntegral n

-- | 'Int' values use the MathLink native integer primitives, converting
-- with 'fromIntegral' in both directions.
instance Expressible Int where
    get   = getScalarWith MLIO.mlGetInt fromIntegral
    put n = putScalarWith MLIO.mlPutInt fromIntegral n

-- | 'Float' values use the MathLink 32-bit real primitives, converting
-- with 'realToFrac' in both directions.
instance Expressible Float where
    get   = getScalarWith MLIO.mlGetReal32 realToFrac
    put x = putScalarWith MLIO.mlPutReal32 realToFrac x

-- | 'Double' values use the MathLink 64-bit real primitives, converting
-- with 'realToFrac' in both directions.
instance Expressible Double where
    get   = getScalarWith MLIO.mlGetReal64 realToFrac
    put x = putScalarWith MLIO.mlPutReal64 realToFrac x

-- | Validate a @(head, argument count)@ pair read from the link.
--
-- The pair is returned unchanged when @hdPred@ accepts the head and
-- @nArgPred@ accepts the count. Otherwise an error is raised with
-- 'throwError'. The head is checked first, so a bad head is reported even
-- when the count is also wrong.
checkFnHead :: (String -> Bool)
            -> (Int -> Bool)
            -> (String,Int)
            -> ML (String,Int)
checkFnHead hdPred nArgPred pr@(hd,nArgs)
  | not (hdPred hd) =
      throwError ("Unexpected head '" ++ hd ++ "'.")
  | not (nArgPred nArgs) =
      throwError ("Unexpected number of arguments: " ++ show nArgs ++ ".")
  | otherwise = return pr

-- | A pair is marshalled as a two-argument @List@ expression. Components are
-- sent and received left to right.
instance ( Expressible e1
         , Expressible e2
         ) => Expressible (e1,e2) where
    put (x1,x2) =
      putFunctionHead "List" 2
        >> put x1
        >> put x2
    get =
      getFunctionHead >>= checkFnHead (== "List") (== 2)
        >> liftM2 (,) get get

-- | A triple is marshalled as a three-argument @List@ expression. Components
-- are sent and received left to right.
instance ( Expressible e1
         , Expressible e2
         , Expressible e3
         ) => Expressible (e1,e2,e3) where
    put (x1,x2,x3) =
      putFunctionHead "List" 3
        >> put x1
        >> put x2
        >> put x3
    get =
      getFunctionHead >>= checkFnHead (== "List") (== 3)
        >> liftM3 (,,) get get get

-- | A 4-tuple is marshalled as a four-argument @List@ expression. Components
-- are sent and received left to right.
instance ( Expressible e1
         , Expressible e2
         , Expressible e3
         , Expressible e4
         ) => Expressible (e1,e2,e3,e4) where
    put (x1,x2,x3,x4) =
      putFunctionHead "List" 4
        >> put x1
        >> put x2
        >> put x3
        >> put x4
    get =
      getFunctionHead >>= checkFnHead (== "List") (== 4)
        >> liftM4 (,,,) get get get get

-- | A 5-tuple is marshalled as a five-argument @List@ expression. Components
-- are sent and received left to right.
instance ( Expressible e1
         , Expressible e2
         , Expressible e3
         , Expressible e4
         , Expressible e5
         ) => Expressible (e1,e2,e3,e4,e5) where
    put (x1,x2,x3,x4,x5) =
      putFunctionHead "List" 5
        >> put x1
        >> put x2
        >> put x3
        >> put x4
        >> put x5
    get =
      getFunctionHead >>= checkFnHead (== "List") (== 5)
        >> liftM5 (,,,,) get get get get get

-- | A 6-tuple is marshalled as a six-argument @List@ expression. Components
-- are sent and received left to right.
instance ( Expressible e1
         , Expressible e2
         , Expressible e3
         , Expressible e4
         , Expressible e5
         , Expressible e6
         ) => Expressible (e1,e2,e3,e4,e5,e6) where
    put (x1,x2,x3,x4,x5,x6) =
      putFunctionHead "List" 6
        >> put x1
        >> put x2
        >> put x3
        >> put x4
        >> put x5
        >> put x6
    get =
      getFunctionHead >>= checkFnHead (== "List") (== 6)
        >> return (,,,,,) `ap` get `ap` get `ap` get
                          `ap` get `ap` get `ap` get

-- | A 7-tuple is marshalled as a seven-argument @List@ expression. Components
-- are sent and received left to right.
instance ( Expressible e1
         , Expressible e2
         , Expressible e3
         , Expressible e4
         , Expressible e5
         , Expressible e6
         , Expressible e7
         ) => Expressible (e1,e2,e3,e4,e5,e6,e7) where
    put (x1,x2,x3,x4,x5,x6,x7) =
      putFunctionHead "List" 7
        >> put x1
        >> put x2
        >> put x3
        >> put x4
        >> put x5
        >> put x6
        >> put x7
    get =
      getFunctionHead >>= checkFnHead (== "List") (== 7)
        >> return (,,,,,,) `ap` get `ap` get `ap` get `ap` get
                           `ap` get `ap` get `ap` get

-- | Generic lists are marshalled element by element as a @List@ expression
-- whose argument count is the list length. The RULES pragmas in this module
-- ask GHC to substitute bulk primitives for some common element types.
instance Expressible e => Expressible [e] where
    put es = do
      putFunctionHead "List" (length es)
      mapM_ put es
    get = do
      -- Any argument count is accepted; the head must be "List".
      (_, nArgs) <- getFunctionHead >>= checkFnHead (== "List") (const True)
      replicateM nArgs get

-- Rewrite rules that swap the generic 'put' / 'get' for dedicated MathLink
-- primitives at specific types. 'String' goes through 'putString' and
-- 'getString'. Lists of 'Int', 'Int16', 'Int32', 'Float' and 'Double' go
-- through the corresponding bulk list functions instead of the
-- element-by-element @[e]@ instance.
--
-- NOTE(review): these rules match class methods. GHC applies rewrite rules
-- only when optimising, and rules keyed on class methods may be pre-empted
-- by method/instance selection. Treat them as best-effort and confirm with
-- -ddump-rule-firings; the generic instances must remain correct without them.
{-# RULES "put/String"   put = putString #-}
{-# RULES "get/String"   get = getString #-}
{-# RULES "put/[Int]"    put = putIntList #-}
{-# RULES "get/[Int]"    get = getIntList #-}
{-# RULES "put/[Int16]"  put = putInt16List #-}
{-# RULES "get/[Int16]"  get = getInt16List #-}
{-# RULES "put/[Int32]"  put = putInt32List #-}
{-# RULES "get/[Int32]"  get = getInt32List #-}
{-# RULES "put/[Float]"  put = putReal32List #-}
{-# RULES "get/[Float]"  get = getReal32List #-}
{-# RULES "put/[Double]" put = putReal64List #-}
{-# RULES "get/[Double]" get = getReal64List #-}

-- | Send a homogeneous list using a MathLink "put list" primitive.
--
-- Each element is converted with @cnv@ and the converted values are
-- marshalled into a temporary C array. @fn@ is then called with the link,
-- the array and its length. The status @fn@ returns is passed to
-- 'throwOnError'.
putListWith :: Storable a
            => (Link -> Ptr a -> CInt -> IO CInt)
            -> (b -> a)
            -> [b]
            -> ML ()
putListWith fn cnv xs = do
  lnk <- getLink
  status <- liftIO $ withArrayLen (map cnv xs) $ \len arrPtr ->
              fn lnk arrPtr (fromIntegral len)
  throwOnError status

-- | Receive a homogeneous list using a MathLink "get list" primitive and its
-- matching "release" primitive.
--
-- @afn@ is called with out-parameters for a buffer pointer and an element
-- count. If it reports success, the elements are copied out and converted
-- with @cnv@, and the buffer is handed back through @rfn@. If it reports
-- failure, the link's error message is raised with 'throwError'.
getListWith :: Storable a
            => (Link -> Ptr (Ptr a) -> Ptr CInt -> IO CInt)
            -> (Link -> Ptr a -> CInt -> IO CInt)
            -> (a -> b)
            -> ML [b]
getListWith afn rfn cnv = do
  l <- getLink
  -- The out-parameters only live for the duration of the call, so stack
  -- allocation via 'alloca' is enough.
  eXs <- liftIO $
    alloca $ \xPtrPtr ->
    alloca $ \nPtr -> do
      ok <- afn l xPtrPtr nPtr >>= MLIO.convToBool
      if ok
        then do
          n <- peek nPtr
          -- 'bracket' guarantees the buffer is released through 'rfn' even
          -- if copying it out throws.
          xs <- bracket (peek xPtrPtr)
                        (\xPtr -> rfn l xPtr n)
                        (peekArray (fromIntegral n))
          return (Right (map cnv xs))
        else fmap Left (MLIO.getErrorMessage l)
  either throwError return eXs

-- | Send a 'String' using the MathLink string primitive.
putString :: String -> ML ()
putString str = putStringWith MLIO.mlPutString str

-- | Receive a 'String' using the MathLink get/release string primitives.
getString :: ML String
getString = getStringWith fetch release
  where fetch   = MLIO.mlGetString
        release = MLIO.mlReleaseString

-- | Send a list of 'Int's in one bulk MathLink integer-list call.
putIntList :: [Int] -> ML ()
putIntList ns = putListWith MLIO.mlPutIntList fromIntegral ns

-- | Receive a list of 'Int's from one bulk MathLink integer-list call.
getIntList :: ML [Int]
getIntList = getListWith fetch release fromIntegral
  where fetch   = MLIO.mlGetIntList
        release = MLIO.mlReleaseIntList

-- | Send a list of 'Int16's in one bulk MathLink 16-bit integer-list call.
putInt16List :: [Int16] -> ML ()
putInt16List ns = putListWith MLIO.mlPutInt16List fromIntegral ns

-- | Receive a list of 'Int16's from one bulk MathLink 16-bit integer-list call.
getInt16List :: ML [Int16]
getInt16List = getListWith fetch release fromIntegral
  where fetch   = MLIO.mlGetInt16List
        release = MLIO.mlReleaseInt16List

-- | Send a list of 'Int32's in one bulk MathLink 32-bit integer-list call.
putInt32List :: [Int32] -> ML ()
putInt32List ns = putListWith MLIO.mlPutInt32List fromIntegral ns

-- | Receive a list of 'Int32's from one bulk MathLink 32-bit integer-list call.
getInt32List :: ML [Int32]
getInt32List = getListWith fetch release fromIntegral
  where fetch   = MLIO.mlGetInt32List
        release = MLIO.mlReleaseInt32List

-- | Send a list of 'Float's in one bulk MathLink 32-bit real-list call.
putReal32List :: [Float] -> ML ()
putReal32List xs = putListWith MLIO.mlPutReal32List realToFrac xs

-- | Receive a list of 'Float's from one bulk MathLink 32-bit real-list call.
getReal32List :: ML [Float]
getReal32List = getListWith fetch release realToFrac
  where fetch   = MLIO.mlGetReal32List
        release = MLIO.mlReleaseReal32List

-- | Send a list of 'Double's in one bulk MathLink 64-bit real-list call.
putReal64List :: [Double] -> ML ()
putReal64List xs = putListWith MLIO.mlPutReal64List realToFrac xs

-- | Receive a list of 'Double's from one bulk MathLink 64-bit real-list call.
getReal64List :: ML [Double]
getReal64List = getListWith fetch release realToFrac
  where fetch   = MLIO.mlGetReal64List
        release = MLIO.mlReleaseReal64List

-- | Send a packed multi-dimensional array using a MathLink "put array"
-- primitive.
--
-- Only the first @product dims@ elements of @xs@ are converted with @cnv@
-- and sent. Every level of the array is given the head @\"List\"@. @fn@ is
-- called with the data, the extents, the per-level heads and the rank. The
-- status it returns is passed to 'throwOnError'.
putArrayWith :: Storable a
             => (Link -> Ptr a -> Ptr CInt -> Ptr CString -> CInt -> IO CInt)
             -> (b -> a)
             -> [Int]
             -> [b]
             -> ML ()
putArrayWith fn cnv dims xs = do
  lnk <- getLink
  status <- liftIO $
    withArray (map cnv (take total xs)) $ \dataPtr ->
    withArrayLen (map fromIntegral dims) $ \nDims dimPtr ->
    withCString "List" $ \listHead ->
    withArray (replicate nDims listHead) $ \headsPtr ->
      fn lnk dataPtr dimPtr headsPtr (fromIntegral nDims)
  throwOnError status
  where
    total = product dims

-- | Receive a packed multi-dimensional array using a MathLink "get array"
-- primitive and its matching "release" primitive.
--
-- On success the result is the list of extents (one per level) and the
-- elements converted with @cnv@, @product extents@ of them. All buffers
-- reported by @afn@ are then handed back through @rfn@. If @afn@ reports
-- failure, the link's error message is raised with 'throwError'.
getArrayWith :: Storable a
             => (Link -> Ptr (Ptr a) -> Ptr (Ptr CInt) -> Ptr (Ptr CString) -> Ptr CInt -> IO CInt)
             -> (Link -> Ptr a -> Ptr CInt -> Ptr CString -> CInt -> IO ())
             -> (a -> b)
             -> ML ([Int],[b])
getArrayWith afn rfn cnv = do
  l <- getLink
  -- The out-parameters only live for the duration of the call, so stack
  -- allocation via 'alloca' is enough.
  mArr <- liftIO $
    alloca $ \xPtrPtr ->
    alloca $ \dimPtrPtr ->
    alloca $ \headPtrPtr ->
    alloca $ \rankPtr -> do
      ok <- afn l xPtrPtr dimPtrPtr headPtrPtr rankPtr >>= MLIO.convToBool
      if ok
        then do
          let acquire = do
                xPtr    <- peek xPtrPtr
                dimPtr  <- peek dimPtrPtr
                headPtr <- peek headPtrPtr
                rank'   <- peek rankPtr
                return (xPtr, dimPtr, headPtr, rank')
              release (xPtr, dimPtr, headPtr, rank') =
                rfn l xPtr dimPtr headPtr rank'
              copyOut (xPtr, dimPtr, _, rank') = do
                dims' <- peekArray (fromIntegral rank') dimPtr
                let dims = map fromIntegral dims'
                xs' <- peekArray (product dims) xPtr
                return (Right (dims, map cnv xs'))
          -- 'bracket' guarantees the buffers are released through 'rfn'
          -- even if copying them out throws.
          bracket acquire release copyOut
        else fmap Left (MLIO.getErrorMessage l)
  either throwError return mArr

-- | One-dimensional 'Int' index. The extent is the 'rangeSize' of the
-- bounds, and a received extent @n@ becomes the bounds @(1, n)@.
instance Dimensional Int where
    rank _ = 1
    dimensions = (: []) . rangeSize
    fromDimensions dims = case dims of
      [n] -> (1,n)
      _   -> error "Unexpected number of dimensions."

-- | Two-dimensional 'Int' index. Each axis is delegated to the
-- one-dimensional 'Int' instance.
instance Dimensional (Int,Int) where
    rank _ = 2
    dimensions ((l1,l2),(u1,u2)) =
      concatMap dimensions [(l1,u1), (l2,u2)]
    fromDimensions [n1,n2] =
      let (lo1,hi1) = fromDimensions [n1]
          (lo2,hi2) = fromDimensions [n2]
      in ((lo1,lo2),(hi1,hi2))
    fromDimensions _ = error "Unexpected number of dimensions."

-- | Three-dimensional 'Int' index. Each axis is delegated to the
-- one-dimensional 'Int' instance.
instance Dimensional (Int,Int,Int) where
    rank _ = 3
    dimensions ((l1,l2,l3),(u1,u2,u3)) =
      concatMap dimensions [(l1,u1), (l2,u2), (l3,u3)]
    fromDimensions [n1,n2,n3] =
      let (lo1,hi1) = fromDimensions [n1]
          (lo2,hi2) = fromDimensions [n2]
          (lo3,hi3) = fromDimensions [n3]
      in ((lo1,lo2,lo3),(hi1,hi2,hi3))
    fromDimensions _ = error "Unexpected number of dimensions."

-- | Four-dimensional 'Int' index. Each axis is delegated to the
-- one-dimensional 'Int' instance.
instance Dimensional (Int,Int,Int,Int) where
    rank _ = 4
    dimensions ((l1,l2,l3,l4),(u1,u2,u3,u4)) =
      concatMap dimensions [(l1,u1), (l2,u2), (l3,u3), (l4,u4)]
    fromDimensions [n1,n2,n3,n4] =
      let (lo1,hi1) = fromDimensions [n1]
          (lo2,hi2) = fromDimensions [n2]
          (lo3,hi3) = fromDimensions [n3]
          (lo4,hi4) = fromDimensions [n4]
      in ((lo1,lo2,lo3,lo4),(hi1,hi2,hi3,hi4))
    fromDimensions _ = error "Unexpected number of dimensions."
    
-- | Five-dimensional 'Int' index. Each axis is delegated to the
-- one-dimensional 'Int' instance.
instance Dimensional (Int,Int,Int,Int,Int) where
    rank _ = 5
    dimensions ((l1,l2,l3,l4,l5),(u1,u2,u3,u4,u5)) =
      concatMap dimensions [(l1,u1), (l2,u2), (l3,u3), (l4,u4), (l5,u5)]
    fromDimensions [n1,n2,n3,n4,n5] =
      let (lo1,hi1) = fromDimensions [n1]
          (lo2,hi2) = fromDimensions [n2]
          (lo3,hi3) = fromDimensions [n3]
          (lo4,hi4) = fromDimensions [n4]
          (lo5,hi5) = fromDimensions [n5]
      in ((lo1,lo2,lo3,lo4,lo5),(hi1,hi2,hi3,hi4,hi5))
    fromDimensions _ = error "Unexpected number of dimensions."

-- | Unboxed 'Int16' arrays travel as packed MathLink 16-bit integer arrays.
-- Extents come from 'dimensions' of the bounds and data in 'A.elems' order.
-- Received arrays get their bounds from 'fromDimensions'.
-- NOTE(review): the Dimensional instances in this module call 'error' for an
-- unexpected rank. That error surfaces lazily when the bounds are forced,
-- not as an ML error.
instance ( Dimensional ix
         ) => Expressible (A.UArray ix Int16) where
    put arr =
      putArrayWith MLIO.mlPutInt16Array fromIntegral
                   (dimensions (A.bounds arr)) (A.elems arr)
    get =
      getArrayWith MLIO.mlGetInt16Array MLIO.mlReleaseInt16Array fromIntegral
        >>= \(dims, xs) -> return (A.listArray (fromDimensions dims) xs)

-- | Unboxed 'Int32' arrays travel as packed MathLink 32-bit integer arrays.
-- Extents come from 'dimensions' of the bounds and data in 'A.elems' order.
-- Received arrays get their bounds from 'fromDimensions'.
-- NOTE(review): the Dimensional instances in this module call 'error' for an
-- unexpected rank. That error surfaces lazily when the bounds are forced,
-- not as an ML error.
instance ( Dimensional ix
         ) => Expressible (A.UArray ix Int32) where
    put arr =
      putArrayWith MLIO.mlPutInt32Array fromIntegral
                   (dimensions (A.bounds arr)) (A.elems arr)
    get =
      getArrayWith MLIO.mlGetInt32Array MLIO.mlReleaseInt32Array fromIntegral
        >>= \(dims, xs) -> return (A.listArray (fromDimensions dims) xs)

-- | Unboxed 'Int' arrays travel as packed MathLink native integer arrays.
-- Extents come from 'dimensions' of the bounds and data in 'A.elems' order.
-- Received arrays get their bounds from 'fromDimensions'.
-- NOTE(review): the Dimensional instances in this module call 'error' for an
-- unexpected rank. That error surfaces lazily when the bounds are forced,
-- not as an ML error.
instance ( Dimensional ix
         ) => Expressible (A.UArray ix Int) where
    put arr =
      putArrayWith MLIO.mlPutIntArray fromIntegral
                   (dimensions (A.bounds arr)) (A.elems arr)
    get =
      getArrayWith MLIO.mlGetIntArray MLIO.mlReleaseIntArray fromIntegral
        >>= \(dims, xs) -> return (A.listArray (fromDimensions dims) xs)

-- | Unboxed 'Float' arrays travel as packed MathLink 32-bit real arrays.
-- Extents come from 'dimensions' of the bounds and data in 'A.elems' order.
-- Received arrays get their bounds from 'fromDimensions'.
-- NOTE(review): the Dimensional instances in this module call 'error' for an
-- unexpected rank. That error surfaces lazily when the bounds are forced,
-- not as an ML error.
instance ( Dimensional ix
         ) => Expressible (A.UArray ix Float) where
    put arr =
      putArrayWith MLIO.mlPutReal32Array realToFrac
                   (dimensions (A.bounds arr)) (A.elems arr)
    get =
      getArrayWith MLIO.mlGetReal32Array MLIO.mlReleaseReal32Array realToFrac
        >>= \(dims, xs) -> return (A.listArray (fromDimensions dims) xs)

-- | Unboxed 'Double' arrays travel as packed MathLink 64-bit real arrays.
-- Extents come from 'dimensions' of the bounds and data in 'A.elems' order.
-- Received arrays get their bounds from 'fromDimensions'.
-- NOTE(review): the Dimensional instances in this module call 'error' for an
-- unexpected rank. That error surfaces lazily when the bounds are forced,
-- not as an ML error.
instance ( Dimensional ix
         ) => Expressible (A.UArray ix Double) where
    put arr =
      putArrayWith MLIO.mlPutReal64Array realToFrac
                   (dimensions (A.bounds arr)) (A.elems arr)
    get =
      getArrayWith MLIO.mlGetReal64Array MLIO.mlReleaseReal64Array realToFrac
        >>= \(dims, xs) -> return (A.listArray (fromDimensions dims) xs)