Extracting and processing segmentation masks from a YOLOv8 instance-segmentation model in TFLite using Android Studio

package com.example.appcoffee;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Path;
import android.graphics.Point;
import android.graphics.PorterDuff;
import android.graphics.PorterDuffXfermode;
import android.graphics.RectF;
import android.util.Log;

import org.opencv.android.Utils;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.MatOfPoint;
import org.opencv.core.Scalar;
import org.opencv.imgproc.Imgproc;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.ops.CastOp;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.io.IOException;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

public class InstanceSegmentation {
private static final String MODEL_PATH = “Coffee-Disease-Detection-Model_float32.tflite”;
private static final int TENSOR_WIDTH = 640;
private static final int TENSOR_HEIGHT = 640;

private static final float TENSOR_WIDTH_FLOAT = (float) TENSOR_WIDTH;
private static final float TENSOR_HEIGHT_FLOAT = (float) TENSOR_HEIGHT;
private static final float INPUT_MEAN = 0f;
private static final float INPUT_STANDARD_DEVIATION = 255f;
private static final DataType INPUT_IMAGE_TYPE = DataType.FLOAT32;
private static final DataType OUTPUT_IMAGE_TYPE = DataType.FLOAT32;
private static final int NUM_ELEMENTS = 8400;
private static final int NUM_CHANNELS = 38;
private static final int BATCH_SIZE = 1;
private static final int X_POINTS = 160;
private static final int Y_POINTS = 160;
private static final int MASKS_NUMBERS = 32;
private static final float CONFIDENCE_THRESHOLD = 0.60f;
private static final float IOU_THRESHOLD = 0.6f;

private static final float MASK_THRESHOLD = 0.5f;
private final ImageProcessor imageProcessor = new ImageProcessor.Builder()
        .add(new NormalizeOp(INPUT_MEAN, INPUT_STANDARD_DEVIATION))
        .add(new CastOp(INPUT_IMAGE_TYPE))
        .build();
public Bitmap invoke(Context context, Bitmap bitmap) throws IOException {
    Interpreter interpreter = new Interpreter(FileUtil.loadMappedFile(context, MODEL_PATH));
    Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, TENSOR_WIDTH, TENSOR_HEIGHT, false);

    TensorImage tensor = new TensorImage(DataType.FLOAT32);
    tensor.load(resizedBitmap);
    TensorImage processedImage = imageProcessor.process(tensor);
    Object[] imageBuffer = {processedImage.getBuffer()};
    //FloatBuffer imageBuffer = processedImage.getBuffer().asFloatBuffer();

    TensorBuffer coordinatesBuffer = TensorBuffer.createFixedSize(
            new int[]{BATCH_SIZE, NUM_CHANNELS, NUM_ELEMENTS},
            OUTPUT_IMAGE_TYPE
    );

    TensorBuffer maskProtoBuffer = TensorBuffer.createFixedSize(
            new int[]{BATCH_SIZE, MASKS_NUMBERS, Y_POINTS, X_POINTS},
            OUTPUT_IMAGE_TYPE
    );

    Map<Integer, Object> outputBuffer = new HashMap<>();
    outputBuffer.put(0, coordinatesBuffer.getBuffer().rewind());
    outputBuffer.put(1, maskProtoBuffer.getBuffer().rewind());

    interpreter.runForMultipleInputsOutputs(imageBuffer, outputBuffer);


    float[] coordinates = coordinatesBuffer.getFloatArray(); //
    float[] masks = maskProtoBuffer.getFloatArray(); //

    Log.d("Coordinates", "Coordinates Length: " + coordinates.length);
    Log.d("Masks", "Masks Length:" + masks.length);

    List<Output0> filterOutput0 = bestBox(coordinates, masks);

    if (filterOutput0 == null) {
        return bitmap;
    }
    filterOutput0.get(0);
    Bitmap resultBitmap = Bitmap.createBitmap(640, 640, Bitmap.Config.ARGB_8888);
    Canvas canvas = new Canvas(resultBitmap);

    // Draw the original bitmap on the new bitmap
    canvas.drawBitmap(bitmap, 0, 0, null);

    Paint paint = new Paint();
    // Use the canvas for drawing bounding boxes and labels on the new bitmap
    Paint textPaint = new Paint();
    textPaint.setColor(Color.WHITE);
    textPaint.setTextSize(20); // Adjust text size as needed

