python cplusplus extension

Python C++ Extension

  • .vscode/c_cpp_properties.json

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    {
    "configurations": [
    {
    "name": "Linux",
    "includePath": [
    "${workspaceFolder}/**",
    "/root/miniconda3/envs/python3/include/python3.10",
    "/root/miniconda3/envs/python3/lib/python3.10/site-packages/torch/include/**",
    "/root/miniconda3/envs/python3/lib/python3.10/site-packages/torch/include/torch/csrc/api/include"
    ],
    "defines": [],
    "compilerPath": "/root/.local/bin/clang",
    "cStandard": "c17",
    "cppStandard": "c++17",
    "intelliSenseMode": "linux-clang-x64"
    }
    ],
    "version": 4
    }
  • quantization.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <torch/extension.h>
#include <vector>

/// Quantizes `src` to `bits`-bit integer codes, independently for every slice
/// along its last dimension.
///
/// Within each slice the minimum maps to code 0 and the maximum to code
/// `2^bits - 1`; everything in between is scaled linearly and rounded.
///
/// @param src  Tensor to quantize; the min/max reduction runs over dim -1.
/// @param bits Width of each code in bits. Must be in [1, 8] because codes are
///             stored in an 8-bit integer tensor.
/// @return `{max_val, min_val, quantized}` where `max_val`/`min_val` are the
///         per-slice extrema (last dim kept with size 1) and `quantized` holds
///         the integer codes. Codes are `kInt8` for `bits <= 7` and `kUInt8`
///         for `bits == 8`.
/// @throws Fails a TORCH_CHECK when `bits` is outside [1, 8].
std::vector<torch::Tensor> quantize(torch::Tensor src, int bits) {
  TORCH_CHECK(bits >= 1 && bits <= 8,
              "quantize: bits must be in [1, 8], got ", bits);

  // Largest representable code; shift in 64-bit so the width is explicit.
  const int64_t level = (int64_t{1} << bits) - 1;

  auto max_val = std::get<0>(src.max(-1, /*keepdim=*/true));
  auto min_val = std::get<0>(src.min(-1, /*keepdim=*/true));

  // A constant slice has zero range; substitute 1 so the division below is
  // well defined (every element of such a slice then quantizes to code 0).
  auto range_val = max_val - min_val;
  range_val = torch::where(range_val == 0, torch::ones_like(range_val), range_val);

  // Codes lie in [0, level]. int8 tops out at 127, so 8-bit codes (up to 255)
  // need an unsigned byte; narrower widths keep the int8 dtype.
  const auto code_dtype = level > 127 ? torch::kUInt8 : torch::kInt8;

  auto quantized = torch::round((src - min_val) * level / range_val).to(code_dtype);
  return {max_val, min_val, quantized};
}

/// Reconstructs floating-point values from integer codes.
///
/// Computes `src * (max_val - min_val) / (2^bits - 1) + min_val`, with `src`
/// first converted to float32. A zero range is replaced by 1 so the arithmetic
/// matches the substitution made when quantizing a constant slice.
///
/// @param src     Tensor of integer codes.
/// @param max_val Maximum used to scale the codes; combined with `src` by
///                ordinary tensor broadcasting.
/// @param min_val Minimum used to offset the codes; same broadcasting as
///                `max_val`.
/// @param bits    Width of each code in bits. Must be in [1, 62] so that
///                `2^bits - 1` is positive and fits in a signed 64-bit integer.
/// @return The dequantized tensor.
/// @throws Fails a TORCH_CHECK when `bits` is outside [1, 62].
torch::Tensor dequantize(torch::Tensor src, torch::Tensor max_val, torch::Tensor min_val, int bits) {
  // bits == 0 would give level 0 (division by zero below); negative or very
  // large shifts are undefined behaviour.
  TORCH_CHECK(bits >= 1 && bits <= 62,
              "dequantize: bits must be in [1, 62], got ", bits);

  const int64_t level = (int64_t{1} << bits) - 1;

  auto range_val = max_val - min_val;
  range_val = torch::where(range_val == 0, torch::ones_like(range_val), range_val);

  auto dequantized = src.to(torch::kFloat32) * range_val / level + min_val;
  return dequantized;
}

// Python bindings for the extension module. Arguments are given names so that
// Python callers may pass them by keyword (e.g. `quantize(src=x, bits=4)`);
// positional calls keep working unchanged.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("quantize", &quantize, "Quantize tensor",
        pybind11::arg("src"), pybind11::arg("bits"));
  m.def("dequantize", &dequantize, "Dequantize tensor",
        pybind11::arg("src"), pybind11::arg("max_val"),
        pybind11::arg("min_val"), pybind11::arg("bits"));
}
  • setup.py
1
2
3
4
5
6
7
8
9
10
11
12
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

# The single C++ extension built by this package: compiles quantization.cpp
# into an importable module named `quantization`.
quantization_ext = CppExtension(
    name='quantization',
    sources=['quantization.cpp'],
)

# BuildExtension replaces the stock build_ext command for the extension build.
setup(
    name='quantization',
    ext_modules=[quantization_ext],
    cmdclass={'build_ext': BuildExtension},
)

Build

1
2
python setup.py build_ext --inplace
python setup.py install

Usage

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import json
import time
from typing import Optional, Tuple
import torch
import quantization

class Quantization(torch.nn.Module):
    """Python-side wrapper around the compiled ``quantization`` extension.

    Both helpers are static: they read no instance state, so they can be
    called either on the class (``Quantization.quant(x)``) or on an instance.
    """

    @staticmethod
    def quant(src: torch.Tensor, bits: int = 4) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]:
        """Quantize ``src`` with the C++ ``quantization.quantize`` kernel.

        The kernel scales each slice along the last dimension to integer codes
        in ``[0, 2 ** bits - 1]`` using that slice's min and max.

        Args:
            src: Tensor to quantize.
            bits: Width of each code in bits.

        Returns:
            ``(max_val, min_val, bits, quantized)``: the extrema and the codes
            returned by the extension, with ``bits`` passed through so the
            tuple carries everything ``dequant`` needs.
        """
        max_val, min_val, quantized = quantization.quantize(src, bits)
        return max_val, min_val, bits, quantized

    @staticmethod
    def dequant(src: torch.Tensor, max_val: torch.Tensor, min_val: torch.Tensor, bits: int = 4) -> torch.Tensor:
        """Invert ``quant`` with the C++ ``quantization.dequantize`` kernel.

        The kernel computes
        ``src * (max_val - min_val) / (2 ** bits - 1) + min_val``.

        Args:
            src: Tensor of integer codes.
            max_val: Maxima returned by ``quant``.
            min_val: Minima returned by ``quant``.
            bits: Code width; must match the value used when quantizing.

        Returns:
            The dequantized tensor.
        """
        dequantized = quantization.dequantize(src, max_val, min_val, bits)
        return dequantized