cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
cargo build # build to verify the setup (or run `cargo check` for a faster check without codegen)
use candle_core::{Device, Tensor};
fn main() -> Result<(), Box> {
let device = Device::Cpu;
let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
let c = a.matmul(&b)?;
println!("{c}");
Ok(())
}
~/myrust$ cargo new myapp
Created binary (application) `myapp` package
~/myrust$ cd myapp
~/myrust/myapp$ cargo add --git https://github.com/huggingface/candle.git candle-core
Updating git repository `https://github.com/huggingface/candle.git`
Updating git submodule `https://github.com/NVIDIA/cutlass.git`
Adding candle-core (git) to dependencies.
Features:
- accelerate
- cuda
- cudarc
- cudnn
- metal
- mkl
Updating git repository `https://github.com/huggingface/candle.git`
Updating crates.io index
~/myrust/myapp$ cargo build
Downloaded serde_derive v1.0.195
Downloaded either v1.9.0
Downloaded autocfg v1.1.0
Downloaded zerofrom v0.1.3
Downloaded zerofrom-derive v0.1.3
Downloaded synstructure v0.13.0
Downloaded crossbeam-deque v0.8.5
Downloaded yoke-derive v0.7.3
Downloaded half v2.3.1
Downloaded bytemuck v1.14.1
Downloaded rand_core v0.6.4
Downloaded paste v1.0.14
Downloaded proc-macro2 v1.0.78
Downloaded itoa v1.0.10
Downloaded memmap2 v0.9.4
Downloaded syn v2.0.48
Downloaded crossbeam-epoch v0.9.18
Downloaded cfg-if v1.0.0
Downloaded bitflags v1.3.2
Downloaded num_cpus v1.16.0
Downloaded gemm-f32 v0.17.0
Downloaded reborrow v0.5.5
Downloaded stable_deref_trait v1.2.0
Downloaded rayon-core v1.12.1
Downloaded seq-macro v0.3.5
Downloaded thiserror-impl v1.0.56
Downloaded dyn-stack v0.10.0
Downloaded thiserror v1.0.56
Downloaded unicode-xid v0.2.4
Downloaded rand_chacha v0.3.1
Downloaded ppv-lite86 v0.2.17
Downloaded bytemuck_derive v1.5.0
Downloaded getrandom v0.2.12
Downloaded once_cell v1.19.0
Downloaded unicode-ident v1.0.12
Downloaded byteorder v1.5.0
Downloaded crc32fast v1.3.2
Downloaded num-complex v0.4.4
Downloaded gemm-common v0.17.0
Downloaded crossbeam-utils v0.8.19
Downloaded quote v1.0.35
Downloaded ryu v1.0.16
Downloaded num-traits v0.2.17
Downloaded zip v0.6.6
Downloaded rand_distr v0.4.3
Downloaded serde v1.0.195
Downloaded rand v0.8.5
Downloaded raw-cpuid v10.7.0
Downloaded libm v0.2.8
Downloaded serde_json v1.0.111
Downloaded rayon v1.8.1
Downloaded libc v0.2.152
Downloaded gemm-c64 v0.17.0
Downloaded gemm-c32 v0.17.0
Downloaded safetensors v0.4.2
Downloaded gemm-f64 v0.17.0
Downloaded gemm v0.17.0
Downloaded gemm-f16 v0.17.0
Downloaded yoke v0.7.3
Downloaded pulp v0.18.6
Downloaded 60 crates (3.1 MB) in 14.91s
Compiling proc-macro2 v1.0.78
Compiling unicode-ident v1.0.12
Compiling libc v0.2.152
Compiling cfg-if v1.0.0
Compiling libm v0.2.8
Compiling autocfg v1.1.0
Compiling crossbeam-utils v0.8.19
Compiling ppv-lite86 v0.2.17
Compiling rayon-core v1.12.1
Compiling reborrow v0.5.5
Compiling paste v1.0.14
Compiling either v1.9.0
Compiling bitflags v1.3.2
Compiling seq-macro v0.3.5
Compiling once_cell v1.19.0
Compiling unicode-xid v0.2.4
Compiling raw-cpuid v10.7.0
Compiling serde v1.0.195
Compiling crc32fast v1.3.2
Compiling serde_json v1.0.111
Compiling stable_deref_trait v1.2.0
Compiling itoa v1.0.10
Compiling ryu v1.0.16
Compiling thiserror v1.0.56
Compiling byteorder v1.5.0
Compiling num-traits v0.2.17
Compiling zip v0.6.6
Compiling crossbeam-epoch v0.9.18
Compiling quote v1.0.35
Compiling syn v2.0.48
Compiling crossbeam-deque v0.8.5
Compiling getrandom v0.2.12
Compiling memmap2 v0.9.4
Compiling num_cpus v1.16.0
Compiling rand_core v0.6.4
Compiling rand_chacha v0.3.1
Compiling rayon v1.8.1
Compiling rand v0.8.5
Compiling rand_distr v0.4.3
Compiling synstructure v0.13.0
Compiling bytemuck_derive v1.5.0
Compiling serde_derive v1.0.195
Compiling zerofrom-derive v0.1.3
Compiling thiserror-impl v1.0.56
Compiling yoke-derive v0.7.3
Compiling bytemuck v1.14.1
Compiling num-complex v0.4.4
Compiling dyn-stack v0.10.0
Compiling half v2.3.1
Compiling zerofrom v0.1.3
Compiling yoke v0.7.3
Compiling pulp v0.18.6
Compiling gemm-common v0.17.0
Compiling gemm-f32 v0.17.0
Compiling gemm-c64 v0.17.0
Compiling gemm-f64 v0.17.0
Compiling gemm-c32 v0.17.0
Compiling gemm-f16 v0.17.0
Compiling gemm v0.17.0
Compiling safetensors v0.4.2
Compiling candle-core v0.3.3 (https://github.com/huggingface/candle.git#fd7c8565)
Compiling myapp v0.1.0 (/home/pdd/myrust/myapp)
Finished dev [unoptimized + debuginfo] target(s) in 32.90s
https://github.com/RileySeaburg/candle_test
git clone https://github.com/RileySeaburg/candle_test.git
Cargo.toml
[package]
name = "candle_test"
version = "0.1.0"
edition = "2021" # Rust edition
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["cuda"] }
# `candle-core`: the dependency's package name. `git` is the source repository; Cargo builds
# its default branch (add `rev = "..."` or `tag = "..."` to pin a specific commit).
# No `version` key: for a git dependency Cargo checks any `version` requirement against the
# crate in the checkout, and candle-core in that repository has moved past 0.2.x.
# `features` lists optional crate features to enable; here "cuda" (needs a CUDA toolchain).
# The same dependency can be added from the command line; for a CPU-only build leave out
# `--features cuda` (or delete "cuda" above), then run `cargo build`:
# cargo add --git https://github.com/huggingface/candle.git candle-core --features cuda
use candle_core::{DType, Device, Result, Tensor};
// 定义一个模型结构体
struct Model {
first: Tensor,
second: Tensor,
}
impl Model {
// 定义模型的前向传播方法
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = image.matmul(&self.first)?; // 输入乘以第一层权重
let x = x.relu()?; // 使用 ReLU 激活函数
x.matmul(&self.second) // 结果乘以第二层权重
}
}
/// Builds a zero-initialised two-layer `Model` and runs one forward pass on a dummy input.
///
/// # Errors
/// Returns any candle error raised while allocating tensors or running `forward`.
fn main() -> Result<()> {
    // Use CUDA device 0 when it can be opened (e.g. built with the "cuda" feature and a GPU
    // present); on any error fall back to the CPU.
    let device = Device::new_cuda(0).unwrap_or(Device::Cpu);

    // Weights for the two layers: 784 -> 100 -> 10. Allocation errors are propagated with `?`
    // rather than panicking; `Tensor::zeros` already yields a contiguous tensor.
    let first = Tensor::zeros((784, 100), DType::F32, &device)?;
    let second = Tensor::zeros((100, 10), DType::F32, &device)?;
    let model = Model { first, second };

    // Dummy 1x784 input (presumably a flattened 28x28 image) to exercise the forward pass.
    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;

    // Forward pass produces a 1x10 output tensor.
    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
// Result定义在/home/pdd/.cargo/git/checkouts/candle-0c2b4fa9e5801351/e8e3375/candle-core/src/error.rs
pub type Result = std::result::Result; // 定义了一个 `Result` 类型,这是一个 `Result` 类型的别名。其中 `T` 是成功时的返回类型,而 `Error` 是失败时的错误类型。
// `Ok` (as in `Ok(())`) is defined in /home/pdd/.rustup/toolchains/stable-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/result.rs
// This is the standard library's public `Result` enum. It has two generic parameters: `T` is the type of the value returned on success, `E` is the type of the error returned on failure.
// `#[...]` is an attribute: it attaches extra information to the item that follows.
pub enum Result<T, E> {
/// Contains the success value
#[lang = "Ok"]
#[stable(feature = "rust1", since = "1.0.0")]
Ok(#[stable(feature = "rust1", since = "1.0.0")] T),// `Ok(T)`: the variant of `Result` that represents success.
// `()` is Rust's unit type, similar to `void` in other languages; `Ok(())` is success carrying no value.
/// Contains the error value
#[lang = "Err"]
#[stable(feature = "rust1", since = "1.0.0")]
Err(#[stable(feature = "rust1", since = "1.0.0")] E),// `Err(E)`: the other variant of `Result`, representing failure.
}
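`?` 运算符用于处理 `Result<T, E>` 或 `Option<T>` 类型的返回值:若值为 `Ok(v)` / `Some(v)`,则取出 `v` 继续执行;若为 `Err(e)` / `None`,则立即从当前函数返回该错误(必要时通过 `From` 转换错误类型)或 `None`,交由调用者处理。错误因此可以沿调用链逐层向上传播,使代码更加简洁易读。例如:
fn forward(&self, image: &Tensor) -> Result<Tensor> {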
let x = image.matmul(&self.first)?; // 如果matmul返回Err,则整个forward函数返回Err
let x = x.relu()?; // 如果relu返回Err,则整个forward函数返回Err
x.matmul(&self.second) // 如果matmul返回Err,则整个forward函数返回Err;否则返回Ok(Tensor)
}
函数体:函数体是一个块表达式,其值是最后一个表达式(末尾不加分号)的值。例如下面 `add` 中的 `x + y` 就是函数的返回值,无需写 `return`。
/// Returns the sum of `x` and `y`.
fn add(x: i32, y: i32) -> i32 {
    let sum = x + y;
    // Tail expression (no trailing semicolon): its value is the function's return value.
    sum
}