{-# LANGUAGE PackageImports, RankNTypes, RecordWildCards,
             ScopedTypeVariables, FlexibleContexts #-}
{-# OPTIONS -Wall #-}
module Numeric.Optimization.Algorithms.CMAES (
       run, Config(..), defaultConfig,
       minimize, minimizeIO,
       minimizeT, minimizeTIO,
       minimizeG, minimizeGIO,
       getDoubles, putDoubles
)where
import           Control.Applicative ((<|>))
import           Control.Monad hiding (forM_, mapM)
import qualified "mtl" Control.Monad.State as State
import           Data.Data
import           Data.Generics
import           Data.List (isPrefixOf)
import           Data.Maybe
import           Data.Foldable
import           Data.Traversable
import           Safe (atDef, headDef)
import           System.IO
import qualified "strict" System.IO.Strict as Strict
import           System.IO.Unsafe(unsafePerformIO)
import           System.Process
import           Prelude hiding (concat, mapM, sum)
import Paths_cmaes
-- | Settings for a single optimization run performed by 'run'.
--   The type parameter @tgt@ is the caller-facing type of a candidate
--   solution; internally candidates travel as plain lists of 'Double'.
data Config tgt = Config
  { funcIO        :: tgt -> IO Double
    -- ^ Objective function. 'run' calls it on every candidate the wrapper
    --   process asks about and sends the resulting value back.
  , projection    :: tgt -> [Double]
    -- ^ Flattens a @tgt@ into a list of 'Double's. The @minimize*@
    --   builders fill this in; 'run' as written here does not consult it.
  , embedding     :: [Double] -> tgt
    -- ^ Rebuilds a @tgt@ from a list of 'Double's. 'run' applies it to
    --   every vector it receives from the wrapper process.
  , initXs        :: [Double]
    -- ^ Starting point; the first thing 'run' sends to the wrapper.
    --   Its length is used as the problem dimension.
  , sigma0        :: Double
    -- ^ Sent to the wrapper right after 'initXs'. Presumably the initial
    --   step size of the search -- confirm against the wrapper script.
  , scaling       :: Maybe [Double]
    -- ^ When 'Just', forwarded as the @scaling_of_variables@ option.
    --   'run' truncates or pads the list so that its length matches
    --   'initXs'.
    
  , typicalXs     :: Maybe [Double]
    -- ^ When 'Just', forwarded as the @typical_x@ option. 'run' truncates
    --   the list to the length of 'initXs' and fills missing trailing
    --   positions with the corresponding entries of 'initXs'.
    
  , noiseHandling :: Bool
    -- ^ Always forwarded, as the @noise_handling@ option.
  , noiseReEvals  :: Maybe Int
    -- ^ When 'Just', forwarded as the @noise_reevals@ option.
  , noiseEps      :: Maybe Double
    -- ^ When 'Just', forwarded as the @noise_eps@ option.
    
  , tolFacUpX     :: Maybe Double
    -- ^ When 'Just', forwarded as the @tolfacupx@ option.
    
  , tolUpSigma    :: Maybe Double
    -- ^ When 'Just', forwarded as the @tolupsigma@ option.
  , tolFun        :: Maybe Double
    -- ^ When 'Just', forwarded to the wrapper as a function-value
    --   tolerance option.
    
  , tolStagnation :: Maybe Int
    -- ^ When 'Just', forwarded as the @tolstagnation@ option.
    
  , tolX          :: Maybe Double
    -- ^ When 'Just', forwarded as the @tolx@ option.
    
  , verbose       :: Bool
    -- ^ When 'True', 'run' echoes every line read from the wrapper's
    --   standard output to 'stderr'.
  , otherArgs     :: [(String, String)]
    -- ^ Extra @(key, value)@ option pairs, sent verbatim after the
    --   options derived from the fields above.
    
  , pythonPath    :: Maybe FilePath
    -- ^ Executable used to launch the wrapper; 'run' falls back to
    --   @python2@ when this is 'Nothing'.
  , cmaesWrapperPath :: Maybe FilePath
    -- ^ Path of the wrapper script; 'run' falls back to the
    --   @cmaes_wrapper.py@ shipped as package data when 'Nothing'.
  }
-- | Baseline 'Config', meant to be customized with record-update syntax
--   (the @minimize*@ builders in this module do exactly that).
--
--   'funcIO', 'projection', 'embedding' and 'initXs' are placeholders
--   that raise an 'error' when evaluated, so they have to be overridden
--   before the configuration is usable.
defaultConfig :: Config a
defaultConfig = Config
  { funcIO        = error "funcIO undefined"
  , projection    = error "projection undefined"
  , embedding     = error "embedding undefined"
  , initXs        = error "initXs undefined"
  , sigma0        = 0.25
  , scaling       = Nothing
  , typicalXs     = Nothing
  , noiseHandling = False
  , noiseReEvals  = Nothing
  , noiseEps      = Just 1e-7
  , tolFacUpX     = Just 1e10
  , tolUpSigma    = Just 1e20
  , tolFun        = Just 1e-11
  , tolStagnation = Nothing
  , tolX          = Just 1e-11
  , verbose       = False
  , otherArgs     = []
  , pythonPath    = Nothing   -- 'run' substitutes its own default
  , cmaesWrapperPath = Nothing   -- 'run' substitutes its own default
  }
-- | Build a 'Config' that minimizes a pure function over lists of
--   'Double', starting from the given point.  The function is lifted
--   into 'IO' and the rest is delegated to 'minimizeIO'.
minimize :: ([Double]-> Double) -> [Double] -> Config [Double]
minimize objective = minimizeIO (\point -> return (objective point))
-- | Build a 'Config' that minimizes an effectful function over lists of
--   'Double', starting from the given point.  Since the target type is
--   already a list, both conversions are the identity.
minimizeIO :: ([Double]-> IO Double) -> [Double] -> Config [Double]
minimizeIO objectiveIO start =
  defaultConfig
    { projection = id
    , embedding  = id
    , initXs     = start
    , funcIO     = objectiveIO
    }
-- | Build a 'Config' that minimizes a pure function over any
--   'Traversable' container of 'Double'.  The function is lifted into
--   'IO' and the rest is delegated to 'minimizeTIO'.
minimizeT :: (Traversable t) => (t Double-> Double) -> t Double -> Config (t Double)
minimizeT objective = minimizeTIO liftedObjective
  where
    liftedObjective container = return (objective container)
-- | Build a 'Config' that minimizes an effectful function over any
--   'Traversable' container of 'Double'.
--
--   The initial container doubles as a shape template: candidates are
--   flattened with 'toList', and rebuilt by writing the supplied values
--   over the template's elements in traversal order.
minimizeTIO :: (Traversable t) => (t Double-> IO Double) -> t Double -> Config (t Double)
minimizeTIO objectiveIO template =
  defaultConfig
    { funcIO     = objectiveIO
    , initXs     = toList template
    , projection = toList
    , embedding  = refill
    }
  where
    -- Keep the template's shape, take every element from the new values.
    refill values = zipTWith (const id) template values
-- | Build a 'Config' that minimizes a pure function over any 'Data'
--   value.  The function is lifted into 'IO' and the rest is delegated
--   to 'minimizeGIO'.
minimizeG :: (Data a) => (a -> Double) -> a -> Config a
minimizeG objective start = minimizeGIO (\candidate -> return (objective candidate)) start
-- | Build a 'Config' that minimizes an effectful function over any 'Data'
--   value.
--
--   The 'Double's inside the value are extracted with 'getDoubles'; a
--   candidate is rebuilt by writing new values into a copy of the initial
--   value with 'putDoubles'.
minimizeGIO :: (Data a) => (a -> IO Double) -> a -> Config a
minimizeGIO objectiveIO template =
  defaultConfig
    { embedding  = \values -> putDoubles values template
    , projection = getDoubles
    , initXs     = getDoubles template
    , funcIO     = objectiveIO
    }
-- | Location of the @cmaes_wrapper.py@ script shipped as package data.
--
--   Computing the path has side effects, hidden behind 'unsafePerformIO'
--   and performed at most once thanks to the @NOINLINE@ pragma:
--
--   * @python --version@ is run to find out the major version of the
--     default interpreter;
--
--   * the first line of the script is rewritten in place, if necessary, so
--     that its shebang names @python@ when that is version 2 and
--     @python2@ otherwise.
wrapperFnFullPath :: FilePath
wrapperFnFullPath = unsafePerformIO $ do
  fullFn <- getDataFileName wrapperFn
  (hChildIn, hChildOut, hChildErr, hproc) <-
    runInteractiveCommand "python --version"
  versionText <- hGetContents hChildOut
  -- Drain the (lazy) output completely before waiting, so the child can
  -- never block on a full pipe; reaching EOF also closes hChildOut.
  length versionText `seq` return ()
  _ <- waitForProcess hproc
  -- The other two pipe ends are unused; release them explicitly.
  hClose hChildIn
  hClose hChildErr
  let -- Expected stdout shape is "Python X.Y.Z".  When nothing usable is
      -- printed there (presumably the case for interpreters that report
      -- their version on stderr -- TODO confirm), assume version 2.
      -- 'reads' is used instead of 'read' so that unexpected output
      -- cannot crash this CAF.
      pythonVersion :: Int
      pythonVersion =
        case reads (take 1 $ atDef "2" (words versionText) 1) of
          [(v, "")] -> v
          _         -> 2
      correctShebang
        | pythonVersion == 2 = "#!/usr/bin/env python"
        | otherwise          = "#!/usr/bin/env python2"
  -- Strict read: the file must be fully consumed before it is rewritten.
  wrapperLines <- lines <$> Strict.readFile fullFn
  when (headDef "" wrapperLines /= correctShebang) $ do
    writeFile fullFn $ unlines $ correctShebang : drop 1 wrapperLines
  return fullFn
  where
    wrapperFn = "cmaes_wrapper.py"
{-# NOINLINE wrapperFnFullPath #-}
-- | Execute the optimization described by a 'Config' and return the
--   solution reported by the wrapper process.
--
--   The wrapper script is started as a child process and spoken to over
--   its standard input and output:
--
--   1. 'initXs', 'sigma0', the number of options and then each option
--      (key line, value line) are written to the child.
--
--   2. Lines coming back are ignored unless they start with the
--      communication header.  A @q@ message carries a candidate; it is
--      evaluated with 'funcIO' and the value is written back.  An @a@
--      message carries the answer and ends the conversation.
--
--   Any other message makes the action fail with a user error.  The pipe
--   handles are closed once the child has terminated.
run :: forall tgt. Config tgt -> IO tgt
run Config{..} = do
  let pythonPath0  = fromMaybe "python2" pythonPath
      wrapperPath0 = fromMaybe wrapperFnFullPath cmaesWrapperPath
  (Just hin, Just hout, _, hproc) <- createProcess (proc pythonPath0 [wrapperPath0])
    { std_in = CreatePipe, std_out = CreatePipe }
  sendLine hin $ unwords (map show initXs)
  sendLine hin $ show sigma0
  sendLine hin $ show $ length options
  forM_ options $ \(key, val) -> do
    sendLine hin key
    sendLine hin val
  let decode = embedding . map read
      loop = do
        str <- recvLine hout
        -- Match on the word list itself so that an empty message is
        -- reported as a protocol error instead of an index crash.
        case words str of
          "a" : xs -> return $ decode xs
          "q" : xs -> do
            ans <- funcIO $ decode xs
            sendLine hin $ show ans
            loop
          _ -> fail $
            "Numeric.Optimization.Algorithms.CMAES.run: "
            ++ "unexpected message from wrapper: " ++ show str
  r <- loop
  _ <- waitForProcess hproc
  hClose hin
  hClose hout
  return r
  where
    -- Problem dimension, dictated by the starting point.
    probDim :: Int
    probDim = length initXs

    -- Bring @orig@ to exactly 'probDim' elements: keep its own elements
    -- where present and fall back to the element of @supply@ at the same
    -- position where @orig@ is too short.
    adjustDim :: [a] -> [a] -> [a]
    adjustDim supply orig =
      take probDim $
      catMaybes $
      zipWith (<|>)
        (map Just orig ++ repeat Nothing)
        (map Just supply)

    -- Options forwarded to the wrapper; 'Nothing' fields are omitted.
    options :: [(String, String)]
    options = catMaybes
      [ -- missing scaling factors default to the neutral value 1
        "scaling_of_variables" `is` fmap (adjustDim (repeat 1)) scaling
      , "typical_x"            `is` fmap (adjustDim initXs) typicalXs
      , "noise_handling"       `is` Just noiseHandling
      , "noise_reevals"        `is` noiseReEvals
      , "noise_eps"            `is` noiseEps
      , "tolfacupx"            `is` tolFacUpX
      , "tolupsigma"           `is` tolUpSigma
        -- 'tolFun' maps to the option of the same name ("tolfunhist" is
        -- a separate option of the optimizer)
      , "tolfun"               `is` tolFun
      , "tolstagnation"        `is` tolStagnation
      , "tolx"                 `is` tolX
      ] ++ otherArgs

    -- Pair an option name with the 'show'n value, if there is one.
    is :: Show a => String -> Maybe a -> Maybe (String,String)
    is key = fmap (\val -> (key, show val))

    -- Prefix that marks a line as a protocol message from the wrapper.
    commHeader :: String
    commHeader = "<CMAES_WRAPPER_PY2HS>"

    -- Read lines until one carries the header; return it without its
    -- first word.  Every line read is echoed to stderr when 'verbose'.
    recvLine :: Handle -> IO String
    recvLine h = do
      str <- hGetLine h
      when (verbose) $ hPutStrLn stderr str
      if commHeader `isPrefixOf` str
        then return $ unwords $ drop 1 $ words str
        else recvLine h

    -- Write one line and flush, so the child sees it immediately.
    sendLine :: Handle -> String -> IO ()
    sendLine h str = do
      hPutStrLn h str
      hFlush h
-- | Zip two traversable containers with a combining function, keeping
--   the shape of the first one.  The elements of the second container
--   are consumed in 'toList' order; surplus elements are ignored, and
--   running out of them is an error.
zipTWith :: (Traversable t1, Traversable t2) => (a->b->c) -> (t1 a) -> (t2 b) -> (t1 c)
zipTWith op xs0 ys0 = State.evalState (mapM step xs0) supply
  where
    supply = toList ys0
    -- Pop the next supplied element and combine it with the current one.
    step x = State.state $ \pending ->
      case pending of
        y : rest -> (op x y, rest)
        []       -> error "zipTWith: empty state"
-- | Collect every 'Double' contained in a 'Data' value, listed in the
--   order in which 'everywhereM' visits them.
getDoubles :: Data a => a -> [Double]
getDoubles d = reverse visited
  where
    -- Hits are consed onto the state, so they come out newest first.
    visited = State.execState (everywhereM collect d) []

    -- Record the node when it is a 'Double'; never modify it.
    collect :: GenericM (State.State [Double])
    collect node =
      case cast node :: Maybe Double of
        Just x  -> State.modify (x :) >> return node
        Nothing -> return node
-- | Overwrite the 'Double's contained in a 'Data' value with the given
--   ones, assigned in the order in which 'everywhereM' visits them (the
--   same order in which 'getDoubles' lists them).
--
--   Surplus values are ignored.  Supplying fewer values than the
--   structure has 'Double's is an error; it is reported with a
--   descriptive message rather than an anonymous empty-list failure.
putDoubles :: Data a => [Double] -> a -> a
putDoubles ys0 d = State.evalState (everywhereM putter d) ys0
  where
    putter :: GenericM (State.State [Double])
    putter a0 =
      case cast a0 :: Maybe Double of
        -- Not a 'Double': leave the node alone and consume nothing.
        Nothing -> return a0
        Just _  -> do
          ys <- State.get
          case ys of
            []       -> error "putDoubles: ran out of replacement values"
            y : rest -> do
              State.put rest
              -- The node is known to be a 'Double' here, so the cast
              -- back to its type succeeds; 'a0' is only a formal fallback.
              return $ fromMaybe a0 (cast y)