{-# LANGUAGE FlexibleInstances #-}

module Feldspar.Compiler.Backend.C.CodeGeneration where

import qualified Data.List as List (find, intercalate, isSuffixOf, last)

import Feldspar.Compiler.Backend.C.Library
import Feldspar.Compiler.Backend.C.Options
import Feldspar.Compiler.Error
import Feldspar.Compiler.Imperative.Representation

-- =======================
-- == C code generation ==
-- =======================

-- | Raise an error attributed to the code generation phase; the error class
-- and message are passed straight through to 'handleError'.
codeGenerationError errorClass message = handleError "CodeGeneration" errorClass message

-- | Fallback name used for a struct member.
defaultMemberName :: String
defaultMemberName = "member"

-- | The syntactic context in which a variable or type is printed. The place
-- decides whether a type is emitted and how a variable name is decorated.
data Place
    = Declaration_pl
      -- ^ Variable declarations: the value of the variable, printed together
      -- with its type (array-style type).
    | MainParameter_pl
      -- ^ Parameters of the main function: the value of the variable, printed
      -- together with its type (pointer-style type).
    | ValueNeed_pl
      -- ^ Inside expressions: the value of the variable, without a type.
    | AddressNeed_pl
      -- ^ Outputs of a function call: an access to the variable, without a
      -- type.
    | FunctionCallIn_pl
      -- ^ Inputs of a function call: the value of the variable, without a
      -- type (originally annotated "SPEC ARRAY FORMAT").
    deriving (Eq, Show)

-- | Values that can be rendered as C source text.
class ToC a where
    toC :: Options  -- ^ code generation options (including the target platform)
        -> Place    -- ^ context the text is printed in
        -> a        -- ^ value to render
        -> String

-- | Derive a name fragment for a type. Struct and union C type names are
-- built from it, so the name depends only on the contents of the type.
--
-- * Struct and union types: @_@ followed by each field type's fragment,
--   every fragment terminated by @_@.
-- * Array types: @arr_T@, the element fragment, @_S@, then the length
--   (the literal value, the indirect length's text, or @UD@ if undefined).
-- * Any other type: its C spelling with all spaces removed
--   (e.g. @float complex@ becomes @floatcomplex@).
getStructTypeName :: Options -> Place -> Type -> String
getStructTypeName options place t = case t of
    StructType fields       -> fieldsName fields
    UnionType fields        -> fieldsName fields
    ArrayType len innerType -> "arr_T" ++ recurse innerType ++ "_S" ++ len2str len
    _                       -> replace (toC options place t) " " "" -- float complex -> floatcomplex
  where
    recurse = getStructTypeName options place
    -- Only the field types (not the field names) contribute to the name.
    fieldsName fields = '_' : concatMap ((++ "_") . recurse . snd) fields
    len2str :: Length -> String
    len2str UndefinedLen    = "UD"
    len2str (LiteralLen i)  = show i
    len2str (IndirectLen s) = s

-- | C spelling of a type. Struct and union types get a name suffix derived
-- by 'getStructTypeName'; user types print verbatim; every other type is
-- looked up in the platform's type table.
instance ToC Type where
    toC options place ty = case ty of
        StructType _ -> "struct s" ++ getStructTypeName options place ty
        UnionType _  -> "union u" ++ getStructTypeName options place ty
        UserType u   -> u
        VoidType     -> "void"
        -- Array types are handled where variables are printed (show_type);
        -- here they fall through to the platform table like other types.
        _            -> fromPlatform
      where
        fromPlatform = case List.find (\(t', _, _) -> ty == t') (types (platform options)) of
            Just (_, cName, _) -> cName
            Nothing -> codeGenerationError InternalError $
                "Unhandled type in platform " ++ name (platform options) ++ ": " ++ show ty ++ " place: " ++ show place

-- | Render a variable by delegating to 'show_variable', which uses the place
-- to decide whether the type is printed and how the name is decorated.
instance ToC (Variable ()) where
    toC options place (Variable varName varType varRole _) =
        show_variable options place varRole varType varName

-- | Render a variable as its (possibly empty) type followed by its decorated
-- name, separated by a single space. Empty parts are omitted.
show_variable :: Options -> Place -> VariableRole -> Type -> String -> String
show_variable options place role typ varName =
    listprint id " " [show_type options place typ restrictness, show_name role place typ varName]
  where
    -- Only main-function parameters take the platform's restrict setting.
    -- NOTE(review): show_type currently ignores its IsRestrict argument, so
    -- this value has no visible effect on the output.
    restrictness
        | place /= MainParameter_pl = NoRestrict
        | otherwise                 = isRestrict (platform options)

-- | The type part of a variable for a given place.
--
-- Only declarations and main-function parameters print a type; every other
-- place yields the empty string. Array-typed declarations are an internal
-- error. Array-typed main parameters print as @struct array@. The
-- 'IsRestrict' argument is currently unused.
show_type :: Options -> Place -> Type -> IsRestrict -> String
show_type options place typ _ = case (place, typ) of
    (Declaration_pl,   ArrayType _ _) -> codeGenerationError InternalError "Array allocation is not allowed."
    (MainParameter_pl, ArrayType _ _) -> "struct array"
    (Declaration_pl,   _)             -> toC options Declaration_pl typ
    (MainParameter_pl, _)             -> toC options MainParameter_pl typ
    _                                 -> ""

-- | The name part of a variable for a given place, decorated by role.
--
-- * 'Value' variables are prefixed with @&@ when their address is needed,
--   and printed unchanged otherwise.
-- * 'Pointer' variables:
--
--     * A name ending in @]@ (an indexed element) gets @&@ when its address
--       is needed and is used unchanged elsewhere.
--     * A plain name is used unchanged when its address is needed, becomes
--       @* name@ as a main parameter, and becomes @(* name)@ elsewhere.
--     * Declaring a pointer variable is an internal error.
--
-- The 'Type' argument is currently unused.
show_name :: VariableRole -> Place -> Type -> String  -> String
show_name Value place _ n
    | place == AddressNeed_pl = "&" ++ n
    | otherwise = n
show_name Pointer place _ n
    | place == AddressNeed_pl = if indexed then "&" ++ n else n
    | place == Declaration_pl = codeGenerationError InternalError $ "Output variable of the function declared!"
    | place == MainParameter_pl = "* " ++ n
    | indexed = n
    | otherwise = "(* " ++ n ++ ")"
  where
    -- Total check (List.last crashed on an empty name).
    indexed = "]" `List.isSuffixOf` n

-- show_array_in_fun :: (HasType a, ToC a) => Options -> Place -> a -> String
-- show_array_in_fun options place exp = case typeof exp of
    -- t@(ArrayType _ _) -> concat["&(", toC options place exp, genIndex t, ")"]
    -- _ -> toC options place exp

-- genIndex :: Type -> String
-- genIndex (ArrayType _ t) = "[0]" ++ genIndex t
-- genIndex _ = ""

----------------------
--   Type           --
----------------------

-- | Values whose C-level 'Type' can be computed.
class HasType a where
    typeof
        :: a     -- ^ value to inspect
        -> Type  -- ^ its type

-- | A variable's type is the second field of the 'Variable' constructor.
instance HasType (Variable t) where
    typeof (Variable _ varType _ _) = varType

-- | Type of a constant.
--
-- * Integer constants are always reported as signed 32-bit.
-- * A complex constant takes the type of its real part.
-- * An array constant's length is its element count. Its element type is
--   the common type of all elements; an empty array or elements of
--   differing types are internal errors.
instance (ShowLabel t) => HasType (Constant t) where
    typeof (IntConst _ _ _) = NumType Signed S32
    typeof (FloatConst _ _ _) = FloatType
    typeof (BoolConst _ _ _) = BoolType
    typeof (ComplexConst re _ _ _) = ComplexType (typeof re)
    typeof arr@(ArrayConst elems _ _) = ArrayType (LiteralLen $ length elems) elemType
      where
        -- Pattern-match instead of calling `head`, and compare the first
        -- element's type only against the remaining elements.
        elemType = case map typeof elems of
            [] -> codeGenerationError InternalError $ "Const array with 0 elements: " ++ show arr
            (t : ts)
                | all (t ==) ts -> t
                | otherwise -> codeGenerationError InternalError $ "Different element types in constant array: " ++ show arr

-- | Type of an expression.
--
-- * Indexing strips one array level.
-- * A struct access yields the named field's type.
-- * Function calls and casts carry their result type explicitly.
-- * @sizeof@ is reported as signed 32-bit.
instance (ShowLabel t) => HasType (Expression t) where
    typeof expr = case expr of
        VarExpr var _                 -> typeof var
        ArrayElem arr _ _ _           -> decrArrayDepth (typeof arr)
        StructField struct field _ _  -> getStructFieldType field (typeof struct)
        ConstExpr c _                 -> typeof c
        FunctionCall _ resultTy _ _ _ _ -> resultTy
        Cast targetTy _ _ _           -> targetTy
        SizeOf _ _ _                  -> NumType Signed S32

-- | An actual parameter has the type of the expression it wraps, whether it
-- is an input or an output.
instance (ShowLabel t) => HasType (ActualParameter t) where
    typeof param = case param of
        In  inExpr _  -> typeof inExpr
        Out outExpr _ -> typeof outExpr

----------------------
-- Helper functions --
----------------------

-- | Render a value with @f@ and indent every resulting line by four spaces.
-- Every output line, including the last, ends with a newline.
ind :: (a -> String) -> a -> String
ind f = unlines . map (indentation ++) . lines . f
  where
    indentation = "    "

-- | Render each element with @f@, drop the elements that render to the empty
-- string, and join the rest with the separator @s@.
--
-- >>> listprint id ", " ["a", "", "b"]
-- "a, b"
listprint :: (a -> String) -> String -> [a] -> String
listprint f s = List.intercalate s . filter (not . null) . map f

-- | The element type of an array type, i.e. the type after one level of
-- indexing. Indexing a non-array type is an internal error.
decrArrayDepth :: Type -> Type
decrArrayDepth ty = case ty of
    ArrayType _ elemType -> elemType
    _                    -> codeGenerationError InternalError "Non-array variable is indexed!"

-- | The type of the field named @f@ in a struct type.
--
-- A missing field, or a type that is not a struct, is an internal error.
getStructFieldType :: String -> Type -> Type
getStructFieldType f (StructType fields) = maybe (structFieldNotFound f) id (lookup f fields)
getStructFieldType f t =
    codeGenerationError InternalError $
        "Trying to get a struct field from not a struct typed expression\n"
            ++ "Field: " ++ f
            ++ "\nType:  " ++ show t

-- | Abort code generation because the struct has no field named @f@.
-- The result type is polymorphic, so it can be used wherever a value is
-- expected.
structFieldNotFound :: String -> a
structFieldNotFound f = codeGenerationError InternalError $ "Not found struct field with this name: " ++ f