本篇博客对 yield 不求甚解,用简短通俗的语言说明其功能及优势。
1 用法
对于函数返回 yield 的通俗理解就是返回了一个存储函数的地址,在某个空间有这个暂时用不到的函数,而这个函数本来是要返回一个容器,比方 list 。
下面具体说明:
在谈 yield 之前,先说说 list 。
list 数据是一种容器,包含其他元素,可以迭代的从众逐个获取元素,举个栗子:
def func1():
return [1,2]
a = func1()
for i in a:
print(i)
输出:
1
2
下面谈谈 yield 。
包含 yield 的函数为生成器函数。这种函数被调用后,函数内部的代码不会立刻执行,而是返回一个生成器。
生成器是特殊的迭代器,迭代器是一种对象(啥是对象,emmmmm.......),抽象理解是一种数据流。
举个栗子:
def func2():
yield [1,2]
b = func2()
print(b)
输出:
<generator object func2 at 0x000000001812F8E0>
只有当生成器调用成员方法时,生成器中的代码才会执行。
举个栗子:
def func3():
for x in range(2):
yield x ** 2
c = func3()
for x in c:
print(x)
输出:
0
1
2 好处
很多时候我们是逐个访问容器中的元素,而不是一下子获得所有的元素。
举个栗子:
对于 [1, 2, 3, 4],我们想依次获得列表里面的前两个数字,有以下两种方法:
1. 获得所有元素,然后取出前2个;
2. 从头逐个迭代,到第二个元素后终止。
显然,第二种方法节省时间空间开销。
3 实际应用
当我们从样本中批量获得数据输入模型的时候,我们就可以用 yield 表达式。
举个栗子:
def minibatches(inputs_data=None, labels=None, batch_size=None, shuffle=False):
'''
inputs_data 为 numpy array 数据
'''
assert len(inputs_data) == len(labels)
if shuffle:
indices = np.arange(len(inputs_data))
np.random.shuffle(indices)
for start_idx in range(0, len(inputs_data) - batch_size + 1, batch_size):
if shuffle:
excerpt = indices[start_idx:start_idx + batch_size]
else:
excerpt = slice(start_idx, start_idx + batch_size)
yield inputs_data[excerpt], labels[excerpt]
for x_train, y_train in minibatches(x_train, y_train, batch_size, shuffle=True):
model(x_train, y_train)
更加详细内容可以参考这篇博客:https://liam.page/2017/06/30/understanding-yield-in-python/