{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE DeriveTraversable #-}
module Data.Grid.Internal.Convolution
( autoConvolute
, convolute
, clampBounds
, wrapBounds
, omitBounds
, window
, Neighboring
) where
import Control.Comonad
import Control.Comonad.Representable.Store
import Data.Functor.Compose
import Data.Functor.Rep
import Data.Grid.Internal.Coord
import Data.Grid.Internal.Grid
import Data.Grid.Internal.Nest
import GHC.TypeNats
-- | Placeholder value for spots that are never expected to be evaluated.
--
-- Forcing it calls 'error' with a message asking the user to report the
-- problem to the maintainer.
criticalError :: a
criticalError = error reportMessage
 where
  -- Text shown to the user if the supposedly unreachable value is forced.
  reportMessage :: String
  reportMessage =
    "Something went wrong, please report this issue to the maintainer of grids"
-- | Convolute a grid using a neighbourhood whose shape is the type-level
-- @window@.
--
-- For each coordinate, 'window' builds a @Grid window@ of coordinates around
-- it. The supplied restriction turns that grid into some functor of
-- coordinates, and the result is used as the selection function for
-- 'convolute'.
autoConvolute
  :: forall window dims f a b
   . ( Dimensions dims
     , Dimensions window
     , Functor f
     , Neighboring window
     )
  => (Grid window (Coord dims) -> f (Coord dims))
  -- ^ Post-processes the raw window of coordinates (e.g. 'clampBounds',
  -- 'wrapBounds' or 'omitBounds')
  -> (f a -> b)
  -- ^ Collapses the values found at the selected coordinates
  -> Grid dims a
  -- ^ Input grid
  -> Grid dims b
autoConvolute restrict = convolute windowAround
 where
  -- Neighbourhood of a single coordinate, after the caller's restriction.
  windowAround :: Coord dims -> f (Coord dims)
  windowAround focus = restrict (window @window @dims focus)
-- | Rebuild a grid by computing each cell from a functor of cells selected
-- around it.
--
-- The grid is viewed as a representable 'Store' whose lookup function is
-- @'index' g@. The store is extended with a function that, at each position,
-- uses 'experiment' to fetch the values at the coordinates produced by the
-- selection function and collapses them with @f@. The lookup function of the
-- resulting store is then turned back into a grid with 'tabulate'.
convolute
  :: forall dims f a b
   . (Functor f, Dimensions dims)
  => (Coord dims -> f (Coord dims))
  -- ^ Chooses which coordinates to look at for a given position
  -> (f a -> b)
  -- ^ Collapses the values found at those coordinates
  -> Grid dims a
  -- ^ Input grid
  -> Grid dims b
convolute selectWindow f g = tabulate (fst (runStore convoluted))
 where
  -- The store's position is seeded with 'criticalError'. Only the lookup
  -- function of the final store is used (the position is discarded above),
  -- so the seed is expected never to be forced.
  seeded :: Store (Grid dims) a
  seeded = store (index g) criticalError

  convoluted :: Store (Grid dims) b
  convoluted = extend collapseAround seeded

  -- Value of the output cell at the store's current position.
  collapseAround :: Store (Grid dims) a -> b
  collapseAround focused = f (experiment neighbourhood focused)

  -- Selected coordinates, each passed through 'viaEnum'.
  neighbourhood :: Coord dims -> f (Coord dims)
  neighbourhood = fmap viaEnum . selectWindow

  -- Sends a coordinate through its 'Enum' instance and back.
  -- NOTE(review): presumably this normalises the coordinate; confirm against
  -- the 'Enum' instance of 'Coord'.
  viaEnum :: Coord dims -> Coord dims
  viaEnum = toEnum . fromEnum
-- | Build the @window@-shaped grid of coordinates surrounding a coordinate of
-- a @dims@-shaped grid.
--
-- The coordinate is re-tagged as a @window@ coordinate with
-- 'coerceCoordDims', its neighbours are computed with 'neighboring', and each
-- resulting coordinate is re-tagged back to @dims@.
window
  :: forall window dims
   . (Neighboring window, Dimensions window)
  => Coord dims
  -> Grid window (Coord dims)
window origin = coerceCoordDims <$> neighboring localOrigin
 where
  -- The same coordinate, viewed as belonging to the window's dimensions.
  localOrigin :: Coord window
  localOrigin = coerceCoordDims origin
-- | Shapes for which a grid of coordinate offsets can be produced.
--
-- 'neighboring' adds these offsets to a coordinate to obtain the coordinates
-- around it.
class Neighboring shape where
  -- | A @shape@-sized grid in which every cell holds a coordinate offset.
  neighborCoords :: Grid shape (Coord shape)
-- | One-dimensional case: the offsets are @0, 1, ..@ (as many as
-- @gridSize \@'[n]@), each shifted down by half of that size using integer
-- division.
instance {-# OVERLAPPING #-} KnownNat n => Neighboring '[n] where
  neighborCoords = fromList' (toOffset <$> take extent [0 ..])
   where
    -- Number of cells along the single dimension.
    extent = gridSize @'[n]
    -- Amount subtracted from every position.
    radius = extent `div` 2
    -- Wrap a shifted position as a single-component coordinate.
    toOffset position = Coord (pure (position - radius))
-- | Multi-dimensional case: pair every offset of the leading dimension with
-- every offset of the remaining dimensions, then flatten the nested grids
-- with 'joinGrid'.
instance (KnownNat n, Neighboring ns) => Neighboring (n : ns) where
  neighborCoords = joinGrid (extendWithInner <$> outerOffsets)
   where
    -- Offsets along the leading dimension only.
    outerOffsets :: Grid '[n] (Coord '[n])
    outerOffsets = neighborCoords

    -- Offsets for all of the remaining dimensions.
    innerOffsets :: Grid ns (Coord ns)
    innerOffsets = neighborCoords

    -- Prefix every inner offset with one leading-dimension offset.
    extendWithInner :: Coord '[n] -> Grid ns (Coord (n : ns))
    extendWithInner outer = fmap (appendC outer) innerOffsets
-- | Grid of coordinates around the given one, obtained by adding the
-- coordinate to every offset in 'neighborCoords'.
neighboring
  :: (Dimensions dims, Neighboring dims)
  => Coord dims
  -> Grid dims (Coord dims)
neighboring center = shift <$> neighborCoords
 where
  -- Keep the centre as the left operand of the addition.
  shift offset = center + offset
-- | Apply 'clampCoord' to every coordinate in the container.
--
-- NOTE(review): presumably 'clampCoord' moves out-of-range coordinates to the
-- nearest edge of the grid; confirm against its definition.
clampBounds
  :: (Dimensions dims, Functor f) => f (Coord dims) -> f (Coord dims)
clampBounds coords = clampCoord <$> coords
-- | Apply 'wrapCoord' to every coordinate in the container.
--
-- NOTE(review): presumably 'wrapCoord' wraps out-of-range coordinates around
-- to the opposite side of the grid; confirm against its definition.
wrapBounds
  :: (Dimensions dims, Functor f) => f (Coord dims) -> f (Coord dims)
wrapBounds coords = wrapCoord <$> coords
-- | Replace every coordinate with 'Just' itself when 'coordInBounds' accepts
-- it and with 'Nothing' otherwise, wrapping the result in 'Compose'.
omitBounds
  :: (Dimensions dims, Functor f)
  => f (Coord dims)
  -> Compose f Maybe (Coord dims)
omitBounds coords = Compose (keepIfInBounds <$> coords)
 where
  -- Keep a coordinate only when it passes the bounds check.
  keepIfInBounds c = if coordInBounds c then Just c else Nothing