diff --git a/Common/PandasMapper.py b/Common/PandasMapper.py index c18bd1812a83..cde0311c42df 100644 --- a/Common/PandasMapper.py +++ b/Common/PandasMapper.py @@ -37,7 +37,10 @@ def __new__(cls, key): def __eq__(self, other): # We need this since Lean created data frames might contain Symbol objects in the indexes - return super().__eq__(other) and type(other) is not Symbol + if type(other) is Symbol: + return False + # For non-strings str.__eq__ returns NotImplemented, which Python resolves to False + return super().__eq__(other) def __hash__(self): return super().__hash__() diff --git a/Tests/Python/PandasIndexingTests.cs b/Tests/Python/PandasIndexingTests.cs index c57999fbc8b2..0644ff1ec372 100644 --- a/Tests/Python/PandasIndexingTests.cs +++ b/Tests/Python/PandasIndexingTests.cs @@ -85,5 +85,17 @@ public void ExpectedException() Assert.IsTrue(exception.Contains("No key found for either mapped or original key.", StringComparison.InvariantCulture), exception); } } + + [Test] + public void ColumnEqualsOnlyMatchingString() + { + using (Py.GIL()) + { + PyObject result = _pandasDataFrameTests.test_column_equals_only_matching_string(); + var test = result.As(); + + Assert.IsTrue(test); + } + } } } diff --git a/Tests/Python/PandasTests/PandasIndexingTests.py b/Tests/Python/PandasTests/PandasIndexingTests.py index 1ebd7d931aae..77a9100420af 100644 --- a/Tests/Python/PandasTests/PandasIndexingTests.py +++ b/Tests/Python/PandasTests/PandasIndexingTests.py @@ -14,6 +14,7 @@ from AlgorithmImports import * from QuantConnect.Tests import * from QuantConnect.Tests.Python import * +from PandasMapper import PandasColumn # TODO: Rename to PandasResearchTests and keep this class for QB related tests; rename py module to PandasTests class PandasIndexingTests(): @@ -68,3 +69,8 @@ def test_contains_user_defined_columns_with_spaces(self, column_name): return True except: return False + + def test_column_equals_only_matching_string(self): + # A column label should only equal a matching string, never None/ints/floats + column = PandasColumn("shares") + return (not (column == None)) and (not (column == 0)) and (not (column == 123)) and (column == "shares")