Skip to content

Commit bf33d7f

Browse files
committed
[ruby/prism] Cache strpbrk lookup tables
ruby/prism@46656b2fd5
1 parent 169ba06 commit bf33d7f

2 files changed

Lines changed: 74 additions & 43 deletions

File tree

prism/parser.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,27 @@ struct pm_parser {
962962
* toggled with a magic comment.
963963
*/
964964
bool warn_mismatched_indentation;
965+
966+
#if defined(PRISM_HAS_NEON) || defined(PRISM_HAS_SSSE3) || defined(PRISM_HAS_SWAR)
967+
/**
968+
* Cached lookup tables for pm_strpbrk's SIMD fast path. Avoids rebuilding
969+
* the nibble-based tables on every call when the charset hasn't changed
970+
* (which is the common case during string/regex/list lexing).
971+
*/
972+
struct {
973+
/** The cached charset (null-terminated, max 11 chars + NUL). */
974+
uint8_t charset[12];
975+
976+
/** Nibble-based low lookup table for SIMD matching. */
977+
uint8_t low_lut[16];
978+
979+
/** Nibble-based high lookup table for SIMD matching. */
980+
uint8_t high_lut[16];
981+
982+
/** Scalar fallback table (4 x 64-bit bitmasks covering all ASCII). */
983+
uint64_t table[4];
984+
} strpbrk_cache;
985+
#endif
965986
};
966987

967988
#endif

prism/util/pm_strpbrk.c

Lines changed: 53 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -45,29 +45,52 @@ pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t l
4545
* 3. SWAR — little-endian fallback, processes 8 bytes per iteration.
4646
*/
4747

48-
#if defined(PRISM_HAS_NEON)
49-
#include <arm_neon.h>
48+
#if defined(PRISM_HAS_NEON) || defined(PRISM_HAS_SSSE3) || defined(PRISM_HAS_SWAR)
5049

51-
static inline bool
52-
scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
53-
// Build nibble-based lookup tables from the charset. All breakpoint
54-
// characters are ASCII (< 0x80), so they fit within high nibbles 0-7.
55-
//
56-
// For each charset byte c, we set bit (1 << (c >> 4)) in low_lut[c & 0xF].
57-
// high_lut[h] = (1 << h) for each high nibble h present in the charset.
58-
// A source byte s matches iff (low_lut[s & 0xF] & high_lut[s >> 4]) != 0.
59-
uint8_t low_arr[16] = { 0 };
60-
uint8_t high_arr[16] = { 0 };
61-
uint64_t table[4] = { 0 };
50+
/**
51+
* Update the cached strpbrk lookup tables if the charset has changed. The
52+
* parser caches the last charset's precomputed tables so that repeated calls
53+
* with the same breakpoints (the common case during string/regex/list lexing)
54+
* skip table construction entirely.
55+
*
56+
* Builds three structures:
57+
* - low_lut/high_lut: nibble-based lookup tables for SIMD matching (NEON/SSSE3)
58+
* - table: 256-bit bitmap for scalar fallback matching (all platforms)
59+
*/
60+
static inline void
61+
pm_strpbrk_cache_update(pm_parser_t *parser, const uint8_t *charset) {
62+
// The cache key is the full 12-byte charset buffer. Since it is always
63+
// NUL-padded, a fixed-size comparison covers both content and length.
64+
if (memcmp(parser->strpbrk_cache.charset, charset, sizeof(parser->strpbrk_cache.charset)) == 0) return;
65+
66+
memset(parser->strpbrk_cache.low_lut, 0, sizeof(parser->strpbrk_cache.low_lut));
67+
memset(parser->strpbrk_cache.high_lut, 0, sizeof(parser->strpbrk_cache.high_lut));
68+
memset(parser->strpbrk_cache.table, 0, sizeof(parser->strpbrk_cache.table));
6269

70+
size_t charset_len = 0;
6371
for (const uint8_t *c = charset; *c != '\0'; c++) {
64-
low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4));
65-
high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4));
66-
table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
72+
parser->strpbrk_cache.low_lut[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4));
73+
parser->strpbrk_cache.high_lut[*c >> 4] = (uint8_t) (1 << (*c >> 4));
74+
parser->strpbrk_cache.table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
75+
charset_len++;
6776
}
6877

