diff --git a/pyrefly/lib/export/exports.rs b/pyrefly/lib/export/exports.rs index c4f748827c..111dd9978b 100644 --- a/pyrefly/lib/export/exports.rs +++ b/pyrefly/lib/export/exports.rs @@ -403,6 +403,14 @@ impl Exports { }; self.0.exports.calculate(f).unwrap_or_default() } + + pub fn is_explicit_reexport(&self, name: &Name) -> bool { + self.0 + .definitions + .definitions + .get(name) + .is_some_and(|definition| matches!(definition.style, DefinitionStyle::ImportAsEq(_))) + } } #[cfg(test)] diff --git a/pyrefly/lib/state/lsp.rs b/pyrefly/lib/state/lsp.rs index f603bfd6df..5b376e633e 100644 --- a/pyrefly/lib/state/lsp.rs +++ b/pyrefly/lib/state/lsp.rs @@ -513,6 +513,15 @@ pub struct FindDefinitionItem { } impl<'a> Transaction<'a> { + fn allows_explicit_reexport(handle: &Handle) -> bool { + matches!( + handle.path().details(), + ModulePathDetails::FileSystem(_) + | ModulePathDetails::Namespace(_) + | ModulePathDetails::Memory(_) + ) + } + pub fn get_type(&self, handle: &Handle, key: &Key) -> Option { let idx = self.get_bindings(handle)?.key_to_idx(key); let answers = self.get_answers(handle)?; @@ -2915,7 +2924,7 @@ impl<'a> Transaction<'a> { } pub fn search_exports_exact(&self, name: &str) -> Vec<(Handle, Export)> { - self.search_exports(|handle, exports| { + self.search_exports(|handle, exports_data, exports| { let name = Name::new(name); match exports.get(&name) { Some(location) => { @@ -2924,7 +2933,9 @@ impl<'a> Transaction<'a> { { let mut results = vec![(canonical_handle.dupe(), export.clone())]; if canonical_handle != *handle - && Self::should_include_reexport(handle, &canonical_handle) + && (Self::should_include_reexport(handle, &canonical_handle) + || (exports_data.is_explicit_reexport(&name) + && Self::allows_explicit_reexport(handle))) { results.push((handle.dupe(), export)); } @@ -2939,7 +2950,7 @@ impl<'a> Transaction<'a> { } pub fn search_exports_fuzzy(&self, pattern: &str) -> Vec<(Handle, String, Export)> { - let mut res = self.search_exports(|handle, exports| { + let mut res = self.search_exports(|handle, exports_data, exports| { let matcher = SkimMatcherV2::default().smart_case(); let mut results = Vec::new(); for (name, location) in exports.iter() { @@ -2955,7 +2966,9 @@ impl<'a> Transaction<'a> { export.clone(), )); if canonical_handle != *handle - && Self::should_include_reexport(handle, &canonical_handle) + && (Self::should_include_reexport(handle, &canonical_handle) + || (exports_data.is_explicit_reexport(name) + && Self::allows_explicit_reexport(handle))) { results.push((score, handle.dupe(), name_str.to_owned(), export)); } diff --git a/pyrefly/lib/state/state.rs b/pyrefly/lib/state/state.rs index 9b301dc795..2faf453d68 100644 --- a/pyrefly/lib/state/state.rs +++ b/pyrefly/lib/state/state.rs @@ -653,7 +653,7 @@ impl<'a> Transaction<'a> { /// The order of the resulting `Vec` is unspecified. pub fn search_exports( &self, - searcher: impl Fn(&Handle, &SmallMap) -> Vec + Sync, + searcher: impl Fn(&Handle, &Exports, &SmallMap) -> Vec + Sync, ) -> Vec { // Make sure all the modules are in updated_modules. // We have to get a mutable module data to do the lookup we need anyway. @@ -675,10 +675,9 @@ impl<'a> Transaction<'a> { tasks.work_without_cancellation(|_, modules| { let mut thread_local_results = Vec::new(); for (handle, module_data) in modules { - let exports = self - .lookup_export(module_data) - .exports(&self.lookup(module_data.dupe())); - thread_local_results.extend(searcher(handle, &exports)); + let exports_data = self.lookup_export(module_data); + let exports = exports_data.exports(&self.lookup(module_data.dupe())); + thread_local_results.extend(searcher(handle, &exports_data, &exports)); } if !thread_local_results.is_empty() { all_results.lock().push(thread_local_results); diff --git a/pyrefly/lib/test/lsp/completion.rs b/pyrefly/lib/test/lsp/completion.rs index 9e56e11061..a83dd3af7f 100644 --- a/pyrefly/lib/test/lsp/completion.rs +++ b/pyrefly/lib/test/lsp/completion.rs @@ -1941,6 +1941,41 @@ Completion Results: ); } +#[test] +fn autoimport_explicit_reexport_suggests_reexport_path() { + let code = r#" +T = Thing +# ^ +"#; + let report = get_batched_lsp_operations_report_allow_error( + &[ + ("main", code), + ("source", "Thing = 1\n"), + ("public", "from source import Thing as Thing\n"), + ], + get_test_report(Default::default(), ImportFormat::Absolute), + ); + assert_eq!( + r#" +# main.py +2 | T = Thing + ^ +Completion Results: +- (Variable) Thing: from public import Thing + +- (Variable) Thing: from source import Thing + + + +# source.py + +# public.py +"# + .trim(), + report.trim(), + ); +} + #[test] fn autoimport_prefers_shorter_module() { let code = r#"