From 41a22ba97e9a06d2d72a53490991acb8d32c20f8 Mon Sep 17 00:00:00 2001
From: Bryan Newbold
Date: Wed, 10 Feb 2021 19:41:44 -0800
Subject: batch download parallelism

---
 fatcat-cli/src/download.rs | 120 ++++++++++++++++++++++++++++++++++++---------
 fatcat-cli/src/main.rs     |  10 +++-
 2 files changed, 104 insertions(+), 26 deletions(-)

diff --git a/fatcat-cli/src/download.rs b/fatcat-cli/src/download.rs
index 5bec0d4..ada9834 100644
--- a/fatcat-cli/src/download.rs
+++ b/fatcat-cli/src/download.rs
@@ -1,7 +1,7 @@
 use anyhow::{anyhow, Context, Result};
 use fatcat_openapi::models::{FileEntity, ReleaseEntity};
 use indicatif::{ProgressBar, ProgressStyle};
-use log::info;
+use log::{info, error};
 use reqwest::header::USER_AGENT;
 use std::fmt;
 use std::fs::File;
@@ -10,6 +10,8 @@ use std::path::PathBuf;
 use url::Url;
 use crate::{ApiModelIdent, Specifier};
 use sha1::Sha1;
+use std::thread;
+use crossbeam_channel as channel;
 
 #[derive(Debug, PartialEq, Clone)]
 pub enum DownloadStatus {
@@ -110,7 +112,7 @@ fn default_filename(specifier: &Specifier, fe: &FileEntity) -> Result<PathBuf> {
 }
 
 /// Attempts to download a file entity, including verifying checksum.
-pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option<PathBuf>) -> Result<DownloadStatus> {
+pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option<PathBuf>, show_progress: bool) -> Result<DownloadStatus> {
     let expected_sha1 = match &fe.sha1 {
         Some(v) => v,
         None => return Ok(DownloadStatus::FileMissingMetadata),
@@ -137,7 +139,7 @@ pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option
         ));
     };
 
-    let download_path = final_path.with_extension("download");
+    let download_path = final_path.with_extension("partial_download");
 
     // TODO: only archive.org URLs (?)
     let raw_url = match fe.urls.as_ref() {
@@ -194,13 +196,26 @@ pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option
         }
     }
 
+    let (write_result, out_sha1) = match show_progress {
+        true => {
+            let pb = ProgressBar::new(fe.size.unwrap() as u64);
+            pb.set_style(ProgressStyle::default_bar()
+                .template("{spinner:.green} [{elapsed_precise}] [{bar:40}] {bytes}/{total_bytes} ({eta})")
+                .progress_chars("#>-"));
+            let mut wrapped_file = Sha1WriteWrapper::new(pb.wrap_write(download_file));
+            let result = resp.copy_to(&mut wrapped_file);
+            let out_sha1 = wrapped_file.into_hexdigest();
+            (result, out_sha1)
+        },
+        false => {
+            let mut wrapped_file = Sha1WriteWrapper::new(download_file);
+            let result = resp.copy_to(&mut wrapped_file);
+            let out_sha1 = wrapped_file.into_hexdigest();
+            (result, out_sha1)
+        },
+    };
 
-    let pb = ProgressBar::new(fe.size.unwrap() as u64);
-    pb.set_style(ProgressStyle::default_bar()
-        .template("{spinner:.green} [{elapsed_precise}] [{bar:40}] {bytes}/{total_bytes} ({eta})")
-        .progress_chars("#>-"));
-    let mut wrapped_file = Sha1WriteWrapper::new(pb.wrap_write(download_file));
-    let out_size = match resp.copy_to(&mut wrapped_file) {
+    let out_size = match write_result {
         Ok(r) => r,
         Err(e) => {
             std::fs::remove_file(download_path)?;
@@ -208,7 +223,6 @@ pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option
         }
     };
 
-    let out_sha1 = wrapped_file.into_hexdigest();
     if &out_sha1 != expected_sha1 {
         std::fs::remove_file(download_path)?;
         return Ok(DownloadStatus::WrongHash);
@@ -225,7 +239,7 @@ pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option
     ))
 }
 
-pub fn download_release(re: &ReleaseEntity, output_path: Option<PathBuf>) -> Result<DownloadStatus> {
+pub fn download_release(re: &ReleaseEntity, output_path: Option<PathBuf>, show_progress: bool) -> Result<DownloadStatus> {
     let file_entities = match &re.files {
         None => {
             return Err(anyhow!(
@@ -237,7 +251,7 @@ pub fn download_release(re: &ReleaseEntity, output_path: Option<PathBuf>) -> Res
     let mut status = DownloadStatus::NoPublicFile;
     let specifier = re.specifier();
     for fe in file_entities {
-        status = download_file(&fe, &specifier, output_path.clone())?;
+        status = download_file(&fe, &specifier, output_path.clone(), show_progress)?;
         match status {
             DownloadStatus::Exists(_) | DownloadStatus::Downloaded(_) => break,
             _ => (),
@@ -247,18 +261,18 @@ pub fn download_release(re: &ReleaseEntity, output_path: Option<PathBuf>) -> Res
 }
 
 /// Tries either file or release
-fn download_entity(json_str: String, output_path: Option<PathBuf>) -> Result<DownloadStatus> {
+fn download_entity(json_str: String, output_path: Option<PathBuf>, show_progress: bool) -> Result<(DownloadStatus, String)> {
     let release_attempt = serde_json::from_str::<ReleaseEntity>(&json_str);
     if let Ok(re) = release_attempt {
         if re.ident.is_some() && (re.title.is_some() || re.files.is_some()) {
-            let status = download_release(&re, output_path)?;
-            println!(
+            let status = download_release(&re, output_path, show_progress)?;
+            let status_line = format!(
                 "release_{}\t{}\t{}",
                 re.ident.unwrap(),
                 status,
                 status.details().unwrap_or("".to_string())
             );
-            return Ok(status);
+            return Ok((status, status_line));
         };
     }
     let file_attempt =
@@ -267,14 +281,14 @@ fn download_entity(json_str: String, output_path: Option<PathBuf>) -> Result<Dow
         Ok(fe) => {
             if fe.ident.is_some() && fe.urls.is_some() {
                 let specifier = fe.specifier();
-                let status = download_file(&fe, &specifier, output_path)?;
-                println!(
+                let status = download_file(&fe, &specifier, output_path, show_progress)?;
+                let status_line = format!(
                     "file_{}\t{}\t{}",
                     fe.ident.unwrap(),
                     status,
                     status.details().unwrap_or("".to_string())
                 );
-                return Ok(status);
+                return Ok((status, status_line));
             } else {
                 Err(anyhow!("not a file entity (JSON)"))
             }
@@ -283,17 +297,73 @@ fn download_entity(json_str: String, output_path: Option<PathBuf>) -> Result<Dow
     }
 }
 
-pub fn download_batch(input_path: Option<PathBuf>, output_dir: Option<PathBuf>, limit: Option<u64>, _jobs: u64) -> Result<u64> {
-    // TODO: create worker pipeline using channels
+struct DownloadTask {
+    json_str: String,
+    output_path: Option<PathBuf>,
+    show_progress: bool,
+}
+
+fn loop_printer(
+    output_receiver: channel::Receiver<String>,
+    done_sender: channel::Sender<()>,
+) -> Result<()> {
+    for line in output_receiver {
+        println!("{}", line);
+    }
+    drop(done_sender);
+    Ok(())
+}
+
+fn loop_download_tasks(task_receiver: channel::Receiver<DownloadTask>, output_sender: channel::Sender<String>) {
+    let thread_result: Result<()> = (|| {
+        for task in task_receiver {
+            let (_, status_line) = download_entity(task.json_str, task.output_path, task.show_progress)?;
+            output_sender.send(status_line)?;
+        }
+        Ok(())
+    })();
+    if let Err(ref e) = thread_result {
+        error!("{}", e);
+    }
+    thread_result.unwrap()
+}
+
+pub fn download_batch(input_path: Option<PathBuf>, output_dir: Option<PathBuf>, limit: Option<u64>, jobs: u64) -> Result<u64> {
     let count = 0;
+
+    assert!(jobs > 0 && jobs <= 12);
+
+    let show_progress = jobs == 1;
+
+    let (task_sender, task_receiver) = channel::bounded(12);
+    let (output_sender, output_receiver) = channel::bounded(12);
+    let (done_sender, done_receiver) = channel::bounded(0);
+
+    for _ in 0..jobs {
+        let task_receiver = task_receiver.clone();
+        let output_sender = output_sender.clone();
+        thread::spawn(move || {
+            loop_download_tasks(task_receiver, output_sender);
+        });
+    }
+    drop(output_sender);
+
+    // Start printer thread
+    thread::spawn(move || loop_printer(output_receiver, done_sender).expect("printing to stdout"));
+
     match input_path {
         None => {
             let stdin = io::stdin();
             let stdin_lock = stdin.lock();
             let lines = stdin_lock.lines();
             for line in lines {
-                let json_str = line?;
-                download_entity(json_str, output_dir.clone())?;
+                let task = DownloadTask {
+                    json_str: line?,
+                    output_path: output_dir.clone(),
+                    show_progress,
+                };
+                task_sender.send(task)?;
+                count += 1;
                 if let Some(limit) = limit {
                     if count >= limit {
                         break;
@@ -307,7 +377,8 @@ pub fn download_batch(input_path: Option<PathBuf>, output_dir: Option<PathBuf>,
             let lines = buffered.lines();
             for line in lines {
                 let json_str = line?;
-                download_entity(json_str, output_dir.clone())?;
+                download_entity(json_str, output_dir.clone(), show_progress)?;
+                count += 1;
                 if let Some(limit) = limit {
                     if count >= limit {
                         break;
@@ -316,5 +387,6 @@ pub fn download_batch(input_path: Option<PathBuf>, output_dir: Option<PathBuf>,
             }
         }
     };
+    done_receiver.recv()?;
     Ok(count)
 }
diff --git a/fatcat-cli/src/main.rs b/fatcat-cli/src/main.rs
index ced35b4..3b96cce 100644
--- a/fatcat-cli/src/main.rs
+++ b/fatcat-cli/src/main.rs
@@ -485,6 +485,12 @@ fn run(opt: Opt) -> Result<()> {
                     return Err(anyhow!("output directory doesn't exist"));
                 }
             }
+            if jobs == 0 {
+                return Err(anyhow!("--jobs=0 not implemented"));
+            }
+            if jobs > 12 {
+                return Err(anyhow!("please don't download more than 12 parallel requests"));
+            }
             download_batch(input_path, output_dir, limit, jobs)?;
         }
         Command::Download {
@@ -515,7 +521,7 @@ fn run(opt: Opt) -> Result<()> {
                     resp => Err(anyhow!("{:?}", resp))
                         .with_context(|| format!("API GET failed: {:?}", ident)),
                 }?;
-                download_release(&release_entity, output_path)
+                download_release(&release_entity, output_path, true)
             }
             Specifier::File(ident) => {
                 let result = api_client.rt.block_on(api_client.api.get_file(
@@ -528,7 +534,7 @@ fn run(opt: Opt) -> Result<()> {
                     resp => Err(anyhow!("{:?}", resp))
                         .with_context(|| format!("API GET failed: {:?}", ident)),
                 }?;
-                download_file(&file_entity, &file_entity.specifier(), output_path)
+                download_file(&file_entity, &file_entity.specifier(), output_path, true)
             }
             other => Err(anyhow!("Don't know how to download: {:?}", other)),
         }?;
--
cgit v1.2.3
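
The worker pipeline that the new download_batch() assembles can be hard to follow through the diff noise, so here is a minimal, self-contained sketch of the same fan-out/fan-in shape, written against the crossbeam-channel API the patch already imports. The Task struct, the example URLs, and the job/queue sizes below are invented for illustration and are not fatcat-cli code; the structure is the same, though: a bounded task channel feeding a fixed pool of worker threads, a bounded result channel draining into a single printer thread so output lines never interleave, and a zero-capacity "done" channel whose disconnection tells the main thread that all output has been flushed.

use crossbeam_channel as channel;
use std::thread;

// Hypothetical work item; the real patch carries a JSON entity string per task.
struct Task {
    url: String,
}

fn worker(tasks: channel::Receiver<Task>, results: channel::Sender<String>) {
    // Iteration ends once every Sender<Task> has been dropped and the queue is empty.
    for task in tasks {
        let line = format!("fetched\t{}", task.url);
        if results.send(line).is_err() {
            break; // printer side already shut down
        }
    }
}

fn main() {
    let jobs = 4;

    // Bounded queues provide backpressure: the producer blocks when workers fall behind.
    let (task_sender, task_receiver) = channel::bounded::<Task>(12);
    let (result_sender, result_receiver) = channel::bounded::<String>(12);
    // Zero-capacity channel used purely as a completion signal.
    let (done_sender, done_receiver) = channel::bounded::<()>(0);

    for _ in 0..jobs {
        let task_receiver = task_receiver.clone();
        let result_sender = result_sender.clone();
        thread::spawn(move || worker(task_receiver, result_sender));
    }
    // Drop the main thread's copy so the result channel closes once all workers exit.
    drop(result_sender);

    // Single printer thread keeps stdout output line-atomic across workers.
    thread::spawn(move || {
        for line in result_receiver {
            println!("{}", line);
        }
        // Dropping the sender disconnects the "done" channel; that is the signal.
        drop(done_sender);
    });

    for n in 0..100 {
        let task = Task {
            url: format!("https://example.org/item/{}", n),
        };
        task_sender.send(task).expect("a worker is still running");
    }
    // Close the task channel so the workers, and then the printer, wind down.
    drop(task_sender);

    // recv() returns Err only after the printer drops its end of the channel,
    // i.e. after every result line has been printed.
    let _ = done_receiver.recv();
}

The bounded(12) queues keep the input reader from buffering an entire batch in memory when workers fall behind, and the single printer thread preserves the one-status-line-per-entity TSV output. The patch enables the indicatif progress bar only when jobs == 1 (show_progress = jobs == 1), presumably because several concurrent bars would contend for the same terminal.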