#include <torch/extension.h>

#include <cstdint>
#include <tuple>
#include <vector>
/// Min-max quantize `src` along its last dimension to integer codes.
///
/// For every slice along dim -1, values are affinely mapped from
/// [min, max] onto the integer grid {0, 1, ..., 2^bits - 1} and rounded to
/// the nearest code.
///
/// @param src   Tensor to quantize; reduced along its last dimension.
/// @param bits  Code width in bits; must be in [1, 32].
/// @return {max_val, min_val, quantized}: the per-slice max and min (last
///         dimension kept with size 1, so they broadcast against `src`) and
///         the integer codes. The codes use the smallest integer dtype that
///         holds every value in [0, 2^bits - 1]: int8 for bits <= 7, uint8
///         for 8, int16 for <= 15, int32 for <= 31, int64 otherwise.
/// @throws c10::Error if `bits` is outside [1, 32].
std::vector<torch::Tensor> quantize(torch::Tensor src, int bits) {
  TORCH_CHECK(bits >= 1 && bits <= 32,
              "quantize: bits must be in [1, 32], got ", bits);
  // Shift a 64-bit one: `1 << bits` on a plain int overflows (UB) for bits >= 31.
  const std::int64_t level = (std::int64_t{1} << bits) - 1;

  auto max_val = std::get<0>(src.max(-1, /*keepdim=*/true));
  auto min_val = std::get<0>(src.min(-1, /*keepdim=*/true));
  auto range_val = max_val - min_val;
  // Constant slices have zero range; divide by 1 instead so they map to
  // code 0 rather than NaN.
  range_val = torch::where(range_val == 0, torch::ones_like(range_val), range_val);

  // int8 tops out at 127, so 8-bit (and wider) codes used to wrap around.
  // Choose a dtype that can represent [0, level]; widths that already fit in
  // int8 keep int8 so existing callers see the same dtype.
  const torch::Dtype code_type = bits <= 7    ? torch::kInt8
                                 : bits == 8  ? torch::kUInt8
                                 : bits <= 15 ? torch::kInt16
                                 : bits <= 31 ? torch::kInt32
                                              : torch::kInt64;

  // Clamp guards against floating-point error nudging a code just outside
  // [0, level] before the integer cast.
  auto quantized = torch::round((src - min_val) * level / range_val)
                       .clamp_(0, level)
                       .to(code_type);
  return {max_val, min_val, quantized};
}
/// Approximately invert `quantize`: map integer codes back to real values.
///
/// @param src      Integer codes, presumably as produced by `quantize`; any
///                 dtype is accepted and converted to float32 first (float32
///                 represents integers exactly only up to 2^24).
/// @param max_val  Per-slice maxima returned by `quantize`.
/// @param min_val  Per-slice minima returned by `quantize`.
/// @param bits     Code width used when quantizing; must be in [1, 32].
/// @return `src * (max_val - min_val) / (2^bits - 1) + min_val`. Zero-range
///         slices use a divisor of 1, matching `quantize`, so their code 0
///         decodes to `min_val`.
/// @throws c10::Error if `bits` is outside [1, 32].
torch::Tensor dequantize(torch::Tensor src, torch::Tensor max_val,
                         torch::Tensor min_val, int bits) {
  TORCH_CHECK(bits >= 1 && bits <= 32,
              "dequantize: bits must be in [1, 32], got ", bits);
  // Shift a 64-bit one: `1 << bits` on a plain int overflows (UB) for bits >= 31.
  const std::int64_t level = (std::int64_t{1} << bits) - 1;

  auto range_val = max_val - min_val;
  // Mirror quantize's zero-range substitution so round trips agree.
  range_val = torch::where(range_val == 0, torch::ones_like(range_val), range_val);

  return src.to(torch::kFloat32) * range_val / level + min_val;
}
// Python bindings for the extension. Argument names are declared so Python
// callers may pass them by keyword (e.g. `quantize(x, bits=4)`); positional
// calls behave exactly as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("quantize", &quantize, "Quantize tensor",
        py::arg("src"), py::arg("bits"));
  m.def("dequantize", &dequantize, "Dequantize tensor",
        py::arg("src"), py::arg("max_val"), py::arg("min_val"),
        py::arg("bits"));
}