You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

100 lines
2.9 KiB
Python

import pkg_resources
import torch
import ctypes
from typing import List
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
# Package name handed to ``pkg_resources`` when checking for and reading the
# bundled kernel files; it is this module's own ``__name__``.
RESOURCE_PACKAGE_NAME = __name__
class Kernel:
    """Load a packaged ``.fatbin`` resource and expose its functions as attributes.

    The resource is looked up with ``pkg_resources`` inside
    ``RESOURCE_PACKAGE_NAME``. Every entry of ``function_names`` becomes an
    instance attribute holding a ``KernelFunction`` built from the shared
    ``LazyKernelCModule``.

    Attributes set on the instance:
        filename: resource name including the ``.fatbin`` suffix.
        code: the raw resource contents returned by ``resource_string``.
    """

    def __init__(self, filename: str, function_names: List[str]):
        """Read ``<filename>.fatbin`` and register one attribute per function name.

        Raises:
            RuntimeError: if the resource does not exist in the package.
        """
        resource_name = filename + ".fatbin"
        package = RESOURCE_PACKAGE_NAME

        # Fail early with a readable message instead of an error from the reader.
        found = pkg_resources.resource_exists(package, resource_name)
        if not found:
            raise RuntimeError("File `%s` not found in `%s`" % (resource_name, package))

        self.filename = resource_name
        self.code = pkg_resources.resource_string(package, resource_name)
        self._function_names = function_names
        self._cmodule = LazyKernelCModule(self.code)

        # Expose each kernel entry point as `self.<name>`.
        for fn_name in self._function_names:
            entry_point = KernelFunction(self._cmodule, fn_name)
            setattr(self, fn_name, entry_point)
# Module-level handle to the "quantization" fatbin. Each name listed below is
# attached to `kernels` as an attribute by `Kernel.__init__`
# (e.g. `kernels.int4WeightCompression`).
kernels = Kernel(
    filename="quantization",
    function_names=[
        "int4WeightCompression",
        "int4WeightExtractionFloat",
        "int4WeightExtractionHalf",
        "int8WeightExtractionFloat",
        "int8WeightExtractionHalf",
    ],
)
def compress_int4_weight(weight: torch.Tensor):  # (n, m)
    """Run the ``int4WeightCompression`` kernel over a 2-D weight tensor.

    Args:
        weight: tensor of shape ``(n, m)`` on a CUDA device; ``m`` must be
            even. NOTE(review): the element type is not checked here -- the
            demo at the bottom of this file passes int8 data; confirm against
            the kernel source.

    Returns:
        A newly allocated ``torch.int8`` tensor of shape ``(n, m // 2)`` on
        ``weight``'s device, written by the kernel (presumably two 4-bit
        values packed per byte -- the packing itself lives in the kernel).

    Raises:
        AssertionError: if ``m`` is odd.
    """
    n, m = weight.size(0), weight.size(1)
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type is unchanged for callers.
    if m % 2 != 0:
        raise AssertionError("weight must have an even number of columns, got %d" % m)
    m = m // 2

    with torch.cuda.device(weight.device):
        # The kernel only receives a raw pointer, so strides are invisible to
        # it. Make the layout dense first (no copy if it already is).
        weight = weight.contiguous()
        # "cuda" resolves to the current device, i.e. weight.device here.
        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
        stream = torch.cuda.current_stream()

        # One block per row; the block size comes from `round_up(m, 32)`
        # capped at 1024.
        grid_dim = (n, 1, 1)
        block_dim = (min(round_up(m, 32), 1024), 1, 1)

        kernel_args = [
            ctypes.c_void_p(weight.data_ptr()),
            ctypes.c_void_p(out.data_ptr()),
            ctypes.c_int32(n),
            ctypes.c_int32(m),
        ]
        kernels.int4WeightCompression(grid_dim, block_dim, 0, stream, kernel_args)
        return out
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
    """Run the int8/int4 ``WeightExtractionHalf`` kernel to produce a half tensor.

    Args:
        weight: 2-D tensor of shape ``(n, m)`` on a CUDA device holding the
            quantized data.
        scale_list: CUDA tensor passed to the kernel alongside ``weight``.
            NOTE(review): presumably one half-precision scale per row (the
            demo in this file builds it with ``weight.size(0)`` entries) --
            confirm against the kernel source.
        source_bit_width: ``8`` or ``4``; selects which kernel is launched.

    Returns:
        A newly allocated ``torch.half`` tensor of shape
        ``(n, m * (8 // source_bit_width))`` on ``weight``'s device.

    Raises:
        AssertionError: if ``source_bit_width`` is neither 8 nor 4.
    """
    if source_bit_width == 8:
        func = kernels.int8WeightExtractionHalf
    elif source_bit_width == 4:
        func = kernels.int4WeightExtractionHalf
    else:
        # Explicit raise rather than `assert False`: under `python -O` the
        # assert would vanish and `func` would be unbound further down.
        raise AssertionError("Unsupported bit-width")

    with torch.cuda.device(weight.device):
        # The kernel only receives raw pointers, so both inputs must be
        # densely laid out (no copy if they already are).
        weight = weight.contiguous()
        scale_list = scale_list.contiguous()

        n, m = weight.size(0), weight.size(1)
        # "cuda" resolves to the current device, i.e. weight.device here.
        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
        stream = torch.cuda.current_stream()

        # One block per row; the block size comes from `round_up(m, 32)`
        # capped at 1024.
        grid_dim = (n, 1, 1)
        block_dim = (min(round_up(m, 32), 1024), 1, 1)

        func(
            grid_dim,
            block_dim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out
if __name__ == "__main__":
    # Manual smoke run (needs a CUDA device): compress a small random int8
    # matrix with the int4 kernel, then expand it back to half precision
    # with all scales set to one, printing each stage.
    weight = torch.randn(4, 32).to(torch.int8).cuda()
    row_count = weight.size(0)
    scale = torch.ones(row_count).to(torch.half).cuda()
    print(weight)

    packed = compress_int4_weight(weight)
    print(packed)

    restored = extract_weight_to_half(packed, scale, source_bit_width=4)
    print(restored)