Effective Python摘录。
一. 培养Pythonic思维
1. 查询自己使用的python版本
import sys
print(sys.version_info)
print(sys.version)
# 代码自动检查:https://pylint.org/ pip install pylint
2. 遵循PEP8风格指南
官网:https://pep8.org/
中文翻译:https://www.cnblogs.com/bymo/p/9567140.html
3. 了解bytes与str的区别
bytes:bytes的实例包含原始数据,即8位的无符号值(通常按ASCII编码显示)
str:str的实例包含Unicode码点(code point),这些码点与人类语言中的文本字符相对应
# 接受str或bytes,并总是返回str的方法
def to_str(bytes_or_str):
    """Return the argument as ``str``.

    A ``bytes`` value is decoded as UTF-8; anything else is returned unchanged.
    """
    if not isinstance(bytes_or_str, bytes):
        return bytes_or_str
    return bytes_or_str.decode('utf-8')
# 接受str或bytes,并总是返回bytes的方法:
def to_bytes(bytes_or_str):
    """Return the argument as ``bytes``.

    A ``str`` value is encoded as UTF-8; anything else is returned unchanged.
    """
    if not isinstance(bytes_or_str, str):
        return bytes_or_str
    return bytes_or_str.encode('utf-8')
4. 用支持插值的f-string取代C风格的格式字符串和str.format方法
pantry = [
('avocados', 1.25),
('bananas', 2.5),
('cherries', 15),
]
for i, (item, count) in enumerate(pantry):
print(f'{i+1}: {item.title():<10s} = {round(count)}')
5. 用辅助函数取代复杂的表达式
from urllib.parse import parse_qs
my_values = parse_qs('red=5&blue=0&green=3', keep_blank_values=True)
print(repr(my_values))
green_str = my_values.get('green', [''])
if green_str[0]:
green = int(green_str[0])
else:
green = 0
print(f'Green: {green!r}')
def get_first_int(values, key, default=0):
    """Return the first value stored under ``key`` converted to ``int``.

    Args:
        values: Mapping of key -> sequence of strings (the shape produced by
            ``urllib.parse.parse_qs``).
        key: Key to look up.
        default: Returned when the key is missing, its sequence is empty, or
            its first entry is an empty string.

    Raises:
        ValueError: If the first entry is non-empty but is not a valid integer.
    """
    found = values.get(key)
    # Check the container before indexing: a key that is present with an
    # empty list would otherwise raise IndexError on ``found[0]``.
    if found and found[0]:
        return int(found[0])
    return default
green = get_first_int(my_values, 'green')
print(f'Green: {green!r}')
6. 把数据结构直接拆分到多个变量里,不要专门通过下标访问
snacks = [('bacon', 350), ('donut', 240), ('muffin', 190)]
for i in range(len(snacks)):
item = snacks[i]
name = item[0]
calories = item[1]
print(f'#{i + 1}: {name} has {calories} calories')
for rank, (name, calories) in enumerate(snacks, 1):
print(f'#{rank}: {name} has {calories} calories')
7. 尽量用enumerate取代range
flavor_list = ['vanilla', 'chocolate', 'pecan', 'strawberry']
for flavor in flavor_list:
print(f'{flavor} is delicious')
for i in range(len(flavor_list)):
flavor = flavor_list[i]
print(f'{i + 1}: {flavor}')
# Example
it = enumerate(flavor_list)
print(next(it))
print(next(it))
for i, flavor in enumerate(flavor_list):
print(f'{i}: {flavor}')
for i, flavor in enumerate(flavor_list, 2): # 从2开始
print(f'enumerate {i}: {flavor}')
8. 用zip函数同时遍历两个迭代器
# Example 1
names = ['Cecilia', 'Lise', 'Marie']
counts = [len(n) for n in names]
print(counts, len(counts), len(names))
# Example 2
# Start from None (not a placeholder string) so that "no name found" is
# distinguishable from a real name, matching Examples 3 and 4.
longest_name = None
max_count = 0
# Deliberately index-based: this is the noisy form that the enumerate and
# zip versions in the following examples improve on.
for i in range(len(names)):
    count = counts[i]
    if count > max_count:
        longest_name = names[i]
        max_count = count
print(longest_name)
# Example 3
longest_name = None
max_count = 0
for i, name in enumerate(names):
count = counts[i]
if count > max_count:
longest_name = name
max_count = count
print('--', longest_name)
assert longest_name == 'Cecilia'
# Example 4
longest_name = None
max_count = 0
for name, count in zip(names, counts):
if count > max_count:
longest_name = name
max_count = count
assert longest_name == 'Cecilia'
# Example 5
names.append('Rosalind')
# counts.append(8)
print(f'names: {names}, counts: {counts}')
for name, count in zip(names, counts):
print(name, count)
# Example 6
import itertools
for name, count in itertools.zip_longest(names, counts):
print(f'itertools: {name}: {count}')
9. 不要在for与while循环后面写else块
# Example 1
for i in range(3):
print('Loop', i)
else:
print('Else block!')
# Example
a = 4
b = 9
for i in range(2, min(a, b) + 1):
print('Testing', i)
if a % i == 0 and b % i == 0:
print('Not coprime')
break
else:
print('Coprime')
# 下面是推荐的写法:方法一,只要发现某个条件成立,就立刻返回
def coprime(a, b):
    """Return True when no integer in ``2..min(a, b)`` divides both arguments."""
    candidates = range(2, min(a, b) + 1)
    return not any(a % d == 0 and b % d == 0 for d in candidates)
assert coprime(4, 9)
assert not coprime(3, 6)
# 方法二,用变量记录循环过程中有没有碰到条件成立的情况,最后返回这个变量的值
def coprime_alternate(a, b):
    """Return True when no integer in ``2..min(a, b)`` divides both arguments.

    Tracks the answer in a flag and stops scanning at the first common divisor.
    """
    is_coprime = True
    divisor = 2
    upper = min(a, b)
    while is_coprime and divisor <= upper:
        is_coprime = not (a % divisor == 0 and b % divisor == 0)
        divisor += 1
    return is_coprime
assert coprime_alternate(4, 9)
assert not coprime_alternate(3, 6)
10. 用赋值表达式减少重复代码
# -*- encoding: utf-8 -*-
"""
赋值表达式通过海象操作符(:=)给变量赋值,并且让这个值成为这个表达式的结果
"""
FRUIT_TO_PICK = [
    {'apple': 1, 'banana': 3},
    {'lemon': 2, 'lime': 5},
    {'orange': 3, 'melon': 2},
]


def pick_fruit():
    """Remove and return the first basket in FRUIT_TO_PICK, or ``[]`` if none is left."""
    if not FRUIT_TO_PICK:
        return []
    return FRUIT_TO_PICK.pop(0)


def make_juice(fruit, count):
    """Return a one-element list holding the ``(fruit, count)`` pair."""
    pair = (fruit, count)
    return [pair]


bottles = []
# The walrus operator assigns the picked basket and tests its truthiness in
# one step, so the loop ends as soon as pick_fruit() returns an empty value.
while fresh_fruit := pick_fruit():
    for fruit, count in fresh_fruit.items():
        batch = make_juice(fruit, count)
        bottles.extend(batch)
print(bottles)
二. 列表与字典
11. 学会对序列做切片
a = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
print('Middle two: ', a[3:5])
print('--:', a[-3:-1])
assert a[:5] == a[0:5]
assert a[5:] == a[5:len(a)]
print(a[2:-1])
print(a[-3:-1])
12. 不要在切片里同时指定起止下标与步进
x = ['red', 'orange', 'yellow', 'green', 'blue', 'purple']
print(f'odds: {x[::2]}') # odds: ['red', 'yellow', 'blue']
print(f'evens: {x[1::2]}') # evens: ['orange', 'green', 'purple']
print(b'mongoose'[::-1]) # b'esoognom'
13. 通过带星号的unpacking操作来捕获多个元素,不要用切片
car_inventory = {
'Downtown': ('Silver Shadow', 'Pinto', 'DMC'),
'Airport': ('Skyline', 'Viper', 'Gremlin', 'Nova'),
}
((loc1, (best1, *rest1)), (loc2, (best2, *rest2))) = car_inventory.items()
print(f'Best at {loc1} is {best1}, {len(rest1)} others') # Best at Downtown is Silver Shadow, 2 others
print(f'Best at {loc2} is {best2}, {len(rest2)} others') # Best at Airport is Skyline, 3 others
short_list = [1, 2]
# The starred target receives whatever is left over; here nothing is left.
first, second, *rest = short_list
print(first, second, rest)  # prints: 1 2 []
def generate_csv():
    """Yield a header tuple, then three repetitions of two sample data rows."""
    yield ('Date', 'Make', 'Model', 'Year', 'Price')
    sample_rows = (
        ('2019-03-25', 'Honda', 'Fit', '2010', '$3400'),
        ('2019-03-26', 'Ford', 'F150', '2008', '$2400'),
    )
    for _ in range(3):
        yield from sample_rows
all_csv_rows = list(generate_csv())
print(all_csv_rows)
print('CSV Header:', all_csv_rows[0]) # CSV Header: ('Date', 'Make', 'Model', 'Year', 'Price')
print('Row count: ', len(all_csv_rows[1:])) # Row count: 6
# Example 12
it = generate_csv()
header, *rows = it
print('CSV Header:', header) # CSV Header: ('Date', 'Make', 'Model', 'Year', 'Price')
print('Row count: ', len(rows)) # Row count: 6
14. 用sort方法的key参数来表示复杂的排序逻辑
numbers = [93, 86, 11, 68, 70]
numbers.sort(reverse=True) # 由大到小排列
print(numbers)
# Example 2
class Tool:
    """A tool with a name and a weight; ``repr`` shows ``Tool(name, weight)``."""

    def __init__(self, name, weight):
        self.name, self.weight = name, weight

    def __repr__(self):
        return 'Tool({!r}, {})'.format(self.name, self.weight)
tools = [
Tool('level', 3.5),
Tool('hammer', 1.25),
Tool('screwdriver', 0.5),
Tool('chisel', 0.25),
]
# Example 4
print('Unsorted:', repr(tools))
tools.sort(key=lambda x: x.name)
print('Sorted: ', tools) # [Tool('chisel', 0.25), Tool('hammer', 1.25), Tool('level', 3.5), Tool('screwdriver', 0.5)]
# Example 5
tools.sort(key=lambda x: x.weight)
print('By weight:', tools) # [Tool('chisel', 0.25), Tool('screwdriver', 0.5), Tool('hammer', 1.25), Tool('level', 3.5)]
# Example 6
places = ['home', 'work', 'New York', 'Paris']
places.sort()
print('Case sensitive: ', places) # ['New York', 'Paris', 'home', 'work']
places.sort(key=lambda x: x.lower())
print('Case insensitive:', places) # ['home', 'New York', 'Paris', 'work']
15. 不要过分依赖给字典添加条目时所用的顺序
baby_names = {
'cat': 'kitten',
'dog': 'puppy',
}
print(baby_names)
print(list(baby_names.keys()))
print(list(baby_names.values()))
print(list(baby_names.items())) # [('cat', 'kitten'), ('dog', 'puppy')]
print(baby_names.popitem()) # Last item inserted : ('dog', 'puppy')
class MyClass:
def __init__(self):
self.alligator = 'hatchling'
self.elephant = 'calf'
a = MyClass()
for key, value in a.__dict__.items():
print(f'{key} = {value}')
# Example 9
votes = {
'otter': 1281,
'polar bear': 587,
'fox': 863,
}
def populate_ranks(votes, ranks):
    """Fill ``ranks`` in place with 1-based positions, highest vote count first.

    Args:
        votes: Mapping of name -> vote count.
        ranks: Mutable mapping that receives name -> rank.
    """
    ordered = list(votes.keys())
    ordered.sort(key=votes.get, reverse=True)
    for position, name in enumerate(ordered, start=1):
        ranks[name] = position
ranks = {}
populate_ranks(votes, ranks)
print(ranks) # {'otter': 1, 'fox': 2, 'polar bear': 3}
print(next(iter(ranks))) # otter
from collections.abc import MutableMapping
class SortedDict(MutableMapping):
    """Mapping backed by a plain dict whose iteration yields keys in sorted order.

    It is dict-like but not a ``dict`` subclass, so code that relies on
    insertion-ordered iteration will see a different order.
    """

    def __init__(self):
        self.data = {}

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value

    def __delitem__(self, key):
        del self.data[key]

    def __iter__(self):
        # Sort a snapshot of the keys, then hand them out one by one.
        yield from sorted(self.data)

    def __len__(self):
        return len(self.data)
my_dict = SortedDict()
my_dict['otter'] = 1
my_dict['cheeta'] = 2
my_dict['anteater'] = 3
my_dict['deer'] = 4
assert my_dict['otter'] == 1
assert 'cheeta' in my_dict
del my_dict['cheeta']
assert 'cheeta' not in my_dict
print(my_dict)
expected = [('anteater', 3), ('deer', 4), ('otter', 1)]
assert list(my_dict.items()) == expected
assert not isinstance(my_dict, dict)
# Example 14
sorted_ranks = SortedDict()
populate_ranks(votes, sorted_ranks)
print(sorted_ranks.data) # {'otter': 1, 'fox': 2, 'polar bear': 3}
print(next(iter(sorted_ranks))) # fox
16. 用get处理键不在字典中的情况,不要使用in与KeyError
有4种方法处理键不在字典中的情况:in表达式,KeyError异常,get方法和setdefault方法。
counters = {
'pumpernickel': 2,
'sourdough': 1,
}
key = 'multigrain'
count = counters.get(key, 0)
counters[key] = count + 1
print(counters) # {'pumpernickel': 2, 'sourdough': 1, 'multigrain': 1}
votes = {
'baguette': ['Bob', 'Alice'],
'ciabatta': ['Coco', 'Deb'],
}
key = 'brioche'
who = 'Elmer'
print(f'votes pre: {votes}') # {'baguette': ['Bob', 'Alice'], 'ciabatta': ['Coco', 'Deb']}
if key in votes:
names = votes[key]
else:
votes[key] = names = []
print(votes) # {'baguette': ['Bob', 'Alice'], 'ciabatta': ['Coco', 'Deb'], 'brioche': []}
names.append(who)
print(f'votes: {votes}') # {'baguette': ['Bob', 'Alice'], 'ciabatta': ['Coco', 'Deb'], 'brioche': ['Elmer']}
key = 'cornbread'
who = 'Kirk'
names = votes.setdefault(key, [])
names.append(who)
print(votes) # {'baguette': ['Bob', 'Alice'], 'ciabatta': ['Coco', 'Deb'], 'brioche': ['Elmer'], 'cornbread': ['Kirk']}
data = {}
key = 'foo'
value = []
data.setdefault(key, value)
print('Before:', data) # Before: {'foo': []}
value.append('hello')
print('After: ', data) # After: {'foo': ['hello']}
17. 用defaultdict处理内部状态中缺失的元素,而不要用setdefault
class Visits:
    """Records visited cities per country in a plain dict of sets."""

    def __init__(self):
        self.data = {}

    def add(self, country, city):
        # setdefault builds a new set() on every call, even when the
        # country already has one.
        self.data.setdefault(country, set()).add(city)
visits = Visits()
visits.add('Russia', 'Yekaterinburg')
visits.add('Tanzania', 'Zanzibar')
print(visits.data) # {'Russia': {'Yekaterinburg'}, 'Tanzania': {'Zanzibar'}}
print('-----------------------------')
from collections import defaultdict # 推荐
class Visits:
    """Records visited cities per country in a ``defaultdict(set)``."""

    def __init__(self):
        self.data = defaultdict(set)

    def add(self, country, city):
        cities = self.data[country]  # a missing country gets an empty set
        cities.add(city)
visits = Visits()
visits.add('England', 'Bath')
visits.add('England', 'London')
# Both cities land in the single 'England' set (set display order may vary).
print(visits.data)  # defaultdict(<class 'set'>, {'England': {'Bath', 'London'}})
18. 学会利用__missing__构造依赖键的默认值
path = 'account_9090.csv'
with open(path, 'wb') as f:
    f.write(b'image data here 9090')


def open_picture(profile_path):
    """Open ``profile_path`` in binary append+read mode and return the handle.

    Raises:
        OSError: Re-raised after printing the path that failed to open.
    """
    try:
        return open(profile_path, 'a+b')
    except OSError:
        print(f'Failed to open path {profile_path}')
        raise


class Pictures(dict):
    """Dict of path -> open file handle that opens missing paths on demand."""

    def __missing__(self, key):
        # dict.__getitem__ calls this only when ``key`` is absent, which lets
        # the default value depend on the key itself.
        value = open_picture(key)
        self[key] = value
        return value


pictures = Pictures()
handle = pictures[path]
try:
    handle.seek(0)
    image_data = handle.read()
    print(pictures)
    print(image_data)
finally:
    # The cached value is a real OS file handle; close it rather than
    # leaving it open for the rest of the process.
    handle.close()
19. 不要把函数返回的多个数值拆分到三个以上的变量中
def get_stats(numbers):
    """Return the ``(minimum, maximum)`` pair of ``numbers``."""
    return min(numbers), max(numbers)
lengths = [63, 73, 72, 60, 67, 66, 71, 61, 72, 70]
minimum, maximum = get_stats(lengths) # Two return values
print(f'Min: {minimum}, Max: {maximum}')
三. 函数
20. 遇到意外状况时应该抛出异常,不要返回None
def careful_divide(a, b):
    """Return ``a / b``, or ``None`` when ``b`` is zero.

    This is the error-prone variant: a caller that tests the result for
    truthiness cannot tell a legitimate ``0`` apart from the ``None`` that
    signals a failed division.
    """
    try:
        return a / b
    except ZeroDivisionError:
        return None


assert careful_divide(4, 2) == 2
assert careful_divide(0, 1) == 0
assert careful_divide(3, 6) == 0.5
# None is a singleton: compare by identity, not equality.
assert careful_divide(1, 0) is None
def careful_divide(a: float, b: float) -> float:
    """Divides a by b.

    Raises:
        ValueError: When the inputs cannot be divided. The original
            ZeroDivisionError is attached as ``__cause__``.
    """
    try:
        return a / b
    except ZeroDivisionError as e:
        print(f'result: {e}')
        # Chain explicitly so the traceback reads "direct cause" rather than
        # "during handling of the above exception, another occurred".
        raise ValueError(f'Invalid inputs: {e}') from e
try:
result = careful_divide(1, 0)
assert False
except ValueError:
print(f'result:')
pass # Expected
assert careful_divide(1, 5) == 0.2
21. 了解如何在闭包里面使用外围作用域中的变量
def sort_priority3(numbers, group):
    """Sort ``numbers`` in place so that members of ``group`` come first.

    Returns:
        True if at least one number was found in ``group``, else False.
    """
    found = False

    def helper(x):
        # nonlocal lets the closure rebind the enclosing function's variable
        # instead of creating a new local one.
        nonlocal found
        in_group = x in group
        if in_group:
            found = True
        return (0 if in_group else 1, x)

    numbers.sort(key=helper)
    return found
numbers = [8, 3, 1, 2, 5, 4, 7, 6]
group = {2, 3, 5, 7}
found = sort_priority3(numbers, group)
assert found
assert numbers == [2, 3, 5, 7, 1, 4, 6, 8]
print('--------------下面是用辅助类来封装状态')
class Sorter:
    """Callable sort key that ranks members of ``group`` first.

    ``found`` becomes True once the key has been called on a member of ``group``.
    """

    def __init__(self, group):
        self.group = group
        self.found = False

    def __call__(self, x):
        if x not in self.group:
            return (1, x)
        self.found = True
        return (0, x)
sorter = Sorter(group)
numbers.sort(key=sorter)
assert sorter.found is True
assert numbers == [2, 3, 5, 7, 1, 4, 6, 8]
22. 用数量可变的位置参数给函数设计清晰的参数列表
def log(message, *values):
    """Print ``message``, followed by the comma-separated ``values`` if any are given."""
    if values:
        joined = ', '.join(map(str, values))
        print(f'{message}: {joined}')
    else:
        print(message)
log('My numbers are', 1, 2)
log('Hi there') # Much better
23. 用关键字参数来表示可选的行为
def flow_rate(weight_diff, time_diff, period=1):
    """Flow rate: weight per unit of time, scaled by ``period`` (1 by default)."""
    per_unit_time = weight_diff / time_diff
    return per_unit_time * period
weight_diff = 0.5
time_diff = 3
flow = flow_rate(weight_diff, time_diff, period=2)
print(f'{flow:.3} kg per second')
24. 用None和docstring来描述默认值会变的参数
import json
from time import sleep
from datetime import datetime
from typing import Optional
def log(message, when=None):
    """Log a message with a timestamp.

    Args:
        message: Message to print.
        when: datetime of when the message occurred.
            Defaults to the present time, evaluated on each call.
    """
    timestamp = datetime.now() if when is None else when
    print(f'{timestamp}: {message}')
# Example
log('Hi there!')
sleep(0.1)
log('Hello again!')
def decode(data, default=None):
    """Load JSON data from a string.

    Args:
        data: JSON data to decode.
        default: Value to return if decoding fails.
            Defaults to a new empty dictionary on every call.
    """
    try:
        parsed = json.loads(data)
    except ValueError:
        # Build the fallback per call so callers never share one dict.
        return {} if default is None else default
    return parsed
foo = decode('bad data')
foo['stuff'] = 5
bar = decode('also bad')
bar['meep'] = 1
print('Foo:', foo)
print('Bar:', bar)
assert foo is not bar
def log_typed(message: str, when: Optional[datetime] = None) -> None:
"""Log a message with a timestamp.
Args:
message: Message to print.
when: datetime of when the message occurred.
Defaults to the present time.
"""
if when is None:
when = datetime.now()
print(f'{when}: {message}')
log_typed('Hi there!')
sleep(0.1)
log_typed('Hello again!')
25. 用只能以关键字指定和只能按位置传入的参数来设计清晰的参数列表
def safe_division_e(numerator, denominator,
                    ndigits=10, *,
                    ignore_overflow=False,
                    ignore_zero_division=False):
    """Divide and round the quotient to ``ndigits`` decimal places.

    Args:
        numerator: Dividend.
        denominator: Divisor.
        ndigits: Number of decimal places passed to ``round``.
        ignore_overflow: Keyword-only. Return 0 instead of raising
            OverflowError.
        ignore_zero_division: Keyword-only. Return ``float('inf')`` instead
            of raising ZeroDivisionError.

    Raises:
        OverflowError, ZeroDivisionError: Unless the matching flag is set.
    """
    try:
        return round(numerator / denominator, ndigits)
    except OverflowError:
        if not ignore_overflow:
            raise
        return 0
    except ZeroDivisionError:
        if not ignore_zero_division:
            raise
        return float('inf')
