How to deploy a custom ViT-based classification model to an on-premises API?

Hello lads

I have trained a custom classification model based on Vision Transformers (ViT). After downloading the model weights (.pt), I tried loading them with the regular PyTorch load method, but I got a "module src not found" error.
What is the best way to load the weights and make predictions with a ViT-based model?

Any help is appreciated.
Thanks

Hello! Would you mind providing a short code snippet that shows how you’re loading the model when you get the error?

Sure, here it is:

import torch

# weights_only=False is needed because the file stores a full pickled
# model object, not just a state_dict
model = torch.load(
    r"D:\Kleber\MetricaModelos\models\screen_detector\screen_detector_classifier_v1.pt",
    weights_only=False,
)

print("End")

Also, attached are a few screenshots that may help.

This is my model; I have downloaded the weights to a local folder:

I see – I’m guessing it is something to do with your PyTorch installation. We highly recommend using https://inference.roboflow.com/ to run inference on models trained in Roboflow!

It could be, but I don’t think so, as I am successfully loading a few other external models with PyTorch.

Moreover, the documentation for the Inference Python module is confusing. I could not find how to authenticate using an API key. If you could point me to where that is documented, it would be great.

Thanks

Absolutely! Here is the documentation page on authenticating with the Roboflow API Key. The best way is to set it as an environment variable! Retrieve Your API Key - Roboflow Inference
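For example, a minimal sketch (ROBOFLOW_API_KEY is the variable name the inference package reads from the environment; the value below is a placeholder):

import os

# Set the key before creating any model or client objects.
# Alternatively, export ROBOFLOW_API_KEY in your shell before starting Python.
os.environ["ROBOFLOW_API_KEY"] = "your_api_key_here"  # placeholder value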

Hey @Lake, the Roboflow inference worked.

However, is there any option to run inference completely offline rather than using the inference server in the background? The reason is that even using a GPU it is slow, and after hundreds of predictions it throws a 500 Internal Server Error.

I mean I just want to load the trained model, call whatever the predict function is, and get the results.

Thanks

Regarding the "module not found src" issue, it seems the .pt weight file has serialized the training code's folder structure as well. After some trial and error I got the following message: "src.classification.timm_model.ImageClassifier" needs to be added to the safe globals.
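To check my understanding, I reproduced the behavior with a tiny self-contained model: torch.save on a full model pickles only a reference to the class's import path, not its source code, so torch.load needs that same module to be importable. TinyClassifier below is just a stand-in I made up:

import io

import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    # made-up stand-in for src.classification.timm_model.ImageClassifier
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 2)

buffer = io.BytesIO()
torch.save(TinyClassifier(), buffer)  # stores a reference to the class, not its code
buffer.seek(0)

# This works only because TinyClassifier is importable in this session;
# with my real .pt file the referenced "src" package does not exist, hence the error.
model = torch.load(buffer, weights_only=False)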

The issue is: I don’t have the source code for the src.classification.timm_model.ImageClassifier class, nor do I have the model architecture saved.

So the question is: is this a bug in how the model weight file was saved? If not, how can I get the model architecture and create a workaround?
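One workaround I am considering (untested, and the stub class body is a guess) is to recreate the missing package path with an empty stand-in class so unpickling can resolve it, then re-save the recovered weights as a plain state_dict. This assumes the dependencies the original class used (e.g. timm) are installed:

# stub file: src/classification/timm_model.py
# (create this package path, with empty __init__.py files, next to the script)
import torch.nn as nn

class ImageClassifier(nn.Module):
    # empty stand-in; unpickling restores the pickled attributes (submodules,
    # parameters) onto the instance, but the original forward() is not recovered
    pass

And then, in a script run from the folder containing the stub src package:

import torch

model = torch.load(
    r"D:\Kleber\MetricaModelos\models\screen_detector\screen_detector_classifier_v1.pt",
    weights_only=False,
)

# re-save the weights in a portable format that no longer depends on "src"
torch.save(model.state_dict(), "screen_detector_state_dict.pt")

Would something like this be sane, or is there a cleaner way?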

Thanks

There are a couple of options to run inference without the hosted API.

You can run an inference server on localhost and then call that using the Inference Client.
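Roughly like this (a sketch assuming the server is already running locally on its default port 9001, e.g. started with "inference server start", and that the inference-sdk package is installed):

from inference_sdk import InferenceHTTPClient

# point the client at the local server instead of the hosted API
client = InferenceHTTPClient(
    api_url="http://localhost:9001",
    api_key="YOUR_ROBOFLOW_API_KEY",  # placeholder
)

result = client.infer("/path/to/image.jpg", model_id="screendetectorclassification/1")
print(result)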

The other option is to instantiate a Roboflow model and run inference on it directly. Here’s a simple snippet to do that:

import cv2
from inference import get_model

# load the model into the current process (no server round-trip);
# assumes ROBOFLOW_API_KEY is defined, e.g. read from the environment
model = get_model(model_id="screendetectorclassification/1", api_key=ROBOFLOW_API_KEY)
image = cv2.imread("/path/to/image.jpg")
results = model.infer(image)[0]  # infer returns a list of results; take the first
print(results)

Take a look at the Classification Inference Model reference for more information on what you can pass to the infer method.

Thanks @bruno

Using model.infer is much faster than the roboflow module.
I had problems before because I was using model.predict, and I had to perform many array operations before passing the image to the method.

By the way, do you have any idea why I can’t load the model into PyTorch using torch.load()?
