Kubeflow/tf-operator源码分析

调用流程

虽然KubeFlow提供了一大堆组件,涵盖了机器学习的方方面面,但模型训练肯定是KubeFlow最重要的功能。 KubeFlow针对各种各样的机器学习框架提供了训练的能力。方式是定义了各种各样的Operator,这些Operator的本质,是K8SCRD
一句话,TF-Operator就是开源社区基于K8S提供的扩展API,提供了TensorFlow的训练能力,从名字也能看出来,这个实现是类似Job的一种方式。
TF-Operator的代码不太多,但是由于用了大量的K8SAPI,结构有点复杂,我们只把重要的地方摘出来。最重要的是下面这几个源代码文件(后附源码分析)。

TF-Operator流程图

源码分析

为了帮助我们更加深刻的理解Kubeflow@TFJob的工作流程和实现机制,下面将TF-Operator重点代码拿出来一起过一遍。
pkg/controller.v1/tensorflow/controller.go:
NewTFController返回一个新的TFJob控制器:

func NewTFController (...) *TFController {
    ... ...
    // 设置同步处理程序。
    tc.syncHandler = tc.syncTFJob
    ... ...
    return tc
}

processNextWorkItem将从WorkQueue中读取单个工作项,并尝试通过调用syncHandler来处理它:

func (tc *TFController) processNextWorkItem() bool {
    obj, quit := tc.WorkQueue.Get()
    ... ...
    // 同步TFJob以将实际状态匹配到所需的状态。
    forget, err := tc.syncHandler(key=obj.(string))
    if err == nil {
        if forget {
            tc.WorkQueue.Forget(key)
        }
    }
}

如果tfjob的期望值已经实现,那么syncTFJob就会用给定的key来同步tfjob,这意味着它不希望更多的
pod/service被创建或删除:

// 这个函数不能与同一个key同时调用
func (tc *TFController) syncTFJob(key string) (bool, error) {
    ... ...
    sharedTFJob, err := tc.getTFJobFromName(namespace, name)
    
    tfjob := sharedTFJob.DeepCopy()

    // 为新tfjob设置默认值。
    scheme.Scheme.Default(tfjob)

    if tfjobNeedsSync && tfjob.DeletionTimestamp == nil {
        // 调用reconcileTFJobs来启动TFJobs
        reconcileTFJobsErr = tc.reconcileTFJobs(tfjob)
    }
    ... ...
}

pkg/controller.v1/tensorflow/pod.go:
reconcileTFJobs检查并更新每个给定TFReplicaSpecreplicas

// 如果在创建/删除 pods/services时发生错误,它将请求tfjob。 
func (tc *TFController) reconcileTFJobs(tfjob *tfv1.TFJob) error {
    ... ...
    // 如果TFJob terminated,则delete所有pod和service。
    if isSucceeded(tfjob.Status) || isFailed(tfjob.Status) {
        if err := tc.deletePodsAndServices(tfjob, pods); err != nil {
            return err
        }
        if err := tc.cleanupTFJob(tfjob); err != nil {
            return err
        }
        if tc.Config.EnableGangScheduling {
            if err := tc.DeletePodGroup(tfjob); err != nil {
                return err
            }
        }
        ... ...
    }

    // 检索以前的重试次数
    previousRetry := tc.WorkQueue.NumRequeues(tfjobKey)

    if tfJobExceedsLimit {
        // 如果TFJob超过了backofflimit或超过了active deadline,删除所有pod和service,然后将状态设置为failed(代码同上)
        ... ...
        // 遍历配置文件的TFReplicaSpecs部分,分别为不同类型的节点启动相应的Pod。
        // 在启动Pod之后,还要为其启动一个Service。
        for rtype, spec := range tfjob.Spec.TFReplicaSpecs {
            err = tc.reconcilePods(tfjob, pods, rtype, spec, replicasStatus)
            ... ...
            err = tc.reconcileServices(tfjob, services, rtype, spec)
            ... ...
        }
    }   
}

reconcilePods为每个给定的TFReplicaSpec检查和更新pod

// 如果在创建/删除pod时发生错误,它将请求tfjob。
func (tc *TFController) reconcilePods(...) error {  
    ... ...
    // 获取rtype类型的所有pod。
    pods, err := tc.FilterPodsForReplicaType(pods, rt)
    ... ...
    podSlices, podsToBeRemoved := tc.GetPodSlices(pods, replicas, logger)

    // 缩减
    if tfjob.Spec.EnableDynamicWorker && len(podsToBeRemoved) > 0 {
        // 目前只允许缩减workers
        if rtype == tfv1.TFReplicaTypeWorker {
            for _, pod := range podsToBeRemoved {
                err := tc.PodControl.DeletePod(tfjob.Namespace, pod.Name, tfjob)
            }
        } 
    }

    for index, podSlice := range podSlices {
        if len(podSlice) == 0 {
            // 如果master pod存在,选择master pod
            // 如果没有master,第一个worker pod被选为master。
            if ContainChieforMasterSpec(tfjob) {
                if tfv1.IsChieforMaster(rtype) {
                    masterRole = true
                }
            } else {
                if tfv1.IsWorker(rtype) && (index == 0) {
                    masterRole = true
                }
            }
            // 调用createNewPod创建Pod
            err = tc.createNewPod(tfjob, rt, strconv.Itoa(index), spec, masterRole)
        } 
        ... ...
    }

    return tc.updateStatusSingle(tfjob, rtype, replicas, restart, worker0Completed)
}

