module Numeric.GSL.SimulatedAnnealing (
  
  simanSolve
  
  , SimulatedAnnealingParams(..)
  ) where
import Control.Exception (bracket)
import Data.IORef (IORef, newIORef, writeIORef, readIORef, modifyIORef')
import Foreign.C.Types
import Foreign.Marshal.Utils(with)
import Foreign.Ptr(Ptr, FunPtr, nullFunPtr, freeHaskellFunPtr)
import Foreign.StablePtr(StablePtr, newStablePtr, deRefStablePtr, freeStablePtr)
import Foreign.Storable(Storable(..))
import System.IO (hFlush, stdout)
import System.IO.Unsafe(unsafePerformIO)

import Data.Vector.Storable(generateM)
import Numeric.LinearAlgebra.HMatrix hiding(step)

import Numeric.GSL.Internal
-- | Tuning parameters for 'simanSolve'.
--
-- The field order is also the order in which the 'Storable' instance lays the
-- record out for the C side. The names follow GSL's @gsl_siman_params_t@; the
-- per-field descriptions below paraphrase the GSL manual.
data SimulatedAnnealingParams = SimulatedAnnealingParams
  { n_tries           :: CInt    -- ^ GSL @n_tries@: points tried for each step
  , iters_fixed_T     :: CInt    -- ^ GSL @iters_fixed_T@: iterations per temperature
  , step_size         :: Double  -- ^ GSL @step_size@: step size handed to the step function
  , boltzmann_k       :: Double  -- ^ GSL @k@: Boltzmann constant
  , cooling_t_initial :: Double  -- ^ GSL @t_initial@: starting temperature
  , cooling_mu_t      :: Double  -- ^ GSL @mu_t@: cooling factor
  , cooling_t_min     :: Double  -- ^ GSL @t_min@: final (minimum) temperature
  }
  deriving (Eq, Show, Read)
-- | Marshals the record into the flat parameter struct handed to @siman@.
-- The layout is two 'CInt's followed by five 'Double's, in field-declaration
-- order, packed back to back with no padding added by this instance.
--
-- NOTE(review): this matches the C struct only while the two ints end exactly
-- where the first double's alignment requires (e.g. 4-byte @int@ and double
-- alignment of at most 8). Confirm on unusual targets.
instance Storable SimulatedAnnealingParams where
  sizeOf _ = 2 * sizeOf (0 :: CInt) + 5 * sizeOf (0 :: Double)

  alignment _ = alignment (0 :: Double)

  peek ptr = do
    tries  <- field 0
    itersT <- field 1
    stepSz <- field 2
    kB     <- field 3
    tInit  <- field 4
    muT    <- field 5
    tMin   <- field 6
    return (SimulatedAnnealingParams tries itersT stepSz kB tInit muT tMin)
    where
      -- read the n-th field (0-based) at its byte offset
      field :: Storable x => Int -> IO x
      field n = peekByteOff ptr (simanParamOffset n)

  poke ptr sap = do
    store 0 (n_tries sap)
    store 1 (iters_fixed_T sap)
    store 2 (step_size sap)
    store 3 (boltzmann_k sap)
    store 4 (cooling_t_initial sap)
    store 5 (cooling_mu_t sap)
    store 6 (cooling_t_min sap)
    where
      -- write the n-th field (0-based) at its byte offset
      store :: Storable x => Int -> x -> IO ()
      store n = pokeByteOff ptr (simanParamOffset n)

-- | Byte offset of the @n@-th (0-based) field of 'SimulatedAnnealingParams'
-- in the marshalled struct. The two 'CInt' fields come first, followed by
-- the 'Double' fields.
simanParamOffset :: Int -> Int
simanParamOffset n
  | n < 2     = n * intBytes
  | otherwise = 2 * intBytes + (n - 2) * doubleBytes
  where
    intBytes    = sizeOf (0 :: CInt)
    doubleBytes = sizeOf (0 :: Double)
-- | Opaque handle through which the C side refers to a configuration: a
-- 'StablePtr' to a mutable 'IORef' holding the Haskell value.
--
-- Using an 'IORef' lets the copy callback ('copyConfig') overwrite the value
-- behind an existing handle in place. The copy constructor
-- ('copyConstructConfig') instead allocates a fresh 'IORef' and a new handle.
type P a = StablePtr (IORef a)
-- | Copy callback: overwrite the configuration behind the second handle with
-- the one currently stored behind the first. Both handles stay valid.
copyConfig :: P a -> P a -> IO ()
copyConfig from to = do
  fromRef <- deRefStablePtr from
  toRef <- deRefStablePtr to
  value <- readIORef fromRef
  writeIORef toRef value
