# Mirror of https://github.com/THUDM/CodeGeeX.git
# Copyright (c) OpenAI (https://openai.com)

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ============================================================================
|
|
import itertools
|
|
import numpy as np
|
|
|
|
from typing import *
|
|
|
|
|
|
def estimate_pass_at_k(
    num_samples: Union[int, List[int], np.ndarray],
    num_correct: Union[List[int], np.ndarray],
    k: int
) -> np.ndarray:
    """
    Estimates pass@k of each problem and returns them in an array.

    Args:
        num_samples: Number of generated samples per problem. Either a single
            integer (Python ``int`` or NumPy integer scalar) that applies to
            every problem, or a sequence with one entry per problem.
        num_correct: Number of correct samples per problem.
        k: The ``k`` in pass@k.

    Returns:
        A 1-D array holding one pass@k estimate per entry of ``num_correct``.

    Raises:
        AssertionError: If ``num_samples`` is a sequence whose length differs
            from ``len(num_correct)``.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k).
        """
        # Fewer than k incorrect samples: every size-k subset must contain at
        # least one correct sample.
        if n - c < k:
            return 1.0
        # comb(n - c, k) / comb(n, k) written as a running product
        # prod_{i = n-c+1}^{n} (1 - k / i), so no binomial coefficient is
        # evaluated explicitly.
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    # NumPy integer scalars (e.g. np.int64) are not instances of ``int``;
    # accept them as well so a scalar sample count is always broadcast.
    if isinstance(num_samples, (int, np.integer)):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct), \
            "num_samples and num_correct must have the same length"
        num_samples_it = iter(num_samples)

    return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
|