/*
 * Decompiled with CFR 0.152.
 */
package dnnParams;

import com.google.flatbuffers.FlatBufferBuilder;
import dnnParams.fb.FBDimension;
import dnnParams.fb.FBDnnParams;
import dnnParams.fb.FBInputTensor;
import dnnParams.fb.FBL2Memory;
import dnnParams.fb.FBNetwork;
import dnnParams.fb.FBOutputTensor;
import dnnParams.xml.DnnParams;
import dnnParams.xml.InputTensorInfo;
import dnnParams.xml.Network;
import dnnParams.xml.TensorFormatType;
import dnnParams.xml.TensorInfo;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Unmarshaller;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.xml.sax.SAXException;

/**
 * Converts a DNN parameter description stored as XML into a FlatBuffers binary file.
 *
 * <p>The XML is unmarshalled into the JAXB model ({@link DnnParams}), every network with its
 * input and output tensors is serialized through the generated {@code FB*} builders, and the
 * finished buffer is written to disk. A human-readable summary of everything that is read is
 * printed to standard output along the way.
 */
public class DnnParamsTool {
    /** Upper bound (0x800000 bytes, reported as 8MB) that the L2 memory figures are checked against. */
    public static final long TOTAL_L2_MEMORY_SIZE = 0x800000L;
    /** Tensor names longer than this many characters are cut down to their trailing characters. */
    public static final int MAX_NAME_LENGTH = 20;

