Skip to content

Commit 0666cea

Browse files
committed
[ruby/prism] SIMD/SWAR for strpbrk
ruby/prism@c464b298aa
1 parent 120c9ed commit 0666cea

3 files changed

Lines changed: 237 additions & 69 deletions

File tree

prism/defines.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,18 @@
276276
#define PRISM_UNLIKELY(x) (x)
277277
#endif
278278

279+
/**
 * Platform detection for SIMD / fast-path implementations. At most one of
 * these macros is defined, selecting the best available vectorization strategy.
 *
 * The bare _M_X64 test is restricted to non-Clang compilers. clang-cl defines
 * _M_X64 as well, but Clang rejects SSSE3 intrinsics unless the ssse3 target
 * feature is enabled. When it is enabled, __SSSE3__ is defined and the first
 * half of the condition already selects the SSSE3 path; when it is not, the
 * build falls through to the SWAR path instead of failing to compile.
 */
#if (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64)
#define PRISM_HAS_NEON
#elif (defined(__x86_64__) && defined(__SSSE3__)) || (defined(_M_X64) && !defined(__clang__))
#define PRISM_HAS_SSSE3
#elif defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
#define PRISM_HAS_SWAR
#endif
290+
279291
/**
280292
* Count trailing zero bits in a 64-bit value. Used by SWAR identifier scanning
281293
* to find the first non-matching byte in a word.

prism/prism.c

Lines changed: 7 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,16 +1783,14 @@ char_is_identifier_utf8(const uint8_t *b, ptrdiff_t n) {
17831783
* Callers must handle any remaining bytes (short tail or non-ASCII/UTF-8)
17841784
* with a byte-at-a-time loop.
17851785
*
1786-
* Up to four optimized implementations are selected at compile time, with a
1786+
* One of up to three optimized implementations is selected at compile time, with a
17871787
* no-op fallback for unsupported platforms:
17881788
* 1. NEON — processes 16 bytes per iteration on aarch64.
1789-
* 2. SSE2 — processes 16 bytes per iteration on x86-64.
1790-
* 3. WASM SIMD — processes 16 bytes per iteration on WebAssembly.
1791-
* 4. SWAR — little-endian fallback, processes 8 bytes per iteration.
1792-
* 5. No-op — returns 0; the caller's byte-at-a-time loop handles everything.
1789+
* 2. SSSE3 — processes 16 bytes per iteration on x86-64.
1790+
* 3. SWAR — little-endian fallback, processes 8 bytes per iteration.
17931791
*/
17941792

1795-
#if defined(__aarch64__) && defined(__ARM_NEON)
1793+
#if defined(PRISM_HAS_NEON)
17961794
#include <arm_neon.h>
17971795

17981796
static inline size_t
@@ -1844,8 +1842,8 @@ scan_identifier_ascii(const uint8_t *start, const uint8_t *end) {
18441842
return (size_t) (cursor - start);
18451843
}
18461844

1847-
#elif defined(__x86_64__) && defined(__SSE2__)
1848-
#include <emmintrin.h>
1845+
#elif defined(PRISM_HAS_SSSE3)
1846+
#include <tmmintrin.h>
18491847

