diff --git a/pkg/controllers/v1alpha1/fluidapp/dataflowaffinity/dataflowaffinity_controller.go b/pkg/controllers/v1alpha1/fluidapp/dataflowaffinity/dataflowaffinity_controller.go index ac7b4f96b57..1e9328f6c32 100644 --- a/pkg/controllers/v1alpha1/fluidapp/dataflowaffinity/dataflowaffinity_controller.go +++ b/pkg/controllers/v1alpha1/fluidapp/dataflowaffinity/dataflowaffinity_controller.go @@ -76,7 +76,7 @@ func (f *DataOpJobReconciler) Reconcile(ctx context.Context, request reconcile.R Log: f.Log.WithValues("namespacedName", request.NamespacedName), NamespacedName: request.NamespacedName, } - job, err := kubeclient.GetJob(f.Client, request.Name, request.Namespace) + job, err := kubeclient.GetJobWithContext(ctx, f.Client, request.Name, request.Namespace) if err != nil { requestCtx.Log.Error(err, "fetch job error") return utils.RequeueIfError(err) @@ -106,7 +106,7 @@ func (f *DataOpJobReconciler) Reconcile(ctx context.Context, request reconcile.R // get job' status, if succeed, add label to job. condition := kubeclient.GetFinishedJobCondition(job) if condition != nil && condition.Type == batchv1.JobComplete { - err = f.injectPodNodeLabelsToJob(job) + err = f.injectPodNodeLabelsToJob(ctx, job) if err != nil { requestCtx.Log.Error(err, "update labels for job failed") return utils.RequeueIfError(err) @@ -120,8 +120,8 @@ func (f *DataOpJobReconciler) SetupWithManager(mgr ctrl.Manager, options control return watch.SetupDataOpJobWatcherWithReconciler(mgr, options, f) } -func (f *DataOpJobReconciler) injectPodNodeLabelsToJob(job *batchv1.Job) error { - pod, err := kubeclient.GetSucceedPodForJob(f.Client, job) +func (f *DataOpJobReconciler) injectPodNodeLabelsToJob(ctx context.Context, job *batchv1.Job) error { + pod, err := kubeclient.GetSucceedPodForJobWithContext(ctx, f.Client, job) if err != nil { return err } @@ -134,9 +134,9 @@ func (f *DataOpJobReconciler) injectPodNodeLabelsToJob(job *batchv1.Job) error { return fmt.Errorf("succeed job has no node name, podNamespace: %s, podName: %s", pod.Namespace, pod.Name) } - node, err := kubeclient.GetNode(f.Client, nodeName) + node, err := kubeclient.GetNodeWithContext(ctx, f.Client, nodeName) if err != nil { - return fmt.Errorf("error to get node %s: %v", nodeName, err) + return fmt.Errorf("error to get node %s: %w", nodeName, err) } annotationsToInject := map[string]string{} @@ -159,7 +159,7 @@ func (f *DataOpJobReconciler) injectPodNodeLabelsToJob(job *batchv1.Job) error { } } - if err = f.Client.Update(context.TODO(), job); err != nil { + if err = f.Client.Update(ctx, job); err != nil { return err } diff --git a/pkg/controllers/v1alpha1/fluidapp/dataflowaffinity/dataflowaffinity_controller_test.go b/pkg/controllers/v1alpha1/fluidapp/dataflowaffinity/dataflowaffinity_controller_test.go index a132fd3947c..bc37293c507 100644 --- a/pkg/controllers/v1alpha1/fluidapp/dataflowaffinity/dataflowaffinity_controller_test.go +++ b/pkg/controllers/v1alpha1/fluidapp/dataflowaffinity/dataflowaffinity_controller_test.go @@ -18,6 +18,7 @@ package dataflowaffinity import ( "context" + "errors" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -275,7 +276,7 @@ var _ = Describe("DataOpJobReconciler", func() { Log: fake.NullLogger(), } - err := f.injectPodNodeLabelsToJob(job) + err := f.injectPodNodeLabelsToJob(context.Background(), job) Expect(err).NotTo(HaveOccurred()) wantAnnotations := map[string]string{ @@ -328,9 +329,38 @@ var _ = Describe("DataOpJobReconciler", func() { Log: fake.NullLogger(), } - err := f.injectPodNodeLabelsToJob(job) + err := f.injectPodNodeLabelsToJob(context.Background(), job) Expect(err).To(HaveOccurred()) }) }) + + Context("when caller context is canceled", func() { + It("should return the context error", func() { + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-job-canceled", + }, + Spec: batchv1.JobSpec{ + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + controllerUIDKey: jobControllerUIDValue, + }, + }, + }, + } + + c := fake.NewFakeClientWithScheme(testScheme, job) + f := &DataOpJobReconciler{ + Client: fake.ContextAwareClient{Client: c}, + Log: fake.NullLogger(), + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := f.injectPodNodeLabelsToJob(ctx, job) + Expect(errors.Is(err, context.Canceled)).To(BeTrue()) + }) + }) }) }) diff --git a/pkg/utils/fake/client.go b/pkg/utils/fake/client.go index 549835a2525..dac6cb29433 100644 --- a/pkg/utils/fake/client.go +++ b/pkg/utils/fake/client.go @@ -17,11 +17,67 @@ limitations under the License. package fake import ( + "context" + "k8s.io/apimachinery/pkg/runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" ) +// ContextAwareClient wraps a fake client and returns ctx.Err() before delegating. +type ContextAwareClient struct { + client.Client +} + +func (c ContextAwareClient) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + if err := ctx.Err(); err != nil { + return err + } + return c.Client.Get(ctx, key, obj, opts...) +} + +func (c ContextAwareClient) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + if err := ctx.Err(); err != nil { + return err + } + return c.Client.List(ctx, list, opts...) +} + +func (c ContextAwareClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { + if err := ctx.Err(); err != nil { + return err + } + return c.Client.Create(ctx, obj, opts...) +} + +func (c ContextAwareClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { + if err := ctx.Err(); err != nil { + return err + } + return c.Client.Delete(ctx, obj, opts...) +} + +func (c ContextAwareClient) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { + if err := ctx.Err(); err != nil { + return err + } + return c.Client.Update(ctx, obj, opts...) +} + +func (c ContextAwareClient) Patch(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.PatchOption) error { + if err := ctx.Err(); err != nil { + return err + } + return c.Client.Patch(ctx, obj, patch, opts...) +} + +func (c ContextAwareClient) DeleteAllOf(ctx context.Context, obj client.Object, opts ...client.DeleteAllOfOption) error { + if err := ctx.Err(); err != nil { + return err + } + return c.Client.DeleteAllOf(ctx, obj, opts...) +} + // NewFakeClientWithScheme is to fix the issue by wrappering it: // fake.NewFakeClientWithScheme is deprecated: Please use NewClientBuilder instead. (staticcheck) func NewFakeClientWithScheme(clientScheme *runtime.Scheme, initObjs ...runtime.Object) client.Client { diff --git a/pkg/utils/kubeclient/context_client_test.go b/pkg/utils/kubeclient/context_client_test.go index 49f4648fa69..483d227ab5c 100644 --- a/pkg/utils/kubeclient/context_client_test.go +++ b/pkg/utils/kubeclient/context_client_test.go @@ -1,39 +1,5 @@ package kubeclient -import ( - "context" +import "github.com/fluid-cloudnative/fluid/pkg/utils/fake" - "sigs.k8s.io/controller-runtime/pkg/client" -) - -type contextAwareClient struct { - client.Client -} - -func (c contextAwareClient) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { - if err := ctx.Err(); err != nil { - return err - } - return c.Client.Get(ctx, key, obj, opts...) -} - -func (c contextAwareClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { - if err := ctx.Err(); err != nil { - return err - } - return c.Client.Create(ctx, obj, opts...) -} - -func (c contextAwareClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { - if err := ctx.Err(); err != nil { - return err - } - return c.Client.Delete(ctx, obj, opts...) -} - -func (c contextAwareClient) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { - if err := ctx.Err(); err != nil { - return err - } - return c.Client.Update(ctx, obj, opts...) -} +type contextAwareClient = fake.ContextAwareClient diff --git a/pkg/utils/kubeclient/job.go b/pkg/utils/kubeclient/job.go index 398074e718b..4064de28d5a 100644 --- a/pkg/utils/kubeclient/job.go +++ b/pkg/utils/kubeclient/job.go @@ -27,35 +27,49 @@ import ( ) // GetJob gets the job given its name and namespace -func GetJob(client client.Client, name, namespace string) (*v1.Job, error) { +func GetJob(c client.Client, name, namespace string) (*v1.Job, error) { + return GetJobWithContext(context.TODO(), c, name, namespace) +} + +// GetJobWithContext gets the job given its name and namespace. +func GetJobWithContext(ctx context.Context, c client.Client, name, namespace string) (*v1.Job, error) { key := types.NamespacedName{ Namespace: namespace, Name: name, } var job v1.Job - if err := client.Get(context.TODO(), key, &job); err != nil { + if err := c.Get(ctx, key, &job); err != nil { return nil, err } return &job, nil } -func UpdateJob(client client.Client, job *v1.Job) error { - return client.Update(context.TODO(), job) +func UpdateJob(c client.Client, job *v1.Job) error { + return UpdateJobWithContext(context.TODO(), c, job) +} + +func UpdateJobWithContext(ctx context.Context, c client.Client, job *v1.Job) error { + return c.Update(ctx, job) } // GetSucceedPodForJob get the first finished pod for the job, if no succeed pod, return nil with no error. func GetSucceedPodForJob(c client.Client, job *v1.Job) (*corev1.Pod, error) { + return GetSucceedPodForJobWithContext(context.TODO(), c, job) +} + +// GetSucceedPodForJobWithContext gets the first finished pod for the job, if no succeed pod, return nil with no error. +func GetSucceedPodForJobWithContext(ctx context.Context, c client.Client, job *v1.Job) (*corev1.Pod, error) { var podList corev1.PodList selector, err := metav1.LabelSelectorAsSelector(job.Spec.Selector) if err != nil { return nil, fmt.Errorf("error converting Job %s in namespace %s selector: %v", job.Name, job.Namespace, err) } - err = c.List(context.TODO(), &podList, &client.ListOptions{ + err = c.List(ctx, &podList, &client.ListOptions{ Namespace: job.Namespace, LabelSelector: selector, }) if err != nil { - return nil, fmt.Errorf("error listing pods for Job %s in namespace %s: %v", job.Name, job.Namespace, err) + return nil, fmt.Errorf("error listing pods for Job %s in namespace %s: %w", job.Name, job.Namespace, err) } for _, pod := range podList.Items { diff --git a/pkg/utils/kubeclient/job_test.go b/pkg/utils/kubeclient/job_test.go index 9f8c8be9c97..3bbaafb43fd 100644 --- a/pkg/utils/kubeclient/job_test.go +++ b/pkg/utils/kubeclient/job_test.go @@ -18,6 +18,7 @@ package kubeclient import ( "context" + "errors" "github.com/fluid-cloudnative/fluid/pkg/utils/fake" batchv1 "k8s.io/api/batch/v1" @@ -104,6 +105,21 @@ var _ = Describe("Job related unit tests", Label("pkg.utils.kubeclient.job_test. Expect(gotPod).To(BeNil()) }) }) + + When("caller context is canceled", func() { + BeforeEach(func() { + resources = []runtime.Object{job, jobPod} + }) + + It("should return the context error", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + gotPod, err := GetSucceedPodForJobWithContext(ctx, contextAwareClient{Client: client}, job) + Expect(errors.Is(err, context.Canceled)).To(BeTrue()) + Expect(gotPod).To(BeNil()) + }) + }) }) Describe("Test UpdateJob()", func() { @@ -152,6 +168,20 @@ var _ = Describe("Job related unit tests", Label("pkg.utils.kubeclient.job_test. Expect(apierrs.IsNotFound(err)).To(BeTrue()) }) }) + + When("caller context is canceled", func() { + BeforeEach(func() { + resources = []runtime.Object{job} + }) + + It("should return the context error", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := UpdateJobWithContext(ctx, contextAwareClient{Client: client}, job) + Expect(err).To(MatchError(context.Canceled)) + }) + }) }) Describe("Test GetJob()", func() { @@ -193,6 +223,21 @@ var _ = Describe("Job related unit tests", Label("pkg.utils.kubeclient.job_test. Expect(apierrs.IsNotFound(err)).To(BeTrue()) }) }) + + When("caller context is canceled", func() { + BeforeEach(func() { + resources = []runtime.Object{job} + }) + + It("should return the context error", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + gotJob, err := GetJobWithContext(ctx, contextAwareClient{Client: client}, job.Name, job.Namespace) + Expect(err).To(MatchError(context.Canceled)) + Expect(gotJob).To(BeNil()) + }) + }) }) Describe("Test GetFinishedJobCondition()", func() { diff --git a/pkg/utils/kubeclient/node.go b/pkg/utils/kubeclient/node.go index 2d594ba2554..bd311e689f7 100644 --- a/pkg/utils/kubeclient/node.go +++ b/pkg/utils/kubeclient/node.go @@ -26,14 +26,19 @@ import ( ) // GetNode gets the latest node info -func GetNode(client client.Reader, name string) (node *corev1.Node, err error) { +func GetNode(c client.Reader, name string) (node *corev1.Node, err error) { + return GetNodeWithContext(context.TODO(), c, name) +} + +// GetNodeWithContext gets the latest node info. +func GetNodeWithContext(ctx context.Context, c client.Reader, name string) (node *corev1.Node, err error) { key := types.NamespacedName{ Name: name, } node = &corev1.Node{} - if err = client.Get(context.TODO(), key, node); err != nil { + if err = c.Get(ctx, key, node); err != nil { return nil, err } return node, err diff --git a/pkg/utils/kubeclient/node_test.go b/pkg/utils/kubeclient/node_test.go index 2f2a0d4b894..daf59d7bbcc 100644 --- a/pkg/utils/kubeclient/node_test.go +++ b/pkg/utils/kubeclient/node_test.go @@ -17,6 +17,8 @@ limitations under the License. package kubeclient import ( + "context" + "github.com/fluid-cloudnative/fluid/pkg/utils/fake" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -69,6 +71,18 @@ var _ = Describe("GetNode", func() { Expect(result.Name).To(Equal("test1")) }) }) + + Context("when caller context is canceled", func() { + It("should return the context error", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := GetNodeWithContext(ctx, contextAwareClient{Client: mockClient}, "test1") + + Expect(err).To(MatchError(context.Canceled)) + Expect(result).To(BeNil()) + }) + }) }) var _ = Describe("IsReady", func() {