diff --git a/snail-job-server/snail-job-server-job-task/src/main/java/com/aizuda/snailjob/server/job/task/support/generator/task/MapReduceTaskGenerator.java b/snail-job-server/snail-job-server-job-task/src/main/java/com/aizuda/snailjob/server/job/task/support/generator/task/MapReduceTaskGenerator.java index 7de1fd40..548899e0 100644 --- a/snail-job-server/snail-job-server-job-task/src/main/java/com/aizuda/snailjob/server/job/task/support/generator/task/MapReduceTaskGenerator.java +++ b/snail-job-server/snail-job-server-job-task/src/main/java/com/aizuda/snailjob/server/job/task/support/generator/task/MapReduceTaskGenerator.java @@ -136,10 +136,8 @@ public class MapReduceTaskGenerator extends AbstractJobTaskGenerator { } // 这里需要判断是否是map - // 平均分配map集合, 若reduceParallel > allMapJobTasks.size(), 则取allMapJobTasks.size()作为分片数 List allMapJobTasks = StreamUtils.toList(jobTasks, JobTask::getResultMessage); - int size = (allMapJobTasks.size() + reduceParallel - 1) / reduceParallel; - List> partition = Lists.partition(allMapJobTasks, size); + List> partition = averageAlgorithm(allMapJobTasks, reduceParallel); jobTasks = new ArrayList<>(partition.size()); final List finalJobTasks = jobTasks; @@ -238,4 +236,28 @@ public class MapReduceTaskGenerator extends AbstractJobTaskGenerator { return jobTasks; } + private List> averageAlgorithm(List allMapJobTasks, int shard) { + + // 最多分片数为allMapJobTasks.size() + shard = Math.min(allMapJobTasks.size(), shard); + int totalSize = allMapJobTasks.size(); + List partitionSizes = new ArrayList<>(); + int quotient = totalSize / shard; + int remainder = totalSize % shard; + + for (int i = 0; i < shard; i++) { + partitionSizes.add(quotient + (i < remainder ? 1 : 0)); + } + + List> partitions = new ArrayList<>(); + int currentIndex = 0; + + for (int size : partitionSizes) { + int endIndex = Math.min(currentIndex + size, totalSize); + partitions.add(new ArrayList<>(allMapJobTasks.subList(currentIndex, endIndex))); + currentIndex = endIndex; + } + + return partitions; + } }