```{-# LANGUAGE RankNTypes, RecordWildCards,ScopedTypeVariables #-}
{-# OPTIONS -Wall #-}

{-|

Usage:

(1) create an optimization problem of type @Config@ by one of
@minimize@, @minimizeIO@ etc.

(2) @run@ it.

Let's optimize the following function /f(xs)/. @xs@ is a vector and
@f@ has its minimum at @xs !! i = sqrt(i)@.

>>> import Test.DocTest.Prop
>>> let f = sum . zipWith (\i x -> (x*abs x - i)**2) [0..] :: [Double] -> Double
>>> let initXs = replicate 10 0                            :: [Double]
>>> bestXs <- run \$ minimize f initXs
>>> assert \$ f bestXs < 1e-10

If your optimization is not working well, try:

* Set @scaling@ in the @Config@ to the appropriate search
range of each parameter.

* Set @tolFun@ in the @Config@ to the appropriate scale of
the function values.

An example for scaling the function value:

>>> let f2 xs = (/1e100) \$ sum \$ zipWith (\i x -> (x*abs x - i)**2) [0..] xs
>>> bestXs <- run \$ (minimize f2 \$ replicate 10 0) {tolFun = Just 1e-111}
>>> assert \$ f2 bestXs < 1e-110

An example for scaling the input values:

>>> let f3 xs = sum \$ zipWith (\i x -> (x*abs x - i)**2) [0,1e100..] xs
>>> let xs30 = replicate 10 0 :: [Double]
>>> let m3 = (minimize f3 xs30) {scaling = Just (replicate 10 1e50)}
>>> xs31 <- run \$ m3
>>> assert \$ f3 xs31 / f3 xs30 < 1e-10

Use @minimizeT@ to optimize functions on traversable structures.

>>> import qualified Data.Vector as V
>>> let f4 = V.sum . V.imap (\i x -> (x*abs x - fromIntegral i)**2) :: V.Vector Double -> Double
>>> bestVx <- run \$ minimizeT f4 \$ V.replicate 10 0
>>> assert \$ f4 bestVx < 1e-10

Or use @minimizeG@ to optimize functions of almost any type. Let's create a triangle ABC
so that AB = 3, AC = 4, BC = 5.

>>> let dist (ax,ay) (bx,by) = ((ax-bx)**2 + (ay-by)**2)**0.5
>>> let f5 [a,b,c] = (dist a b - 3.0)**2 + (dist a c - 4.0)**2 + (dist b c - 5.0)**2
>>> bestTriangle <- run \$ (minimizeG f5 [(0,0),(0,0),(0,0)]){tolFun = Just 1e-20}
>>> assert \$ f5 bestTriangle < 1e-10

Then the angle BAC should be a right angle.

>>> let [(ax,ay),(bx,by),(cx,cy)] = bestTriangle
>>> assert \$ abs ((bx-ax)*(cx-ax) + (by-ay)*(cy-ay)) < 1e-10

-}

module Numeric.Optimization.Algorithms.CMAES (
run, Config(..), defaultConfig,
minimize, minimizeIO,
minimizeT, minimizeTIO,
minimizeG, minimizeGIO,
)where

import           Data.Data
import           Data.Generics
import           Data.List (isPrefixOf)
import           Data.Maybe
import           Data.Foldable
import           Data.Traversable
import           System.IO
import           System.Process
import           Prelude hiding (concat, mapM, sum)

import Paths_cmaes

-- | Everything the optimizer needs to know about one problem.  The type
-- parameter @tgt@ is the type of the value being optimized.
data Config tgt = Config
  { funcIO        :: tgt -> IO Double
    -- ^ Objective function to be optimized.
  , projection    :: tgt -> [Double]
    -- ^ Pulls the tunable parameters out of a @tgt@.
  , embedding     :: [Double] -> tgt
    -- ^ Builds a @tgt@ back from a list of parameters.
  , initXs        :: [Double]
    -- ^ Starting guess for the parameters.
  , sigma0        :: Double
    -- ^ Global scaling factor.
  , scaling       :: Maybe [Double]
    -- ^ Typical deviation of each input parameter.
  , typicalXs     :: Maybe [Double]
    -- ^ Typical mean of each input parameter.
  , tolFacUpX     :: Maybe Double
    -- ^ Stop once any of the scalings has grown too large
    -- (meaning the initial scaling was too small).
  , tolUpSigma    :: Maybe Double
    -- ^ Stop once the global scaling has grown too large.
  , tolFun        :: Maybe Double
    -- ^ Stop once the spread of function values over the current and
    -- the last few generations falls below this value.
  , tolStagnation :: Maybe Int
    -- ^ Stop after this many iterations without improvement.
  , tolX          :: Maybe Double
    -- ^ Stop once the deviations in the solutions fall below this value.
  , verbose       :: Bool
    -- ^ Echo the CMA-ES output to stderr.
  }

-- | Baseline 'Config'.  The four problem-specific fields ('funcIO',
-- 'projection', 'embedding' and 'initXs') are placeholders that raise an
-- error when evaluated, so callers are expected to override them with a
-- record update; the remaining fields carry ready-to-use defaults.
defaultConfig :: Config a
defaultConfig =
  Config
    { funcIO        = error "funcIO undefined"
    , projection    = error "projection undefined"
    , embedding     = error "embedding undefined"
    , initXs        = error "initXs undefined"
      -- Tuning knobs: only sigma0 and the four tolerances below are set.
    , sigma0        = 0.25
    , scaling       = Nothing
    , typicalXs     = Nothing
    , tolFacUpX     = Just 1e10
    , tolUpSigma    = Just 1e20
    , tolFun        = Just 1e-11
    , tolStagnation = Nothing
    , tolX          = Just 1e-11
    , verbose       = False
    }

-- | Build a minimization problem from a pure objective function over a
-- list of parameters and a starting point.
minimize :: ([Double]-> Double) -> [Double] -> Config [Double]
minimize objective = minimizeIO (return . objective)

-- | Build a minimization problem from an objective function that runs in
-- @IO@ and a starting point.  The target type is the parameter list
-- itself, so projection and embedding are both the identity.
minimizeIO :: ([Double]-> IO Double) -> [Double] -> Config [Double]
minimizeIO objectiveIO start = problem
  where
    problem =
      defaultConfig
        { initXs     = start
        , funcIO     = objectiveIO
        , embedding  = id
        , projection = id
        }

-- | Build a minimization problem for a pure function over a traversable
-- structure @t@ of Doubles.
minimizeT :: (Traversable t) => (t Double-> Double) -> t Double -> Config (t Double)
minimizeT objective = minimizeTIO (return . objective)

-- | Build a minimization problem for an effectful function over a
-- traversable structure @t@ of Doubles.
--
-- The initial structure doubles as a template: parameters are read out
-- with 'toList', and candidate parameter lists are written back into the
-- template's shape, replacing its elements in traversal order.
minimizeTIO :: (Traversable t) => (t Double-> IO Double) -> t Double -> Config (t Double)
minimizeTIO objectiveIO template =
  defaultConfig
    { funcIO     = objectiveIO
    , initXs     = toList template
    , projection = toList
    , embedding  = zipTWith (const id) template
    }

-- | Build a minimization problem for a pure function on almost any type
-- @a@ that contains Doubles.
minimizeG :: (Data a) => (a -> Double) -> a -> Config a
minimizeG objective = minimizeGIO (return . objective)

-- | Build a minimization problem for an effectful function on almost any
-- type that contains Doubles.
--
-- The initial value doubles as a template: the Doubles inside it are the
-- starting parameters, and candidate parameter lists are written back
-- into a copy of it.
minimizeGIO :: (Data a) => (a -> IO Double) -> a -> Config a
minimizeGIO objectiveIO template =
  defaultConfig
    { funcIO     = objectiveIO
    , initXs     = getDoubles template
    , projection = getDoubles
    , embedding  = \ys -> putDoubles ys template
    }

-- | Execute the optimizer and get the solution.
--
-- Starts @python2@ on the packaged @cmaes_wrapper.py@ script with its
-- stdin and stdout connected to pipes, sends it the initial parameters,
-- @sigma0@ and the configured options, and then answers the evaluation
-- requests it sends back until it reports an answer.
run :: forall tgt. Config tgt -> IO tgt
run Config{..} = do
-- Locate the wrapper script among the package's data files and launch it.
fn <- getDataFileName wrapperFn
(Just hin, Just hout, _, _) <- createProcess (proc "python2" [fn])
{ std_in = CreatePipe, std_out = CreatePipe }
-- Handshake, one item per line: the initial parameters (space separated),
-- sigma0, the number of options, then a key line and a value line for
-- every option.
sendLine hin \$ unwords (map show initXs)
sendLine hin \$ show sigma0
sendLine hin \$ show \$ length options
forM_ options \$ \(key, val) -> do
sendLine hin key
sendLine hin val
-- Message loop.  The first word of each received message selects the
-- action; the remaining words are parsed with 'read' as the parameters.
-- NOTE(review): 'ws!!0' is partial and fails on a message with no words.
-- NOTE(review): nothing visible in this block closes the pipes or waits
-- for the child process to exit.
let loop = do
str <- recvLine hout
let ws = words str
case ws!!0 of
-- "a": the parameters are the answer; embed them and finish.
"a" -> do
return \$ embedding \$ map read \$ drop 1 ws
-- "q": evaluate the objective at the embedded parameters, send the
-- value back, and keep going.
"q" -> do
ans <- funcIO . embedding \$ map read \$ drop 1 ws
sendLine hin \$ show ans
loop
-- Anything else aborts the run.
_ -> do
fail "ohmy god"
loop
where
-- Key/value pairs sent to the wrapper; a field set to Nothing contributes
-- no entry.
options :: [(String, String)]
options = concat \$ map maybeToList
[ "scaling_of_variables" `is` scaling
, "typical_x"            `is` typicalXs
, "tolfacupx"            `is` tolFacUpX
, "tolupsigma"           `is` tolUpSigma
, "tolfunhist"           `is` tolFun
, "tolstagnation"        `is` tolStagnation
, "tolx"                 `is` tolX
]

-- Pair an option name with the 'show'n value, if the value is present.
is :: Show a => String -> Maybe a -> Maybe (String,String)
is key = fmap (\val -> (key, show val))

-- File name of the wrapper script, resolved through getDataFileName above.
wrapperFn = "cmaes_wrapper.py"

-- Read one line from the handle, echoing it to stderr when 'verbose' is
-- set.
-- NOTE(review): the 'then'/'else' branches below have no governing 'if';
-- the line carrying the condition appears to be missing from this source.
-- 'isPrefixOf' is imported but unused in the visible code, so the test
-- was presumably a prefix check on 'str' -- confirm against the original.
-- As written, the 'then' branch returns the line with its first word
-- dropped and the 'else' branch discards the line and reads the next one.
recvLine :: Handle -> IO String
recvLine h = do
str <- hGetLine h
when (verbose) \$ hPutStrLn stderr str
then return \$ unwords \$ drop 1 \$ words str
else do
recvLine h

-- Write one line to the handle and flush it immediately.
sendLine :: Handle -> String -> IO ()
sendLine h str = do
hPutStrLn h str
hFlush h

-- | Zip a traversable structure with the elements of a second structure.
--
-- The result has the shape of the first structure; its elements are
-- combined, in traversal order, with the elements of the second one.
-- Surplus elements of the second structure are ignored.  If the second
-- structure has fewer elements than the first, the elements that cannot
-- be paired are an 'error' with a descriptive message.
--
-- Implemented with 'mapAccumL' rather than a refutable pattern bind in
-- the @State@ monad: that monad has no @MonadFail@ instance, so such a
-- bind is rejected by compilers that desugar pattern-match failure via
-- @MonadFail@.
zipTWith :: (Traversable t1, Traversable t2) => (a->b->c) -> (t1 a) -> (t2 b) -> (t1 c)
zipTWith op xs0 ys0 = snd (mapAccumL step (toList ys0) xs0)
  where
    -- Consume one element of the supply for every element visited.
    step (y:ys) x = (ys, op x y)
    step []     _ = error "zipTWith: the second structure has fewer elements than the first"

{-|

getDoubles and putDoubles are generic functions used to put [Double] in and out
of generic data types. Let's test them.

>>> let d3 = (1,2,3) :: (Double,Int,Double)
>>> getDoubles d3
[1.0,3.0]
>>> putDoubles [4,5] d3
(4.0,2,5.0)

>>> let complicated = ([0,1],(2,[3,4])) :: ([Double],(Double,[Double]))
>>> getDoubles complicated
[0.0,1.0,2.0,3.0,4.0]
>>> putDoubles [5,6,7,8,9] complicated
([5.0,6.0],(7.0,[8.0,9.0]))

Putting back the obtained values should not change the data.

>>> import Test.DocTest.Prop
>>> type Complicated = ([[Double]],(),(([(Double,String)]),[Double]))
>>> prop ((\x -> putDoubles (getDoubles x) x == x) :: Complicated -> Bool)

You can get the original list back after putting it.

>>> let make3 xs = take 3 \$ xs ++ [0..]
>>> prop ((\xs' y -> let xs = make3 xs' in getDoubles (putDoubles xs y)==xs) :: [Double] -> (Double,Double,Double) -> Bool)

-}

-- Collect every Double contained in a value, in left-to-right traversal
-- order.  A node that is itself a Double contributes that value; any
-- other node contributes whatever its immediate children contribute.
-- This is a direct generic query, so no state threading, 'reverse', or
-- 'head'-based type proxy is needed.
getDoubles :: Data d => d -> [Double]
getDoubles d = case (cast d :: Maybe Double) of
  Just x  -> [x]
  Nothing -> concat (gmapQ getDoubles d)

-- Replace the Doubles contained in a value, in traversal order, with the
-- values from the given list.  Everything that is not a Double is left
-- untouched.  Surplus list elements are ignored; if the list runs out
-- early, the remaining Doubles keep their original value (the previous
-- 'head'/'tail' version crashed with a Prelude error in that case).
putDoubles :: Data d => [Double] -> d -> d
putDoubles ys0 d = State.evalState (everywhereM putter d) ys0
  where
    putter :: GenericM (State.State [Double])
    putter a = case (cast a :: Maybe Double) of
      -- Not a Double: return it unchanged without looking at the supply.
      Nothing -> return a
      Just _  -> do
        ys <- State.get
        case ys of
          -- The node is a Double, so casting the replacement to the
          -- node's own type succeeds.
          (y:rest) | Just a' <- cast y -> do
            State.put rest
            return a'
          -- Supply exhausted: keep the original value.
          _ -> return a
```