diff --git a/core/src/main/java/com/google/errorprone/bugpatterns/ReferenceEquality.java b/core/src/main/java/com/google/errorprone/bugpatterns/ReferenceEquality.java
index dee7ad853ff..fd180563827 100644
--- a/core/src/main/java/com/google/errorprone/bugpatterns/ReferenceEquality.java
+++ b/core/src/main/java/com/google/errorprone/bugpatterns/ReferenceEquality.java
@@ -18,19 +18,29 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
+import static com.google.errorprone.VisitorState.memoize;
+import static com.google.errorprone.util.ASTHelpers.getUpperBound;
+import static com.google.errorprone.util.ASTHelpers.isSameType;
+import static com.google.errorprone.util.ASTHelpers.isSubtype;
+import static javax.lang.model.element.Modifier.FINAL;
+import static javax.lang.model.element.Modifier.SEALED;
import com.google.errorprone.BugPattern;
import com.google.errorprone.BugPattern.StandardTags;
import com.google.errorprone.VisitorState;
+import com.google.errorprone.suppliers.Supplier;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.ClassTree;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.LambdaExpressionTree;
import com.sun.source.tree.MethodTree;
+import com.sun.tools.javac.code.Scope;
import com.sun.tools.javac.code.Symbol;
+import com.sun.tools.javac.code.Symbol.ClassSymbol;
import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.code.Symtab;
import com.sun.tools.javac.code.Type;
+import com.sun.tools.javac.util.Name;
/** A {@link BugChecker}; see the associated {@link BugPattern} annotation for details. */
@BugPattern(
@@ -56,10 +66,7 @@ protected boolean matchArgument(ExpressionTree tree, VisitorState state) {
if (inComparisonMethod(classType, type, state)) {
return false;
}
- if (ASTHelpers.isSubtype(type, state.getSymtab().enumSym.type, state)) {
- return false;
- }
- if (ASTHelpers.isSubtype(type, state.getSymtab().classType, state)) {
+ if (definitelyUsesReferenceEquality(type, state)) {
return false;
}
return true;
@@ -107,4 +114,93 @@ private static boolean overridesMethodOnType(
private static Symbol getOnlyMember(VisitorState state, Type type, String name) {
return getOnlyElement(type.tsym.members().getSymbolsByName(state.getName(name)));
}
+
+ /**
+ * Returns {@code true} if an instance of {@code type} is guaranteed to have an {@code equals}
+ * implementation that is equivalent to {@code ==}.
+ *
+ *
We can guarantee this for:
+ *
+ *
+ * - enum classes
+ *
- {@code final} classes that inherit {@link Object#equals} instead of having a more
+ * specific implementation
+ *
- {@code sealed} classes whose permitted subclasses all definitely use reference equality
+ * according to this method
+ *
+ */
+ private static boolean definitelyUsesReferenceEquality(Type type, VisitorState state) {
+ return definitelyUsesReferenceEquality(type, state, 0);
+ }
+
+ private static boolean definitelyUsesReferenceEquality(Type type, VisitorState state, int depth) {
+ if (depth > 1000) {
+ /*
+ * javac should never generate classes that form a PermittedSubclasses cycle, but just in case
+ * some system does, we bail out when we have seen a chain that is implausibly long.
+ */
+ return false;
+ }
+
+ /*
+ * If a value has static type `Class`, for example, then it uses reference equality, since
+ * `Class` is a `final` class that does not override `equals`. But we also want to cover cases
+ * like those of a value whose static type is `T` if `T` is declared as `T extends Class`.
+ * To do so, we look at the upper bound of the static type, transitively resolving a chain of
+ * bounds (e.g., `, U extends T>`) until we reach a fixed point.
+ */
+ Type previous;
+ do {
+ previous = type;
+ type = getUpperBound(type, state.getTypes());
+ } while (!state.getTypes().isSameType(type, previous));
+ if (type.tsym == null) {
+ return false;
+ }
+ if (isSubtype(type, state.getSymtab().enumSym.type, state)) {
+ return true;
+ }
+ if (implementsEquals(type, state)) {
+ return false;
+ }
+ if (type.tsym.getModifiers().contains(FINAL)) {
+ return true;
+ }
+ if (type.tsym.getModifiers().contains(SEALED)) {
+ if (type.tsym instanceof ClassSymbol classSymbol) {
+ for (Type sub : classSymbol.getPermittedSubclasses()) {
+ if (!definitelyUsesReferenceEquality(sub, state, depth + 1)) {
+ return false;
+ }
+ }
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Returns {@code true} if {@code type} declares or inherits an override of {@link Object#equals}.
+ */
+ private static boolean implementsEquals(Type type, VisitorState state) {
+ Name equalsName = EQUALS.get(state);
+ Symbol objectEquals = getOnlyMember(state, state.getSymtab().objectType, "equals");
+ for (Type sup : state.getTypes().closure(type)) {
+ if (isSameType(sup, state.getSymtab().objectType, state)) {
+ continue;
+ }
+ Scope scope = sup.tsym.members();
+ if (scope == null) {
+ continue;
+ }
+ for (Symbol sym : scope.getSymbolsByName(equalsName)) {
+ if (sym.overrides(objectEquals, type.tsym, state.getTypes(), /* checkResult= */ false)) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private static final Supplier EQUALS = memoize(state -> state.getName("equals"));
}
diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/ReferenceEqualityTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/ReferenceEqualityTest.java
index 746f00f85d7..6613e62fd5a 100644
--- a/core/src/test/java/com/google/errorprone/bugpatterns/ReferenceEqualityTest.java
+++ b/core/src/test/java/com/google/errorprone/bugpatterns/ReferenceEqualityTest.java
@@ -617,6 +617,271 @@ class Test {
.doTest();
}
+ @Test
+ public void arrayComparison() {
+ compilationHelper
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test {
+ boolean f(int[] a, int[] b) {
+ // BUG: Diagnostic contains: ReferenceEquality
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void finalClassWithoutEquals() {
+ compilationHelper
+ .addSourceLines(
+ "Test.java",
+ """
+ final class Test {
+ boolean f(Test a, Test b) {
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void finalClassWithEquals() {
+ compilationHelper
+ .addSourceLines(
+ "Test.java",
+ """
+ final class Test {
+ public boolean equals(Object o) {
+ return true;
+ }
+
+ boolean f(Test a, Test b) {
+ // BUG: Diagnostic contains: a.equals(b)
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void sealedClassWithoutEquals() {
+ compilationHelper
+ .addSourceLines(
+ "Sealed.java",
+ """
+ sealed interface Sealed permits Final1, Final2 {}
+ """)
+ .addSourceLines(
+ "Final1.java",
+ """
+ final class Final1 implements Sealed {}
+ """)
+ .addSourceLines(
+ "Final2.java",
+ """
+ final class Final2 implements Sealed {}
+ """)
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test {
+ boolean f(Sealed a, Sealed b) {
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void sealedClassWithEquals() {
+ compilationHelper
+ .addSourceLines(
+ "Sealed.java",
+ """
+ sealed interface Sealed permits Final1, Final2 {}
+ """)
+ .addSourceLines(
+ "Final1.java",
+ """
+ final class Final1 implements Sealed {
+ public boolean equals(Object o) {
+ return true;
+ }
+ }
+ """)
+ .addSourceLines(
+ "Final2.java",
+ """
+ final class Final2 implements Sealed {}
+ """)
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test {
+ boolean f(Sealed a, Sealed b) {
+ // BUG: Diagnostic contains: a.equals(b)
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void sealedClassWithNonSealedSubclass() {
+ compilationHelper
+ .addSourceLines(
+ "Sealed.java",
+ """
+ sealed interface Sealed permits NonSealedSub {}
+ """)
+ .addSourceLines(
+ "NonSealedSub.java",
+ """
+ non-sealed class NonSealedSub implements Sealed {}
+ """)
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test {
+ boolean f(Sealed a, Sealed b) {
+ // BUG: Diagnostic contains: a.equals(b)
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void sealedClassWithEnumSubclass() {
+ compilationHelper
+ .addSourceLines(
+ "Sealed.java",
+ """
+ sealed interface Sealed permits MyEnum {}
+ """)
+ .addSourceLines(
+ "MyEnum.java",
+ """
+ enum MyEnum implements Sealed {
+ INSTANCE;
+ }
+ """)
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test {
+ boolean f(Sealed a, Sealed b) {
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void typeVariableBoundedByClass() {
+ compilationHelper
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test> {
+ boolean f(T a, T b) {
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void typeVariableBoundedByFinalClassWithoutEquals() {
+ compilationHelper
+ .addSourceLines(
+ "Final.java",
+ """
+ final class Final {}
+ """)
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test {
+ boolean f(T a, T b) {
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void typeVariableBoundedByFinalClassWithEquals() {
+ compilationHelper
+ .addSourceLines(
+ "Final.java",
+ """
+ final class Final {
+ public boolean equals(Object o) {
+ return true;
+ }
+ }
+ """)
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test {
+ boolean f(T a, T b) {
+ // BUG: Diagnostic contains: a.equals(b)
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void typeVariableBoundedByNonFinalClass() {
+ compilationHelper
+ .addSourceLines(
+ "NonFinal.java",
+ """
+ class NonFinal {}
+ """)
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test {
+ boolean f(T a, T b) {
+ // BUG: Diagnostic contains: a.equals(b)
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
+ @Test
+ public void typeVariableWithTransitiveBound() {
+ compilationHelper
+ .addSourceLines(
+ "Test.java",
+ """
+ class Test, U extends T> {
+ boolean f(U a, U b) {
+ return a == b;
+ }
+ }
+ """)
+ .doTest();
+ }
+
@Test
public void memorySegment() {
assume().that(Runtime.version().feature()).isAtLeast(22);