-- | Checking for missing cases in a match expression.  Based on
-- "Warnings for pattern matching" by Luc Maranget.  We only detect
-- inexhaustiveness here - ideally, we would also like to check for
-- redundant cases.
module Language.Futhark.TypeChecker.Match
  ( unmatched,
    Match,
  )
where

import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Util (maybeHead, nubOrd)
import Futhark.Util.Pretty hiding (group, space)
import Language.Futhark hiding (ExpBase (Constr))

-- | The head constructor of a pattern, as far as exhaustiveness
-- checking is concerned.
data Constr
  = -- | A named constructor of a sum type.
    Constr Name
  | -- | The (single) constructor of a tuple.
    ConstrTuple
  | -- | The (single) constructor of a record with the given fields.
    ConstrRecord [Name]
  | -- | Treated as 0-ary.
    ConstrLit PatLit
  deriving (Eq, Ord, Show)

-- | A representation of the essentials of a pattern.  Every node
-- carries an annotation of type @t@; within this module that is
-- either the 'StructType' of the value being matched, or @()@ for the
-- counterexamples that are reported back.
data Match t
  = -- | Matches any value.
    MatchWild t
  | -- | A constructor applied to sub-patterns.
    MatchConstr Constr [Match t] t
  deriving (Eq, Ord, Show)

-- | The type annotation stored at the root of a 'Match'.
matchType :: Match StructType -> StructType
matchType (MatchWild t) = t
matchType (MatchConstr _ _ t) = t

-- | Pretty-print a 'Match' in a context of the given precedence.  A
-- named constructor that has arguments is wrapped in parentheses when
-- the precedence is at least 10, which is the precedence used for
-- constructor arguments.  Tuple components and record fields are
-- printed at precedence -1.
pprMatch :: Int -> Match t -> Doc a
pprMatch _ MatchWild {} = "_"
pprMatch _ (MatchConstr (ConstrLit l) _ _) = pretty l
pprMatch p (MatchConstr (Constr c) ps _) =
  parensIf (not (null ps) && p >= 10) $
    "#" <> pretty c <> mconcat (map ((" " <>) . pprMatch 10) ps)
pprMatch _ (MatchConstr ConstrTuple ps _) =
  parens $ commasep $ map (pprMatch (-1)) ps
pprMatch _ (MatchConstr (ConstrRecord fs) ps _) =
  braces $ commasep $ zipWith ppField fs ps
  where
    -- A single @name=pattern@ record field.
    ppField name t = pretty (nameToString name) <> equals <> pprMatch (-1) t

-- | Printed via 'pprMatch' at precedence -1, so a constructor
-- application at the top level is never parenthesised.
instance Pretty (Match t) where
  pretty = pprMatch (-1)

-- | Convert a source pattern to its 'Match' skeleton.  Name bindings
-- and wildcards both become 'MatchWild'; parentheses, attributes and
-- type ascriptions are looked through; every other pattern becomes a
-- 'MatchConstr' annotated with the type of the pattern.
patternToMatch :: Pat -> Match StructType
patternToMatch (Id _ (Info t) _) = MatchWild $ toStruct t
patternToMatch (Wildcard (Info t) _) = MatchWild $ toStruct t
patternToMatch (PatParens p _) = patternToMatch p
patternToMatch (PatAttr _ p _) = patternToMatch p
patternToMatch (PatAscription p _ _) = patternToMatch p
patternToMatch (PatLit l (Info t) _) =
  MatchConstr (ConstrLit l) [] $ toStruct t
patternToMatch p@(TuplePat ps _) =
  MatchConstr ConstrTuple (map patternToMatch ps) $
    patternStructType p
patternToMatch p@(RecordPat fs _) =
  MatchConstr (ConstrRecord fnames) (map patternToMatch ps) $
    patternStructType p
  where
    -- Pass the fields through 'sortFields' so that the field names and
    -- the sub-patterns are listed in the same, fixed order.
    (fnames, ps) = unzip $ sortFields $ M.fromList fs
