{-# LANGUAGE TemplateHaskell, CPP #-}
{- |
Module          : Foreign.C.Structs.Templates
Description     : Create C structs from Haskell
Copyright       : (c) Simon Plakolb, 2020
License         : MIT
Maintainer      : s.plakolb@gmail.com
Stability       : beta

This module exposes the template haskell framework to create Struct types.
-}
module Foreign.C.Structs.Templates
    (structT, acs)
where

import Language.Haskell.TH

import Foreign.Storable (Storable, peek, poke, sizeOf, alignment)
import Foreign.Ptr (castPtr)
import Foreign.C.Structs.Utils (next, sizeof, fmax)

-- | All @StructN@ types and their instances of 'Storable' are declared using 'structT'.
-- It can theoretically create C structs with an infinite number of fields.
-- The parameter of 'structT' is the number of fields the struct type should have.
-- Its constructor and type will both be named @StructN@ where N is equal to the argument to 'structT'.
structT :: Int -> DecsQ
structT :: Int -> DecsQ
structT = forall (m :: * -> *) a. Monad m => a -> m a
return forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall a b. (a -> b) -> a -> b
($) [Int -> Dec
structTypeT, Int -> Dec
storableInstanceT] forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. a -> [a]
repeat

-- | Access function for fields of a @StructN@ where @N@ is the number of fields in the struct.
-- N is the first argument passed to 'acs', while the second is the field number.
-- The first field has number 1, the second 2 and so on.
--
-- > s = Struct4 1 2 3 4
-- > $(acs 4 3) s
--
acs :: Int -> Int -> ExpQ
acs :: Int -> Int -> ExpQ
acs Int
big_n Int
small_n = [| \struct -> $(caseE [| struct |] [m]) |]
    where
          m :: MatchQ
          m :: MatchQ
m = forall (m :: * -> *).
Quote m =>
m Pat -> m Body -> [m Dec] -> m Match
match PatQ
pat (forall (m :: * -> *). Quote m => m Exp -> m Body
normalB forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). Quote m => Name -> m Exp
varE forall a b. (a -> b) -> a -> b
$ [Name]
vrs forall a. [a] -> Int -> a
!! (Int
small_nforall a. Num a => a -> a -> a
-Int
1)) []

          pat :: PatQ
          pat :: PatQ
pat = forall (m :: * -> *). Quote m => Name -> [m Pat] -> m Pat
conP Name
str forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall (m :: * -> *). Quote m => Name -> m Pat
varP forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take Int
big_n [Name]
vrs

          str :: Name
str = String -> Name
mkName forall a b. (a -> b) -> a -> b
$ String
"Struct" forall a. [a] -> [a] -> [a]
++ forall a. Show a => a -> String
show Int
big_n

          vrs :: [Name]
vrs = String -> [Name]
fieldnames String
""

-- Templating functions

structTypeT :: Int -> Dec
#if __GLASGOW_HASKELL__ < 800
structTypeT nfields = DataD [] (structType nfields) tyVars [constructor] deriv''
#elif __GLASGOW_HASKELL__ < 802
structTypeT nfields = DataD [] (structType nfields) tyVars Nothing [constructor] deriv'
#else
structTypeT :: Int -> Dec
structTypeT Int
nfields = Cxt
-> Name
-> [TyVarBndr ()]
-> Maybe Kind
-> [Con]
-> [DerivClause]
-> Dec
DataD [] (forall {a}. Show a => a -> Name
structType Int
nfields) [TyVarBndr ()]
tyVars forall a. Maybe a
Nothing [Con
constructor] [DerivClause
deriv]
#endif
    where
#if __GLASGOW_HASKELL__ < 900
          tyVars    = map PlainTV $ take nfields $ fieldnames ""
#else
          tyVars :: [TyVarBndr ()]