18501848
static inline size_t
18511849
scan_identifier_ascii(const uint8_t *start, const uint8_t *end) {
@@ -1886,54 +1884,11 @@ scan_identifier_ascii(const uint8_t *start, const uint8_t *end) {
18861884
return (size_t) (cursor - start);
18871885
}
18881886

1889-
#elif defined(__wasm_simd128__)
1890-
#include <wasm_simd128.h>
1891-
1892-
static inline size_t
1893-
scan_identifier_ascii(const uint8_t *start, const uint8_t *end) {
1894-
const uint8_t *cursor = start;
1895-
1896-
while (cursor + 16 <= end) {
1897-
v128_t v = wasm_v128_load(cursor);
1898-
1899-
// Range checks via subtract-and-unsigned-compare: (v - lo) < count
1900-
// is true iff v is in [lo, lo + count). One subtract + one compare
1901-
// per range instead of two comparisons + AND.
1902-
1903-
// Fold case: OR with 0x20 maps A-Z to a-z.
1904-
v128_t lowered = wasm_v128_or(v, wasm_u8x16_splat(0x20));
1905-
v128_t letter = wasm_u8x16_lt(
1906-
wasm_i8x16_sub(lowered, wasm_u8x16_splat(0x61)),
1907-
wasm_u8x16_splat(0x1A));
1908-
1909-
v128_t digit = wasm_u8x16_lt(
1910-
wasm_i8x16_sub(v, wasm_u8x16_splat(0x30)),
1911-
wasm_u8x16_splat(0x0A));
1912-
1913-
v128_t underscore = wasm_i8x16_eq(v, wasm_u8x16_splat(0x5F));
1914-
1915-
v128_t ident = wasm_v128_or(wasm_v128_or(letter, digit), underscore);
1916-
1917-
// Fast path: if all 16 bytes are identifier chars, advance.
1918-
if (wasm_i8x16_all_true(ident)) {
1919-
cursor += 16;
1920-
continue;
1921-
}
1922-
1923-
// Extract bitmask only on the exit path to find the first non-match.
1924-
uint32_t mask = wasm_i8x16_bitmask(ident);
1925-
cursor += pm_ctzll((uint64_t) (~mask & 0xFFFF));
1926-
return (size_t) (cursor - start);
1927-
}
1928-
1929-
return (size_t) (cursor - start);
1930-
}
1931-
19321887
// The SWAR path uses pm_ctzll to find the first non-matching byte within a
19331888
// word, which only yields the correct byte index on little-endian targets.
19341889
// We gate on a positive little-endian check so that unknown-endianness
19351890
// platforms safely fall through to the no-op fallback.
1936-
#elif defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
1891+
#elif defined(PRISM_HAS_SWAR)
19371892

19381893
/**
19391894
* Portable SWAR fallback — processes 8 bytes per iteration.

prism/util/pm_strpbrk.c

Lines changed: 218 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,214 @@ pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t l
2929
parser->explicit_encoding = parser->encoding;
3030
}
3131

32+
/**
33+
* Scan forward through ASCII bytes looking for a byte that is in the given
34+
* charset. Returns true if a match was found, storing its offset in *index.
35+
* Returns false if no match was found, storing the number of ASCII bytes
36+
* consumed in *index (so the caller can skip past them).
37+
*
38+
* All charset characters must be ASCII (< 0x80). The scanner stops at non-ASCII
39+
* bytes, returning control to the caller's encoding-aware loop.
40+
*
41+
* One of up to three optimized implementations is selected at compile time, with a
42+
* no-op fallback for unsupported platforms:
43+
* 1. NEON — processes 16 bytes per iteration on aarch64.
44+
* 2. SSSE3 — processes 16 bytes per iteration on x86-64.
45+
* 3. SWAR — little-endian fallback, processes 8 bytes per iteration.
46+
*/
47+
48+
#if defined(PRISM_HAS_NEON)
49+
#include <arm_neon.h>
50+
51+
static inline bool
52+
scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
53+
// Build nibble-based lookup tables from the charset. All breakpoint
54+
// characters are ASCII (< 0x80), so they fit within high nibbles 0-7.
55+
//
56+
// For each charset byte c, we set bit (1 << (c >> 4)) in low_lut[c & 0xF].
57+
// high_lut[h] = (1 << h) for each high nibble h present in the charset.
58+
// A source byte s matches iff (low_lut[s & 0xF] & high_lut[s >> 4]) != 0.
59+
uint8_t low_arr[16] = { 0 };
60+
uint8_t high_arr[16] = { 0 };
61+
uint64_t table[4] = { 0 };
62+
63+
for (const uint8_t *c = charset; *c != '\0'; c++) {
64+
low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4));
65+
high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4));
66+
table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
67+
}
68+
69+
uint8x16_t low_lut = vld1q_u8(low_arr);
70+
uint8x16_t high_lut = vld1q_u8(high_arr);
71+
uint8x16_t mask_0f = vdupq_n_u8(0x0F);
72+
uint8x16_t mask_80 = vdupq_n_u8(0x80);
73+
74+
size_t idx = 0;
75+
76+
while (idx + 16 <= maximum) {
77+
uint8x16_t v = vld1q_u8(source + idx);
78+
79+
// If any byte has the high bit set, we have non-ASCII data.
80+
// Return to let the caller's encoding-aware loop handle it.
81+
if (vmaxvq_u8(vandq_u8(v, mask_80)) != 0) break;
82+
83+
uint8x16_t lo_class = vqtbl1q_u8(low_lut, vandq_u8(v, mask_0f));
84+
uint8x16_t hi_class = vqtbl1q_u8(high_lut, vshrq_n_u8(v, 4));
85+
uint8x16_t matched = vtstq_u8(lo_class, hi_class);
86+
87+
if (vmaxvq_u8(matched) == 0) {
88+
idx += 16;
89+
continue;
90+
}
91+
92+
// Find the position of the first matching byte.
93+
uint64_t lo64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 0);
94+
if (lo64 != 0) {
95+
*index = idx + pm_ctzll(lo64) / 8;
96+
return true;
97+
}
98+
uint64_t hi64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 1);
99+
*index = idx + 8 + pm_ctzll(hi64) / 8;
100+
return true;
101+
}
102+
103+
// Scalar tail for remaining < 16 ASCII bytes.
104+
while (idx < maximum && source[idx] < 0x80) {
105+
uint8_t byte = source[idx];
106+
if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
107+
*index = idx;
108+
return true;
109+
}
110+
idx++;
111+
}
112+
113+
*index = idx;
114+
return false;
115+
}
116+
117+
#elif defined(PRISM_HAS_SSSE3)
118+
#include <tmmintrin.h>
119+
120+
static inline bool
121+
scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
122+
// Build nibble-based lookup tables and bitmap table in a single pass.
123+
uint8_t low_arr[16] = { 0 };
124+
uint8_t high_arr[16] = { 0 };
125+
uint64_t table[4] = { 0 };
126+
127+
for (const uint8_t *c = charset; *c != '\0'; c++) {
128+
low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4));
129+
high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4));
130+
table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
131+
}
132+
133+
__m128i low_lut = _mm_loadu_si128((const __m128i *) low_arr);
134+
__m128i high_lut = _mm_loadu_si128((const __m128i *) high_arr);
135+
__m128i mask_0f = _mm_set1_epi8(0x0F);
136+
137+
size_t idx = 0;
138+
139+
while (idx + 16 <= maximum) {
140+
__m128i v = _mm_loadu_si128((const __m128i *) (source + idx));
141+
142+
// If any byte has the high bit set, stop.
143+
if (_mm_movemask_epi8(v) != 0) break;
144+
145+
// Nibble-based classification using pshufb (SSSE3), same as NEON
146+
// vqtbl1q_u8. A byte matches iff (low_lut[lo_nib] & high_lut[hi_nib]) != 0.
147+
__m128i lo_class = _mm_shuffle_epi8(low_lut, _mm_and_si128(v, mask_0f));
148+
__m128i hi_class = _mm_shuffle_epi8(high_lut, _mm_and_si128(_mm_srli_epi16(v, 4), mask_0f));
149+
__m128i matched = _mm_and_si128(lo_class, hi_class);
150+
151+
// Check if any byte matched.
152+
int mask = _mm_movemask_epi8(_mm_cmpeq_epi8(matched, _mm_setzero_si128()));
153+
154+
if (mask == 0xFFFF) {
155+
// All bytes were zero — no match in this chunk.
156+
idx += 16;
157+
continue;
158+
}
159+
160+
// Find the first matching byte (first non-zero in matched).
161+
*index = idx + pm_ctzll((uint64_t) (~mask & 0xFFFF));
162+
return true;
163+
}
164+
165+
// Scalar tail.
166+
while (idx < maximum && source[idx] < 0x80) {
167+
uint8_t byte = source[idx];
168+
if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
169+
*index = idx;
170+
return true;
171+
}
172+
idx++;
173+
}
174+
175+
*index = idx;
176+
return false;
177+
}
178+
179+
#elif defined(PRISM_HAS_SWAR)
180+
181+
/**
 * SWAR implementation of the ASCII charset scan: tests 8 bytes at a time for
 * non-ASCII data, then checks each byte against a bitmap of the charset.
 *
 * A NUL byte in source never matches, because the charset is read up to (and
 * not including) its terminator. Charset bytes >= 0x80 never match either,
 * because the scan stops before any non-ASCII source byte.
 *
 * @param source The bytes to scan.
 * @param maximum The number of bytes that may be read from source.
 * @param charset A NUL-terminated list of bytes to search for.
 * @param index Receives the offset of the first match when true is returned;
 *     otherwise the offset at which scanning stopped (the first non-ASCII
 *     byte, or maximum).
 * @return true if a charset byte was found before any non-ASCII byte.
 */
