mirror of https://github.com/THUDM/CodeGeeX.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
import pkg_resources
|
|
import torch
|
|
import ctypes
|
|
|
|
from typing import List
|
|
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
|
|
|
|
# Name of this module, passed to pkg_resources as the package in which
# bundled resource files (the kernel binaries) are looked up.
RESOURCE_PACKAGE_NAME = __name__
|
|
|
|
|
|
class Kernel:
    """Bundle of kernel functions backed by a packaged ``.fatbin`` resource.

    The resource ``<filename>.fatbin`` is read from ``RESOURCE_PACKAGE_NAME``
    via ``pkg_resources`` and wrapped in a ``LazyKernelCModule``. Every entry
    of ``function_names`` then becomes an instance attribute holding a
    ``KernelFunction`` bound to that module, so callers can write
    ``kernel.someFunction(...)``.

    Args:
        filename: Resource name without the ``.fatbin`` suffix.
        function_names: Names of the functions to expose as attributes.

    Raises:
        RuntimeError: If the ``.fatbin`` resource is not found in the package.
    """

    def __init__(self, filename: str, function_names: List[str]):
        resource_name = filename + ".fatbin"
        is_present = pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, resource_name)
        if not is_present:
            raise RuntimeError("File `%s` not found in `%s`" % (resource_name, RESOURCE_PACKAGE_NAME))

        # Note: the stored filename includes the ".fatbin" suffix.
        self.filename = resource_name
        self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, resource_name)
        self._function_names = function_names
        self._cmodule = LazyKernelCModule(self.code)

        # Expose each requested function as an attribute of this instance.
        for fn_name in function_names:
            setattr(self, fn_name, KernelFunction(self._cmodule, fn_name))
|
|
|
|
|
|
# Module-level handle to the quantization kernels. Kernel() appends ".fatbin"
# to the filename and exposes each listed function as an attribute, e.g.
# kernels.int4WeightCompression.
kernels = Kernel(
    filename="quantization",
    function_names=[
        "int4WeightCompression",
        "int4WeightExtractionFloat",
        "int4WeightExtractionHalf",
        "int8WeightExtractionFloat",
        "int8WeightExtractionHalf",
    ],
)
|
|
|
|
|
|
def compress_int4_weight(weight: torch.Tensor) -> torch.Tensor:  # (n, m)
    """Compress an ``(n, m)`` weight into ``(n, m // 2)`` int8 storage.

    Launches the ``int4WeightCompression`` kernel on the current CUDA stream
    of ``weight``'s device. Only raw data pointers and the two sizes are
    passed to the kernel, so no stride information reaches it.

    Args:
        weight: 2-D CUDA tensor with an even number of columns.
            NOTE(review): presumably int8 values that fit in 4 bits and a
            contiguous layout are expected -- confirm against the kernel.

    Returns:
        A newly allocated ``(n, m // 2)`` ``torch.int8`` tensor on the CUDA
        device that holds ``weight``.

    Raises:
        ValueError: If the number of columns ``m`` is odd.
    """
    n, m = weight.size(0), weight.size(1)

    # Validate with an explicit raise instead of ``assert``: asserts are
    # removed under ``python -O``, and an odd width would then be silently
    # truncated by the integer division below.
    if m % 2 != 0:
        raise ValueError(
            "weight must have an even number of columns for int4 compression, got %d" % m
        )

    with torch.cuda.device(weight.device):
        packed_m = m // 2
        out = torch.empty(n, packed_m, dtype=torch.int8, device="cuda")
        stream = torch.cuda.current_stream()

        # n blocks; block width is the packed row length passed through
        # round_up(..., 32) and capped at 1024 threads.
        grid_dim = (n, 1, 1)
        block_dim = (min(round_up(packed_m, 32), 1024), 1, 1)

        kernels.int4WeightCompression(
            grid_dim,
            block_dim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(packed_m),
            ],
        )
        return out
|
|
|
|
|
|
def extract_weight_to_half(
    weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int
) -> torch.Tensor:
    """Expand a quantized ``(n, m)`` weight into a half-precision tensor.

    Selects the int8 or int4 "ExtractionHalf" kernel according to
    ``source_bit_width`` and launches it on the current CUDA stream of
    ``weight``'s device.

    Args:
        weight: 2-D CUDA tensor holding the quantized data.
        scale_list: CUDA tensor of scales handed to the kernel.
            NOTE(review): presumably one half-precision scale per row --
            confirm against the kernel source.
        source_bit_width: Bit width of the stored values; must be 8 or 4.

    Returns:
        A newly allocated ``(n, m * (8 // source_bit_width))`` ``torch.half``
        tensor on the CUDA device that holds ``weight``.

    Raises:
        ValueError: If ``source_bit_width`` is neither 8 nor 4.
    """
    if source_bit_width == 8:
        func = kernels.int8WeightExtractionHalf
    elif source_bit_width == 4:
        func = kernels.int4WeightExtractionHalf
    else:
        # Explicit raise instead of ``assert False``: under ``python -O`` the
        # assert disappears and ``func`` would be left unbound, failing later
        # with an unrelated UnboundLocalError.
        raise ValueError("Unsupported bit-width")

    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        # Each stored byte yields 8 // source_bit_width output columns.
        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
        stream = torch.cuda.current_stream()

        # n blocks; block width is the stored row length passed through
        # round_up(..., 32) and capped at 1024 threads.
        grid_dim = (n, 1, 1)
        block_dim = (min(round_up(m, 32), 1024), 1, 1)

        func(
            grid_dim,
            block_dim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test (needs a CUDA device): compress a small random int8
    # matrix to int4 storage, expand it back to half precision, and print
    # every stage for visual comparison.
    random_values = torch.randn(4, 32)
    source = random_values.to(torch.int8).cuda()

    # One unit scale per row of the source matrix.
    row_count = source.size(0)
    unit_scales = torch.ones(row_count).to(torch.half).cuda()

    print(source)

    packed = compress_int4_weight(source)
    print(packed)

    restored = extract_weight_to_half(packed, unit_scales, source_bit_width=4)
    print(restored)
|