module Language.ImProve.Code.Modelica (codeModelica) where

import Data.Function
import Data.List
import Data.Maybe
import Text.Printf

import Language.ImProve.Code.Simulink
import Language.ImProve.Core

-- | Write a Modelica block named @name@ to the file @name.mo@.
--
-- The statement is converted to a netlist, its names are rewritten by
-- 'moRename', and the netlist is printed as a sampled Modelica @block@:
-- inputs and outputs first, then the protected state and internal
-- signals, then a single @when sample(0, Period)@ clause holding the
-- equations.  Declarations are listed in name order; equations keep the
-- order of the netlist's block list.
codeModelica :: Name -> Statement -> IO ()
codeModelica name stmt = do
  net <- fmap moRename (netlist stmt)
  writeFile (name ++ ".mo") (unlines (render net))
  where
  -- All lines of the generated file, top to bottom.
  render :: Netlist -> [String]
  render net = concat
    [ [ "// Generated by ImProve."
      , ""
      , "block " ++ name
      , "\tparameter Real Period=0.001;"
      ]
    , inputs net
    , outputs net
    , [ "protected" ]
    , states net
    , internals net
    , [ "equation"
      , "\twhen sample(0, Period) then"
      ]
    , equations net
    , [ "\tend when;"
      , "end " ++ name ++ ";"
      ]
    ]

  inputs, outputs, states, internals, equations :: Netlist -> [String]

  inputs net =
    [ printf "\tinput %s %s;" (constType c) n
    | (n, Inport c) <- sortedBlocks net
    ]

  outputs net =
    [ printf "\tdiscrete output %s %s;" (constType c) n
    | (n, Outport c) <- sortedBlocks net
    ]

  -- State variables carry their initial value as the start attribute.
  states net =
    [ printf "\tdiscrete %s %s (start = %s);" (constType c) n (showConst c)
    | (n, UnitDelay c) <- sortedBlocks net
    ]

  -- Every remaining signal-producing block gets a protected variable
  -- whose type is resolved through the netlist.
  internals net =
    [ printf "\tdiscrete %s %s;" (constType (netType net n)) n
    | (n, b) <- sortedBlocks net
    , isInternal b
    ]

  equations net = mapMaybe (equation net) (blocks net)

  sortedBlocks :: Netlist -> [(Name, Block)]
  sortedBlocks net = sortBy (compare `on` fst) (blocks net)

-- | Render the Modelica statement that defines one netlist block.
--
-- Returns 'Nothing' for input ports, which have no defining equation, and
-- @Just line@ for every other block, where @line@ is one tab-indented
-- statement.  Operands are found by port number among the nets that end
-- at this block.
--
-- Calls 'error' when the netlist is malformed: a referenced port that
-- does not have exactly one driver, or a division whose type resolves to
-- Bool.
equation :: Netlist -> (Name, Block) -> Maybe String
equation net (name, block) = case block of
  Inport  _   -> Nothing
  Outport _   -> assign $ arg 0
  -- A unit delay yields the value its input had at the previous sample.
  UnitDelay _ -> assign $ printf "pre(%s)" $ arg 0
  -- A cast is combinational: forward the operand within the same sample.
  Cast _      -> assign $ arg 0
  Assertion   -> Just $ printf "\tassert(%s, \"%s\");" (arg 0) name
  Const' c    -> assign $ showConst c
  Add'        -> assign $ printf "%s + %s" (arg 0) (arg 1)
  Sub'        -> assign $ printf "%s - %s" (arg 0) (arg 1)
  Mul'        -> assign $ printf "%s * %s" (arg 0) (arg 1)
  -- Integer division needs the div() operator; "/" always yields a Real.
  Div'        -> case netType net name of
    Bool  _ -> error "Modelica.equation: invalid netlist (1)"
    Int   _ -> assign $ printf "div(%s, %s)" (arg 0) (arg 1)
    Float _ -> assign $ printf "%s / %s"     (arg 0) (arg 1)
  -- NOTE(review): Modelica's mod() is floored while div() truncates;
  -- confirm mod() rather than rem() is the intended pairing.
  Mod'        -> assign $ printf "mod(%s, %s)" (arg 0) (arg 1)
  Not'        -> assign $ printf "not %s" $ arg 0
  And'        -> assign $ printf "%s and %s" (arg 0) (arg 1)
  Or'         -> assign $ printf "%s or %s" (arg 0) (arg 1)
  Eq'         -> assign $ printf "%s == %s" (arg 0) (arg 1)
  Lt'         -> assign $ printf "%s < %s" (arg 0) (arg 1)
  Gt'         -> assign $ printf "%s > %s" (arg 0) (arg 1)
  Le'         -> assign $ printf "%s <= %s" (arg 0) (arg 1)
  Ge'         -> assign $ printf "%s >= %s" (arg 0) (arg 1)
  Mux'        -> assign $ printf "if %s then %s else %s" (arg 0) (arg 1) (arg 2)
  where
  -- Assign an expression to this block's own variable.
  assign :: String -> Maybe String
  assign rhs = Just $ printf "\t%s = %s;" name rhs

  -- Name of the signal driving input port @i@ of this block.
  arg :: Int -> Name
  arg i = case [ n | (n, (n1, p1)) <- nets net, n1 == name, p1 == i ] of
    [n] -> n
    _   -> error "Modelica.equation: invalid netlist (2)"

