{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE TemplateHaskell            #-}

module Data.API.Tools.Traversal
    ( traversalTool
    , traversalsTool
    ) where

import           Data.API.NormalForm
import           Data.API.Tools.Combinators
import           Data.API.Tools.Datatypes
import           Data.API.TH
import           Data.API.Types

import           Control.Applicative
import qualified Data.Map                       as Map
import           Data.Maybe
import           Data.Monoid
import qualified Data.Set                       as Set
import qualified Data.Text                      as T
import           Data.Traversable
import           Language.Haskell.TH
import           Prelude


-- | Build a traversal of the root type (first argument) that updates
-- values of the second type, e.g. @traversalTool "Root" "Sub"@
-- produces
--
-- > traverseSubRoot :: Applicative f => (Sub -> f Sub) -> Root -> f Root
--
-- along with similar functions for all the types nested inside @Root@
-- that depend on @Sub@.  This is simply 'traversalsTool' applied to a
-- one-element list of roots.
--
-- Note that types with custom representations will not have
-- traversals generated automatically: if required, these must be
-- defined manually in the same module as the call to 'traversalTool',
-- otherwise the generated code will lead to scope errors.
traversalTool :: TypeName -> TypeName -> APITool
traversalTool = traversalsTool . (: [])

-- | Like 'traversalTool', but takes a list of \"roots\".
--
-- Generating the traversals for all roots in one call avoids the
-- conflicting declarations that separate 'traversalTool' calls would
-- produce for nested types reachable from more than one root.
traversalsTool :: [TypeName] -> TypeName -> APITool
traversalsTool roots x = readTool (apiNodeTool . perApi)
  where
    rootSet, xSet :: Set.Set TypeName
    rootSet = Set.fromList roots
    xSet    = Set.singleton x

    -- Only records and unions get a generator; the newtype, enum and
    -- synonym slots of 'apiSpecTool' are filled with 'mempty'.
    perApi :: API -> Tool APINode
    perApi api = apiSpecTool mempty onRecord onUnion mempty mempty
      where
        napi :: NormAPI
        napi = apiNormalForm api

        onRecord :: Tool (APINode, SpecRecord)
        onRecord = simpleTool $ \ (an, sr) -> traversalRecord napi targets x an sr

        onUnion :: Tool (APINode, SpecUnion)
        onUnion = simpleTool $ \ (an, su) -> traversalUnion napi targets x an su

        -- The types for which we must provide traversals: those the
        -- roots depend on (or the roots themselves) that also depend on
        -- the traversed type (or are the traversed type itself).
        targets :: Set.Set TypeName
        targets = reachableFromRoots `Set.intersection` reachingX

        reachableFromRoots, reachingX :: Set.Set TypeName
        reachableFromRoots = transitiveDeps napi rootSet `Set.union` rootSet
        reachingX          = transitiveReverseDeps napi xSet `Set.union` xSet


-- | @traversalName x tn@ is the name of the generated function that
-- traverses @x@ values inside @tn@.
--
-- The name is the prefix @traverse@ followed by both type names, so
-- @traversalName "Sub" "Root"@ is @traverseSubRoot@.
traversalName :: TypeName -> TypeName -> Name
traversalName x tn = mkNameText (T.concat ["traverse", _TypeName x, _TypeName tn])

-- | @traversalType x an@ is the type of the generated function that
-- traverses @x@ values inside the type described by @an@:
--
-- > forall f . Applicative f => (x -> f x) -> an -> f an
traversalType :: TypeName -> APINode -> TypeQ
traversalType x an =
    [t| forall f . Applicative f => ($focusT -> f $focusT) -> $wholeT -> f $wholeT |]
  where
    -- The traversed (focus) type, referenced by its bare name.
    focusT :: TypeQ
    focusT = conT (mkNameText (_TypeName x))

    -- The enclosing type being traversed.
    wholeT :: TypeQ
    wholeT = nodeT an


-- | Construct a traversal of the X substructures of the given type.
--
-- When 'traverser'' finds nothing to traverse, this falls back to
-- @const pure@, which ignores the update function and returns the value
-- unchanged.
traverser :: NormAPI -> Set.Set TypeName -> TypeName -> APIType -> ExpQ
traverser napi targets x = maybe [| const pure |] id . traverser' napi targets x

-- | Construct a traversal of the X substructures of the given type,
-- or return 'Nothing' if there are no substructures to traverse.
--
-- * Lists and 'Maybe's lift the element traversal @t@ to @traverse . t@.
-- * The traversed type itself is traversed by @id@, i.e. the update
--   function is applied directly.
-- * Records and unions in the target set use their generated traversal
--   function (see 'traversalName').
-- * Synonyms are expanded.
-- * Enums, newtypes, basic types, JSON values and types outside the
--   target set yield 'Nothing'.
--
-- Calls 'error' if a named type in the target set has no declaration in
-- the normalised API.
traverser' :: NormAPI -> Set.Set TypeName -> TypeName -> APIType -> Maybe ExpQ
traverser' napi targets x = go
  where
    go :: APIType -> Maybe ExpQ
    go (TyList  ty) = lifted <$> go ty
    go (TyMaybe ty) = lifted <$> go ty
    go (TyName  tn) = named tn
    go (TyBasic _)  = Nothing
    go TyJSON       = Nothing

    -- Push an element traversal under a Traversable container.
    lifted :: ExpQ -> ExpQ
    lifted = appE [e|(.) traverse|]

    named :: TypeName -> Maybe ExpQ
    named tn
      | tn == x                       = Just [e| id |]
      | not (tn `Set.member` targets) = Nothing
      | otherwise =
          case Map.lookup tn napi of
            Nothing                -> error $ "missing API type declaration: " ++ T.unpack (_TypeName tn)
            Just (NTypeSynonym ty) -> go ty
            Just (NRecordType  _)  -> generated
            Just (NUnionType   _)  -> generated
            Just (NEnumType    _)  -> Nothing
            Just (NNewtype     _)  -> Nothing
      where
        generated = Just (varE (traversalName x tn))


-- | Build a traversal for a record type that applies @f@ to any fields
-- of type X, and traverses nested structures.  For example, for a record
-- @Foo@ with fields @foo_a :: X@ and @foo_b :: [Bar]@:
--
-- > traverseXFoo :: Applicative f => (X -> f X) -> Foo -> f Foo
-- > traverseXFoo f r = Foo <$> f (foo_a r) <*> traverse (traverseXBar f) (foo_b r)
--
-- Fields that cannot contain an X are traversed with @const pure@, so
-- they are rebuilt unchanged.
--
-- No declaration is produced for types outside the target set.  None is
-- produced either for types with a custom representation ('anConvert');
-- their traversals must be written by hand.
traversalRecord :: NormAPI -> Set.Set TypeName -> TypeName -> APINode -> SpecRecord -> Q [Dec]
traversalRecord napi targets x an sr
  | not (anName an `Set.member` targets) = return []
  | isJust (anConvert an)                = return []
  | otherwise                            = simpleSigD nom (traversalType x an) bdy
  where
    nom :: Name
    nom = traversalName x (anName an)

    -- \f r -> Con <$> t1 f (field1 r) <*> t2 f (field2 r) <*> ...
    bdy :: ExpQ
    bdy = do
      f <- newName "f"
      r <- newName "r"
      lamE [varP f, varP r] $ applicativeE (nodeConE an) $ map (traverseField f r) (srFields sr)

    -- Traverse a single field: apply the field's traversal, partially
    -- applied to f, to the field's value projected out of r.
    traverseField :: Name -> Name -> (FieldName, FieldType) -> ExpQ
    traverseField f r (fn, fty) =
      [e| $(traverser napi targets x (ftType fty)) $(varE f) ($(nodeFieldE an fn) $(varE r)) |]


-- | Build a traversal for a union type that traverses nested structures.
-- For example:
--
-- > traverseXBar :: Applicative f => (X -> f X) -> Bar -> f Bar
-- > traverseXBar f (BAR_one a) = BAR_one <$> f a
-- > traverseXBar f (BAR_two b) = BAR_two <$> traverseXBaz f b
--
-- One clause is generated per alternative.  As with 'traversalRecord',
-- no declaration is produced for types outside the target set.  None is
-- produced either for types with a custom representation ('anConvert');
-- their traversals must be written by hand.
traversalUnion :: NormAPI -> Set.Set TypeName -> TypeName -> APINode -> SpecUnion -> Q [Dec]
traversalUnion napi targets x an su
  | not (anName an `Set.member` targets) = return []
  | isJust (anConvert an)                = return []
  | otherwise                            = funSigD nom (traversalType x an) cls
  where
    nom :: Name
    nom = traversalName x (anName an)

    cls :: [ClauseQ]
    cls = map cl (suFields su)

    -- The clause for one alternative: match its constructor, traverse the
    -- payload and rebuild with the same constructor.
    cl :: (FieldName, (APIType, b)) -> ClauseQ
    cl (fn, (ty, _)) = do
      f <- newName "f"
      z <- newName "z"
      clause [varP f, nodeAltConP an fn [varP z]] (normalB (bdy fn ty f z)) []

    bdy :: FieldName -> APIType -> Name -> Name -> ExpQ
    bdy fn ty f z =
      [e| $(nodeAltConE an fn) <$> $(traverser napi targets x ty) $(varE f) $(varE z) |]