Skip to content
19 changes: 19 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,25 @@ def _impl_v13(cls, bb, inputs, attr, params):
shape_val = data[np_index]
return relax.PrimValue(shape_val)

data_shape = bb.normalize(relax.op.shape_of(data))
data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
axis_extent = bb.normalize(
relax.op.take(data_shape_tensor, relax.const(axis, "int64"), axis=0, mode="wrap")
)

indices_dtype = indices.struct_info.dtype
if not indices_dtype.startswith("uint"):
if indices_dtype !="int64":
axis_extent = bb.normalize(relax.op.astype(axis_extent, indices_dtype))

indices = bb.normalize(
relax.op.where(
relax.op.less(indices, relax.const(0, indices_dtype)),
relax.op.add(indices, axis_extent),
indices,
)
)
Comment thread
cchung100m marked this conversation as resolved.
Outdated

return relax.op.take(data, indices, axis)


Expand Down
Loading