result = safe_division_e(22, 7)
print(result) # 3.1428571429
result = safe_division_e(22, 7, 5)
print(result) # 3.14286
result = safe_division_e(22, 7, ndigits=2)
print(result) # 3.14
26. 用functools.wraps定义函数修饰器
import pickle
from functools import wraps
def trace(func):
    """Decorator that prints each call's arguments and result.

    ``functools.wraps`` copies the wrapped function's metadata (name,
    docstring, ...) onto the wrapper.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        outcome = func(*args, **kwargs)
        call_text = f'{func.__name__}({args!r}, {kwargs!r}) '
        print(call_text + f'-> {outcome!r}')
        return outcome
    return wrapper
@trace
def fibonacci(n):
"""Return the n-th Fibonacci number"""
if n in (0, 1):
return n
return fibonacci(n - 2) + fibonacci(n - 1)
help(fibonacci)
print(pickle.dumps(fibonacci))
四. 推导与生成
27. 用列表推导取代map与filter
a = range(1, 10)
even_squares = [x**2 for x in a if x % 2 == 0]
print(even_squares)
alt = map(lambda x: x**2, filter(lambda x: x % 2 == 0, a))
assert even_squares == list(alt)
# 字典推导
even_squares_dict = {x: x**2 for x in a if x % 2 == 0}
threes_cubed_set = {x**3 for x in a if x % 3 == 0}
print(even_squares_dict)
print(threes_cubed_set)
28. 控制推导逻辑的子表达式不要超过两个
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat = [x for row in matrix for x in row]
print(flat)
squared = [[x**2 for x in row] for row in matrix]
print(squared)
29. 用赋值表达式消除推导中的重复代码
stock = {
    'nails': 125,
    'screws': 35,
    'wingnuts': 8,
    'washers': 24,
}
order = ['screws', 'wingnuts', 'clips']


def get_batches(count, size):
    """Return how many whole batches of ``size`` fit into ``count``."""
    return count // size


# The walrus operator computes each batch count once and reuses it both as
# the filter condition and as the yielded value; items with zero batches
# (or that are not in stock) are skipped.
found = ((name, batches) for name in order
         if (batches := get_batches(stock.get(name, 0), 8)))
print(next(found))
print(next(found))
30. 不要让函数直接返回列表,应该让它逐个生成列表里的值
import itertools

address_lines = """Four score and seven years
ago our fathers brought forth on this
continent a new nation, conceived in liberty,
and dedicated to the proposition that all men
are created equal."""
with open('address.txt', 'w') as f:
    f.write(address_lines)


def index_file(handle):
    """Yield the character offset at which each word in ``handle`` starts.

    Offsets are produced lazily, one input line at a time: one at the start
    of every line and one after every space character.

    Args:
        handle: Iterable of text lines (e.g. an open text file).
    """
    offset = 0
    for line in handle:
        if line:
            yield offset
        for letter in line:
            offset += 1
            if letter == ' ':
                yield offset


with open('address.txt', 'r') as f:
    it = index_file(f)
    # islice takes the first ten offsets without reading the whole file.
    results = itertools.islice(it, 0, 10)
    print(list(results))
31. 谨慎地迭代函数所收到的参数
from collections.abc import Iterator
def normalize_defensive(numbers):
    """Return each value of ``numbers`` as a percentage of their sum.

    Args:
        numbers: A container that can be iterated more than once.

    Raises:
        TypeError: If ``numbers`` is an iterator, which would be exhausted
            by ``sum`` before the percentages could be computed.
    """
    if isinstance(numbers, Iterator):
        raise TypeError('Must supply a container')
    total = sum(numbers)
    return [100 * value / total for value in numbers]
visits = [15, 35, 80]
result = normalize_defensive(visits) # No error
print(result, type(result))
it = iter(visits)
try:
normalize_defensive(it)
except TypeError:
pass
else:
assert False
32. 考虑用生成器表达式改写数据量较大的列表推导
import random
with open('my_file.txt', 'w') as f:
    for _ in range(10):
        f.write('a' * random.randint(0, 100))
        f.write('\n')

# List comprehension: every line is read eagerly, so the file can be closed
# as soon as the list has been built.
with open('my_file.txt') as f:
    value = [len(x) for x in f]
print(value)

# Generator expression: lines are read lazily, so the file has to stay open
# for as long as the generators are being consumed.
with open('my_file.txt') as f:
    it = (len(x) for x in f)
    print(it)
    print(next(it))
    print(next(it))
    # Generators compose: roots pulls from it, which pulls from the file.
    roots = ((x, x**0.5) for x in it)
    print(next(roots))
33. 通过yield from把多个生成器连起来用
import timeit
def child():
for i in range(1_000_000):
yield i
def slow():
for i in child():
yield i
def fast():
yield from child()
baseline = timeit.timeit(
stmt='for _ in slow(): pass',
globals=globals(),
number=50)
print(f'Manual nesting {baseline:.2f}s')
comparison = timeit.timeit(
stmt='for _ in fast(): pass',
globals=globals(),
number=50)
print(f'Composed nesting {comparison:.2f}s')
reduction = -(comparison - baseline) / baseline
print(f'{reduction:.1%} less time')
34. 不要用send给生成器注入数据
import math
def wave_cascading(amplitude_it, steps):
    """Yield ``steps`` sine-wave samples, each scaled by the next amplitude.

    Args:
        amplitude_it: Iterator supplying one amplitude per step; advanced
            with ``next`` once for every sample produced.
        steps: Number of samples spread over one full period (2*pi).
    """
    step_size = 2 * math.pi / steps
    for step in range(steps):
        fraction = math.sin(step * step_size)
        yield next(amplitude_it) * fraction
def complex_wave_cascading(amplitude_it):
yield from wave_cascading(amplitude_it, 3)
yield from wave_cascading(amplitude_it, 4)
yield from wave_cascading(amplitude_it, 5)
def run_cascading():
amplitudes = [7, 7, 7, 2, 2, 2, 2, 10, 10, 10, 10, 10]
it = complex_wave_cascading(iter(amplitudes))
for amplitude in amplitudes:
output = next(it)
if output is None:
print(f'Output is None')
else:
print(f'Output: {output:>5.1f} {amplitude}')
run_cascading()
35. 不要通过throw变换生成器的状态
RESETS = [False, False, True, False, True, False, False, False, False, False, False, False, False]
class Timer:
def __init__(self, period):
self.current = period
self.period = period
def reset(self):
self.current = self.period
def __iter__(self):
while self.current:
self.current -= 1
yield self.current
def run():
timer = Timer(4)
for current in timer:
if RESETS.pop(0):
timer.reset()
print(f'{current} ticks remaining')
run()
36. 考虑用itertools拼装迭代器与生成器
import itertools
print('------- 一. 连接多个迭代器')
# chain: 把多个迭代器从头到尾连成一个迭代器
it = itertools.chain([1, 2, 3], [4, 5, 6])
print(list(it)) # [1, 2, 3, 4, 5, 6]
# repeat: 不停的输出某个值
it = itertools.repeat('hello', 3)
print(list(it)) # ['hello', 'hello', 'hello']
# cycle: 循环地输出源迭代器中的各项元素(到达末尾后从头再来)
it = itertools.cycle([1, 2])
result = [next(it) for _ in range (10)]
print(result) # [1, 2, 1, 2, 1, 2, 1, 2, 1, 2]
# tee: 让一个迭代器分裂成多个平行的迭代器
it1, it2, it3 = itertools.tee(['first', 'second'], 3)
print(list(it1)) # ['first', 'second']
print(list(it2)) # ['first', 'second']
print(list(it3)) # ['first', 'second']
# zip_longest: 和zip函数类似,但会用默认值填充
keys = ['one', 'two', 'three']
values = [1, 2]
normal = list(zip(keys, values))
print('zip: ', normal) # zip: [('one', 1), ('two', 2)]
it = itertools.zip_longest(keys, values, fillvalue='nope')
longest = list(it)
print('zip_longest:', longest) # zip_longest: [('one', 1), ('two', 2), ('three', 'nope')]
print('------- 二. 过滤源迭代器中的元素')
# islice: 按照下标切割源迭代器
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
first_five = itertools.islice(values, 5)
print('First five: ', list(first_five)) # First five: [1, 2, 3, 4, 5]
middle_odds = itertools.islice(values, 2, 8, 2)
print('Middle odds:', list(middle_odds)) # Middle odds: [3, 5, 7]
# takewhile: 一直从源迭代器里获取元素,直到返回false
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
less_than_seven = lambda x: x < 7
it = itertools.takewhile(less_than_seven, values)
print(list(it)) # [1, 2, 3, 4, 5, 6]
# dropwhile: 一直跳过序列中的元素,直到返回true
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
less_than_seven = lambda x: x < 7
it = itertools.dropwhile(less_than_seven, values)
print(list(it)) # [7, 8, 9, 10]
# filterfalse: 与内置的filter函数相反,输出false的那些元素
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
evens = lambda x: x % 2 == 0
filter_result = filter(evens, values)
print('Filter: ', list(filter_result)) # Filter: [2, 4, 6, 8, 10]
filter_false_result = itertools.filterfalse(evens, values)
print('Filter false:', list(filter_false_result)) # Filter false: [1, 3, 5, 7, 9]
print('------- 三. 从源迭代器中的元素合成新的元素')
# accumulate: 从源迭代器中依次取出元素进行累加,和functools模块中的reduce函数类似,区别是它会逐项输出每一步的累计结果,而不是只返回最终值
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
sum_reduce = itertools.accumulate(values)
print('Sum: ', list(sum_reduce)) # Sum: [1, 3, 6, 10, 15, 21, 28, 36, 45, 55]
def sum_modulo_20(first, second):
    """Return ``(first + second)`` reduced modulo 20."""
    return (first + second) % 20
modulo_reduce = itertools.accumulate(values, sum_modulo_20)
print('Modulo:', list(modulo_reduce)) # Modulo: [1, 3, 6, 10, 15, 1, 8, 16, 5, 15]
# product: 从一个或多个源迭代器中获取元素,并计算笛卡尔积
single = itertools.product([1, 2], repeat=2)
print('Single: ', list(single)) # Single: [(1, 1), (1, 2), (2, 1), (2, 2)]
multiple = itertools.product([1, 2], ['a', 'b'])
print('Multiple:', list(multiple)) # Multiple: [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
# permutations: 输出迭代器中n个元素形成的每种 有序排列
it = itertools.permutations([1, 2, 3], 2)
print(f'permutations: {list(it)}') # permutations: [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]
# combinations: 输出迭代器中n个元素形成的每种 无序组合
it = itertools.combinations([1, 2, 3, 4], 2)
print(list(it)) # [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
# combinations_with_replacement:允许同一个元素在组合里多次出现
it = itertools.combinations_with_replacement([1, 2, 3], 2)
print(list(it)) # [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
五. 类与接口
37. 用组合起来的类来实现多层结构,不要用嵌套的内置类型
from collections import defaultdict
from collections import namedtuple
Grade = namedtuple('Grade', ('score', 'weight'))


class Subject:
    """One subject's grades, stored as a list of ``Grade`` tuples."""

    def __init__(self):
        self._grades = []

    def report_grade(self, score, weight):
        self._grades.append(Grade(score, weight))

    def average_grade(self):
        """Return the weighted average of every reported grade."""
        total = total_weight = 0
        for score, weight in self._grades:
            total += score * weight
            total_weight += weight
        return total / total_weight


class Student:
    """One student's subjects, created on first access by name."""

    def __init__(self):
        self._subjects = defaultdict(Subject)

    def get_subject(self, name):
        return self._subjects[name]

    def average_grade(self):
        """Return the unweighted mean of the per-subject averages."""
        subjects = self._subjects.values()
        total = 0
        for subject in subjects:
            total += subject.average_grade()
        return total / len(subjects)


class Gradebook:
    """All students, created on first access by name."""

    def __init__(self):
        self._students = defaultdict(Student)

    def get_student(self, name):
        return self._students[name]
book = Gradebook()
albert = book.get_student('Albert Einstein')
# Named math_subject rather than math so that the stdlib math module
# imported earlier in this file is not rebound to a Subject instance.
math_subject = albert.get_subject('Math')
math_subject.report_grade(75, 0.05)
math_subject.report_grade(65, 0.15)
math_subject.report_grade(70, 0.80)
gym = albert.get_subject('Gym')
gym.report_grade(100, 0.40)
gym.report_grade(85, 0.60)
print(albert.average_grade())
38. 让简单的接口接受函数,而不是类的实例
# Example 1
names = ['Socrates', 'Archimedes', 'Plato', 'Aristotle']
names.sort(key=len)
print(names) # ['Plato', 'Socrates', 'Aristotle', 'Archimedes']
# Example 2
def log_missing():
print('Key added')
return 0
# Example 3
from collections import defaultdict
current = {'green': 12, 'blue': 3}
increments = [
('red', 5),
('blue', 17),
('orange', 9),
]
result = defaultdict(log_missing, current)
print('Before:', dict(result)) # Before: {'green': 12, 'blue': 3}
for key, amount in increments:
result[key] += amount
print('After: ', dict(result)) # After: {'green': 12, 'blue': 20, 'red': 5, 'orange': 9}
# Example 4
def increment_with_report(current, increments):
    """Apply ``increments`` to a copy of ``current`` and count newly added keys.

    Args:
        current: Mapping of key -> starting amount.
        increments: Iterable of ``(key, amount)`` pairs to add.

    Returns:
        A ``(result, added_count)`` tuple: the resulting defaultdict and the
        number of keys that were missing from ``current``.
    """
    new_keys = []

    def missing():
        # Stateful closure: defaultdict calls this once per missing key.
        new_keys.append(None)
        return 0

    totals = defaultdict(missing, current)
    for key, amount in increments:
        totals[key] += amount
    return totals, len(new_keys)
# Example 5
result, count = increment_with_report(current, increments)
assert count == 2
print(result) # defaultdict(<function increment_with_report.<locals>.missing at 0x7faa53a46ef0>, {'green': 12, 'blue': 20, 'red': 5, 'orange': 9})
# Example 6
class CountMissing:
def __init__(self):
self.added = 0
def missing(self):
self.added += 1
return 0
# Example 7
counter = CountMissing()
result = defaultdict(counter.missing, current) # Method ref
for key, amount in increments:
result[key] += amount
assert counter.added == 2
print(result) # defaultdict(<bound method CountMissing.missing of <__main__.CountMissing object at 0x7faa53ac7810>>, {'green': 12, 'blue': 20, 'red': 5, 'orange': 9})
# Example 8
class BetterCountMissing:
def __init__(self):
self.added = 0
def __call__(self):
self.added += 1
return 0
counter = BetterCountMissing()
assert counter() == 0
assert callable(counter)
# Example 9
counter = BetterCountMissing()
result = defaultdict(counter, current) # Relies on __call__
for key, amount in increments:
result[key] += amount
assert counter.added == 2
print(result) # defaultdict(<__main__.BetterCountMissing object at 0x7faa53ac7990>, {'green': 12, 'blue': 20, 'red': 5, 'orange': 9})
39. 通过@classmethod多态来构造同一体系中的各类对象
# Example 1
class InputData:
def read(self):
raise NotImplementedError
# Example 2
class PathInputData(InputData):
def __init__(self, path):
super().__init__()
self.path = path
def read(self):
with open(self.path) as f:
return f.read()
# Example 3
class Worker:
def __init__(self, input_data):
self.input_data = input_data
self.result = None
def map(self):
raise NotImplementedError
def reduce(self, other):
raise NotImplementedError
# Example 4
class LineCountWorker(Worker):
def map(self):
data = self.input_data.read()
self.result = data.count('\n')
def reduce(self, other):
self.result += other.result
# Example 5
import os
def generate_inputs(data_dir):
for name in os.listdir(data_dir):
yield PathInputData(os.path.join(data_dir, name))
# Example 6
def create_workers(input_list):
workers = []
for input_data in input_list:
workers.append(LineCountWorker(input_data))
return workers
# Example 7
from threading import Thread
def execute(workers):
    """Run every worker's ``map`` in its own thread, then reduce into the first.

    Args:
        workers: Non-empty sequence of objects exposing ``map()``,
            ``reduce(other)`` and ``result``.

    Returns:
        The first worker's ``result`` after all others were reduced into it.
    """
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    head, *others = workers
    for other in others:
        head.reduce(other)
    return head.result
# Example 8
def mapreduce(data_dir):
inputs = generate_inputs(data_dir)
workers = create_workers(inputs)
return execute(workers)
# Example 9
import os
import random
def write_test_files(tmpdir):
    """Create 100 files named ``0``..``99`` in ``tmpdir``.

    Each file holds between 0 and 100 newline characters (inclusive).

    Args:
        tmpdir: Directory to fill. It is created if missing; existing files
            with the same names are overwritten.
    """
    # exist_ok makes the script re-runnable: without it the second run
    # raises FileExistsError because the directory is left behind.
    os.makedirs(tmpdir, exist_ok=True)
    for i in range(100):
        with open(os.path.join(tmpdir, str(i)), 'w') as f:
            f.write('\n' * random.randint(0, 100))
tmpdir = 'test_inputs'
write_test_files(tmpdir)
result = mapreduce(tmpdir)
print(f'There are {result} lines') # There are 4762 lines
# Example 10
class GenericInputData:
def read(self):
raise NotImplementedError
@classmethod
def generate_inputs(cls, config):
raise NotImplementedError
# Example 11
class PathInputData(GenericInputData):
def __init__(self, path):
super().__init__()
self.path = path
def read(self):
with open(self.path) as f:
return f.read()
@classmethod
def generate_inputs(cls, config):
data_dir = config['data_dir']
for name in os.listdir(data_dir):
yield cls(os.path.join(data_dir, name))
# Example 12
class GenericWorker:
    """Base class for map/reduce workers; subclasses implement ``map`` and ``reduce``."""

    def __init__(self, input_data):
        self.input_data = input_data
        self.result = None

    def map(self):
        raise NotImplementedError

    def reduce(self, other):
        raise NotImplementedError

    @classmethod
    def create_workers(cls, input_class, config):
        """Return one ``cls`` instance per input from ``input_class.generate_inputs(config)``."""
        # Constructing through cls means a subclass gets workers of its own type.
        return [cls(input_data)
                for input_data in input_class.generate_inputs(config)]
# Example 13
class LineCountWorker(GenericWorker):
def map(self):
data = self.input_data.read()
self.result = data.count('\n')
def reduce(self, other):
self.result += other.result
# Example 14
def mapreduce(worker_class, input_class, config):
workers = worker_class.create_workers(input_class, config)
return execute(workers)
# Example 15
config = {'data_dir': tmpdir}
result = mapreduce(LineCountWorker, PathInputData, config)
print(f'There are {result} lines') # There are 4762 lines
40. 通过super初始化超类
# Example 1
class MyBaseClass:
def __init__(self, value):
self.value = value
class MyChildClass(MyBaseClass):
def __init__(self):
MyBaseClass.__init__(self, 5)
def times_two(self):
return self.value * 2
foo = MyChildClass()
assert foo.times_two() == 10
# Example 2
class TimesTwo:
def __init__(self):
self.value *= 2
class PlusFive:
def __init__(self):
self.value += 5
# Example 3
class OneWay(MyBaseClass, TimesTwo, PlusFive):
def __init__(self, value):
MyBaseClass.__init__(self, value)
TimesTwo.__init__(self)
PlusFive.__init__(self)
# Example 4
foo = OneWay(5)
print('First ordering value is (5 * 2) + 5 =', foo.value) # First ordering value is (5 * 2) + 5 = 15
# Example 5
class AnotherWay(MyBaseClass, PlusFive, TimesTwo):
def __init__(self, value):
MyBaseClass.__init__(self, value)
TimesTwo.__init__(self)
PlusFive.__init__(self)
# Example 6
bar = AnotherWay(5)
print('Second ordering value is', bar.value) # Second ordering value is 15
# Example 7
class TimesSeven(MyBaseClass):
def __init__(self, value):
MyBaseClass.__init__(self, value)
self.value *= 7
class PlusNine(MyBaseClass):
def __init__(self, value):
MyBaseClass.__init__(self, value)
self.value += 9
# Example 8
class ThisWay(TimesSeven, PlusNine):
def __init__(self, value):
TimesSeven.__init__(self, value)
PlusNine.__init__(self, value)
foo = ThisWay(5)
print('Should be (5 * 7) + 9 = 44 but is', foo.value) # Should be (5 * 7) + 9 = 44 but is 14
# Example 9
class MyBaseClass:
def __init__(self, value):
self.value = value
class TimesSevenCorrect(MyBaseClass):
def __init__(self, value):
super().__init__(value)
self.value *= 7
class PlusNineCorrect(MyBaseClass):
def __init__(self, value):
super().__init__(value)
self.value += 9
# Example 10
class GoodWay(TimesSevenCorrect, PlusNineCorrect):
def __init__(self, value):
super().__init__(value)
foo = GoodWay(5)
print('Should be 7 * (5 + 9) = 98 and is', foo.value) # Should be 7 * (5 + 9) = 98 and is 98
# Example 11
mro_str = '\n'.join(repr(cls) for cls in GoodWay.mro())
print(f'mro_str: {mro_str}') # <class '__main__.GoodWay'>
# Example 12
class ExplicitTrisect(MyBaseClass):
def __init__(self, value):
super(ExplicitTrisect, self).__init__(value)
self.value /= 3
assert ExplicitTrisect(9).value == 3
# Example 13
class AutomaticTrisect(MyBaseClass):
def __init__(self, value):
super(__class__, self).__init__(value)
self.value /= 3
class ImplicitTrisect(MyBaseClass):
def __init__(self, value):
super().__init__(value)
self.value /= 3
assert ExplicitTrisect(9).value == 3
assert AutomaticTrisect(9).value == 3
assert ImplicitTrisect(9).value == 3
41. 考虑用mix-in类来表示可组合的功能
# Example 1
class ToDictMixin:
    """Mix-in adding ``to_dict``, a recursive dump of the instance's ``__dict__``."""

    def to_dict(self):
        return self._traverse_dict(self.__dict__)

    def _traverse_dict(self, instance_dict):
        """Convert every value of ``instance_dict``, keeping its keys."""
        return {key: self._traverse(key, value)
                for key, value in instance_dict.items()}

    def _traverse(self, key, value):
        """Convert one value; subclasses may override to special-case a ``key``."""
        if isinstance(value, ToDictMixin):
            return value.to_dict()
        if isinstance(value, dict):
            return self._traverse_dict(value)
        if isinstance(value, list):
            return [self._traverse(key, item) for item in value]
        if hasattr(value, '__dict__'):
            return self._traverse_dict(value.__dict__)
        return value
# Example 3
class BinaryTree(ToDictMixin):
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right
# Example 4
tree = BinaryTree(10, left=BinaryTree(7, right=BinaryTree(9)), right=BinaryTree(13, left=BinaryTree(11)))
print(tree.to_dict()) # {'value': 10, 'left': {'value': 7, 'left': None, 'right': {'value': 9, 'left': None, 'right': None}}, 'right': {'value': 13, 'left': {'value': 11, 'left': None, 'right': None}, 'right': None}}
# Example 5
class BinaryTreeWithParent(BinaryTree):
    """Binary tree node that also keeps a reference to its parent node."""

    def __init__(self, value, left=None, right=None, parent=None):
        super().__init__(value, left=left, right=right)
        self.parent = parent

    # Example 6
    def _traverse(self, key, value):
        """Emit only the parent's value for the ``parent`` attribute.

        Following the parent link like any other attribute would walk back
        into the node being converted and never terminate.
        """
        if isinstance(value, BinaryTreeWithParent) and key == 'parent':
            return value.value  # Prevent cycles
        return super()._traverse(key, value)
# Example 7
root = BinaryTreeWithParent(10)
root.left = BinaryTreeWithParent(7, parent=root)
root.left.right = BinaryTreeWithParent(9, parent=root.left)
print(root.to_dict()) # {'value': 10, 'left': {'value': 7, 'left': None, 'right': {'value': 9, 'left': None, 'right': None, 'parent': 7}, 'parent': 10}, 'right': None, 'parent': None}
# Example 8
class NamedSubTree(ToDictMixin):
    """Pair a name with a tree node; serialisable through ``ToDictMixin``."""

    def __init__(self, name, tree_with_parent):
        self.name = name
        self.tree_with_parent = tree_with_parent
my_tree = NamedSubTree('foobar', root.left.right)
print(f'my_tree.to_dict(): {my_tree.to_dict()}') # my_tree.to_dict(): {'name': 'foobar', 'tree_with_parent': {'value': 9, 'left': None, 'right': None, 'parent': 7}}
# Example 9
import json
class JsonMixin:
    """Mix-in adding JSON (de)serialisation.

    The host class must accept the decoded JSON object's keys as keyword
    arguments and must provide a ``to_dict()`` method.
    """

    @classmethod
    def from_json(cls, data):
        """Build an instance from a JSON object string."""
        kwargs = json.loads(data)
        return cls(**kwargs)

    def to_json(self):
        """Return ``self.to_dict()`` encoded as a JSON string."""
        return json.dumps(self.to_dict())
# Example 10
class DatacenterRack(ToDictMixin, JsonMixin):
    """A rack made of one switch and a list of machines."""

    def __init__(self, switch=None, machines=None):
        """Build the rack from plain keyword-argument dicts.

        Args:
            switch: mapping of ``Switch`` keyword arguments; ``None`` is
                treated as an empty mapping.
            machines: iterable of mappings of ``Machine`` keyword arguments;
                ``None`` is treated as no machines.
        """
        # The declared defaults are None, so they must not be unpacked or
        # iterated directly.
        self.switch = Switch(**(switch or {}))
        self.machines = [
            Machine(**kwargs) for kwargs in (machines or [])]
class Switch(ToDictMixin, JsonMixin):
    """Network switch description: port count and speed."""

    def __init__(self, ports=None, speed=None):
        self.ports = ports
        self.speed = speed
class Machine(ToDictMixin, JsonMixin):
    """Machine description: cores, RAM and disk."""

    def __init__(self, cores=None, ram=None, disk=None):
        self.cores = cores
        self.ram = ram
        self.disk = disk
# Example 11
serialized = """{
"switch": {"ports": 5, "speed": 1e9},
"machines": [
{"cores": 8, "ram": 32e9, "disk": 5e12},
{"cores": 4, "ram": 16e9, "disk": 1e12},
{"cores": 2, "ram": 4e9, "disk": 500e9}
]
}"""
deserialized = DatacenterRack.from_json(serialized)
roundtrip = deserialized.to_json()
assert json.loads(serialized) == json.loads(roundtrip)
42. 优先考虑用public属性表示应受保护的数据,不要用private属性表示
class ApiClass:
    """Base class that keeps its state in the protected attribute ``_value``."""

    def __init__(self):
        self._value = 5

    def get(self):
        """Return the current ``_value``."""
        return self._value
class Child(ApiClass):
    """Subclass whose own ``_value`` overwrites the one set by the parent."""

    def __init__(self):
        super().__init__()
        self._value = 'hello'  # Conflicts
a = Child()
print(f'{a.get()} and {a._value} should be different') # hello and hello should be different
# Example 15
class ApiClass:
    """Base class that keeps its state in the name-mangled ``__value``."""

    def __init__(self):
        self.__value = 5  # Double underscore

    def get(self):
        """Return the private value."""
        return self.__value  # Double underscore
class Child(ApiClass):
    """Subclass whose ``_value`` does not clash with the parent's ``__value``."""

    def __init__(self):
        super().__init__()
        self._value = 'hello'  # OK!
a = Child()
print(f'{a.get()} and {a._value} are different') # 5 and hello are different
43. 自定义的容器类型应该从collections.abc继承
from collections.abc import Sequence
class BinaryNode:
    """Plain binary tree node holding a value and optional children."""

    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right
class IndexableNode(BinaryNode):
    """Binary node that supports ``node[i]`` over its subtree."""

    def _traverse(self):
        """Yield the nodes of this subtree: left subtree, this node, right subtree."""
        if self.left is not None:
            yield from self.left._traverse()
        yield self
        if self.right is not None:
            yield from self.right._traverse()

    def __getitem__(self, index):
        """Return the value of the *index*-th node in traversal order.

        Raises:
            IndexError: if no node sits at position *index*.
        """
        for i, item in enumerate(self._traverse()):
            if i == index:
                return item.value
        raise IndexError(f'Index {index} is out of range')
class SequenceNode(IndexableNode):
    """Indexable node that also reports the size of its subtree."""

    def __len__(self):
        """Return the number of nodes yielded by ``_traverse()``."""
        return sum(1 for _ in self._traverse())
class BetterNode(SequenceNode, Sequence):
    """Node that inherits the ``Sequence`` mix-in methods such as ``index`` and ``count``."""
tree = BetterNode(
10,
left=BetterNode(
5,
left=BetterNode(2),
right=BetterNode(
6,
right=BetterNode(7))),
right=BetterNode(
15,
left=BetterNode(11))
)
print('Index of 7 is', tree.index(7)) # Index of 7 is 3
print('Count of 10 is', tree.count(10)) # Count of 10 is 1
六. 元类与属性
44. 用纯属性与修饰器取代旧式的setter与getter方法
class Resistor:
    """Resistor exposing ``ohms``, ``voltage`` and ``current`` as plain attributes."""

    def __init__(self, ohms):
        self.ohms = ohms
        self.voltage = 0
        self.current = 0
class MysteriousResistor(Resistor):
    """Resistor whose ``ohms`` getter has a side effect (shown as a bad example)."""

    @property
    def ohms(self):
        # Reading ``ohms`` silently rewrites ``voltage``; this surprising
        # side effect is the point of the example.
        self.voltage = self._ohms * self.current
        return self._ohms

    @ohms.setter
    def ohms(self, ohms):
        self._ohms = ohms
# Example
r7 = MysteriousResistor(10)
r7.current = 0.01
print(f'Before: {r7.voltage:.2f}') # Before: 0.00
print(r7.ohms) # 10
print(f'After: {r7.voltage:.2f}') # After: 0.10
45. 考虑用@property实现新的属性访问逻辑,不要急着重构原有的代码
# Example 1
from datetime import datetime, timedelta
class Bucket:
    """Quota bucket with a period length, the time of the last reset and a quota."""

    def __init__(self, period):
        """Create an empty bucket whose period lasts *period* seconds."""
        self.period_delta = timedelta(seconds=period)
        self.reset_time = datetime.now()
        self.quota = 0

    def __repr__(self):
        return f'Bucket(quota={self.quota})'
bucket = Bucket(60)
print(bucket) # Bucket(quota=0)
# Example 2
def fill(bucket, amount):
    """Add *amount* to the bucket's quota.

    If more than ``bucket.period_delta`` has passed since
    ``bucket.reset_time``, the quota is first cleared and the reset time is
    set to now.
    """
    now = datetime.now()
    if (now - bucket.reset_time) > bucket.period_delta:
        bucket.quota = 0
        bucket.reset_time = now
    bucket.quota += amount
