{-# LANGUAGE ScopedTypeVariables, ExplicitForAll, BangPatterns #-}

{-|
Module: Text.PhonotacticLearner.Util.ConjugateGradient
Description: Line search and Conjugate Gradient Search.
Copyright: © 2016-2017 George Steel and Peter Jurgec
License: GPL-2+
Maintainer: george.steel@gmail.com

Implementations of line search and conjugate gradient search for minimization. Line search uses Illinois False Position.
-}

module Text.PhonotacticLearner.Util.ConjugateGradient (
    traceInline, regulaFalsiSearch, conjugateGradientSearch,

    -- llpOptimizeWeights
) where

import qualified Data.Map as M
import Data.List
import Data.Ix
import Debug.Trace
import qualified Data.Vector.Unboxed as V
import System.IO
import System.IO.Unsafe
import Numeric
import Data.Array.IArray
--import Text.PhonotacticLearner.WeightedDFA
--import Text.PhonotacticLearner.Util.Probability
import Text.PhonotacticLearner.Util.Ring
--import Text.PhonotacticLearner.MaxentGrammar

-- Uses the conjugate gradient method as described by Shewchuk in
-- "An Introduction to the Conjugate Gradient Method Without the Agonizing Pain"


-- | Initial step length tried by the doubling phase of the line search
-- ('regulaFalsiSearch'); each later probe doubles the previous one.
rfInitSigma :: Double
rfInitSigma = 5.0e-2

-- | Version of 'trace' which does not output a trailing linebreak. Good for progress bars.
--
-- Writes the message to 'stderr' and flushes it immediately, then returns the
-- second argument unchanged. As with 'trace', the message is emitted when the
-- result is forced, not when the expression is written.
traceInline :: String -> a -> a
traceInline s x = unsafePerformIO $ do
    hPutStr stderr s
    hFlush stderr
    return x
-- NOINLINE (as on 'Debug.Trace.trace') keeps GHC from inlining the
-- 'unsafePerformIO' into call sites, where it could be floated or shared and so
-- print the message a different number of times than expected.
{-# NOINLINE traceInline #-}

-- | Line search minimization using a modified Illinois False Position method.
--
-- Adapted from description at https://en.wikipedia.org/wiki/False_position_method
--
-- The search direction is normalized first. The starting point is returned
-- unchanged in two cases:
--
-- * the direction is the zero vector;
-- * the derivative along the direction at the start is positive (uphill) or NaN.
--
-- Otherwise the step length is doubled, starting from 'rfInitSigma', until the
-- derivative along the direction changes from non-positive to non-negative.
-- That bracket is then narrowed by Illinois false position until it is
-- narrower than the threshold.
--
-- NOTE(review): if the derivative along the direction never becomes
-- non-negative (the function keeps decreasing), the doubling phase does not
-- terminate.
regulaFalsiSearch :: Double -- ^ stopping threshold uncertainty
                  -> (Vec -> Vec -> Double) -- ^ derivative of function to minimize, at a point along a direction
                  -> Vec -- ^ starting point
                  -> Vec -- ^ direction to search in
                  -> Vec -- ^ minimum point
regulaFalsiSearch epsilon f' xinit sdir
    | normVec sdir == 0 = xinit -- nothing to search along; avoids normalizing a zero vector
    | not (dxinit <= 0) = xinit -- uphill (or NaN) derivative: no descent along this direction
    | otherwise = pos (rfs a1 a2 0)
    where
        dir = normalizeVec sdir
        dxinit = f' xinit dir
        -- point at distance alpha along the normalized search direction
        pos :: Double -> Vec
        pos alpha = xinit ⊕ (alpha ⊙ dir)
        -- (step length, derivative along dir) at sigma, 2*sigma, 4*sigma, ...
        doublingSearch = [(a, f' (pos a) dir) | a <- iterate (*2) rfInitSigma]
        -- first consecutive pair (starting from 0) whose derivatives bracket zero;
        -- 'head' cannot fail because doublingSearch is infinite
        (a1,a2) = head (filter (\((_,dx),(_,dy)) -> (dx <= 0) && (dy >= 0)) (zip ((0, dxinit):doublingSearch) doublingSearch))
        -- root of the straight line through (x,dx) and (y,dy)
        secant (!x,!dx) (!y,!dy) = (x*dy - y*dx) / (dy - dx)
        -- Invariant: x < y and dx <= 0 <= dy.
        -- bal counts consecutive replacements of the same endpoint:
        -- negative means x was replaced, positive means y was replaced.
        rfs :: (Double, Double) -> (Double, Double) -> Int -> Double
        rfs (!x,!dx) (!y,!dy) !bal
            | (dx == 0) = x
            | (dy == 0) = y
            | ((y-x) < epsilon) = secant (x,dx) (y,dy)
            | (dz <= 0) = rfs (z,dz) (y,dy) (min bal 0 - 1)
            | otherwise = rfs (x,dx) (z,dz) (max bal 0 + 1)
            where
                -- Illinois modification: once the same endpoint has been
                -- retained for two or more steps, shrink its derivative by
                -- 0.707^k, which pulls the next secant point towards it.
                sy = if bal <= (-2) then (0.707 ^ negate bal) else 1
                sx = if bal >= 2 then (0.707 ^ bal) else 1
                z = (secant (x, sx*dx) (y, sy*dy))
                dz = f' (pos z) dir

-- | Nonlinear conjugate gradient search using Polak-Ribière method.
--
-- Each iteration does the following:
--
-- 1. Run a line search ('regulaFalsiSearch') along the current conjugate
--    direction.
-- 2. Pass the result through the projection function.
--
-- The direction is reset to steepest descent (a restart) in three cases:
--
-- * on the first step;
-- * after @n@ consecutive conjugate steps, where @n@ is the dimension of the
--   starting point;
-- * whenever the Polak-Ribière coefficient is not positive (or is NaN), or the
--   previous step had to be corrected by the projection.
--
-- Stopping condition is two consecutive steps both having a delta below the
-- threshold. The returned point has been passed through the projection.
conjugateGradientSearch :: Bool -- ^ trace progress to 'stderr' if true
                        -> (Double, Double) -- ^ stopping thresholds for conjugate gradient step and line search
                        -> (Vec -> (Vec, Bool)) -- ^ function to project points back into area defined by inequality constraints;
                                                -- a 'True' flag forces the next step to restart
                                                -- (for unconstrained problems use @(\x->(x,False))@)
                        -> (Vec -> (Double, Vec)) -- ^ function to minimize, returns value and gradient
                        -> (Vec -> Vec -> Double) -- ^ derivative of function to minimize at a point along a direction
                        -> Vec -- ^ starting point
                        -> Vec -- ^ minimum point
conjugateGradientSearch shouldtrace (e1, e2) conproj fstar f' start = cjs dims (start ⊕ vec [2*e1]) zero zero start
    where
        -- The call above passes two seeds:
        -- * a fake previous point offset by 2*e1, so the first iteration cannot
        --   satisfy the stopping condition;
        -- * a balance of dims, which makes the first step a restart.
        opttrace = if shouldtrace then traceInline else const id
        dims = length (coords start)
        cjs :: Int -> Vec -> Vec -> Vec -> Vec -> Vec
        cjs !bal !oldx !olddir !oldgrad !x
            -- two consecutive steps small enough: done (return the projected point)
            | normVec (oldx ⊖ x) < e1 && normVec (x ⊖ newx) < e1 = newx'
            | otherwise = cjs nbal' x sdir grad newx'
            where
                (_, grad) = fstar x
                beta' = innerProd grad (grad ⊖ oldgrad) / innerProd oldgrad oldgrad -- Polak-Ribière
                -- Restart (beta = 0) after dims conjugate steps or when beta' is
                -- not positive. The test is written as `not (beta' > 0)` so that a
                -- NaN (e.g. 0/0 from a zero previous gradient) also restarts.
                (beta, nbal) = if bal >= dims || not (beta' > 0) then (0,0) else (beta', bal + 1)
                sdir = (beta ⊙ olddir) ⊖ grad
                -- progress marker: "+" for a restart step, "-" for a conjugate step
                newx = opttrace (if beta <= 0 then "+" else "-") $ regulaFalsiSearch e2 f' x sdir
                (newx', iscorr) = conproj newx
                -- a projected (corrected) step invalidates conjugacy: force a restart
                nbal' = if iscorr then dims else nbal