module Language.Atom.Verification
  ( Model
  , model
  -- , verify
  -- , kInduction
  -- , boundedModelChecking
  , test
  ) where

import Control.Exception (bracket)
import Data.Char ()
import Data.Function ()
import Data.Int ()
import Data.List ()
import Data.Ratio ()
import Data.Word ()
import Math.SMT.Yices.Pipe
import Math.SMT.Yices.Syntax
import System.Process ()

import Language.Atom.Elaboration ()
import Language.Atom.Expressions
import Language.Atom.Scheduling

-- | A model of a scheduled program, intended as the input to model checking.
--
-- This is currently a placeholder with no fields. The fields sketched for it
-- (not yet implemented) are:
--
-- > transition :: Int -> [Name] -> ([CmdY], [CmdY])
-- > names      :: [(Name, Int)]
data Model
  = Model

-- | Build a 'Model' from a scheduled program together with its 'UV' and
-- 'UA' lists.
--
-- Placeholder: all three arguments are currently ignored and the fieldless
-- 'Model' value is returned.
model :: Schedule -> [UV] -> [UA] -> Model
model _schedule _uvs _uas = Model

-- | A list of assertions captured by a model.
--assertions :: Model -> [Name]

-- | Bounded model checking starting from the initial state.
--boundedModelChecking :: Model -> Int -> Name -> IO (Maybe Witness)

-- | K-induction model checking given a min and a max k-depth.
--kInduction :: Model -> Int -> Int -> Name -> IO ()


-- | Smoke test for the Yices pipe.
--
-- Builds a tiny Yices program: it declares two @int@ variables @a@ and @b@
-- and asserts @a = b@. The program prints the command list, runs it through
-- 'runY' (which launches the @yices@ executable), and prints the result.
test :: IO ()
test = do
  print cmds
  result <- runY cmds
  print result
  where
    -- Named 'cmds' (not 'test') so it does not shadow the top-level 'test'.
    cmds =
      [ DEFINE ("a", VarT "int") Nothing
      , DEFINE ("b", VarT "int") Nothing
      , ASSERT (VarE "a" := VarE "b")
      -- NOTE(review): the line below appears to use older constructor names
      -- (YNe/YVar) than the VarE/:= used above; update it before re-enabling.
      -- , ASSERT (YNe (YVar "a") (YVar "b"))
      ]
  
-- | Runs a Yices program.
--
-- Opens a pipe to the @yices@ executable and sends it the given commands. It
-- then asks for a satisfiability check and returns that check's result, which
-- per the original intent carries the variable values when satisfiable.
--
-- 'bracket' guarantees the pipe is closed with 'exitY' even if sending the
-- commands or running the check throws. Previously an exception skipped
-- 'exitY' and leaked the pipe (presumably the @yices@ process as well).
-- On success the order is unchanged: run commands, check, exit, return.
runY :: [CmdY] -> IO ResY
runY cmds =
  bracket (createYicesPipe "yices" []) exitY $ \pipe -> do
    runCmdsY pipe cmds
    checkY pipe



{-
yType :: Type -> YT
yType t = case t of
  Bool   -> YBool
  Int8   -> YInt
  Int16  -> YInt
  Int32  -> YInt
  Int64  -> YInt
  Word8  -> YNat
  Word16 -> YNat
  Word32 -> YNat
  Word64 -> YNat
  Float  -> YReal
  Double -> YReal
-}

