diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 268d91b7500a..7d85906cffdd 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1106,6 +1106,25 @@ def _impl_v13(cls, bb, inputs, attr, params): shape_val = data[np_index] return relax.PrimValue(shape_val) + indices_dtype = indices.struct_info.dtype + if not indices_dtype.startswith("uint"): + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + axis_extent = bb.normalize( + relax.op.take(data_shape_tensor, relax.const(axis, "int64"), axis=0, mode="wrap") + ) + + if indices_dtype !="int64": + axis_extent = bb.normalize(relax.op.astype(axis_extent, indices_dtype)) + + indices = bb.normalize( + relax.op.where( + relax.op.less(indices, relax.const(0, indices_dtype)), + relax.op.add(indices, axis_extent), + indices, + ) + ) + return relax.op.take(data, indices, axis) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 5a8d84b0900c..52a4064cc8f5 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -874,6 +874,68 @@ def _verify_gather(data_shape, indices, out_shape, axis=0): _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1) +@pytest.mark.parametrize( + "axis, indices, out_shape", + [ + (0, [-1, 0], [2, 4]), + (1, [-1, 0], [3, 2]), + ( + 1, + [[-1, 0], [1, -2]], + [3, 2, 2], + ), + ], +) +@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32]) +def test_gather_negative_indices(axis, indices, out_shape, indices_type): + gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis) + indices_shape = np.asarray(indices).shape + + graph = helper.make_graph( + [gather_node], + "gather_negative_indices_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4]), + helper.make_tensor_value_info("indices", indices_type, indices_shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)], + ) + + model = helper.make_model(graph, producer_name="gather_negative_indices_test") + indices_np_dtype = { + TensorProto.INT64: np.int64, + TensorProto.INT32: np.int32, + }[indices_type] + input_values = { + "data": np.random.randn(3, 4).astype("float32"), + "indices": np.array(indices).astype(indices_np_dtype), + } + check_correctness(model, inputs=input_values) + + +@pytest.mark.parametrize("indices_type", [TensorProto.INT64, TensorProto.INT32]) +def test_gather_negative_indices_ir_normalization(indices_type): + gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=1) + graph = helper.make_graph( + [gather_node], + "gather_negative_indices_ir_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4]), + helper.make_tensor_value_info("indices", indices_type, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 2])], + ) + + model = helper.make_model(graph, producer_name="gather_negative_indices_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.where" in call_ops + assert "relax.less" in call_ops + assert "relax.add" in call_ops + assert "relax.take" in call_ops + + @pytest.mark.parametrize( "data_shape, indices_shape, axis", [