Skip to content

Commit 2029eef

Browse files
authored
[AC]: support single channel foreground in background matting (#3683)
1 parent b9719a4 commit 2029eef

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/metrics/background_matting.py

+6
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ class MeanOfAbsoluteDifference(BaseBackgroundMattingMetrics):
9494
def update(self, annotation, prediction):
9595
pred = self.get_prediction(prediction)
9696
gt = self.get_annotation(annotation)
97+
if pred.shape[-1] == 1 and pred.shape[-1] != gt.shape[-1]:
98+
gt = cv2.cvtColor(gt, cv2.COLOR_RGB2GRAY)
9799
value = np.mean(abs(pred - gt)) * 1e3
98100
self.results.append(value)
99101
return value
@@ -105,6 +107,8 @@ class SpatialGradient(BaseBackgroundMattingMetrics):
105107
def update(self, annotation, prediction):
106108
pred = self.get_prediction(prediction)
107109
gt = self.get_annotation(annotation)
110+
if pred.shape[-1] == 1 and pred.shape[-1] != gt.shape[-1]:
111+
gt = cv2.cvtColor(gt, cv2.COLOR_RGB2GRAY)
108112
gt_grad = self.gauss_gradient(gt)
109113
pred_grad = self.gauss_gradient(pred)
110114
value = np.sum((gt_grad - pred_grad) ** 2) / 1000
@@ -152,6 +156,8 @@ class MeanSquaredErrorWithMask(BaseBackgroundMattingMetrics):
152156
def update(self, annotation, prediction):
153157
pred = self.get_prediction(prediction)
154158
gt = self.get_annotation(annotation)
159+
if pred.shape[-1] == 1 and pred.shape[-1] != gt.shape[-1]:
160+
gt = cv2.cvtColor(gt, cv2.COLOR_RGB2GRAY)
155161
if self.use_mask:
156162
mask = self.prepare_pha(annotation.value) > 0
157163
pred = pred[mask]

0 commit comments

Comments
 (0)