{-# LANGUAGE CPP #-}
module Record.TH where

import BasePrelude hiding (Proxy)
import GHC.TypeLits
import Language.Haskell.TH hiding (classP)


-- | A class constraint @className t1 ... tk@, built in whichever form the
-- template-haskell version in use expects: from 2.10 on, constraints are
-- ordinary types (the class constructor applied to its arguments); before
-- that, they use the dedicated 'ClassP' constructor.
classP :: Name -> [Type] -> Pred
#if MIN_VERSION_template_haskell(2,10,0)
classP className = foldl' AppT (ConT className)
#else
classP className argTypes = ClassP className argTypes
#endif

-- | The data declaration for the record type of the given strictness and
-- arity. It has a single constructor that shares the type's name, with one
-- field per value parameter.
--
-- The type takes interleaved parameters @n1 v1 n2 v2 ...@: each @nI@ is
-- kinded 'Symbol' and each @vI@ is the type of constructor field @I@. All
-- constructor fields are strict when @strict@ is 'True', lazy otherwise.
recordTypeDec :: Bool -> Int -> Dec
recordTypeDec strict arity =
  DataD [] typeName typeParams [NormalC typeName fields] derivedClasses
  where
    typeName = recordName strict arity
    valueVar i = mkName ("v" <> show i)
    typeParams =
      concatMap
        (\i -> [KindedTV (mkName ("n" <> show i)) (ConT ''Symbol), PlainTV (valueVar i)])
        [1 .. arity]
    fieldStrictness
      | strict = IsStrict
      | otherwise = NotStrict
    fields = [(fieldStrictness, VarT (valueVar i)) | i <- [1 .. arity]]
    derivedClasses =
#if MIN_VERSION_base(4,7,0)
      [''Show, ''Eq, ''Ord, ''Typeable, ''Generic]
#else
      [''Show, ''Eq, ''Ord, ''Generic]
#endif

-- | The TH 'Name' used for both the record type and its sole constructor,
-- for the given strictness and arity; see 'recordNameString'.
recordName :: Bool -> Int -> Name
recordName strict = mkName . recordNameString strict

-- | The source-level name of a record type: @"Strict"@ or @"Lazy"@, then
-- @"Record"@, then the arity (e.g. @"StrictRecord3"@, @"LazyRecord1"@).
recordNameString :: Bool -> Int -> String
recordNameString strict arity =
  concat [strictnessPrefix, "Record", show arity]
  where
    strictnessPrefix
      | strict = "Strict"
      | otherwise = "Lazy"

-- | A @Field@ instance for the field at position @fieldIndex@ of the strict
-- (@True@) or lazy (@False@) record type of the given arity.
recordFieldInstanceDec :: Bool -> Int -> Int -> Dec
recordFieldInstanceDec strict arity fieldIndex =
  fieldInstanceDec (FieldInstanceDecMode_Record strict) arity fieldIndex

-- | A @Field@ instance for the element at position @fieldIndex@ (second
-- argument) of the built-in tuple of the given arity (first argument).
tupleFieldInstanceDec :: Int -> Int -> Dec
tupleFieldInstanceDec = fieldInstanceDec FieldInstanceDecMode_Tuple

-- | Which kind of product type 'fieldInstanceDec' generates an instance for.
data FieldInstanceDecMode
  = FieldInstanceDecMode_Tuple
    -- ^ A built-in tuple; its type has only value parameters.
  | FieldInstanceDecMode_Record Bool
    -- ^ A generated record type; the flag picks the strict (@True@) or
    -- lazy (@False@) variant.

-- | A @Field@ instance giving lens access to the field at 1-based position
-- @fieldIndex@ of a tuple or generated record of the given @arity@.
--
-- With @k = fieldIndex@, the instance head is
--
-- > Field n<k> T T' v<k> v<k>'
--
-- where @T@ is the type applied to all of its variables and @T'@ is the same
-- with @v<k>@ replaced by @v<k>'@. Record types take interleaved
-- @n<i> v<i>@ parameters; tuples take only the @v<i>@ ones, so in tuple mode
-- @n<k>@ appears only in the instance head. The single method generated is
--
-- > fieldLens _ f (Con v1 .. vN) = fmap (\v<k>' -> Con v1 .. v<k>' .. vN) (f v<k>)
--
-- The class name @Field@ is resolved at the splice site via 'mkName'.
fieldInstanceDec :: FieldInstanceDecMode -> Int -> Int -> Dec
fieldInstanceDec mode arity fieldIndex =
  InstanceD [] headType [fieldLensDec]
  where
    indices = [1 .. arity]
    nVarName i = mkName ("n" <> show i)
    vVarName i = mkName ("v" <> show i)
    selectedNVarName = nVarName fieldIndex
    selectedVVarName = vVarName fieldIndex
    selectedVPrimeVarName = mkName ("v" <> show fieldIndex <> "'")

    -- The value variable at position i after the update: primed for the
    -- selected field, unchanged everywhere else.
    updatedVVarName i
      | i == fieldIndex = selectedVPrimeVarName
      | otherwise = vVarName i

    headType =
      foldl' AppT (ConT (mkName "Field"))
        [ VarT selectedNVarName
        , appliedType vVarName
        , appliedType updatedVVarName
        , VarT selectedVVarName
        , VarT selectedVPrimeVarName
        ]

    (typeName, conName) =
      case mode of
        FieldInstanceDecMode_Tuple ->
          (tupleTypeName arity, tupleDataName arity)
        FieldInstanceDecMode_Record strict ->
          (recordName strict arity, recordName strict arity)

    -- The type constructor applied to its parameters, taking the value
    -- parameter at each position from @valueVar@.
    appliedType valueVar =
      foldl' (\acc i -> AppT (withNameVar acc i) (VarT (valueVar i)))
             (ConT typeName)
             indices

    -- Records put a name parameter before each value parameter; tuples
    -- have none.
    withNameVar acc i =
      case mode of
        FieldInstanceDecMode_Tuple -> acc
        FieldInstanceDecMode_Record _ -> AppT acc (VarT (nVarName i))

    fieldLensDec =
      FunD (mkName "fieldLens") [Clause lensPatterns (NormalB lensBody) []]
    fVarName = mkName "f"
    lensPatterns =
      [WildP, VarP fVarName, ConP conName (map (VarP . vVarName) indices)]
    -- fmap (\v<k>' -> Con ...) (f v<k>)
    lensBody =
      AppE (AppE (VarE 'fmap) rebuild)
           (AppE (VarE fVarName) (VarE selectedVVarName))
    -- Reassembles the value with the selected field replaced by v<k>'.
    rebuild =
      LamE [VarP selectedVPrimeVarName] $
        foldl' AppE (ConE conName) (map (VarE . updatedVVarName) indices)

-- | A 'Storable' instance for the record type of the given strictness and
-- arity.
--
-- Fields are stored back to back in declaration order with no padding.
-- Field @i@ sits at the sum of the 'sizeOf's of fields @1 .. i-1@, and the
-- record's 'sizeOf' is the sum over all fields. Its 'alignment' is the
-- largest field 'alignment' (never less than 1). Each value parameter gets a
-- @Storable@ constraint.
--
-- The generated methods refer, via 'mkName', to @sizeOf@, @alignment@,
-- @peek@, @poke@, @plusPtr@, @maximum@, @return@, @+@ and @undefined@, so
-- those must be in scope at the splice site. The @undefined :: vI@
-- annotations need the instance's type variables in scope inside the method
-- bodies, and @peek@ uses bang patterns.
--
-- NOTE(review): with this packed layout a field may land on an offset that
-- is not a multiple of its own alignment; confirm that is acceptable on the
-- target architectures.
recordStorableInstanceDec :: Bool -> Int -> Dec
recordStorableInstanceDec strict arity =
  InstanceD context (AppT (ConT (mkName "Storable")) recordType)
    [ sizeOfFun, inlineFun "sizeOf"
    , alignmentFun, inlineFun "alignment"
    , peekFun, inlineFun "peek"
    , pokeFun, inlineFun "poke"
    ]
  where
    name = recordName strict arity
    indices = [1 .. arity]
    nameE = VarE . mkName
    vName i = mkName ("v" <> show i)
    xName i = mkName ("x" <> show i)
    ptrName = mkName "ptr"

    recordType =
      foldl' (\acc i -> AppT (AppT acc (VarT (mkName ("n" <> show i)))) (VarT (vName i)))
             (ConT name)
             indices

    context = map (\i -> classP (mkName "Storable") [VarT (vName i)]) indices

    -- @method (undefined :: vI)@: a Storable query on field I's type.
    fieldQuery method i =
      AppE (nameE method) (SigE (nameE "undefined") (VarT (vName i)))

    -- The sum of the sizes of the first @n@ fields: the byte offset of field
    -- @n + 1@, and the total size when @n == arity@.
    sizeOfFirst n =
      foldr (\a b -> AppE (AppE (nameE "+") a) b) (LitE (IntegerL 0)) $
        map (fieldQuery "sizeOf") [1 .. n]

    -- @plusPtr ptr offset@ pointing at field @i@.
    fieldPtr i = AppE (AppE (nameE "plusPtr") (VarE ptrName)) (sizeOfFirst (i - 1))

    sizeOfFun =
      FunD (mkName "sizeOf") [Clause [WildP] (NormalB (sizeOfFirst arity)) []]

    -- The largest field alignment. This must query 'alignment', not
    -- 'sizeOf': a field's size (e.g. of a nested record) is not in general a
    -- valid alignment. The leading 1 keeps 'maximum' total when arity is 0.
    alignmentFun =
      FunD (mkName "alignment")
        [ Clause [WildP]
            (NormalB
              (AppE (nameE "maximum")
                    (ListE (LitE (IntegerL 1) : map (fieldQuery "alignment") indices))))
            []
        ]

    -- Peek every field at its offset (strictly), then rebuild the record.
    peekFun =
      FunD (mkName "peek")
        [ Clause [VarP ptrName]
            (NormalB (DoE (map peekField indices ++ [NoBindS returnRecord])))
            []
        ]
    peekField i =
      BindS (BangP (VarP (xName i))) (AppE (nameE "peek") (fieldPtr i))
    returnRecord =
      AppE (nameE "return")
           (foldl' (\acc i -> AppE acc (VarE (xName i))) (ConE name) indices)

    -- Poke every field at its offset. A record with no fields writes
    -- nothing, but still needs a non-empty @do@ block.
    pokeFun =
      FunD (mkName "poke")
        [ Clause [VarP ptrName, ConP name (map (VarP . vName) indices)]
            (NormalB (DoE pokeStmts))
            []
        ]
    pokeStmts =
      case map pokeField indices of
        [] -> [NoBindS (AppE (nameE "return") (ConE '()))]
        stmts -> stmts
    pokeField i =
      NoBindS (AppE (AppE (VarE (mkName "poke")) (fieldPtr i)) (VarE (vName i)))

    inlineFun fnName = PragmaD $ InlineP (mkName fnName) Inline FunLike AllPhases

-- | Declarations for a constructor function named after the record type with
-- its first letter lower-cased (e.g. @strictRecord2@ for @StrictRecord2@):
-- an INLINE pragma, a type signature that takes a @FieldName@ proxy before
-- each value, and a definition bound to 'recordConLambdaExp'.
recordConFunDecs :: Bool -> Int -> [Dec]
recordConFunDecs strict arity =
  [ PragmaD (InlineP funName Inline FunLike AllPhases)
  , SigD funName funType
  , FunD funName [Clause [] (NormalB (recordConLambdaExp strict arity)) []]
  ]
  where
    funName = mkName (lowerFirst (recordNameString strict arity))
    lowerFirst str =
      case str of
        c : rest -> toLower c : rest
        [] -> []

    indices = [1 .. arity]
    nNames = [mkName ("n" <> show i) | i <- indices]
    vNames = [mkName ("v" <> show i) | i <- indices]
    -- [x1, y1, x2, y2, ...]
    pairUp xs ys = concat (zipWith (\x y -> [x, y]) xs ys)

    -- forall n1 v1 ... . FieldName n1 -> v1 -> ... -> Record n1 v1 ...
    funType =
      ForallT (map PlainTV (pairUp nNames vNames)) [] $
        foldr (AppT . AppT ArrowT) resultType argTypes
    argTypes =
      pairUp (map (AppT (ConT (mkName "FieldName")) . VarT) nNames) (map VarT vNames)
    resultType =
      foldl' AppT (ConT (recordName strict arity)) (pairUp (map VarT nNames) (map VarT vNames))

-- |
-- A type-annotated lambda that builds a record from interleaved field-name
-- proxies and values. The proxies are ignored at runtime; they let callers
-- fix the field names at the value level, which helps type inference.
--
-- E.g., for a strict record of arity 2 it produces
--
-- >(\_ v1 _ v2 -> StrictRecord2 v1 v2) :: FieldName n1 -> v1 -> FieldName n2 -> v2 -> StrictRecord2 n1 v1 n2 v2
--
-- so passing name proxies to the lambda determines the name parameters of
-- the resulting record type. @FieldName@ is resolved at the splice site.
recordConLambdaExp :: Bool -> Int -> Exp
recordConLambdaExp strict arity =
  SigE lambda lambdaType
  where
    conName = recordName strict arity
    indices = [1 .. arity]
    nVar i = mkName ("n" <> show i)
    vVar i = mkName ("v" <> show i)

    -- \_ v1 _ v2 ... -> Con v1 v2 ...
    lambda =
      LamE
        (concatMap (\i -> [WildP, VarP (vVar i)]) indices)
        (foldl' AppE (ConE conName) [VarE (vVar i) | i <- indices])

    lambdaType =
      ForallT
        (concatMap (\i -> [PlainTV (nVar i), PlainTV (vVar i)]) indices)
        []
        (foldr1 (\l r -> AppT (AppT ArrowT l) r) (argTypes ++ [resultType]))
    argTypes =
      concatMap
        (\i -> [AppT (ConT (mkName "FieldName")) (VarT (nVar i)), VarT (vVar i)])
        indices
    resultType =
      foldl' AppT (ConT conName) (concatMap (\i -> [VarT (nVar i), VarT (vVar i)]) indices)