{-# LANGUAGE FlexibleInstances, GADTs, RankNTypes, TemplateHaskell #-}
-- | Description : an EDSL for describing nonlinear programs
--
-- see usage in @examples/Test3.hs@
module Ipopt.NLP where

import Control.Applicative
import Control.Exception (bracket)
import Control.Lens
import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Data.Foldable (toList)
import Data.IntMap (IntMap)
import Data.List
import Data.Monoid
import Data.Monoid (First)
import Data.Sequence (Seq)
import Data.Vector (Vector)
import Foreign.C.Types (CDouble(..))
import qualified Data.Foldable as F
import qualified Data.IntMap as IM
import qualified Data.Map as M
import qualified Data.Sequence as Seq
import qualified Data.Set as S
import qualified Data.Vector as V
import qualified Data.Vector.Storable as VS

import Ipopt.Raw

-- * state
-- | The functions and bounds that together describe the problem
-- passed to the solver.
data NLPFun = NLPFun
    { _funF :: AnyRF Seq
      -- ^ pieces of the objective; 'solveNLP'' sums them
    , _funG :: AnyRF Seq
      -- ^ constraint functions, one element per 'addG'
    , _boundX :: Seq (Double,Double)
      -- ^ @(lower, upper)@ bounds, one pair per variable
    , _boundG :: Seq (Double,Double)
      -- ^ @(lower, upper)@ bounds, one pair per constraint
    }
-- | The two function fields cannot be printed, so they show up as the
-- placeholders @\<f\>@ and @\<g\>@; only the bounds are rendered.
instance Show NLPFun where
    show (NLPFun _ _ xBounds gBounds) =
        concat ["NLPFun <f> <g> {", show xBounds, "}{", show gBounds, "}"]

-- | Everything accumulated while a problem is being described.
data NLPState = NLPState
    { _nMax :: Ix
      -- ^ current maximum index (the most recently allocated variable)
    , _currentEnv :: [String]
      -- ^ what namespace we are in, innermost first (see 'inEnv')
    , _variables :: M.Map String Ix
      -- ^ fully qualified variable name to its index
    , _constraintLabels :: IntMap String
      -- ^ human-readable descriptions for the constraints
    , _objLabels :: IntMap String
      -- ^ human-readable descriptions for the objective pieces
    , _varLabels :: IntMap String
      -- ^ human-readable descriptions for the variables
    , _varEnv :: IntMap (S.Set [String])
      -- ^ namespaces from which each variable was requested
    , _constraintEnv :: IntMap [String]
      -- ^ namespace that was active when each constraint was added
    , _objEnv :: IntMap [String]
      -- ^ namespace that was active when each objective piece was added
    , _nlpfun :: NLPFun
      -- ^ the functions and bounds collected so far
    , _defaultBounds :: (Double,Double)
      -- ^ bounds given to a newly created variable
    , _initX :: Vector Double
      -- ^ initial state variable for the solver
    }
    deriving (Show)

-- | The solver works on a flat array of unknowns; an 'Ix' is the
-- position of one variable inside that array.
newtype Ix = Ix
    { _varIx :: Int
    } deriving Show

-- | The empty problem: no variables, constraints or objective pieces,
-- unbounded default bounds and an empty starting point. '_nMax' starts
-- at @-1@ because 'var'' increments it before using it, so the first
-- variable gets index 0. Use this as the initial state when running an
-- 'NLPT' computation.
nlpstate0 :: NLPState
nlpstate0 = NLPState
    { _nMax = Ix (-1)
    , _currentEnv = mempty
    , _variables = mempty
      -- labels
    , _constraintLabels = mempty
    , _objLabels = mempty
    , _varLabels = mempty
      -- env
    , _varEnv = mempty
    , _constraintEnv = mempty
    , _objEnv = mempty
    , _nlpfun = mempty
    , _defaultBounds = (-1/0, 1/0)
    , _initX = V.empty
    }

-- | State transformer threading the 'NLPState' that is being built up
-- over an arbitrary base monad.
type NLPT = StateT NLPState
-- | 'NLPT' specialised to 'IO' as the base monad.
type NLP = NLPT IO
   
-- ** representing functions

{- | This wrapper holds functions that can be used for the objective
(@f@) or for constraints (@g@). The wrapped function takes the vector of
all variables and must work for every 'RealFloat' element type; its
result is delivered inside the container @cb@ ('Identity' for a single
value, 'Seq' for a collection of values).

Many functions in the instances provided are partial: this seems to be
unavoidable because the input variables haven't been decided yet, so you
should not be allowed to use 'compare' on these. But for now just use the
standard Prelude classes; methods that cannot produce an 'AnyRF' (for
example 'toRational' and 'properFraction') are calls to 'error'.

Generate these using 'var', or perhaps by directly using the constructor:
@AnyRF $ Identity . V.sum@ would, for example, give the sum of all variables.
-}
data AnyRF cb = AnyRF (forall a. RealFloat a => Vector a -> cb a)

-- *** helpers for defining instances
-- | Turn a constant into a function that ignores the variable vector.
liftOp0 :: (forall a. RealFloat a => a) -> AnyRF Identity
liftOp0 c = AnyRF (\_ -> Identity c)

-- | Apply a unary operation to the result of the wrapped function.
liftOp1 :: (forall a. RealFloat a => a -> a) -> AnyRF Identity -> AnyRF Identity
liftOp1 op (AnyRF f) = AnyRF (\v -> fmap op (f v))

-- | Combine the results of two wrapped functions, both evaluated at the
-- same variable vector, with a binary operation.
liftOp2 :: (forall a. RealFloat a => a -> a -> a) -> AnyRF Identity -> AnyRF Identity -> AnyRF Identity
liftOp2 op (AnyRF f) (AnyRF g) = AnyRF (\v -> op <$> f v <*> g v)

-- | Pointwise arithmetic: each operation acts on the results that the
-- wrapped functions produce for the same variable vector.
instance Num (AnyRF Identity) where
    -- literals become constant functions
    fromInteger n = liftOp0 (fromInteger n)
    -- unary
    signum = liftOp1 signum
    abs = liftOp1 abs
    -- binary
    (-) = liftOp2 (-)
    (+) = liftOp2 (+)
    (*) = liftOp2 (*)

-- | Pointwise division; rational literals become constant functions.
instance Fractional (AnyRF Identity) where
    fromRational q = liftOp0 (fromRational q)
    recip = liftOp1 recip
    (/) = liftOp2 (/)

-- | Every 'Floating' method is lifted pointwise onto the wrapped function.
instance Floating (AnyRF Identity) where
    -- constant
    pi = liftOp0 pi
    -- exponentials, logarithms and powers
    exp = liftOp1 exp
    log = liftOp1 log
    sqrt = liftOp1 sqrt
    logBase = liftOp2 logBase
    (**) = liftOp2 (**)
    -- trigonometric functions and their inverses
    sin = liftOp1 sin
    cos = liftOp1 cos
    tan = liftOp1 tan
    asin = liftOp1 asin
    acos = liftOp1 acos
    atan = liftOp1 atan
    -- hyperbolic functions and their inverses
    sinh = liftOp1 sinh
    cosh = liftOp1 cosh
    tanh = liftOp1 tanh
    asinh = liftOp1 asinh
    acosh = liftOp1 acosh
    atanh = liftOp1 atanh

-- | Present only because 'RealFrac' requires it: a function of
-- not-yet-known variables has no rational value, so the conversion
-- always fails.
instance Real (AnyRF Identity) where
    toRational = const (error "Real AnyRF Identity")

-- | Functions of not-yet-known variables cannot be ordered. The instance
-- exists only because 'Real' requires it. 'compare' is defined
-- explicitly as an error so that every 'Ord' method (they all default to
-- 'compare') fails with a message instead of looping forever through the
-- mutually recursive class defaults.
instance Ord (AnyRF Identity) where
    compare _ _ = error "compare AnyRF"

-- | Equality is undecidable for the same reason. '==' is defined
-- explicitly so that it and '/=' raise an error rather than recursing
-- into each other's default definition.
instance Eq (AnyRF Identity) where
    _ == _ = error "(==) AnyRF"
-- | Present only because 'RealFloat' requires it. 'properFraction' is an
-- error stub: the integral part of a function of not-yet-known variables
-- cannot be computed.
instance RealFrac (AnyRF Identity) where
    properFraction = error "properFraction AnyRF"
-- | Only the methods that yield another function can be implemented;
-- queries that would have to inspect an actual number are error stubs.
instance RealFloat (AnyRF Identity) where
    -- operations that can be lifted pointwise
    atan2 = liftOp2 atan2
    significand = liftOp1 significand
    scaleFloat n = liftOp1 (scaleFloat n)
    encodeFloat a b = liftOp0 (encodeFloat a b)
    -- representation parameters: delegate to the 'Double' instance
    floatRadix _ = floatRadix   (error "RealFrac AnyRF Identity floatRadix" :: Double)
    floatDigits _ = floatDigits (error "RealFrac AnyRF Identity floatDigits" :: Double)
    -- not implementable without a concrete value
    floatRange = error "floatRange AnyRF"
    decodeFloat = error "decodeFloat AnyRF"
    isNaN = error "isNaN AnyRF"
    isInfinite = error "isInfinite AnyRF"
    isDenormalized = error "isDenormalized AnyRF"
    isNegativeZero = error "isNegativeZero AnyRF"
    isIEEE = error "isIEEE AnyRF"

-- | Two collections of functions are combined by concatenating their
-- results for the same input; the identity produces no results at all.
instance Monoid (AnyRF Seq) where
    mempty = AnyRF (\_ -> Seq.empty)
    mappend (AnyRF f) (AnyRF g) = AnyRF (\v -> f v Seq.>< g v)

-- | Field-wise combination: objective pieces, constraints and both
-- kinds of bounds are each appended.
instance Monoid NLPFun where
    mempty = NLPFun mempty mempty mempty mempty
    mappend (NLPFun f1 g1 bx1 bg1) (NLPFun f2 g2 bx2 bg2) =
        NLPFun (mappend f1 f2) (mappend g1 g2) (mappend bx1 bx2) (mappend bg1 bg2)

-- ** low level lenses to NLPState
-- Template Haskell splices: for each record field @_name@ of the given
-- type a lens called @name@ is generated (e.g. 'funF', 'varIx', 'nMax').
-- The rest of the module accesses the state only through these lenses.
makeLenses ''NLPFun
makeLenses ''Ix
makeLenses ''NLPState

-- | Record the current namespace under key @n@ in the map focused by
-- @envLens@. The lens must point at an @IntMap [String]@, i.e.
-- 'constraintEnv' or 'objEnv'.
copyEnv envLens n =
    use currentEnv >>= \env -> cloneLens envLens %= IM.insert n env

-- | Store an optional description under key @n@. @labels@ should be one
-- of 'constraintLabels', 'objLabels', 'varLabels'. Nothing is stored
-- when no description is supplied.
addDesc labels desc n =
    F.forM_ desc $ \txt -> labels %= IM.insert n txt

-- * high-level functions

-- | calls 'createIpoptProblemAD' and 'ipoptSolve'. To be used at the
-- end of a do-block.
--
-- The objective passed to the solver is the sum of all pieces added with
-- 'addF'; the constraints are the functions added with 'addG', in order.
-- The problem handle is released with 'freeIpoptProblem' even when
-- setting the options or solving throws an exception.
--
-- NOTE(review): 'initX' is handed to the solver as-is; nothing in this
-- function checks that its length matches the number of variables --
-- confirm that callers set it accordingly.
solveNLP' :: MonadIO m =>
    (IpProblem -> IO ()) -- ^ set ipopt options (using functions from "Ipopt.Raw")
    -> NLPT m IpOptSolved
solveNLP' setOpts = do
    (xl,xu) <- join $ uses (nlpfun . boundX) seqToVecs
    (gl,gu) <- join $ uses (nlpfun . boundG) seqToVecs
    AnyRF fs <- use (nlpfun . funF)
    AnyRF gs <- use (nlpfun . funG)
    -- reading the starting point is pure, so it can happen before the
    -- problem is allocated
    x0 <- uses initX (V.convert . V.map CDouble)

    -- bracket guarantees the problem is freed on every exit path
    liftIO $ bracket
        (createIpoptProblemAD xl xu gl gu (F.sum . fs) (V.fromList . toList . gs))
        freeIpoptProblem
        (\p -> do
            setOpts p
            ipoptSolve p =<< VS.thaw x0)

-- | Append one constraint together with its bounds. The namespace and
-- the optional description are recorded under the number of constraints
-- present after the addition (so the first constraint gets key 1).
addG :: Monad m
    => Maybe String -- ^ optional description
    -> (Double,Double) -- ^ bounds @(gl,gu)@ for the single inequality @gl_i <= g_i(x) <= gu_i@
    -> AnyRF Identity -- ^ @g_i(x)@
    -> NLPT m ()
addG d b (AnyRF gi) = do
    nlpfun . funG %= appendG
    nlpfun . boundG %= (Seq.|> b)
    n <- uses (nlpfun . boundG) Seq.length
    copyEnv constraintEnv n
    addDesc constraintLabels d n
  where
    -- evaluate the existing constraints, then put the new one at the end
    appendG :: AnyRF Seq -> AnyRF Seq
    appendG (AnyRF gs) = AnyRF (\x -> gs x Seq.|> runIdentity (gi x))

{- | Append one term to the objective. The objective is kept as separate
terms @f_1 + f_2 + ...@ (they are only summed when solving), to make it
easier to understand (at some point) which components are responsible
for the majority of the cost, and which are irrelevant. The namespace
and the optional description are recorded under a key that is one more
than the number of terms recorded before.
-}
addF :: Monad m
    => Maybe String -- ^ description
    -> AnyRF Identity -- ^ @f_i(x)@
    -> NLPT m ()
addF d (AnyRF fi) = do
    n <- uses objEnv ((+ 1) . IM.size)
    nlpfun . funF %= appendF
    copyEnv objEnv n
    addDesc objLabels d n
  where
    -- evaluate the existing terms, then put the new one at the end
    appendF :: AnyRF Seq -> AnyRF Seq
    appendF (AnyRF fs) = AnyRF (\x -> fs x Seq.|> runIdentity (fi x))

-- | add a variable, or get a reference to the same variable if it has
-- already been used.
--
-- The name is qualified with the current namespace (outermost first,
-- joined with dots), so @"x"@ requested inside @inEnv \"A\"@ is the
-- variable @"A.x"@. A new variable starts with 'defaultBounds'; any
-- bounds passed here only ever shrink the interval (see 'narrowBounds').
-- Every namespace from which the variable is requested is accumulated in
-- 'varEnv'.
var' :: (Monad m, Functor m)
    => Maybe (Double,Double) -- ^ bounds @(xl,xu)@ to request that @xl <= x <= xu@.
                             -- if Nothing, you get whatever is in 'defaultBounds'
    -> String -- ^ variable name (namespace from the 'pushEnv' / 'popEnv' can
              -- make an @"x"@ you request here different from one you
              -- previously requested
    -> NLPT m (AnyRF Identity, Ix) -- ^ the value, and index (into the raw
                                   -- vector of variables that the solver
                                   -- sees)
var' bs s = do
    ev <- use currentEnv
    known <- use variables
    -- currentEnv is innermost-first, hence the reverse
    let qualified = intercalate "." (reverse (s:ev))
    n <- case M.lookup qualified known of
        Just existing -> return existing
        Nothing -> do
            -- allocate the next index and give it the default bounds
            nMax %= over varIx (+1)
            db <- use defaultBounds
            nlpfun . boundX %= (Seq.|> db)
            fresh <- use nMax
            variables %= M.insert qualified fresh
            return fresh
    -- union with what is already recorded: a plain insert would discard
    -- the namespaces of earlier requests for the same variable
    varEnv %= IM.insertWith S.union (view varIx n) (S.singleton ev)
    F.traverse_ (narrowBounds n) bs
    return (AnyRF $ \x -> Identity $ x V.! view varIx n, n)

-- | Like 'var'', but returns only the value and drops the 'Ix', which is
-- usually not needed.
var :: (Monad m, Functor m)
    => Maybe (Double,Double)
    -> String
    -> NLPT m (AnyRF Identity)
var bs s = liftM fst (var' bs s)

{- | 'var', except this causes the solver to get a new variable,
so that you can use:

> [a,b,c,d,e] <- replicateM 5 (varFresh (Just (0, 10)) "x")

and the different letters can take different values (between 0 and 10)
in the optimal solution (depending on what you do with @a@ and similar
in the objective function).

The variable is named by appending a number to the given name. The
number starts at one more than the count of existing variables and is
increased until the qualified name is unused, so the result never
aliases a variable that already exists (for example one created earlier
with @var b \"x2\"@).
-}
varFresh :: (Monad m, Functor m)
    => Maybe (Double,Double)
    -> String
    -> NLPT m (AnyRF Identity)
varFresh bs s = do
    ev <- use currentEnv
    taken <- use variables
    let -- the key under which 'var'' would store the candidate name
        keyFor :: Int -> String
        keyFor k = intercalate "." (reverse ((s ++ show k) : ev))
        -- terminates because the map of existing variables is finite
        firstFree :: Int -> Int
        firstFree k
            | M.member (keyFor k) taken = firstFree (k + 1)
            | otherwise = k
    var bs (s ++ show (firstFree (M.size taken + 1)))

-- *** namespace
{- $namespace

When you build up an optimization problem, it may be composed of pieces.
Functions in this section help to ease the pain of making unique variables.
To illustrate:

> m <- inEnv "A" (var b "x")
> n <- var b "A.x" 

@m@ and @n@ above should refer to the same variable. In some sense this
is \"better\" than using 'varFresh' all the time, since perhaps you would
like to add dependencies between components (say the size of a header pipe,
refrigeration unit, foundation etc. has to satisfy sizes of individual
components).

-}

-- | Run an action inside the namespace @s@: 'pushEnv' before it,
-- 'popEnv' after it, and return the action's result.
inEnv :: Monad m => String -> NLPT m a -> NLPT m a
inEnv s action =
    pushEnv s >> action >>= \result -> popEnv >> return result

-- | Enter a nested namespace by putting @s@ on top of the namespace stack.
pushEnv :: Monad m => String -> NLPT m ()
pushEnv s = modify (over currentEnv (s :))

-- | Leave the innermost namespace and return its name.
--
-- Calling this without a matching 'pushEnv' is a programming error and
-- is reported with an explicit message. The case analysis is total, so
-- the function does not rely on a failing pattern match in the base
-- monad.
popEnv :: Monad m => NLPT m String
popEnv = do
    env <- use currentEnv
    case env of
        [] -> error "popEnv: no enclosing environment (unbalanced pushEnv/popEnv)"
        n : ns -> do
            currentEnv .= ns
            return n

-- *** bounds

-- | Replace the bounds of a variable outright. Should be unnecessary
-- given 'var' takes bounds. An index outside the current bounds
-- sequence leaves it unchanged.
setBounds :: Monad m => Ix -> (Double,Double) -> NLPT m ()
setBounds n bs =
    modify (over (nlpfun . boundX) (Seq.update (view varIx n) bs))

-- | Intersect the interval currently allowed for a variable with the
-- given one: the lower bound can only go up and the upper bound can
-- only go down.
narrowBounds :: Monad m => Ix -> (Double,Double) -> NLPT m ()
narrowBounds (Ix i) (lo, hi) = nlpfun . boundX . ix i %= tighten
  where
    tighten (curLo, curHi) = (max lo curLo, min hi curHi)

-- * internal
-- | Split a sequence of @(lower, upper)@ pairs into two freshly
-- allocated mutable vectors of 'CDouble': all lower bounds first, all
-- upper bounds second.
seqToVecs :: MonadIO m => Seq (Double,Double) -> m (Vec,Vec)
seqToVecs bounds = liftIO (liftM2 (,) (toVec lows) (toVec highs))
  where
    (lows, highs) = unzip (toList bounds)
    toVec xs = VS.thaw (VS.fromList (map CDouble xs))