[Repost] Mastering Spark (Advanced): Spark Source Code Reading, Part 8
Posted by youduoduo on 2018-5-9 10:48:36

Task Execution
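In the previous section, we saw that the launchTasks method of CoarseGrainedSchedulerBackend on the Driver side sends a launch-task command to the Executor on a Worker node. In Standalone mode, that command is received by CoarseGrainedExecutorBackend, whose class definition is: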

private[spark] class CoarseGrainedExecutorBackend(
    override val rpcEnv: RpcEnv,
    driverUrl: String,
    executorId: String,
    hostPort: String,
    cores: Int,
    userClassPath: Seq[URL],
    env: SparkEnv)
  extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
  // ... (class body omitted)
}


As the definition shows, it extends ThreadSafeRpcEndpoint and implements the receive method it inherits from it. The source is:

override def receive: PartialFunction[Any, Unit] = {
  case RegisteredExecutor =>
    logInfo("Successfully registered with driver")
    val (hostname, _) = Utils.parseHostPort(hostPort)
    executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)

  case RegisterExecutorFailed(message) =>
    logError("Slave registration failed: " + message)
    System.exit(1)

  // Handle the LaunchTask command sent by the Driver
  case LaunchTask(data) =>
    if (executor == null) {
      logError("Received LaunchTask command but executor was null")
      System.exit(1)
    } else {
      // Deserialize the task description
      val taskDesc = ser.deserialize[TaskDescription](data.value)
      logInfo("Got assigned task " + taskDesc.taskId)
      // Hand the task to the Executor to run
      executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
        taskDesc.name, taskDesc.serializedTask)
    }

  case KillTask(taskId, _, interruptThread) =>
    if (executor == null) {
      logError("Received KillTask command but executor was null")
      System.exit(1)
    } else {
      executor.killTask(taskId, interruptThread)
    }

  case StopExecutor =>
    logInfo("Driver commanded a shutdown")
    executor.stop()
    stop()
    rpcEnv.shutdown()
}
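To make the dispatch pattern concrete, here is a minimal, self-contained sketch of a receive-style PartialFunction. The message types and ToyExecutorBackend are hypothetical stand-ins, not Spark classes. Unlike the real backend, which calls System.exit(1) when the executor is missing, this toy version only records an error so that it can be exercised safely.

```scala
// Hypothetical, simplified messages modeled on the ones matched in receive above.
sealed trait ExecutorMessage
case object RegisteredExecutor extends ExecutorMessage
final case class LaunchTask(taskId: Long, name: String) extends ExecutorMessage
final case class KillTask(taskId: Long) extends ExecutorMessage
case object StopExecutor extends ExecutorMessage

// A toy backend: receive is a PartialFunction, like RpcEndpoint.receive,
// so a caller can check isDefinedAt before dispatching a message.
final class ToyExecutorBackend {
  // Stands in for the Executor instance that is created on registration.
  private var registered = false
  val log = scala.collection.mutable.ArrayBuffer.empty[String]

  val receive: PartialFunction[Any, Unit] = {
    case RegisteredExecutor =>
      registered = true
      log += "registered"
    case LaunchTask(id, name) =>
      // The real code exits the JVM here; the sketch just logs.
      if (!registered) log += s"error: launch $id before registration"
      else log += s"launch $name (TID $id)"
    case KillTask(id) =>
      if (!registered) log += s"error: kill $id before registration"
      else log += s"kill TID $id"
    case StopExecutor =>
      log += "stop"
  }
}
```

Because receive is only a PartialFunction, an unknown message is simply not matched (isDefinedAt returns false) instead of being silently swallowed.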


The code above shows that executor.launchTask starts running the Task on the Worker node. Its source is:

// launchTask in the Executor class
def launchTask(
    context: ExecutorBackend,
    taskId: Long,
    attemptNumber: Int,
    taskName: String,
    serializedTask: ByteBuffer): Unit = {
  // Create a TaskRunner for this task
  val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
    serializedTask)
  runningTasks.put(taskId, tr)
  // Hand the TaskRunner to the thread pool; its run method executes the Task
  threadPool.execute(tr)
}
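The launch path boils down to this: register a Runnable in a concurrent map, submit it to a thread pool, report RUNNING, then report exactly one terminal state and unregister in a finally block. The sketch below models that lifecycle under simplifying assumptions. ToyExecutor, ToyTaskRunner, ToyTaskState and ToyTaskKilledException are hypothetical names, the report callback replaces ExecutorBackend.statusUpdate, and killing only sets a flag that is checked before the task body runs (the real TaskRunner.kill also forwards to Task.kill).

```scala
import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}

// Hypothetical states mirroring the TaskState values reported by TaskRunner.run.
object ToyTaskState extends Enumeration {
  val RUNNING, FINISHED, FAILED, KILLED = Value
}

// Hypothetical stand-in for Spark's TaskKilledException.
final class ToyTaskKilledException extends RuntimeException

// `report` plays the role of ExecutorBackend.statusUpdate.
final class ToyExecutor(report: (Long, ToyTaskState.Value) => Unit) {
  // Like Executor.runningTasks: taskId -> runner, so killTask can find it.
  val runningTasks = new ConcurrentHashMap[Long, ToyTaskRunner]()
  private val threadPool = Executors.newCachedThreadPool()

  final class ToyTaskRunner(val taskId: Long, body: () => Unit) extends Runnable {
    @volatile var killed = false

    override def run(): Unit = {
      report(taskId, ToyTaskState.RUNNING)
      try {
        // Killed before it started: fail fast instead of running the body.
        if (killed) throw new ToyTaskKilledException
        body()
        report(taskId, ToyTaskState.FINISHED)
      } catch {
        case _: ToyTaskKilledException => report(taskId, ToyTaskState.KILLED)
        case _: Exception              => report(taskId, ToyTaskState.FAILED)
      } finally {
        // Always unregister, whatever the outcome.
        runningTasks.remove(taskId)
      }
    }
  }

  def launchTask(taskId: Long, body: () => Unit): Unit = {
    val tr = new ToyTaskRunner(taskId, body)
    runningTasks.put(taskId, tr)
    threadPool.execute(tr)
  }

  def killTask(taskId: Long): Unit = {
    val tr = runningTasks.get(taskId)
    if (tr != null) tr.killed = true
  }

  // Stop accepting work and wait for already-submitted runners to finish.
  def stop(): Unit = {
    threadPool.shutdown()
    threadPool.awaitTermination(10, TimeUnit.SECONDS)
  }
}
```

Keeping the runner in a map keyed by task ID is what lets a later KillTask message find the right runner while the task is still in flight.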


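TaskRunner is an inner class defined in org.apache.spark.executor.Executor. It implements Runnable rather than extending Thread, so it runs on a thread taken from the Executor's thread pool. Its source is: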

class TaskRunner(
    execBackend: ExecutorBackend,
    val taskId: Long,
    val attemptNumber: Int,
    taskName: String,
    serializedTask: ByteBuffer)
  extends Runnable {

  /** Whether this task has been killed. */
  @volatile private var killed = false

  /** How much the JVM process has spent in GC when the task starts to run. */
  @volatile var startGCTime: Long = _

  /**
   * The task to run. This will be set in run() by deserializing the task binary coming
   * from the driver. Once it is set, it will never be changed.
   */
  @volatile var task: Task[Any] = _

  def kill(interruptThread: Boolean): Unit = {
    logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
    killed = true
    if (task != null) {
      task.kill(interruptThread)
    }
  }

  override def run(): Unit = {
    val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
    val deserializeStartTime = System.currentTimeMillis()
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = env.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    // Send a RUNNING status update to the Driver
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    var taskStart: Long = 0
    startGCTime = computeTotalGcTime()

    try {
      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
      updateDependencies(taskFiles, taskJars)
      task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
      task.setTaskMemoryManager(taskMemoryManager)

      // If this task has been killed before we deserialized it, let's quit now. Otherwise,
      // continue executing the task.
      if (killed) {
        // Throw an exception rather than returning, because returning within a try{} block
        // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
        // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
        // for the task.
        throw new TaskKilledException
      }

      logDebug("Task " + taskId + "'s epoch is " + task.epoch)
      env.mapOutputTracker.updateEpoch(task.epoch)

      // Run the actual task and measure its runtime.
      taskStart = System.currentTimeMillis()
      var threwException = true
      val (value, accumUpdates) = try {
        // Call Task.run; each Task type (e.g. ShuffleMapTask, ResultTask) has its own implementation
        val res = task.run(
          taskAttemptId = taskId,
          attemptNumber = attemptNumber,
          metricsSystem = env.metricsSystem)
        threwException = false
        res
      } finally {
        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
        if (freedMemory > 0) {
          val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
          if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
            throw new SparkException(errMsg)
          } else {
            logError(errMsg)
          }
        }
      }
      val taskFinish = System.currentTimeMillis()

      // If the task has been killed, let's fail it.
      if (task.killed) {
        throw new TaskKilledException
      }

      val resultSer = env.serializer.newInstance()
      val beforeSerialization = System.currentTimeMillis()
      val valueBytes = resultSer.serialize(value)
      val afterSerialization = System.currentTimeMillis()

      for (m <- task.metrics) {
        // Deserialization happens in two parts: first, we deserialize a Task object, which
        // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
        m.setExecutorDeserializeTime(
          (taskStart - deserializeStartTime) + task.executorDeserializeTime)
        // We need to subtract Task.run()'s deserialization time to avoid double-counting
        m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
        m.setJvmGCTime(computeTotalGcTime() - startGCTime)
        m.setResultSerializationTime(afterSerialization - beforeSerialization)
        m.updateAccumulators()
      }

      val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
      val serializedDirectResult = ser.serialize(directResult)
      val resultSize = serializedDirectResult.limit

      // directSend = sending directly back to the driver
      val serializedResult: ByteBuffer = {
        if (maxResultSize > 0 && resultSize > maxResultSize) {
          logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
            s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
            s"dropping it.")
          ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
        } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
          val blockId = TaskResultBlockId(taskId)
          env.blockManager.putBytes(
            blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
          logInfo(
            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
          ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
        } else {
          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
          serializedDirectResult
        }
      }

      // Task finished: send a FINISHED status update (carrying the result) to the Driver
      execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

    } catch {
      case ffe: FetchFailedException =>
        val reason = ffe.toTaskEndReason
        execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

      case _: TaskKilledException | _: InterruptedException if task.killed =>
        logInfo(s"Executor killed $taskName (TID $taskId)")
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))

      case cDE: CommitDeniedException =>
        val reason = cDE.toTaskEndReason
        execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

      case t: Throwable =>
        // Attempt to exit cleanly by informing the driver of our failure.
        // If anything goes wrong (or this was a fatal exception), we will delegate to
        // the default uncaught exception handler, which will terminate the Executor.
        logError(s"Exception in $taskName (TID $taskId)", t)

        val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
          task.metrics.map { m =>
            m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
            m.setJvmGCTime(computeTotalGcTime() - startGCTime)
            m.updateAccumulators()
            m
          }
        }
        val serializedTaskEndReason = {
          try {
            ser.serialize(new ExceptionFailure(t, metrics))
          } catch {
            case _: NotSerializableException =>
              // t is not serializable so just send the stacktrace
              ser.serialize(new ExceptionFailure(t, metrics, false))
          }
        }
        // On failure, also report the status to the Driver so the task can be rescheduled
        execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)

        // Don't forcibly exit unless the exception was inherently fatal, to avoid
        // stopping other tasks unnecessarily.
        if (Utils.isFatalError(t)) {
          SparkUncaughtExceptionHandler.uncaughtException(t)
        }

    } finally {
      // Remove the task from the running-task map
      runningTasks.remove(taskId)
    }
  }
}
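Near its end, TaskRunner.run picks one of three ways to return the serialized result. That decision can be isolated as a pure function. This is a sketch: the parameter names frameSize and reservedBytes are assumptions that stand in for akkaFrameSize and AkkaUtils.reservedSizeBytes, and the ResultRoute type is hypothetical.

```scala
// Hypothetical model of the three ways TaskRunner.run can return a result.
sealed trait ResultRoute
case object Dropped extends ResultRoute          // over maxResultSize: only the size is reported
case object ViaBlockManager extends ResultRoute  // stored in the BlockManager, IndirectTaskResult sent
case object Direct extends ResultRoute           // serialized DirectTaskResult sent with the status update

// Mirrors the if / else if / else chain in TaskRunner.run.
// frameSize and reservedBytes stand in for akkaFrameSize and AkkaUtils.reservedSizeBytes.
def routeResult(resultSize: Long, maxResultSize: Long, frameSize: Long, reservedBytes: Long): ResultRoute =
  if (maxResultSize > 0 && resultSize > maxResultSize) Dropped
  else if (resultSize >= frameSize - reservedBytes) ViaBlockManager
  else Direct
```

With maxResultSize set to 0 the first branch never fires, matching the `maxResultSize > 0 &&` guard in the source. Because the source uses `>=`, a result of exactly frameSize - reservedBytes bytes already goes through the BlockManager.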


Task.run is responsible for executing the task. Its source is:

/**
 * Called by [[Executor]] to run this task.
 *
 * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
 * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
 * @return the result of the task along with updates of Accumulators.
 */
final def run(
    taskAttemptId: Long,
    attemptNumber: Int,
    metricsSystem: MetricsSystem)
  : (T, AccumulatorUpdates) = {
  // Build the task's runtime context
  context = new TaskContextImpl(
    stageId,
    partitionId,
    taskAttemptId,
    attemptNumber,
    taskMemoryManager,
    metricsSystem,
    internalAccumulators,
    runningLocally = false)
  TaskContext.setTaskContext(context)
  context.taskMetrics.setHostname(Utils.localHostName())
  context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
  taskThread = Thread.currentThread()
  if (_killed) {
    kill(interruptThread = false)
  }
  try {
    // Call runTask; each Task subclass implements it differently, e.g. ShuffleMapTask vs. ResultTask
    (runTask(context), context.collectAccumulators())
  } finally {
    context.markTaskCompleted()
    try {
      Utils.tryLogNonFatalError {
        // Release memory used by this thread for shuffles
        SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
      }
      Utils.tryLogNonFatalError {
        // Release memory used by this thread for unrolling blocks
        SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
      }
    } finally {
      TaskContext.unset()
    }
  }
}
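Task.run follows the template-method pattern: it is final, installs a thread-local TaskContext, delegates to the abstract runTask, and cleans up in finally whether runTask succeeds or throws. Below is a minimal sketch of that shape, using hypothetical ToyTask, ToyTaskContext and ToyTaskContextHolder classes; accumulators, metrics and memory release are deliberately left out.

```scala
// Hypothetical, stripped-down model of the context passed to a task.
final class ToyTaskContext(val partitionId: Int) {
  @volatile var completed = false
}

// Thread-local holder, in the spirit of TaskContext.setTaskContext / TaskContext.unset.
object ToyTaskContextHolder {
  private val current = new ThreadLocal[ToyTaskContext]
  def get(): ToyTaskContext = current.get()
  def set(tc: ToyTaskContext): Unit = current.set(tc)
  def unset(): Unit = current.remove()
}

abstract class ToyTask[T](val partitionId: Int) {
  // final: subclasses customize runTask, never the surrounding setup and cleanup.
  final def run(): T = {
    val context = new ToyTaskContext(partitionId)
    ToyTaskContextHolder.set(context)
    try {
      runTask(context)
    } finally {
      context.completed = true     // stands in for context.markTaskCompleted()
      ToyTaskContextHolder.unset() // mirrors TaskContext.unset()
    }
  }

  // Each concrete task (cf. ShuffleMapTask, ResultTask) supplies its own logic.
  protected def runTask(context: ToyTaskContext): T
}
```

Unsetting the thread-local in finally matters because executor threads are pooled: without it, the next task scheduled on the same thread could see a stale context.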


Taking ResultTask as an example, its runTask method is:

// runTask in ResultTask
override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val deserializeStartTime = System.currentTimeMillis()
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime

  metrics = Some(context.taskMetrics)
  // Call rdd.iterator to compute this partition, then apply func to the result
  func(context, rdd.iterator(partition, context))
}
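Once rdd and func have been deserialized, a ResultTask amounts to applying func to this partition's iterator. The sketch below replaces the RDD with an in-memory Seq and skips serialization entirely; ToyContext, ToyResultTask and toyCount are hypothetical names. toyCount mimics, in spirit, RDD.count(), which runs a job that counts each partition's iterator and sums the per-partition results on the driver.

```scala
// Hypothetical stand-ins: a partition is an in-memory Seq and the context only carries its id.
final case class ToyContext(partitionId: Int)

final class ToyResultTask[T, U](
    partitionId: Int,
    data: Seq[T],                           // stands in for rdd.iterator(partition, context)
    func: (ToyContext, Iterator[T]) => U) { // the user function shipped with the task
  def runTask(): U = {
    val context = ToyContext(partitionId)
    func(context, data.iterator)
  }
}

// Driver-side sketch: one task per partition, per-partition counts summed afterwards.
def toyCount[T](partitions: Seq[Seq[T]]): Long = {
  val tasks = partitions.zipWithIndex.map { case (data, id) =>
    new ToyResultTask[T, Long](id, data, (_, iter) => iter.size.toLong)
  }
  tasks.map(_.runTask()).sum
}
```

The iterator is consumed lazily inside func, which is why the real runTask passes rdd.iterator(...) rather than a materialized collection.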


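To summarize, a Task is executed as follows:
1 On the Driver, launchTasks in org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend sends a LaunchTask message to the executor
2 On the Worker, org.apache.spark.executor.CoarseGrainedExecutorBackend handles LaunchTask in its receive method and calls org.apache.spark.executor.Executor.launchTask
3 The run method of TaskRunner (an inner class of org.apache.spark.executor.Executor) executes on a thread-pool thread
4 TaskRunner.run calls org.apache.spark.scheduler.Task.run
5 Task.run calls the runTask method of the concrete Task subclass (org.apache.spark.scheduler.ShuffleMapTask or org.apache.spark.scheduler.ResultTask)
6 runTask calls org.apache.spark.rdd.RDD.iterator to compute the partition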