厌倦了C++,CS&ML博士用Rust重写Python扩展,还总结了9条规则
效果好不好,试一试就知道了。
比 Python 快;
兼容 NumPy;
可以进行数据并行多线程处理;
与执行数据并行多线程的所有其他包兼容;
安全。
[build-system]
requires = ["maturin==0.12.5"]
build-backend = "maturin"
[dependencies]
thiserror = "1.0.30"
ndarray-npy = { version = "0.8.1", default-features = false }
rayon = "1.5.1"
numpy = "0.15.0"
ndarray = { version = "0.15.4", features = ["approx", "rayon"] }
pyo3 = { version = "0.15.1", features = ["extension-module"] }
[dev-dependencies]
temp_testdir = "0.2.3"
mod python_module;
mod tests;
/// Python-facing entry point, exported to Python as `read_f64`.
///
/// Reads the selected values from the file `filename` into the
/// caller-supplied `f64` array `val`, passing `f64::NAN` as the
/// missing-value marker. The work itself is done by `read_no_alloc`,
/// which runs inside a thread pool built for `num_threads` threads.
///
/// # Errors
///
/// Returns a `PyErr` if a slice cannot be obtained from either index
/// array, if the thread pool cannot be created, or if `read_no_alloc`
/// fails.
#[pyfn(m)]
#[pyo3(name = "read_f64")]
fn read_f64_py(
    _py: Python<'_>,
    filename: &str,
    iid_count: usize,
    sid_count: usize,
    count_a1: bool,
    iid_index: &PyArray1<usize>,
    sid_index: &PyArray1<usize>,
    val: &PyArray2<f64>,
    num_threads: usize,
) -> Result<(), PyErr> {
    // Read-only views of the two index arrays.
    let iid_readonly = iid_index.readonly();
    let sid_readonly = sid_index.readonly();

    // SAFETY: NOTE(review) - this assumes no other view of `val` is alive
    // while `out_view` is in use; confirm against the Python caller.
    let mut out_view = unsafe { val.as_array_mut() };

    let iid_slice = iid_readonly.as_slice()?;
    let sid_slice = sid_readonly.as_slice()?;

    // Run the read inside a dedicated pool so that `num_threads` is honored.
    let pool = create_pool(num_threads)?;
    pool.install(|| {
        read_no_alloc(
            filename,
            iid_count,
            sid_count,
            count_a1,
            iid_slice,
            sid_slice,
            f64::NAN,
            &mut out_view,
        )
    })?;

    Ok(())
}
let iid_index = iid_index.readonly();
let ii = &iid_index.as_slice()?;
let mut val = unsafe { val.as_array_mut() };
from .bed_reader import [...] read_f64 [...]
def read([...]):
[...]
val = np.zeros((len(iid_index), len(sid_index)), order=order, dtype=dtype)
[...]
reader = read_f64
[...]
reader(
str(self.filepath),
iid_count=self.iid_count,
sid_count=self.sid_count,
count_a1=self.count_A1,
iid_index=iid_index,
sid_index=sid_index,
val=val,
num_threads=num_threads,
)
[...]
return val
let mut buf_reader = BufReader::new(File::open(filename)?);
use thiserror::Error;
...
/// BedErrorPlus enumerates all possible errors
/// returned by this library.
///
/// Every variant wraps an underlying error type. `#[error(transparent)]`
/// forwards `Display` and `source()` to the wrapped error, and `#[from]`
/// generates the `From` impl that lets `?` and `.into()` convert the
/// wrapped error into a `BedErrorPlus`.
///
/// Based on https://nick.groenen.me/posts/rust-error-handling/#the-library-error-type
#[derive(Error, Debug)]
pub enum BedErrorPlus {
/// An I/O error from the standard library.
#[error(transparent)]
IOError(#[from] std::io::Error),
/// An error described by this crate's own `BedError` enum.
#[error(transparent)]
BedError(#[from] BedError),
/// A failure to build a thread pool (wraps `ThreadPoolBuildError`).
#[error(transparent)]
ThreadPoolError(#[from] ThreadPoolBuildError),
}
/// Maps the library's error type onto Python exception classes.
///
/// * I/O errors become `PyIOError`.
/// * The index-related `BedError` variants listed below become `PyIndexError`.
/// * Everything else (thread-pool errors and the remaining `BedError`
///   variants) becomes `PyValueError`.
///
/// In every case the exception message is the error's `Display` text.
impl std::convert::From<BedErrorPlus> for PyErr {
    fn from(err: BedErrorPlus) -> PyErr {
        // The message does not depend on which arm is taken, so build it once.
        let message = err.to_string();
        match err {
            BedErrorPlus::IOError(_) => PyIOError::new_err(message),
            BedErrorPlus::BedError(
                BedError::IidIndexTooBig(..)
                | BedError::SidIndexTooBig(..)
                | BedError::IndexMismatch(..)
                | BedError::IndexesTooBigForFiles(..)
                | BedError::SubsetMismatch(..),
            ) => PyIndexError::new_err(message),
            // Thread-pool errors and all other `BedError` variants.
            _ => PyValueError::new_err(message),
        }
    }
}
if (BED_FILE_MAGIC1 != bytes_vector[0]) || (BED_FILE_MAGIC2 != bytes_vector[1]) {
return Err(BedError::IllFormed(filename.to_string()).into());
}
use thiserror::Error;
[...]
// https://docs.rs/thiserror/1.0.23/thiserror/
#[derive(Error, Debug, Clone)]
pub enum BedError {
#[error("Ill-formed BED file. BED file header is incorrect or length is wrong.'{0}'")]
IllFormed(String),
[...]
}
DNA 位置的二进制数据;
输出数组的列。
[... not shown, read bytes for DNA location's data ...]
// Zip in the column of the output array
.zip(out_val.axis_iter_mut(nd::Axis(1)))
// In parallel, decompress the iid info and put it in its column
.par_bridge() // This seems faster than parallel zip
.try_for_each(|(bytes_vector_result, mut col)| {
match bytes_vector_result {
Err(e) => Err(e),
Ok(bytes_vector) => {
for out_iid_i in 0..out_iid_count {
let in_iid_i = iid_index[out_iid_i];
let i_div_4 = in_iid_i / 4;
let i_mod_4 = in_iid_i % 4;
let genotype_byte: u8 = (bytes_vector[i_div_4] >> (i_mod_4 * 2)) & 0x03;
col[out_iid_i] = from_two_bits_to_value[genotype_byte as usize];
}
Ok(())
}
}
})?;
[...]
let mut result_list: Vec<Result<(), BedError>> = vec![Ok(()); sid_count];
nd::par_azip!((mut stats_row in stats.axis_iter_mut(nd::Axis(0)),
&n_observed in &n_observed_array,
&sum_s in &sum_s_array,
&sum2_s in &sum2_s_array,
result_ptr in &mut result_list)
{
[...some code not shown...]
});
// Check the result list for errors
result_list.par_iter().try_for_each(|x| (*x).clone())?;
[...]
def get_num_threads(num_threads=None):
    """Decide how many threads to use.

    The first of the following that applies wins:

    1. ``num_threads`` itself, when it is not ``None`` (returned unchanged).
    2. The ``PST_NUM_THREADS`` environment variable.
    3. The ``NUM_THREADS`` environment variable.
    4. The ``MKL_NUM_THREADS`` environment variable.
    5. ``multiprocessing.cpu_count()``.

    Environment values are converted with ``int``, so a non-integer value
    raises ``ValueError``.
    """
    if num_threads is not None:
        return num_threads
    # Order matters: earlier names take priority over later ones.
    for env_name in ("PST_NUM_THREADS", "NUM_THREADS", "MKL_NUM_THREADS"):
        if env_name in os.environ:
            return int(os.environ[env_name])
    return multiprocessing.cpu_count()
/// Builds a rayon thread pool configured for `num_threads` threads.
///
/// NOTE(review): per rayon's documentation, passing `0` lets rayon choose
/// the thread count itself - confirm for the pinned rayon version.
///
/// # Errors
///
/// Returns `BedErrorPlus::ThreadPoolError` (via the `From` conversion) if
/// rayon fails to build the pool.
pub fn create_pool(num_threads: usize) -> Result<rayon::ThreadPool, BedErrorPlus> {
    // `?` performs the same `ThreadPoolBuildError -> BedErrorPlus` conversion
    // that the explicit `e.into()` did.
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(num_threads)
        .build()?;
    Ok(pool)
}
[ ]
create_pool(num_threads)?.install(|| {
read_no_alloc(
filename,
[ ]
)
})?;
[ ]
def read(
[...]
dtype: Optional[Union[type, str]] = "float32",
[...]
)
[...]
if dtype == np.int8:
reader = read_i8
elif dtype == np.float64:
reader = read_f64
elif dtype == np.float32:
reader = read_f32
else:
raise ValueError(
f"dtype '{dtype}' not known, only "
"'int8', 'float32', and 'float64' are allowed."
)
reader(
str(self.filepath),
[...]
)
#[pyfn(m)]
#[pyo3(name = "read_f64")]
fn read_f64_py(
[...]
val: &PyArray2<f64>,
num_threads: usize,
) -> Result<(), PyErr> {
[...]
let mut val = unsafe { val.as_array_mut() };
[...]
read_no_alloc(
[...]
f64::NAN,
val,
)
[...]
}
fn read_no_alloc<TOut: Copy + Default + From<i8> + Debug + Sync + Send>(
filename: &str,
[...]
missing_value: TOut,
val: &mut nd::ArrayViewMut2<'_, TOut>,
) -> Result<(), BedErrorPlus> {
[...]
}
© THE END
投稿或寻求报道:[email protected]