-- | Copy-constructor callback: snapshot the configuration behind the given
-- handle into a brand-new 'IORef' and return a fresh handle to it. The new
-- handle must eventually be released (see 'destroyConfig').
copyConstructConfig :: P a -> IO (P a)
copyConstructConfig x = deRefRead x >>= newIORef >>= newStablePtr
-- | Destructor callback: release a configuration handle. Only the
-- 'StablePtr' is freed. Once nothing else refers to the 'IORef' behind it,
-- it is left to the garbage collector.
destroyConfig :: P a -> IO ()
destroyConfig = freeStablePtr
-- | Follow a configuration handle and read the value currently stored in
-- its 'IORef'.
deRefRead :: P a -> IO a
deRefRead p = do
  ref <- deRefStablePtr p
  readIORef ref
-- | Adapt a pure energy functional to the energy callback type
-- @P a -> Double@.
--
-- The callback type has no 'IO', so the read of the handle's 'IORef' is
-- hidden behind 'unsafePerformIO'. The read depends on the handle argument,
-- so it happens anew on every application (i.e. every call from C), when
-- the resulting 'Double' is demanded.
wrapEnergy :: (a -> Double) -> P a -> Double
wrapEnergy energy handle = unsafePerformIO $ do
  current <- deRefRead handle
  return (energy current)
-- | Adapt a pure metric (distance between two configurations) to the metric
-- callback type @P a -> P a -> Double@.
--
-- As in 'wrapEnergy', the two 'IORef' reads are hidden with 'unsafePerformIO'
-- because the callback type is pure. The first handle is read before the
-- second.
wrapMetric :: (a -> a -> Double) -> P a -> P a -> Double
wrapMetric metric p q = unsafePerformIO $ do
  left <- deRefRead p
  right <- deRefRead q
  return (metric left right)
-- | Adapt a pure stepping function to the step callback.
--
-- On each call, @nrand@ uniform random numbers are drawn from the C-side
-- generator into a 'Vector'. The configuration behind the handle is then
-- replaced by @f randoms stepSize old@. The update is strict (to weak head
-- normal form), so unevaluated steps do not pile up inside the 'IORef'.
wrapStep :: Int
         -> (Vector Double -> Double -> a -> a)
         -> GSLRNG
         -> P a
         -> Double
         -> IO ()
wrapStep count stepper (GSLRNG gen) handle size = do
  draws <- generateM count (const (gslRngUniform gen))
  ref <- deRefStablePtr handle
  modifyIORef' ref (stepper draws size)
-- | Adapt a user rendering function to the print callback. It writes the
-- rendered configuration to stdout without a trailing newline, then flushes
-- stdout so the text appears in order with any output from the C side.
wrapPrint :: (a -> String) -> P a -> IO ()
wrapPrint render handle = do
  current <- deRefRead handle
  putStr (render current)
  hFlush stdout
-- "wrapper" imports: each one turns a Haskell closure into a C function
-- pointer that can be passed to 'siman' as one of its callbacks.
-- Per the Haskell FFI specification, the storage behind such a 'FunPtr' is
-- not reclaimed automatically. It must be released with 'freeHaskellFunPtr'
-- once C will no longer call it.

-- | Wrap the energy callback.
foreign import ccall safe "wrapper"
  mkEnergyFun :: (P a -> Double) -> IO (FunPtr (P a -> Double))
-- | Wrap the metric callback (distance between two configurations).
foreign import ccall safe "wrapper"
  mkMetricFun :: (P a -> P a -> Double) -> IO (FunPtr (P a -> P a -> Double))
-- | Wrap the step callback (generator handle, configuration, step size).
foreign import ccall safe "wrapper"
  mkStepFun :: (GSLRNG -> P a -> Double -> IO ())
            -> IO (FunPtr (GSLRNG -> P a -> Double -> IO ()))
-- | Wrap the copy callback (source handle, destination handle).
foreign import ccall safe "wrapper"
  mkCopyFun :: (P a -> P a -> IO ()) -> IO (FunPtr (P a -> P a -> IO ()))
-- | Wrap the copy-constructor callback.
foreign import ccall safe "wrapper"
  mkCopyConstructorFun :: (P a -> IO (P a)) -> IO (FunPtr (P a -> IO (P a)))
-- | Wrap a @P a -> IO ()@ callback. 'simanSolve' uses it both for the
-- destructor and, because the type is identical, for the optional print
-- callback.
foreign import ccall safe "wrapper"
  mkDestructFun :: (P a -> IO ()) -> IO (FunPtr (P a -> IO ()))
-- | Opaque handle to the C-side random number generator that 'siman' passes
-- to the step callback. The wrapped pointer is only ever handed to
-- @gsl_rng_uniform@ (via 'gslRngUniform'). The 'Ptr' type argument is just a
-- tag; the pointee is never inspected from Haskell.
newtype GSLRNG = GSLRNG (Ptr GSLRNG)
-- | Draw one uniform random 'Double' from a GSL generator
-- (@gsl_rng_uniform@; per the GSL manual the result lies in [0, 1)).
--
-- Imported @unsafe@ because the C routine is a short, non-blocking
-- computation that never calls back into Haskell, so it does not need the
-- costlier safe calling convention. It is called once per element of the
-- random vector that 'wrapStep' builds on every annealing step. That is
-- fine even though 'wrapStep' itself runs inside the safe 'siman' call.
foreign import ccall unsafe "gsl_rng.h gsl_rng_uniform"
  gslRngUniform :: Ptr GSLRNG -> IO Double
