{-# LANGUAGE DataKinds #-}

-- | This module provides convenient TemplateHaskell functions for making type lens suitable for use with multivariate distributions.
-- 
-- Given a data type that looks like:
-- 
-- >data Character = Character
-- >    { _name      :: String
-- >    , _species   :: String
-- >    , _job       :: Job
-- >    , _isGood    :: Maybe Bool
-- >    , _age       :: Double -- in years
-- >    , _height    :: Double -- in feet
-- >    , _weight    :: Double -- in pounds
-- >    }
-- >    deriving (Read,Show,Eq,Ord)
-- > 
-- >data Job = Manager | Crew | Henchman | Other
-- >    deriving (Read,Show,Eq,Ord)
-- 
--
-- when we run the command:
--
-- >makeTypeLenses ''Character
--
-- We generate the following type lenses automatically:
--
-- >data TH_name    = TH_name
-- >data TH_species = TH_species
-- >data TH_job     = TH_job
-- >data TH_isGood  = TH_isGood
-- >data TH_age     = TH_age
-- >data TH_height  = TH_height
-- >data TH_weight  = TH_weight
-- >
-- >instance TypeLens TH_name where
-- >    type instance TypeLensIndex TH_name = Nat1Box Zero
-- >instance TypeLens TH_species where
-- >    type instance TypeLensIndex TH_species = Nat1Box (Succ Zero)
-- >instance TypeLens TH_job where
-- >    type instance TypeLensIndex TH_job = Nat1Box (Succ (Succ Zero))
-- >instance TypeLens TH_isGood where
-- >    type instance TypeLensIndex TH_isGood = Nat1Box (Succ (Succ (Succ Zero)))
-- >instance TypeLens TH_age where
-- >    type instance TypeLensIndex TH_age = Nat1Box (Succ (Succ (Succ (Succ Zero))))
-- >instance TypeLens TH_height where
-- >    type instance TypeLensIndex TH_height = Nat1Box (Succ (Succ (Succ (Succ (Succ Zero)))))
-- >instance TypeLens TH_weight where
-- >    type instance TypeLensIndex TH_weight = Nat1Box (Succ (Succ (Succ (Succ (Succ (Succ Zero))))))
-- >        
-- >instance Trainable Character where
-- >    type instance GetHList Character = HList '[String,String,Job,Maybe Bool, Double,Double,Double]
-- >    getHList var = name var:::species var:::job var:::isGood var:::age var:::height var:::weight var:::HNil
-- >
-- >instance MultivariateLabels Character where
-- >    getLabels dist = ["TH_name","TH_species","TH_job","TH_isGood","TH_age","TH_height","TH_weight"]
-- 
-- 
-- 

module HLearn.Models.Distributions.Multivariate.Internal.TypeLens
    ( 
    -- * Lens
    Trainable (..)
    , TypeLens (..)
    , TypeFunction (..)
    -- * TemplateHaskell
    , makeTypeLenses
    , nameTransform
    )
    where

import HLearn.Algebra
import Language.Haskell.TH hiding (Range)
import Language.Haskell.TH.Syntax hiding (Range)


-------------------------------------------------------------------------------
-- Trainable

-- | The Trainable class allows us to convert data types into an isomorphic "HList".  All of our multivariate distributions work on "HList"s, so they work on all instances of "Trainable" as well.
class Trainable t where
    type GetHList t
    getHList :: t -> GetHList t

instance Trainable (HList '[]) where
    type GetHList (HList '[]) = HList '[]
    getHList t = t
    
instance (Trainable (HList xs)) => Trainable (HList (x ': xs)) where
    type GetHList (HList (x ': xs)) = HList (x ': xs)
    getHList t = t

-- | This specifies a type level natural number (i.e. "Nat1") that indexes at the right location into our HList
class TypeLens i where
    type TypeLensIndex i

class TypeFunction f where
    type Domain f
    type Range f
    
    typefunc :: f -> Domain f -> Range f

-- | given the name of one of our records, transform it into the name of our type lens
nameTransform :: String -> String
nameTransform str = "TH"++str

nameTransform' :: Name -> Name
nameTransform' name = mkName $ "TH"++(nameBase name)

-- | constructs the type lens
makeTypeLenses :: Name -> Q [Dec]
makeTypeLenses name = do
    datatypes <- makeDatatypes name
    indexNames <- makeIndexNames name
    trainableInstance <- makeTrainable name
    multivariateLabels <- makeMultivariateLabels name
    typeFunctions <- makeTypeFunctions name
    return $ datatypes ++ indexNames ++ trainableInstance ++ multivariateLabels ++ typeFunctions

makeDatatypes :: Name -> Q [Dec]
makeDatatypes name = fmap (map makeEmptyData) $ extractContructorNames name
    where
        makeEmptyData str = DataD [] (mkName $ nameTransform str) [] [NormalC (mkName $ nameTransform str) []] []

makeIndexNames :: Name -> Q [Dec]
makeIndexNames name = fmap (map makeIndexName . zip [0..]) $ extractContructorNames name
    where
    makeIndexName (i,str) = InstanceD [] (AppT (ConT $ mkName "TypeLens") (ConT $ mkName $ nameTransform str)) 
        [ TySynInstD (mkName "TypeLensIndex") [ConT $ mkName $ nameTransform str] (AppT (ConT $ mkName "Nat1Box") (typeNat i))
        ]
        where
            typeNat 0 = ConT $ mkName "Zero"
            typeNat n = AppT (ConT $ mkName "Succ") $ typeNat (n-1)

makeTypeFunctions :: Name -> Q [Dec]
makeTypeFunctions constructorName = fmap (map makeTypeFunction) $ extractConstructorFields constructorName
    where
        makeTypeFunction (recordName,_,recordType) = InstanceD [] (AppT (ConT $ mkName "TypeFunction") (ConT $ nameTransform' recordName)) 
            [ TySynInstD (mkName "Domain") [ConT $ nameTransform' recordName] (ConT constructorName)
            , TySynInstD (mkName "Range") [ConT $ nameTransform' recordName] (SigT recordType StarT)
            , FunD (mkName "typefunc") [Clause [VarP $ mkName "_"{-, VarP $ mkName "domain"-}] (NormalB $ VarE recordName) []]
            ]

makeTrainable :: Name -> Q [Dec]
makeTrainable name = do
    hlistType <- extractHListType name
    hlistExp <- extractHListExp (mkName "var") name
    return $ [InstanceD [] (AppT (ConT (mkName "Trainable")) (ConT name)) 
        [ TySynInstD (mkName "GetHList") [ConT name] (AppT (ConT $ mkName "HList") hlistType)
        , FunD (mkName "getHList") [Clause [VarP $ mkName "var"] (NormalB hlistExp) []]
        ]]

makeMultivariateLabels :: Name -> Q [Dec]
makeMultivariateLabels name = do
    labelL <- extractContructorNames name
    return $ [ InstanceD [] (AppT (ConT (mkName "MultivariateLabels")) (ConT name)) 
        [ FunD (mkName "getLabels") [Clause [VarP $ mkName "dist"] (NormalB $ go labelL ) []]
        ]]
        where
            go [] = ConE $ mkName "[]"
            go (x:xs) = AppE (AppE (ConE $ mkName ":") (LitE $ StringL (nameTransform x))) $ go xs

---------------------------------------

extractHListType :: Name -> Q Type
extractHListType name = do
    typeL <- fmap (map getType) $ extractConstructorFields name
    return $ go typeL
    where
        go [] = ConT $ mkName "[]"
        go (x:xs) = AppT (AppT (ConT $ mkName ":") (x)) $ go xs
            
        getType (n,s,t) = t

extractHListExp :: Name -> Name -> Q Exp
extractHListExp var name = do
    typeL <- fmap (map getName) $ extractConstructorFields name
    return $ go typeL
    where
        go [] = ConE $ mkName "HNil"
        go (x:xs) = AppE (AppE (ConE $ mkName ":::") (AppE (VarE x) (VarE var))) $ go xs
            
        getName (n,s,t) = n

-------------------------------------------------------------------------------
-- below taken from Data.Lens 

type ConstructorFieldInfo = (Name, Strict, Type)

extractContructorNames :: Name -> Q [String]
extractContructorNames datatype = fmap (map name) $ extractConstructorFields datatype
    where
        name (n,s,t) = nameBase n

extractConstructorFields :: Name -> Q [ConstructorFieldInfo]
extractConstructorFields datatype = do
  let datatypeStr = nameBase datatype
  i <- reify datatype
  return $ case i of
    TyConI (DataD    _ _ _ [RecC _ fs] _) -> fs
    TyConI (NewtypeD _ _ _ (RecC _ fs) _) -> fs
    TyConI (DataD    _ _ _ [_]         _) -> error $ "Can't derive Lens without record selectors: " ++ datatypeStr
    TyConI NewtypeD{} -> error $ "Can't derive Lens without record selectors: " ++ datatypeStr
    TyConI TySynD{}   -> error $ "Can't derive Lens for type synonym: " ++ datatypeStr
    TyConI DataD{}    -> error $ "Can't derive Lens for tagged union: " ++ datatypeStr
    _                 -> error $ "Can't derive Lens for: "  ++ datatypeStr ++ ", type name required."

extractTypeInfo name = do
    i <- reify name
    return $ case i of
        TyConI (DataD    _ n ts _ _) -> (n, ts)
        TyConI (NewtypeD _ n ts _ _) -> (n, ts)