@@ -29,13 +29,214 @@ pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t l
2929 parser -> explicit_encoding = parser -> encoding ;
3030}
3131
32+ /**
33+ * Scan forward through ASCII bytes looking for a byte that is in the given
34+ * charset. Returns true if a match was found, storing its offset in *index.
35+ * Returns false if no match was found, storing the number of ASCII bytes
36+ * consumed in *index (so the caller can skip past them).
37+ *
38+ * All charset characters must be ASCII (< 0x80). The scanner stops at non-ASCII
39+ * bytes, returning control to the caller's encoding-aware loop.
40+ *
41+ * Up to three optimized implementations are selected at compile time, with a
42+ * no-op fallback for unsupported platforms:
43+ * 1. NEON — processes 16 bytes per iteration on aarch64.
44+ * 2. SSSE3 — processes 16 bytes per iteration on x86-64.
45+ * 3. SWAR — little-endian fallback, processes 8 bytes per iteration.
46+ */
47+
48+ #if defined(PRISM_HAS_NEON )
49+ #include <arm_neon.h>
50+
/**
 * NEON (aarch64) implementation of the ASCII breakpoint scanner.
 *
 * Scans `source` for the first byte that is either a member of `charset` or a
 * NUL byte, processing 16 bytes per iteration and finishing with a scalar
 * loop. Scanning stops at the first non-ASCII byte (>= 0x80) so that the
 * caller's encoding-aware loop can take over from there.
 *
 * NUL is always treated as a breakpoint: the encoding-aware loops test
 * membership with strchr(), which also matches the charset's terminating NUL,
 * so a NUL source byte must stop this scanner as well or the result of
 * pm_strpbrk would depend on which scanner was compiled in.
 *
 * @param source  The bytes to scan.
 * @param maximum The number of bytes of `source` that may be read.
 * @param charset NUL-terminated list of breakpoint bytes; all must be < 0x80.
 * @param index   Receives the offset of the match when true is returned,
 *                otherwise the number of leading bytes that were scanned.
 * @return true if a breakpoint was found, false otherwise.
 */
