
Lessons Learned From Writing a Data Pipeline in Rust Using Tokio

From JOHNWICK

I needed to write a data pipeline in Rust for a machine learning project I am working on.

Some Context

I will not discuss the details of the data or model here, but here are some general parameters that defined the scope.

  • The source of the data is a third-party API with no rate limits; each training/inference record requires between 1 and 5 requests to the third-party API.
  • Each request returns data on a single potential training/inference record.
  • There is a dataset, created up front, that the final data records are joined on.
  • Records are filtered after the initial set of requests if certain criteria are met.
  • Given the number of requests needed for each record, the pipeline is very I/O heavy.
  • Significant CPU work is needed to perform aggregations.

Given these constraints, the main thing to consider is balancing heavy I/O against somewhat expensive CPU work. With tokio, the general recommendation I had read is not to mix these two kinds of work, so this was certainly something I considered. For example, I have heard of projects that use multiple tokio runtimes to separate these concerns, though I did not use this approach. Instead, tokio lets you run CPU-bound work on a dedicated pool of blocking threads, separate from the worker threads that drive non-blocking I/O. This is where I settled.

I used tokio, reqwest, futures, and of course serde as the core components of this multi-step pipeline, with channels propagating data forward. The destination is a number of Parquet files in S3, which are compressed using an AWS Glue job before model training occurs. This was the first real project I built using tokio. I learned a lot about the patterns that worked well throughout this process, and I want to pass them on.

The inference pipeline is very similar, but it uses a more real-time data API.

Use Blocking Threads for CPU-Bound Work

Regular tokio tasks are not well suited to CPU-bound work, because they are designed for non-blocking I/O. Tokio schedules tasks cooperatively: a task only hands control back to the runtime at an .await point. A CPU-bound task therefore never voluntarily yields, and it keeps the other tasks on the same worker thread from making progress. Fortunately, tokio also maintains a separate pool of blocking threads whose work is not expected to yield, and tokio::task::spawn_blocking runs a closure on that pool.

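// Regular async task (the default): runs on one of the runtime's worker
// threads. CPU-heavy work that never reaches an .await never yields, so it
// starves the other tasks scheduled on that worker.
let result = tokio::task::spawn(async move {
  // cpu heavy work
}).await?;

// Blocking task: runs on tokio's separate blocking thread pool, leaving the
// worker threads free to drive I/O.
let result = tokio::task::spawn_blocking(move || {
  // cpu heavy work
}).await?;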

There are different approaches to using blocking threads. One is to use rayon parallel iterators. In this approach, data items are queued in a blocking thread. When either a batch size is reached or a timeout expires, the items are submitted to a rayon parallel iterator for transformation or aggregation work. This is similar to the inference batching technique used in high-throughput ML services. The caveat is that channel operations must use the blocking APIs on Senders and Receivers (blocking_send and blocking_recv), because blocking threads run synchronously. Rayon should also be configured to use fewer threads than its default, which is the number of cores on the machine, so that its workers do not compete with the tokio runtime for every core. You can do this by setting the RAYON_NUM_THREADS environment variable to the number of threads you want to allocate to rayon workers. Here is some pseudocode describing the pattern.

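use rayon::prelude::*;
use tokio::sync::mpsc;

struct Data {
  // data for processing
}

struct TransformedData {
  // post transform data
}

// CHANNEL_SIZE and BATCH_SIZE are constants you tune for your workload
let (tx, mut rx) = mpsc::channel::<Data>(CHANNEL_SIZE);
let (tx2, mut rx2) = mpsc::channel::<TransformedData>(CHANNEL_SIZE);

let task = tokio::task::spawn_blocking(move || {
  let mut buffer: Vec<Data> = Vec::with_capacity(BATCH_SIZE);

  // Transform a batch in parallel on the rayon pool and forward the results;
  // returns false once the downstream receiver has been dropped.
  // (transform: fn(&Data) -> TransformedData is your per-item function)
  let flush = |buffer: &mut Vec<Data>| -> bool {
    let transformed: Vec<TransformedData> =
        buffer.par_iter().map(transform).collect();
    buffer.clear();
    transformed.into_iter().all(|item| tx2.blocking_send(item).is_ok())
  };

  // blocking_recv returns None once every Sender for rx has been dropped
  while let Some(item) = rx.blocking_recv() {
    buffer.push(item);
    if buffer.len() == BATCH_SIZE && !flush(&mut buffer) {
      return;
    }
  }

  // clean up: process the final partial batch
  if !buffer.is_empty() {
    flush(&mut buffer);
  }
});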
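The pseudocode above only flushes when a batch is full. For the timeout half of the pattern, note that tokio's blocking_recv has no timeout variant. The minimal sketch below therefore swaps in std::sync::mpsc, whose Receiver::recv_timeout does have one. It assumes items reach the blocking thread over a std channel; next_batch, max_items and max_wait are hypothetical names.

```rust
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::{Duration, Instant};

/// Collects up to `max_items` items, flushing early once `max_wait` has passed
/// since the first item arrived. Returns None once every Sender is gone and
/// nothing is left to collect.
pub fn next_batch<T>(rx: &Receiver<T>, max_items: usize, max_wait: Duration) -> Option<Vec<T>> {
    // Block until the first item arrives; an error means all senders dropped.
    let first = rx.recv().ok()?;
    // Start the clock at the first item so we never flush empty batches.
    let deadline = Instant::now() + max_wait;
    let mut batch = Vec::with_capacity(max_items.max(1));
    batch.push(first);
    while batch.len() < max_items {
        let remaining = deadline.saturating_duration_since(Instant::now());
        match rx.recv_timeout(remaining) {
            Ok(item) => batch.push(item),
            // Timed out, or all senders dropped: flush what we have.
            Err(RecvTimeoutError::Timeout) | Err(RecvTimeoutError::Disconnected) => break,
        }
    }
    Some(batch)
}
```

Inside the blocking thread, the loop then becomes `while let Some(batch) = next_batch(&rx, BATCH_SIZE, max_wait) { ... }`, with the rayon transform applied to each batch.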


Another variation of this pattern that I have found effective is a long-running blocking thread: data for aggregation is passed into and out of this dedicated CPU-bound worker via channels. Here is some pseudocode that describes the pattern.


use serde::Deserialize;
use tokio::sync::mpsc;

struct RequestParameters {
  // parameters for data fetch
}

#[derive(Deserialize)]
struct ApiResponse {
  // third party api shape
}

struct AggregationResults {
  // data after aggregation is performed
}

#[tokio::main]
async fn main() {
  let client = reqwest::Client::new();
  let (req_param_send, mut req_param_recv) =
      mpsc::channel::<RequestParameters>(10);
  let (response_send, mut response_recv) =
      mpsc::channel::<ApiResponse>(10);
  let (aggregation_send, mut aggregation_recv) =
      mpsc::channel::<AggregationResults>(10);

  // I/O stage: async task that fetches data for each set of parameters
  let request_task = tokio::spawn(async move {
      while let Some(req_params) = req_param_recv.recv().await {
          let url = // construct request URL from req_params
          let response = client.get(url).send().await.unwrap();
          let response_data = response.json::<ApiResponse>().await.unwrap();
          response_send.send(response_data).await.unwrap();
      }
      // response_send is dropped here, which ends the aggregation loop
  });

  // CPU stage: long-running blocking thread that aggregates responses
  let agg_thread = tokio::task::spawn_blocking(move || {
      while let Some(response) = response_recv.blocking_recv() {
          let agg = // perform aggregation on response
          aggregation_send.blocking_send(agg).unwrap();
      }
      // aggregation_send is dropped here, which ends the collection loop
  });

  // Feed parameters from a separate task: sending them all from main before
  // draining results would deadlock once there are more parameters than the
  // three bounded channels can hold.
  let producer = tokio::spawn(async move {
      let mut req_params: Vec<RequestParameters> = Vec::new();
      // populate request params
      for req in req_params {
          req_param_send.send(req).await.unwrap();
      }
      // req_param_send is dropped here, which lets the request task finish
  });

  let mut agg_results: Vec<AggregationResults> = Vec::new();
  while let Some(agg_result) = aggregation_recv.recv().await {
      agg_results.push(agg_result);
  }

  producer.await.unwrap();
  request_task.await.unwrap();
  agg_thread.await.unwrap();
}

Use a futures Buffer When Making a Large Number of HTTP Requests

My first attempt was to iterate through a Vec of reqwest RequestBuilders and spawn a task to send each one. I saved the JoinHandles in a Vec, then iterated over that Vec again to await each handle.

Something like this:

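let requests: Vec<RequestBuilder> = vec![<built up requests>];

// spawn one task per request, all at once
let handles: Vec<JoinHandle<reqwest::Result<Response>>> = requests
    .into_iter()
    .map(|r| tokio::spawn(r.send()))
    .collect();

let mut results = Vec::new();

// await the handles strictly in submission order
for handle in handles {
    match handle.await {
        Ok(Ok(response)) => {
            // do stuff with the data
            results.push(response);
        }
        Ok(Err(e)) => {
            // handle request error
        }
        Err(e) => {
            // handle join error
        }
    }
}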

This approach is fragile. If all goes well and the number of requests is not too large, it can be fast. However, it can also overwhelm resources by spawning a task for every request at once, resulting in significant resource contention. Also, because the handles are awaited in order, one slow request holds up processing of every result behind it, even when order does not matter.

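A more efficient alternative is to turn the requests into a futures stream and buffer it. buffer_unordered keeps at most BUFFER_SIZE requests in flight and yields each result as soon as it completes, so a slow request no longer holds up the others.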

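use futures::stream::{self, StreamExt};
use reqwest::RequestBuilder;

let requests: Vec<RequestBuilder> = vec![<built up requests>];

// at most BUFFER_SIZE requests are in flight at any time;
// results arrive in completion order, not submission order
let mut stream = stream::iter(requests.into_iter().map(|r| r.send()))
    .buffer_unordered(BUFFER_SIZE);

while let Some(res) = stream.next().await {
  // res is a reqwest::Result<Response>; do stuff with result
}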

Reqwest Client is Arc Internally

This took me a while to figure out. Because the data came from a single vendor service, HTTP connections to it can be kept alive and reused from the client's connection pool. This lowers overhead and eliminates the need to constantly reconnect. The reqwest crate lets you set a custom configuration using the builder pattern, something like this:

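use std::time::Duration;
use reqwest::ClientBuilder;

let client = ClientBuilder::new()
            // keep up to 60 idle connections per host around for reuse
            .pool_max_idle_per_host(60)
            // give up if a connection cannot be established within 1 s
            .connect_timeout(Duration::from_secs(1))
            // time out if a read from the connection stalls for 500 ms
            .read_timeout(Duration::from_millis(500))
            .build()
            .unwrap();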

Initially, I wrapped the client in an Arc so it could be shared safely between threads. Only later, when I read the source code for Client, did I realize that the reqwest Client already uses an Arc internally. Here is the struct definition for Client.

/// An asynchronous `Client` to make Requests with.
///
/// The Client has various configuration values to tweak, but the defaults
/// are set to what is usually the most commonly desired value. To configure a
/// `Client`, use `Client::builder()`.
///
/// The `Client` holds a connection pool internally, so it is advised that
/// you create one and **reuse** it.
///
/// You do **not** have to wrap the `Client` in an [`Rc`] or [`Arc`] to **reuse** it,
/// because it already uses an [`Arc`] internally.
///
/// [`Rc`]: std::rc::Rc

#[derive(Clone)]
pub struct Client {
    inner: Arc<ClientRef>,
}
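The same handle-around-an-Arc pattern is easy to reuse in your own types. Below is a minimal std-only sketch. ApiHandle, ApiHandleInner and the request counter are hypothetical, standing in for something like reqwest's private ClientRef. Cloning the handle copies a pointer and bumps the Arc's reference count, and every clone sees the same inner state.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Hypothetical shared state behind the handle (the role ClientRef plays in reqwest).
#[derive(Debug)]
struct ApiHandleInner {
    base_url: String,
    requests_sent: AtomicUsize,
}

/// Cheap-to-clone handle: every clone points at the same ApiHandleInner.
#[derive(Clone, Debug)]
pub struct ApiHandle {
    inner: Arc<ApiHandleInner>,
}

impl ApiHandle {
    pub fn new(base_url: &str) -> Self {
        ApiHandle {
            inner: Arc::new(ApiHandleInner {
                base_url: base_url.to_string(),
                requests_sent: AtomicUsize::new(0),
            }),
        }
    }

    pub fn base_url(&self) -> &str {
        &self.inner.base_url
    }

    /// Takes &self: shared mutation goes through the atomic, so no lock is needed.
    pub fn record_request(&self) {
        self.inner.requests_sent.fetch_add(1, Ordering::Relaxed);
    }

    pub fn requests_sent(&self) -> usize {
        self.inner.requests_sent.load(Ordering::Relaxed)
    }
}
```

Each thread or task gets its own clone to move in, and no outer Arc is needed.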

It pays to read the documentation closely. Since the client already uses an Arc internally, cloning it before moving the clone into a tokio task raises no safety concern, and the only cost is a reference-count increment. I've come to really like this pattern: a wrapper type with an internal smart pointer to the core data structure.

Use Channel Sizes to Apply Back Pressure

Channels are a great way in multi-threaded Rust to pass data from one thread or task to another without writing any locking code yourself. Tokio provides several channel types: oneshot, mpsc (multiple producer, single consumer), broadcast, and watch. The async_channel crate also has a multiple producer, multiple consumer channel that delivers each message to exactly one of its consumers. Channels make data sharing across threads much simpler than shared memory. When an item is sent through a channel, ownership moves to the receiver. In a multi-step data processing pipeline, this is a natural way to promote data items to the next processing step.
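Here is a minimal std-only sketch of that hand-off. It uses threads and std::sync::mpsc::sync_channel as stand-ins for tokio tasks and channels; RawRecord, Summary and run_pipeline are hypothetical. Each send moves a record into the next stage, so no two stages ever share a record.

```rust
use std::sync::mpsc;
use std::thread;

// Hypothetical record types for a two-stage pipeline.
#[derive(Debug)]
pub struct RawRecord {
    pub id: u32,
    pub values: Vec<f64>,
}

#[derive(Debug, PartialEq)]
pub struct Summary {
    pub id: u32,
    pub total: f64,
}

/// Source -> aggregate -> collect, each stage connected by a bounded channel.
pub fn run_pipeline(input: Vec<RawRecord>) -> Vec<Summary> {
    let (raw_tx, raw_rx) = mpsc::sync_channel::<RawRecord>(8);
    let (sum_tx, sum_rx) = mpsc::sync_channel::<Summary>(8);

    // Stage 1: owns the input and hands records off one at a time.
    let source = thread::spawn(move || {
        for record in input {
            raw_tx.send(record).unwrap(); // ownership moves into the channel
        }
        // raw_tx is dropped here, telling stage 2 that input is exhausted
    });

    // Stage 2: consumes RawRecords and produces Summaries.
    let aggregate = thread::spawn(move || {
        for record in raw_rx {
            let total: f64 = record.values.iter().sum();
            sum_tx.send(Summary { id: record.id, total }).unwrap();
        }
        // sum_tx is dropped here, ending the collection loop below
    });

    // Final stage: the caller's thread collects results until sum_tx is gone.
    let results: Vec<Summary> = sum_rx.iter().collect();
    source.join().unwrap();
    aggregate.join().unwrap();
    results
}
```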

Bounded channels require a buffer size. Internally, tokio's bounded mpsc channel tracks that capacity with a semaphore holding one permit per slot. As a result, no more than buffer-size items can be queued in the channel at any given time, which provides implicit back pressure. Right-sizing these buffers is critical. When a buffer is too large, the pipeline's memory footprint can grow rapidly. When it is too small, upstream stages spend their time waiting on downstream ones.

When a channel buffer is full, pushing more data into it either waits or returns an error, depending on the API you use. Sender::send().await suspends the task until there is space in the channel (blocking_send blocks the whole thread instead), whereas try_send tries to send immediately and returns an error if the channel is full. The channel buffer size therefore controls how many records are in flight between two stages at any given time.
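You can see the same back pressure with the standard library's bounded channel, std::sync::mpsc::sync_channel. Its try_send returns TrySendError::Full when the buffer is full, much like tokio's Sender::try_send. Here is a minimal sketch, with a hypothetical accepted_before_full helper and no consumer draining the channel:

```rust
use std::sync::mpsc::{sync_channel, TrySendError};

/// Pushes into a bounded channel that nobody reads and reports how many
/// try_send calls succeeded before the channel pushed back.
pub fn accepted_before_full(capacity: usize, attempts: usize) -> usize {
    // Keep the receiver alive (bound to a name, not `_`) so sends are not
    // rejected as Disconnected.
    let (tx, _rx) = sync_channel::<usize>(capacity);
    let mut accepted = 0;
    for i in 0..attempts {
        match tx.try_send(i) {
            Ok(()) => accepted += 1,
            // Buffer full: this is the back pressure signal.
            Err(TrySendError::Full(_)) => break,
            Err(TrySendError::Disconnected(_)) => unreachable!("receiver is still alive"),
        }
    }
    accepted
}
```

Swapping try_send for send would instead block the producer at that point until a consumer catches up. That is usually the behavior you want between pipeline stages.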

Use a Semaphore to Limit the Number of In-Flight Requests

Something I found useful for limiting the number of in-flight requests is to wrap the client in a struct with an internal semaphore and route all requests through that struct. Given that the client is already Arc internally, only the Semaphore needs to be wrapped in an Arc for the wrapper to be cheap to clone and share across tasks. Here is the data structure I used. The number of permits the semaphore is created with caps the number of in-flight requests. You can also wrap a request in tokio::time::timeout to cancel it if it hangs, so its permit is released and other tasks can acquire it (not shown here). Otherwise, hung requests can hold their permits indefinitely and stall the pipeline.

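use std::sync::Arc;

use reqwest::{Client, ClientBuilder, Response, Url};
use tokio::sync::{AcquireError, Semaphore};

// Errors HttpLimitedClient::get can return
#[derive(Debug)]
pub enum HttpRequestError {
    SemaphoreClosed(AcquireError),
    Request(reqwest::Error),
}

impl From<AcquireError> for HttpRequestError {
    fn from(e: AcquireError) -> Self {
        HttpRequestError::SemaphoreClosed(e)
    }
}

impl From<reqwest::Error> for HttpRequestError {
    fn from(e: reqwest::Error) -> Self {
        HttpRequestError::Request(e)
    }
}

#[derive(Clone, Debug)]
pub struct HttpLimitedClient {
    client: Client,            // cheap to clone: already Arc internally
    semaphore: Arc<Semaphore>, // shared by every clone of the wrapper
}

impl HttpLimitedClient {
    pub fn new() -> HttpLimitedClient {
        let client = ClientBuilder::new()
            .pool_max_idle_per_host(60)
            .connect_timeout(std::time::Duration::from_secs(1))
            .read_timeout(std::time::Duration::from_millis(500))
            .build()
            .unwrap();

        // at most 32 requests in flight across all clones
        let semaphore = Semaphore::new(32);
        HttpLimitedClient {
            client,
            semaphore: Arc::new(semaphore),
        }
    }

    pub async fn get(&self, url: Url) -> Result<Response, HttpRequestError> {
        // wait for a permit; acquire_owned consumes an Arc<Semaphore> clone
        let permit = self.semaphore.clone().acquire_owned().await?;
        // send() resolves once the response headers have arrived
        let resp = self.client.get(url).send().await?;
        // release the permit (an early return via `?` also drops it)
        drop(permit);
        Ok(resp)
    }
}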
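The standard library has no semaphore, but the permit mechanics are small enough to sketch. Below is a hypothetical, std-only BlockingSemaphore built from Mutex and Condvar, for illustration only; in async code, tokio::sync::Semaphore is the tool to use. The returned Permit gives its slot back when dropped, which is why an early return through ? cannot leak a permit.

```rust
use std::sync::{Condvar, Mutex};

/// Minimal counting semaphore (hypothetical) that blocks the calling thread.
pub struct BlockingSemaphore {
    permits: Mutex<usize>,
    available: Condvar,
}

/// RAII permit: returns its slot to the semaphore when dropped.
pub struct Permit<'a> {
    sem: &'a BlockingSemaphore,
}

impl BlockingSemaphore {
    pub fn new(permits: usize) -> Self {
        BlockingSemaphore {
            permits: Mutex::new(permits),
            available: Condvar::new(),
        }
    }

    /// Blocks until a permit is free, then takes it.
    pub fn acquire(&self) -> Permit<'_> {
        let mut n = self.permits.lock().unwrap();
        // Loop guards against spurious wakeups.
        while *n == 0 {
            n = self.available.wait(n).unwrap();
        }
        *n -= 1;
        Permit { sem: self }
    }
}

impl Drop for Permit<'_> {
    fn drop(&mut self) {
        *self.sem.permits.lock().unwrap() += 1;
        // Wake one waiter, if any, now that a permit is available.
        self.sem.available.notify_one();
    }
}
```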

Limit Shared State

Rust's concurrency model forces you to avoid data races, but this safety comes with some overhead. The compiler must be able to prove statically that no data races are introduced, so safely sharing mutable state across threads requires synchronization such as Mutex or RwLock, and quite a bit of locking. Generally, I find passing data between threads via channels to be a far more elegant approach. Concurrent data structures such as DashMap reduce lock contention by locking individual shards rather than the entire table, but nonetheless, the less shared state the better. This holds for performance, complexity, and safety alike, and it is a rather natural pattern for this kind of work.
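As a small illustration of preferring message passing to shared state, here is a std-only sketch. Each worker builds a private map and sends it to a single owner that merges them, instead of every worker locking one Arc<Mutex<HashMap>>. Word counting is a hypothetical stand-in for the aggregation step, and count_words is an illustrative name.

```rust
use std::collections::HashMap;
use std::sync::mpsc;
use std::thread;

/// Counts words across chunks with no shared, locked map: workers send their
/// partial counts to the caller, which is the only owner of the totals.
pub fn count_words(chunks: Vec<Vec<String>>) -> HashMap<String, usize> {
    let (tx, rx) = mpsc::channel::<HashMap<String, usize>>();
    let mut workers = Vec::new();
    for chunk in chunks {
        let tx = tx.clone();
        workers.push(thread::spawn(move || {
            // Purely local state: no locks while counting.
            let mut local: HashMap<String, usize> = HashMap::new();
            for word in chunk {
                *local.entry(word).or_insert(0) += 1;
            }
            tx.send(local).unwrap(); // hand ownership of the partial map over
        }));
    }
    // Drop our own sender so the loop below ends when the workers finish.
    drop(tx);

    let mut totals = HashMap::new();
    for partial in rx {
        for (word, n) in partial {
            *totals.entry(word).or_insert(0) += n;
        }
    }
    for w in workers {
        w.join().unwrap();
    }
    totals
}
```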

Consider Latency vs Throughput

The tradeoff for this kind of process is the latency of a single record through the entire pipeline versus the throughput of records completed per unit of time. Initially I tried to optimize the latency of a single record. Performance across the entire dataset improved drastically once I instead optimized for the number of records in flight and thought in terms of records finished per unit of time.
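Little's law from queueing theory offers one way to see why this works (my framing, not something I measured). For a stable pipeline in steady state, the average number of records in flight L, the average throughput λ, and the average end-to-end latency W per record satisfy L = λW. Rearranged for throughput, with purely hypothetical numbers:

```latex
\lambda = \frac{L}{W}
\qquad
L = 1,\; W = 2\,\text{s} \;\Rightarrow\; \lambda = 0.5\ \text{records/s}
\qquad
L = 100,\; W = 4\,\text{s} \;\Rightarrow\; \lambda = 25\ \text{records/s}
```

In the second case each record takes twice as long, yet the pipeline finishes 50 times as many records per second. This holds as long as the added concurrency does not raise latency proportionally.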

Closing

These are some of the things I learned in this project, and I hope they help. If you spot a mistake or disagree with anything I say, please share it in the comments of this post. I built this project a while ago now, and I have certainly learned some things since.

Read the full article here: https://medium.com/@kilianhammersmith/lessons-learned-from-writing-a-data-pipeline-in-rust-using-tokio-1716f49ce970