{-# LANGUAGE TemplateHaskell #-}
module Data.Profunctor.Product.Tuples.TH
  ( mkTs
  , pTns
  , mkFlattenNs
  , mkUnflattenNs
  , pNs
  , mkDefaultNs
  , maxTupleSize
  ) where

import Language.Haskell.TH

import Data.Profunctor (Profunctor (dimap))
import Data.Profunctor.Product.Class (ProductProfunctor, (***!), empty)
import Data.Profunctor.Product.Default.Class (Default (def))
import Control.Applicative (pure)

-- | Generate one @Tn@ type-synonym declaration (see 'mkT') for every size in
-- the given list, in the same order.
mkTs :: [Int] -> Q [Dec]
mkTs sizes = sequence (map mkT sizes)

-- | Build the declaration of the type synonym @Tn@ over @n@ type variables.
--
-- * @T0@ is the unit type.
-- * @T1 a@ is just @a@.
-- * For larger @n@, @Tn@ is a pair of the first variable and
--   @T(n-1)@ applied to the remaining variables.
--
-- Type variables are named @a0, b0, ..., z0, a1, b1, ...@.
mkT :: Int -> Q Dec
mkT n = tySynD (synName n) (map PlainTV params) rhs
  where
    synName k = mkName ('T' : show k)
    params = take n varNames
    rhs = case n of
      0 -> tupleT 0
      1 -> varT firstVar
      _ -> tupleT 2 `appT` varT firstVar `appT` nested (n - 1)
    -- The synonym of size k applied to the k variables following the first.
    nested k = foldl appT (conT (synName k)) (map varT (take k restVars))
    -- varNames is infinite, so this lazy pattern always succeeds when forced.
    (firstVar : restVars) = varNames
    varNames = [ mkName (letter : show idx) | idx <- [0 :: Int ..], letter <- ['a' .. 'z'] ]

-- | Combine the profunctor in the first component of a pair with the result
-- of running @rest@ on the second component, joining the two with '***!'.
chain :: ProductProfunctor p => (t -> p a2 b2) -> (p a1 b1, t)
      -> p (a1, a2) (b1, b2)
chain rest (a, as) = a ***! rest as

-- | Generate the @pTn@ signature and definition (see 'pTn') for every size in
-- the given list and concatenate the resulting declarations.
pTns :: [Int] -> Q [Dec]
pTns sizes = do
  groups <- mapM pTn sizes
  return (concat groups)

-- | The constraint @ProductProfunctor p@ for the type variable named @p@.
productProfunctor :: Name -> Q Pred
productProfunctor p = classP ''ProductProfunctor [varT p]

-- | The constraint @Default p a b@ for the three given type-variable names.
default_ :: Name -> Name -> Name -> Q Pred
default_ p a b = classP ''Default [varT p, varT a, varT b]

-- | Generate the type signature and definition of @pTn@.
--
-- The generated signature is
--
-- > pTn :: ProductProfunctor p
-- >     => Tn (p a0 b0) ... (p a(n-1) b(n-1)) -> p (Tn a0 ...) (Tn b0 ...)
--
-- The body is @const empty@ for 0, @id@ for 1, @uncurry (***!)@ for 2, and
-- @chain pT(n-1)@ for every other size.
pTn :: Int -> Q [Dec]
pTn n = sequence [signature, definition]
  where
    prof = mkName "p"
    adaptorName k = mkName ("pT" ++ show k)
    funName = adaptorName n
    ins = take n [ mkName ('a' : show i) | i <- [0 :: Int ..] ]
    outs = take n [ mkName ('b' : show i) | i <- [0 :: Int ..] ]
    -- The synonym @Tn@ applied to the given argument types.
    applySyn args = foldl appT (conT (mkName ('T' : show n))) args
    component a b = varT prof `appT` varT a `appT` varT b
    argTy = applySyn (zipWith component ins outs)
    resTy = varT prof `appT` applySyn (map varT ins)
                      `appT` applySyn (map varT outs)
    signature = sigD funName
      (forallT (map PlainTV (prof : ins ++ outs))
               (sequence [productProfunctor prof])
               (arrowT `appT` argTy `appT` resTy))
    definition = funD funName [ clause [] (normalB body) [] ]
    body
      | n == 0 = [| const empty |]
      | n == 1 = [| id |]
      | n == 2 = [| uncurry (***!) |]
      | otherwise = varE 'chain `appE` varE (adaptorName (n - 1))

-- | Generate @flattenN@ (see 'mkFlattenN') for every size in the given list
-- and concatenate the resulting declarations.
mkFlattenNs :: [Int] -> Q [Dec]
mkFlattenNs sizes = fmap concat (mapM mkFlattenN sizes)

-- | Generate the signature and definition of @flattenN@, which converts a
-- right-nested chain of pairs @(a, (b, (c, ...)))@ into the flat @n@-tuple
-- @(a, b, c, ...)@.
--
-- For zero names both sides are unit; for one name both sides are the bare
-- variable.
mkFlattenN :: Int -> Q [Dec]
mkFlattenN n = sequence [signature, definition]
  where
    funName = mkName ("flatten" ++ show n)
    vars = take n [ mkName (letter : show idx) | idx <- [0 :: Int ..], letter <- ['a' .. 'z'] ]
    signature = sigD funName
      (forallT (map PlainTV vars) (pure [])
               (arrowT `appT` nestedTy `appT` flatTy))
    definition = funD funName [ clause [nestedPat] (normalB flatExp) [] ]
    -- Right-nested pairs: argument type and the pattern that matches it.
    nestedTy = rightNest (tupleT 0) (\l r -> tupleT 2 `appT` l `appT` r) (map varT vars)
    nestedPat = rightNest (tupP []) (\l r -> tupP [l, r]) (map varP vars)
    -- Flat tuple: result type and the expression that builds it.
    flatTy = case vars of
      [v] -> varT v
      _ -> foldl appT (tupleT (length vars)) (map varT vars)
    flatExp = case vars of
      [v] -> varE v
      _ -> tupE (map varE vars)
    -- Fold a list into right-nested pairs; an empty list yields @unit@ and a
    -- single element is returned unchanged.
    rightNest unit _ [] = unit
    rightNest _ pair xs = foldr1 pair xs

-- | Generate @unflattenN@ (see 'mkUnflattenN') for every size in the given
-- list and concatenate the resulting declarations.
mkUnflattenNs :: [Int] -> Q [Dec]
mkUnflattenNs sizes = do
  groups <- mapM mkUnflattenN sizes
  return (concat groups)

-- | Generate the signature and definition of @unflattenN@, which converts a
-- flat @n@-tuple @(a, b, c, ...)@ into the right-nested chain of pairs
-- @(a, (b, (c, ...)))@.
--
-- For zero names both sides are unit; for one name both sides are the bare
-- variable.
mkUnflattenN :: Int -> Q [Dec]
mkUnflattenN n = sequence [signature, definition]
  where
    funName = mkName ("unflatten" ++ show n)
    vars = take n [ mkName (letter : show idx) | idx <- [0 :: Int ..], letter <- ['a' .. 'z'] ]
    signature = sigD funName
      (forallT (map PlainTV vars) (pure [])
               (arrowT `appT` flatTy `appT` nestedTy))
    definition = funD funName [ clause [flatPat] (normalB nestedExp) [] ]
    -- Flat tuple: argument type and the pattern that matches it.
    flatTy = case vars of
      [v] -> varT v
      _ -> foldl appT (tupleT (length vars)) (map varT vars)
    flatPat = case vars of
      [v] -> varP v
      _ -> tupP (map varP vars)
    -- Right-nested pairs: result type and the expression that builds it.
    nestedTy = rightNest (tupleT 0) (\l r -> tupleT 2 `appT` l `appT` r) (map varT vars)
    nestedExp = rightNest (tupE []) (\l r -> tupE [l, r]) (map varE vars)
    -- Fold a list into right-nested pairs; an empty list yields @unit@ and a
    -- single element is returned unchanged.
    rightNest unit _ [] = unit
    rightNest _ pair xs = foldr1 pair xs

-- | Generate the @pN@ signature and definition (see 'pN') for every size in
-- the given list and concatenate the resulting declarations.
pNs :: [Int] -> Q [Dec]
pNs sizes = fmap concat (mapM pN sizes)

-- | Generate the type signature and definition of @pN@, the adaptor that
-- turns a flat @n@-tuple of profunctors into a profunctor on flat @n@-tuples:
--
-- > pN :: ProductProfunctor p
-- >    => (p a0 b0, ..., p a(n-1) b(n-1)) -> p (a0, ...) (b0, ...)
--
-- The body is @convert unflattenN unflattenN flattenN pTN@; those four names
-- are referenced by name and must be in scope where the declarations are
-- spliced.
pN :: Int -> Q [Dec]
pN n = sequence [sig, fun]
  where
    sig = sigD nm (forallT (map PlainTV $ p : as ++ bs)
                           (sequence [productProfunctor p])
                           (arrowT `appT` mkLeftTy `appT` mkRightTy)
                   )
    mkLeftTy = mkTupT (zipWith mkPT as bs)
    mkRightTy = varT p `appT` mkTupT (map varT as) `appT` mkTupT (map varT bs)
    -- A single component is used bare instead of being wrapped in
    -- @tupleT 1@.  Previously only the left-hand type was special-cased for
    -- n == 1; treating both sides the same way keeps the signature in line
    -- with the bare-variable types that mkFlattenN / mkUnflattenN produce for
    -- one name, without depending on how a unary 'TupleT' is interpreted.
    mkTupT [t] = t
    mkTupT ts = foldl appT (tupleT n) ts
    mkPT a b = varT p `appT` varT a `appT` varT b
    fun = funD nm [ clause [] (normalB bdy) [] ]
    bdy = varE 'convert `appE` unflat `appE` unflat `appE` flat `appE` pT
    unflat = varE $ mkName unflatNm
    flat = varE $ mkName flatNm
    pT = varE $ mkName pTNm
    unflatNm = "unflatten" ++ show n
    flatNm = "flatten" ++ show n
    pTNm = "pT" ++ show n
    nm = mkName ('p':show n)
    p = mkName "p"
    as = take n [ mkName $ 'a':show i | i <- [0::Int ..] ]
    bs = take n [ mkName $ 'b':show i | i <- [0::Int ..] ]

-- | Adapt a profunctor builder to different input, output and container
-- shapes: the container is first passed through @regroup@, the builder is run
-- on the result, and the resulting profunctor is then 'dimap'ped with
-- @contra@ on its input side and @co@ on its output side.
convert :: Profunctor p => (a2 -> a1) -> (tp -> tTp) -> (b1 -> b2)
                           -> (tTp -> p a1 b1)
                           -> tp -> p a2 b2
convert contra regroup co build = \container ->
  dimap contra co (build (regroup container))

-- | Generate one @Default@ instance declaration (see 'mkDefaultN') for every
-- size in the given list, in the same order.
mkDefaultNs :: [Int] -> Q [Dec]
mkDefaultNs sizes = sequence [ mkDefaultN size | size <- sizes ]

-- | Generate the instance
--
-- > instance (ProductProfunctor p, Default p a0 b0, ..., Default p a(n-1) b(n-1))
-- >       => Default p (a0, ..., a(n-1)) (b0, ..., b(n-1))
--
-- whose @def@ is @empty@ for size 0 and @pN (def, ..., def)@ otherwise. The
-- function @pN@ is referenced by name and must be in scope at the splice site.
--
-- NOTE(review): size 1 is not special-cased here; the instance head and the
-- argument to @p1@ are built with a unary 'tupleT' / 'tupE'. Confirm callers
-- do not request size 1.
mkDefaultN :: Int -> Q Dec
mkDefaultN n = instanceD context headTy [defMethod]
  where
    prof = mkName "p"
    ins = take n [ mkName ('a' : show i) | i <- [0 :: Int ..] ]
    outs = take n [ mkName ('b' : show i) | i <- [0 :: Int ..] ]
    context = sequence (productProfunctor prof : zipWith (default_ prof) ins outs)
    tupleOf vs = foldl appT (tupleT n) (map varT vs)
    headTy = conT ''Default `appT` varT prof `appT` tupleOf ins `appT` tupleOf outs
    defMethod = funD 'def [clause [] (normalB body) []]
    body
      | n == 0 = varE 'empty
      | otherwise = varE (mkName ('p' : show n)) `appE` tupE (replicate n (varE 'def))

-- | Largest tuple size this module advertises. None of the generators in
-- this module read it; it is exported for callers that build the size lists.
--
-- NOTE(review): 62 presumably mirrors the compiler's tuple-size limit —
-- confirm against the GHC version in use.
maxTupleSize :: Int
maxTupleSize = 62