static inline bool
scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
    // Build nibble-based lookup tables from the charset. All breakpoint
    // characters are ASCII (< 0x80), so they fit within high nibbles 0-7.
    //
    // For each charset byte c, we set bit (1 << (c >> 4)) in low_arr[c & 0xF].
    // high_arr[h] = (1 << h) for each high nibble h present in the charset.
    // A source byte s matches iff (low_arr[s & 0xF] & high_arr[s >> 4]) != 0.
    //
    // `table` is a 256-bit bitmap of the same set, used by the scalar tail.
    uint8_t low_arr[16] = { 0 };
    uint8_t high_arr[16] = { 0 };
    uint64_t table[4] = { 0 };

    for (const uint8_t *c = charset; *c != '\0'; c++) {
        low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4));
        high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4));
        table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
    }

    // The loop above stops at the terminator, so NUL has to be added by hand
    // to mirror the strchr()-based membership test of the slow paths.
    low_arr[0x00] |= (uint8_t) (1 << 0);
    high_arr[0x00] = (uint8_t) (1 << 0);
    table[0] |= (uint64_t) 1;

    uint8x16_t low_lut = vld1q_u8(low_arr);
    uint8x16_t high_lut = vld1q_u8(high_arr);
    uint8x16_t mask_0f = vdupq_n_u8(0x0F);
    uint8x16_t mask_80 = vdupq_n_u8(0x80);

    size_t idx = 0;

    while (idx + 16 <= maximum) {
        uint8x16_t v = vld1q_u8(source + idx);

        // If any byte has the high bit set, we have non-ASCII data. Leave the
        // vector loop; the scalar tail below consumes the ASCII bytes that
        // precede it and then hands control back to the caller.
        if (vmaxvq_u8(vandq_u8(v, mask_80)) != 0) break;

        // Classify every lane by its low and high nibble. vtstq_u8 sets a
        // lane to 0xFF when the two classes share a bit, i.e. on a match.
        uint8x16_t lo_class = vqtbl1q_u8(low_lut, vandq_u8(v, mask_0f));
        uint8x16_t hi_class = vqtbl1q_u8(high_lut, vshrq_n_u8(v, 4));
        uint8x16_t matched = vtstq_u8(lo_class, hi_class);

        if (vmaxvq_u8(matched) == 0) {
            idx += 16;
            continue;
        }

        // Find the position of the first matching byte. Each lane is 8 bits
        // wide, so the trailing-zero count divided by 8 is the lane number
        // (this relies on little-endian lane order within each 64-bit half).
        uint64_t lo64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 0);
        if (lo64 != 0) {
            *index = idx + pm_ctzll(lo64) / 8;
            return true;
        }
        uint64_t hi64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 1);
        *index = idx + 8 + pm_ctzll(hi64) / 8;
        return true;
    }

    // Scalar tail: fewer than 16 bytes remain, or the current chunk contains
    // a non-ASCII byte. Walk ASCII bytes one at a time using the bitmap.
    while (idx < maximum && source[idx] < 0x80) {
        uint8_t byte = source[idx];
        if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
            *index = idx;
            return true;
        }
        idx++;
    }

    *index = idx;
    return false;
}
116+
117+ #elif defined(PRISM_HAS_SSSE3 )
118+ #include <tmmintrin.h>
119+
/**
 * SSSE3 (x86-64) implementation of the ASCII breakpoint scanner.
 *
 * Scans `source` for the first byte that is either a member of `charset` or a
 * NUL byte, processing 16 bytes per iteration and finishing with a scalar
 * loop. Scanning stops at the first non-ASCII byte (>= 0x80) so that the
 * caller's encoding-aware loop can take over from there.
 *
 * NUL is always treated as a breakpoint: the encoding-aware loops test
 * membership with strchr(), which also matches the charset's terminating NUL,
 * so a NUL source byte must stop this scanner as well or the result of
 * pm_strpbrk would depend on which scanner was compiled in.
 *
 * @param source  The bytes to scan.
 * @param maximum The number of bytes of `source` that may be read.
 * @param charset NUL-terminated list of breakpoint bytes; all must be < 0x80.
 * @param index   Receives the offset of the match when true is returned,
 *                otherwise the number of leading bytes that were scanned.
 * @return true if a breakpoint was found, false otherwise.
 */
