Training Florence2 for Ball Detection
This guide demonstrates how to train a Florence2 model specifically for baseball detection using BaseballCV’s ball dataset. Florence2’s vision-language capabilities make it particularly suitable for precise ball detection and tracking.
Setting Up the Training Pipeline
First, let’s create a comprehensive training pipeline that leverages BaseballCV’s utilities:
from baseballcv.functions import LoadTools
from baseballcv.model import Florence2
import os
import torch
class BallDetectionPipeline:
def __init__(self):
"""Initialize pipeline components"""
self.load_tools = LoadTools()
# Initialize Florence2 with appropriate batch size
self.model = Florence2(
model_id='microsoft/Florence-2-large',
batch_size=4 # Adjust based on GPU memory
)
# Define class mapping
self.classes = {
0: "baseball" # Single class for ball detection
}
# Set up output directories
self.output_dir = "ball_detector"
os.makedirs(self.output_dir, exist_ok=True)
def prepare_dataset(self):
"""
Load and prepare the ball detection dataset
"""
# Load the baseball-only dataset
dataset_path = self.load_tools.load_dataset("baseball")
return dataset_path
def configure_training(self):
"""
Configure Florence2 training parameters
"""
training_config = {
# Dataset parameters
'train_test_split': (80, 10, 10), # Training/Test/Validation split
# Training hyperparameters
'epochs': 20,
'lr': 4e-6,
'batch_size': 4,
# LoRA parameters for efficient fine-tuning
'lora_r': 8,
'lora_scaling': 8,
'lora_dropout': 0.05,
# Training optimizations
'warmup_epochs': 1,
'gradient_accumulation_steps': 2,
'lr_schedule': "cosine",
# Early stopping settings
'patience': 5,
'patience_threshold': 0.01,
# Model saving
'save_dir': self.output_dir,
# Hardware utilization
'num_workers': 4 # Adjust based on CPU cores
}
return training_config
def train(self):
"""
Execute the training pipeline
"""
dataset_path = self.prepare_dataset()
config = self.configure_training()
metrics = self.model.finetune(
dataset=dataset_path,
classes=self.classes,
**config
)
return metrics
# Initialize and run the pipeline
pipeline = BallDetectionPipeline()
training_metrics = pipeline.train()
Dataset Structure and Preparation
The baseball dataset in BaseballCV is already optimized for training, but understanding its structure is important:
def inspect_dataset(dataset_path):
"""
Analyze the ball detection dataset structure
"""
# Dataset statistics
train_path = os.path.join(dataset_path, "train")
val_path = os.path.join(dataset_path, "valid")
test_path = os.path.join(dataset_path, "test")
print(f"Training images: {len(os.listdir(os.path.join(train_path, 'images')))}")
print(f"Validation images: {len(os.listdir(os.path.join(val_path, 'images')))}")
print(f"Test images: {len(os.listdir(os.path.join(test_path, 'images')))}")
Training Configuration Details
Florence2 training can be customized through various parameters. Here’s a detailed configuration:
def detailed_training_config(self):
"""
Detailed training configuration with explanations
"""
return {
# Model configuration
'dataset': self.dataset_path,
'classes': self.classes,
'epochs': 20,
'batch_size': 4,
# Optimization parameters
'lr': 4e-6, # Learning rate for fine-tuning
'weight_decay': 0.01,
'warmup_epochs': 1,
# LoRA specific parameters
'lora_r': 8, # LoRA attention dimension
'lora_scaling': 8, # LoRA alpha scaling factor
'lora_dropout': 0.05,
# Training stability
'gradient_accumulation_steps': 2,
'patience': 5, # Early stopping patience
'patience_threshold': 0.01,
# Hardware utilization
'num_workers': 4, # DataLoader workers
# Model saving
'save_dir': "ball_detector",
'create_peft_config': True # Enable LoRA configuration
}
Monitoring Training Progress
Monitoring training progress helps identify issues early:
def setup_training_monitoring(self):
"""
Configure training monitoring and visualization
"""
import matplotlib.pyplot as plt
class TrainingMonitor:
def __init__(self):
self.train_losses = []
self.val_losses = []
self.learning_rates = []
def update(self, train_loss, val_loss, lr):
self.train_losses.append(train_loss)
self.val_losses.append(val_loss)
self.learning_rates.append(lr)
def plot_metrics(self, save_path):
# Plot loss curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(self.train_losses, label='Train Loss')
plt.plot(self.val_losses, label='Val Loss')
plt.title('Training Progress')
plt.legend()
# Plot learning rate
plt.subplot(1, 2, 2)
plt.plot(self.learning_rates)
plt.title('Learning Rate Schedule')
plt.savefig(save_path)
plt.close()
return TrainingMonitor()
Model Evaluation
After training, evaluate the model’s performance:
def evaluate_model(self):
"""
Evaluate trained model performance
"""
# Load test dataset
test_dataset = self.load_tools.load_dataset(
"baseball",
subset="test"
)
# Run evaluation
metrics = self.model.evaluate(
base_path=test_dataset,
classes=self.classes,
num_workers=4,
dataset_type="yolo"
)
return metrics
Practical Tips for Florence2 Training
When training Florence2 for ball detection, consider these important factors:
1. Memory Management
Florence2 is a large model. Optimize memory usage by:
- Using gradient accumulation
- Implementing LoRA for efficient fine-tuning
- Adjusting batch size based on available GPU memory
- Utilizing mixed precision training
This training pipeline provides a robust foundation for fine-tuning Florence2 specifically for baseball detection. The configuration can be adjusted based on your specific needs and computational resources.