{-# OPTIONS_GHC -Wall #-}

module Dvda.MultipleShooting.MSCoctave ( msCoctave
                                       , run
                                       ) where

import qualified Control.Monad.State as State
import Data.Hashable ( Hashable )
import qualified Data.HashSet as HS
import Data.List ( zipWith6 )
import Data.Maybe ( fromMaybe )

import Dvda.AD ( rad )
import Dvda.CGen  ( showMex )
import Dvda.CSE ( cse )
import Dvda.Codegen.WriteFile ( writeSourceFile )
import Dvda.Expr ( Expr(..), sym, substitute )
import Dvda.FunGraph ( (:*)(..), toFunGraph, countNodes )
import Dvda.HashMap ( HashMap )
import qualified Dvda.HashMap as HM
import Dvda.MultipleShooting.CoctaveTemplates
import Dvda.MultipleShooting.MSMonad
import Dvda.MultipleShooting.Types

{-
    min f(x) st:
    
    c(x) <= 0
    ceq(x) == 0
    A*x <= b
    Aeq*x == beq
    lb <= x <= ub
-}
type Integrator a = [Expr Double]
                   -> [Expr Double]
                   -> [Expr Double]
                   -> [Expr Double]
                   -> ([Expr Double]
                       -> [Expr Double] -> [Expr Double])
                   -> Expr Double
                   -> [Expr Double]

-- take user provided bounds and make sure they're complete
-- return functions which will lookup bounds on given state/action @ timestep, and given param
setupBounds :: (Eq a, Hashable a, Show a)
               => [(Expr a, (a,a, BCTime))]
               -> Int
               -> (Expr a -> Int -> (a,a), Expr a -> (a,a))
setupBounds userBounds nSteps = (lookupAll, lookupParam)
  where
    lookupAll x k
      | k >= nSteps = error "don't ask for bounds at timestep >= number of total timesteps"
      | otherwise = case HM.lookup (x,k) specificTimestepBounds of
        Just bnd -> bnd
        Nothing -> case HM.lookup x everyTimestepBounds of
          Just bnd -> bnd
          Nothing -> error $ "need to set bounds for \"" ++ show x ++ "\" at timestep " ++ show k

    lookupParam x = case HM.lookup x everyTimestepBounds of
        Just bnd -> bnd
        Nothing -> error $ "need to set bounds for \"" ++ show x ++ "\""

    -- bounds set at only one timestep
--    everyTimestepBounds :: HashMap (Expr a) (a,a)
    everyTimestepBounds = let
      everyTS (e,(lb,ub,ALWAYS)) = [(e,(lb,ub))]
      everyTS _ = []
      f (e,lbub) hm =
        if HM.member e hm
        then error $ "you set bounds twice for \"" ++ show e ++ "\""
        else HM.insert e lbub hm
      in foldr f HM.empty $ concatMap everyTS userBounds

    -- bounds set at specific timestep
--    specificTimestepBounds :: HashMap (Expr a, Int) (a,a)
    specificTimestepBounds = let
      specificTS (e,(lb,ub,TIMESTEP k)) = [((e,k),(lb,ub))]
      specificTS _ = []
      f (e,lbub) hm =
        if HM.member e hm
        then error $ "you set bounds twice for \"" ++ show e ++ "\""
        else HM.insert e lbub hm
      in foldr f HM.empty $ concatMap specificTS userBounds

vectorizeDvs :: [[a]] -> [[a]] -> [a] -> [a]
vectorizeDvs allStates allActions params = concat allStates ++ concat allActions ++ params

msCoctave ::
  State (Step Double) b
  -> Integrator Double
  -> Int
  -> String
  -> FilePath
  -> IO ()
msCoctave userStep' odeError n funDir name = do
  let step = State.execState userStep' $
             Step { stepStates  = Nothing
                  , stepActions = Nothing
                  , stepDxdt = Nothing
                  , stepDt = Nothing
                  , stepLagrangeTerm = Nothing
                  , stepMayerTerm = Nothing
                  , stepBounds = []
                  , stepConstraints = []
                  , stepParams = HS.empty
                  , stepConstants = HS.empty
                  , stepOutputs = HM.empty
                  , stepPeriodic = HS.empty
                  }
      getWithErr :: String -> (Step Double -> Maybe c) -> c
      getWithErr fieldName f = case f step of
        Nothing -> error $ "need to set " ++ fieldName
        Just ret -> ret

      actions = getWithErr "actions" stepActions
      dt      = getWithErr "dt"      stepDt
      (states,outputs,dxdt,lagrangeState) = let
        states'  = getWithErr "states" stepStates
        dxdt'    = getWithErr "dxdt"   stepDxdt
        outputs' = stepOutputs step
        in
         case stepLagrangeTerm step of
           Nothing -> (states',outputs',dxdt',Nothing)
           Just (lagrangeTerm,(lb,ub)) ->
             ( states' ++ [lagrangeState']
             , HM.union outputs' $ HM.fromList
               [(lagrangeStateName, lagrangeState'), (lagrangeTermName, lagrangeTerm)]
             , dxdt'++[lagrangeTerm]
             , Just (lagrangeState',(lb,ub)) )
              where
                lagrangeState' = sym lagrangeStateName
        
      params    = HS.toList (stepParams    step)
      constants = HS.toList (stepConstants step)

      allStates   = [[sym $ show x ++ "__" ++ show k | x <-  states] | k <- [0..(n-1)]]
      allActions  = [[sym $ show u ++ "__" ++ show k | u <- actions] | k <- [0..(n-1)]]
      dvs = vectorizeDvs allStates allActions params

      outputMap :: HashMap String [Expr Double]
      outputMap = HM.map f outputs
        where
          f output = zipWith (subStatesActions output) allStates allActions

      subStatesActions f x u = substitute f (zip states x ++ zip actions u)

      subAllTimesteps :: Expr Double -> [Expr Double]
      subAllTimesteps something = zipWith (subStatesActions something) allStates allActions

      (lbs,ubs) = unzip $ vectorizeDvs stateBounds actionBounds paramBounds
        where
          (getAllBounds,getParamBounds) = setupBounds bounds n
          stateBounds  = [[getAllBounds x k | x <- states ] | k <- [0..(n-1)]]
          actionBounds = [[getAllBounds u k | u <- actions] | k <- [0..(n-1)]]
          paramBounds  = [getParamBounds p | p <- params]

          bounds = stepBounds step ++ lagrangeBound
            where
              lagrangeBound = case lagrangeState of
                Nothing -> []
                Just (ls,(lb,ub)) -> [(ls,(0,0,TIMESTEP 0)),(ls, (lb, ub, ALWAYS))]

      cost = subStatesActions finalCost (last allStates) (last allActions)
        where
          finalCost = case (stepMayerTerm step, lagrangeState) of
            (Just mc, Nothing) -> mc
            (Nothing, Just (ls,_)) -> ls
            (Just mc, Just (ls,_)) -> mc + ls
            (Nothing,Nothing) -> error "need to set cost function"

      (ceq, cineq) = foldl f ([],[]) allConstraints
        where
          f (eqs,ineqs) (Constraint x EQ y) = (eqs ++ [x - y], ineqs)
          f (eqs,ineqs) (Constraint x LT y) = (eqs, ineqs ++ [x - y])
          f (eqs,ineqs) (Constraint x GT y) = (eqs, ineqs ++ [y - x])
      
          execDxdt x u = map (flip substitute (zip states x ++ zip actions u)) dxdt

          dodeConstraints = map (Constraint 0 EQ) $ concat $
                            zipWith6 odeError (init allStates) (init allActions) (tail allStates) (tail allActions)
                            (repeat execDxdt) (repeat dt)

          allConstraints = dodeConstraints ++ (concatMap (g . (fmap subAllTimesteps)) (stepConstraints step)) ++ periodicConstraints
            where
              g (Constraint [] _ _) = []
              g (Constraint _ _ []) = []
              g (Constraint (x:xs) ord (y:ys)) = Constraint x ord y : g (Constraint xs ord ys)
            
              periodicConstraints = map lookup' $ HS.toList (stepPeriodic step)
                where
                  lookup' x = fromMaybe (error $ "couldn't find periodic thing \"" ++ show x ++ "\" in hashmap")
                              $ HM.lookup x xuMap
                  xuMap = HM.fromList $ zip states  (zipWith setEqual (head  allStates) (last allStates )) ++
                                        zip actions (zipWith setEqual (head allActions) (last allActions))
                    where
                      setEqual x y = Constraint x EQ y

  (costSource,costFg0,costFg) <- do
    let costGrad = rad cost dvs
    fg0 <- toFunGraph (dvs :* constants) (cost :* costGrad)
    let fg = cse fg0
    return (showMex (name ++ "_cost") fg, fg0, fg)
  
  (constraintsSource,constraintsFg0,constraintsFg) <- do
    let cineqJacob = map (flip rad dvs) cineq
        ceqJacob   = map (flip rad dvs) ceq
    fg0 <- toFunGraph (dvs :* constants) (cineq :* ceq :* cineqJacob :* ceqJacob)
    let fg = cse fg0
    return (showMex (name ++ "_constraints") fg, fg0, fg)

  (timeSource,timeFg) <- do
    fg <- toFunGraph (dvs :* constants) (take n $ scanl (+) 0 (repeat dt))
    return (showMex (name ++ "_time") fg, fg)

  (outputSource,outputFg) <- do
    fg <- toFunGraph (dvs :* constants) (HM.elems outputMap)
    return (showMex (name ++ "_outputs") fg, fg)

  (simSource,simFg) <- do
    fg <- toFunGraph (states :* actions :* params :* constants) dxdt
    return (showMex (name ++ "_sim") fg, fg)
      
  let setupSource = writeSetupSource name dvs lbs ubs
      mexAllSource = writeMexAll name
      unstructConstsSource = writeUnstructConsts name constants
      structSource = writeToStruct name dvs params constants outputMap
      unstructSource = writeUnstruct name dvs params states allStates actions allActions
      plotSource = writePlot name outputMap

  _ <- writeSourceFile         mexAllSource funDir $ name ++ "_mex_all.m"
  _ <- writeSourceFile          setupSource funDir $ name ++ "_setup.m"
  _ <- writeSourceFile         structSource funDir $ name ++ "_struct.m"
  _ <- writeSourceFile unstructConstsSource funDir $ name ++ "_unstructConstants.m"
  _ <- writeSourceFile       unstructSource funDir $ name ++ "_unstruct.m"
  _ <- writeSourceFile           plotSource funDir $ name ++ "_plot.m"

  _ <- writeSourceFile           timeSource funDir $ name ++ "_time.c"
  _ <- writeSourceFile         outputSource funDir $ name ++ "_outputs.c"
  _ <- writeSourceFile            simSource funDir $ name ++ "_sim.c"
  _ <- writeSourceFile           costSource funDir $ name ++ "_cost.c"
  _ <- writeSourceFile    constraintsSource funDir $ name ++ "_constraints.c"

  putStrLn $ "nodes in time:        " ++ show (countNodes timeFg)
  putStrLn $ "nodes in output:      " ++ show (countNodes outputFg)
  putStrLn $ "nodes in sim:         " ++ show (countNodes simFg)
  putStrLn $ "nodes in cost:        " ++ show (countNodes costFg) ++
    " (" ++ show (countNodes costFg0) ++ " before CSE)"
  putStrLn $ "nodes in constraints: " ++ show (countNodes constraintsFg) ++
    " (" ++ show (countNodes constraintsFg0) ++ " before CSE)"
  

spring :: State (Step Double) ()
spring = do
  [x, v] <- setStates ["x","v"]
  [u]    <- setActions ["u"]
  [k, b] <- addConstants ["k", "b"]
  let cost = 2*x*x + 3*v*v + 10*u*u
  setDxdt [v, -k*x - b*v + u]
  setDt (tEnd/((fromIntegral n')-1))

  setLagrangeTerm cost (-1,2000)

  setBound x (5,5) (TIMESTEP 0)
  setBound v (0,0) (TIMESTEP 0)
  
  setBound x (-5,5) ALWAYS
  setBound v (-10,10) ALWAYS
  setBound u (-200, 200) ALWAYS

  setBound v (0,0) (TIMESTEP (n'-1))

  setPeriodic x
  setPeriodic u

tEnd :: Expr Double
tEnd = 1.5

n' :: Int
n' = 18

run :: IO ()
run = msCoctave spring simpsonsRuleError' n' "../Documents/MATLAB/" "spring"
--run = msCoctave spring eulerError' n' "../Documents/MATLAB/" "spring"