Skip to content

Commit

Permalink
rustfmt
Browse files Browse the repository at this point in the history
  • Loading branch information
llogiq committed Sep 26, 2023
1 parent 5c13966 commit 5dc85ad
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 51 deletions.
20 changes: 15 additions & 5 deletions src/integer_simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ unsafe fn usize_load_unchecked(bytes: &[u8], offset: usize) -> usize {
ptr::copy_nonoverlapping(
bytes.as_ptr().add(offset),
&mut output as *mut usize as *mut u8,
mem::size_of::<usize>()
mem::size_of::<usize>(),
);
output
}
Expand Down Expand Up @@ -65,11 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize {
// 8
let mut counts = 0;
for i in 0..(haystack.len() - offset) / chunksize {
counts += bytewise_equal(usize_load_unchecked(haystack, offset + i * chunksize), needles);
counts += bytewise_equal(
usize_load_unchecked(haystack, offset + i * chunksize),
needles,
);
}
if haystack.len() % 8 != 0 {
let mask = usize::from_le(!(!0 >> ((haystack.len() % chunksize) * 8)));
counts += bytewise_equal(usize_load_unchecked(haystack, haystack.len() - chunksize), needles) & mask;
counts += bytewise_equal(
usize_load_unchecked(haystack, haystack.len() - chunksize),
needles,
) & mask;
}
count += sum_usize(counts);

Expand Down Expand Up @@ -98,11 +104,15 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
// 8
let mut counts = 0;
for i in 0..(utf8_chars.len() - offset) / chunksize {
counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize));
counts +=
is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize));
}
if utf8_chars.len() % 8 != 0 {
let mask = usize::from_le(!(!0 >> ((utf8_chars.len() % chunksize) * 8)));
counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, utf8_chars.len() - chunksize)) & mask;
counts += is_leading_utf8_byte(usize_load_unchecked(
utf8_chars,
utf8_chars.len() - chunksize,
)) & mask;
}
count += sum_usize(counts);

Expand Down
9 changes: 7 additions & 2 deletions src/naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ pub fn naive_count_32(haystack: &[u8], needle: u8) -> usize {
/// assert_eq!(number_of_spaces, 6);
/// ```
pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize {
utf8_chars.iter().fold(0, |n, c| n + (*c == needle) as usize)
utf8_chars
.iter()
.fold(0, |n, c| n + (*c == needle) as usize)
}

/// Count the number of UTF-8 encoded Unicode codepoints in a slice of bytes, simple
Expand All @@ -38,5 +40,8 @@ pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize {
/// assert_eq!(char_count, 4);
/// ```
pub fn naive_num_chars(utf8_chars: &[u8]) -> usize {
utf8_chars.iter().filter(|&&byte| (byte >> 6) != 0b10).count()
utf8_chars
.iter()
.filter(|&&byte| (byte >> 6) != 0b10)
.count()
}
20 changes: 11 additions & 9 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use std::mem;
use self::packed_simd::{u8x32, u8x64, FromCast};

const MASK: [u8; 64] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
];

unsafe fn u8x64_from_offset(slice: &[u8], offset: usize) -> u8x64 {
Expand Down Expand Up @@ -66,15 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize {
// 32
let mut counts = u8x32::splat(0);
for i in 0..(haystack.len() - offset) / 32 {
counts -= u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32));
counts -=
u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32));
}
count += sum_x32(&counts);

// Straggler; need to reset counts because prior loop can run 255 times
counts = u8x32::splat(0);
if haystack.len() % 32 != 0 {
counts -= u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32)) &
u8x32_from_offset(&MASK, haystack.len() % 32);
counts -=
u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32))
& u8x32_from_offset(&MASK, haystack.len() % 32);
}
count += sum_x32(&counts);

Expand Down Expand Up @@ -127,8 +128,9 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
// Straggler; need to reset counts because prior loop can run 255 times
counts = u8x32::splat(0);
if utf8_chars.len() % 32 != 0 {
counts -= is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32)) &
u8x32_from_offset(&MASK, utf8_chars.len() % 32);
counts -=
is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32))
& u8x32_from_offset(&MASK, utf8_chars.len() % 32);
}
count += sum_x32(&counts);

Expand Down
52 changes: 23 additions & 29 deletions src/simd/x86_avx2.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
use std::arch::x86_64::{
__m256i,
_mm256_and_si256,
_mm256_cmpeq_epi8,
_mm256_extract_epi64,
_mm256_loadu_si256,
_mm256_sad_epu8,
_mm256_set1_epi8,
_mm256_setzero_si256,
_mm256_sub_epi8,
_mm256_xor_si256,
__m256i, _mm256_and_si256, _mm256_cmpeq_epi8, _mm256_extract_epi64, _mm256_loadu_si256,
_mm256_sad_epu8, _mm256_set1_epi8, _mm256_setzero_si256, _mm256_sub_epi8, _mm256_xor_si256,
};