# Example 3
def deduct(bucket, amount):
    """Try to consume *amount* of quota from *bucket*.

    Returns:
        ``True`` if the quota was consumed; ``False`` if the current period
        has expired or the remaining quota is smaller than *amount*.
    """
    now = datetime.now()
    if (now - bucket.reset_time) > bucket.period_delta:
        return False  # Bucket hasn't been filled this period
    if bucket.quota - amount < 0:
        return False  # Bucket was filled, but not enough
    bucket.quota -= amount
    return True  # Bucket had enough, quota consumed
# Example 4
bucket = Bucket(60)
fill(bucket, 100)
print(bucket) # Bucket(quota=100)
# Example 5
if deduct(bucket, 99):
print('Had 99 quota') # Had 99 quota
else:
print('Not enough for 99 quota')
print(bucket) # Bucket(quota=1)
# Example 6
if deduct(bucket, 3):
print('Had 3 quota')
else:
print('Not enough for 3 quota') # Not enough for 3 quota
print(bucket) # Bucket(quota=1)
# Example 7
class NewBucket:
    """Bucket that tracks the period's maximum quota and the amount consumed.

    ``quota`` is kept as a property so that code written against a plain
    ``quota`` attribute (reading it, ``+=``, ``-=``, assigning 0) still works.
    """

    def __init__(self, period):
        """Create an empty bucket whose period lasts *period* seconds."""
        self.period_delta = timedelta(seconds=period)
        self.reset_time = datetime.now()
        self.max_quota = 0
        self.quota_consumed = 0

    def __repr__(self):
        return (f'NewBucket(max_quota={self.max_quota}, '
                f'quota_consumed={self.quota_consumed})')

    # Example 8
    @property
    def quota(self):
        """Remaining quota: ``max_quota - quota_consumed``."""
        return self.max_quota - self.quota_consumed

    # Example 9
    @quota.setter
    def quota(self, amount):
        delta = self.max_quota - amount
        if amount == 0:
            # Quota being reset for a new period
            self.quota_consumed = 0
            self.max_quota = 0
        elif delta < 0:
            # Quota being filled during the period
            self.max_quota = amount + self.quota_consumed
        else:
            # Quota being consumed during the period
            self.quota_consumed = delta
# Example 10
bucket = NewBucket(60)
print('Initial', bucket)  # Initial NewBucket(max_quota=0, quota_consumed=0)
fill(bucket, 100)
print('Filled', bucket)  # Filled NewBucket(max_quota=100, quota_consumed=0)
if deduct(bucket, 99):
    print('Had 99 quota')  # Had 99 quota
else:
    print('Not enough for 99 quota')
print('Now', bucket)  # Now NewBucket(max_quota=100, quota_consumed=99)
if deduct(bucket, 3):
    print('Had 3 quota')
else:
    # Only 1 unit of quota is left, so this is the branch that runs.
    print('Not enough for 3 quota')  # Not enough for 3 quota
print('Still', bucket)  # Still NewBucket(max_quota=100, quota_consumed=99)
# Example 11
bucket = NewBucket(6000)
assert bucket.max_quota == 0
assert bucket.quota_consumed == 0
assert bucket.quota == 0
fill(bucket, 100)
assert bucket.max_quota == 100
assert bucket.quota_consumed == 0
assert bucket.quota == 100
assert deduct(bucket, 10)
assert bucket.max_quota == 100
assert bucket.quota_consumed == 10
assert bucket.quota == 90
assert deduct(bucket, 20)
assert bucket.max_quota == 100
assert bucket.quota_consumed == 30
assert bucket.quota == 70
fill(bucket, 50)
assert bucket.max_quota == 150
assert bucket.quota_consumed == 30
assert bucket.quota == 120
assert deduct(bucket, 40)
assert bucket.max_quota == 150
assert bucket.quota_consumed == 70
assert bucket.quota == 80
assert not deduct(bucket, 81)
assert bucket.max_quota == 150
assert bucket.quota_consumed == 70
assert bucket.quota == 80
bucket.reset_time += bucket.period_delta - timedelta(1)
assert bucket.quota == 80
assert not deduct(bucket, 79)
fill(bucket, 1)
assert bucket.quota == 1
46. 用描述符来改写需要复用的@property方法
# Example 1
class Homework:
    """Homework whose ``grade`` is validated by a property setter."""

    def __init__(self):
        self._grade = 0

    @property
    def grade(self):
        """The stored grade."""
        return self._grade

    @grade.setter
    def grade(self, value):
        """Store *value*; raise ``ValueError`` unless ``0 <= value <= 100``."""
        if not (0 <= value <= 100):
            raise ValueError(
                'Grade must be between 0 and 100')
        self._grade = value
# Example 2
galileo = Homework()
galileo.grade = 95
assert galileo.grade == 95
# Example 3
class Exam:
    """Exam with two validated grades, each written out as its own property."""

    def __init__(self):
        self._writing_grade = 0
        self._math_grade = 0

    @staticmethod
    def _check_grade(value):
        """Raise ``ValueError`` unless ``0 <= value <= 100``."""
        if not (0 <= value <= 100):
            raise ValueError(
                'Grade must be between 0 and 100')

    # Example 4
    @property
    def writing_grade(self):
        return self._writing_grade

    @writing_grade.setter
    def writing_grade(self, value):
        self._check_grade(value)
        self._writing_grade = value

    @property
    def math_grade(self):
        return self._math_grade

    @math_grade.setter
    def math_grade(self, value):
        self._check_grade(value)
        self._math_grade = value
galileo = Exam()
galileo.writing_grade = 85
galileo.math_grade = 99
assert galileo.writing_grade == 85
assert galileo.math_grade == 99
# Example 5
class Grade:
    """Skeleton descriptor: shows the protocol methods, stores nothing."""

    def __get__(self, instance, instance_type):
        pass

    def __set__(self, instance, value):
        pass
class Exam:
    """Exam whose grades are ``Grade`` descriptors declared on the class."""

    # Class attributes
    math_grade = Grade()
    writing_grade = Grade()
    science_grade = Grade()
# Example 6
exam = Exam()
exam.writing_grade = 40
# Example 7
Exam.__dict__['writing_grade'].__set__(exam, 40)
# Example 8
exam.writing_grade
# Example 9
Exam.__dict__['writing_grade'].__get__(exam, Exam)
# Example 10
class Grade:
    """Validating descriptor that keeps one value on the descriptor itself.

    Because the value lives on the descriptor object, every instance of the
    owning class reads and writes the same value.
    """

    def __init__(self):
        self._value = 0

    def __get__(self, instance, instance_type):
        return self._value

    def __set__(self, instance, value):
        """Store *value*; raise ``ValueError`` unless ``0 <= value <= 100``."""
        if not (0 <= value <= 100):
            raise ValueError(
                'Grade must be between 0 and 100')
        self._value = value
# Example 11
class Exam:
    """Exam whose three grades are ``Grade`` descriptors declared on the class."""

    math_grade = Grade()
    writing_grade = Grade()
    science_grade = Grade()
first_exam = Exam()
first_exam.writing_grade = 82
first_exam.science_grade = 99
print('Writing', first_exam.writing_grade) # Writing 82
print('Science', first_exam.science_grade) # Science 99
# Example 12
second_exam = Exam()
second_exam.writing_grade = 75
print(f'Second {second_exam.writing_grade} is right') # Second 75 is right
print(f'First {first_exam.writing_grade} is wrong; should be 82') # First 75 is wrong; should be 82
# Example 13
class Grade:
    """Validating descriptor that stores one value per owning instance.

    Values are kept in a regular dict keyed by the instance, so the dict
    holds a reference to every instance that was ever assigned a grade.
    """

    def __init__(self):
        self._values = {}

    def __get__(self, instance, instance_type):
        """Return the instance's grade (0 if unset), or the descriptor for class access."""
        if instance is None:
            return self
        return self._values.get(instance, 0)

    def __set__(self, instance, value):
        """Store *value* for *instance*; raise ``ValueError`` unless ``0 <= value <= 100``."""
        if not (0 <= value <= 100):
            raise ValueError(
                'Grade must be between 0 and 100')
        self._values[instance] = value
# Example 14
from weakref import WeakKeyDictionary
class Grade:
    """Validating descriptor that stores one value per owning instance.

    Values are kept in a ``WeakKeyDictionary`` keyed by the instance, so the
    descriptor does not keep those instances alive.
    """

    def __init__(self):
        self._values = WeakKeyDictionary()

    def __get__(self, instance, instance_type):
        """Return the instance's grade (0 if unset), or the descriptor for class access."""
        if instance is None:
            return self
        return self._values.get(instance, 0)

    def __set__(self, instance, value):
        """Store *value* for *instance*; raise ``ValueError`` unless ``0 <= value <= 100``."""
        if not (0 <= value <= 100):
            raise ValueError('Grade must be between 0 and 100')
        self._values[instance] = value
# Example 15
class Exam:
    """Exam whose three grades are ``Grade`` descriptors declared on the class."""

    math_grade = Grade()
    writing_grade = Grade()
    science_grade = Grade()
first_exam = Exam()
first_exam.writing_grade = 82
second_exam = Exam()
second_exam.writing_grade = 75
print(f'First {first_exam.writing_grade} is right') # First 82 is right
print(f'Second {second_exam.writing_grade} is right') # Second 75 is right
47. 针对惰性属性使用__getattr__、__getattribute__及 __setattr__
class DictionaryRecord:
    """Expose the keys of a dict as attributes through ``__getattribute__``."""

    def __init__(self, data):
        self._data = data

    def __getattribute__(self, name):
        """Return ``data[name]`` for every attribute access.

        Raises:
            AttributeError: if *name* is not a key of the wrapped dict.
        """
        # Prevent weird interactions with isinstance() used
        # by example code harness.
        if name == '__class__':
            return DictionaryRecord
        print(f'* Called __getattribute__({name!r})')  # * Called __getattribute__('foo')
        # Go through object's implementation to avoid recursing into this method.
        data_dict = super().__getattribute__('_data')
        try:
            return data_dict[name]
        except KeyError:
            # hasattr() and getattr(obj, name, default) only recognise
            # AttributeError as "attribute is missing".
            raise AttributeError(
                f'DictionaryRecord has no field {name!r}') from None
data = DictionaryRecord({'foo': 3})
print('foo: ', data.foo) # foo: 3
48. 用__init_subclass__验证子类写得是否正确
class ValidatePolygon(type):
    """Metaclass that rejects polygon classes declaring fewer than 3 sides."""

    def __new__(meta, name, bases, class_dict):
        """Validate ``sides`` from the class body, then create the class.

        Classes whose body sets a truthy ``is_root`` are not validated.

        Raises:
            ValueError: if the class body's ``sides`` is below 3.
        """
        # Only validate non-root classes
        if not class_dict.get('is_root'):
            if class_dict['sides'] < 3:
                raise ValueError('Polygons need 3+ sides')
        return type.__new__(meta, name, bases, class_dict)
class Polygon(metaclass=ValidatePolygon):
    """Root polygon class; marked ``is_root`` so the metaclass skips validation."""

    is_root = True
    sides = None  # Must be specified by subclasses
class ValidateFilledPolygon(ValidatePolygon):
    """Metaclass that checks the fill colour before the parent metaclass runs."""

    def __new__(meta, name, bases, class_dict):
        """Validate ``color`` from the class body, then defer to the parent.

        Raises:
            ValueError: if the class body's ``color`` is not 'red' or 'green'.
        """
        # Only validate non-root classes
        if not class_dict.get('is_root'):
            if class_dict['color'] not in ('red', 'green'):
                raise ValueError('Fill color must be supported')
        return super().__new__(meta, name, bases, class_dict)
class FilledPolygon(Polygon, metaclass=ValidateFilledPolygon):
    """Root class for filled polygons; marked ``is_root`` to skip validation."""

    is_root = True
    color = None  # Must be specified by subclasses
# Example 9
class GreenPentagon(FilledPolygon):
    """Concrete filled polygon declaring both a colour and a side count."""

    color = 'green'
    sides = 5
greenie = GreenPentagon()
print(greenie, type(greenie)) # <__main__.GreenPentagon object at 0x7fac8bccd8d0> <class '__main__.GreenPentagon'>
print(Polygon, type(Polygon)) # <class '__main__.Polygon'> <class '__main__.ValidatePolygon'>
assert isinstance(greenie, Polygon)
# 下面是基本的菱形继承体系
class Top:
    """Top of the diamond; reports every class that is created beneath it."""

    def __init_subclass__(cls):
        super().__init_subclass__()
        print(f'Top for {cls}')
class Left(Top):
    """Left branch of the diamond."""

    def __init_subclass__(cls):
        super().__init_subclass__()
        print(f'Left for {cls}')
class Right(Top):
    """Right branch of the diamond."""

    def __init_subclass__(cls):
        super().__init_subclass__()
        print(f'Right for {cls}')
class Bottom(Left, Right):
    """Bottom of the diamond; its own hook only runs for subclasses of ``Bottom``."""

    def __init_subclass__(cls):
        super().__init_subclass__()
        print(f'Bottom for {cls}')
"""打印结果如下:
Top for <class '__main__.Left'>
Top for <class '__main__.Right'>
Top for <class '__main__.Bottom'>
Right for <class '__main__.Bottom'>
Left for <class '__main__.Bottom'>"""
49. 用__init_subclass__记录现有的子类
import json
registry = {}
def register_class(target_class):
    """Record *target_class* in the module-level ``registry`` under its ``__name__``."""
    registry[target_class.__name__] = target_class
def deserialize(data):
    """Rebuild an object from a JSON string holding ``class`` and ``args`` keys.

    The class is looked up by name in the module-level ``registry`` and
    called with the decoded ``args`` as positional arguments.
    """
    params = json.loads(data)
    name = params['class']
    target_class = registry[name]
    return target_class(*params['args'])
class BetterSerializable:
    """Remember the constructor arguments so the object can be written as JSON."""

    def __init__(self, *args):
        self.args = args

    def serialize(self):
        """Return a JSON object string with the class name and constructor args."""
        return json.dumps({
            'class': self.__class__.__name__,
            'args': self.args,
        })

    def __repr__(self):
        name = self.__class__.__name__
        args_str = ', '.join(str(x) for x in self.args)
        return f'{name}({args_str})'
class BetterRegisteredSerializable(BetterSerializable):
    """Serializable base that passes every new subclass to ``register_class``."""

    def __init_subclass__(cls):
        super().__init_subclass__()
        register_class(cls)
class Vector1D(BetterRegisteredSerializable):
    """One-dimensional vector; registered through its base class hook."""

    def __init__(self, magnitude):
        super().__init__(magnitude)
        self.magnitude = magnitude
before = Vector1D(6)
print('Before:', before) # Before: Vector1D(6)
data = before.serialize()
print('Serialized:', data) # Serialized: {"class": "Vector1D", "args": [6]}
print('After: ', deserialize(data)) # After: Vector1D(6)
50. 用__set_name__给类属性加注解
class Field:
    """Descriptor that stores its value on the instance under ``'_' + name``."""

    def __init__(self):
        self.name = None
        self.internal_name = None

    def __set_name__(self, owner, name):
        # Called on class creation for each descriptor
        self.name = name
        self.internal_name = '_' + name

    def __get__(self, instance, instance_type):
        """Return the stored value ('' if unset), or the descriptor for class access."""
        if instance is None:
            return self
        return getattr(instance, self.internal_name, '')

    def __set__(self, instance, value):
        setattr(instance, self.internal_name, value)
# Example
class FixedCustomer:
    """Customer record whose columns are ``Field`` descriptors."""

    first_name = Field()
    last_name = Field()
    prefix = Field()
    suffix = Field()
cust = FixedCustomer()
print(f'Before: {cust.first_name!r} {cust.__dict__}') # Before: '' {}
cust.first_name = 'Mersenne'
print(f'After: {cust.first_name!r} {cust.__dict__}') # After: 'Mersenne' {'_first_name': 'Mersenne'}
51. 优先考虑通过类修饰器来提供可组合的扩充功能,不要使用元类
import types
from functools import wraps
trace_types = (
types.MethodType,
types.FunctionType,
types.BuiltinFunctionType,
types.BuiltinMethodType,
types.MethodDescriptorType,
types.ClassMethodDescriptorType)
def trace_func(func):
    """Wrap *func* so that every call prints its arguments and outcome.

    The printed outcome is the return value, or the exception instance when
    the call raises (the exception is re-raised). A callable that already
    carries a ``tracing`` attribute is returned unchanged.
    """
    if hasattr(func, 'tracing'):  # Only decorate once
        return func

    @wraps(func)
    def wrapper(*args, **kwargs):
        result = None
        try:
            result = func(*args, **kwargs)
            return result
        except Exception as e:
            result = e
            raise
        finally:
            print(f'{func.__name__}({args!r}, {kwargs!r}) -> {result!r}')

    wrapper.tracing = True
    return wrapper
# Example
def my_class_decorator(klass):
    """Set ``extra_param = 'hello'`` on *klass* and return the same class."""
    klass.extra_param = 'hello'
    return klass
@my_class_decorator
class MyClass:
    """Empty class used to show the effect of ``my_class_decorator``."""
print(MyClass) # <class '__main__.MyClass'>
print(MyClass.extra_param) # hello
# Example
def trace(klass):
    """Class decorator: wrap every attribute of a ``trace_types`` type with ``trace_func``.

    Returns the same class object after patching it in place.
    """
    for key in dir(klass):
        value = getattr(klass, key)
        if isinstance(value, trace_types):
            wrapped = trace_func(value)
            setattr(klass, key, wrapped)
    return klass
# Example
@trace
class TraceDict(dict):
    """``dict`` subclass whose methods are traced by the ``trace`` decorator."""
trace_dict = TraceDict([('hi', 1)])
trace_dict['there'] = 2
try:
trace_dict['does not exist']
except KeyError:
pass # Expected
else:
assert False
# Example 11
class OtherMeta(type):
    """Empty metaclass, used to show that ``trace`` works on a class that has one."""
@trace
class TraceDict(dict, metaclass=OtherMeta):
    """Traced ``dict`` subclass that also uses the ``OtherMeta`` metaclass."""
trace_dict = TraceDict([('hi', 1)])
trace_dict['there'] = 2
try:
trace_dict['does not exist']
except KeyError:
pass # Expected
else:
assert False
"""
<class '__main__.MyClass'>
hello
__new__((<class '__main__.TraceDict'>, [('hi', 1)]), {}) -> {}
__getitem__(({'hi': 1, 'there': 2}, 'does not exist'), {}) -> KeyError('does not exist')
__new__((<class '__main__.TraceDict'>, [('hi', 1)]), {}) -> {}
__getitem__(({'hi': 1, 'there': 2}, 'does not exist'), {}) -> KeyError('does not exist')
"""
七. 并发与并行
52. 用subprocess管理子进程
# Example 1
import subprocess
# Enable these lines to make this example work on Windows
# import os
# os.environ['COMSPEC'] = 'powershell'
result = subprocess.run(
['echo', 'Hello from the child!'],
capture_output=True,
# Enable this line to make this example work on Windows
# shell=True,
encoding='utf-8')
result.check_returncode() # No exception means it exited cleanly
print(result.stdout)
# Example 2
# Use this line instead to make this example work on Windows
# proc = subprocess.Popen(['sleep', '1'], shell=True)
proc = subprocess.Popen(['sleep', '1'])
while proc.poll() is None:
print('Working...')
# Some time-consuming work here
import time
time.sleep(0.3)
print('Exit status', proc.poll())
# Example 3
import time
start = time.time()
sleep_procs = []
for _ in range(10):
# Use this line instead to make this example work on Windows
# proc = subprocess.Popen(['sleep', '1'], shell=True)
proc = subprocess.Popen(['sleep', '1'])
sleep_procs.append(proc)
# Example 4
for proc in sleep_procs:
proc.communicate()
end = time.time()
delta = end - start
print(f'Finished in {delta:.3} seconds')
# Example 5
import os
# On Windows, after installing OpenSSL, you may need to
# alias it in your PowerShell path with a command like:
# $env:path = $env:path + ";C:\Program Files\OpenSSL-Win64\bin"
def run_encrypt(data):
    """Start an ``openssl enc`` child process and write *data* to its stdin.

    The password is handed to the child through the ``password`` environment
    variable. The running ``Popen`` object is returned; the caller is expected
    to collect the output (e.g. with ``communicate()``).
    """
    env = os.environ.copy()
    # NOTE(review): hard-coded demo password; do not reuse outside this example.
    env['password'] = 'zf7ShyBhZOraQDdE/FiZpm/m/8f9X+M1'
    proc = subprocess.Popen(
        ['openssl', 'enc', '-des3', '-pass', 'env:password'],
        env=env,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE)
    proc.stdin.write(data)
    proc.stdin.flush()  # Ensure that the child gets input
    return proc
# Example 6
procs = []
for _ in range(3):
data = os.urandom(10)
proc = run_encrypt(data)
procs.append(proc)
# Example 7
for proc in procs:
out, _ = proc.communicate()
print(f'out[-10:] {out[-10:]}')
# Example 8
def run_hash(input_stdin):
    """Start an ``openssl dgst`` child that reads from *input_stdin*.

    Returns the running ``Popen`` object with its stdout connected to a pipe.
    """
    return subprocess.Popen(
        ['openssl', 'dgst', '-whirlpool', '-binary'],
        stdin=input_stdin,
        stdout=subprocess.PIPE)
# Example 9
encrypt_procs = []
hash_procs = []
for _ in range(3):
data = os.urandom(100)
encrypt_proc = run_encrypt(data)
encrypt_procs.append(encrypt_proc)
hash_proc = run_hash(encrypt_proc.stdout)
hash_procs.append(hash_proc)
# Ensure that the child consumes the input stream and
# the communicate() method doesn't inadvertently steal
# input from the child. Also lets SIGPIPE propagate to
# the upstream process if the downstream process dies.
encrypt_proc.stdout.close()
encrypt_proc.stdout = None
# Example 10
for proc in encrypt_procs:
proc.communicate()
assert proc.returncode == 0
for proc in hash_procs:
out, _ = proc.communicate()
print(out[-10:])
assert proc.returncode == 0
# Example 11
# Use this line instead to make this example work on Windows
# proc = subprocess.Popen(['sleep', '10'], shell=True)
proc = subprocess.Popen(['sleep', '10'])
try:
proc.communicate(timeout=0.1)
except subprocess.TimeoutExpired:
proc.terminate()
proc.wait()
print('Exit status', proc.poll())
53. 可以用线程执行阻塞式I/O,但不要用它做并行计算
# Example 1
def factorize(number):
    """Yield every divisor of *number* from 1 to *number*, in increasing order."""
    for i in range(1, number + 1):
        if number % i == 0:
            yield i
# Example 2
import time
numbers = [2139079, 1214759, 1516637, 1852285]
start = time.time()
for number in numbers:
list(factorize(number))
end = time.time()
delta = end - start
print(f'Took {delta:.4f} seconds') # Took 0.4613 seconds
# Example 3
from threading import Thread
class FactorizeThread(Thread):
    """Thread that factorizes one number and stores the result in ``factors``."""

    def __init__(self, number):
        super().__init__()
        self.number = number

    def run(self):
        # ``factors`` only exists once the thread has run.
        self.factors = list(factorize(self.number))
# Example 4
start = time.time()
threads = []
for number in numbers:
thread = FactorizeThread(number)
thread.start()
threads.append(thread)
# Example 5
for thread in threads:
thread.join()
end = time.time()
delta = end - start
print(f'Took {delta:.3f} seconds') # Took 0.448 seconds
# Example 6
import select
import socket
def slow_systemcall():
    """Call ``select`` on a fresh socket with a 0.1 second timeout.

    Used in the examples as a stand-in for a blocking system call.
    """
    # Close the throw-away socket afterwards instead of leaking its
    # file descriptor on every call.
    with socket.socket() as sock:
        select.select([sock], [], [], 0.1)
# Example 7
start = time.time()
for _ in range(5):
slow_systemcall()
end = time.time()
delta = end - start
print(f'Took {delta:.3f} seconds...') # Took 0.517 seconds...
# Example 8
start = time.time()
threads = []
for _ in range(5):
thread = Thread(target=slow_systemcall)
thread.start()
threads.append(thread)
# Example 9
def compute_helicopter_location(index):
    """Placeholder for work done on the main thread; does nothing."""
    pass
for i in range(5):
compute_helicopter_location(i)
for thread in threads:
thread.join()
end = time.time()
delta = end - start
print(f'Took {delta:.3f} seconds') # Took 0.102 seconds
54. 利用Lock防止多个线程争用同一份数据
# Example 1
class Counter:
    """Plain counter with no locking around the update."""

    def __init__(self):
        self.count = 0

    def increment(self, offset):
        """Add *offset* to ``count``."""
        self.count += offset
# Example 2
def worker(sensor_index, how_many, counter):
    """Wait on the module-level ``BARRIER``, then increment *counter* *how_many* times."""
    # I have a barrier in here so the workers synchronize
    # when they start counting, otherwise it's hard to get a race
    # because the overhead of starting a thread is high.
    BARRIER.wait()
    for _ in range(how_many):
        # Read from the sensor
        # Nothing actually happens here, but this is where
        # the blocking I/O would go.
        counter.increment(1)
# Example 3
from threading import Barrier
BARRIER = Barrier(5)
from threading import Thread
how_many = 10**5
counter = Counter()
threads = []
for i in range(5):
thread = Thread(target=worker,
args=(i, how_many, counter))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
expected = how_many * 5
found = counter.count
print(f'Counter should be {expected}, got {found}') # Counter should be 500000, got 325923
# Example 4
counter.count += 1
# Example 5
value = getattr(counter, 'count')
result = value + 1
setattr(counter, 'count', result)
# Example 6
# Running in Thread A
value_a = getattr(counter, 'count')
# Context switch to Thread B
value_b = getattr(counter, 'count')
result_b = value_b + 1
setattr(counter, 'count', result_b)
# Context switch back to Thread A
result_a = value_a + 1
setattr(counter, 'count', result_a)
# Example 7
from threading import Lock
class LockingCounter:
    """Counter whose update runs while holding a ``Lock``."""

    def __init__(self):
        self.lock = Lock()
        self.count = 0

    def increment(self, offset):
        """Add *offset* to ``count`` while holding ``self.lock``."""
        with self.lock:
            self.count += offset
# Example 8
BARRIER = Barrier(5)
counter = LockingCounter()
# Start from an empty list so that only this run's threads are tracked
# and joined.
threads = []
for i in range(5):
    thread = Thread(target=worker,
                    args=(i, how_many, counter))
    threads.append(thread)
    thread.start()
for thread in threads:
    thread.join()
expected = how_many * 5
found = counter.count
print(f'Counter should be {expected}, got {found}')  # Counter should be 500000, got 500000
55. 用Queue来协调各线程之间的工作进度
# Example 1
def download(item):
    """Stand-in for the download stage; returns *item* unchanged."""
    return item
def resize(item):
    """Stand-in for the resize stage; returns *item* unchanged."""
    return item