createNewPod为给定的indextype创建一个新的pod

func (tc *TFController) createNewPod(tfjob *tfv1.TFJob, rt, index string, spec *common.ReplicaSpec, masterRole bool) error {
    
    expectationPodsKey := jobcontroller.GenExpectationPodsKey(tfjobKey, rt)
    err = tc.Expectations.ExpectCreations(expectationPodsKey, 1)
    
    // 创建 OwnerReference.
    controllerRef := tc.GenOwnerReference(tfjob)

    podTemplate := spec.Template.DeepCopy()
    ... ...
    // 生成集群的配置信息,这里最关键,看一下实现
    if err := setClusterSpec(podTemplate, tfjob, rt, index); err != nil {
        return err
    }
    ... ...
    // 使用上面的配置信息,真正启动Pod的创建
    err = tc.PodControl.CreatePodsWithControllerRef(tfjob.Namespace, podTemplate, tfjob, controllerRef)
}

setClusterSpec为给定的podTemplateSpec生成并设置TF_CONFIG

func setClusterSpec(podTemplateSpec *v1.PodTemplateSpec, tfjob *tfv1.TFJob, rt, index string) error {
    ... ...
    // 生成TF_CONFIG JSON字符串。
    tfConfigStr, err := genTFConfigJSONStr(tfjob, rt, index)
    ... ...
    // 将TF_CONFIG环境变量添加到pod中的tensorflow容器中。
    for i := range podTemplateSpec.Spec.Containers {
        if podTemplateSpec.Spec.Containers[i].Name == tfv1.DefaultContainerName {
            ... ...
            podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, v1.EnvVar{
                Name:  tfConfig,
                Value: tfConfigStr,
            })
            break
        }
    }
}

pkg/controller.v1/tensorflow/tensorflow.go:
genTFConfig将生成环境变量TF_CONFIG

{
    "cluster": {
        "ps": ["ps1:2222", "ps2:2222"],
        "worker": ["worker1:2222", "worker2:2222", "worker3:2222"]
    },
    "task": {
        "type": "ps",
        "index": 1
    },
}

主要代码如下:

func genTFConfigJSONStr(tfjob *tfv1.TFJob, rtype, index string) (string, error) {
    // 配置TFCONFIG环境变量。
    cluster, err := genClusterSpec(tfjob)
    ... ...
    // 组装形成TF_CONFIG
    if tfjob.Spec.EnableDynamicWorker {
        sparseCluster := convertClusterSpecToSparseClusterSpec(cluster, rtype, int32(i))
        sparseTFConfig := SparseTFConfig{
            Cluster: sparseCluster,
            Task: TaskSpec{
                Type:  rtype,
                Index: int(i),
            },
        }
        tfConfigJSONByteSlice, err = json.Marshal(sparseTFConfig)
    } else {
        tfConfig := TFConfig{
            Cluster: cluster,
            Task: TaskSpec{
                Type:  rtype,
                Index: int(i),
            },
            // 我们需要设置环境为cloud,否则它会默认为local,这不是我们想要的。
            Environment: "cloud",
        }
        tfConfigJSONByteSlice, err = json.Marshal(tfConfig)
    }
    return string(tfConfigJSONByteSlice), nil
}

genClusterSpec将生成ClusterSpec

func genClusterSpec(tfjob *tfv1.TFJob) (ClusterSpec, error) {
    ... ...
    for rtype, spec := range tfjob.Spec.TFReplicaSpecs {
        port, err := GetPortFromTFJob(tfjob, rtype)
        // 这里循环生成了TF_CONFIG里面的Cluster信息。注意看注释,使用DNS配合Service,解决的还是各个节点IP不固定的问题
        for i := int32(0); i < *spec.Replicas; i++ {
            // 如下所述:https://kubernetes.io/docs/concepts/services-networking/dns-pos-service/#a-records。
            // Headless service为"my-svc.my-namespace.svc.cluster.local"的名称分配一个DNS记录。
            // 最后一部分是"svc.cluster.local"被称为cluster domain,在不同的kubernetes集群之间可能存在差异。
            hostName := jobcontroller.GenGeneralName(tfjob.Name, rt, fmt.Sprintf("%d", i))
            svcName := hostName + "." + tfjob.Namespace + "." + "svc"
            cluserDomain := os.Getenv(EnvCustomClusterDomain)
            if len(cluserDomain) > 0 {
                svcName += "." + cluserDomain
            }
            endpoint := fmt.Sprintf("%s:%d", svcName, port)
            replicaNames = append(replicaNames, endpoint)
        }
        clusterSpec[rt] = replicaNames
    }
    return clusterSpec, nil
}

