I am seeking assistance with an issue I’m encountering after fine-tuning a Segment Anything Model (SAM) using a Roboflow dataset. My end goal is to use this custom-trained model to segment kitchen variants from uploaded images.
I have fine-tuned the SAM model and am now attempting to load it for inference. However, when executing my Python script, I receive the following error:
Using device: cpu
Loading SAM model from models/finetuned_sam_model.pth...
Traceback (most recent call last):
  File "/home/chandani.vibhakar/Documents/Projects/ai_ml_projects/kitchen-visualizer/onlySam.py", line 29, in <module>
    checkpoint = torch.load(SAM_CHECKPOINT_PATH, map_location=torch.device('cpu'))
  File "/home/chandani.vibhakar/Documents/Projects/ai_ml_projects/kitchen-visualizer/venv/lib/python3.10/site-packages/torch/serialization.py", line 1486, in load
    with _open_zipfile_reader(opened_file) as opened_zipfile:
  File "/home/chandani.vibhakar/Documents/Projects/ai_ml_projects/kitchen-visualizer/venv/lib/python3.10/site-packages/torch/serialization.py", line 771, in __init__
    super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/chandani.vibhakar/Documents/Projects/ai_ml_projects/kitchen-visualizer/onlySam.py", line 39, in <module>
    raise RuntimeError(f"Failed to load SAM model. Ensure you are using the official checkpoint file. Error: {e}")
RuntimeError: Failed to load SAM model. Ensure you are using the official checkpoint file. Error: PytorchStreamReader failed reading zip archive: failed finding central directory
The core failure appears to be the RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory, which suggests a problem with the .pth file itself or with how PyTorch is attempting to read it.
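Since torch.save writes a zip archive by default on recent PyTorch versions, I believe the file can be sanity-checked with the standard library before torch.load is involved at all. A minimal sketch (the path is the one from my traceback):

import os
import zipfile

CKPT = "models/finetuned_sam_model.pth"

# A truncated copy or download usually shows up as a file that is far too small.
print(f"Size on disk: {os.path.getsize(CKPT) / 1e6:.1f} MB")

# torch.save() output is a zip archive; False here means the file is corrupt
# (or was written with the legacy serialization format).
print(f"Valid zip archive: {zipfile.is_zipfile(CKPT)}")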
I am attaching my training script and the relevant inference code snippet for your reference:
Training Script:
import os
import torch
import roboflow
import supervision as sv
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
import cv2
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
# --- 1. Initial Setup ---
HOME = os.getcwd()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_TYPE = "vit_h"
print(f"Home directory: {HOME}")
print(f"Using device: {DEVICE}")
# --- 2. Download Pre-trained SAM Checkpoint ---
print("\nDownloading the pre-trained SAM model (ViT-H)...")
# The two lines below are notebook shell commands (they require Jupyter/Colab, not plain Python):
!mkdir -p {HOME}/weights
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(f"SAM checkpoint downloaded to: {CHECKPOINT_PATH}")
# --- 3. Download Custom Dataset from Roboflow ---
# ‼️ PASTE YOUR ROBOFLOW CREDENTIALS HERE ‼️
ROBOFLOW_API_KEY = "ROBOFLOW_API_KEY" # Replace with your key
WORKSPACE_ID = "WORKSPACE_ID" # Replace with your workspace ID
PROJECT_ID = "PROJECT_ID" # Replace with your project ID
VERSION_ID = 1 # Replace with your dataset version number
try:
    rf = roboflow.Roboflow(api_key=ROBOFLOW_API_KEY)
    project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID)
    dataset = project.version(VERSION_ID).download("coco")
    DATASET_PATH = dataset.location
    print(f"\nDataset downloaded successfully to: {DATASET_PATH}")
except Exception as e:
    print(f"\n❌ Error downloading dataset: {e}")
    print("Please check your Roboflow API key, workspace, project, and version details.")
    DATASET_PATH = None
# --- 4. Fine-Tuning the SAM Model ---
if DATASET_PATH:
    print(f"\nLoading SAM model of type '{MODEL_TYPE}' to '{DEVICE}'...")
    sam_model = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
    sam_model.to(device=DEVICE)
    print("Model loaded successfully.")

    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
    loss_fn = torch.nn.MSELoss()

    # Custom Dataset Class for SAM
    class SAMDataset(Dataset):
        def __init__(self, dataset_path, model):
            self.dataset_path = dataset_path
            self.model = model
            self.predictor = SamPredictor(self.model)
            train_dir = os.path.join(dataset_path, "train")
            self.image_paths = sorted([os.path.join(train_dir, f) for f in os.listdir(train_dir) if f.endswith(('.jpg', '.png', '.jpeg'))])
            self.annotations_path = os.path.join(train_dir, "_annotations.coco.json")
            # 💡 FIX: Use sv.DetectionDataset.from_coco for loading annotations
            self.annotations = sv.DetectionDataset.from_coco(
                annotations_path=self.annotations_path,
                images_directory_path=train_dir
            )

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            image_path = self.image_paths[idx]
            image_name = os.path.basename(image_path)

            # Search for a corresponding annotation
            img_annotations = None
            for annotation in self.annotations:
                print(f"Checking annotation: {annotation}")  # Debugging each annotation
                boxes = annotation[0]  # Bounding boxes are the first element in the tuple
                if boxes is not None:  # Ensure boxes exist
                    img_annotations = annotation
                    break
            if img_annotations is None or len(img_annotations[0]) == 0:
                return None

            image = cv2.imread(image_path)
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            self.predictor.set_image(image_rgb)
            image_embedding = self.predictor.get_image_embedding().squeeze()

            # Check the data type of the bounding box
            print(f"Bounding box data type: {type(img_annotations[0][0])}")
            print(f"Bounding box values: {img_annotations[0][0]}")

            # Extract the bounding boxes (first element of the tuple)
            try:
                box = torch.tensor([float(x) for x in img_annotations[0][0]], device=self.model.device)  # Ensure values are floats
            except Exception as e:
                print(f"Error converting bounding box to tensor: {e}")
                return None  # Skip this sample if there's an issue with the box

            # Extract the mask if it exists (second element in the tuple)
            if img_annotations[1] is not None:
                ground_truth_mask = torch.tensor(img_annotations[1], device=self.model.device).float()
            else:
                ground_truth_mask = torch.zeros_like(box)  # Empty mask if no mask exists

            return image_embedding, box, ground_truth_mask

    def collate_fn(batch):
        batch = list(filter(lambda x: x is not None, batch))
        if not batch:
            return None, None, None
        return torch.utils.data.dataloader.default_collate(batch)

    sam_dataset = SAMDataset(dataset_path=DATASET_PATH, model=sam_model)
    train_dataloader = DataLoader(sam_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

    print("\nStarting the fine-tuning process...")
    epochs = 10
    for epoch in range(epochs):
        epoch_losses = []
        for image_embedding, box, ground_truth_mask in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
            if image_embedding is None:
                continue
            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(points=None, boxes=box, masks=None)
            low_res_masks, _ = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            upscaled_masks = sam_model.postprocess_masks(low_res_masks, (1024, 1024), (1024, 1024)).to(DEVICE)
            loss = loss_fn(upscaled_masks, ground_truth_mask.unsqueeze(1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        print(f"Epoch {epoch + 1} | Average Loss: {np.mean(epoch_losses):.4f}")

    print("\n✅ Fine-tuning finished.")
    FINETUNED_CHECKPOINT_PATH = os.path.join(HOME, "weights", "finetuned_sam_model.pth")
    torch.save(sam_model.state_dict(), FINETUNED_CHECKPOINT_PATH)
    print(f"Fine-tuned model saved to: {FINETUNED_CHECKPOINT_PATH}")
# --- 5. Inference with the Fine-Tuned Model ---
if DATASET_PATH and 'FINETUNED_CHECKPOINT_PATH' in locals():
    print("\nRunning inference with the fine-tuned model...")
    fine_tuned_model = sam_model_registry[MODEL_TYPE]()
    fine_tuned_model.load_state_dict(torch.load(FINETUNED_CHECKPOINT_PATH))
    fine_tuned_model.to(device=DEVICE)
    predictor = SamPredictor(fine_tuned_model)

    valid_dir = os.path.join(DATASET_PATH, "valid")
    test_image_names = [f for f in os.listdir(valid_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]
    if not test_image_names:
        print("No validation images found to test inference.")
    else:
        test_image_path = os.path.join(valid_dir, np.random.choice(test_image_names))
        test_image = cv2.imread(test_image_path)
        test_image_rgb = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
        predictor.set_image(test_image_rgb)

        valid_annotations_path = os.path.join(valid_dir, "_annotations.coco.json")
        valid_annotations = sv.DetectionDataset.from_coco(
            annotations_path=valid_annotations_path,
            images_directory_path=valid_dir
        )
        img_annotations = valid_annotations[os.path.basename(test_image_path)]
        input_box = np.array(img_annotations.xyxy[0])
        masks, _, _ = predictor.predict(box=input_box, multimask_output=False)
        detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)

        # Annotate and display the result
        mask_annotator = sv.MaskAnnotator()
        annotated_image = mask_annotator.annotate(scene=test_image.copy(), detections=detections)
        print("\nDisplaying segmentation result from the fine-tuned model:")
        sv.plot_image(annotated_image)

print("\n✨ Script finished successfully!")
Inference Code Snippet:
from segment_anything import sam_model_registry, SamPredictor
import os
import torch
SAM_CHECKPOINT_PATH = 'models/finetuned_sam_model.pth'
SAM_MODEL_TYPE = 'vit_h'
INPUT_IMAGE_PATH = 'images/kitchen14.jpg'
# --- Setup ---
# Load SAM model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
try:
    checkpoint = torch.load(SAM_CHECKPOINT_PATH, map_location=torch.device('cpu'))
    sam = sam_model_registry[SAM_MODEL_TYPE]()
    sam.load_state_dict(checkpoint)
    sam.to(device)
    predictor = SamPredictor(sam)
    print("SAM model loaded successfully.")
except FileNotFoundError:
    raise FileNotFoundError(f"SAM checkpoint not found at: {SAM_CHECKPOINT_PATH}. Please download the official model and place it here.")
except Exception as e:
    raise RuntimeError(f"Failed to load SAM model. Ensure you are using the official checkpoint file. Error: {e}")
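While debugging, I am considering replacing the bare torch.load call with a more defensive loader, roughly like the following (a sketch; load_checkpoint_safely is my own hypothetical helper, not part of segment_anything):

import os
import zipfile
import torch

def load_checkpoint_safely(path):
    """Fail early with a specific message if the checkpoint file is
    missing or is not a valid torch.save() zip archive."""
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No checkpoint at {path}")
    if not zipfile.is_zipfile(path):
        size_mb = os.path.getsize(path) / 1e6
        raise RuntimeError(
            f"{path} ({size_mb:.1f} MB) is not a zip archive; it was likely "
            "truncated in transfer, or is a Git LFS pointer whose blob was never pulled."
        )
    return torch.load(path, map_location="cpu")

checkpoint = load_checkpoint_safely(SAM_CHECKPOINT_PATH)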
Could anyone shed light on what might be causing this PytorchStreamReader error when loading a fine-tuned SAM model? Specifically:

- Could the fine-tuning process or saving method result in a corrupted .pth file?
- Are there specific considerations when saving a fine-tuned SAM model that I might have missed?
- Are there common pitfalls when attempting to load a custom SAM checkpoint that might lead to this “failed finding central directory” error?
Any guidance or suggestions on troubleshooting this issue would be greatly appreciated.
Thank you for your time and assistance.