diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a14c58a20..5c2175fe4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -68,6 +68,22 @@ jobs:
       - name: Build and Test hf_xet
         run: |
           cd hf_xet && cargo test --verbose --no-fail-fast
+      - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
+        with:
+          python-version: '3.10'
+      - name: Create venv
+        run: python3 -m venv .venv
+      - name: Build wheel
+        uses: PyO3/maturin-action@04ac600d27cdf7a9a280dadf7147097c42b757ad # v1
+        with:
+          command: develop
+          sccache: 'true'
+          working-directory: hf_xet
+      - name: Python integration tests (hf_xet)
+        run: |
+          source .venv/bin/activate
+          pip install pytest
+          pytest hf_xet/tests/ -v
       - name: Check Cargo.lock has no uncommitted changes
         run: |
           # the build and test steps would update Cargo.lock if it is out of date
@@ -88,6 +104,23 @@ jobs:
       - name: Build and Test hf_xet
         run: |
           cd hf_xet && cargo test --verbose --no-fail-fast
+      - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
+        with:
+          python-version: '3.10'
+      - name: Create venv
+        run: python -m venv .venv
+      - name: Build wheel
+        uses: PyO3/maturin-action@04ac600d27cdf7a9a280dadf7147097c42b757ad # v1
+        with:
+          command: develop
+          sccache: 'true'
+          working-directory: hf_xet
+      - name: Python integration tests (hf_xet)
+        shell: bash
+        run: |
+          source .venv/Scripts/activate
+          pip install pytest
+          pytest hf_xet/tests/ -v
   build_and_test-macos:
     runs-on: macos-latest
     steps:
@@ -108,6 +141,22 @@ jobs:
       - name: Build and Test hf_xet
         run: |
           cd hf_xet && cargo test --verbose --no-fail-fast
+      - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
+        with:
+          python-version: '3.10'
+      - name: Create venv
+        run: python3 -m venv .venv
+      - name: Build wheel
+        uses: PyO3/maturin-action@04ac600d27cdf7a9a280dadf7147097c42b757ad # v1
+        with:
+          command: develop
+          sccache: 'true'
+          working-directory: hf_xet
+      - name: Python integration tests (hf_xet)
+        run: |
+          source .venv/bin/activate
+          pip install pytest
+          pytest hf_xet/tests/ -v
   build_and_test-wasm:
     name: Build WASM
     runs-on: ubuntu-latest
diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock
index 40ebdbdce..73885fcf9 100644
--- a/hf_xet/Cargo.lock
+++ b/hf_xet/Cargo.lock
@@ -1110,10 +1110,9 @@ dependencies = [
 
 [[package]]
 name = "hf_xet"
-version = "1.4.2"
+version = "1.5.0"
 dependencies = [
  "async-trait",
- "chrono",
 "hf-xet",
 "http",
 "itertools 0.14.0",
 "pyo3",
 "rand 0.10.1",
 "signal-hook",
+ "tempfile",
 "tracing",
 "winapi",
 "xet-client",
@@ -4171,6 +4171,7 @@ dependencies = [
 "itertools 0.14.0",
 "lazy_static",
 "more-asserts",
+ "pyo3",
 "rand 0.10.1",
 "serde",
 "serde_json",
diff --git a/hf_xet/Cargo.toml b/hf_xet/Cargo.toml
index f2c25ca34..08bb041a9 100644
--- a/hf_xet/Cargo.toml
+++ b/hf_xet/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "hf_xet"
-version = "1.4.2"
+version = "1.5.0"
 edition = "2024"
 license = "Apache-2.0"
@@ -10,14 +10,10 @@
 name = "hf_xet"
 crate-type = ["cdylib", "lib"]
 
 [dependencies]
+xet-pkg = { package = "hf-xet", path = "../xet_pkg", features = ["python"] }
 xet-runtime = { path = "../xet_runtime" }
 xet-client = { path = "../xet_client" }
-xet-pkg = { package = "hf-xet", path = "../xet_pkg", features = ["python"] }
-async-trait = "0.1"
-chrono = "0.4"
-itertools = "0.14"
-lazy_static = "1.5"
 pprof = { version = "0.14", features = [
     "flamegraph",
     "prost",
@@ -27,9 +23,12 @@ pyo3 = { version = "0.26", features = [
     "abi3-py37",
     "auto-initialize",
 ] }
+async-trait = "0.1"
+http = "1"
+itertools = "0.14"
+lazy_static = "1.5"
 rand = "0.10"
 tracing = "0.1"
-http = "1"
 
 # Unix-specific dependencies
 [target.'cfg(unix)'.dependencies]
@@ -42,12 +41,12 @@ winapi = { version = "0.3", features = ["consoleapi", "wincon", "errhandlingapi"
 
 [features]
 default = ["no-default-cache", "elevated_information_level"] # By default, hf_xet disables the disk cache and uses aggressive logging level
 extension-module = ["pyo3/extension-module"] # Only enabled when building with maturin
-native-tls = ["xet-client/native-tls-vendored"]
-native-tls-vendored = ["xet-client/native-tls-vendored"]
-no-default-cache = ["xet-runtime/no-default-cache"]
+native-tls = ["xet-pkg/native-tls"]
+native-tls-vendored = ["xet-pkg/native-tls-vendored"]
+no-default-cache = ["xet-pkg/no-default-cache"]
 profiling = ["pprof"]
-tokio-console = ["xet-runtime/tokio-console"]
-elevated_information_level = ["xet-client/elevated_information_level", "xet-runtime/elevated_information_level"]
+tokio-console = ["xet-pkg/tokio-console"]
+elevated_information_level = ["xet-pkg/elevated_information_level"]
 
 [profile.release]
 split-debuginfo = "packed"
@@ -64,6 +63,9 @@ debug = true
 split-debuginfo = "none"
 
+[dev-dependencies]
+tempfile = "3"
+
 [profile.opt-test]
 inherits = "dev"
 debug = true
diff --git a/hf_xet/src/config.rs b/hf_xet/src/config.rs
new file mode 100644
index 000000000..07f6823bc
--- /dev/null
+++ b/hf_xet/src/config.rs
@@ -0,0 +1,241 @@
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+use xet_runtime::config::XetConfig;
+
+#[pyclass(name = "XetConfig")]
+#[derive(Clone)]
+pub struct PyXetConfig {
+    inner: XetConfig,
+}
+
+impl From<XetConfig> for PyXetConfig {
+    fn from(inner: XetConfig) -> Self {
+        Self { inner }
+    }
+}
+
+impl PyXetConfig {
+    pub fn inner(&self) -> &XetConfig {
+        &self.inner
+    }
+
+    pub fn into_inner(self) -> XetConfig {
+        self.inner
+    }
+}
+
+#[pymethods]
+impl PyXetConfig {
+    #[new]
+    fn py_new() -> Self {
+        Self {
+            inner: XetConfig::new(),
+        }
+    }
+
+    /// Return a new XetConfig with one or more values updated.
+    ///
+    /// Can be called in two ways:
+    ///     config.with_config("group.field", value)         -- single update
+    ///     config.with_config({"group.field": value, ...})  -- batch update
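+    ///
+    /// For illustration, a sketch of both forms from Python (the
+    /// "data.max_concurrent_file_ingestion" key is the one exercised by the
+    /// tests below; other keys follow the same dotted "group.field" pattern):
+    ///
+    /// ```python
+    /// from hf_xet import XetConfig
+    ///
+    /// cfg = XetConfig()
+    /// # single update: returns a new config, leaving `cfg` unchanged
+    /// cfg2 = cfg.with_config("data.max_concurrent_file_ingestion", 17)
+    /// # batch update from a dict; a second positional argument is a TypeError
+    /// cfg3 = cfg.with_config({"data.max_concurrent_file_ingestion": 21})
+    /// ```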
"1" +itertools = "0.14" +lazy_static = "1.5" rand = "0.10" tracing = "0.1" -http = "1" # Unix-specific dependencies [target.'cfg(unix)'.dependencies] @@ -42,12 +41,12 @@ winapi = { version = "0.3", features = ["consoleapi", "wincon", "errhandlingapi" [features] default = ["no-default-cache", "elevated_information_level"] # By default, hf_xet disables the disk cache and uses aggressive logging level extension-module = ["pyo3/extension-module"] # Only enabled when building with maturin -native-tls = ["xet-client/native-tls-vendored"] -native-tls-vendored = ["xet-client/native-tls-vendored"] -no-default-cache = ["xet-runtime/no-default-cache"] +native-tls = ["xet-pkg/native-tls"] +native-tls-vendored = ["xet-pkg/native-tls-vendored"] +no-default-cache = ["xet-pkg/no-default-cache"] profiling = ["pprof"] -tokio-console = ["xet-runtime/tokio-console"] -elevated_information_level = ["xet-client/elevated_information_level", "xet-runtime/elevated_information_level"] +tokio-console = ["xet-pkg/tokio-console"] +elevated_information_level = ["xet-pkg/elevated_information_level"] [profile.release] split-debuginfo = "packed" @@ -64,6 +63,9 @@ debug = true split-debuginfo = "none" +[dev-dependencies] +tempfile = "3" + [profile.opt-test] inherits = "dev" debug = true diff --git a/hf_xet/src/config.rs b/hf_xet/src/config.rs new file mode 100644 index 000000000..07f6823bc --- /dev/null +++ b/hf_xet/src/config.rs @@ -0,0 +1,241 @@ +use pyo3::prelude::*; +use pyo3::types::PyDict; +use xet_runtime::config::XetConfig; + +#[pyclass(name = "XetConfig")] +#[derive(Clone)] +pub struct PyXetConfig { + inner: XetConfig, +} + +impl From for PyXetConfig { + fn from(inner: XetConfig) -> Self { + Self { inner } + } +} + +impl PyXetConfig { + pub fn inner(&self) -> &XetConfig { + &self.inner + } + + pub fn into_inner(self) -> XetConfig { + self.inner + } +} + +#[pymethods] +impl PyXetConfig { + #[new] + fn py_new() -> Self { + Self { + inner: XetConfig::new(), + } + } + + /// Return a new XetConfig with one or more values updated. + /// + /// Can be called in two ways: + /// config.with_config("group.field", value) -- single update + /// config.with_config({"group.field": value, ...}) -- batch update + #[pyo3(name = "with_config")] + #[pyo3(signature = (name_or_dict, value=None))] + fn py_with_config(&self, name_or_dict: &Bound<'_, PyAny>, value: Option<&Bound<'_, PyAny>>) -> PyResult { + let mut new_inner = self.inner.clone(); + + if let Ok(dict) = name_or_dict.downcast::() { + if value.is_some() { + return Err(pyo3::exceptions::PyTypeError::new_err( + "with_config(dict) does not accept a second argument", + )); + } + for (key, val) in dict.iter() { + let key_str: String = key.extract()?; + new_inner.update_field_from_python(&key_str, &val)?; + } + } else { + let name: String = name_or_dict.extract()?; + let val = value.ok_or_else(|| { + pyo3::exceptions::PyTypeError::new_err("with_config(name, value) requires a value argument") + })?; + new_inner.update_field_from_python(&name, val)?; + } + + Ok(Self { inner: new_inner }) + } + + /// Get a configuration value as its native Python type by dotted path + /// (e.g. "data.max_concurrent_file_ingestion"). 
+    #[pyo3(name = "get")]
+    fn py_get(&self, py: Python<'_>, path: &str) -> PyResult<Py<PyAny>> {
+        self.inner.get_field_to_python(path, py)
+    }
+
+    fn __getitem__(&self, py: Python<'_>, key: &str) -> PyResult<Py<PyAny>> {
+        self.inner
+            .get_field_to_python(key, py)
+            .map_err(|_| pyo3::exceptions::PyKeyError::new_err(key.to_owned()))
+    }
+
+    /// Return all (key, value) pairs as a list of tuples.
+    /// Keys are dotted paths like "data.max_concurrent_file_ingestion".
+    fn items(&self, py: Python<'_>) -> PyResult<Vec<(String, Py<PyAny>)>> {
+        self.inner.all_items_to_python(py)
+    }
+
+    /// Return all dotted-path keys.
+    fn keys(&self) -> Vec<String> {
+        self.inner.all_keys()
+    }
+
+    fn __len__(&self) -> usize {
+        self.inner.all_keys().len()
+    }
+
+    fn __iter__(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult<Py<PyXetConfigIter>> {
+        let items = slf.inner.all_items_to_python(py)?;
+        Py::new(py, PyXetConfigIter { items, index: 0 })
+    }
+
+    fn __repr__(&self) -> String {
+        format!("XetConfig({:?})", self.inner)
+    }
+
+    fn __str__(&self) -> String {
+        format!("{:?}", self.inner)
+    }
+}
+
+#[pyclass]
+pub(crate) struct PyXetConfigIter {
+    items: Vec<(String, Py<PyAny>)>,
+    index: usize,
+}
+
+#[pymethods]
+impl PyXetConfigIter {
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    fn __next__(&mut self, py: Python<'_>) -> Option<(String, Py<PyAny>)> {
+        if self.index < self.items.len() {
+            let (key, value) = &self.items[self.index];
+            self.index += 1;
+            Some((key.clone(), value.clone_ref(py)))
+        } else {
+            None
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pyo3::exceptions::{PyTypeError, PyValueError};
+    use pyo3::prelude::*;
+    use pyo3::types::{PyDict, PyString};
+    use xet_runtime::config::XetConfig;
+
+    use super::*;
+
+    #[test]
+    fn from_into_roundtrip_preserves_xet_config() {
+        let cfg = XetConfig::default();
+        let py_cfg = PyXetConfig::from(cfg.clone());
+        assert_eq!(format!("{:?}", py_cfg.into_inner()), format!("{:?}", cfg));
+    }
+
+    #[test]
+    fn py_with_config_single_updates_inner() {
+        Python::attach(|py| {
+            let base = PyXetConfig::from(XetConfig::new());
+            let original = base.inner().data.max_concurrent_file_ingestion;
+            let key = PyString::new(py, "data.max_concurrent_file_ingestion");
+            let val = 17usize.into_pyobject(py).unwrap();
+            let updated = base
+                .py_with_config(key.as_any(), Some(&val))
+                .expect("with_config single-arg form");
+            assert_eq!(updated.inner().data.max_concurrent_file_ingestion, 17);
+            assert_eq!(base.inner().data.max_concurrent_file_ingestion, original);
+        });
+    }
+
+    #[test]
+    fn py_with_config_dict_updates_inner() {
+        Python::attach(|py| {
+            let base = PyXetConfig::from(XetConfig::new());
+            let dict = PyDict::new(py);
+            dict.set_item("data.max_concurrent_file_ingestion", 21usize)
+                .expect("dict set_item");
+            let updated = base.py_with_config(dict.as_any(), None).expect("with_config dict form");
+            assert_eq!(updated.inner().data.max_concurrent_file_ingestion, 21);
+        });
+    }
+
+    #[test]
+    fn py_with_config_dict_rejects_second_positional_argument() {
+        Python::attach(|py| {
+            let base = PyXetConfig::from(XetConfig::new());
+            let dict = PyDict::new(py);
+            dict.set_item("data.max_concurrent_file_ingestion", 1usize).unwrap();
+            let dup = PyString::new(py, "x");
+            match base.py_with_config(dict.as_any(), Some(dup.as_any())) {
+                Ok(_) => panic!("expected TypeError"),
+                Err(err) => assert!(err.is_instance_of::<PyTypeError>(py)),
+            }
+        });
+    }
+
+    #[test]
+    fn py_with_config_name_requires_value() {
+        Python::attach(|py| {
+            let base = PyXetConfig::from(XetConfig::new());
+            let key = PyString::new(py, "data.max_concurrent_file_ingestion");
+            match base.py_with_config(key.as_any(), None) {
+                Ok(_) => panic!("expected TypeError"),
+                Err(err) => assert!(err.is_instance_of::<PyTypeError>(py)),
+            }
+        });
+    }
+
+    #[test]
+    fn py_with_config_unknown_path_errors() {
+        Python::attach(|py| {
+            let base = PyXetConfig::from(XetConfig::new());
+            let key = PyString::new(py, "not_a_real_group.some_field");
+            let val = 1i64.into_pyobject(py).unwrap();
+            match base.py_with_config(key.as_any(), Some(&val)) {
+                Ok(_) => panic!("expected ValueError"),
+                Err(err) => assert!(err.is_instance_of::<PyValueError>(py)),
+            }
+        });
+    }
+
+    #[test]
+    fn keys_include_known_setting() {
+        let cfg = PyXetConfig::from(XetConfig::new());
+        assert!(cfg.keys().contains(&String::from("data.max_concurrent_file_ingestion")));
+    }
+
+    #[test]
+    fn items_count_matches_keys_len_under_gil() {
+        Python::attach(|py| {
+            let cfg = PyXetConfig::from(XetConfig::new());
+            let k = cfg.keys().len();
+            assert_eq!(cfg.items(py).expect("items").len(), k);
+        });
+    }
+
+    #[test]
+    fn py_get_reads_back_roundtrip_via_python_extract() {
+        Python::attach(|py| {
+            let cfg = PyXetConfig::from(XetConfig::new());
+            let key = PyString::new(py, "data.max_concurrent_file_ingestion");
+            let val = 3usize.into_pyobject(py).unwrap();
+            let updated = cfg.py_with_config(key.as_any(), Some(&val)).unwrap();
+            let got = updated.py_get(py, "data.max_concurrent_file_ingestion").unwrap();
+            let py_int: isize = got.extract(py).unwrap();
+            assert_eq!(py_int, 3);
+        });
+    }
+}
diff --git a/hf_xet/src/headers.rs b/hf_xet/src/headers.rs
new file mode 100644
index 000000000..a8672d25f
--- /dev/null
+++ b/hf_xet/src/headers.rs
@@ -0,0 +1,124 @@
+use std::collections::HashMap;
+
+use http::header::{self, HeaderMap, HeaderName, HeaderValue};
+use pyo3::PyResult;
+use pyo3::exceptions::PyValueError;
+
+const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
+
+/// Build a HeaderMap from a Python dict and merge in the USER_AGENT.
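+///
+/// Illustrative behavior (output values assumed for the example; the crate
+/// version comes from CARGO_PKG_VERSION at build time):
+///
+/// ```text
+/// None or {}                         -> {"user-agent": "hf_xet/1.5.0"}
+/// {"User-Agent": "CustomClient/1.0"} -> {"user-agent": "CustomClient/1.0; hf_xet/1.5.0"}
+/// ```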
+pub(crate) fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>>) -> PyResult<HeaderMap> {
+    let mut map = request_headers.map(build_header_map).transpose()?.unwrap_or_default();
+
+    let combined_user_agent = if let Some(existing_ua) = map.get(header::USER_AGENT) {
+        let existing_str = existing_ua.to_str().unwrap_or("");
+        format!("{}; {}", existing_str, USER_AGENT)
+    } else {
+        USER_AGENT.to_string()
+    };
+
+    let user_agent_value =
+        HeaderValue::from_str(&combined_user_agent).unwrap_or_else(|_| HeaderValue::from_static(USER_AGENT));
+    map.insert(header::USER_AGENT, user_agent_value);
+
+    Ok(map)
+}
+
+/// Build a HeaderMap from a Python dict.
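+///
+/// Names and values are validated; either failure surfaces as a Python
+/// ``ValueError`` (message shapes taken from the errors constructed below):
+///
+/// ```text
+/// "Invalid Header!"         -> ValueError: Invalid header name 'Invalid Header!': ...
+/// value with embedded "\n"  -> ValueError: Invalid header value for 'X-Custom': ...
+/// ```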
+pub(crate) fn build_header_map(headers: HashMap<String, String>) -> PyResult<HeaderMap> {
+    let mut map = HeaderMap::with_capacity(headers.len());
+    for (key, value) in headers {
+        let name = HeaderName::from_bytes(key.as_bytes())
+            .map_err(|e| PyValueError::new_err(format!("Invalid header name '{}': {}", key, e)))?;
+        let value = HeaderValue::from_str(&value)
+            .map_err(|e| PyValueError::new_err(format!("Invalid header value for '{}': {}", key, e)))?;
+        map.insert(name, value);
+    }
+    Ok(map)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_build_headers_with_none_empty_hashmap() {
+        let empty_map: HashMap<String, String> = HashMap::new();
+        let headers = build_headers_with_user_agent(Some(empty_map)).unwrap();
+
+        // Should have exactly one header: USER_AGENT
+        assert_eq!(headers.len(), 1);
+        assert!(headers.contains_key(header::USER_AGENT));
+
+        let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
+        assert_eq!(user_agent, USER_AGENT);
+
+        let headers = build_headers_with_user_agent(None).unwrap();
+
+        // Should have exactly one header: USER_AGENT
+        assert_eq!(headers.len(), 1);
+        assert!(headers.contains_key(header::USER_AGENT));
+
+        let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
+        assert_eq!(user_agent, USER_AGENT);
+    }
+
+    #[test]
+    fn test_build_headers_with_valid_headers() {
+        let mut headers_map = HashMap::new();
+        headers_map.insert("Content-Type".to_string(), "application/json".to_string());
+        headers_map.insert("Authorization".to_string(), "Bearer token123".to_string());
+
+        let headers = build_headers_with_user_agent(Some(headers_map)).unwrap();
+
+        // Should have 3 headers: Content-Type, Authorization, and USER_AGENT
+        assert_eq!(headers.len(), 3);
+
+        // Verify each header was converted correctly
+        assert_eq!(headers.get(header::CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
+        assert_eq!(headers.get(header::AUTHORIZATION).unwrap().to_str().unwrap(), "Bearer token123");
+
+        // Verify USER_AGENT was added
+        let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
+        assert_eq!(user_agent, USER_AGENT);
+    }
+
+    #[test]
+    fn test_build_headers_appends_to_existing_user_agent() {
+        let mut headers_map = HashMap::new();
+        headers_map.insert("User-Agent".to_string(), "CustomClient/1.0".to_string());
+
+        let headers = build_headers_with_user_agent(Some(headers_map)).unwrap();
+
+        // Should have exactly one header: USER_AGENT
+        assert_eq!(headers.len(), 1);
+
+        // Verify USER_AGENT was appended to existing one
+        let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
+        assert_eq!(user_agent, format!("CustomClient/1.0; {}", USER_AGENT));
+    }
+
+    #[test]
+    fn test_build_headers_with_invalid_header_name_or_value() {
+        let mut headers_map = HashMap::new();
+        headers_map.insert("Invalid Header!".to_string(), "value".to_string());
+
+        let result = build_headers_with_user_agent(Some(headers_map));
+
+        // Should return an error for invalid header name
+        assert!(result.is_err());
+        let err_msg = result.unwrap_err().to_string();
+        assert!(err_msg.contains("Invalid header name"));
+
+        let mut headers_map = HashMap::new();
+        // Header values cannot contain newlines
+        headers_map.insert("X-Custom".to_string(), "value\nwith\nnewlines".to_string());
+
+        let result = build_headers_with_user_agent(Some(headers_map));
+
+        // Should return an error for invalid header value
+        assert!(result.is_err());
+        let err_msg = result.unwrap_err().to_string();
+        assert!(err_msg.contains("Invalid header value"));
+    }
+}
diff --git a/hf_xet/src/legacy/functions.rs b/hf_xet/src/legacy/functions.rs
new file mode 100644
index 000000000..2d41f0be4
--- /dev/null
+++ b/hf_xet/src/legacy/functions.rs
@@ -0,0 +1,303 @@
+/// Deprecated top-level upload/download/hash functions.
+///
+/// These are the original `hf_xet` module-level functions from the pre-1.x API.
+/// They are kept here for backward compatibility with older versions of
+/// `huggingface_hub`. New code should use the ``XetSession`` object-oriented API.
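+///
+/// A migration sketch (method names taken from the deprecation messages
+/// below; surrounding arguments are illustrative, not a complete signature):
+///
+/// ```python
+/// import hf_xet
+///
+/// # legacy module-level call (emits DeprecationWarning):
+/// infos = hf_xet.upload_files(paths, None, None, None, None, None)
+///
+/// # XetSession replacement:
+/// session = hf_xet.XetSession()
+/// commit = session.new_upload_commit()
+/// handle = commit.start_upload_file(paths[0])
+/// ```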
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use pyo3::exceptions::PyKeyboardInterrupt;
+use pyo3::prelude::*;
+use rand::RngExt;
+use tracing::debug;
+use xet_pkg::legacy::progress_tracking::TrackingProgressUpdater;
+use xet_pkg::legacy::{Sha256Policy, XetFileInfo, data_client};
+
+use super::progress_update::WrappedProgressUpdater;
+use super::runtime::async_run;
+use super::token_refresh::WrappedTokenRefresher;
+use super::types::{PyXetDownloadInfo, PyXetUploadInfo};
+use crate::convert_xet_error;
+use crate::headers::build_headers_with_user_agent;
+
+fn legacy_headers(request_headers: Option<HashMap<String, String>>) -> PyResult<Option<Arc<http::HeaderMap>>> {
+    Ok(Some(Arc::new(build_headers_with_user_agent(request_headers)?)))
+}
+
+fn emit_deprecation(py: Python, msg: &str) -> PyResult<()> {
+    let warnings = py.import("warnings")?;
+    let category = py.get_type::<pyo3::exceptions::PyDeprecationWarning>();
+    // stacklevel=2 so the warning points at the caller's frame, not this wrapper.
+    warnings.call_method1("warn", (msg, category, 2i32))?;
+    Ok(())
+}
+
+type DestinationPath = String;
+
+impl From<PyXetDownloadInfo> for (XetFileInfo, DestinationPath) {
+    fn from(pf: PyXetDownloadInfo) -> Self {
+        let file_info = match pf.file_size {
+            Some(size) => XetFileInfo::new(pf.hash, size),
+            None => XetFileInfo::new_hash_only(pf.hash),
+        };
+        (file_info, pf.destination_path)
+    }
+}
+
+/// Upload raw bytes to Xet storage.
+///
+/// .. deprecated::
+///     Use :class:`XetSession` and :meth:`XetUploadCommit.start_upload_bytes` instead.
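+///
+/// Legacy call shape — a hedged sketch; the six leading positional arguments
+/// follow the signature below, with the callback slots left as ``None``:
+///
+/// ```python
+/// infos = hf_xet.upload_bytes([b"hello"], None, None, None, None, None)
+/// print(infos[0].hash, infos[0].file_size)
+/// ```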
+#[pyfunction]
+#[pyo3(signature = (file_contents, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false),
+       text_signature = "(file_contents, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=False)")]
+#[allow(clippy::too_many_arguments)]
+pub fn upload_bytes(
+    py: Python,
+    file_contents: Vec<Vec<u8>>,
+    endpoint: Option<String>,
+    token_info: Option<(String, u64)>,
+    token_refresher: Option<Py<PyAny>>,
+    progress_updater: Option<Py<PyAny>>,
+    _repo_type: Option<String>,
+    request_headers: Option<HashMap<String, String>>,
+    sha256s: Option<Vec<String>>,
+    skip_sha256: bool,
+) -> PyResult<Vec<PyXetUploadInfo>> {
+    emit_deprecation(
+        py,
+        "hf_xet.upload_bytes() is deprecated. Use XetSession().new_upload_commit().start_upload_bytes() instead.",
+    )?;
+
+    if skip_sha256 && sha256s.is_some() {
+        return Err(pyo3::exceptions::PyValueError::new_err("skip_sha256=True and sha256s are mutually exclusive"));
+    }
+
+    if let Some(ref s) = sha256s
+        && s.len() != file_contents.len()
+    {
+        return Err(pyo3::exceptions::PyValueError::new_err(format!(
+            "sha256s length ({}) must match file_contents length ({})",
+            s.len(),
+            file_contents.len()
+        )));
+    }
+
+    let sha256_policies: Vec<Sha256Policy> = match sha256s {
+        _ if skip_sha256 => vec![Sha256Policy::Skip; file_contents.len()],
+        Some(v) => v.iter().map(|s| Sha256Policy::from_hex(s)).collect(),
+        None => vec![Sha256Policy::Compute; file_contents.len()],
+    };
+
+    let ctx = super::runtime::get_or_init_runtime().map_err(super::runtime::convert_multithreading_error)?;
+    let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
+    let updater = progress_updater
+        .map(|f| WrappedProgressUpdater::new(f, ctx.clone()))
+        .transpose()?
+        .map(Arc::new);
+    let header_map = legacy_headers(request_headers)?;
+    let x: u64 = rand::rng().random();
+
+    async_run(py, async move {
+        debug!(
+            "upload_bytes (legacy) call {x:x}: (PID = {}) Uploading {} files as bytes.",
+            std::process::id(),
+            file_contents.len(),
+        );
+
+        let out: Vec<PyXetUploadInfo> = data_client::upload_bytes_async(
+            &ctx,
+            file_contents,
+            sha256_policies,
+            endpoint,
+            token_info,
+            refresher.map(|v| v as Arc<_>),
+            updater.map(|v| v as Arc<_>),
+            header_map,
+        )
+        .await
+        .map_err(convert_xet_error)?
+        .into_iter()
+        .map(PyXetUploadInfo::from)
+        .collect();
+
+        debug!("upload_bytes (legacy) call {x:x} finished.");
+        PyResult::Ok(out)
+    })
+}
+
+/// Upload files from disk to Xet storage.
+///
+/// .. deprecated::
+///     Use :class:`XetSession` and :meth:`XetUploadCommit.start_upload_file` instead.
+#[pyfunction]
+#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false),
+       text_signature = "(file_paths, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=False)")]
+#[allow(clippy::too_many_arguments)]
+pub fn upload_files(
+    py: Python,
+    file_paths: Vec<String>,
+    endpoint: Option<String>,
+    token_info: Option<(String, u64)>,
+    token_refresher: Option<Py<PyAny>>,
+    progress_updater: Option<Py<PyAny>>,
+    _repo_type: Option<String>,
+    request_headers: Option<HashMap<String, String>>,
+    sha256s: Option<Vec<String>>,
+    skip_sha256: bool,
+) -> PyResult<Vec<PyXetUploadInfo>> {
+    emit_deprecation(
+        py,
+        "hf_xet.upload_files() is deprecated. Use XetSession().new_upload_commit().start_upload_file() instead.",
+    )?;
+
+    if skip_sha256 && sha256s.is_some() {
+        return Err(pyo3::exceptions::PyValueError::new_err("skip_sha256=True and sha256s are mutually exclusive"));
+    }
+
+    if let Some(ref s) = sha256s
+        && s.len() != file_paths.len()
+    {
+        return Err(pyo3::exceptions::PyValueError::new_err(format!(
+            "sha256s length ({}) must match file_paths length ({})",
+            s.len(),
+            file_paths.len()
+        )));
+    }
+
+    let sha256_policies: Vec<Sha256Policy> = match sha256s {
+        _ if skip_sha256 => vec![Sha256Policy::Skip; file_paths.len()],
+        Some(v) => v.iter().map(|s| Sha256Policy::from_hex(s)).collect(),
+        None => vec![Sha256Policy::Compute; file_paths.len()],
+    };
+
+    let ctx = super::runtime::get_or_init_runtime().map_err(super::runtime::convert_multithreading_error)?;
+    let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
+    let updater = progress_updater
+        .map(|f| WrappedProgressUpdater::new(f, ctx.clone()))
+        .transpose()?
+        .map(Arc::new);
+    let header_map = legacy_headers(request_headers)?;
+    let x: u64 = rand::rng().random();
+
+    async_run(py, async move {
+        debug!(
+            "upload_files (legacy) call {x:x}: (PID = {}) Uploading {} files.",
+            std::process::id(),
+            file_paths.len(),
+        );
+
+        let out: Vec<PyXetUploadInfo> = data_client::upload_async(
+            &ctx,
+            file_paths,
+            sha256_policies,
+            endpoint,
+            token_info,
+            refresher.map(|v| v as Arc<_>),
+            updater.map(|v| v as Arc<_>),
+            header_map,
+        )
+        .await
+        .map_err(convert_xet_error)?
+        .into_iter()
+        .map(PyXetUploadInfo::from)
+        .collect();
+
+        debug!("upload_files (legacy) call {x:x} finished.");
+        PyResult::Ok(out)
+    })
+}
+
+/// Compute Xet hashes for files without uploading.
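+///
+/// A usage sketch:
+///
+/// ```python
+/// import hf_xet
+/// results = hf_xet.hash_files(["/path/to/file1.txt", "/path/to/file2.txt"])
+/// for info in results:
+///     print(f"Hash: {info.hash}, Size: {info.file_size}")
+/// ```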
+#[pyfunction]
+#[pyo3(signature = (file_paths), text_signature = "(file_paths)")]
+pub fn hash_files(py: Python, file_paths: Vec<String>) -> PyResult<Vec<PyXetUploadInfo>> {
+    let ctx = super::runtime::get_or_init_runtime().map_err(super::runtime::convert_multithreading_error)?;
+
+    async_run(py, async move {
+        let out: Vec<PyXetUploadInfo> = data_client::hash_files_async(&ctx, file_paths)
+            .await
+            .map_err(convert_xet_error)?
+            .into_iter()
+            .map(PyXetUploadInfo::from)
+            .collect();
+
+        PyResult::Ok(out)
+    })
+}
+
+/// Download files from Xet storage to local paths.
+///
+/// .. deprecated::
+///     Use :class:`XetSession` and :meth:`XetFileDownloadGroup.start_download_file` instead.
+#[pyfunction]
+#[pyo3(signature = (files, endpoint, token_info, token_refresher, progress_updater, request_headers=None),
+       text_signature = "(files, endpoint, token_info, token_refresher, progress_updater, request_headers=None)")]
+pub fn download_files(
+    py: Python,
+    files: Vec<PyXetDownloadInfo>,
+    endpoint: Option<String>,
+    token_info: Option<(String, u64)>,
+    token_refresher: Option<Py<PyAny>>,
+    progress_updater: Option<Vec<Py<PyAny>>>,
+    request_headers: Option<HashMap<String, String>>,
+) -> PyResult<Vec<String>> {
+    emit_deprecation(
+        py,
+        "hf_xet.download_files() is deprecated. Use XetSession().new_file_download_group().start_download_file() instead.",
+    )?;
+
+    let ctx = super::runtime::get_or_init_runtime().map_err(super::runtime::convert_multithreading_error)?;
+    let file_infos: Vec<_> = files.into_iter().map(<(XetFileInfo, DestinationPath)>::from).collect();
+    let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
+    let updaters = progress_updater
+        .map(|fs| try_parse_progress_updaters(fs, ctx.clone()))
+        .transpose()?;
+    let header_map = legacy_headers(request_headers)?;
+    let x: u64 = rand::rng().random();
+
+    async_run(py, async move {
+        debug!(
+            "download_files (legacy) call {x:x}: (PID = {}) Downloading {} files.",
+            std::process::id(),
+            file_infos.len(),
+        );
+
+        let out = data_client::download_async(
+            &ctx,
+            file_infos,
+            endpoint,
+            token_info,
+            refresher.map(|v| v as Arc<_>),
+            updaters,
+            header_map,
+        )
+        .await
+        .map_err(convert_xet_error)?;
+
+        debug!("download_files (legacy) call {x:x}: Completed.");
+        PyResult::Ok(out)
+    })
+}
+
+/// Force a SIGINT shutdown when it has been intercepted by another process.
+///
+/// .. deprecated::
+///     Use :meth:`XetSession.abort` or :meth:`XetSession.sigint_abort` instead.
+#[pyfunction]
+pub fn force_sigint_shutdown(py: Python) -> PyResult<()> {
+    emit_deprecation(py, "hf_xet.force_sigint_shutdown() is deprecated. Use XetSession.sigint_abort() instead.")?;
+    super::runtime::perform_sigint_shutdown();
+    Err(PyKeyboardInterrupt::new_err(()))
+}
+
+fn try_parse_progress_updaters(
+    funcs: Vec<Py<PyAny>>,
+    ctx: xet_runtime::core::XetContext,
+) -> PyResult<Vec<Arc<dyn TrackingProgressUpdater>>> {
+    let mut updaters = Vec::with_capacity(funcs.len());
+    for func in funcs {
+        updaters.push(Arc::new(WrappedProgressUpdater::new(func, ctx.clone())?) as Arc<dyn TrackingProgressUpdater>);
+    }
+    Ok(updaters)
+}
diff --git a/hf_xet/src/legacy/mod.rs b/hf_xet/src/legacy/mod.rs
new file mode 100644
index 000000000..baa3ae320
--- /dev/null
+++ b/hf_xet/src/legacy/mod.rs
@@ -0,0 +1,14 @@
+//! Legacy types and functions kept for backward compatibility with `huggingface_hub`.
+//!
+//! All items in this module emit a ``DeprecationWarning`` when called from Python.
+//! New code should use the ``XetSession`` object-oriented API.
+
+pub mod functions;
+pub(super) mod progress_update;
+pub(super) mod runtime;
+pub(super) mod token_refresh;
+mod types;
+
+pub use functions::{download_files, force_sigint_shutdown, hash_files, upload_bytes, upload_files};
+pub use progress_update::{PyItemProgressUpdate, PyTotalProgressUpdate};
+pub use types::{PyPointerFile, PyXetDownloadInfo, PyXetUploadInfo};
diff --git a/hf_xet/src/progress_update.rs b/hf_xet/src/legacy/progress_update.rs
similarity index 99%
rename from hf_xet/src/progress_update.rs
rename to hf_xet/src/legacy/progress_update.rs
index bb3f74a64..e52fa687b 100644
--- a/hf_xet/src/progress_update.rs
+++ b/hf_xet/src/legacy/progress_update.rs
@@ -11,7 +11,7 @@ use xet_pkg::legacy::progress_tracking::{ProgressUpdate, TrackingProgressUpdater
 use xet_runtime::core::XetContext;
 use xet_runtime::error_printer::ErrorPrinter;
 
-use crate::runtime::convert_multithreading_error;
+use super::runtime::convert_multithreading_error;
 
 /// Python-exposed versions of the per-item and total progress update classes.
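+/// A progress callback supplied from Python is a plain callable; a sketch of
+/// the shape the legacy API accepts (per the old ``Callable[[int], None]``
+/// type signature; one updater per file for downloads):
+///
+/// ```python
+/// def progress_updater(bytes_completed: int) -> None:
+///     ...  # update a progress bar, counter, etc.
+/// ```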
 ///
@@ -139,7 +139,7 @@ struct WrappedProgressUpdaterImpl {
 
 impl Debug for WrappedProgressUpdaterImpl {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(f, "WrappedTokenRefresher({})", self.name)
+        write!(f, "WrappedProgressUpdater({})", self.name)
     }
 }
diff --git a/hf_xet/src/runtime.rs b/hf_xet/src/legacy/runtime.rs
similarity index 100%
rename from hf_xet/src/runtime.rs
rename to hf_xet/src/legacy/runtime.rs
diff --git a/hf_xet/src/token_refresh.rs b/hf_xet/src/legacy/token_refresh.rs
similarity index 100%
rename from hf_xet/src/token_refresh.rs
rename to hf_xet/src/legacy/token_refresh.rs
diff --git a/hf_xet/src/legacy/types.rs b/hf_xet/src/legacy/types.rs
new file mode 100644
index 000000000..1ec03e990
--- /dev/null
+++ b/hf_xet/src/legacy/types.rs
@@ -0,0 +1,149 @@
+//! Legacy download/upload info types kept for backward compatibility with `huggingface_hub`.
+
+use pyo3::prelude::*;
+use xet_pkg::legacy::XetFileInfo;
+
+// ── PyXetDownloadInfo ─────────────────────────────────────────────────────────
+
+// TODO: we won't need to subclass this in the next major version update.
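+/// Destination path plus Xet hash (and optional size) for one download target.
+///
+/// A minimal construction sketch from Python (values borrowed from the unit
+/// test at the bottom of this file):
+///
+/// ```python
+/// from hf_xet import PyXetDownloadInfo
+/// info = PyXetDownloadInfo("out.bin", "abc123", 1024)  # file_size may be None
+/// ```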
+#[pyclass(subclass)]
+#[derive(Clone, Debug)]
+pub struct PyXetDownloadInfo {
+    #[pyo3(get, set)]
+    pub(crate) destination_path: String,
+    #[pyo3(get)]
+    pub(crate) hash: String,
+    #[pyo3(get)]
+    pub(crate) file_size: Option<u64>,
+}
+
+#[pymethods]
+impl PyXetDownloadInfo {
+    #[new]
+    #[pyo3(signature = (destination_path, hash, file_size=None))]
+    pub fn new(destination_path: String, hash: String, file_size: Option<u64>) -> Self {
+        Self {
+            destination_path,
+            hash,
+            file_size,
+        }
+    }
+
+    fn __str__(&self) -> String {
+        format!("{self:?}")
+    }
+
+    fn __repr__(&self) -> String {
+        let size_str = self.file_size.map_or("None".to_string(), |s| s.to_string());
+        format!("PyXetDownloadInfo({}, {}, {})", self.destination_path, self.hash, size_str)
+    }
+}
+
+// ── PyXetUploadInfo ───────────────────────────────────────────────────────────
+
+/// Result returned by the legacy ``upload_bytes`` / ``upload_files`` / ``hash_files`` functions.
+#[pyclass]
+#[derive(Clone, Debug)]
+pub struct PyXetUploadInfo {
+    #[pyo3(get)]
+    pub hash: String,
+    #[pyo3(get)]
+    pub file_size: u64,
+    #[pyo3(get)]
+    pub sha256: Option<String>,
+}
+
+#[pymethods]
+impl PyXetUploadInfo {
+    #[new]
+    pub fn new(hash: String, file_size: u64) -> Self {
+        Self {
+            hash,
+            file_size,
+            sha256: None,
+        }
+    }
+
+    fn __str__(&self) -> String {
+        format!("{self:?}")
+    }
+
+    fn __repr__(&self) -> String {
+        format!("PyXetUploadInfo({}, {}, {:?})", self.hash, self.file_size, self.sha256)
+    }
+
+    /// Alias kept for backward compatibility.
+    #[getter]
+    fn filesize(self_: PyRef<'_, Self>) -> u64 {
+        self_.file_size
+    }
+}
+
+impl From<XetFileInfo> for PyXetUploadInfo {
+    fn from(xf: XetFileInfo) -> Self {
+        Self {
+            hash: xf.hash().to_owned(),
+            file_size: xf.file_size().expect("upload metadata must always include a known file size"),
+            sha256: xf.sha256().map(str::to_owned),
+        }
+    }
+}
+
+// ── PyPointerFile ─────────────────────────────────────────────────────────────
+
+/// Legacy subclass of :class:`PyXetDownloadInfo`.
+///
+/// Kept for backward compatibility with old versions of ``huggingface_hub``.
+// TODO: remove during the next major version update.
+#[pyclass(extends=PyXetDownloadInfo)]
+#[derive(Clone, Debug)]
+pub struct PyPointerFile {}
+
+#[pymethods]
+impl PyPointerFile {
+    #[new]
+    pub fn new(path: String, hash: String, filesize: u64) -> (Self, PyXetDownloadInfo) {
+        (PyPointerFile {}, PyXetDownloadInfo::new(path, hash, Some(filesize)))
+    }
+
+    fn __str__(&self) -> String {
+        format!("{self:?}")
+    }
+
+    fn __repr__(self_: PyRef<'_, Self>) -> String {
+        let super_ = self_.as_super();
+        let size_str = super_.file_size.map_or("None".to_string(), |s| s.to_string());
+        format!("PyPointerFile({}, {}, {})", super_.destination_path, super_.hash, size_str)
+    }
+
+    #[getter]
+    fn get_path(self_: PyRef<'_, Self>) -> String {
+        self_.as_super().destination_path.clone()
+    }
+
+    #[setter]
+    fn set_path(mut self_: PyRefMut<'_, Self>, path: String) {
+        self_.as_super().destination_path = path;
+    }
+
+    #[getter]
+    fn filesize(self_: PyRef<'_, Self>) -> Option<u64> {
+        self_.as_super().file_size
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pyo3::Python;
+
+    use super::*;
+
+    #[test]
+    fn test_pyxetdownloadinfo_new() {
+        let _ = Python::attach(|_py| {});
+        let info = PyXetDownloadInfo::new("out.bin".into(), "abc123".into(), Some(1024));
+        assert_eq!(info.hash, "abc123");
+        assert_eq!(info.file_size, Some(1024));
+        assert_eq!(info.destination_path, "out.bin");
+    }
+}
diff --git a/hf_xet/src/lib.rs b/hf_xet/src/lib.rs
index 955b5f4d8..337cf47f7 100644
--- a/hf_xet/src/lib.rs
+++ b/hf_xet/src/lib.rs
@@ -1,494 +1,111 @@
+pub mod config;
+mod headers;
+mod legacy;
 mod logging;
-mod progress_update;
-mod runtime;
-mod token_refresh;
+mod py_download_stream_group;
+mod py_download_stream_handle;
+mod py_file_download_group;
+mod py_file_download_handle;
+mod py_file_upload_handle;
+mod py_stream_upload_handle;
+mod py_upload_commit;
+mod py_xet_session;
+pub(crate) mod utils;
 
-use std::collections::HashMap;
-use std::fmt::Debug;
-use std::iter::IntoIterator;
-use std::sync::Arc;
-
-use http::header::{self, HeaderMap, HeaderName, HeaderValue};
-use itertools::Itertools;
-use pyo3::exceptions::{PyKeyboardInterrupt, PyValueError};
 use pyo3::prelude::*;
-use pyo3::pyfunction;
-use rand::RngExt;
-use runtime::{async_run, get_or_init_runtime};
-use token_refresh::WrappedTokenRefresher;
-use tracing::debug;
-use xet_pkg::XetError;
-use xet_pkg::legacy::progress_tracking::TrackingProgressUpdater;
-use xet_pkg::legacy::{Sha256Policy, XetFileInfo, data_client};
-use xet_runtime::core::{XetContext, file_handle_limits};
+pub(crate) use utils::{blocking_call_with_signal_check, convert_xet_error};
+use xet_runtime::core::file_handle_limits;
 
 use crate::logging::init_logging;
-use crate::progress_update::WrappedProgressUpdater;
-
-const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
 
 // For profiling
 #[cfg(feature = "profiling")]
 pub(crate) mod profiling;
 
-/// Converts a HashMap of headers to a HeaderMap and merges in the USER_AGENT.
-///
-/// If the input contains a User-Agent header, the USER_AGENT is appended to it.
-/// Otherwise, USER_AGENT is set as the only User-Agent header.
-fn build_headers_with_user_agent(request_headers: Option>) -> PyResult>> { - let mut map = request_headers - .map(|headers| { - let mut map = HeaderMap::new(); - for (key, value) in headers { - let name = HeaderName::from_bytes(key.as_bytes()) - .map_err(|e| PyValueError::new_err(format!("Invalid header name '{}': {}", key, e)))?; - let value = HeaderValue::from_str(&value) - .map_err(|e| PyValueError::new_err(format!("Invalid header value for '{}': {}", key, e)))?; - map.insert(name, value); - } - Ok::<_, PyErr>(map) - }) - .transpose()? - .unwrap_or_default(); - - // Append our USER_AGENT to any existing User-Agent header, or add it if not present - let combined_user_agent = if let Some(existing_ua) = map.get(header::USER_AGENT) { - // Append our user agent to the existing one - let existing_str = existing_ua.to_str().unwrap_or(""); - format!("{}; {}", existing_str, USER_AGENT) - } else { - // No existing user agent, use ours - USER_AGENT.to_string() - }; - - // Try to create the combined header value, fall back gracefully if invalid - let user_agent_value = HeaderValue::from_str(&combined_user_agent) - .or_else(|_: http::header::InvalidHeaderValue| { - Ok::(HeaderValue::from_static(USER_AGENT)) - }) - .unwrap_or_else(|_: http::header::InvalidHeaderValue| HeaderValue::from_static("unknown")); - map.insert(header::USER_AGENT, user_agent_value); - - Ok(Some(Arc::new(map))) -} - -fn convert_xet_error(e: impl Into) -> PyErr { - PyErr::from(e.into()) -} - -#[pyfunction] -#[pyo3(signature = (file_contents, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_contents: List[bytes], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")] -#[allow(clippy::too_many_arguments)] -pub fn upload_bytes( - py: Python, - file_contents: Vec>, - endpoint: Option, - token_info: Option<(String, u64)>, - token_refresher: Option>, - progress_updater: Option>, - _repo_type: Option, - request_headers: Option>, - sha256s: Option>, - skip_sha256: bool, -) -> PyResult> { - if skip_sha256 && sha256s.is_some() { - return Err(PyValueError::new_err("skip_sha256=True and sha256s are mutually exclusive")); - } - - if let Some(ref s) = sha256s - && s.len() != file_contents.len() - { - return Err(PyValueError::new_err(format!( - "sha256s length ({}) must match file_contents length ({})", - s.len(), - file_contents.len() - ))); - } - - let sha256_policies: Vec = match sha256s { - _ if skip_sha256 => vec![Sha256Policy::Skip; file_contents.len()], - Some(v) => v.iter().map(|s| Sha256Policy::from_hex(s)).collect(), - None => vec![Sha256Policy::Compute; file_contents.len()], - }; - - let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new); - let runtime = get_or_init_runtime().map_err(convert_xet_error)?; - let updater = progress_updater - .map(|p| WrappedProgressUpdater::new(p, runtime.clone())) - .transpose()? 
- .map(Arc::new); - let x: u64 = rand::rng().random(); - - // Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT - let header_map = build_headers_with_user_agent(request_headers)?; +// ── XetTaskState Python enum ────────────────────────────────────────────────── - async_run(py, async move { - debug!( - "Upload bytes call {x:x}: (PID = {}) Uploading {} files as bytes.", - std::process::id(), - file_contents.len(), - ); - let out: Vec = data_client::upload_bytes_async( - &runtime, - file_contents, - sha256_policies, - endpoint, - token_info, - refresher.map(|v| v as Arc<_>), - updater.map(|v| v as Arc<_>), - header_map, - ) - .await - .map_err(convert_xet_error)? - .into_iter() - .map(PyXetUploadInfo::from) - .collect(); - - debug!("Upload bytes call {x:x} finished."); - - PyResult::Ok(out) - }) -} - -#[pyfunction] -#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")] -#[allow(clippy::too_many_arguments)] -pub fn upload_files( - py: Python, - file_paths: Vec, - endpoint: Option, - token_info: Option<(String, u64)>, - token_refresher: Option>, - progress_updater: Option>, - _repo_type: Option, - request_headers: Option>, - sha256s: Option>, - skip_sha256: bool, -) -> PyResult> { - if skip_sha256 && sha256s.is_some() { - return Err(PyValueError::new_err("skip_sha256=True and sha256s are mutually exclusive")); - } - - if let Some(ref s) = sha256s - && s.len() != file_paths.len() - { - return Err(PyValueError::new_err(format!( - "sha256s length ({}) must match file_paths length ({})", - s.len(), - file_paths.len() - ))); - } - - let sha256_policies: Vec = match sha256s { - _ if skip_sha256 => vec![Sha256Policy::Skip; file_paths.len()], - Some(v) => v.iter().map(|s| Sha256Policy::from_hex(s)).collect(), - None => vec![Sha256Policy::Compute; file_paths.len()], - }; - - let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new); - let runtime = get_or_init_runtime().map_err(convert_xet_error)?; - let updater = progress_updater - .map(|p| WrappedProgressUpdater::new(p, runtime.clone())) - .transpose()? - .map(Arc::new); - - let file_names = file_paths.iter().take(3).join(", "); - - let x: u64 = rand::rng().random(); - - // Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT - let header_map = build_headers_with_user_agent(request_headers)?; - - async_run(py, async move { - debug!( - "Upload call {x:x}: (PID = {}) Uploading {} files {file_names}{}", - std::process::id(), - file_paths.len(), - if file_paths.len() > 3 { "..." } else { "." } - ); - let out: Vec = data_client::upload_async( - &runtime, - file_paths, - sha256_policies, - endpoint, - token_info, - refresher.map(|v| v as Arc<_>), - updater.map(|v| v as Arc<_>), - header_map, - ) - .await - .map_err(convert_xet_error)? - .into_iter() - .map(PyXetUploadInfo::from) - .collect(); - debug!("Upload call {x:x} finished."); - PyResult::Ok(out) - }) -} - -/// Compute xet hashes for files without uploading. 
-/// -/// This function computes cryptographic hashes for the specified files using the same -/// chunking and hashing algorithm as upload operations, but without requiring -/// authentication or server connection. The resulting hashes can be used to verify -/// file integrity after downloads or to determine which files need to be uploaded. +/// Task state returned by ``status()`` on sessions, commits, and download groups. /// -/// Args: -/// file_paths: List of file paths to hash. +/// Raises on the ``Error`` variant. Compare with class-level constants: /// -/// Returns: -/// List[PyXetUploadInfo]: List of hash results in the same order as input paths. -/// Each result contains the hash (as hex string) and file size in bytes. +/// ```python +/// from hf_xet import XetTaskState +/// if session.status() == XetTaskState.Running: +/// ... +/// ``` /// -/// Raises: -/// RuntimeError: If any file cannot be read or hashed. +/// # Why not expose `xet_pkg::xet_session::XetTaskState` directly? /// -/// Example: -/// >>> import hf_xet -/// >>> results = hf_xet.hash_files(["/path/to/file1.txt", "/path/to/file2.txt"]) -/// >>> for path, info in zip(file_paths, results): -/// ... print(f"Hash: {info.hash}, Size: {info.file_size}") -/// -/// Note: -/// This function is primarily used for validation and verification of transferred -/// files. Clients can verify that downloaded files are correctly reassembled by -/// comparing the computed hash with the expected hash from the server. -#[pyfunction] -#[pyo3(signature = (file_paths), text_signature = "(file_paths: List[str]) -> List[PyXetUploadInfo]")] -pub fn hash_files(py: Python, file_paths: Vec) -> PyResult> { - async_run(py, async move { - let runtime = get_or_init_runtime().map_err(convert_xet_error)?; - let out: Vec = data_client::hash_files_async(&runtime, file_paths) - .await - .map_err(convert_xet_error)? - .into_iter() - .map(PyXetUploadInfo::from) - .collect(); - - PyResult::Ok(out) - }) -} - -#[pyfunction] -#[pyo3(signature = (files, endpoint, token_info, token_refresher, progress_updater, request_headers=None), text_signature = "(files: List[PyXetDownloadInfo], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[List[Callable[[int], None]]], request_headers: Optional[Dict[str, str]]) -> List[str]")] -pub fn download_files( - py: Python, - files: Vec, - endpoint: Option, - token_info: Option<(String, u64)>, - token_refresher: Option>, - progress_updater: Option>>, - request_headers: Option>, -) -> PyResult> { - let file_infos: Vec<_> = files.into_iter().map(<(XetFileInfo, DestinationPath)>::from).collect(); - let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new); - let runtime = get_or_init_runtime().map_err(convert_xet_error)?; - let updaters = progress_updater.map(|f| try_parse_progress_updaters(f, &runtime)).transpose()?; - - // Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT - let header_map = build_headers_with_user_agent(request_headers)?; - - let x: u64 = rand::rng().random(); - - let file_names = file_infos.iter().take(3).map(|(_, p)| p).join(", "); - - async_run(py, async move { - debug!( - "Download call {x:x}: (PID = {}) Downloading {} files {file_names}{}", - std::process::id(), - file_infos.len(), - if file_infos.len() > 3 { "..." } else { "." 
} - ); - let out: Vec = data_client::download_async( - &runtime, - file_infos, - endpoint, - token_info, - refresher.map(|v| v as Arc<_>), - updaters, - header_map, - ) - .await - .map_err(convert_xet_error)?; - - debug!("Download call {x:x}: Completed."); - - PyResult::Ok(out) - }) -} - -#[pyfunction] -pub fn force_sigint_shutdown() -> PyResult<()> { - // Force a signint shutdown in the case where it gets intercepted by another process. - crate::runtime::perform_sigint_shutdown(); - Err(PyKeyboardInterrupt::new_err(())) -} - -fn try_parse_progress_updaters( - funcs: Vec>, - ctx: &XetContext, -) -> PyResult>> { - let mut updaters = Vec::with_capacity(funcs.len()); - for updater_func in funcs { - let wrapped = Arc::new(WrappedProgressUpdater::new(updater_func, ctx.clone())?); - updaters.push(wrapped as Arc); - } - Ok(updaters) -} - -// TODO: we won't need to subclass this in the next major version update. -#[pyclass(subclass)] -#[derive(Clone, Debug)] -pub struct PyXetDownloadInfo { - #[pyo3(get, set)] - destination_path: String, - #[pyo3(get)] - hash: String, - #[pyo3(get)] - file_size: Option, -} - -#[pymethods] -impl PyXetDownloadInfo { - #[new] - #[pyo3(signature = (destination_path, hash, file_size=None))] - pub fn new(destination_path: String, hash: String, file_size: Option) -> Self { - Self { - destination_path, - hash, - file_size, - } - } - - fn __str__(&self) -> String { - format!("{self:?}") - } - - fn __repr__(&self) -> String { - let size_str = self.file_size.map_or("None".to_string(), |s| s.to_string()); - format!("PyXetDownloadInfo({}, {}, {})", self.destination_path, self.hash, size_str) - } -} - -// TODO: on the next major version update, delete this class and the trait implementation. -// This is used to support backward compatibility for PyPointerFile with old versions of huggingface_hub -#[pyclass(extends=PyXetDownloadInfo)] -#[derive(Clone, Debug)] -pub struct PyPointerFile {} - -#[pymethods] -impl PyPointerFile { - #[new] - pub fn new(path: String, hash: String, filesize: u64) -> (Self, PyXetDownloadInfo) { - (PyPointerFile {}, PyXetDownloadInfo::new(path, hash, Some(filesize))) - } - - fn __str__(&self) -> String { - format!("{self:?}") - } - - fn __repr__(self_: PyRef<'_, Self>) -> String { - let super_ = self_.as_super(); - let size_str = super_.file_size.map_or("None".to_string(), |s| s.to_string()); - format!("PyPointerFile({}, {}, {})", super_.destination_path, super_.hash, size_str) - } - - #[getter] - fn get_path(self_: PyRef<'_, Self>) -> String { - self_.as_super().destination_path.clone() - } - - #[setter] - fn set_path(mut self_: PyRefMut<'_, Self>, path: String) { - self_.as_super().destination_path = path; - } - - #[getter] - fn filesize(self_: PyRef<'_, Self>) -> Option { - self_.as_super().file_size - } -} - -#[pyclass] -#[derive(Clone, Debug)] -pub struct PyXetUploadInfo { - #[pyo3(get)] - pub hash: String, - #[pyo3(get)] - pub file_size: u64, - #[pyo3(get)] - pub sha256: Option, -} - -#[pymethods] -impl PyXetUploadInfo { - #[new] - pub fn new(hash: String, file_size: u64) -> Self { - Self { - hash, - file_size, - sha256: None, - } - } - - fn __str__(&self) -> String { - format!("{self:?}") - } - - fn __repr__(&self) -> String { - format!("PyXetUploadInfo({}, {}, {:?})", self.hash, self.file_size, self.sha256) - } - - /// TODO: Remove these getters in the next major version update. 
-    #[getter]
-    fn filesize(self_: PyRef<'_, Self>) -> u64 {
-        self_.file_size
-    }
-}
-
-type DestinationPath = String;
-
-impl From<XetFileInfo> for PyXetUploadInfo {
-    fn from(xf: XetFileInfo) -> Self {
-        Self {
-            hash: xf.hash().to_owned(),
-            file_size: xf.file_size().expect("upload metadata must always include a known file size"),
-            sha256: xf.sha256().map(str::to_owned),
-        }
-    }
-}
-
-impl From<PyXetDownloadInfo> for (XetFileInfo, DestinationPath) {
-    fn from(pf: PyXetDownloadInfo) -> Self {
-        let file_info = match pf.file_size {
-            Some(size) => XetFileInfo::new(pf.hash, size),
-            None => XetFileInfo::new_hash_only(pf.hash),
-        };
-        (file_info, pf.destination_path)
-    }
-}
+/// PyO3 0.26 requires that every variant in a `#[pyclass]` enum be either all
+/// unit variants or all "complex" (tuple/struct) variants — mixing is not yet
+/// supported. The internal `XetTaskState` has both unit variants (`Running`,
+/// `Completed`, …) and a complex variant (`Error(String)`), so it cannot be
+/// annotated with `#[pyclass]` as-is. Rather than restructuring the internal
+/// enum, we expose this four-variant unit-only wrapper. The `Error` case is
+/// surfaced as a raised Python exception by `task_state_to_pystate` instead.
+#[pyclass(eq, name = "XetTaskState")]
+#[derive(Clone, Debug, PartialEq)]
+pub enum PyXetTaskState {
+    Running,
+    Finalizing,
+    Completed,
+    UserCancelled,
+}
+
+// ── Module registration ───────────────────────────────────────────────────────
 
 #[pymodule(gil_used = false)]
 #[allow(unused_variables)]
 pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_function(wrap_pyfunction!(upload_files, m)?)?;
-    m.add_function(wrap_pyfunction!(upload_bytes, m)?)?;
-    m.add_function(wrap_pyfunction!(hash_files, m)?)?;
-    m.add_function(wrap_pyfunction!(download_files, m)?)?;
-    m.add_function(wrap_pyfunction!(force_sigint_shutdown, m)?)?;
-    m.add_class::()?;
-    m.add_class::()?;
-    m.add_class::()?;
-    m.add_class::()?;
-
-    // TODO: remove this during the next major version update.
-    // This supports backward compatibility for PyPointerFile with old versions
-    // huggingface_hub.
-    m.add_class::()?;
-
+    // ── Configuration ─────────────────────────────────────────────────────────
+    m.add_class::<config::PyXetConfig>()?;
+    m.add_class::<config::PyXetConfigIter>()?;
+
+    // ── New XetSession API ────────────────────────────────────────────────────
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add("COMPUTE_SHA256", py_upload_commit::PyComputeSha256)?;
+    m.add("SKIP_SHA256", py_upload_commit::PySkipSha256)?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+
+    // ── Python-facing task state enum ─────────────────────────────────────────
+    m.add_class::<PyXetTaskState>()?;
+
+    // ── Report types (pyclass-annotated in xet_pkg with "python" feature) ────
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+
+    // ── Legacy types and functions (kept for backward compatibility) ─────────
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
+    m.add_function(wrap_pyfunction!(legacy::upload_bytes, m)?)?;
+    m.add_function(wrap_pyfunction!(legacy::upload_files, m)?)?;
+    m.add_function(wrap_pyfunction!(legacy::hash_files, m)?)?;
+    m.add_function(wrap_pyfunction!(legacy::download_files, m)?)?;
+    m.add_function(wrap_pyfunction!(legacy::force_sigint_shutdown, m)?)?;
+
+    // ── Exceptions ────────────────────────────────────────────────────────────
     xet_pkg::register_exceptions(m)?;
 
-    // Make sure the logger is set up.
+    // ── Logging ───────────────────────────────────────────────────────────────
     init_logging(py);
 
     // Raise the soft file handle limits if possible
@@ -498,7 +115,6 @@ pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
     {
         profiling::start_profiler();
 
-        // Setup to save the results at the end.
         #[pyfunction]
         fn profiler_cleanup() {
             profiling::save_profiler_report();
         }
@@ -512,136 +128,3 @@ pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
 
     Ok(())
 }
-
-#[cfg(test)]
-mod tests {
-    use std::collections::HashMap;
-
-    use super::*;
-
-    // Initialize Python once for all tests
-    fn setup() {
-        // When auto-initialize is enabled, Python will be initialized on first use
-        // This ensures Python is available for the tests
-        let _ = pyo3::Python::attach(|_py| {});
-    }
-
-    #[test]
-    fn test_build_headers_with_none_empty_hashmap() {
-        setup();
-        let empty_map: HashMap<String, String> = HashMap::new();
-        let result = build_headers_with_user_agent(Some(empty_map)).unwrap();
-        let headers = result.unwrap();
-
-        // Should have exactly one header: USER_AGENT
-        assert_eq!(headers.len(), 1);
-        assert!(headers.contains_key(header::USER_AGENT));
-
-        let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
-        assert_eq!(user_agent, USER_AGENT);
-
-        let result = build_headers_with_user_agent(None).unwrap();
-        let headers = result.unwrap();
-
-        // Should have exactly one header: USER_AGENT
-        assert_eq!(headers.len(), 1);
-        assert!(headers.contains_key(header::USER_AGENT));
-
-        let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
-        assert_eq!(user_agent, USER_AGENT);
-    }
-
-    #[test]
-    fn test_build_headers_with_valid_headers() {
-        setup();
-        let mut headers_map = HashMap::new();
-        headers_map.insert("Content-Type".to_string(), "application/json".to_string());
-        headers_map.insert("Authorization".to_string(), "Bearer token123".to_string());
-
-        let result = build_headers_with_user_agent(Some(headers_map)).unwrap();
-        let headers = result.unwrap();
-
-        // Should have 3 headers: Content-Type, Authorization, and USER_AGENT
-        assert_eq!(headers.len(), 3);
-
-        // Verify each header was converted correctly
-        assert_eq!(headers.get(header::CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
-        assert_eq!(headers.get(header::AUTHORIZATION).unwrap().to_str().unwrap(), "Bearer token123");
-
-        // Verify USER_AGENT was added
-        let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
-        assert_eq!(user_agent, USER_AGENT);
-    }
-
-    #[test]
-    fn test_build_headers_appends_to_existing_user_agent() {
-        setup();
-        let mut headers_map = HashMap::new();
-        headers_map.insert("User-Agent".to_string(), "CustomClient/1.0".to_string());
-
-        let result = build_headers_with_user_agent(Some(headers_map)).unwrap();
-        let headers = result.unwrap();
-
-        // Should have exactly one header: USER_AGENT
-        assert_eq!(headers.len(), 1);
-
-        // Verify USER_AGENT was appended to existing one
-        let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
-        assert_eq!(user_agent, format!("CustomClient/1.0; {}", USER_AGENT));
-    }
-
-    #[test]
-    fn test_build_headers_with_invalid_header_name_or_value() {
-        setup();
-        let mut headers_map = HashMap::new();
-        headers_map.insert("Invalid Header!".to_string(), "value".to_string());
-
-        let result = build_headers_with_user_agent(Some(headers_map));
-
-        // Should return an error for invalid header name
-        assert!(result.is_err());
-        let err_msg = result.unwrap_err().to_string();
-        assert!(err_msg.contains("Invalid header name"));
-
-        let mut headers_map = HashMap::new();
-        // Header values cannot contain newlines
-        headers_map.insert("X-Custom".to_string(), "value\nwith\nnewlines".to_string());
-
-        let result = build_headers_with_user_agent(Some(headers_map));
-
-        // Should return an error for invalid header value
-        assert!(result.is_err());
-        let err_msg = result.unwrap_err().to_string();
-        assert!(err_msg.contains("Invalid header value"));
-    }
-
-    #[test]
-    fn test_upload_files_sha256s_length_mismatch() {
-        setup();
-        pyo3::Python::attach(|py| {
-            let file_paths = vec!["a.txt".to_string(), "b.txt".to_string()];
-            let sha256s = Some(vec!["abc123".to_string()]); // 1 hash for 2 files
-
-            let result = upload_files(py, file_paths, None, None, None, None, None, None, sha256s, false);
-
-            assert!(result.is_err());
-            let err_msg = result.unwrap_err().to_string();
-            assert!(err_msg.contains("sha256s length (1) must match file_paths length (2)"), "got: {err_msg}");
-        });
-    }
-
-    #[test]
-    fn test_upload_files_skip_sha256_conflicts_with_sha256s() {
-        setup();
-        pyo3::Python::attach(|py| {
-            let file_paths = vec!["a.txt".to_string()];
-            let sha256s = Some(vec!["abc123".to_string()]);
-
-            let result = upload_files(py, file_paths, None, None, None, None, None, None, sha256s, true);
-
-            assert!(result.is_err());
-            let err_msg = result.unwrap_err().to_string();
-            assert!(err_msg.contains("mutually exclusive"), "got: {err_msg}");
-        });
-    }
-}
diff --git a/hf_xet/src/logging.rs b/hf_xet/src/logging.rs
index c8e47f809..78a4efb91 100644
--- a/hf_xet/src/logging.rs
+++ b/hf_xet/src/logging.rs
@@ -1,10 +1,9 @@
 use pyo3::Python;
 use pyo3::types::PyAnyMethods;
 use tracing::info;
-use xet_runtime::logging::LoggingConfig;
 
 fn get_version_info_string(py: Python<'_>) -> String {
-    // populate remote telemetry calls with versions for python and hf_hub if possible
+    // populate version info for the User-Agent header
    let mut version_info = String::new();
 
     // Get Python version
@@ -16,7 +15,7 @@ fn get_version_info_string(py: Python<'_>) -> String {
     }
 
     // Get huggingface_hub+hf_xet versions
-    let package_names = ["huggingface_hub", "hfxet"];
+    let package_names = ["huggingface_hub", "hf_xet"];
     if let Ok(importlib_metadata) = py.import("importlib.metadata") {
         for package_name in package_names.iter() {
             if let Ok(version) = importlib_metadata
@@ -30,17 +29,9 @@ fn get_version_info_string(py: Python<'_>) -> String {
     version_info
 }
 
-/// Wrap the core runtime logging functions.
-pub fn init_logging(py: Python) {
+/// Initialize the global tracing subscriber.
+pub fn init_logging(py: Python<'_>) {
     let version_info = get_version_info_string(py);
-    let xet_cache_directory = xet_runtime::core::xet_cache_root();
-    let log_dir = xet_cache_directory.join("logs");
-
-    // Called before any XetContext is created, so we use a standalone default config for
-    // early-init logging setup.
-    let cfg = LoggingConfig::from_directory(&xet_runtime::config::XetConfig::new(), version_info, log_dir);
-
-    xet_runtime::logging::init(cfg);
-
+    xet_pkg::init_logging(version_info);
     info!("hf_xet logging configured.");
 }
diff --git a/hf_xet/src/py_download_stream_group.rs b/hf_xet/src/py_download_stream_group.rs
new file mode 100644
index 000000000..f4d0486fd
--- /dev/null
+++ b/hf_xet/src/py_download_stream_group.rs
@@ -0,0 +1,138 @@
+use std::collections::HashMap;
+use std::ops::Range;
+
+use pyo3::prelude::*;
+use xet_pkg::xet_session::{XetDownloadStreamGroup, XetFileInfo, XetSession};
+
+use crate::convert_xet_error;
+use crate::headers::{build_header_map, build_headers_with_user_agent};
+use crate::py_download_stream_handle::{PyXetDownloadStream, PyXetUnorderedDownloadStream};
+
+// ── build_download_stream_group ───────────────────────────────────────────────
+
+/// Create an :class:`XetDownloadStreamGroup` from a session and optional configuration.
+///
+/// Called by :meth:`XetSession.new_download_stream_group`. The Rust builder type is
+/// created and consumed entirely here — it never surfaces in any public API.
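+///
+/// From Python, a sketch of the call that reaches this function (keyword
+/// names mirror the parameters below and are assumptions about the
+/// Python-side signature; the URL is a placeholder):
+///
+/// ```python
+/// group = session.new_download_stream_group(
+///     endpoint="https://cas.example",
+///     token="…", token_expiry_unix_secs=1234567890,
+///     custom_headers={"X-Trace": "abc"},
+/// )
+/// ```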
+///
+/// Called by :meth:`XetSession.new_download_stream_group`. The Rust builder type is
+/// created and consumed entirely here — it never surfaces in any public API.
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn build_download_stream_group(
+    py: Python<'_>,
+    session: &XetSession,
+    endpoint: Option<String>,
+    token: Option<String>,
+    token_expiry_unix_secs: Option<u64>,
+    token_refresh_url: Option<String>,
+    token_refresh_headers: Option<HashMap<String, String>>,
+    custom_headers: Option<HashMap<String, String>>,
+) -> PyResult<PyXetDownloadStreamGroup> {
+    let mut builder = session.new_download_stream_group().map_err(convert_xet_error)?;
+    if let Some(ep) = endpoint {
+        builder = builder.with_endpoint(ep);
+    }
+    if let (Some(tok), Some(exp)) = (token, token_expiry_unix_secs) {
+        builder = builder.with_token_info(tok, exp);
+    }
+    if let Some(url) = token_refresh_url {
+        let headers = build_header_map(token_refresh_headers.unwrap_or_default())?;
+        builder = builder.with_token_refresh_url(url, headers);
+    }
+    let merged_headers = build_headers_with_user_agent(custom_headers)?;
+    let group = py.detach(move || {
+        builder
+            .with_custom_headers(merged_headers)
+            .build_blocking()
+            .map_err(convert_xet_error)
+    })?;
+    Ok(PyXetDownloadStreamGroup { inner: group })
+}
+
+// ── PyXetDownloadStreamGroup ──────────────────────────────────────────────────
+
+/// A group of streaming file downloads sharing a single CAS connection pool.
+///
+/// Each call to :meth:`download_stream` or :meth:`download_unordered_stream`
+/// returns an independent Python iterator. Multiple streams can be active
+/// concurrently from the same group.
+///
+/// Cloning is cheap — all clones share the same underlying state.
+#[pyclass(name = "XetDownloadStreamGroup")]
+#[derive(Clone)]
+pub struct PyXetDownloadStreamGroup {
+    pub(crate) inner: XetDownloadStreamGroup,
+}
+
+#[pymethods]
+impl PyXetDownloadStreamGroup {
+    fn __repr__(&self) -> &'static str {
+        "XetDownloadStreamGroup()"
+    }
+
+    // ── Stream constructors ──────────────────────────────────────────────────
+
+    /// Open an ordered byte stream for a file.
+    ///
+    /// ``file_info`` — a :class:`XetFileInfo` identifying the file.
+    ///
+    /// ``start`` / ``end`` — optional byte offsets (exclusive end). Both
+    /// default to ``None``, meaning the full file. Either may be omitted
+    /// independently:
+    ///
+    /// ```python
+    /// group.download_stream(info)                    # whole file
+    /// group.download_stream(info, start=3)           # 3 .. EOF
+    /// group.download_stream(info, end=100)           # 0 .. 100
+    /// group.download_stream(info, start=3, end=100)  # 3 .. 100
+    /// ```
+    ///
+    /// Returns a :class:`XetDownloadStream` iterator that yields ``bytes``
+    /// chunks in order. Iterate it directly or call :meth:`cancel`.
+    ///
+    /// Releases the GIL during setup.
+    #[pyo3(signature = (file_info, start=None, end=None))]
+    pub fn download_stream(
+        &self,
+        py: Python<'_>,
+        file_info: XetFileInfo,
+        start: Option<u64>,
+        end: Option<u64>,
+    ) -> PyResult<PyXetDownloadStream> {
+        let byte_range: Option<Range<u64>> = match (start, end) {
+            (None, None) => None,
+            (s, e) => Some(s.unwrap_or(0)..e.unwrap_or(u64::MAX)),
+        };
+        let inner = self.inner.clone();
+        let stream = py.detach(|| inner.download_stream_blocking(file_info, byte_range).map_err(convert_xet_error))?;
+        Ok(PyXetDownloadStream { inner: stream })
+    }
+
+    /// Open an unordered byte stream for a file.
+    ///
+    /// Yields ``(offset, bytes)`` tuples in completion order — chunks may
+    /// arrive out of order relative to their position in the file. Use this
+    /// when you want to process or write chunks as they arrive without waiting
+    /// for prior chunks.
+    ///
+    /// ``start`` / ``end`` behave the same as in :meth:`download_stream`.
+    ///
+    /// Releases the GIL during setup.
+    #[pyo3(signature = (file_info, start=None, end=None))]
+    pub fn download_unordered_stream(
+        &self,
+        py: Python<'_>,
+        file_info: XetFileInfo,
+        start: Option<u64>,
+        end: Option<u64>,
+    ) -> PyResult<PyXetUnorderedDownloadStream> {
+        let byte_range: Option<Range<u64>> = match (start, end) {
+            (None, None) => None,
+            (s, e) => Some(s.unwrap_or(0)..e.unwrap_or(u64::MAX)),
+        };
+        let inner = self.inner.clone();
+        let stream = py.detach(|| {
+            inner
+                .download_unordered_stream_blocking(file_info, byte_range)
+                .map_err(convert_xet_error)
+        })?;
+        Ok(PyXetUnorderedDownloadStream { inner: stream })
+    }
+}
diff --git a/hf_xet/src/py_download_stream_handle.rs b/hf_xet/src/py_download_stream_handle.rs
new file mode 100644
index 000000000..c4d2f44dd
--- /dev/null
+++ b/hf_xet/src/py_download_stream_handle.rs
@@ -0,0 +1,130 @@
+use pyo3::prelude::*;
+use pyo3::types::PyBytes;
+use xet_pkg::xet_session::{ItemProgressReport, XetDownloadStream, XetUnorderedDownloadStream};
+
+use crate::convert_xet_error;
+use crate::utils::progress_display;
+
+// ── PyXetDownloadStream ───────────────────────────────────────────────────────
+
+/// Ordered byte stream iterator for a single file.
+///
+/// Returned by :meth:`XetDownloadStreamGroup.download_stream`.
+///
+/// Usage:
+///
+/// ```text
+/// for chunk in group.download_stream(file_info):
+///     f.write(chunk)  # chunk is bytes, in file order
+/// ```
+///
+/// Or with a byte range:
+///
+/// ```text
+/// for chunk in group.download_stream(file_info, start=0, end=1024):
+///     process(chunk)
+/// ```
+#[pyclass(name = "XetDownloadStream")]
+pub struct PyXetDownloadStream {
+    pub(crate) inner: XetDownloadStream,
+}
+
+#[pymethods]
+impl PyXetDownloadStream {
+    // Example output:
+    //   XetDownloadStream(task_id=4, bytes_completed=512/2048)
+    //   XetDownloadStream(task_id=5, bytes_completed=?/?)   ← before first progress report
+    fn __repr__(&self) -> String {
+        let prog = progress_display(self.inner.progress());
+        format!("XetDownloadStream(task_id={}, bytes_completed={})", self.inner.task_id(), prog)
+    }
+
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    /// Return the next ``bytes`` chunk, or raise ``StopIteration`` when done.
+    ///
+    /// Note: the GIL is held while waiting for the next chunk.
+    /// ``XetDownloadStream`` is not ``Clone``, so ``py.detach()`` cannot be
+    /// used here. In practice chunks arrive quickly from the background task,
+    /// so this is not expected to cause significant contention.
+    fn __next__(&mut self, py: Python<'_>) -> PyResult<Option<Py<PyBytes>>> {
+        match self.inner.blocking_next().map_err(convert_xet_error)? {
+            Some(bytes) => Ok(Some(PyBytes::new(py, &bytes).unbind())),
+            None => Ok(None),
+        }
+    }
+
+    /// Cancel this stream. Subsequent iteration will stop immediately.
+    pub fn cancel(&mut self) {
+        self.inner.cancel();
+    }
+
+    /// Current download progress for this stream, or ``None`` if not yet available.
+    pub fn progress(&self) -> Option<ItemProgressReport> {
+        self.inner.progress()
+    }
+}
+
+// ── PyXetUnorderedDownloadStream ──────────────────────────────────────────────
+
+/// Unordered byte stream iterator for a single file.
+///
+/// Returned by :meth:`XetDownloadStreamGroup.download_unordered_stream`.
+///
+/// Each iteration yields a ``(offset: int, data: bytes)`` tuple where
+/// ``offset`` is the byte position of ``data`` within the file (or range).
+/// Chunks may arrive in any order.
+///
+/// Usage:
+///
+/// ```text
+/// buf = bytearray(file_size)
+/// for offset, chunk in group.download_unordered_stream(file_info):
+///     buf[offset:offset + len(chunk)] = chunk
+/// ```
+#[pyclass(name = "XetUnorderedDownloadStream")]
+pub struct PyXetUnorderedDownloadStream {
+    pub(crate) inner: XetUnorderedDownloadStream,
+}
+
+#[pymethods]
+impl PyXetUnorderedDownloadStream {
+    // Example output:
+    //   XetUnorderedDownloadStream(task_id=6, bytes_completed=4096/16384)
+    //   XetUnorderedDownloadStream(task_id=7, bytes_completed=?/?)   ← before first progress report
+    fn __repr__(&self) -> String {
+        let prog = progress_display(self.inner.progress());
+        format!("XetUnorderedDownloadStream(task_id={}, bytes_completed={})", self.inner.task_id(), prog)
+    }
+
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    /// Return the next ``(offset, bytes)`` chunk, or raise ``StopIteration``
+    /// when done.
+    ///
+    /// Note: the GIL is held while waiting for the next chunk.
+    /// ``XetUnorderedDownloadStream`` is not ``Clone``, so ``py.detach()``
+    /// cannot be used here. In practice chunks arrive quickly from the
+    /// background task, so this is not expected to cause significant
+    /// contention.
+    fn __next__<'py>(&mut self, py: Python<'py>) -> PyResult<Option<(u64, Bound<'py, PyBytes>)>> {
+        match self.inner.blocking_next().map_err(convert_xet_error)? {
+            Some((offset, bytes)) => Ok(Some((offset, PyBytes::new(py, &bytes)))),
+            None => Ok(None),
+        }
+    }
+
+    /// Cancel this stream. Subsequent iteration will stop immediately.
+    pub fn cancel(&mut self) {
+        self.inner.cancel();
+    }
+
+    /// Current download progress for this stream, or ``None`` if not yet available.
+    pub fn progress(&self) -> Option<ItemProgressReport> {
+        self.inner.progress()
+    }
+}
diff --git a/hf_xet/src/py_file_download_group.rs b/hf_xet/src/py_file_download_group.rs
new file mode 100644
index 000000000..e7f997c1a
--- /dev/null
+++ b/hf_xet/src/py_file_download_group.rs
@@ -0,0 +1,267 @@
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
+use std::time::Duration;
+
+use pyo3::prelude::*;
+use xet_pkg::xet_session::{
+    GroupProgressReport, ItemProgressReport, UniqueID, XetDownloadGroupReport, XetFileDownload, XetFileDownloadGroup,
+    XetFileInfo, XetSession, XetTaskState,
+};
+
+use crate::headers::{build_header_map, build_headers_with_user_agent};
+use crate::py_file_download_handle::PyXetFileDownload;
+use crate::utils::{progress_display, task_state_display, task_state_to_pystate};
+use crate::{PyXetTaskState, blocking_call_with_signal_check, convert_xet_error};
+
+// ── build_file_download_group ─────────────────────────────────────────────────
+
+/// Create an :class:`XetFileDownloadGroup` from a session and optional configuration.
+///
+/// Called by :meth:`XetSession.new_file_download_group`. The Rust builder type is
+/// created and consumed entirely here — it never surfaces in any public API.
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn build_file_download_group(
+    py: Python<'_>,
+    session: &XetSession,
+    endpoint: Option<String>,
+    token: Option<String>,
+    token_expiry_unix_secs: Option<u64>,
+    token_refresh_url: Option<String>,
+    token_refresh_headers: Option<HashMap<String, String>>,
+    custom_headers: Option<HashMap<String, String>>,
+    progress_callback: Option<Py<PyAny>>,
+    progress_interval_ms: u64,
+) -> PyResult<PyXetFileDownloadGroup> {
+    let mut builder = session.new_file_download_group().map_err(convert_xet_error)?;
+    if let Some(ep) = endpoint {
+        builder = builder.with_endpoint(ep);
+    }
+    if let (Some(tok), Some(exp)) = (token, token_expiry_unix_secs) {
+        builder = builder.with_token_info(tok, exp);
+    }
+    if let Some(url) = token_refresh_url {
+        let headers = build_header_map(token_refresh_headers.unwrap_or_default())?;
+        builder = builder.with_token_refresh_url(url, headers);
+    }
+    let merged_headers = build_headers_with_user_agent(custom_headers)?;
+    let group = py.detach(move || {
+        builder
+            .with_custom_headers(merged_headers)
+            .build_blocking()
+            .map_err(convert_xet_error)
+    })?;
+
+    let download_handles = if let Some(callback) = progress_callback {
+        let handles: Arc<RwLock<Vec<XetFileDownload>>> = Arc::new(RwLock::new(Vec::new()));
+        let inner = group.clone();
+        let handles_for_thread = Arc::clone(&handles);
+        let interval = Duration::from_millis(progress_interval_ms);
+        std::thread::spawn(move || {
+            loop {
+                std::thread::sleep(interval);
+                let is_terminal = !matches!(inner.status(), Ok(XetTaskState::Running) | Ok(XetTaskState::Finalizing));
+                let group_report = inner.progress();
+                let item_reports: HashMap<UniqueID, ItemProgressReport> = handles_for_thread
+                    .read()
+                    .map(|g| g.iter().filter_map(|h| h.progress().map(|p| (h.task_id(), p))).collect())
+                    .unwrap_or_default();
+                let result = Python::attach(|py| callback.call1(py, (group_report, item_reports)));
+                if let Err(e) = result {
+                    Python::attach(|py| e.print(py));
+                    break;
+                }
+                if is_terminal {
+                    break;
+                }
+            }
+        });
+        Some(handles)
+    } else {
+        None
+    };
+
+    Ok(PyXetFileDownloadGroup {
+        inner: group,
+        download_handles,
+    })
+}
+
+// ── PyXetFileDownloadGroup ────────────────────────────────────────────────────
+
+/// A group of related file downloads.
+///
+/// Implements the context-manager protocol.
+///
+/// ```text
+/// with session.new_file_download_group(endpoint="...") as group:
+///     h = group.start_download_file(info, "/tmp/out.bin")
+///     # on normal exit: wait_to_finish() is called automatically
+///     # on exception: abort() is called automatically
+/// ```
+#[pyclass(name = "XetFileDownloadGroup")]
+pub struct PyXetFileDownloadGroup {
+    pub(crate) inner: XetFileDownloadGroup,
+    /// Per-file handles shared with the progress thread; None when no callback was registered.
+    download_handles: Option<Arc<RwLock<Vec<XetFileDownload>>>>,
+}
+
+#[pymethods]
+impl PyXetFileDownloadGroup {
+    // Example output:
+    //   XetFileDownloadGroup(status="Running", downloads=[(3, "/tmp/model.bin", bytes_completed=1024/4096), (4,
+    //   "/tmp/data.bin", bytes_completed=?/?)])
+    //
+    // Each download entry is (task_id, dest_path, bytes_completed/total_bytes).
+    // Progress shows "?/?" before the first report arrives.
+    fn __repr__(&self) -> String {
+        let status = task_state_display(self.inner.status());
+        let downloads: Vec<String> = self
+            .inner
+            .active_download_info()
+            .into_iter()
+            .map(|(id, path, progress)| {
+                let prog = progress_display(progress);
+                format!("({id}, \"{}\", bytes_completed={prog})", path.display())
+            })
+            .collect();
+        format!("XetFileDownloadGroup(status=\"{}\", downloads=[{}])", status, downloads.join(", "))
+    }
+
+    // ── Context manager ──────────────────────────────────────────────────────
+
+    fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    fn __exit__(
+        &self,
+        py: Python<'_>,
+        exc_type: Bound<'_, pyo3::PyAny>,
+        _exc_val: Bound<'_, pyo3::PyAny>,
+        _exc_tb: Bound<'_, pyo3::PyAny>,
+    ) -> PyResult<bool> {
+        if exc_type.is_none() {
+            // Normal exit: wait for all downloads (signal-interruptible).
+            self.wait_to_finish(py)?;
+        } else {
+            if let Err(e) = self.inner.abort() {
+                tracing::warn!("abort() failed during __exit__ exception path: {e}");
+            }
+        }
+        Ok(false)
+    }
+
+    // ── Download methods ─────────────────────────────────────────────────────
+
+    /// Queue a file for download.
+    ///
+    /// ``file_info`` — a :class:`XetFileInfo` identifying the file (hash and size).
+    ///
+    /// ``dest_path`` — local filesystem path to write the file to.
+    ///
+    /// Returns immediately with a :class:`XetFileDownload` handle. Call
+    /// :meth:`wait_to_finish` (or exit the ``with`` block) to wait for completion.
+    pub fn start_download_file(
+        &self,
+        py: Python<'_>,
+        file_info: XetFileInfo,
+        dest_path: String,
+    ) -> PyResult<PyXetFileDownload> {
+        let path: std::path::PathBuf = dest_path.into();
+        let inner = self.inner.clone();
+        let handle = py.detach(|| inner.download_file_to_path_blocking(file_info, path).map_err(convert_xet_error))?;
+        if let Some(ref handles) = self.download_handles {
+            handles
+                .write()
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?
+                .push(handle.clone());
+        }
+        Ok(PyXetFileDownload { inner: handle })
+    }
+
+    // ── Finish / abort ───────────────────────────────────────────────────────
+
+    /// Wait for all downloads to complete and return a summary report.
+    ///
+    /// Returns a :class:`XetDownloadGroupReport`. Also called automatically
+    /// when exiting a ``with`` block without an exception.
+    ///
+    /// Releases the GIL while waiting, polling for ``KeyboardInterrupt`` every
+    /// 100 ms so that Ctrl-C is delivered promptly.
+    pub fn wait_to_finish(&self, py: Python<'_>) -> PyResult<XetDownloadGroupReport> {
+        let group = self.inner.clone();
+        blocking_call_with_signal_check(py, move || group.finish_blocking())
+    }
+
+    /// Cancel all active downloads in this group.
+    pub fn abort(&self) -> PyResult<()> {
+        self.inner.abort().map_err(convert_xet_error)
+    }
+
+    // ── Progress / status ────────────────────────────────────────────────────
+
+    /// Aggregate progress for all downloads in this group.
+    ///
+    /// Returns a :class:`GroupProgressReport`. Lock-free.
+    pub fn progress(&self) -> GroupProgressReport {
+        self.inner.progress()
+    }
+
+    /// Current task state as a :class:`XetTaskState` enum value. Raises on error.
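+    ///
+    /// Illustrative polling sketch (assumes the enum is exported as
+    /// ``hf_xet.XetTaskState``; the sleep interval is arbitrary):
+    ///
+    /// ```text
+    /// while group.status() in (hf_xet.XetTaskState.Running, hf_xet.XetTaskState.Finalizing):
+    ///     time.sleep(0.1)
+    /// ```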
+    pub fn status(&self) -> PyResult<PyXetTaskState> {
+        task_state_to_pystate(self.inner.status())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pyo3::Python;
+    use tempfile::tempdir;
+    use xet_pkg::xet_session::XetSessionBuilder;
+
+    use super::*;
+
+    // ── PyXetFileDownloadGroup ────────────────────────────────────────────────
+
+    #[test]
+    fn test_finish_empty_group() {
+        let temp = tempdir().unwrap();
+        let endpoint = format!("local://{}", temp.path().join("cas").display());
+        let session = XetSessionBuilder::new().build().unwrap();
+        let group = PyXetFileDownloadGroup {
+            inner: session
+                .new_file_download_group()
+                .unwrap()
+                .with_endpoint(&endpoint)
+                .build_blocking()
+                .unwrap(),
+            download_handles: None,
+        };
+
+        Python::attach(|py| {
+            let report = group.wait_to_finish(py).unwrap();
+            assert!(report.downloads.is_empty());
+        });
+    }
+
+    #[test]
+    fn test_abort_makes_finish_fail() {
+        let temp = tempdir().unwrap();
+        let endpoint = format!("local://{}", temp.path().join("cas").display());
+        let session = XetSessionBuilder::new().build().unwrap();
+        let group = PyXetFileDownloadGroup {
+            inner: session
+                .new_file_download_group()
+                .unwrap()
+                .with_endpoint(&endpoint)
+                .build_blocking()
+                .unwrap(),
+            download_handles: None,
+        };
+
+        Python::attach(|py| {
+            group.abort().unwrap();
+            assert!(group.wait_to_finish(py).is_err());
+        });
+    }
+}
diff --git a/hf_xet/src/py_file_download_handle.rs b/hf_xet/src/py_file_download_handle.rs
new file mode 100644
index 000000000..0e41ea2ee
--- /dev/null
+++ b/hf_xet/src/py_file_download_handle.rs
@@ -0,0 +1,140 @@
+use pyo3::prelude::*;
+use xet_pkg::xet_session::{ItemProgressReport, UniqueID, XetDownloadReport, XetFileDownload};
+
+use crate::utils::{progress_display, task_state_display, task_state_to_pystate};
+use crate::{PyXetTaskState, convert_xet_error};
+
+// ── PyXetFileDownload ─────────────────────────────────────────────────────────
+
+/// Handle for a background file-download task.
+///
+/// Returned by :meth:`XetFileDownloadGroup.start_download_file`.
+#[pyclass(name = "XetFileDownload")]
+pub struct PyXetFileDownload {
+    pub(crate) inner: XetFileDownload,
+}
+
+#[pymethods]
+impl PyXetFileDownload {
+    // Example output:
+    //   XetFileDownload(task_id=3, status="Running", bytes_completed=1024/4096)
+    //   XetFileDownload(task_id=4, status="Completed", bytes_completed=2048/2048)
+    //   XetFileDownload(task_id=5, status="Running", bytes_completed=?/?)   ← before first progress report
+    fn __repr__(&self) -> String {
+        let status = task_state_display(self.inner.status());
+        let prog = progress_display(self.inner.progress());
+        format!("XetFileDownload(task_id={}, status=\"{}\", bytes_completed={})", self.inner.task_id(), status, prog)
+    }
+
+    /// Per-file progress, or ``None`` if not yet available.
+    pub fn progress(&self) -> Option<ItemProgressReport> {
+        self.inner.progress()
+    }
+
+    /// Current task state as a :class:`XetTaskState` enum value. Raises on error.
+    pub fn status(&self) -> PyResult<PyXetTaskState> {
+        task_state_to_pystate(self.inner.status())
+    }
+
+    /// Wait for this download to complete and return its report.
+    ///
+    /// Releases the GIL.
+    pub fn result(&self, py: Python<'_>) -> PyResult<XetDownloadReport> {
+        let inner = self.inner.clone();
+        py.detach(|| inner.finish_blocking().map_err(convert_xet_error))
+    }
+
+    /// Return the download report without blocking.
+    ///
+    /// Returns ``None`` if the download has not yet completed.
+    /// Raises if the download completed with an error.
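+    ///
+    /// A possible polling pattern (illustrative sketch; ``do_other_work`` is a
+    /// placeholder, not part of the API):
+    ///
+    /// ```text
+    /// while (report := handle.try_result()) is None:
+    ///     do_other_work()
+    /// ```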
+    pub fn try_result(&self) -> PyResult<Option<XetDownloadReport>> {
+        match self.inner.result() {
+            Some(Ok(r)) => Ok(Some(r)),
+            Some(Err(e)) => Err(convert_xet_error(e)),
+            None => Ok(None),
+        }
+    }
+
+    /// The unique task ID for this download.
+    ///
+    /// Matches the keys in :attr:`XetDownloadGroupReport.downloads`.
+    pub fn task_id(&self) -> UniqueID {
+        self.inner.task_id()
+    }
+
+    /// Cancel this individual download.
+    pub fn cancel(&self) {
+        self.inner.cancel();
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pyo3::Python;
+    use tempfile::tempdir;
+    use xet_pkg::xet_session::{Sha256Policy, XetFileInfo, XetSessionBuilder};
+
+    use super::*;
+
+    fn upload_bytes_and_get_info(
+        data: &[u8],
+        endpoint: &str,
+        session: &xet_pkg::xet_session::XetSession,
+    ) -> XetFileInfo {
+        let commit = session
+            .new_upload_commit()
+            .unwrap()
+            .with_endpoint(endpoint)
+            .build_blocking()
+            .unwrap();
+        let handle = commit
+            .upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, None)
+            .unwrap();
+        commit.commit_blocking().unwrap();
+        let meta = handle.try_finish().unwrap();
+        XetFileInfo::new(meta.xet_info.hash().to_owned(), meta.xet_info.file_size().unwrap_or(0))
+    }
+
+    #[test]
+    fn test_file_download_handle_task_id_and_result() {
+        let temp = tempdir().unwrap();
+        let endpoint = format!("local://{}", temp.path().join("cas").display());
+        let session = XetSessionBuilder::new().build().unwrap();
+        let file_info = upload_bytes_and_get_info(b"hello world", &endpoint, &session);
+        let group = session
+            .new_file_download_group()
+            .unwrap()
+            .with_endpoint(&endpoint)
+            .build_blocking()
+            .unwrap();
+        let dest = temp.path().join("out.bin");
+        let handle = group.download_file_to_path_blocking(file_info, dest.clone()).unwrap();
+        let py_handle = PyXetFileDownload { inner: handle };
+        assert!(py_handle.task_id().0 > 0);
+        Python::attach(|py| {
+            let result = py_handle.result(py).unwrap();
+            assert_eq!(result.file_info.file_size, Some(11));
+            assert!(dest.exists());
+        });
+    }
+
+    #[test]
+    fn test_file_download_handle_cancel() {
+        let temp = tempdir().unwrap();
+        let endpoint = format!("local://{}", temp.path().join("cas").display());
+        let session = XetSessionBuilder::new().build().unwrap();
+        let file_info = upload_bytes_and_get_info(b"cancel me", &endpoint, &session);
+        let group = session
+            .new_file_download_group()
+            .unwrap()
+            .with_endpoint(&endpoint)
+            .build_blocking()
+            .unwrap();
+        let dest = temp.path().join("cancelled.bin");
+        let handle = group.download_file_to_path_blocking(file_info, dest).unwrap();
+        let py_handle = PyXetFileDownload { inner: handle };
+        // cancel should not panic
+        py_handle.cancel();
+    }
+}
diff --git a/hf_xet/src/py_file_upload_handle.rs b/hf_xet/src/py_file_upload_handle.rs
new file mode 100644
index 000000000..661a7b2ef
--- /dev/null
+++ b/hf_xet/src/py_file_upload_handle.rs
@@ -0,0 +1,113 @@
+use pyo3::prelude::*;
+use xet_pkg::xet_session::{ItemProgressReport, UniqueID, XetFileMetadata, XetFileUpload};
+
+use crate::utils::{progress_display, task_state_display, task_state_to_pystate};
+use crate::{PyXetTaskState, convert_xet_error};
+
+// ── PyXetFileUpload ───────────────────────────────────────────────────────────
+
+/// Handle for a background file-upload task.
+///
+/// Returned by :meth:`XetUploadCommit.start_upload_file` and
+/// :meth:`XetUploadCommit.start_upload_bytes`.
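+///
+/// Typical flow (illustrative sketch):
+///
+/// ```text
+/// h = commit.start_upload_file("/path/to/model.bin")
+/// commit.wait_to_finish()
+/// meta = h.result()  # XetFileMetadata
+/// ```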
+#[pyclass(name = "XetFileUpload")]
+pub struct PyXetFileUpload {
+    pub(crate) inner: XetFileUpload,
+}
+
+#[pymethods]
+impl PyXetFileUpload {
+    // Example output:
+    //   XetFileUpload(task_id=1, status="Running", bytes_completed=1024/4096)
+    //   XetFileUpload(task_id=2, status="Completed", bytes_completed=4096/4096)
+    //   XetFileUpload(task_id=3, status="Running", bytes_completed=?/?)   ← before first progress report
+    fn __repr__(&self) -> String {
+        let status = task_state_display(self.inner.status());
+        let prog = progress_display(self.inner.progress());
+        format!("XetFileUpload(task_id={}, status=\"{}\", bytes_completed={})", self.inner.task_id(), status, prog)
+    }
+
+    /// Per-file progress, or ``None`` if not yet available.
+    pub fn progress(&self) -> Option<ItemProgressReport> {
+        self.inner.progress()
+    }
+
+    /// Current task state as a :class:`XetTaskState` enum value. Raises on error.
+    pub fn status(&self) -> PyResult<PyXetTaskState> {
+        task_state_to_pystate(self.inner.status())
+    }
+
+    /// Wait for ingestion to complete and return upload metadata.
+    ///
+    /// Releases the GIL. Call after :meth:`XetUploadCommit.wait_to_finish` to get
+    /// the final :class:`XetFileMetadata`.
+    pub fn result(&self, py: Python<'_>) -> PyResult<XetFileMetadata> {
+        let inner = self.inner.clone();
+        py.detach(|| inner.finalize_ingestion_blocking().map_err(convert_xet_error))
+    }
+
+    /// Return upload metadata without blocking, or ``None`` if not yet done.
+    pub fn try_result(&self) -> Option<XetFileMetadata> {
+        self.inner.try_finish()
+    }
+
+    /// The unique task ID for this upload.
+    ///
+    /// Matches the keys in :attr:`XetCommitReport.uploads`.
+    pub fn task_id(&self) -> UniqueID {
+        self.inner.task_id()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pyo3::Python;
+    use tempfile::tempdir;
+    use xet_pkg::xet_session::{Sha256Policy, XetSessionBuilder};
+
+    use super::*;
+
+    #[test]
+    fn test_file_upload_handle_task_id_and_result() {
+        let temp = tempdir().unwrap();
+        let endpoint = format!("local://{}", temp.path().join("cas").display());
+        let session = XetSessionBuilder::new().build().unwrap();
+        let commit = session
+            .new_upload_commit()
+            .unwrap()
+            .with_endpoint(&endpoint)
+            .build_blocking()
+            .unwrap();
+        let handle = commit
+            .upload_bytes_blocking(b"hello world".to_vec(), Sha256Policy::Compute, Some("test.bin".into()))
+            .unwrap();
+        let py_handle = PyXetFileUpload { inner: handle };
+        assert!(py_handle.task_id().0 > 0);
+        commit.commit_blocking().unwrap();
+        Python::attach(|py| {
+            let result = py_handle.result(py).unwrap();
+            assert_eq!(result.xet_info.file_size, Some(11));
+            assert!(!result.xet_info.hash.is_empty());
+        });
+    }
+
+    #[test]
+    fn test_file_upload_handle_try_result_before_commit_is_none() {
+        let temp = tempdir().unwrap();
+        let endpoint = format!("local://{}", temp.path().join("cas").display());
+        let session = XetSessionBuilder::new().build().unwrap();
+        let commit = session
+            .new_upload_commit()
+            .unwrap()
+            .with_endpoint(&endpoint)
+            .build_blocking()
+            .unwrap();
+        let handle = commit
+            .upload_bytes_blocking(b"data".to_vec(), Sha256Policy::Compute, None)
+            .unwrap();
+        let py_handle = PyXetFileUpload { inner: handle };
+        // Before commit, ingestion may not be finalized yet
+        // (try_result may or may not be Some depending on timing; just verify no panic)
+        let _ = py_handle.try_result();
+    }
+}
diff --git a/hf_xet/src/py_stream_upload_handle.rs b/hf_xet/src/py_stream_upload_handle.rs
new file mode 100644
index 000000000..1a0c2a9be
--- /dev/null
+++ b/hf_xet/src/py_stream_upload_handle.rs
@@ -0,0 +1,127 @@
+use pyo3::prelude::*;
+use xet_pkg::xet_session::{ItemProgressReport, UniqueID, XetFileMetadata, XetStreamUpload};
+
+use crate::utils::{progress_display, task_state_display, task_state_to_pystate};
+use crate::{PyXetTaskState, convert_xet_error};
+
+// ── PyXetStreamUpload ─────────────────────────────────────────────────────────
+
+/// Handle for a streaming upload within an :class:`XetUploadCommit`.
+///
+/// Returned by :meth:`XetUploadCommit.start_upload_stream`. Feed data incrementally
+/// with :meth:`write`, then call :meth:`finish` to finalize ingestion.
+/// **:meth:`finish` must be called before** :meth:`XetUploadCommit.wait_to_finish`.
+#[pyclass(name = "XetStreamUpload")]
+#[derive(Clone)]
+pub struct PyXetStreamUpload {
+    pub(crate) inner: XetStreamUpload,
+}
+
+#[pymethods]
+impl PyXetStreamUpload {
+    // Example output:
+    //   XetStreamUpload(task_id=1, status="Running", bytes_completed=512/4096)
+    //   XetStreamUpload(task_id=2, status="Completed", bytes_completed=4096/4096)
+    //   XetStreamUpload(task_id=3, status="Running", bytes_completed=?/?)   ← before first progress report
+    fn __repr__(&self) -> String {
+        let status = task_state_display(self.inner.status());
+        let prog = progress_display(self.inner.progress());
+        format!("XetStreamUpload(task_id={}, status=\"{}\", bytes_completed={})", self.inner.task_id(), status, prog)
+    }
+
+    /// Feed a chunk of data into the upload pipeline.
+    ///
+    /// May be called any number of times before :meth:`finish`.
+    /// Releases the GIL while writing.
+    pub fn write(&self, py: Python<'_>, data: &[u8]) -> PyResult<()> {
+        // Copy bytes into an owned Vec before releasing the GIL so the
+        // Python bytes object can be freed independently.
+        let owned: Vec<u8> = data.to_vec();
+        let inner = self.inner.clone();
+        py.detach(|| inner.write_blocking(owned).map_err(convert_xet_error))
+    }
+
+    /// Finalize the stream and return per-file upload metadata.
+    ///
+    /// Must be called before :meth:`XetUploadCommit.wait_to_finish`.
+    /// Releases the GIL while waiting.
+    pub fn finish(&self, py: Python<'_>) -> PyResult<XetFileMetadata> {
+        let inner = self.inner.clone();
+        py.detach(|| inner.finish_blocking().map_err(convert_xet_error))
+    }
+
+    /// Return upload metadata without blocking, or ``None`` if not yet finished.
+    pub fn try_finish(&self) -> Option<XetFileMetadata> {
+        self.inner.try_finish()
+    }
+
+    /// Per-file progress snapshot, or ``None`` if not yet available.
+    pub fn progress(&self) -> Option<ItemProgressReport> {
+        self.inner.progress()
+    }
+
+    /// Current task state as a :class:`XetTaskState` enum value. Raises on error.
+    pub fn status(&self) -> PyResult<PyXetTaskState> {
+        task_state_to_pystate(self.inner.status())
+    }
+
+    /// The unique task ID for this stream.
+    pub fn task_id(&self) -> UniqueID {
+        self.inner.task_id()
+    }
+
+    /// Cancel the streaming upload.
+ pub fn abort(&self) { + self.inner.abort(); + } +} + +#[cfg(test)] +mod tests { + use pyo3::Python; + use tempfile::tempdir; + use xet_pkg::xet_session::{Sha256Policy, XetSessionBuilder}; + + use super::*; + + #[test] + fn test_stream_upload_write_and_finish() { + let temp = tempdir().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let session = XetSessionBuilder::new().build().unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); + let stream_handle = commit + .upload_stream_blocking(Some("stream.bin".into()), Sha256Policy::Compute) + .unwrap(); + let py_stream = PyXetStreamUpload { inner: stream_handle }; + assert!(py_stream.task_id().0 > 0); + Python::attach(|py| { + py_stream.write(py, b"hello world").unwrap(); + let result = py_stream.finish(py).unwrap(); + assert_eq!(result.xet_info.file_size, Some(11)); + assert!(!result.xet_info.hash.is_empty()); + }); + } + + #[test] + fn test_stream_upload_try_finish_before_finish_is_none() { + let temp = tempdir().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let session = XetSessionBuilder::new().build().unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); + let stream_handle = commit.upload_stream_blocking(None, Sha256Policy::Compute).unwrap(); + let py_stream = PyXetStreamUpload { inner: stream_handle }; + // Before finish(), try_finish() should return None + assert!(py_stream.try_finish().is_none()); + } +} diff --git a/hf_xet/src/py_upload_commit.rs b/hf_xet/src/py_upload_commit.rs new file mode 100644 index 000000000..56e31a2a3 --- /dev/null +++ b/hf_xet/src/py_upload_commit.rs @@ -0,0 +1,447 @@ +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use std::time::Duration; + +use pyo3::prelude::*; +use xet_pkg::xet_session::{ + GroupProgressReport, ItemProgressReport, Sha256Policy, UniqueID, XetCommitReport, XetFileUpload, XetSession, + XetTaskState, XetUploadCommit, +}; + +// ── SHA-256 policy sentinels ────────────────────────────────────────────────── + +/// Sentinel: compute SHA-256 from the file data (default behaviour). +/// +/// Pass this as the ``sha256`` argument to :meth:`XetUploadCommit.start_upload_file`, +/// :meth:`XetUploadCommit.start_upload_bytes`, or :meth:`XetUploadCommit.start_upload_stream`. +#[pyclass(frozen, name = "_ComputeSha256Type")] +pub struct PyComputeSha256; + +#[pymethods] +impl PyComputeSha256 { + fn __repr__(&self) -> &'static str { + "COMPUTE_SHA256" + } +} + +/// Sentinel: skip SHA-256 computation entirely. +/// +/// Pass this as the ``sha256`` argument to :meth:`XetUploadCommit.start_upload_file`, +/// :meth:`XetUploadCommit.start_upload_bytes`, or :meth:`XetUploadCommit.start_upload_stream`. +#[pyclass(frozen, name = "_SkipSha256Type")] +pub struct PySkipSha256; + +#[pymethods] +impl PySkipSha256 { + fn __repr__(&self) -> &'static str { + "SKIP_SHA256" + } +} + +/// Convert the Python ``sha256`` argument to a :type:`Sha256Policy`. 
+///
+/// Accepts:
+/// - ``None`` or :data:`COMPUTE_SHA256` → compute from data
+/// - :data:`SKIP_SHA256` → skip
+/// - ``str`` → treat as a pre-computed hex digest
+fn parse_sha256(py: Python<'_>, sha256: Option<Py<PyAny>>) -> PyResult<Sha256Policy> {
+    match sha256 {
+        None => Ok(Sha256Policy::Compute),
+        Some(obj) => {
+            let obj = obj.bind(py);
+            if obj.is_instance_of::<PyComputeSha256>() {
+                Ok(Sha256Policy::Compute)
+            } else if obj.is_instance_of::<PySkipSha256>() {
+                Ok(Sha256Policy::Skip)
+            } else if let Ok(hex) = obj.extract::<String>() {
+                Ok(Sha256Policy::from_hex(&hex))
+            } else {
+                Err(pyo3::exceptions::PyTypeError::new_err("sha256 must be a str, COMPUTE_SHA256, or SKIP_SHA256"))
+            }
+        },
+    }
+}
+
+use crate::headers::{build_header_map, build_headers_with_user_agent};
+use crate::py_file_upload_handle::PyXetFileUpload;
+use crate::py_stream_upload_handle::PyXetStreamUpload;
+use crate::utils::{progress_display, task_state_display, task_state_to_pystate};
+use crate::{PyXetTaskState, blocking_call_with_signal_check, convert_xet_error};
+
+// ── build_upload_commit ───────────────────────────────────────────────────────
+
+/// Create an :class:`XetUploadCommit` from a session and optional configuration.
+///
+/// Called by :meth:`XetSession.new_upload_commit`. The Rust builder type is
+/// created and consumed entirely here — it never surfaces in any public API.
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn build_upload_commit(
+    py: Python<'_>,
+    session: &XetSession,
+    endpoint: Option<String>,
+    token: Option<String>,
+    token_expiry_unix_secs: Option<u64>,
+    token_refresh_url: Option<String>,
+    token_refresh_headers: Option<HashMap<String, String>>,
+    custom_headers: Option<HashMap<String, String>>,
+    progress_callback: Option<Py<PyAny>>,
+    progress_interval_ms: u64,
+) -> PyResult<PyXetUploadCommit> {
+    let mut builder = session.new_upload_commit().map_err(convert_xet_error)?;
+    if let Some(ep) = endpoint {
+        builder = builder.with_endpoint(ep);
+    }
+    if let (Some(tok), Some(exp)) = (token, token_expiry_unix_secs) {
+        builder = builder.with_token_info(tok, exp);
+    }
+    if let Some(url) = token_refresh_url {
+        let headers = build_header_map(token_refresh_headers.unwrap_or_default())?;
+        builder = builder.with_token_refresh_url(url, headers);
+    }
+    let merged_headers = build_headers_with_user_agent(custom_headers)?;
+    let commit = py.detach(move || {
+        builder
+            .with_custom_headers(merged_headers)
+            .build_blocking()
+            .map_err(convert_xet_error)
+    })?;
+
+    let upload_handles = if let Some(callback) = progress_callback {
+        let handles: Arc<RwLock<Vec<XetFileUpload>>> = Arc::new(RwLock::new(Vec::new()));
+        let inner = commit.clone();
+        let handles_for_thread = handles.clone();
+        let interval = Duration::from_millis(progress_interval_ms);
+        std::thread::spawn(move || {
+            loop {
+                std::thread::sleep(interval);
+                let is_terminal = !matches!(inner.status(), Ok(XetTaskState::Running) | Ok(XetTaskState::Finalizing));
+                let group_report = inner.progress();
+                let item_reports: HashMap<UniqueID, ItemProgressReport> = handles_for_thread
+                    .read()
+                    .map(|g| g.iter().filter_map(|h| h.progress().map(|p| (h.task_id(), p))).collect())
+                    .unwrap_or_default();
+                let result = Python::attach(|py| callback.call1(py, (group_report, item_reports)));
+                if let Err(e) = result {
+                    Python::attach(|py| e.print(py));
+                    break;
+                }
+                if is_terminal {
+                    break;
+                }
+            }
+        });
+        Some(handles)
+    } else {
+        None
+    };
+
+    Ok(PyXetUploadCommit {
+        inner: commit,
+        upload_handles,
+    })
+}
+
+// ── PyXetUploadCommit ─────────────────────────────────────────────────────────
+
+/// A group of related file uploads.
+///
+/// Implements the context-manager protocol.
+///
+/// ```text
+/// with session.new_upload_commit(endpoint="...") as commit:
+///     h = commit.start_upload_file("/path/to/file.bin")
+///     # on normal exit: wait_to_finish() is called automatically
+///     # on exception: abort() is called automatically
+/// ```
+#[pyclass(name = "XetUploadCommit")]
+pub struct PyXetUploadCommit {
+    pub(crate) inner: XetUploadCommit,
+    /// Per-file handles shared with the progress thread; None when no callback was registered.
+    upload_handles: Option<Arc<RwLock<Vec<XetFileUpload>>>>,
+}
+
+#[pymethods]
+impl PyXetUploadCommit {
+    // Example output:
+    //   XetUploadCommit(status="Running", uploads=[(1, "/path/model.bin", bytes_completed=1024/4096), (2, None,
+    //   bytes_completed=?/?)])
+    //
+    // Each upload entry is (task_id, path_or_None, bytes_completed/total_bytes).
+    // Path is None for uploads started from bytes rather than a file path.
+    // Progress shows "?/?" before the first report arrives.
+    fn __repr__(&self) -> String {
+        let status = task_state_display(self.inner.status());
+        let uploads: Vec<String> = self
+            .inner
+            .active_upload_info()
+            .into_iter()
+            .map(|(id, path, progress)| {
+                let p = match path {
+                    Some(pb) => format!("\"{}\"", pb.display()),
+                    None => "None".to_string(),
+                };
+                let prog = progress_display(progress);
+                format!("({id}, {p}, bytes_completed={prog})")
+            })
+            .collect();
+        format!("XetUploadCommit(status=\"{}\", uploads=[{}])", status, uploads.join(", "))
+    }
+
+    // ── Context manager ──────────────────────────────────────────────────────
+
+    fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    fn __exit__(
+        &self,
+        py: Python<'_>,
+        exc_type: Bound<'_, pyo3::PyAny>,
+        _exc_val: Bound<'_, pyo3::PyAny>,
+        _exc_tb: Bound<'_, pyo3::PyAny>,
+    ) -> PyResult<bool> {
+        if exc_type.is_none() {
+            // Normal exit: commit all uploads (signal-interruptible).
+            self.wait_to_finish(py)?;
+        } else {
+            // Exception: cancel uploads.
+            if let Err(e) = self.inner.abort() {
+                tracing::warn!("abort() failed during __exit__ exception path: {e}");
+            }
+        }
+        Ok(false) // do not suppress the exception
+    }
+
+    // ── Upload methods ────────────────────────────────────────────────────────
+
+    /// Queue a file from disk for upload.
+    ///
+    /// Returns immediately with a :class:`XetFileUpload` handle. The upload
+    /// runs in the background. Call :meth:`XetUploadCommit.wait_to_finish` (or exit
+    /// the ``with`` block) to wait for all uploads to complete.
+    ///
+    /// ``sha256`` controls how the SHA-256 digest is handled:
+    ///
+    /// - ``sha256="f2358d9a…"`` — pre-computed hex string (most common for models/datasets)
+    /// - ``sha256=hf_xet.COMPUTE_SHA256`` — compute from file data (default when omitted)
+    /// - ``sha256=hf_xet.SKIP_SHA256`` — skip SHA-256 entirely
+    #[pyo3(signature = (path, sha256=None))]
+    pub fn start_upload_file(
+        &self,
+        py: Python<'_>,
+        path: String,
+        sha256: Option<Py<PyAny>>,
+    ) -> PyResult<PyXetFileUpload> {
+        let policy = parse_sha256(py, sha256)?;
+        let inner = self.inner.clone();
+        let handle = py.detach(|| inner.upload_from_path_blocking(path.into(), policy).map_err(convert_xet_error))?;
+        if let Some(ref handles) = self.upload_handles {
+            handles
+                .write()
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?
+                .push(handle.clone());
+        }
+        Ok(PyXetFileUpload { inner: handle })
+    }
+
+    /// Queue raw bytes for upload.
+    ///
+    /// ``name`` is an optional display name used for progress reporting.
+    /// ``sha256`` accepts the same values as :meth:`start_upload_file`.
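+    ///
+    /// Illustrative sketch (the payload and name are arbitrary):
+    ///
+    /// ```text
+    /// h = commit.start_upload_bytes(b"hello world", name="greeting.txt")
+    /// ```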
+    #[pyo3(signature = (data, sha256=None, name=None))]
+    pub fn start_upload_bytes(
+        &self,
+        py: Python<'_>,
+        data: Vec<u8>,
+        sha256: Option<Py<PyAny>>,
+        name: Option<String>,
+    ) -> PyResult<PyXetFileUpload> {
+        let policy = parse_sha256(py, sha256)?;
+        let inner = self.inner.clone();
+        let handle = py.detach(|| inner.upload_bytes_blocking(data, policy, name).map_err(convert_xet_error))?;
+        if let Some(ref handles) = self.upload_handles {
+            handles
+                .write()
+                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?
+                .push(handle.clone());
+        }
+        Ok(PyXetFileUpload { inner: handle })
+    }
+
+    /// Open a streaming upload channel.
+    ///
+    /// Returns a :class:`XetStreamUpload` handle. Feed data incrementally
+    /// with :meth:`XetStreamUpload.write`, then call :meth:`XetStreamUpload.finish`
+    /// **before** calling :meth:`XetUploadCommit.wait_to_finish`.
+    ///
+    /// ``name`` is an optional display name used for progress reporting.
+    /// ``sha256`` accepts the same values as :meth:`start_upload_file`.
+    ///
+    /// Example:
+    ///
+    /// ```text
+    /// stream = commit.start_upload_stream(name="model.bin")
+    /// for chunk in produce_chunks():
+    ///     stream.write(chunk)
+    /// result = stream.finish()  # must be called before wait_to_finish()
+    /// print(result.xet_info.hash, result.xet_info.file_size)
+    /// ```
+    #[pyo3(signature = (name=None, sha256=None))]
+    pub fn start_upload_stream(
+        &self,
+        py: Python<'_>,
+        name: Option<String>,
+        sha256: Option<Py<PyAny>>,
+    ) -> PyResult<PyXetStreamUpload> {
+        let policy = parse_sha256(py, sha256)?;
+        let inner = self.inner.clone();
+        let handle = py.detach(|| inner.upload_stream_blocking(name, policy).map_err(convert_xet_error))?;
+        Ok(PyXetStreamUpload { inner: handle })
+    }
+
+    // ── Commit / abort ────────────────────────────────────────────────────────
+
+    /// Wait for all uploads to finish and push metadata to the CAS server.
+    ///
+    /// Returns a :class:`XetCommitReport`. Also called automatically when
+    /// exiting a ``with`` block without an exception.
+    ///
+    /// Releases the GIL while waiting, polling for ``KeyboardInterrupt`` every
+    /// 100 ms so that Ctrl-C is delivered promptly.
+    pub fn wait_to_finish(&self, py: Python<'_>) -> PyResult<XetCommitReport> {
+        let inner = self.inner.clone();
+        blocking_call_with_signal_check(py, move || inner.commit_blocking())
+    }
+
+    /// Cancel all active uploads in this commit.
+    pub fn abort(&self) -> PyResult<()> {
+        self.inner.abort().map_err(convert_xet_error)
+    }
+
+    // ── Progress / status ─────────────────────────────────────────────────────
+
+    /// Aggregate progress for all uploads in this commit.
+    ///
+    /// Returns a :class:`GroupProgressReport`. Lock-free — safe to call from
+    /// any thread without holding the GIL.
+    pub fn progress(&self) -> GroupProgressReport {
+        self.inner.progress()
+    }
+
+    /// Current task state as a :class:`XetTaskState` enum value. Raises on error.
+    pub fn status(&self) -> PyResult<PyXetTaskState> {
+        task_state_to_pystate(self.inner.status())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pyo3::Python;
+    use xet_pkg::xet_session::Sha256Policy;
+
+    use super::*;
+
+    // ── parse_sha256 ──────────────────────────────────────────────────────────
+
+    #[test]
+    fn test_parse_sha256_none_gives_compute() {
+        Python::attach(|py| {
+            let policy = parse_sha256(py, None).unwrap();
+            assert!(matches!(policy, Sha256Policy::Compute));
+        });
+    }
+
+    #[test]
+    fn test_parse_sha256_compute_sentinel() {
+        Python::attach(|py| {
+            let sentinel: Py<PyAny> = Py::new(py, PyComputeSha256).unwrap().into();
+            let policy = parse_sha256(py, Some(sentinel)).unwrap();
+            assert!(matches!(policy, Sha256Policy::Compute));
+        });
+    }
+
+    #[test]
+    fn test_parse_sha256_skip_sentinel() {
+        Python::attach(|py| {
+            let sentinel: Py<PyAny> = Py::new(py, PySkipSha256).unwrap().into();
+            let policy = parse_sha256(py, Some(sentinel)).unwrap();
+            assert!(matches!(policy, Sha256Policy::Skip));
+        });
+    }
+
+    #[test]
+    fn test_parse_sha256_provided_hex_string() {
+        Python::attach(|py| {
+            let hex = "a".repeat(64);
+            let obj: Py<PyAny> = hex.into_pyobject(py).unwrap().into_any().unbind();
+            let policy = parse_sha256(py, Some(obj)).unwrap();
+            assert!(matches!(policy, Sha256Policy::Provided(_)));
+        });
+    }
+
+    #[test]
+    fn test_parse_sha256_invalid_type_returns_type_error() {
+        Python::attach(|py| {
+            let obj: Py<PyAny> = 42i64.into_pyobject(py).unwrap().into_any().unbind();
+            match parse_sha256(py, Some(obj)) {
+                Ok(_) => panic!("expected TypeError"),
+                Err(e) => assert!(e.is_instance_of::<pyo3::exceptions::PyTypeError>(py)),
+            }
+        });
+    }
+
+    // ── PyXetUploadCommit ─────────────────────────────────────────────────────
+
+    #[test]
+    fn test_start_upload_bytes_and_commit_report() {
+        use tempfile::tempdir;
+        use xet_pkg::xet_session::XetSessionBuilder;
+
+        let temp = tempdir().unwrap();
+        let endpoint = format!("local://{}", temp.path().join("cas").display());
+        let session = XetSessionBuilder::new().build().unwrap();
+        let commit = PyXetUploadCommit {
+            inner: session
+                .new_upload_commit()
+                .unwrap()
+                .with_endpoint(&endpoint)
+                .build_blocking()
+                .unwrap(),
+            upload_handles: None,
+        };
+
+        Python::attach(|py| {
+            let handle = commit.start_upload_bytes(py, b"hello world".to_vec(), None, None).unwrap();
+            let task_id = handle.task_id();
+            let report = commit.wait_to_finish(py).unwrap();
+            assert!(report.uploads.contains_key(&task_id));
+            let meta = &report.uploads[&task_id];
+            assert_eq!(meta.xet_info.file_size, Some(11));
+            assert!(!meta.xet_info.hash.is_empty());
+        });
+    }
+
+    #[test]
+    fn test_abort_makes_commit_fail() {
+        use tempfile::tempdir;
+        use xet_pkg::xet_session::XetSessionBuilder;
+
+        let temp = tempdir().unwrap();
+        let endpoint = format!("local://{}", temp.path().join("cas").display());
+        let session = XetSessionBuilder::new().build().unwrap();
+        let commit = PyXetUploadCommit {
+            inner: session
+                .new_upload_commit()
+                .unwrap()
+                .with_endpoint(&endpoint)
+                .build_blocking()
+                .unwrap(),
+            upload_handles: None,
+        };
+
+        Python::attach(|py| {
+            commit.abort().unwrap();
+            assert!(commit.wait_to_finish(py).is_err());
+        });
+    }
+}
diff --git a/hf_xet/src/py_xet_session.rs b/hf_xet/src/py_xet_session.rs
new file mode 100644
index 000000000..77d2f0ef8
--- /dev/null
+++ b/hf_xet/src/py_xet_session.rs
@@ -0,0 +1,281 @@
+use std::collections::HashMap;
+
+use pyo3::prelude::*;
+use xet_pkg::xet_session::{XetSession, XetSessionBuilder};
+use xet_runtime::config::XetConfig;
+
+use crate::py_download_stream_group::{PyXetDownloadStreamGroup, build_download_stream_group};
+use crate::py_file_download_group::{PyXetFileDownloadGroup, build_file_download_group};
+use crate::py_upload_commit::{PyXetUploadCommit, build_upload_commit};
+use crate::utils::{task_state_display, task_state_to_pystate};
+use crate::{PyXetTaskState, convert_xet_error};
+
+// ── PyXetSession ─────────────────────────────────────────────────────────────
+
+/// Manages a Xet runtime context and connection pool.
+///
+/// Session objects are cheap to clone — all clones share the same underlying state.
+#[pyclass(name = "XetSession")]
+#[derive(Clone)]
+pub struct PyXetSession {
+    pub(crate) inner: XetSession,
+}
+
+#[pymethods]
+impl PyXetSession {
+    // Example output:
+    //   XetSession(id="01JBQW...", status="Running", config={data.max_concurrent_file_ingestion=4, ...})
+    fn __repr__(&self, py: Python<'_>) -> PyResult<String> {
+        let status = task_state_display(self.inner.status());
+        let id = self.inner.id();
+        let items = self.inner.config().all_items_to_python(py)?;
+        let config_str = items
+            .into_iter()
+            .map(|(k, v): (String, Py<PyAny>)| {
+                let repr = v.bind(py).repr().map(|r| r.to_string()).unwrap_or_else(|_| "?".to_string());
+                format!("{k}={repr}")
+            })
+            .collect::<Vec<String>>()
+            .join(", ");
+        Ok(format!("XetSession(id=\"{id}\", status=\"{status}\", config={{{config_str}}})"))
+    }
+
+    /// Create a new XetSession.
+    ///
+    /// ``config`` is an optional :class:`XetConfig` instance. When omitted, a
+    /// default config (with environment-variable overrides applied) is used.
+    #[new]
+    #[pyo3(signature = (config=None))]
+    pub fn new(config: Option<crate::PyXetConfig>) -> PyResult<Self> {
+        #[allow(clippy::unwrap_or_default)]
+        // XetConfig::new starts from default() and applies HF_XET_* environment variable overrides
+        let xet_config = config.map(|c| c.into_inner()).unwrap_or_else(XetConfig::new);
+        let session = XetSessionBuilder::new_with_config(xet_config).build().map_err(PyErr::from)?;
+        Ok(Self { inner: session })
+    }
+
+    /// Create a new :class:`XetUploadCommit` and establish the CAS connection.
+    ///
+    /// All parameters are optional. Releases the GIL during the blocking
+    /// network handshake.
+    ///
+    /// ``endpoint`` — Xet CAS server URL (e.g. ``"https://cas.xethub.hf.co"``). If
+    /// omitted but ``token_refresh_url`` is provided, the endpoint is fetched
+    /// automatically from the token refresh response.
+    ///
+    /// ``token`` and ``token_expiry_unix_secs`` — seed an initial CAS access token and
+    /// its expiry as a Unix timestamp (seconds). Both must be supplied together; if
+    /// either is absent the token is not pre-seeded. When ``token_refresh_url`` is also
+    /// provided, the refresh response's token is used only if no token was pre-seeded here.
+    ///
+    /// ``token_refresh_url`` — URL called with an HTTP GET whenever the current CAS token
+    /// is about to expire. The response must be JSON:
+    /// ``{"accessToken": "…", "exp": <unix seconds>, "casUrl": "…"}``.
+    ///
+    /// ``token_refresh_headers`` — HTTP headers sent with every token refresh request
+    /// (e.g. ``{"Authorization": "Bearer hf_…"}``). Defaults to ``{}`` when
+    /// ``token_refresh_url`` is set but headers are omitted.
+    ///
+    /// ``custom_headers`` — additional HTTP headers forwarded with every CAS request.
+    ///
+    /// ``progress_callback`` — callable invoked every ``progress_interval_ms``
+    /// milliseconds with ``(GroupProgressReport, dict[UniqueID, ItemProgressReport])``.
+    ///
+    /// ``progress_interval_ms`` — milliseconds between progress callbacks (default ``100``).
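+    ///
+    /// A minimal callback sketch (illustrative; the ``ItemProgressReport`` fields
+    /// shown are ``bytes_completed`` and ``total_bytes``):
+    ///
+    /// ```text
+    /// def on_progress(group_report, item_reports):
+    ///     for task_id, r in item_reports.items():
+    ///         print(task_id, r.bytes_completed, r.total_bytes)
+    /// ```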
+    ///
+    /// Example:
+    ///
+    /// ```text
+    /// with session.new_upload_commit(
+    ///     endpoint="https://cas.xethub.hf.co",
+    ///     token="jwt", token_expiry_unix_secs=9999999999,
+    ///     token_refresh_url="https://…/xet-write-token/main",
+    ///     token_refresh_headers={"Authorization": "Bearer hf_…"},
+    ///     progress_callback=on_progress,
+    /// ) as commit:
+    ///     commit.start_upload_file("/path/to/model.bin")
+    /// ```
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (
+        endpoint=None, token=None, token_expiry_unix_secs=None,
+        token_refresh_url=None, token_refresh_headers=None,
+        custom_headers=None, progress_callback=None, progress_interval_ms=100
+    ))]
+    pub fn new_upload_commit(
+        &self,
+        py: Python<'_>,
+        endpoint: Option<String>,
+        token: Option<String>,
+        token_expiry_unix_secs: Option<u64>,
+        token_refresh_url: Option<String>,
+        token_refresh_headers: Option<HashMap<String, String>>,
+        custom_headers: Option<HashMap<String, String>>,
+        progress_callback: Option<Py<PyAny>>,
+        progress_interval_ms: u64,
+    ) -> PyResult<PyXetUploadCommit> {
+        build_upload_commit(
+            py,
+            &self.inner,
+            endpoint,
+            token,
+            token_expiry_unix_secs,
+            token_refresh_url,
+            token_refresh_headers,
+            custom_headers,
+            progress_callback,
+            progress_interval_ms,
+        )
+    }
+
+    /// Create a new :class:`XetFileDownloadGroup` and establish the CAS connection.
+    ///
+    /// All parameters are optional. Releases the GIL during the blocking
+    /// network handshake.
+    ///
+    /// ``endpoint`` — Xet CAS server URL (e.g. ``"https://cas.xethub.hf.co"``). If
+    /// omitted but ``token_refresh_url`` is provided, the endpoint is fetched
+    /// automatically from the token refresh response.
+    ///
+    /// ``token`` and ``token_expiry_unix_secs`` — seed an initial CAS access token and
+    /// its expiry as a Unix timestamp (seconds). Both must be supplied together; if
+    /// either is absent the token is not pre-seeded. When ``token_refresh_url`` is also
+    /// provided, the refresh response's token is used only if no token was pre-seeded here.
+    ///
+    /// ``token_refresh_url`` — URL called with an HTTP GET whenever the current CAS token
+    /// is about to expire. The response must be JSON:
+    /// ``{"accessToken": "…", "exp": <unix seconds>, "casUrl": "…"}``.
+    ///
+    /// ``token_refresh_headers`` — HTTP headers sent with every token refresh request
+    /// (e.g. ``{"Authorization": "Bearer hf_…"}``). Defaults to ``{}`` when
+    /// ``token_refresh_url`` is set but headers are omitted.
+    ///
+    /// ``custom_headers`` — additional HTTP headers forwarded with every CAS request.
+    ///
+    /// ``progress_callback`` — callable invoked every ``progress_interval_ms``
+    /// milliseconds with ``(GroupProgressReport, dict[UniqueID, ItemProgressReport])``.
+    ///
+    /// ``progress_interval_ms`` — milliseconds between progress callbacks (default ``100``).
+    ///
+    /// Example:
+    ///
+    /// ```text
+    /// with session.new_file_download_group(
+    ///     endpoint="https://cas.xethub.hf.co",
+    ///     token_refresh_url="https://…/xet-read-token/main",
+    ///     token_refresh_headers={"Authorization": "Bearer hf_…"},
+    /// ) as group:
+    ///     group.start_download_file(file_info, "/tmp/out.bin")
+    /// ```
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (
+        endpoint=None, token=None, token_expiry_unix_secs=None,
+        token_refresh_url=None, token_refresh_headers=None,
+        custom_headers=None, progress_callback=None, progress_interval_ms=100
+    ))]
+    pub fn new_file_download_group(
+        &self,
+        py: Python<'_>,
+        endpoint: Option<String>,
+        token: Option<String>,
+        token_expiry_unix_secs: Option<u64>,
+        token_refresh_url: Option<String>,
+        token_refresh_headers: Option<HashMap<String, String>>,
+        custom_headers: Option<HashMap<String, String>>,
+        progress_callback: Option<Py<PyAny>>,
+        progress_interval_ms: u64,
+    ) -> PyResult<PyXetFileDownloadGroup> {
+        build_file_download_group(
+            py,
+            &self.inner,
+            endpoint,
+            token,
+            token_expiry_unix_secs,
+            token_refresh_url,
+            token_refresh_headers,
+            custom_headers,
+            progress_callback,
+            progress_interval_ms,
+        )
+    }
+
+    /// Create a new :class:`XetDownloadStreamGroup` and establish the CAS connection.
+    ///
+    /// All parameters are optional. Releases the GIL during the blocking
+    /// network handshake.
+    ///
+    /// ``endpoint`` — Xet CAS server URL (e.g. ``"https://cas.xethub.hf.co"``). If
+    /// omitted but ``token_refresh_url`` is provided, the endpoint is fetched
+    /// automatically from the token refresh response.
+    ///
+    /// ``token`` and ``token_expiry_unix_secs`` — seed an initial CAS access token and
+    /// its expiry as a Unix timestamp (seconds). Both must be supplied together; if
+    /// either is absent the token is not pre-seeded. When ``token_refresh_url`` is also
+    /// provided, the refresh response's token is used only if no token was pre-seeded here.
+    ///
+    /// ``token_refresh_url`` — URL called with an HTTP GET whenever the current CAS token
+    /// is about to expire. The response must be JSON:
+    /// ``{"accessToken": "…", "exp": <unix seconds>, "casUrl": "…"}``.
+    ///
+    /// ``token_refresh_headers`` — HTTP headers sent with every token refresh request
+    /// (e.g. ``{"Authorization": "Bearer hf_…"}``). Defaults to ``{}`` when
+    /// ``token_refresh_url`` is set but headers are omitted.
+    ///
+    /// ``custom_headers`` — additional HTTP headers forwarded with every CAS request.
+    ///
+    /// Example:
+    ///
+    /// ```text
+    /// group = session.new_download_stream_group(
+    ///     endpoint="https://cas.xethub.hf.co",
+    ///     token_refresh_url="https://…/xet-read-token/main",
+    ///     token_refresh_headers={"Authorization": "Bearer hf_…"},
+    /// )
+    /// for chunk in group.download_stream(file_info):
+    ///     process(chunk)
+    /// ```
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (
+        endpoint=None, token=None, token_expiry_unix_secs=None,
+        token_refresh_url=None, token_refresh_headers=None,
+        custom_headers=None
+    ))]
+    pub fn new_download_stream_group(
+        &self,
+        py: Python<'_>,
+        endpoint: Option<String>,
+        token: Option<String>,
+        token_expiry_unix_secs: Option<u64>,
+        token_refresh_url: Option<String>,
+        token_refresh_headers: Option<HashMap<String, String>>,
+        custom_headers: Option<HashMap<String, String>>,
+    ) -> PyResult<PyXetDownloadStreamGroup> {
+        build_download_stream_group(
+            py,
+            &self.inner,
+            endpoint,
+            token,
+            token_expiry_unix_secs,
+            token_refresh_url,
+            token_refresh_headers,
+            custom_headers,
+        )
+    }
+
+    /// Current task state as a :class:`XetTaskState` enum value. Raises on error.
+    pub fn status(&self) -> PyResult<PyXetTaskState> {
+        task_state_to_pystate(self.inner.status())
+    }
+
+    /// Cancel all in-progress operations and shut down the underlying runtime.
+    ///
+    /// Unlike :meth:`XetUploadCommit.abort` or :meth:`XetFileDownloadGroup.abort`,
+    /// which cancel a single operation while leaving the session usable, this
+    /// method destroys the session's runtime entirely. The :class:`XetSession`
+    /// object must be discarded and a new one created before issuing further
+    /// uploads or downloads.
+    ///
+    /// Intended for use in ``except KeyboardInterrupt:`` handlers.
+    pub fn sigint_abort(&self) -> PyResult<()> {
+        self.inner.sigint_abort().map_err(convert_xet_error)
+    }
+}
diff --git a/hf_xet/src/utils.rs b/hf_xet/src/utils.rs
new file mode 100644
index 000000000..7a4b9f27e
--- /dev/null
+++ b/hf_xet/src/utils.rs
@@ -0,0 +1,159 @@
+//! Shared display helpers used across Python binding modules.
+
+use pyo3::prelude::*;
+use xet_pkg::XetError;
+use xet_pkg::xet_session::{ItemProgressReport, XetTaskState};
+
+// ── Error conversion ──────────────────────────────────────────────────────────
+
+pub(crate) fn convert_xet_error(e: impl Into<XetError>) -> PyErr {
+    PyErr::from(e.into())
+}
+
+// ── Signal-checked blocking call ──────────────────────────────────────────────
+
+/// Run `f` on a background thread while periodically calling `py.check_signals()`
+/// so that Ctrl-C is delivered to Python promptly during long-running operations.
+///
+/// Background: Python handles SIGINT by setting a flag that is only checked when
+/// control returns to the interpreter. While a blocking Rust call holds the GIL
+/// (even via `py.detach()`), that flag is never observed. By running the blocking
+/// work on a separate thread and polling `py.check_signals()` every 100 ms, we
+/// give CPython a chance to raise `KeyboardInterrupt` during calls like
+/// `commit_blocking()` and `finish_blocking()` that can run for many seconds.
+///
+/// Reference: the PyO3 FAQ entry "Ctrl-C doesn't do anything
+/// while my Rust code is executing".
+///
+/// When `KeyboardInterrupt` is raised here:
+/// - It propagates to the Python caller's `except KeyboardInterrupt:` block.
+/// - The caller calls `session.sigint_abort()` → sets `sigint_shutdown = true`.
+/// - The background thread's next await checkpoint returns immediately.
+/// - The thread exits cleanly.
+///
+/// This is the most Pythonic approach: `KeyboardInterrupt` is a standard exception
+/// that propagates through `try/except` like any other. The cleanup
+/// (`sigint_abort()`) lives in Python where it is visible and auditable — no
+/// hidden global state, no surprise interactions with `signal.signal()`.
+pub(crate) fn blocking_call_with_signal_check<T, E, F>(py: Python<'_>, f: F) -> PyResult<T>
+where
+    T: Send + 'static,
+    E: Into<XetError> + Send + 'static,
+    F: FnOnce() -> Result<T, E> + Send + 'static,
+{
+    use std::sync::mpsc::RecvTimeoutError;
+    use std::time::Duration;
+
+    let (tx, mut rx) = std::sync::mpsc::channel();
+    std::thread::spawn(move || {
+        tx.send(f()).ok();
+    });
+    loop {
+        // Release the GIL while waiting so that other threads that need it
+        // (e.g. the progress-callback thread, other Python threads) can run.
+        // After `detach` returns we re-hold the GIL for the signal check.
+        //
+        // `Receiver<T>: Send` but `!Sync`, so `&Receiver<T>` is `!Ungil`. We
+        // satisfy the `Ungil` bound by moving `rx` into a `move` closure and
+        // returning it alongside the result so it can be rebound for the next
+        // iteration.
+        let (rx_back, recv_result) = py.detach(move || {
+            let result = rx.recv_timeout(Duration::from_millis(100));
+            (rx, result)
+        });
+        rx = rx_back;
+        match recv_result {
+            Ok(result) => return result.map_err(convert_xet_error),
+            Err(RecvTimeoutError::Timeout) => py.check_signals()?,
+            // The sender was dropped without sending — the background thread panicked.
+            // Return a recoverable error rather than panicking a second time, which
+            // would crash the Python interpreter in a PyO3 context.
+            Err(RecvTimeoutError::Disconnected) => {
+                return Err(pyo3::exceptions::PyRuntimeError::new_err(
+                    "blocking operation panicked on background thread",
+                ));
+            },
+        }
+    }
+}
+
+// ── Task-state helpers ────────────────────────────────────────────────────────
+
+/// Convert a Rust `status()` result to a Python [`crate::PyXetTaskState`], raising on error.
+pub(crate) fn task_state_to_pystate(
+    result: Result<XetTaskState, XetError>,
+) -> PyResult<crate::PyXetTaskState> {
+    match result {
+        Ok(XetTaskState::Running) => Ok(crate::PyXetTaskState::Running),
+        Ok(XetTaskState::Finalizing) => Ok(crate::PyXetTaskState::Finalizing),
+        Ok(XetTaskState::Completed) => Ok(crate::PyXetTaskState::Completed),
+        Ok(XetTaskState::UserCancelled) => Ok(crate::PyXetTaskState::UserCancelled),
+        Ok(XetTaskState::Error(msg)) => Err(convert_xet_error(xet_pkg::XetError::TaskError(msg))),
+        Err(e) => Err(convert_xet_error(e)),
+    }
+}
+
+/// Convert a `status()` result to a display string for use in `__repr__`.
+///
+/// Never discards information: the error message inside `XetTaskState::Error(msg)`
+/// and any `XetError` from the `status()` call itself are both forwarded as the string.
+pub(crate) fn task_state_display(result: Result<XetTaskState, XetError>) -> String {
+    match result {
+        Ok(XetTaskState::Running) => "Running".to_string(),
+        Ok(XetTaskState::Finalizing) => "Finalizing".to_string(),
+        Ok(XetTaskState::Completed) => "Completed".to_string(),
+        Ok(XetTaskState::UserCancelled) => "UserCancelled".to_string(),
+        Ok(XetTaskState::Error(msg)) => msg,
+        Err(e) => e.to_string(),
+    }
+}
+
+// ── Progress helpers ──────────────────────────────────────────────────────────
+
+/// Format an `Option<ItemProgressReport>` as `"bytes_completed/total_bytes"`,
+/// or `"?/?"` if no report is available yet.
+pub(crate) fn progress_display(progress: Option<ItemProgressReport>) -> String {
+    match progress {
+        Some(r) => format!("{}/{}", r.bytes_completed, r.total_bytes),
+        None => "?/?".to_string(),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pyo3::Python;
+
+    use super::*;
+
+    #[test]
+    fn test_task_state_error_raises() {
+        Python::attach(|_py| {
+            let result = task_state_to_pystate(Ok(XetTaskState::Error("something went wrong".into())));
+            let msg = result.unwrap_err().to_string();
+            assert!(msg.contains("something went wrong"));
+        });
+    }
+
+    #[test]
+    fn test_task_state_outer_error_raises() {
+        Python::attach(|_py| {
+            let result = task_state_to_pystate(Err(xet_pkg::XetError::TaskError("outer error".into())));
+            assert!(result.is_err());
+        });
+    }
+
+    #[test]
+    fn test_progress_display_some() {
+        let report = ItemProgressReport {
+            item_name: "f".into(),
+            total_bytes: 100,
+            bytes_completed: 42,
+        };
+        assert_eq!(progress_display(Some(report)), "42/100");
+    }
+
+    #[test]
+    fn test_progress_display_none() {
+        assert_eq!(progress_display(None), "?/?");
+    }
+}
diff --git a/hf_xet/tests/conftest.py b/hf_xet/tests/conftest.py
new file mode 100644
index 000000000..be59a3468
--- /dev/null
+++ b/hf_xet/tests/conftest.py
@@ -0,0 +1,47 @@
+"""
+Shared pytest fixtures and upload helpers used across all test modules.
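+
+Each upload_*_get_info helper below uploads a payload to the local CAS
+endpoint and returns its XetFileInfo, so a download test can start from a
+single line, for example:
+
+    info = upload_bytes_get_info(endpoint, b"payload")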
+ +Run after building the extension: + cd hf_xet && maturin develop + pytest tests/ -v +""" + +import pytest +import hf_xet + + +@pytest.fixture +def endpoint(tmp_path): + """Local CAS endpoint backed by a per-test temp directory.""" + return f"local://{tmp_path / 'cas'}" + + +# ── Upload helpers ──────────────────────────────────────────────────────────── + +def upload_bytes_get_info(endpoint: str, data: bytes) -> hf_xet.XetFileInfo: + """Upload raw bytes, commit, and return the resulting XetFileInfo.""" + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_bytes(data, sha256=hf_xet.SKIP_SHA256) + commit.wait_to_finish() + return h.result().xet_info + + +def upload_file_get_info(endpoint: str, tmp_path, data: bytes) -> hf_xet.XetFileInfo: + """Write data to a temp file, upload it, and return the resulting XetFileInfo.""" + import uuid + src = tmp_path / f"upload_src_{uuid.uuid4().hex}.bin" + src.write_bytes(data) + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_file(str(src), sha256=hf_xet.SKIP_SHA256) + commit.wait_to_finish() + return h.result().xet_info + + +def upload_stream_get_info(endpoint: str, data: bytes) -> hf_xet.XetFileInfo: + """Upload data via upload_stream and return the resulting XetFileInfo.""" + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + stream = commit.start_upload_stream() + stream.write(data) + r = stream.finish() + commit.wait_to_finish() + return r.xet_info diff --git a/hf_xet/tests/test_config.py b/hf_xet/tests/test_config.py new file mode 100644 index 000000000..d7a2b433b --- /dev/null +++ b/hf_xet/tests/test_config.py @@ -0,0 +1,74 @@ +""" +Tests for XetConfig: construction, field access, update, and type coercions. +""" + +from datetime import timedelta + +import pytest + +import hf_xet + + +class TestXetConfig: + def test_keys_items_len_and_iteration_are_consistent(self): + config = hf_xet.XetConfig() + keys = config.keys() + assert len(config) == len(keys) + assert frozenset(keys) == frozenset(k for k, _ in config.items()) + assert sum(1 for _ in config) == len(keys) + for k, _ in config: + assert isinstance(k, str) + + def test_get_and_getitem_agree_for_valid_paths(self): + config = hf_xet.XetConfig() + assert config.get("data.max_concurrent_file_ingestion") == config[ + "data.max_concurrent_file_ingestion" + ] + assert config.get("data.progress_update_interval") == config[ + "data.progress_update_interval" + ] + + def test_with_config_dict_sets_distinct_field_types(self): + # Duration accepts either a humantime string ("501ms") or a timedelta. 
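+        # Both spellings below are equivalent for the duration field:
+        #   with_config("data.progress_update_interval", "501ms")
+        #   with_config("data.progress_update_interval", timedelta(milliseconds=501))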
+ cfg = hf_xet.XetConfig().with_config( + { + "data.max_concurrent_file_ingestion": 4, + "data.progress_update_interval": "501ms", + "data.local_cas_scheme": "local://test-scheme/", + } + ) + assert cfg.get("data.max_concurrent_file_ingestion") == 4 + assert cfg.get("data.progress_update_interval") == timedelta(milliseconds=501) + assert cfg.get("data.local_cas_scheme") == "local://test-scheme/" + + def test_with_config_duration_accepts_timedelta(self): + cfg = hf_xet.XetConfig().with_config("data.progress_update_interval", timedelta(milliseconds=501)) + assert cfg.get("data.progress_update_interval") == timedelta(milliseconds=501) + + def test_with_config_is_immutable_on_original(self): + base = hf_xet.XetConfig() + before = base.get("data.max_concurrent_file_ingestion") + updated = base.with_config("data.max_concurrent_file_ingestion", 999) + assert base.get("data.max_concurrent_file_ingestion") == before + assert updated.get("data.max_concurrent_file_ingestion") == 999 + + def test_with_config_rejects_invalid_arguments(self): + with pytest.raises(TypeError, match="second argument"): + hf_xet.XetConfig().with_config( + {"data.max_concurrent_file_ingestion": 1}, + "extra", + ) + with pytest.raises(TypeError, match="value argument"): + hf_xet.XetConfig().with_config("data.max_concurrent_file_ingestion") + + def test_get_raises_value_error_for_invalid_paths(self): + with pytest.raises(ValueError): + hf_xet.XetConfig().get("no_dot_segment") + with pytest.raises(ValueError): + hf_xet.XetConfig().get("not_a_real_config_group.field") + + def test_getitem_raises_key_error_for_unknown_path(self): + missing = "definitely_missing.unlikely_xyz_field_quack" + with pytest.raises(KeyError) as ei: + hf_xet.XetConfig()[missing] + assert ei.value.args[0] == missing diff --git a/hf_xet/tests/test_file_download.py b/hf_xet/tests/test_file_download.py new file mode 100644 index 000000000..45a9ddadd --- /dev/null +++ b/hf_xet/tests/test_file_download.py @@ -0,0 +1,139 @@ +""" +Tests for XetFileDownloadGroup and XetFileDownload handles. 
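+
+Typical shape of the API under test (a sketch; the upload helpers come from
+conftest and the destination path is arbitrary):
+
+    info = upload_bytes_get_info(endpoint, b"payload")
+    with hf_xet.XetSession().new_file_download_group(endpoint=endpoint) as group:
+        group.start_download_file(info, str(tmp_path / "out.bin"))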
+ +Not covered here (require a real CAS server): + - token, token_refresh_url, custom_headers kwargs +""" + +import hf_xet +from conftest import upload_bytes_get_info, upload_file_get_info, upload_stream_get_info + + +# ── XetFileDownloadGroup ────────────────────────────────────────────────────── + +class TestFileDownloadGroup: + def test_file_written_to_disk(self, endpoint, tmp_path): + data = b"download file content" + info = upload_bytes_get_info(endpoint, data) + dest = tmp_path / "out.bin" + with hf_xet.XetSession().new_file_download_group(endpoint=endpoint) as group: + group.start_download_file(info, str(dest)) + assert dest.exists() + assert dest.read_bytes() == data + + def test_result_has_correct_file_size(self, endpoint, tmp_path): + data = b"result check" + info = upload_bytes_get_info(endpoint, data) + dest = tmp_path / "out.bin" + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + h = group.start_download_file(info, str(dest)) + group.wait_to_finish() + result = h.result() + assert result.file_info.file_size == len(data) + assert result.file_info.hash + + def test_task_id_matches_group_report(self, endpoint, tmp_path): + data = b"report match" + info = upload_bytes_get_info(endpoint, data) + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + h = group.start_download_file(info, str(tmp_path / "out.bin")) + report = group.wait_to_finish() + assert h.task_id() in report.downloads + + def test_multiple_files(self, endpoint, tmp_path): + payloads = [f"file {i}".encode() for i in range(3)] + infos = [upload_bytes_get_info(endpoint, d) for d in payloads] + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + handles = [group.start_download_file(info, str(tmp_path / f"out{i}.bin")) + for i, info in enumerate(infos)] + group.wait_to_finish() + for h, data in zip(handles, payloads): + assert h.result().file_info.file_size == len(data) + + def test_round_trip_via_upload_file(self, endpoint, tmp_path): + payloads = [f"upload_file content {i}".encode() for i in range(3)] + infos = [upload_file_get_info(endpoint, tmp_path, d) for d in payloads] + dests = [tmp_path / f"out{i}.bin" for i in range(len(payloads))] + with hf_xet.XetSession().new_file_download_group(endpoint=endpoint) as group: + for info, dest in zip(infos, dests): + group.start_download_file(info, str(dest)) + for dest, data in zip(dests, payloads): + assert dest.read_bytes() == data + + def test_round_trip_via_upload_stream(self, endpoint, tmp_path): + payloads = [f"upload_stream content {i}".encode() for i in range(3)] + infos = [upload_stream_get_info(endpoint, d) for d in payloads] + dests = [tmp_path / f"out{i}.bin" for i in range(len(payloads))] + with hf_xet.XetSession().new_file_download_group(endpoint=endpoint) as group: + for info, dest in zip(infos, dests): + group.start_download_file(info, str(dest)) + for dest, data in zip(dests, payloads): + assert dest.read_bytes() == data + + def test_status_is_valid_state(self, endpoint, tmp_path): + info = upload_bytes_get_info(endpoint, b"status check") + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + group.start_download_file(info, str(tmp_path / "out.bin")) + assert group.status() == hf_xet.XetTaskState.Running + group.wait_to_finish() + + def test_progress_returns_report(self, endpoint): + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + report = group.progress() + assert hasattr(report, "total_bytes_completed") + group.wait_to_finish() + + def 
test_abort_makes_finish_fail(self, endpoint): + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + group.abort() + try: + group.wait_to_finish() + assert False, "expected finish to raise after abort" + except Exception: + pass + + def test_context_manager_calls_abort_on_exception(self, endpoint, tmp_path): + info = upload_bytes_get_info(endpoint, b"abort download") + raised = False + try: + with hf_xet.XetSession().new_file_download_group(endpoint=endpoint) as group: + group.start_download_file(info, str(tmp_path / "out.bin")) + raise RuntimeError("intentional error") + except RuntimeError: + raised = True + assert raised + + +# ── XetFileDownload handle ──────────────────────────────────────────────────── + +class TestFileDownloadHandle: + def test_try_result_after_finish_is_not_none(self, endpoint, tmp_path): + info = upload_bytes_get_info(endpoint, b"try result data") + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + h = group.start_download_file(info, str(tmp_path / "out.bin")) + group.wait_to_finish() + result = h.try_result() + assert result is not None + assert result.file_info.file_size == len(b"try result data") + + def test_cancel_does_not_raise(self, endpoint, tmp_path): + info = upload_bytes_get_info(endpoint, b"cancel target") + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + h = group.start_download_file(info, str(tmp_path / "out.bin")) + h.cancel() # should not raise; download may or may not have completed + + def test_status_is_valid_state(self, endpoint, tmp_path): + info = upload_bytes_get_info(endpoint, b"status data") + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + h = group.start_download_file(info, str(tmp_path / "out.bin")) + group.wait_to_finish() + assert h.status() == hf_xet.XetTaskState.Completed + + def test_task_id_is_not_none(self, endpoint, tmp_path): + info = upload_bytes_get_info(endpoint, b"task id data") + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + h = group.start_download_file(info, str(tmp_path / "out.bin")) + assert h.task_id() is not None + group.wait_to_finish() + + diff --git a/hf_xet/tests/test_progress.py b/hf_xet/tests/test_progress.py new file mode 100644 index 000000000..1e384c547 --- /dev/null +++ b/hf_xet/tests/test_progress.py @@ -0,0 +1,126 @@ +""" +Tests for progress callbacks on XetUploadCommit and XetFileDownloadGroup. 
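+
+A progress callback is any two-argument callable, for example:
+
+    def on_progress(group_report, item_reports):
+        print(group_report.total_bytes_completed, "/", group_report.total_bytes)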
+ +The callback receives: + arg1: GroupProgressReport — aggregate bytes (total_bytes, total_bytes_completed, …) + arg2: dict[UniqueID, ItemProgressReport] — per-file (item_name, bytes_completed, total_bytes) +""" + +import threading + +import hf_xet +from conftest import upload_bytes_get_info + + +# ── Upload progress ─────────────────────────────────────────────────────────── + +class TestUploadProgressCallback: + def test_callback_receives_correct_argument_types(self, endpoint): + calls = [] + fired = threading.Event() + + def on_progress(group, items): + calls.append((group, items)) + fired.set() + + (hf_xet.XetSession() + .new_upload_commit(endpoint=endpoint, progress_callback=on_progress, progress_interval_ms=10) + .wait_to_finish()) + + fired.wait(timeout=1.0) + assert len(calls) > 0 + group_report, item_reports = calls[0] + assert hasattr(group_report, "total_bytes_completed") + assert hasattr(group_report, "total_bytes") + assert isinstance(item_reports, dict) + + def test_item_names_match_uploaded_names(self, endpoint): + seen_names = set() + fired = threading.Event() + + def on_progress(_, items): + for item in items.values(): + seen_names.add(item.item_name) + fired.set() + + commit = hf_xet.XetSession().new_upload_commit( + endpoint=endpoint, progress_callback=on_progress, progress_interval_ms=10 + ) + commit.start_upload_bytes(b"file a", name="a.bin", sha256=hf_xet.SKIP_SHA256) + commit.start_upload_bytes(b"file b", name="b.bin", sha256=hf_xet.SKIP_SHA256) + commit.wait_to_finish() + + fired.wait(timeout=1.0) + assert "a.bin" in seen_names + assert "b.bin" in seen_names + + def test_item_progress_has_expected_fields(self, endpoint): + items_seen = [] + fired = threading.Event() + + def on_progress(_, items): + for item in items.values(): + items_seen.append(item) + fired.set() + + commit = hf_xet.XetSession().new_upload_commit( + endpoint=endpoint, progress_callback=on_progress, progress_interval_ms=10 + ) + commit.start_upload_bytes(b"progress fields", name="p.bin", sha256=hf_xet.SKIP_SHA256) + commit.wait_to_finish() + + fired.wait(timeout=1.0) + assert len(items_seen) > 0 + item = items_seen[0] + assert hasattr(item, "item_name") + assert hasattr(item, "bytes_completed") + assert hasattr(item, "total_bytes") + + +# ── Download progress ───────────────────────────────────────────────────────── + +class TestDownloadProgressCallback: + def test_callback_receives_correct_argument_types(self, endpoint, tmp_path): + data = b"progress download content" + info = upload_bytes_get_info(endpoint, data) + dest = tmp_path / "out.bin" + + calls = [] + fired = threading.Event() + + def on_progress(group, items): + calls.append((group, items)) + fired.set() + + with hf_xet.XetSession().new_file_download_group( + endpoint=endpoint, progress_callback=on_progress, progress_interval_ms=10 + ) as group: + group.start_download_file(info, str(dest)) + + fired.wait(timeout=1.0) + assert len(calls) > 0 + group_report, item_reports = calls[0] + assert hasattr(group_report, "total_bytes_completed") + assert hasattr(group_report, "total_bytes") + assert isinstance(item_reports, dict) + + def test_item_count_matches_file_count(self, endpoint, tmp_path): + payloads = {f"f{i}.bin": f"content {i}".encode() for i in range(3)} + infos = {name: upload_bytes_get_info(endpoint, data) for name, data in payloads.items()} + + seen_names = set() + fired = threading.Event() + + def on_progress(_, items): + for item in items.values(): + seen_names.add(item.item_name) + fired.set() + + with 
hf_xet.XetSession().new_file_download_group( + endpoint=endpoint, progress_callback=on_progress, progress_interval_ms=10 + ) as group: + for name, info in infos.items(): + group.start_download_file(info, str(tmp_path / name)) + + fired.wait(timeout=1.0) + assert len(seen_names) == len(payloads) diff --git a/hf_xet/tests/test_session.py b/hf_xet/tests/test_session.py new file mode 100644 index 000000000..7d439e136 --- /dev/null +++ b/hf_xet/tests/test_session.py @@ -0,0 +1,36 @@ +""" +Tests for XetSession itself: status, sigint_abort, and factory methods. +""" + +import hf_xet + + +class TestXetSession: + def test_status_returns_valid_state(self): + session = hf_xet.XetSession() + assert session.status() == hf_xet.XetTaskState.Running + + def test_sigint_abort_does_not_raise(self): + # Uses a dedicated session; sigint_abort shuts down the internal runtime. + session = hf_xet.XetSession() + session.sigint_abort() # should not raise + + def test_session_with_config_applies_overrides(self, endpoint): + config = hf_xet.XetConfig().with_config("data.max_concurrent_file_ingestion", 2) + session = hf_xet.XetSession(config=config) + assert session.status() == hf_xet.XetTaskState.Running + # Verify the session is usable with the custom config. + commit = session.new_upload_commit(endpoint=endpoint) + assert commit is not None + + def test_new_upload_commit_creates_commit(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + assert commit is not None + + def test_new_file_download_group_creates_group(self, endpoint): + group = hf_xet.XetSession().new_file_download_group(endpoint=endpoint) + assert group is not None + + def test_new_download_stream_group_creates_group(self, endpoint): + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + assert group is not None diff --git a/hf_xet/tests/test_stream_download.py b/hf_xet/tests/test_stream_download.py new file mode 100644 index 000000000..da7d38289 --- /dev/null +++ b/hf_xet/tests/test_stream_download.py @@ -0,0 +1,218 @@ +""" +Tests for XetDownloadStreamGroup: ordered and unordered streaming downloads. + +_LARGE_SIZE = 300 KB is used throughout to exercise the full reconstruction +path while keeping tests fast. + +Covers: + - Full-file ordered stream (small and large) + - Bounded range (start + end) on large files + - Open-ended range (start only, end only) on large files + - cancel() + - Multiple concurrent streams from the same group + - Full-file unordered stream (reassemble from offsets), small and large + - Bounded range unordered on large files + - Open-ended range unordered on large files +Not covered here (require a real CAS server): + - token, token_refresh_url, custom_headers kwargs +""" + +import pytest + +import hf_xet +from conftest import upload_bytes_get_info + + +# ── Shared data ─────────────────────────────────────────────────────────────── + +DATA = b"0123456789abcdef" # 16 bytes — known content for slice assertions + +_LARGE_SIZE = 300 * 1024 # 300 KB — well above the 128 KB max chunk size +# Deterministic byte pattern: 0x00..0xFF repeating, so the content varies +# throughout the file and won't collapse into a single deduplicated chunk. +_LARGE_DATA = (bytes(range(256)) * (_LARGE_SIZE // 256 + 1))[:_LARGE_SIZE] + +# Offsets chosen to land strictly inside the file and well away from both ends, +# so byte-range slicing is exercised without needing multiple stream yields. 
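+# The bounded range below ([_RANGE_START, _RANGE_END)) therefore covers
+# 200_000 bytes of the 300 KB file and crosses multiple chunk boundaries.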
+_RANGE_START = 50_000 +_RANGE_END = 250_000 + + +@pytest.fixture(scope="module") +def large_file_endpoint(tmp_path_factory): + """Upload _LARGE_DATA once per module and return (XetFileInfo, endpoint).""" + tmp = tmp_path_factory.mktemp("large_stream") + ep = f"local://{tmp / 'cas'}" + info = upload_bytes_get_info(ep, _LARGE_DATA) + return info, ep + + +# ── XetDownloadStreamGroup (ordered) ───────────────────────────────────────── + +class TestDownloadStream: + # ── small-file correctness ──────────────────────────────────────────────── + + def test_full_file_reassembles(self, endpoint): + data = b"ordered stream content" + info = upload_bytes_get_info(endpoint, data) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + chunks = list(group.download_stream(info)) + assert b"".join(chunks) == data + + def test_bounded_range(self, endpoint): + info = upload_bytes_get_info(endpoint, DATA) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + chunks = list(group.download_stream(info, start=4, end=12)) + assert b"".join(chunks) == DATA[4:12] + + def test_open_ended_start(self, endpoint): + """start=N with no end streams from N to EOF.""" + info = upload_bytes_get_info(endpoint, DATA) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + chunks = list(group.download_stream(info, start=8)) + assert b"".join(chunks) == DATA[8:] + + def test_open_ended_end(self, endpoint): + """end=N with no start streams from 0 to N.""" + info = upload_bytes_get_info(endpoint, DATA) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + chunks = list(group.download_stream(info, end=8)) + assert b"".join(chunks) == DATA[:8] + + def test_cancel_stops_iteration(self, endpoint): + data = b"cancel me" + info = upload_bytes_get_info(endpoint, data) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + stream = group.download_stream(info) + stream.cancel() + chunks = list(stream) + assert isinstance(chunks, list) # no crash; may be empty or partial + + def test_multiple_concurrent_streams_same_group(self, endpoint): + """The same group object can serve multiple independent streams.""" + data_a = b"stream A content" + data_b = b"stream B content" + info_a = upload_bytes_get_info(endpoint, data_a) + info_b = upload_bytes_get_info(endpoint, data_b) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + result_a = b"".join(group.download_stream(info_a)) + result_b = b"".join(group.download_stream(info_b)) + assert result_a == data_a + assert result_b == data_b + + # ── large-file multi-chunk paths ────────────────────────────────────────── + + def test_large_file_full_reassembles(self, large_file_endpoint): + """300 KB file spans multiple chunks; ordered stream must reassemble all bytes.""" + info, ep = large_file_endpoint + group = hf_xet.XetSession().new_download_stream_group(endpoint=ep) + result = b"".join(group.download_stream(info)) + assert result == _LARGE_DATA + + def test_large_file_bounded_range(self, large_file_endpoint): + """Range [50 000, 250 000] on a 300 KB file; must return exact slice.""" + info, ep = large_file_endpoint + group = hf_xet.XetSession().new_download_stream_group(endpoint=ep) + result = b"".join(group.download_stream(info, start=_RANGE_START, end=_RANGE_END)) + assert result == _LARGE_DATA[_RANGE_START:_RANGE_END] + + def test_large_file_open_ended_start(self, large_file_endpoint): + """start=N on a large file streams from N to EOF.""" + info, ep = 
large_file_endpoint + group = hf_xet.XetSession().new_download_stream_group(endpoint=ep) + result = b"".join(group.download_stream(info, start=_RANGE_START)) + assert result == _LARGE_DATA[_RANGE_START:] + + def test_large_file_open_ended_end(self, large_file_endpoint): + """end=N on a large file streams from 0 to N.""" + info, ep = large_file_endpoint + group = hf_xet.XetSession().new_download_stream_group(endpoint=ep) + result = b"".join(group.download_stream(info, end=_RANGE_END)) + assert result == _LARGE_DATA[:_RANGE_END] + + +# ── XetDownloadStreamGroup (unordered) ─────────────────────────────────────── + +class TestDownloadUnorderedStream: + def _reassemble(self, chunks_iter, total: int) -> bytes: + buf = bytearray(total) + for offset, chunk in chunks_iter: + buf[offset:offset + len(chunk)] = chunk + return bytes(buf) + + def _reassemble_range(self, chunks_iter) -> bytes: + pieces = {} + for offset, chunk in chunks_iter: + pieces[offset] = chunk + return b"".join(pieces[k] for k in sorted(pieces)) + + # ── small-file correctness ──────────────────────────────────────────────── + + def test_full_file_reassembles(self, endpoint): + data = b"unordered stream content" + info = upload_bytes_get_info(endpoint, data) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + result = self._reassemble(group.download_unordered_stream(info), len(data)) + assert result == data + + def test_bounded_range(self, endpoint): + info = upload_bytes_get_info(endpoint, DATA) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + assembled = self._reassemble_range(group.download_unordered_stream(info, start=2, end=10)) + assert assembled == DATA[2:10] + + def test_open_ended_start(self, endpoint): + """start=N with no end streams from N to EOF.""" + info = upload_bytes_get_info(endpoint, DATA) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + assembled = self._reassemble_range(group.download_unordered_stream(info, start=8)) + assert assembled == DATA[8:] + + def test_open_ended_end(self, endpoint): + """end=N with no start streams from 0 to N.""" + info = upload_bytes_get_info(endpoint, DATA) + group = hf_xet.XetSession().new_download_stream_group(endpoint=endpoint) + assembled = self._reassemble_range(group.download_unordered_stream(info, end=8)) + assert assembled == DATA[:8] + + # ── large-file multi-chunk paths ────────────────────────────────────────── + + def test_large_file_full_reassembles(self, large_file_endpoint): + """300 KB unordered stream must reassemble to the original bytes.""" + info, ep = large_file_endpoint + group = hf_xet.XetSession().new_download_stream_group(endpoint=ep) + result = self._reassemble(group.download_unordered_stream(info), _LARGE_SIZE) + assert result == _LARGE_DATA + + def test_large_file_offsets_are_valid(self, large_file_endpoint): + """Every (offset, chunk) pair must lie within the file bounds.""" + info, ep = large_file_endpoint + group = hf_xet.XetSession().new_download_stream_group(endpoint=ep) + for offset, chunk in group.download_unordered_stream(info): + assert 0 <= offset < _LARGE_SIZE + assert offset + len(chunk) <= _LARGE_SIZE + + def test_large_file_bounded_range(self, large_file_endpoint): + """Range [50 000, 250 000] unordered on a 300 KB file must reassemble to the correct slice.""" + info, ep = large_file_endpoint + group = hf_xet.XetSession().new_download_stream_group(endpoint=ep) + assembled = self._reassemble_range( + group.download_unordered_stream(info, start=_RANGE_START, 
end=_RANGE_END) + ) + assert assembled == _LARGE_DATA[_RANGE_START:_RANGE_END] + + def test_large_file_open_ended_start(self, large_file_endpoint): + """start=N unordered on a large file streams N..EOF correctly.""" + info, ep = large_file_endpoint + group = hf_xet.XetSession().new_download_stream_group(endpoint=ep) + assembled = self._reassemble_range(group.download_unordered_stream(info, start=_RANGE_START)) + assert assembled == _LARGE_DATA[_RANGE_START:] + + def test_large_file_open_ended_end(self, large_file_endpoint): + """end=N unordered on a large file streams 0..N correctly.""" + info, ep = large_file_endpoint + group = hf_xet.XetSession().new_download_stream_group(endpoint=ep) + assembled = self._reassemble_range(group.download_unordered_stream(info, end=_RANGE_END)) + assert assembled == _LARGE_DATA[:_RANGE_END] + + diff --git a/hf_xet/tests/test_upload_commit.py b/hf_xet/tests/test_upload_commit.py new file mode 100644 index 000000000..a59d1b04f --- /dev/null +++ b/hf_xet/tests/test_upload_commit.py @@ -0,0 +1,224 @@ +""" +Tests for XetUploadCommit: upload_file, upload_bytes, upload_stream, sha256 sentinels. + +Not covered here (require a real CAS server): + - with_token_info / with_token_refresh_url / with_custom_headers on the builder +""" + +import hf_xet + + +# ── sha256 sentinels ────────────────────────────────────────────────────────── + +class TestSha256Sentinels: + def test_compute_sentinel_is_not_none(self): + assert hf_xet.COMPUTE_SHA256 is not None + + def test_skip_sentinel_is_not_none(self): + assert hf_xet.SKIP_SHA256 is not None + + def test_sentinels_have_repr(self): + assert repr(hf_xet.COMPUTE_SHA256) == "COMPUTE_SHA256" + assert repr(hf_xet.SKIP_SHA256) == "SKIP_SHA256" + + + +# ── upload_file ─────────────────────────────────────────────────────────────── + +class TestUploadFile: + def test_result_has_correct_file_size(self, endpoint, tmp_path): + data = b"upload_file content" + src = tmp_path / "src.bin" + src.write_bytes(data) + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_file(str(src), sha256=hf_xet.SKIP_SHA256) + commit.wait_to_finish() + result = h.result() + assert result.xet_info.file_size == len(data) + assert result.xet_info.hash + assert result.xet_info.sha256 is None + + def test_sha256_computed_for_file(self, endpoint, tmp_path): + src = tmp_path / "src.bin" + src.write_bytes(b"sha256 file") + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_file(str(src), sha256=hf_xet.COMPUTE_SHA256) + commit.wait_to_finish() + result = h.result() + assert result.xet_info.sha256 is not None + assert len(result.xet_info.sha256) == 64 + + def test_sha256_provided_as_string_for_file(self, endpoint, tmp_path): + src = tmp_path / "src.bin" + src.write_bytes(b"provided sha256 file") + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + precomputed = "b" * 64 + h = commit.start_upload_file(str(src), sha256=precomputed) + commit.wait_to_finish() + assert h.result().xet_info.sha256 == precomputed + + def test_try_result_after_commit(self, endpoint, tmp_path): + src = tmp_path / "src.bin" + src.write_bytes(b"try result") + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_file(str(src), sha256=hf_xet.SKIP_SHA256) + commit.wait_to_finish() + result = h.try_result() + assert result is not None + assert result.xet_info.file_size == len(b"try result") + + def test_task_id_is_positive(self, endpoint, tmp_path): + src = tmp_path 
/ "src.bin" + src.write_bytes(b"task id") + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_file(str(src), sha256=hf_xet.SKIP_SHA256) + commit.wait_to_finish() + assert h.task_id() is not None + + +# ── upload_bytes ────────────────────────────────────────────────────────────── + +class TestUploadBytes: + def test_result_has_correct_file_size(self, endpoint): + data = b"hello upload bytes" + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_bytes(data, name="f.bin", sha256=hf_xet.SKIP_SHA256) + commit.wait_to_finish() + result = h.result() + assert result.xet_info.file_size == len(data) + assert result.xet_info.hash + + def test_sha256_computed_when_requested(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_bytes(b"compute sha256", sha256=hf_xet.COMPUTE_SHA256) + commit.wait_to_finish() + result = h.result() + assert result.xet_info.sha256 is not None + assert len(result.xet_info.sha256) == 64 + + def test_sha256_provided_as_string(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + precomputed = "a" * 64 + h = commit.start_upload_bytes(b"provided sha256", sha256=precomputed) + commit.wait_to_finish() + assert h.result().xet_info.sha256 == precomputed + + def test_sha256_skipped_when_requested(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_bytes(b"skip sha256", sha256=hf_xet.SKIP_SHA256) + commit.wait_to_finish() + assert h.result().xet_info.sha256 is None + + def test_commit_report_contains_result(self, endpoint): + data = b"report content" + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + h = commit.start_upload_bytes(data, sha256=hf_xet.SKIP_SHA256) + report = commit.wait_to_finish() + result = report.uploads[h.task_id()] + assert result.xet_info.file_size == len(data) + assert result.xet_info.hash + + def test_multiple_files_in_one_commit(self, endpoint): + files = {f"f{i}.bin": f"content {i}".encode() for i in range(4)} + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + handles = {name: commit.start_upload_bytes(data, name=name, sha256=hf_xet.SKIP_SHA256) + for name, data in files.items()} + commit.wait_to_finish() + for name, h in handles.items(): + assert h.result().xet_info.file_size == len(files[name]) + + def test_status_is_valid_state(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + assert commit.status() == hf_xet.XetTaskState.Running + + def test_progress_returns_report(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + report = commit.progress() + assert hasattr(report, "total_bytes_completed") + + def test_abort_makes_commit_fail(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + commit.abort() + try: + commit.wait_to_finish() + assert False, "expected commit to raise after abort" + except Exception: + pass + + def test_context_manager_commits_on_normal_exit(self, endpoint): + with hf_xet.XetSession().new_upload_commit(endpoint=endpoint) as commit: + h = commit.start_upload_bytes(b"context manager", sha256=hf_xet.SKIP_SHA256) + result = h.result() + assert result.xet_info.file_size == len(b"context manager") + assert result.xet_info.hash + + def test_context_manager_aborts_on_exception(self, endpoint): + raised = False + try: + with 
hf_xet.XetSession().new_upload_commit(endpoint=endpoint) as commit: + commit.start_upload_bytes(b"will be aborted", sha256=hf_xet.SKIP_SHA256) + raise ValueError("intentional error") + except ValueError: + raised = True + assert raised # exception must propagate, not be suppressed + + +# ── upload_stream ───────────────────────────────────────────────────────────── + +class TestUploadStream: + def test_write_and_finish(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + stream = commit.start_upload_stream(name="stream.bin") + stream.write(b"hello ") + stream.write(b"world") + result = stream.finish() + assert result.xet_info.file_size == 11 + assert result.xet_info.hash + + def test_try_finish_before_finish_is_none(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + stream = commit.start_upload_stream() + assert stream.try_finish() is None + + def test_try_finish_after_finish(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + stream = commit.start_upload_stream() + stream.write(b"data") + stream.finish() + result = stream.try_finish() + assert result is not None + assert result.xet_info.file_size == 4 + + def test_multiple_chunks(self, endpoint): + data = b"chunk" * 200 + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + stream = commit.start_upload_stream(name="big.bin") + for i in range(0, len(data), 50): + stream.write(data[i:i + 50]) + result = stream.finish() + assert result.xet_info.file_size == len(data) + + def test_finish_must_precede_commit(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + stream = commit.start_upload_stream() + stream.write(b"abc") + stream.finish() + commit.wait_to_finish() # should not raise + + def test_status_while_open(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + stream = commit.start_upload_stream() + assert stream.status() == hf_xet.XetTaskState.Running + + def test_task_id_is_not_none(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + stream = commit.start_upload_stream() + assert stream.task_id() is not None + + def test_abort_before_finish(self, endpoint): + commit = hf_xet.XetSession().new_upload_commit(endpoint=endpoint) + stream = commit.start_upload_stream() + stream.write(b"to be aborted") + stream.abort() # should not raise + + diff --git a/xet_data/src/deduplication/dedup_metrics.rs b/xet_data/src/deduplication/dedup_metrics.rs index 461eebdb6..219eda12b 100644 --- a/xet_data/src/deduplication/dedup_metrics.rs +++ b/xet_data/src/deduplication/dedup_metrics.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; #[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)] +#[cfg_attr(feature = "python", pyo3::pyclass(get_all))] pub struct DeduplicationMetrics { pub total_bytes: u64, pub deduped_bytes: u64, diff --git a/xet_data/src/processing/xet_file.rs b/xet_data/src/processing/xet_file.rs index f3bc13dc0..9658ebf42 100644 --- a/xet_data/src/processing/xet_file.rs +++ b/xet_data/src/processing/xet_file.rs @@ -4,6 +4,7 @@ use xet_runtime::error_printer::ErrorPrinter; /// A struct that wraps a the Xet file information. 
 #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
+#[cfg_attr(feature = "python", pyo3::pyclass(get_all))]
 pub struct XetFileInfo {
     /// The Merkle hash of the file
     pub hash: String,
@@ -17,6 +18,21 @@ pub struct XetFileInfo {
     pub sha256: Option<String>,
 }
 
+#[cfg_attr(feature = "python", pyo3::pymethods)]
+impl XetFileInfo {
+    /// Python constructor: ``XetFileInfo(hash, file_size=None)``
+    #[cfg(feature = "python")]
+    #[new]
+    #[pyo3(signature = (hash, file_size=None))]
+    fn py_new(hash: String, file_size: Option<u64>) -> Self {
+        Self {
+            hash,
+            file_size,
+            sha256: None,
+        }
+    }
+}
+
 impl XetFileInfo {
     /// Creates a new `XetFileInfo` instance with a known size.
     ///
diff --git a/xet_pkg/Cargo.toml b/xet_pkg/Cargo.toml
index 77d7a8fdd..b5f1abdef 100644
--- a/xet_pkg/Cargo.toml
+++ b/xet_pkg/Cargo.toml
@@ -46,8 +46,13 @@ uuid = { workspace = true, features = ["v7"] }
 [features]
 smoke-test = []
 fd-track = ["xet-runtime/fd-track", "xet-client/fd-track", "xet-data/fd-track"]
-python = ["xet-runtime/python", "dep:pyo3"]
+python = ["xet-runtime/python", "xet-data/python", "dep:pyo3"]
 simulation = ["xet-client/simulation"]
+native-tls = ["xet-client/native-tls"]
+native-tls-vendored = ["xet-client/native-tls-vendored"]
+elevated_information_level = ["xet-client/elevated_information_level", "xet-runtime/elevated_information_level"]
+no-default-cache = ["xet-runtime/no-default-cache"]
+tokio-console = ["xet-runtime/tokio-console"]
 
 [dev-dependencies]
 anyhow = { workspace = true }
diff --git a/xet_pkg/examples/example.rs b/xet_pkg/examples/example.rs
index 09752a66e..eb8fcf03e 100644
--- a/xet_pkg/examples/example.rs
+++ b/xet_pkg/examples/example.rs
@@ -138,11 +138,7 @@ async fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: O
     let report = group.finish().await?;
 
     for r in report.downloads.values() {
-        println!(
-            " {} ({:?} bytes)",
-            r.path.as_ref().map_or("?".into(), |p| p.display().to_string()),
-            r.file_info.file_size
-        );
+        println!(" {} ({:?} bytes)", r.path.display(), r.file_info.file_size);
     }
 
     Ok(())
diff --git a/xet_pkg/examples/example_sync.rs b/xet_pkg/examples/example_sync.rs
index f26b09761..9627d5900 100644
--- a/xet_pkg/examples/example_sync.rs
+++ b/xet_pkg/examples/example_sync.rs
@@ -137,11 +137,7 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<
     let report = group.finish_blocking()?;
 
     for r in report.downloads.values() {
-        println!(
-            " {} ({:?} bytes)",
-            r.path.as_ref().map_or("?".into(), |p| p.display().to_string()),
-            r.file_info.file_size
-        );
+        println!(" {} ({:?} bytes)", r.path.display(), r.file_info.file_size);
     }
 
     Ok(())
diff --git a/xet_pkg/src/lib.rs b/xet_pkg/src/lib.rs
index 3f208e466..ea6f44b9b 100644
--- a/xet_pkg/src/lib.rs
+++ b/xet_pkg/src/lib.rs
@@ -57,3 +57,21 @@ pub use error::{XetAuthenticationError, XetObjectNotFoundError, register_excepti
 // and `git_xet`. New code should use the [`xet_session`] API instead.
 pub mod legacy;
 pub mod xet_session;
+
+/// Initialize the global tracing subscriber using xet_runtime defaults.
+///
+/// Reads `HF_XET_LOG_FILE` / `RUST_LOG` environment variables. Repeated calls
+/// are no-ops — the global subscriber is installed only once.
+pub fn init_logging(version_info: String) {
+    let log_dir = xet_runtime::core::xet_cache_root().join("logs");
+
+    // Called before any XetContext is created, so we use a standalone default config for
+    // early-init logging setup.
+    let cfg = xet_runtime::logging::LoggingConfig::from_directory(
+        &xet_runtime::config::XetConfig::new(),
+        version_info,
+        log_dir,
+    );
+
+    xet_runtime::logging::init(cfg);
+}
diff --git a/xet_pkg/src/xet_session/download_stream_group.rs b/xet_pkg/src/xet_session/download_stream_group.rs
index 58f5002eb..fa5a70b7f 100644
--- a/xet_pkg/src/xet_session/download_stream_group.rs
+++ b/xet_pkg/src/xet_session/download_stream_group.rs
@@ -371,7 +371,7 @@ mod tests {
         let group = stream_group_async(&session, &endpoint).await;
         let mut stream = group.download_stream(file_info, None).await.unwrap();
 
-        let initial = stream.progress();
+        let initial = stream.progress().unwrap();
         assert_eq!(initial.total_bytes, original.len() as u64);
         assert_eq!(initial.bytes_completed, 0);
 
@@ -381,7 +381,7 @@ mod tests {
         }
         assert_eq!(collected, original);
 
-        let final_progress = stream.progress();
+        let final_progress = stream.progress().unwrap();
         assert_eq!(final_progress.total_bytes, original.len() as u64);
         assert_eq!(final_progress.bytes_completed, original.len() as u64);
     }
@@ -404,7 +404,7 @@ mod tests {
         }
         assert_eq!(collected, original);
 
-        let final_progress = stream.progress();
+        let final_progress = stream.progress().unwrap();
         assert_eq!(final_progress.total_bytes, original.len() as u64);
         assert_eq!(final_progress.bytes_completed, original.len() as u64);
     }
diff --git a/xet_pkg/src/xet_session/download_stream_handle.rs b/xet_pkg/src/xet_session/download_stream_handle.rs
index b850d78eb..3d6d25f9e 100644
--- a/xet_pkg/src/xet_session/download_stream_handle.rs
+++ b/xet_pkg/src/xet_session/download_stream_handle.rs
@@ -88,15 +88,19 @@ impl XetDownloadStream {
         self.inner.cancel();
     }
 
-    /// Returns a snapshot of this stream's download progress.
+    /// Returns the unique task ID for this stream.
+    pub fn task_id(&self) -> UniqueID {
+        self.id
+    }
+
+    /// Returns a snapshot of this stream's download progress, or `None` if
+    /// the progress item is not yet available.
     ///
     /// The returned [`ItemProgressReport`] contains the item name,
     /// total bytes, and bytes completed so far. This method is lock-free
    /// (reads atomic counters) and safe to call from any thread.
-    pub fn progress(&self) -> ItemProgressReport {
-        self.download_session
-            .item_report(self.id)
-            .expect("progress item was registered at stream creation and is never removed")
+    pub fn progress(&self) -> Option<ItemProgressReport> {
+        self.download_session.item_report(self.id)
+    }
 }
@@ -186,15 +190,19 @@ impl XetUnorderedDownloadStream {
         self.inner.cancel();
     }
 
-    /// Returns a snapshot of this stream's download progress.
+    /// Returns the unique task ID for this stream.
+    pub fn task_id(&self) -> UniqueID {
+        self.id
+    }
+
+    /// Returns a snapshot of this stream's download progress, or `None` if
+    /// the progress item is not yet available.
     ///
     /// The returned [`ItemProgressReport`] contains the item name,
     /// total bytes, and bytes completed so far. This method is lock-free
     /// (reads atomic counters) and safe to call from any thread.
-    pub fn progress(&self) -> ItemProgressReport {
-        self.download_session
-            .item_report(self.id)
-            .expect("progress item was registered at stream creation and is never removed")
+    pub fn progress(&self) -> Option<ItemProgressReport> {
+        self.download_session.item_report(self.id)
+    }
 }
diff --git a/xet_pkg/src/xet_session/file_download_group.rs b/xet_pkg/src/xet_session/file_download_group.rs
index 83b8f8094..52a64878f 100644
--- a/xet_pkg/src/xet_session/file_download_group.rs
+++ b/xet_pkg/src/xet_session/file_download_group.rs
@@ -6,7 +6,7 @@ use std::sync::{Arc, RwLock};
 use tracing::info;
 use xet_data::processing::{FileDownloadSession, XetFileInfo};
-use xet_data::progress_tracking::{GroupProgressReport, UniqueID};
+use xet_data::progress_tracking::{GroupProgressReport, ItemProgressReport, UniqueID};
 
 use super::auth_group_builder::{AuthGroupBuilder, AuthOptions};
 use super::common::create_translator_config;
@@ -67,6 +67,7 @@ impl AuthGroupBuilder {
 /// Contains final progress and per-file results keyed by [`UniqueID`].
 /// Only created when all downloads succeed; any failure propagates as an error.
 #[derive(Clone, Debug)]
+#[cfg_attr(feature = "python", pyo3::pyclass(get_all))]
 pub struct XetDownloadGroupReport {
     /// Final progress snapshot at the time the group finished.
     pub progress: GroupProgressReport,
@@ -74,6 +75,39 @@ pub struct XetDownloadGroupReport {
     pub downloads: HashMap<UniqueID, XetDownloadReport>,
 }
 
+#[cfg(feature = "python")]
+#[pyo3::pymethods]
+impl XetDownloadGroupReport {
+    // Example output:
+    //   XetDownloadGroupReport(files=2, bytes_completed=5120/8192,
+    //     downloads=[(3, "/tmp/model.bin", bytes_completed=4096/4096), (4, "/tmp/data.bin", bytes_completed=?/?)])
+    //   XetDownloadGroupReport(files=0, bytes_completed=0/0, downloads=[])
+    //
+    // Each download entry is (task_id, dest_path, bytes_completed/total_bytes).
+    // Progress shows "?/?" when no snapshot was captured.
+    fn __repr__(&self) -> String {
+        let per_file: Vec<String> = self
+            .downloads
+            .iter()
+            .map(|(id, r)| {
+                let path = r.path.display();
+                let prog = match &r.progress {
+                    Some(p) => format!("{}/{}", p.bytes_completed, p.total_bytes),
+                    None => "?/?".to_string(),
+                };
+                format!("({id}, \"{path}\", bytes_completed={prog})")
+            })
+            .collect();
+        format!(
+            "XetDownloadGroupReport(files={}, bytes_completed={}/{}, downloads=[{}])",
+            self.downloads.len(),
+            self.progress.total_bytes_completed,
+            self.progress.total_bytes,
+            per_file.join(", ")
+        )
+    }
+}
+
 /// API for grouping related file downloads into a single unit of work.
 ///
 /// Obtain via [`XetSession::new_file_download_group`] — configure per-group
@@ -195,6 +229,18 @@ impl XetFileDownloadGroup {
         self.task_runtime.status()
     }
 
+    /// Return `(task_id, dest_path, progress)` for every queued file download.
+    ///
+    /// `progress` is `None` if the download has not started reporting yet.
+    /// Used for display and diagnostics (e.g. `__repr__`).
+    pub fn active_download_info(&self) -> Vec<(UniqueID, PathBuf, Option<ItemProgressReport>)> {
+        self.inner
+            .active_tasks
+            .read()
+            .map(|tasks| tasks.values().map(|h| (h.task_id(), h.file_path(), h.progress())).collect())
+            .unwrap_or_default()
+    }
+
     /// Wait for all downloads to complete and return a report.
     ///
     /// Returns an [`XetDownloadGroupReport`] with per-file
@@ -306,7 +352,7 @@ impl XetFileDownloadGroupInner {
         match join_result {
             Ok(Ok(n_bytes)) => Ok(XetDownloadReport {
                 task_id,
-                path: Some(dp),
+                path: dp,
                 file_info: XetFileInfo {
                     hash: fi.hash,
                     file_size: Some(n_bytes),
diff --git a/xet_pkg/src/xet_session/file_download_handle.rs b/xet_pkg/src/xet_session/file_download_handle.rs
index c110e12ed..5f99e5ba6 100644
--- a/xet_pkg/src/xet_session/file_download_handle.rs
+++ b/xet_pkg/src/xet_session/file_download_handle.rs
@@ -13,17 +13,31 @@ use crate::error::XetError;
 /// Per-file download result returned by
 /// [`XetFileDownloadGroup::finish`](crate::xet_session::XetFileDownloadGroup::finish).
 #[derive(Clone, Debug)]
+#[cfg_attr(feature = "python", pyo3::pyclass(get_all))]
 pub struct XetDownloadReport {
     /// Unique identifier for this download task.
     pub task_id: UniqueID,
-    /// Local path where the file was written, if applicable.
-    pub path: Option<PathBuf>,
+    /// Local path where the file was written.
+    pub path: PathBuf,
     /// Xet file hash and size of the downloaded file.
     pub file_info: XetFileInfo,
     /// Per-file progress snapshot at the time of completion.
     pub progress: Option<ItemProgressReport>,
 }
 
+#[cfg(feature = "python")]
+#[pyo3::pymethods]
+impl XetDownloadReport {
+    fn __repr__(&self) -> String {
+        format!(
+            "XetDownloadReport(task_id={}, hash={:?}, path={:?})",
+            self.task_id,
+            self.file_info.hash,
+            self.path.display()
+        )
+    }
+}
+
 // ── XetFileDownloadInner ────────────────────────────────────────────────────
 
 pub(super) struct XetFileDownloadInner {
@@ -42,6 +56,9 @@ pub(super) struct XetFileDownloadInner {
 /// [`XetFileDownloadGroup::download_file_to_path`](crate::xet_session::XetFileDownloadGroup::download_file_to_path).
 /// Use [`finish`](Self::finish) to wait for completion or
 /// [`result`](Self::result) to poll without blocking.
+///
+/// Cloning is cheap — all clones share the same underlying state via `Arc`.
+#[derive(Clone)]
 pub struct XetFileDownload {
     pub(super) inner: Arc<XetFileDownloadInner>,
     pub(super) task_runtime: Arc,
diff --git a/xet_pkg/src/xet_session/session.rs b/xet_pkg/src/xet_session/session.rs
index ad03b4e3a..829d16c44 100644
--- a/xet_pkg/src/xet_session/session.rs
+++ b/xet_pkg/src/xet_session/session.rs
@@ -368,9 +368,13 @@ impl XetSession {
         Ok(())
     }
 
-    pub(super) fn id(&self) -> &Uuid {
+    pub fn id(&self) -> &Uuid {
         &self.inner.id
     }
+
+    pub fn config(&self) -> &XetConfig {
+        &self.inner.ctx.config
+    }
 }
 
 #[cfg(test)]
@@ -660,7 +664,7 @@ mod tests {
             .download_stream(file_info, None)
             .await
            .unwrap();
-        let initial = stream.progress();
+        let initial = stream.progress().unwrap();
         assert_eq!(initial.total_bytes, original.len() as u64);
         assert_eq!(initial.bytes_completed, 0);
 
@@ -670,7 +674,7 @@ mod tests {
         }
         assert_eq!(collected, original);
 
-        let final_progress = stream.progress();
+        let final_progress = stream.progress().unwrap();
         assert_eq!(final_progress.total_bytes, original.len() as u64);
         assert_eq!(final_progress.bytes_completed, original.len() as u64);
     }
@@ -699,7 +703,7 @@ mod tests {
         }
         assert_eq!(collected, original);
 
-        let final_progress = stream.progress();
+        let final_progress = stream.progress().unwrap();
         assert_eq!(final_progress.total_bytes, original.len() as u64);
         assert_eq!(final_progress.bytes_completed, original.len() as u64);
     }
diff --git a/xet_pkg/src/xet_session/upload_commit.rs b/xet_pkg/src/xet_session/upload_commit.rs
index 8fa34b28e..57373da9e 100644
--- a/xet_pkg/src/xet_session/upload_commit.rs
+++ b/xet_pkg/src/xet_session/upload_commit.rs
@@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex, OnceLock};
 use tracing::{error, info};
 use xet_data::deduplication::DeduplicationMetrics;
 use xet_data::processing::{FileUploadSession, Sha256Policy, XetFileInfo};
-use xet_data::progress_tracking::{GroupProgressReport, UniqueID};
+use xet_data::progress_tracking::{GroupProgressReport, ItemProgressReport, UniqueID};
 
 use super::auth_group_builder::{AuthGroupBuilder, AuthOptions};
 use super::common::create_translator_config;
@@ -71,6 +71,7 @@ impl AuthGroupBuilder {
 /// [`XetFileMetadata`] for every file that was successfully ingested,
 /// keyed by [`UniqueID`].
 #[derive(Clone, Debug)]
+#[cfg_attr(feature = "python", pyo3::pyclass(get_all))]
 pub struct XetCommitReport {
     /// Aggregate deduplication metrics across all files in this commit.
     pub dedup_metrics: DeduplicationMetrics,
@@ -80,9 +81,33 @@ pub struct XetCommitReport {
     pub uploads: HashMap<UniqueID, XetFileMetadata>,
 }
 
+#[cfg(feature = "python")]
+#[pyo3::pymethods]
+impl XetCommitReport {
+    fn __repr__(&self) -> String {
+        let per_file: Vec<String> = self
+            .uploads
+            .iter()
+            .map(|(id, m)| {
+                let name = m.tracking_name.as_deref().unwrap_or("None");
+                let size = m.xet_info.file_size.map_or("?".to_string(), |s| s.to_string());
+                format!("({id}, \"{name}\", hash=\"{}\", size={size})", m.xet_info.hash)
+            })
+            .collect();
+        format!(
+            "XetCommitReport(files={}, total_bytes={}, deduped_bytes={}, uploads=[{}])",
+            self.uploads.len(),
+            self.dedup_metrics.total_bytes,
+            self.dedup_metrics.deduped_bytes,
+            per_file.join(", ")
+        )
+    }
+}
+
 /// Per-file metadata returned by [`XetFileUpload::finalize_ingestion`] and
 /// [`XetStreamUpload::finish`].
 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
+#[cfg_attr(feature = "python", pyo3::pyclass(get_all))]
 pub struct XetFileMetadata {
     /// Unique identifier for the task that produced this metadata.
     #[serde(skip)]
@@ -96,6 +121,19 @@ pub struct XetFileMetadata {
     pub tracking_name: Option<String>,
 }
 
+#[cfg(feature = "python")]
+#[pyo3::pymethods]
+impl XetFileMetadata {
+    fn __repr__(&self) -> String {
+        let name = self.tracking_name.as_deref().unwrap_or("None");
+        let size = self.xet_info.file_size.map_or("?".to_string(), |s| s.to_string());
+        format!(
+            "XetFileMetadata(task_id={}, name=\"{}\", hash=\"{}\", size={})",
+            self.task_id, name, self.xet_info.hash, size
+        )
+    }
+}
+
 // ── XetUploadCommitInner ────────────────────────────────────────────────────
 
 pub(super) struct XetUploadCommitInner {
@@ -143,10 +181,7 @@ impl XetUploadCommitInner {
             task_runtime,
         };
 
-        self.stream_handles
-            .lock()
-            .expect("stream_handles lock poisoned")
-            .push(handle.clone());
+        self.stream_handles.lock()?.push(handle.clone());
 
         Ok(handle)
     }
@@ -208,13 +243,10 @@ impl XetUploadCommitInner {
             }),
         });
 
-        self.file_handles
-            .lock()
-            .expect("file_handles lock poisoned")
-            .push(XetFileUpload {
-                inner: inner.clone(),
-                task_runtime: task_runtime.clone(),
-            });
+        self.file_handles.lock()?.push(XetFileUpload {
+            inner: inner.clone(),
+            task_runtime: task_runtime.clone(),
+        });
 
         Ok(XetFileUpload { inner, task_runtime })
     }
@@ -435,6 +467,19 @@ impl XetUploadCommit {
         self.task_runtime.status()
     }
 
+    /// Return `(task_id, file_path, progress)` for every queued file upload.
+    ///
+    /// `file_path` is `Some` for path/bytes uploads and `None` for stream uploads.
+    /// `progress` is `None` if the upload has not started reporting yet.
+    /// Used for display and diagnostics (e.g. `__repr__`).
+    pub fn active_upload_info(&self) -> Vec<(UniqueID, Option<PathBuf>, Option<ItemProgressReport>)> {
+        self.inner
+            .file_handles
+            .lock()
+            .map(|handles| handles.iter().map(|h| (h.task_id(), h.file_path(), h.progress())).collect())
+            .unwrap_or_default()
+    }
+
     /// Wait for all uploads to complete and push metadata to the CAS server.
     ///
     /// Returns a [`XetCommitReport`] with aggregate dedup metrics, progress,
diff --git a/xet_pkg/src/xet_session/upload_file_handle.rs b/xet_pkg/src/xet_session/upload_file_handle.rs
index 8aaac88c2..d0bde6f36 100644
--- a/xet_pkg/src/xet_session/upload_file_handle.rs
+++ b/xet_pkg/src/xet_session/upload_file_handle.rs
@@ -31,6 +31,9 @@ pub(super) struct XetFileUploadInner {
 ///
 /// Important: ingestion completion means the file has been chunked/deduplicated.
 /// The file is not uploaded to CAS until [`XetUploadCommit::commit`] is called.
+///
+/// Cloning is cheap — all clones share the same underlying state via `Arc`.
+#[derive(Clone)]
 pub struct XetFileUpload {
     pub(super) inner: Arc<XetFileUploadInner>,
     pub(super) task_runtime: Arc,
diff --git a/xet_pkg/tests/test_xet_session.rs b/xet_pkg/tests/test_xet_session.rs
index d08bd713c..a0ba12ad7 100644
--- a/xet_pkg/tests/test_xet_session.rs
+++ b/xet_pkg/tests/test_xet_session.rs
@@ -1256,13 +1256,13 @@ async fn async_stream_progress_tracking() {
     let group = async_stream_group(&session, &endpoint).await;
     let mut stream = group.download_stream(file_info, None).await.unwrap();
 
-    let initial = stream.progress();
+    let initial = stream.progress().unwrap();
     assert_eq!(initial.total_bytes, data.len() as u64);
     assert_eq!(initial.bytes_completed, 0);
 
     let _ = collect_stream(&mut stream).await;
 
-    let final_progress = stream.progress();
+    let final_progress = stream.progress().unwrap();
     assert_eq!(final_progress.total_bytes, data.len() as u64);
     assert_eq!(final_progress.bytes_completed, data.len() as u64);
 }
@@ -1372,7 +1372,7 @@ fn blocking_stream_progress_tracking() {
     let mut stream = group.download_stream_blocking(file_info, None).unwrap();
     let _ = collect_stream_blocking(&mut stream);
 
-    let final_progress = stream.progress();
+    let final_progress = stream.progress().unwrap();
     assert_eq!(final_progress.total_bytes, data.len() as u64);
     assert_eq!(final_progress.bytes_completed, data.len() as u64);
 }
@@ -1540,13 +1540,13 @@ async fn async_unordered_stream_progress_tracking() {
     let group = async_stream_group(&session, &endpoint).await;
     let mut stream = group.download_unordered_stream(file_info, None).await.unwrap();
 
-    let initial = stream.progress();
+    let initial = stream.progress().unwrap();
     assert_eq!(initial.total_bytes, data.len() as u64);
     assert_eq!(initial.bytes_completed, 0);
 
     let _ = collect_unordered_stream(&mut stream, data.len()).await;
 
-    let final_progress = stream.progress();
+    let final_progress = stream.progress().unwrap();
     assert_eq!(final_progress.total_bytes, data.len() as u64);
     assert_eq!(final_progress.bytes_completed, data.len() as u64);
 }
@@ -1637,7 +1637,7 @@ fn blocking_unordered_stream_progress_tracking() {
     let mut stream = group.download_unordered_stream_blocking(file_info, None).unwrap();
     let _ = collect_unordered_stream_blocking(&mut stream, data.len());
 
-    let final_progress = stream.progress();
+    let final_progress = stream.progress().unwrap();
     assert_eq!(final_progress.total_bytes, data.len() as u64);
     assert_eq!(final_progress.bytes_completed, data.len() as u64);
 }
@@ -1793,13 +1793,13 @@ async fn async_stream_range_progress() {
     let group = async_stream_group(&session, &endpoint).await;
     let mut stream = group.download_stream(file_info, Some(50..150)).await.unwrap();
 
-    let initial = stream.progress();
+    let initial = stream.progress().unwrap();
    assert_eq!(initial.total_bytes, 100);
     assert_eq!(initial.bytes_completed, 0);
 
     let _ = collect_stream(&mut stream).await;
 
-    let final_progress = stream.progress();
+    let final_progress = stream.progress().unwrap();
     assert_eq!(final_progress.total_bytes, 100);
     assert_eq!(final_progress.bytes_completed, 100);
 }
@@ -1827,7 +1827,7 @@ fn blocking_stream_range_progress() {
     let mut stream = group.download_stream_blocking(file_info, Some(10..110)).unwrap();
     let _ = collect_stream_blocking(&mut stream);
 
-    let final_progress = stream.progress();
+    let final_progress = stream.progress().unwrap();
     assert_eq!(final_progress.total_bytes, 100);
     assert_eq!(final_progress.bytes_completed, 100);
 }
@@ -1878,13 +1878,13 @@ async fn async_unordered_stream_range_progress() {
     let group = async_stream_group(&session, &endpoint).await;
     let mut stream = group.download_unordered_stream(file_info, Some(50..150)).await.unwrap();
diff --git a/xet_runtime/src/config/mod.rs b/xet_runtime/src/config/mod.rs
index e1950d42f..bee07267b 100644
--- a/xet_runtime/src/config/mod.rs
+++ b/xet_runtime/src/config/mod.rs
@@ -25,6 +25,3 @@ pub type ClientConfig = groups::client::ConfigValues;
 pub type LogConfig = groups::log::ConfigValues;
 pub type XorbConfig = groups::xorb::ConfigValues;
 pub type SessionConfig = groups::session::ConfigValues;
-
-#[cfg(feature = "python")]
-pub use xet_config::py_xet_config::PyXetConfig;
diff --git a/xet_runtime/src/config/python.rs b/xet_runtime/src/config/python.rs
index 0d2441c0d..737fdd47a 100644
--- a/xet_runtime/src/config/python.rs
+++ b/xet_runtime/src/config/python.rs
@@ -11,7 +11,7 @@ use crate::utils::{ByteSize, ConfigEnum};
 /// - Numeric types <-> Python int/float
 /// - String <-> Python str
 /// - bool <-> Python bool
-/// - Duration <-> Python datetime.timedelta
+/// - Duration <-> Python str (e.g. ``"500ms"``, ``"1s"``) or ``datetime.timedelta``
 /// - ByteSize <-> Python int (bytes)
 /// - Option<T> <-> Optional[T]
 pub trait PythonConfigValue {
@@ -47,7 +47,27 @@ macro_rules! impl_python_extract {
     };
 }
 
-impl_python_extract!(usize, u8, u16, u32, u64, isize, i8, i16, i32, i64, f32, f64, bool, String, std::time::Duration);
+impl_python_extract!(usize, u8, u16, u32, u64, isize, i8, i16, i32, i64, f32, f64, bool, String);
+
+impl PythonConfigValue for std::time::Duration {
+    fn to_python(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
+        self.into_py_any(py)
+    }
+
+    fn from_python(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
+        // Accept a human-readable string like "500ms", "1s", "2m30s" (parsed by humantime),
+        // or a Python datetime.timedelta (PyO3's native Duration extraction).
+        if let Ok(s) = obj.extract::<String>() {
+            humantime::parse_duration(&s).map_err(|e| {
+                pyo3::exceptions::PyValueError::new_err(format!(
+                    "Invalid duration string {s:?}: {e}. Expected a humantime string like \"500ms\", \"1s\", or \"2m30s\"."
+                ))
+            })
+        } else {
+            obj.extract()
+        }
+    }
+}
 
 impl PythonConfigValue for ByteSize {
     fn to_python(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
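// Editor's aside: what humantime::parse_duration accepts for the new string
// form above. Runnable standalone with humantime = "2" in Cargo.toml; the
// "2m30s" form is the one the patch's own error message cites as valid.
use std::time::Duration;

fn main() {
    assert_eq!(humantime::parse_duration("500ms").unwrap(), Duration::from_millis(500));
    assert_eq!(humantime::parse_duration("2m30s").unwrap(), Duration::from_secs(150));
    // Bad input surfaces as Err, which from_python maps to a Python ValueError.
    assert!(humantime::parse_duration("five seconds").is_err());
}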
diff --git a/xet_runtime/src/config/xet_config.rs b/xet_runtime/src/config/xet_config.rs
index 9b0b6fa28..658cb5c0f 100644
--- a/xet_runtime/src/config/xet_config.rs
+++ b/xet_runtime/src/config/xet_config.rs
@@ -23,7 +23,7 @@ macro_rules! impl_xet_config_group_dispatch {
     }
 
     #[cfg(feature = "python")]
-    fn split_path_for_python(path: &str) -> pyo3::PyResult<(&str, &str)> {
+    pub fn split_path_for_python(path: &str) -> pyo3::PyResult<(&str, &str)> {
         Self::split_path(path).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
     }
 
@@ -49,7 +49,7 @@ macro_rules! impl_xet_config_group_dispatch {
     }
 
     #[cfg(feature = "python")]
-    fn update_field_from_python(
+    pub fn update_field_from_python(
         &mut self,
         path: &str,
         value: &pyo3::Bound<'_, pyo3::PyAny>,
@@ -66,7 +66,7 @@ macro_rules! impl_xet_config_group_dispatch {
     }
 
     #[cfg(feature = "python")]
-    fn get_field_to_python(
+    pub fn get_field_to_python(
         &self,
         path: &str,
         py: pyo3::Python<'_>,
@@ -110,7 +110,7 @@ macro_rules! impl_xet_config_group_dispatch {
     }
 
     #[cfg(feature = "python")]
-    fn all_items_to_python(
+    pub fn all_items_to_python(
         &self,
         py: pyo3::Python<'_>,
     ) -> pyo3::PyResult<Vec<(String, pyo3::Py<pyo3::PyAny>)>> {
@@ -388,133 +388,3 @@ mod tests {
         assert_eq!(config.get("system_monitor.log_path").unwrap(), "~/logs/monitor_{PID}.log");
     }
 }
-
-#[cfg(feature = "python")]
-pub mod py_xet_config {
-    use pyo3::prelude::*;
-    use pyo3::types::PyDict;
-
-    use super::*;
-
-    #[pyclass(name = "XetConfig")]
-    pub struct PyXetConfig {
-        inner: XetConfig,
-    }
-
-    impl From<XetConfig> for PyXetConfig {
-        fn from(inner: XetConfig) -> Self {
-            Self { inner }
-        }
-    }
-
-    impl PyXetConfig {
-        pub fn inner(&self) -> &XetConfig {
-            &self.inner
-        }
-    }
-
-    #[pymethods]
-    impl PyXetConfig {
-        #[new]
-        fn py_new() -> Self {
-            Self {
-                inner: XetConfig::new(),
-            }
-        }
-
-        /// Return a new XetConfig with one or more values updated.
-        ///
-        /// Can be called in two ways:
-        ///     config.with_config("group.field", value)        -- single update
-        ///     config.with_config({"group.field": value, ...}) -- batch update
-        #[pyo3(name = "with_config")]
-        #[pyo3(signature = (name_or_dict, value=None))]
-        fn py_with_config(&self, name_or_dict: &Bound<'_, PyAny>, value: Option<&Bound<'_, PyAny>>) -> PyResult<Self> {
-            let mut new_inner = self.inner.clone();
-
-            if let Ok(dict) = name_or_dict.downcast::<PyDict>() {
-                if value.is_some() {
-                    return Err(pyo3::exceptions::PyTypeError::new_err(
-                        "with_config(dict) does not accept a second argument",
-                    ));
-                }
-                for (key, val) in dict.iter() {
-                    let key_str: String = key.extract()?;
-                    new_inner.update_field_from_python(&key_str, &val)?;
-                }
-            } else {
-                let name: String = name_or_dict.extract()?;
-                let val = value.ok_or_else(|| {
-                    pyo3::exceptions::PyTypeError::new_err("with_config(name, value) requires a value argument")
-                })?;
-                new_inner.update_field_from_python(&name, val)?;
-            }
-
-            Ok(Self { inner: new_inner })
-        }
-
-        /// Get a configuration value as its native Python type by dotted path
-        /// (e.g. "data.max_concurrent_file_ingestion").
-        #[pyo3(name = "get")]
-        fn py_get(&self, py: Python<'_>, path: &str) -> PyResult<Py<PyAny>> {
-            self.inner.get_field_to_python(path, py)
-        }
-
-        fn __getitem__(&self, py: Python<'_>, key: &str) -> PyResult<Py<PyAny>> {
-            self.inner
-                .get_field_to_python(key, py)
-                .map_err(|_| pyo3::exceptions::PyKeyError::new_err(key.to_owned()))
-        }
-
-        /// Return all (key, value) pairs as a list of tuples.
-        /// Keys are dotted paths like "data.max_concurrent_file_ingestion".
-        fn items(&self, py: Python<'_>) -> PyResult<Vec<(String, Py<PyAny>)>> {
-            self.inner.all_items_to_python(py)
-        }
-
-        /// Return all dotted-path keys.
-        fn keys(&self) -> Vec<String> {
-            self.inner.all_keys()
-        }
-
-        fn __len__(&self) -> usize {
-            self.inner.all_keys().len()
-        }
-
-        fn __iter__(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult<Py<PyXetConfigIter>> {
-            let items = slf.inner.all_items_to_python(py)?;
-            Py::new(py, PyXetConfigIter { items, index: 0 })
-        }
-
-        fn __repr__(&self) -> String {
-            format!("XetConfig({:?})", self.inner)
-        }
-
-        fn __str__(&self) -> String {
-            format!("{:?}", self.inner)
-        }
-    }
-
-    #[pyclass]
-    struct PyXetConfigIter {
-        items: Vec<(String, Py<PyAny>)>,
-        index: usize,
-    }
-
-    #[pymethods]
-    impl PyXetConfigIter {
-        fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
-            slf
-        }
-
-        fn __next__(&mut self, py: Python<'_>) -> Option<(String, Py<PyAny>)> {
-            if self.index < self.items.len() {
-                let (key, value) = &self.items[self.index];
-                self.index += 1;
-                Some((key.clone(), value.clone_ref(py)))
-            } else {
-                None
-            }
-        }
-    }
-}
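// Editor's aside: the point of making the dispatch methods pub. With the
// PyXetConfig wrapper deleted from xet_runtime, an equivalent pyclass can now
// live in a downstream crate (e.g. hf_xet) and delegate by dotted path. A
// hedged sketch — the wrapper, its placement, and the `xet_runtime::config`
// import path are assumptions, not confirmed by the patch.
#[cfg(feature = "python")]
mod sketch {
    use pyo3::prelude::*;
    use xet_runtime::config::XetConfig; // assumed re-export path

    #[pyclass]
    struct PyXetConfig {
        inner: XetConfig,
    }

    #[pymethods]
    impl PyXetConfig {
        fn get(&self, py: Python<'_>, path: &str) -> PyResult<Py<PyAny>> {
            // Callable from outside the defining crate now that it is pub.
            self.inner.get_field_to_python(path, py)
        }
    }
}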
diff --git a/xet_runtime/src/core/runtime.rs b/xet_runtime/src/core/runtime.rs
index ec69a4f50..6d0f70451 100644
--- a/xet_runtime/src/core/runtime.rs
+++ b/xet_runtime/src/core/runtime.rs
@@ -169,6 +169,9 @@ pub struct XetRuntime {
     // Are we in the middle of a sigint shutdown?
     sigint_shutdown: AtomicBool,
 
+    // PID of the process that created this runtime; used in Drop to detect fork.
+    creation_pid: u32,
+
     // System monitor instance if enabled, monitor starts on initiation
     #[cfg(not(target_family = "wasm"))]
     system_monitor: Option,
@@ -201,6 +204,7 @@ impl XetRuntime {
             handle_ref: OnceLock::new(),
             external_executor_count: 0.into(),
             sigint_shutdown: false.into(),
+            creation_pid: std::process::id(),
             #[cfg(not(target_family = "wasm"))]
             system_monitor: system_monitor_for_config(config),
         });
@@ -282,6 +286,7 @@ impl XetRuntime {
             handle_ref: rt_handle.into(),
             external_executor_count: 0.into(),
             sigint_shutdown: false.into(),
+            creation_pid: std::process::id(),
             #[cfg(not(target_family = "wasm"))]
             system_monitor: system_monitor_for_config(config),
         });
@@ -301,6 +306,7 @@ impl XetRuntime {
             handle_ref: rt_handle.into(),
             external_executor_count: 0.into(),
             sigint_shutdown: false.into(),
+            creation_pid: std::process::id(),
             #[cfg(not(target_family = "wasm"))]
             system_monitor: None,
         })
@@ -477,6 +483,13 @@ impl XetRuntime {
         T: Send + 'static,
     {
         self.check_sigint()?;
+        if std::process::id() != self.creation_pid {
+            return Err(RuntimeError::InvalidRuntime(format!(
+                "XetRuntime was created in process {} but is being used in process {}",
+                self.creation_pid,
+                std::process::id(),
+            )));
+        }
         match &self.backend {
             RuntimeBackend::External { .. } => Ok(fut.await),
             RuntimeBackend::OwnedThreadPool { .. } => self.bridge_to_owned(task_name, fut).await,
@@ -499,6 +512,13 @@ impl XetRuntime {
         F::Output: Send + 'static,
     {
         self.check_sigint()?;
+        if std::process::id() != self.creation_pid {
+            return Err(RuntimeError::InvalidRuntime(format!(
+                "XetRuntime was created in process {} but is being used in process {}",
+                self.creation_pid,
+                std::process::id(),
+            )));
+        }
         if matches!(self.backend, RuntimeBackend::External { .. }) {
             return Err(RuntimeError::InvalidRuntime(
                 "bridge_sync() cannot be called on an External-mode runtime; \
@@ -662,26 +682,37 @@ impl Drop for XetRuntime {
         self.handle_ref.take();
 
-        match &self.backend {
-            RuntimeBackend::External { handle_id: Some(id) } => {
-                if let Ok(mut reg) = EXTERNAL_THREADPOOL_REGISTRY.write() {
-                    reg.remove(id);
-                }
-            },
-            RuntimeBackend::External { handle_id: None } => {},
-            RuntimeBackend::OwnedThreadPool { runtime } => {
-                let in_async_context = TokioRuntimeHandle::try_current().is_ok();
-                if let Ok(mut guard) = runtime.write()
-                    && let Some(rt_arc) = guard.take()
-                    && let Ok(rt) = Arc::try_unwrap(rt_arc)
-                {
-                    if in_async_context {
-                        rt.shutdown_background();
-                    } else {
-                        rt.shutdown_timeout(std::time::Duration::from_secs(5));
-                    }
-                }
-            },
+        // Fork detection: if we are in a child process the Tokio worker threads from the
+        // parent do not exist here. shutdown_timeout() would block ~5 s waiting on futexes
+        // that never fire. Use discard_runtime() (std::mem::forget) instead — the OS
+        // reclaims all memory when the child exits.
+        if self.creation_pid != std::process::id() {
+            self.discard_runtime();
+            return;
+        }
+
+        if let RuntimeBackend::External { handle_id: Some(id) } = &self.backend {
+            if let Ok(mut reg) = EXTERNAL_THREADPOOL_REGISTRY.write() {
+                reg.remove(id);
+            }
+            return;
+        }
+
+        // When dropping from within an async context, the default TokioRuntime Drop
+        // would panic ("Cannot drop a runtime in a context where blocking is not allowed").
+        // Avoid this by taking ownership of the runtime and using shutdown_background(),
+        // which spawns a thread for the blocking shutdown work instead.
+        let in_async_context = TokioRuntimeHandle::try_current().is_ok();
+        if let RuntimeBackend::OwnedThreadPool { runtime } = &self.backend
+            && let Ok(mut guard) = runtime.write()
+            && let Some(rt_arc) = guard.take()
+            && let Ok(rt) = Arc::try_unwrap(rt_arc)
+        {
+            if in_async_context {
+                rt.shutdown_background();
+            } else {
+                rt.shutdown_timeout(std::time::Duration::from_secs(5));
+            }
         }
     }
 }
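// Editor's aside: the creation_pid technique above, reduced to its core and
// runnable standalone — record the PID at construction, then refuse to operate
// (or skip blocking teardown) when the current PID differs, i.e. after fork().
// `PidGuard` is an illustrative name, not crate API.
struct PidGuard {
    creation_pid: u32,
}

impl PidGuard {
    fn new() -> Self {
        Self { creation_pid: std::process::id() }
    }

    fn same_process(&self) -> bool {
        std::process::id() == self.creation_pid
    }
}

impl Drop for PidGuard {
    fn drop(&mut self) {
        if !self.same_process() {
            // Forked child: the parent's worker threads do not exist here,
            // so skip any teardown that would wait on them and let the OS
            // reclaim memory at process exit.
            return;
        }
        // ...normal shutdown path runs only in the creating process...
    }
}

fn main() {
    let guard = PidGuard::new();
    assert!(guard.same_process());
}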
diff --git a/xet_runtime/src/utils/unique_id.rs b/xet_runtime/src/utils/unique_id.rs
index 07d8637e9..a2bc13455 100644
--- a/xet_runtime/src/utils/unique_id.rs
+++ b/xet_runtime/src/utils/unique_id.rs
@@ -4,7 +4,8 @@ use std::sync::atomic::{AtomicU64, Ordering};
 static NEXT_ID: AtomicU64 = AtomicU64::new(1);
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
-pub struct UniqueId(u64);
+#[cfg_attr(feature = "python", pyo3::pyclass)]
+pub struct UniqueId(pub u64);
 
 impl UniqueId {
     pub fn new() -> Self {
@@ -22,6 +23,22 @@ impl Default for UniqueId {
     }
 }
 
+#[cfg(feature = "python")]
+#[pyo3::pymethods]
+impl UniqueId {
+    fn __repr__(&self) -> String {
+        format!("UniqueId({})", self.0)
+    }
+
+    fn __hash__(&self) -> u64 {
+        self.0
+    }
+
+    fn __eq__(&self, other: &Self) -> bool {
+        self.0 == other.0
+    }
+}
+
 impl fmt::Display for UniqueId {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "{}", self.0)
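// Editor's aside: the process-wide atomic counter behind UniqueId::new(),
// shown standalone. Relaxed ordering suffices because only uniqueness of the
// returned values matters, not any ordering between threads. Note also that
// the Python __eq__ and __hash__ above agree (both defined on self.0), which
// is what Python dict/set semantics require.
use std::sync::atomic::{AtomicU64, Ordering};

static NEXT_ID: AtomicU64 = AtomicU64::new(1);

fn next_id() -> u64 {
    // fetch_add returns the previous value, so ids start at 1 and never repeat.
    NEXT_ID.fetch_add(1, Ordering::Relaxed)
}

fn main() {
    assert_eq!(next_id(), 1);
    assert_eq!(next_id(), 2);
}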