-- | True for blocks that are neither ports, unit delays nor assertions,
-- i.e. every block not explicitly listed below.
isInternal :: Block -> Bool
isInternal (Inport    _) = False
isInternal (Outport   _) = False
isInternal (UnitDelay _) = False
isInternal Assertion     = False
isInternal _             = True

-- | Wrap every name in the netlist in single quotes.
--
-- Modelica delimits quoted identifiers with single quotes, so this is the
-- form the generated file must use for names.  The same rewrite is
-- applied to the environment, the block names and both ends of every
-- net, so references stay consistent.
moRename :: Netlist -> Netlist
moRename net = net
  { env    = map quote (env net)
  , blocks = [ (quote n, b) | (n, b) <- blocks net ]
  , nets   = [ (quote src, (quote dst, port)) | (src, (dst, port)) <- nets net ]
  }
  where
  quote :: Name -> Name
  quote n = "'" ++ n ++ "'"

-- | Render a constant as literal text: @true@ / @false@ for booleans and
-- the 'show' form for integers and floats.
showConst :: Const -> String
showConst (Bool b)
  | b         = "true"
  | otherwise = "false"
showConst (Int   i) = show i
showConst (Float x) = show x

-- | Modelica type name for a constant; only the constructor is inspected.
-- The Real entry is padded with trailing spaces to the width of the
-- other two names.
constType :: Const -> String
constType (Bool  _) = "Boolean"
constType (Int   _) = "Integer"
constType (Float _) = "Real   "

-- | Type of the signal produced by the named block, as a 'Const'.
--
-- Ports, delays and constants report the 'Const' they carry.  Logical and
-- relational operators are Bool, modulo is Int, a mux takes the type of
-- its port-1 operand, and every other block takes the type of its port-0
-- operand.  For operator blocks the payload of the result is a
-- placeholder; only the constructor is meaningful.
--
-- Calls 'error' when the name is not a block of the netlist, when asked
-- about an Outport or Assertion, or when a followed port does not have
-- exactly one driver.
netType :: Netlist -> Name -> Const
netType net name = case lookup name (blocks net) of
  Nothing    -> error $ "Modelica.netType: unknown block " ++ name
  Just block -> case block of
    Inport    a -> a
    UnitDelay a -> a
    Const'    a -> a
    Outport _   -> error "Modelica.netType: not expecting Outport"
    Assertion   -> error "Modelica.netType: not expecting Assertion"
    Mod'        -> Int 0
    Not'        -> Bool False
    And'        -> Bool False
    Or'         -> Bool False
    Eq'         -> Bool False
    Lt'         -> Bool False
    Gt'         -> Bool False
    Le'         -> Bool False
    Ge'         -> Bool False
    Mux'        -> follow 1
    -- NOTE(review): Cast lands here and so inherits its operand's type
    -- rather than the cast target -- confirm this is intended.
    _           -> follow 0
  where
  -- Type of whatever drives input port @p@ of this block.
  follow :: Int -> Const
  follow p = case [ n | (n, (n1, p1)) <- nets net, n1 == name, p1 == p ] of
    [n] -> netType net n
    _   -> error "Modelica.netType: invalid netlist"

{-
data Block
  = Inport  Const
  | Outport Const
  | UnitDelay Const
  | Cast String
  | Assertion
  | Const' Const
  | Add'
  | Sub'
  | Mul'
  | Div'
  | Mod'
  | Not'
  | And'
  | Or'
  | Eq'
  | Lt'
  | Gt'
  | Le'
  | Ge'
  | Mux'

data Netlist = Netlist
  { nextId :: Int
  , path   :: Path
  , vars   :: [Path]
  , env    :: [Name]
  , blocks :: [(Name, Block)]
  , nets   :: [(Name, (Name, Int))]
  }
-}