• 微信公众号：美女很有趣。 工作之余，放松一下，关注即送10G+美女照片！

# python3使用迭代生成器yield减少内存占用

1周前 (05-04) 10次浏览

# 技术背景

在Python编码中，如果先把所有待遍历的数据构造成一个列表，再用for循环逐个处理，那么这些数据会全部同时加载到内存中。其实这往往没有必要：这些数据很可能只会使用一次，很多场景下也不需要同时存储在内存中。这时候就可以用到本文所介绍的yield关键字：包含yield的函数被调用时会返回一个生成器（generator），每次迭代时才计算并产出下一个元素，而不是一次性生成全部元素。

# 基本使用

``````python
# test_yield.py

def square_number(length):
    """Return the list [0**2, 1**2, ..., (length-1)**2].

    All ``length`` squares are computed up front and held in memory at the
    same time; this is the eager counterpart of ``square_number_yield``.
    """
    s = []
    for i in range(length):
        s.append(i ** 2)  # every square is stored before the function returns
    return s

def square_number_yield(length):
    """Lazily yield 0**2, 1**2, ..., (length-1)**2, one value per request.

    Calling this function returns a generator. Each square is computed only
    when the caller asks for the next item (via ``next()`` or a ``for`` loop),
    so no list holding all results is ever built.
    """
    for i in range(length):
        yield i ** 2

if __name__ == '__main__':
    length = 10
    sn1 = square_number(length)        # eager: a fully built, indexable list
    sn2 = square_number_yield(length)  # lazy: a generator, advanced with next()
    for i in range(length):
        # '\t' (a tab, not the letter 't') separates the two columns
        print(sn1[i], '\t', end='')
        print(next(sn2))
``````

``````
[dechin@dechin-manjaro yield]$ python3 test_yield.py
0       0
1       1
4       4
9       9
16      16
25      25
36      36
49      49
64      64
81      81
``````

``````python
# test_yield.py

def square_number(length):
    """Return the list [0**2, 1**2, ..., (length-1)**2].

    All ``length`` squares are computed up front and held in memory at the
    same time; this is the eager counterpart of ``square_number_yield``.
    """
    s = []
    for i in range(length):
        s.append(i ** 2)  # every square is stored before the function returns
    return s

def square_number_yield(length):
    """Lazily yield 0**2, 1**2, ..., (length-1)**2, one value per request.

    Calling this function returns a generator. Each square is computed only
    when the caller asks for the next item (via ``next()`` or a ``for`` loop),
    so no list holding all results is ever built.
    """
    for i in range(length):
        yield i ** 2

if __name__ == '__main__':
    length = 10
    sn1 = square_number(length)              # eager: a fully built list
    sn2 = square_number_yield(length)        # lazy: a generator, advanced with next()
    sn3 = list(square_number_yield(length))  # generator drained into a list
    for i in range(length):
        # '\t' (a tab, not the letter 't') separates the three columns
        print(sn1[i], '\t', end='')
        print(next(sn2), '\t', end='')
        print(sn3[i])
``````

``````
[dechin@dechin-manjaro yield]$ python3 test_yield.py
0       0       0
1       1       1
4       4       4
9       9       9
16      16      16
25      25      25
36      36      36
49      49      49
64      64      64
81      81      81
``````

# 进阶测试

``````python
# square_sum.py

import tracemalloc
import time
import numpy as np
tracemalloc.start()  # begin tracing Python memory allocations

start_time = time.time()
ss_list = np.random.random(100000)  # all values held at once; uniform [0, 1) like the yield version
s = 0
for ss in ss_list:
    s += ss ** 2  # accumulate the sum of squares
end_time = time.time()
print('Time cost is: {}s'.format(end_time - start_time))

snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')  # group live allocations by source line

for stat in top_stats[:5]:
    print(stat)  # the five largest allocation sites
``````

``````python
# yield_square_sum.py

import tracemalloc
import time
import numpy as np
tracemalloc.start(1)  # trace allocations, keeping 1 frame per traceback (the default depth)

start_time = time.time()  # wall-clock start of the timed section
def ss_list(length):  # generator: produces `length` uniform [0, 1) floats, one per next() call
    for _ in range(length):  # index unused; only the count matters
        yield np.random.random()  # drawn on demand, never stored in a list

s = 0
ss = ss_list(100000)  # generator object; no values have been produced yet
for _ in range(100000):  # must not exceed the length given to ss_list, or next() raises StopIteration
    s += next(ss) ** 2  # pull one value, square it, accumulate; the value is not kept
end_time = time.time()
print('Time cost is: {}s'.format(end_time - start_time))

snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')  # group live allocations by source line

for stat in top_stats[:5]:
    print(stat)  # the five largest allocation sites
``````

``````
[dechin@dechin-manjaro yield]$ python3 square_sum.py
Time cost is: 0.24723434448242188s
square_sum.py:9: size=781 KiB, count=2, average=391 KiB
square_sum.py:12: size=24 B, count=1, average=24 B
square_sum.py:11: size=24 B, count=1, average=24 B
[dechin@dechin-manjaro yield]$ python3 yield_square_sum.py
Time cost is: 0.23023390769958496s
yield_square_sum.py:9: size=136 B, count=1, average=136 B
yield_square_sum.py:14: size=112 B, count=1, average=112 B
yield_square_sum.py:11: size=79 B, count=2, average=40 B
yield_square_sum.py:10: size=76 B, count=2, average=38 B
yield_square_sum.py:15: size=28 B, count=1, average=28 B
``````

# 无限长迭代器

``````def get_primes(number):
while True:
if is_prime(number):
yield number
number += 1
``````

``````python
# yield_iter.py

def yield_range2(i):
    """Yield i, i + 2, i + 4, ... without end (same values as itertools.count(i, 2)).

    The generator never stops on its own, so consume it with next() or a
    bounded loop such as itertools.islice.
    """
    while True:
        yield i
        i += 2

evens = yield_range2(0)  # named so as not to shadow the built-in iter()
for _ in range(10):  # take only the first 10 items of the infinite generator
    print(next(evens))
``````

``````
[dechin@dechin-manjaro yield]$ python3 yield_iter.py
0
2
4
6
8
10
12
14
16
18
``````