module Agda.Compiler.Treeless.Identity
( detectIdentityFunctions ) where
import Control.Applicative ( Alternative((<|>), empty) )
import Data.Foldable (foldMap)
import Data.Semigroup
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List as List
import Agda.Syntax.Treeless
import Agda.TypeChecking.Substitute
import Agda.TypeChecking.Monad
import Agda.Utils.Lens
-- | Check whether the compiled body @t@ of definition @q@ is recognised (by
-- 'isIdentity') as returning one of its arguments unchanged.
--
-- If it is, @q@ is passed to @markInline True@ and the body is replaced by
-- the trivial term @mkTLam n (TVar k)@: @n@ lambdas returning the variable
-- with de Bruijn index @k@. Otherwise @t@ is returned as is.
detectIdentityFunctions :: QName -> TTerm -> TCM TTerm
detectIdentityFunctions q t =
  case isIdentity q t of
    Nothing     -> return t
    Just (n, k) -> do
      markInline True q
      -- (A signature lookup whose result was never used used to sit here.)
      return $ mkTLam n $ TVar k
-- | Decide whether @t@ (the body of @q@) returns one of its arguments.
--
-- The purely syntactic 'trivialIdentity' check is tried first. Only when it
-- fails is the structural check 'recursiveIdentity' consulted. A result
-- @(n, k)@ means: @n@ leading lambdas, returning de Bruijn index @k@.
isIdentity :: QName -> TTerm -> Maybe (Int, Int)
isIdentity q t =
  case trivialIdentity q t of
    Nothing -> recursiveIdentity q t
    found   -> found
-- | Recognise functions that case-split on one of their lambda-bound
-- arguments and, in every branch, rebuild the value they matched on.
--
-- Only bodies of the shape @case x of { alts; _ -> TError TUnreachable }@
-- are considered. Literal and guard alternatives are rejected. A constructor
-- alternative @c y_1 .. y_a -> rhs@ is accepted when @rhs@ is either
--
--   * the scrutinee itself, i.e. the variable @x + a@ once the @a@ fields
--     are in scope, or
--   * @c@ applied to exactly @a@ arguments, where the @i@-th argument
--     counted from the end is the variable @i@, or an erased term, or a
--     recursive call to @q@ with exactly @n@ arguments whose argument in the
--     scrutinee's position satisfies the same field check.
--
-- On success returns @(n, x)@, where @n@ is the number of leading lambdas.
recursiveIdentity :: QName -> TTerm -> Maybe (Int, Int)
recursiveIdentity q t
  | TCase scrut _ (TError TUnreachable) alts <- body
  , all (rebuildsScrutinee scrut) alts
  = Just (arity, scrut)
  | otherwise
  = Nothing
  where
    (arity, body) = tLamView t

    rebuildsScrutinee scrut alt =
      case alt of
        TALit{}              -> False
        TAGuard{}            -> False
        TACon con fields rhs -> conBranch scrut con fields rhs

    conBranch scrut con fields rhs =
      case rhs of
        -- The scrutinee itself, shifted past the freshly bound fields.
        TVar v -> v == scrut + fields
        TApp (TCon con') rebuilt ->
          con == con'
            && length rebuilt == fields
            && and (zipWith fieldOk (reverse rebuilt) [0 ..])
        _ -> False
      where
        fieldOk arg i =
          case arg of
            TErased -> True
            TVar w  -> w == i
            -- A recursive call: the argument in the scrutinee's position
            -- must itself be an acceptable field.
            TApp (TDef g) callArgs ->
              g == q
                && length callArgs == arity
                && fieldOk (reverse callArgs !! scrut) i
            _ -> False
-- | Candidate arguments (de Bruijn indices under the leading lambdas) that a
-- term may be returning unchanged.
--
-- A single-field wrapper, so a 'newtype' rather than a boxed 'data' type.
newtype IdentityIn = IdIn [Int]

-- | No candidate at all. Because '<>' is list intersection, this is
-- absorbing on both sides.
notId :: IdentityIn
notId = IdIn []

-- | Combine two results by keeping only the candidates they share. The
-- order of the left operand is preserved.
instance Semigroup IdentityIn where
  IdIn xs <> IdIn ys = IdIn $ List.intersect xs ys
-- | Recognise functions that, along every reachable path through their body,
-- return the same one of their lambda-bound arguments.
--
-- On success returns @(n, k)@: the number of leading lambdas and the de Bruijn
-- index, under those lambdas, of the argument that is returned.
--
-- Each leaf of the body is classified by 'go' into a set of candidates:
--
--   * a variable bound by the leading lambdas is a candidate;
--   * a recursive call to @q@ keeps candidate @y@ when its @y@-th argument,
--     counted from the end, is the variable for lambda @y@;
--   * an unreachable leaf (@TError TUnreachable@) never produces a value, so
--     it is compatible with every argument;
--   * anything else rules out every candidate.
--
-- The candidate sets of all leaves are intersected, and exactly one
-- candidate must survive.
trivialIdentity :: QName -> TTerm -> Maybe (Int, Int)
trivialIdentity q t =
  case go 0 body of
    IdIn [x]     -> pure (arity, x)
    IdIn []      -> Nothing
    -- Several survivors are ambiguous, e.g. a body that is entirely
    -- unreachable.
    IdIn (_:_:_) -> Nothing
  where
    (arity, body) = tLamView t

    -- @go k v@ classifies @v@, which sits under @k@ binders introduced inside
    -- the body (let-bindings, constructor fields). Candidates are reported
    -- relative to the leading lambdas, hence the shift by @k@.
    go :: Int -> TTerm -> IdentityIn
    go k v =
      case v of
        TVar x | x >= k    -> IdIn [x - k]
               | otherwise -> notId   -- bound inside the body, not an argument
        TLet _ inner       -> go (k + 1) inner
        TCase _ _ d alts   -> sconcat (go k d :| map (goAlt k) alts)
        TApp (TDef f) args
          | f == q         -> IdIn [ y | (TVar x, y) <- zip (reverse args) [0..], y + k == x ]
        TCoerce u          -> go k u
        TApp{}             -> notId
        TLam{}             -> notId
        TLit{}             -> notId
        TDef{}             -> notId
        TCon{}             -> notId
        TPrim{}            -> notId
        TUnit{}            -> notId
        TSort{}            -> notId
        TErased{}          -> notId
        -- An unreachable leaf never returns. It must not veto the other
        -- branches, otherwise every case with the usual unreachable default
        -- would be rejected.
        TError TUnreachable -> IdIn [0 .. arity - 1]
        TError{}           -> notId

    goAlt :: Int -> TAlt -> IdentityIn
    goAlt k (TALit _ rhs)        = go k rhs
    goAlt k (TAGuard _ rhs)      = go k rhs
    goAlt k (TACon _ fields rhs) = go (k + fields) rhs