static inline bool
scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
    // Build a 256-bit lookup table: one bit per possible byte value, split
    // across four 64-bit words (word = byte >> 6, bit = byte & 0x3F).
    uint64_t table[4] = { 0 };
    for (const uint8_t *c = charset; *c != '\0'; c++) {
        table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
    }

    static const uint64_t highs = 0x8080808080808080ULL;
    size_t idx = 0;

    while (idx + 8 <= maximum) {
        // memcpy performs the unaligned 8-byte load without violating
        // alignment or aliasing rules.
        uint64_t word;
        memcpy(&word, source + idx, 8);

        // Bail on any non-ASCII byte; the scalar tail below then advances to
        // the exact byte.
        if (word & highs) {
            break;
        }

        // Check each byte against the charset table.
        for (size_t j = 0; j < 8; j++) {
            uint8_t byte = source[idx + j];
            if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
                *index = idx + j;
                return true;
            }
        }

        idx += 8;
    }

    // Scalar tail: handles the final < 8 bytes, and the ASCII bytes that
    // precede a non-ASCII byte found by the word loop.
    while (idx < maximum && source[idx] < 0x80) {
        uint8_t byte = source[idx];
        if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
            *index = idx;
            return true;
        }
        idx++;
    }

    *index = idx;
    return false;
}
224+
225+
#else
226+
227+
static inline bool
228+
scan_strpbrk_ascii(PRISM_ATTRIBUTE_UNUSED const uint8_t *source, PRISM_ATTRIBUTE_UNUSED size_t maximum, PRISM_ATTRIBUTE_UNUSED const uint8_t *charset, size_t *index) {
229+
*index = 0;
230+
return false;
231+
}
232+
233+
#endif
234+
32235
/**
33236
* This is the default path.
34237
*/
35238
static inline const uint8_t *
36-
pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) {
37-
size_t index = 0;
38-
239+
pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) {
39240
while (index < maximum) {
40241
if (strchr((const char *) charset, source[index]) != NULL) {
41242
return source + index;
@@ -73,9 +274,7 @@ pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *chars
73274
* This is the path when the encoding is ASCII-8BIT.
74275
*/
75276
static inline const uint8_t *
76-
pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) {
77-
size_t index = 0;
78-
277+
pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) {
79278
while (index < maximum) {
80279
if (strchr((const char *) charset, source[index]) != NULL) {
81280
return source + index;
@@ -92,8 +291,7 @@ pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t
92291
* This is the slow path that does care about the encoding.
93292
*/
94293
static inline const uint8_t *
95-
pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) {
96-
size_t index = 0;
294+
pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) {
97295
const pm_encoding_t *encoding = parser->encoding;
98296

99297
while (index < maximum) {
@@ -135,8 +333,7 @@ pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t
135333
* the encoding only supports single-byte characters.
136334
*/
137335
static inline const uint8_t *
138-
pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) {
139-
size_t index = 0;
336+
pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) {
140337
const pm_encoding_t *encoding = parser->encoding;
141338

142339
while (index < maximum) {
@@ -192,15 +389,19 @@ pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t
192389
*/
193390
const uint8_t *
194391
pm_strpbrk(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, ptrdiff_t length, bool validate) {
195-
if (length <= 0) {
196-
return NULL;
197-
} else if (!parser->encoding_changed) {
198-
return pm_strpbrk_utf8(parser, source, charset, (size_t) length, validate);
392+
if (length <= 0) return NULL;
393+
394+
size_t maximum = (size_t) length;
395+
size_t index = 0;
396+
if (scan_strpbrk_ascii(source, maximum, charset, &index)) return source + index;
397+
398+
if (!parser->encoding_changed) {
399+
return pm_strpbrk_utf8(parser, source, charset, index, maximum, validate);
199400
} else if (parser->encoding == PM_ENCODING_ASCII_8BIT_ENTRY) {
200-
return pm_strpbrk_ascii_8bit(parser, source, charset, (size_t) length, validate);
401+
return pm_strpbrk_ascii_8bit(parser, source, charset, index, maximum, validate);
201402
} else if (parser->encoding->multibyte) {
202-
return pm_strpbrk_multi_byte(parser, source, charset, (size_t) length, validate);
403+
return pm_strpbrk_multi_byte(parser, source, charset, index, maximum, validate);
203404
} else {
204-
return pm_strpbrk_single_byte(parser, source, charset, (size_t) length, validate);
405+
return pm_strpbrk_single_byte(parser, source, charset, index, maximum, validate);
205406
}
206407
}

0 commit comments

Comments
 (0)