tyVars    = forall a b. (a -> b) -> [a] -> [b]
map (forall a b c. (a -> b -> c) -> b -> a -> c
flip forall flag. Name -> flag -> TyVarBndr flag
PlainTV ()) forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take Int
nfields forall a b. (a -> b) -> a -> b
$ String -> [Name]
fieldnames String
""
#endif

          constructor :: Con
constructor = Name -> [VarBangType] -> Con
RecC (forall {a}. Show a => a -> Name
structType Int
nfields) forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take Int
nfields [VarBangType]
records

          records :: [VarBangType]
records     = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall {a}. a -> Name -> (a, Bang, Kind)
defRec (Int -> [Name]
getters Int
nfields) (String -> [Name]
fieldnames String
"")
#if __GLASGOW_HASKELL__ < 800
          defRec n t  = (,,) n NotStrict (VarT t)
#else
          defRec :: a -> Name -> (a, Bang, Kind)
defRec a
n Name
t  = (,,) a
n (SourceUnpackedness -> SourceStrictness -> Bang
Bang SourceUnpackedness
NoSourceUnpackedness SourceStrictness
NoSourceStrictness) (Name -> Kind
VarT Name
t)
#endif
          deriv'' :: [Name]
deriv'' = [''Show, ''Eq]

          deriv' :: Cxt
deriv' = forall a b. (a -> b) -> [a] -> [b]
map Name -> Kind
ConT [Name]
deriv''
#if __GLASGOW_HASKELL__ > 800
          deriv :: DerivClause
deriv = Maybe DerivStrategy -> Cxt -> DerivClause
DerivClause forall a. Maybe a
Nothing Cxt
deriv'
#endif

storableInstanceT :: Int -> Dec
#if __GLASGOW_HASKELL__ < 800
storableInstanceT nfields = InstanceD cxt tp decs
#else
storableInstanceT :: Int -> Dec
storableInstanceT Int
nfields = Maybe Overlap -> Cxt -> Kind -> [Dec] -> Dec
InstanceD forall a. Maybe a
Nothing Cxt
cxt Kind
tp [Dec]
decs
#endif
    where
          vars :: [Name]
vars = forall a. Int -> [a] -> [a]
take Int
nfields forall a b. (a -> b) -> a -> b
$ String -> [Name]
fieldnames String
""

          storable :: Kind -> Kind
storable = Kind -> Kind -> Kind
AppT forall a b. (a -> b) -> a -> b
$ Name -> Kind
ConT ''Storable
#if __GLASGOW_HASKELL__ < 710
          cxt  = map (\v -> ClassP ''Storable [VarT v]) vars
#else
          cxt :: Cxt
cxt  = forall a b. (a -> b) -> [a] -> [b]
map (Kind -> Kind
storable forall b c a. (b -> c) -> (a -> b) -> a -> c
. Name -> Kind
VarT) [Name]
vars
#endif
          tp :: Kind
tp   = Kind -> Kind
storable forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Kind -> Kind -> Kind
AppT (Name -> Kind
ConT forall a b. (a -> b) -> a -> b
$ forall {a}. Show a => a -> Name
structType Int
nfields) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map Name -> Kind
VarT [Name]
vars

          decs :: [Dec]
decs = [ Int -> Dec
sizeOfT Int
nfields
                 , Int -> Dec
alignmentT Int
nfields
                 , Int -> Dec
peekT Int
nfields
                 , Int -> Dec
pokeT Int
nfields
                 ]

-- Storable instance function temaples

sizeOfT :: Int -> Dec
sizeOfT :: Int -> Dec
sizeOfT Int
nfields = Name -> [Clause] -> Dec
FunD 'sizeOf [Clause
clause]
    where
          clause :: Clause
clause = [Pat] -> Body -> [Dec] -> Clause
Clause [Name -> Pat
VarP Name
struct] (Exp -> Body
NormalB Exp
body) [Dec]
wheres

          body :: Exp
