Coverage for pyhiperta/cleaning.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-07 20:49 +0000

1"""Implements the charge images cleaning algorithms: tail-cut cleaning.""" 

2 

3import numpy as np 

4 

5from pyhiperta.utils.convolve import convolve_view 

6from pyhiperta.waveform_indexing import neighbors_only_stencil 

7 

8 

def tail_cuts_cleaning(
    waveforms_2D: np.ndarray,
    pixel_threshold: float,
    neighbors_threshold: float,
    min_number_neighbors: int,
) -> np.ndarray:
    """Compute the mask of "signal" pixels that pass the tail-cuts thresholds.

    The implementation is in 2 steps:

    - condition 1 ("core" pixels): pixels with a value greater or equal than `pixel_threshold` that also have at
      least `min_number_neighbors` neighbors with a value greater or equal than `pixel_threshold`.
    - condition 2 ("boundary" pixels): pixels with a value greater or equal than `neighbors_threshold` that have
      at least 1 neighbor passing condition 1.

    Only positions selected by the neighbors stencil and lying inside the image are counted as neighbors. This
    holds for any threshold value, including non-positive ones: positions outside the image (zero padding) and
    positions excluded by the stencil are never counted.

    Parameters
    ----------
    waveforms_2D : np.ndarray
        Batch or integrated waveform in 2D format. Shape: ([N_batch,] N_pixels_x, N_pixels_y). Any number of
        leading batch dimensions is accepted; the last 2 axes are the image axes.
    pixel_threshold : float
        A pixel with a value greater or equal than `pixel_threshold` and at least `min_number_neighbors` neighbors
        that have a value greater or equal than `pixel_threshold` is considered "signal".
    neighbors_threshold : float
        A pixel with a value greater or equal than `neighbors_threshold` and at least 1 neighbor that is considered
        signal according to `pixel_threshold` will be considered "signal" as well.
    min_number_neighbors : int
        Minimum number of neighboring pixels that must have a value above `pixel_threshold` for a pixel above
        `pixel_threshold` to be considered "signal".

    Returns
    -------
    np.ndarray
        A boolean mask with the same shape as `waveforms_2D`, True for "signal" pixels and False otherwise.

    Raises
    ------
    ValueError
        If the shape of waveforms_2D can not be interpreted as a (batch of) 2D waveforms
    """
    if len(waveforms_2D.shape) < 2:
        raise ValueError(
            "waveforms must be an array with at least 2 dimensions "
            "(waveform 2D or batch of waveform 2D), but got {}".format(waveforms_2D.shape)
        )

    nb_batch_dimension = len(waveforms_2D.shape) - 2

    neighbors_stencil = neighbors_only_stencil()
    # Add one leading singleton axis per batch dimension so the 2D stencil broadcasts over the batch axes.
    # (reshape is used instead of `stencil[*newaxes, ...]`, which is a SyntaxError before Python 3.11)
    neighbors_stencil = neighbors_stencil.reshape((1,) * nb_batch_dimension + neighbors_stencil.shape)
    stencil_shape = neighbors_stencil.shape
    # Boolean mask of the stencil positions that are neighbors.
    # NOTE(review): any non-zero stencil entry is treated as a neighbor; assumes the stencil is binary.
    is_neighbor = neighbors_stencil != 0

    # Axes to reduce when collapsing the convolution view. For waveforms_2D.shape = (3, 55, 55), the stencil
    # has shape (1, 3, 3) and the view has shape (3, 55, 55, 1, 3, 3), so the reduction is over axes -1, -2, -3.
    convolution_reduction_axis = tuple(-i - 1 for i in range(len(stencil_shape)))

    # Pad only the 2 image axes, with one element on each side, so the convolution keeps the image shape.
    # The batch axes are not padded.
    pad_values = [(0, 0)] * nb_batch_dimension + [(1, 1), (1, 1)]

    # Condition 1: above `pixel_threshold` with at least `min_number_neighbors` neighbors above it as well.
    # The threshold is applied BEFORE padding and masking by the stencil. Padded positions (False) and
    # non-neighbor stencil positions therefore never count, even when `pixel_threshold <= 0`. Multiplying
    # values by the stencil first would turn both into 0, which passes a non-positive threshold.
    above_pixel_threshold = waveforms_2D >= pixel_threshold
    padded_above = np.pad(above_pixel_threshold, pad_values, mode="constant", constant_values=False)
    nb_neighbors_above = (convolve_view(padded_above, stencil_shape) & is_neighbor).sum(
        axis=convolution_reduction_axis
    )
    mask = above_pixel_threshold & (nb_neighbors_above >= min_number_neighbors)

    # Condition 2: above `neighbors_threshold` with at least 1 neighbor passing condition 1.
    padded_mask = np.pad(mask, pad_values, mode="constant", constant_values=False)
    has_condition_1_neighbor = (convolve_view(padded_mask, stencil_shape) & is_neighbor).any(
        axis=convolution_reduction_axis
    )
    mask |= (waveforms_2D >= neighbors_threshold) & has_condition_1_neighbor
    return mask