69-
uint8x16_t low_lut = vld1q_u8(low_arr);
70-
uint8x16_t high_lut = vld1q_u8(high_arr);
78+
// Store the new charset key, NUL-padded to the full buffer size.
79+
memcpy(parser->strpbrk_cache.charset, charset, charset_len + 1);
80+
memset(parser->strpbrk_cache.charset + charset_len + 1, 0, sizeof(parser->strpbrk_cache.charset) - charset_len - 1);
81+
}
82+
83+
#endif
84+
85+
#if defined(PRISM_HAS_NEON)
86+
#include <arm_neon.h>
87+
88+
static inline bool
89+
scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
90+
pm_strpbrk_cache_update(parser, charset);
91+
92+
uint8x16_t low_lut = vld1q_u8(parser->strpbrk_cache.low_lut);
93+
uint8x16_t high_lut = vld1q_u8(parser->strpbrk_cache.high_lut);
7194
uint8x16_t mask_0f = vdupq_n_u8(0x0F);
7295
uint8x16_t mask_80 = vdupq_n_u8(0x80);
7396

@@ -103,7 +126,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset
103126
// Scalar tail for remaining < 16 ASCII bytes.
104127
while (idx < maximum && source[idx] < 0x80) {
105128
uint8_t byte = source[idx];
106-
if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
129+
if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
107130
*index = idx;
108131
return true;
109132
}
@@ -118,20 +141,11 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset
118141
#include <tmmintrin.h>
119142

120143
static inline bool
121-
scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
122-
// Build nibble-based lookup tables and bitmap table in a single pass.
123-
uint8_t low_arr[16] = { 0 };
124-
uint8_t high_arr[16] = { 0 };
125-
uint64_t table[4] = { 0 };
144+
scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
145+
pm_strpbrk_cache_update(parser, charset);
126146

127-
for (const uint8_t *c = charset; *c != '\0'; c++) {
128-
low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4));
129-
high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4));
130-
table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
131-
}
132-
133-
__m128i low_lut = _mm_loadu_si128((const __m128i *) low_arr);
134-
__m128i high_lut = _mm_loadu_si128((const __m128i *) high_arr);
147+
__m128i low_lut = _mm_loadu_si128((const __m128i *) parser->strpbrk_cache.low_lut);
148+
__m128i high_lut = _mm_loadu_si128((const __m128i *) parser->strpbrk_cache.high_lut);
135149
__m128i mask_0f = _mm_set1_epi8(0x0F);
136150

137151
size_t idx = 0;
@@ -165,7 +179,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset
165179
// Scalar tail.
166180
while (idx < maximum && source[idx] < 0x80) {
167181
uint8_t byte = source[idx];
168-
if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
182+
if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
169183
*index = idx;
170184
return true;
171185
}
@@ -179,12 +193,8 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset
179193
#elif defined(PRISM_HAS_SWAR)
180194

181195
static inline bool
182-
scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
183-
// Build a 256-bit lookup table (one bit per ASCII value).
184-
uint64_t table[4] = { 0 };
185-
for (const uint8_t *c = charset; *c != '\0'; c++) {
186-
table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
187-
}
196+
scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
197+
pm_strpbrk_cache_update(parser, charset);
188198

189199
static const uint64_t highs = 0x8080808080808080ULL;
190200
size_t idx = 0;
@@ -199,7 +209,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset
199209
// Check each byte against the charset table.
200210
for (size_t j = 0; j < 8; j++) {
201211
uint8_t byte = source[idx + j];
202-
if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
212+
if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
203213
*index = idx + j;
204214
return true;
205215
}
@@ -211,7 +221,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset
211221
// Scalar tail.
212222
while (idx < maximum && source[idx] < 0x80) {
213223
uint8_t byte = source[idx];
214-
if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
224+
if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
215225
*index = idx;
216226
return true;
217227
}
@@ -225,7 +235,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset
225235
#else
226236

227237
static inline bool
228-
scan_strpbrk_ascii(PRISM_ATTRIBUTE_UNUSED const uint8_t *source, PRISM_ATTRIBUTE_UNUSED size_t maximum, PRISM_ATTRIBUTE_UNUSED const uint8_t *charset, size_t *index) {
238+
scan_strpbrk_ascii(PRISM_ATTRIBUTE_UNUSED pm_parser_t *parser, PRISM_ATTRIBUTE_UNUSED const uint8_t *source, PRISM_ATTRIBUTE_UNUSED size_t maximum, PRISM_ATTRIBUTE_UNUSED const uint8_t *charset, size_t *index) {
229239
*index = 0;
230240
return false;
231241
}
@@ -393,7 +403,7 @@ pm_strpbrk(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, p
393403

394404
size_t maximum = (size_t) length;
395405
size_t index = 0;
396-
if (scan_strpbrk_ascii(source, maximum, charset, &index)) return source + index;
406+
if (scan_strpbrk_ascii(parser, source, maximum, charset, &index)) return source + index;
397407

398408
if (!parser->encoding_changed) {
399409
return pm_strpbrk_utf8(parser, source, charset, index, maximum, validate);

0 commit comments

Comments
 (0)