1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
| func executeSubTasks[T any, R interface{ GetStatus() *commonpb.Status }]( ctx context.Context, tasks []subTask[T], evaluator PartialResultEvaluator, execute func(context.Context, T, cluster.Worker) (R, error), taskType string, log *log.MLogger, ) ([]R, error) { ctx, cancel := context.WithCancel(ctx) defer cancel() var partialResultRequiredDataRatio float64 if taskType == "Query" || taskType == "Search" { partialResultRequiredDataRatio = paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat() } else { partialResultRequiredDataRatio = 1.0 } wg, ctx := errgroup.WithContext(ctx) resultCh := make(chan channelResult, len(tasks)) for _, task := range tasks { task := task wg.Go(func() error { var result R var err error if task.targetID == -1 || task.worker == nil { err = fmt.Errorf("segments not loaded in any worker: %v", ...) } else { result, err = execute(ctx, task.req, task.worker) if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { err = fmt.Errorf("worker(%d) query failed: %s", task.targetID, result.GetStatus().GetReason()) } } if err != nil { log.Warn("failed to execute sub task", ...) if partialResultRequiredDataRatio == 1 { return err } } resultCh <- channelResult{ nodeID: task.targetID, result: result, err: err, segments: req.GetSegmentIDs(), } return nil }) } if err := wg.Wait(); err != nil { return nil, err } close(resultCh) successSegmentList := typeutil.NewSet[int64]() failureSegmentList := make([]int64, 0) var errors []error results := make([]R, 0, len(tasks)) for item := range resultCh { if item.err == nil { successSegmentList.Insert(item.segments...) results = append(results, item.result) } else { failureSegmentList = append(failureSegmentList, item.segments...) errors = append(errors, item.err) } } if len(errors) == 0 { return results, nil } if evaluator != nil { shouldReturnPartial, accessedDataRatio := evaluator( taskType, successSegmentList, failureSegmentList, errors) if shouldReturnPartial { log.Info("partial result executed successfully", zap.Float64("accessedDataRatio", accessedDataRatio), zap.Int64s("failureSegmentList", failureSegmentList), ) return results, nil } } return nil, merr.Combine(errors...) }
|