提问人:Script Raccoon 提问时间:10/28/2023 最后编辑:Script Raccoon 更新时间:10/29/2023 访问量:123
射影函数集的高效计算
Efficient computation of the set of surjective functions
问:
当 的每个元素在 中至少有一个前像时,函数是射射函数。当 和 是两个有限集合时,则对应于一个数字的元组,并且当每个数字至少出现一次时,它就是投射的。(当我们要求每个数字只出现一次时,我们已经并且正在谈论排列。f : X -> Y
Y
X
X = {0,...,m-1}
Y = {0,...,n-1}
f
m
< n
< n
n=m
我想知道一种有效的算法,用于计算两个给定数字和 的所有射射元组的集合。这些元组的数量可以通过包含-排除原则非常有效地计算(参见此处的示例),但我认为这在这里没有用(因为我们会首先计算所有元组,然后逐步删除非投射元组,并且我假设所有元组的计算将花费更长的时间*)。另一种方法如下:n
m
例如,考虑元组
(1,6,4,2,1,6,0,2,5,1,3,2,3)
其中每个数字< 7 至少出现一次。看最大的 编号并擦除它:
(1,*,4,2,1,*,0,2,5,1,3,2,3)
它出现在索引 1 和 5 中,因此这对应于集合 ,索引的子集。
其余部分对应于元组{1,5}
(1,4,2,1,0,2,5,1,3,2,3)
属性< 6 的每个数字至少出现一次。
我们看到数的射射元组对应于对,其中 是 的非空子集,是数的射射 -元组,其中 有元素。m
< n
(T,a)
T
{0,...,m-1}
a
(m-k)
< n-1
T
k
这导致了以下递归实现(用 Python 编写):
import itertools
def surjective_tuples(m: int, n: int) -> set[tuple]:
"""Set of all m-tuples of numbers < n where every number < n appears at least once.
Arguments:
m: length of the tuple
n: number of distinct values
"""
if n == 0:
return set() if m > 0 else {()}
if n > m:
return set()
result = set()
for k in range(1, m + 1):
smaller_tuples = surjective_tuples(m - k, n - 1)
subsets = itertools.combinations(range(m), k)
for subset in subsets:
for smaller_tuple in smaller_tuples:
my_tuple = []
count = 0
for i in range(m):
if i in subset:
my_tuple.append(n - 1)
count += 1
else:
my_tuple.append(smaller_tuple[i - count])
result.add(tuple(my_tuple))
return result
不过,我注意到,当输入数字很大时,这很慢。例如,当我的(旧)PC 上的计算需要几秒钟时,集合在这里有元素。我怀疑有一种更快的算法。(m,n)=(10,6)
32
16435440
*事实上,下面的实现非常缓慢。
def surjective_tuples_stupid(m: int, n: int) -> list[int]:
all_tuples = list(itertools.product(*(range(n) for _ in range(m))))
surjective_tuples = filter(lambda t: all(i in t for i in range(n)), all_tuples)
return list(surjective_tuples)
答:
我设法使用 s multiset_permutations
提高了几个百分点的性能:sympy
from itertools import combinations_with_replacement
from sympy.utilities.iterables import multiset_permutations
def get_combs(s, n):
for c in combinations_with_replacement(range(1, s), n):
if sum(c) == s:
yield c
def surjective_tuples_new(s, n):
for c in get_combs(s, n):
for p in multiset_permutations(c):
out = []
for i, n in enumerate(p):
out.extend(i for _ in range(n))
yield from multiset_permutations(out)
基准:
from timeit import timeit
assert sorted(surjective_tuples_new(10, 8)) == list(
map(list, sorted(surjective_tuples(10, 8)))
)
t1 = timeit("surjective_tuples(10, 8)", number=1, globals=globals())
t2 = timeit("list(surjective_tuples_new(10, 8))", number=1, globals=globals())
t3 = timeit("surjective_tuples_kelly(10, 8)", number=1, globals=globals())
print(t1)
print(t2)
print(t3)
在我的机器上打印(AMD 5700x,Python 3.11):
27.863450561184436
22.92276939912699
9.207325214054435
编辑:将@kelly的答案添加到基准测试中。
评论
只是稍微优化了你的,主要是通过使用来构建元组。m=9,n=7 时,大约比你的快 5 倍。insert
def surjective_tuples(m: int, n: int) -> list[tuple]:
"""List of all m-tuples of numbers < n where every number < n appears at least once.
Arguments:
m: length of the tuple
n: number of distinct values
"""
if not n:
return [] if m else [()]
if n > m:
return []
n -= 1
result = []
for k in range(1, m - n + 1):
smaller_tuples = surjective_tuples(m - k, n)
subsets = itertools.combinations(range(m), k)
for subset in subsets:
for smaller_tuple in smaller_tuples:
my_tuple = [*smaller_tuple]
for i in subset:
my_tuple.insert(i, n)
result.append(tuple(my_tuple))
return result
评论
.insert
functools.reduce
map
set
评论