patternToMatch (PatConstr c (Info t) args _) =
  MatchConstr (Constr c) (map patternToMatch args) $ toStruct t

-- | The constructor name, if the pattern is headed by a named
-- constructor; 'Nothing' for everything else (including wildcards).
isConstr :: Match t -> Maybe Name
isConstr (MatchConstr (Constr c) _ _) = Just c
isConstr _ = Nothing

-- | The boolean value, if the pattern is a boolean literal; 'Nothing'
-- for everything else (including wildcards).
isBool :: Match t -> Maybe Bool
isBool (MatchConstr (ConstrLit (PatLitPrim (BoolValue b))) _ _) = Just b
isBool _ = Nothing

-- | Do the given patterns, taken as the heads of one column, form a
-- complete signature?  This holds when one of the following is true:
--
-- * the first pattern has a sum type, every pattern is headed by a
--   named constructor, and every constructor of that sum type occurs
--   among them;
--
-- * every pattern is a boolean literal, and both 'True' and 'False'
--   occur;
--
-- * every pattern is a record pattern;
--
-- * every pattern is a tuple pattern.
--
-- A single wildcard in the list makes all of these checks fail.  The
-- empty list is vacuously accepted by the record and tuple checks, so
-- callers that care must rule it out themselves.
complete :: [Match StructType] -> Bool
complete xs
  | Just x <- maybeHead xs,
    Scalar (Sum all_cs) <- matchType x,
    Just xs_cs <- mapM isConstr xs =
      all (`elem` xs_cs) (M.keys all_cs)
  | otherwise =
      -- 'mapM' yields 'Nothing' as soon as one pattern is not a boolean
      -- literal, in which case the boolean check fails outright.
      all (`elem` fromMaybe [] (mapM isBool xs)) [True, False]
        || all isRecord xs
        || all isTuple xs
  where
    isRecord (MatchConstr ConstrRecord {} _ _) = True
    isRecord _ = False
    isTuple (MatchConstr ConstrTuple _ _) = True
    isTuple _ = False

-- | @specialise ats c1 pmat@ specialises the pattern matrix @pmat@ with
-- respect to the head constructor of @c1@.  For each row, looking at
-- its first column:
--
-- * if it has the same constructor as @c1@, the column is replaced by
--   the argument patterns of that constructor;
--
-- * if it is a wildcard, the column is replaced by one wildcard for
--   each type in @ats@ (the argument types of @c1@);
--
-- * otherwise the row is dropped.
--
-- Rows are expected to be non-empty; processing stops at the first
-- empty row.
specialise ::
  [StructType] ->
  Match StructType ->
  [[Match StructType]] ->
  [[Match StructType]]
specialise ats c1 = go
  where
    go ((c2 : row) : ps)
      | Just args <- match c1 c2 =
          (args ++ row) : go ps
      | otherwise =
          go ps
    go _ = []

    -- The patterns that replace the first column of a row, or
    -- 'Nothing' if the row is incompatible with @c1@.
    match (MatchConstr c1' _ _) (MatchConstr c2' args _)
      | c1' == c2' =
          Just args
      | otherwise =
          Nothing
    match _ MatchWild {} =
      Just $ map MatchWild ats
    match _ _ =
      Nothing

-- | The default matrix: keep only the rows whose first column is a
-- wildcard, and remove that column from them.
defaultMat :: [[Match t]] -> [[Match t]]
defaultMat = mapMaybe onRow
  where
    onRow (MatchConstr {} : _) = Nothing
    onRow (MatchWild {} : ps) = Just ps
    onRow [] = Nothing -- Should not happen.

