@@ -94,6 +94,8 @@ class MeanOfAbsoluteDifference(BaseBackgroundMattingMetrics):
94
94
def update (self , annotation , prediction ):
95
95
pred = self .get_prediction (prediction )
96
96
gt = self .get_annotation (annotation )
97
+ if pred .shape [- 1 ] == 1 and pred .shape [- 1 ] != gt .shape [- 1 ]:
98
+ gt = cv2 .cvtColor (gt , cv2 .COLOR_RGB2GRAY )
97
99
value = np .mean (abs (pred - gt )) * 1e3
98
100
self .results .append (value )
99
101
return value
@@ -105,6 +107,8 @@ class SpatialGradient(BaseBackgroundMattingMetrics):
105
107
def update (self , annotation , prediction ):
106
108
pred = self .get_prediction (prediction )
107
109
gt = self .get_annotation (annotation )
110
+ if pred .shape [- 1 ] == 1 and pred .shape [- 1 ] != gt .shape [- 1 ]:
111
+ gt = cv2 .cvtColor (gt , cv2 .COLOR_RGB2GRAY )
108
112
gt_grad = self .gauss_gradient (gt )
109
113
pred_grad = self .gauss_gradient (pred )
110
114
value = np .sum ((gt_grad - pred_grad ) ** 2 ) / 1000
@@ -152,6 +156,8 @@ class MeanSquaredErrorWithMask(BaseBackgroundMattingMetrics):
152
156
def update (self , annotation , prediction ):
153
157
pred = self .get_prediction (prediction )
154
158
gt = self .get_annotation (annotation )
159
+ if pred .shape [- 1 ] == 1 and pred .shape [- 1 ] != gt .shape [- 1 ]:
160
+ gt = cv2 .cvtColor (gt , cv2 .COLOR_RGB2GRAY )
155
161
if self .use_mask :
156
162
mask = self .prepare_pha (annotation .value ) > 0
157
163
pred = pred [mask ]
0 commit comments