module Contravariant.Extras.TH (
    opContrazipDecs,
    contrazipDecs,
    contrazipExp,
  ) where

import Contravariant.Extras.Prelude
import Data.Functor.Contravariant
import Data.Functor.Contravariant.Divisible
import Language.Haskell.TH.Syntax hiding (classP)
import qualified TemplateHaskell.Compat.V0208 as Compat


{-|
Generates declarations in the spirit of the following:

@
tuple3 :: Monoid a => Op a b1 -> Op a b2 -> Op a b3 -> Op a ( b1 , b2 , b3 )
tuple3 ( Op op1 ) ( Op op2 ) ( Op op3 ) =
  Op $ \\( v1 , v2 , v3 ) -> mconcat [ op1 v1 , op2 v2 , op3 v3 ]
@

The function is named by appending the arity to the base name, so
@opContrazipDecs "tuple" 3@ produces @tuple3@. The result holds exactly two
declarations: the type signature, followed by the single-clause definition.
-}
opContrazipDecs :: String -> Int -> [Dec]
opContrazipDecs baseName arity = [signature, value]
  where
    name :: Name
    name = mkName (showString baseName (show arity))

    -- 1-based positions of the tuple components and of the Op arguments.
    indices :: [Int]
    indices = enumFromTo 1 arity

    -- indexed "op" 2 is the name op2.
    indexed :: String -> Int -> Name
    indexed prefix index = mkName (showString prefix (show index))

    -- The monoid that every Op argument produces.
    aType :: Type
    aType = VarT (mkName "a")

    -- Wraps a type t as Op a t.
    opOf :: Type -> Type
    opOf = AppT (AppT (ConT ''Op) aType)

    -- forall a b1 .. bN. Monoid a => Op a b1 -> .. -> Op a bN -> Op a (b1, .., bN)
    signature :: Dec
    signature = SigD name (ForallT vars [Compat.classP ''Monoid [aType]] type_)
      where
        bNames = map (indexed "b") indices
        -- No local signature here: the binder type differs between
        -- template-haskell versions, which Compat.specifiedPlainTV hides.
        vars = map Compat.specifiedPlainTV (mkName "a" : bNames)
        bTypes = map VarT bNames
        tuple = foldl AppT (TupleT arity) bTypes
        type_ =
          foldr
            (\param rest -> AppT (AppT ArrowT param) rest)
            (opOf tuple)
            (map opOf bTypes)

    -- name (Op op1) .. (Op opN) = Op $ \(v1, .., vN) -> mconcat [op1 v1, .., opN vN]
    value :: Dec
    value = FunD name [Clause opPats (NormalB body) []]
      where
        opPats = [Compat.conp 'Op [VarP (indexed "op" index)] | index <- indices]
        tuplePat = TupP [VarP (indexed "v" index) | index <- indices]
        applied =
          [ AppE (VarE (indexed "op" index)) (VarE (indexed "v" index))
          | index <- indices
          ]
        body = AppE (ConE 'Op) (LamE [tuplePat] (AppE (VarE 'mconcat) (ListE applied)))

{-|
Generates declarations in the spirit of the following:

@
contrazip4 :: Divisible f => f a1 -> f a2 -> f a3 -> f a4 -> f ( a1 , a2 , a3 , a4 )
contrazip4 f1 f2 f3 f4 =
  divide $(TupleTH.splitTupleAt 4 1) f1 $
  divide $(TupleTH.splitTupleAt 3 1) f2 $
  divide $(TupleTH.splitTupleAt 2 1) f3 $
  f4
@

The generated code does not depend on @tuple-th@. Each splitter is emitted as
an inline lambda, in effect @\\(_0, _1, _2, _3) -> (_0, (_1, _2, _3))@. The
right-hand side of the definition is the type-annotated lambda built by
'contrazipExp'.

The function is named by appending the arity to the base name. The result
holds exactly two declarations: the type signature, then the definition.
-}
contrazipDecs :: String -> Int -> [Dec]
contrazipDecs baseName arity = [signature, value]
  where
    name :: Name
    name = mkName (showString baseName (show arity))

    signature :: Dec
    signature = SigD name (contrazipType arity)

    -- The definition binds no parameters itself; its whole body is the
    -- lambda expression produced by contrazipExp.
    value :: Dec
    value = FunD name [Clause [] (NormalB (contrazipExp arity)) []]

