module HLearn.Models.Distributions.Multivariate.Internal.TypeLens
(
Trainable (..)
, TypeLens (..)
, makeTypeLenses
, nameTransform
)
where
import HLearn.Algebra
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- | Types that can be turned into a heterogeneous list of values.
--
-- Every instance defined or generated in this module chooses an 'HList'
-- type for 'GetHList'.
class Trainable t where
  -- | The heterogeneous-list representation of @t@.
  type GetHList t

  -- | Convert a value into its 'GetHList' representation.
  getHList :: t -> GetHList t
-- | The empty 'HList' is already in list form; conversion is the identity.
instance Trainable (HList '[]) where
  type GetHList (HList '[]) = HList '[]
  getHList = id
-- | A non-empty 'HList' is already in list form; conversion is the identity.
instance (Trainable (HList xs)) => Trainable (HList (x ': xs)) where
  type GetHList (HList (x ': xs)) = HList (x ': xs)
  getHList = id
-- | Associates an index type with @i@.
--
-- 'makeTypeLenses' generates one instance per record field, for the empty
-- marker type named by 'nameTransform'.
class TypeLens i where
  -- | Type-level index of @i@. Generated instances use @Nat1Box@ applied to
  -- a @Zero@/@Succ@ chain that encodes the field's position.
  type TypeLensIndex i
-- | Name used for the generated marker type (and its constructor) of a
-- record field: the field name with @"TH"@ prepended.
nameTransform :: String -> String
nameTransform = ("TH" ++)
-- | Template Haskell entry point for a record type @name@.
--
-- Splices, in this order:
--
--   * one empty marker data type per field ('makeDatatypes'),
--   * a 'TypeLens' instance per marker giving its index ('makeIndexNames'),
--   * a 'Trainable' instance for the record ('makeTrainable'),
--   * a @MultivariateLabels@ instance for the record ('makeMultivariateLabels').
makeTypeLenses :: Name -> Q [Dec]
makeTypeLenses name =
  fmap concat $ mapM ($ name) generators
  where
    generators =
      [ makeDatatypes
      , makeIndexNames
      , makeTrainable
      , makeMultivariateLabels
      ]
-- | For each field of the record @name@, declare an empty data type.
-- The type and its single nullary constructor both carry the field name
-- passed through 'nameTransform'.
makeDatatypes :: Name -> Q [Dec]
makeDatatypes name = do
  fieldNames <- extractContructorNames name
  return [ markerDecl (mkName (nameTransform field)) | field <- fieldNames ]
  where
    markerDecl marker = DataD [] marker [] [NormalC marker []] []
-- | For the field at (0-based) position @i@ of the record @name@, emit
--
-- > instance TypeLens TH<field> where
-- >   type TypeLensIndex TH<field> = Nat1Box (Succ (... (Succ Zero)))
--
-- where the @Succ@ chain has length @i@.
makeIndexNames :: Name -> Q [Dec]
makeIndexNames name =
  fmap (map makeIndexName . zip [0 ..]) $ extractContructorNames name
  where
    makeIndexName (i, str) =
      InstanceD []
        (AppT (ConT $ mkName "TypeLens") (ConT $ mkName $ nameTransform str))
        [ TySynInstD (mkName "TypeLensIndex")
            [ConT $ mkName $ nameTransform str]
            (AppT (ConT $ mkName "Nat1Box") (typeNat i))
        ]

    -- Unary type-level encoding of a field position.
    -- The original recursed on an unbound @n1@; the decrement is restored.
    -- The guard also stops the recursion for non-positive input instead of
    -- diverging (indices from @zip [0 ..]@ are never negative).
    typeNat :: Int -> Type
    typeNat n
      | n <= 0    = ConT $ mkName "Zero"
      | otherwise = AppT (ConT $ mkName "Succ") $ typeNat (n - 1)
-- | Emit a 'Trainable' instance for the record type @name@:
--
-- > instance Trainable T where
-- >   type GetHList T = HList '[field types...]
-- >   getHList var = field1 var ::: field2 var ::: ... ::: HNil
makeTrainable :: Name -> Q [Dec]
makeTrainable name = do
  fieldsType <- extractHListType name
  fieldsExp  <- extractHListExp recordVar name
  let assocType =
        TySynInstD (mkName "GetHList") [ConT name]
          (AppT (ConT $ mkName "HList") fieldsType)
      method =
        FunD (mkName "getHList")
          [Clause [VarP recordVar] (NormalB fieldsExp) []]
      instHead = AppT (ConT (mkName "Trainable")) (ConT name)
  return [InstanceD [] instHead [assocType, method]]
  where
    -- Pattern variable bound to the record in the generated clause.
    recordVar = mkName "var"
-- | Emit a @MultivariateLabels@ instance for the record type @name@.
-- Its @getLabels@ returns, in field order, each field name passed through
-- 'nameTransform', built as an explicit @(:)@/@[]@ chain.
makeMultivariateLabels :: Name -> Q [Dec]
makeMultivariateLabels name = do
  fieldNames <- extractContructorNames name
  let labelsExp  = foldr consLabel (ConE $ mkName "[]") fieldNames
      getLabelsD =
        FunD (mkName "getLabels")
          [Clause [VarP $ mkName "dist"] (NormalB labelsExp) []]
      instHead   = AppT (ConT (mkName "MultivariateLabels")) (ConT name)
  return [InstanceD [] instHead [getLabelsD]]
  where
    consLabel field rest =
      AppE (AppE (ConE $ mkName ":") (LitE $ StringL $ nameTransform field)) rest
-- | One field of a record constructor, as found in a reified 'RecC'.
type ConstructorFieldInfo =
  ( Name    -- field selector name
  , Strict  -- strictness annotation
  , Type    -- declared field type
  )
-- | Build the promoted type-level list of the record's field types, in field
-- order, e.g. @'[Int, String]@ for a record with an 'Int' field followed by
-- a 'String' field. 'makeTrainable' uses it as the argument of @HList@.
--
-- The list is built with TH's dedicated promoted-list constructors.
-- @ConT (mkName "[]")@ in a type denotes the ordinary list type constructor
-- @[] :: * -> *@, not the promoted empty list @'[]@. The previous encoding
-- therefore described an ill-kinded list.
extractHListType :: Name -> Q Type
extractHListType name = do
  fieldTypes <- fmap (map fieldType) $ extractConstructorFields name
  return $ foldr consT PromotedNilT fieldTypes
  where
    consT x rest = AppT (AppT PromotedConsT x) rest
    fieldType (_, _, t) = t
-- | Build the expression
--
-- > field1 var ::: field2 var ::: ... ::: HNil
--
-- applying every field selector of the record @name@, in field order, to the
-- variable @var@.
extractHListExp :: Name -> Name -> Q Exp
extractHListExp var name = do
  selectors <- fmap (map (\(sel, _, _) -> sel)) $ extractConstructorFields name
  return $ foldr hcons (ConE $ mkName "HNil") selectors
  where
    hcons sel rest =
      AppE (AppE (ConE $ mkName ":::") (AppE (VarE sel) (VarE var))) rest
-- | Unqualified names of the record fields of @datatype@, in declaration
-- order. (Despite the name, these are field names, not constructor names.)
extractContructorNames :: Name -> Q [String]
extractContructorNames datatype = do
  fields <- extractConstructorFields datatype
  return [ nameBase fieldName | (fieldName, _, _) <- fields ]
-- | Reify @datatype@ and return the fields of its only constructor. That
-- constructor must be a record constructor of a @data@ or @newtype@
-- declaration.
--
-- Any other shape aborts the splice with 'fail' in the 'Q' monad. The problem
-- is then reported as a compile error at the splice site and can be
-- intercepted with 'recover'. The previous version returned a pure 'error'
-- thunk, which only fired when the field list was later forced.
extractConstructorFields :: Name -> Q [ConstructorFieldInfo]
extractConstructorFields datatype = do
  info <- reify datatype
  case info of
    TyConI (DataD _ _ _ [RecC _ fs] _)    -> return fs
    TyConI (NewtypeD _ _ _ (RecC _ fs) _) -> return fs
    TyConI (DataD _ _ _ [_] _)            -> failFor "Can't derive Lens without record selectors: "
    TyConI NewtypeD{}                     -> failFor "Can't derive Lens without record selectors: "
    TyConI TySynD{}                       -> failFor "Can't derive Lens for type synonym: "
    -- Remaining data declarations have zero or several constructors.
    TyConI DataD{}                        -> failFor "Can't derive Lens for tagged union: "
    _ -> fail $ "Can't derive Lens for: " ++ datatypeStr ++ ", type name required."
  where
    datatypeStr = nameBase datatype
    failFor msg = fail $ msg ++ datatypeStr
-- Reify the type constructor @name@ and return its declared name together
-- with its type-variable binders, for @data@ and @newtype@ declarations.
-- NOTE(review): no type signature, and not referenced in the visible part of
-- this module. The case alternatives visible here cover only 'DataD' and
-- 'NewtypeD'. If the definition really ends here, reifying any other kind of
-- declaration fails with a pattern-match error; confirm and consider a
-- 'fail' fallback.
extractTypeInfo name = do
  i <- reify name
  return $ case i of
    TyConI (DataD _ n ts _ _) -> (n, ts)
    TyConI (NewtypeD _ n ts _ _) -> (n, ts)