Coverage for pyhiperta/utils/convolve.py: 100%

10 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-21 20:50 +0000

1from typing import Tuple 

2 

3import numpy as np 

4 

5 

def convolve_view(a: np.ndarray, stencil_shape: Tuple[int, ...]) -> np.ndarray:
    """Read-only view into `a` that has `stencil_shape` extra dimensions for each part of `a` that a stencil
    of shape `stencil_shape` would operate on.

    Directly taken from https://stackoverflow.com/questions/43086557/convolve2d-just-by-using-numpy
    It allows viewing the sub-parts of `a` of shape `stencil_shape` without copying or duplicating `a`'s data.

    Parameters
    ----------
    a : np.ndarray
        The array to get the view in. Any array-like is accepted; it is converted with `np.asarray`, which
        does not copy when `a` already is an ndarray.
    stencil_shape : Tuple[int, ...]
        The shape of the stencil that we would want to operate on `a`. It must have the same number of
        dimensions as `a`. Any sequence of ints (tuple, list, 1D integer array) is accepted.

    Returns
    -------
    np.ndarray
        Read-only view into `a` with `len(stencil_shape)` extra dimensions, one stencil-sized window per
        valid stencil position.
        Shape: (*(a.shape - stencil_shape + 1), *stencil_shape). The first dimensions count the positions
        where the whole stencil fits inside `a`; the last dimensions are `stencil_shape`.
        Example: a.shape=(55, 55), stencil_shape=(3, 3): result's shape is (53, 53, 3, 3).

    Examples
    --------
    >>> a = np.arange(55 * 55).reshape((55, 55))
    >>> stencil = np.array([[0, 1], [1, 0]])
    >>> convolve_view(a, stencil.shape).shape
    (54, 54, 2, 2)
    >>> convolve_view(a, stencil.shape)[0, 0].tolist()  # top-left 2x2 part of `a`
    [[0, 1], [55, 56]]
    >>> convolve_view(a, stencil.shape)[1, 1].tolist()
    [[56, 57], [111, 112]]

    Raises
    ------
    ValueError
        If `stencil_shape` and the array's shape are not compatible, or if the stencil shape is invalid.
    """
    # No-op for ndarrays; lets callers pass nested lists / other array-likes.
    a = np.asarray(a)

    if a.ndim != len(stencil_shape):
        raise ValueError(
            "Stencil shape {} and array shape {} must have the same number of dimensions".format(
                stencil_shape, a.shape
            )
        )
    if not all(0 < s <= dim for s, dim in zip(stencil_shape, a.shape)):
        raise ValueError(
            "Stencil shape must be strictly positive and smaller or equal than a.shape in all dimensions. "
            "Got stencil shape {} and a's shape: {}".format(stencil_shape, a.shape)
        )

    # Along each axis there are a.shape - stencil_shape + 1 positions where the whole stencil fits,
    # e.g. a (3, 3) stencil reduces each axis by 2.
    # `tuple(stencil_shape)` is required: concatenating a list would raise TypeError, and a 1D ndarray
    # would be broadcast-added to the tuple instead of appended.
    convolve_view_shape = tuple(np.subtract(a.shape, stencil_shape) + 1) + tuple(stencil_shape)
    # Moving one step along either a window-position axis or a within-window axis advances by the same
    # number of bytes as moving along the corresponding axis of `a`, hence the duplicated strides.
    convolve_view_strides = a.strides + a.strides

    return np.lib.stride_tricks.as_strided(
        a, shape=convolve_view_shape, strides=convolve_view_strides, writeable=False
    )