-- | @findUnmatched pmat n@ produces value vectors of length @n@ that
-- are not matched by any row of the pattern matrix @pmat@, whose rows
-- are expected to all have @n@ columns.  An empty result means that
-- the matrix is exhaustive.  The result may contain duplicates.
--
-- The signature of the first column is computed from its constructor
-- heads only, with wildcards ignored:
--
-- * If those constructors form a complete signature, the matrix is
--   specialised for each of them in turn.  Rows with a wildcard in the
--   first column are part of every specialised matrix (see
--   'specialise'), so they are not iterated over separately.
--
-- * Otherwise only the default matrix is inspected, and each of its
--   counterexamples is extended with a first component not covered by
--   the column.
findUnmatched :: [[Match StructType]] -> Int -> [[Match ()]]
findUnmatched pmat n
  | ((p : _) : _) <- pmat,
    Just heads <- mapM maybeHead pmat =
      -- Wildcards must not take part in the completeness check: a
      -- column such as @#a, #b, _@ over the type @#a | #b@ has a
      -- complete signature, and looking only at the default matrix
      -- would give wrong answers for it.  The explicit emptiness test
      -- is needed because 'complete' vacuously accepts the empty list.
      let constrs = filter isMatchConstr heads
       in if not (null constrs) && complete constrs
            then completeCase constrs
            else incompleteCase (matchType p) heads
  where
    isMatchConstr MatchConstr {} = True
    isMatchConstr MatchWild {} = False

    -- For every head @c@ with @a_k@ arguments, look for counterexamples
    -- of length @a_k + n - 1@ in the matrix specialised to @c@, and fold
    -- the first @a_k@ components back into an application of @c@.
    completeCase cs = do
      c <- cs
      let ats = case c of
            MatchConstr _ args _ -> map matchType args
            MatchWild _ -> []
          a_k = length ats
          pmat' = specialise ats c pmat
      u <- findUnmatched pmat' (a_k + n - 1)
      pure $ case c of
        MatchConstr c' _ _ ->
          let (r, p) = splitAt a_k u
           in MatchConstr c' r () : p
        MatchWild _ ->
          MatchWild () : u

    -- Extend every counterexample of the default matrix with a first
    -- component that none of the heads @cs@ covers.  @pt@ is the type of
    -- the first column.
    incompleteCase pt cs = do
      u <- findUnmatched (defaultMat pmat) (n - 1)
      if null cs
        then pure $ MatchWild () : u
        else case pt of
          Scalar (Sum all_cs) -> do
            -- Figure out which constructors are missing.
            let sigma = mapMaybe isConstr cs
                notCovered (k, _) = k `notElem` sigma
            (cname, ts) <- filter notCovered $ M.toList all_cs
            pure $ MatchConstr (Constr cname) (map (const (MatchWild ())) ts) () : u
          Scalar (Prim Bool) -> do
            -- Figure out which constants are missing.
            let sigma = mapMaybe isBool cs
            b <- filter (`notElem` sigma) [True, False]
            pure $ MatchConstr (ConstrLit (PatLitPrim (BoolValue b))) [] () : u
          _ -> do
            -- FIXME: this is wrong in the unlikely case where someone
            -- is pattern-matching every single possible number for
            -- some numeric type.  It should be handled more like Bool
            -- above.
            pure $ MatchWild () : u
-- No rows at all: nothing is matched, so the all-wildcard vector is a
-- counterexample.
findUnmatched [] n = [replicate n $ MatchWild ()]
-- Remaining case: there are rows, but no columns are left to inspect
-- (or a row has no first column), so no counterexample is produced.
findUnmatched _ _ = []

{-# NOINLINE unmatched #-}

-- | Find the unmatched cases.  Each pattern becomes one single-column
-- row of the pattern matrix; an empty result means that the patterns
-- are exhaustive.
unmatched :: [Pat] -> [Match ()]
unmatched orig_ps =
  -- The algorithm may find duplicate examples, which we filter away
  -- here.
  nubOrd $
    -- The matrix has one column, so each counterexample is a
    -- one-element vector; keep just that element.
    mapMaybe maybeHead $
      findUnmatched (map (L.singleton . patternToMatch) orig_ps) 1