module Language.Atom.Verify (verify) where

import Data.Char
import Data.Int
import Data.List
import Data.Ratio
import Data.Word
import System.Process

import Language.Atom.Elaboration
import Language.Atom.Expressions

-- | Check one named assertion with the external @yices@ solver
-- (run as @yices -e@ through 'readProcess').
--
-- Returns 'True' when the assertion is trivially true or yices answers
-- @unsat@. Returns 'False' when it is trivially false or yices answers
-- @sat@; in the @sat@ case the model is written as a VCD counterexample to
-- @<name>.vcd@. Any other solver reply is an 'error'.
verify :: Int -> [[[Rule]]] -> [UV] -> (String, UE) -> IO Bool
verify depth rules uvs (name, assert) = do
    putStrLn $ "Checking assertion " ++ name ++ "..."
    check
  where
    check
      | assert == ubool True  = return True
      | assert == ubool False = do
          putStrLn "Assertion failed trivially."
          return False
      | otherwise = do
          reply <- readProcess "yices" ["-e"] script
          case parseYices reply of
            S "unsat" : _     -> return True
            S "sat"   : model -> do
              let counterExample = map (parseVar uvs) model
                  vcdFile        = name ++ ".vcd"
              putStrLn ("Assertion failed.  See counter example: " ++ vcdFile)
              writeFile vcdFile (vcd uvs counterExample)
              return False
            _ -> error "Unexpected results from yices."

    -- Solver script: step-0 declarations and initial values, then one
    -- section per transition step, then the negated assertion per step.
    script = concat
      [ initialize uvs
      , concatMap (transition rules uvs) [1 .. depth]
      , concatMap negatedAt [0 .. depth]
      , "(check)"
      ]

    -- NOTE(review): one "(assert (not ...))" is emitted for every step, so
    -- yices must falsify the assertion at all steps 0..depth at once (a
    -- conjunction), whereas a bounded check usually asserts that it fails at
    -- some step. Because 'transition' does not yet constrain later steps,
    -- this effectively only checks the initial state -- confirm intent.
    negatedAt :: Int -> String
    negatedAt step = "(assert (not " ++ stateAt step ++ "))\n"

    -- Only a bare variable reference is supported as an assertion.
    stateAt :: Int -> String
    stateAt step = case assert of
      UVRef uv -> uvName step uv
      _        -> error "expressions not supported assertions yet"

-- | Yices sort used when declaring a variable of the given Atom 'Type'.
-- Signed and unsigned integers of the same width share one bit-vector sort;
-- both floating-point types are declared as Yices reals.
yicesType :: Type -> String
yicesType Bool   = "bool"
yicesType Int8   = "(bitvector  8)"
yicesType Word8  = "(bitvector  8)"
yicesType Int16  = "(bitvector 16)"
yicesType Word16 = "(bitvector 16)"
yicesType Int32  = "(bitvector 32)"
yicesType Word32 = "(bitvector 32)"
yicesType Int64  = "(bitvector 64)"
yicesType Word64 = "(bitvector 64)"
yicesType Float  = "real"
yicesType Double = "real"

-- | One Yices @define@ line per variable, using the variable's name at the
-- given step and the sort of its Atom type.
vars :: [UV] -> Int -> String
vars uvs step = concat [ define uv | uv <- uvs ]
  where
    define :: UV -> String
    define uv = "(define " ++ uvName step uv ++ "::" ++ yicesType (uvType uv) ++ ")\n"

-- | Solver-level name of a variable at a step: @v<id>_<step>@.
-- 'parseVar' splits model names on this exact shape.
uvName :: Int -> UV -> String
uvName step (UV uvId _ _) = concat ["v", show uvId, "_", show step]

-- | Step-0 declarations followed by one assertion per variable that pins it
-- to its initial value.
initialize :: [UV] -> String
initialize uvs = vars uvs 0 ++ concatMap pinInitial uvs
  where
    pinInitial :: UV -> String
    pinInitial uv@(UV _ _ c) =
      "(assert (= " ++ uvName 0 uv ++ " " ++ literal c ++ "))\n"

    -- Yices literal for an initial value. Integers become binary literals
    -- of the type's width (negative values wrap through Word64, giving the
    -- two's-complement bits); floats become an exact numerator/denominator.
    literal :: Const -> String
    literal k = case k of
      CBool   b -> if b then "true" else "false"
      CInt8   x -> binary  8 x
      CInt16  x -> binary 16 x
      CInt32  x -> binary 32 x
      CInt64  x -> binary 64 x
      CWord8  x -> binary  8 x
      CWord16 x -> binary 16 x
      CWord32 x -> binary 32 x
      CWord64 x -> binary 64 x
      CFloat  x -> ratio (toRational x)
      CDouble x -> ratio (toRational x)

    binary :: Integral a => Int -> a -> String
    binary width x = "0b" ++ bits width (fromIntegral x)

    ratio :: Rational -> String
    ratio r = "(/ " ++ show (numerator r) ++ " " ++ show (denominator r) ++ ")"

    -- The low n bits of a word, most significant bit first.
    bits :: Int -> Word64 -> String
    bits n w = [ if odd (w `div` 2 ^ k) then '1' else '0' | k <- [n - 1, n - 2 .. 0] ]

