/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.writer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.spark.Aggregator;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.FetchFailedException;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.BufferManagerOptions;
import org.apache.spark.shuffle.writer.WriteBufferManager;
import org.apache.spark.storage.BlockManagerId;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.request.RssReassignServersRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.response.RssReassignServersReponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
import org.apache.uniffle.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.shaded.com.google.common.collect.Lists;
import org.apache.uniffle.shaded.com.google.common.collect.Maps;
import org.apache.uniffle.shaded.com.google.common.collect.Sets;
import org.apache.uniffle.shaded.com.google.common.util.concurrent.Uninterruptibles;
import org.apache.uniffle.storage.util.StorageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Function1;
import scala.Option;
import scala.Product2;
import scala.collection.Iterator;

/**
 * Spark {@link ShuffleWriter} that feeds records into a {@link WriteBufferManager}, hands the
 * resulting blocks to the {@link RssShuffleManager} for sending to remote shuffle servers, and,
 * on a successful {@link #stop(boolean)}, reports which block ids were written to which
 * server/partition via the {@link ShuffleWriteClient}.
 */
public class RssShuffleWriter<K, V, C>
extends ShuffleWriter<K, V> {
    private static final Logger LOG = LoggerFactory.getLogger(RssShuffleWriter.class);
    /** Host of the synthetic BlockManagerId placed in the returned MapStatus. */
    private static final String DUMMY_HOST = "dummy_host";
    /** Port of the synthetic BlockManagerId placed in the returned MapStatus. */
    private static final int DUMMY_PORT = 99999;
    /** Servers that receive the commit request in {@link #sendCommit()}. */
    private final Set<ShuffleServerInfo> shuffleServersForData;
    /** Block ids grouped by destination server, then by partition; reported in stop(true). */
    private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds;
    private final ShuffleWriteClient shuffleWriteClient;
    private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
    private String appId;
    private int numMaps;
    private int shuffleId;
    private int bitmapSplitNum;
    private String taskId;
    private long taskAttemptId;
    private ShuffleDependency<K, V, C> shuffleDependency;
    private ShuffleWriteMetrics shuffleWriteMetrics;
    private Partitioner partitioner;
    /** False when there is a single partition, in which case every record goes to partition 0. */
    private boolean shouldPartition;
    private WriteBufferManager bufferManager;
    private RssShuffleManager shuffleManager;
    private long sendCheckTimeout;
    private long sendCheckInterval;
    /** When true, the explicit commit step after writing is skipped. */
    private boolean isMemoryShuffleEnabled;
    /** Invoked with the task id whenever {@link #write} fails. */
    private final Function<String, Boolean> taskFailureCallback;
    /** Ids of every block handed to the send path; entries are removed once confirmed sent. */
    private final Set<Long> blockIds = Sets.newConcurrentHashSet();
    private TaskContext taskContext;
    private SparkConf sparkConf;

    /**
     * Creates a writer around a caller-supplied buffer manager. The task-failure callback is a
     * no-op that always returns {@code true}.
     */
    public RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, WriteBufferManager bufferManager, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, RssShuffleHandle<K, V, C> rssHandle, ShuffleHandleInfo shuffleHandleInfo, TaskContext context) {
        this(appId, shuffleId, taskId, taskAttemptId, shuffleWriteMetrics, shuffleManager, sparkConf, shuffleWriteClient, rssHandle, tid -> true, shuffleHandleInfo, context);
        this.bufferManager = bufferManager;
    }

    /**
     * Shared initialisation: copies identifiers, reads send-check/bitmap settings from
     * {@code sparkConf} and captures server assignments from {@code shuffleHandleInfo}.
     * Does not create the buffer manager; the public constructors set it.
     */
    private RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, ShuffleHandleInfo shuffleHandleInfo, TaskContext context) {
        this.appId = appId;
        this.shuffleId = shuffleId;
        this.taskId = taskId;
        this.taskAttemptId = taskAttemptId;
        this.numMaps = rssHandle.getNumMaps();
        this.shuffleDependency = rssHandle.getDependency();
        this.shuffleWriteMetrics = shuffleWriteMetrics;
        this.partitioner = this.shuffleDependency.partitioner();
        this.shuffleManager = shuffleManager;
        this.shouldPartition = this.partitioner.numPartitions() > 1;
        this.sendCheckTimeout = (Long) sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
        this.sendCheckInterval = (Long) sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
        this.bitmapSplitNum = (Integer) sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
        this.serverToPartitionToBlockIds = Maps.newHashMap();
        this.shuffleWriteClient = shuffleWriteClient;
        this.shuffleServersForData = shuffleHandleInfo.getShuffleServersForData();
        this.partitionToServers = shuffleHandleInfo.getPartitionToServers();
        this.isMemoryShuffleEnabled = this.isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
        this.taskFailureCallback = taskFailureCallback;
        this.taskContext = context;
        this.sparkConf = sparkConf;
    }

    /**
     * Creates a writer that resolves its {@link ShuffleHandleInfo} from {@code shuffleManager} and
     * builds its own {@link WriteBufferManager}. That buffer manager is wired back to
     * {@link #processShuffleBlockInfos}.
     */
    public RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, TaskContext context) {
        this(appId, shuffleId, taskId, taskAttemptId, shuffleWriteMetrics, shuffleManager, sparkConf, shuffleWriteClient, rssHandle, taskFailureCallback, shuffleManager.getShuffleHandleInfo(rssHandle), context);
        BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
        // Reuse the partition->servers map captured by the delegated constructor instead of
        // calling getShuffleHandleInfo a second time. This way the buffer manager and this writer
        // are built from the same handle-info snapshot, and one redundant lookup is avoided.
        this.bufferManager = new WriteBufferManager(shuffleId, taskId, taskAttemptId, bufferOptions, rssHandle.getDependency().serializer(), this.partitionToServers, context.taskMemoryManager(), shuffleWriteMetrics, RssSparkConfig.toRssConf(sparkConf), this::processShuffleBlockInfos);
    }

    /** Returns whether the named {@link StorageType} includes a memory tier. */
    private boolean isMemoryShuffleEnabled(String storageType) {
        return StorageType.withMemory(StorageType.valueOf(storageType));
    }

    /**
     * Builds a placeholder {@link BlockManagerId} for the MapStatus. The task attempt id is
     * carried in the topology-info slot.
     */
    private BlockManagerId createDummyBlockManagerId(String executorId, long taskAttemptId) {
        return BlockManagerId.apply(executorId, DUMMY_HOST, DUMMY_PORT, Option.apply(Long.toString(taskAttemptId)));
    }

    /**
     * Writes all records and waits until their blocks are confirmed sent.
     *
     * <p>On any failure, the task-failure callback is invoked first. If the manager has stage
     * resubmission enabled, the failure is then routed through
     * {@link #throwFetchFailedIfNecessary}, which always throws. Otherwise the original
     * exception is rethrown.
     */
    @Override
    public void write(Iterator<Product2<K, V>> records) {
        try {
            this.writeImpl(records);
        }
        catch (Exception e) {
            this.taskFailureCallback.apply(this.taskId);
            if (this.shuffleManager.isRssResubmitStage()) {
                this.throwFetchFailedIfNecessary(e);
            }
            throw e;
        }
    }

    /**
     * Buffers every record, flushes the remaining buffers, verifies the record count and send
     * results, commits unless memory shuffle is enabled, and records the write time.
     */
    private void writeImpl(Iterator<Product2<K, V>> records) {
        // Hoisted out of the per-record loop.
        // NOTE(review): this relies on the dependency having an aggregator whenever
        // mapSideCombine() is true. Spark's ShuffleDependency enforces this at construction.
        final boolean mapSideCombine = this.shuffleDependency.mapSideCombine();
        final Function1<V, C> createCombiner = mapSideCombine
                ? this.shuffleDependency.aggregator().get().createCombiner()
                : null;
        long recordCount = 0L;
        while (records.hasNext()) {
            ++recordCount;
            Product2<K, V> record = records.next();
            int partition = this.getPartition(record._1());
            List<ShuffleBlockInfo> shuffleBlockInfos;
            if (mapSideCombine) {
                C combined = createCombiner.apply(record._2());
                shuffleBlockInfos = this.bufferManager.addRecord(partition, record._1(), combined);
            } else {
                shuffleBlockInfos = this.bufferManager.addRecord(partition, record._1(), record._2());
            }
            this.processShuffleBlockInfos(shuffleBlockInfos);
        }
        long start = System.currentTimeMillis();
        // Flush whatever is still buffered after the last record.
        this.processShuffleBlockInfos(this.bufferManager.clear());
        long checkStart = System.currentTimeMillis();
        this.checkSentRecordCount(recordCount);
        this.checkBlockSendResult(this.blockIds);
        long checkDuration = System.currentTimeMillis() - checkStart;
        long commitDuration = 0L;
        if (!this.isMemoryShuffleEnabled) {
            long commitStart = System.currentTimeMillis();
            this.sendCommit();
            commitDuration = System.currentTimeMillis() - commitStart;
        }
        long writeDurationMs = this.bufferManager.getWriteTime() + (System.currentTimeMillis() - start);
        this.shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(writeDurationMs));
        LOG.info("Finish write shuffle for appId[{}], shuffleId[{}], taskId[{}] with write {} ms, include checkSendResult[{}], commit[{}], {}",
                this.appId, this.shuffleId, this.taskId, writeDurationMs, checkDuration, commitDuration,
                this.bufferManager.getManagerCostInfo());
    }

    /**
     * Fails if the number of records read from the input differs from the buffer manager's count.
     *
     * @throws RssSendFailedException on mismatch
     */
    private void checkSentRecordCount(long recordCount) {
        if (recordCount != this.bufferManager.getRecordCount()) {
            String errorMsg = "Potential record loss may have occurred while preparing to send blocks for task[" + this.taskId + "]";
            throw new RssSendFailedException(errorMsg);
        }
    }

    /**
     * Records the ids of the given blocks, both in the pending-send set and per server/partition,
     * then posts the blocks for sending.
     *
     * @param shuffleBlockInfoList blocks to send; may be {@code null} or empty
     * @return the send futures, or an empty list when there is nothing to send
     */
    private List<CompletableFuture<Long>> processShuffleBlockInfos(List<ShuffleBlockInfo> shuffleBlockInfoList) {
        if (shuffleBlockInfoList == null || shuffleBlockInfoList.isEmpty()) {
            return Collections.emptyList();
        }
        for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
            long blockId = sbi.getBlockId();
            this.blockIds.add(blockId);
            int partitionId = sbi.getPartitionId();
            for (ShuffleServerInfo shuffleServerInfo : sbi.getShuffleServerInfos()) {
                Map<Integer, Set<Long>> pToBlockIds =
                        this.serverToPartitionToBlockIds.computeIfAbsent(shuffleServerInfo, k -> Maps.newHashMap());
                pToBlockIds.computeIfAbsent(partitionId, k -> Sets.newHashSet()).add(blockId);
            }
        }
        return this.postBlockEvent(shuffleBlockInfoList);
    }

    /** Groups the blocks into events via the buffer manager and submits each one to the manager. */
    protected List<CompletableFuture<Long>> postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
        List<CompletableFuture<Long>> futures = new ArrayList<>();
        for (AddBlockEvent event : this.bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
            futures.add(this.shuffleManager.sendData(event));
        }
        return futures;
    }

    /**
     * Sends the commit to {@link #shuffleServersForData} on a helper thread. While waiting, it
     * polls with a back-off that starts at 200 ms and doubles up to 5000 ms. There is no overall
     * deadline.
     *
     * @throws RssException if the commit reports failure or the commit call itself throws
     */
    @VisibleForTesting
    protected void sendCommit() {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            Future<Boolean> future = executor.submit(() -> this.shuffleWriteClient.sendCommit(this.shuffleServersForData, this.appId, this.shuffleId, this.numMaps));
            long start = System.currentTimeMillis();
            int currentWait = 200;
            int maxWait = 5000;
            while (!future.isDone()) {
                LOG.info("Wait commit to shuffle server for task[{}] cost {} ms", this.taskAttemptId, System.currentTimeMillis() - start);
                Uninterruptibles.sleepUninterruptibly(currentWait, TimeUnit.MILLISECONDS);
                currentWait = Math.min(currentWait * 2, maxWait);
            }
            boolean committed;
            try {
                committed = future.get();
            }
            catch (InterruptedException ie) {
                LOG.warn("Ignore the InterruptedException which should be caused by internal killed");
                // Keep the interrupt visible to the caller even though the wait is abandoned.
                Thread.currentThread().interrupt();
                return;
            }
            catch (Exception e) {
                throw new RssException("Exception happened when get commit status", e);
            }
            // This check sits outside the catch above, so a rejected commit is reported directly
            // rather than being re-wrapped as a "get commit status" error.
            if (!committed) {
                throw new RssException("Failed to commit task to shuffle server");
            }
        }
        finally {
            executor.shutdown();
        }
    }

    /**
     * Waits until every id in {@code blockIds} is reported as sent. It polls every
     * {@code sendCheckInterval} ms for up to roughly {@code sendCheckTimeout} ms. Confirmed ids
     * are removed from {@code blockIds}.
     *
     * @throws RssSendFailedException as soon as any block of this task is reported failed
     * @throws RssWaitFailedException if blocks are still unconfirmed when the timeout elapses
     */
    @VisibleForTesting
    protected void checkBlockSendResult(Set<Long> blockIds) {
        long start = System.currentTimeMillis();
        do {
            Set<Long> failedBlockIds = this.shuffleManager.getFailedBlockIds(this.taskId);
            Set<Long> successBlockIds = this.shuffleManager.getSuccessBlockIds(this.taskId);
            if (!failedBlockIds.isEmpty()) {
                String errorMsg = "Send failed: Task[" + this.taskId + "] failed because " + failedBlockIds.size() + " blocks can't be sent to shuffle server: " + this.shuffleManager.getBlockIdsFailedSendTracker(this.taskId).getFaultyShuffleServers();
                LOG.error(errorMsg);
                throw new RssSendFailedException(errorMsg);
            }
            blockIds.removeAll(successBlockIds);
            if (blockIds.isEmpty()) {
                return;
            }
            LOG.info("Wait {} blocks sent to shuffle server", blockIds.size());
            Uninterruptibles.sleepUninterruptibly(this.sendCheckInterval, TimeUnit.MILLISECONDS);
        } while (System.currentTimeMillis() - start <= this.sendCheckTimeout);
        String errorMsg = "Timeout: Task[" + this.taskId + "] failed because " + blockIds.size() + " blocks can't be sent to shuffle server in " + this.sendCheckTimeout + " ms.";
        LOG.error(errorMsg);
        throw new RssWaitFailedException(errorMsg);
    }

    /**
     * On success, reports the per-server/partition block ids and returns a MapStatus that points
     * at a placeholder BlockManagerId. On failure, returns an empty Option. In both cases the
     * buffer memory is freed and the task's metadata is cleared from the manager.
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        try {
            if (!success) {
                return Option.empty();
            }
            // Every partition length is reported as 1. NOTE(review): presumably this keeps Spark
            // from treating partitions as empty, since the data lives on the shuffle servers.
            long[] partitionLengths = new long[this.partitioner.numPartitions()];
            Arrays.fill(partitionLengths, 1L);
            BlockManagerId blockManagerId = this.createDummyBlockManagerId(this.appId + "_" + this.taskId, this.taskAttemptId);
            long start = System.currentTimeMillis();
            this.shuffleWriteClient.reportShuffleResult(this.serverToPartitionToBlockIds, this.appId, this.shuffleId, this.taskAttemptId, this.bitmapSplitNum);
            LOG.info("Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms", this.taskAttemptId, this.bitmapSplitNum, System.currentTimeMillis() - start);
            MapStatus mapStatus = MapStatus$.MODULE$.apply(blockManagerId, partitionLengths);
            return Option.apply(mapStatus);
        }
        finally {
            if (this.bufferManager != null) {
                this.bufferManager.freeAllMemory();
            }
            if (this.shuffleManager != null) {
                this.shuffleManager.clearTaskMeta(this.taskId);
            }
        }
    }

    /** Returns the partitioner's partition for {@code key}, or 0 when there is only one partition. */
    @VisibleForTesting
    protected <T> int getPartition(T key) {
        return this.shouldPartition ? this.partitioner.getPartition(key) : 0;
    }

    /** Flattens the per-server mapping into partition -> union of block ids across servers. */
    @VisibleForTesting
    protected Map<Integer, Set<Long>> getPartitionToBlockIds() {
        return this.serverToPartitionToBlockIds.values().stream()
                .flatMap(partitionToIds -> partitionToIds.entrySet().stream())
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existingSet, newSet) -> {
                    Set<Long> mergedSet = new HashSet<>(existingSet);
                    mergedSet.addAll(newSet);
                    return mergedSet;
                }));
    }

    @VisibleForTesting
    protected ShuffleWriteMetrics getShuffleWriteMetrics() {
        return this.shuffleWriteMetrics;
    }

    /** Creates a gRPC {@link ShuffleManagerClient} for the given driver endpoint. */
    private static ShuffleManagerClient createShuffleManagerClient(String host, int port) throws IOException {
        return ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(ClientType.GRPC, host, port);
    }

    /**
     * Always throws. For a {@link RssSendFailedException}, it first reports the faulty servers to
     * the driver. If the driver asks for a whole-stage resubmit, it requests server reassignment
     * and throws an {@link RssException} wrapping a FetchFailedException. In every other case it
     * throws {@code new RssException(e)}.
     */
    private void throwFetchFailedIfNecessary(Exception e) {
        if (e instanceof RssSendFailedException) {
            FailedBlockSendTracker blockIdsFailedSendTracker = this.shuffleManager.getBlockIdsFailedSendTracker(this.taskId);
            List<ShuffleServerInfo> shuffleServerInfos = Lists.newArrayList(blockIdsFailedSendTracker.getFaultyShuffleServers());
            RssReportShuffleWriteFailureRequest req = new RssReportShuffleWriteFailureRequest(this.appId, this.shuffleId, this.taskContext.stageAttemptNumber(), shuffleServerInfos, e.getMessage());
            RssConf rssConf = RssSparkConfig.toRssConf(this.sparkConf);
            String driver = rssConf.getString("driver.host", "");
            int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
            try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) {
                RssReportShuffleWriteFailureResponse response = shuffleManagerClient.reportShuffleWriteFailure(req);
                if (response.getReSubmitWholeStage()) {
                    RssReassignServersRequest rssReassignServersRequest = new RssReassignServersRequest(this.taskContext.stageId(), this.taskContext.stageAttemptNumber(), this.shuffleId, this.partitioner.numPartitions());
                    RssReassignServersReponse rssReassignServersReponse = shuffleManagerClient.reassignShuffleServers(rssReassignServersRequest);
                    LOG.info("Whether the reassignment is successful: {}", rssReassignServersReponse.isNeedReassign());
                    FetchFailedException ffe = RssSparkShuffleUtils.createFetchFailedException(this.shuffleId, -1, this.taskContext.stageAttemptNumber(), e);
                    throw new RssException(ffe);
                }
            }
            catch (IOException ioe) {
                LOG.info("Error closing shuffle manager client with error:", ioe);
            }
        }
        throw new RssException(e);
    }
}