body = Exp -> Exp -> Exp
AppE (Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'sizeof) forall a b. (a -> b) -> a -> b
$ String -> Exp
alignments String
"a") (String -> Exp
sizes String
"s")

          alignments :: String -> Exp
alignments = [Exp] -> Exp
ListE forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Int -> [a] -> [a]
take Int
nfields forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a -> b) -> [a] -> [b]
map Name -> Exp
VarE forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> [Name]
fieldnames

          sizes :: String -> Exp
sizes = [Exp] -> Exp
ListE forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. Int -> [a] -> [a]
take Int
nfields forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a -> b) -> [a] -> [b]
map Name -> Exp
VarE forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> [Name]
fieldnames

          wheres :: [Dec]
wheres = Name -> Int -> String -> [Dec]
vals 'alignment Int
nfields String
"a" forall a. [a] -> [a] -> [a]
++ Name -> Int -> String -> [Dec]
vals 'sizeOf Int
nfields String
"s"

alignmentT :: Int -> Dec
alignmentT :: Int -> Dec
alignmentT Int
nfields = Name -> [Clause] -> Dec
FunD 'alignment [Clause
clause]
    where
           clause :: Clause
clause = [Pat] -> Body -> [Dec] -> Clause
Clause [Name -> Pat
VarP Name
struct] (Exp -> Body
NormalB Exp
body) [Dec]
wheres

           body :: Exp
body = Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'fmax) ([Exp] -> Exp
ListE forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take Int
nfields forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map Name -> Exp
VarE forall a b. (a -> b) -> a -> b
$ String -> [Name]
fieldnames String
"")

           wheres :: [Dec]
wheres = Name -> Int -> String -> [Dec]
vals 'alignment Int
nfields String
""

peekT :: Int -> Dec
peekT :: Int -> Dec
peekT Int
nfields = Name -> [Clause] -> Dec
FunD 'peek [Clause
clause]
    where
          vars :: [Name]
vars = forall a. Int -> [a] -> [a]
take Int
nfields forall a b. (a -> b) -> a -> b
$ String -> [Name]
fieldnames String
""

          ptrs :: [Name]
ptrs = forall a. [a] -> [a]
tail forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take Int
nfields forall a b. (a -> b) -> a -> b
$ String -> [Name]
fieldnames String
"_ptr"

          clause :: Clause
clause = [Pat] -> Body -> [Dec] -> Clause
Clause [Name -> Pat
VarP Name
ptr] (Exp -> Body
NormalB Exp
body) []

#if __GLASGOW_HASKELL__ < 900
          body = DoE $ initial ++ concat gotos ++ final
#else
          body :: Exp
body = Maybe ModName -> [Stmt] -> Exp
DoE forall a. Maybe a
Nothing forall a b. (a -> b) -> a -> b
$ [Stmt]
initial forall a. [a] -> [a] -> [a]
++ forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Stmt]]
gotos forall a. [a] -> [a] -> [a]
++ [Stmt]
final
#endif

          initial :: [Stmt]
initial = [ Pat -> Exp -> Stmt
BindS (Name -> Pat
VarP forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [Name]
vars) (Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'peek) Exp
castPtr')
                    , Pat -> Exp -> Stmt
BindS (Name -> Pat
VarP forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [Name]
ptrs) (Exp -> Exp -> Exp
AppE (Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'next) forall a b. (a -> b) -> a -> b
$ Name -> Exp
VarE Name
ptr) forall a b. (a -> b) -> a -> b
$ Name -> Exp
VarE forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [Name]
vars)
                    ]


          gotos :: [[Stmt]]
gotos = forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 Name -> Name -> Name -> [Stmt]
goto (forall a. [a] -> [a]
tail [Name]
vars) [Name]
ptrs (forall a. [a] -> [a]
tail [Name]
ptrs)

          goto :: Name -> Name -> Name -> [Stmt]
goto Name
n Name
p Name
next_p = [Name -> Name -> Stmt
bindVar' Name
p Name
n, Name -> Name -> Exp -> Stmt
bindPtr' Name
next_p Name
p (Name -> Exp
VarE Name
n)]

          final :: [Stmt]
final = [ Name -> Name -> Stmt
bindVar' (forall a. [a] -> a
last [Name]
ptrs) (forall a. [a] -> a
last [Name]
vars)
                  , Exp -> Stmt
NoBindS forall a b. (a -> b) -> a -> b
$ Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'return) forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Exp -> Exp -> Exp
AppE (Name -> Exp
ConE (forall {a}. Show a => a -> Name
structType Int
nfields)) (forall a b. (a -> b) -> [a] -> [b]
map Name -> Exp
VarE [Name]
vars)
                  ]

