给定整数数组
[1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5]
我需要遮盖重复超过N
次的元素。需要说明的是:主要目标是检索布尔掩码数组,以后再用于装箱计算。
我想出了一个相当复杂的解决方案
import numpy as np
bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
N = 3
splits = np.split(bins, np.where(np.diff(bins) != 0)[0]+1)
mask = []
for s in splits:
if s.shape[0] <= N:
mask.append(np.ones(s.shape[0]).astype(np.bool_))
else:
mask.append(np.append(np.ones(N), np.zeros(s.shape[0]-N)).astype(np.bool_))
mask = np.concatenate(mask)
给例如
bins[mask]
Out[90]: array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])
有没有更好的方法可以做到这一点?
编辑,#2
非常感谢您的回答!这是MSeifert基准测试图的精简版。感谢您将我指向simple_benchmark
。仅显示4个最快的选项:
结论
由Florian H提出并由Paul Panzer修改的想法似乎是解决此问题的好方法,因为它很简单并且仅numpy
。但是,如果使用numba
没问题,则MSeifert's solution会胜过另一个。
我选择接受MSeifert的答案作为解决方案,因为它是更笼统的答案:它可以正确地处理带有(非唯一)连续重复元素块的任意数组。如果numba
不可行,Divakar's answer也值得一看!
python大神给出的解决方案
我想提出一个使用numba的解决方案,该解决方案应该很容易理解。我假设您要“屏蔽”连续的重复项:
import numpy as np
import numba as nb
@nb.njit
def mask_more_n(arr, n):
mask = np.ones(arr.shape, np.bool_)
current = arr[0]
count = 0
for idx, item in enumerate(arr):
if item == current:
count += 1
else:
current = item
count = 1
mask[idx] = count <= n
return mask
例如:
>>> bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
>>> bins[mask_more_n(bins, 3)]
array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])
>>> bins[mask_more_n(bins, 2)]
array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
性能:
使用simple_benchmark
-但是我没有包括所有方法。这是对数-对数比例:
似乎numba解决方案无法胜过Paul Panzer的解决方案,后者对于大型阵列来说似乎要快一点(并且不需要其他依赖项)。
但是,两者似乎都胜过其他解决方案,但是它们确实返回掩码而不是“过滤”数组。
import numpy as np
import numba as nb
from simple_benchmark import BenchmarkBuilder, MultiArgument
b = BenchmarkBuilder()
bins = np.array([1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5])
@nb.njit
def mask_more_n(arr, n):
mask = np.ones(arr.shape, np.bool_)
current = arr[0]
count = 0
for idx, item in enumerate(arr):
if item == current:
count += 1
else:
current = item
count = 1
mask[idx] = count <= n
return mask
@b.add_function(warmups=True)
def MSeifert(arr, n):
return mask_more_n(arr, n)
from scipy.ndimage.morphology import binary_dilation
@b.add_function()
def Divakar_1(a, N):
k = np.ones(N,dtype=bool)
m = np.r_[True,a[:-1]!=a[1:]]
return a[binary_dilation(m,k,origin=-(N//2))]
@b.add_function()
def Divakar_2(a, N):
k = np.ones(N,dtype=bool)
return a[binary_dilation(np.ediff1d(a,to_begin=a[0])!=0,k,origin=-(N//2))]
@b.add_function()
def Divakar_3(a, N):
m = np.r_[True,a[:-1]!=a[1:],True]
idx = np.flatnonzero(m)
c = np.diff(idx)
return np.repeat(a[idx[:-1]],np.minimum(c,N))
from skimage.util import view_as_windows
@b.add_function()
def Divakar_4(a, N):
m = np.r_[True,a[:-1]!=a[1:]]
w = view_as_windows(m,N)
idx = np.flatnonzero(m)
v = idx<len(w)
w[idx[v]] = 1
if v.all()==0:
m[idx[v.argmin()]:] = 1
return a[m]
@b.add_function()
def Divakar_5(a, N):
m = np.r_[True,a[:-1]!=a[1:]]
w = view_as_windows(m,N)
last_idx = len(a)-m[::-1].argmax()-1
w[m[:-N+1]] = 1
m[last_idx:last_idx+N] = 1
return a[m]
@b.add_function()
def PaulPanzer(a,N):
mask = np.empty(a.size,bool)
mask[:N] = True
np.not_equal(a[N:],a[:-N],out=mask[N:])
return mask
import random
@b.add_arguments('array size')
def argument_provider():
for exp in range(2, 20):
size = 2**exp
yield size, MultiArgument([np.array([random.randint(0, 5) for _ in range(size)]), 3])
r = b.run()
import matplotlib.pyplot as plt
plt.figure(figsize=[10, 8])
r.plot()
在Numpy数组中替换子数组 - python我正在尝试将Numpy数组中的子数组替换为形状相同的数组,以使所有更改都在两个数组中得到镜像。我在IDLE中运行了以下代码。import numpy a=numpy.zeros((2,1)) a array([[0.], [0.]]) b=numpy.zeros((1)) b array([0.]) a[0]=b b[0]=1 b array([1.]) 现…
Python sqlite3数据库已锁定 - python我在Windows上使用Python 3和sqlite3。我正在开发一个使用数据库存储联系人的小型应用程序。我注意到,如果应用程序被强制关闭(通过错误或通过任务管理器结束),则会收到sqlite3错误(sqlite3.OperationalError:数据库已锁定)。我想这是因为在应用程序关闭之前,我没有正确关闭数据库连接。我已经试过了: connectio…
用傅立叶变换做卷积? - python根据卷积定理(links),我们可以将傅立叶变换算子进行卷积。使用python和scripy,我的代码在下面,但不正确。你能帮我解释一下吗?import tensorflow as tf import sys from scipy import signal from scipy import linalg import numpy as np x = [[…
如何在python中将.npz格式转换为.csv? - python我是python的新手。我想将.npz file(.npz是一种numpy文件格式)转换为.csv文件,以便在R中使用它。请提出一种方法 python大神给出的解决方案 尝试类似的方法:import numpy as np data = np.load(filename) for key, value in data.items(): np.savetxt(…
并行dask for循环比常规循环慢? - python如果我尝试用dask并行化for循环,它的执行速度将比常规版本慢。基本上,我只是按照dask教程中的介绍性示例进行操作,但是由于某种原因,它最终还是失败了。我究竟做错了什么?In [1]: import numpy as np ...: from dask import delayed, compute ...: import dask.multiproce…