
Commit 549c30a

xiaonanyang-db authored and cloud-fan committed

[SPARK-52582][SQL] Improve the memory usage of XML parser
### What changes were proposed in this pull request?

Today, the XML parser is not memory efficient. It loads each XML record fully into memory before parsing, which causes OOMs when input XML records are large. This PR improves the parser to parse XML records token by token, avoiding copying entire XML records into memory ahead of time. The improved parser uses less memory when parsing large XML files than the legacy parser. However, it enforces stricter validation to ensure the XML is well-formed:

1. The legacy parser does not deterministically recover all valid records from a malformed file. The improved parser, by contrast, stops processing the file at the point where malformedness is detected.
2. The legacy parser could handle malformed XML files with multiple root tags, whereas the improved parser only reads the records under the first root tag.

The improved parser is enabled by default, but users can fall back to the legacy parser via the `spark.sql.xml.legacyParser.enabled` SQL conf.

### Why are the changes needed?

To solve the OOM issue in XML ingestion.

### Does this PR introduce _any_ user-facing change?

No. The new behavior is disabled by default for now.

### How was this patch tested?

New UTs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51287 from xiaonanyang-db/SPARK-52582.

Authored-by: Xiaonan Yang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 594d26c commit 549c30a
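
For readers who want to try this out, here is a minimal usage sketch (not part of this commit): the conf name comes from the description above, while the row tag and input path are hypothetical. Since Spark 4.0 the XML source is built in, so reads like these should work out of the box:

// Default: the improved token-by-token parser.
val df = spark.read
  .option("rowTag", "book") // hypothetical row tag
  .xml("/tmp/books.xml")    // hypothetical input path

// Opt back into the legacy parser, which materializes each record in memory.
spark.conf.set("spark.sql.xml.legacyParser.enabled", "true")
val legacyDf = spark.read
  .option("rowTag", "book")
  .xml("/tmp/books.xml")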

File tree

21 files changed: +7217 −3721 lines changed
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXMLRecordReader.scala

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.catalyst.xml

import java.io.InputStream
import javax.xml.stream.{XMLEventReader, XMLStreamConstants, XMLStreamReader}
import javax.xml.stream.events.{EndDocument, StartElement, XMLEvent}
import javax.xml.transform.stax.StAXSource

import scala.util.control.NonFatal

import org.apache.hadoop.shaded.com.ctc.wstx.exc.WstxEOFException

import org.apache.spark.internal.Logging
import org.apache.spark.util.SparkErrorUtils

/**
 * XML record reader that reads the next XML record in the underlying XML stream. It can support
 * XSD schema validation by maintaining a separate XML reader and keeping it in sync with the
 * primary XML reader.
 */
case class StaxXMLRecordReader(inputStream: () => InputStream, options: XmlOptions)
  extends XMLEventReader
  with Logging {
  // Reader for the XML record parsing.
  private val in1 = inputStream()
  private val primaryEventReader = StaxXmlParserUtils.filteredReader(in1, options)

  private val xsdSchemaValidator = Option(options.rowValidationXSDPath)
    .map(path => ValidatorUtil.getSchema(path).newValidator())
  // Reader for the XSD validation, if an XSD schema is provided.
  private val in2 = xsdSchemaValidator.map(_ => inputStream())
  // An XMLStreamReader used by StAXSource for XSD validation.
  private val xsdValidationStreamReader =
    in2.map(in => StaxXmlParserUtils.filteredStreamReader(in, options))

  final var hasMoreRecord: Boolean = true

  /**
   * Skip through the XML stream until we find the next row start element.
   * Returns true if a row start element is found, false if the end of the stream is reached.
   */
  def skipToNextRecord(): Boolean = {
    hasMoreRecord = skipToNextRowStart()
    if (hasMoreRecord) {
      xsdValidationStreamReader.foreach(validateXSDSchema)
    } else {
      close()
    }
    hasMoreRecord
  }

  /**
   * Skip through the XML stream until we find the next row start element.
   */
  private def skipToNextRowStart(): Boolean = {
    val rowTagName = options.rowTag
    try {
      while (primaryEventReader.hasNext) {
        val event = primaryEventReader.peek()
        event match {
          case startElement: StartElement =>
            val elementName = StaxXmlParserUtils.getName(startElement.getName, options)
            if (elementName == rowTagName) {
              return true
            }
          case _: EndDocument =>
            return false
          case _ =>
            // Continue searching.
        }
        // If not the event we want, advance the reader.
        primaryEventReader.nextEvent()
      }
      false
    } catch {
      case NonFatal(e) if SparkErrorUtils.getRootCause(e).isInstanceOf[WstxEOFException] =>
        logWarning("Reached end of file while looking for next row start element.")
        false
    }
  }

  private def validateXSDSchema(streamReader: XMLStreamReader): Unit = {
    // StAXSource requires the stream reader to start with the START_DOCUMENT or START_ELEMENT
    // events.
    def rowTagStarted: Boolean =
      streamReader.getEventType == XMLStreamConstants.START_ELEMENT &&
        StaxXmlParserUtils.getName(streamReader.getName, options) == options.rowTag
    while (!rowTagStarted && streamReader.hasNext) {
      streamReader.next()
    }
    xsdSchemaValidator.get.reset()
    xsdSchemaValidator.get.validate(new StAXSource(streamReader))
  }

  override def close(): Unit = {
    primaryEventReader.close()
    xsdValidationStreamReader.foreach(_.close())
    in1.close()
    in2.foreach(_.close())
    hasMoreRecord = false
  }

  override def nextEvent(): XMLEvent = primaryEventReader.nextEvent()
  override def hasNext: Boolean = primaryEventReader.hasNext
  override def peek(): XMLEvent = primaryEventReader.peek()
  override def getElementText: String = primaryEventReader.getElementText
  override def nextTag(): XMLEvent = primaryEventReader.nextTag()
  override def getProperty(name: String): AnyRef = primaryEventReader.getProperty(name)
  override def next(): AnyRef = primaryEventReader.next()
}
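
To make the reader's contract concrete, here is a minimal consumption loop. It is a sketch only: convertStream in StaxXmlParser.scala below drives the reader this way, but the counting logic and the assumption that row tags are not nested are mine.

// Sketch: count records by advancing the reader until the stream is exhausted.
def countRecords(inputStream: () => InputStream, options: XmlOptions): Long = {
  val reader = StaxXMLRecordReader(inputStream, options)
  var n = 0L
  try {
    // skipToNextRecord() positions the reader at the next rowTag start element
    // (running XSD validation when configured) and returns false at end of stream.
    while (reader.skipToNextRecord()) {
      reader.nextEvent() // consume the row start element so the next scan advances
      n += 1
    }
  } finally {
    reader.close() // skipToNextRecord() also closes internally at end of stream
  }
  n
}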

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala

Lines changed: 132 additions & 7 deletions
@@ -33,6 +33,7 @@ import scala.util.control.Exception.allCatch
 import scala.util.control.NonFatal
 import scala.xml.SAXException

+import com.google.common.io.ByteStreams
 import org.apache.hadoop.hdfs.BlockMissingException
 import org.apache.hadoop.security.AccessControlException
@@ -50,7 +51,7 @@ import org.apache.spark.types.variant.{Variant, VariantBuilder}
 import org.apache.spark.types.variant.VariantBuilder.FieldEntry
 import org.apache.spark.types.variant.VariantUtil
 import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SparkErrorUtils, Utils}

 class StaxXmlParser(
     schema: StructType,
@@ -127,12 +128,12 @@ class StaxXmlParser(
     // is not manually specified, then fall back to DROPMALFORMED, which will return
     // null column values where parsing fails.
     val parseMode =
-    if (options.parseMode == PermissiveMode &&
-      !schema.fields.exists(_.name == options.columnNameOfCorruptRecord)) {
-      DropMalformedMode
-    } else {
-      options.parseMode
-    }
+      if (options.parseMode == PermissiveMode &&
+        !schema.fields.exists(_.name == options.columnNameOfCorruptRecord)) {
+        DropMalformedMode
+      } else {
+        options.parseMode
+      }
     val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
     doParseColumn(xml, parseMode, xsdSchema).orNull
   }
@@ -188,6 +189,110 @@ class StaxXmlParser(
     }
   }

+  /**
+   * XML stream parser that reads XML records from the input file stream sequentially, without
+   * loading each individual XML record string into memory.
+   */
+  def parseStreamOptimized(
+      inputStream: () => InputStream,
+      schema: StructType): Iterator[InternalRow] = {
+    val streamLiteral = () =>
+      Utils.tryWithResource(inputStream()) { is =>
+        UTF8String.fromBytes(ByteStreams.toByteArray(is))
+      }
+    val safeParser = new FailureSafeParser[StaxXMLRecordReader](
+      input => doParseColumnOptimized(input, streamLiteral),
+      options.parseMode,
+      schema,
+      options.columnNameOfCorruptRecord)
+
+    convertStream(inputStream, options) { reader =>
+      safeParser.parse(reader)
+    }.flatten
+  }
+
+  /**
+   * Parse the next XML record from the XML event stream.
+   * Note that this method will **NOT** close the XML event stream, as there could be more XML
+   * records to parse. It is the caller's responsibility to close the stream.
+   *
+   * @param parser The XML event reader.
+   * @param xmlLiteral A function that returns the entire XML file content as a UTF8String. Used
+   *                   to create a BadRecordException in case of parsing errors.
+   *                   TODO: Only include the file content starting with the current record.
+   */
+  def doParseColumnOptimized(
+      parser: StaxXMLRecordReader,
+      xmlLiteral: () => UTF8String): Option[InternalRow] = {
+    try {
+      if (!parser.skipToNextRecord()) {
+        return None
+      }
+
+      options.singleVariantColumn match {
+        case Some(_) =>
+          // If singleVariantColumn is specified, parse the entire XML record as a Variant.
+          val v = StaxXmlParser.parseVariant(parser, options)
+          Some(InternalRow(v))
+        case _ =>
+          // Otherwise, parse the XML record as structs.
+          val rootAttributes = parser.nextEvent().asStartElement.getAttributes.asScala.toArray
+          Some(convertObject(parser, schema, rootAttributes))
+      }
+    } catch {
+      case e: SparkUpgradeException =>
+        parser.close()
+        throw e
+      case e: CharConversionException if options.charset.isEmpty =>
+        val msg =
+          """XML parser cannot handle a character in its input.
+            |Specifying encoding as an input option explicitly might help to resolve the issue.
+            |""".stripMargin + e.getMessage
+        val wrappedCharException = new CharConversionException(msg)
+        wrappedCharException.initCause(e)
+        throw BadRecordException(xmlLiteral, () => Array.empty, wrappedCharException)
+      case PartialResultException(row, cause) =>
+        throw BadRecordException(
+          record = xmlLiteral,
+          partialResults = () => Array(row),
+          cause)
+      case PartialResultArrayException(rows, cause) =>
+        throw BadRecordException(record = xmlLiteral, partialResults = () => rows, cause)
+      case e: Throwable =>
+        SparkErrorUtils.getRootCause(e) match {
+          case _: FileNotFoundException if options.ignoreMissingFiles =>
+            logWarning("Skipped missing file", e)
+            parser.close()
+            None
+          case _: IOException | _: RuntimeException | _: InternalError
+              if options.ignoreCorruptFiles =>
+            logWarning("Skipped the rest of the content in the corrupted file", e)
+            parser.close()
+            None
+          case _: XMLStreamException | _: MalformedInputException =>
+            // Skip the rest of the content in the parser and put the whole XML file in the
+            // BadRecordException.
+            parser.close()
+            // The XML parser currently doesn't support partial results for corrupted records.
+            // For such records, all fields other than the field configured by
+            // `columnNameOfCorruptRecord` are set to `null`.
+            throw BadRecordException(xmlLiteral, () => Array.empty, e)
+          case _: SAXException =>
+            // XSD validation failed: throw a bad record exception and continue parsing the
+            // remaining records.
+            val record = UTF8String.fromString(
+              StaxXmlParserUtils.currentElementAsString(parser, options.rowTag, options).trim)
+            throw BadRecordException(() => record, () => Array.empty, e)
+        }
+    }
+  }
+
   /**
    * Parse the current token (and related children) according to a desired schema
    */
@@ -929,6 +1034,20 @@ object StaxXmlParser {
     }
   }

+  def convertStream[T](inputStream: () => InputStream, options: XmlOptions)(
+      convert: StaxXMLRecordReader => T): Iterator[T] = new Iterator[T] {
+    private val reader = StaxXMLRecordReader(inputStream, options)
+
+    override def hasNext: Boolean = reader.hasMoreRecord
+
+    override def next(): T = {
+      if (!hasNext) {
+        throw QueryExecutionErrors.endOfStreamError()
+      }
+      convert(reader)
+    }
+  }
+
   /**
    * Parse the input XML string as a Variant value
    */
@@ -940,6 +1059,12 @@ object StaxXmlParser {
     v
   }

+  def parseVariant(parser: StaxXMLRecordReader, options: XmlOptions): VariantVal = {
+    val rootAttributes = parser.nextEvent().asStartElement.getAttributes.asScala.toArray
+    val v = convertVariant(parser, rootAttributes, options)
+    new VariantVal(v.getValue, v.getMetadata)
+  }
+
   /**
    * Parse an XML element from the XML event stream into a Variant.
   * This method transforms the XML element along with its attributes and child elements
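
The singleVariantColumn branch in doParseColumnOptimized above corresponds to reads like the following sketch. The option and rowTag names come from this diff and the XML source's options; the column name and input path are hypothetical:

// Parse each <book> record into a single VARIANT column instead of a struct schema.
val variantDf = spark.read
  .option("rowTag", "book")             // hypothetical row tag
  .option("singleVariantColumn", "var") // each record parsed via parseVariant
  .xml("/tmp/books.xml")                // hypothetical input path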
