Image Classification in Rust with Tch-rs (Torch bindings)
Rust has rapidly become a favorite among developers who want both performance and safety. While Python dominates the machine learning landscape, Rust is increasingly used for AI and data-intensive applications — especially when speed, memory efficiency, and low-level control are critical. In this tutorial, you'll learn how to perform image classification in Rust using tch-rs, the Rust bindings for LibTorch, the core engine behind PyTorch. You'll see how to:
- Set up a Rust project with tch-rs
- Load a pretrained neural network model such as ResNet18
- Preprocess and load images into tensors
- Run predictions and interpret the results
- Optionally, load your own trained model for inference
By the end of this guide, you'll have a fully working Rust-based image classifier capable of predicting objects in images with just a few lines of code — all powered by the performance of Torch and the safety of Rust.

Before diving into code, make sure your system is ready for Rust and Torch development. In this section, we'll set up the environment, install the required tools, and ensure that everything works smoothly.

What You'll Need
- Basic Rust knowledge — You should be comfortable with creating projects using Cargo, working with modules, and handling external crates.
- Rust installed — Ensure you have the latest stable version of Rust and Cargo. You can check this by running:

  ```bash
  rustc --version
  cargo --version
  ```
- If not installed, install it using Rustup:

  ```bash
  curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
  ```
- LibTorch installed — tch-rs requires LibTorch, the C++ backend used by PyTorch. You can download the prebuilt binaries from the official PyTorch website. Choose the LibTorch C++/Java distribution that matches your OS and CUDA configuration (use CPU-only if you don't have a GPU).
- For example, on Linux:

  ```bash
  wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-latest.zip
  unzip libtorch-shared-with-deps-latest.zip
  export LIBTORCH=$(pwd)/libtorch
  export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
  ```
- On macOS, the commands are similar (use DYLD_LIBRARY_PATH instead of LD_LIBRARY_PATH, and adjust the paths accordingly).
- 💡 Tip: You can add these environment variables to your .bashrc or .zshrc to make them persistent.
- C++ Build Tools — Since tch-rs links to LibTorch (a C++ library), you'll need a working C++ compiler toolchain:
- macOS: Xcode Command Line Tools (xcode-select --install)
- Linux: sudo apt install build-essential
- Windows: Visual Studio Build Tools
- An Image to Classify — Prepare a test image (e.g., cat.jpg, dog.jpg, or car.jpg) in your project directory for testing.
Once these prerequisites are in place, you're ready to create your Rust project and start coding.

Setting Up the Rust Project

With your environment ready, the next step is to create a new Rust project and configure it to use the tch crate, which provides the bindings to the PyTorch C++ library (LibTorch).

Step 1: Create a New Cargo Project

Open your terminal and create a new Rust binary project called rust-image-classifier:

```bash
cargo new rust-image-classifier
cd rust-image-classifier
```

This will create a directory structure like:

```
rust-image-classifier/
├── Cargo.toml
└── src/
    └── main.rs
```

Step 2: Add the tch Dependency

Open the Cargo.toml file and add the following under [dependencies]:

```toml
[package]
name = "rust-image-classifier"
version = "0.1.0"
edition = "2024"

[dependencies]
tch = "0.22.0"
anyhow = "1.0"
```

💡 Check the latest version of tch on crates.io and update accordingly if there's a newer release.

Step 3: Set Up the LibTorch Environment Variables

tch-rs needs to know where your LibTorch installation is located. Make sure the following environment variables are set in your shell before running the program:

On macOS/Linux:

```bash
export LIBTORCH=/opt/homebrew/Cellar/pytorch/2.9.0_1
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```

On Windows (PowerShell):

```powershell
setx LIBTORCH "C:\path\to\libtorch"
setx PATH "%LIBTORCH%\lib;%PATH%"
```

You can verify that tch detects LibTorch correctly by running your Rust program (we'll check this next).

Step 4: Verify the Setup

Before moving on, let's ensure that tch can create tensors and talk to LibTorch. Open src/main.rs and replace the contents with:

```rust
use tch::{Device, Tensor};

fn main() {
    // Create a simple tensor from a slice of floats
    let data = [1.0f32, 2.0, 3.0, 4.0, 5.0];
    let tensor = Tensor::from_slice(&data);
    println!("Tensor: {:?}", tensor);

    // Check which device is available (CUDA if present, otherwise CPU)
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Move the tensor to the device and perform a simple operation
    let tensor_on_device = tensor.to_device(device);
    let result = &tensor_on_device * 2;
    println!("Tensor * 2 = {:?}", result);
}
```

Now, run the program:

```bash
cargo run
```

If everything is set up correctly, you'll see output similar to:

```
Tensor: [1.0, 2.0, 3.0, 4.0, 5.0]
Using device: Cpu
Tensor * 2 = [2.0, 4.0, 6.0, 8.0, 10.0]
```

If you have CUDA properly configured, it might show Cuda(0) instead of Cpu. At this point, your Rust project is ready to use the Torch bindings!

Now that your project is set up and verified, let's explore how to use the tch crate (Rust bindings for PyTorch) effectively. This section covers what's inside tch, how it interfaces with LibTorch, and how to use its core modules.

What Is tch-rs?

tch-rs is a Rust wrapper around LibTorch, the C++ backend of PyTorch. It provides a safe and idiomatic Rust API for performing operations such as:
- Tensor creation and manipulation
- Neural network model loading and inference
- GPU/CPU device handling
- Access to TorchVision models like ResNet, VGG, and MobileNet
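To give a feel for the API surface, here is a tiny sketch that builds two of those TorchVision architectures. This only constructs the networks (with randomly initialized weights); the rest of the tutorial builds on exactly this pattern:

```rust
use tch::{nn, vision::{resnet, vgg}, Device};

fn main() {
    // Each architecture gets its own VarStore to hold its parameters.
    let vs_resnet = nn::VarStore::new(Device::Cpu);
    let _resnet18 = resnet::resnet18(&vs_resnet.root(), 1000);

    let vs_vgg = nn::VarStore::new(Device::Cpu);
    let _vgg16 = vgg::vgg16(&vs_vgg.root(), 1000);

    println!("Built ResNet18 and VGG16 architectures");
}
```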
Essentially, tch-rs lets you do inference and training in Rust using the same underlying engine as PyTorch - but with Rust's safety, performance, and type guarantees.

Core Modules Overview

Some of the most commonly used modules include:

- tch::Tensor — creation and manipulation of multi-dimensional arrays
- tch::nn — neural network building blocks and the VarStore that holds parameters
- tch::vision — ready-made architectures (resnet, vgg, mobilenet, ...) and the imagenet helpers
- tch::Device and tch::Kind — device placement (CPU/CUDA) and element types
- tch::CModule — loading TorchScript modules exported from Python

Example: Creating and Manipulating Tensors

Let's try a few tensor operations to get familiar with the API. Replace your src/main.rs with the following:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Create a tensor of random values
    let random_tensor = Tensor::randn([3, 3], (Kind::Float, Device::Cpu));
    println!("Random tensor:\n{:?}", random_tensor);

    // Create a tensor filled with zeros
    let zeros = Tensor::zeros([2, 4], (Kind::Float, Device::Cpu));
    println!("Zeros tensor:\n{:?}", zeros);

    // Perform arithmetic operations
    let ones = Tensor::ones([2, 4], (Kind::Float, Device::Cpu));
    let sum = &zeros + &ones;
    println!("Sum of zeros + ones:\n{:?}", sum);

    // Matrix multiplication
    let a = Tensor::randn([2, 3], (Kind::Float, Device::Cpu));
    let b = Tensor::randn([3, 2], (Kind::Float, Device::Cpu));
    let result = a.matmul(&b);
    println!("Matrix multiplication result:\n{:?}", result);
}
```

Run it:

```bash
cargo run
```

Expected output (values will vary):

```
Random tensor:
Tensor[[3, 3], Float]
Zeros tensor:
Tensor[[2, 4], Float]
Sum of zeros + ones:
Tensor[[2, 4], Float]
Matrix multiplication result:
Tensor[[2, 2], Float]
```

This demonstrates how simple tensor math works in Rust with Torch.

💡 Pro Tip

If you're familiar with PyTorch in Python, you'll notice that tch-rs is conceptually similar - most function names and tensor operations are almost identical. This makes transitioning between the two languages straightforward for developers.

Loading a Pretrained Model (ResNet18)

One of the most powerful features of tch-rs is the ability to use model architectures from TorchVision. Combined with weights pretrained on large datasets (like ImageNet), this lets you perform high-quality image classification without building or training anything from scratch. In this section, we'll construct the ResNet18 model — a lightweight and popular convolutional neural network for image recognition.

Step 1: Import Required Modules

Open src/main.rs and replace the code with:

```rust
use anyhow::Result;
use tch::{nn, vision::resnet, Device};

fn main() -> Result<()> {
    // Choose device (CPU or CUDA)
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Create a variable store to hold the model parameters
    let vs = nn::VarStore::new(device);

    // Build the ResNet18 architecture with 1000 output classes (ImageNet)
    let _model = resnet::resnet18(&vs.root(), 1000);

    println!("✅ ResNet18 model loaded successfully!");
    Ok(())
}
```

Step 2: Run the Program

Note that resnet::resnet18 only builds the network architecture; the parameters in the VarStore start out randomly initialized, and tch-rs does not download pretrained weights automatically. We'll load real weights in a later section. You should see:

```
Using device: Cpu
✅ ResNet18 model loaded successfully!
```

If you have a CUDA-compatible GPU and the CUDA build of LibTorch, it may show:

```
Using device: Cuda(0)
✅ ResNet18 model loaded successfully!
```

If you see no extra output, that's simply because the model is constructed but you're not doing any inference or printing any tensors yet. The ResNet model itself doesn't log anything internally — it's just built and ready to use.

Step 3: Understanding What Happens
- nn::VarStore manages the model parameters (weights and biases).
- resnet::resnet18(&vs.root(), 1000) builds the ResNet18 architecture; the second argument is the number of output classes (1000 for ImageNet).
- The parameters start out randomly initialized - useful for training from scratch. To get real predictions, you load pretrained weights into the VarStore (or use a TorchScript export), as shown below and in later sections.
- The model is ready to perform inference (classification) on any input tensor that matches its expected shape and normalization.
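To make the weight-loading point concrete, here is a minimal sketch of overwriting the random parameters with saved ones via VarStore::load. It assumes you already have ResNet18 weights in tch's .ot format; the filename is hypothetical, and we'll produce real weight files later in this guide:

```rust
use anyhow::Result;
use tch::{nn, vision::resnet, Device};

fn main() -> Result<()> {
    let mut vs = nn::VarStore::new(Device::cuda_if_available());

    // Build the architecture first; its parameters are random at this point.
    let _model = resnet::resnet18(&vs.root(), 1000);

    // Overwrite the random parameters in place (hypothetical weights file).
    vs.load("resnet18.ot")?;
    Ok(())
}
```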
Step 4: Optional — Print Model Summary

You can inspect the architecture of the loaded model by adding (rename _model to model first):

```rust
println!("{:?}", model);
```

This will print the structure of ResNet18's layers — useful for debugging or educational purposes.

💡 Note on Model Weights

tch-rs never fetches pretrained weights for you, so if you are behind a proxy or working offline, prepare the weights file ahead of time (for example by exporting it from Python, as shown in the "Loading Pretrained Weights" section below). tch-rs will then load it locally without any network access.

Loading and Preprocessing Images

Now that your ResNet18 model loads successfully, let's feed it an image and get predictions.

Step 1: Add an Example Image

Put an image file (for example grass.jpg) inside your project folder:

```
rust-image-classifier/
├── src/
├── Cargo.toml
└── grass.jpg
```

Step 2: Update main.rs

Here's a complete working example that builds the model, preprocesses the image, runs inference, and maps the predicted index to a label:

```rust
use anyhow::Result;
use tch::{
    nn::{self, ModuleT},
    vision::{imagenet, resnet},
    Device,
};

fn main() -> Result<()> {
    // Choose device (CPU or CUDA)
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Create a variable store
    let vs = nn::VarStore::new(device);

    // Build ResNet18 (1000 ImageNet classes; weights are still random here)
    let model = resnet::resnet18(&vs.root(), 1000);

    // Load the image, resize it to 224x224, and apply the standard
    // ImageNet normalization (the helper handles all of this)
    let image = imagenet::load_image_and_resize224("grass.jpg")?;
    let input_tensor = image
        .unsqueeze(0) // add batch dimension
        .to_device(device);

    // Run inference
    let output = model.forward_t(&input_tensor, false);
    let predicted = output.argmax(1, false);
    let class_idx = predicted.int64_value(&[0]);
    println!("🧠 Predicted class index: {}", class_idx);

    // Load ImageNet labels from a local file
    let labels: Vec<&str> = include_str!("imagenet_classes.txt").lines().collect();
    println!("🐾 Predicted class label: {}", labels[class_idx as usize]);

    Ok(())
}
```

Download the official ImageNet class list file from PyTorch: https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

Then save it in your src/ folder beside your main file:

```
src/
├── main.rs
└── imagenet_classes.txt
```

Step 3: Run It

```bash
cargo run
```

Output example (remember: with random weights, the prediction is arbitrary):

```
Using device: Cpu
🧠 Predicted class index: 207
🐾 Predicted class label: golden retriever
```

🧠 What Happens Here
- imagenet::load_image_and_resize224() loads the image, resizes it to 224×224, and applies the standard ImageNet mean/std normalization (see the sketch after this list for what that normalization does).
- unsqueeze(0) adds the batch dimension ([1,3,224,224]).
- forward_t(&input_tensor, false) runs inference; the false selects evaluation mode (no dropout or batch-norm updates).
- argmax(1) picks the top predicted label index.
- imagenet_classes.txt provides the 1000 ImageNet category names that we index with the predicted class id.
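For clarity, here is roughly what that ImageNet normalization does, written out by hand. This is a minimal sketch based on the standard ImageNet channel statistics, not the library's exact internals:

```rust
use tch::Tensor;

// Normalize a float CHW image tensor with values in [0, 1], shape [3, H, W],
// using the standard ImageNet per-channel mean and standard deviation.
fn normalize_imagenet(image: &Tensor) -> Tensor {
    let mean = Tensor::from_slice(&[0.485f32, 0.456, 0.406]).view([3, 1, 1]);
    let std = Tensor::from_slice(&[0.229f32, 0.224, 0.225]).view([3, 1, 1]);
    // Broadcasting subtracts/divides each channel by its own statistic.
    (image - mean) / std
}
```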
Displaying Top-5 Predictions and Confidence Scores

Now that your model successfully produces output tensors, let's extract and display the top 5 most likely ImageNet classes with their confidence percentages.

🦀 Code Example

```rust
use anyhow::Result;
use tch::{
    nn::{self, ModuleT},
    no_grad,
    vision::{imagenet, resnet},
    Device, Kind,
};

fn main() -> Result<()> {
    // Choose device (CPU or CUDA)
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Create variable store
    let vs = nn::VarStore::new(device);

    // Build ResNet18 (1000 classes for ImageNet)
    let model = resnet::resnet18(&vs.root(), 1000);

    // Load and preprocess an image (resize + ImageNet normalization)
    let image = imagenet::load_image_and_resize224("grass.jpg")?;
    let input_tensor = image.unsqueeze(0).to_device(device);

    // Disable gradient tracking during inference
    let output = no_grad(|| model.forward_t(&input_tensor, false));

    // Apply softmax to get probabilities
    let probabilities = output.softmax(-1, Kind::Float);

    // Get top-5 predictions
    let (top_probs, top_indices) = probabilities.topk(5, 1, true, true);

    // Convert tensors to Rust Vecs
    let top_probs: Vec<f32> = top_probs.squeeze().try_into()?;
    let top_indices: Vec<i64> = top_indices.squeeze().try_into()?;

    // Load labels from a local file (imagenet_classes.txt)
    let labels: Vec<&str> = include_str!("imagenet_classes.txt").lines().collect();

    println!("\n🏆 Top-5 Predictions:");
    for (i, (&idx, &prob)) in top_indices.iter().zip(top_probs.iter()).enumerate() {
        println!("{:>2}. {:<30} - {:.2}%", i + 1, labels[idx as usize], prob * 100.0);
    }

    Ok(())
}
```

🧠 What's Happening Here
1. Softmax Converts raw logits (unbounded numbers) into probabilities that sum to 1.
2. topk(5) Finds the indices and values of the top 5 probabilities.
3. Vec conversion Copies the top probabilities and indices into plain Rust Vecs (via try_into()) for easy printing.
4. Mapping Uses the ImageNet label list to show readable class names. (tch also ships a convenience helper, imagenet::top(&output, 5), which returns probability/label pairs directly.)

📊 Example Output (random weights)

```
Using device: Cpu

🏆 Top-5 Predictions:
 1. bloodhound                     - 1.13%
 2. rugby ball                     - 0.74%
 3. pizza                          - 0.61%
 4. gong                           - 0.55%
 5. Greater Swiss Mountain dog     - 0.50%
```

If you load real pretrained weights, these probabilities will reflect meaningful classifications.

💡 Tip

Once you add pretrained weights (e.g., from a .ot file or a TorchScript export), this section will instantly start showing real classification probabilities - no code changes needed.

Loading Pretrained Weights for Real Predictions

Until now, our model has been running with random weights. To get meaningful predictions, we'll load pretrained ResNet18 weights that were trained on the ImageNet dataset.

1️⃣ Export the Pretrained Model from Python

If you have Python with PyTorch installed, run this short script to export the model:
export_resnet18.py:

```python
import torch
import torchvision.models as models

# Load pretrained ResNet18
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()

# Save to TorchScript format (compatible with tch-rs)
example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example)
traced.save("resnet18.pt")

print("✅ Exported model saved as resnet18.pt")
```
This saves a resnet18.pt file in your project directory - which Rust can load directly using tch.

2️⃣ Load the TorchScript Model in Rust

Now, update your Rust code to use the exported model:

```rust
use anyhow::Result;
use tch::{no_grad, vision::imagenet, CModule, Device, Kind};

fn main() -> Result<()> {
    // Choose device (CPU or CUDA)
    let device = Device::cuda_if_available();
    println!("Using device: {:?}", device);

    // Load the pretrained model exported from PyTorch
    let model = CModule::load_on_device("resnet18.pt", device)?;

    // Load and preprocess an image (resize + ImageNet normalization)
    let image = imagenet::load_image_and_resize224("grass.jpg")?;
    let input_tensor = image.unsqueeze(0).to_device(device);

    // Run inference with gradients disabled
    let output = no_grad(|| model.forward_ts(&[input_tensor]))?;

    // Apply softmax to get probabilities
    let probabilities = output.softmax(-1, Kind::Float);

    // Get top-5 predictions
    let (top_probs, top_indices) = probabilities.topk(5, 1, true, true);
    let top_probs: Vec<f32> = top_probs.squeeze().try_into()?;
    let top_indices: Vec<i64> = top_indices.squeeze().try_into()?;

    // Load class labels
    let labels: Vec<&str> = include_str!("imagenet_classes.txt").lines().collect();

    println!("\n🏆 Top-5 Predictions (Pretrained Model):");
    for (i, (&idx, &prob)) in top_indices.iter().zip(top_probs.iter()).enumerate() {
        println!("{:>2}. {:<30} - {:.2}%", i + 1, labels[idx as usize], prob * 100.0);
    }

    Ok(())
}
```

3️⃣ Explanation

- CModule::load_on_device loads the TorchScript module and places it on the chosen device.
- forward_ts runs the scripted model on a slice of input tensors and returns the output tensor.
- no_grad disables gradient tracking, which saves memory and time during inference.

4️⃣ Example Output

After running this with grass.jpg, you'll now get real predictions:

```
Using device: Cpu

🏆 Top-5 Predictions (Pretrained Model):
 1. matchstick                     - 9.29%
 2. spotlight                      - 5.38%
 3. nematode                       - 2.87%
 4. lighter                        - 2.76%
 5. digital clock                  - 2.74%
```

✅ Summary

In this section, you've learned how to:
- Export a pretrained model from PyTorch as TorchScript
- Load and run it in Rust using the tch crate
- Perform top-5 image classification with real probabilities
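Before wiring the exported model into a bigger program, it can be worth a quick sanity check that the TorchScript file loads and produces the expected output shape. A minimal sketch, assuming resnet18.pt sits in the project root:

```rust
use anyhow::Result;
use tch::{CModule, Device, Kind, Tensor};

fn main() -> Result<()> {
    // Load the exported TorchScript module on the CPU
    let model = CModule::load_on_device("resnet18.pt", Device::Cpu)?;

    // Push a dummy batch through it: ResNet18 expects [N, 3, 224, 224]
    let dummy = Tensor::zeros([1, 3, 224, 224], (Kind::Float, Device::Cpu));
    let output = model.forward_ts(&[dummy])?;

    // An ImageNet model should emit 1000 logits per image
    assert_eq!(output.size(), vec![1, 1000]);
    println!("✅ Model loads and outputs shape {:?}", output.size());
    Ok(())
}
```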
Saving and Reusing the Model (Rust Inference API Example)

Once you have successfully loaded and tested your ResNet18 model with top-5 predictions, it's time to make it reusable by saving it once and then creating a simple inference API. This allows you to load your model only once and classify multiple images quickly — just like a real backend service.

🧠 1. Saving the Model in Rust

Although we exported the TorchScript model (resnet18.pt) from Python, you can save an updated version (e.g., fine-tuned or modified weights) directly from Rust using VarStore::save().

```rust
use anyhow::Result;
use tch::{nn, vision::resnet, Device};

fn main() -> Result<()> {
    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let _model = resnet::resnet18(&vs.root(), 1000); // same number of classes

    // Save the model weights to a file
    vs.save("resnet18_saved.ot")?;
    println!("✅ Model parameters saved to resnet18_saved.ot");
    Ok(())
}
```

This saves only the model parameters (an .ot file). You can later reload them with vs.load("resnet18_saved.ot")?;

🧱 2. Loading and Using the Saved Model

You can load the saved weights and reuse them without retracing or re-exporting (note that vs must be mutable, since load overwrites the parameters in place):

```rust
use anyhow::Result;
use tch::{nn, vision::resnet, Device};

fn main() -> Result<()> {
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let _model = resnet::resnet18(&vs.root(), 1000);

    vs.load("resnet18_saved.ot")?;
    println!("✅ Loaded model from resnet18_saved.ot");
    Ok(())
}
```

🌐 3. Creating a Simple Rust Inference API

To serve your model via an HTTP endpoint, you can use Axum, a modern async web framework for Rust. Add this to your Cargo.toml:

```toml
[dependencies]
axum = "0.8.6"
axum-extra = { version = "0.12.1", features = ["multipart"] }
tokio = { version = "1", features = ["full"] }
tch = "0.22.0"
anyhow = "1"
serde_json = "1"
image = "0.25"
```

Then create src/main.rs like this:

```rust
use axum::{extract::State, response::Json, routing::post, Router};
use axum_extra::extract::Multipart;
use serde_json::json;
use std::{net::SocketAddr, sync::Arc};
use tch::{CModule, Device, Kind, Tensor};
use tokio::sync::Mutex;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Load the model once and share it across requests using Arc<Mutex<_>>
    let model = Arc::new(Mutex::new(CModule::load_on_device(
        "resnet18.pt",
        Device::cuda_if_available(),
    )?));

    // Attach the model as shared state
    let app = Router::new()
        .route("/predict", post(predict_handler))
        .with_state(model.clone());

    let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
    println!("🚀 Server running at http://{addr}/predict");

    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}

async fn predict_handler(
    State(model): State<Arc<Mutex<CModule>>>,
    mut multipart: Multipart,
) -> Result<Json<serde_json::Value>, (axum::http::StatusCode, Json<serde_json::Value>)> {
    while let Some(field) = multipart.next_field().await.map_err(|_| {
        (
            axum::http::StatusCode::BAD_REQUEST,
            Json(json!({ "error": "Invalid form data" })),
        )
    })? {
        let data = field.bytes().await.map_err(|_| {
            (
                axum::http::StatusCode::BAD_REQUEST,
                Json(json!({ "error": "Failed to read file bytes" })),
            )
        })?;

        // 🧠 Load and preprocess the image
        let image = image::load_from_memory(&data)
            .map_err(|_| {
                (
                    axum::http::StatusCode::BAD_REQUEST,
                    Json(json!({ "error": "Invalid image format" })),
                )
            })?
            .to_rgb8();

        let resized = image::imageops::resize(
            &image,
            224,
            224,
            image::imageops::FilterType::Nearest,
        );
        let img_data = resized.into_raw();

        // HWC u8 -> CHW float in [0, 1] with a batch dimension.
        // Note: for best accuracy you would also apply the ImageNet
        // mean/std normalization here, as in the earlier sections.
        let tensor = Tensor::from_slice(&img_data)
            .view([224, 224, 3])
            .permute([2, 0, 1])
            .unsqueeze(0)
            .to_kind(Kind::Float) / 255.0;
        let tensor = tensor.to_device(Device::cuda_if_available());

        let model = model.lock().await;
        let output = model.forward_ts(&[tensor]).map_err(|_| {
            (
                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({ "error": "Failed to run inference" })),
            )
        })?;

        let probabilities = output.softmax(-1, Kind::Float);
        let (confidence, class_index) = probabilities.max_dim(1, false);
        let confidence = confidence.double_value(&[]);
        let class_index = class_index.int64_value(&[]);

        return Ok(Json(json!({
            "class_index": class_index,
            "confidence": format!("{:.2}%", confidence * 100.0)
        })));
    }

    Err((
        axum::http::StatusCode::BAD_REQUEST,
        Json(json!({ "error": "No file uploaded" })),
    ))
}
```

🧪 4. Running the API

Run the server:

```bash
cargo run
```

Output:

```
🚀 Server running at http://127.0.0.1:8080/predict
```

Then test with an image file:

```bash
curl -F "file=@grass.jpg" http://127.0.0.1:8080/predict
```

Example response:

```json
{"class_index":973,"confidence":"17.15%"}
```

✅ Summary

In this section, you learned how to:
- Save and reload your model parameters in Rust.
- Serve predictions via an Axum-based API.
- Perform inference on uploaded images efficiently.
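If you'd rather exercise the endpoint from Rust instead of curl, here is a small client sketch. It assumes you add the reqwest crate with the blocking and multipart features; the "file" field name matches what the curl example sends:

```rust
// Cargo.toml addition (assumption):
// reqwest = { version = "0.12", features = ["blocking", "multipart"] }
use reqwest::blocking::{multipart, Client};

fn main() -> anyhow::Result<()> {
    // Build a multipart form with the image under the "file" field
    let form = multipart::Form::new().file("file", "grass.jpg")?;

    // Post it to the running inference server and print the JSON reply
    let response = Client::new()
        .post("http://127.0.0.1:8080/predict")
        .multipart(form)
        .send()?;
    println!("{}", response.text()?);
    Ok(())
}
```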
Conclusion and Next Steps

Congratulations! You've just built a complete deep learning inference API in Rust — from loading a pretrained ResNet18 model to serving real-time predictions through an HTTP endpoint. 🚀 Here's what you achieved step by step:

✅ What You've Learned
- Setting up Rust and tch (LibTorch)
- Loading and Running a ResNet18 Model
- Saving and Reusing the Model
- Building an Inference API with Axum
- Created a RESTful endpoint (POST /predict) for file uploads.
- Integrated tch inference into a real-world Rust web service.
- Handled JSON responses, file uploads, and error handling cleanly.
🧠 What Makes This Tutorial Unique
- Zero Python at runtime — the model runs natively in Rust.
- Fast and safe — leveraging Rust’s performance and memory safety.
- Production-ready foundation — you can easily extend it into a full ML microservice.
🚀 Next Steps

Now that your API works locally, here are practical ways to extend it:
- 🧩 Add Support for Multiple Models Load different TorchScript models (e.g., ResNet, MobileNet, EfficientNet) and select one via a query parameter (see the sketch after this list).
- ⚙️ Serve via Docker or Kubernetes Package your Rust app with Docker for easy deployment:
  ```dockerfile
  FROM rust:1.81 AS builder
  WORKDIR /app
  COPY . .
  RUN cargo build --release

  FROM debian:bookworm-slim
  WORKDIR /app
  COPY --from=builder /app/target/release/rust-image-classifier .
  COPY resnet18.pt .
  EXPOSE 8080
  CMD ["./rust-image-classifier"]
  ```

  (The binary name matches the package name from Cargo.toml, and the exposed port matches the server. When running inside a container, bind the server to 0.0.0.0 instead of 127.0.0.1.)
- 📊 Add Metrics and Logging Use crates like tracing or prometheus to monitor API performance.
- 🧪 Integrate into a Larger ML System Combine this with a frontend or queue-based architecture (RabbitMQ/Kafka) for distributed inference.
- ⚡ Try Other tch Models Explore the other architectures under tch::vision - e.g., the mobilenet, vgg, and efficientnet modules - or custom models exported from PyTorch.
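To make the multi-model idea concrete, here is a minimal sketch of a registry that maps a model name (e.g., from a ?model=resnet18 query parameter) to a loaded TorchScript module. The model names and file paths are hypothetical:

```rust
use std::collections::HashMap;
use tch::{CModule, Device};

// Load every configured model once at startup; handlers can then look up
// the requested one by name instead of reloading it per request.
fn load_models(device: Device) -> anyhow::Result<HashMap<String, CModule>> {
    let mut models = HashMap::new();
    for (name, path) in [("resnet18", "resnet18.pt"), ("mobilenet", "mobilenet.pt")] {
        models.insert(name.to_string(), CModule::load_on_device(path, device)?);
    }
    Ok(models)
}
```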
✨ Final Thoughts

Rust + tch gives you the best of both worlds - machine learning power with system-level speed and safety. You've built a foundation that can easily evolve into a production-grade AI inference service.

"Performance, safety, and reliability — that's what Rust brings to AI."

You can find the full source code on our GitHub.