module Math.Nuha.Internal where
import Data.Vector.Unboxed (Unbox)
import qualified Data.Vector.Unboxed as V
-- | Cartesian product of a list of index lists, in row-major order
-- (the last list varies fastest).
--
-- >>> cartProd [[0,1],[0,1,2]]
-- [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2]]
--
-- The product of zero lists is the single empty tuple @[[]]@, so a
-- rank-0 shape is handled instead of raising an error. If any inner
-- list is empty, the result is @[]@.
cartProd :: [[Int]] -> [[Int]]
{-# INLINE cartProd #-}
-- In the list monad, 'sequence' picks one element from each list,
-- leftmost list outermost, which is exactly the row-major product.
-- The tail product is shared rather than recomputed per head element.
cartProd = sequence
-- | Convert a flat (linear) index into a multi-index, given per-dimension
-- strides (e.g. the output of 'fromShapeToStrides').
--
-- Each component is the quotient of the running remainder by that
-- dimension's stride. The remainder is carried to the next dimension,
-- and the final remainder is discarded.
--
-- >>> fromIndexToMultiIndex [12,4,1] 23
-- [1,2,3]
--
-- An empty stride list yields an empty multi-index. A zero stride makes
-- 'divMod' throw a divide-by-zero 'ArithException'.
fromIndexToMultiIndex :: [Int] -> Int -> [Int]
{-# INLINE fromIndexToMultiIndex #-}
fromIndexToMultiIndex strides idx = go idx strides
  where
    -- One pass over the strides: no (!!) indexing and no repeated
    -- 'length', so the cost is linear in the rank.
    go :: Int -> [Int] -> [Int]
    go _ [] = []
    go rest (s : ss) =
      let (q, r) = rest `divMod` s
       in q : go r ss
-- | Convert a multi-index into a flat (linear) index: the dot product of
-- the strides with the multi-index components.
--
-- >>> fromMultiIndexToIndex [12,4,1] [1,2,3]
-- 23
--
-- No validation is performed. If the lists differ in length, the surplus
-- elements of the longer one are ignored, because 'zipWith' truncates.
fromMultiIndexToIndex
  :: [Int] -- ^ strides
  -> [Int] -- ^ multi-index
  -> Int
{-# INLINE fromMultiIndexToIndex #-}
fromMultiIndexToIndex strides mIdx = sum (zipWith (*) strides mIdx)
-- | Row-major strides for a shape. The stride of dimension @i@ is the
-- product of all dimension sizes after @i@, so the last stride is 1.
--
-- >>> fromShapeToStrides [2,3,4]
-- [12,4,1]
--
-- An empty shape gives an empty stride list.
fromShapeToStrides :: [Int] -> [Int]
{-# INLINE fromShapeToStrides #-}
-- 'scanr' yields every suffix product, @[product shape, ..., 1]@, in one
-- pass. Dropping the first (full) product leaves the strict-suffix
-- products. 'drop' is used rather than 'tail' to stay total.
fromShapeToStrides shape = drop 1 (scanr (*) 1 shape)
-- | All multi-indices of an array with the given shape, in the row-major
-- order produced by 'cartProd'.
--
-- >>> fromShapeToMultiIndices [2,2]
-- [[0,0],[0,1],[1,0],[1,1]]
--
-- A dimension of size 0 or less contributes an empty range, so the result
-- is then empty. An empty shape is passed straight to 'cartProd' as an
-- empty list.
fromShapeToMultiIndices :: [Int] -> [[Int]]
{-# INLINE fromShapeToMultiIndices #-}
fromShapeToMultiIndices shape = cartProd ranges
  where
    -- One index range @[0 .. s-1]@ per dimension.
    ranges :: [[Int]]
    ranges = [[0 .. s - 1] | s <- shape]
-- | Whether @idx@ is a valid index into a dimension of length @len@,
-- i.e. @0 <= idx < len@.
isValidIdx
  :: Int  -- ^ dimension length
  -> Int  -- ^ index to check
  -> Bool
{-# INLINE isValidIdx #-}
isValidIdx len idx = 0 <= idx && idx < len
-- | Whether @mIdx@ is a valid multi-index for shape @shp@. It must have
-- the same rank, and every component @m@ must satisfy @0 <= m < s@ for
-- the corresponding dimension size @s@.
isValidMIdx
  :: [Int]  -- ^ shape
  -> [Int]  -- ^ multi-index
  -> Bool
{-# INLINE isValidMIdx #-}
isValidMIdx shp mIdx =
  length shp == length mIdx
    && and (zipWith inRange shp mIdx)
  where
    -- Compare directly instead of testing @s - m > 0@. That subtraction
    -- can overflow 'Int' and accept e.g. @m = 1@ for @s = minBound@.
    inRange :: Int -> Int -> Bool
    inRange s m = 0 <= m && m < s
-- | Whether @mIdcs@ is a valid collection of per-dimension index lists for
-- shape @shp@. There must be exactly one list per dimension, and every
-- index @k@ in the list paired with a dimension of size @s@ must satisfy
-- @0 <= k < s@. Empty index lists are accepted.
isValidMIdcs
  :: [Int]    -- ^ shape
  -> [[Int]]  -- ^ one index list per dimension
  -> Bool
{-# INLINE isValidMIdcs #-}
isValidMIdcs shp mIdcs =
  length mIdcs == length shp
    && and (zipWith allInRange shp mIdcs)
  where
    -- Pair each dimension size with its index list directly instead of
    -- indexing both lists with (!!), which was quadratic in the rank.
    allInRange :: Int -> [Int] -> Bool
    allInRange s = all (\k -> 0 <= k && k < s)