static inline bool
scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
    // Build the nibble-based lookup tables (for pshufb) and the 256-bit
    // bitmap (for the scalar tail) in a single pass over the charset.
    //
    // For each charset byte c, bit (1 << (c >> 4)) is set in low_arr[c & 0xF]
    // and high_arr[c >> 4] = (1 << (c >> 4)). A source byte s matches iff
    // (low_arr[s & 0xF] & high_arr[s >> 4]) != 0.
    uint8_t low_arr[16] = { 0 };
    uint8_t high_arr[16] = { 0 };
    uint64_t table[4] = { 0 };

    for (const uint8_t *c = charset; *c != '\0'; c++) {
        low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4));
        high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4));
        table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
    }

    // The loop above stops at the terminator, so NUL has to be added by hand
    // to mirror the strchr()-based membership test of the slow paths.
    low_arr[0x00] |= (uint8_t) (1 << 0);
    high_arr[0x00] = (uint8_t) (1 << 0);
    table[0] |= (uint64_t) 1;

    __m128i low_lut = _mm_loadu_si128((const __m128i *) low_arr);
    __m128i high_lut = _mm_loadu_si128((const __m128i *) high_arr);
    __m128i mask_0f = _mm_set1_epi8(0x0F);

    size_t idx = 0;

    while (idx + 16 <= maximum) {
        __m128i v = _mm_loadu_si128((const __m128i *) (source + idx));

        // movemask collects the top bit of every byte; any set bit means the
        // chunk contains non-ASCII data. Leave the vector loop and let the
        // scalar tail consume the ASCII bytes that precede it.
        if (_mm_movemask_epi8(v) != 0) break;

        // Nibble-based classification using pshufb. The 16-bit shift drags
        // bits in from the neighbouring byte, so the high-nibble index is
        // masked with 0x0F before the shuffle.
        __m128i lo_class = _mm_shuffle_epi8(low_lut, _mm_and_si128(v, mask_0f));
        __m128i hi_class = _mm_shuffle_epi8(high_lut, _mm_and_si128(_mm_srli_epi16(v, 4), mask_0f));
        __m128i matched = _mm_and_si128(lo_class, hi_class);

        // One bit per byte, set where `matched` is zero (i.e. no match).
        int mask = _mm_movemask_epi8(_mm_cmpeq_epi8(matched, _mm_setzero_si128()));

        if (mask == 0xFFFF) {
            // All bytes were zero: no match in this chunk.
            idx += 16;
            continue;
        }

        // Invert to get the matching lanes; the lowest set bit is the first
        // matching byte.
        *index = idx + pm_ctzll((uint64_t) (~mask & 0xFFFF));
        return true;
    }

    // Scalar tail: fewer than 16 bytes remain, or the current chunk contains
    // a non-ASCII byte. Walk ASCII bytes one at a time using the bitmap.
    while (idx < maximum && source[idx] < 0x80) {
        uint8_t byte = source[idx];
        if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
            *index = idx;
            return true;
        }
        idx++;
    }

    *index = idx;
    return false;
}
178+
179+ #elif defined(PRISM_HAS_SWAR )
180+
/**
 * SWAR implementation of the ASCII breakpoint scanner.
 *
 * Scans `source` for the first byte that is either a member of `charset` or a
 * NUL byte. Eight bytes at a time are tested for non-ASCII content with a
 * single word operation; the membership test itself is done per byte against
 * a bitmap. Scanning stops at the first non-ASCII byte (>= 0x80) so that the
 * caller's encoding-aware loop can take over from there.
 *
 * NUL is always treated as a breakpoint: the encoding-aware loops test
 * membership with strchr(), which also matches the charset's terminating NUL,
 * so a NUL source byte must stop this scanner as well or the result of
 * pm_strpbrk would depend on which scanner was compiled in.
 *
 * @param source  The bytes to scan.
 * @param maximum The number of bytes of `source` that may be read.
 * @param charset NUL-terminated list of breakpoint bytes; all must be < 0x80.
 * @param index   Receives the offset of the match when true is returned,
 *                otherwise the number of leading bytes that were scanned.
 * @return true if a breakpoint was found, false otherwise.
 */