pokeT :: Int -> Dec
pokeT :: Int -> Dec
pokeT Int
nfields = Name -> [Clause] -> Dec
FunD 'poke [Clause
clause]
    where
          vars :: [Name]
vars = forall a. Int -> [a] -> [a]
take Int
nfields forall a b. (a -> b) -> a -> b
$ String -> [Name]
fieldnames String
""

          types :: [Name]
types = forall a. Int -> [a] -> [a]
take Int
nfields forall a b. (a -> b) -> a -> b
$ String -> [Name]
fieldnames String
"_ty"

          ptrs :: [Name]
ptrs = forall a. [a] -> [a]
tail forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take Int
nfields forall a b. (a -> b) -> a -> b
$ String -> [Name]
fieldnames String
"_ptr"

          clause :: Clause
clause = [Pat] -> Body -> [Dec] -> Clause
Clause [Pat]
patterns (Exp -> Body
NormalB Exp
body) []

#if __GLASGOW_HASKELL__ < 902
          patterns = [VarP ptr, ConP (structType nfields) (map VarP vars)]
#else
          patterns :: [Pat]
patterns = [Name -> Pat
VarP Name
ptr, Name -> Cxt -> [Pat] -> Pat
ConP (forall {a}. Show a => a -> Name
structType Int
nfields) (forall a b. (a -> b) -> [a] -> [b]
map Name -> Kind
VarT [Name]
types) (forall a b. (a -> b) -> [a] -> [b]
map Name -> Pat
VarP [Name]
vars)]
#endif

#if __GLASGOW_HASKELL__ < 900
          body = DoE $ [init_poke, init_next] ++ concat gotos ++ [final]
#else
          body :: Exp
body = Maybe ModName -> [Stmt] -> Exp
DoE forall a. Maybe a
Nothing forall a b. (a -> b) -> a -> b
$ [Stmt
init_poke, Stmt
init_next] forall a. [a] -> [a] -> [a]
++ forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Stmt]]
gotos forall a. [a] -> [a] -> [a]
++ [Stmt
final]
#endif

          init_poke :: Stmt
init_poke = Exp -> Stmt
NoBindS
                    forall a b. (a -> b) -> a -> b
$ Exp -> Exp -> Exp
AppE Exp
cast_poke_ptr (Name -> Exp
VarE forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [Name]
vars)
                    where  cast_poke_ptr :: Exp
cast_poke_ptr = Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'poke) Exp
castPtr'

          init_next :: Stmt
init_next = Name -> Name -> Exp -> Stmt
bindPtr' (forall a. [a] -> a
head [Name]
ptrs) Name
ptr (Name -> Exp
VarE forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [Name]
vars)

          gotos :: [[Stmt]]
gotos = forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 Name -> Name -> Name -> [Stmt]
goto (forall a. [a] -> [a]
tail [Name]
vars) [Name]
ptrs forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
tail [Name]
ptrs

          goto :: Name -> Name -> Name -> [Stmt]
goto Name
n Name
p Name
next_p = [Name -> Exp -> Stmt
pokeVar' Name
p Exp
var, Name -> Name -> Exp -> Stmt
bindPtr' Name
next_p Name
p Exp
var]
                where var :: Exp
