{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE NoStarIsType #-}
module Data.Array.Shaped.Convolve(convolve) where
import Data.Array.Shaped
import Data.Array.Shaped.MatMul
import Data.Array.Internal.Shape
import GHC.TypeLits
import qualified Numeric.LinearAlgebra as N
-- | Convolve an input array with a kernel over the first @n@ dimensions of
-- the input.
--
-- The convolution is done by:
--
-- 1. taking all sliding windows of the input;
-- 2. flattening them into a matrix with one row per window position;
-- 3. flattening the kernel into a matrix;
-- 4. doing a single matrix multiplication;
-- 5. reshaping the product to the output shape.
--
-- The number of convolved dimensions @n@ does not occur in the argument or
-- result types, so callers must supply it with a type application, e.g.
-- @convolve \@2@ (this is why @AllowAmbiguousTypes@ is enabled).
--
-- Shape conventions, writing @i = Rank ish@:
--
-- * the first @n@ kernel dimensions (@Take n ksh@) are the window size;
-- * the first @i@ kernel dimensions are contracted against one input
--   window, so their total size must equal the size of one window
--   (@iwc ~ ksc@);
-- * the remaining kernel dimensions (@Drop i ksh@) become the trailing
--   dimensions of the result;
-- * the result shape is @Take n wsh ++ Drop i ksh@, i.e. the window-position
--   dimensions of the windowed input followed by those trailing kernel
--   dimensions.
--
-- For example, with @n = 2@, a @[20,30,3]@ input and a @[5,5,3,8]@ kernel
-- give a @[16,26,8]@ result.
--
-- The kernel is used as-is (it is not reversed here). Assuming 'window'
-- enumerates window elements in increasing index order, this computes a
-- cross-correlation rather than a flipped-kernel convolution.
--
-- NOTE(review): the constraints only require the /sizes/ of
-- @Drop n wsh@ and @Take i ksh@ to agree. If the input's non-convolved
-- dimensions and the matching kernel dimensions have the same product but a
-- different shape, this still type-checks and pairs elements by their flat
-- position.
convolve
  :: forall (n :: Nat) ish ksh osh wsh a ksc ksf i ws isp iwc .
     ( i ~ Rank ish
     , ws ~ Take n ksh
     , Window ws ish wsh, KnownNat (Rank ws)
     , ksc ~ Size (Take i ksh)
     , ksf ~ Size (Drop i ksh)
     , isp ~ Size (Take n wsh)
     , iwc ~ Size (Drop n wsh)
     , iwc ~ ksc
     , osh ~ (Take n wsh ++ Drop i ksh)
     , Size wsh ~ (isp * iwc)
     , Size ksh ~ (ksc * ksf)
     , Size osh ~ (isp * ksf)
     , Shape wsh, Shape ksh, Shape osh
     , KnownNat ksc, KnownNat isp, KnownNat ksf
     , N.Numeric a
     )
  => Array ish a  -- ^ input array
  -> Array ksh a  -- ^ kernel
  -> Array osh a
convolve input kernel =
  let -- Every window of size @ws@ over the input.
      iw :: Array wsh a
      iw = window @ws input
      -- One row per window position (isp), one column per element of a
      -- window (iwc).
      ir :: Array '[isp, iwc] a
      ir = reshape iw
      -- Contracted kernel dimensions as rows (ksc, equal to iwc), the
      -- remaining kernel dimensions as columns (ksf).
      kr :: Array '[ksc, ksf] a
      kr = reshape kernel
      -- Every window dotted with every kernel column in one product.
      m :: Array '[isp, ksf] a
      m = matMul ir kr
      -- Unflatten back to window-position dims ++ trailing kernel dims.
      r :: Array osh a
      r = reshape m
  in  r
-- | A concrete, non-exported instantiation of 'convolve' with @n = 2@.
-- Presumably it exists only so the shape arithmetic is type-checked at
-- compile time.
--
-- * The input is @[20,30,3]@, so @i = Rank ish = 3@.
-- * The windows are @Take 2 [5,5,3,8] = [5,5]@.
-- * The window-position dimensions are @20 - 5 + 1 = 16@ and
--   @30 - 5 + 1 = 26@.
-- * The first three kernel dimensions (@5*5*3@) are contracted against
--   each window.
-- * The last kernel dimension, @8@, is kept.
--
-- The result is therefore @[16,26,8]@.
_example
  :: Array '[20, 30, 3] Float
  -> Array '[5, 5, 3, 8] Float
  -> Array '[16, 26, 8] Float
_example = convolve @2