module Data.Array.Repa.Algorithms.Convolve
( convolve
, GetOut
, outAs
, outClamp
, convolveOut )
where
import Data.Array.Repa as A
import qualified Data.Vector.Unboxed as V
import qualified Data.Array.Repa.Shape as S
import Prelude as P
-- | Image-kernel convolution on manifest arrays.
--
-- Output pixels whose kernel window would extend past the edge of the image
-- (those within half a kernel height/width of the border) are not convolved;
-- their value is @makeOut ix@ for the output index @ix@ instead.
--
-- Both arrays must be manifest, i.e. a single 'RangeAll' region backed by an
-- unboxed vector (e.g. the result of 'force'); any other representation is
-- rejected with 'error'.
convolve
  :: (Elt a, Num a)
  => (DIM2 -> a)      -- ^ value for border pixels, given the output index
  -> Array DIM2 a     -- ^ kernel
  -> Array DIM2 a     -- ^ image
  -> Array DIM2 a
convolve makeOut
  kernel@(Array (_ :. krnHeight :. krnWidth) [Region RangeAll (GenManifest krnVec)])
  image@(Array imgSh@(_ :. imgHeight :. imgWidth) [Region RangeAll (GenManifest imgVec)])
  = kernel `deepSeqArray` image `deepSeqArray`
    force $ unsafeTraverse image id update
  where
    -- Half-extents of the kernel: the window for output (j, i) starts at
    -- image row j - krnHeight2 and column i - krnWidth2.
    !krnHeight2 = krnHeight `div` 2
    !krnWidth2  = krnWidth  `div` 2

    -- Output indices outside [borderUp, borderDown] x [borderLeft, borderRight]
    -- would need image data beyond the edge, so they are handed to makeOut.
    !borderLeft  = krnWidth2
    !borderRight = imgWidth  - krnWidth2  - 1
    !borderUp    = krnHeight2
    !borderDown  = imgHeight - krnHeight2 - 1

    update _ ix@(_ :. j :. i)
      | i < borderLeft  = makeOut ix
      | i > borderRight = makeOut ix
      | j < borderUp    = makeOut ix
      | j > borderDown  = makeOut ix
      | otherwise       = stencil j i

    -- Dot product of the kernel with the image window whose top-left corner
    -- is (j - krnHeight2, i - krnWidth2), walking both vectors row-major.
    stencil j i
      = let imgStart = S.toIndex imgSh (Z :. j - krnHeight2 :. i - krnWidth2)
        in  integrate 0 0 0 imgStart 0

    -- acc: running sum; (x, y): position inside the kernel;
    -- imgCur / krnCur: linear offsets into the image / kernel vectors.
    integrate !acc !x !y !imgCur !krnCur
      | y >= krnHeight
      = acc
      -- End of a kernel row: skip the remainder of the image row so imgCur
      -- lands on the first window element of the next row.
      | x >= krnWidth
      = integrate acc 0 (y + 1) (imgCur + imgWidth - krnWidth) krnCur
      | otherwise
      = let imgZ = imgVec `V.unsafeIndex` imgCur
            krnZ = krnVec `V.unsafeIndex` krnCur
            here = imgZ * krnZ
        in  integrate (acc + here) (x + 1) y (imgCur + 1) (krnCur + 1)

convolve _ _ _
  = error $ "Data.Array.Repa.Algorithms.Convolve.convolve: "
         ++ "kernel and image must both be manifest arrays (see 'force')"
-- | Supplies image elements for indices that 'convolveOut' does not read
-- directly (indices near or beyond the image edge).
--
-- Arguments, in order: a function that reads an element of the image, the
-- extent of the image, and the requested index, which may lie outside that
-- extent.
type GetOut a = (DIM2 -> a) -> DIM2 -> DIM2 -> a
-- | A 'GetOut' that ignores the image, its extent and the requested index,
-- and always yields the given constant.
outAs :: a -> GetOut a
outAs x = \_get _extent _ix -> x
-- | A 'GetOut' that clamps the requested index into the image extent and
-- reads the element there. Out-of-range reads therefore return the nearest
-- edge pixel.
outClamp :: GetOut a
outClamp get (_ :. yLen :. xLen) (sh :. j :. i)
  = clampX j i
  where
    -- First clamp the column into [0, xLen - 1] ...
    clampX !y !x
      | x < 0     = clampY y 0
      | x >= xLen = clampY y (xLen - 1)
      | otherwise = clampY y x

    -- ... then clamp the row into [0, yLen - 1] and read the element.
    clampY !y !x
      | y < 0     = get (sh :. 0 :. x)
      | y >= yLen = get (sh :. (yLen - 1) :. x)
      | otherwise = get (sh :. y :. x)
-- | Image-kernel convolution in which every output pixel is computed.
--
-- When the kernel window for an output pixel extends past the edge of the
-- image, the missing elements are supplied by the 'GetOut' function (for
-- example 'outClamp', or @'outAs' 0@). Every in-range element is read from
-- the image itself.
convolveOut
  :: (Elt a, Num a)
  => GetOut a         -- ^ how to obtain out-of-range image elements
  -> Array DIM2 a     -- ^ kernel
  -> Array DIM2 a     -- ^ image
  -> Array DIM2 a
convolveOut getOut
  kernel@(Array krnSh@(_ :. krnHeight :. krnWidth) _)
  image@(Array imgSh@(_ :. imgHeight :. imgWidth) _)
  = kernel `deepSeqArray` image `deepSeqArray`
    force $ unsafeTraverse image id stencil
  where
    -- Half-extents of the kernel and its total element count.
    !krnHeight2 = krnHeight `div` 2
    !krnWidth2  = krnWidth  `div` 2
    !krnSize    = S.size krnSh

    stencil get (_ :. j :. i)
      = let
          -- Read an image element, deferring to getOut only when the index
          -- really lies outside the image extent. Testing against the
          -- kernel-sized border instead would route valid pixels near the
          -- edge through getOut as well (wrong for e.g. outAs).
          get' ix@(_ :. y :. x)
            | x < 0 || x >= imgWidth  = getOut get imgSh ix
            | y < 0 || y >= imgHeight = getOut get imgSh ix
            | otherwise               = get ix

          -- Image position of the kernel's (0, 0) element for output (j, i).
          !ikrnWidth'  = i - krnWidth2
          !jkrnHeight' = j - krnHeight2

          -- Sum kernel * image over every kernel element, visiting kernel
          -- indices in linear order via S.fromIndex.
          integrate !count !acc
            | count == krnSize = acc
            | otherwise
            = let !ix@(sh :. y :. x) = S.fromIndex krnSh count
                  !ix'               = sh :. y + jkrnHeight' :. x + ikrnWidth'
                  !here              = kernel `unsafeIndex` ix * get' ix'
              in  integrate (count + 1) (acc + here)
        in integrate 0 0