var = Name -> Exp
VarE Name
n

          final :: Stmt
final = Name -> Exp -> Stmt
pokeVar' (forall a. [a] -> a
last [Name]
ptrs) (Name -> Exp
VarE forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
last [Name]
vars)

-- Helpers and Constants

structType :: a -> Name
structType a
n = String -> Name
mkName (String
"Struct" forall a. [a] -> [a] -> [a]
++ forall a. Show a => a -> String
show a
n)

struct :: Name
struct   = String -> Name
mkName String
"struct"

ptr :: Name
ptr      = String -> Name
mkName String
"ptr"

castPtr' :: Exp
castPtr' = Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'castPtr) (Name -> Exp
VarE Name
ptr)

fieldnames :: String -> [Name]
fieldnames :: String -> [Name]
fieldnames String
s = forall a b. (a -> b) -> [a] -> [b]
map (String -> Name
mkName forall b c a. (b -> c) -> (a -> b) -> a -> c
. (forall a. a -> [a] -> [a]
:String
s)) [Char
'a'..Char
'z']

getters    :: Int -> [Name]
getters :: Int -> [Name]
getters Int
n = forall a b. (a -> b) -> [a] -> [b]
map (String -> Name
mkName forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((String
"s" forall a. [a] -> [a] -> [a]
++ forall a. Show a => a -> String
show Int
n) forall a. [a] -> [a] -> [a]
++))
          forall a b. (a -> b) -> a -> b
$  [String
"1st",String
"2nd",String
"3rd"]
          forall a. [a] -> [a] -> [a]
++ [forall a. Show a => a -> String
show Integer
n forall a. [a] -> [a] -> [a]
++ String
"th" | Integer
n <- [Integer
4..]]

vals :: Name -> Int -> String -> [Dec]
vals Name
f Int
n String
s = forall a. Int -> [a] -> [a]
take Int
n forall a b. (a -> b) -> a -> b
$ forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Name -> Name -> Dec
val (String -> [Name]
fieldnames String
s) (Int -> [Name]
getters Int
n)
    where
          val :: Name -> Name -> Dec
val Name
v Name
getter = Pat -> Body -> [Dec] -> Dec
ValD (Name -> Pat
VarP Name
v) (Exp -> Body
NormalB forall a b. (a -> b) -> a -> b
$ Name -> Exp
body Name
getter) []

          body :: Name -> Exp
body Name
getter  = Exp -> Exp -> Exp
AppE (Name -> Exp
VarE Name
f) forall a b. (a -> b) -> a -> b
$ Exp -> Exp -> Exp
AppE (Name -> Exp
VarE Name
getter) forall a b. (a -> b) -> a -> b
$ Name -> Exp
VarE Name
struct

bindVar' :: Name -> Name -> Stmt
bindVar' Name
ptr Name
var = Pat -> Exp -> Stmt
BindS (Name -> Pat
VarP Name
var) (Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'peek) forall a b. (a -> b) -> a -> b
$ Name -> Exp
VarE Name
ptr)

pokeVar' :: Name -> Exp -> Stmt
pokeVar' Name
ptr Exp
var = Exp -> Stmt
NoBindS
       forall a b. (a -> b) -> a -> b
$ Exp -> Exp -> Exp
AppE (Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'poke) forall a b. (a -> b) -> a -> b
$ Name -> Exp
VarE Name
ptr) Exp
var

bindPtr' :: Name -> Name -> Exp -> Stmt
bindPtr' Name
np Name
pp Exp
var = Pat -> Exp -> Stmt
BindS (Name -> Pat
VarP Name
np)
       forall a b. (a -> b) -> a -> b
$ Exp -> Exp -> Exp
AppE Exp
next_ptr Exp
var
       where next_ptr :: Exp
next_ptr = Exp -> Exp -> Exp
AppE (Name -> Exp
VarE 'next) forall a b. (a -> b) -> a -> b
$ Name -> Exp
VarE Name
pp