Spark SQL百万级数据批量读写入MySQL

Spark SQL读取MySQL的方式

Spark SQL还包括一个可以使用JDBC从其他数据库读取数据的数据源。与使用JdbcRDD相比,应优先使用此功能。这是因为结果作为DataFrame返回,它们可以在Spark SQL中轻松处理或与其他数据源连接。JDBC数据源也更易于使用Java或Python,因为它不需要用户提供ClassTag。

可以使用Data Sources API将远程数据库中的表加载为DataFrame或Spark SQL临时视图。用户可以在数据源选项中指定JDBC连接属性。 user和password通常作为用于登录数据源的连接属性。除连接属性外,Spark还支持以下不区分大小写的选项:

属性名称解释
url要连接的JDBC URL
dbtable读取或写入的JDBC表
query指定查询语句
driver用于连接到该URL的JDBC驱动类名
partitionColumn, lowerBound, upperBound如果指定了这些选项,则必须全部指定。另外, numPartitions必须指定
numPartitions表读写中可用于并行处理的最大分区数。这也确定了并发JDBC连接的最大数量。如果要写入的分区数超过此限制,我们可以通过coalesce(numPartitions)在写入之前进行调用将其降低到此限制
queryTimeout默认为0,查询超时时间
fetchsizeJDBC的获取大小,它确定每次要获取多少行。这可以帮助提高JDBC驱动程序的性能
batchsize默认为1000,JDBC批处理大小,这可以帮助提高JDBC驱动程序的性能。
isolationLevel事务隔离级别,适用于当前连接。它可以是一个NONEREAD_COMMITTEDREAD_UNCOMMITTEDREPEATABLE_READ,或SERIALIZABLE,对应于由JDBC的连接对象定义,缺省值为标准事务隔离级别READ_UNCOMMITTED。此选项仅适用于写作。
sessionInitStatement在向远程数据库打开每个数据库会话之后,在开始读取数据之前,此选项将执行自定义SQL语句,使用它来实现会话初始化代码。
truncate这是与JDBC writer相关的选项。当SaveMode.Overwrite启用时,就会清空目标表的内容,而不是删除和重建其现有的表。默认为false
pushDownPredicate用于启用或禁用谓词下推到JDBC数据源的选项。默认值为true,在这种情况下,Spark将尽可能将过滤器下推到JDBC数据源。

源码

  • SparkSession
/**
 * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
 * `DataFrame`.
 * {{{
 * sparkSession.read.parquet("/path/to/file.parquet")
 * sparkSession.read.schema(schema).json("/path/to/file.json")
 * }}}
 *
 * @since 2.0.0
 */
 def read: DataFrameReader = new DataFrameReader(self)
  • DataFrameReader
 // ...省略代码...
 /**
 *所有的数据由RDD的一个分区处理,如果你这个表很大,很可能会出现OOM
 *可以使用DataFrameDF.rdd.partitions.size方法查看
 */
 def jdbc(url: String, table: String, properties: Properties): DataFrame = {
 assertNoSpecifiedSchema("jdbc")
 this.extraOptions ++= properties.asScala
 this.extraOptions += (JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
 format("jdbc").load()
 }
/**
 * @param url 数据库url
 * @param table 表名
 * @param columnName 分区字段名
 * @param lowerBound `columnName`的最小值,用于分区步长
 * @param upperBound `columnName`的最大值,用于分区步长.
 * @param numPartitions 分区数量 
 * @param connectionProperties 其他参数
 * @since 1.4.0
 */
 def jdbc(
 url: String,
 table: String,
 columnName: String,
 lowerBound: Long,
 upperBound: Long,
 numPartitions: Int,
 connectionProperties: Properties): DataFrame = {
 this.extraOptions ++= Map(
 JDBCOptions.JDBC_PARTITION_COLUMN -> columnName,
 JDBCOptions.JDBC_LOWER_BOUND -> lowerBound.toString,
 JDBCOptions.JDBC_UPPER_BOUND -> upperBound.toString,
 JDBCOptions.JDBC_NUM_PARTITIONS -> numPartitions.toString)
 jdbc(url, table, connectionProperties)
 }
 /**
 * @param predicates 每个分区的where条件
 * 比如:"id <= 1000", "score > 1000 and score <= 2000"
 * 将会分成两个分区
 * @since 1.4.0
 */
 def jdbc(
 url: String,
 table: String,
 predicates: Array[String],
 connectionProperties: Properties): DataFrame = {
 assertNoSpecifiedSchema("jdbc")
 val params = extraOptions.toMap ++ connectionProperties.asScala.toMap
 val options = new JDBCOptions(url, table, params)
 val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) =>
 JDBCPartition(part, i) : Partition
 }
 val relation = JDBCRelation(parts, options)(sparkSession)
 sparkSession.baseRelationToDataFrame(relation)
 }