-- | Yices text for the transition from time @step - 1@ to time @step@
-- (time 0 to 1 is step 1).
--
-- Only the step's variable declarations and a marker comment are emitted:
-- the rule schedule (first argument) is ignored, so no constraint relates
-- consecutive steps yet.
transition :: [[[Rule]]] -> [UV] -> Int -> String
transition _schedule uvs step = vars uvs step ++ marker
  where
    -- XXX (carried over): presumably the transition relation belongs here.
    marker = "; transition " ++ show (step - 1) ++ " to " ++ show step ++ "\n"

-- | The first variable whose id equals the given id; 'error' if none does.
getUV :: [UV] -> Int -> UV
getUV uvs j = maybe (error "Verify.getUV") id (find hasId uvs)
  where
    hasId (UV i _ _) = i == j


-- | Decode one model entry @(= v<id>_<step> value)@ (the name shape produced
-- by 'uvName') into (step, variable, VCD value text).
parseVar :: [UV] -> Group -> (Int, UV, String)
parseVar uvs entry = case entry of
  G [S "=", S name, value] ->
    -- drop the leading 'v', then split the id from the step at the first '_'
    let (idDigits, stepPart) = span (/= '_') (tail name)
    in  (read (tail stepPart), getUV uvs (read idDigits), parseValue value)
  _ -> error $ "Verify.parseVar: " ++ show entry

-- | Render one model value from yices as VCD value-change text (everything
-- except the identifier code that follows it):
--
--   * @0b...@ bit-vector            -> @b...@ plus a space (vector change)
--   * @true@ / @false@              -> @1@ / @0@ (scalar change, no space)
--   * rational atom @n/d@ or group @(/ n d)@ -> @r@ + decimal, plus a space
--   * any other atom                -> @r@ + the atom text, plus a space
parseValue :: Group -> String
parseValue (S ('0':'b':a)) = "b" ++ a ++ " "
parseValue (S "true")      = "1"
parseValue (S "false")     = "0"
parseValue (S v)           =
  -- "r1/2" is not a valid VCD real, so a rational atom is converted first.
  case break (== '/') v of
    (n, '/' : d) | Just real <- realFromRatio n d -> real
    _                                              -> "r" ++ v ++ " "
parseValue g@(G [S "/", S n, S d]) =
  case realFromRatio n d of
    Just real -> real
    Nothing   -> error $ "Verify.parseValue: " ++ show g
parseValue a               = error $ "Verify.parseValue: " ++ show a

-- | VCD real-value text for @n/d@ when both parts are integer literals and
-- the denominator is non-zero; 'Nothing' otherwise.
realFromRatio :: String -> String -> Maybe String
realFromRatio n d = case (readInteger n, readInteger d) of
  (Just num, Just den) | den /= 0 ->
    Just ("r" ++ show (fromRational (num % den) :: Double) ++ " ")
  _ -> Nothing
  where
    readInteger :: String -> Maybe Integer
    readInteger s = case reads s of
      [(k, "")] -> Just k
      _         -> Nothing


-- | An s-expression read from yices output.
data Group
  = G [Group]  -- ^ a parenthesised list
  | S String   -- ^ an atom: any token other than a parenthesis
  deriving Show

-- | Parse raw yices output into its top-level atoms and s-expressions.
parseYices :: String -> [Group]
parseYices output = groups (tokens output)

-- | Build s-expression trees from a token stream. An unmatched ")" is kept
-- as an atom; an unclosed "(" makes 'split' fail with an error.
groups :: [String] -> [Group]
groups toks = case toks of
  []         -> []
  "(" : rest ->
    let (inner, after) = split 0 [] rest
    in  G (groups inner) : groups after
  t : rest   -> S t : groups rest

-- | Find the ")" that closes the current group. The first argument counts
-- the groups opened since then; the second holds the consumed tokens in
-- reverse. Returns the group's interior (in order) and the tokens after
-- its closing ")". Running out of tokens is an 'error'.
split :: Int -> [String] -> [String] -> ([String], [String])
split _ _ [] = error "Verify.split"
split depth acc (tok : rest)
  | tok == ")" && depth == 0 = (reverse acc, rest)
  | otherwise                = split depth' (tok : acc) rest
  where
    depth'
      | tok == "("  = depth + 1
      | tok == ")"  = depth - 1
      | otherwise   = depth

-- | Lexer for yices output: whitespace separates tokens and every "(" or
-- ")" is a token of its own.
tokens :: String -> [String]
tokens input = case dropWhile isSpace input of
  []       -> []
  c : rest
    | c `elem` "()" -> [c] : tokens rest
    | otherwise     -> tokens' "" (c : rest)

-- | Finish an atom whose already-read characters are held, reversed, in the
-- first argument. The atom runs up to the next whitespace or parenthesis,
-- after which lexing resumes with 'tokens'.
tokens' :: String -> String -> [String]
tokens' seen input = (reverse seen ++ atom) : tokens rest
  where
    (atom, rest) = break (\ c -> isSpace c || c `elem` "()") input

