Skip to content

Commit

Permalink
add inner use sort to formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
dean-starkware committed Oct 16, 2024
1 parent 65d3246 commit a805cdb
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 35 deletions.
277 changes: 243 additions & 34 deletions crates/cairo-lang-formatter/src/formatter_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::fmt;
use cairo_lang_filesystem::span::TextWidth;
use cairo_lang_syntax as syntax;
use cairo_lang_syntax::attribute::consts::FMT_SKIP_ATTR;
use cairo_lang_syntax::node::ast::UsePath;
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{SyntaxNode, Terminal, TypedSyntaxNode, ast};
use itertools::Itertools;
Expand Down Expand Up @@ -730,6 +731,20 @@ impl BreakLinePointsPositions {
}
}

/// Gives access to the identifier text of a `use` path node.
///
/// Implemented for the `use` path node kinds that carry an identifier, so that such nodes can be
/// compared by name when `use` paths are sorted.
trait IdentExtractor {
/// Returns the text of this node's identifier as an owned `String`.
fn extract_ident(&self, db: &dyn SyntaxGroup) -> String;
}
impl IdentExtractor for ast::UsePathLeaf {
    /// Returns the text of the leaf's identifier, as produced by `get_text_without_trivia`.
    fn extract_ident(&self, db: &dyn SyntaxGroup) -> String {
        let ident_node = self.ident(db).as_syntax_node();
        ident_node.get_text_without_trivia(db)
    }
}
impl IdentExtractor for ast::UsePathSingle {
    /// Returns the text of the segment's own identifier (not of the path that follows it), as
    /// produced by `get_text_without_trivia`.
    fn extract_ident(&self, db: &dyn SyntaxGroup) -> String {
        let ident_node = self.ident(db).as_syntax_node();
        ident_node.get_text_without_trivia(db)
    }
}