pkg/control/pod_control.go:
使用集群的配置信息,真正启动Pod的创建:

func (r RealPodControl) CreatePodsWithControllerRef(...) error {
    ... ...
    return r.createPods("", namespace, template, controllerObject, controllerRef)
}

调用K8S接口创建pod

func (r RealPodControl) createPods(...) error {
    pod, err := GetPodFromTemplate(template, object, controllerRef)
    ... ...
    if newPod, err := r.KubeClient.CoreV1().Pods(namespace).Create(pod); err != nil {
        r.Recorder.Eventf(object, v1.EventTypeWarning, FailedCreatePodReason, "Error creating: %v", err)
        return err
    } 
    ... ...
}

pkg/controller.v1/tensorflow/service.go:
为每个给定的TFReplicaSpec检查和更新service

// 它将在创建/删除服务时发生错误时请求tfjob。
func (tc *TFController) reconcileServices(...) error {

    // 获取rt类型的所有service。
    services, err := tc.FilterServicesForReplicaType(services, rt)
    
    serviceSlices, servicesToBeRemoved := tc.GetServiceSlices(services, replicas, tflogger.LoggerForReplica(tfjob, rt))

    // 缩减
    if tfjob.Spec.EnableDynamicWorker && len(servicesToBeRemoved) > 0 {
        // 目前只允许缩小worker的service范围
        if rtype == tfv1.TFReplicaTypeWorker {
            for _, service := range servicesToBeRemoved {
                if err := tc.ServiceControl.DeleteService(tfjob.Namespace, service.Name, tfjob); err != nil {
                    return err
                }
            }
        }
    }

    for index, serviceSlice := range serviceSlices {
        if len(serviceSlice) == 0 {
            err = tc.createNewService(tfjob, rtype, strconv.Itoa(index), spec)
            
        }
    }
}

为给定的indextype创建一个新service

func (tc *TFController) createNewService(tfjob *tfv1.TFJob, rtype tfv1.TFReplicaType, index string, spec *common.ReplicaSpec) error {
    ... ...
    expectationServicesKey := jobcontroller.GenExpectationServicesKey(tfjobKey, rt)
    err = tc.Expectations.ExpectCreations(expectationServicesKey, 1)
    
    // 创建 OwnerReference.
    controllerRef := tc.GenOwnerReference(tfjob)
    ... ...
    // 直接生成了Service的配置信息
    service := &v1.Service{
        Spec: v1.ServiceSpec{
            ClusterIP: "None",
            Selector:  labels,
            Ports: []v1.ServicePort{
                {
                    Name: tfv1.DefaultPortName,
                    Port: port,
                },
            },
        },
    }
    ... ...
    err = tc.ServiceControl.CreateServicesWithControllerRef(tfjob.Namespace, service, tfjob, controllerRef)
    ... ...
}

pkg/control/service_control.go:
使用集群的配置信息,真正启动Service的创建:

func (r RealServiceControl) CreateServicesWithControllerRef(...) error {
    ... ...
    return r.createServices(namespace, service, controllerObject, controllerRef)
}

调用K8S接口创建service

func (r RealServiceControl) createServices(namespace string, service *v1.Service, object runtime.Object, controllerRef *metav1.OwnerReference) error {
    serviceWithOwner, err := getServiceFromTemplate(service, object, controllerRef)
    ... ...
    newService, err := r.KubeClient.CoreV1().Services(namespace).Create(serviceWithOwner)
    ... ...
}

Good!要想真正搞懂Kubeflow,就必须要搞懂其核心TFJob的实现机制,如我们所见,TFJob代码量并不多,实现逻辑也不难掌握,以此为突破口,如果有必要,我们完全可以在参照它实现一套自己的定制化分布式训练框架。后续会有Kubeflow@Pipelines系列,如果本文对你有帮助,需要你的点赞收藏或直接关注我,会不定时更新技术干货和学习感悟,感谢支持~。

技术分享:

  1. Kubeflow-K8S的机器学习工具包,太牛了!
  2. 从原理到实战,彻底搞懂 Nginx!
  3. 掌握Shell编程,一篇就够了
  4. Kafka 概述:深入理解架构

你可能感兴趣的:(Kubeflow/tf-operator源码分析)