Skip to content

[SPARK-51834][SQL] Support end-to-end table constraint management #50631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
12 changes: 12 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,12 @@
],
"sqlState" : "42704"
},
"CONSTRAINT_DOES_NOT_HAVE_DATA_TYPE" : {
"message" : [
"Table constraint expressions do not have a data type."
],
"sqlState" : "0A000"
},
"CONVERSION_INVALID_INPUT" : {
"message" : [
"The value <str> (<fmt>) cannot be converted to <targetType> because it is malformed. Correct the value as per the syntax, or change its format. Use <suggestion> to tolerate malformed input and return NULL instead."
Expand Down Expand Up @@ -4070,6 +4076,12 @@
],
"sqlState" : "HV091"
},
"NON_DETERMINISTIC_CHECK_CONSTRAINT" : {
"message" : [
"The check constraint `<checkCondition>` is non-deterministic. Check constraints must only contain deterministic expressions."
],
"sqlState" : "42621"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error code seems consistent with DB2 and what we use for generated columns, +1.

},
"NON_FOLDABLE_ARGUMENT" : {
"message" : [
"The function <funcName> requires the parameter <paramName> to be a foldable expression of the type <paramType>, but the actual argument is a non-foldable."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,15 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
case RenameColumn(table: ResolvedTable, col: ResolvedFieldName, newName) =>
checkColumnNotExists("rename", col.path :+ newName, table.schema)


case AddConstraint(_: ResolvedTable, check: CheckConstraint) =>
if (!check.deterministic) {
check.child.failAnalysis(
errorClass = "NON_DETERMINISTIC_CHECK_CONSTRAINT",
messageParameters = Map("checkCondition" -> check.condition)
)
}

case AlterColumns(table: ResolvedTable, specs) =>
val groupedColumns = specs.groupBy(_.column.name)
groupedColumns.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, Identifier, LookupCatalog, SupportsNamespaces}
Expand Down Expand Up @@ -77,14 +78,19 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
assertValidSessionVariableNameParts(nameParts, resolved)
d.copy(name = resolved)

// For CREATE TABLE and REPLACE TABLE statements, resolve the table identifier and include
// the table columns as output. This allows expressions (e.g., constraints) referencing these
// columns to be resolved correctly.
case c @ CreateTable(UnresolvedIdentifier(nameParts, allowTemp), columns, _, _, _) =>
val resolvedIdentifier = resolveIdentifier(nameParts, allowTemp, columns)
c.copy(name = resolvedIdentifier)

case r @ ReplaceTable(UnresolvedIdentifier(nameParts, allowTemp), columns, _, _, _) =>
val resolvedIdentifier = resolveIdentifier(nameParts, allowTemp, columns)
r.copy(name = resolvedIdentifier)

case UnresolvedIdentifier(nameParts, allowTemp) =>
if (allowTemp && catalogManager.v1SessionCatalog.isTempView(nameParts)) {
val ident = Identifier.of(nameParts.dropRight(1).toArray, nameParts.last)
ResolvedIdentifier(FakeSystemCatalog, ident)
} else {
val CatalogAndIdentifier(catalog, identifier) = nameParts
ResolvedIdentifier(catalog, identifier)
}
resolveIdentifier(nameParts, allowTemp, Seq.empty)

case CurrentNamespace =>
ResolvedNamespace(currentCatalog, catalogManager.currentNamespace.toImmutableArraySeq)
Expand All @@ -94,6 +100,22 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
resolveNamespace(catalog, ns, fetchMetadata)
}

private def resolveIdentifier(
nameParts: Seq[String],
allowTemp: Boolean,
columns: Seq[ColumnDefinition]): ResolvedIdentifier = {
val columnOutput = columns.map { col =>
AttributeReference(col.name, col.dataType, col.nullable, col.metadata)()
}
if (allowTemp && catalogManager.v1SessionCatalog.isTempView(nameParts)) {
val ident = Identifier.of(nameParts.dropRight(1).toArray, nameParts.last)
ResolvedIdentifier(FakeSystemCatalog, ident, columnOutput)
} else {
val CatalogAndIdentifier(catalog, identifier) = nameParts
ResolvedIdentifier(catalog, identifier, columnOutput)
}
}

