-- {-# LANGUAGE BangPatterns #-}

-- |
-- Copyright   : (c) Johannes Kropp
-- License     : BSD 3-Clause
-- Maintainer  : Johannes Kropp <jodak932@gmail.com>

module Math.Nuha.Internal where

-- import Debug.Trace
import Data.Vector.Unboxed (Unbox)
import qualified Data.Vector.Unboxed as V



-- | Cartesian product of a list of index lists.
--
-- The first list varies slowest and the last list fastest, e.g.
-- @cartProd [[0,1],[0,1,2]] == [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2]]@.
-- If any inner list is empty the result is empty.
--
-- Calls 'error' on an empty outer list; callers are expected to always pass
-- at least one list.
cartProd :: [[Int]] -> [[Int]]
{-# INLINE cartProd #-}
cartProd [] = error "cartProd : Should not happen, empty mIdcs"
cartProd mIdcs = foldr prependEach [[]] mIdcs
    where
        -- Prefix every tuple built from the remaining lists with every index
        -- of the current list. 'rest' is bound once and therefore shared,
        -- instead of being rebuilt for each index of the current list.
        prependEach :: [Int] -> [[Int]] -> [[Int]]
        prependEach axis rest = [i : tuple | i <- axis, tuple <- rest]

-- | Unsafe conversion of a 1-D index (into the holor values) to a multiindex
-- of the holor.
--
-- The strides are consumed left to right. Each component is the quotient of
-- the remaining offset by that stride, and the remainder is carried on to the
-- next stride; the final remainder is dropped. For the strides produced by
-- 'fromShapeToStrides' and an in-range index this undoes
-- 'fromMultiIndexToIndex', e.g. @fromIndexToMultiIndex [12,4,1] 23 == [1,2,3]@.
--
-- No bounds checking is performed, and a zero stride makes 'divMod' throw.
-- An empty stride list yields the empty multiindex.
fromIndexToMultiIndex :: [Int] -> Int -> [Int]
{-# INLINE fromIndexToMultiIndex #-}
fromIndexToMultiIndex strides idx = go idx strides
    where
        -- Walk the stride list directly instead of indexing it with '!!',
        -- which made the conversion quadratic in the number of axes.
        go :: Int -> [Int] -> [Int]
        go _ [] = []
        go offset (stride : rest) =
            let (q, r) = offset `divMod` stride
            in q : go r rest

-- | Unsafe conversion of a multiindex of the holor to a 1-D index (into the
-- holor values): the dot product of the strides and the multiindex, e.g.
-- @fromMultiIndexToIndex [12,4,1] [1,2,3] == 23@.
--
-- No bounds checking is performed. If the two lists differ in length, the
-- longer one is silently truncated (as with 'zipWith').
fromMultiIndexToIndex
    :: [Int] -- ^ strides
    -> [Int] -- ^ multiindex
    -> Int -- ^ index
{-# INLINE fromMultiIndexToIndex #-}
fromMultiIndexToIndex strides mIdx = sum (zipWith (*) strides mIdx)

-- | Calculates the holor strides from the holor shape.
--
-- The stride of an axis is the product of the extents of all later axes, so
-- the last axis has stride 1 (row-major order), e.g.
-- @fromShapeToStrides [2,3,4] == [12,4,1]@. An empty shape gives no strides.
fromShapeToStrides :: [Int] -> [Int]
{-# INLINE fromShapeToStrides #-}
-- 'scanr' builds all suffix products in one pass (the full product comes
-- first); 'drop 1' discards that full product. 'drop 1' rather than 'tail'
-- avoids a partial function, even though 'scanr' never returns [].
fromShapeToStrides shape = drop 1 (scanr (*) 1 shape)

-- | Calculates all possible multiindices for a holor from its shape.
--
-- Each axis of extent @s@ contributes the range @[0 .. s-1]@, and the ranges
-- are combined with 'cartProd', so the last axis varies fastest, e.g.
-- @fromShapeToMultiIndices [2,2] == [[0,0],[0,1],[1,0],[1,1]]@.
-- An axis of extent 0 makes the result empty. An empty shape is passed on to
-- 'cartProd', which calls 'error' for it.
fromShapeToMultiIndices :: [Int] -> [[Int]]
{-# INLINE fromShapeToMultiIndices #-}
fromShapeToMultiIndices shape = cartProd [[0 .. s - 1] | s <- shape]

-- | Tests whether @idx@ is a valid index for something of length @len@,
-- i.e. whether @0 <= idx < len@.
isValidIdx
    :: Int -- ^ length
    -> Int -- ^ idx
    -> Bool
{-# INLINE isValidIdx #-}
isValidIdx len idx = idx < len && idx >= 0

-- | Tests whether a multiindex is valid for a holor of the given shape.
--
-- The multiindex must have exactly one component per axis, and each
-- component @k@ must satisfy @0 <= k < extent@ for its axis.
isValidMIdx
    :: [Int] -- ^ shape
    -> [Int] -- ^ multiindex
    -> Bool
{-# INLINE isValidMIdx #-}
isValidMIdx shp mIdx =
    length shp == length mIdx && and (zipWith inRange shp mIdx)
    where
        -- Compare directly rather than testing @extent - k > 0@; the
        -- subtraction can wrap around on 'Int' overflow and accept an
        -- out-of-range component.
        inRange :: Int -> Int -> Bool
        inRange extent k = 0 <= k && k < extent

-- | Tests whether a list of per-axis index lists is valid for a holor of the
-- given shape.
--
-- There must be exactly one index list per axis. Every index @k@ in the list
-- for an axis of extent @s@ must satisfy @0 <= k < s@. An empty index list for
-- an axis is accepted.
isValidMIdcs
    :: [Int] -- ^ shape
    -> [[Int]] -- ^ one list of indices per axis
    -> Bool
{-# INLINE isValidMIdcs #-}
isValidMIdcs shp mIdcs =
    length mIdcs == length shp && and (zipWith axisOk shp mIdcs)
    where
        -- Walk the shape and the index lists together instead of indexing
        -- both with '!!', which made the check quadratic in the number of axes.
        axisOk :: Int -> [Int] -> Bool
        axisOk extent ks = all (\k -> 0 <= k && k < extent) ks