#[target_feature(enable = "avx2")]
Expand All @@ -22,10 +14,9 @@ pub unsafe fn mm256_cmpneq_epi8(a: __m256i, b: __m256i) -> __m256i {
}

const MASK: [u8; 64] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
];

#[target_feature(enable = "avx2")]
Expand All @@ -36,10 +27,10 @@ unsafe fn mm256_from_offset(slice: &[u8], offset: usize) -> __m256i {
#[target_feature(enable = "avx2")]
unsafe fn sum(u8s: &__m256i) -> usize {
let sums = _mm256_sad_epu8(*u8s, _mm256_setzero_si256());
(
_mm256_extract_epi64(sums, 0) + _mm256_extract_epi64(sums, 1) +
_mm256_extract_epi64(sums, 2) + _mm256_extract_epi64(sums, 3)
) as usize
(_mm256_extract_epi64(sums, 0)
+ _mm256_extract_epi64(sums, 1)
+ _mm256_extract_epi64(sums, 2)
+ _mm256_extract_epi64(sums, 3)) as usize
}

#[target_feature(enable = "avx2")]
Expand All @@ -57,7 +48,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
for _ in 0..255 {
counts = _mm256_sub_epi8(
counts,
_mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles)
_mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles),
);
offset += 32;
}
Expand All @@ -70,7 +61,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
for _ in 0..128 {
counts = _mm256_sub_epi8(
counts,
_mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles)
_mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles),
);
offset += 32;
}
Expand All @@ -82,16 +73,16 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
for i in 0..(haystack.len() - offset) / 32 {
counts = _mm256_sub_epi8(
counts,
_mm256_cmpeq_epi8(mm256_from_offset(haystack, offset + i * 32), needles)
_mm256_cmpeq_epi8(mm256_from_offset(haystack, offset + i * 32), needles),
);
}
if haystack.len() % 32 != 0 {
counts = _mm256_sub_epi8(
counts,
_mm256_and_si256(
_mm256_cmpeq_epi8(mm256_from_offset(haystack, haystack.len() - 32), needles),
mm256_from_offset(&MASK, haystack.len() % 32)
)
mm256_from_offset(&MASK, haystack.len() % 32),
),
);
}
count += sum(&counts);
Expand All @@ -101,7 +92,10 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {

#[target_feature(enable = "avx2")]
unsafe fn is_leading_utf8_byte(u8s: __m256i) -> __m256i {
mm256_cmpneq_epi8(_mm256_and_si256(u8s, _mm256_set1_epu8(0b1100_0000)), _mm256_set1_epu8(0b1000_0000))
mm256_cmpneq_epi8(
_mm256_and_si256(u8s, _mm256_set1_epu8(0b1100_0000)),
_mm256_set1_epu8(0b1000_0000),
)
}

#[target_feature(enable = "avx2")]
Expand All @@ -118,7 +112,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
for _ in 0..255 {
counts = _mm256_sub_epi8(
counts,
is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset))
is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)),
);
offset += 32;
}
Expand All @@ -131,7 +125,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
for _ in 0..128 {
counts = _mm256_sub_epi8(
counts,
is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset))
is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)),
);
offset += 32;
}
Expand All @@ -143,16 +137,16 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
for i in 0..(utf8_chars.len() - offset) / 32 {
counts = _mm256_sub_epi8(
counts,
is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset + i * 32))
is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset + i * 32)),
);
}
if utf8_chars.len() % 32 != 0 {
counts = _mm256_sub_epi8(
counts,
_mm256_and_si256(
is_leading_utf8_byte(mm256_from_offset(utf8_chars, utf8_chars.len() - 32)),
mm256_from_offset(&MASK, utf8_chars.len() % 32)
)
mm256_from_offset(&MASK, utf8_chars.len() % 32),
),
);
}
count += sum(&counts);
Expand Down
7 changes: 1 addition & 6 deletions tests/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ extern crate bytecount;
extern crate quickcheck;
extern crate rand;

use bytecount::{
count, naive_count,
num_chars, naive_num_chars,
};
use bytecount::{count, naive_count, naive_num_chars, num_chars};
use rand::RngCore;

fn random_bytes(len: usize) -> Vec<u8> {
Expand Down Expand Up @@ -59,8 +56,6 @@ fn check_count_overflow_many() {
}
}



quickcheck! {
fn check_num_chars_correct(haystack: Vec<u8>) -> bool {
num_chars(&haystack) == naive_num_chars(&haystack)
Expand Down

0 comments on commit 5dc85ad

Please sign in to comment.