105 changes: 104 additions & 1 deletion native/spark-expr/src/conversion_funcs/cast.rs
@@ -21,7 +21,7 @@ use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{
BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
PrimitiveBuilder, StringArray, StructArray,
PrimitiveBuilder, StringArray, StructArray, TimestampMicrosecondBuilder,
};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
@@ -1100,6 +1100,7 @@ fn cast_array(
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
}
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
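// Spark interprets a date as local midnight in the session timezone, so this arm
// dispatches to the dedicated kernel below rather than Arrow's generic cast, which
// rejected this conversion (see the previously ignored Spark-side test).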
(Date32, Timestamp(_, tz)) => Ok(cast_date_to_timestamp(&array, cast_options, tz)?),
_ if cast_options.is_adapting_schema
|| is_datafusion_spark_compatible(from_type, to_type) =>
{
@@ -1118,6 +1119,50 @@
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}

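/// Casts a Date32 array to microsecond timestamps using Spark semantics: each date
/// maps to midnight in the session timezone, which is then converted to UTC.
/// Worked example (matching the tests below): Date32 value 19793 (2024-03-11) with
/// session timezone America/Los_Angeles resolves to 2024-03-11T00:00:00-07:00,
/// i.e. 1_710_140_400_000_000 microseconds since the epoch.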
fn cast_date_to_timestamp(
array_ref: &ArrayRef,
cast_options: &SparkCastOptions,
target_tz: &Option<Arc<str>>,
) -> SparkResult<ArrayRef> {
let tz_str = if cast_options.timezone.is_empty() {
"UTC"
} else {
cast_options.timezone.as_str()
};
// we fall back to UTC above, so parsing only fails for a malformed session
// timezone, and that error is propagated
let tz = timezone::Tz::from_str(tz_str)?;
let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
let date_array = array_ref.as_primitive::<Date32Type>();

let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len());

for date in date_array.iter() {
match date {
Some(date) => {
// and_hms_opt(0, 0, 0) below is always Some, so its unwrap is safe. Note that
// chrono's NaiveDate range (about ±262,143 years) is narrower than the full
// i32 day range (~5.8M years) but covers Spark's supported dates.
// Convert midnight of this date in the session timezone to a UTC timestamp.
let naive_date = epoch + chrono::Duration::days(date as i64);
let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap();
let local_midnight_in_microsec = tz
.from_local_datetime(&local_midnight)
// take the earliest valid instant (handles local midnights made ambiguous or skipped by DST transitions)
.earliest()
.map(|dt| dt.timestamp_micros())
// if no valid local instant exists (DST gap) and earliest() returns None, fall back to UTC midnight
.unwrap_or((date as i64) * 86_400 * 1_000_000);
builder.append_value(local_midnight_in_microsec);
}
None => {
builder.append_null();
}
}
}
Ok(Arc::new(
builder.finish().with_timezone_opt(target_tz.clone()),
))
}

fn cast_string_to_float(
array: &ArrayRef,
to_type: &DataType,
@@ -3408,6 +3453,64 @@ mod tests {
assert!(result.is_err())
}

#[test]
fn test_cast_date_to_timestamp() {
use arrow::array::Date32Array;

// verify the epoch, US DST transition dates, and a null value (comprehensive tests live on the Spark side)
let dates: ArrayRef = Arc::new(Date32Array::from(vec![
Some(0),
Some(19723),
Some(19793),
None,
]));

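// 1704067200000000 = 2024-01-01T00:00:00 UTC (day 19723) and 1710115200000000 =
// 2024-03-11T00:00:00 UTC (day 19793, the Monday after the 2024 US spring-forward)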
let non_dst_date = 1704067200000000i64;
let dst_date = 1710115200000000i64;
let seven_hours_ts = 25200000000i64;
let eight_hours_ts = 28800000000i64;

// validate UTC
let result = cast_array(
Arc::clone(&dates),
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
&SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
)
.unwrap();
let ts = result.as_primitive::<TimestampMicrosecondType>();
assert_eq!(ts.value(0), 0);
assert_eq!(ts.value(1), non_dst_date);
assert_eq!(ts.value(2), dst_date);
assert!(ts.is_null(3));

// validate the Los Angeles timezone (observes daylight saving time)
let result = cast_array(
Arc::clone(&dates),
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
&SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", false),
)
.unwrap();
let ts = result.as_primitive::<TimestampMicrosecondType>();
assert_eq!(ts.value(0), eight_hours_ts);
assert_eq!(ts.value(1), non_dst_date + eight_hours_ts);
// should adjust for DST
assert_eq!(ts.value(2), dst_date + seven_hours_ts);
assert!(ts.is_null(3));

// validate the Phoenix timezone (does not observe daylight saving time)
let result = cast_array(
Arc::clone(&dates),
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
&SparkCastOptions::new(EvalMode::Legacy, "America/Phoenix", false),
)
.unwrap();
let ts = result.as_primitive::<TimestampMicrosecondType>();
assert_eq!(ts.value(0), seven_hours_ts);
assert_eq!(ts.value(1), non_dst_date + seven_hours_ts);
assert_eq!(ts.value(2), dst_date + seven_hours_ts);
assert!(ts.is_null(3));
}

#[test]
fn test_cast_struct_to_utf8() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![
@@ -168,6 +168,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
}
}
Compatible()
case (DataTypes.DateType, toType) => canCastFromDate(toType)
case _ => unsupported(fromType, toType)
}
}
@@ -344,6 +345,12 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
case _ => Unsupported(Some(s"Cast from DecimalType to $toType is not supported"))
}

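// Date-to-timestamp is implemented by the native cast_date_to_timestamp kernel
// in cast.rs, so it can be declared Compatible here.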
private def canCastFromDate(toType: DataType): SupportLevel = toType match {
case DataTypes.TimestampType =>
Compatible()
case _ => Unsupported(Some(s"Cast from DateType to $toType is not supported"))
}

private def unsupported(fromType: DataType, toType: DataType): Unsupported = {
Unsupported(Some(s"Cast from $fromType to $toType is not supported"))
}
82 changes: 78 additions & 4 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -989,9 +989,27 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateDates(), DataTypes.StringType)
}

ignore("cast DateType to TimestampType") {
// Arrow error: Cast error: Casting from Date32 to Timestamp(Microsecond, Some("UTC")) not supported
castTest(generateDates(), DataTypes.TimestampType)
test("cast DateType to TimestampType") {
val compatibleTimezones = Seq(
"UTC",
"America/New_York",
"America/Chicago",
"America/Denver",
"America/Los_Angeles",
"Europe/London",
"Europe/Paris",
"Europe/Berlin",
"Asia/Tokyo",
"Asia/Shanghai",
"Asia/Singapore",
"Asia/Kolkata",
"Australia/Sydney",
"Pacific/Auckland")
compatibleTimezones.foreach { tz =>
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
castTest(generateDates(), DataTypes.TimestampType)
}
}
}

// CAST from TimestampType
@@ -1264,7 +1282,63 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

private def generateDates(): DataFrame = {
val values = Seq("2024-01-01", "999-01-01", "12345-01-01")
// add the 1st, 10th, and 20th of each month from the epoch through 2027
val sampledDates = (1970 to 2027).flatMap { year =>
(1 to 12).flatMap { month =>
Seq(1, 10, 20).map(day => f"$year-$month%02d-$day%02d")
}
}

// DST transition dates (1970-2099) for US, EU, Australia
val dstDates = (1970 to 2099).flatMap { year =>
Seq(
// spring forward
s"$year-03-08",
s"$year-03-09",
s"$year-03-10",
s"$year-03-11",
s"$year-03-14",
s"$year-03-15",
s"$year-03-25",
s"$year-03-26",
s"$year-03-27",
s"$year-03-28",
s"$year-03-29",
s"$year-03-30",
s"$year-03-31",
// April (Australia fall back)
s"$year-04-01",
s"$year-04-02",
s"$year-04-03",
s"$year-04-04",
s"$year-04-05",
// October (EU fall back and Australia spring forward)
s"$year-10-01",
s"$year-10-02",
s"$year-10-03",
s"$year-10-04",
s"$year-10-05",
s"$year-10-25",
s"$year-10-26",
s"$year-10-27",
s"$year-10-28",
s"$year-10-29",
s"$year-10-30",
s"$year-10-31",
// US fall back
s"$year-11-01",
s"$year-11-02",
s"$year-11-03",
s"$year-11-04",
s"$year-11-05",
s"$year-11-06",
s"$year-11-07",
s"$year-11-08")
}

// Edge cases
val edgeCases = Seq("1969-12-31", "2000-02-29", "999-01-01", "12345-01-01")
val values = (sampledDates ++ dstDates ++ edgeCases).distinct
withNulls(values).toDF("b").withColumn("a", col("b").cast(DataTypes.DateType)).drop("b")
}