private def resolveNamespace(
catalog: CatalogPlugin,
ns: Seq[String],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.SparkThrowable
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.expressions._
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the agreement in the community on wildcard imports? Are they permitted after a given number of elements are imported directly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per https://github.com/databricks/scala-style-guide?tab=readme-ov-file#imports,
"Avoid using wildcard imports, unless you are importing more than 6 entities"

import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -61,7 +61,7 @@ object ResolveTableSpec extends Rule[LogicalPlan] {
input: LogicalPlan,
tableSpec: TableSpecBase,
withNewSpec: TableSpecBase => LogicalPlan): LogicalPlan = tableSpec match {
case u: UnresolvedTableSpec if u.optionExpression.resolved =>
case u: UnresolvedTableSpec if u.childrenResolved =>
val newOptions: Seq[(String, String)] = u.optionExpression.options.map {
case (key: String, null) =>
(key, null)
Expand All @@ -86,6 +86,18 @@ object ResolveTableSpec extends Rule[LogicalPlan] {
}
(key, newValue)
}

u.constraints.foreach {
case check: CheckConstraint =>
if (!check.child.deterministic) {
check.child.failAnalysis(
errorClass = "NON_DETERMINISTIC_CHECK_CONSTRAINT",
messageParameters = Map("checkCondition" -> check.condition)
)
}
case _ =>
}

val newTableSpec = TableSpec(
properties = u.properties,
provider = u.provider,
Expand All @@ -94,7 +106,8 @@ object ResolveTableSpec extends Rule[LogicalPlan] {
comment = u.comment,
collation = u.collation,
serde = u.serde,
external = u.external)
external = u.external,
constraints = u.constraints.map(_.toV2Constraint(isCreateTable = true)))
withNewSpec(newTableSpec)
case _ =>
input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,13 @@ case class ResolvedNonPersistentFunc(
*/
case class ResolvedIdentifier(
catalog: CatalogPlugin,
identifier: Identifier) extends LeafNodeWithoutStats {
override def output: Seq[Attribute] = Nil
identifier: Identifier,
override val output: Seq[Attribute] = Nil) extends LeafNodeWithoutStats

object ResolvedIdentifier {
def unapply(ri: ResolvedIdentifier): Option[(CatalogPlugin, Identifier)] = {
Some((ri.catalog, ri.identifier))
}
}

// A fake v2 catalog to hold temp views.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.UUID

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.types.DataType

trait TableConstraint {
trait TableConstraint extends Expression with Unevaluable {
// Convert to a data source v2 constraint
def toV2Constraint(isCreateTable: Boolean): Constraint

/** Returns the user-provided name of the constraint */
def userProvidedName: String
Expand Down Expand Up @@ -92,6 +98,11 @@ trait TableConstraint {
)
}
}

override def nullable: Boolean = true

override def dataType: DataType =
throw new SparkUnsupportedOperationException("CONSTRAINT_DOES_NOT_HAVE_DATA_TYPE")
}

case class ConstraintCharacteristic(enforced: Option[Boolean], rely: Option[Boolean])
Expand All @@ -108,10 +119,30 @@ case class CheckConstraint(
override val tableName: String = null,
override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
extends UnaryExpression
with Unevaluable
with TableConstraint {
// scalastyle:on line.size.limit

def toV2Constraint(isCreateTable: Boolean): Constraint = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if the input param should be related to the validation status, rather than to whether it is create or alter. For instance, we can make validation optional in ALTER.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, how about let's make all the validate status as UNVALIDATED in this PR? Once we support enforcing check constraint, we can have more discussions on this one

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me.

val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull
val enforced = userProvidedCharacteristic.enforced.getOrElse(true)
val rely = userProvidedCharacteristic.rely.getOrElse(false)
// The validation status is set to UNVALIDATED for create table and
// VALID for alter table.
val validateStatus = if (isCreateTable) {
Constraint.ValidationStatus.UNVALIDATED
} else {
Constraint.ValidationStatus.VALID
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the idea here that we always validate existing data in ALTER?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes for check constraint

}
Constraint
.check(name)
.predicateSql(condition)
.predicate(predicate)
.rely(rely)
.enforced(enforced)
.validationStatus(validateStatus)
.build()
}

override protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)

Expand All @@ -121,8 +152,6 @@ case class CheckConstraint(

override def sql: String = s"CONSTRAINT $userProvidedName CHECK ($condition)"

override def dataType: DataType = StringType

override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)

override def withTableName(tableName: String): TableConstraint = copy(tableName = tableName)
Expand All @@ -137,9 +166,20 @@ case class PrimaryKeyConstraint(
override val userProvidedName: String = null,
override val tableName: String = null,
override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
extends TableConstraint {
extends LeafExpression with TableConstraint {
// scalastyle:on line.size.limit

override def toV2Constraint(isCreateTable: Boolean): Constraint = {
val enforced = userProvidedCharacteristic.enforced.getOrElse(false)
val rely = userProvidedCharacteristic.rely.getOrElse(false)
Constraint
.primaryKey(name, columns.map(FieldReference.column).toArray)
.rely(rely)
.enforced(enforced)
.validationStatus(Constraint.ValidationStatus.UNVALIDATED)
.build()
}

override protected def generateName(tableName: String): String = s"${tableName}_pk"

override def withUserProvidedName(name: String): TableConstraint = copy(userProvidedName = name)
Expand All @@ -158,9 +198,20 @@ case class UniqueConstraint(
override val userProvidedName: String = null,
override val tableName: String = null,
override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
extends TableConstraint {
extends LeafExpression with TableConstraint {
// scalastyle:on line.size.limit

override def toV2Constraint(isCreateTable: Boolean): Constraint = {
val enforced = userProvidedCharacteristic.enforced.getOrElse(false)
val rely = userProvidedCharacteristic.rely.getOrElse(false)
Constraint
.unique(name, columns.map(FieldReference.column).toArray)
.rely(rely)
.enforced(enforced)
.validationStatus(Constraint.ValidationStatus.UNVALIDATED)
.build()
}

override protected def generateName(tableName: String): String = {
s"${tableName}_uniq_$randomSuffix"
}
Expand All @@ -183,9 +234,25 @@ case class ForeignKeyConstraint(
override val userProvidedName: String = null,
override val tableName: String = null,
override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
extends TableConstraint {
extends LeafExpression with TableConstraint {
// scalastyle:on line.size.limit

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

override def toV2Constraint(isCreateTable: Boolean): Constraint = {
val enforced = userProvidedCharacteristic.enforced.getOrElse(false)
val rely = userProvidedCharacteristic.rely.getOrElse(false)
Constraint
.foreignKey(name,
childColumns.map(FieldReference.column).toArray,
parentTableId.asIdentifier,
parentColumns.map(FieldReference.column).toArray)
.rely(rely)
.enforced(enforced)
.validationStatus(Constraint.ValidationStatus.UNVALIDATED)
.build()
}

override protected def generateName(tableName: String): String =
s"${tableName}_${parentTableId.last}_fk_$randomSuffix"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, UnresolvedException}
import org.apache.spark.sql.catalyst.analysis.{FieldName, FieldPosition, ResolvedFieldName, ResolvedTable, UnresolvedException}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ClusterBySpec
import org.apache.spark.sql.catalyst.expressions.{Expression, TableConstraint, Unevaluable}
Expand Down Expand Up @@ -295,7 +295,16 @@ case class AlterTableCollation(
case class AddConstraint(
table: LogicalPlan,
tableConstraint: TableConstraint) extends AlterTableCommand {
override def changes: Seq[TableChange] = Seq.empty
override def changes: Seq[TableChange] = {
val constraint = tableConstraint.toV2Constraint(isCreateTable = false)
val validatedTableVersion = table match {
case t: ResolvedTable if constraint.enforced() =>
t.table.currentVersion()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created a follow-up https://issues.apache.org/jira/browse/SPARK-51835 for testing the table version

case _ =>
null
}
Seq(TableChange.addConstraint(constraint, validatedTableVersion))
Copy link
Contributor

@aokolnychyi aokolnychyi Apr 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHECK constraints must optionally validate existing data in ALTER.
Am I right this PR doesn't have this? What would be our plan?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

must optionally validate

Make sense. Do you mean CHECK ... NOT ENFOCED?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ENFORCED/NOT ENFORCED impacts subsequent writes. I was referring to ALTER TABLE ... ADD CONSTRAINT that must scan the existing data.

}

protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
}
Expand All @@ -308,7 +317,8 @@ case class DropConstraint(
name: String,
ifExists: Boolean,
cascade: Boolean) extends AlterTableCommand {
override def changes: Seq[TableChange] = Seq.empty
override def changes: Seq[TableChange] =
Seq(TableChange.dropConstraint(name, ifExists, cascade))

protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -1505,19 +1505,25 @@ case class UnresolvedTableSpec(
serde: Option[SerdeInfo],
external: Boolean,
constraints: Seq[TableConstraint])
extends UnaryExpression with Unevaluable with TableSpecBase {
extends Expression with Unevaluable with TableSpecBase {

override def dataType: DataType =
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3113")

override def child: Expression = optionExpression

override protected def withNewChildInternal(newChild: Expression): Expression =
this.copy(optionExpression = newChild.asInstanceOf[OptionList])

override def simpleString(maxFields: Int): String = {
this.copy(properties = Utils.redact(properties).toMap).toString
}

override def nullable: Boolean = true

override def children: Seq[Expression] = optionExpression +: constraints

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
copy(
optionExpression = newChildren.head.asInstanceOf[OptionList],
constraints = newChildren.tail.asInstanceOf[Seq[TableConstraint]])
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ case class CreateTableExec(
.withColumns(columns)
.withPartitions(partitioning.toArray)
.withProperties(tableProperties.asJava)
.withConstraints(tableSpec.constraints.toArray)
.build()
catalog.createTable(identifier, tableInfo)
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ case class ReplaceTableExec(
.withColumns(columns)
.withPartitions(partitioning.toArray)
.withProperties(tableProperties.asJava)
.withConstraints(tableSpec.constraints.toArray)
.build()
catalog.createTable(ident, tableInfo)
Seq.empty
Expand Down
Loading