diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 760dc3570f..1ac5fbcab8 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -23,7 +23,7 @@ use crate::{
     spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
     spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
     spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateDiff, SparkDateTrunc,
-    SparkSizeFunc, SparkStringSpace,
+    SparkSizeFunc, SparkStringReplace, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -194,6 +194,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkStringReplace::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
     ]
diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs
index aac8204e29..f518df7d86 100644
--- a/native/spark-expr/src/string_funcs/mod.rs
+++ b/native/spark-expr/src/string_funcs/mod.rs
@@ -15,8 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod string_replace;
 mod string_space;
 mod substring;
 
+pub use string_replace::SparkStringReplace;
 pub use string_space::SparkStringSpace;
 pub use substring::SubstringExpr;
diff --git a/native/spark-expr/src/string_funcs/string_replace.rs b/native/spark-expr/src/string_funcs/string_replace.rs
new file mode 100644
index 0000000000..597ebd79ad
--- /dev/null
+++ b/native/spark-expr/src/string_funcs/string_replace.rs
@@ -0,0 +1,148 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
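+
+//! Spark-compatible `replace` string function.
+//!
+//! Spark's `UTF8String.replace` returns its input unchanged when the search
+//! string is empty, while Rust's `str::replace` (used by DataFusion's
+//! built-in `replace`) inserts the replacement at every character boundary.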
+
+use arrow::array::{AsArray, StringArray};
+use arrow::datatypes::DataType;
+use datafusion::common::{utils::take_function_args, Result};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::{any::Any, sync::Arc};
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkStringReplace {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for SparkStringReplace {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkStringReplace {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![],
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkStringReplace {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "spark_replace"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args: [ColumnarValue; 3] = take_function_args(self.name(), args.args)?;
+        spark_string_replace(&args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+fn spark_string_replace(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let arrays = ColumnarValue::values_to_arrays(args)?;
+    let src = arrays[0].as_string::<i32>();
+    let search = arrays[1].as_string::<i32>();
+    let replace = arrays[2].as_string::<i32>();
+
+    let result: StringArray = src
+        .iter()
+        .zip(search.iter())
+        .zip(replace.iter())
+        .map(|((s, search), replace)| match (s, search, replace) {
+            (Some(s), Some(search), Some(replace)) => {
+                Some(spark_replace_string(s, search, replace))
+            }
+            _ => None,
+        })
+        .collect();
+
+    Ok(ColumnarValue::Array(Arc::new(result)))
+}
+
+fn spark_replace_string(src: &str, search: &str, replace: &str) -> String {
+    if search.is_empty() {
+        // Spark returns the input unchanged when the search string is empty,
+        // unlike Rust's str::replace, which would insert `replace` at every boundary.
+        src.to_string()
+    } else {
+        src.replace(search, replace)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::Array;
+    use datafusion::common::ScalarValue;
+
+    #[test]
+    fn test_empty_search_string() {
+        let src = ColumnarValue::Array(Arc::new(StringArray::from(vec![
+            Some("hello"),
+            Some("world"),
+            None,
+        ])));
+        let search = ColumnarValue::Scalar(ScalarValue::Utf8(Some("".to_string())));
+        let replace = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x".to_string())));
+
+        match spark_string_replace(&[src, search, replace]) {
+            Ok(ColumnarValue::Array(result)) => {
+                let string_arr = result.as_string::<i32>();
+                // Matches Spark: an empty search string is a no-op.
+                assert_eq!(string_arr.value(0), "hello");
+                assert_eq!(string_arr.value(1), "world");
+                assert!(string_arr.is_null(2));
+            }
+            _ => unreachable!(),
+        }
+    }
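+
+    // Illustrative extra test (not part of the original patch): a non-empty
+    // search string behaves like a plain substring replacement.
+    #[test]
+    fn test_non_empty_search_string() {
+        assert_eq!(spark_replace_string("aabbaa", "aa", "z"), "zbbz");
+        assert_eq!(spark_replace_string("hello", "l", "LL"), "heLLLLo");
+        assert_eq!(spark_replace_string("hello", "xyz", "q"), "hello");
+        assert_eq!(spark_replace_string("", "", "x"), "");
+    }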
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e25d7fb4eb..a4b216a8b5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -162,7 +162,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[StartsWith] -> CometScalarFunction("starts_with"),
     classOf[StringInstr] -> CometScalarFunction("instr"),
     classOf[StringRepeat] -> CometStringRepeat,
-    classOf[StringReplace] -> CometScalarFunction("replace"),
+    classOf[StringReplace] -> CometScalarFunction("spark_replace"),
     classOf[StringRPad] -> CometStringRPad,
     classOf[StringLPad] -> CometStringLPad,
     classOf[StringSpace] -> CometScalarFunction("string_space"),
diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
index 2a2932c643..188b779efd 100644
--- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -280,6 +280,16 @@ class CometStringExpressionSuite extends CometTestBase {
     }
   }
 
+  test("replace with empty search string") {
+    val table = "test"
+    withTable(table) {
+      sql(s"create table $table(col string) using parquet")
+      sql(s"insert into $table values('hello'), (NULL), ('')")
+      checkSparkAnswerAndOperator(
+        s"select replace(col, '', 'x'), replace(col, '', '') from $table")
+    }
+  }
+
   // Simplified version of "filter pushdown - StringPredicate" that does not generate dictionaries
   test("string predicate filter") {
     Seq(false, true).foreach { pushdown =>