diff --git a/crates/transpiler/src/passes/inverse_cancellation.rs b/crates/transpiler/src/passes/inverse_cancellation.rs index dd5be9068983..28110b33cb63 100644 --- a/crates/transpiler/src/passes/inverse_cancellation.rs +++ b/crates/transpiler/src/passes/inverse_cancellation.rs @@ -174,11 +174,31 @@ static SELF_INVERSE_GATES_FOR_CANCELLATION: [StandardGate; 15] = [ StandardGate::C3X, ]; -static INVERSE_PAIRS_FOR_CANCELLATION: [[StandardGate; 2]; 4] = [ - [StandardGate::T, StandardGate::Tdg], - [StandardGate::S, StandardGate::Sdg], - [StandardGate::SX, StandardGate::SXdg], - [StandardGate::CS, StandardGate::CSdg], +// Inverse cancellation pairs. We store pairs, plus additional info +// if the gates are symmetric and cancel irrespective of qubit order. +struct InversePair { + gates: [StandardGate; 2], + symmetric: bool, +} +static INVERSE_PAIRS_FOR_CANCELLATION: [InversePair; 4] = [ + // for 1-q gates, the symmetric flag does not matter -- it is slightly more efficient + // to set it to `false` in this case to avoid more involved qubit equality checks + InversePair { + gates: [StandardGate::T, StandardGate::Tdg], + symmetric: false, + }, + InversePair { + gates: [StandardGate::S, StandardGate::Sdg], + symmetric: false, + }, + InversePair { + gates: [StandardGate::SX, StandardGate::SXdg], + symmetric: false, + }, + InversePair { + gates: [StandardGate::CS, StandardGate::CSdg], + symmetric: true, + }, ]; fn std_self_inverse(dag: &mut DAGCircuit) { @@ -235,22 +255,22 @@ fn std_self_inverse(dag: &mut DAGCircuit) { } fn std_inverse_pairs(dag: &mut DAGCircuit) { - if !INVERSE_PAIRS_FOR_CANCELLATION.iter().any(|gate| { - dag.get_op_counts().contains_key(gate[0].name()) - && dag.get_op_counts().contains_key(gate[1].name()) + if !INVERSE_PAIRS_FOR_CANCELLATION.iter().any(|pair| { + dag.get_op_counts().contains_key(pair.gates[0].name()) + && dag.get_op_counts().contains_key(pair.gates[1].name()) }) { return; } // Handle inverse pairs - for [gate_0, gate_1] in INVERSE_PAIRS_FOR_CANCELLATION { - if !dag.get_op_counts().contains_key(gate_0.name()) - || !dag.get_op_counts().contains_key(gate_1.name()) + for pair in INVERSE_PAIRS_FOR_CANCELLATION.iter() { + if !dag.get_op_counts().contains_key(pair.gates[0].name()) + || !dag.get_op_counts().contains_key(pair.gates[1].name()) { continue; } let filter = |inst: &PackedInstruction| -> bool { match inst.op.view() { - OperationRef::StandardGate(gate) => gate == gate_0 || gate == gate_1, + OperationRef::StandardGate(gate) => gate == pair.gates[0] || gate == pair.gates[1], _ => false, } }; @@ -264,11 +284,25 @@ fn std_inverse_pairs(dag: &mut DAGCircuit) { let NodeType::Operation(next_inst) = &dag[nodes[i + 1]] else { unreachable!("Not an op node"); }; - if inst.qubits == next_inst.qubits - && (inst.op.try_standard_gate() == Some(gate_0) - && next_inst.op.try_standard_gate() == Some(gate_1)) - || (inst.op.try_standard_gate() == Some(gate_1) - && next_inst.op.try_standard_gate() == Some(gate_0)) + + // For symmetric gates, check if the qubits match irrespective of the order + let qubits_match = if pair.symmetric { + let inst_qubits = dag.get_qargs(inst.qubits); + let next_inst_qubits = dag.get_qargs(next_inst.qubits); + + // We know that we have only 2-qubit gates here, so we perform a brief + // check based on the slice, without converting to a set. We also do not + // have to check the size matches, since we later check that the gates match, + // hence we know the number of qubits matches. + inst_qubits.iter().all(|q| next_inst_qubits.contains(q)) + } else { + inst.qubits == next_inst.qubits + }; + if qubits_match + && ((inst.op.try_standard_gate() == Some(pair.gates[0]) + && next_inst.op.try_standard_gate() == Some(pair.gates[1])) + || (inst.op.try_standard_gate() == Some(pair.gates[1]) + && next_inst.op.try_standard_gate() == Some(pair.gates[0]))) { dag.remove_op_node(nodes[i]); dag.remove_op_node(nodes[i + 1]); diff --git a/releasenotes/notes/fix-inverse-cancellation-symmetric-pairs-6739a450e53e5c84.yaml b/releasenotes/notes/fix-inverse-cancellation-symmetric-pairs-6739a450e53e5c84.yaml new file mode 100644 index 000000000000..578ab3b78662 --- /dev/null +++ b/releasenotes/notes/fix-inverse-cancellation-symmetric-pairs-6739a450e53e5c84.yaml @@ -0,0 +1,9 @@ +--- +fixes: + - | + Fixed a bug in :class:`.InverseCancellation` where :class:`.CSGate` and :class:`.CSdgGate` pairs + with reversed qubit arguments were inconsistently cancelled. Previously, a sequence of + :math:`CS-CS^\dagger` (with reversed qubit args) was cancelled, but + :math:`CS^\dagger-CS` was not. Now both are consistently cancelled. + + Fixed `#15855 `__. diff --git a/test/python/transpiler/test_inverse_cancellation.py b/test/python/transpiler/test_inverse_cancellation.py index 5a33ea29ca4b..db057d833b04 100644 --- a/test/python/transpiler/test_inverse_cancellation.py +++ b/test/python/transpiler/test_inverse_cancellation.py @@ -15,6 +15,7 @@ """ import unittest +import itertools import numpy as np import ddt @@ -34,6 +35,8 @@ TdgGate, CZGate, RZGate, + CSGate, + CSdgGate, ) from test import QiskitTestCase @@ -740,6 +743,33 @@ def test_nested_control_flow(self, gates_to_cancel): self.assertEqual(pass_(test), expected) + def test_symmetries(self): + """Test cancellation of symmetric gates. + + CS/CSdg should cancel if they are on the same set of qubits, irrespective of the order. + """ + # We're testing CS and CSdg in every possible combination and qubit order. + # All should cancel. + for qargs in itertools.product([(0, 1), (1, 0)], repeat=2): + for gates in itertools.permutations([CSGate(), CSdgGate()]): + qc = QuantumCircuit(2) + qc.append(gates[0], qargs[0]) + qc.append(gates[1], qargs[1]) + + with self.subTest(gates=gates, qargs=qargs): + tqc = InverseCancellation()(qc) + self.assertEqual(tqc.count_ops(), {}) + + # sanity check: verify that CS-CSdg doesn't globally get cancelled + with self.subTest(msg="sanity check"): + qc = QuantumCircuit(3) + qc.cs(0, 1) + qc.csdg(2, 1) + + tqc = InverseCancellation()(qc) + self.assertEqual(tqc.count_ops().get("cs", 0), 1) + self.assertEqual(tqc.count_ops().get("csdg", 0), 1) + if __name__ == "__main__": unittest.main()