    for (Output0 box : filterOutput0) {

        //Normalize the masks using sigmoid function
        //Reshape the array

        paint.setColor(Color.YELLOW);
        paint.setStyle(Paint.Style.STROKE);
        paint.setStrokeWidth(3);

        /*int boxColor;
        if (box.classId == 1) {
            // Unhealthy (red) with 50% transparency
            boxColor = Color.argb(128, 255, 0, 0); // 128 is the alpha value (0 to 255)
        } else if (box.classId == 0) {
            // Healthy (green) with 50% transparency
            boxColor = Color.argb(128, 0, 255, 0); // 128 is the alpha value (0 to 255)
        } else {
            // Unknown Class (default color) with 50% transparency
            boxColor = Color.argb(128, 255, 255, 0); // 128 is the alpha value (0 to 255)
        }*/
        /*paint.setColor(boxColor);
        paint.setStyle(Paint.Style.FILL);*/


        // Draw filled rectangle
        float x1 = box.cx - (box.w / 2F);
        float y1 = box.cy - (box.h / 2F);
        float x2 = box.cx + (box.w / 2F);
        float y2 = box.cy + (box.h / 2F);

        float left = (x1 * resultBitmap.getWidth());
        float top = (y1 * resultBitmap.getHeight());
        float right = (x2 * resultBitmap.getWidth());
        float bottom = (y2 * resultBitmap.getHeight());

        canvas.drawRect(left, top, right, bottom, paint);

        // Display label text
        String label;
        if (box.classId == 1) {
            label = "Unhealthy";
        } else if (box.classId == 0) {
            label = "Healthy";
        } else {
            label = "Unknown Class";
        }

        // Display label and confidence score together
        String confidenceText = label + " (" + String.format("%.2f", box.cnf) + ")";

        // Calculate the position for displaying label and confidence score
        float textX = left;
        float textY = top - 10;

        // Draw label and confidence score together
        canvas.drawText(confidenceText, textX, textY, textPaint);

        float[][] reshapedMask = box.reshapedMasks;
    }

    interpreter.close();
    return resultBitmap;

    //return null;
    //return matToBitmap(mask);
}
private List<Output0> applyNMS(List<Output0> bestOutput0) {
    List<Output0> sortedBoxes = new ArrayList<>(bestOutput0);
    sortedBoxes.sort((o1, o2) -> Float.compare(o2.cnf, o1.cnf));

    List<Output0> selectedBoxes = new ArrayList<>();

    while (!sortedBoxes.isEmpty()) {
        Output0 first = sortedBoxes.get(0);
        selectedBoxes.add(first);
        sortedBoxes.remove(0);

        Iterator<Output0> iterator = sortedBoxes.iterator();
        while (iterator.hasNext()) {
            Output0 nextBox = iterator.next();
            float iou = calculateIoU(first, nextBox);
            if (iou >= IOU_THRESHOLD) {
                iterator.remove();
            }
        }
    }

    return selectedBoxes;
}

private float calculateIoU(Output0 b1, Output0 b2) {
    float x1 = Math.max(b1.cx - (b1.w / 2F), b2.cx - (b2.w / 2F));
    float y1 = Math.max(b1.cy - (b1.h / 2F), b2.cy - (b2.h / 2F));
    float x2 = Math.min(b1.cx + (b1.w / 2F), b2.cx + (b2.w / 2F));
    float y2 = Math.min(b1.cy + (b1.h / 2F), b2.cy + (b2.h / 2F));


    float intersectionArea = Math.max(0F, x2 - x1) * Math.max(0F, y2 - y1);
    float box1Area = b1.w * b1.h;
    float box2Area = b2.w * b2.h;

    return intersectionArea / (box1Area + box2Area - intersectionArea);
}

private static class Output0 {
    float cx;
    float cy;
    float w;
    float h;
    float cnf;
    List<Float> maskWeight;

    int classId;

    float[][] reshapedMasks;

    //float [][] actualMaskValues;


