-rw-r--r--  Cargo.lock                  |  22
-rw-r--r--  fatcat-cli/Cargo.toml       |   1
-rw-r--r--  fatcat-cli/src/download.rs  | 120
-rw-r--r--  fatcat-cli/src/main.rs      |  10
4 files changed, 127 insertions, 26 deletions
diff --git a/Cargo.lock b/Cargo.lock
index da300b7..6a7f331 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -235,6 +235,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3a71ab494c0b5b860bdc8407ae08978052417070c2ced38573a9157ad75b8ac"
[[package]]
+name = "crossbeam-channel"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dca26ee1f8d361640700bde38b2c37d8c22b3ce2d360e1fc1c74ea4b0aa7d775"
+dependencies = [
+ "cfg-if 1.0.0",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "02d96d1e189ef58269ebe5b97953da3274d83a93af647c2ddd6f9dab28cedb8d"
+dependencies = [
+ "autocfg",
+ "cfg-if 1.0.0",
+ "lazy_static",
+]
+
+[[package]]
name = "data-encoding"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -301,6 +322,7 @@ dependencies = [
"atty",
"chrono-humanize",
"colored_json",
+ "crossbeam-channel",
"data-encoding",
"env_logger",
"fatcat-openapi",
diff --git a/fatcat-cli/Cargo.toml b/fatcat-cli/Cargo.toml
index 1cb7aab..01758ed 100644
--- a/fatcat-cli/Cargo.toml
+++ b/fatcat-cli/Cargo.toml
@@ -39,6 +39,7 @@ tempfile = "3"
indicatif = "0.15"
url = "*"
sha1 = { version = "*", features = ["std"] }
+crossbeam-channel = "0.5"
[dev-dependencies]
assert_cmd = "1"
diff --git a/fatcat-cli/src/download.rs b/fatcat-cli/src/download.rs
index 5bec0d4..ada9834 100644
--- a/fatcat-cli/src/download.rs
+++ b/fatcat-cli/src/download.rs
@@ -1,7 +1,7 @@
use anyhow::{anyhow, Context, Result};
use fatcat_openapi::models::{FileEntity, ReleaseEntity};
use indicatif::{ProgressBar, ProgressStyle};
-use log::info;
+use log::{info, error};
use reqwest::header::USER_AGENT;
use std::fmt;
use std::fs::File;
@@ -10,6 +10,8 @@ use std::path::PathBuf;
use url::Url;
use crate::{ApiModelIdent, Specifier};
use sha1::Sha1;
+use std::thread;
+use crossbeam_channel as channel;
#[derive(Debug, PartialEq, Clone)]
pub enum DownloadStatus {
@@ -110,7 +112,7 @@ fn default_filename(specifier: &Specifier, fe: &FileEntity) -> Result<PathBuf> {
}
/// Attempts to download a file entity, including verifying checksum.
-pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option<PathBuf>) -> Result<DownloadStatus> {
+pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option<PathBuf>, show_progress: bool) -> Result<DownloadStatus> {
let expected_sha1 = match &fe.sha1 {
Some(v) => v,
None => return Ok(DownloadStatus::FileMissingMetadata),
@@ -137,7 +139,7 @@ pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option
));
};
- let download_path = final_path.with_extension("download");
+ let download_path = final_path.with_extension("partial_download");
// TODO: only archive.org URLs (?)
let raw_url = match fe.urls.as_ref() {
@@ -194,13 +196,26 @@ pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option
}
}
+ let (write_result, out_sha1) = match show_progress {
+ true => {
+ let pb = ProgressBar::new(fe.size.unwrap() as u64);
+ pb.set_style(ProgressStyle::default_bar()
+ .template("{spinner:.green} [{elapsed_precise}] [{bar:40}] {bytes}/{total_bytes} ({eta})")
+ .progress_chars("#>-"));
+ let mut wrapped_file = Sha1WriteWrapper::new(pb.wrap_write(download_file));
+ let result = resp.copy_to(&mut wrapped_file);
+ let out_sha1 = wrapped_file.into_hexdigest();
+ (result, out_sha1)
+ },
+ false => {
+ let mut wrapped_file = Sha1WriteWrapper::new(download_file);
+ let result = resp.copy_to(&mut wrapped_file);
+ let out_sha1 = wrapped_file.into_hexdigest();
+ (result, out_sha1)
+ },
+ };
- let pb = ProgressBar::new(fe.size.unwrap() as u64);
- pb.set_style(ProgressStyle::default_bar()
- .template("{spinner:.green} [{elapsed_precise}] [{bar:40}] {bytes}/{total_bytes} ({eta})")
- .progress_chars("#>-"));
- let mut wrapped_file = Sha1WriteWrapper::new(pb.wrap_write(download_file));
- let out_size = match resp.copy_to(&mut wrapped_file) {
+ let out_size = match write_result {
Ok(r) => r,
Err(e) => {
std::fs::remove_file(download_path)?;
@@ -208,7 +223,6 @@ pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option
}
};
- let out_sha1 = wrapped_file.into_hexdigest();
if &out_sha1 != expected_sha1 {
std::fs::remove_file(download_path)?;
return Ok(DownloadStatus::WrongHash);
@@ -225,7 +239,7 @@ pub fn download_file(fe: &FileEntity, specifier: &Specifier, output_path: Option
))
}
-pub fn download_release(re: &ReleaseEntity, output_path: Option<PathBuf>) -> Result<DownloadStatus> {
+pub fn download_release(re: &ReleaseEntity, output_path: Option<PathBuf>, show_progress: bool) -> Result<DownloadStatus> {
let file_entities = match &re.files {
None => {
return Err(anyhow!(
@@ -237,7 +251,7 @@ pub fn download_release(re: &ReleaseEntity, output_path: Option<PathBuf>) -> Res
let mut status = DownloadStatus::NoPublicFile;
let specifier = re.specifier();
for fe in file_entities {
- status = download_file(&fe, &specifier, output_path.clone())?;
+ status = download_file(&fe, &specifier, output_path.clone(), show_progress)?;
match status {
DownloadStatus::Exists(_) | DownloadStatus::Downloaded(_) => break,
_ => (),
@@ -247,18 +261,18 @@ pub fn download_release(re: &ReleaseEntity, output_path: Option<PathBuf>) -> Res
}
/// Tries either file or release
-fn download_entity(json_str: String, output_path: Option<PathBuf>) -> Result<DownloadStatus> {
+fn download_entity(json_str: String, output_path: Option<PathBuf>, show_progress: bool) -> Result<(DownloadStatus, String)> {
let release_attempt = serde_json::from_str::<ReleaseEntity>(&json_str);
if let Ok(re) = release_attempt {
if re.ident.is_some() && (re.title.is_some() || re.files.is_some()) {
- let status = download_release(&re, output_path)?;
- println!(
+ let status = download_release(&re, output_path, show_progress)?;
+ let status_line = format!(
"release_{}\t{}\t{}",
re.ident.unwrap(),
status,
status.details().unwrap_or("".to_string())
);
- return Ok(status);
+ return Ok((status, status_line));
};
}
let file_attempt =
@@ -267,14 +281,14 @@ fn download_entity(json_str: String, output_path: Option<PathBuf>) -> Result<Dow
Ok(fe) => {
if fe.ident.is_some() && fe.urls.is_some() {
let specifier = fe.specifier();
- let status = download_file(&fe, &specifier, output_path)?;
- println!(
+ let status = download_file(&fe, &specifier, output_path, show_progress)?;
+ let status_line = format!(
"file_{}\t{}\t{}",
fe.ident.unwrap(),
status,
status.details().unwrap_or("".to_string())
);
- return Ok(status);
+ return Ok((status, status_line));
} else {
Err(anyhow!("not a file entity (JSON)"))
}
@@ -283,17 +297,73 @@ fn download_entity(json_str: String, output_path: Option<PathBuf>) -> Result<Dow
}
}
-pub fn download_batch(input_path: Option<PathBuf>, output_dir: Option<PathBuf>, limit: Option<u64>, _jobs: u64) -> Result<u64> {
- // TODO: create worker pipeline using channels
+struct DownloadTask {
+ json_str: String,
+ output_path: Option<PathBuf>,
+ show_progress: bool,
+}
+
+fn loop_printer(
+ output_receiver: channel::Receiver<String>,
+ done_sender: channel::Sender<()>,
+) -> Result<()> {
+ for line in output_receiver {
+ println!("{}", line);
+ }
+ drop(done_sender);
+ Ok(())
+}
+
+fn loop_download_tasks(task_receiver: channel::Receiver<DownloadTask>, output_sender: channel::Sender<String>) {
+ let thread_result: Result<()> = (|| {
+ for task in task_receiver {
+ let (_, status_line) = download_entity(task.json_str, task.output_path, task.show_progress)?;
+ output_sender.send(status_line)?;
+ }
+ Ok(())
+ })();
+ if let Err(ref e) = thread_result {
+ error!("{}", e);
+ }
+ thread_result.unwrap()
+}
+
+pub fn download_batch(input_path: Option<PathBuf>, output_dir: Option<PathBuf>, limit: Option<u64>, jobs: u64) -> Result<u64> {
- let count = 0;
+ let mut count = 0;
+
+ assert!(jobs > 0 && jobs <= 12);
+
+ let show_progress = jobs == 1;
+
+ let (task_sender, task_receiver) = channel::bounded(12);
+ let (output_sender, output_receiver) = channel::bounded(12);
+ let (done_sender, done_receiver) = channel::bounded(0);
+
+ for _ in 0..jobs {
+ let task_receiver = task_receiver.clone();
+ let output_sender = output_sender.clone();
+ thread::spawn(move || {
+ loop_download_tasks(task_receiver, output_sender);
+ });
+ }
+ drop(output_sender);
+
+ // Start printer thread
+ thread::spawn(move || loop_printer(output_receiver, done_sender).expect("printing to stdout"));
+
match input_path {
None => {
let stdin = io::stdin();
let stdin_lock = stdin.lock();
let lines = stdin_lock.lines();
for line in lines {
- let json_str = line?;
- download_entity(json_str, output_dir.clone())?;
+ let task = DownloadTask {
+ json_str: line?,
+ output_path: output_dir.clone(),
+ show_progress,
+ };
+ task_sender.send(task)?;
+ count += 1;
if let Some(limit) = limit {
if count >= limit {
break;
@@ -307,7 +377,8 @@ pub fn download_batch(input_path: Option<PathBuf>, output_dir: Option<PathBuf>,
let lines = buffered.lines();
for line in lines {
let json_str = line?;
- download_entity(json_str, output_dir.clone())?;
+ let (_status, status_line) = download_entity(json_str, output_dir.clone(), show_progress)?;
+ println!("{}", status_line);
+ count += 1;
if let Some(limit) = limit {
if count >= limit {
break;
@@ -316,5 +387,6 @@ pub fn download_batch(input_path: Option<PathBuf>, output_dir: Option<PathBuf>,
}
}
};
+ // Close the task channel so the worker threads (and then the printer thread) can exit
+ drop(task_sender);
+ // The printer thread drops done_sender when it finishes, so a RecvError here just means all output was printed
+ let _ = done_receiver.recv();
Ok(count)
}
diff --git a/fatcat-cli/src/main.rs b/fatcat-cli/src/main.rs
index ced35b4..3b96cce 100644
--- a/fatcat-cli/src/main.rs
+++ b/fatcat-cli/src/main.rs
@@ -485,6 +485,12 @@ fn run(opt: Opt) -> Result<()> {
return Err(anyhow!("output directory doesn't exist"));
}
}
+ if jobs == 0 {
+ return Err(anyhow!("--jobs=0 not implemented"));
+ }
+ if jobs > 12 {
+ return Err(anyhow!("please don't download more than 12 parallel requests"));
+ }
download_batch(input_path, output_dir, limit, jobs)?;
}
Command::Download {
@@ -515,7 +521,7 @@ fn run(opt: Opt) -> Result<()> {
resp => Err(anyhow!("{:?}", resp))
.with_context(|| format!("API GET failed: {:?}", ident)),
}?;
- download_release(&release_entity, output_path)
+ download_release(&release_entity, output_path, true)
}
Specifier::File(ident) => {
let result = api_client.rt.block_on(api_client.api.get_file(
@@ -528,7 +534,7 @@ fn run(opt: Opt) -> Result<()> {
resp => Err(anyhow!("{:?}", resp))
.with_context(|| format!("API GET failed: {:?}", ident)),
}?;
- download_file(&file_entity, &file_entity.specifier(), output_path)
+ download_file(&file_entity, &file_entity.specifier(), output_path, true)
}
other => Err(anyhow!("Don't know how to download: {:?}", other)),
}?;