static inline bool
scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) {
    // Build a 256-bit lookup table (one bit per byte value).
    uint64_t table[4] = { 0 };
    for (const uint8_t *c = charset; *c != '\0'; c++) {
        table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F);
    }

    // The loop above stops at the terminator, so NUL has to be added by hand
    // to mirror the strchr()-based membership test of the slow paths.
    table[0] |= (uint64_t) 1;

    static const uint64_t highs = 0x8080808080808080ULL;
    size_t idx = 0;

    while (idx + 8 <= maximum) {
        // memcpy keeps the load free of alignment and aliasing problems.
        uint64_t word;
        memcpy(&word, source + idx, 8);

        // Bail on any non-ASCII byte; the scalar tail below consumes the
        // ASCII bytes that precede it.
        if (word & highs) break;

        // Check each byte against the charset table.
        for (size_t j = 0; j < 8; j++) {
            uint8_t byte = source[idx + j];
            if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
                *index = idx + j;
                return true;
            }
        }

        idx += 8;
    }

    // Scalar tail: fewer than 8 bytes remain, or the current word contains a
    // non-ASCII byte.
    while (idx < maximum && source[idx] < 0x80) {
        uint8_t byte = source[idx];
        if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) {
            *index = idx;
            return true;
        }
        idx++;
    }

    *index = idx;
    return false;
}
224+
225+ #else
226+
227+ static inline bool
228+ scan_strpbrk_ascii (PRISM_ATTRIBUTE_UNUSED const uint8_t * source , PRISM_ATTRIBUTE_UNUSED size_t maximum , PRISM_ATTRIBUTE_UNUSED const uint8_t * charset , size_t * index ) {
229+ * index = 0 ;
230+ return false;
231+ }
232+
233+ #endif
234+
32235/**
33236 * This is the default path.
34237 */
35238static inline const uint8_t *
36- pm_strpbrk_utf8 (pm_parser_t * parser , const uint8_t * source , const uint8_t * charset , size_t maximum , bool validate ) {
37- size_t index = 0 ;
38-
239+ pm_strpbrk_utf8 (pm_parser_t * parser , const uint8_t * source , const uint8_t * charset , size_t index , size_t maximum , bool validate ) {
39240 while (index < maximum ) {
40241 if (strchr ((const char * ) charset , source [index ]) != NULL ) {
41242 return source + index ;
@@ -73,9 +274,7 @@ pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *chars
73274 * This is the path when the encoding is ASCII-8BIT.
74275 */
75276static inline const uint8_t *
76- pm_strpbrk_ascii_8bit (pm_parser_t * parser , const uint8_t * source , const uint8_t * charset , size_t maximum , bool validate ) {
77- size_t index = 0 ;
78-
277+ pm_strpbrk_ascii_8bit (pm_parser_t * parser , const uint8_t * source , const uint8_t * charset , size_t index , size_t maximum , bool validate ) {
79278 while (index < maximum ) {
80279 if (strchr ((const char * ) charset , source [index ]) != NULL ) {
81280 return source + index ;
@@ -92,8 +291,7 @@ pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t
92291 * This is the slow path that does care about the encoding.
93292 */
94293static inline const uint8_t *
95- pm_strpbrk_multi_byte (pm_parser_t * parser , const uint8_t * source , const uint8_t * charset , size_t maximum , bool validate ) {
96- size_t index = 0 ;
294+ pm_strpbrk_multi_byte (pm_parser_t * parser , const uint8_t * source , const uint8_t * charset , size_t index , size_t maximum , bool validate ) {
97295 const pm_encoding_t * encoding = parser -> encoding ;
98296
99297 while (index < maximum ) {
@@ -135,8 +333,7 @@ pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t
135333 * the encoding only supports single-byte characters.
136334 */
137335static inline const uint8_t *
138- pm_strpbrk_single_byte (pm_parser_t * parser , const uint8_t * source , const uint8_t * charset , size_t maximum , bool validate ) {
139- size_t index = 0 ;
336+ pm_strpbrk_single_byte (pm_parser_t * parser , const uint8_t * source , const uint8_t * charset , size_t index , size_t maximum , bool validate ) {
140337 const pm_encoding_t * encoding = parser -> encoding ;
141338
142339 while (index < maximum ) {
@@ -192,15 +389,19 @@ pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t
192389 */
193390const uint8_t *
194391pm_strpbrk (pm_parser_t * parser , const uint8_t * source , const uint8_t * charset , ptrdiff_t length , bool validate ) {
195- if (length <= 0 ) {
196- return NULL ;
197- } else if (!parser -> encoding_changed ) {
198- return pm_strpbrk_utf8 (parser , source , charset , (size_t ) length , validate );
392+ if (length <= 0 ) return NULL ;
393+
394+ size_t maximum = (size_t ) length ;
395+ size_t index = 0 ;
396+ if (scan_strpbrk_ascii (source , maximum , charset , & index )) return source + index ;
397+
398+ if (!parser -> encoding_changed ) {
399+ return pm_strpbrk_utf8 (parser , source , charset , index , maximum , validate );
199400 } else if (parser -> encoding == PM_ENCODING_ASCII_8BIT_ENTRY ) {
200- return pm_strpbrk_ascii_8bit (parser , source , charset , ( size_t ) length , validate );
401+ return pm_strpbrk_ascii_8bit (parser , source , charset , index , maximum , validate );
201402 } else if (parser -> encoding -> multibyte ) {
202- return pm_strpbrk_multi_byte (parser , source , charset , ( size_t ) length , validate );
403+ return pm_strpbrk_multi_byte (parser , source , charset , index , maximum , validate );
203404 } else {
204- return pm_strpbrk_single_byte (parser , source , charset , ( size_t ) length , validate );
405+ return pm_strpbrk_single_byte (parser , source , charset , index , maximum , validate );
205406 }
206407}
0 commit comments