{-# LANGUAGE BangPatterns
           , FlexibleContexts
           , GADTs #-}

module Vision.Image.Threshold (
    -- * Simple threshold
      ThresholdType (..), thresholdType
    , threshold
    -- * Adaptive threshold
    , AdaptiveThresholdKernel (..), AdaptiveThreshold
    , adaptiveThreshold, adaptiveThresholdFilter
    -- * Other methods
    , otsu, scw
    ) where

import Data.Int
import Foreign.Storable (Storable)

import qualified Data.Vector.Storable as V
import qualified Data.Vector as VU

import Vision.Image.Class (
      Image, ImagePixel, FromFunction (..), FunctorImage, (!), shape
    )
import Vision.Image.Filter.Internal (
      Filter (..), BoxFilter, Kernel (..), SeparableFilter, SeparatelyFiltrable
    , KernelAnchor (KernelAnchorCenter), FilterFold (..)
    , BorderInterpolate (BorderReplicate)
    , apply, blur, gaussianBlur, Mean, mean
    )
import Vision.Image.Type (Manifest, delayed, manifest)
import Vision.Histogram (
      HistogramShape, PixelValueSpace, ToHistogram, histogram
    )
import Vision.Primitive (Z (..), (:.) (..), Size, shapeLength)

import qualified Vision.Histogram as H
import qualified Vision.Image.Class as I

-- | Specifies what to do with pixels matching the threshold predicate.
--
-- @'BinaryThreshold' a b@ will replace matching pixels by @a@ and non-matching
-- pixels by @b@.
--
-- @'Truncate' a@ will replace matching pixels by @a@.
--
-- @'TruncateInv' a@ will replace non-matching pixels by @a@.
data ThresholdType src res where
    -- | Matching pixels become the first value, the others the second one.
    BinaryThreshold :: res -> res -> ThresholdType src res

    -- | Matching pixels become the given value, the others are kept.
    Truncate        :: src        -> ThresholdType src src

    -- | Matching pixels are kept, the others become the given value.
    TruncateInv     :: src        -> ThresholdType src src

-- | Given the thresholding method, a boolean telling whether the pixel matches
-- the thresholding condition, and the pixel itself, returns the new pixel
-- value.
thresholdType :: ThresholdType src res -> Bool -> src -> res
thresholdType thresType match pix =
    case thresType of
        -- Both outcomes are provided by the policy; the pixel is ignored.
        BinaryThreshold ifTrue ifFalse ->
            if match then ifTrue else ifFalse
        -- Only matching pixels are replaced.
        Truncate ifTrue ->
            if match then ifTrue else pix
        -- Only non-matching pixels are replaced.
        TruncateInv ifFalse ->
            if match then pix else ifFalse
{-# INLINE thresholdType #-}

-- -----------------------------------------------------------------------------

-- | Applies the given predicate and threshold policy on the image.
--
-- Each pixel is evaluated with the predicate and then rewritten by
-- 'thresholdType' according to the given policy.
threshold :: FunctorImage src res
          => (ImagePixel src -> Bool)
          -> ThresholdType (ImagePixel src) (ImagePixel res) -> src -> res
threshold !cond !thresType = I.map step
  where
    step pix = thresholdType thresType (cond pix) pix
{-# INLINE threshold #-}

-- -----------------------------------------------------------------------------

-- | Defines how pixels of the kernel of the adaptive threshold will be
-- weighted.
--
-- With 'MeanKernel', pixels of the kernel have the same weight.
--
-- With @'GaussianKernel' sigma@, pixels are weighted according to their
-- distance from the thresholded pixel using a Gaussian function parameterized
-- by @sigma@. See 'gaussianBlur' for details.
data AdaptiveThresholdKernel acc where
    -- | Uniform weights (the kernel is built with 'blur').
    MeanKernel     :: Integral acc
                   => AdaptiveThresholdKernel acc

    -- | Gaussian weights (the kernel is built with 'gaussianBlur'), with an
    -- optional sigma.
    GaussianKernel :: (Floating acc, RealFrac acc)
                   => Maybe acc -> AdaptiveThresholdKernel acc

-- | Compares every pixel to its surrounding ones in the kernel of the given
-- radius.
--
-- Convenience wrapper which builds the filter with 'adaptiveThresholdFilter'
-- and immediately applies it to the image.
adaptiveThreshold :: ( Image src, Integral (ImagePixel src)
                     , Ord (ImagePixel src)
                     , FromFunction res, Integral (FromFunctionPixel res)
                     , Storable acc
                     , SeparatelyFiltrable src res acc)
                  => AdaptiveThresholdKernel acc
                  -> Int
                  -- ^ Kernel radius.
                  -> ImagePixel src
                  -- ^ Minimum difference between the pixel and the kernel
                  -- average. The pixel is thresholded if
                  -- @pixel_value - kernel_mean > difference@ where difference
                  -- is this number. Can be negative.
                  -> ThresholdType (ImagePixel src) (FromFunctionPixel res)
                  -> src
                  -> res
adaptiveThreshold kernelType radius thres thresType img =
    apply filt img
  where
    filt = adaptiveThresholdFilter kernelType radius thres thresType
{-# INLINABLE adaptiveThreshold #-}

-- | Type of the filters built by 'adaptiveThresholdFilter': a
-- 'SeparableFilter' over @src@ pixels whose kernel initial value is @()@,
-- accumulating in @acc@ and producing @res@ pixels.
type AdaptiveThreshold src acc res = SeparableFilter src () acc res

-- | Creates an adaptive thresholding filter to be used with 'apply'.
--
-- Use 'adaptiveThreshold' if you only want to apply the filter on the image.
--
-- Compares every pixel to its surrounding ones in the kernel of the given
-- radius. The kernel average is computed by 'blur' ('MeanKernel') or
-- 'gaussianBlur' ('GaussianKernel'); the pixel then matches when
-- @pixel_value - kernel_average > difference@, the difference being computed
-- without overflow whatever the pixel type.
adaptiveThresholdFilter :: (Integral src, Ord src, Storable acc)
                        => AdaptiveThresholdKernel acc
                        -> Int
                        -- ^ Kernel radius.
                        -> src
                        -- ^ Minimum difference between the pixel and the kernel
                        -- average. The pixel is thresholded if
                        -- @pixel_value - kernel_mean > difference@ where
                        -- difference is this number. Can be negative.
                        -> ThresholdType src res
                        -> AdaptiveThreshold src acc res
adaptiveThresholdFilter !kernelType !radius !thres !thresType =
    kernelFilter { fPost = post }
  where
    -- Blurring filter whose own post-processing step yields the kernel
    -- average; that step is reused by 'post' below.
    !kernelFilter =
        case kernelType of
            MeanKernel         -> blur         radius
            GaussianKernel sig -> gaussianBlur radius sig

    !thres' = toInteger thres

    post ix pix ini acc =
        let -- 'asTypeOf' keeps the kernel average in the pixel type, as the
            -- blur filter is expected to produce.
            !avg  = (fPost kernelFilter) ix pix ini acc `asTypeOf` pix
            -- The difference is computed on 'Integer's: in the pixel type it
            -- wraps around for unsigned pixels (e.g. 'Word8') whenever the
            -- pixel is below the kernel average, making such pixels wrongly
            -- match (and can overflow for extreme signed values).
            !cond = toInteger pix - toInteger avg > thres'
        in thresholdType thresType cond pix
{-# INLINE adaptiveThresholdFilter #-}

-- -----------------------------------------------------------------------------

-- | Applies a clustering-based image thresholding using Otsu's method.
--
-- The threshold @t@ is the pixel value (in @[0..255]@) maximising
-- @wB * wF * (muB - muF)^2@, where @wB@/@muB@ are the number and the mean of
-- the pixels @<= t@ and @wF@/@muF@ those of the pixels @> t@. Pixels @<= t@
-- match the threshold and are processed according to the given
-- 'ThresholdType'. On ties, the largest @t@ is chosen.
--
-- See <https://en.wikipedia.org/wiki/Otsu's_method>.
otsu :: ( HistogramShape (PixelValueSpace (ImagePixel src))
        , ToHistogram (ImagePixel src), FunctorImage src res
        , Ord (ImagePixel src), Num (ImagePixel src), Enum (ImagePixel src))
     => ThresholdType (ImagePixel src) (ImagePixel res) -> src -> res
otsu !thresType !img =
    threshold (<= level) thresType img
  where
    -- Strict so that the histogram analysis is done once, up-front.
    !level = bestLevel (H.vector (histogram Nothing img))

    -- Total number of pixels of the image.
    nPixels = shapeLength (I.shape img)

    -- Takes the per-value pixel counts and returns the value maximising the
    -- between-class criterion described above.
    bestLevel counts =
        snd $ VU.maximum $ VU.zip (VU.convert variances) (VU.fromList [0..255])
      where
        weighted  = V.zipWith (*) counts (V.fromList [0..255])
        totalSum  = double (V.sum (V.drop 1 weighted))
        backCount = V.scanl1 (+) counts
        foreCount = V.map (nPixels -) backCount
        backSum   = V.scanl1 (+) weighted
        backMean  = V.zipWith backAvg backSum backCount
        foreMean  = V.zipWith foreAvg backSum foreCount
        variances = V.zipWith4 betweenVar backCount foreCount backMean foreMean

        -- An empty class gets a mean of 1; its zero weight cancels the term.
        backAvg s n | n == 0    = 1
                    | otherwise = double s / double n
        foreAvg s n | n == 0    = 1
                    | otherwise = (totalSum - double s) / double n

        betweenVar nB nF muB muF =
            double nB * double nF * (muB - muF) ^ (2 :: Int)
{-# INLINABLE otsu #-}

-- -----------------------------------------------------------------------------

-- | This is a sliding concentric window filter (SCW) that uses the ratio of the
-- standard deviations of two sliding windows centered on a same point to detect
-- regions of interest (ROI).
--
-- > scw sizeWindowA sizeWindowB beta thresType img
--
-- Let @σA@ be the standard deviation of a first window around a pixel and @σB@
-- be the standard deviation of another window around the same pixel.
-- Then the pixel will match the threshold if @σB / σA >= beta@, and will be
-- thresholded according to the given 'ThresholdType'.
--
-- See <http://www.academypublisher.com/jcp/vol04/no08/jcp0408771777.pdf>.
scw :: ( Image src, Integral (ImagePixel src), FromFunction dst
       , Floating stdev, Fractional stdev, Ord stdev, Storable stdev)
    => Size -> Size -> stdev
    -> ThresholdType (ImagePixel src) (FromFunctionPixel dst) -> src -> dst
scw !sizeA !sizeB !beta !thresType !img =
    betaThreshold (stdDev sizeA) (stdDev sizeB)
  where
    betaThreshold a b =
        fromFunction (shape img) $ \pt ->
            -- Matches when σB / σA >= beta, as documented above. The ratio is
            -- a plain division: when σA is 0 it yields whatever the @stdev@
            -- type's division gives (for IEEE floats, Infinity or NaN, and a
            -- NaN ratio never matches).
            let !cond = (b ! pt) / (a ! pt) >= beta
            in thresholdType thresType cond (img ! pt)

    stdDev size =
       -- NOTE: 'mean' presumably accumulates the sum of the window in its
       -- @acc@ type; the former 'Int16' overflowed as soon as a window held
       -- more than 128 pixels of 8-bit values, so a 64-bit accumulator is
       -- used instead.
       let filt :: (Integral src, Fractional res) => Mean src Int64 res
           filt     = mean size
           !meanImg = manifest $ apply filt img
           !varImg  = manifest $ apply (variance size meanImg) img
       in delayed $ I.map sqrt varImg
{-# INLINABLE scw #-}

-- | Given a mean image and an original image, computes the variance of the
-- kernel of the given size.
--
-- For every pixel, the kernel is seeded with the value of the mean image at
-- that pixel, accumulates the squared differences between the original kernel
-- pixels and that mean, and divides the total by the number of kernel pixels:
--
-- @average [ (origPix - mean)^2 | origPix <- kernel pixels on original ]@.
--
-- Borders are handled by replicating the edge pixels.
variance :: (Integral src, Fractional res, Storable res)
         => Size -> Manifest res -> BoxFilter src res res res
variance !size@(Z :. h :. w) !meanImg =
    Filter size KernelAnchorCenter (Kernel accumulate) initial
           (FilterFold (const 0)) normalize BorderReplicate
  where
    -- Mean of the kernel centered on the given point.
    initial pt _ = meanImg ! pt

    accumulate !mu _ !pix !total = total + square (fromIntegral pix - mu)

    -- Reciprocal of the kernel area, computed once.
    !invArea = 1 / (fromIntegral $! h * w)

    normalize _ _ _ !total = total * invArea
{-# INLINABLE variance #-}

-- -----------------------------------------------------------------------------

-- | Multiplies a number by itself.
square :: Num a => a -> a
square x = x * x

-- | Converts any integral value to a 'Double' (shorthand for 'fromIntegral'
-- with a fixed result type).
double :: Integral a => a -> Double
double n = fromIntegral n