-- | Scope tree for the VCD header.
data Heirarchy
  = Variable UV                -- ^ a leaf signal
  | Module String [Heirarchy]  -- ^ a named scope and its entries

-- | Render a counterexample as a Value Change Dump.
--
-- Declarations come from @uvs@, nested into module scopes by splitting
-- names on @.@. The changes are (step, variable, value-text) triples as
-- produced by 'parseVar'. Each step is one time unit (timescale 1 ms), and a
-- closing timestamp one unit past the last step ends the dump.
vcd :: [UV] -> [(Int, UV, String)] -> String
vcd uvs signals' = header ++ samples ++ end
  where
    -- 'sortBy' is stable, so changes within one step keep the solver's order.
    signals = sortBy (\ (a, _, _) (b, _, _) -> compare a b) signals'

    header = "$timescale\n  1 ms\n$end\n"
          ++ concatMap decl (heirarchy 0 uvs)
          ++ "$enddefinitions $end\n"

    decl :: Heirarchy -> String
    decl (Module name subs) = "$scope module " ++ name ++ " $end\n" ++ concatMap decl subs ++ "$upscope $end\n"
    decl (Variable uv)      = declVar uv

    declVar :: UV -> String
    declVar uv@(UV i n _) =
      "$var " ++ kind ++ " " ++ show width ++ " " ++ vcdCode i ++ " " ++ leaf ++ " $end\n"
      where
        (kind, width) = vcdKind (uvType uv)
        -- last dotted component; the enclosing ones are the $scope entries
        leaf = reverse (takeWhile (/= '.') (reverse n))

    -- VCD variable kind and bit width for each Atom type.
    vcdKind :: Type -> (String, Int)
    vcdKind t = case t of
      Bool   -> ("wire",     1)
      Int8   -> ("integer",  8)
      Int16  -> ("integer", 16)
      Int32  -> ("integer", 32)
      Int64  -> ("integer", 64)
      Word8  -> ("wire",     8)
      Word16 -> ("wire",    16)
      Word32 -> ("wire",    32)
      Word64 -> ("wire",    64)
      Float  -> ("real",    32)
      Double -> ("real",    64)

    -- One "#<step>" marker per distinct step, followed by every change at
    -- that step (instead of repeating the same timestamp for each change).
    samples = concatMap stepChanges (groupBy sameStep signals)

    sameStep (a, _, _) (b, _, _) = a == b

    stepChanges :: [(Int, UV, String)] -> String
    stepChanges []                      = ""
    stepChanges changes@((t, _, _) : _) = "#" ++ show t ++ "\n" ++ concatMap change changes

    change :: (Int, UV, String) -> String
    change (_, UV i _ _, v) = v ++ vcdCode i ++ "\n"

    -- An empty model no longer crashes on 'last'; it just closes at #1.
    lastTime = case reverse signals of
      (t, _, _) : _ -> t
      []            -> 0
    end = "#" ++ show (lastTime + 1) ++ "\n"

-- | Scope tree for @uvs@, reading their dotted names from component @depth@
-- on. Sub-modules come first (built by heirarchy'), followed by the
-- variables whose names have exactly one component left.
heirarchy :: Int -> [UV] -> [Heirarchy]
heirarchy _ [] = []
heirarchy depth uvs = heirarchy' depth nested ++ map Variable leaves
  where
    (leaves, nested) = partition isLeaf uvs
    isLeaf uv = case path depth uv of
      [_] -> True
      _   -> False

-- | One 'Module' per distinct path component at @depth@, in order of first
-- appearance. Each module's members are laid out one level deeper by
-- 'heirarchy'.
--
-- NOTE(review): 'head' fails if a variable has no path component left at
-- this depth (e.g. an empty name); 'heirarchy' only separates out names with
-- exactly one remaining component.
heirarchy' :: Int -> [UV] -> [Heirarchy]
heirarchy' _ [] = []
heirarchy' depth uvs@(first : _) =
  Module scope (heirarchy (depth + 1) members) : heirarchy' depth others
  where
    scope = head (path depth first)
    inScope uv = scope == head (path depth uv)
    (members, others) = partition inScope uvs

-- | Components of a variable's name split on @.@ (whitespace also splits,
-- since 'words' is used), with the first @depth@ components dropped.
path :: Int -> UV -> [String]
path depth (UV _ n _) = drop depth (words (map dotToSpace n))
  where
    dotToSpace '.' = ' '
    dotToSpace c   = c
 
-- | VCD identifier code for a variable id: its base-94 digits, most
-- significant first, drawn from the printable ASCII range '!'..'~'.
vcdCode :: Int -> String
vcdCode i = go i ""
  where
    go k acc
      | k < 94    = digit k : acc
      | otherwise = go (k `div` 94) (digit k : acc)
    digit k = chr (33 + k `mod` 94)

{-
bitString :: Int -> String
bitString n = if null bits then "0" else bits
  where
  bit :: Int -> Char
  bit i = if testBit n i then '1' else '0'
  bits = dropWhile (== '0') $ map bit $ reverse [0..31]
-}