module CodeGen.Render.Haskell
  ( render
  ) where

import CodeGen.Prelude
import CodeGen.Types
import CodeGen.Parse.Cases (type2hsreal, type2real, type2accreal)
import qualified Data.Text as T


-- | Render a 'Parsable' as Haskell type text for the given position in a
-- signature.
--
-- The parsable is first turned into text by 'renderParsable' (using the
-- 'TemplateType'), and that text is then post-processed by 'typeCatHelper'
-- according to the 'TypeCategory'.
render :: TypeCategory -> TemplateType -> Parsable -> Maybe Text
render tc tt parsable = typeCatHelper tc (renderParsable tt parsable)


-- | Adjust already-rendered type text for where it appears in a signature.
--
-- * 'FunctionParam': the text is returned unchanged.
-- * 'ReturnValue': the text is placed under @IO@. The unit type @()@ becomes
--   @IO ()@, anything beginning with @Ptr@ is parenthesised
--   (@IO (Ptr ...)@), and everything else is appended as-is (@IO X@).
typeCatHelper :: TypeCategory -> Text -> Maybe Text
typeCatHelper FunctionParam s = Just s
typeCatHelper ReturnValue s
  | s == "()"              = Just "IO ()"
  | "Ptr" `T.isPrefixOf` s = Just ("IO (" <> s <> ")")
  | otherwise              = Just ("IO " <> s)


-- | Render a parsed C type as the text of the corresponding Haskell type.
--
-- The 'TemplateType' is passed on to 'renderTenType' for tensor-related
-- types; plain C types go through 'renderCType'. Each pointer level becomes
-- one @Ptr@ application, and a pointee whose rendering contains a space is
-- parenthesised, so pointers of any nesting depth come out as well-formed
-- type applications (@Ptr (Ptr X)@, never @Ptr Ptr X@).
renderParsable :: TemplateType -> Parsable -> Text
renderParsable tt =
  \case
    -- Raw DescBuffs need to be wrapped in a pointer for marshalling.
    TenType desc@(Pair (DescBuff, _)) -> "Ptr " <> renderTenType tt desc

    Ptr pointee -> "Ptr " <> atomic (renderParsable tt pointee)
    TenType x -> renderTenType tt x
    -- NNType x -> renderNNType lt tt x   (disabled: renderNNType is not implemented yet)
    CType x -> renderCType x
  where
    -- Type application binds tighter than any operator, so an argument that
    -- is itself an application must be wrapped in parentheses to stay a
    -- single argument of the enclosing @Ptr@.
    atomic :: Text -> Text
    atomic rendered
      | T.any (== ' ') rendered = "(" <> rendered <> ")"
      | otherwise               = rendered


-- | Render a 'TenType' (a raw tensor-ish kind paired with a library type) as
-- the name of its Haskell-side type.
--
-- 'Real' and 'AccReal' are delegated to 'type2real' and 'type2accreal'.
-- Every other kind yields a name starting with @C'@: tensors and storages
-- embed the 'type2hsreal' rendering of the 'TemplateType', while the
-- remaining kinds are rendered from 'prefix' and the shown kind.
renderTenType :: TemplateType -> TenType -> Text
renderTenType tt ten@(Pair (raw, lt)) =
  case raw of
    Tensor  -> cName (prefix lt True <> type2hsreal tt <> "Tensor")
    Storage -> cName (tshow lt <> type2hsreal tt <> "Storage")
    Real    -> type2real lt tt
    AccReal -> type2accreal lt tt
    _       -> cName (prefix lt (isConcreteCudaPrefixed ten) <> tshow raw)
  where
    -- Marker prepended to every name except the Real/AccReal ones.
    cName :: Text -> Text
    cName body = "C'" <> body


-- | Render a plain C type as the name of a Haskell type.
--
-- @void@ maps to the unit type and the fixed-width integer constructors map
-- to names from "Foreign.C.Types"; any other constructor is rendered with
-- 'tshow'. For the int/uint conversions, see
--
-- * https://www.haskell.org/onlinereport/haskell2010/haskellch8.html
-- * https://hackage.haskell.org/package/base-4.10.0.0/docs/Foreign-C-Types.html
--
-- NOTE(review): the 64-bit cases are asymmetric ('CUInt64' -> @CULong@ but
-- 'CInt64' -> @CLLong@) -- confirm this is intentional.
renderCType :: CType -> Text
renderCType CVoid   = "()"
-- unsigned integers
renderCType CUInt8  = "CUChar"
renderCType CUInt16 = "CUShort"
renderCType CUInt32 = "CUInt"
renderCType CUInt64 = "CULong"
-- signed integers
renderCType CInt8   = "CSChar"
renderCType CInt16  = "CShort"
renderCType CInt32  = "CInt"
renderCType CInt64  = "CLLong"
-- everything else is named after its constructor
renderCType other   = tshow other

{-
-- FIXME: get back to this when THC is finished
renderNNType :: LibType -> TemplateType -> NNType -> Text
renderNNType _ _ = \case
  IndexTensor   -> undefined
  IntegerTensor -> undefined
-}