-- | Raw binding to the C helper @siman@ declared in @gsl-aux.h@.
--
-- This must stay a @safe@ import. Every 'FunPtr' argument is a Haskell
-- closure produced by one of the "wrapper" imports, so C calls back into
-- Haskell while this call runs, and GHC only allows that from safe calls.
-- The result is presumably a GSL status code; 'simanSolve' passes it to
-- 'check'.
foreign import ccall safe "gsl-aux.h siman"
  siman :: CInt                                        -- ^ seed (passed through from 'simanSolve')
        -> Ptr SimulatedAnnealingParams                -- ^ annealing parameters
        -> P a                                         -- ^ handle to the initial configuration
        -> FunPtr (P a -> Double)                      -- ^ energy callback
        -> FunPtr (P a -> P a -> Double)               -- ^ metric callback
        -> FunPtr (GSLRNG -> P a -> Double -> IO ())   -- ^ step callback
        -> FunPtr (P a -> P a -> IO ())                -- ^ copy callback
        -> FunPtr (P a -> IO (P a))                    -- ^ copy-constructor callback
        -> FunPtr (P a -> IO ())                       -- ^ destructor callback
        -> FunPtr (P a -> IO ())                       -- ^ print callback ('nullFunPtr' when absent)
        -> IO CInt
-- | Run GSL simulated annealing (through the C helper @siman@) on an
-- arbitrary Haskell configuration type.
--
-- The configuration lives on the Haskell side inside an 'IORef', and C only
-- sees it as an opaque 'StablePtr'. Every callback (energy, metric, step,
-- copy, copy-construct, destroy, print) is a Haskell function turned into a
-- 'FunPtr'.
--
-- All callback 'FunPtr's and the handle to the initial configuration are
-- released when the run finishes, including when 'check' reports a failure
-- from @siman@.
simanSolve :: Int   -- ^ seed, converted with 'fromIntegral' and passed to
                    -- @siman@ (presumably seeds GSL's generator)
           -> Int   -- ^ @nrand@: how many uniform random 'Double's are drawn
                    -- and handed to the step function on every step
           -> SimulatedAnnealingParams  -- ^ solver parameters
           -> a                    -- ^ initial configuration
           -> (a -> Double)        -- ^ energy functional @e@
           -> (a -> a -> Double)   -- ^ metric @m@ between two configurations
           -> (Vector Double -> Double -> a -> a)
              -- ^ step function: random vector, step size, current
              -- configuration; returns the next configuration
           -> Maybe (a -> String)  -- ^ optional printer, used as the print
                                   -- callback when present
           -> a  -- ^ the configuration held by the initial handle when
                 -- @siman@ returns (GSL is expected to store its best
                 -- configuration there; confirm against the GSL manual)
simanSolve seed nrand params conf e m step printfun =
  -- The whole run is one foreign call whose only observable result is the
  -- returned configuration, hence the 'unsafePerformIO'.
  unsafePerformIO $
    with params $ \paramptr ->
    withWrapper (mkEnergyFun (wrapEnergy e)) $ \ewrap ->
    withWrapper (mkMetricFun (wrapMetric m)) $ \mwrap ->
    withWrapper (mkStepFun (wrapStep nrand step)) $ \stepwrap ->
    withWrapper (mkCopyFun copyConfig) $ \cpwrap ->
    withWrapper (mkCopyConstructorFun copyConstructConfig) $ \ccwrap ->
    withWrapper (mkDestructFun destroyConfig) $ \dwrap ->
    -- The print callback has the same type as the destructor, so the
    -- destructor's "wrapper" import is reused to build it.
    withOptionalWrapper (mkDestructFun . wrapPrint <$> printfun) $ \pwrap ->
    bracket (newIORef conf >>= newStablePtr) freeStablePtr $ \confptr -> do
      siman (fromIntegral seed)
        paramptr confptr
        ewrap mwrap stepwrap cpwrap ccwrap dwrap pwrap // check "siman"
      deRefRead confptr
  where
    -- Build a callback 'FunPtr', run the body with it, then free it.
    -- "wrapper" FunPtrs are never reclaimed automatically.
    withWrapper :: IO (FunPtr f) -> (FunPtr f -> IO r) -> IO r
    withWrapper mk = bracket mk freeHaskellFunPtr

    -- Like 'withWrapper', but passes 'nullFunPtr' (and frees nothing) when
    -- there is no callback to build.
    withOptionalWrapper :: Maybe (IO (FunPtr f)) -> (FunPtr f -> IO r) -> IO r
    withOptionalWrapper Nothing   body = body nullFunPtr
    withOptionalWrapper (Just mk) body = withWrapper mk body