module Lua.Bytecode5_1.Types where

import Debug.Trace

import Data.Word
import Data.Int
import Data.Bits
import Data.ByteString
import Data.Serialize
import Data.Serialize.Get
import Data.Serialize.IEEE754
import Data.Array
import Data.Map (Map)
import Data.Vector (Vector, fromList)
import Data.Hashable

-- | Types that know how to decode themselves from a Lua 5.1 binary chunk
-- with a cereal 'Get' parser.
class LuaGet a where
	luaGet :: Get a

-- | Runs 'luaGet' on a strict 'ByteString' via cereal's 'runGet'; a 'Left'
-- carries the parser's error message.
runLuaGet :: (LuaGet a) => ByteString -> Either String a
runLuaGet bytes = runGet luaGet bytes

-- | One raw byte.
instance LuaGet Word8 where
	luaGet = getWord8

-- | A little-endian 32-bit word reinterpreted (two's complement) as signed.
instance LuaGet Int32 where
	luaGet = fmap fromIntegral getWord32le

-- | A little-endian unsigned 32-bit word.
instance LuaGet Word32 where
	luaGet = getWord32le

-- | A byte string carrying a 32-bit length prefix (see 'getByteString32').
instance LuaGet ByteString where
	luaGet = getByteString32

-- | A little-endian IEEE-754 double.
instance LuaGet Double where
	luaGet = getFloat64le

-- | A list prefixed by its 32-bit element count (see 'getList32').
instance (LuaGet a) => LuaGet [a] where
	luaGet = getList32

-- | Decoded exactly like a list, then packed into a 'Vector'.
instance (LuaGet a) => LuaGet (Vector a) where
	luaGet = fmap fromList luaGet

-- | Reads a list prefixed by its element count (a little-endian 'Word32'),
-- decoding each element with 'luaGet'. Each element is forced to weak head
-- normal form as soon as it is read.
getList32 :: (LuaGet a) => Get [a]
getList32 = getWord32le >>= collect []
	where
		-- Elements are accumulated in reverse and flipped once at the end.
		collect acc itemsLeft
			| itemsLeft == 0 = return (Prelude.reverse acc)
			| otherwise = do
				item <- luaGet
				item `seq` collect (item : acc) (itemsLeft - 1)

-- | Reads a byte string prefixed by its length, stored as a little-endian
-- 'Word32'.
-- NOTE(review): this fixes the length prefix at 4 bytes; Lua 5.1 writes a
-- size_t here whose width the chunk header declares — confirm the chunks
-- being read come from a build with a 32-bit size_t.
getByteString32 :: Get ByteString
getByteString32 = getWord32le >>= getByteString . fromIntegral

-- | A Lua table, modelled as an ordered map from constant keys to values.
type Table = Map Constant Constant

-- | A Lua value. Constructor order determines the derived 'Ord' instance.
data Constant
	= NIL
	| BOOLEAN Bool
	| NUMBER Double
	| STRING ByteString
	| TABLE Table
	| CLOSURE Closure
	deriving (Eq, Ord, Show)

-- | Hashes booleans, numbers and strings. A small per-constructor offset
-- (0, 1, 2) is added to the payload hash, and the salt is added on top.
-- Hashing NIL, a table or a closure raises an error.
instance Hashable Constant where
	hashWithSalt salt constant = salt + tagged constant
		where
			tagged (BOOLEAN flag) = hashWithSalt salt flag
			tagged (NUMBER num) = 1 + hashWithSalt salt num
			tagged (STRING str) = 2 + hashWithSalt salt str
			tagged (CLOSURE _) = error "Data.Hashable.hashWithSalt Lua.Bytecode.Types.CLOSURE"
			tagged (TABLE _) = error "Data.Hashable.hashWithSalt Lua.Bytecode.Types.TABLE"
			tagged NIL = error "Data.Hashable.hashWithSalt Lua.Bytecode.Types.NIL"

-- | Constant-pool entries are decoded by 'getConstant'.
instance LuaGet Constant where
	luaGet = getConstant

-- | Reads one constant-pool entry: a one-byte type tag followed by a
-- tag-specific payload.
--
-- * tag 0: nil, no payload
-- * tag 1: boolean, a single byte (any non-zero value is true)
-- * tag 3: number, a little-endian IEEE-754 double
-- * tag 4: string, read by 'getByteString32'
--
-- Any other tag fails the parse.
getConstant :: Get Constant
getConstant = do
	tag <- getWord8
	case tag of
		0 -> return NIL
		-- A Lua 5.1 chunk stores a boolean constant as one byte, not as a
		-- double. Reading 8 bytes here would desynchronise every later field.
		1 -> do
			flag <- getWord8
			return $ BOOLEAN $ flag /= 0
		3 -> do
			n <- getFloat64le
			return $ NUMBER n
		4 -> do
			s <- getByteString32
			return $ STRING s
		other -> fail $ "Unexpected identifier for a Lua.Bytecode.Types.Constant: " ++ show other ++ "."

-- | Debug record for one local variable of a function prototype.
data Local = Local
	{ name :: ByteString -- ^ the variable's name as stored in the chunk
	, startPc :: Word32 -- ^ NOTE(review): presumably the first instruction where the variable is live — confirm
	, endPc :: Word32 -- ^ NOTE(review): presumably the instruction where its scope ends — confirm
	}
	deriving (Show)

-- | Locals are decoded by 'getLocal'.
instance LuaGet Local where
	luaGet = getLocal

-- | Reads a 'Local' as three consecutive fields: a length-prefixed name, then
-- the start and end pcs, each a little-endian 'Word32'.
getLocal :: Get Local
getLocal = do
	varName <- luaGet
	firstPc <- luaGet
	lastPc <- luaGet
	return (Local varName firstPc lastPc)

-- | Type tags for Lua values. 'fromEnum' gives each tag's number: None is -1,
-- Nil..Thread are 0..8, and NumTags (9) counts the real tags.
data LuaTypes = None | Nil | LBool | LightUser | Number | LString | Table | Function | UserData | Thread | NumTags

-- | Hand-written because None maps to -1; a derived instance would start at 0.
instance Enum LuaTypes where
	fromEnum x = case x of
		None -> -1
		Nil -> 0
		LBool -> 1
		LightUser -> 2
		Number -> 3
		LString -> 4
		Table -> 5
		Function -> 6
		UserData -> 7
		Thread -> 8
		NumTags -> 9
	toEnum x = case x of
		-1 -> None
		0 -> Nil
		1 -> LBool
		2 -> LightUser
		3 -> Number
		4 -> LString
		5 -> Table
		6 -> Function
		7 -> UserData
		8 -> Thread
		9 -> NumTags
		_ -> error $ "Lua.Bytecode5_1.Types.toEnum: no LuaTypes value for " ++ show x
	-- The class defaults for these two count through Int without limit, so
	-- e.g. [None ..] would crash on toEnum 10. Stop at the end of the range.
	enumFrom x = enumFromTo x NumTags
	enumFromThen x y = enumFromThenTo x y limit
		where
			limit
				| fromEnum y >= fromEnum x = NumTags
				| otherwise = None

instance Bounded LuaTypes where
	minBound = None
	maxBound = NumTags

-- | NOTE(review): the name suggests the bit width of a C int; nothing in this
-- module reads it — confirm it is still needed.
luaIBitsInt :: Integer
luaIBitsInt = 32

-- Field widths, in bits. With the positions below they describe a 32-bit
-- instruction word laid out from the low bit upward as
-- op (6) | A (8) | C (9) | B (9), where Bx (18 bits) overlays C and B.
sizeOp, sizeA, sizeB, sizeC, sizeBx :: Int
sizeOp = 6
sizeA = 8
sizeB = 9
sizeC = 9
sizeBx = sizeC + sizeB
--sizeAx = sizeC + sizeB + sizeA

-- Bit offset of each field's lowest bit within the instruction word.
posOp, posA, posC, posB, posBx :: Int
posOp = 0
posA = posOp + sizeOp
posC = posA + sizeA
posB = posC + sizeC
posBx = posC
--posAx = posA

-- Largest unsigned value each field can hold.
maxArgA, maxArgB, maxArgC, maxArgBx :: Word32
maxArgA = bit sizeA - 1
maxArgB = bit sizeB - 1
maxArgC = bit sizeC - 1
maxArgBx = bit sizeBx - 1

-- | Bias of the signed sBx field: half of 'maxArgBx', rounded down.
maxArgSBx :: Word32
maxArgSBx = maxArgBx `shiftR` 1

--maxArgAx :: Word32
--maxArgAx = (1 `shiftL` sizeAx) - 1

-- | @mask1 n p@ has @n@ consecutive one bits starting at bit @p@ and zeros
-- elsewhere.
mask1 :: (Bits a, Num a) => Int -> Int -> a
mask1 n p = shiftL (complement (shiftL (complement 0) n)) p

-- | @mask0 n p@ is the bitwise complement of @'mask1' n p@.
mask0 :: (Bits a, Num a) => Int -> Int -> a
mask0 n p = complement (mask1 n p)

-- | The opcode field ('sizeOp' bits starting at 'posOp') of an instruction.
getOpcode :: Word32 -> Word32
getOpcode i = (i `shiftR` posOp) .&. mask1 sizeOp 0

-- | Overwrites the opcode field of @i@ with the low 'sizeOp' bits of @o@.
setOpcode :: (Bits a, Num a) => a -> a -> a
setOpcode i o = setArg i o posOp sizeOp

-- | @getArg i pos size@ extracts the @size@-bit field of @i@ that starts at
-- bit @pos@.
getArg :: (Bits a, Num a) => a -> Int -> Int -> a
getArg i pos size = (i `shiftR` pos) .&. mask1 size 0

-- | @setArg i v pos size@ replaces that field of @i@ with the low @size@ bits
-- of @v@ and leaves every other bit of @i@ untouched.
setArg :: (Bits a, Num a) => a -> a -> Int -> Int -> a
setArg i v pos size = kept .|. inserted
	where
		kept = i .&. mask0 size pos
		inserted = (v `shiftL` pos) .&. mask1 size pos

getArgA, getArgB, getArgC :: (Bits a, Num a) => a -> a
getArgA i = getArg i posA sizeA
getArgB i = getArg i posB sizeB
getArgC i = getArg i posC sizeC

setArgA, setArgB, setArgC, setArgBx :: (Bits a, Num a) => a -> a -> a
setArgA i v = setArg i v posA sizeA
setArgB i v = setArg i v posB sizeB
setArgC i v = setArg i v posC sizeC
setArgBx i v = setArg i v posBx sizeBx

-- | The unsigned Bx field, which spans the C and B fields.
getArgBx :: Word32 -> Word32
getArgBx i = getArg i posBx sizeBx

--getArgAx i = getArg i posAx sizeAx
--setArgAx i v = setArg i v posAx sizeAx

-- | The signed sBx field, stored with a bias of 'maxArgSBx'. The result is a
-- 'Word32', so negative offsets come back wrapped modulo 2^32.
getArgSBx :: Word32 -> Word32
getArgSBx i = getArgBx i - maxArgSBx

-- | Stores @sbx@ (wrapped modulo 2^32 if negative) into the sBx field,
-- adding the 'maxArgSBx' bias.
setArgSBx :: Word32 -> Word32 -> Word32
setArgSBx i sbx = setArgBx i (sbx + maxArgSBx)

-- | Assembles an ABC-format instruction word. Operands are shifted into place
-- but not masked, so each must already fit in its field.
createABC :: Bits a => a -> a -> a -> a -> a
createABC op argA argB argC =
	(op `shiftL` posOp)
		.|. (argA `shiftL` posA)
		.|. (argB `shiftL` posB)
		.|. (argC `shiftL` posC)

-- | Assembles an ABx-format instruction word. As with 'createABC', operands
-- are not masked.
createABx :: Bits a => a -> a -> a -> a
createABx op argA argBx =
	(op `shiftL` posOp)
		.|. (argA `shiftL` posA)
		.|. (argBx `shiftL` posBx)

--createAx o a = o `shiftL` posOp .|. a `shiftL` posAx

-- | Flag that marks an RK operand (a B or C field) as a constant index rather
-- than a register: the top bit of the 'sizeB'-bit field.
bitRK :: Word32
bitRK = 1 `shiftL` (sizeB - 1)

-- | True when the RK operand has 'bitRK' set, i.e. it names a constant.
isK :: Word32 -> Bool
isK = (>0) . (.&. bitRK)

-- | Splits an RK operand into a constant index or a register number.
-- Left means a constant, Right means a register.
k :: Word32 -> Either Word32 Word32
k x
	| isK x = Left stripped
	| otherwise = Right stripped
	where
		-- Keep every bit below the flag (sizeB - 1 of them). Masking only 7
		-- bits would truncate constant indices and registers >= 128.
		stripped = x .&. mask1 (sizeB - 1) 0

-- | Lua 5.1 virtual-machine opcodes.
--
-- Constructor order is significant. The derived 'Enum' maps each constructor
-- to the opcode number held in the low 'sizeOp' bits of an instruction word
-- (decoded with toEnum . fromIntegral . getOpcode), and 'opModes' has one row
-- per constructor in this same order. Do not reorder or insert constructors.
data Opcode =
	  OP_MOVE	-- 0
	| OP_LOADK	-- 1
	| OP_LOADBOOL	-- 2
	| OP_LOADNIL	-- 3
	| OP_GETUPVAL	-- 4
	| OP_GETGLOBAL	-- 5
	| OP_GETTABLE	-- 6
	| OP_SETGLOBAL	-- 7
	| OP_SETUPVAL	-- 8
	| OP_SETTABLE	-- 9
	| OP_NEWTABLE	-- 10
	| OP_SELF	-- 11
	| OP_ADD	-- 12
	| OP_SUB	-- 13
	| OP_MUL	-- 14
	| OP_DIV	-- 15
	| OP_MOD	-- 16
	| OP_POW	-- 17
	| OP_UNM	-- 18
	| OP_NOT	-- 19
	| OP_LEN	-- 20
	| OP_CONCAT	-- 21
	| OP_JMP	-- 22
	| OP_EQ	-- 23
	| OP_LT	-- 24
	| OP_LE	-- 25
	| OP_TEST	-- 26
	| OP_TESTSET	-- 27
	| OP_CALL	-- 28
	| OP_TAILCALL	-- 29
	| OP_RETURN	-- 30
	| OP_FORLOOP	-- 31
	| OP_FORPREP	-- 32
	| OP_TFORLOOP	-- 33
	| OP_SETLIST	-- 34
	| OP_CLOSE	-- 35
	| OP_CLOSURE	-- 36
	| OP_VARARG	-- 37
	deriving (Show,Enum,Bounded,Ord,Eq,Ix)

-- | Number of 'Opcode' constructors (38), i.e. one more than the largest
-- valid opcode number.
numOpcodes :: Int
numOpcodes = fromEnum (maxBound :: Opcode) + 1

-- | Reads one little-endian 32-bit instruction word and keeps only its opcode.
-- The 6-bit opcode field can hold 0..63, but only 0 .. numOpcodes - 1 are
-- defined. Anything else fails the parse here. Otherwise the result would be
-- a lazy 'toEnum' error that escapes 'runGet' once the value is forced.
instance LuaGet Opcode where
	luaGet = do
		w <- getWord32le
		let opNum = fromIntegral (getOpcode w)
		if opNum < numOpcodes
			then return (toEnum opNum)
			else fail $ "Unexpected opcode number for a Lua.Bytecode5_1.Types.Opcode: " ++ show opNum ++ "."

-- | Operand layout of an instruction: A/B/C fields, A plus the unsigned Bx,
-- or A plus the signed sBx.
-- Constructor order is significant: 'opModes' packs fromEnum of these values
-- (0..2) into the low two bits of each mode byte.
data OpMode = ABC | ABx | AsBx
	deriving (Show,Enum,Bounded,Ord,Eq,Ix)

-- | How an instruction uses its B or C operand.
-- NOTE(review): presumably mirrors Lua 5.1's OpArgN/OpArgU/OpArgR/OpArgK
-- (unused / used / register or jump / constant or register) — confirm.
-- Constructor order is significant: 'opModes' packs fromEnum of these values
-- (0..3) into two-bit fields.
data OpArgMask = N | U | R | K
	deriving (Show,Enum,Bounded,Ord,Eq,Ix)

-- | Every field view of one instruction word, plus its opcode's mode bits from
-- 'opModes'. 'toOpcodeFields' fills in all views whatever the opcode's real
-- format, so consult opMode to know which ones are meaningful.
data OpcodeFields = OpcodeFields
	{ a :: Word32 -- ^ the A field
	, b :: Word32 -- ^ the B field
	, bx :: Word32 -- ^ the unsigned Bx field (C and B together)
	, sBx :: Word32 -- ^ Bx minus its bias; negative values wrap modulo 2^32
	, c :: Word32 -- ^ the C field
	, opMode :: Word8 -- ^ the opcode's 'OpMode', as its fromEnum value
	, bMode :: Word8 -- ^ the 'OpArgMask' for B, as its fromEnum value
	, cMode :: Word8 -- ^ the 'OpArgMask' for C, as its fromEnum value
	, testAMode :: Word8 -- ^ mode bit 6 left in place (0 or 64)
	, testTMode :: Word8 -- ^ mode bit 7 left in place (0 or 128)
	}
	deriving (Eq, Ord, Show)

-- | Decodes every field view of an instruction word and attaches the mode
-- bits of its opcode, looked up in 'opModes'. An opcode number outside
-- 'Opcode' only errors when one of the mode fields is forced.
toOpcodeFields :: Word32 -> OpcodeFields
toOpcodeFields w =
	OpcodeFields {
		  a = getArgA w
		, b = getArgB w
		, bx = getArgBx w
		, sBx = getArgSBx w
		, c = getArgC w
		, opMode = modeBits 0 3
		, bMode = modeBits 4 3
		, cMode = modeBits 2 3
		, testAMode = modes .&. bit 6
		, testTMode = modes .&. bit 7
	}
	where
		op = toEnum (fromIntegral (getOpcode w)) :: Opcode
		modes = opModes ! op
		-- Shift the packed mode byte down, then keep the bits selected by m.
		modeBits offset m = (modes `shiftR` offset) .&. m

-- | One decoded instruction: its opcode plus every field view of the raw word.
data Operation = Operation {
	  operator :: Opcode
	, fields :: OpcodeFields
}
	deriving (Eq, Ord, Show)

-- | A function's instruction stream.
type Operations = Vector Operation

-- | Reads one little-endian 32-bit instruction word and decodes it.
-- The 6-bit opcode field can hold 0..63, but only 0 .. numOpcodes - 1 name an
-- 'Opcode'. Anything else fails the parse. Otherwise the operator and mode
-- fields would hold lazy 'toEnum'/array errors that escape 'runGet'.
instance LuaGet Operation where
	luaGet = do
		w <- getWord32le
		let opNum = fromIntegral (getOpcode w)
		if opNum < numOpcodes
			then return Operation { operator = toEnum opNum, fields = toOpcodeFields w }
			else fail $ "Unexpected opcode number for a Lua.Bytecode5_1.Types.Operation: " ++ show opNum ++ "."

-- | Mode byte for each opcode, indexed by 'Opcode'. There is one row per
-- constructor, in constructor order; the trailing comments name each row.
--
-- Bit layout produced by @opFromModes t a b c m@:
--
-- > bit 7    : t
-- > bit 6    : a
-- > bits 4-5 : B operand mask (fromEnum of an 'OpArgMask')
-- > bits 2-3 : C operand mask (fromEnum of an 'OpArgMask')
-- > bits 0-1 : operand layout (fromEnum of an 'OpMode')
--
-- 'toOpcodeFields' reads these back as testTMode, testAMode, bMode, cMode
-- and opMode respectively.
-- NOTE(review): t and a presumably mean "instruction is a test" and
-- "instruction writes register A", as in Lua 5.1's luaP_opmodes — confirm.
opModes :: Array Opcode Word8
opModes = listArray (minBound,maxBound) [
	  opFromModes 0 1 R N ABC	-- OP_MOVE
	, opFromModes 0 1 K N ABx	-- OP_LOADK
	, opFromModes 0 1 U U ABC	-- OP_LOADBOOL
	, opFromModes 0 1 R N ABC	-- OP_LOADNIL
	, opFromModes 0 1 U N ABC	-- OP_GETUPVAL
	, opFromModes 0 1 K N ABx	-- OP_GETGLOBAL
	, opFromModes 0 1 R K ABC	-- OP_GETTABLE
	, opFromModes 0 0 K N ABx	-- OP_SETGLOBAL
	, opFromModes 0 0 U N ABC	-- OP_SETUPVAL
	, opFromModes 0 0 K K ABC	-- OP_SETTABLE
	, opFromModes 0 1 U U ABC	-- OP_NEWTABLE
	, opFromModes 0 1 R K ABC	-- OP_SELF
	, opFromModes 0 1 K K ABC	-- OP_ADD
	, opFromModes 0 1 K K ABC	-- OP_SUB
	, opFromModes 0 1 K K ABC	-- OP_MUL
	, opFromModes 0 1 K K ABC	-- OP_DIV
	, opFromModes 0 1 K K ABC	-- OP_MOD
	, opFromModes 0 1 K K ABC	-- OP_POW
	, opFromModes 0 1 R N ABC	-- OP_UNM
	, opFromModes 0 1 R N ABC	-- OP_NOT
	, opFromModes 0 1 R N ABC	-- OP_LEN
	, opFromModes 0 1 R R ABC	-- OP_CONCAT
	, opFromModes 0 0 R N AsBx	-- OP_JMP
	, opFromModes 1 0 K K ABC	-- OP_EQ
	, opFromModes 1 0 K K ABC	-- OP_LT
	, opFromModes 1 0 K K ABC	-- OP_LE
	, opFromModes 1 1 R U ABC	-- OP_TEST
	, opFromModes 1 1 R U ABC	-- OP_TESTSET
	, opFromModes 0 1 U U ABC	-- OP_CALL
	, opFromModes 0 1 U U ABC	-- OP_TAILCALL
	, opFromModes 0 0 U N ABC	-- OP_RETURN
	, opFromModes 0 1 R N AsBx	-- OP_FORLOOP
	, opFromModes 0 1 R N AsBx	-- OP_FORPREP
	, opFromModes 1 0 N U ABC	-- OP_TFORLOOP
	, opFromModes 0 0 U U ABC	-- OP_SETLIST
	, opFromModes 0 0 N N ABC	-- OP_CLOSE
	, opFromModes 0 1 U N ABx	-- OP_CLOSURE
	, opFromModes 0 1 U N ABC	-- OP_VARARG
	]
	where
		-- Packs one row into a mode byte (see the layout above). t and a
		-- should be 0 or 1; larger values would spill into other bits.
		opFromModes :: Word8 -> Word8 -> OpArgMask -> OpArgMask -> OpMode -> Word8
		opFromModes t a b c m =
			(t `shiftL` 7) .|.
			(a `shiftL` 6) .|.
			(word8FromEnum' b `shiftL` 4) .|.
			(word8FromEnum' c `shiftL` 2) .|.
			word8FromEnum' m
		-- Constructor index of an enum value, narrowed to a byte.
		word8FromEnum' :: Enum a => a -> Word8
		word8FromEnum' = fromIntegral . fromEnum

-- | The value stack: a vector of constants.
type Stack = Vector Constant

-- | Placeholder for interpreter-wide state; it currently has no fields.
data GlobalState = GlobalState {}

-- | NOTE(review): presumably an index into a 'Stack' — confirm.
type StackId = Int

-- | Interpreter state.
data State = State
	{ top :: StackId -- ^ first free slot in the stack
	, baseS :: StackId -- ^ base of the current function
	, global :: GlobalState
	, ci :: CallInfo
	, instructions :: Operations -- ^ operations in the current function
	, stackLast :: StackId -- ^ last free slot in the stack
	, stack :: StackId -- ^ stack base
	, stackSize :: Int
	, globalTable :: Table
	}

-- | Short alias for 'globalTable'.
g :: State -> Table
g st = globalTable st

-- | A function value: its upvalues, an isC flag, and the prototype it
-- instantiates.
data Closure = Closure
	{ upVals :: Stack
	, isC :: Bool -- ^ NOTE(review): presumably True for host (C) functions — confirm
	, proto :: Prototype
	}
	deriving (Ord, Eq, Show)

-- | A compiled function: constant pool, instruction stream, nested
-- prototypes and per-function metadata.
data Prototype = Prototype
	{ constants :: Stack
	, code :: Operations
	, prototypes :: Prototypes -- ^ nested function prototypes
	, localVariables :: Stack -- ^ NOTE(review): typed as 'Stack' rather than a vector of 'Local' — confirm intent
	, upValues :: Vector String
	, numParams :: Word8
	, isVarArg :: Bool
	, maxStackSize :: Word8
	}
	deriving (Ord, Eq, Show)

-- | A sequence of function prototypes.
type Prototypes = Vector Prototype

-- | Bookkeeping for a call.
-- NOTE(review): presumably one record per active call frame — confirm.
data CallInfo = CallInfo
	{ baseCI :: StackId
	, function :: StackId
	, savedPc :: Int
	, nResults :: Word8
	}