示例

 private def runJdbcDatasetExample(spark: SparkSession): Unit = {
 
 // 从JDBC source加载数据(load)
 val jdbcDF = spark.read
 .format("jdbc")
 .option("url", "jdbc:mysql://127.0.0.1:3306/test")
 .option("dbtable", "mytable")
 .option("user", "root")
 .option("password", "root")
 .load()
 val connectionProperties = new Properties()
 connectionProperties.put("user", "root")
 connectionProperties.put("password", "root")
 val jdbcDF2 = spark.read
 .jdbc("jdbc:mysql://127.0.0.1:3306/test", "mytable", connectionProperties)
 // 指定读取schema的数据类型
 connectionProperties.put("customSchema", "id DECIMAL(38, 0), name STRING")
 val jdbcDF3 = spark.read
 .jdbc("jdbc:mysql://127.0.0.1:3306/test", "mytable", connectionProperties)
 }

值得注意的是,上面的方式如果不指定分区的话,Spark默认会使用一个分区读取数据,这样在数据量特别大的情况下,会出现OOM。在读取数据之后,调用DataFrameDF.rdd.partitions.size方法可以查看分区数。

Spark SQL批量写入MySQL

代码示例如下:

object BatchInsertMySQL {
 case class Person(name: String, age: Int)
 def main(args: Array[String]): Unit = {
 // 创建sparkSession对象
 val conf = new SparkConf()
 .setAppName("BatchInsertMySQL")
 val spark: SparkSession = SparkSession.builder()
 .config(conf)
 .getOrCreate()
 import spark.implicits._
 // MySQL连接参数
 val url = JDBCUtils.url
 val user = JDBCUtils.user
 val pwd = JDBCUtils.password
 // 创建Properties对象,设置连接mysql的用户名和密码
 val properties: Properties = new Properties()
 properties.setProperty("user", user) // 用户名
 properties.setProperty("password", pwd) // 密码
 properties.setProperty("driver", "com.mysql.jdbc.Driver")
 properties.setProperty("numPartitions","10")
 // 读取mysql中的表数据
 val testDF: DataFrame = spark.read.jdbc(url, "test", properties)
 println("testDF的分区数: " + testDF.rdd.partitions.size)
 testDF.createOrReplaceTempView("test")
 testDF.persist(StorageLevel.MEMORY_AND_DISK)
 testDF.printSchema()
 val result =
 s"""-- SQL代码
 """.stripMargin
 val resultBatch = spark.sql(result).as[Person]
 println("resultBatch的分区数: " + resultBatch.rdd.partitions.size)
 // 批量写入MySQL
 // 此处最好对处理的结果进行一次重分区
 // 由于数据量特别大,会造成每个分区数据特别多
 resultBatch.repartition(500).foreachPartition(record => {
 val list = new ListBuffer[Person]
 record.foreach(person => {
 val name = Person.name
 val age = Person.age
 list.append(Person(name,age))
 })
 upsertDateMatch(list) //执行批量插入数据
 })
 // 批量插入MySQL的方法
 def upsertPerson(list: ListBuffer[Person]): Unit = {
 var connect: Connection = null
 var pstmt: PreparedStatement = null
 try {
 connect = JDBCUtils.getConnection()
 // 禁用自动提交
 connect.setAutoCommit(false)
 val sql = "REPLACE INTO `person`(name, age)" +
 " VALUES(?, ?)"
 pstmt = connect.prepareStatement(sql)
 var batchIndex = 0
 for (person <- list) {
 pstmt.setString(1, person.name)
 pstmt.setString(2, person.age)
 // 加入批次
 pstmt.addBatch()
 batchIndex +=1
 // 控制提交的数量,
 // MySQL的批量写入尽量限制提交批次的数据量,否则会把MySQL写挂!!!
 if(batchIndex % 1000 == 0 && batchIndex !=0){
 pstmt.executeBatch()
 pstmt.clearBatch()
 }
 }
 // 提交批次
 pstmt.executeBatch()
 connect.commit()
 } catch {
 case e: Exception =>
 e.printStackTrace()
 } finally {
 JDBCUtils.closeConnection(connect, pstmt)
 }
 }
 spark.close()
 }
}

JDBC连接工具类:

object JDBCUtils {
 val user = "root"
 val password = "root"
 val url = "jdbc:mysql://localhost:3306/mydb"
 Class.forName("com.mysql.jdbc.Driver")
 // 获取连接
 def getConnection() = {
 DriverManager.getConnection(url,user,password)
 }
// 释放连接
 def closeConnection(connection: Connection, pstmt: PreparedStatement): Unit = {
 try {
 if (pstmt != null) {
 pstmt.close()
 }
 } catch {
 case e: Exception => e.printStackTrace()
 } finally {
 if (connection != null) {
 connection.close()
 }
 }
 }
}

总结

Spark写入大量数据到MySQL时,在写入之前尽量对写入的DF进行重分区处理,避免分区内数据过多。在写入时,要注意使用foreachPartition来进行写入,这样可以为每一个分区获取一个连接,在分区内部设定批次提交,提交的批次不易过大,以免将数据库写挂。

公众号『大数据技术与数仓』,回复『资料』领取大数据资料包
作者:大数据技术与数仓原文地址:https://segmentfault.com/a/1190000038207701

%s 个评论

要回复文章请先登录注册