{-
vars :: [UV] -> Int -> String
vars uvs step = concatMap (var step) uvs
  where
  var :: Int -> UV -> String
  var step uv = "(define " ++ uvName step uv ++ "::" ++ yicesType (uvType uv) ++ ")\n"

uvName :: Int -> UV -> String
uvName step (UV i _ _) = "v" ++ show i ++ "_" ++ show step

initialize :: [UV] -> String
initialize uvs = vars uvs 0 ++ concatMap initialize uvs
  where
  initialize :: UV -> String
  initialize uv@(UV _ _ (Local c)) = "(assert (= " ++ uvName 0 uv ++ " " ++ const c ++ "))\n"
  initialize (UV _ _ (External _)) = ""
  const :: Const -> String
  const c = case c of
    CBool   c -> if c then "true" else "false"
    CInt8   c -> "0b" ++ bits  8 (fromIntegral c)
    CInt16  c -> "0b" ++ bits 16 (fromIntegral c)
    CInt32  c -> "0b" ++ bits 32 (fromIntegral c)
    CInt64  c -> "0b" ++ bits 64 (fromIntegral c)
    CWord8  c -> "0b" ++ bits  8 (fromIntegral c)
    CWord16 c -> "0b" ++ bits 16 (fromIntegral c)
    CWord32 c -> "0b" ++ bits 32 (fromIntegral c)
    CWord64 c -> "0b" ++ bits 64 (fromIntegral c)
    CFloat  c -> "(/ " ++ show (numerator $ toRational c) ++ " " ++ show (denominator $ toRational c) ++ ")"
    CDouble c -> "(/ " ++ show (numerator $ toRational c) ++ " " ++ show (denominator $ toRational c) ++ ")"

  bits :: Int -> Word64 -> String
  bits 0 _ = ""
  bits n a = bits (n - 1) (div a 2) ++ show (mod a 2)

-- Time 0 to 1 is step 1.
transition :: [[[Rule]]] -> [UV] -> Int -> String
transition _ {-schedule-} uvs step = vars uvs step ++ transition
  where
  transition = "; transition " ++ show (step - 1) ++ " to " ++ show step ++ "\n"  --XXX

getUV :: [UV] -> Int -> UV
getUV [] _ = error "Verify.getUV"
getUV (uv@(UV i _ _):_) j | i == j = uv
getUV (_:a) i = getUV a i


parseVar :: [UV] -> Group -> (Int, UV, String)
parseVar uvs (G [S "=", S name, value]) = (t, getUV uvs i, parseValue value)
  where
  (i', t') = break (== '_') $ tail name
  i = read i'
  t = read $ tail t'
parseVar _ g = error $ "Verify.parseVar: " ++ show g

parseValue :: Group -> String
parseValue (S ('0':'b':a)) = "b" ++ a ++ " "
parseValue (S "true")      = "1"
parseValue (S "false")     = "0"
parseValue (S v)           = "r" ++ v ++ " "
parseValue (G [S "/", S n, S d]) = "r" ++ show (fromRational (read n % read d)) ++ " "
parseValue a               = error $ "Verify.parseValue: " ++ show a


data Heirarchy = Variable UV | Module String [Heirarchy]

vcd :: [UV] -> [(Int, UV, String)] -> String
vcd uvs signals' = header ++ samples ++ end
  where
  signals = sortBy (\ (a,_,_) (b,_,_) -> compare a b) signals'

  header = "$timescale\n  1 ms\n$end\n" ++ concatMap decl (heirarchy 0 uvs) ++ "$enddefinitions $end\n"
  (lastTime, _, _) = last signals
  end = "#" ++ show (lastTime + 1) ++ "\n"

  decl :: Heirarchy -> String
  decl (Module name subs) = "$scope module " ++ name ++ " $end\n" ++ concatMap decl subs ++ "$upscope $end\n"
  decl (Variable uv)      = declVar uv

  declVar :: UV -> String
  declVar uv@(UV i n _) = case uvType uv of
    Bool   -> "$var wire 1 "     ++ code ++ " " ++ name ++ " $end\n"
    Int8   -> "$var integer 8 "  ++ code ++ " " ++ name ++ " $end\n"
    Int16  -> "$var integer 16 " ++ code ++ " " ++ name ++ " $end\n"
    Int32  -> "$var integer 32 " ++ code ++ " " ++ name ++ " $end\n"
    Int64  -> "$var integer 64 " ++ code ++ " " ++ name ++ " $end\n"
    Word8  -> "$var wire 8 "     ++ code ++ " " ++ name ++ " $end\n"
    Word16 -> "$var wire 16 "    ++ code ++ " " ++ name ++ " $end\n"
    Word32 -> "$var wire 32 "    ++ code ++ " " ++ name ++ " $end\n"
    Word64 -> "$var wire 64 "    ++ code ++ " " ++ name ++ " $end\n"
    Float  -> "$var real 32 "    ++ code ++ " " ++ name ++ " $end\n"
    Double -> "$var real 64 "    ++ code ++ " " ++ name ++ " $end\n"
    where
    code = vcdCode i
    name = reverse $ takeWhile (/= '.') $ reverse n

  samples = concatMap sample signals

  sample (t, (UV i _ _), v) = "#" ++ show t ++ "\n" ++ v ++ vcdCode i ++ "\n"

heirarchy :: Int -> [UV] -> [Heirarchy]
heirarchy _ [] = []
heirarchy depth uvs = heirarchy' depth notvars ++ map Variable vars
  where
  isVar :: UV -> Bool
  isVar uv = length (path depth uv) == 1
  (vars, notvars) = partition isVar uvs

heirarchy' :: Int -> [UV] -> [Heirarchy]
heirarchy' _ [] = []
heirarchy' depth uvs@(a:_) = Module n (heirarchy (depth + 1) yes) : heirarchy' depth no
  where
  n = head $ path depth a
  isMod uv = n == head (path depth uv)
  (yes, no) = partition isMod uvs

path :: Int -> UV -> [String]
path depth (UV _ n _) = drop depth $ words $ map (\ c -> if c == '.' then ' ' else c) n

vcdCode :: Int -> String
vcdCode i | i < 94 =              [chr (33 + mod i 94)]
vcdCode i = vcdCode (div i 94) ++ [chr (33 + mod i 94)]

{-
bitString :: Int -> String
bitString n = if null bits then "0" else bits
  where
  bit :: Int -> Char
  bit i = if testBit n i then '1' else '0'
  bits = dropWhile (== '0') $ map bit $ reverse [0..31]
-}

-}