Converting Faster-RCNN from PyTorch to CoreML #2479

gizzleon opened this issue Apr 4, 2025 · 5 comments
Labels
bug Unexpected behaviour that should be corrected (type)

Comments

@gizzleon

gizzleon commented Apr 4, 2025

🐞Describing the bug

Hi, I am converting a PyTorch Faster R-CNN model to CoreML and encountered a data type mismatch issue, which may be related to #2440.

The model I'm converting is torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn_v2.

The first issue was the unsupported torchvision::roi_align operator. With the implementation from this PR, I was able to convert a single RoIAlign layer.

However, when converting the whole Faster R-CNN model, the second input variable rois has an unexpected shape (0, 1) and dtype int32, whereas it is supposed to be an (N, 5) float tensor.
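
For reference, torchvision's roi_align expects rois passed as a single stacked tensor to have shape (N, 5), where each row is [batch_index, x1, y1, x2, y2], for example:

rois = torch.tensor([
    [0.0, 0.0, 0.0, 10.0, 10.0],   # batch index 0, box from (0, 0) to (10, 10)
    [0.0, 5.0, 5.0, 20.0, 20.0],   # batch index 0, box from (5, 5) to (20, 20)
])  # shape (2, 5), dtype float32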

Stack Trace

ERROR - converting 'torchvision::roi_align' op (located at: 'network/roi_heads/box_roi_pool/result_idx_in_level.1'):

Converting PyTorch Frontend ==> MIL Ops:  81%|████████▏ | 1374/1686 [00:00<00:00, 6381.85 ops/s]
Traceback (most recent call last):
  File "./bug_report.py", line 104, in <module>
    convert_faster_rcnn_model()
  File "./bug_report.py", line 101, in convert_faster_rcnn_model
    ct.convert(traced_model, inputs=[ct.TensorType(name="Input", shape=input_.shape)])
  File "./venv/lib/python3.12/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
    mlmodel = mil_convert(
              ^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
                         ^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 88, in load
    return _perform_torch_convert(converter, debug)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 151, in _perform_torch_convert
    prog = converter.convert()
           ^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1387, in convert
    convert_nodes(self.context, self.graph, early_exit=not has_states)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 116, in convert_nodes
    raise e     # re-raise exception
    ^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 111, in convert_nodes
    convert_single_node(context, node)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 175, in convert_single_node
    add_op(context, node)
  File "./bug_report.py", line 46, in roi_align
    x = mb.crop_resize(
        ^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/builder.py", line 217, in _add_op
    new_op = op_cls(**kwargs)
             ^^^^^^^^^^^^^^^^
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 195, in __init__
    self._validate_and_set_inputs(input_kv)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 511, in _validate_and_set_inputs
    self.input_spec.validate_inputs(self.name, self.op_type, input_kvs)
  File "./venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/input_type.py", line 138, in validate_inputs
    raise ValueError(msg)
ValueError: In op, of type crop_resize, named crop_resize_0, the named input `roi` must have the same data type as the named input `x`. However, roi has dtype int32 whereas x has dtype fp32.
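
For context, crop_resize requires roi to have the same dtype as x. A guard along these lines inside the custom roi_align op would address the dtype half of the error (a sketch only; it does not explain the unexpected (0, 1) rois shape):

# Sketch (untested): inside the custom roi_align op, just before mb.crop_resize,
# cast non-constant boxes to fp32 so `roi` matches the dtype of `x`.
# This addresses the dtype mismatch only, not the (0, 1) rois shape.
boxes = mb.cast(x=boxes, dtype="fp32")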

To Reproduce

import coremltools as ct
import torch
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    register_torch_op,
)
from coremltools.converters.mil.mil import Builder as mb
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2

from torchvision.ops.roi_align import RoIAlign


@register_torch_op(torch_alias=["torchvision::roi_align"])
def roi_align(context, node):
    inputs = _get_inputs(context, node)

    x = context[node.inputs[0]]
    input_shape = x.shape  # (B, C, h_in, w_in)
    if len(input_shape) != 4:
        raise ValueError(
            '"CropResize" op: expected input rank 4, got {}'.format(x.rank)
        )

    const_box_info = True
    if context[node.inputs[1]].val is None or context[node.inputs[2]].val is None:
        const_box_info = False

    extrapolation_value = context[node.inputs[2]].val

    # CoreML index information along with boxes
    if const_box_info:
        boxes = context[node.inputs[1]].val
        # CoreML expects boxes/ROI in
        # [N, 1, 5, 1, 1] format
        boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
    else:
        boxes = inputs[1]
        boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
    # Get Height and Width of crop
    h_out = inputs[3]
    w_out = inputs[4]

    # Torch input format: [B, C, h_in, w_in]
    # CoreML input format: [B, C, h_in, w_in]

    # Crop Resize
    x = mb.crop_resize(
        x=x,
        roi=boxes,
        target_height=h_out.val,
        target_width=w_out.val,
        normalized_coordinates=True,
        spatial_scale=extrapolation_value,
        box_coordinate_mode="CORNERS_HEIGHT_FIRST",
        sampling_mode="OFFSET_CORNERS",
    )

    # CoreML output format: [N, 1, C, h_out, w_out]
    # Torch output format: [N, C, h_out, w_out]
    x = mb.squeeze(x=x, axes=[1])

    context.add(x, torch_name=node.outputs[0])


def convert_roi_align_layer():
    roi_align_layer = RoIAlign(
        output_size=(7, 7), spatial_scale=1.0, sampling_ratio=1, aligned=False
    )

    input_tensor = torch.randn((1, 3, 400, 800))
    rois_stacked = torch.FloatTensor([[0, 0, 0, 10, 10], [0, 5, 5, 20, 20]])

    roi_align_layer.eval()

    traced_model = torch.jit.trace(roi_align_layer, (input_tensor, rois_stacked))

    ct.convert(
        traced_model,
        inputs=[
            ct.TensorType(name="Input", shape=input_tensor.shape),
            ct.TensorType(name="Rois", shape=rois_stacked.shape),
        ],
    )


def convert_faster_rcnn_model():
    model = fasterrcnn_resnet50_fpn_v2(pretrained=False)

    class ModelWrapper(torch.nn.Module):
        def __init__(self, network: torch.nn.Module):
            super().__init__()
            self.network = network

        def forward(self, x):
            output = self.network(x)[0]
            return output["boxes"], output["labels"], output["scores"]

    wrapped_model = ModelWrapper(model)

    input_ = torch.randn((1, 3, 400, 800))
    wrapped_model.eval()

    traced_model = torch.jit.trace(wrapped_model, input_)

    ct.convert(traced_model, inputs=[ct.TensorType(name="Input", shape=input_.shape)])


convert_roi_align_layer()
convert_faster_rcnn_model()

System environment:

  • coremltools version: 8.1
  • OS (e.g. MacOS version or Linux type): MacOS 15.3.2
  • Any other relevant version information (e.g. PyTorch or TensorFlow version):
    • torch==2.5.1
    • torchvision==0.20.1
    • numpy==1.26.4
@reneleonhardt
Copy link
Contributor

@gizzleon Just out of curiosity: coremltools, torch, and torchvision all had newer releases in January.
If this report is more than two weeks old, could you try again with the current dependencies?

Also, when I install coremltools, numpy 2.2 is used; wasn't that the case for you?

@gizzleon
Author

@reneleonhardt The issue persists with newer versions: coremltools 8.2, torch 2.6.0, and torchvision 0.21.0.

I am not able to use numpy 2.x as coremltools/converters/mil/mil/ops/defs/iOS15/elementwise_unary.py uses a copy operation deprecated in 2.x.

@reneleonhardt
Contributor

I can't find any deprecations in 2.0, 2.1 or 2.2 regarding copy or these two function calls.
https://numpy.org/doc/stable/release/2.0.0-notes.html#deprecations

Release notes say numpy 2 is supported:
https://github.com/apple/coremltools/releases/tag/8.0

If you have time maybe you can open another issue for your environment 🙂

@gizzleon
Author

gizzleon commented Apr 17, 2025

It is a behavior change in the copy keyword rather than a deprecation. Sorry for the confusion.
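
In a nutshell (minimal illustration, plain NumPy only):

import numpy as np

data = [1.0, 2.0]            # building an array from a list always requires a copy
np.array(data, copy=False)   # NumPy 1.x: copies silently; NumPy 2.x: raises ValueError
np.asarray(data)             # works on both, copying only when actually needed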

The log I got with numpy 2.2.4:

  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 112, in convert_nodes
    convert_single_node(context, node)
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 173, in convert_single_node
    add_op(context, node)
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 6992, in reciprocal
    context.add(mb.inverse(x=inputs[0], name=node.name))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/builder.py", line 237, in _add_op
    new_op.type_value_inference()
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 265, in type_value_inference
    output_vals = self._auto_val(output_types)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 382, in _auto_val
    vals = self.value_inference()
           ^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/operation.py", line 111, in wrapper
    return func(self)
           ^^^^^^^^^^
  File "venv/lib/python3.12/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/elementwise_unary.py", line 449, in value_inference
    return np.array(np.reciprocal(self.x.val + self.epsilon.val), copy=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Unable to avoid copy while creating an array as requested.
If using `np.array(obj, copy=False)` replace it with `np.asarray(obj)` to allow a copy when needed (no behavior change in NumPy 1.x).
For more details, see https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword.

Unfortunately I couldn't narrow this down to a smaller reproduction: the issue goes away when I convert a minimal network that only uses torch.reciprocal.
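
(For reference, the kind of minimal standalone test meant above, which converts without hitting the error, would be something like this sketch:)

import coremltools as ct
import torch


class Reciprocal(torch.nn.Module):
    def forward(self, x):
        return torch.reciprocal(x)


example = torch.rand(1, 3, 8, 8) + 1.0  # keep values away from zero
traced = torch.jit.trace(Reciprocal().eval(), example)
ct.convert(traced, inputs=[ct.TensorType(name="Input", shape=example.shape)])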

@reneleonhardt
Contributor

Interesting, now I can see what you mean, thank you!
This was the only copy=False, so I created #2488 to migrate it to np.asarray(); maybe it will be merged 🤞
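
The change in question would presumably be the one suggested by the migration guide quoted in the traceback:

# coremltools/converters/mil/mil/ops/defs/iOS15/elementwise_unary.py
# before: raises under NumPy 2.x when a copy cannot be avoided
return np.array(np.reciprocal(self.x.val + self.epsilon.val), copy=False)
# after: copies only when needed, on both NumPy 1.x and 2.x
return np.asarray(np.reciprocal(self.x.val + self.epsilon.val))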
