{-# LANGUAGE GADTs #-}
module Agda.Syntax.Concrete.Operators
    ( parseApplication
    , parseModuleApplication
    , parseLHS
    , parsePattern
    , parsePatternSyn
    ) where
import Control.Applicative ( Alternative((<|>)))
import Control.Arrow (second)
import Control.Monad
import Data.Either (partitionEithers)
import qualified Data.Foldable as Fold
import Data.Function
import qualified Data.List as List
import Data.Maybe
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Traversable (traverse)
import qualified Data.Traversable as Trav
import Agda.Syntax.Common
import Agda.Syntax.Concrete hiding (appView)
import Agda.Syntax.Concrete.Operators.Parser
import Agda.Syntax.Concrete.Operators.Parser.Monad hiding (parse)
import Agda.Syntax.Concrete.Pattern
import qualified Agda.Syntax.Abstract.Name as A
import Agda.Syntax.Position
import Agda.Syntax.Fixity
import Agda.Syntax.Notation
import Agda.Syntax.Scope.Base
import Agda.Syntax.Scope.Monad
import Agda.TypeChecking.Monad.Base (typeError, TypeError(..), LHSOrPatSyn(..))
import qualified Agda.TypeChecking.Monad.Benchmark as Bench
import Agda.TypeChecking.Monad.Debug
import Agda.TypeChecking.Monad.State (getScope)
import Agda.Utils.Either
import Agda.Utils.Pretty
import Agda.Utils.List
import Agda.Utils.Trie (Trie)
import qualified Agda.Utils.Trie as Trie
import Agda.Utils.Impossible
-- | Bill the given scope computation to the benchmark accounts
-- @['Bench.Parsing', sub]@, where @sub@ is the operator-parsing account
-- matching the 'ExprKind' (expressions or patterns).
billToParser :: ExprKind -> ScopeM a -> ScopeM a
billToParser k = Bench.billTo [Bench.Parsing, subAccount k]
  where
    subAccount IsExpr    = Bench.OperatorsExpr
    subAccount IsPattern = Bench.OperatorsPattern
-- | A flattened scope: each concrete (possibly qualified) name mapped to
-- the abstract names it may refer to.
type FlatScope = Map QName [AbstractName]

-- | For every concrete name in the flattened scope that has at least one
-- binding of one of the given kinds, the merged notations of /all/ of
-- its bindings (not only the ones whose kind matches).
getDefinedNames :: KindsOfNames -> FlatScope -> [[NewNotation]]
getDefinedNames kinds names =
  [ mergeNotations (map notationOf ds)
  | (x, ds) <- Map.toList names
  , let notationOf      = namesToNotation x . A.qnameName . anameName
        hasWantedKind d = anameKind d `elemKindsOfNames` kinds
  , any hasWantedKind ds
  , not (null ds)
  ]
  
  
  
  
-- | Collect the names visible for operator parsing, as a pair of:
--
--   * the names of all local variables and defined names, and
--   * those of their notations that have a non-empty 'notation', with
--     'useDefaultFixity' applied.
--
-- A defined name whose notation name coincides with that of a local
-- variable is dropped: the local variable shadows it.
localNames :: FlatScope -> ScopeM ([QName], [NewNotation])
localNames flat = do
  let defs = getDefinedNames allKindsOfNames flat
  locals <- nubOn fst . notShadowedLocals <$> getLocalVars
  reportS "scope.operators" 50
    [ "flat  = " ++ show flat
    , "defs  = " ++ show defs
    , "locals= " ++ show locals
    ]
  let localNots = [ namesToNotation (QName x) y | (x, y) <- locals ]
      localSet  = Set.fromList (map notaName localNots)
      otherNots = [ n | n <- concat defs
                      , notaName n `Set.notMember` localSet ]
      allNots   = localNots ++ otherNots
  return ( map notaName allNots
         , [ useDefaultFixity n | n <- allNots, not (null (notation n)) ]
         )
-- | The mutually recursive parsers that make up the operator grammar.
-- 'buildParsers' ties them together with 'Data.Function.fix'.
data InternalParsers e = InternalParsers
  { pTop    :: Parser e e
    -- ^ Top level: operators of every precedence level, falling back
    --   to 'pApp'.
  , pApp    :: Parser e e
    -- ^ Built with 'appP' from 'pNonfix' and 'pArgs' (applications).
  , pArgs   :: Parser e [NamedArg e]
    -- ^ Built with 'argsP' from 'pNonfix' (argument lists).
  , pNonfix :: Parser e e
    -- ^ Atoms, nonfix notations, and sections that behave as nonfix.
  , pAtom   :: Parser e e
    -- ^ Atoms, as decided by the @isAtom@ test in 'buildParsers'.
  }
-- | Are we parsing an expression or a pattern?  Patterns are parsed
-- without sections, use a different atom test, and are billed to a
-- different benchmark account.
data ExprKind = IsExpr | IsPattern
  deriving (Eq, Show)
-- | The result of 'buildParsers': the parsers themselves, plus the
-- information needed to report errors and classify results.
data Parsers e = Parsers
  { parser :: [e] -> [e]
    -- ^ Parse a sequence of expressions (the parts of a raw
    --   application), returning every possible parse.
  , argsParser :: [e] -> [[NamedArg e]]
    -- ^ Parse a sequence of arguments, returning every possible parse.
  , operators :: [NotationSection]
    -- ^ The notations and sections included in the grammar.  These are
    --   reported in error messages via 'OperatorInformation'.
  , flattenedScope :: FlatScope
    -- ^ The flattened scope the parsers were built from.  It is used,
    --   for instance, to look up constructor and field names when
    --   classifying parsed patterns.
  }
-- | Build the operator parsers for a single expression or pattern.
--
-- To keep the grammar small, only notations all of whose identifier
-- parts occur in the given names become grammar rules.  Operators and
-- sections are additionally only included if their first and last
-- identifier parts occur in those names with the matching leading and
-- trailing underscores.
buildParsers
  :: forall e. IsExpr e
  => ExprKind
     -- ^ Are we parsing an expression ('IsExpr') or a pattern
     --   ('IsPattern')?  Sections are only parsed for expressions.
  -> [QName]
     -- ^ The names occurring in the expression or pattern to be
     --   parsed.  Only notations whose name parts occur among these
     --   names are considered.  The qualifier modules of these names
     --   (see 'qualifierModules') are passed to 'flattenScope'.
  -> ScopeM (Parsers e)
buildParsers kind exprNames = do
    -- Flatten the current scope, taking the qualifier modules of the
    -- names in the expression into account.  Then collect the names
    -- and notations in it.
    flat         <- flattenScope (qualifierModules exprNames) <$>
                      getScope
    (names, ops) <- localNames flat
    let
        -- All names in the expression.
        namesInExpr :: Set QName
        namesInExpr = Set.fromList exprNames
        -- The name parts of every name in the expression, with the
        -- qualifiers dropped.
        partListsInExpr' = map (nameParts . unqualify) $
                           Set.toList namesInExpr
        -- A trie containing every prefix ('Trie.everyPrefix') of every
        -- part list in @f partListsInExpr'@.
        partListTrie f =
          foldr (\ps -> Trie.union (Trie.everyPrefix ps ()))
                Trie.empty
                (f partListsInExpr')
        -- Prefixes of the part lists of all names in the expression.
        partListsInExpr :: Trie NamePart ()
        partListsInExpr = partListTrie id
        -- Prefixes of the reversed part lists, i.e. the part lists read
        -- from the end.
        reversedPartListsInExpr :: Trie NamePart ()
        reversedPartListsInExpr = partListTrie (map reverse)
        -- Every identifier part ('Id', not a hole) of the names in the
        -- expression.
        partsInExpr :: Set RawName
        partsInExpr =
          Set.fromList [ s | Id s <- concat partListsInExpr' ]
        -- For each identifier part of n's notation, does it occur in
        -- the expression?
        partsPresent n =
          [ Set.member p partsInExpr
          | p <- stringParts (notation n)
          ]
        addHole True  p = [Hole, Id p]
        addHole False p = [Id p]
        -- Is the first identifier part of the notation n found at the
        -- start of some name in the expression, preceded by a hole iff
        -- withHole is True?
        firstPartPresent withHole n =
          Trie.member (addHole withHole p) partListsInExpr
          where
          p = case n of
            NormalHole{} : IdPart p : _ -> rangedThing p
            IdPart p : _                -> rangedThing p
            _                           -> __IMPOSSIBLE__
        -- Is the last identifier part of the notation n found at the
        -- end of some name in the expression, followed by a hole iff
        -- withHole is True?
        lastPartPresent withHole n =
          Trie.member (addHole withHole p) reversedPartListsInExpr
          where
          p = case reverse n of
            NormalHole{} : IdPart p : _ -> rangedThing p
            IdPart p : _                -> rangedThing p
            _                           -> __IMPOSSIBLE__
        -- Are the first and the last identifier parts present with the
        -- requested leading and trailing holes?
        correctUnderscores :: Bool -> Bool -> Notation -> Bool
        correctUnderscores withInitialHole withFinalHole n =
          firstPartPresent withInitialHole n
            &&
          lastPartPresent  withFinalHole   n
        -- Keep the notations that may be used as-is (not as sections),
        -- wrapped with 'noSection'.  An operator must be written
        -- without leading or trailing underscores.  For a non-operator,
        -- each identifier part must begin some name in the expression.
        filterCorrectUnderscoresOp :: [NewNotation] -> [NotationSection]
        filterCorrectUnderscoresOp ns =
          [ noSection n
          | n <- ns
          , if notaIsOperator n
            then correctUnderscores False False (notation n)
            else all (\s -> Trie.member [Id s] partListsInExpr)
                     (stringParts $ notation n)
          ]
        -- Underscore check for notation n used as a section of kind k.
        -- The table says, for each combination, which of the first and
        -- last identifier parts must be written with an adjacent
        -- underscore.  Other combinations are never requested.
        correctUnderscoresSect :: NotationKind -> Notation -> Bool
        correctUnderscoresSect k n = case (k, notationKind n) of
          (PrefixNotation,  InfixNotation)   -> correctUnderscores True False n
          (PostfixNotation, InfixNotation)   -> correctUnderscores False True n
          (NonfixNotation,  InfixNotation)   -> correctUnderscores True True n
          (NonfixNotation,  PrefixNotation)  -> correctUnderscores False True n
          (NonfixNotation,  PostfixNotation) -> correctUnderscores True False n
          _                                  -> __IMPOSSIBLE__
        -- The notations all of whose identifier parts occur in the
        -- expression, split into nonfix notations (non) and the rest
        -- (fix).  Note that conParts/allParts below use "or" (some part
        -- present), whereas this filter uses "and" (all parts present).
        (non, fix) = List.partition nonfix (filter (and . partsPresent) ops)
        -- Notations of the constructors, fields and pattern synonyms in
        -- the flattened scope.
        cons       = getDefinedNames
                       (someKindsOfNames [ConName, FldName, PatternSynName]) flat
        -- The constructor/field/pattern synonym names that occur in the
        -- expression.
        conNames   = Set.fromList $
                       filter (flip Set.member namesInExpr) $
                       map (notaName . head) cons
        -- Name parts (per 'notationNames') of the constructor/field/
        -- pattern synonym notations with at least one identifier part
        -- in the expression.
        conParts   = Set.fromList $
                       concatMap notationNames $
                       filter (or . partsPresent) $
                       concat cons
        -- The names in scope (from 'localNames') that occur in the
        -- expression.
        allNames   = Set.fromList $
                       filter (flip Set.member namesInExpr) names
        -- Name parts of all relevant notations, including conParts.
        allParts   = Set.union conParts
                       (Set.fromList $
                        concatMap notationNames $
                        filter (or . partsPresent) ops)
        -- Can x be parsed as an atom?  Not if it is a name part of a
        -- relevant notation, unless x is itself one of the names found
        -- above.  For unqualified names in patterns, only constructor/
        -- field/pattern synonym notations are taken into account.
        isAtom x
          | kind == IsPattern && not (isQualified x) =
            not (Set.member x conParts) || Set.member x conNames
          | otherwise =
            not (Set.member x allParts) || Set.member x allNames
        -- Sections are only parsed in expressions.
        parseSections = case kind of
          IsPattern -> DoNotParseSections
          IsExpr    -> ParseSections
    -- Prefix and postfix sections, at level l, of the infix operators
    -- in ns (empty when sections are not parsed).
    let nonClosedSections l ns =
          case parseSections of
            DoNotParseSections -> []
            ParseSections      ->
              [ NotationSection n k (Just l) True
              | n <- ns
              , isinfix n && notaIsOperator n
              , k <- [PrefixNotation, PostfixNotation]
              , correctUnderscoresSect k (notation n)
              ]
        -- Operators without a precedence level, with their sections.
        unrelatedOperators :: [NotationSection]
        unrelatedOperators =
          filterCorrectUnderscoresOp unrelated
            ++
          nonClosedSections Unrelated unrelated
          where
          unrelated = filter ((== Unrelated) . level) fix
        -- Nonfix notations (without a level), plus, when sections are
        -- parsed, the sections of the other operators that behave like
        -- nonfix notations.
        nonWithSections :: [NotationSection]
        nonWithSections =
          map (\s -> s { sectLevel = Nothing })
              (filterCorrectUnderscoresOp non)
            ++
          case parseSections of
            DoNotParseSections -> []
            ParseSections      ->
              [ NotationSection n NonfixNotation Nothing True
              | n <- fix
              , notaIsOperator n
              , correctUnderscoresSect NonfixNotation (notation n)
              ]
        -- Operators with a precedence level, with their non-closed
        -- sections, grouped by level.  The lowest level comes first.
        relatedOperators :: [(PrecedenceLevel, [NotationSection])]
        relatedOperators =
          map (\((l, ns) : rest) -> (l, ns ++ concat (map snd rest))) .
          List.groupBy ((==) `on` fst) .
          List.sortBy (compare `on` fst) .
          mapMaybe (\n -> case level n of
                            Unrelated     -> Nothing
                            r@(Related l) ->
                              Just (l, filterCorrectUnderscoresOp [n] ++
                                       nonClosedSections r [n])) $
          fix
        -- Every notation section in the grammar (exported as
        -- 'operators' for error messages).
        everything :: [NotationSection]
        everything =
          concatMap snd relatedOperators ++
          unrelatedOperators ++
          nonWithSections
    reportS "scope.operators" 50
      [ "unrelatedOperators = " ++ show unrelatedOperators
      , "nonWithSections    = " ++ show nonWithSections
      , "relatedOperators   = " ++ show relatedOperators
      ]
    -- The grammar, tied with a fixpoint.  pTop chains the related
    -- precedence levels with foldr, so the lowest level is outermost.
    -- Each level's "higher" parser is the next level up, and the
    -- highest level falls back to pApp.  Each unrelated operator gets
    -- its own node, keyed by its index.
    let g = Data.Function.fix $ \p -> InternalParsers
              { pTop    = memoise TopK $
                          Fold.asum $
                            foldr ($) (pApp p)
                              (map (\(l, ns) higher ->
                                       mkP (Right l) parseSections
                                           (pTop p) ns higher True)
                                   relatedOperators) :
                            map (\(k, n) ->
                                    mkP (Left k) parseSections
                                        (pTop p) [n] (pApp p) False)
                                (zip [0..] unrelatedOperators)
              , pApp    = memoise AppK $ appP (pNonfix p) (pArgs p)
              , pArgs   = argsP (pNonfix p)
              -- Atoms, or a nonfix notation/section.  Each hole left
              -- open at the beginning or end is filled by a placeholder.
              , pNonfix = memoise NonfixK $
                          Fold.asum $
                            pAtom p :
                            flip map nonWithSections (\sect ->
                              let n = sectNotation sect
                                  inner :: forall k. NK k ->
                                           Parser e (OperatorType k e)
                                  inner = opP parseSections (pTop p) n
                              in
                              case notationKind (notation n) of
                                InfixNotation ->
                                  flip ($) <$> placeholder Beginning
                                           <*> inner In
                                           <*> placeholder End
                                PrefixNotation ->
                                  inner Pre <*> placeholder End
                                PostfixNotation ->
                                  flip ($) <$> placeholder Beginning
                                           <*> inner Post
                                NonfixNotation -> inner Non
                                NoNotation     -> __IMPOSSIBLE__)
              , pAtom   = atomP isAtom
              }
    reportSDoc "scope.grammar" 10 $ return $
      "Operator grammar:" $$ nest 2 (grammar (pTop g))
    return $ Parsers
      { parser         = parse (parseSections, pTop  g)
      , argsParser     = parse (parseSections, pArgs g)
      , operators      = everything
      , flattenedScope = flat
      }
    where
        level :: NewNotation -> FixityLevel
        level = fixityLevel . notaFixity
        -- Classify a notation by its notation kind.
        nonfix, isinfix, isprefix, ispostfix :: NewNotation -> Bool
        nonfix    = (== NonfixNotation)  . notationKind . notation
        isinfix   = (== InfixNotation)   . notationKind . notation
        isprefix  = (== PrefixNotation)  . notationKind . notation
        ispostfix = (== PostfixNotation) . notationKind . notation
        -- Classify a section by the kind it is used as.
        isPrefix, isPostfix :: NotationSection -> Bool
        isPrefix  = (== PrefixNotation)  . sectKind
        isPostfix = (== PostfixNotation) . sectKind
        -- Is the section used as infix, with an operator of the given
        -- associativity?
        isInfix :: Associativity -> NotationSection -> Bool
        isInfix ass s =
          sectKind s == InfixNotation
             &&
          fixityAssoc (notaFixity (sectNotation s)) == ass
        -- Parser for one grammar node, i.e. one precedence level (or one
        -- unrelated operator).
        mkP :: PrecedenceKey
               -- ^ Memoisation key: @Left i@ for the i-th unrelated
               --   operator, @Right l@ for precedence level l.
            -> ParseSections
            -> Parser e e
            -> [NotationSection]
            -> Parser e e
               -- ^ Parser for the next higher level.
            -> Bool
               -- ^ Should the higher level also be accepted on its own
               --   at this node?
            -> Parser e e
        mkP key parseSections p0 ops higher includeHigher =
            memoise (NodeK key) $
              Fold.asum $
                (if includeHigher then (higher :) else id) $
                catMaybes [nonAssoc, preRights, postLefts]
            where
            -- Parse any of the given sections as an operator of kind k.
            -- A section whose notation has a different kind is adapted
            -- by supplying a placeholder for the extra hole.
            choice :: forall k.
                      NK k -> [NotationSection] ->
                      Parser e (OperatorType k e)
            choice k =
              Fold.asum .
              map (\sect ->
                let n = sectNotation sect
                    inner :: forall k.
                             NK k -> Parser e (OperatorType k e)
                    inner = opP parseSections p0 n
                in
                case k of
                  In   -> inner In
                  Pre  -> if isinfix n || ispostfix n
                          then flip ($) <$> placeholder Beginning
                                        <*> inner In
                          else inner Pre
                  Post -> if isinfix n || isprefix n
                          then flip <$> inner In
                                    <*> placeholder End
                          else inner Post
                  Non  -> __IMPOSSIBLE__)
            -- Non-associative infix operators: higher op higher.
            nonAssoc :: Maybe (Parser e e)
            nonAssoc = case filter (isInfix NonAssoc) ops of
              []  -> Nothing
              ops -> Just $
                (\x f y -> f (noPlaceholder x) (noPlaceholder y))
                  <$> higher
                  <*> choice In ops
                  <*> higher
            -- Combine two alternatives, each used only when its list of
            -- sections is non-empty.
            or p1 []   p2 []   = Nothing
            or p1 []   p2 ops2 = Just (p2 ops2)
            or p1 ops1 p2 []   = Just (p1 ops1)
            or p1 ops1 p2 ops2 = Just (p1 ops1 <|> p2 ops2)
            -- Prefix operators, and right-associative infix operators
            -- with their left argument parsed at the higher level.
            preRight :: Maybe (Parser e (MaybePlaceholder e -> e))
            preRight =
              or (choice Pre)
                 (filter isPrefix ops)
                 (\ops -> flip ($) <$> (noPlaceholder <$> higher)
                                   <*> choice In ops)
                 (filter (isInfix RightAssoc) ops)
            -- Right-nested chains of preRight, ending in a higher-level
            -- expression.
            preRights :: Maybe (Parser e e)
            preRights = do
              preRight <- preRight
              return $ Data.Function.fix $ \preRights ->
                memoiseIfPrinting (PreRightsK key) $
                  preRight <*> (noPlaceholder <$> (preRights <|> higher))
            -- Postfix operators, and left-associative infix operators
            -- with their right argument parsed at the higher level.
            postLeft :: Maybe (Parser e (MaybePlaceholder e -> e))
            postLeft =
              or (choice Post)
                 (filter isPostfix ops)
                 (\ops -> flip <$> choice In ops
                               <*> (noPlaceholder <$> higher))
                 (filter (isInfix LeftAssoc) ops)
            -- Left-nested chains of postLeft, starting from a
            -- higher-level expression.
            postLefts :: Maybe (Parser e e)
            postLefts = do
              postLeft <- postLeft
              return $ Data.Function.fix $ \postLefts ->
                memoise (PostLeftsK key) $
                  flip ($) <$> (noPlaceholder <$> (postLefts <|> higher))
                           <*> postLeft
-- | Resolve the operator applications inside a pattern, returning every
-- possible parse.  The given function turns the parts of a raw
-- application ('RawAppP') into its candidate parses.  Rebuilt
-- applications are fully parenthesised with 'fullParen''.  Hidden and
-- instance argument patterns and ellipses yield no parses here.
parsePat :: ([Pattern] -> [Pattern]) -> Pattern -> [Pattern]
parsePat prs = go
  where
    go pat = case pat of
      AppP f (Arg info q) ->
        fullParen' <$> (AppP <$> go f <*> (Arg info <$> traverse go q))
      RawAppP _ ps     -> fullParen' <$> (prs ps >>= go)
      OpAppP r d ns ps ->
        fullParen' . OpAppP r d ns <$> (traverse . traverse . traverse) go ps
      HiddenP _ _      -> fail "bad hidden argument"
      InstanceP _ _    -> fail "bad instance argument"
      EllipsisP _      -> fail "bad ellipsis"
      AsP r x q        -> AsP r x <$> go q
      ParenP _ q       -> fullParen' <$> go q
      RecP r fs        -> RecP r <$> traverse (traverse go) fs
      WithP r q        -> WithP r <$> go q
      -- Leaves, and patterns whose contents are not parsed here.
      DotP{}           -> pure pat
      WildP{}          -> pure pat
      AbsurdP{}        -> pure pat
      LitP{}           -> pure pat
      QuoteP{}         -> pure pat
      IdentP{}         -> pure pat
      EqualP{}         -> pure pat
-- | A classified parse: either an ordinary pattern ('Left'), or a
-- left-hand side of the named definition ('Right').
type ParseLHS = Either Pattern (QName, LHSCore)
-- | Parse a pattern, either as an ordinary pattern or as the left-hand
-- side of a clause.  Returns the unique valid parse together with the
-- operator sections that were considered.  Raises 'NoParseForLHS' if
-- there is no valid parse and 'AmbiguousParseForLHS' if there are
-- several.
parseLHS'
  :: LHSOrPatSyn
       -- ^ Are we parsing a left-hand side or a pattern synonym?  This
       --   appears in error messages and selects the shortcut below.
  -> Maybe QName
       -- ^ @Just f@: parse a left-hand side of the definition @f@.
       --   @Nothing@: parse an ordinary pattern.
  -> Pattern
       -- ^ The pattern to parse.
  -> ScopeM (ParseLHS, [NotationSection])
       -- ^ The unique parse, and the notation sections used by the
       --   parser (for error messages).
-- A left-hand side that is a lone underscore is taken to be the head qn
-- with no patterns.  No operator parsing is needed.
parseLHS' IsLHS (Just qn) (RawAppP _ [WildP{}]) =
    return (Right (qn, LHSHead qn []), [])
parseLHS' lhsOrPatSyn top p = do
    -- Build a pattern parser for the names occurring in p.
    patP <- buildParsers IsPattern (patternQNames p)
    -- Run the parser, forcing every result to WHNF.
    let ps   = let result = parsePat (parser patP) p
               in  foldr seq () result `seq` result
    -- Classify the results, using the constructor/pattern synonym names
    -- and the field names from the flattened scope.
    let cons = getNames (someKindsOfNames [ConName, PatternSynName])
                        (flattenedScope patP)
    let flds = getNames (someKindsOfNames [FldName])
                        (flattenedScope patP)
    let conf = PatternCheckConfig top cons flds
    case mapMaybe (validPattern conf) ps of
        -- Unique result.
        [(_,lhs)] -> do reportS "scope.operators" 50 $ "Parsed lhs:" <+> pretty lhs
                        return (lhs, operators patP)
        -- No result.
        []        -> typeError $ OperatorInformation (operators patP)
                               $ NoParseForLHS lhsOrPatSyn p
        -- Ambiguous result: report the fully parenthesised candidates.
        rs        -> typeError $ OperatorInformation (operators patP)
                               $ AmbiguousParseForLHS lhsOrPatSyn p $
                       map (fullParen . fst) rs
    where
        -- One representative name per concrete name of the given kinds.
        getNames kinds flat =
          map (notaName . head) $ getDefinedNames kinds flat
        -- Keep a parse if its classification matches what we expect.
        -- The pattern is kept for reporting ambiguous parses.
        validPattern :: PatternCheckConfig -> Pattern -> Maybe (Pattern, ParseLHS)
        validPattern conf p =
          case (classifyPattern conf p, top) of
            (Just res@Left{}, Nothing) -> Just (p, res)  -- expect pattern
            (Just res@Right{}, Just{}) -> Just (p, res)  -- expect lhs
            _ -> Nothing
-- | The names needed by 'classifyPattern' to classify a parsed pattern.
data PatternCheckConfig = PatternCheckConfig
  { topName  :: Maybe QName -- ^ The definition whose left-hand side is being parsed, if any.
  , conNames :: [QName]     -- ^ Constructor and pattern synonym names in the flattened scope.
  , fldNames :: [QName]     -- ^ Field names in the flattened scope.
  }
-- | Classify a parsed pattern as one of:
--
--   * a left-hand side of the definition 'topName';
--   * a projection left-hand side ('LHSProj') headed by a field name;
--   * an ordinary constructor pattern.
--
-- Returns 'Nothing' if the pattern fits none of these.
classifyPattern :: PatternCheckConfig -> Pattern -> Maybe ParseLHS
classifyPattern conf p =
  case patternAppView p of
    -- Case @f ps@: the definition being defined, applied to patterns
    -- that must all be valid constructor patterns.
    Arg _ (Named _ (IdentP x)) : ps | Just x == topName conf -> do
      guard $ all validPat ps
      return $ Right (x, lhsCoreAddSpine (LHSHead x []) ps)
    -- Case @d ps@: a field name applied to patterns.  Every argument
    -- must classify, and exactly one must be a left-hand side (the
    -- projected one); the others must be ordinary patterns.
    Arg _ (Named _ (IdentP x)) : ps | x `elem` fldNames conf -> do
      -- ps0 :: [NamedArg ParseLHS]
      ps0 <- mapM classPat ps
      let (ps1, rest) = span (isLeft . namedArg) ps0
      (p2, ps3) <- uncons rest -- fails when no argument is a left-hand side
      guard $ all (isLeft . namedArg) ps3
      let (f, lhs)      = fromR p2
          -- Split the original arguments around the left-hand side
          -- argument.  This lazy pattern cannot fail because rest is
          -- non-empty.
          (ps', _:ps'') = splitAt (length ps1) ps
      return $ Right (f, lhsCoreAddSpine (LHSProj x ps' lhs []) ps'')
    -- Otherwise: an ordinary pattern, which must be a valid constructor
    -- pattern.
    _ -> do
      guard $ validConPattern (conNames conf) p
      return $ Left p
  where
        validPat = validConPattern (conNames conf) . namedArg
        -- Classify the pattern inside a named argument.
        classPat :: NamedArg Pattern -> Maybe (NamedArg ParseLHS)
        classPat = Trav.mapM (Trav.mapM (classifyPattern conf))
        -- Only applied to a 'Right' (guaranteed by the span/uncons
        -- above).
        fromR :: NamedArg (Either a (b, c)) -> (b, NamedArg c)
        fromR (Arg info (Named n (Right (b, c)))) = (b, Arg info (Named n c))
        fromR (Arg info (Named n (Left  a     ))) = __IMPOSSIBLE__
-- | Parse the left-hand side of a clause of the definition @top@.
-- Raises 'NoParseForLHS' if the result is not a left-hand side.
parseLHS :: QName -> Pattern -> ScopeM LHSCore
parseLHS top p = billToParser IsPattern $
  parseLHS' IsLHS (Just top) p >>= \(res, ops) ->
    either (\_ -> typeError $ OperatorInformation ops $ NoParseForLHS IsLHS p)
           (\(_, lhs) -> return lhs)
           res
-- | Parse an ordinary pattern, reporting errors as for a left-hand side
-- ('IsLHS').
parsePattern :: Pattern -> ScopeM Pattern
parsePattern p = parsePatternOrSyn IsLHS p
-- | Parse an ordinary pattern, reporting errors as for a pattern
-- synonym ('IsPatSyn').
parsePatternSyn :: Pattern -> ScopeM Pattern
parsePatternSyn p = parsePatternOrSyn IsPatSyn p
-- | Parse an ordinary pattern.  Raises 'NoParseForLHS' if the unique
-- parse is not an ordinary pattern.  The 'LHSOrPatSyn' flag is only
-- passed on to 'parseLHS'' and the error messages.
parsePatternOrSyn :: LHSOrPatSyn -> Pattern -> ScopeM Pattern
parsePatternOrSyn lhsOrPatSyn p = billToParser IsPattern $
  parseLHS' lhsOrPatSyn Nothing p >>= \(res, ops) ->
    either return
           (\_ -> typeError $ OperatorInformation ops
                            $ NoParseForLHS lhsOrPatSyn p)
           res
-- | Is the pattern, as flattened by 'appView', a valid constructor
-- pattern?  It is valid if it is one of:
--
--   * a single pattern;
--   * a with-pattern around a valid pattern;
--   * a quote pattern applied to one argument;
--   * a dot pattern applied to valid patterns;
--   * one of the given names applied to valid patterns.
validConPattern :: [QName] -> Pattern -> Bool
validConPattern cons = valid
  where
    valid pat = case appView pat of
      [WithP _ inner] -> valid inner
      [_]             -> True
      IdentP x : args -> x `elem` cons && all valid args
      [QuoteP _, _]   -> True
      DotP _ _ : args -> all valid args
      _               -> False
-- | Flatten a pattern application into its head followed by its
-- arguments, dropping argument hiding and names.  An operator
-- application contributes its operator as an 'IdentP' head, and
-- parentheses are looked through.  Raw applications and hidden or
-- instance patterns must already have been eliminated.
--
-- The arguments are accumulated from the right.  A left-nested
-- application with @n@ arguments is therefore flattened in O(n), rather
-- than the O(n²) cost of repeated @xs ++ [x]@.
appView :: Pattern -> [Pattern]
appView = go []
  where
    go args pat = case pat of
      AppP f a         -> go (namedArg a : args) f
      OpAppP _ op _ ps -> IdentP op : map namedArg ps ++ args
      ParenP _ q       -> go args q
      RawAppP _ _      -> __IMPOSSIBLE__
      HiddenP _ _      -> __IMPOSSIBLE__
      InstanceP _ _    -> __IMPOSSIBLE__
      _                -> pat : args
-- | The distinct, non-empty qualifier prefixes of the given names: every
-- part of each name but the last (per 'qnameParts').
qualifierModules :: [QName] -> [[Name]]
qualifierModules qs =
  nubOn id [ ms | q <- qs, let ms = init (qnameParts q), not (null ms) ]
-- | Parse a raw application (a list of expressions, possibly containing
-- operator parts) into a single expression.  A singleton list is
-- returned unchanged.  Raises 'NoParseForApplication' if there is no
-- parse.  Raises 'AmbiguousParseForApplication', listing the fully
-- parenthesised candidates, if there are several.
parseApplication :: [Expr] -> ScopeM Expr
parseApplication es = case es of
  [single] -> return single
  _        -> billToParser IsExpr $ do
    -- The grammar is restricted to the names in the application.
    ps <- buildParsers IsExpr [ q | Ident q <- es ]
    let candidates = parser ps es
        withOps    = OperatorInformation (operators ps)
    -- Force every candidate (to WHNF) before inspecting the result.
    case foldr seq () candidates `seq` candidates of
      [e] -> do
        reportSDoc "scope.operators" 50 $ return $
          "Parsed an operator application:" <+> pretty e
        return e
      []        -> typeError $ withOps $ NoParseForApplication es
      ambiguous -> typeError $ withOps
                             $ AmbiguousParseForApplication es
                             $ map fullParen ambiguous
-- | A module identifier must be a bare (possibly qualified) name.  Any
-- other expression is rejected with 'NotAModuleExpr'.
parseModuleIdentifier :: Expr -> ScopeM QName
parseModuleIdentifier e = case e of
  Ident m -> return m
  _       -> typeError $ NotAModuleExpr e
-- | Parse a raw module application @M e1 ... en@.  The first expression
-- must be a module identifier; the rest are parsed with the operator
-- arguments parser.  Raises 'NoParseForApplication' or
-- 'AmbiguousParseForApplication' unless the arguments have exactly one
-- parse.
--
-- The list must be non-empty.  An empty list is reported as an internal
-- error ('__IMPOSSIBLE__') instead of an opaque irrefutable-pattern
-- failure.
parseRawModuleApplication :: [Expr] -> ScopeM (QName, [NamedArg Expr])
parseRawModuleApplication [] = __IMPOSSIBLE__
parseRawModuleApplication es@(e : es_args) = billToParser IsExpr $ do
    m <- parseModuleIdentifier e

    -- Build the arguments parser for the names in the arguments.
    p <- buildParsers IsExpr [ q | Ident q <- es_args ]

    -- A unique parse is the result; otherwise report an error.
    case argsParser p es_args of
        [as] -> return (m, as)
        []   -> typeError $ OperatorInformation (operators p)
                          $ NoParseForApplication es
        ass -> do
          -- Show each candidate as a fully parenthesised application of
          -- the module name.
          let f = fullParen . foldl (App noRange) (Ident m)
          typeError $ OperatorInformation (operators p)
                    $ AmbiguousParseForApplication es
                    $ map f ass
-- | Split a module application into the module name and its arguments.
-- Raw applications go through the operator parser.  Explicit
-- applications are unfolded from the left.  Anything else must be a
-- bare module identifier with no arguments.
parseModuleApplication :: Expr -> ScopeM (QName, [NamedArg Expr])
parseModuleApplication expr = case expr of
  RawApp _ es -> parseRawModuleApplication es
  App _ f arg -> second (++ [arg]) <$> parseModuleApplication f
  _           -> (\m -> (m, [])) <$> parseModuleIdentifier expr
-- | Fully parenthesise an expression with 'fullParen'', then drop the
-- parentheses around the expression as a whole.
fullParen :: IsExpr e => e -> e
fullParen = dropOuterParens . exprView . fullParen'
  where
    dropOuterParens (ParenV inner) = inner
    dropOuterParens other          = unExprView other
-- | Put parentheses around every application, operator application and
-- lambda in an expression.  Atoms and already parenthesised
-- expressions are left unchanged.  Hidden and instance arguments are
-- not descended into.
fullParen' :: IsExpr e => e -> e
fullParen' expr = case exprView expr of
    LocalV _       -> expr
    WildV _        -> expr
    OtherV _       -> expr
    HiddenArgV _   -> expr
    InstanceArgV _ -> expr
    ParenV _       -> expr
    AppV fun (Arg info arg) ->
      parenthesise $ AppV (fullParen' fun) (Arg info (visibleOnly info arg))
    OpAppV op parts args ->
      parenthesise $ OpAppV op parts $
        (map . fmap . fmap . fmap . fmap) fullParen' args
    LamV binds body -> parenthesise $ LamV binds (fullParen body)
  where
    -- Rebuild a view and wrap the result in parentheses.
    parenthesise v = unExprView (ParenV (unExprView v))
    -- Only visible arguments are parenthesised recursively.
    visibleOnly info arg = case argInfoHiding info of
      Hidden     -> arg
      Instance{} -> arg
      NotHidden  -> fullParen' <$> arg