{-# LANGUAGE GeneralizedNewtypeDeriving
           , ExistentialQuantification
  #-}

module Foreign.MathLink.ML ( runMathLink
                           , runMathLinkWithArgs
                           , evaluate
                           , evaluateString
                           , getLink
                           , throwOnError
                           , boolToError
                           , getType
                           , putFunctionHead
                           , getFunctionHead
                           , putScalarWith
                           , getScalarWith
                           , putStringWith
                           , getStringWith
                           , withLink0
                           , withLink1
                           , withLink2
                           , withLink3
                           , withLink4
                           ) where

import Foreign.MathLink.Types
import qualified Foreign.MathLink.IO as MLIO

import Foreign
import Foreign.C
import Foreign.Storable
import Control.Exception (bracket)
import Control.Monad
import Control.Monad.Trans
import qualified Control.Monad.Reader as Rd
import qualified Control.Monad.State as St
import qualified Control.Monad.Error as Er
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM
import Data.Int
import System.IO
import System.Environment

-- | Returns the 'Link' stored in the 'Config' of the running 'ML'
--   computation.
getLink :: ML Link
getLink = Rd.asks link

-- | Drains pending /MathLink/ messages from the link.
--
-- A 'TerminateMessage' sets 'done'. An 'InterruptMessage' or
-- 'AbortMessage' sets 'abort'. Other known messages are ignored. Draining
-- stops when 'checkMessage' reports nothing pending, when 'getMessage'
-- yields 'Nothing', or at the first 'UnknownMessage'.
updateState :: ML ()
updateState = do
  pending <- checkMessage
  when pending $ getMessage >>= maybe (return ()) (react . fst)
  where
    react TerminateMessage   = flag (\st -> st { done = True })
    react InterruptMessage   = flag (\st -> st { abort = True })
    react AbortMessage       = flag (\st -> st { abort = True })
    react (UnknownMessage _) = return ()
    react _                  = updateState
    -- record the message in the state, then keep draining
    flag f = St.modify f >> updateState

-- | Runs /MathLink/ with the program's own command-line arguments
--   (from 'getArgs'), exposing the given list of functions.
runMathLink :: [Function] -> IO ()
runMathLink functions =
  getArgs >>= \args -> runMathLinkWithArgs args functions

-- | Like 'runMathLink', but takes the command-line arguments to be passed
--   to /MathLink/ explicitly.
--
-- The environment and link are acquired and released with 'bracket'.
-- Functions are indexed by their position in the list. An error returned
-- by the 'ML' computation is written to 'stderr'.
runMathLinkWithArgs :: [String] -> [Function] -> IO ()
runMathLinkWithArgs args functions =
  bracket MLIO.acquireEnvironment MLIO.releaseEnvironment $ \env ->
    bracket (MLIO.acquireLink env args) MLIO.releaseLink $ \lnk -> do
      outcome <- runML runLoop initialState (configFor env lnk)
      either (hPutStrLn stderr) return outcome
  where
    initialState = State { abort = False, done = False }
    configFor env lnk =
      Config { environment   = env
             , link          = lnk
             , functionTable = IM.fromList (zip [0 ..] functions)
             }
            
-- | Installs the exported function table, then serves packets until
--   'processPackets' returns.
runLoop :: ML ()
runLoop = installFunctionTable >> processPackets

-- | Serves packets via 'answer'. Whenever 'answer' stops on a
--   'ResumePacket', declines to act as a front end and continues. Any other
--   packet ends the loop.
processPackets :: ML ()
processPackets = do
  pkt <- answer
  when (pkt == ResumePacket) $ do
    refuseToBeAFrontEnd
    processPackets
          
-- | Serves 'CallPacket's until some other packet arrives, and returns that
--   packet.
--
-- Each call packet is dispatched through 'processCallPacket'. Then
-- 'endPacket' and 'newPacket' are issued before the next packet is read.
answer :: ML Packet
answer = do
  pkt <- getPacket
  case pkt of
    CallPacket -> do
      processCallPacket
      endPacket
      newPacket
      -- The packet that ends the loop is the one the recursive call
      -- stopped on, not the 'CallPacket' just handled; returning the
      -- latter would hide e.g. a 'ResumePacket' from 'processPackets'.
      answer
    _ -> return pkt

-- | Asks /Mathematica/ to @Print@ the given text and waits for the
--   evaluation to finish.
--
-- Double quotes and backslashes are escaped, so arbitrary text (such as an
-- error message) cannot break out of the string literal in the generated
-- @Print@ expression.
--
-- The call blocks (via 'evaluateString') so that the reply to the @Print@
-- evaluation is consumed here. Otherwise 'answer' would later read it as a
-- non-call packet and end the call loop.
printString :: String -> ML ()
printString str = do
  _ <- evaluateString $ "Print[\"" ++ concatMap escape str ++ "\"]"
  return ()
  where
    escape c
      | c == '"' || c == '\\' = ['\\', c]
      | otherwise             = [c]

-- | Handles one 'CallPacket'. Reads the integer index sent by
--   /Mathematica/, looks it up in the 'functionTable' and runs that
--   function.
--
-- If anything throws along the way, the link error is cleared, the message
-- is shown via 'printString', and the symbol @$Failed@ is sent as the
-- result instead.
processCallPacket :: ML ()
processCallPacket = dispatch `Er.catchError` recover
  where
    dispatch = do
      expr <- get
      case expr of
        ExInt n -> do
          table <- Rd.asks functionTable
          maybe (Er.throwError "Function lookup failed.") function
                (IM.lookup n table)
        _ -> Er.throwError "Expected int."
    recover err = do
      _ <- clearError
      printString err
      put $ ExSymbol "$Failed"
  

-- | Responds to a 'ResumePacket' (see 'processPackets') by declining to act
--   as a front end.
--
-- Writes the following expression to the link:
--
-- > EvaluatePacket[
-- >   Module[{me = $ParentLink},
-- >     $ParentLink = EXPR;
-- >     Message[MessageName[$ParentLink, "notfe"], me];
-- >     me]]
--
-- Here EXPR is whatever 'MLIO.transferExpression' copies from the link back
-- onto the same link. NOTE(review): presumably this is the content of the
-- pending 'ResumePacket'; confirm.
--
-- Afterwards, packets are read and discarded up to and including the first
-- 'SuspendPacket' (see 'waitForPacket').
refuseToBeAFrontEnd :: ML ()
refuseToBeAFrontEnd = do
  putFunctionHead "EvaluatePacket" 1
  -- Module[{me = $ParentLink}, ...]
  putFunctionHead "Module" 2
  putFunctionHead "List" 1
  putFunctionHead "Set" 2
  put meSym
  put plSym
  -- body: $ParentLink = EXPR; Message[...]; me
  putFunctionHead "CompoundExpression" 3
  putFunctionHead "Set" 2
  put plSym
  -- copy the next expression from the link onto itself as the right-hand
  -- side of the Set; NOTE(review): the result of this call is ignored
  getLink >>= (liftIO . (\l -> MLIO.transferExpression l l))
  -- Message[MessageName[$ParentLink, "notfe"], me]
  putFunctionHead "Message" 2
  putFunctionHead "MessageName" 2
  put plSym
  put $ ExString "notfe"
  put meSym
  -- final element of the CompoundExpression
  put meSym
  -- NOTE(review): the Bool returned by endPacket is ignored here
  endPacket
  waitForPacket (== SuspendPacket)
  where meSym = ExSymbol "me"
        plSym = ExSymbol "$ParentLink"

-- | Activates the link and registers every entry of the 'functionTable'
--   with /Mathematica/ through 'definePattern'. The list is terminated with
--   the symbol @End@, and then the link is flushed.
installFunctionTable :: ML ()
installFunctionTable = do
  activate
  table <- Rd.asks functionTable
  forM_ (IM.toList table) definePattern
  put (ExSymbol "End")
  flush

-- | Sends @DefineExternal[callPattern, argumentPattern, ident]@ for one
--   exported function. @ident@ is the index that later arrives in a
--   'CallPacket' (see 'processCallPacket').
definePattern :: (Int,Function) -> ML ()
definePattern (ident, func) = put definition
  where
    definition = ExFunction "DefineExternal" fields
    fields     = [ ExString (callPattern func)
                 , ExString (argumentPattern func)
                 , ExInt ident
                 ]

-- | Sends the given 'String' to /Mathematica/ for evaluation, as
--   @EvaluatePacket[ToExpression[s]]@.
--
-- Does not block: any reply is left on the link for the caller to read
-- (compare 'evaluateString'). Returns 'True' if the packet was written and
-- ended successfully. Returns 'False' if any /MathLink/ operation along the
-- way reported an error; that error is not rethrown.
evaluate :: String -> ML Bool
evaluate s =
    do put $ ExFunction "EvaluatePacket"
               [ ExFunction "ToExpression"
                   [ ExString s ] ]
       -- a failed endPacket means the packet never went out; report it
       endPacket >>= boolToError
       return True
    `Er.catchError` \_ -> return False

-- | Like 'evaluate', but blocks until the evaluation is complete. Packets
--   are read and discarded up to and including the 'ReturnPacket'.
--
-- If 'evaluate' reports failure, the 'EvaluatePacket' was not completely
-- sent and no 'ReturnPacket' can be expected. In that case 'False' is
-- returned immediately instead of waiting indefinitely.
evaluateString :: String -> ML Bool
evaluateString s = do
    sent <- evaluate s
    when sent $ waitForPacket (== ReturnPacket)
    return sent

-- | Reads packets, calling 'newPacket' after each one, until a packet
--   satisfies the predicate. The matching packet is consumed as well.
waitForPacket :: (Packet -> Bool) -> ML ()
waitForPacket q = loop
  where
    loop = do
      pkt <- getPacket
      _ <- newPacket
      unless (q pkt) loop
      

-- misc

-- | Activates the link, throwing the link's error message on failure.
activate :: ML ()
activate = boolToError =<< withLink0 MLIO.activate

-- | Flushes the link, throwing the link's error message on failure.
flush :: ML ()
flush = boolToError =<< withLink0 MLIO.flush

-- | Returns the result of 'MLIO.checkReady' for the current link.
checkReady :: ML Bool
checkReady = getLink >>= liftIO . MLIO.checkReady


-- errors

-- | Returns the current error of the link, as reported by 'MLIO.getError'.
getError :: ML Error
getError = getLink >>= liftIO . MLIO.getError

-- | Clears the link's error state, returning the status reported by
--   'MLIO.clearError'.
clearError :: ML Bool
clearError = getLink >>= liftIO . MLIO.clearError

-- | Returns the link's current error message.
getErrorMessage :: ML String
getErrorMessage = getLink >>= liftIO . MLIO.getErrorMessage

-- | Interprets a raw status code: @0@ means failure and throws the link's
--   current error message; any other value is accepted silently.
throwOnError :: Integral a => a -> ML ()
throwOnError i = when (i == 0) $ getErrorMessage >>= Er.throwError

-- | Turns a 'False' status into a thrown error carrying the link's current
--   error message; 'True' does nothing.
boolToError :: Bool -> ML ()
boolToError ok = unless ok $ getErrorMessage >>= Er.throwError

-- packets

-- | Reads the head of the next packet on the link.
getPacket :: ML Packet
getPacket = getLink >>= liftIO . MLIO.getPacket

-- | Calls 'MLIO.newPacket' on the current link and returns its status.
newPacket :: ML Bool
newPacket = getLink >>= liftIO . MLIO.newPacket

-- | Calls 'MLIO.endPacket' on the current link and returns its status.
endPacket :: ML Bool
endPacket = getLink >>= liftIO . MLIO.endPacket


-- messages

-- | Reads a pending out-of-band message from the link, if any.
getMessage :: ML (Maybe (Message,Int))
getMessage = getLink >>= liftIO . MLIO.getMessage

-- | Sends an out-of-band message on the link, returning the reported status.
putMessage :: Message -> ML Bool
putMessage msg = getLink >>= \l -> liftIO (MLIO.putMessage l msg)

-- | Reports whether an out-of-band message is waiting on the link.
checkMessage :: ML Bool
checkMessage = getLink >>= liftIO . MLIO.checkMessage

-- | Helper for marshaling scalar values to /Mathematica/.
--
-- Converts the value with @cnv@ and hands it to the raw putter @fn@. A
-- returned status of @0@ is turned into an error via 'throwOnError'.
putScalarWith :: (Link -> b -> IO CInt)
              -> (a -> b)
              -> a
              -> ML ()
putScalarWith fn cnv x = do
  status <- withLink1 fn (cnv x)
  throwOnError status

-- | Helper for marshaling scalar values from /Mathematica/.
--
-- @fn@ writes the value into a temporary buffer and returns a status
-- (interpreted by 'MLIO.convToBool'). On success the value is read back
-- and converted with @cnv@; otherwise the link's error message is thrown.
getScalarWith :: Storable a
              => (Link -> Ptr a -> IO CInt)
              -> (a -> b)
              -> ML b
getScalarWith fn cnv = do
  l <- getLink
  -- 'alloca' gives a scoped temporary buffer, freed on exit or exception
  result <- liftIO $ alloca $ \xPtr -> do
    ok <- fn l xPtr >>= MLIO.convToBool
    if ok
      then fmap (Right . cnv) (peek xPtr)
      else fmap Left (MLIO.getErrorMessage l)
  either Er.throwError return result

-- | Helper for marshaling 'String's to /Mathematica/.
--
-- Passes the string as a temporary C string (see 'withCString') to the
-- raw putter @fn@. A returned status of @0@ becomes an error via
-- 'throwOnError'.
putStringWith :: (Link -> CString -> IO CInt)
              -> String
              -> ML ()
putStringWith fn str = do
  l <- getLink
  status <- liftIO $ withCString str (fn l)
  throwOnError status

-- | Helper for marshaling 'String's from /Mathematica/.
--
-- @afn@ stores a pointer to a C string in a temporary buffer and returns a
-- status (interpreted by 'MLIO.convToBool'). On success the string is
-- copied into Haskell and the C string is handed back to @rfn@ for
-- release. Otherwise the link's error message is thrown.
getStringWith :: (Link -> Ptr CString -> IO CInt)
              -> (Link -> CString -> IO ())
              -> ML String
getStringWith afn rfn = do
  l <- getLink
  result <- liftIO $ alloca $ \strPtrPtr -> do
    ok <- afn l strPtrPtr >>= MLIO.convToBool
    if ok
      -- 'bracket' guarantees rfn runs even if peekCString throws
      then fmap Right (bracket (peek strPtrPtr) (rfn l) peekCString)
      else fmap Left (MLIO.getErrorMessage l)
  either Er.throwError return result

-- | Gets the type of the next expression to be read on the /MathLink/
--   connection, decoded with 'mkType'.
getType :: ML Type
getType = liftM mkType (withLink0 MLIO.mlGetType)

-- | Writes a function head with the given name and argument count. The
--   arguments themselves must be put afterwards, as 'put' does for
--   'ExFunction'.
putFunctionHead :: String -> Int -> ML ()
putFunctionHead hd n = putStringWith putHead hd
  where putHead l s = MLIO.mlPutFunction l s (fromIntegral n)

-- | Reads a function head from the link, returning the head's name and its
--   argument count. Throws the link's error message on failure.
--
-- The head's C string is released with 'MLIO.mlReleaseSymbol' even if
-- copying it into Haskell throws.
getFunctionHead :: ML (String,Int)
getFunctionHead = do
  l <- getLink
  result <- liftIO $ alloca $ \strPtrPtr -> alloca $ \nPtr -> do
    ok <- MLIO.mlGetFunction l strPtrPtr nPtr >>= MLIO.convToBool
    if ok
      then do
        hd <- bracket (peek strPtrPtr) (MLIO.mlReleaseSymbol l) peekCString
        n  <- peek nPtr
        return $ Right (hd, fromIntegral n)
      else fmap Left (MLIO.getErrorMessage l)
  either Er.throwError return result

instance Expressible Expression where
    -- Writes an expression to the link. A function is written as its head
    -- (name plus argument count) followed by each argument in order.
    put e =
      case e of
        ExInt i -> putScalarWith MLIO.mlPutInt fromIntegral i
        ExReal r -> putScalarWith MLIO.mlPutReal64 realToFrac r
        ExString s -> putStringWith MLIO.mlPutString s
        ExSymbol s -> putStringWith MLIO.mlPutSymbol s
        ExFunction hd args -> do
          putFunctionHead hd (length args)
          mapM_ put args

    -- Reads the next expression from the link, dispatching on 'getType'.
    -- An 'ErrorType' is turned into a thrown error with the link's message.
    get = do
      typ <- getType
      case typ of
        ErrorType -> getErrorMessage >>= Er.throwError
        IntType ->
          liftM ExInt (getScalarWith MLIO.mlGetInt fromIntegral)
        RealType ->
          liftM ExReal (getScalarWith MLIO.mlGetReal64 realToFrac)
        StringType ->
          liftM ExString (getStringWith MLIO.mlGetString MLIO.mlReleaseString)
        SymbolType ->
          liftM ExSymbol (getStringWith MLIO.mlGetSymbol MLIO.mlReleaseSymbol)
        FunctionType -> do
          (hd,nArgs) <- getFunctionHead
          -- read exactly as many arguments as the head announced
          args <- replicateM nArgs get
          return $ ExFunction hd args


-- lifting utilities

-- | Lifts an 'IO' action that takes the current link into 'ML'.
withLink0 :: (Link -> IO a) -> ML a
withLink0 f = getLink >>= liftIO . f

-- | Like 'withLink0', for actions taking one extra argument after the link.
withLink1 :: (Link -> a -> IO b) -> (a -> ML b)
withLink1 f x = do
  l <- getLink
  liftIO (f l x)
       
-- | Like 'withLink0', for actions taking two extra arguments after the link.
withLink2 :: (Link -> a -> b -> IO c) -> (a -> b -> ML c)
withLink2 f x y = do
  l <- getLink
  liftIO (f l x y)

-- | Like 'withLink0', for actions taking three extra arguments after the
--   link.
withLink3 :: (Link -> a -> b -> c -> IO d) -> (a -> b -> c -> ML d)
withLink3 f x y z = do
  l <- getLink
  liftIO (f l x y z)

-- | Like 'withLink0', for actions taking four extra arguments after the
--   link.
withLink4 :: (Link -> a -> b -> c -> d -> IO e) -> (a -> b -> c -> d -> ML e)
withLink4 f x y z w = do
  l <- getLink
  liftIO (f l x y z w)