    Output0(float cx, float cy, float w, float h, float cnf, List<Float> maskWeight, int classId, float[][] reshapedMasks) {
        this.cx = cx;
        this.cy = cy;
        this.w = w;
        this.h = h;
        this.cnf = cnf;
        this.maskWeight = maskWeight;
        this.classId = classId;
        this.reshapedMasks = reshapedMasks;
        //this.actualMaskValues = actualMaskValues;
    }
}
private List<Output0> bestBox(float[] array, float[] masks) {
    List<Output0> boundingBoxes = new ArrayList<>();

    for (int c = 0; c < NUM_ELEMENTS; c++) {
        float cnf = array[c + NUM_ELEMENTS * 4];
        if (cnf > CONFIDENCE_THRESHOLD) {
            float cx = array[c];
            float cy = array[c + NUM_ELEMENTS];
            float w = array[c + NUM_ELEMENTS * 2];
            float h = array[c + NUM_ELEMENTS * 3];
            float x1 = cx - (w / 2F);
            float y1 = cy - (h / 2F);
            float x2 = cx + (w / 2F);
            float y2 = cy + (h / 2F);
            if (x1 <= 0F || x1 >= TENSOR_WIDTH_FLOAT) continue;
            if (y1 <= 0F || y1 >= TENSOR_HEIGHT_FLOAT) continue;
            if (x2 <= 0F || x2 >= TENSOR_WIDTH_FLOAT) continue;
            if (y2 <= 0F || y2 >= TENSOR_HEIGHT_FLOAT) continue;

            // Extract class probabilities starting at index c + NUM_ELEMENTS * (MASKS_NUMBERS + 5)
            float[] classProbabilities = new float[2];
            int classIndex = c + NUM_ELEMENTS * (MASKS_NUMBERS + 5);

            // Check if the remaining elements are available
            if (array.length >= classIndex + 2) {
                for (int index = 0; index < 2; index++) {
                    classProbabilities[index] = array[classIndex + index];
                }
            } else {
                // Handle the case where there are not enough elements in the array
                Log.e("YourTag", "Not enough elements in the array for class probabilities");
                continue; // Skip this bounding box
            }
            int classId = findMaxClassId(classProbabilities);

            List<Float> maskWeight = new ArrayList<>();
            for (int index = 0; index < MASKS_NUMBERS; index++) {
                maskWeight.add(array[c + NUM_ELEMENTS * (index + 5)]);
            }

            // Get the corresponding mask for the detected object
            float[] objectMask = new float[Y_POINTS * X_POINTS];
            int maskIndex = MASKS_NUMBERS * Y_POINTS * (int) cy + MASKS_NUMBERS * (int) cx;

            for (int i = 0; i < Y_POINTS * X_POINTS; i++) {
                objectMask[i] = masks[maskIndex + i];
            }

            for (int i = 0; i < objectMask.length; i++) {
                objectMask[i] = sigmoid(objectMask[i]);
            }

            float[][] reshapedMask = reshapeMask(objectMask, Y_POINTS, X_POINTS);

            boundingBoxes.add(new Output0(cx, cy, w, h, cnf, maskWeight, classId, reshapedMask));
        }
    }

    if (boundingBoxes.isEmpty()) return null;

    return applyNMS(boundingBoxes);
}
private float sigmoid(float x) {
    return (float) (1 / (1 + Math.exp(-x)));
}

private float[][] reshapeMask(float[] flatMask, int numRows, int numCols) {
    float[][] reshapedMask = new float[numRows][numCols];
    for (int i = 0; i < numRows; i++) {
        for (int j = 0; j < numCols; j++) {
            reshapedMask[i][j] = flatMask[i * numCols + j];
        }
    }
    return reshapedMask;
}
private int findMaxClassId(float[] classProbabilities) {
    // Assuming class labels: 0 for Healthy, 1 for Unhealthy
    int numClasses = classProbabilities.length;

    if (numClasses == 2) {
        int maxClassId = (classProbabilities[0] > classProbabilities[1]) ? 1 : 0;
        return maxClassId;
    } else {
        // For multi-class classification, find the index with the highest probability
        int maxIndex = -1;
        float maxProbability = -1;

        for (int i = 0; i < numClasses; i++) {
            if (classProbabilities[i] > maxProbability) {
                maxProbability = classProbabilities[i];
                maxIndex = i;
            }
        }

        return maxIndex;
    }
}
private Bitmap matToBitmap(Mat mat) {
    Bitmap outputBitmap = Bitmap.createBitmap(mat.width(), mat.height(), Bitmap.Config.ARGB_8888);
    Utils.matToBitmap(mat, outputBitmap);
    return outputBitmap;
}

}
I just need some help with my code. I am having trouble processing or extracting the masks from my model. The bounding boxes, class labels, class IDs, and class scores already work, but I'm having trouble with the segmentation masks. Can somebody help me? I urgently need it for my thesis project.

This topic was automatically closed 21 days after the last reply. New replies are no longer allowed.