diff options
Diffstat (limited to 'components/script/xpath')
-rw-r--r-- | components/script/xpath/context.rs | 95 | ||||
-rw-r--r-- | components/script/xpath/eval.rs | 589 | ||||
-rw-r--r-- | components/script/xpath/eval_function.rs | 357 | ||||
-rw-r--r-- | components/script/xpath/eval_value.rs | 242 | ||||
-rw-r--r-- | components/script/xpath/mod.rs | 73 | ||||
-rw-r--r-- | components/script/xpath/parser.rs | 1209 |
6 files changed, 2565 insertions, 0 deletions
diff --git a/components/script/xpath/context.rs b/components/script/xpath/context.rs new file mode 100644 index 00000000000..90873a89776 --- /dev/null +++ b/components/script/xpath/context.rs @@ -0,0 +1,95 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ + +use std::iter::Enumerate; +use std::vec::IntoIter; + +use super::Node; +use crate::dom::bindings::root::DomRoot; + +/// The context during evaluation of an XPath expression. +pub struct EvaluationCtx { + /// Where we started at + pub starting_node: DomRoot<Node>, + /// The "current" node in the evaluation + pub context_node: DomRoot<Node>, + /// Details needed for evaluating a predicate list + pub predicate_ctx: Option<PredicateCtx>, + /// The nodes we're currently matching against + pub predicate_nodes: Option<Vec<DomRoot<Node>>>, +} + +#[derive(Clone, Copy)] +pub struct PredicateCtx { + pub index: usize, + pub size: usize, +} + +impl EvaluationCtx { + /// Prepares the context used while evaluating the XPath expression + pub fn new(context_node: &Node) -> EvaluationCtx { + EvaluationCtx { + starting_node: DomRoot::from_ref(context_node), + context_node: DomRoot::from_ref(context_node), + predicate_ctx: None, + predicate_nodes: None, + } + } + + /// Creates a new context using the provided node as the context node + pub fn subcontext_for_node(&self, node: &Node) -> EvaluationCtx { + EvaluationCtx { + starting_node: self.starting_node.clone(), + context_node: DomRoot::from_ref(node), + predicate_ctx: self.predicate_ctx, + predicate_nodes: self.predicate_nodes.clone(), + } + } + + pub fn update_predicate_nodes(&self, nodes: Vec<&Node>) -> EvaluationCtx { + EvaluationCtx { + starting_node: self.starting_node.clone(), + context_node: self.context_node.clone(), + predicate_ctx: None, + predicate_nodes: 
Some(nodes.into_iter().map(DomRoot::from_ref).collect()), + } + } + + pub fn subcontext_iter_for_nodes(&self) -> EvalNodesetIter { + let size = self.predicate_nodes.as_ref().map_or(0, |v| v.len()); + EvalNodesetIter { + ctx: self, + nodes_iter: self + .predicate_nodes + .as_ref() + .map_or_else(|| Vec::new().into_iter(), |v| v.clone().into_iter()) + .enumerate(), + size, + } + } +} + +/// When evaluating predicates, we need to keep track of the current node being evaluated and +/// the index of that node in the nodeset we're operating on. +pub struct EvalNodesetIter<'a> { + ctx: &'a EvaluationCtx, + nodes_iter: Enumerate<IntoIter<DomRoot<Node>>>, + size: usize, +} + +impl<'a> Iterator for EvalNodesetIter<'a> { + type Item = EvaluationCtx; + + fn next(&mut self) -> Option<EvaluationCtx> { + self.nodes_iter.next().map(|(idx, node)| EvaluationCtx { + starting_node: self.ctx.starting_node.clone(), + context_node: node.clone(), + predicate_nodes: self.ctx.predicate_nodes.clone(), + predicate_ctx: Some(PredicateCtx { + index: idx + 1, + size: self.size, + }), + }) + } +} diff --git a/components/script/xpath/eval.rs b/components/script/xpath/eval.rs new file mode 100644 index 00000000000..d880e2f3930 --- /dev/null +++ b/components/script/xpath/eval.rs @@ -0,0 +1,589 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. 
*/ + +use std::fmt; + +use html5ever::{local_name, namespace_prefix, namespace_url, ns, QualName}; + +use super::parser::{ + AdditiveOp, Axis, EqualityOp, Expr, FilterExpr, KindTest, Literal, MultiplicativeOp, NodeTest, + NumericLiteral, PathExpr, PredicateExpr, PredicateListExpr, PrimaryExpr, + QName as ParserQualName, RelationalOp, StepExpr, UnaryOp, +}; +use super::{EvaluationCtx, Value}; +use crate::dom::bindings::codegen::Bindings::NodeBinding::NodeMethods; +use crate::dom::bindings::inheritance::{Castable, CharacterDataTypeId, NodeTypeId}; +use crate::dom::bindings::root::DomRoot; +use crate::dom::bindings::xmlname::validate_and_extract; +use crate::dom::element::Element; +use crate::dom::node::{Node, ShadowIncluding}; +use crate::dom::processinginstruction::ProcessingInstruction; + +#[derive(Clone, Debug, PartialEq)] +pub enum Error { + NotANodeset, + InvalidPath, + UnknownFunction { name: QualName }, + UnknownVariable { name: QualName }, + UnknownNamespace { prefix: String }, + InvalidQName { qname: ParserQualName }, + FunctionEvaluation { fname: String }, + Internal { msg: String }, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::NotANodeset => write!(f, "expression did not evaluate to a nodeset"), + Error::InvalidPath => write!(f, "invalid path expression"), + Error::UnknownFunction { name } => write!(f, "unknown function {:?}", name), + Error::UnknownVariable { name } => write!(f, "unknown variable {:?}", name), + Error::UnknownNamespace { prefix } => { + write!(f, "unknown namespace prefix {:?}", prefix) + }, + Error::InvalidQName { qname } => { + write!(f, "invalid QName {:?}", qname) + }, + Error::FunctionEvaluation { fname } => { + write!(f, "error while evaluating function: {}", fname) + }, + Error::Internal { msg } => { + write!(f, "internal error: {}", msg) + }, + } + } +} + +impl std::error::Error for Error {} + +pub fn try_extract_nodeset(v: Value) -> 
Result<Vec<DomRoot<Node>>, Error> { + match v { + Value::Nodeset(ns) => Ok(ns), + _ => Err(Error::NotANodeset), + } +} + +pub trait Evaluatable: fmt::Debug { + fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error>; + /// Returns true if this expression evaluates to a primitive value, without needing to touch the DOM + fn is_primitive(&self) -> bool; +} + +impl<T: ?Sized> Evaluatable for Box<T> +where + T: Evaluatable, +{ + fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error> { + (**self).evaluate(context) + } + + fn is_primitive(&self) -> bool { + (**self).is_primitive() + } +} + +impl Evaluatable for Expr { + fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error> { + match self { + Expr::And(left, right) => { + let left_bool = left.evaluate(context)?.boolean(); + let v = left_bool && right.evaluate(context)?.boolean(); + Ok(Value::Boolean(v)) + }, + Expr::Or(left, right) => { + let left_bool = left.evaluate(context)?.boolean(); + let v = left_bool || right.evaluate(context)?.boolean(); + Ok(Value::Boolean(v)) + }, + Expr::Equality(left, equality_op, right) => { + let left_val = left.evaluate(context)?; + let right_val = right.evaluate(context)?; + + let v = match equality_op { + EqualityOp::Eq => left_val == right_val, + EqualityOp::NotEq => left_val != right_val, + }; + + Ok(Value::Boolean(v)) + }, + Expr::Relational(left, relational_op, right) => { + let left_val = left.evaluate(context)?.number(); + let right_val = right.evaluate(context)?.number(); + + let v = match relational_op { + RelationalOp::Lt => left_val < right_val, + RelationalOp::Gt => left_val > right_val, + RelationalOp::LtEq => left_val <= right_val, + RelationalOp::GtEq => left_val >= right_val, + }; + Ok(Value::Boolean(v)) + }, + Expr::Additive(left, additive_op, right) => { + let left_val = left.evaluate(context)?.number(); + let right_val = right.evaluate(context)?.number(); + + let v = match additive_op { + AdditiveOp::Add => left_val + right_val, + 
AdditiveOp::Sub => left_val - right_val, + }; + Ok(Value::Number(v)) + }, + Expr::Multiplicative(left, multiplicative_op, right) => { + let left_val = left.evaluate(context)?.number(); + let right_val = right.evaluate(context)?.number(); + + let v = match multiplicative_op { + MultiplicativeOp::Mul => left_val * right_val, + MultiplicativeOp::Div => left_val / right_val, + MultiplicativeOp::Mod => left_val % right_val, + }; + Ok(Value::Number(v)) + }, + Expr::Unary(unary_op, expr) => { + let v = expr.evaluate(context)?.number(); + + match unary_op { + UnaryOp::Minus => Ok(Value::Number(-v)), + } + }, + Expr::Union(left, right) => { + let as_nodes = |e: &Expr| e.evaluate(context).and_then(try_extract_nodeset); + + let mut left_nodes = as_nodes(left)?; + let right_nodes = as_nodes(right)?; + + left_nodes.extend(right_nodes); + Ok(Value::Nodeset(left_nodes)) + }, + Expr::Path(path_expr) => path_expr.evaluate(context), + } + } + + fn is_primitive(&self) -> bool { + match self { + Expr::Or(left, right) => left.is_primitive() && right.is_primitive(), + Expr::And(left, right) => left.is_primitive() && right.is_primitive(), + Expr::Equality(left, _, right) => left.is_primitive() && right.is_primitive(), + Expr::Relational(left, _, right) => left.is_primitive() && right.is_primitive(), + Expr::Additive(left, _, right) => left.is_primitive() && right.is_primitive(), + Expr::Multiplicative(left, _, right) => left.is_primitive() && right.is_primitive(), + Expr::Unary(_, expr) => expr.is_primitive(), + Expr::Union(_, _) => false, + Expr::Path(path_expr) => path_expr.is_primitive(), + } + } +} + +impl Evaluatable for PathExpr { + fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error> { + let mut current_nodes = vec![context.context_node.clone()]; + + // If path starts with '//', add an implicit descendant-or-self::node() step + if self.is_descendant { + current_nodes = current_nodes + .iter() + .flat_map(|n| n.traverse_preorder(ShadowIncluding::No)) + .collect(); + 
} + + trace!("[PathExpr] Evaluating path expr: {:?}", self); + + let have_multiple_steps = self.steps.len() > 1; + + for step in &self.steps { + let mut next_nodes = Vec::new(); + for node in current_nodes { + let step_context = context.subcontext_for_node(&node); + let step_result = step.evaluate(&step_context)?; + match (have_multiple_steps, step_result) { + (_, Value::Nodeset(mut nodes)) => { + // as long as we evaluate to nodesets, keep going + next_nodes.append(&mut nodes); + }, + (false, value) => { + trace!("[PathExpr] Got single primitive value: {:?}", value); + return Ok(value); + }, + (true, value) => { + error!( + "Expected nodeset from step evaluation, got: {:?} node: {:?}, step: {:?}", + value, node, step + ); + return Ok(value); + }, + } + } + current_nodes = next_nodes; + } + + trace!("[PathExpr] Got nodes: {:?}", current_nodes); + + Ok(Value::Nodeset(current_nodes)) + } + + fn is_primitive(&self) -> bool { + !self.is_absolute && + !self.is_descendant && + self.steps.len() == 1 && + self.steps[0].is_primitive() + } +} + +impl TryFrom<&ParserQualName> for QualName { + type Error = Error; + + fn try_from(qname: &ParserQualName) -> Result<Self, Self::Error> { + let qname_as_str = qname.to_string(); + if let Ok((ns, prefix, local)) = validate_and_extract(None, &qname_as_str) { + Ok(QualName { prefix, ns, local }) + } else { + Err(Error::InvalidQName { + qname: qname.clone(), + }) + } + } +} + +pub enum NameTestComparisonMode { + /// Namespaces must match exactly + XHtml, + /// Missing namespace information is treated as the HTML namespace + Html, +} + +pub fn element_name_test( + expected_name: QualName, + element_qualname: QualName, + comparison_mode: NameTestComparisonMode, +) -> bool { + let is_wildcard = expected_name.local == local_name!("*"); + + let test_prefix = expected_name + .prefix + .clone() + .unwrap_or(namespace_prefix!("")); + let test_ns_uri = match test_prefix { + namespace_prefix!("*") => ns!(*), + namespace_prefix!("html") => 
ns!(html), + namespace_prefix!("xml") => ns!(xml), + namespace_prefix!("xlink") => ns!(xlink), + namespace_prefix!("svg") => ns!(svg), + namespace_prefix!("mathml") => ns!(mathml), + namespace_prefix!("") => { + if matches!(comparison_mode, NameTestComparisonMode::XHtml) { + ns!() + } else { + ns!(html) + } + }, + _ => { + // We don't support custom namespaces, use fallback or panic depending on strictness + if matches!(comparison_mode, NameTestComparisonMode::XHtml) { + panic!("Unrecognized namespace prefix: {}", test_prefix) + } else { + ns!(html) + } + }, + }; + + if is_wildcard { + test_ns_uri == element_qualname.ns + } else { + test_ns_uri == element_qualname.ns && expected_name.local == element_qualname.local + } +} + +fn apply_node_test(test: &NodeTest, node: &Node) -> Result<bool, Error> { + let result = match test { + NodeTest::Name(qname) => { + // Convert the unvalidated "parser QualName" into the proper QualName structure + let wanted_name: QualName = qname.try_into()?; + if matches!(node.type_id(), NodeTypeId::Element(_)) { + let element = node.downcast::<Element>().unwrap(); + let comparison_mode = if node.owner_doc().is_xhtml_document() { + NameTestComparisonMode::XHtml + } else { + NameTestComparisonMode::Html + }; + let element_qualname = QualName::new( + element.prefix().as_ref().cloned(), + element.namespace().clone(), + element.local_name().clone(), + ); + element_name_test(wanted_name, element_qualname, comparison_mode) + } else { + false + } + }, + NodeTest::Wildcard => true, + NodeTest::Kind(kind) => match kind { + KindTest::PI(target) => { + if NodeTypeId::CharacterData(CharacterDataTypeId::ProcessingInstruction) == + node.type_id() + { + let pi = node.downcast::<ProcessingInstruction>().unwrap(); + match (target, pi.target()) { + (Some(target_name), node_target_name) + if target_name == &node_target_name.to_string() => + { + true + }, + (Some(_), _) => false, + (None, _) => true, + } + } else { + false + } + }, + KindTest::Comment => 
matches!( + node.type_id(), + NodeTypeId::CharacterData(CharacterDataTypeId::Comment) + ), + KindTest::Text => matches!( + node.type_id(), + NodeTypeId::CharacterData(CharacterDataTypeId::Text(_)) + ), + KindTest::Node => true, + }, + }; + Ok(result) +} + +impl Evaluatable for StepExpr { + fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error> { + match self { + StepExpr::Filter(filter_expr) => filter_expr.evaluate(context), + StepExpr::Axis(axis_step) => { + let nodes: Vec<DomRoot<Node>> = match axis_step.axis { + Axis::Child => context.context_node.children().collect(), + Axis::Descendant => context + .context_node + .traverse_preorder(ShadowIncluding::No) + .skip(1) + .collect(), + Axis::Parent => vec![context.context_node.GetParentNode()] + .into_iter() + .flatten() + .collect(), + Axis::Ancestor => context.context_node.ancestors().collect(), + Axis::Following => context + .context_node + .following_nodes(&context.context_node) + .skip(1) + .collect(), + Axis::Preceding => context + .context_node + .preceding_nodes(&context.context_node) + .skip(1) + .collect(), + Axis::FollowingSibling => context.context_node.following_siblings().collect(), + Axis::PrecedingSibling => context.context_node.preceding_siblings().collect(), + Axis::Attribute => { + if matches!(Node::type_id(&context.context_node), NodeTypeId::Element(_)) { + let element = context.context_node.downcast::<Element>().unwrap(); + element + .attrs() + .iter() + .map(|attr| attr.upcast::<Node>()) + .map(DomRoot::from_ref) + .collect() + } else { + vec![] + } + }, + Axis::Self_ => vec![context.context_node.clone()], + Axis::DescendantOrSelf => context + .context_node + .traverse_preorder(ShadowIncluding::No) + .collect(), + Axis::AncestorOrSelf => context + .context_node + .inclusive_ancestors(ShadowIncluding::No) + .collect(), + Axis::Namespace => Vec::new(), // Namespace axis is not commonly implemented + }; + + trace!("[StepExpr] Axis {:?} got nodes {:?}", axis_step.axis, nodes); + + // 
Filter nodes according to the step's node_test. Will error out if any NodeTest + // application errors out. + let filtered_nodes: Vec<DomRoot<Node>> = nodes + .into_iter() + .map(|node| { + apply_node_test(&axis_step.node_test, &node) + .map(|matches| matches.then_some(node)) + }) + .collect::<Result<Vec<_>, _>>()? + .into_iter() + .flatten() + .collect(); + + trace!("[StepExpr] Filtering got nodes {:?}", filtered_nodes); + + if axis_step.predicates.predicates.is_empty() { + trace!( + "[StepExpr] No predicates, returning nodes {:?}", + filtered_nodes + ); + Ok(Value::Nodeset(filtered_nodes)) + } else { + // Apply predicates + let predicate_list_subcontext = context + .update_predicate_nodes(filtered_nodes.iter().map(|n| &**n).collect()); + axis_step.predicates.evaluate(&predicate_list_subcontext) + } + }, + } + } + + fn is_primitive(&self) -> bool { + match self { + StepExpr::Filter(filter_expr) => filter_expr.is_primitive(), + StepExpr::Axis(_) => false, + } + } +} + +impl Evaluatable for PredicateListExpr { + fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error> { + if let Some(ref predicate_nodes) = context.predicate_nodes { + // Initializing: every node the predicates act on is matched + let mut matched_nodes: Vec<DomRoot<Node>> = predicate_nodes.clone(); + + // apply each predicate to the nodes matched by the previous predicate + for predicate_expr in &self.predicates { + let context_for_predicate = + context.update_predicate_nodes(matched_nodes.iter().map(|n| &**n).collect()); + + let narrowed_nodes = predicate_expr + .evaluate(&context_for_predicate) + .and_then(try_extract_nodeset)?; + matched_nodes = narrowed_nodes; + trace!( + "[PredicateListExpr] Predicate {:?} matched nodes {:?}", + predicate_expr, + matched_nodes + ); + } + Ok(Value::Nodeset(matched_nodes)) + } else { + Err(Error::Internal { + msg: "[PredicateListExpr] No nodes on stack for predicate to operate on" + .to_string(), + }) + } + } + + fn is_primitive(&self) -> bool { + 
self.predicates.len() == 1 && self.predicates[0].is_primitive() + } +} + +impl Evaluatable for PredicateExpr { + /// Evaluates the predicate once per node in `context.predicate_nodes` and returns the + /// nodeset of nodes for which it holds. A numeric result is a position test; any other + /// result is converted with `Value::boolean()`. + fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error> { + let narrowed_nodes: Result<Vec<DomRoot<Node>>, Error> = context + .subcontext_iter_for_nodes() + .filter_map(|ctx| { + if let Some(predicate_ctx) = ctx.predicate_ctx { + let eval_result = self.expr.evaluate(&ctx); + + let v = match eval_result { + // A numeric predicate is true only when the number equals the context + // position exactly. Comparing as f64 (instead of casting the number to + // usize) keeps a fractional value such as `[1.5]` from matching position 1. + Ok(Value::Number(v)) => Ok(predicate_ctx.index as f64 == v), + Ok(v) => Ok(v.boolean()), + Err(e) => Err(e), + }; + + match v { + Ok(true) => Some(Ok(ctx.context_node)), + Ok(false) => None, + Err(e) => Some(Err(e)), + } + } else { + Some(Err(Error::Internal { + msg: "[PredicateExpr] No predicate context set".to_string(), + })) + } + }) + .collect(); + + Ok(Value::Nodeset(narrowed_nodes?)) + } + + fn is_primitive(&self) -> bool { + self.expr.is_primitive() + } +} + +impl Evaluatable for FilterExpr { + /// Evaluates the primary expression and, if predicates are present, filters the resulting + /// nodeset through them. Predicates on a non-nodeset value yield `Error::NotANodeset`. + fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error> { + let primary_result = self.primary.evaluate(context)?; + let have_predicates = !self.predicates.predicates.is_empty(); + + match (have_predicates, &primary_result) { + (false, _) => { + trace!( + "[FilterExpr] No predicates, returning primary result: {:?}", + primary_result + ); + Ok(primary_result) + }, + (true, Value::Nodeset(vec)) => { + let predicate_list_subcontext = + context.update_predicate_nodes(vec.iter().map(|n| &**n).collect()); + let result_filtered_by_predicates = + self.predicates.evaluate(&predicate_list_subcontext); + trace!( + "[FilterExpr] Result filtered by predicates: {:?}", + result_filtered_by_predicates + ); + result_filtered_by_predicates + }, + // You can't use filtering expressions `[]` on other than node-sets + (true, _) => Err(Error::NotANodeset), + } + } + + fn is_primitive(&self) -> bool { + self.predicates.predicates.is_empty() && self.primary.is_primitive() + } +} + +impl Evaluatable for PrimaryExpr { + fn evaluate(&self, context: 
&EvaluationCtx) -> Result<Value, Error> { + match self { + PrimaryExpr::Literal(literal) => literal.evaluate(context), + PrimaryExpr::Variable(_qname) => todo!(), + PrimaryExpr::Parenthesized(expr) => expr.evaluate(context), + PrimaryExpr::ContextItem => Ok(Value::Nodeset(vec![context.context_node.clone()])), + PrimaryExpr::Function(core_function) => core_function.evaluate(context), + } + } + + fn is_primitive(&self) -> bool { + match self { + PrimaryExpr::Literal(_) => true, + PrimaryExpr::Variable(_qname) => false, + PrimaryExpr::Parenthesized(expr) => expr.is_primitive(), + PrimaryExpr::ContextItem => false, + PrimaryExpr::Function(_) => false, + } + } +} + +impl Evaluatable for Literal { + fn evaluate(&self, _context: &EvaluationCtx) -> Result<Value, Error> { + match self { + Literal::Numeric(numeric_literal) => match numeric_literal { + // We currently make no difference between ints and floats + NumericLiteral::Integer(v) => Ok(Value::Number(*v as f64)), + NumericLiteral::Decimal(v) => Ok(Value::Number(*v)), + }, + Literal::String(s) => Ok(Value::String(s.into())), + } + } + + fn is_primitive(&self) -> bool { + true + } +} diff --git a/components/script/xpath/eval_function.rs b/components/script/xpath/eval_function.rs new file mode 100644 index 00000000000..dbf77c503ea --- /dev/null +++ b/components/script/xpath/eval_function.rs @@ -0,0 +1,357 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ + +use super::context::EvaluationCtx; +use super::eval::{try_extract_nodeset, Error, Evaluatable}; +use super::parser::CoreFunction; +use super::Value; +use crate::dom::bindings::codegen::Bindings::NodeBinding::NodeMethods; +use crate::dom::bindings::inheritance::{Castable, NodeTypeId}; +use crate::dom::element::Element; +use crate::dom::node::Node; + +/// Returns e.g. 
"rect" for `<svg:rect>` +fn local_name(node: &Node) -> Option<String> { + if matches!(Node::type_id(node), NodeTypeId::Element(_)) { + let element = node.downcast::<Element>().unwrap(); + Some(element.local_name().to_string()) + } else { + None + } +} + +/// Returns e.g. "svg:rect" for `<svg:rect>`, or `None` if the node is not an element +fn name(node: &Node) -> Option<String> { + if matches!(Node::type_id(node), NodeTypeId::Element(_)) { + let element = node.downcast::<Element>().unwrap(); + if let Some(prefix) = element.prefix().as_ref() { + Some(format!("{}:{}", prefix, element.local_name())) + } else { + Some(element.local_name().to_string()) + } + } else { + None + } +} + +/// Returns e.g. the SVG namespace URI for `<svg:rect>`, or `None` if the node is not an element +fn namespace_uri(node: &Node) -> Option<String> { + if matches!(Node::type_id(node), NodeTypeId::Element(_)) { + let element = node.downcast::<Element>().unwrap(); + Some(element.namespace().to_string()) + } else { + None + } +} + +/// Returns the text contents of the Node, or empty string if none. +fn string_value(node: &Node) -> String { + node.GetTextContent().unwrap_or_default().to_string() +} + +/// If s2 is found inside s1, return everything *before* the first occurrence of s2. Return the empty string otherwise. +fn substring_before(s1: &str, s2: &str) -> String { + match s1.find(s2) { + Some(pos) => s1[..pos].to_string(), + None => String::new(), + } +} + +/// If s2 is found inside s1, return everything *after* the first occurrence of s2. Return the empty string otherwise. 
+fn substring_after(s1: &str, s2: &str) -> String { + match s1.find(s2) { + Some(pos) => s1[pos + s2.len()..].to_string(), + None => String::new(), + } +} + +/// Returns the part of `s` covering `len` characters starting at the zero-based character +/// offset `start_idx`, or everything from `start_idx` to the end when `len` is `None`. +/// +/// The end of the range is computed before `start_idx` is clamped to zero, so a negative +/// start shortens the result instead of shifting it. Offsets count `char`s rather than +/// bytes, and a range that lies outside the string yields an empty string rather than a panic. +fn substring(s: &str, start_idx: isize, len: Option<isize>) -> String { + // Exclusive end offset; `None` means "up to the end of the string". + let end_idx = match len { + Some(len) => start_idx.saturating_add(len.max(0)), + None => isize::MAX, + }; + let first = start_idx.max(0) as usize; + let count = (end_idx.max(0) as usize).saturating_sub(first); + s.chars().skip(first).take(count).collect() +} + +/// Strips leading and trailing whitespace and collapses each inner whitespace run to one space. +/// <https://www.w3.org/TR/1999/REC-xpath-19991116/#function-normalize-space> +pub fn normalize_space(s: &str) -> String { + let mut result = String::with_capacity(s.len()); + let mut last_was_whitespace = true; // Handles leading whitespace + + for c in s.chars() { + match c { + '\x20' | '\x09' | '\x0D' | '\x0A' => { + if !last_was_whitespace { + result.push(' '); + last_was_whitespace = true; + } + }, + other => { + result.push(other); + last_was_whitespace = false; + }, + } + } + + if last_was_whitespace { + result.pop(); + } + + result +} + +impl Evaluatable for CoreFunction { + fn evaluate(&self, context: &EvaluationCtx) -> Result<Value, Error> { + match self { + CoreFunction::Last => { + let predicate_ctx = context.predicate_ctx.ok_or_else(|| Error::Internal { + msg: "[CoreFunction] last() is only usable as a predicate".to_string(), + })?; + Ok(Value::Number(predicate_ctx.size as f64)) + }, + CoreFunction::Position => { + let predicate_ctx = context.predicate_ctx.ok_or_else(|| Error::Internal { + msg: "[CoreFunction] position() is only usable as a predicate".to_string(), + })?; + Ok(Value::Number(predicate_ctx.index as f64)) + }, + CoreFunction::Count(expr) => { + let nodes = expr.evaluate(context).and_then(try_extract_nodeset)?; + Ok(Value::Number(nodes.len() as f64)) + }, + CoreFunction::String(expr_opt) => match expr_opt { + Some(expr) => Ok(Value::String(expr.evaluate(context)?.string())), + None => Ok(Value::String(string_value(&context.context_node))), + }, + CoreFunction::Concat(exprs) => { + let strings: 
Result<Vec<_>, _> = exprs + .iter() + .map(|e| Ok(e.evaluate(context)?.string())) + .collect(); + Ok(Value::String(strings?.join(""))) + }, + CoreFunction::Id(_expr) => todo!(), + CoreFunction::LocalName(expr_opt) => { + let node = match expr_opt { + Some(expr) => expr + .evaluate(context) + .and_then(try_extract_nodeset)? + .first() + .cloned(), + None => Some(context.context_node.clone()), + }; + let name = node.and_then(|n| local_name(&n)).unwrap_or_default(); + Ok(Value::String(name.to_string())) + }, + CoreFunction::NamespaceUri(expr_opt) => { + let node = match expr_opt { + Some(expr) => expr + .evaluate(context) + .and_then(try_extract_nodeset)? + .first() + .cloned(), + None => Some(context.context_node.clone()), + }; + let ns = node.and_then(|n| namespace_uri(&n)).unwrap_or_default(); + Ok(Value::String(ns.to_string())) + }, + CoreFunction::Name(expr_opt) => { + let node = match expr_opt { + Some(expr) => expr + .evaluate(context) + .and_then(try_extract_nodeset)? + .first() + .cloned(), + None => Some(context.context_node.clone()), + }; + let name = node.and_then(|n| name(&n)).unwrap_or_default(); + Ok(Value::String(name)) + }, + CoreFunction::StartsWith(str1, str2) => { + let s1 = str1.evaluate(context)?.string(); + let s2 = str2.evaluate(context)?.string(); + Ok(Value::Boolean(s1.starts_with(&s2))) + }, + CoreFunction::Contains(str1, str2) => { + let s1 = str1.evaluate(context)?.string(); + let s2 = str2.evaluate(context)?.string(); + Ok(Value::Boolean(s1.contains(&s2))) + }, + CoreFunction::SubstringBefore(str1, str2) => { + let s1 = str1.evaluate(context)?.string(); + let s2 = str2.evaluate(context)?.string(); + Ok(Value::String(substring_before(&s1, &s2))) + }, + CoreFunction::SubstringAfter(str1, str2) => { + let s1 = str1.evaluate(context)?.string(); + let s2 = str2.evaluate(context)?.string(); + Ok(Value::String(substring_after(&s1, &s2))) + }, + CoreFunction::Substring(str1, start, length_opt) => { + let s = str1.evaluate(context)?.string(); + 
let start_idx = start.evaluate(context)?.number().round() as isize - 1; + let len = match length_opt { + Some(len_expr) => Some(len_expr.evaluate(context)?.number().round() as isize), + None => None, + }; + Ok(Value::String(substring(&s, start_idx, len))) + }, + CoreFunction::StringLength(expr_opt) => { + let s = match expr_opt { + Some(expr) => expr.evaluate(context)?.string(), + None => string_value(&context.context_node), + }; + Ok(Value::Number(s.chars().count() as f64)) + }, + CoreFunction::NormalizeSpace(expr_opt) => { + let s = match expr_opt { + Some(expr) => expr.evaluate(context)?.string(), + None => string_value(&context.context_node), + }; + + Ok(Value::String(normalize_space(&s))) + }, + CoreFunction::Translate(str1, str2, str3) => { + let s = str1.evaluate(context)?.string(); + let from = str2.evaluate(context)?.string(); + let to = str3.evaluate(context)?.string(); + let result = s + .chars() + .map(|c| match from.find(c) { + Some(i) if i < to.chars().count() => to.chars().nth(i).unwrap(), + _ => c, + }) + .collect(); + Ok(Value::String(result)) + }, + CoreFunction::Number(expr_opt) => { + let val = match expr_opt { + Some(expr) => expr.evaluate(context)?, + None => Value::String(string_value(&context.context_node)), + }; + Ok(Value::Number(val.number())) + }, + CoreFunction::Sum(expr) => { + let nodes = expr.evaluate(context).and_then(try_extract_nodeset)?; + let sum = nodes + .iter() + .map(|n| Value::String(string_value(n)).number()) + .sum(); + Ok(Value::Number(sum)) + }, + CoreFunction::Floor(expr) => { + let num = expr.evaluate(context)?.number(); + Ok(Value::Number(num.floor())) + }, + CoreFunction::Ceiling(expr) => { + let num = expr.evaluate(context)?.number(); + Ok(Value::Number(num.ceil())) + }, + CoreFunction::Round(expr) => { + let num = expr.evaluate(context)?.number(); + Ok(Value::Number(num.round())) + }, + CoreFunction::Boolean(expr) => Ok(Value::Boolean(expr.evaluate(context)?.boolean())), + CoreFunction::Not(expr) => 
Ok(Value::Boolean(!expr.evaluate(context)?.boolean())), + CoreFunction::True => Ok(Value::Boolean(true)), + CoreFunction::False => Ok(Value::Boolean(false)), + CoreFunction::Lang(_) => Ok(Value::Nodeset(vec![])), // Not commonly used in the DOM, short-circuit it + } + } + + fn is_primitive(&self) -> bool { + match self { + CoreFunction::Last => false, + CoreFunction::Position => false, + CoreFunction::Count(_) => false, + CoreFunction::Id(_) => false, + CoreFunction::LocalName(_) => false, + CoreFunction::NamespaceUri(_) => false, + CoreFunction::Name(_) => false, + CoreFunction::String(expr_opt) => expr_opt + .as_ref() + .map(|expr| expr.is_primitive()) + .unwrap_or(false), + CoreFunction::Concat(vec) => vec.iter().all(|expr| expr.is_primitive()), + CoreFunction::StartsWith(expr, substr) => expr.is_primitive() && substr.is_primitive(), + CoreFunction::Contains(expr, substr) => expr.is_primitive() && substr.is_primitive(), + CoreFunction::SubstringBefore(expr, substr) => { + expr.is_primitive() && substr.is_primitive() + }, + CoreFunction::SubstringAfter(expr, substr) => { + expr.is_primitive() && substr.is_primitive() + }, + CoreFunction::Substring(expr, start_pos, length_opt) => { + expr.is_primitive() && + start_pos.is_primitive() && + length_opt + .as_ref() + .map(|length| length.is_primitive()) + .unwrap_or(false) + }, + CoreFunction::StringLength(expr_opt) => expr_opt + .as_ref() + .map(|expr| expr.is_primitive()) + .unwrap_or(false), + CoreFunction::NormalizeSpace(expr_opt) => expr_opt + .as_ref() + .map(|expr| expr.is_primitive()) + .unwrap_or(false), + CoreFunction::Translate(expr, from_chars, to_chars) => { + expr.is_primitive() && from_chars.is_primitive() && to_chars.is_primitive() + }, + CoreFunction::Number(expr_opt) => expr_opt + .as_ref() + .map(|expr| expr.is_primitive()) + .unwrap_or(false), + CoreFunction::Sum(expr) => expr.is_primitive(), + CoreFunction::Floor(expr) => expr.is_primitive(), + CoreFunction::Ceiling(expr) => expr.is_primitive(), + 
CoreFunction::Round(expr) => expr.is_primitive(), + CoreFunction::Boolean(expr) => expr.is_primitive(), + CoreFunction::Not(expr) => expr.is_primitive(), + CoreFunction::True => true, + CoreFunction::False => true, + CoreFunction::Lang(_) => false, + } + } +} +#[cfg(test)] +mod tests { + use super::{substring, substring_after, substring_before}; + + #[test] + fn test_substring_before() { + assert_eq!(substring_before("hello world", "world"), "hello "); + assert_eq!(substring_before("prefix:name", ":"), "prefix"); + assert_eq!(substring_before("no-separator", "xyz"), ""); + assert_eq!(substring_before("", "anything"), ""); + assert_eq!(substring_before("multiple:colons:here", ":"), "multiple"); + assert_eq!(substring_before("start-match-test", "start"), ""); + } + + #[test] + fn test_substring_after() { + assert_eq!(substring_after("hello world", "hello "), "world"); + assert_eq!(substring_after("prefix:name", ":"), "name"); + assert_eq!(substring_after("no-separator", "xyz"), ""); + assert_eq!(substring_after("", "anything"), ""); + assert_eq!(substring_after("multiple:colons:here", ":"), "colons:here"); + assert_eq!(substring_after("test-end-match", "match"), ""); + } + + #[test] + fn test_substring() { + assert_eq!(substring("hello world", 0, Some(5)), "hello"); + assert_eq!(substring("hello world", 6, Some(5)), "world"); + assert_eq!(substring("hello", 1, Some(3)), "ell"); + assert_eq!(substring("hello", -5, Some(2)), "he"); + assert_eq!(substring("hello", 0, None), "hello"); + assert_eq!(substring("hello", 2, Some(10)), "llo"); + assert_eq!(substring("hello", 5, Some(1)), ""); + assert_eq!(substring("", 0, Some(5)), ""); + assert_eq!(substring("hello", 0, Some(0)), ""); + assert_eq!(substring("hello", 0, Some(-5)), ""); + } +} diff --git a/components/script/xpath/eval_value.rs b/components/script/xpath/eval_value.rs new file mode 100644 index 00000000000..de3f0d9d075 --- /dev/null +++ b/components/script/xpath/eval_value.rs @@ -0,0 +1,242 @@ +/* This Source 
/// Converts a string to a number using the XPath 1.0 `number()` rules.
///
/// The string must consist of optional XPath whitespace (space, tab, CR, LF), an optional `-`,
/// and then digits with at most one decimal point (`12`, `12.`, `12.5`, `.5`), followed by
/// optional whitespace. Anything else yields `NaN`; this includes exponents, a leading `+`, and
/// the words `inf`/`infinity`/`nan`, all of which Rust's own float parser would accept.
pub fn str_to_num(s: &str) -> f64 {
    // `str::trim` would also strip Unicode whitespace such as U+00A0, which XPath does not allow.
    let trimmed = s.trim_matches(|c: char| matches!(c, '\x20' | '\x09' | '\x0D' | '\x0A'));
    let unsigned = trimmed.strip_prefix('-').unwrap_or(trimmed);

    let mut seen_digit = false;
    let mut seen_dot = false;
    for c in unsigned.chars() {
        match c {
            '0'..='9' => seen_digit = true,
            '.' if !seen_dot => seen_dot = true,
            _ => return f64::NAN,
        }
    }
    if !seen_digit {
        // Rejects "", "-" and "."
        return f64::NAN;
    }

    // The validated text is a subset of what `f64::from_str` accepts.
    trimmed.parse().unwrap_or(f64::NAN)
}
&Value::Nodeset(ref nodes)) => { + let numbers = num_vals(nodes); + numbers.iter().any(|n| *n == val) + }, + (&Value::Nodeset(ref nodes), &Value::String(ref val)) | + (&Value::String(ref val), &Value::Nodeset(ref nodes)) => { + let strings = str_vals(nodes); + strings.contains(val) + }, + (&Value::Boolean(_), _) | (_, &Value::Boolean(_)) => self.boolean() == other.boolean(), + (&Value::Number(_), _) | (_, &Value::Number(_)) => self.number() == other.number(), + _ => self.string() == other.string(), + } + } +} + +impl Value { + pub fn boolean(&self) -> bool { + match *self { + Value::Boolean(val) => val, + Value::Number(n) => n != 0.0 && !n.is_nan(), + Value::String(ref s) => !s.is_empty(), + Value::Nodeset(ref nodeset) => !nodeset.is_empty(), + } + } + + pub fn into_boolean(self) -> bool { + self.boolean() + } + + pub fn number(&self) -> f64 { + match *self { + Value::Boolean(val) => { + if val { + 1.0 + } else { + 0.0 + } + }, + Value::Number(val) => val, + Value::String(ref s) => str_to_num(s), + Value::Nodeset(..) => str_to_num(&self.string()), + } + } + + pub fn into_number(self) -> f64 { + self.number() + } + + pub fn string(&self) -> string::String { + match *self { + Value::Boolean(v) => v.to_string(), + Value::Number(n) => { + if n.is_infinite() { + if n.signum() < 0.0 { + "-Infinity".to_owned() + } else { + "Infinity".to_owned() + } + } else if n == 0.0 { + // catches -0.0 also + 0.0.to_string() + } else { + n.to_string() + } + }, + Value::String(ref val) => val.clone(), + Value::Nodeset(ref nodes) => match nodes.document_order_first() { + Some(n) => n.GetTextContent().unwrap_or_default().to_string(), + None => "".to_owned(), + }, + } + } + + pub fn into_string(self) -> string::String { + match self { + Value::String(val) => val, + other => other.string(), + } + } +} + +macro_rules! 
from_impl { + ($raw:ty, $variant:expr) => { + impl From<$raw> for Value { + fn from(other: $raw) -> Value { + $variant(other) + } + } + }; +} + +from_impl!(bool, Value::Boolean); +from_impl!(f64, Value::Number); +from_impl!(String, Value::String); +impl<'a> From<&'a str> for Value { + fn from(other: &'a str) -> Value { + Value::String(other.into()) + } +} +from_impl!(Vec<DomRoot<Node>>, Value::Nodeset); + +macro_rules! partial_eq_impl { + ($raw:ty, $variant:pat => $b:expr) => { + impl PartialEq<$raw> for Value { + fn eq(&self, other: &$raw) -> bool { + match *self { + $variant => $b == other, + _ => false, + } + } + } + + impl PartialEq<Value> for $raw { + fn eq(&self, other: &Value) -> bool { + match *other { + $variant => $b == self, + _ => false, + } + } + } + }; +} + +partial_eq_impl!(bool, Value::Boolean(ref v) => v); +partial_eq_impl!(f64, Value::Number(ref v) => v); +partial_eq_impl!(String, Value::String(ref v) => v); +partial_eq_impl!(&str, Value::String(ref v) => v); +partial_eq_impl!(Vec<DomRoot<Node>>, Value::Nodeset(ref v) => v); + +pub trait NodesetHelpers { + /// Returns the node that occurs first in [document order] + /// + /// [document order]: https://www.w3.org/TR/xpath/#dt-document-order + fn document_order_first(&self) -> Option<DomRoot<Node>>; + fn document_order(&self) -> Vec<DomRoot<Node>>; + fn document_order_unique(&self) -> Vec<DomRoot<Node>>; +} + +impl NodesetHelpers for Vec<DomRoot<Node>> { + fn document_order_first(&self) -> Option<DomRoot<Node>> { + self.iter() + .min_by(|a, b| { + if a == b { + std::cmp::Ordering::Equal + } else if a.is_before(b) { + std::cmp::Ordering::Less + } else { + std::cmp::Ordering::Greater + } + }) + .cloned() + } + fn document_order(&self) -> Vec<DomRoot<Node>> { + let mut nodes: Vec<DomRoot<Node>> = self.clone(); + if nodes.len() == 1 { + return nodes; + } + + nodes.sort_by(|a, b| { + if a == b { + std::cmp::Ordering::Equal + } else if a.is_before(b) { + std::cmp::Ordering::Less + } else { + 
std::cmp::Ordering::Greater + } + }); + + nodes + } + fn document_order_unique(&self) -> Vec<DomRoot<Node>> { + let mut nodes: Vec<DomRoot<Node>> = self.document_order(); + + nodes.dedup_by_key(|n| n.as_void_ptr()); + + nodes + } +} diff --git a/components/script/xpath/mod.rs b/components/script/xpath/mod.rs new file mode 100644 index 00000000000..b1bfa3b3089 --- /dev/null +++ b/components/script/xpath/mod.rs @@ -0,0 +1,73 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ + +use context::EvaluationCtx; +use eval::Evaluatable; +pub use eval_value::{NodesetHelpers, Value}; +use parser::OwnedParserError; +pub use parser::{parse as parse_impl, Expr}; + +use super::dom::node::Node; + +mod context; +mod eval; +mod eval_function; +mod eval_value; +mod parser; + +/// The failure modes of executing an XPath. +#[derive(Debug, PartialEq)] +pub enum Error { + /// The XPath was syntactically invalid + Parsing { source: OwnedParserError }, + /// The XPath could not be executed + Evaluating { source: eval::Error }, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Error::Parsing { source } => write!(f, "Unable to parse XPath: {}", source), + Error::Evaluating { source } => write!(f, "Unable to evaluate XPath: {}", source), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Parsing { source } => Some(source), + Error::Evaluating { source } => Some(source), + } + } +} + +/// Parse an XPath expression from a string +pub fn parse(xpath: &str) -> Result<Expr, Error> { + match parse_impl(xpath) { + Ok(expr) => { + debug!("Parsed XPath: {:?}", expr); + Ok(expr) + }, + Err(e) => { + debug!("Unable to parse XPath: {}", e); + Err(Error::Parsing { source: e }) + }, + } 
+} + +/// Evaluate an already-parsed XPath expression +pub fn evaluate_parsed_xpath(expr: &Expr, context_node: &Node) -> Result<Value, Error> { + let context = EvaluationCtx::new(context_node); + match expr.evaluate(&context) { + Ok(v) => { + debug!("Evaluated XPath: {:?}", v); + Ok(v) + }, + Err(e) => { + debug!("Unable to evaluate XPath: {}", e); + Err(Error::Evaluating { source: e }) + }, + } +} diff --git a/components/script/xpath/parser.rs b/components/script/xpath/parser.rs new file mode 100644 index 00000000000..3d439b8a4ca --- /dev/null +++ b/components/script/xpath/parser.rs @@ -0,0 +1,1209 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ + +use nom::branch::alt; +use nom::bytes::complete::{tag, take_while1}; +use nom::character::complete::{alpha1, alphanumeric1, char, digit1, multispace0}; +use nom::combinator::{map, opt, recognize, value}; +use nom::error::{Error as NomError, ErrorKind as NomErrorKind}; +use nom::multi::{many0, separated_list0}; +use nom::sequence::{delimited, pair, preceded, tuple}; +use nom::{Finish, IResult}; + +pub fn parse(input: &str) -> Result<Expr, OwnedParserError> { + let (_, ast) = expr(input).finish().map_err(OwnedParserError::from)?; + Ok(ast) +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum Expr { + Or(Box<Expr>, Box<Expr>), + And(Box<Expr>, Box<Expr>), + Equality(Box<Expr>, EqualityOp, Box<Expr>), + Relational(Box<Expr>, RelationalOp, Box<Expr>), + Additive(Box<Expr>, AdditiveOp, Box<Expr>), + Multiplicative(Box<Expr>, MultiplicativeOp, Box<Expr>), + Unary(UnaryOp, Box<Expr>), + Union(Box<Expr>, Box<Expr>), + Path(PathExpr), +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum EqualityOp { + Eq, + NotEq, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum RelationalOp { + Lt, + Gt, + LtEq, + GtEq, +} + +#[derive(Clone, 
Debug, MallocSizeOf, PartialEq)] +pub enum AdditiveOp { + Add, + Sub, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum MultiplicativeOp { + Mul, + Div, + Mod, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum UnaryOp { + Minus, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub struct PathExpr { + pub is_absolute: bool, + pub is_descendant: bool, + pub steps: Vec<StepExpr>, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub struct PredicateListExpr { + pub predicates: Vec<PredicateExpr>, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub struct PredicateExpr { + pub expr: Expr, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub struct FilterExpr { + pub primary: PrimaryExpr, + pub predicates: PredicateListExpr, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum StepExpr { + Filter(FilterExpr), + Axis(AxisStep), +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub struct AxisStep { + pub axis: Axis, + pub node_test: NodeTest, + pub predicates: PredicateListExpr, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum Axis { + Child, + Descendant, + Attribute, + Self_, + DescendantOrSelf, + FollowingSibling, + Following, + Namespace, + Parent, + Ancestor, + PrecedingSibling, + Preceding, + AncestorOrSelf, +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum NodeTest { + Name(QName), + Wildcard, + Kind(KindTest), +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub struct QName { + pub prefix: Option<String>, + pub local_part: String, +} + +impl std::fmt::Display for QName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.prefix { + Some(prefix) => write!(f, "{}:{}", prefix, self.local_part), + None => write!(f, "{}", self.local_part), + } + } +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum KindTest { + PI(Option<String>), + Comment, + Text, + Node, +} + +#[derive(Clone, Debug, MallocSizeOf, 
PartialEq)] +pub enum PrimaryExpr { + Literal(Literal), + Variable(QName), + Parenthesized(Box<Expr>), + ContextItem, + /// We only support the built-in core functions + Function(CoreFunction), +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum Literal { + Numeric(NumericLiteral), + String(String), +} + +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum NumericLiteral { + Integer(u64), + Decimal(f64), +} + +/// In the DOM we do not support custom functions, so we can enumerate the usable ones +#[derive(Clone, Debug, MallocSizeOf, PartialEq)] +pub enum CoreFunction { + // Node Set Functions + /// last() + Last, + /// position() + Position, + /// count(node-set) + Count(Box<Expr>), + /// id(object) + Id(Box<Expr>), + /// local-name(node-set?) + LocalName(Option<Box<Expr>>), + /// namespace-uri(node-set?) + NamespaceUri(Option<Box<Expr>>), + /// name(node-set?) + Name(Option<Box<Expr>>), + + // String Functions + /// string(object?) + String(Option<Box<Expr>>), + /// concat(string, string, ...) + Concat(Vec<Expr>), + /// starts-with(string, string) + StartsWith(Box<Expr>, Box<Expr>), + /// contains(string, string) + Contains(Box<Expr>, Box<Expr>), + /// substring-before(string, string) + SubstringBefore(Box<Expr>, Box<Expr>), + /// substring-after(string, string) + SubstringAfter(Box<Expr>, Box<Expr>), + /// substring(string, number, number?) + Substring(Box<Expr>, Box<Expr>, Option<Box<Expr>>), + /// string-length(string?) + StringLength(Option<Box<Expr>>), + /// normalize-space(string?) + NormalizeSpace(Option<Box<Expr>>), + /// translate(string, string, string) + Translate(Box<Expr>, Box<Expr>, Box<Expr>), + + // Number Functions + /// number(object?) 
+ Number(Option<Box<Expr>>), + /// sum(node-set) + Sum(Box<Expr>), + /// floor(number) + Floor(Box<Expr>), + /// ceiling(number) + Ceiling(Box<Expr>), + /// round(number) + Round(Box<Expr>), + + // Boolean Functions + /// boolean(object) + Boolean(Box<Expr>), + /// not(boolean) + Not(Box<Expr>), + /// true() + True, + /// false() + False, + /// lang(string) + Lang(Box<Expr>), +} + +impl CoreFunction { + pub fn name(&self) -> &'static str { + match self { + CoreFunction::Last => "last", + CoreFunction::Position => "position", + CoreFunction::Count(_) => "count", + CoreFunction::Id(_) => "id", + CoreFunction::LocalName(_) => "local-name", + CoreFunction::NamespaceUri(_) => "namespace-uri", + CoreFunction::Name(_) => "name", + CoreFunction::String(_) => "string", + CoreFunction::Concat(_) => "concat", + CoreFunction::StartsWith(_, _) => "starts-with", + CoreFunction::Contains(_, _) => "contains", + CoreFunction::SubstringBefore(_, _) => "substring-before", + CoreFunction::SubstringAfter(_, _) => "substring-after", + CoreFunction::Substring(_, _, _) => "substring", + CoreFunction::StringLength(_) => "string-length", + CoreFunction::NormalizeSpace(_) => "normalize-space", + CoreFunction::Translate(_, _, _) => "translate", + CoreFunction::Number(_) => "number", + CoreFunction::Sum(_) => "sum", + CoreFunction::Floor(_) => "floor", + CoreFunction::Ceiling(_) => "ceiling", + CoreFunction::Round(_) => "round", + CoreFunction::Boolean(_) => "boolean", + CoreFunction::Not(_) => "not", + CoreFunction::True => "true", + CoreFunction::False => "false", + CoreFunction::Lang(_) => "lang", + } + } + + pub fn min_args(&self) -> usize { + match self { + // No args + CoreFunction::Last | + CoreFunction::Position | + CoreFunction::True | + CoreFunction::False => 0, + + // Optional single arg + CoreFunction::LocalName(_) | + CoreFunction::NamespaceUri(_) | + CoreFunction::Name(_) | + CoreFunction::String(_) | + CoreFunction::StringLength(_) | + CoreFunction::NormalizeSpace(_) | + 
CoreFunction::Number(_) => 0, + + // Required single arg + CoreFunction::Count(_) | + CoreFunction::Id(_) | + CoreFunction::Sum(_) | + CoreFunction::Floor(_) | + CoreFunction::Ceiling(_) | + CoreFunction::Round(_) | + CoreFunction::Boolean(_) | + CoreFunction::Not(_) | + CoreFunction::Lang(_) => 1, + + // Required two args + CoreFunction::StartsWith(_, _) | + CoreFunction::Contains(_, _) | + CoreFunction::SubstringBefore(_, _) | + CoreFunction::SubstringAfter(_, _) => 2, + + // Special cases + CoreFunction::Concat(_) => 2, // Minimum 2 args + CoreFunction::Substring(_, _, _) => 2, // 2 or 3 args + CoreFunction::Translate(_, _, _) => 3, // Exactly 3 args + } + } + + pub fn max_args(&self) -> Option<usize> { + match self { + // No args + CoreFunction::Last | + CoreFunction::Position | + CoreFunction::True | + CoreFunction::False => Some(0), + + // Optional single arg (0 or 1) + CoreFunction::LocalName(_) | + CoreFunction::NamespaceUri(_) | + CoreFunction::Name(_) | + CoreFunction::String(_) | + CoreFunction::StringLength(_) | + CoreFunction::NormalizeSpace(_) | + CoreFunction::Number(_) => Some(1), + + // Exactly one arg + CoreFunction::Count(_) | + CoreFunction::Id(_) | + CoreFunction::Sum(_) | + CoreFunction::Floor(_) | + CoreFunction::Ceiling(_) | + CoreFunction::Round(_) | + CoreFunction::Boolean(_) | + CoreFunction::Not(_) | + CoreFunction::Lang(_) => Some(1), + + // Exactly two args + CoreFunction::StartsWith(_, _) | + CoreFunction::Contains(_, _) | + CoreFunction::SubstringBefore(_, _) | + CoreFunction::SubstringAfter(_, _) => Some(2), + + // Special cases + CoreFunction::Concat(_) => None, // Unlimited args + CoreFunction::Substring(_, _, _) => Some(3), // 2 or 3 args + CoreFunction::Translate(_, _, _) => Some(3), // Exactly 3 args + } + } + + /// Returns true if the number of arguments is valid for this function + pub fn is_valid_arity(&self, num_args: usize) -> bool { + let min = self.min_args(); + let max = self.max_args(); + + num_args >= min && 
max.map_or(true, |max| num_args <= max) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct OwnedParserError { + input: String, + kind: NomErrorKind, +} + +impl<'a> From<NomError<&'a str>> for OwnedParserError { + fn from(err: NomError<&'a str>) -> Self { + OwnedParserError { + input: err.input.to_string(), + kind: err.code, + } + } +} + +impl std::fmt::Display for OwnedParserError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "error {:?} at: {}", self.kind, self.input) + } +} + +impl std::error::Error for OwnedParserError {} + +/// Top-level parser +fn expr(input: &str) -> IResult<&str, Expr> { + expr_single(input) +} + +fn expr_single(input: &str) -> IResult<&str, Expr> { + or_expr(input) +} + +fn or_expr(input: &str) -> IResult<&str, Expr> { + let (input, first) = and_expr(input)?; + let (input, rest) = many0(preceded(ws(tag("or")), and_expr))(input)?; + + Ok(( + input, + rest.into_iter() + .fold(first, |acc, expr| Expr::Or(Box::new(acc), Box::new(expr))), + )) +} + +fn and_expr(input: &str) -> IResult<&str, Expr> { + let (input, first) = equality_expr(input)?; + let (input, rest) = many0(preceded(ws(tag("and")), equality_expr))(input)?; + + Ok(( + input, + rest.into_iter() + .fold(first, |acc, expr| Expr::And(Box::new(acc), Box::new(expr))), + )) +} + +fn equality_expr(input: &str) -> IResult<&str, Expr> { + let (input, first) = relational_expr(input)?; + let (input, rest) = many0(tuple(( + ws(alt(( + map(tag("="), |_| EqualityOp::Eq), + map(tag("!="), |_| EqualityOp::NotEq), + ))), + relational_expr, + )))(input)?; + + Ok(( + input, + rest.into_iter().fold(first, |acc, (op, expr)| { + Expr::Equality(Box::new(acc), op, Box::new(expr)) + }), + )) +} + +fn relational_expr(input: &str) -> IResult<&str, Expr> { + let (input, first) = additive_expr(input)?; + let (input, rest) = many0(tuple(( + ws(alt(( + map(tag("<="), |_| RelationalOp::LtEq), + map(tag(">="), |_| RelationalOp::GtEq), + map(tag("<"), |_| RelationalOp::Lt), + 
map(tag(">"), |_| RelationalOp::Gt), + ))), + additive_expr, + )))(input)?; + + Ok(( + input, + rest.into_iter().fold(first, |acc, (op, expr)| { + Expr::Relational(Box::new(acc), op, Box::new(expr)) + }), + )) +} + +fn additive_expr(input: &str) -> IResult<&str, Expr> { + let (input, first) = multiplicative_expr(input)?; + let (input, rest) = many0(tuple(( + ws(alt(( + map(tag("+"), |_| AdditiveOp::Add), + map(tag("-"), |_| AdditiveOp::Sub), + ))), + multiplicative_expr, + )))(input)?; + + Ok(( + input, + rest.into_iter().fold(first, |acc, (op, expr)| { + Expr::Additive(Box::new(acc), op, Box::new(expr)) + }), + )) +} + +fn multiplicative_expr(input: &str) -> IResult<&str, Expr> { + let (input, first) = unary_expr(input)?; + let (input, rest) = many0(tuple(( + ws(alt(( + map(tag("*"), |_| MultiplicativeOp::Mul), + map(tag("div"), |_| MultiplicativeOp::Div), + map(tag("mod"), |_| MultiplicativeOp::Mod), + ))), + unary_expr, + )))(input)?; + + Ok(( + input, + rest.into_iter().fold(first, |acc, (op, expr)| { + Expr::Multiplicative(Box::new(acc), op, Box::new(expr)) + }), + )) +} + +fn unary_expr(input: &str) -> IResult<&str, Expr> { + let (input, minus_count) = many0(ws(char('-')))(input)?; + let (input, expr) = union_expr(input)?; + + Ok(( + input, + (0..minus_count.len()).fold(expr, |acc, _| Expr::Unary(UnaryOp::Minus, Box::new(acc))), + )) +} + +fn union_expr(input: &str) -> IResult<&str, Expr> { + let (input, first) = path_expr(input)?; + let (input, rest) = many0(preceded(ws(char('|')), path_expr))(input)?; + + Ok(( + input, + rest.into_iter().fold(first, |acc, expr| { + Expr::Union(Box::new(acc), Box::new(expr)) + }), + )) +} + +fn path_expr(input: &str) -> IResult<&str, Expr> { + alt(( + // "//" RelativePathExpr + map(pair(tag("//"), relative_path_expr), |(_, rel_path)| { + Expr::Path(PathExpr { + is_absolute: true, + is_descendant: true, + steps: match rel_path { + Expr::Path(p) => p.steps, + _ => unreachable!(), + }, + }) + }), + // "/" RelativePathExpr? 
+ map(pair(char('/'), opt(relative_path_expr)), |(_, rel_path)| { + Expr::Path(PathExpr { + is_absolute: true, + is_descendant: false, + steps: rel_path + .map(|p| match p { + Expr::Path(p) => p.steps, + _ => unreachable!(), + }) + .unwrap_or_default(), + }) + }), + // RelativePathExpr + relative_path_expr, + ))(input) +} + +fn relative_path_expr(input: &str) -> IResult<&str, Expr> { + let (input, first) = step_expr(input)?; + let (input, steps) = many0(pair( + // ("/" | "//") + ws(alt((value(false, char('/')), value(true, tag("//"))))), + step_expr, + ))(input)?; + + let mut all_steps = vec![first]; + for (is_descendant, step) in steps { + if is_descendant { + // Insert an implicit descendant-or-self::node() step + all_steps.push(StepExpr::Axis(AxisStep { + axis: Axis::DescendantOrSelf, + node_test: NodeTest::Kind(KindTest::Node), + predicates: PredicateListExpr { predicates: vec![] }, + })); + } + all_steps.push(step); + } + + Ok(( + input, + Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: all_steps, + }), + )) +} + +fn step_expr(input: &str) -> IResult<&str, StepExpr> { + alt(( + map(filter_expr, StepExpr::Filter), + map(axis_step, StepExpr::Axis), + ))(input) +} + +fn axis_step(input: &str) -> IResult<&str, AxisStep> { + let (input, (step, predicates)) = + pair(alt((forward_step, reverse_step)), predicate_list)(input)?; + + let (axis, node_test) = step; + Ok(( + input, + AxisStep { + axis, + node_test, + predicates, + }, + )) +} + +fn forward_step(input: &str) -> IResult<&str, (Axis, NodeTest)> { + alt(( + // ForwardAxis NodeTest + pair(forward_axis, node_test), + // AbbrevForwardStep + abbrev_forward_step, + ))(input) +} + +fn forward_axis(input: &str) -> IResult<&str, Axis> { + let (input, axis) = alt(( + value(Axis::Child, tag("child::")), + value(Axis::Descendant, tag("descendant::")), + value(Axis::Attribute, tag("attribute::")), + value(Axis::Self_, tag("self::")), + value(Axis::DescendantOrSelf, tag("descendant-or-self::")), 
+ value(Axis::FollowingSibling, tag("following-sibling::")), + value(Axis::Following, tag("following::")), + value(Axis::Namespace, tag("namespace::")), + ))(input)?; + + Ok((input, axis)) +} + +fn abbrev_forward_step(input: &str) -> IResult<&str, (Axis, NodeTest)> { + let (input, attr) = opt(char('@'))(input)?; + let (input, test) = node_test(input)?; + + Ok(( + input, + ( + if attr.is_some() { + Axis::Attribute + } else { + Axis::Child + }, + test, + ), + )) +} + +fn reverse_step(input: &str) -> IResult<&str, (Axis, NodeTest)> { + alt(( + // ReverseAxis NodeTest + pair(reverse_axis, node_test), + // AbbrevReverseStep + abbrev_reverse_step, + ))(input) +} + +fn reverse_axis(input: &str) -> IResult<&str, Axis> { + alt(( + value(Axis::Parent, tag("parent::")), + value(Axis::Ancestor, tag("ancestor::")), + value(Axis::PrecedingSibling, tag("preceding-sibling::")), + value(Axis::Preceding, tag("preceding::")), + value(Axis::AncestorOrSelf, tag("ancestor-or-self::")), + ))(input) +} + +fn abbrev_reverse_step(input: &str) -> IResult<&str, (Axis, NodeTest)> { + map(tag(".."), |_| { + (Axis::Parent, NodeTest::Kind(KindTest::Node)) + })(input) +} + +fn node_test(input: &str) -> IResult<&str, NodeTest> { + alt(( + map(kind_test, NodeTest::Kind), + map(name_test, |name| match name { + NameTest::Wildcard => NodeTest::Wildcard, + NameTest::QName(qname) => NodeTest::Name(qname), + }), + ))(input) +} + +#[derive(Clone, Debug, PartialEq)] +enum NameTest { + QName(QName), + Wildcard, +} + +fn name_test(input: &str) -> IResult<&str, NameTest> { + alt(( + // NCName ":" "*" + map(tuple((ncname, char(':'), char('*'))), |(prefix, _, _)| { + NameTest::QName(QName { + prefix: Some(prefix.to_string()), + local_part: "*".to_string(), + }) + }), + // "*" + value(NameTest::Wildcard, char('*')), + // QName + map(qname, NameTest::QName), + ))(input) +} + +fn filter_expr(input: &str) -> IResult<&str, FilterExpr> { + let (input, primary) = primary_expr(input)?; + let (input, predicates) = 
predicate_list(input)?; + + Ok(( + input, + FilterExpr { + primary, + predicates, + }, + )) +} + +fn predicate_list(input: &str) -> IResult<&str, PredicateListExpr> { + let (input, predicates) = many0(predicate)(input)?; + Ok((input, PredicateListExpr { predicates })) +} + +fn predicate(input: &str) -> IResult<&str, PredicateExpr> { + let (input, expr) = delimited(ws(char('[')), expr, ws(char(']')))(input)?; + Ok((input, PredicateExpr { expr })) +} + +fn primary_expr(input: &str) -> IResult<&str, PrimaryExpr> { + alt(( + literal, + var_ref, + map(parenthesized_expr, |e| { + PrimaryExpr::Parenthesized(Box::new(e)) + }), + context_item_expr, + function_call, + ))(input) +} + +fn literal(input: &str) -> IResult<&str, PrimaryExpr> { + map(alt((numeric_literal, string_literal)), |lit| { + PrimaryExpr::Literal(lit) + })(input) +} + +fn numeric_literal(input: &str) -> IResult<&str, Literal> { + alt((decimal_literal, integer_literal))(input) +} + +fn var_ref(input: &str) -> IResult<&str, PrimaryExpr> { + let (input, _) = char('$')(input)?; + let (input, name) = qname(input)?; + Ok((input, PrimaryExpr::Variable(name))) +} + +fn parenthesized_expr(input: &str) -> IResult<&str, Expr> { + delimited(ws(char('(')), expr, ws(char(')')))(input) +} + +fn context_item_expr(input: &str) -> IResult<&str, PrimaryExpr> { + map(char('.'), |_| PrimaryExpr::ContextItem)(input) +} + +fn function_call(input: &str) -> IResult<&str, PrimaryExpr> { + let (input, name) = qname(input)?; + let (input, args) = delimited( + ws(char('(')), + separated_list0(ws(char(',')), expr_single), + ws(char(')')), + )(input)?; + + // Helper to create error + let arity_error = || nom::Err::Error(NomError::new(input, NomErrorKind::Verify)); + + let core_fn = match name.local_part.as_str() { + // Node Set Functions + "last" => CoreFunction::Last, + "position" => CoreFunction::Position, + "count" => CoreFunction::Count(Box::new(args.into_iter().next().ok_or_else(arity_error)?)), + "id" => 
CoreFunction::Id(Box::new(args.into_iter().next().ok_or_else(arity_error)?)), + "local-name" => CoreFunction::LocalName(args.into_iter().next().map(Box::new)), + "namespace-uri" => CoreFunction::NamespaceUri(args.into_iter().next().map(Box::new)), + "name" => CoreFunction::Name(args.into_iter().next().map(Box::new)), + + // String Functions + "string" => CoreFunction::String(args.into_iter().next().map(Box::new)), + "concat" => CoreFunction::Concat(args.into_iter().collect()), + "starts-with" => { + let mut args = args.into_iter(); + CoreFunction::StartsWith( + Box::new(args.next().ok_or_else(arity_error)?), + Box::new(args.next().ok_or_else(arity_error)?), + ) + }, + "contains" => { + let mut args = args.into_iter(); + CoreFunction::Contains( + Box::new(args.next().ok_or_else(arity_error)?), + Box::new(args.next().ok_or_else(arity_error)?), + ) + }, + "substring-before" => { + let mut args = args.into_iter(); + CoreFunction::SubstringBefore( + Box::new(args.next().ok_or_else(arity_error)?), + Box::new(args.next().ok_or_else(arity_error)?), + ) + }, + "substring-after" => { + let mut args = args.into_iter(); + CoreFunction::SubstringAfter( + Box::new(args.next().ok_or_else(arity_error)?), + Box::new(args.next().ok_or_else(arity_error)?), + ) + }, + "substring" => { + let mut args = args.into_iter(); + CoreFunction::Substring( + Box::new(args.next().ok_or_else(arity_error)?), + Box::new(args.next().ok_or_else(arity_error)?), + args.next().map(Box::new), + ) + }, + "string-length" => CoreFunction::StringLength(args.into_iter().next().map(Box::new)), + "normalize-space" => CoreFunction::NormalizeSpace(args.into_iter().next().map(Box::new)), + "translate" => { + let mut args = args.into_iter(); + CoreFunction::Translate( + Box::new(args.next().ok_or_else(arity_error)?), + Box::new(args.next().ok_or_else(arity_error)?), + Box::new(args.next().ok_or_else(arity_error)?), + ) + }, + + // Number Functions + "number" => 
CoreFunction::Number(args.into_iter().next().map(Box::new)), + "sum" => CoreFunction::Sum(Box::new(args.into_iter().next().ok_or_else(arity_error)?)), + "floor" => CoreFunction::Floor(Box::new(args.into_iter().next().ok_or_else(arity_error)?)), + "ceiling" => { + CoreFunction::Ceiling(Box::new(args.into_iter().next().ok_or_else(arity_error)?)) + }, + "round" => CoreFunction::Round(Box::new(args.into_iter().next().ok_or_else(arity_error)?)), + + // Boolean Functions + "boolean" => { + CoreFunction::Boolean(Box::new(args.into_iter().next().ok_or_else(arity_error)?)) + }, + "not" => CoreFunction::Not(Box::new(args.into_iter().next().ok_or_else(arity_error)?)), + "true" => CoreFunction::True, + "false" => CoreFunction::False, + "lang" => CoreFunction::Lang(Box::new(args.into_iter().next().ok_or_else(arity_error)?)), + + // Unknown function + _ => return Err(nom::Err::Error(NomError::new(input, NomErrorKind::Verify))), + }; + + Ok((input, PrimaryExpr::Function(core_fn))) +} + +fn kind_test(input: &str) -> IResult<&str, KindTest> { + alt((pi_test, comment_test, text_test, any_kind_test))(input) +} + +fn any_kind_test(input: &str) -> IResult<&str, KindTest> { + map(tuple((tag("node"), ws(char('(')), ws(char(')')))), |_| { + KindTest::Node + })(input) +} + +fn text_test(input: &str) -> IResult<&str, KindTest> { + map(tuple((tag("text"), ws(char('(')), ws(char(')')))), |_| { + KindTest::Text + })(input) +} + +fn comment_test(input: &str) -> IResult<&str, KindTest> { + map( + tuple((tag("comment"), ws(char('(')), ws(char(')')))), + |_| KindTest::Comment, + )(input) +} + +fn pi_test(input: &str) -> IResult<&str, KindTest> { + map( + tuple(( + tag("processing-instruction"), + ws(char('(')), + opt(ws(string_literal)), + ws(char(')')), + )), + |(_, _, literal, _)| { + KindTest::PI(literal.map(|l| match l { + Literal::String(s) => s, + _ => unreachable!(), + })) + }, + )(input) +} + +fn ws<'a, F, O>(inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O> +where + F: FnMut(&'a 
str) -> IResult<&'a str, O>, +{ + delimited(multispace0, inner, multispace0) +} + +fn integer_literal(input: &str) -> IResult<&str, Literal> { + map(recognize(tuple((opt(char('-')), digit1))), |s: &str| { + Literal::Numeric(NumericLiteral::Integer(s.parse().unwrap())) + })(input) +} + +fn decimal_literal(input: &str) -> IResult<&str, Literal> { + map( + recognize(tuple((opt(char('-')), opt(digit1), char('.'), digit1))), + |s: &str| Literal::Numeric(NumericLiteral::Decimal(s.parse().unwrap())), + )(input) +} + +fn string_literal(input: &str) -> IResult<&str, Literal> { + alt(( + delimited( + char('"'), + map(take_while1(|c| c != '"'), |s: &str| { + Literal::String(s.to_string()) + }), + char('"'), + ), + delimited( + char('\''), + map(take_while1(|c| c != '\''), |s: &str| { + Literal::String(s.to_string()) + }), + char('\''), + ), + ))(input) +} + +// QName parser +fn qname(input: &str) -> IResult<&str, QName> { + let (input, prefix) = opt(tuple((ncname, char(':'))))(input)?; + let (input, local) = ncname(input)?; + + Ok(( + input, + QName { + prefix: prefix.map(|(p, _)| p.to_string()), + local_part: local.to_string(), + }, + )) +} + +// NCName parser +fn ncname(input: &str) -> IResult<&str, &str> { + recognize(pair( + alpha1, + many0(alt((alphanumeric1, tag("-"), tag("_")))), + ))(input) +} + +// Test functions to verify the parsers: +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_node_tests() { + let cases = vec![ + ("node()", NodeTest::Kind(KindTest::Node)), + ("text()", NodeTest::Kind(KindTest::Text)), + ("comment()", NodeTest::Kind(KindTest::Comment)), + ( + "processing-instruction()", + NodeTest::Kind(KindTest::PI(None)), + ), + ( + "processing-instruction('test')", + NodeTest::Kind(KindTest::PI(Some("test".to_string()))), + ), + ("*", NodeTest::Wildcard), + ( + "prefix:*", + NodeTest::Name(QName { + prefix: Some("prefix".to_string()), + local_part: "*".to_string(), + }), + ), + ( + "div", + NodeTest::Name(QName { + prefix: None, + local_part: 
"div".to_string(), + }), + ), + ( + "ns:div", + NodeTest::Name(QName { + prefix: Some("ns".to_string()), + local_part: "div".to_string(), + }), + ), + ]; + + for (input, expected) in cases { + match node_test(input) { + Ok((remaining, result)) => { + assert!(remaining.is_empty(), "Parser didn't consume all input"); + assert_eq!(result, expected); + }, + Err(e) => panic!("Failed to parse '{}': {:?}", input, e), + } + } + } + + #[test] + fn test_filter_expr() { + let cases = vec![ + ( + "processing-instruction('test')[2]", + Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Axis(AxisStep { + axis: Axis::Child, + node_test: NodeTest::Kind(KindTest::PI(Some("test".to_string()))), + predicates: PredicateListExpr { + predicates: vec![PredicateExpr { + expr: Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Literal(Literal::Numeric( + NumericLiteral::Integer(2), + )), + predicates: PredicateListExpr { predicates: vec![] }, + })], + }), + }], + }, + })], + }), + ), + ( + "concat('hello', ' ', 'world')", + Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Function(CoreFunction::Concat(vec![ + Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Literal(Literal::String( + "hello".to_string(), + )), + predicates: PredicateListExpr { predicates: vec![] }, + })], + }), + Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Literal(Literal::String(" ".to_string())), + predicates: PredicateListExpr { predicates: vec![] }, + })], + }), + Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Literal(Literal::String( + 
"world".to_string(), + )), + predicates: PredicateListExpr { predicates: vec![] }, + })], + }), + ])), + predicates: PredicateListExpr { predicates: vec![] }, + })], + }), + ), + ]; + + for (input, expected) in cases { + match parse(input) { + Ok(result) => { + assert_eq!(result, expected); + }, + Err(e) => panic!("Failed to parse '{}': {:?}", input, e), + } + } + } + + #[test] + fn test_complex_paths() { + let cases = vec![ + ( + "//*[contains(@class, 'test')]", + Expr::Path(PathExpr { + is_absolute: true, + is_descendant: true, + steps: vec![StepExpr::Axis(AxisStep { + axis: Axis::Child, + node_test: NodeTest::Wildcard, + predicates: PredicateListExpr { + predicates: vec![PredicateExpr { + expr: Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Function(CoreFunction::Contains( + Box::new(Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Axis(AxisStep { + axis: Axis::Attribute, + node_test: NodeTest::Name(QName { + prefix: None, + local_part: "class".to_string(), + }), + predicates: PredicateListExpr { + predicates: vec![], + }, + })], + })), + Box::new(Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Literal(Literal::String( + "test".to_string(), + )), + predicates: PredicateListExpr { + predicates: vec![], + }, + })], + })), + )), + predicates: PredicateListExpr { predicates: vec![] }, + })], + }), + }], + }, + })], + }), + ), + ( + "//div[position() > 1]/*[last()]", + Expr::Path(PathExpr { + is_absolute: true, + is_descendant: true, + steps: vec![ + StepExpr::Axis(AxisStep { + axis: Axis::Child, + node_test: NodeTest::Name(QName { + prefix: None, + local_part: "div".to_string(), + }), + predicates: PredicateListExpr { + predicates: vec![PredicateExpr { + expr: Expr::Relational( + Box::new(Expr::Path(PathExpr { + is_absolute: false, + is_descendant: 
false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Function( + CoreFunction::Position, + ), + predicates: PredicateListExpr { + predicates: vec![], + }, + })], + })), + RelationalOp::Gt, + Box::new(Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Literal(Literal::Numeric( + NumericLiteral::Integer(1), + )), + predicates: PredicateListExpr { + predicates: vec![], + }, + })], + })), + ), + }], + }, + }), + StepExpr::Axis(AxisStep { + axis: Axis::Child, + node_test: NodeTest::Wildcard, + predicates: PredicateListExpr { + predicates: vec![PredicateExpr { + expr: Expr::Path(PathExpr { + is_absolute: false, + is_descendant: false, + steps: vec![StepExpr::Filter(FilterExpr { + primary: PrimaryExpr::Function(CoreFunction::Last), + predicates: PredicateListExpr { predicates: vec![] }, + })], + }), + }], + }, + }), + ], + }), + ), + ]; + + for (input, expected) in cases { + match parse(input) { + Ok(result) => { + assert_eq!(result, expected); + }, + Err(e) => panic!("Failed to parse '{}': {:?}", input, e), + } + } + } +} |