/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderDummycode;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderFeatureHash;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderPassThrough;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderUDF;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderWordEmbedding;
import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
import org.apache.sysds.runtime.transform.encode.EncoderOmit;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.TransformStatistics;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONObject;

/**
 * Factory for {@link MultiColumnEncoder}s built from a JSON transform specification, plus the
 * mapping between concrete {@link ColumnEncoder} classes and their integer type codes.
 */
public interface EncoderFactory {
    public static final Log LOG = LogFactory.getLog(EncoderFactory.class.getName());

    /**
     * Creates an encoder without column names, meta data, or embeddings.
     *
     * @param spec JSON transform specification
     * @param clen number of input columns
     * @return the configured encoder
     */
    public static MultiColumnEncoder createEncoder(String spec, int clen) {
        return EncoderFactory.createEncoder(spec, null, clen, null, null, -1, -1);
    }

    /**
     * Creates an encoder with optional column names and meta data, without embeddings.
     *
     * @param spec     JSON transform specification
     * @param colnames input column names (may be null)
     * @param clen     number of input columns
     * @param meta     transform meta data used to initialize the encoders (may be null)
     * @return the configured encoder
     */
    public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta) {
        return EncoderFactory.createEncoder(spec, colnames, clen, meta, null, -1, -1);
    }

    /**
     * Creates an encoder, forwarding the given column range bounds to the spec parsers.
     *
     * @param spec     JSON transform specification
     * @param colnames input column names (may be null)
     * @param clen     number of input columns
     * @param meta     transform meta data (may be null)
     * @param minCol   column range bound passed to the spec parsers (-1 when unused)
     * @param maxCol   column range bound passed to the spec parsers (-1 when unused)
     * @return the configured encoder
     */
    public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, int minCol, int maxCol) {
        return EncoderFactory.createEncoder(spec, colnames, clen, meta, null, minCol, maxCol);
    }

    /**
     * Creates an encoder; note that {@code schema} is not consulted by this overload, the number
     * of columns is taken from {@code clen}.
     *
     * @param spec     JSON transform specification
     * @param colnames input column names (may be null)
     * @param schema   input schema (currently unused by this overload)
     * @param clen     number of input columns
     * @param meta     transform meta data (may be null)
     * @return the configured encoder
     */
    public static MultiColumnEncoder createEncoder(String spec, String[] colnames, Types.ValueType[] schema, int clen, FrameBlock meta) {
        return EncoderFactory.createEncoder(spec, colnames, clen, meta);
    }

    /**
     * Creates an encoder with embeddings; a null {@code schema} is replaced by {@code clen}
     * STRING columns. Only the schema length is used to determine the number of columns.
     *
     * @param spec       JSON transform specification
     * @param colnames   input column names (may be null)
     * @param schema     input schema (may be null)
     * @param clen       number of input columns if {@code schema} is null
     * @param meta       transform meta data (may be null)
     * @param embeddings embedding matrix for word-embedding columns (may be null)
     * @return the configured encoder
     */
    public static MultiColumnEncoder createEncoder(String spec, String[] colnames, Types.ValueType[] schema, int clen, FrameBlock meta, MatrixBlock embeddings) {
        Types.ValueType[] lschema = schema == null ? UtilFunctions.nCopies(clen, Types.ValueType.STRING) : schema;
        return EncoderFactory.createEncoder(spec, colnames, lschema, meta, embeddings);
    }

    /**
     * Creates an encoder whose number of columns is {@code schema.length}.
     *
     * @param spec     JSON transform specification
     * @param colnames input column names (may be null)
     * @param schema   input schema (must not be null)
     * @param meta     transform meta data (may be null)
     * @return the configured encoder
     */
    public static MultiColumnEncoder createEncoder(String spec, String[] colnames, Types.ValueType[] schema, FrameBlock meta) {
        return EncoderFactory.createEncoder(spec, colnames, schema, meta, -1, -1);
    }

    /**
     * Creates an encoder whose number of columns is {@code schema.length}, forwarding the given
     * column range bounds to the spec parsers.
     *
     * @param spec     JSON transform specification
     * @param colnames input column names (may be null)
     * @param schema   input schema (must not be null)
     * @param meta     transform meta data (may be null)
     * @param minCol   column range bound passed to the spec parsers (-1 when unused)
     * @param maxCol   column range bound passed to the spec parsers (-1 when unused)
     * @return the configured encoder
     */
    public static MultiColumnEncoder createEncoder(String spec, String[] colnames, Types.ValueType[] schema, FrameBlock meta, int minCol, int maxCol) {
        return EncoderFactory.createEncoder(spec, colnames, schema.length, meta, null, minCol, maxCol);
    }

    /**
     * Creates an encoder with embeddings for {@code clen} STRING columns.
     *
     * @param spec       JSON transform specification
     * @param colnames   input column names (may be null)
     * @param clen       number of input columns
     * @param meta       transform meta data (may be null)
     * @param embeddings embedding matrix for word-embedding columns (may be null)
     * @return the configured encoder
     */
    public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, MatrixBlock embeddings) {
        return EncoderFactory.createEncoder(spec, colnames, UtilFunctions.nCopies(clen, Types.ValueType.STRING), meta, embeddings);
    }

    /**
     * Creates an encoder with embeddings whose number of columns is {@code schema.length}.
     *
     * @param spec       JSON transform specification
     * @param colnames   input column names (may be null)
     * @param schema     input schema (must not be null)
     * @param meta       transform meta data (may be null)
     * @param embeddings embedding matrix for word-embedding columns (may be null)
     * @return the configured encoder
     */
    public static MultiColumnEncoder createEncoder(String spec, String[] colnames, Types.ValueType[] schema, FrameBlock meta, MatrixBlock embeddings) {
        return EncoderFactory.createEncoder(spec, colnames, schema.length, meta, embeddings, -1, -1);
    }

    /**
     * Creates a multi-column encoder from a JSON transform specification.
     *
     * <p>For every column, the recode, hash, pass-through, word-embedding, bag-of-words, binning,
     * dummycode and UDF encoders requested by the spec are grouped (in that order) into one
     * {@link ColumnEncoderComposite}. Columns 1..clen that are not covered by recode, hash, bin,
     * word-embedding or bag-of-words get a pass-through encoder. Omit and impute are attached as
     * legacy encoders. Column IDs greater than {@code clen} are ignored.
     *
     * @param spec       JSON transform specification
     * @param colnames   input column names (may be null)
     * @param clen       number of input columns
     * @param meta       transform meta data used to initialize the encoders (may be null); if its
     *                   column names differ from {@code colnames} (and the spec is not ID-based),
     *                   its columns are reordered to match {@code colnames}
     * @param embeddings embedding matrix, only used if word-embedding columns exist (may be null)
     * @param minCol     column range bound passed to the spec parsers (-1 when unused)
     * @param maxCol     column range bound passed to the spec parsers (-1 when unused)
     * @return the configured encoder
     * @throws DMLRuntimeException if parsing fails, a column has more than one of
     *         recode/bin/hash/word_embedding/bag_of_words, a binning method is unsupported, or a
     *         column name is missing from {@code meta}
     */
    public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, MatrixBlock embeddings, int minCol, int maxCol) {
        try {
            JSONObject jSpec = new JSONObject(spec);
            boolean ids = jSpec.containsKey("ids") && jSpec.getBoolean("ids");
            TfMetaUtils.checkValidEncoders(jSpec);

            // All ID lists are materialized as mutable ArrayLists: Arrays.asList views are
            // fixed-size and removeIf below would throw UnsupportedOperationException.
            List<Integer> rcIDs = toMutableList(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfUtils.TfMethod.RECODE.toString(), minCol, maxCol));
            List<Integer> haIDs = toMutableList(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfUtils.TfMethod.HASH.toString(), minCol, maxCol));
            List<Integer> dcIDs = toMutableList(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfUtils.TfMethod.DUMMYCODE.toString(), minCol, maxCol));
            List<Integer> binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol);
            List<Integer> weIDs = toMutableList(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfUtils.TfMethod.WORD_EMBEDDING.toString(), minCol, maxCol));
            List<Integer> bowIDs = toMutableList(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfUtils.TfMethod.BAG_OF_WORDS.toString(), minCol, maxCol));

            // dummycode columns that are neither binned nor hashed are additionally recoded
            rcIDs = CollectionUtils.unionDistinct(rcIDs, CollectionUtils.except(CollectionUtils.except(dcIDs, binIDs), haIDs));
            if (CollectionUtils.intersect(rcIDs, binIDs, haIDs, weIDs, bowIDs)) {
                throw new DMLRuntimeException("More than one encoders (recode, binning, hashing, word_embedding, bag_of_words) on one column is not allowed:\n" + spec);
            }
            List<Integer> ptIDs = CollectionUtils.except(UtilFunctions.getSeqList(1, clen, 1), CollectionUtils.naryUnionDistinct(rcIDs, haIDs, binIDs, weIDs, bowIDs));
            List<Integer> oIDs = toMutableList(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfUtils.TfMethod.OMIT.toString(), minCol, maxCol));
            List<Integer> mvIDs = toMutableList(TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfUtils.TfMethod.IMPUTE.toString(), minCol, maxCol));
            List<Integer> udfIDs = TfMetaUtils.parseUDFColIDs(jSpec, colnames, minCol, maxCol);

            // drop every column ID beyond the number of input columns
            rcIDs.removeIf(i -> i > clen);
            haIDs.removeIf(i -> i > clen);
            dcIDs.removeIf(i -> i > clen);
            ptIDs.removeIf(i -> i > clen);
            oIDs.removeIf(i -> i > clen);
            mvIDs.removeIf(i -> i > clen);
            udfIDs.removeIf(i -> i > clen);
            binIDs.removeIf(i -> i > clen);
            weIDs.removeIf(i -> i > clen);
            bowIDs.removeIf(i -> i > clen);

            // per-column encoder lists; insertion order below defines the order inside a composite
            HashMap<Integer, List<ColumnEncoder>> colEncoders = new HashMap<>();
            for (Integer id : rcIDs) {
                addEncoderToMap(new ColumnEncoderRecode(id), colEncoders);
            }
            for (Integer id : haIDs) {
                addEncoderToMap(new ColumnEncoderFeatureHash(id, TfMetaUtils.getK(jSpec)), colEncoders);
            }
            for (Integer id : ptIDs) {
                addEncoderToMap(new ColumnEncoderPassThrough(id), colEncoders);
            }
            for (Integer id : weIDs) {
                addEncoderToMap(new ColumnEncoderWordEmbedding(id), colEncoders);
            }
            for (Integer id : bowIDs) {
                addEncoderToMap(new ColumnEncoderBagOfWords(id), colEncoders);
            }
            if (!binIDs.isEmpty()) {
                // the bin spec is an array of JSON objects, one per binned column
                for (Object o : (JSONArray) jSpec.get(TfUtils.TfMethod.BIN.toString())) {
                    JSONObject colspec = (JSONObject) o;
                    int numBins = colspec.containsKey("numbins") ? colspec.getInt("numbins") : 1;
                    int id = TfMetaUtils.parseJsonObjectID(colspec, colnames, minCol, maxCol, ids);
                    // skip non-positive IDs (presumably not selected by minCol/maxCol) and IDs beyond clen
                    if (id <= 0 || id > clen) {
                        continue;
                    }
                    // Locale.ROOT: default-locale casing (e.g. Turkish) would break the method names
                    String method = colspec.get("method").toString().toUpperCase(Locale.ROOT);
                    addEncoderToMap(new ColumnEncoderBin(id, numBins, parseBinMethod(method)), colEncoders);
                }
            }
            for (Integer id : dcIDs) {
                addEncoderToMap(new ColumnEncoderDummycode(id), colEncoders);
            }
            if (!udfIDs.isEmpty()) {
                String name = jSpec.getJSONObject("udf").getString("name");
                for (Integer id : udfIDs) {
                    addEncoderToMap(new ColumnEncoderUDF(id, name), colEncoders);
                }
            }

            ArrayList<ColumnEncoderComposite> lencoders = new ArrayList<>(colEncoders.size());
            for (List<ColumnEncoder> columnEncoders : colEncoders.values()) {
                if (DMLScript.STATISTICS) {
                    TransformStatistics.incEncoderCount(columnEncoders.size());
                }
                lencoders.add(new ColumnEncoderComposite(columnEncoders));
            }
            MultiColumnEncoder encoder = new MultiColumnEncoder(lencoders);

            if (!oIDs.isEmpty()) {
                encoder.addReplaceLegacyEncoder(new EncoderOmit(jSpec, colnames, clen, minCol, maxCol));
                if (DMLScript.STATISTICS) {
                    TransformStatistics.incEncoderCount(1L);
                }
            }
            if (!mvIDs.isEmpty()) {
                EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, clen, minCol, maxCol);
                ma.initRecodeIDList(rcIDs);
                encoder.addReplaceLegacyEncoder(ma);
                if (DMLScript.STATISTICS) {
                    TransformStatistics.incEncoderCount(1L);
                }
            }
            if (meta != null) {
                String[] metaColnames = meta.getColumnNames();
                if (!TfMetaUtils.isIDSpec(jSpec) && colnames != null && metaColnames != null && !Objects.deepEquals(colnames, metaColnames)) {
                    meta = reorderMetaColumns(meta, metaColnames, colnames);
                }
                encoder.initMetaData(meta);
            }
            if (!weIDs.isEmpty()) {
                encoder.initEmbeddings(embeddings);
            }
            return encoder;
        }
        catch (DMLRuntimeException ex) {
            // already a runtime error of this system; avoid wrapping it a second time
            throw ex;
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /** Boxes the given column IDs into a new, mutable list. */
    private static List<Integer> toMutableList(int[] ids) {
        return new ArrayList<>(Arrays.asList(ArrayUtils.toObject(ids)));
    }

    /**
     * Maps an upper-cased binning method name from the spec to its enum constant.
     *
     * @throws DMLRuntimeException for unknown method names
     */
    private static ColumnEncoderBin.BinMethod parseBinMethod(String method) {
        switch (method) {
            case "EQUI-WIDTH":
                return ColumnEncoderBin.BinMethod.EQUI_WIDTH;
            case "EQUI-HEIGHT":
                return ColumnEncoderBin.BinMethod.EQUI_HEIGHT;
            case "EQUI-HEIGHT-APPROX":
                return ColumnEncoderBin.BinMethod.EQUI_HEIGHT_APPROX;
            default:
                throw new DMLRuntimeException("Unsupported binning method: " + method);
        }
    }

    /**
     * Builds a new meta frame whose column i holds the meta column (data and column meta data)
     * named {@code colnames[i]}; the new frame keeps meta's schema and column names.
     *
     * @throws DMLRuntimeException if a name in {@code colnames} is not among {@code metaColnames}
     */
    private static FrameBlock reorderMetaColumns(FrameBlock meta, String[] metaColnames, String[] colnames) {
        HashMap<String, Integer> positions = getColumnPositions(metaColnames);
        FrameBlock reordered = new FrameBlock(meta.getSchema(), metaColnames);
        for (int i = 0; i < colnames.length; ++i) {
            Integer pos = positions.get(colnames[i]);
            if (pos == null) {
                throw new DMLRuntimeException("Column name not found in meta data: " + colnames[i] + " (meta: " + Arrays.toString(metaColnames) + ")");
            }
            reordered.setColumn(i, meta.getColumn(pos));
            reordered.setColumnMetadata(i, meta.getColumnMetadata(pos));
        }
        return reordered;
    }

    /** Appends {@code encoder} to the list of encoders registered for its column ID. */
    private static void addEncoderToMap(ColumnEncoder encoder, HashMap<Integer, List<ColumnEncoder>> map) {
        map.computeIfAbsent(encoder._colID, k -> new ArrayList<>()).add(encoder);
    }

    /**
     * Returns the integer type code (the {@link ColumnEncoder.EncoderType} ordinal) of the given
     * encoder; the inverse of {@link #createInstance(int)}.
     *
     * @throws DMLRuntimeException for encoder classes without a type code
     */
    public static int getEncoderType(ColumnEncoder columnEncoder) {
        if (columnEncoder instanceof ColumnEncoderBin) {
            return ColumnEncoder.EncoderType.Bin.ordinal();
        }
        if (columnEncoder instanceof ColumnEncoderDummycode) {
            return ColumnEncoder.EncoderType.Dummycode.ordinal();
        }
        if (columnEncoder instanceof ColumnEncoderFeatureHash) {
            return ColumnEncoder.EncoderType.FeatureHash.ordinal();
        }
        if (columnEncoder instanceof ColumnEncoderPassThrough) {
            return ColumnEncoder.EncoderType.PassThrough.ordinal();
        }
        if (columnEncoder instanceof ColumnEncoderRecode) {
            return ColumnEncoder.EncoderType.Recode.ordinal();
        }
        if (columnEncoder instanceof ColumnEncoderWordEmbedding) {
            return ColumnEncoder.EncoderType.WordEmbedding.ordinal();
        }
        if (columnEncoder instanceof ColumnEncoderBagOfWords) {
            return ColumnEncoder.EncoderType.BagOfWords.ordinal();
        }
        throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName());
    }

    /**
     * Creates an empty encoder instance for a type code produced by {@link #getEncoderType}.
     *
     * @throws DMLRuntimeException if {@code type} is out of range or not instantiable here
     */
    public static ColumnEncoder createInstance(int type) {
        ColumnEncoder.EncoderType[] types = ColumnEncoder.EncoderType.values();
        if (type < 0 || type >= types.length) {
            throw new DMLRuntimeException("Unsupported encoder type: " + type);
        }
        ColumnEncoder.EncoderType etype = types[type];
        switch (etype) {
            case Bin:
                return new ColumnEncoderBin();
            case Dummycode:
                return new ColumnEncoderDummycode();
            case FeatureHash:
                return new ColumnEncoderFeatureHash();
            case PassThrough:
                return new ColumnEncoderPassThrough();
            case Recode:
                return new ColumnEncoderRecode();
            case WordEmbedding:
                return new ColumnEncoderWordEmbedding();
            case BagOfWords:
                return new ColumnEncoderBagOfWords();
            default:
                throw new DMLRuntimeException("Unsupported encoder type: " + etype);
        }
    }

    /** Maps each column name to its index; for duplicate names the last index wins. */
    private static HashMap<String, Integer> getColumnPositions(String[] colnames) {
        HashMap<String, Integer> ret = new HashMap<String, Integer>();
        for (int i = 0; i < colnames.length; ++i) {
            ret.put(colnames[i], i);
        }
        return ret;
    }
}