    /**
     * Reads the XML file {@code inFileName} and writes the equivalent FlatBuffers binary to
     * {@code outFileName}.
     *
     * <p>Failures while reading, parsing or writing are reported with a stack trace on standard
     * error and end the conversion. Soft validation problems (dimension count mismatch,
     * unsupported bits-per-element, unknown tensor format, oversized L2 memory) are only reported
     * on standard output and the conversion continues. A network or tensor ordinal outside
     * {@code [0, count)} is reported on standard output and ends the conversion without writing an
     * output file, because the ordinal is used as the slot index of the serialized vector.
     *
     * @param inFileName path of the XML file to convert
     * @param outFileName path of the FlatBuffers file to create
     */
    public static void createDnnParamsFlatBuffer(String inFileName, String outFileName) {
        DnnParams dnnParams;
        try {
            dnnParams = readDnnParams(inFileName);
        } catch (IOException | JAXBException | ParserConfigurationException | SAXException e) {
            e.printStackTrace();
            return;
        }

        FlatBufferBuilder fbb = new FlatBufferBuilder(0);
        int numNetworks = dnnParams.getNetworks().getNetwork().size();
        System.out.println("Num of Networks = " + numNetworks);
        System.out.println("================================");
        int[] networkOffsets = new int[numNetworks];
        for (int i = 0; i < numNetworks; ++i) {
            Network network = dnnParams.getNetworks().getNetwork().get(i);
            System.out.println("   ID: " + String.valueOf(network.getOrdinal()));
            System.out.println("   Name: " + network.getName());
            System.out.println("   Type: " + network.getType());
            // The ordinal selects the slot in networkOffsets, so it must be a valid index.
            if (network.getOrdinal() < 0 || network.getOrdinal() >= numNetworks) {
                System.out.println("Invalid network Enumeration");
                return;
            }

            // NOTE(review): duplicate ordinals are not detected; a repeated ordinal overwrites
            // the earlier slot and leaves another slot at its default of 0.
            int numInputTensors = network.getInputTensors().getInputTensor().size();
            int[] inputTensorsOffsets = new int[numInputTensors];
            for (int j = 0; j < numInputTensors; ++j) {
                InputTensorInfo inputTensorInfo = network.getInputTensors().getInputTensor().get(j);
                int tensorOffset = buildInputTensor(fbb, inputTensorInfo, j);
                short inputOrdinal = inputTensorInfo.getOrdinal();
                if (inputOrdinal < 0 || inputOrdinal >= numInputTensors) {
                    System.out.println("Error : Invalid Input Tensor Ordinal");
                    return;
                }
                inputTensorsOffsets[inputOrdinal] = tensorOffset;
            }

            int numOutputTensors = network.getOutputTensors().getOutputTensor().size();
            int[] outputTensorsOffsets = new int[numOutputTensors];
            for (int j = 0; j < numOutputTensors; ++j) {
                TensorInfo outputTensorInfo = network.getOutputTensors().getOutputTensor().get(j);
                int tensorOffset = buildOutputTensor(fbb, outputTensorInfo, j);
                short outputOrdinal = outputTensorInfo.getOrdinal();
                if (outputOrdinal < 0 || outputOrdinal >= numOutputTensors) {
                    System.out.println("Error : Invalid Output Tensor Ordinal");
                    return;
                }
                outputTensorsOffsets[outputOrdinal] = tensorOffset;
            }

            // Strings and vectors are created as call arguments, i.e. before the network table
            // itself is started by createFBNetwork.
            short networkOrdinal = network.getOrdinal().shortValue();
            networkOffsets[networkOrdinal] = FBNetwork.createFBNetwork(fbb, networkOrdinal, fbb.createString(network.getName()), fbb.createString(network.getType()), FBNetwork.createInputTensorsVector(fbb, inputTensorsOffsets), FBNetwork.createOutputTensorsVector(fbb, outputTensorsOffsets));
        }

        if (dnnParams.getL2Memory().getTotalSize() > TOTAL_L2_MEMORY_SIZE) {
            System.out.println("Invalid L2Memory Size: " + String.valueOf(dnnParams.getL2Memory().getTotalSize()));
        }
        if (dnnParams.getL2Memory().getReservedMemorySize() + dnnParams.getL2Memory().getNetworksRuntimeSize() + dnnParams.getL2Memory().getCoefficientsSize() > TOTAL_L2_MEMORY_SIZE) {
            System.out.println("Error in L2Memory Size: Total size of all components > 8MB");
        }
        System.out.println("================================");
        System.out.println(" L2 Memory Details ");
        System.out.println("Total Memory Size:" + String.valueOf(dnnParams.getL2Memory().getTotalSize()));
        System.out.println("Reserved Memory Size:" + String.valueOf(dnnParams.getL2Memory().getReservedMemorySize()));
        System.out.println("Runtime Memory Size:" + String.valueOf(dnnParams.getL2Memory().getNetworksRuntimeSize()));
        System.out.println("Coefficeints Memory Size:" + String.valueOf(dnnParams.getL2Memory().getCoefficientsSize()));
        System.out.println("================================");

        int networksOffset = FBDnnParams.createNetworksVector(fbb, networkOffsets);
        int l2MemoryOffset = FBL2Memory.createFBL2Memory(fbb, dnnParams.getL2Memory().getTotalSize(), dnnParams.getL2Memory().getReservedMemorySize(), dnnParams.getL2Memory().getNetworksRuntimeSize(), dnnParams.getL2Memory().getCoefficientsSize());
        fbb.finish(FBDnnParams.createFBDnnParams(fbb, networksOffset, l2MemoryOffset));

        // try-with-resources: the stream is closed on every path, and a failure to open the
        // file is reported instead of leading to a write on a null stream.
        try (FileOutputStream stream = new FileOutputStream(outFileName)) {
            stream.write(fbb.sizedByteArray());
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Parses {@code inFileName} into a DOM document and unmarshals it into the JAXB model.
     *
     * @param inFileName path of the XML file to read
     * @return the unmarshalled root object
     * @throws IOException if the file cannot be read
     * @throws JAXBException if the JAXB context cannot be created or unmarshalling fails
     * @throws ParserConfigurationException if no document builder can be created
     * @throws SAXException if the file is not well-formed XML
     */
    private static DnnParams readDnnParams(String inFileName) throws IOException, JAXBException, ParserConfigurationException, SAXException {
        File file = new File(inFileName);
        JAXBContext jaxbContext = JAXBContext.newInstance(DnnParams.class);
        Unmarshaller jaxbUnmarshaller = jaxbContext.createUnmarshaller();
        DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
        try {
            // Reject DOCTYPE declarations so external entities cannot be pulled in.
            dbf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
        } catch (ParserConfigurationException e) {
            // Best effort: a parser that does not know the feature is still used.
            e.printStackTrace();
        }
        DocumentBuilder db = dbf.newDocumentBuilder();
        Document document = db.parse(file);
        return (DnnParams)jaxbUnmarshaller.unmarshal(document);
    }

    /**
     * Prints one input tensor and serializes it, including its dimensions, into {@code fbb}.
     *
     * @param fbb builder receiving the serialized tensor
     * @param inputTensorInfo tensor description read from the XML
     * @param index position of the tensor in the XML list, used only in messages
     * @return the offset returned by {@code FBInputTensor.createFBInputTensor}
     */
    private static int buildInputTensor(FlatBufferBuilder fbb, InputTensorInfo inputTensorInfo, int index) {
        String inputTensorName = truncateName(inputTensorInfo.getName());
        System.out.println("      Input Tensor" + inputTensorInfo.getOrdinal());
        // The summary shows the full name; only the serialized name is truncated.
        System.out.println("         Name = " + inputTensorInfo.getName());
        System.out.println("         ID = " + inputTensorInfo.getOrdinal());
        System.out.println("         L2 Offset = " + inputTensorInfo.getL2Offset());
        System.out.println("         Shift  = " + inputTensorInfo.getShift());
        System.out.println("         Scale  = " + inputTensorInfo.getScale());
        System.out.println("         Format  = " + String.valueOf((Object)inputTensorInfo.getFormat()));
        System.out.println("         Persistency  = " + inputTensorInfo.getPersistency());
        System.out.println("         Num of Dimensions  = " + inputTensorInfo.getNumOfDimensions());

        int numDims = inputTensorInfo.getDimensions().getDimension().size();
        if (numDims != inputTensorInfo.getNumOfDimensions()) {
            System.out.println("Error: numOfDimensions mismatch in InputTensor " + index);
        }
        int[] dimArray = new int[numDims];
        for (int k = 0; k < numDims; ++k) {
            short dimOrdinal = inputTensorInfo.getDimensions().getDimension().get(k).getOrdinal();
            int size = inputTensorInfo.getDimensions().getDimension().get(k).getSize();
            short serOrder = inputTensorInfo.getDimensions().getDimension().get(k).getSerializationOrder();
            short padding = inputTensorInfo.getDimensions().getDimension().get(k).getPadding();
            dimArray[k] = FBDimension.createFBDimension(fbb, dimOrdinal, size, serOrder, padding);
            System.out.println("         Dimension" + dimOrdinal + " = (size/serOrder/padding)" + size + "/" + serOrder + "/" + padding);
        }
        int dimArrayOffset = FBInputTensor.createDimensionsVector(fbb, dimArray);

        // An unrecognized format is reported and serialized as SIGNED.
        int format = FBTensorFormatType.SIGNED;
        if (inputTensorInfo.getFormat().value().equals(TensorFormatType.SIGNED.value())) {
            format = FBTensorFormatType.SIGNED;
        } else if (inputTensorInfo.getFormat().value().equals(TensorFormatType.UNSIGNED.value())) {
            format = FBTensorFormatType.UNSIGNED;
        } else {
            System.out.println("Error : Invalid Input Tensor type");
        }
        if (inputTensorInfo.getBitsPerElement() != 8) {
            System.out.println("Error : Invalid bitsPerELement in Input Tensor: Only 8 is supported");
        }
        return FBInputTensor.createFBInputTensor(fbb, inputTensorInfo.getOrdinal(), fbb.createString(inputTensorName), inputTensorInfo.getNumOfDimensions(), dimArrayOffset, inputTensorInfo.getL2Offset(), inputTensorInfo.getBitsPerElement(), inputTensorInfo.getShift(), inputTensorInfo.getScale(), format, inputTensorInfo.getPersistency());
    }

    /**
     * Prints one output tensor and serializes it, including its dimensions, into {@code fbb}.
     *
     * @param fbb builder receiving the serialized tensor
     * @param outputTensorInfo tensor description read from the XML
     * @param index position of the tensor in the XML list, used only in messages
     * @return the offset returned by {@code FBOutputTensor.createFBOutputTensor}
     */
    private static int buildOutputTensor(FlatBufferBuilder fbb, TensorInfo outputTensorInfo, int index) {
        String outputTensorName = truncateName(outputTensorInfo.getName());
        System.out.println("     Output Tensor" + outputTensorInfo.getOrdinal());
        System.out.println("         ID = " + outputTensorInfo.getOrdinal());
        System.out.println("         L2 Offset = " + outputTensorInfo.getL2Offset());
        System.out.println("         Name = " + outputTensorName);
        System.out.println("         bitsPerElement = " + outputTensorInfo.getBitsPerElement());
        System.out.println("         shift = " + outputTensorInfo.getShift());
        System.out.println("         scale = " + outputTensorInfo.getScale());
        System.out.println("         format = " + String.valueOf((Object)outputTensorInfo.getFormat()));
        System.out.println("         Num of Dimensions  = " + outputTensorInfo.getNumOfDimensions());

        int numDims = outputTensorInfo.getDimensions().getDimension().size();
        if (numDims != outputTensorInfo.getNumOfDimensions()) {
            System.out.println("Error: numOfDimensions mismatch in OutputTensor " + index);
        }
        int[] dimArray = new int[numDims];
        for (int k = 0; k < numDims; ++k) {
            short dimOrdinal = outputTensorInfo.getDimensions().getDimension().get(k).getOrdinal();
            int size = outputTensorInfo.getDimensions().getDimension().get(k).getSize();
            short serOrder = outputTensorInfo.getDimensions().getDimension().get(k).getSerializationOrder();
            short padding = outputTensorInfo.getDimensions().getDimension().get(k).getPadding();
            dimArray[k] = FBDimension.createFBDimension(fbb, dimOrdinal, size, serOrder, padding);
            System.out.println("         Dimension" + k + " = (size/serOrder/padding)" + size + "/" + serOrder + "/" + padding);
        }
        int dimArrayOffset = FBOutputTensor.createDimensionsVector(fbb, dimArray);

        // An unrecognized format is reported and serialized as SIGNED.
        int format = FBTensorFormatType.SIGNED;
        if (outputTensorInfo.getFormat().value().equals(TensorFormatType.SIGNED.value())) {
            format = FBTensorFormatType.SIGNED;
        } else if (outputTensorInfo.getFormat().value().equals(TensorFormatType.UNSIGNED.value())) {
            format = FBTensorFormatType.UNSIGNED;
        } else {
            System.out.println("Error : Invalid Output Tensor type");
        }
        return FBOutputTensor.createFBOutputTensor(fbb, outputTensorInfo.getOrdinal(), fbb.createString(outputTensorName), outputTensorInfo.getNumOfDimensions(), dimArrayOffset, outputTensorInfo.getL2Offset(), outputTensorInfo.getBitsPerElement(), outputTensorInfo.getShift(), outputTensorInfo.getScale(), format);
    }

    /**
     * Returns {@code name} unchanged when it has at most {@link #MAX_NAME_LENGTH} characters,
     * otherwise its last {@link #MAX_NAME_LENGTH} characters.
     */
    private static String truncateName(String name) {
        if (name.length() > MAX_NAME_LENGTH) {
            return name.substring(name.length() - MAX_NAME_LENGTH);
        }
        return name;
    }

    /**
     * Integer codes written into the serialized tensors for the tensor format.
     *
     * <p>Kept as a (non-static) inner class so the declared shape of the type is unchanged; the
     * constants are compile-time constants and can be used from static code.
     */
    class FBTensorFormatType {
        /** Code stored for {@code TensorFormatType.SIGNED}. */
        static final int SIGNED = 0;
        /** Code stored for {@code TensorFormatType.UNSIGNED}. */
        static final int UNSIGNED = 1;

        FBTensorFormatType() {
        }
    }
}