def upload(item):
    """Stand-in for the upload stage; returns *item* unchanged."""
    return item
# Example 2
from collections import deque
from threading import Lock
class MyQueue:
    """FIFO queue built from a ``deque`` guarded by a ``Lock``."""

    def __init__(self):
        self.items = deque()
        self.lock = Lock()

    # Example 3
    def put(self, item):
        """Append *item* to the end of the queue."""
        with self.lock:
            self.items.append(item)

    # Example 4
    def get(self):
        """Remove and return the oldest item.

        Raises:
            IndexError: if the queue is empty.
        """
        with self.lock:
            return self.items.popleft()
# Example 5
from threading import Thread
import time
class Worker(Thread):
    """Thread that polls ``in_queue``, applies ``func`` and feeds ``out_queue``."""

    def __init__(self, func, in_queue, out_queue):
        super().__init__()
        self.func = func
        self.in_queue = in_queue
        self.out_queue = out_queue
        self.polled_count = 0  # number of polling attempts
        self.work_done = 0  # number of items processed

    # Example 6
    def run(self):
        """Poll until ``in_queue`` is replaced by an object without ``get``.

        An empty queue (``IndexError``) makes the thread sleep briefly and
        retry; an ``AttributeError`` from the ``get`` call ends the loop.
        """
        while True:
            self.polled_count += 1
            try:
                item = self.in_queue.get()
            except IndexError:
                time.sleep(0.01)  # No work to do
            except AttributeError:
                # The magic exit signal
                return
            else:
                result = self.func(item)
                self.out_queue.put(result)
                self.work_done += 1
# Example 7
download_queue = MyQueue()
resize_queue = MyQueue()
upload_queue = MyQueue()
done_queue = MyQueue()
threads = [
Worker(download, download_queue, resize_queue),
Worker(resize, resize_queue, upload_queue),
Worker(upload, upload_queue, done_queue),
]
# Example 8
for thread in threads:
thread.start()
for _ in range(1000):
download_queue.put(object())
# Example 9
while len(done_queue.items) < 1000:
# Do something useful while waiting
time.sleep(0.1)
# Stop all the threads by causing an exception in their
# run methods.
for thread in threads:
thread.in_queue = None
thread.join()
# Example 10
processed = len(done_queue.items)
polled = sum(t.polled_count for t in threads)
print(f'Processed {processed} items after '
f'polling {polled} times')
# Example 11
from queue import Queue
my_queue = Queue()
def consumer():
print('Consumer waiting')
my_queue.get() # Runs after put() below
print('Consumer done')
thread = Thread(target=consumer)
thread.start()
# Example 12
print('Producer putting')
my_queue.put(object()) # Runs before get() above
print('Producer done')
thread.join()
# Example 13
my_queue = Queue(1) # Buffer size of 1
def consumer():
time.sleep(0.1) # Wait
my_queue.get() # Runs second
print('Consumer got 1')
my_queue.get() # Runs fourth
print('Consumer got 2')
print('Consumer done')
thread = Thread(target=consumer)
thread.start()
# Example 14
my_queue.put(object()) # Runs first
print('Producer put 1')
my_queue.put(object()) # Runs third
print('Producer put 2')
print('Producer done')
thread.join()
# Example 15
in_queue = Queue()
def consumer():
print('Consumer waiting')
work = in_queue.get() # Done second
print('Consumer working')
# Doing work
print('Consumer done')
in_queue.task_done() # Done third
thread = Thread(target=consumer)
thread.start()
# Example 16
print('Producer putting')
in_queue.put(object()) # Done first
print('Producer waiting')
in_queue.join() # Done fourth
print('Producer done')
thread.join()
# Example 17
class ClosableQueue(Queue):
    """``Queue`` that can be closed and iterated until the close marker."""

    SENTINEL = object()

    def close(self):
        """Enqueue the sentinel that tells a consumer to stop iterating."""
        self.put(self.SENTINEL)

    # Example 18
    def __iter__(self):
        """Yield items until the sentinel arrives.

        ``task_done()`` is called once for every item taken from the queue,
        the sentinel included.
        """
        while True:
            item = self.get()
            try:
                if item is self.SENTINEL:
                    return  # Cause the thread to exit
                yield item
            finally:
                self.task_done()
# Example 19
class StoppableWorker(Thread):
    """Thread that processes items by iterating over ``in_queue``."""

    def __init__(self, func, in_queue, out_queue):
        super().__init__()
        self.func = func
        self.in_queue = in_queue
        self.out_queue = out_queue

    def run(self):
        """Apply ``func`` to every item from ``in_queue`` and put the result on ``out_queue``."""
        for item in self.in_queue:
            result = self.func(item)
            self.out_queue.put(result)
# Example 20
download_queue = ClosableQueue()
resize_queue = ClosableQueue()
upload_queue = ClosableQueue()
done_queue = ClosableQueue()
threads = [
StoppableWorker(download, download_queue, resize_queue),
StoppableWorker(resize, resize_queue, upload_queue),
StoppableWorker(upload, upload_queue, done_queue),
]
# Example 21
for thread in threads:
thread.start()
for _ in range(1000):
download_queue.put(object())
download_queue.close()
# Example 22
download_queue.join()
resize_queue.close()
resize_queue.join()
upload_queue.close()
upload_queue.join()
print(done_queue.qsize(), 'items finished')
for thread in threads:
thread.join()
# Example 23
def start_threads(count, *args):
    """Create *count* ``StoppableWorker(*args)`` threads, start them and return the list."""
    threads = [StoppableWorker(*args) for _ in range(count)]
    for thread in threads:
        thread.start()
    return threads
def stop_threads(closable_queue, threads):
    """Close the queue once per thread, wait for it to drain, then join the threads."""
    # One close() per consumer thread.
    for _ in threads:
        closable_queue.close()
    closable_queue.join()
    for thread in threads:
        thread.join()
# Example 24
download_queue = ClosableQueue()
resize_queue = ClosableQueue()
upload_queue = ClosableQueue()
done_queue = ClosableQueue()
download_threads = start_threads(
3, download, download_queue, resize_queue)
resize_threads = start_threads(
4, resize, resize_queue, upload_queue)
upload_threads = start_threads(
5, upload, upload_queue, done_queue)
for _ in range(1000):
download_queue.put(object())
stop_threads(download_queue, download_threads)
stop_threads(resize_queue, resize_threads)
stop_threads(upload_queue, upload_threads)
print(done_queue.qsize(), 'items finished')
56. 学会判断什么场合必须做并发
# Example 1
ALIVE = '*'
EMPTY = '-'
# Example 2
class Grid:
    """Fixed-size grid of cell states whose coordinates wrap around the edges."""

    def __init__(self, height, width):
        self.height = height
        self.width = width
        # One independent list per row, every cell starting as EMPTY.
        self.rows = [[EMPTY] * self.width for _ in range(self.height)]

    def get(self, y, x):
        """Return the state at row *y*, column *x* (both taken modulo the grid size)."""
        return self.rows[y % self.height][x % self.width]

    def set(self, y, x, state):
        """Store *state* at row *y*, column *x* (both taken modulo the grid size)."""
        self.rows[y % self.height][x % self.width] = state

    def __str__(self):
        """Return the grid as text, one newline-terminated line per row."""
        return ''.join(''.join(row) + '\n' for row in self.rows)
# Example 3
grid = Grid(5, 9)
grid.set(0, 3, ALIVE)
grid.set(1, 4, ALIVE)
grid.set(2, 2, ALIVE)
grid.set(2, 3, ALIVE)
grid.set(2, 4, ALIVE)
print(grid)
"""
---*-----
----*----
--***----
---------
---------
"""
# Example 4
def count_neighbors(y, x, get):
    """Return how many of the eight cells around (*y*, *x*) are ``ALIVE``.

    *get* is called as ``get(row, column)`` once per neighbour, clockwise
    starting from north.
    """
    offsets = (
        (-1, 0),   # North
        (-1, 1),   # Northeast
        (0, 1),    # East
        (1, 1),    # Southeast
        (1, 0),    # South
        (1, -1),   # Southwest
        (0, -1),   # West
        (-1, -1),  # Northwest
    )
    neighbor_states = [get(y + dy, x + dx) for dy, dx in offsets]
    return sum(1 for state in neighbor_states if state == ALIVE)
alive = {(9, 5), (9, 6)}
seen = set()
def fake_get(y, x):
position = (y, x)
seen.add(position)
return ALIVE if position in alive else EMPTY
count = count_neighbors(10, 5, fake_get)
assert count == 2
expected_seen = {
(9, 5), (9, 6), (10, 6), (11, 6),
(11, 5), (11, 4), (10, 4), (9, 4)
}
assert seen == expected_seen
# Example 5
def game_logic(state, neighbors):
    """Return a cell's next state from its current *state* and live-neighbour count."""
    if state == ALIVE:
        if neighbors < 2:
            return EMPTY  # Die: Too few
        elif neighbors > 3:
            return EMPTY  # Die: Too many
    else:
        if neighbors == 3:
            return ALIVE  # Regenerate
    return state
assert game_logic(ALIVE, 0) == EMPTY
assert game_logic(ALIVE, 1) == EMPTY
assert game_logic(ALIVE, 2) == ALIVE
assert game_logic(ALIVE, 3) == ALIVE
assert game_logic(ALIVE, 4) == EMPTY
assert game_logic(EMPTY, 0) == EMPTY
assert game_logic(EMPTY, 1) == EMPTY
assert game_logic(EMPTY, 2) == EMPTY
assert game_logic(EMPTY, 3) == ALIVE
assert game_logic(EMPTY, 4) == EMPTY
# Example 6
def step_cell(y, x, get, set):
    """Compute the next state of cell (*y*, *x*) and store it through *set*.

    *get* reads from the current generation; *set* is called as
    ``set(y, x, next_state)``.
    """
    state = get(y, x)
    neighbors = count_neighbors(y, x, get)
    next_state = game_logic(state, neighbors)
    set(y, x, next_state)
alive = {(10, 5), (9, 5), (9, 6)}
new_state = None
def fake_get(y, x):
return ALIVE if (y, x) in alive else EMPTY
def fake_set(y, x, state):
global new_state
new_state = state
# Stay alive
step_cell(10, 5, fake_get, fake_set)
assert new_state == ALIVE
# Stay dead
alive.remove((10, 5))
step_cell(10, 5, fake_get, fake_set)
assert new_state == EMPTY
# Regenerate
alive.add((10, 6))
step_cell(10, 5, fake_get, fake_set)
assert new_state == ALIVE
# Example 7
def simulate(grid):
    """Return a new ``Grid`` holding the generation that follows *grid*.

    *grid* is only read; every cell's next state is written to the new grid.
    """
    next_grid = Grid(grid.height, grid.width)
    for y in range(grid.height):
        for x in range(grid.width):
            step_cell(y, x, grid.get, next_grid.set)
    return next_grid
