Coverage for pyhiperta/utils/convolve.py: 100%
10 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-07 20:49 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-07 20:49 +0000
1from typing import Tuple
3import numpy as np
def convolve_view(a: np.ndarray, stencil_shape: Tuple[int, ...]) -> np.ndarray:
    """Return a read-only sliding-window view of `a` for a stencil of shape `stencil_shape`.

    The view has `len(stencil_shape)` extra trailing dimensions: for every position at which a
    stencil of shape `stencil_shape` fits entirely inside `a`, the trailing dimensions index the
    sub-part of `a` the stencil would operate on. No data of `a` is copied or duplicated.

    Directly taken from https://stackoverflow.com/questions/43086557/convolve2d-just-by-using-numpy

    Parameters
    ----------
    a : np.ndarray
        The array to get the view in.
    stencil_shape : Tuple[int, ...]
        The shape of the stencil that we would want to operate on `a`. It must have the same
        number of dimensions as `a`. Any sequence of integers (e.g. a list) is accepted at runtime.

    Returns
    -------
    np.ndarray
        Read-only view in `a`, of shape `(*(a.shape - stencil_shape + 1), *stencil_shape)`.
        The leading dimensions are those of `a`, reduced by the stencil positions that would
        fall outside of `a` at the boundaries. The trailing dimensions are `stencil_shape`.
        Example: a.shape=(55, 55), stencil_shape=(3, 3) gives a result of shape (53, 53, 3, 3).

    Raises
    ------
    ValueError
        If `stencil_shape` and `a.shape` do not have the same number of dimensions, or if any
        stencil dimension is not strictly positive or is larger than the matching dimension of `a`.

    Examples
    --------
    >>> a = np.arange(55*55).reshape((55, 55))
    >>> stencil = np.array([[0, 1], [1, 0]])
    >>> convolve_view(a, stencil.shape).shape
    (54, 54, 2, 2)
    >>> convolve_view(a, stencil.shape)[0, 0, :, :]  # top-left 2x2 part of `a`
    array([[ 0,  1],
           [55, 56]])
    >>> convolve_view(a, stencil.shape)[1, 1, :, :]
    array([[ 56,  57],
           [111, 112]])
    """
    # Normalise to a tuple: the view's shape is built by tuple concatenation below, which
    # would raise TypeError for a list even though a list passes the validation checks.
    stencil = tuple(stencil_shape)

    if len(a.shape) != len(stencil):
        raise ValueError(
            "Stencil shape {} and array shape {} must have the same number of dimensions".format(
                stencil_shape, a.shape
            )
        )
    if not all(0 < extent <= dim for extent, dim in zip(stencil, a.shape)):
        raise ValueError(
            "Stencil shape must be strictly positive and smaller or equal than a.shape in all dimensions. "
            "Got stencil shape {} and a's shape: {}".format(stencil_shape, a.shape)
        )

    # Number of valid stencil positions along each axis: a.shape - (stencil_shape - 1).
    # A stencil of extent k fits at (dim - k + 1) positions; e.g. a [3, 3] stencil shrinks each
    # axis by 2 (one element lost at each end).
    positions = tuple(dim - extent + 1 for dim, extent in zip(a.shape, stencil))
    convolve_view_shape = positions + stencil
    # Both the position axes and the stencil axes step through the same buffer, one element of
    # `a` at a time, so the strides of `a` are simply repeated.
    convolve_view_strides = a.strides + a.strides

    # writeable=False: window elements overlap in memory, so writing through the view is unsafe.
    return np.lib.stride_tricks.as_strided(
        a, shape=convolve_view_shape, strides=convolve_view_strides, writeable=False
    )