Implementing halt and resume for FfDL training jobs

The upstream FfDL code currently implements job halt only by setting the Status field in the database to HALTED: it neither actually destroys the job's pods nor preserves the job's state information, and the resume endpoint is not implemented at all. Our project needs this functionality, so we implemented it ourselves.

Approach: reuse the restapi PatchModel endpoint. A payload status of halt marks a halt operation, and resume marks a resume operation.
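For reference, a minimal client-side sketch of driving the reused PatchModel endpoint with "halt" or "resume". The URL path /v1/models/{model_id}, the port, the credentials, and the training ID below are assumptions for illustration, not values taken from an actual FfDL deployment.

package main

import (
    "bytes"
    "fmt"
    "net/http"
)

// patchModelStatus drives the reused PatchModel endpoint with either "halt"
// or "resume". The path, port, and credentials are placeholder assumptions.
func patchModelStatus(baseURL, modelID, status string) error {
    body := bytes.NewBufferString(fmt.Sprintf(`{"status": %q}`, status))
    req, err := http.NewRequest(http.MethodPatch, fmt.Sprintf("%s/v1/models/%s", baseURL, modelID), body)
    if err != nil {
        return err
    }
    req.Header.Set("Content-Type", "application/json")
    req.SetBasicAuth("test-user", "test") // placeholder credentials
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return err
    }
    defer resp.Body.Close()
    if resp.StatusCode >= 300 {
        return fmt.Errorf("PatchModel returned %s", resp.Status)
    }
    return nil
}

func main() {
    // Hypothetical REST endpoint and training ID, for illustration only.
    if err := patchModelStatus("http://localhost:31001", "training-abc123", "halt"); err != nil {
        fmt.Println(err)
    }
}

Calling this twice, first with "halt" and later with "resume", reproduces the experiment described further below.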

Halt logic: the request carries the training ID. The trainer fetches the job's record from the MongoDB database, marks its status as HALTED, and makes a gRPC call to the LCM, which calls the Kubernetes API to destroy the pods belonging to that job.
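A condensed sketch of this halt path (the full diff against trainer_impl.go is at the end of the post). The repository and podKiller interfaces are simplified stand-ins for the trainer's real dependencies, the Mongo-backed repo and the LCM gRPC client; KillTrainingJob is a placeholder name, not the actual LCM RPC.

// Sketch of the halt path with simplified stand-in types, not real FfDL APIs.
package trainersketch

import "context"

type trainingRecord struct {
    TrainingID string
    UserID     string
    Status     string
}

type repository interface {
    Find(trainingID string) (*trainingRecord, error)
    Store(t *trainingRecord) error
}

type podKiller interface {
    // KillTrainingJob stands in for whatever kill/cleanup RPC the LCM exposes.
    KillTrainingJob(ctx context.Context, trainingID, userID string) error
}

// haltTrainingJob marks the job HALTED in Mongo (keeping the record and its
// checkpoints), then asks the LCM to delete the job's pods via the
// Kubernetes API.
func haltTrainingJob(ctx context.Context, repo repository, lcm podKiller, trainingID, userID string) error {
    training, err := repo.Find(trainingID)
    if err != nil {
        return err
    }
    training.Status = "HALTED"
    if err := repo.Store(training); err != nil {
        return err
    }
    return lcm.KillTrainingJob(ctx, trainingID, userID)
}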

Resume logic: the request carries the training ID. The trainer fetches the job's record from MongoDB and checks whether its status is HALTED; if so, it sets the status back to PENDING and makes a gRPC call to the LCM, which calls the Kubernetes API to recreate the job's pods.
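And the matching sketch of the resume path, written as a second file in the same sketch package so it can reuse the trainingRecord and repository stand-ins above; SubmitTrainingJob is a placeholder for the trainer's submitJobToLCM helper, not a real FfDL API.

// Sketch of the resume path; second file of the same stand-in package.
package trainersketch

import (
    "context"
    "fmt"
)

type jobSubmitter interface {
    // SubmitTrainingJob stands in for submitJobToLCM: the LCM asks the
    // Kubernetes API to recreate the learner/helper pods for this job.
    SubmitTrainingJob(ctx context.Context, t *trainingRecord) error
}

// resumeTrainingJob accepts only HALTED jobs, moves them back to PENDING in
// Mongo, and resubmits them so their pods are created again; the learner
// container then restores from its last saved checkpoint.
func resumeTrainingJob(ctx context.Context, repo repository, lcm jobSubmitter, trainingID string) error {
    training, err := repo.Find(trainingID)
    if err != nil {
        return err
    }
    if training.Status != "HALTED" {
        return fmt.Errorf("training %s is %s, only HALTED jobs can be resumed", trainingID, training.Status)
    }
    training.Status = "PENDING"
    if err := repo.Store(training); err != nil {
        return err
    }
    return lcm.SubmitTrainingJob(ctx, training)
}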

TensorFlow experiment: the training script saves a checkpoint (ckpt file) every 1000 iterations. Calling halt during training and then calling resume shows that the script restores the previously saved training state from the checkpoint and continues training.

The modified code, as a git patch:

From 89cc0990e6b012723bdbe0d981be3ba1ac00bbf4 Mon Sep 17 00:00:00 2001
From: James 
Date: Mon, 18 Feb 2019 11:28:59 +0000
Subject: [PATCH] add trainer halt and resume

Signed-off-by: James 
---
 restapi/api_v1/server/models_impl.go |  3 +-
 trainer/trainer/trainer_impl.go      | 54 ++++++++++++++++++++++++----
 2 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/restapi/api_v1/server/models_impl.go b/restapi/api_v1/server/models_impl.go
index f0776a1..a04c6ed 100644
--- a/restapi/api_v1/server/models_impl.go
+++ b/restapi/api_v1/server/models_impl.go
@@ -795,7 +795,7 @@ func patchModel(params models.PatchModelParams) middleware.Responder {
 	logr := logger.LocLogger(logWithUpdateStatusParams(params))
 	logr.Debugf("patchModel invoked: %v", params.HTTPRequest.Header)
 
-	if params.Payload.Status != "halt" {
+	if (params.Payload.Status != "halt" && params.Payload.Status != "resume") {
 		return models.NewPatchModelBadRequest().WithPayload(&restmodels.Error{
 			Error:       "Bad request",
 			Code:        http.StatusBadRequest,
@@ -814,6 +814,7 @@ func patchModel(params models.PatchModelParams) middleware.Responder {
 		TrainingId: params.ModelID,
 		UserId:     getUserID(params.HTTPRequest),
 		Status:     grpc_trainer_v2.Status_HALTED,
+        StatusMessage: params.Payload.Status,
 	})
 	//
 	if err != nil {
diff --git a/trainer/trainer/trainer_impl.go b/trainer/trainer/trainer_impl.go
index d34a4f0..721b5a2 100644
--- a/trainer/trainer/trainer_impl.go
+++ b/trainer/trainer/trainer_impl.go
@@ -70,7 +70,7 @@ const (
 	collectionNameTrainingJobs = "training_jobs"
 	collectionNameJobHistory   = "job_history"
 
-	debugLogsMode = false
+	debugLogsMode = true
 
 	oldEndpointInternalPageSize = 10
 
@@ -604,7 +604,7 @@ func (s *trainerService) CreateTrainingJob(ctx context.Context, req *grpc_traine
 		qHandler = s.queues["ANY"]
 	}
 
-	rateLimited := true
+    rateLimited := true
 	qSize, err := qHandler.Size()
 	logGpuTypeQueueSize := fmt.Sprintf("%s_%s", gpuType, "queue_size")
 	logr.WithFields(logrus.Fields{
@@ -617,6 +617,7 @@ func (s *trainerService) CreateTrainingJob(ctx context.Context, req *grpc_traine
 		rateLimited = s.rateLimitTrainingJob(tr, logr)
 	}
 
+    //rateLimited = true
 	if rateLimited {
 		// either queue was not empty or rate-limiting was needed, so send this job to the queue
 		logr.Infof("training job %s is rate-limited, adding to queue %s", tr.TrainingID, gpuType)
@@ -733,7 +734,28 @@ func (s *trainerService) GetTrainingStatusID(ctx context.Context, req *grpc_trai
 
 func (s *trainerService) UpdateTrainingJob(ctx context.Context, req *grpc_trainer_v2.UpdateRequest) (*grpc_trainer_v2.UpdateResponse, error) {
 	logr := logger.LocLogger(logWith(req.TrainingId, req.UserId))
-	logr.Debugf("UpdateTrainingJob called for training %s", req.TrainingId)
+	logr.Debugf("UpdateTrainingJob called for training %s message %s", req.TrainingId, req.StatusMessage)
+
+    if(req.Status == grpc_trainer_v2.Status_HALTED) {
+	    training, err := s.repo.Find(req.TrainingId)
+        if err != nil {
+		    logr.WithError(err).Errorf("Cannot retrieve training '%s'", req.TrainingId)
+		    return nil, err
+        }
+	    ts := training.TrainingStatus
+        if (ts.Status == grpc_trainer_v2.Status_HALTED && req.StatusMessage == "resume") {
+            s.ResumeTrainingJob(ctx, &grpc_trainer_v2.ResumeRequest{
+                TrainingId: req.TrainingId,
+                UserId: req.UserId,
+            })
+        } else if (ts.Status != grpc_trainer_v2.Status_FAILED && ts.Status != grpc_trainer_v2.Status_COMPLETED && req.StatusMessage == "halt") {
+            s.HaltTrainingJob(ctx, &grpc_trainer_v2.HaltRequest{
+                TrainingId: req.TrainingId,
+                UserId: req.UserId,
+            })
+        }
+	    return &grpc_trainer_v2.UpdateResponse{TrainingId: req.TrainingId}, nil
+    }
 
 	return updateTrainingJobPostLock(s, req)
 }
@@ -1132,7 +1154,7 @@ func (s *trainerService) HaltTrainingJob(ctx context.Context, req *grpc_trainer_
 		logr.Debugf("Kubernetes job '%s' no longer exists.", job.JobId)
 
 		// update the status in mongo
-		_, err = updateTrainingJobPostLock(s, &grpc_trainer_v2.UpdateRequest{
+/*		_, err = updateTrainingJobPostLock(s, &grpc_trainer_v2.UpdateRequest{
 			TrainingId:    req.TrainingId,
 			UserId:        req.UserId,
 			Status:        grpc_trainer_v2.Status_HALTED,
@@ -1142,7 +1164,15 @@ func (s *trainerService) HaltTrainingJob(ctx context.Context, req *grpc_trainer_
 		if err != nil {
 			logr.WithError(err).Errorln("Unable to update job status to halted")
 			return nil, err
-		}
+		}*/
+	    training, _ := s.repo.Find(req.TrainingId)
+        ts := training.TrainingStatus
+	    ts.Status = grpc_trainer_v2.Status_HALTED
+	    err = s.repo.Store(training)
+	    if err != nil {
+		    logr.WithError(err).Errorf("Failed updating status of training %s in DB", req.TrainingId)
+		    return nil, err
+        }
 
 		return &grpc_trainer_v2.HaltResponse{TrainingId: job.JobId, UserId: job.UserId, Status: grpc_trainer_v2.Status_HALTED}, nil
 	}
@@ -1151,7 +1181,19 @@ func (s *trainerService) HaltTrainingJob(ctx context.Context, req *grpc_trainer_
 
 func (s *trainerService) ResumeTrainingJob(ctx context.Context, req *grpc_trainer_v2.ResumeRequest) (*grpc_trainer_v2.ResumeResponse, error) {
 	logr := logger.LocLogger(logWith(req.TrainingId, req.UserId))
-	logr.Debugf("HaltTrainingJob called")
+	logr.Debugf("ResumeTrainingJob called")
+
+	training, err := s.repo.Find(req.TrainingId)
+	if err != nil {
+		logr.WithError(err).Errorf("Cannot retrieve training '%s'", req.TrainingId)
+		return nil, err
+	}
+    err = s.submitJobToLCM(training, logr)
+    if err != nil {
+	    // err logged in submitJobToLCM
+		return nil, err
+	}
+
 	return nil, gerrf(codes.Unimplemented, "ResumeTrainingJob not implemented yet")
 }
 
-- 
2.17.1

 