# Example 8
class ColumnPrinter:
    """Collect multi-line strings and print them side by side as columns."""

    def __init__(self):
        self.columns = []

    def append(self, data):
        """Add one multi-line string as the next column."""
        self.columns.append(data)

    def __str__(self):
        """Return the columns joined with ``' | '``.

        The first row is a header holding each column's index, padded on
        both sides by half the width of that column's first line.
        """
        # Split every column once instead of once per output cell.
        split_columns = [data.splitlines() for data in self.columns]
        row_count = 1 + max(
            (len(lines) for lines in split_columns), default=0)
        rows = []
        for j in range(row_count):
            cells = []
            for i, lines in enumerate(split_columns):
                line = lines[max(0, j - 1)]
                if j == 0:
                    padding = ' ' * (len(line) // 2)
                    cells.append(padding + str(i) + padding)
                else:
                    cells.append(line)
            rows.append(' | '.join(cells))
        return '\n'.join(rows)
columns = ColumnPrinter()
for i in range(5):
columns.append(str(grid))
grid = simulate(grid)
print(columns)
"""
0 | 1 | 2 | 3 | 4
---*----- | --------- | --------- | --------- | ---------
----*---- | --*-*---- | ----*---- | ---*----- | ----*----
--***---- | ---**---- | --*-*---- | ----**--- | -----*---
--------- | ---*----- | ---**---- | ---**---- | ---***---
--------- | --------- | --------- | --------- | ---------
"""
57. 不要在每次fan-out时都新建一批Thread实例
# Example 1
from threading import Lock
ALIVE = '*'
EMPTY = '-'


class Grid:
    """Fixed-size board of cell states whose coordinates wrap around."""

    def __init__(self, height, width):
        self.height = height
        self.width = width
        # One independent list per row so that rows never alias each other.
        self.rows = [[EMPTY] * self.width for _ in range(self.height)]

    def get(self, y, x):
        """Return the state at row ``y``, column ``x`` (modulo the size)."""
        return self.rows[y % self.height][x % self.width]

    def set(self, y, x, state):
        """Store ``state`` at row ``y``, column ``x`` (modulo the size)."""
        self.rows[y % self.height][x % self.width] = state

    def __str__(self):
        """Return the board as text, one newline-terminated line per row."""
        return ''.join(''.join(row) + '\n' for row in self.rows)
class LockingGrid(Grid):
    """Grid whose reads, writes and rendering each hold a shared lock."""

    def __init__(self, height, width):
        super().__init__(height, width)
        self.lock = Lock()

    def __str__(self):
        with self.lock:
            return super().__str__()

    def get(self, y, x):
        with self.lock:
            return super().get(y, x)

    def set(self, y, x, state):
        with self.lock:
            return super().set(y, x, state)
# Example 2
from threading import Thread
def count_neighbors(y, x, get):
    """Return how many of the eight cells around ``(y, x)`` are ALIVE.

    Args:
        y: Row index of the centre cell.
        x: Column index of the centre cell.
        get: Callable ``get(y, x)`` returning the state of a cell.
    """
    # (dy, dx) offsets, clockwise starting from north.
    offsets = [
        (-1, 0),   # North
        (-1, 1),   # Northeast
        (0, 1),    # East
        (1, 1),    # Southeast
        (1, 0),    # South
        (1, -1),   # Southwest
        (0, -1),   # West
        (-1, -1),  # Northwest
    ]
    return sum(1 for dy, dx in offsets if get(y + dy, x + dx) == ALIVE)
# def game_logic(state, neighbors):
# # Do some blocking input/output in here:
# data = my_socket.recv(100)
def game_logic(state, neighbors):
    """Return the next state of a cell given its live-neighbor count.

    A live cell becomes EMPTY with fewer than 2 or more than 3 neighbors; a
    cell that is not alive becomes ALIVE with exactly 3 neighbors; in every
    other case ``state`` is returned unchanged.
    """
    if state == ALIVE:
        if neighbors < 2 or neighbors > 3:
            return EMPTY  # Die: too few or too many
    elif neighbors == 3:
        return ALIVE  # Regenerate
    return state
def step_cell(y, x, get, set):
    """Compute the next state of one cell and store it.

    Args:
        y: Row index of the cell.
        x: Column index of the cell.
        get: Callable ``get(y, x)`` returning the current state of a cell.
        set: Callable ``set(y, x, state)`` that records the next state.
    """
    state = get(y, x)
    neighbors = count_neighbors(y, x, get)
    next_state = game_logic(state, neighbors)
    set(y, x, next_state)
def simulate_threaded(grid):
    """Return the next generation, stepping every cell in its own Thread.

    One thread is started per cell (fan out) and all of them are joined
    (fan in) before the new LockingGrid is returned.
    """
    next_grid = LockingGrid(grid.height, grid.width)
    threads = []
    for y in range(grid.height):
        for x in range(grid.width):
            args = (y, x, grid.get, next_grid.set)
            thread = Thread(target=step_cell, args=args)
            thread.start()  # Fan out
            threads.append(thread)
    for thread in threads:
        thread.join()  # Fan in
    return next_grid
# Example 3
class ColumnPrinter:
    """Render several multi-line strings side by side under numbered headers."""

    def __init__(self):
        self.columns = []

    def append(self, data):
        """Add ``data`` (newline-separated text) as the next column."""
        self.columns.append(data)

    def __str__(self):
        """Return the columns joined with ``' | '``, preceded by a header row.

        The header shows each column's index, padded on both sides by half
        the width of that column's first line. Columns shorter than the
        tallest one contribute empty cells instead of raising IndexError.
        """
        # Split every column once instead of once per output cell.
        split_columns = [data.splitlines() for data in self.columns]
        tallest = max((len(lines) for lines in split_columns), default=0)

        header_cells = []
        for i, lines in enumerate(split_columns):
            first_line = lines[0] if lines else ''
            padding = ' ' * (len(first_line) // 2)
            header_cells.append(padding + str(i) + padding)
        rows = [' | '.join(header_cells)]

        for j in range(tallest):
            cells = [lines[j] if j < len(lines) else ''
                     for lines in split_columns]
            rows.append(' | '.join(cells))
        return '\n'.join(rows)
# Seed a 5x9 board with five live cells.
grid = LockingGrid(5, 9)  # Changed
grid.set(0, 3, ALIVE)
grid.set(1, 4, ALIVE)
grid.set(2, 2, ALIVE)
grid.set(2, 3, ALIVE)
grid.set(2, 4, ALIVE)
# Collect five consecutive generations and print them side by side.
columns = ColumnPrinter()
for _ in range(5):
    columns.append(str(grid))
    grid = simulate_threaded(grid)  # Changed
print(columns)
# Example 4
# def game_logic(state, neighbors):
# raise OSError('Problem with I/O')
# Example 5
import contextlib
import io
# Run game_logic in a thread while stderr is redirected into a buffer, then
# print whatever was written to that buffer.
fake_stderr = io.StringIO()
with contextlib.redirect_stderr(fake_stderr):
    thread = Thread(target=game_logic, args=(ALIVE, 3))
    thread.start()
    thread.join()
print(fake_stderr.getvalue())
58. 学会正确地重构代码,以便用Queue做并发
把队列与一定数量的线程搭配起来,可以高效地实现fan-out(分派)与fan-in(归集)。
59. 如果必须用线程做并发,那就考虑通过ThreadPoolExecutor实现
# Example 1
ALIVE = '*'
EMPTY = '-'


class Grid:
    """Fixed-size board of cell states whose coordinates wrap around."""

    def __init__(self, height, width):
        self.height = height
        self.width = width
        # One independent list per row so that rows never alias each other.
        self.rows = [[EMPTY] * self.width for _ in range(self.height)]

    def get(self, y, x):
        """Return the state at row ``y``, column ``x`` (modulo the size)."""
        return self.rows[y % self.height][x % self.width]

    def set(self, y, x, state):
        """Store ``state`` at row ``y``, column ``x`` (modulo the size)."""
        self.rows[y % self.height][x % self.width] = state

    def __str__(self):
        """Return the board as text, one newline-terminated line per row."""
        return ''.join(''.join(row) + '\n' for row in self.rows)
from threading import Lock
class LockingGrid(Grid):
    """Grid whose reads, writes and rendering each hold a shared lock."""

    def __init__(self, height, width):
        super().__init__(height, width)
        self.lock = Lock()

    def __str__(self):
        with self.lock:
            return super().__str__()

    def get(self, y, x):
        with self.lock:
            return super().get(y, x)

    def set(self, y, x, state):
        with self.lock:
            return super().set(y, x, state)
def count_neighbors(y, x, get):
    """Return how many of the eight cells around ``(y, x)`` are ALIVE.

    Args:
        y: Row index of the centre cell.
        x: Column index of the centre cell.
        get: Callable ``get(y, x)`` returning the state of a cell.
    """
    # (dy, dx) offsets, clockwise starting from north.
    offsets = [
        (-1, 0),   # North
        (-1, 1),   # Northeast
        (0, 1),    # East
        (1, 1),    # Southeast
        (1, 0),    # South
        (1, -1),   # Southwest
        (0, -1),   # West
        (-1, -1),  # Northwest
    ]
    return sum(1 for dy, dx in offsets if get(y + dy, x + dx) == ALIVE)
# def game_logic(state, neighbors):
# # Do some blocking input/output in here:
# data = my_socket.recv(100)
def game_logic(state, neighbors):
    """Return the next state of a cell given its live-neighbor count.

    A live cell becomes EMPTY with fewer than 2 or more than 3 neighbors; a
    cell that is not alive becomes ALIVE with exactly 3 neighbors; in every
    other case ``state`` is returned unchanged.
    """
    if state == ALIVE:
        if neighbors < 2 or neighbors > 3:
            return EMPTY  # Die: too few or too many
    elif neighbors == 3:
        return ALIVE  # Regenerate
    return state
def step_cell(y, x, get, set):
    """Compute the next state of one cell and store it.

    Args:
        y: Row index of the cell.
        x: Column index of the cell.
        get: Callable ``get(y, x)`` returning the current state of a cell.
        set: Callable ``set(y, x, state)`` that records the next state.
    """
    state = get(y, x)
    neighbors = count_neighbors(y, x, get)
    next_state = game_logic(state, neighbors)
    set(y, x, next_state)
# Example 2
from concurrent.futures import ThreadPoolExecutor
def simulate_pool(pool, grid):
    """Return the next generation, stepping every cell on ``pool``.

    Args:
        pool: Executor exposing ``submit``; one task is submitted per cell.
        grid: Current generation; only read from.

    Calling ``result()`` on each future waits for it and re-raises any
    exception raised inside ``step_cell``.
    """
    next_grid = LockingGrid(grid.height, grid.width)
    futures = []
    for y in range(grid.height):
        for x in range(grid.width):
            args = (y, x, grid.get, next_grid.set)
            future = pool.submit(step_cell, *args)  # Fan out
            futures.append(future)
    for future in futures:
        future.result()  # Fan in
    return next_grid
# Example 3
class ColumnPrinter:
    """Render several multi-line strings side by side under numbered headers."""

    def __init__(self):
        self.columns = []

    def append(self, data):
        """Add ``data`` (newline-separated text) as the next column."""
        self.columns.append(data)

    def __str__(self):
        """Return the columns joined with ``' | '``, preceded by a header row.

        The header shows each column's index, padded on both sides by half
        the width of that column's first line. Columns shorter than the
        tallest one contribute empty cells instead of raising IndexError.
        """
        # Split every column once instead of once per output cell.
        split_columns = [data.splitlines() for data in self.columns]
        tallest = max((len(lines) for lines in split_columns), default=0)

        header_cells = []
        for i, lines in enumerate(split_columns):
            first_line = lines[0] if lines else ''
            padding = ' ' * (len(first_line) // 2)
            header_cells.append(padding + str(i) + padding)
        rows = [' | '.join(header_cells)]

        for j in range(tallest):
            cells = [lines[j] if j < len(lines) else ''
                     for lines in split_columns]
            rows.append(' | '.join(cells))
        return '\n'.join(rows)
# Seed a 5x9 board with five live cells.
grid = LockingGrid(5, 9)
grid.set(0, 3, ALIVE)
grid.set(1, 4, ALIVE)
grid.set(2, 2, ALIVE)
grid.set(2, 3, ALIVE)
grid.set(2, 4, ALIVE)
columns = ColumnPrinter()
# A single pool of 10 workers is shared by all five generations.
with ThreadPoolExecutor(max_workers=10) as pool:
    for _ in range(5):
        columns.append(str(grid))
        grid = simulate_pool(pool, grid)
print(columns)
60. 用协程实现高并发的I/O
import logging
# Example 1
ALIVE = '*'
EMPTY = '-'


class Grid:
    """Fixed-size board of cell states whose coordinates wrap around."""

    def __init__(self, height, width):
        self.height = height
        self.width = width
        # One independent list per row so that rows never alias each other.
        self.rows = [[EMPTY] * self.width for _ in range(self.height)]

    def get(self, y, x):
        """Return the state at row ``y``, column ``x`` (modulo the size)."""
        return self.rows[y % self.height][x % self.width]

    def set(self, y, x, state):
        """Store ``state`` at row ``y``, column ``x`` (modulo the size)."""
        self.rows[y % self.height][x % self.width] = state

    def __str__(self):
        """Return the board as text, one newline-terminated line per row."""
        return ''.join(''.join(row) + '\n' for row in self.rows)
def count_neighbors(y, x, get):
    """Return how many of the eight cells around ``(y, x)`` are ALIVE.

    Args:
        y: Row index of the centre cell.
        x: Column index of the centre cell.
        get: Callable ``get(y, x)`` returning the state of a cell.
    """
    # (dy, dx) offsets, clockwise starting from north.
    offsets = [
        (-1, 0),   # North
        (-1, 1),   # Northeast
        (0, 1),    # East
        (1, 1),    # Southeast
        (1, 0),    # South
        (1, -1),   # Southwest
        (0, -1),   # West
        (-1, -1),  # Northwest
    ]
    return sum(1 for dy, dx in offsets if get(y + dy, x + dx) == ALIVE)
async def game_logic(state, neighbors):
    """Coroutine returning the next state of a cell.

    A live cell becomes EMPTY with fewer than 2 or more than 3 neighbors; a
    cell that is not alive becomes ALIVE with exactly 3 neighbors; in every
    other case ``state`` is returned unchanged.
    """
    if state == ALIVE:
        if neighbors < 2 or neighbors > 3:
            return EMPTY  # Die: too few or too many
    elif neighbors == 3:
        return ALIVE  # Regenerate
    return state
# Example 2
async def step_cell(y, x, get, set):
    """Coroutine that computes the next state of one cell and stores it.

    Args:
        y: Row index of the cell.
        x: Column index of the cell.
        get: Callable ``get(y, x)`` returning the current state of a cell.
        set: Callable ``set(y, x, state)`` that records the next state.
    """
    state = get(y, x)
    neighbors = count_neighbors(y, x, get)
    next_state = await game_logic(state, neighbors)
    set(y, x, next_state)
# Example 3
import asyncio
async def simulate(grid):
    """Coroutine returning the generation that follows ``grid``.

    One ``step_cell`` coroutine is created per cell (fan out) and all of
    them are awaited together with ``asyncio.gather`` (fan in).
    """
    next_grid = Grid(grid.height, grid.width)
    tasks = []
    for y in range(grid.height):
        for x in range(grid.width):
            task = step_cell(
                y, x, grid.get, next_grid.set)  # Fan out
            tasks.append(task)
    await asyncio.gather(*tasks)  # Fan in
    return next_grid
# Example 4
class ColumnPrinter:
    """Render several multi-line strings side by side under numbered headers."""

    def __init__(self):
        self.columns = []

    def append(self, data):
        """Add ``data`` (newline-separated text) as the next column."""
        self.columns.append(data)

    def __str__(self):
        """Return the columns joined with ``' | '``, preceded by a header row.

        The header shows each column's index, padded on both sides by half
        the width of that column's first line. Columns shorter than the
        tallest one contribute empty cells instead of raising IndexError.
        """
        # Split every column once instead of once per output cell.
        split_columns = [data.splitlines() for data in self.columns]
        tallest = max((len(lines) for lines in split_columns), default=0)

        header_cells = []
        for i, lines in enumerate(split_columns):
            first_line = lines[0] if lines else ''
            padding = ' ' * (len(first_line) // 2)
            header_cells.append(padding + str(i) + padding)
        rows = [' | '.join(header_cells)]

        for j in range(tallest):
            cells = [lines[j] if j < len(lines) else ''
                     for lines in split_columns]
            rows.append(' | '.join(cells))
        return '\n'.join(rows)
# Seed a 5x9 board with five live cells.
grid = Grid(5, 9)
grid.set(0, 3, ALIVE)
grid.set(1, 4, ALIVE)
grid.set(2, 2, ALIVE)
grid.set(2, 3, ALIVE)
grid.set(2, 4, ALIVE)
# Collect five consecutive generations and print them side by side.
columns = ColumnPrinter()
for _ in range(5):
    columns.append(str(grid))
    grid = asyncio.run(simulate(grid))  # Run the event loop
print(columns)
# Example 6
async def count_neighbors(y, x, get):
    """Coroutine returning how many of the eight cells around ``(y, x)`` are ALIVE.

    Args:
        y: Row index of the centre cell.
        x: Column index of the centre cell.
        get: Plain (non-awaited) callable ``get(y, x)`` returning a state.
    """
    # (dy, dx) offsets, clockwise starting from north.
    offsets = [
        (-1, 0),   # North
        (-1, 1),   # Northeast
        (0, 1),    # East
        (1, 1),    # Southeast
        (1, 0),    # South
        (1, -1),   # Southwest
        (0, -1),   # West
        (-1, -1),  # Northwest
    ]
    return sum(1 for dy, dx in offsets if get(y + dy, x + dx) == ALIVE)
async def step_cell(y, x, get, set):
    """Coroutine that computes the next state of one cell and stores it.

    Both the neighbor count and the game rule are awaited here.

    Args:
        y: Row index of the cell.
        x: Column index of the cell.
        get: Callable ``get(y, x)`` returning the current state of a cell.
        set: Callable ``set(y, x, state)`` that records the next state.
    """
    state = get(y, x)
    neighbors = await count_neighbors(y, x, get)
    next_state = await game_logic(state, neighbors)
    set(y, x, next_state)
async def game_logic(state, neighbors):
    """Coroutine returning the next state of a cell.

    A live cell becomes EMPTY with fewer than 2 or more than 3 neighbors; a
    cell that is not alive becomes ALIVE with exactly 3 neighbors; in every
    other case ``state`` is returned unchanged.
    """
    if state == ALIVE:
        if neighbors < 2 or neighbors > 3:
            return EMPTY  # Die: too few or too many
    elif neighbors == 3:
        return ALIVE  # Regenerate
    return state
# Seed a 5x9 board with five live cells.
grid = Grid(5, 9)
grid.set(0, 3, ALIVE)
grid.set(1, 4, ALIVE)
grid.set(2, 2, ALIVE)
grid.set(2, 3, ALIVE)
grid.set(2, 4, ALIVE)
# Collect five consecutive generations and print them side by side.
columns = ColumnPrinter()
for _ in range(5):
    columns.append(str(grid))
    grid = asyncio.run(simulate(grid))
print(columns)
61. 学会用asyncio改写那些通过线程实现的I/O
import logging
# Example 1
class EOFError(Exception):
    """Raised by the connection classes when the peer closes the stream.

    NOTE: this name shadows the built-in ``EOFError`` for all code that
    follows it in this module.
    """
class ConnectionBase:
    """Line-oriented text protocol on top of a connected socket."""

    def __init__(self, connection):
        self.connection = connection
        # Buffered binary reader used for line-at-a-time receives.
        self.file = connection.makefile('rb')

    def send(self, command):
        """Send ``command`` followed by a newline."""
        line = command + '\n'
        data = line.encode()
        # sendall() keeps writing until every byte is sent; send() may stop
        # after a partial write and silently drop the rest of the command.
        self.connection.sendall(data)

    def receive(self):
        """Return the next line without its trailing newline.

        Raises:
            EOFError: If the stream has ended.
        """
        line = self.file.readline()
        if not line:
            raise EOFError('Connection closed')
        return line[:-1].decode()
# Example 2
import random
# Decision strings that report_outcome sends after the word 'REPORT'.
WARMER = 'Warmer'
COLDER = 'Colder'
UNSURE = 'Unsure'
CORRECT = 'Correct'
class UnknownCommandError(Exception):
    """Raised when a received line starts with an unrecognized command."""
class Session(ConnectionBase):
    """Server side of the guessing game for a single connection.

    The peer sets the range with ``PARAMS``, requests a guess with
    ``NUMBER`` and grades the latest guess with ``REPORT``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self._clear_state(None, None)

    def _clear_state(self, lower, upper):
        """Start a fresh game over the inclusive range ``[lower, upper]``."""
        self.lower = lower
        self.upper = upper
        self.secret = None
        self.guesses = []

    # Example 3
    def loop(self):
        """Dispatch received commands until an empty line or end of stream.

        Raises:
            UnknownCommandError: If a line starts with an unknown command.
        """
        while command := self.receive():
            parts = command.split(' ')
            if parts[0] == 'PARAMS':
                self.set_params(parts)
            elif parts[0] == 'NUMBER':
                self.send_number()
            elif parts[0] == 'REPORT':
                self.receive_report(parts)
            else:
                raise UnknownCommandError(command)

    # Example 4
    def set_params(self, parts):
        """Handle ``PARAMS <lower> <upper>`` by resetting the game state."""
        assert len(parts) == 3
        lower = int(parts[1])
        upper = int(parts[2])
        self._clear_state(lower, upper)

    # Example 5
    def next_guess(self):
        """Return the known secret, or a random number not yet guessed.

        Raises:
            ValueError: If every number in the range has been guessed, or
                the range is empty.
        """
        if self.secret is not None:
            return self.secret
        # Without this check the retry loop below would spin forever once
        # every number in the range has been used.
        if len(set(self.guesses)) > self.upper - self.lower:
            raise ValueError('No unguessed numbers remain in range')
        while True:
            guess = random.randint(self.lower, self.upper)
            if guess not in self.guesses:
                return guess

    def send_number(self):
        """Handle ``NUMBER``: pick a guess, remember it and send it."""
        guess = self.next_guess()
        self.guesses.append(guess)
        self.send(format(guess))

    # Example 6
    def receive_report(self, parts):
        """Handle ``REPORT <decision>`` for the most recent guess."""
        assert len(parts) == 2
        decision = parts[1]
        last = self.guesses[-1]
        if decision == CORRECT:
            self.secret = last
        print(f'Server: {last} is {decision}')
# Example 7
import contextlib
import math
class Client(ConnectionBase):
    """Client side of the guessing game: holds the secret and grades guesses."""

    def __init__(self, *args):
        super().__init__(*args)
        self._clear_state()

    def _clear_state(self):
        """Forget the secret and the distance of the previous guess."""
        self.secret = None
        self.last_distance = None

    # Example 8
    @contextlib.contextmanager
    def session(self, lower, upper, secret):
        """Context manager that runs one game over ``[lower, upper]``.

        Sends ``PARAMS`` on entry. On exit it clears local state and sends
        ``PARAMS 0 -1``, even if the body raised.
        """
        print(f'Guess a number between {lower} and {upper}!'
              f' Shhhhh, it\'s {secret}.')
        self.secret = secret
        self.send(f'PARAMS {lower} {upper}')
        try:
            yield
        finally:
            self._clear_state()
            self.send('PARAMS 0 -1')

    # Example 9
    def request_numbers(self, count):
        """Yield up to ``count`` guesses requested from the peer.

        Stops early once the previously graded guess had distance 0.
        """
        for _ in range(count):
            self.send('NUMBER')
            data = self.receive()
            yield int(data)
            # last_distance is updated by report_outcome while this
            # generator is suspended at the yield above.
            if self.last_distance == 0:
                return

    # Example 10
    def report_outcome(self, number):
        """Grade ``number`` against the secret, send the verdict, return it."""
        new_distance = math.fabs(number - self.secret)
        if new_distance == 0:
            decision = CORRECT
        elif self.last_distance is None:
            decision = UNSURE  # First guess: nothing to compare against
        elif new_distance < self.last_distance:
            decision = WARMER
        elif new_distance > self.last_distance:
            decision = COLDER
        else:
            decision = UNSURE  # Same distance as the previous guess
        self.last_distance = new_distance
        self.send(f'REPORT {decision}')
        return decision
# Example 11
import socket
from threading import Thread
def handle_connection(connection):
    """Serve one client connection and close the socket when finished."""
    with connection:
        session = Session(connection)
        try:
            session.loop()
        except EOFError:
            # The stream ended: nothing more to serve on this connection.
            pass
def run_server(address):
    """Accept connections on ``address`` forever, one daemon thread each."""
    with socket.socket() as listener:
        # Allow the port to be reused
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(address)
        listener.listen()
        while True:
            connection, _ = listener.accept()
            thread = Thread(target=handle_connection,
                            args=(connection,),
                            daemon=True)
            thread.start()
# Example 12
def run_client(address):
    """Play two games against the server at ``address``.

    Returns:
        A list of ``(number, decision)`` pairs covering both games.
    """
    with socket.create_connection(address) as connection:
        client = Client(connection)
        with client.session(1, 5, 3):
            results = [(x, client.report_outcome(x))
                       for x in client.request_numbers(5)]
        with client.session(10, 15, 12):
            for number in client.request_numbers(5):
                outcome = client.report_outcome(number)
                results.append((number, outcome))
    return results
# Example 13
def main():
    """Start the server in a daemon thread, run the client, print results."""
    address = ('127.0.0.1', 1234)
    server_thread = Thread(
        target=run_server, args=(address,), daemon=True)
    server_thread.start()
    # NOTE(review): nothing here waits for the server to start listening
    # before the client connects; confirm this cannot race in practice.
    results = run_client(address)
    for number, outcome in results:
        print(f'Client: {number} is {outcome}')


main()
# Example 14
class AsyncConnectionBase:
    """Line-oriented text protocol on top of an asyncio stream pair."""

    def __init__(self, reader, writer):  # Changed
        self.reader = reader  # Changed
        self.writer = writer  # Changed

    async def send(self, command):
        """Write ``command`` plus a newline and wait for the buffer to drain."""
        line = command + '\n'
        data = line.encode()
        self.writer.write(data)  # Changed
        await self.writer.drain()  # Changed

    async def receive(self):
        """Return the next line without its trailing newline.

        Raises:
            EOFError: If the stream has ended.
        """
        line = await self.reader.readline()  # Changed
        if not line:
            raise EOFError('Connection closed')
        return line[:-1].decode()
# Example 15
class AsyncSession(AsyncConnectionBase):  # Changed
    """Server side of the guessing game for a single asyncio connection.

    The peer sets the range with ``PARAMS``, requests a guess with
    ``NUMBER`` and grades the latest guess with ``REPORT``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self._clear_values(None, None)

    def _clear_values(self, lower, upper):
        """Start a fresh game over the inclusive range ``[lower, upper]``."""
        self.lower = lower
        self.upper = upper
        self.secret = None
        self.guesses = []

    # Example 16
    async def loop(self):  # Changed
        """Dispatch received commands until an empty line or end of stream.

        Raises:
            UnknownCommandError: If a line starts with an unknown command.
        """
        while command := await self.receive():  # Changed
            parts = command.split(' ')
            if parts[0] == 'PARAMS':
                self.set_params(parts)
            elif parts[0] == 'NUMBER':
                await self.send_number()  # Changed
            elif parts[0] == 'REPORT':
                self.receive_report(parts)
            else:
                raise UnknownCommandError(command)

    # Example 17
    def set_params(self, parts):
        """Handle ``PARAMS <lower> <upper>`` by resetting the game state."""
        assert len(parts) == 3
        lower = int(parts[1])
        upper = int(parts[2])
        self._clear_values(lower, upper)

    # Example 18
    def next_guess(self):
        """Return the known secret, or a random number not yet guessed.

        Raises:
            ValueError: If every number in the range has been guessed, or
                the range is empty.
        """
        if self.secret is not None:
            return self.secret
        # Without this check the retry loop below would spin forever (and
        # block the event loop) once every number in the range is used.
        if len(set(self.guesses)) > self.upper - self.lower:
            raise ValueError('No unguessed numbers remain in range')
        while True:
            guess = random.randint(self.lower, self.upper)
            if guess not in self.guesses:
                return guess

    async def send_number(self):  # Changed
        """Handle ``NUMBER``: pick a guess, remember it and send it."""
        guess = self.next_guess()
        self.guesses.append(guess)
        await self.send(format(guess))  # Changed

    # Example 19
    def receive_report(self, parts):
        """Handle ``REPORT <decision>`` for the most recent guess."""
        assert len(parts) == 2
        decision = parts[1]
        last = self.guesses[-1]
        if decision == CORRECT:
            self.secret = last
        print(f'Server: {last} is {decision}')
# Example 20
class AsyncClient(AsyncConnectionBase):  # Changed
    """Asyncio client for the guessing game: holds the secret, grades guesses."""

    def __init__(self, *args):
        super().__init__(*args)
        self._clear_state()

    def _clear_state(self):
        """Forget the secret and the distance of the previous guess."""
        self.secret = None
        self.last_distance = None

    # Example 21
    @contextlib.asynccontextmanager  # Changed
    async def session(self, lower, upper, secret):  # Changed
        """Async context manager that runs one game over ``[lower, upper]``.

        Sends ``PARAMS`` on entry. On exit it clears local state and sends
        ``PARAMS 0 -1``, even if the body raised.
        """
        print(f'Guess a number between {lower} and {upper}!'
              f' Shhhhh, it\'s {secret}.')
        self.secret = secret
        await self.send(f'PARAMS {lower} {upper}')  # Changed
        try:
            yield
        finally:
            self._clear_state()
            await self.send('PARAMS 0 -1')  # Changed

    # Example 22
    async def request_numbers(self, count):  # Changed
        """Asynchronously yield up to ``count`` guesses from the peer.

        Stops early once the previously graded guess had distance 0.
        """
        for _ in range(count):
            await self.send('NUMBER')  # Changed
            data = await self.receive()  # Changed
            yield int(data)
            # last_distance is updated by report_outcome while this
            # generator is suspended at the yield above.
            if self.last_distance == 0:
                return

    # Example 23
    async def report_outcome(self, number):  # Changed
        """Grade ``number`` against the secret, send the verdict, return it."""
        new_distance = math.fabs(number - self.secret)
        if new_distance == 0:
            decision = CORRECT
        elif self.last_distance is None:
            decision = UNSURE  # First guess: nothing to compare against
        elif new_distance < self.last_distance:
            decision = WARMER
        elif new_distance > self.last_distance:
            decision = COLDER
        else:
            decision = UNSURE  # Same distance as the previous guess
        self.last_distance = new_distance
        await self.send(f'REPORT {decision}')  # Changed
        # Make it so the output printing is in
        # the same order as the threaded version.
        await asyncio.sleep(0.01)
        return decision
# Example 24
import asyncio
async def handle_async_connection(reader, writer):
    """Serve one client over an asyncio stream pair, then close the writer."""
    session = AsyncSession(reader, writer)
    try:
        await session.loop()
    except EOFError:
        # The stream ended: nothing more to serve on this connection.
        pass
    finally:
        # Close our side of the connection once the session is over.
        writer.close()
async def run_async_server(address):
    """Listen on ``address`` (a ``(host, port)`` pair) and serve until cancelled."""
    server = await asyncio.start_server(
        handle_async_connection, *address)
    async with server:
        await server.serve_forever()
# Example 25
async def run_async_client(address):
    """Play two games against the server at ``address``.

    Returns:
        A list of ``(number, decision)`` pairs covering both games.
    """
    # Wait for the server to listen before trying to connect
    await asyncio.sleep(0.1)
    streams = await asyncio.open_connection(*address)  # New
    _, writer = streams  # New
    try:
        client = AsyncClient(*streams)  # New
        async with client.session(1, 5, 3):
            results = [(x, await client.report_outcome(x))
                       async for x in client.request_numbers(5)]
        async with client.session(10, 15, 12):
            async for number in client.request_numbers(5):
                outcome = await client.report_outcome(number)
                results.append((number, outcome))
    finally:
        # Close the connection even when a game raises part-way through.
        writer.close()  # New
        await writer.wait_closed()  # New
    return results
# Example 26
async def main_async():
    """Run the asyncio server and client in one event loop and print results."""
    address = ('127.0.0.1', 4321)
    # Keep a reference to the task: the event loop holds only weak
    # references, so an unreferenced task can be garbage-collected mid-run.
    server_task = asyncio.create_task(run_async_server(address))
    try:
        results = await run_async_client(address)
    finally:
        # The server would otherwise serve forever.
        server_task.cancel()
    for number, outcome in results:
        print(f'Client: {number} is {outcome}')
# Raise the root logger threshold to ERROR while the asyncio demo runs.
logging.getLogger().setLevel(logging.ERROR)
asyncio.run(main_async())
# NOTE(review): this sets DEBUG unconditionally instead of restoring the
# level that was in effect before the demo; confirm that is intended.
logging.getLogger().setLevel(logging.DEBUG)
62. 结合线程与协程,将代码顺利迁移到asyncio
# Example 1
class NoNewData(Exception):
    """Raised by readline() when the handle has nothing new to read."""


def readline(handle):
    """Return the next line from ``handle``, or raise if nothing was added.

    Args:
        handle: Seekable file object.

    Raises:
        NoNewData: If the current position is already at the end of the file.
    """
    offset = handle.tell()
    handle.seek(0, 2)  # whence=2: jump to the end to measure the length
    length = handle.tell()
    if length == offset:
        raise NoNewData
    handle.seek(offset, 0)  # whence=0: go back to where reading left off
    return handle.readline()
# Example 2
import time
def tail_file(handle, interval, write_func):
    """Forward every new line from ``handle`` to ``write_func`` until closed.

    Args:
        handle: File object to poll; the loop ends once it is closed.
        interval: Seconds to sleep when no new data is available.
        write_func: Callable invoked with each line that is read.
    """
    while not handle.closed:
        try:
            line = readline(handle)
        except NoNewData:
            time.sleep(interval)
        except ValueError:
            # Operations on a closed file raise ValueError. If another
            # thread closed the handle after the loop check, stop quietly.
            if handle.closed:
                break
            raise
        else:
            write_func(line)
# Example 3
from threading import Lock, Thread
def run_threads(handles, interval, output_path):
    """Merge lines from all ``handles`` into ``output_path``, one thread each.

    Returns once every tailing thread has finished. Writes to the output
    file are serialized with a lock.
    """
    with open(output_path, 'wb') as output:
        lock = Lock()

        def write(data):
            with lock:
                output.write(data)

        threads = []
        for handle in handles:
            args = (handle, interval, write)
            thread = Thread(target=tail_file, args=args)
            thread.start()
            threads.append(thread)
        # Join inside the with-block so the output stays open for writers.
        for thread in threads:
            thread.join()
# Example 4
# This is all code to simulate the writers to the handles
import collections
import os
import random
import string
from tempfile import TemporaryDirectory
def write_random_data(path, write_count, interval):
    """Write ``write_count`` random lines to ``path`` with random pauses.

    Each line is ``<path>-<index, two digits>-<10 lowercase letters>`` and is
    flushed immediately. Before each line the function sleeps for a random
    duration below ``interval`` seconds.
    """
    with open(path, 'wb') as f:
        for i in range(write_count):
            time.sleep(random.random() * interval)
            letters = random.choices(
                string.ascii_lowercase, k=10)
            data = f'{path}-{i:02}-{"".join(letters)}\n'
            f.write(data.encode())
            f.flush()
def start_write_threads(directory, file_count):
    """Start one writer thread per file and return the list of file paths.

    Files are named ``0`` .. ``file_count - 1`` inside ``directory``. Each
    thread writes 10 lines with pauses of up to 0.1 seconds between them.
    """
    paths = []
    for i in range(file_count):
        path = os.path.join(directory, str(i))
        with open(path, 'w'):
            # Make sure the file at this path will exist when
            # the reading thread tries to poll it.
            pass
        paths.append(path)
        args = (path, 10, 0.1)
        thread = Thread(target=write_random_data, args=args)
        thread.start()
    return paths
def close_all(handles):
    """Wait one second, then close every handle in ``handles``."""
    time.sleep(1)
    for handle in handles:
        handle.close()
def setup():
    """Create the temp directory, writer threads and read handles.

    Also starts a thread that runs ``close_all`` on the read handles.

    Returns:
        ``(tmpdir, input_paths, handles, output_path)``; the caller is
        responsible for calling ``tmpdir.cleanup()``.
    """
    tmpdir = TemporaryDirectory()
    input_paths = start_write_threads(tmpdir.name, 5)
    handles = [open(path, 'rb') for path in input_paths]
    Thread(target=close_all, args=(handles,)).start()
    output_path = os.path.join(tmpdir.name, 'merged')
    return tmpdir, input_paths, handles, output_path
# Example 5
def confirm_merge(input_paths, output_path):
    """Check that the merged output holds every input's lines, in order.

    Output lines are attributed to an input by the path they start with.

    Raises:
        AssertionError: If an input file's lines differ from the lines
            attributed to it in the output.
    """
    found = collections.defaultdict(list)
    with open(output_path, 'rb') as f:
        for line in f:
            for path in input_paths:
                if line.startswith(path.encode()):
                    found[path].append(line)

    expected = collections.defaultdict(list)
    for path in input_paths:
        with open(path, 'rb') as f:
            expected[path].extend(f.readlines())

    for key, expected_lines in expected.items():
        found_lines = found[key]
        # Raised explicitly so the check still runs under ``python -O``.
        if expected_lines != found_lines:
            raise AssertionError(
                f'{expected_lines!r} == {found_lines!r}')
# Placeholder bindings; all three names are rebound by setup() below.
input_paths = ...
handles = ...
output_path = ...
# Build the fixture, merge with threads, verify the result, then clean up.
tmpdir, input_paths, handles, output_path = setup()
run_threads(handles, 0.1, output_path)
confirm_merge(input_paths, output_path)
tmpdir.cleanup()
# Example 6
import asyncio
# On Windows, a ProactorEventLoop can't be created within
# threads because it tries to register signal handlers. This
# is a work-around to always use the SelectorEventLoop policy
# instead. See: https://bugs.python.org/issue33792
# NOTE(review): `_loop_factory` is a private attribute of the policy object;
# confirm it still exists on the Python version in use.
policy = asyncio.get_event_loop_policy()
policy._loop_factory = asyncio.SelectorEventLoop
async def run_tasks_mixed(handles, interval, output_path):
    """Merge ``handles`` into ``output_path`` using threads driven by asyncio.

    The blocking ``tail_file`` runs in the default executor for each handle.
    Its writes are sent back to this coroutine's event loop, so only the
    loop thread touches the output file.
    """
    # Inside a coroutine the running loop is the one to use.
    loop = asyncio.get_running_loop()
    with open(output_path, 'wb') as output:
        async def write_async(data):
            output.write(data)

        def write(data):
            # Called from worker threads: schedule the write on the event
            # loop and block this worker until it has completed.
            coro = write_async(data)
            future = asyncio.run_coroutine_threadsafe(
                coro, loop)
            future.result()

        tasks = []
        for handle in handles:
            task = loop.run_in_executor(
                None, tail_file, handle, interval, write)
            tasks.append(task)
        await asyncio.gather(*tasks)
# Example 7
# Placeholder bindings; all three names are rebound by setup() below.
input_paths = ...
handles = ...
output_path = ...
# Build the fixture, merge with the mixed version, verify, then clean up.
tmpdir, input_paths, handles, output_path = setup()
asyncio.run(run_tasks_mixed(handles, 0.1, output_path))
confirm_merge(input_paths, output_path)
tmpdir.cleanup()
# Example 8
async def tail_async(handle, interval, write_func):
    """Forward every new line from ``handle`` to ``write_func`` until closed.

    Args:
        handle: File object to poll; the loop ends once it is closed.
        interval: Seconds to sleep (without blocking) when no data is new.
        write_func: Coroutine function awaited with each line that is read.
    """
    loop = asyncio.get_running_loop()
    while not handle.closed:
        try:
            # The blocking read runs in the default executor.
            line = await loop.run_in_executor(
                None, readline, handle)
        except NoNewData:
            await asyncio.sleep(interval)
        except ValueError:
            # Operations on a closed file raise ValueError. If another
            # thread closed the handle after the loop check, stop quietly.
            if handle.closed:
                break
            raise
        else:
            await write_func(line)
# Example 9
async def run_tasks(handles, interval, output_path):
    """Merge ``handles`` into ``output_path`` with one asyncio task each."""
    with open(output_path, 'wb') as output:
        async def write_async(data):
            output.write(data)

        tasks = []
        for handle in handles:
            coro = tail_async(handle, interval, write_async)
            task = asyncio.create_task(coro)
            tasks.append(task)
        # Await inside the with-block so the output stays open for writers.
        await asyncio.gather(*tasks)
# Example 10
# Placeholder bindings; all three names are rebound by setup() below.
input_paths = ...
handles = ...
output_path = ...
# Build the fixture, merge with the asyncio version, verify, then clean up.
tmpdir, input_paths, handles, output_path = setup()
asyncio.run(run_tasks(handles, 0.1, output_path))
confirm_merge(input_paths, output_path)
tmpdir.cleanup()
# Example 11
def tail_file(handle, interval, write_func):
    """Blocking wrapper that runs ``tail_async`` on a private event loop.

    Args:
        handle: File object to poll until it is closed.
        interval: Seconds to wait when no new data is available.
        write_func: Plain (blocking) callable invoked with each line.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    async def write_async(data):
        write_func(data)

    coro = tail_async(handle, interval, write_async)
    try:
        loop.run_until_complete(coro)
    finally:
        # Release the loop's resources; it is private to this call.
        loop.close()
# Example 12
# Placeholder bindings; all three names are rebound by setup() below.
input_paths = ...
handles = ...
output_path = ...
# Build the fixture, merge with threads, verify the result, then clean up.
tmpdir, input_paths, handles, output_path = setup()
run_threads(handles, 0.1, output_path)
confirm_merge(input_paths, output_path)
tmpdir.cleanup()
63. 让asyncio的事件循环保持畅通,以便进一步提升程序的响应能力
调用asyncio.run时,可以把debug参数设置为True,这样就能知道哪些协程拖慢了事件循环。
# Example 1
import asyncio
# On Windows, a ProactorEventLoop can't be created within
# threads because it tries to register signal handlers. This
# is a work-around to always use the SelectorEventLoop policy
# instead. See: https://bugs.python.org/issue33792
# NOTE(review): `_loop_factory` is a private attribute of the policy object;
# confirm it still exists on the Python version in use.
policy = asyncio.get_event_loop_policy()
policy._loop_factory = asyncio.SelectorEventLoop
async def run_tasks(handles, interval, output_path):
    """Merge ``handles`` into ``output_path`` with one asyncio task each.

    The file is opened and written with blocking calls made directly from
    coroutines running on the event loop.
    """
    with open(output_path, 'wb') as output:
        async def write_async(data):
            output.write(data)

        tasks = []
        for handle in handles:
            coro = tail_async(handle, interval, write_async)
            task = asyncio.create_task(coro)
            tasks.append(task)
        # Await inside the with-block so the output stays open for writers.
        await asyncio.gather(*tasks)
# Example 2
import time
async def slow_coroutine():
    """Block the event loop for half a second with a synchronous sleep."""
    time.sleep(0.5)  # Simulating slow I/O


# Run the blocking coroutine with asyncio's debug mode enabled.
asyncio.run(slow_coroutine(), debug=True)
# Example 3
from threading import Thread
class WriteThread(Thread):
    """Thread that owns an output file and a private event loop for writes.

    Use it as an async context manager. Entering starts the thread, and
    ``write`` hands data to the thread's loop. Exiting stops that loop,
    which closes the file.
    """

    def __init__(self, output_path):
        super().__init__()
        self.output_path = output_path
        self.output = None
        self.loop = asyncio.new_event_loop()

    def run(self):
        """Thread body: keep the file open while the private loop runs."""
        asyncio.set_event_loop(self.loop)
        try:
            with open(self.output_path, 'wb') as self.output:
                self.loop.run_forever()
            # Run one final round of callbacks so the await on
            # stop() in another event loop will be resolved.
            self.loop.run_until_complete(asyncio.sleep(0))
        finally:
            # Release the loop's resources now that the thread is done.
            self.loop.close()

    # Example 4
    async def real_write(self, data):
        """Write ``data``; this coroutine runs on the writer thread's loop."""
        self.output.write(data)

    async def write(self, data):
        """Schedule a write on the writer loop and wait until it is done."""
        coro = self.real_write(data)
        future = asyncio.run_coroutine_threadsafe(
            coro, self.loop)
        await asyncio.wrap_future(future)

    # Example 5
    async def real_stop(self):
        """Stop the writer loop; this coroutine runs on that loop."""
        self.loop.stop()

    async def stop(self):
        """Ask the writer loop to stop and wait for the request to finish."""
        coro = self.real_stop()
        future = asyncio.run_coroutine_threadsafe(
            coro, self.loop)
        await asyncio.wrap_future(future)

    # Example 6
    async def __aenter__(self):
        # Thread.start() blocks briefly, so run it in the default executor.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.start)
        return self

    async def __aexit__(self, *_):
        await self.stop()
# Example 7
class NoNewData(Exception):
    """Raised by readline() when the handle has nothing new to read."""


def readline(handle):
    """Return the next line from ``handle``, or raise if nothing was added.

    Args:
        handle: Seekable file object.

    Raises:
        NoNewData: If the current position is already at the end of the file.
    """
    offset = handle.tell()
    handle.seek(0, 2)  # whence=2: jump to the end to measure the length
    length = handle.tell()
    if length == offset:
        raise NoNewData
    handle.seek(offset, 0)  # whence=0: go back to where reading left off
    return handle.readline()
async def tail_async(handle, interval, write_func):
    """Forward every new line from ``handle`` to ``write_func`` until closed.

    Args:
        handle: File object to poll; the loop ends once it is closed.
        interval: Seconds to sleep (without blocking) when no data is new.
        write_func: Coroutine function awaited with each line that is read.
    """
    loop = asyncio.get_running_loop()
    while not handle.closed:
        try:
            # The blocking read runs in the default executor.
            line = await loop.run_in_executor(
                None, readline, handle)
        except NoNewData:
            await asyncio.sleep(interval)
        except ValueError:
            # Operations on a closed file raise ValueError. If another
            # thread closed the handle after the loop check, stop quietly.
            if handle.closed:
                break
            raise
        else:
            await write_func(line)
async def run_fully_async(handles, interval, output_path):
    """Merge ``handles`` into ``output_path`` via a WriteThread.

    Reads run as asyncio tasks, and every write is delegated to the
    WriteThread through its ``write`` coroutine.
    """
    async with WriteThread(output_path) as output:
        tasks = []
        for handle in handles:
            coro = tail_async(handle, interval, output.write)
            task = asyncio.create_task(coro)
            tasks.append(task)
        await asyncio.gather(*tasks)
# Example 8
# This is all code to simulate the writers to the handles
import collections
import os
import random
import string
from tempfile import TemporaryDirectory
def write_random_data(path, write_count, interval):
    """Write ``write_count`` random lines to ``path`` with random pauses.

    Each line is ``<path>-<index, two digits>-<10 lowercase letters>`` and is
    flushed immediately. Before each line the function sleeps for a random
    duration below ``interval`` seconds.
    """
    with open(path, 'wb') as f:
        for i in range(write_count):
            time.sleep(random.random() * interval)
            letters = random.choices(
                string.ascii_lowercase, k=10)
            data = f'{path}-{i:02}-{"".join(letters)}\n'
            f.write(data.encode())
            f.flush()
def start_write_threads(directory, file_count):
    """Start one writer thread per file and return the list of file paths.

    Files are named ``0`` .. ``file_count - 1`` inside ``directory``. Each
    thread writes 10 lines with pauses of up to 0.1 seconds between them.
    """
    paths = []
    for i in range(file_count):
        path = os.path.join(directory, str(i))
        with open(path, 'w'):
            # Make sure the file at this path will exist when
            # the reading thread tries to poll it.
            pass
        paths.append(path)
        args = (path, 10, 0.1)
        thread = Thread(target=write_random_data, args=args)
        thread.start()
    return paths
def close_all(handles):
    """Wait one second, then close every handle in ``handles``."""
    time.sleep(1)
    for handle in handles:
        handle.close()
def setup():
    """Create the temp directory, writer threads and read handles.

    Also starts a thread that runs ``close_all`` on the read handles.

    Returns:
        ``(tmpdir, input_paths, handles, output_path)``; the caller is
        responsible for calling ``tmpdir.cleanup()``.
    """
    tmpdir = TemporaryDirectory()
    input_paths = start_write_threads(tmpdir.name, 5)
    handles = [open(path, 'rb') for path in input_paths]
    Thread(target=close_all, args=(handles,)).start()
    output_path = os.path.join(tmpdir.name, 'merged')
    return tmpdir, input_paths, handles, output_path
# Example 9
def confirm_merge(input_paths, output_path):
    """Check that the merged output holds every input's lines, in order.

    Output lines are attributed to an input by the path they start with.

    Raises:
        AssertionError: If an input file's lines differ from the lines
            attributed to it in the output.
    """
    found = collections.defaultdict(list)
    with open(output_path, 'rb') as f:
        for line in f:
            for path in input_paths:
                if line.startswith(path.encode()):
                    found[path].append(line)

    expected = collections.defaultdict(list)
    for path in input_paths:
        with open(path, 'rb') as f:
            expected[path].extend(f.readlines())

    for key, expected_lines in expected.items():
        found_lines = found[key]
        # Raised explicitly so the check still runs under ``python -O``;
        # the message shows both sides to make a mismatch diagnosable.
        if expected_lines != found_lines:
            raise AssertionError(
                f'{expected_lines!r} == {found_lines!r}')
# Placeholder bindings; all three names are rebound by setup() below.
input_paths = ...
handles = ...
output_path = ...
# Build the fixture, merge with the fully async version, verify, clean up.
tmpdir, input_paths, handles, output_path = setup()
asyncio.run(run_fully_async(handles, 0.1, output_path))
confirm_merge(input_paths, output_path)
tmpdir.cleanup()
64. 考虑用concurrent.futures实现真正的并行计算
- Run_thread.py
from concurrent.futures import ThreadPoolExecutor
import time
# Input pairs for the timing comparison.
NUMBERS = [
    (1963309, 2265973), (2030677, 3814172),
    (1551645, 2229620), (2039045, 2020802),
    (1823712, 1924928), (2293129, 1020491),
    (1281238, 2273782), (3823812, 4237281),
    (3812741, 4729139), (1292391, 2123811),
]


def gcd(pair):
    """Return the greatest common divisor of the two integers in ``pair``.

    Searches downward from the smaller value. NOTE(review): presumably kept
    brute-force on purpose so the timing comparison has real CPU work.
    """
    a, b = pair
    low = min(a, b)
    for i in range(low, 0, -1):
        if a % i == 0 and b % i == 0:
            return i
    # Raised explicitly so the guard survives ``python -O``.
    raise AssertionError('Not reachable')


def main():
    """Time gcd() over NUMBERS using a two-worker thread pool."""
    start = time.perf_counter()
    # The with-block shuts the pool down instead of leaking its threads.
    with ThreadPoolExecutor(max_workers=2) as pool:
        list(pool.map(gcd, NUMBERS))  # list() waits for every result
        end = time.perf_counter()
    delta = end - start
    print(f'Took {delta:.3f} seconds')  # Took 1.278 seconds


if __name__ == '__main__':
    main()
- Run_serial.py
import time
# Input pairs for the timing comparison.
NUMBERS = [
    (1963309, 2265973), (2030677, 3814172),
    (1551645, 2229620), (2039045, 2020802),
    (1823712, 1924928), (2293129, 1020491),
    (1281238, 2273782), (3823812, 4237281),
    (3812741, 4729139), (1292391, 2123811),
]


def gcd(pair):
    """Return the greatest common divisor of the two integers in ``pair``.

    Searches downward from the smaller value. NOTE(review): presumably kept
    brute-force on purpose so the timing comparison has real CPU work.
    """
    a, b = pair
    low = min(a, b)
    for i in range(low, 0, -1):
        if a % i == 0 and b % i == 0:
            return i
    # Raised explicitly so the guard survives ``python -O``.
    raise AssertionError('Not reachable')


def main():
    """Time gcd() over NUMBERS, one pair after another on a single thread."""
    start = time.perf_counter()
    list(map(gcd, NUMBERS))  # list() forces every result to be computed
    end = time.perf_counter()
    delta = end - start
    print(f'Took {delta:.3f} seconds')  # Took 1.204 seconds


if __name__ == '__main__':
    main()
- run_parallel.py
from concurrent.futures import ProcessPoolExecutor
import time
# Input pairs for the timing comparison.
NUMBERS = [
    (1963309, 2265973), (2030677, 3814172),
    (1551645, 2229620), (2039045, 2020802),
    (1823712, 1924928), (2293129, 1020491),
    (1281238, 2273782), (3823812, 4237281),
    (3812741, 4729139), (1292391, 2123811),
]


def gcd(pair):
    """Return the greatest common divisor of the two integers in ``pair``.

    Searches downward from the smaller value. NOTE(review): presumably kept
    brute-force on purpose so the timing comparison has real CPU work.
    """
    a, b = pair
    low = min(a, b)
    for i in range(low, 0, -1):
        if a % i == 0 and b % i == 0:
            return i
    # Raised explicitly so the guard survives ``python -O``.
    raise AssertionError('Not reachable')


def main():
    """Time gcd() over NUMBERS using a two-worker process pool."""
    start = time.perf_counter()
    # The with-block shuts the pool down instead of leaking its processes.
    with ProcessPoolExecutor(max_workers=2) as pool:  # The one change
        list(pool.map(gcd, NUMBERS))  # list() waits for every result
        end = time.perf_counter()
    delta = end - start
    print(f'Took {delta:.3f} seconds')  # Took 0.635 seconds


if __name__ == '__main__':
    main()
八. 稳定与性能
65. 合理利用try/except/else/finally结构中的每个代码块
import json
# Sentinel returned by divide_json() when the denominator is zero.
UNDEFINED = object()
# When true, divide_json() raises OSError just before writing the result.
DIE_IN_ELSE_BLOCK = False


def divide_json(path):
    """Divide ``numerator`` by ``denominator`` from a JSON file, in place.

    The file at ``path`` is rewritten with the quotient added under
    ``result``.

    Returns:
        The quotient, or ``UNDEFINED`` if the denominator is zero.

    Raises:
        OSError: If the file cannot be opened or written.
        ValueError: If the file does not contain valid JSON.
    """
    print('* Opening file')
    # UTF-8 matches how the example file is written.
    handle = open(path, 'r+', encoding='utf-8')  # May raise OSError
    try:
        print('* Reading data')
        data = handle.read()  # May raise UnicodeDecodeError
        print('* Loading JSON data')
        op = json.loads(data)  # May raise ValueError
        print('* Performing calculation')
        value = (
            op['numerator'] /
            op['denominator'])  # May raise ZeroDivisionError
    except ZeroDivisionError as e:
        print(f'* Handling ZeroDivisionError: {e}')
        return UNDEFINED
    else:
        print('* Writing calculation')
        op['result'] = value
        result = json.dumps(op)
        handle.seek(0)  # May raise OSError
        if DIE_IN_ELSE_BLOCK:
            import errno
            import os
            raise OSError(errno.ENOSPC, os.strerror(errno.ENOSPC))
        handle.write(result)  # May raise OSError
        # Drop leftover bytes when the new text is shorter than the old
        # contents; otherwise the file ends with stale, invalid JSON.
        handle.truncate()
        return value
    finally:
        print('* Calling close()')
        handle.close()  # Always runs
# Write a sample input file, then check the quotient that comes back.
temp_path = 'random_data.json'
with open(temp_path, 'w', encoding='utf-8') as f:
    f.write('{"numerator": 1, "denominator": 10}')
assert divide_json(temp_path) == 0.1
66. 考虑用contextlib和with语句来改写可复用的try/finally代码
# Example 1
from threading import Lock
lock = Lock()
# The with-statement acquires the lock and releases it on exit.
with lock:
    # Do something while maintaining an invariant
    pass
# Example 2
# The same acquire/release pair written out by hand with try/finally.
lock.acquire()
try:
    # Do something while maintaining an invariant
    pass
finally:
    lock.release()
# Example 3
import logging
# With the root logger at WARNING, the debug() calls below are filtered out.
logging.getLogger().setLevel(logging.WARNING)


def my_function():
    """Emit two debug messages with one error message between them."""
    logging.debug('Some debug data')
    logging.error('Error log here')
    logging.debug('More debug data')


# Example 4
my_function()
# Example 5
from contextlib import contextmanager
@contextmanager
def debug_logging(level):
    """Temporarily set the root logger to ``level`` inside a with-block.

    The level that was effective on entry is restored on exit, even if the
    body raises.
    """
    logger = logging.getLogger()
    old_level = logger.getEffectiveLevel()
    logger.setLevel(level)
    try:
        yield
    finally:
        logger.setLevel(old_level)
# Example 6
# Call my_function() with the root level at DEBUG, then again after the
# with-block has restored the previous level.
with debug_logging(logging.DEBUG):
    print('* Inside:')
    my_function()
print('* After:')
my_function()
# Example 7
# open() is itself a context manager: the file is closed on exit.
with open('my_output.txt', 'w', encoding='utf-8') as handle:
    handle.write('This is some data!')
# Example 8
@contextmanager
def log_level(level, name):
    """Temporarily set the logger called ``name`` to ``level``.

    Yields:
        The logger, so it can be bound with ``as``.

    The logger's own level is restored on exit, even if the body raises.
    """
    logger = logging.getLogger(name)
    # Save the logger's own level, not the effective one. Restoring the
    # effective level would pin a logger that inherits from its parent
    # (NOTSET) to an explicit level and stop it following the parent.
    old_level = logger.level
    logger.setLevel(level)
    try:
        yield logger
    finally:
        logger.setLevel(old_level)
# Example 9
# Only the logger named in log_level() is changed; the root logger used by
# logging.debug() keeps its own level.
with log_level(logging.DEBUG, 'my-log') as logger:
    logger.debug(f'This is a message for {logger.name}!')
    logging.debug('This will not print')
# Example 10
logger = logging.getLogger('my-log')
logger.debug('Debug will not print')
logger.error('Error will print')
# Example 11
with log_level(logging.DEBUG, 'other-log') as logger:
    logger.debug(f'This is a message for {logger.name}!')
    logging.debug('This will not print')
67. 用datetime模块处理本地时间,不要用time模块
import time
from datetime import datetime, timezone
now = datetime(2019, 3, 16, 22, 14, 35)
now_utc = now.replace(tzinfo=timezone.utc)
now_local = now_utc.astimezone()
print(now_local) # 2019-03-17 06:14:35+08:00
# Example 6
time_str = '2019-03-16 15:14:35'
time_format = '%Y-%m-%d %H:%M:%S'
now = datetime.strptime(time_str, time_format)
time_tuple = now.timetuple()
utc_now = time.mktime(time_tuple)
print(utc_now) # 1552720475.0
# Example 7
import pytz
arrival_nyc = '2019-03-16 23:33:24'
nyc_dt_naive = datetime.strptime(arrival_nyc, time_format)
eastern = pytz.timezone('US/Eastern')
nyc_dt = eastern.localize(nyc_dt_naive)
utc_dt = pytz.utc.normalize(nyc_dt.astimezone(pytz.utc))
print(utc_dt) # 2019-03-17 03:33:24+00:00
# Example 8
pacific = pytz.timezone('US/Pacific')
sf_dt = pacific.normalize(utc_dt.astimezone(pacific))
print(sf_dt) # 2019-03-16 20:33:24-07:00
# Example 9
nepal = pytz.timezone('Asia/Katmandu')
nepal_dt = nepal.normalize(utc_dt.astimezone(nepal))
print(nepal_dt) # 2019-03-17 09:18:24+05:45
68. 用copyreg实现可靠的pickle操作
import logging
# Example 1
class GameState:
def __init__(self):
self.level = 0
self.lives = 4
# Example 2
state = GameState()
state.level += 1 # Player beat a level
state.lives -= 1 # Player had to try again
print(state.__dict__)
# Example 3
import pickle
state_path = 'game_state.bin'
with open(state_path, 'wb') as f:
pickle.dump(state, f)
# Example 4
with open(state_path, 'rb') as f:
state_after = pickle.load(f)
print(state_after.__dict__)
# Example 5
class GameState:
def __init__(self):
self.level = 0
self.lives = 4
self.points = 0 # New field
# Example 6
state = GameState()
serialized = pickle.dumps(state)
state_after = pickle.loads(serialized)
print(state_after.__dict__)
# Example 7
with open(state_path, 'rb') as f:
state_after = pickle.load(f)
print(state_after.__dict__)
# Example 8
assert isinstance(state_after, GameState)
# Example 9
class GameState:
    """Player progress for the pickle examples.

    Every field has a default, so an instance can be rebuilt from whatever
    subset of keyword arguments is available.
    """

    def __init__(self, level=0, lives=4, points=0):
        self.level = level
        self.lives = lives
        self.points = points
# Example 10
def pickle_game_state(game_state):
kwargs = game_state.__dict__
return unpickle_game_state, (kwargs,)
# Example 11
def unpickle_game_state(kwargs):
return GameState(**kwargs)
# Example 12
import copyreg
copyreg.pickle(GameState, pickle_game_state)
# Example 13
state = GameState()
state.points += 1000
serialized = pickle.dumps(state)
state_after = pickle.loads(serialized)
print(state_after.__dict__)
# Example 14
class GameState:
def __init__(self, level=0, lives=4, points=0, magic=5):
self.level = level
self.lives = lives
self.points = points
self.magic = magic # New field
# Example 15
print('Before:', state.__dict__)
state_after = pickle.loads(serialized)
print('After: ', state_after.__dict__)
# Example 16
class GameState:
def __init__(self, level=0, points=0, magic=5):
self.level = level
self.points = points
self.magic = magic
# Example 17
# try:
# pickle.loads(serialized)
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 18
def pickle_game_state(game_state):
    """Reduce ``game_state`` to ``(callable, args)`` for copyreg.

    The payload is the object's attributes plus a ``version`` marker that
    ``unpickle_game_state`` uses to pick the right upgrade path.
    """
    # Work on a copy: tagging game_state.__dict__ directly would add a stray
    # ``version`` attribute to the live object every time it is pickled.
    kwargs = dict(game_state.__dict__)
    kwargs['version'] = 2
    return unpickle_game_state, (kwargs,)


# Example 19
def unpickle_game_state(kwargs):
    """Rebuild a GameState from a payload made by ``pickle_game_state``.

    Payloads without a ``version`` key are treated as version 1; their
    ``lives`` field is dropped before the constructor is called.
    """
    version = kwargs.pop('version', 1)
    if version == 1:
        # Tolerate version-1 payloads that never stored ``lives``.
        kwargs.pop('lives', None)
    return GameState(**kwargs)
# Example 20
copyreg.pickle(GameState, pickle_game_state)
print('Before:', state.__dict__)
state_after = pickle.loads(serialized)
print('After: ', state_after.__dict__)
# Example 21
copyreg.dispatch_table.clear()
state = GameState()
serialized = pickle.dumps(state)
del GameState
class BetterGameState:
def __init__(self, level=0, points=0, magic=5):
self.level = level
self.points = points
self.magic = magic
# Example 22
# try:
# pickle.loads(serialized)
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 23
print(serialized)
# Example 24
copyreg.pickle(BetterGameState, pickle_game_state)
# Example 25
state = BetterGameState()
serialized = pickle.dumps(state)
print(serialized)
69. 在需要准确计算的场合,用decimal表示相应的数值
# Example 1
rate = 1.45
seconds = 3*60 + 42
cost = rate * seconds / 60
print(cost) # 5.364999999999999
# Example 2
print(round(cost, 2)) # 5.36
# Example 3
from decimal import Decimal
rate = Decimal('1.45')
seconds = Decimal(3*60 + 42)
cost = rate * seconds / Decimal(60)
print(cost) # 5.365
# Example 4
print(Decimal('1.45')) # 1.45
print(Decimal(1.45)) # 1.4499999999999999555910790149937383830547332763671875
# Example 5
print('456') # 456
print(456) # 456
# Example 6
rate = Decimal('0.05')
seconds = Decimal('5')
small_cost = rate * seconds / Decimal(60)
print(small_cost) # 0.004166666666666666666666666667
# Example 7
print(round(small_cost, 2)) # 0.00
# Example 8
from decimal import ROUND_UP
rounded = cost.quantize(Decimal('0.01'), rounding=ROUND_UP)
print(f'Rounded {cost} to {rounded}') # Rounded 5.365 to 5.37
# Example 9
rounded = small_cost.quantize(Decimal('0.01'), rounding=ROUND_UP)
print(f'Rounded {small_cost} to {rounded}') # Rounded 0.004166666666666666666666666667 to 0.01
70. 先分析性能,然后再优化
- 优化python程序之前,一定要先分析它的性能,程序变慢的真正原因未必和我们想的一样;
- 可以通过Stats对象筛选出我们关心的结果。
# Example 1
def insertion_sort(data):
    """Return a new ascending list built by inserting each value of ``data``."""
    result = []
    for value in data:
        insert_value(result, value)
    return result


# Example 2
def insert_value(array, value):
    """Insert ``value`` into the sorted list ``array`` with a linear scan.

    The value goes in front of the first element greater than it, or at the
    end when no such element exists.
    """
    for i, existing in enumerate(array):
        if existing > value:
            array.insert(i, value)
            return
    array.append(value)
# Example 3
from random import randint
max_size = 10**4
data = [randint(0, max_size) for _ in range(max_size)]
test = lambda: insertion_sort(data)
# Example 4
from cProfile import Profile
profiler = Profile()
profiler.runcall(test)
# Example 5
from pstats import Stats
stats = Stats(profiler)
# stats = Stats(profiler, stream=STDOUT)
stats.strip_dirs()
stats.sort_stats('cumulative')
stats.print_stats()
# Example 6
from bisect import bisect_left
def insert_value(array, value):
    """Insert ``value`` into the sorted list ``array``, keeping it sorted.

    The position is found with a binary search instead of a linear scan.
    """
    i = bisect_left(array, value)
    array.insert(i, value)
# Example 7
profiler = Profile()
profiler.runcall(test)
# Build a fresh Stats from this run; reusing the earlier ``stats`` object
# would only reprint the previous profile. Pass stream=... to Stats to send
# the report somewhere other than stdout.
stats = Stats(profiler)
stats.strip_dirs()
stats.sort_stats('cumulative')
stats.print_stats()
# Example 8
def my_utility(a, b):
    """Burn some CPU: add ``a * b`` to a counter 100 times and discard it."""
    c = 1
    for _ in range(100):
        c += a * b


def first_func():
    """Call ``my_utility`` 1000 times."""
    for _ in range(1000):
        my_utility(4, 5)


def second_func():
    """Call ``my_utility`` 10 times."""
    for _ in range(10):
        my_utility(1, 3)


def my_program():
    """Run ``first_func`` then ``second_func``, 20 times over."""
    for _ in range(20):
        first_func()
        second_func()
# Example 9
profiler = Profile()
profiler.runcall(my_program)
# Build Stats from this profiler; the ``stats`` object left over from the
# earlier examples describes a different run. Pass stream=... to Stats to
# send the report somewhere other than stdout.
stats = Stats(profiler)
stats.strip_dirs()
stats.sort_stats('cumulative')
stats.print_stats()
# Example 10
# Same profile, this time listing which callers contributed to each function.
stats.print_callers()
71. 优先考虑用deque实现生产者-消费者队列
# Example 1
class Email:
    """A message with a sender, a receiver and a body."""

    def __init__(self, sender, receiver, message):
        self.sender = sender
        self.receiver = receiver
        self.message = message


# Example 2
def get_emails():
    """Yield a scripted stream of emails; ``None`` marks "nothing available"."""
    yield Email('foo@example.com', 'bar@example.com', 'hello1')
    yield Email('baz@example.com', 'banana@example.com', 'hello2')
    yield None
    yield Email('meep@example.com', 'butter@example.com', 'hello3')
    yield Email('stuff@example.com', 'avocado@example.com', 'hello4')
    yield None
    yield Email('thingy@example.com', 'orange@example.com', 'hello5')
    yield Email('roger@example.com', 'bob@example.com', 'hello6')
    yield None
    yield Email('peanut@example.com', 'alice@example.com', 'hello7')
    yield None


# Shared iterator consumed by try_receive_email().
EMAIL_IT = get_emails()


class NoEmailError(Exception):
    """Raised when no email is currently available."""


def try_receive_email():
    """Return the next Email from ``EMAIL_IT`` or raise ``NoEmailError``.

    Both a ``None`` item and an exhausted iterator count as "no email".
    """
    email = next(EMAIL_IT, None)
    if not email:
        raise NoEmailError
    print(f'Produced email: {email.message}')
    return email
# Example 3
def produce_emails(queue):
while True:
try:
email = try_receive_email()
except NoEmailError:
return
else:
queue.append(email) # Producer
# Example 4
def consume_one_email(queue):
    """Remove the oldest email from the list ``queue`` and report it.

    Does nothing when the queue is empty.
    """
    if not queue:
        return
    email = queue.pop(0)  # Consumer
    # Index the message for long-term archival
    print(f'Consumed email: {email.message}')
# Example 5
def loop(queue, keep_running):
while keep_running():
produce_emails(queue)
consume_one_email(queue)
def make_test_end():
    """Return a ``keep_running`` callable that is true for its first 10 calls.

    Every later call returns False.
    """
    remaining = 10

    def func():
        nonlocal remaining
        if remaining:
            remaining -= 1
            return True
        return False

    return func
def my_end_func():
pass
my_end_func = make_test_end()
loop([], my_end_func)
# Example 6
import timeit
def print_results(count, tests):
    """Print the mean of the timings in ``tests`` and return ``(count, mean)``."""
    avg_iteration = sum(tests) / len(tests)
    print(f'Count {count:>5,} takes {avg_iteration:.6f}s')
    return count, avg_iteration
def list_append_benchmark(count):
def run(queue):
for i in range(count):
queue.append(i)
tests = timeit.repeat(
setup='queue = []',
stmt='run(queue)',
globals=locals(),
repeat=1000,
number=1)
return print_results(count, tests)
# Example 7
def print_delta(before, after):
    """Print how data size and time grew between two benchmark results.

    ``before`` and ``after`` are ``(count, time)`` pairs such as those
    returned by ``print_results``.
    """
    before_count, before_time = before
    after_count, after_time = after
    growth = 1 + (after_count - before_count) / before_count
    slowdown = 1 + (after_time - before_time) / before_time
    print(f'{growth:>4.1f}x data size, {slowdown:>4.1f}x time')
baseline = list_append_benchmark(500)
for count in (1_000, 2_000, 3_000, 4_000, 5_000):
print()
comparison = list_append_benchmark(count)
print_delta(baseline, comparison)
# Example 8
def list_pop_benchmark(count):
def prepare():
return list(range(count))
def run(queue):
while queue:
queue.pop(0)
tests = timeit.repeat(
setup='queue = prepare()',
stmt='run(queue)',
globals=locals(),
repeat=1000,
number=1)
return print_results(count, tests)
# Example 9
baseline = list_pop_benchmark(500)
for count in (1_000, 2_000, 3_000, 4_000, 5_000):
print()
comparison = list_pop_benchmark(count)
print_delta(baseline, comparison)
# Example 10
import collections
def consume_one_email(queue):
    """Remove the oldest email from ``queue`` via ``popleft`` and report it.

    Does nothing when the queue is empty.
    """
    if not queue:
        return
    email = queue.popleft()  # Consumer
    # Process the email message
    print(f'Consumed email: {email.message}')
def my_end_func():
pass
my_end_func = make_test_end()
EMAIL_IT = get_emails()
loop(collections.deque(), my_end_func)
# Example 11
def deque_append_benchmark(count):
def prepare():
return collections.deque()
def run(queue):
for i in range(count):
queue.append(i)
tests = timeit.repeat(
setup='queue = prepare()',
stmt='run(queue)',
globals=locals(),
repeat=1000,
number=1)
return print_results(count, tests)
baseline = deque_append_benchmark(500)
for count in (1_000, 2_000, 3_000, 4_000, 5_000):
print()
comparison = deque_append_benchmark(count)
print_delta(baseline, comparison)
# Example 12
def dequeue_popleft_benchmark(count):
def prepare():
return collections.deque(range(count))
def run(queue):
while queue:
queue.popleft()
tests = timeit.repeat(
setup='queue = prepare()',
stmt='run(queue)',
globals=locals(),
repeat=1000,
number=1)
return print_results(count, tests)
baseline = dequeue_popleft_benchmark(500)
for count in (1_000, 2_000, 3_000, 4_000, 5_000):
print()
comparison = dequeue_popleft_benchmark(count)
print_delta(baseline, comparison)
72. 考虑用bisect搜索已排序的序列
# Example 1
data = list(range(10**5))
index = data.index(91234)
assert index == 91234
# Example 2
def find_closest(sequence, goal):
    """Return the index of the first value in ``sequence`` greater than ``goal``.

    Scans linearly from the start.

    Raises:
        ValueError: if no value is greater than ``goal``.
    """
    for index, value in enumerate(sequence):
        if goal < value:
            return index
    raise ValueError(f'{goal} is out of bounds')
index = find_closest(data, 91234.56)
assert index == 91235
try:
find_closest(data, 100000000)
except ValueError:
pass # Expected
else:
assert False
# Example 3 Python内置的bisect模块可以更好地搜索有序列表
from bisect import bisect_left
index = bisect_left(data, 91234) # Exact match
assert index == 91234
index = bisect_left(data, 91234.56) # Closest match
assert index == 91235
# Example 4
import random
import timeit
size = 10**5
iterations = 1000
data = list(range(size))
# randrange() excludes ``size`` itself. randint(0, size) is inclusive and
# could pick ``size``, which is not in ``data`` and would make data.index()
# raise ValueError during the linear benchmark.
to_lookup = [random.randrange(size) for _ in range(iterations)]
def run_linear(data, to_lookup):
for index in to_lookup:
data.index(index)
def run_bisect(data, to_lookup):
for index in to_lookup:
bisect_left(data, index)
baseline = timeit.timeit(
stmt='run_linear(data, to_lookup)',
globals=globals(),
number=10)
print(f'Linear search takes {baseline:.6f}s')
comparison = timeit.timeit(
stmt='run_bisect(data, to_lookup)',
globals=globals(),
number=10)
print(f'Bisect search takes {comparison:.6f}s')
slowdown = 1 + ((baseline - comparison) / comparison)
print(f'{slowdown:.1f}x time')
73. 学会使用heapq制作优先级队列
有时候,我们想根据元素的优先程度来排序,此时应该使用优先级队列。
import logging
# Example 1
class Book:
def __init__(self, title, due_date):
self.title = title
self.due_date = due_date
# Example 2
def add_book(queue, book):
    """Append ``book`` and re-sort ``queue`` by due date, latest first.

    After the sort the earliest due date sits at the end of the list.
    """
    queue.append(book)
    queue.sort(key=lambda x: x.due_date, reverse=True)
queue = []
add_book(queue, Book('Don Quixote', '2019-06-07'))
add_book(queue, Book('Frankenstein', '2019-06-05'))
add_book(queue, Book('Les Misérables', '2019-06-08'))
add_book(queue, Book('War and Peace', '2019-06-03'))
# Example 3
class NoOverdueBooks(Exception):
    """Raised when the queue holds no book that is overdue."""


def next_overdue_book(queue, now):
    """Pop and return the last book in ``queue`` if it is due before ``now``.

    Only the final element is inspected.

    Raises:
        NoOverdueBooks: if the queue is empty or its last book is not overdue.
    """
    if queue:
        book = queue[-1]
        if book.due_date < now:
            queue.pop()
            return book
    raise NoOverdueBooks
# Example 4
now = '2019-06-10'
found = next_overdue_book(queue, now)
print(found.title)
found = next_overdue_book(queue, now)
print(found.title)
# Example 5
def return_book(queue, book):
queue.remove(book)
queue = []
book = Book('Treasure Island', '2019-06-04')
add_book(queue, book)
print('Before return:', [x.title for x in queue])
return_book(queue, book)
print('After return: ', [x.title for x in queue])
# Example 6
try:
next_overdue_book(queue, now)
except NoOverdueBooks:
pass # Expected
else:
assert False # Doesn't happen
# Example 7
import random
import timeit
def print_results(count, tests):
avg_iteration = sum(tests) / len(tests)
print(f'Count {count:>5,} takes {avg_iteration:.6f}s')
return count, avg_iteration
def print_delta(before, after):
before_count, before_time = before
after_count, after_time = after
growth = 1 + (after_count - before_count) / before_count
slowdown = 1 + (after_time - before_time) / before_time
print(f'{growth:>4.1f}x data size, {slowdown:>4.1f}x time')
def list_overdue_benchmark(count):
def prepare():
to_add = list(range(count))
random.shuffle(to_add)
return [], to_add
def run(queue, to_add):
for i in to_add:
queue.append(i)
queue.sort(reverse=True)
while queue:
queue.pop()
tests = timeit.repeat(
setup='queue, to_add = prepare()',
stmt=f'run(queue, to_add)',
globals=locals(),
repeat=100,
number=1)
return print_results(count, tests)
# Example 8
baseline = list_overdue_benchmark(500)
for count in (1_000, 1_500, 2_000):
print()
comparison = list_overdue_benchmark(count)
print_delta(baseline, comparison)
# Example 9
def list_return_benchmark(count):
def prepare():
queue = list(range(count))
random.shuffle(queue)
to_return = list(range(count))
random.shuffle(to_return)
return queue, to_return
def run(queue, to_return):
for i in to_return:
queue.remove(i)
tests = timeit.repeat(
setup='queue, to_return = prepare()',
stmt=f'run(queue, to_return)',
globals=locals(),
repeat=100,
number=1)
return print_results(count, tests)
# Example 10
baseline = list_return_benchmark(500)
for count in (1_000, 1_500, 2_000):
print()
comparison = list_return_benchmark(count)
print_delta(baseline, comparison)
# Example 11
from heapq import heappush
def add_book(queue, book):
heappush(queue, book)
# Example 12
# try:
# queue = []
# add_book(queue, Book('Little Women', '2019-06-05'))
# add_book(queue, Book('The Time Machine', '2019-05-30'))
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 13
import functools
@functools.total_ordering
class Book:
    """A borrowed book, ordered by its due date.

    Only ``__lt__`` is defined here; ``functools.total_ordering`` fills in
    the remaining ordering methods from it.
    """

    def __init__(self, title, due_date):
        self.title = title
        self.due_date = due_date

    def __lt__(self, other):
        return self.due_date < other.due_date
# Example 14
queue = []
add_book(queue, Book('Pride and Prejudice', '2019-06-01'))
add_book(queue, Book('The Time Machine', '2019-05-30'))
add_book(queue, Book('Crime and Punishment', '2019-06-06'))
add_book(queue, Book('Wuthering Heights', '2019-06-12'))
print([b.title for b in queue])
# Example 15
queue = [
Book('Pride and Prejudice', '2019-06-01'),
Book('The Time Machine', '2019-05-30'),
Book('Crime and Punishment', '2019-06-06'),
Book('Wuthering Heights', '2019-06-12'),
]
queue.sort()
print([b.title for b in queue])
# Example 16
from heapq import heapify
queue = [
Book('Pride and Prejudice', '2019-06-01'),
Book('The Time Machine', '2019-05-30'),
Book('Crime and Punishment', '2019-06-06'),
Book('Wuthering Heights', '2019-06-12'),
]
heapify(queue)
print([b.title for b in queue])
# Example 17
from heapq import heappop
def next_overdue_book(queue, now):
if queue:
book = queue[0] # Most overdue first
if book.due_date < now:
heappop(queue) # Remove the overdue book
return book
raise NoOverdueBooks
# Example 18
now = '2019-06-02'
book = next_overdue_book(queue, now)
print(book.title)
book = next_overdue_book(queue, now)
print(book.title)
try:
next_overdue_book(queue, now)
except NoOverdueBooks:
pass # Expected
else:
assert False # Doesn't happen
# Example 19
def heap_overdue_benchmark(count):
def prepare():
to_add = list(range(count))
random.shuffle(to_add)
return [], to_add
def run(queue, to_add):
for i in to_add:
heappush(queue, i)
while queue:
heappop(queue)
tests = timeit.repeat(
setup='queue, to_add = prepare()',
stmt=f'run(queue, to_add)',
globals=locals(),
repeat=100,
number=1)
return print_results(count, tests)
# Example 20
baseline = heap_overdue_benchmark(500)
for count in (1_000, 1_500, 2_000):
print()
comparison = heap_overdue_benchmark(count)
print_delta(baseline, comparison)
# Example 21
@functools.total_ordering
class Book:
def __init__(self, title, due_date):
self.title = title
self.due_date = due_date
self.returned = False # New field
def __lt__(self, other):
return self.due_date < other.due_date
# Example 22
def next_overdue_book(queue, now):
    """Pop and return the first overdue, not-yet-returned book in the heap.

    Books at the front that are flagged ``returned`` are discarded along
    the way.

    Raises:
        NoOverdueBooks: if the heap runs out, or its first unreturned book
            is not due before ``now``.
    """
    while queue:
        book = queue[0]
        if book.returned:
            # Already handed back: drop it and look at the next one.
            heappop(queue)
            continue
        if book.due_date < now:
            heappop(queue)
            return book
        # The earliest unreturned book is not overdue, so nothing later is.
        break
    raise NoOverdueBooks
queue = []
book = Book('Pride and Prejudice', '2019-06-01')
add_book(queue, book)
book = Book('The Time Machine', '2019-05-30')
add_book(queue, book)
book.returned = True
book = Book('Crime and Punishment', '2019-06-06')
add_book(queue, book)
book.returned = True
book = Book('Wuthering Heights', '2019-06-12')
add_book(queue, book)
now = '2019-06-11'
book = next_overdue_book(queue, now)
assert book.title == 'Pride and Prejudice'
try:
next_overdue_book(queue, now)
except NoOverdueBooks:
pass # Expected
else:
assert False # Doesn't happen
# Example 23
def return_book(queue, book):
book.returned = True
assert not book.returned
return_book(queue, book)
assert book.returned
74. 考虑用memoryview与bytearray来实现无须拷贝的bytes操作
# Example 1
def timecode_to_index(video_id, timecode):
return 1234
# Returns the byte offset in the video data
def request_chunk(video_id, byte_offset, size):
pass
# Returns size bytes of video_id's data from the offset
video_id = ...
timecode = '01:09:14:28'
byte_offset = timecode_to_index(video_id, timecode)
size = 20 * 1024 * 1024
video_data = request_chunk(video_id, byte_offset, size)
# Example 2
class NullSocket:
    """Stand-in for a client socket that discards everything sent to it.

    Can be used as a context manager so the underlying file gets closed.
    """

    def __init__(self):
        # Writes to os.devnull succeed and the bytes are thrown away.
        self.handle = open(os.devnull, 'wb')

    def send(self, data):
        """Write ``data`` to the null device."""
        self.handle.write(data)

    def close(self):
        """Close the underlying file handle."""
        self.handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
socket = ... # socket connection to client
video_data = ... # bytes containing data for video_id
byte_offset = ... # Requested starting position
size = 20 * 1024 * 1024 # Requested chunk size
import os
socket = NullSocket()
video_data = 100 * os.urandom(1024 * 1024)
byte_offset = 1234
chunk = video_data[byte_offset:byte_offset + size]
socket.send(chunk)
# Example 3
import timeit
def run_test():
chunk = video_data[byte_offset:byte_offset + size]
# Call socket.send(chunk), but ignoring for benchmark
result = timeit.timeit(
stmt='run_test()',
globals=globals(),
number=100) / 100
print(f'{result:0.9f} seconds')
# Example 4
data = b'shave and a haircut, two bits'
view = memoryview(data)
chunk = view[12:19]
print(chunk)
print('Size: ', chunk.nbytes)
print('Data in view: ', chunk.tobytes())
print('Underlying data:', chunk.obj)
# Example 5
video_view = memoryview(video_data)
def run_test():
chunk = video_view[byte_offset:byte_offset + size]
# Call socket.send(chunk), but ignoring for benchmark
result = timeit.timeit(
stmt='run_test()',
globals=globals(),
number=100) / 100
print(f'{result:0.9f} seconds')
# Example 6
class FakeSocket:
def recv(self, size):
return video_view[byte_offset:byte_offset+size]
def recv_into(self, buffer):
source_data = video_view[byte_offset:byte_offset+size]
buffer[:] = source_data
socket = ... # socket connection to the client
video_cache = ... # Cache of incoming video stream
byte_offset = ... # Incoming buffer position
size = 1024 * 1024 # Incoming chunk size
socket = FakeSocket()
video_cache = video_data[:]
byte_offset = 1234
chunk = socket.recv(size)
video_view = memoryview(video_cache)
before = video_view[:byte_offset]
after = video_view[byte_offset + size:]
new_cache = b''.join([before, chunk, after])
# Example 7
def run_test():
chunk = socket.recv(size)
before = video_view[:byte_offset]
after = video_view[byte_offset + size:]
new_cache = b''.join([before, chunk, after])
result = timeit.timeit(
stmt='run_test()',
globals=globals(),
number=100) / 100
print(f'{result:0.9f} seconds')
# Example 8
# try:
# my_bytes = b'hello'
# my_bytes[0] = b'\x79'
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 9
my_array = bytearray(b'hello')
my_array[0] = 0x79
print(my_array)
# Example 10
my_array = bytearray(b'row, row, row your boat')
my_view = memoryview(my_array)
write_view = my_view[3:13]
write_view[:] = b'-10 bytes-'
print(my_array)
# Example 11
video_array = bytearray(video_cache)
write_view = memoryview(video_array)
chunk = write_view[byte_offset:byte_offset + size]
socket.recv_into(chunk)
# Example 12
def run_test():
chunk = write_view[byte_offset:byte_offset + size]
socket.recv_into(chunk)
result = timeit.timeit(
stmt='run_test()',
globals=globals(),
number=100) / 100
print(f'{result:0.9f} seconds')
九. 测试与调试
75. 通过repr字符串输出调试信息
# Example 1
print('foo bar')
# Example 2
my_value = 'foo bar'
print(str(my_value))
print('%s' % my_value)
print(f'{my_value}')
print(format(my_value))
print(my_value.__format__('s'))
print(my_value.__str__())
# Example 3
print(5)
print('5')
int_value = 5
str_value = '5'
print(f'{int_value} == {str_value} ?')
# Example 4
a = '\x07'
print(repr(a))
# Example 5
b = eval(repr(a))
assert a == b
# Example 6
print(repr(5))
print(repr('5'))
# Example 7
print('%r' % 5)
print('%r' % '5')
int_value = 5
str_value = '5'
print(f'{int_value!r} != {str_value!r}')
# Example 8
class OpaqueClass:
def __init__(self, x, y):
self.x = x
self.y = y
obj = OpaqueClass(1, 'foo')
print(obj)
# Example 9
class BetterClass:
    """Two-field example class whose repr shows how to recreate it."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f'BetterClass({self.x!r}, {self.y!r})'
# Example 10
obj = BetterClass(2, 'bar')
print(obj)
# Example 11
obj = OpaqueClass(4, 'baz')
print(obj.__dict__)
76. 在TestCase子类里验证相关的行为
- Helper_test.py
from unittest import TestCase, main
def sum_squares(values):
    """Yield the running total of the squares of ``values``, one per input."""
    cumulative = 0
    for value in values:
        cumulative += value ** 2
        yield cumulative
class HelperTestCase(TestCase):
def verify_complex_case(self, values, expected):
expect_it = iter(expected)
found_it = iter(sum_squares(values))
test_it = zip(expect_it, found_it)
for i, (expect, found) in enumerate(test_it):
self.assertEqual(
expect,
found,
f'Index {i} is wrong')
# Verify both generators are exhausted
try:
next(expect_it)
except StopIteration:
pass
else:
self.fail('Expected longer than found')
try:
next(found_it)
except StopIteration:
pass
else:
self.fail('Found longer than expected')
def test_wrong_lengths(self):
values = [1.1, 2.2, 3.3]
expected = [
1.1**2,
]
self.verify_complex_case(values, expected)
def test_wrong_results(self):
values = [1.1, 2.2, 3.3]
expected = [
1.1**2,
1.1**2 + 2.2**2,
1.1**2 + 2.2**2 + 3.3**2 + 4.4**2,
]
self.verify_complex_case(values, expected)
if __name__ == '__main__':
main()
- Data_driven_test.py
from unittest import TestCase, main
from utils import to_str
class DataDrivenTestCase(TestCase):
def test_good(self):
good_cases = [
(b'my bytes', 'my bytes'),
('no error', b'no error'), # This one will fail
('other str', 'other str'),
]
for value, expected in good_cases:
with self.subTest(value):
self.assertEqual(expected, to_str(value))
def test_bad(self):
bad_cases = [
(object(), TypeError),
(b'\xfa\xfa', UnicodeDecodeError),
]
for value, exception in bad_cases:
with self.subTest(value):
with self.assertRaises(exception):
to_str(value)
if __name__ == '__main__':
main()
77. 把测试前、后的准备与清理逻辑写在setUp、tearDown、setUpModule与tearDownModule中,以防用例之间互相干扰
from unittest import TestCase, main
def setUpModule():
print('* Module setup')
def tearDownModule():
print('* Module clean-up')
class IntegrationTest(TestCase):
def setUp(self):
print('* Test setup')
def tearDown(self):
print('* Test clean-up')
def test_end_to_end1(self):
print('* Test 1')
def test_end_to_end2(self):
print('* Test 2')
if __name__ == '__main__':
main()
78. 用Mock来模拟受测代码所依赖的复杂函数
import logging
# Example 1
class DatabaseConnection:
def __init__(self, host, port):
pass
class DatabaseConnectionError(Exception):
pass
def get_animals(database, species):
# Query the Database
raise DatabaseConnectionError('Not connected')
# Return a list of (name, last_mealtime) tuples
# Example 2
# try:
# database = DatabaseConnection('localhost', '4444')
#
# get_animals(database, 'Meerkat')
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 3
from datetime import datetime
from unittest.mock import Mock
mock = Mock(spec=get_animals)
expected = [
('Spot', datetime(2019, 6, 5, 11, 15)),
('Fluffy', datetime(2019, 6, 5, 12, 30)),
('Jojo', datetime(2019, 6, 5, 12, 45)),
]
mock.return_value = expected
# Example 4
# try:
# mock.does_not_exist
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 5
database = object()
result = mock(database, 'Meerkat')
assert result == expected
# Example 6
mock.assert_called_once_with(database, 'Meerkat')
# Example 7
# try:
# mock.assert_called_once_with(database, 'Giraffe')
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 8
from unittest.mock import ANY
mock = Mock(spec=get_animals)
mock('database 1', 'Rabbit')
mock('database 2', 'Bison')
mock('database 3', 'Meerkat')
mock.assert_called_with(ANY, 'Meerkat')
# Example 9
# try:
# class MyError(Exception):
# pass
#
# mock = Mock(spec=get_animals)
# mock.side_effect = MyError('Whoops! Big problem')
# result = mock(database, 'Meerkat')
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 10
def get_food_period(database, species):
# Query the Database
pass
# Return a time delta
def feed_animal(database, name, when):
# Write to the Database
pass
def do_rounds(database, species):
    """Feed every animal of ``species`` whose last meal is too long ago.

    Returns the number of animals fed.
    """
    # ``datetime`` is the class here (from datetime import datetime), so the
    # call is datetime.utcnow(); datetime.datetime.utcnow() would raise
    # AttributeError.
    now = datetime.utcnow()
    feeding_timedelta = get_food_period(database, species)
    animals = get_animals(database, species)
    fed = 0
    for name, last_mealtime in animals:
        if (now - last_mealtime) > feeding_timedelta:
            feed_animal(database, name, now)
            fed += 1
    return fed
# Example 11
def do_rounds(database, species, *,
now_func=datetime.utcnow,
food_func=get_food_period,
animals_func=get_animals,
feed_func=feed_animal):
now = now_func()
feeding_timedelta = food_func(database, species)
animals = animals_func(database, species)
fed = 0
for name, last_mealtime in animals:
if (now - last_mealtime) > feeding_timedelta:
feed_func(database, name, now)
fed += 1
return fed
# Example 12
from datetime import timedelta
now_func = Mock(spec=datetime.utcnow)
now_func.return_value = datetime(2019, 6, 5, 15, 45)
food_func = Mock(spec=get_food_period)
food_func.return_value = timedelta(hours=3)
animals_func = Mock(spec=get_animals)
animals_func.return_value = [
('Spot', datetime(2019, 6, 5, 11, 15)),
('Fluffy', datetime(2019, 6, 5, 12, 30)),
('Jojo', datetime(2019, 6, 5, 12, 45)),
]
feed_func = Mock(spec=feed_animal)
# Example 13
result = do_rounds(
database,
'Meerkat',
now_func=now_func,
food_func=food_func,
animals_func=animals_func,
feed_func=feed_func)
assert result == 2
# Example 14
from unittest.mock import call
food_func.assert_called_once_with(database, 'Meerkat')
animals_func.assert_called_once_with(database, 'Meerkat')
feed_func.assert_has_calls(
[
call(database, 'Spot', now_func.return_value),
call(database, 'Fluffy', now_func.return_value),
],
any_order=True)
# Example 15
from unittest.mock import patch
print('Outside patch:', get_animals)
with patch('__main__.get_animals'):
print('Inside patch: ', get_animals)
print('Outside again:', get_animals)
# Example 16
# try:
# fake_now = datetime(2019, 6, 5, 15, 45)
#
# with patch('datetime.datetime.utcnow'):
# datetime.utcnow.return_value = fake_now
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 17
def get_do_rounds_time():
    """Return the current UTC time; kept as a function so tests can patch it."""
    # ``datetime`` is the class here (from datetime import datetime), so
    # datetime.datetime.utcnow() would raise AttributeError.
    return datetime.utcnow()
def do_rounds(database, species):
now = get_do_rounds_time()
with patch('__main__.get_do_rounds_time'):
pass
# Example 18
def do_rounds(database, species, *, utcnow=datetime.utcnow):
now = utcnow()
feeding_timedelta = get_food_period(database, species)
animals = get_animals(database, species)
fed = 0
for name, last_mealtime in animals:
if (now - last_mealtime) > feeding_timedelta:
feed_animal(database, name, now)
fed += 1
return fed
# Example 19
from unittest.mock import DEFAULT
with patch.multiple('__main__',
autospec=True,
get_food_period=DEFAULT,
get_animals=DEFAULT,
feed_animal=DEFAULT):
now_func = Mock(spec=datetime.utcnow)
now_func.return_value = datetime(2019, 6, 5, 15, 45)
get_food_period.return_value = timedelta(hours=3)
get_animals.return_value = [
('Spot', datetime(2019, 6, 5, 11, 15)),
('Fluffy', datetime(2019, 6, 5, 12, 30)),
('Jojo', datetime(2019, 6, 5, 12, 45))
]
# Example 20
result = do_rounds(database, 'Meerkat', utcnow=now_func)
assert result == 2
get_food_period.assert_called_once_with(database, 'Meerkat')
get_animals.assert_called_once_with(database, 'Meerkat')
feed_animal.assert_has_calls(
[
call(database, 'Spot', now_func.return_value),
call(database, 'Fluffy', now_func.return_value),
],
any_order=True)
79. 把受测代码所依赖的系统封装起来,以便于模拟和测试
# Example 1
class ZooDatabase:
def get_animals(self, species):
pass
def get_food_period(self, species):
pass
def feed_animal(self, name, when):
pass
# Example 2
from datetime import datetime
def do_rounds(database, species, *, utcnow=datetime.utcnow):
    """Feed every animal of ``species`` that is due a meal.

    An animal is fed when the time since its last meal is at least the
    feeding period reported by ``database``. ``utcnow`` supplies the current
    time and can be replaced in tests.

    Returns the number of animals fed.
    """
    now = utcnow()
    feeding_timedelta = database.get_food_period(species)
    animals = database.get_animals(species)
    fed = 0
    for name, last_mealtime in animals:
        if (now - last_mealtime) >= feeding_timedelta:
            database.feed_animal(name, now)
            fed += 1
    return fed
# Example 3
from unittest.mock import Mock
database = Mock(spec=ZooDatabase)
print(database.feed_animal)
database.feed_animal()
database.feed_animal.assert_any_call()
# Example 4
from datetime import timedelta
from unittest.mock import call
now_func = Mock(spec=datetime.utcnow)
now_func.return_value = datetime(2019, 6, 5, 15, 45)
database = Mock(spec=ZooDatabase)
database.get_food_period.return_value = timedelta(hours=3)
database.get_animals.return_value = [
('Spot', datetime(2019, 6, 5, 11, 15)),
('Fluffy', datetime(2019, 6, 5, 12, 30)),
('Jojo', datetime(2019, 6, 5, 12, 55))
]
# Example 5
result = do_rounds(database, 'Meerkat', utcnow=now_func)
assert result == 2
database.get_food_period.assert_called_once_with('Meerkat')
database.get_animals.assert_called_once_with('Meerkat')
database.feed_animal.assert_has_calls(
[
call('Spot', now_func.return_value),
call('Fluffy', now_func.return_value),
],
any_order=True)
# # Example 6
# try:
# database.bad_method_name()
# except:
# logging.exception('Expected')
# else:
# assert False
# Example 7
DATABASE = None
def get_database():
global DATABASE
if DATABASE is None:
DATABASE = ZooDatabase()
return DATABASE
def main(argv):
database = get_database()
species = argv[1]
count = do_rounds(database, species)
print(f'Fed {count} {species}(s)')
return 0
# Example 8
import contextlib
import io
from unittest.mock import patch
# Replace the module-level DATABASE with a ZooDatabase-shaped mock for the
# duration of the block. The patch target is '__main__.DATABASE', so this
# only applies when the file is executed as the main script.
with patch('__main__.DATABASE', spec=ZooDatabase):
    now = datetime.utcnow()

    # Canned query results, with last meal times expressed relative to "now".
    DATABASE.get_animals.return_value = [
        ('Spot', now - timedelta(minutes=4.5)),
        ('Fluffy', now - timedelta(hours=3.25)),
        ('Jojo', now - timedelta(hours=3)),
    ]
    DATABASE.get_food_period.return_value = timedelta(hours=3)

    # Capture everything main() prints so it can be compared below.
    fake_stdout = io.StringIO()
    with contextlib.redirect_stdout(fake_stdout):
        main(['program name', 'Meerkat'])

    expected = 'Fed 2 Meerkat(s)\n'
    found = fake_stdout.getvalue()
    assert found == expected
80. 考虑用pdb做交互调试
import math
def compute_rmse(observed, ideal):
    """Return the root-mean-square error between two paired sequences.

    Args:
        observed: Iterable of measured numbers.
        ideal: Iterable of expected numbers, paired with ``observed`` by zip().

    Returns:
        The square root of the mean squared difference.

    Raises:
        ZeroDivisionError: If zip() yields no pairs.
    """
    squared_sum = 0
    pair_count = 0
    for pair_count, (got, wanted) in enumerate(zip(observed, ideal), 1):
        err_2 = (got - wanted) ** 2
        breakpoint()  # Start the debugger here
        squared_sum += err_2
    return math.sqrt(squared_sum / pair_count)
# Run the RMSE computation on four sample pairs and print the result.
result = compute_rmse(
    [1.8, 1.7, 3.2, 6],
    [2, 1.5, 3, 5])
print(result)
81. 用tracemalloc来掌握内存的使用与泄漏情况
- gc模块可以帮助我们了解垃圾回收器追踪到了哪些对象,但并不能告诉我们那些对象是如何分配的;
import gc
# Count the objects the garbage collector tracks before the workload runs.
found_objects = gc.get_objects()
print('Before:', len(found_objects))

# Imported after the first count, so anything the import itself creates
# is only included in the second count.
import waste_memory

# Keep the returned value referenced so its objects stay alive.
hold_reference = waste_memory.run()
found_objects = gc.get_objects()
print('After: ', len(found_objects))

# Show a truncated repr of the first three tracked objects.
for tracked in found_objects[:3]:
    print(repr(tracked)[:100])
print('...')
- Python内置的tracemalloc模块提供了一套强大的工具,可以帮助我们了解内存的使用情况,并且找到这些内存是由哪一行代码所分配的。
import tracemalloc
# Begin tracing allocations, keeping up to 10 stack frames per allocation.
tracemalloc.start(10) # Set stack depth
# Baseline snapshot taken before the workload allocates anything.
time1 = tracemalloc.take_snapshot() # Before snapshot
import os
class MyObject:
    """Holds 100 random bytes; exists only to produce allocations to trace."""

    def __init__(self):
        self.data = os.urandom(100)


def get_data():
    """Return a list of 100 freshly created MyObject instances."""
    return [MyObject() for _ in range(100)]


def run():
    """Return a list of 100 lists, each produced by get_data()."""
    return [get_data() for _ in range(100)]
# Run the workload and keep its result referenced in x.
x = run()

# Second snapshot, then diff it against the baseline grouped by source line.
time2 = tracemalloc.take_snapshot()
stats = time2.compare_to(time1, 'lineno')

# Print the first three entries of the comparison.
for entry in stats[:3]:
    print(entry)
import tracemalloc
# Begin tracing allocations, keeping up to 10 stack frames per allocation.
tracemalloc.start(10)
# Baseline snapshot taken before the workload runs.
time1 = tracemalloc.take_snapshot()
import os
class MyObject:
    """Small object carrying 100 random bytes, used as an allocation source."""

    def __init__(self):
        payload = os.urandom(100)
        self.data = payload


def get_data():
    """Create 100 MyObject instances and return them in a list."""
    batch = []
    for _ in range(100):
        batch.append(MyObject())
    return batch


def run():
    """Call get_data() 100 times and return the resulting lists in a list."""
    batches = []
    for _ in range(100):
        batch = get_data()
        batches.append(batch)
    return batches
# Run the workload and keep its result referenced in x.
x = run()
# Second snapshot, then diff against the baseline grouped by full traceback.
time2 = tracemalloc.take_snapshot()
stats = time2.compare_to(time1, 'traceback')
# Take the first entry of the comparison and print its formatted stack trace.
top = stats[0]
print('Biggest offender is:')
print('\n'.join(top.traceback.format()))
十. 协作开发
82. 学会寻找由其他Python开发者所构建的模块
python集中存放模块的地方:https://pypi.org 。
83. 用虚拟环境隔离项目,并重建依赖关系
通过pip show可以查看某个包的依赖关系,例如:python3 -m pip show flask
结果如下:
Name: Flask
Version: 2.0.1
Summary: A simple framework for building complex web applications.
Home-page: https://palletsprojects.com/p/flask
Author: Armin Ronacher
Author-email: armin.ronacher@active-4.com
License: BSD-3-Clause
Location: /usr/local/lib/python3.9/site-packages
Requires: itsdangerous, click, Jinja2, Werkzeug
Required-by:
在命令行界面执行venv命令:
python3 -m venv 虚拟环境名 # 创建虚拟环境
source bin/activate 和 deactivate # 启用和禁用该环境
python3 -m pip install -r requirements.txt # 安装包
84. 每一个函数、类与模块都要写docstring
import itertools
def find_anagrams(word, dictionary):
    """Return every rearrangement of ``word`` that appears in ``dictionary``.

    All permutations of the letters are generated, so the cost grows
    factorially with the word length, and each permutation costs one
    membership test against ``dictionary``.

    Args:
        word: String whose letters are rearranged.
        dictionary: Container supporting ``in`` that holds the known words.

    Returns:
        List of the distinct matches, in no particular order. Empty if
        nothing matched.
    """
    matches = set()
    for letters in itertools.permutations(word, len(word)):
        candidate = ''.join(letters)
        if candidate in dictionary:
            matches.add(candidate)
    return list(matches)
# Sanity check: 'scanpeak' is a rearrangement of 'pancakes', so it is the single match.
assert find_anagrams('pancakes', ['scanpeak']) == ['scanpeak']
85. 用包来安排模块,以提供稳固的API
如我们想设计一个包,用来计算抛射物之间的碰撞
# mypackage/models.py
__all__ = ['Projectile']
class Projectile:
    """A body described by a mass and a velocity."""

    def __init__(self, mass, velocity):
        # Store both constructor arguments unchanged.
        self.mass, self.velocity = mass, velocity
# mypackage/utils.py
from . models import Projectile
__all__ = ['simulate_collision']
def _dot_product(a, b):
    """Unimplemented private helper; currently always returns None.

    The leading underscore and its absence from ``__all__`` keep it out of
    ``from ... import *``.
    """
    return None
def simulate_collision(a, b):
    """Return two new Projectile objects built from the negated inputs.

    Args:
        a: Object with ``mass`` and ``velocity`` attributes.
        b: Object with ``mass`` and ``velocity`` attributes.

    Returns:
        Tuple ``(after_a, after_b)`` where each Projectile carries the
        negated mass and negated velocity of the corresponding argument.
    """
    return (
        Projectile(-a.mass, -a.velocity),
        Projectile(-b.mass, -b.velocity),
    )
# mypackage/__init__.py
# Package API: re-export each submodule's public names and collect them
# into the package-level __all__.
__all__ = []
from . models import *
__all__ += models.__all__
from . utils import *
__all__ += utils.__all__
没有出现在__all__
之中的,都不会随着from mypackage import *
语句引入,对外部使用者隐藏了这些名字。
# api_consumer.py
# The star-import brings in the names that mypackage lists in its __all__.
from mypackage import *
a = Projectile(1.5, 3)
b = Projectile(4, 1.7)
after_a, after_b = simulate_collision(a, b)
print(after_a.__dict__, after_b.__dict__) # {'mass': -1.5, 'velocity': -3} {'mass': -4, 'velocity': -1.7}
86. 考虑用模块级别的代码配置不同的部署环境
class TestingDatabase:
    """Empty placeholder for the database used when TESTING is true."""


class RealDatabase:
    """Empty placeholder for the database used when TESTING is false."""


# Deployment switch evaluated once, at import time.
TESTING = True

# Bind the public name Database to the implementation chosen by the switch.
Database = TestingDatabase if TESTING else RealDatabase
87. 为自编的模块定义根异常,让调用者能够专门处理与此API有关的异常
import logging
# Example 1
# my_module.py
def determine_weight(volume, density):
    """Validate ``density``; the weight computation itself is not implemented.

    Args:
        volume: Unused by this version.
        density: Must be greater than zero.

    Returns:
        None.

    Raises:
        ValueError: If ``density`` is zero or negative.
    """
    density_is_invalid = density <= 0
    if density_is_invalid:
        raise ValueError('Density must be positive')
# A zero density must be rejected with ValueError; reaching the else
# branch means no exception was raised, which fails the demo.
try:
    determine_weight(1, 0)
except ValueError:
    pass  # Expected outcome
else:
    assert False
# Example 2
# my_module.py
class Error(Exception):
    """Base-class for all exceptions raised by this module."""


class InvalidDensityError(Error):
    """There was a problem with a provided density value."""


class InvalidVolumeError(Error):
    """There was a problem with the provided volume value."""
def determine_weight(volume, density):
    """Validate the arguments; the weight computation is not implemented.

    Args:
        volume: Rejected when negative.
        density: Rejected when negative.

    Returns:
        None when both arguments pass validation and volume is non-zero.

    Raises:
        InvalidDensityError: If ``density`` is negative.
        InvalidVolumeError: If ``volume`` is negative.
        ZeroDivisionError: If ``volume`` is zero and the division rejects it.
    """
    if density < 0:
        raise InvalidDensityError('Density must be positive')

    if volume < 0:
        raise InvalidVolumeError('Volume must be positive')

    if volume == 0:
        # NOTE(review): the quotient is discarded. This looks like a
        # deliberate way to raise an exception outside the Error hierarchy.
        density / volume
# Example 3
class my_module:
    """Namespace class that stands in for a separate ``my_module`` module.

    It exposes the exception classes and ``determine_weight`` as
    attributes so callers can write ``my_module.Error`` and so on.
    """

    Error = Error
    InvalidDensityError = InvalidDensityError
    # Exported as well, because determine_weight() below raises it.
    InvalidVolumeError = InvalidVolumeError

    @staticmethod
    def determine_weight(volume, density):
        """Validate the arguments; the weight computation is not implemented.

        Raises:
            InvalidDensityError: If ``density`` is negative.
            InvalidVolumeError: If ``volume`` is negative.
            ZeroDivisionError: If ``volume`` is zero and the division rejects it.
        """
        if density < 0:
            raise InvalidDensityError('Density must be positive')
        if volume < 0:
            raise InvalidVolumeError('Volume must be positive')
        if volume == 0:
            # NOTE(review): the quotient is discarded. This looks like a
            # deliberate way to raise an exception outside the Error hierarchy.
            density / volume
# A negative density raises InvalidDensityError, which the root
# my_module.Error handler catches and logs with its traceback.
try:
    weight = my_module.determine_weight(1, -1)
except my_module.Error:
    logging.exception('Unexpected error')
else:
    # Reaching here means no exception was raised, which fails the demo.
    assert False
# Example 4
# Unique marker used to detect whether weight was ever reassigned.
SENTINEL = object()
weight = SENTINEL

try:
    weight = my_module.determine_weight(-1, 1)
except my_module.InvalidDensityError:
    weight = 0
except my_module.Error:
    # Any other exception from the module's hierarchy lands here.
    logging.exception('Bug in the calling code')
else:
    assert False

# Neither the call nor the InvalidDensityError handler assigned weight.
assert weight is SENTINEL
# Example 5
# The outer try is the demo harness. A zero volume makes
# my_module.determine_weight divide by zero, which is outside the
# my_module.Error hierarchy, so the inner `except Exception` logs it and
# re-raises. The outer handler names ZeroDivisionError explicitly so that
# an unexpected outcome (such as the inner `assert False`) is not
# swallowed and reported as 'Expected'.
try:
    weight = SENTINEL
    try:
        weight = my_module.determine_weight(0, 1)
    except my_module.InvalidDensityError:
        weight = 0
    except my_module.Error:
        logging.exception('Bug in the calling code')
    except Exception:
        logging.exception('Bug in the API code!')
        raise  # Re-raise exception to the caller
    else:
        assert False

    assert weight == 0
except ZeroDivisionError:
    logging.exception('Expected')
else:
    assert False
# Example 6
# my_module.py
class NegativeDensityError(InvalidDensityError):
    """Raised for a density below zero; a narrower kind of InvalidDensityError."""
def determine_weight(volume, density):
    """Validate ``density``; the weight computation is not implemented.

    Args:
        volume: Unused by this version.
        density: Rejected when negative.

    Returns:
        None when ``density`` is not negative.

    Raises:
        NegativeDensityError: If ``density`` is negative.
    """
    density_is_negative = density < 0
    if density_is_negative:
        raise NegativeDensityError('Density must be positive')
# Example 7
# The outer try is the demo harness. The NegativeDensityError handler
# converts the error into ValueError, which escapes the inner try. The
# outer handler names ValueError explicitly so that an unexpected outcome
# (such as the inner `assert False`) is not swallowed and reported as
# 'Expected'.
try:
    # Install the newer exception class and function on the namespace.
    my_module.NegativeDensityError = NegativeDensityError
    my_module.determine_weight = determine_weight
    try:
        weight = my_module.determine_weight(1, -1)
    except my_module.NegativeDensityError:
        raise ValueError('Must supply non-negative density')
    except my_module.InvalidDensityError:
        weight = 0
    except my_module.Error:
        logging.exception('Bug in the calling code')
    except Exception:
        logging.exception('Bug in the API code!')
        raise
    else:
        assert False
except ValueError:
    logging.exception('Expected')
else:
    assert False
88. 用适当的方式打破循环依赖关系
如果两个模块都要在开头引入对方,那就会形成循环依赖关系,可能会导致程序启动时崩溃;
想要打破循环依赖,最好的办法是把这两个模块都用到的代码重构到整个依赖的最底层;
如果不想大幅度重构代码,也不想让代码太复杂,最简单的方法就是通过动态引入来消除循环依赖关系。
# dialog.py
class Dialog:
    """Dialog object; save_dir is attached later by show()."""

    def __init__(self):
        pass


# Using this instead will break things
# save_dialog = Dialog(app.prefs.get('save_dir'))
save_dialog = Dialog()


def show():
    """Set save_dialog.save_dir from app.prefs and print a message."""
    # Dynamic import: app is only imported when show() is called, not when
    # this module is loaded, which breaks the circular dependency.
    import app
    save_dialog.save_dir = app.prefs.get('save_dir')
    print('Showing the dialog!')
# app.py
import dialog


class Prefs:
    """Preferences stub; get() is an unimplemented placeholder."""

    def get(self, name):
        pass


prefs = Prefs()
# prefs is bound before this call, so dialog.show() can read app.prefs.
dialog.show()
一般来说,还是应该尽量避免动态引入,因为import语句毕竟是有开销的,如果它出现在频繁执行的循环体里面,这种开销会更大。虽然如此,但是比大幅度修改整个程序要好。
89. 重构时考虑通过warnings提醒开发者API已经发生变化
# Example 1
def print_distance(speed, duration):
    """Print ``speed * duration`` followed by the word 'miles'.

    Args:
        speed: Number multiplied by ``duration``; no unit is recorded.
        duration: Number multiplied by ``speed``; no unit is recorded.
    """
    distance = speed * duration
    print(f'{distance} miles')
# Multiplies 5 by 2.5 and prints the product labelled as miles.
print_distance(5, 2.5)
# Example 2
# Same call shape, but nothing in the call says which units 1000 and 3 are in.
print_distance(1000, 3)
# Example 3
# Multiplicative factor for each unit name; the trailing comment on each
# entry names the unit a value ends up in after multiplying by the factor.
CONVERSIONS = {
    'mph': 1.60934 / 3600 * 1000, # m/s
    'hours': 3600, # seconds
    'miles': 1.60934 * 1000, # m
    'meters': 1, # m
    'm/s': 1, # m/s
    'seconds': 1, # s
}
def convert(value, units):
    """Multiply ``value`` by the CONVERSIONS factor registered for ``units``.

    Raises:
        KeyError: If ``units`` is not a key of CONVERSIONS.
    """
    return CONVERSIONS[units] * value
def localize(value, units):
    """Divide ``value`` by the CONVERSIONS factor registered for ``units``.

    This is the inverse of convert() for the same unit name.

    Raises:
        KeyError: If ``units`` is not a key of CONVERSIONS.
    """
    factor = CONVERSIONS[units]
    localized = value / factor
    return localized
def print_distance(speed, duration, *,
                   speed_units='mph',
                   time_units='hours',
                   distance_units='miles'):
    """Print the distance covered, with explicit units on every quantity.

    Args:
        speed: Speed expressed in ``speed_units``.
        duration: Time expressed in ``time_units``.
        speed_units: CONVERSIONS key for ``speed``.
        time_units: CONVERSIONS key for ``duration``.
        distance_units: CONVERSIONS key for the printed distance.
    """
    # Normalize both inputs with convert(), multiply, then express the
    # product in the requested output unit with localize().
    base_speed = convert(speed, speed_units)
    base_duration = convert(duration, time_units)
    distance = localize(base_speed * base_duration, distance_units)
    print(f'{distance} {distance_units}')
# Example 4
# Override the speed and time units; distance_units keeps its default.
print_distance(1000, 3,
    speed_units='meters',
    time_units='seconds')
# Example 5
import warnings
def print_distance(speed, duration, *,
                   speed_units=None,
                   time_units=None,
                   distance_units=None):
    """Print the distance covered, warning about every omitted unit.

    Each unit argument left as None triggers a DeprecationWarning and
    then falls back to 'mph', 'hours' or 'miles' respectively.

    Args:
        speed: Speed expressed in ``speed_units``.
        duration: Time expressed in ``time_units``.
        speed_units: CONVERSIONS key for ``speed``, or None.
        time_units: CONVERSIONS key for ``duration``, or None.
        distance_units: CONVERSIONS key for the printed distance, or None.
    """
    if speed_units is None:
        warnings.warn('speed_units required', DeprecationWarning)
        speed_units = 'mph'

    if time_units is None:
        warnings.warn('time_units required', DeprecationWarning)
        time_units = 'hours'

    if distance_units is None:
        warnings.warn('distance_units required', DeprecationWarning)
        distance_units = 'miles'

    base_speed = convert(speed, speed_units)
    base_duration = convert(duration, time_units)
    distance = localize(base_speed * base_duration, distance_units)
    print(f'{distance} {distance_units}')
# Example 6
import contextlib
import io
# Redirect sys.stderr into an in-memory buffer during the call, then
# print whatever was written to it. distance_units is left out here.
fake_stderr = io.StringIO()
with contextlib.redirect_stderr(fake_stderr):
    print_distance(1000, 3, speed_units='meters', time_units='seconds')

print(fake_stderr.getvalue())
# Example 7
def require(name, value, default):
    """Return ``value``, or warn and return ``default`` when it is None.

    Args:
        name: Argument name interpolated into the warning message.
        value: The caller-supplied value; None means "not supplied".
        default: Fallback returned when ``value`` is None.

    Returns:
        ``value`` when it is not None, otherwise ``default``.
    """
    if value is None:
        # stacklevel=3 attributes the warning two frames above this one:
        # the caller of the function that called require().
        warnings.warn(
            f'{name} will be required soon, update your code',
            DeprecationWarning,
            stacklevel=3)
        return default
    return value
def print_distance(speed, duration, *,
                   speed_units=None,
                   time_units=None,
                   distance_units=None):
    """Print the distance covered, using require() for omitted units.

    Args:
        speed: Speed expressed in ``speed_units``.
        duration: Time expressed in ``time_units``.
        speed_units: CONVERSIONS key for ``speed``; 'mph' if None.
        time_units: CONVERSIONS key for ``duration``; 'hours' if None.
        distance_units: CONVERSIONS key for the output; 'miles' if None.
    """
    # Resolve each unit in turn; require() is called once per argument.
    speed_units = require('speed_units', speed_units, 'mph')
    time_units = require('time_units', time_units, 'hours')
    distance_units = require('distance_units', distance_units, 'miles')

    base_speed = convert(speed, speed_units)
    base_duration = convert(duration, time_units)
    distance = localize(base_speed * base_duration, distance_units)
    print(f'{distance} {distance_units}')
# Example 8
import contextlib
import io
# Capture stderr during the call and print what was written to it.
# distance_units is left out here.
fake_stderr = io.StringIO()
with contextlib.redirect_stderr(fake_stderr):
    print_distance(1000, 3, speed_units='meters', time_units='seconds')

print(fake_stderr.getvalue())
# Example 9
# With the 'error' filter installed, warnings.warn() raises the warning
# as an exception.
warnings.simplefilter('error')
try:
    warnings.warn('This usage is deprecated', DeprecationWarning)
except DeprecationWarning:
    pass  # Expected
else:
    assert False

# Drop the 'error' filter again.
warnings.resetwarnings()
# Example 10
# Start from an empty filter list, then install a filter that ignores every warning.
warnings.resetwarnings()
warnings.simplefilter('ignore')
warnings.warn('This will not be printed to stderr')
# Remove the 'ignore' filter again.
warnings.resetwarnings()
# Example 11
import logging
# In-memory stream that receives the formatted log records.
fake_stderr = io.StringIO()
handler = logging.StreamHandler(fake_stderr)
formatter = logging.Formatter(
    '%(asctime)-15s WARNING] %(message)s')
handler.setFormatter(formatter)
# Route warnings issued through the warnings module into the logging system;
# they arrive on the logger named 'py.warnings'.
logging.captureWarnings(True)
logger = logging.getLogger('py.warnings')
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
# Clear earlier filters and install the 'default' action before warning.
warnings.resetwarnings()
warnings.simplefilter('default')
warnings.warn('This will go to the logs output')
print(fake_stderr.getvalue())
warnings.resetwarnings()
# Example 12
# record=True collects warnings in found_warnings instead of showing them.
with warnings.catch_warnings(record=True) as found_warnings:
    expected = 'fake units'
    found = require('my_arg', None, 'fake units')
    # A None value makes require() fall back to the supplied default.
    assert found == expected
# Example 13
# Exactly one warning must have been recorded.
assert len(found_warnings) == 1
single_warning = found_warnings[0]
# Check the message text and category of the recorded warning.
assert str(single_warning.message) == (
    'my_arg will be required soon, update your code')
assert single_warning.category == DeprecationWarning
90. 考虑通过typing做静态分析,以消除bug
typing模块:https://docs.python.org/3.8/library/typing.html。可以给变量,字段,函数与方法标注类型信息。
from typing import Callable, List, TypeVar
# Type variable shared by the fold function and the list it folds.
Value = TypeVar('Value')
# A binary function that takes two Values and returns a Value.
Func = Callable[[Value, Value], Value]


def combine(func: Func[Value], values: List[Value]) -> Value:
    """Fold ``values`` from left to right with ``func``.

    Args:
        func: Called as ``func(accumulated, next_item)`` for each item after
            the first.
        values: Non-empty list; its first item seeds the accumulator.

    Returns:
        The final accumulated value; the single item if the list has one.

    Raises:
        AssertionError: If ``values`` is empty.
    """
    assert len(values) > 0

    accumulated, *remaining = values
    for item in remaining:
        accumulated = func(accumulated, item)
    return accumulated
# Type variable constrained to int or float.
Real = TypeVar('Real', int, float)


def add(x: Real, y: Real) -> Real:
    """Return the sum ``x + y``.

    The annotations are not enforced at runtime.
    """
    total = x + y
    return total
inputs = [1, 2, 3, 4j] # Oops: included a complex number
result = combine(add, inputs)
print(result) # (6+4j)
# NOTE: this assertion fails at runtime, because result is (6+4j) rather
# than 10. The example presumably exists to show that a static type
# checker can reject the combine() call before the program runs.
assert result == 10
下面的写法从Python 3.7开始支持:程序执行时会忽略类型注解里提到的值,于是就解决了提前引用的问题,而且程序启动时的性能也会提升。
from __future__ import annotations
class FirstClass:
    """Wraps a value annotated with a class that is defined later on."""

    def __init__(self, value: SecondClass) -> None:  # OK
        # The SecondClass annotation is a forward reference. It needs
        # `from __future__ import annotations` to avoid being evaluated here.
        self.value = value
class SecondClass:
    """Stores the value passed to the constructor in ``self.value``."""

    def __init__(self, value: int) -> None:
        # The int annotation is not checked at runtime.
        self.value = value
# Build the later-defined class first, then pass the instance to FirstClass.
second = SecondClass(5)
print(second) # <__main__.SecondClass object at 0x7fafc12cd4d0>
first = FirstClass(second)
print(first) # <__main__.FirstClass object at 0x7fafc12cd510>