-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Retrie.Universe
  ( Universe
  , printU
  , Matchable(..)
  , UMap(..)
  ) where

import Control.Monad
import Data.Data

import Retrie.AlphaEnv
import Retrie.ExactPrint
import Retrie.GHC
import Retrie.PatternMap.Class
import Retrie.PatternMap.Instances
import Retrie.Quantifiers
import Retrie.Substitution

-- | A sum type to collect all possible top-level rewritable types.
--
-- Each constructor wraps one kind of parsed ('GhcPs') AST fragment.
-- 'UMap' is the pattern map keyed on this type. 'Data' is derived so that
-- generic utilities such as 'showAst' (used by 'printU') can walk the
-- wrapped AST.
data Universe
  = ULHsExpr (LHsExpr GhcPs)
    -- ^ An expression.
  | ULStmt (LStmt GhcPs (LHsExpr GhcPs))
    -- ^ A statement whose body is an expression.
  | ULType (LHsType GhcPs)
    -- ^ A type.
  | ULPat (LPat GhcPs)
    -- ^ A pattern.
  deriving (Data)

-- | Exactprint an annotated 'Universe'.
--
-- The result is 'exactPrintU' applied to the wrapped AST. A 'showAst' dump
-- of the same AST is handed to 'debug'. Presumably that dump is only
-- emitted when exactprint debug tracing is enabled; confirm against the
-- 'debug' definition.
printU :: Annotated Universe -> String
printU u = exactPrintU ast `debug` ("printU:" ++ showAst ast)
  where
    -- Shared between the printed result and the debug dump.
    ast = astA u

-- | Primitive exactprint for 'Universe'.
--
-- Dispatches to 'exactPrint' on whichever AST fragment is wrapped.
exactPrintU :: Universe -> String
exactPrintU (ULHsExpr e) = exactPrint e
exactPrintU (ULStmt s)   = exactPrint s
exactPrintU (ULType t)   = exactPrint t
exactPrintU (ULPat p)    = exactPrint p

-------------------------------------------------------------------------------

-- | Types that can be wrapped into, and unwrapped from, a 'Universe'.
class Matchable ast where
  -- | Wrap an AST in its 'Universe' constructor.
  inject :: ast -> Universe

  -- | Unwrap an AST from a 'Universe'.
  --
  -- This is partial: it can fail when the 'Universe' holds a different
  -- kind of AST than @ast@.
  project :: Universe -> ast

  -- | The source span the AST originally came from.
  getOrigin :: ast -> SrcSpan

instance Matchable Universe where
  -- A 'Universe' is trivially its own injection and projection.
  inject = id
  project = id

  -- Delegate to the 'Matchable' instance of whichever fragment is wrapped.
  getOrigin (ULHsExpr e) = getOrigin e
  getOrigin (ULStmt s)   = getOrigin s
  getOrigin (ULType t)   = getOrigin t
  getOrigin (ULPat p)    = getOrigin p

instance Matchable (LocatedA (HsExpr GhcPs)) where
  inject = ULHsExpr

  -- Partial: any non-expression 'Universe' is a caller error.
  project (ULHsExpr x) = x
  project _ = error "project LHsExpr"

  getOrigin = getLocA

instance Matchable (LocatedA (Stmt GhcPs (LocatedA (HsExpr GhcPs)))) where
  inject = ULStmt

  -- Partial: any non-statement 'Universe' is a caller error.
  project (ULStmt x) = x
  project _ = error "project LStmt"

  getOrigin = getLocA

instance Matchable (LocatedA (HsType GhcPs)) where
  inject = ULType

  -- Partial: any non-type 'Universe' is a caller error.
  project (ULType t) = t
  project _ = error "project ULType"

  getOrigin = getLocA

instance Matchable (LocatedA (Pat GhcPs)) where
  inject = ULPat

  -- Partial: any non-pattern 'Universe' is a caller error.
  project (ULPat p) = p
  project _ = error "project ULPat"

  getOrigin = getLocA

-------------------------------------------------------------------------------

-- | The pattern map for 'Universe'.
--
-- There is one sub-map per 'Universe' constructor. A key is stored in, and
-- looked up from, the sub-map that matches its constructor.
data UMap a = UMap
  { umExpr :: EMap a   -- ^ Entries keyed on 'ULHsExpr'.
  , umStmt :: SMap a   -- ^ Entries keyed on 'ULStmt'.
  , umType :: TyMap a  -- ^ Entries keyed on 'ULType'.
  , umPat  :: PatMap a -- ^ Entries keyed on 'ULPat'.
  }
  deriving (Functor)

instance PatternMap UMap where
  type Key UMap = Universe

  -- Every per-constructor sub-map starts out empty.
  mEmpty :: UMap a
  mEmpty = UMap mEmpty mEmpty mEmpty mEmpty

  -- Union is taken field-by-field over the four sub-maps.
  mUnion :: UMap a -> UMap a -> UMap a
  mUnion m1 m2 = UMap
    (unionOn umExpr m1 m2)
    (unionOn umStmt m1 m2)
    (unionOn umType m1 m2)
    (unionOn umPat m1 m2)

  -- Route the key to the sub-map for its constructor and alter only that
  -- field; the other three sub-maps are left untouched.
  mAlter :: AlphaEnv -> Quantifiers -> Universe -> A a -> UMap a -> UMap a
  mAlter env vs u f m = go u
    where
      go (ULHsExpr e) = m { umExpr = mAlter env vs e f (umExpr m) }
      go (ULStmt s)   = m { umStmt = mAlter env vs s f (umStmt m) }
      go (ULType t)   = m { umType = mAlter env vs t f (umType m) }
      -- Patterns are passed through 'cLPat' here and in 'mMatch' alike, so
      -- stored and queried keys agree (presumably a normalisation step;
      -- confirm in its definition).
      go (ULPat p)    = m { umPat  = mAlter env vs (cLPat p) f (umPat m) }

  -- 'mapFor' projects the relevant sub-map out of the
  -- (Substitution, UMap) pair, then matching continues in that sub-map;
  -- the two steps are composed in the list monad with '>=>'.
  mMatch :: MatchEnv -> Universe -> (Substitution, UMap a) -> [(Substitution, a)]
  mMatch env = go
    where
      go (ULHsExpr e) = mapFor umExpr >=> mMatch env e
      go (ULStmt s)   = mapFor umStmt >=> mMatch env s
      go (ULType t)   = mapFor umType >=> mMatch env t
      go (ULPat p)    = mapFor umPat >=> mMatch env (cLPat p)