{-# OPTIONS_GHC -Wall #-}

module Dvda.MultipleShooting.MSMonad ( State
                                     , setStates
                                     , setActions
                                     , addParam
                                     , addParams
                                     , addConstant
                                     , addConstants
                                     , setDxdt
                                     , setLagrangeTerm
                                     , setMayerTerm
                                     , setDt
                                     , addOutput
                                     , setPeriodic
                                     , addConstraint
                                     , setBound
                                     , lagrangeStateName
                                     , lagrangeTermName
                                     ) where

import Data.Hashable ( Hashable )
import qualified Data.HashSet as HS
import Data.List ( nub, sort )
import Data.Maybe ( isJust, fromMaybe )
import Data.Monoid ( mappend )
import Control.Monad ( when, zipWithM_ )
import Control.Monad.State ( State )
import qualified Control.Monad.State as State

import qualified Dvda.HashMap as HM

import Dvda.Expr ( Expr(..), sym )
import Dvda.MultipleShooting.Types

-- | Reserved variable name "lagrangeState". 'checkOctaveName' rejects any
-- user-supplied name equal to it.
lagrangeStateName :: String
lagrangeStateName = "lagrangeState"

-- | Reserved variable name "lagrangeTerm". 'checkOctaveName' rejects any
-- user-supplied name equal to it.
lagrangeTermName :: String
lagrangeTermName = "lagrangeTerm"

-- | Return the list unchanged if every name in it is distinct; otherwise
-- fail with an error listing the repeated names and the full (sorted) list.
--
-- Repeats are found by sorting and comparing neighbours, which is
-- O(n log n) instead of the O(n^2) cost of 'nub' on the whole list.
failDuplicates :: [String] -> [String]
failDuplicates names
  | null dups = names
  | otherwise = error $ "ERROR: saw duplicate names " ++ show dups
                ++ " in: " ++ show sorted
  where
    sorted = sort names
    -- In a sorted list, equal names are adjacent. A name repeated k times
    -- yields k-1 matches, and 'nub' collapses them (this list is tiny).
    dups = nub [a | (a, b) <- zip sorted (drop 1 sorted), a == b]

-- | Validate a user-supplied variable name and return it unchanged.
--
-- Errors if the name contains any character from @badChars@, or if it equals
-- one of the reserved names 'lagrangeStateName' / 'lagrangeTermName'.
checkOctaveName :: String -> String
checkOctaveName name
  | any (`elem` badChars) name =
    error $ "ERROR: saw illegal octave variable character in string: \"" ++ name ++
    "\", illegal characters: " ++ badChars
  -- '$' is required here: without it, 'error' applies only to the first
  -- string literal and the rest of the message is never shown.
  | name `elem` reservedNames =
    error $ "don't call your variable \"" ++ name ++ "\", it's reserved"
  | otherwise = name
  where
    badChars = "\"'~!@#$%^&*()+`-=[]{}\\|;:,.<>/?"
    reservedNames = [lagrangeStateName, lagrangeTermName]

-- | Declare the step's states: one symbol per name. Each symbol is also
-- registered as an output under its own name (via 'addOutput').
--
-- Errors if states were already set, if a name repeats, or if a name fails
-- 'checkOctaveName'.
setStates :: [String] -> State (Step a) [Expr a]
setStates names' = do
  step <- State.get
  case stepStates step of
    Just _ -> error "states already set, don't call setStates twice"
    Nothing -> do
      -- 'failDuplicates' already guarantees uniqueness; checking once suffices.
      let names = failDuplicates (map checkOctaveName names')
          syms = map sym names
      State.put $ step {stepStates = Just syms}
      zipWithM_ addOutput syms names
      return syms

-- | Declare the step's actions: one symbol per name. Each symbol is also
-- registered as an output under its own name (via 'addOutput').
--
-- Errors if actions were already set, if a name repeats, or if a name fails
-- 'checkOctaveName'.
setActions :: [String] -> State (Step a) [Expr a]
setActions names' = do
  step <- State.get
  case stepActions step of
    Just _ -> error "actions already set, don't call setActions twice"
    Nothing -> do
      -- 'failDuplicates' already guarantees uniqueness; checking once suffices.
      let names = failDuplicates (map checkOctaveName names')
          syms = map sym names
      State.put $ step {stepActions = Just syms}
      zipWithM_ addOutput syms names
      return syms

-- | Add a single parameter symbol; a one-element wrapper around 'addParams'.
addParam :: (Eq a, Hashable a) => String -> State (Step a) (Expr a)
addParam name = do
  -- Match explicitly instead of with a refutable '[s] <-' bind: that form
  -- needs 'MonadFail', which 'State' (over 'Identity') lacks on GHC >= 8.8.
  syms <- addParams [name]
  case syms of
    [s] -> return s
    _ -> error "addParam: addParams did not return exactly one symbol"

-- | Add a single constant symbol; a one-element wrapper around 'addConstants'.
addConstant :: (Eq a, Hashable a) => String -> State (Step a) (Expr a)
addConstant name = do
  -- Match explicitly instead of with a refutable '[s] <-' bind: that form
  -- needs 'MonadFail', which 'State' (over 'Identity') lacks on GHC >= 8.8.
  syms <- addConstants [name]
  case syms of
    [s] -> return s
    _ -> error "addConstant: addConstants did not return exactly one symbol"

-- | Make one symbol per name (each name passed through 'checkOctaveName'),
-- union them into the step's parameter set, and return them in input order.
addParams :: (Eq a, Hashable a) => [String] -> State (Step a) [Expr a]
addParams names = do
  State.modify $ \st -> st {stepParams = stepParams st `HS.union` HS.fromList newSyms}
  return newSyms
  where
    newSyms = map (sym . checkOctaveName) names

-- | Make one symbol per name (each name passed through 'checkOctaveName'),
-- union them into the step's constant set, and return them in input order.
addConstants :: (Eq a, Hashable a) => [String] -> State (Step a) [Expr a]
addConstants names = do
  State.modify $ \st -> st {stepConstants = stepConstants st `HS.union` HS.fromList newSyms}
  return newSyms
  where
    newSyms = map (sym . checkOctaveName) names

-- | Record an expression in the step's outputs under the given name. The name
-- is checked with 'checkOctaveName'.
--
-- If the name is already present, the combining function passed to
-- 'HM.insertWith' raises an error.
addOutput :: Expr a -> String -> State (Step a) ()
addOutput var name = State.modify registerOutput
  where
    registerOutput st =
      st {stepOutputs = HM.insertWith clash (checkOctaveName name) var (stepOutputs st)}
    clash _ _ = error $ "ERROR: already have an output with name: \"" ++ name ++ "\""

-- | Set the step's time-step expression. Errors if dt has already been set.
setDt :: Expr a -> State (Step a) ()
setDt expr = do
  step <- State.get
  -- Branch with 'case' rather than 'when ... error': in the lazy State monad,
  -- a 'when' action whose result and state are discarded by the following
  -- 'put' is never forced, so the duplicate check would silently never fire.
  case stepDt step of
    Just _ -> error "dt already set, don't call setDt twice"
    Nothing -> State.put $ step {stepDt = Just expr}

-- | Mark a state or action as periodic by inserting it into the step's
-- periodic set.
--
-- The validation errors are defined inside the new 'stepPeriodic' value.
-- They trigger when that field is evaluated, not when this action runs:
-- marking the same variable twice, or marking something that is in neither
-- the states nor the actions list, is an error.
setPeriodic :: (Eq a, Hashable a, Show a) => Expr a -> State (Step a) ()
setPeriodic var = State.modify markPeriodic
  where
    markPeriodic st = st {stepPeriodic = updated}
      where
        current = stepPeriodic st
        designVars = fromMaybe [] (stepStates st `mappend` stepActions st)
        updated
          | HS.member var current =
            error $ "you called setPeriodic twice on \"" ++ show var ++ "\""
          | var `notElem` designVars =
            error $ "you can only make states or actions periodic, you can't make \"" ++ show var ++ "\" periodic"
          | otherwise = HS.insert var current
-------------------------------------------

-- | Set the step's dxdt expressions. Errors if dxdt has already been set.
setDxdt :: [Expr a] -> State (Step a) ()
setDxdt vars = do
  step <- State.get
  -- Branch with 'case' rather than 'when ... error': in the lazy State monad,
  -- a 'when' action whose result and state are discarded by the following
  -- 'put' is never forced, so the duplicate check would silently never fire.
  case stepDxdt step of
    Just _ -> error "dxdt already set, don't call setDxdt twice"
    Nothing -> State.put $ step {stepDxdt = Just vars}

-- | Set the Lagrange term expression together with a (lower, upper) pair.
-- Errors if a Lagrange term has already been set.
setLagrangeTerm :: Expr a -> (a,a) -> State (Step a) ()
setLagrangeTerm var (lb,ub) = do
  step <- State.get
  -- Branch with 'case' rather than 'when ... error': in the lazy State monad,
  -- a 'when' action whose result and state are discarded by the following
  -- 'put' is never forced, so the duplicate check would silently never fire.
  case stepLagrangeTerm step of
    Just _ -> error "Lagrange term already set, don't call setLagrangeTerm twice"
    Nothing -> State.put $ step {stepLagrangeTerm = Just (var,(lb,ub))}

-- | Set the Mayer term expression. Errors if a Mayer term has already been set.
setMayerTerm :: Expr a -> State (Step a) ()
setMayerTerm var = do
  step <- State.get
  -- Branch with 'case' rather than 'when ... error': in the lazy State monad,
  -- a 'when' action whose result and state are discarded by the following
  -- 'put' is never forced, so the duplicate check would silently never fire.
  case stepMayerTerm step of
    Just _ -> error "Mayer term already set, don't call setMayerTerm twice"
    Nothing -> State.put $ step {stepMayerTerm = Just var}

-- | Prepend a (lower, upper, BCTime) bound for a symbolic variable to the
-- step's bound list.
--
-- Only 'ESym' expressions are accepted; any other expression is an error.
-- The variable is matched before the bounds tuple, as in a function-head match.
setBound :: (Show a, Eq a, Hashable a)
            => Expr a -> (a, a) -> BCTime -> State (Step a) ()
setBound var bounds bctime =
  case (var, bounds) of
    (ESym _, (lb, ub)) ->
      State.modify $ \st -> st {stepBounds = (var, (lb, ub, bctime)) : stepBounds st}
    _ -> error "WARNING - setBound called on non-design variable, try addConstraint instead"

-- | Append the constraint @x `ordering` y@ to the end of the step's
-- constraint list.
addConstraint :: Expr a -> Ordering -> Expr a -> State (Step a) ()
addConstraint x ordering y = State.modify appendConstraint
  where
    appendConstraint st =
      st {stepConstraints = stepConstraints st ++ [Constraint x ordering y]}