module ContextFilter (filterContext) where

-- | Like 'filter' except you can request the elements before
-- and after the matched elements.
--
-- An element is kept when it lies at most @before@ positions before, or
-- at most @after@ positions after, some element satisfying the predicate.
-- Kept elements appear in their original order, each exactly once, even
-- when the windows of several matches overlap.
--
-- Calls 'error' if either context count is negative.
--
-- >>> filterContext 2 1 (\x -> 4<=x&&x<=6) [0..10::Int]
-- [2,3,4,5,6,7]
--
-- >>> filterContext 2 1 even [0..10::Int]
-- [0,1,2,3,4,5,6,7,8,9,10]
--
-- >>> filterContext 2 1 (==10) [0..10::Int]
-- [8,9,10]
--
-- >>> filterContext 2 1 (==0) [0..10::Int]
-- [0,1]
--
-- >>> filterContext 0 0 (==0) [0..10::Int]
-- [0]
--
-- >>> filterContext 2 1 (==5) [0..10::Int]
-- [3,4,5,6]
--
-- >>> filterContext 1 2 (==5) [0..10::Int]
-- [4,5,6,7]
--
-- Inputs shorter than the @before@ context are handled too:
--
-- >>> filterContext 5 0 (==1) [0,1,2::Int]
-- [0,1]
filterContext ::
  Int         {- ^ context before       -} ->
  Int         {- ^ context after        -} ->
  (a -> Bool) {- ^ predicate            -} ->
  [a]         {- ^ inputs               -} ->
  [a]         {- ^ matches with context -}
filterContext before after p xs0
  | before < 0 = error "filterContext: bad before"
  | after  < 0 = error "filterContext: bad after"
  | otherwise  = go 0 0 xs0 xs0
  where
    -- Number of further steps a single match keeps the output window open
    -- after the step at which it is seen (before + after + 1 elements total).
    width :: Int
    width = before + after

    -- The predicate is tested on the leading cursor @xs@ (at index i),
    -- while output is taken from the trailing cursor @ys@. The trailing
    -- cursor only starts advancing once i reaches @before@, so it points
    -- at index i - before. A match at i therefore first emits the element
    -- @before@ positions earlier, and @m@ keeps the window open for the
    -- following @width@ steps.
    --
    -- i:  index of the head of xs
    -- m:  remaining steps in the current match window
    -- xs: list to match
    -- ys: offset list to generate results from
    go i m (x:xs) yys@(y:ys) =
      if (m > 0 || px) && i >= before then y : rest else rest
      where
        rest = go (i + 1) m' xs ys'
        px   = p x
        m'   = max (m - 1) (if px then width else 0)
        ys'  = if i >= before then ys else yys

    -- No more elements to test, so emit what is left of the window from the
    -- trailing cursor. If the input was shorter than @before@, @ys@ never
    -- advanced: it still points at index 0 rather than at the (negative)
    -- index i - before. @m@ still counts those (before - i) nonexistent
    -- positions, so discount them. Otherwise elements past the end of the
    -- window would be emitted.
    go i m _ ys = take (m - max 0 (before - i)) ys