// TODO(spapini): Introduce the correct types here, to reflect the "applicable" nodes types.
pub trait SyntaxNodeFormat {
/// Returns true if a token should never have a space before it.
Expand Down Expand Up @@ -826,41 +841,11 @@ impl<'a> FormatterImpl<'a> {
// TODO(ilya): consider not copying here.
let mut children = self.db.get_children(syntax_node.clone()).to_vec();
let n_children = children.len();

if self.config.sort_module_level_items {
let mut start_idx = 0;
while start_idx < children.len() {
let kind = children[start_idx].as_sort_kind(self.db);
let mut end_idx = start_idx + 1;
// Find the end of the current section.
while end_idx < children.len() {
if kind != children[end_idx].as_sort_kind(self.db) {
break;
}
end_idx += 1;
}
// Sort within this section if it's `Module` or `UseItem`.
match kind {
SortKind::Module => {
children[start_idx..end_idx].sort_by_key(|node| {
ast::ItemModule::from_syntax_node(self.db, node.clone())
.name(self.db)
.text(self.db)
});
}
SortKind::UseItem => {
children[start_idx..end_idx].sort_by_key(|node| {
ast::ItemUse::from_syntax_node(self.db, node.clone())
.use_path(self.db)
.as_syntax_node()
.get_text_without_trivia(self.db)
});
}
SortKind::Immovable => {}
}

// Move past the sorted section.
start_idx = end_idx;
if let SyntaxKind::UsePathList = syntax_node.kind(self.db) {
self.sort_inner_use_path(&mut children);
} else {
self.sort_items_sections(&mut children);
}
}
for (i, child) in children.iter().enumerate() {
Expand All @@ -878,6 +863,90 @@ impl<'a> FormatterImpl<'a> {
self.empty_lines_allowance = allowed_empty_between;
}
}

/// Sorts the elements of a `UsePathList` (the braces part of a multi `use` path) in place.
///
/// `children` holds the child nodes of the list: the path elements and the commas between them.
/// The elements are ordered with `compare_use_paths` and the list is rebuilt with a single comma
/// between every two neighbouring elements. A trailing comma, if present, is not re-emitted.
///
/// Nodes for which `extract_use_path` returns `None` are kept and ordered before all recognised
/// `use` paths.
fn sort_inner_use_path(&self, children: &mut Vec<SyntaxNode>) {
    // Split the children into list elements and comma separators. Classifying by "is not a
    // comma" (instead of by an allow-list of element kinds) guarantees that no element - for
    // example a nested `UsePathMulti` - is dropped when the list is rebuilt below.
    let (mut elements, commas): (Vec<_>, Vec<_>) = std::mem::take(children)
        .into_iter()
        .partition(|node| !matches!(node.kind(self.db), SyntaxKind::TerminalComma));

    // Stable sort, so elements that compare equal keep their original relative order.
    elements.sort_by(|a_node, b_node| {
        match (extract_use_path(a_node, self.db), extract_use_path(b_node, self.db)) {
            (Some(a_path), Some(b_path)) => compare_use_paths(&a_path, &b_path, self.db),
            (None, Some(_)) => Ordering::Less,
            (Some(_), None) => Ordering::Greater,
            (None, None) => Ordering::Equal,
        }
    });

    // Rebuild the list, reusing the first comma node as the separator everywhere. The closure
    // only runs when there are at least two elements.
    *children = itertools::Itertools::intersperse_with(elements.into_iter(), || {
        commas
            .first()
            .cloned()
            .expect("a `use` path list with several elements has a comma separator")
    })
    .collect();
}

/// Sorts module-level items in place, one section at a time.
///
/// A section is a maximal run of neighbouring nodes that share the same `SortKind`. Items never
/// move across section boundaries:
/// * `SortKind::Module` sections are ordered by the text of the module name.
/// * `SortKind::UseItem` sections are ordered by `compare_use_paths` on their `use` paths.
/// * `SortKind::Immovable` sections are left exactly as they are.
fn sort_items_sections(&self, children: &mut [SyntaxNode]) {
    let mut section_start = 0;
    while section_start < children.len() {
        let section_kind = children[section_start].as_sort_kind(self.db);
        // The section covers the first node plus every directly following node of that kind.
        let section_end = section_start
            + 1
            + children[section_start + 1..]
                .iter()
                .take_while(|node| section_kind == node.as_sort_kind(self.db))
                .count();

        let section = &mut children[section_start..section_end];
        match section_kind {
            SortKind::Immovable => {
                // Nothing to reorder in an immovable section.
            }
            SortKind::Module => section.sort_by_key(|node| {
                ast::ItemModule::from_syntax_node(self.db, node.clone())
                    .name(self.db)
                    .text(self.db)
            }),
            SortKind::UseItem => section.sort_by(|lhs, rhs| {
                let lhs_path =
                    ast::ItemUse::from_syntax_node(self.db, lhs.clone()).use_path(self.db);
                let rhs_path =
                    ast::ItemUse::from_syntax_node(self.db, rhs.clone()).use_path(self.db);
                compare_use_paths(&lhs_path, &rhs_path, self.db)
            }),
        }

        // Continue with the node right after this section.
        section_start = section_end;
    }
}

/// Formats a terminal node and appends the formatted string to the result.
fn format_terminal(&mut self, syntax_node: &SyntaxNode) {
// TODO(spapini): Introduce a Terminal and a Token enum in ast.rs to make this cleaner.
Expand Down Expand Up @@ -957,6 +1026,146 @@ impl<'a> FormatterImpl<'a> {
}
}

/// Compares two `UsePath` nodes to determine their ordering.
///
/// Rules, by the kinds of `a` and `b`:
/// * Multi vs. Leaf/Single: the multi path always comes first.
/// * Multi vs. Multi: the elements with the smallest identifier of each brace list are compared
///   recursively; an empty brace list sorts before a non-empty one.
/// * Leaf vs. Single: ordered by identifier; on a tie the leaf comes first.
/// * Leaf vs. Leaf: ordered by identifier only. Aliases do not take part in tie breaks (same
///   methodology as Rust).
/// * Single vs. Single: ordered by the identifier of the first segment; on a tie the
///   continuations returned by `next_use_path` are compared recursively, and a path without a
///   continuation sorts before one that has it.
///
/// This function never panics on empty brace lists (e.g. `use a::{};`).
fn compare_use_paths(a: &UsePath, b: &UsePath, db: &dyn SyntaxGroup) -> Ordering {
    // Key used to pick the smallest element of a brace list: the identifier for leaf and single
    // paths, the empty string for anything else (so such elements are picked first).
    let element_key = |path: &UsePath| match path {
        UsePath::Leaf(leaf) => leaf.extract_ident(db),
        UsePath::Single(single) => single.extract_ident(db),
        _ => String::new(),
    };

    match (a, b) {
        // Multi paths are always ordered before non-multi paths.
        (UsePath::Multi(_), UsePath::Leaf(_) | UsePath::Single(_)) => Ordering::Less,
        (UsePath::Leaf(_) | UsePath::Single(_), UsePath::Multi(_)) => Ordering::Greater,

        (UsePath::Multi(a_multi), UsePath::Multi(b_multi)) => {
            // Keep the element vectors alive while references into them are compared.
            let a_elements = a_multi.use_paths(db).elements(db);
            let b_elements = b_multi.use_paths(db).elements(db);
            let a_min = a_elements.iter().min_by_key(|path| element_key(*path));
            let b_min = b_elements.iter().min_by_key(|path| element_key(*path));

            // `None` means the brace list is empty; decide directly instead of asking for the
            // continuation of a multi path, which does not exist.
            match (a_min, b_min) {
                (Some(a_min), Some(b_min)) => compare_use_paths(a_min, b_min, db),
                (None, Some(_)) => Ordering::Less,
                (Some(_), None) => Ordering::Greater,
                (None, None) => Ordering::Equal,
            }
        }

        // Leaf is always ordered before Single if the identifiers are equal.
        (UsePath::Leaf(a_leaf), UsePath::Single(b_single)) => {
            a_leaf.extract_ident(db).cmp(&b_single.extract_ident(db)).then(Ordering::Less)
        }
        (UsePath::Single(a_single), UsePath::Leaf(b_leaf)) => {
            a_single.extract_ident(db).cmp(&b_leaf.extract_ident(db)).then(Ordering::Greater)
        }

        // In case of tie breaks we do not sort by aliases (same methodology as Rust).
        (UsePath::Leaf(a_leaf), UsePath::Leaf(b_leaf)) => {
            a_leaf.extract_ident(db).cmp(&b_leaf.extract_ident(db))
        }

        (UsePath::Single(a_single), UsePath::Single(b_single)) => {
            a_single.extract_ident(db).cmp(&b_single.extract_ident(db)).then_with(|| {
                // Same first segment: compare what follows it. The continuation is missing
                // when the segment is followed by an empty brace list.
                match (next_use_path(a.clone(), db), next_use_path(b.clone(), db)) {
                    (Some(next_a), Some(next_b)) => compare_use_paths(&next_a, &next_b, db),
                    (None, Some(_)) => Ordering::Less,
                    (Some(_), None) => Ordering::Greater,
                    (None, None) => Ordering::Equal,
                }
            })
        }
    }
}

/// Wraps a syntax node in the matching `ast::UsePath` variant.
///
/// Returns `None` when the node's kind is none of `UsePathLeaf`, `UsePathSingle` or
/// `UsePathMulti`.
fn extract_use_path(node: &SyntaxNode, db: &dyn SyntaxGroup) -> Option<ast::UsePath> {
    let use_path = match node.kind(db) {
        SyntaxKind::UsePathLeaf => {
            ast::UsePath::Leaf(ast::UsePathLeaf::from_syntax_node(db, node.clone()))
        }
        SyntaxKind::UsePathSingle => {
            ast::UsePath::Single(ast::UsePathSingle::from_syntax_node(db, node.clone()))
        }
        SyntaxKind::UsePathMulti => {
            ast::UsePath::Multi(ast::UsePathMulti::from_syntax_node(db, node.clone()))
        }
        // Not a `use` path node.
        _ => return None,
    };
    Some(use_path)
}

/// Returns the part of a `UsePath` that follows its first segment, if there is one.
///
/// * `Leaf`: nothing follows, so `None`.
/// * `Single`: the path nested under the segment. When that path is a brace list (`Multi`), the
///   list element with the smallest identifier is returned instead, or `None` if the list is
///   empty. Elements that are neither `Leaf` nor `Single` use the empty string as their key.
/// * `Multi`: `None`. A brace list has no single continuation; returning `None` rather than
///   panicking keeps callers that pass a multi path (as `compare_use_paths` can) safe.
fn next_use_path(path: UsePath, db: &dyn SyntaxGroup) -> Option<UsePath> {
    match path {
        UsePath::Single(single) => match single.use_path(db) {
            UsePath::Multi(multi) => {
                multi
                    .use_paths(db)
                    .elements(db)
                    .iter()
                    .min_by_key(|element| match element {
                        UsePath::Leaf(leaf) => leaf.extract_ident(db),
                        UsePath::Single(single) => single.extract_ident(db),
                        // Neither Single nor Leaf: use the empty string as the key.
                        _ => "".to_string(),
                    })
                    .cloned()
            }
            // A leaf or single continuation is returned as is.
            next => Some(next),
        },
        UsePath::Leaf(_) | UsePath::Multi(_) => None,
    }
}

/// Represents the kind of sections in the syntax tree that can be sorted.
/// Classify consecutive nodes into sections that are eligible for sorting.
#[derive(PartialEq, Eq)]
Expand Down
5 changes: 5 additions & 0 deletions crates/cairo-lang-formatter/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ impl Upcast<dyn FilesGroup> for DatabaseImpl {
"test_data/expected_results/sorted_mod_use.cairo",
true
)]
#[test_case(
"test_data/cairo_files/sort_inner_use.cairo",
"test_data/expected_results/sort_inner_use.cairo",
true
)]
fn format_and_compare_file(unformatted_filename: &str, expected_filename: &str, use_sorting: bool) {
let db_val = SimpleParserDatabase::default();
let db = &db_val;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use std::collections::HashMap;
use crate::utils::{a, b, d, c};
use b::{d, c, b, a};
use a::{d, b, a, c};

use a::{c, d};
use a::{d, b};
use aba;

use a::{d, b as bee, c as cee, a};
use a::{d, b};
use a::{a as ab, a as bc};
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use a::{a, b, c, d};

use a::{a, b as bee, c as cee, d};
use a::{a as ab, a as bc};
use a::{b, d};
use a::{b, d};

use a::{c, d};
use aba;
use b::{a, b, c, d};
use crate::utils::{a, b, c, d};
use std::collections::HashMap;
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use openzeppelin::introspection::interface;
#[starknet::contract]
mod SRC5 {
//! Header comment, should not be moved by the formatter.
use openzeppelin::introspection::{AB, interface};
/// Doc comment, should be moved by the formatter.
use openzeppelin::introspection::interface;
use openzeppelin::introspection::{interface, AB};

#[storage]
struct Storage {
Expand Down

0 comments on commit a805cdb

Please sign in to comment.