-- | The type of a contrazip function of the given arity:
--
-- @
-- forall f a1 .. aN. Divisible f => f a1 -> .. -> f aN -> f (a1, .., aN)
-- @
contrazipType :: Int -> Type
contrazipType arity = ForallT vars cxt type_
  where
    fName :: Name
    fName = mkName "f"

    fType :: Type
    fType = VarT fName

    -- a1 .. aN, one per tuple component.
    aNames :: [Name]
    aNames = [mkName (showString "a" (show index)) | index <- enumFromTo 1 arity]

    -- No local signature: the binder type differs between template-haskell
    -- versions, which Compat.specifiedPlainTV hides.
    vars = map Compat.specifiedPlainTV (fName : aNames)

    cxt :: Cxt
    cxt = [Compat.classP ''Divisible [fType]]

    -- f a1, .., f aN
    params :: [Type]
    params = [AppT fType (VarT aName) | aName <- aNames]

    -- f (a1, .., aN)
    result :: Type
    result = AppT fType (foldl AppT (TupleT arity) (map VarT aNames))

    type_ :: Type
    type_ = foldr (\param rest -> AppT (AppT ArrowT param) rest) result params

{-|
Contrazip lambda expression of specified arity.

Allows to create contrazip expressions of any arity:

>>> :t $(return (contrazipExp 2))
$(return (contrazipExp 2))
  :: Data.Functor.Contravariant.Divisible.Divisible f =>
     f a1 -> f a2 -> f (a1, a2)

For arity @0@ the expression is 'conquer' at type @f ()@. A negative arity is
rejected with 'error'.
-}
contrazipExp :: Int -> Exp
contrazipExp arity
  | arity < 0 =
      error (showString "Contravariant.Extras.TH.contrazipExp: negative arity " (show arity))
  | arity == 0 =
      -- Nothing to split: the unit value is consumed by conquer. Without this
      -- case the recursion below never reaches its base case and diverges.
      SigE (VarE 'conquer) (contrazipType 0)
  | otherwise =
      SigE (LamE pats (go arity)) (contrazipType arity)
  where
    -- f1 .. fN name the lambda's parameters.
    fName :: Int -> Name
    fName index = mkName (showString "f" (show index))

    pats :: [Pat]
    pats = map (VarP . fName) (enumFromTo 1 arity)

    -- go n combines the last n parameters, f(arity-n+1) .. f(arity).
    -- It splits the head off an n-tuple for the first of them and recurses on
    -- the remaining (n-1)-tuple. Only called with 1 <= n <= arity.
    go :: Int -> Exp
    go 1 = VarE (fName arity)
    go n =
      foldl1
        AppE
        [ VarE 'divide
        , splitTupleAtExp n 1
        , VarE (fName (arity - n + 1))
        , go (pred n)
        ]

-- | A lambda that takes an @arity@-tuple and splits it into a pair of tuples
-- after @position@ components.
--
-- For example, @splitTupleAtExp 3 1@ builds a lambda equivalent to
-- @\\(_0, _1, _2) -> (_0, (_1, _2))@. The split follows 'splitAt', so a
-- @position@ beyond @arity@ puts every component in the first tuple.
splitTupleAtExp :: Int -> Int -> Exp
splitTupleAtExp arity position =
  LamE [TupP (map VarP names)] (Compat.tupE [Compat.tupE front, Compat.tupE back])
  where
    -- _0 .. _(arity-1), bound by the tuple pattern.
    names :: [Name]
    names = [mkName ('_' : show index) | index <- enumFromTo 0 (pred arity)]

    front, back :: [Exp]
    (front, back) = splitAt position (map VarE names)