The Ultimate Guide to Writing Functions
1.视频 https://www.youtube.com/watch?v=yatgY4NpZXE
2.代码 https://github.com/ArjanCodes/2022-funcguide
Python高质量函数编写指南
1. 一次做好一件事
from dataclasses import dataclass
from datetime import datetime


@dataclass
class Customer:
    name: str
    phone: str
    cc_number: str
    cc_exp_month: int
    cc_exp_year: int
    cc_valid: bool = False


# "Before" example: validate_card does too many things at once -- it computes
# the Luhn checksum, checks the expiry date AND stores the result on the
# customer.
def validate_card(customer: Customer) -> bool:
    """Validate ``customer``'s card, store the result in ``cc_valid`` and return it.

    The card is valid when its number is non-empty, passes the Luhn checksum
    and has not expired. A card stays valid until the end of its expiry month.

    Raises:
        ValueError: if ``cc_number`` contains a non-digit character, or (only
            evaluated when the checksum passes) if the expiry month/year is
            out of range for ``datetime``.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(customer.cc_number)
    # Luhn: counting from the right, digits in odd positions are summed as is,
    # digits in even positions are doubled and the digits of the product summed.
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = 0
    checksum += sum(odd_digits)
    for digit in even_digits:
        checksum += sum(digits_of(str(digit * 2)))

    now = datetime.now()
    start_of_this_month = datetime(now.year, now.month, 1)
    customer.cc_valid = (
        # An empty number has checksum 0 and would otherwise "pass".
        bool(digits)
        and checksum % 10 == 0
        # Compare with the start of the current month (not "now") so the card
        # is still accepted during its expiry month.
        and datetime(customer.cc_exp_year, customer.cc_exp_month, 1)
        >= start_of_this_month
    )
    return customer.cc_valid


def main() -> None:
    alice = Customer(
        name="Alice",
        phone="2341",
        cc_number="1249190007575069",
        cc_exp_month=1,
        cc_exp_year=2024,
    )
    is_valid = validate_card(alice)
    print(f"Is Alice's card valid? {is_valid}")
    print(alice)


if __name__ == "__main__":
    main()
我们发现 `validate_card` 函数做了两件事:用 Luhn 算法校验卡号、检查卡片是否已过期。
我们把卡号校验拆分成一个单独的函数 `luhn_checksum`,并在 `validate_card` 中调用它。
修改后:
from dataclasses import dataclass
from datetime import datetime


# Luhn checksum, extracted from validate_card into its own function.
def luhn_checksum(card_number: str) -> bool:
    """Return True if ``card_number`` passes the Luhn checksum.

    Working from the rightmost digit, every second digit (starting with the
    second-to-last) is doubled and the digits of each product are summed; the
    number is valid when the grand total is a multiple of 10. An empty string
    is rejected.

    Raises:
        ValueError: if ``card_number`` contains a non-digit character.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    if not digits:
        # With no digits the checksum would be 0 and wrongly "pass".
        return False
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = 0
    checksum += sum(odd_digits)
    for digit in even_digits:
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0


@dataclass
class Customer:
    name: str
    phone: str
    cc_number: str
    cc_exp_month: int
    cc_exp_year: int
    cc_valid: bool = False


def validate_card(customer: Customer) -> bool:
    """Validate ``customer``'s card, store the result in ``cc_valid`` and return it.

    The card is valid when its number passes ``luhn_checksum`` and it has not
    expired; a card stays valid until the end of its expiry month.

    Raises:
        ValueError: from ``luhn_checksum`` for non-digit characters, or (only
            evaluated when the checksum passes) if the expiry month/year is
            out of range for ``datetime``.
    """
    now = datetime.now()
    start_of_this_month = datetime(now.year, now.month, 1)
    customer.cc_valid = (
        luhn_checksum(customer.cc_number)
        # Compare with the start of the current month so the card is still
        # accepted during its expiry month.
        and datetime(customer.cc_exp_year, customer.cc_exp_month, 1)
        >= start_of_this_month
    )
    return customer.cc_valid
2. 分离命令和查询(command and query)
`validate_card` 中同时进行了查询(计算卡是否有效)和赋值(修改 `customer.cc_valid`)两个操作。这样不好:调用者只想知道卡是否有效时,也会在无意中修改 `Customer` 对象。
我们将查询和赋值拆分成两个步骤:`validate_card` 只返回卡是否有效,而赋值操作 `alice.cc_valid = validate_card(alice)` 移到了 `main` 函数中。
from dataclasses import dataclass
from datetime import datetime


def luhn_checksum(card_number: str) -> bool:
    """Return True if ``card_number`` passes the Luhn checksum.

    Working from the rightmost digit, every second digit (starting with the
    second-to-last) is doubled and the digits of each product are summed; the
    number is valid when the grand total is a multiple of 10. An empty string
    is rejected.

    Raises:
        ValueError: if ``card_number`` contains a non-digit character.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    if not digits:
        # With no digits the checksum would be 0 and wrongly "pass".
        return False
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = 0
    checksum += sum(odd_digits)
    for digit in even_digits:
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0


@dataclass
class Customer:
    name: str
    phone: str
    cc_number: str
    cc_exp_month: int
    cc_exp_year: int
    cc_valid: bool = False


# Query: computes and returns a result without changing any state.
def validate_card(customer: Customer) -> bool:
    """Return True if ``customer``'s card passes the Luhn check and has not expired.

    A card stays valid until the end of its expiry month. ``customer`` is not
    modified.

    Raises:
        ValueError: from ``luhn_checksum`` for non-digit characters, or (only
            evaluated when the checksum passes) if the expiry month/year is
            out of range for ``datetime``.
    """
    now = datetime.now()
    start_of_this_month = datetime(now.year, now.month, 1)
    return (
        luhn_checksum(customer.cc_number)
        # Compare with the start of the current month so the card is still
        # accepted during its expiry month.
        and datetime(customer.cc_exp_year, customer.cc_exp_month, 1)
        >= start_of_this_month
    )


def main() -> None:
    alice = Customer(
        name="Alice",
        phone="2341",
        cc_number="1249190007575069",
        cc_exp_month=1,
        cc_exp_year=2024,
    )
    # Command: the assignment happens here, separately from the query.
    alice.cc_valid = validate_card(alice)
    print(f"Is Alice's card valid? {alice.cc_valid}")
    print(alice)


if __name__ == "__main__":
    main()
3. 只请求你需要的
函数 `validate_card` 实际上只需要卡号、到期月份和到期年份这 3 个值,而不需要整个 `Customer` 对象。
因此只请求这 3 个参数(参数列表开头的 `*` 表示它们必须以关键字形式传入):`def validate_card(*, number: str, exp_month: int, exp_year: int) -> bool:`
from dataclasses import dataclass
from datetime import datetime


def luhn_checksum(card_number: str) -> bool:
    """Return True if ``card_number`` passes the Luhn checksum.

    Working from the rightmost digit, every second digit (starting with the
    second-to-last) is doubled and the digits of each product are summed; the
    number is valid when the grand total is a multiple of 10. An empty string
    is rejected.

    Raises:
        ValueError: if ``card_number`` contains a non-digit character.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    if not digits:
        # With no digits the checksum would be 0 and wrongly "pass".
        return False
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = 0
    checksum += sum(odd_digits)
    for digit in even_digits:
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0


@dataclass
class Customer:
    name: str
    phone: str
    cc_number: str
    cc_exp_month: int
    cc_exp_year: int
    cc_valid: bool = False


# Ask only for the arguments the function actually needs (keyword-only).
def validate_card(*, number: str, exp_month: int, exp_year: int) -> bool:
    """Return True if the card ``number`` passes the Luhn check and has not expired.

    A card stays valid until the end of its expiry month.

    Raises:
        ValueError: from ``luhn_checksum`` for non-digit characters, or (only
            evaluated when the checksum passes) if ``exp_month``/``exp_year``
            is out of range for ``datetime``.
    """
    now = datetime.now()
    start_of_this_month = datetime(now.year, now.month, 1)
    # Compare with the start of the current month so the card is still
    # accepted during its expiry month.
    return luhn_checksum(number) and datetime(exp_year, exp_month, 1) >= start_of_this_month


def main() -> None:
    alice = Customer(
        name="Alice",
        phone="2341",
        cc_number="1249190007575069",
        cc_exp_month=1,
        cc_exp_year=2024,
    )
    alice.cc_valid = validate_card(
        number=alice.cc_number,
        exp_month=alice.cc_exp_month,
        exp_year=alice.cc_exp_year,
    )
    print(f"Is Alice's card valid? {alice.cc_valid}")
    print(alice)


if __name__ == "__main__":
    main()
4. 保持最小参数量
参数很多时,调用函数时传参会比较麻烦。另一方面,如果函数需要很多参数,往往暗示该函数做了太多事情。
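下面我们抽象出 `Card` 类,从而减少了 `Customer` 和 `validate_card` 的参数量。`validate_card` 的参数类型是 `CardInfo` 协议(`Protocol`),只要求传入的对象提供 `number`、`exp_month` 和 `exp_year` 三个属性。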
from dataclasses import dataclass
from datetime import datetime
from typing import Protocol


def luhn_checksum(card_number: str) -> bool:
    """Return True if ``card_number`` passes the Luhn checksum.

    Working from the rightmost digit, every second digit (starting with the
    second-to-last) is doubled and the digits of each product are summed; the
    number is valid when the grand total is a multiple of 10. An empty string
    is rejected.

    Raises:
        ValueError: if ``card_number`` contains a non-digit character.
    """

    def digits_of(number: str) -> list[int]:
        return [int(d) for d in number]

    digits = digits_of(card_number)
    if not digits:
        # With no digits the checksum would be 0 and wrongly "pass".
        return False
    odd_digits = digits[-1::-2]
    even_digits = digits[-2::-2]
    checksum = 0
    checksum += sum(odd_digits)
    for digit in even_digits:
        checksum += sum(digits_of(str(digit * 2)))
    return checksum % 10 == 0


@dataclass
class Card:
    number: str
    exp_month: int
    exp_year: int
    valid: bool = False


@dataclass
class Customer:
    name: str
    phone: str
    card: Card
    # NOTE(review): duplicates Card.valid and is never set by main(), so
    # print(alice) always shows card_valid=False -- consider removing it.
    card_valid: bool = False


class CardInfo(Protocol):
    """Structural type: anything exposing a card number and expiry month/year."""

    @property
    def number(self) -> str:
        ...

    @property
    def exp_month(self) -> int:
        ...

    @property
    def exp_year(self) -> int:
        ...


def validate_card(card: CardInfo) -> bool:
    """Return True if ``card`` passes the Luhn check and has not expired.

    A card stays valid until the end of its expiry month. ``card`` is not
    modified.

    Raises:
        ValueError: from ``luhn_checksum`` for non-digit characters, or (only
            evaluated when the checksum passes) if the expiry month/year is
            out of range for ``datetime``.
    """
    now = datetime.now()
    start_of_this_month = datetime(now.year, now.month, 1)
    return (
        luhn_checksum(card.number)
        # Compare with the start of the current month so the card is still
        # accepted during its expiry month.
        and datetime(card.exp_year, card.exp_month, 1) >= start_of_this_month
    )


def main() -> None:
    card = Card(number="1249190007575069", exp_month=1, exp_year=2024)
    alice = Customer(name="Alice", phone="2341", card=card)
    # Pass the card itself instead of three separate arguments.
    card.valid = validate_card(card)
    print(f"Is Alice's card valid? {card.valid}")
    print(alice)


if __name__ == "__main__":
    main()
5. 不要在同一个地方创建并使用对象
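不要在函数内部创建对象后立即使用它。更好的方式是在函数外部创建对象,再作为参数传递给函数:这样 `order_food` 就不再依赖具体的 `StripePaymentHandler`,任何实现了 `handle_payment` 方法的对象都可以传进来(例如测试用的假对象)。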
import logging


class StripePaymentHandler:
    """Payment handler that only logs the charge it would make via Stripe."""

    def handle_payment(self, amount: int) -> None:
        # ``amount`` is in cents; log it as dollars with two decimals.
        dollars = amount / 100
        logging.info(f"Charging ${dollars:.2f} using Stripe")


# Menu prices in cents.
PRICES = dict(
    burger=10_00,
    fries=5_00,
    drink=2_00,
    salad=15_00,
)


# Anti-pattern on purpose: order_food both creates the payment handler and
# uses it, so it is hard-wired to Stripe.
def order_food(items: list[str]) -> None:
    """Total up ``items``, log the total and charge it via a new Stripe handler.

    Raises KeyError for an item that is not in ``PRICES``.
    """
    order_total = 0
    for food in items:
        order_total += PRICES[food]
    logging.info(f"Order total is ${order_total/100:.2f}.")
    StripePaymentHandler().handle_payment(order_total)  # created and used in one place
    logging.info("Order completed.")


def main() -> None:
    logging.basicConfig(level=logging.INFO)
    order = ["burger", "fries", "drink"]
    order_food(order)


if __name__ == "__main__":
    main()
修改后:
import logging
from typing import Protocol


class StripePaymentHandler:
    """Payment handler that only logs the charge it would make via Stripe."""

    def handle_payment(self, amount: int) -> None:
        # ``amount`` is in cents; log it as dollars with two decimals.
        dollars = amount / 100
        logging.info(f"Charging ${dollars:.2f} using Stripe")


# Menu prices in cents.
PRICES = dict(
    burger=10_00,
    fries=5_00,
    drink=2_00,
    salad=15_00,
)


class PaymentHandler(Protocol):
    """Anything with a ``handle_payment(amount)`` method can pay for an order."""

    def handle_payment(self, amount: int) -> None:
        ...


# The payment handler is now created by the caller and passed in as a parameter.
def order_food(items: list[str], payment_handler: PaymentHandler) -> None:
    """Total up ``items``, log the total and charge it through ``payment_handler``.

    Raises KeyError for an item that is not in ``PRICES`` (before anything is
    charged).
    """
    order_total = 0
    for food in items:
        order_total += PRICES[food]
    logging.info(f"Order total is ${order_total/100:.2f}.")
    payment_handler.handle_payment(order_total)
    logging.info("Order completed.")


def main() -> None:
    logging.basicConfig(level=logging.INFO)
    stripe = StripePaymentHandler()
    order_food(["burger", "salad", "drink"], stripe)


if __name__ == "__main__":
    main()
6. 不要用flag参数
flag 参数意味着一个函数要处理两种不同的情况,函数会因此变得复杂。建议将这两种情况拆分成单独的函数。
from dataclasses import dataclass from enum import StrEnum, auto FIXED_VACATION_DAYS_PAYOUT = 5 class Role(StrEnum): PRESIDENT = auto() VICEPRESIDENT = auto() MANAGER = auto() LEAD = auto() ENGINEER = auto() INTERN = auto() @dataclass class Employee: name: str role: Role vacation_days: int = 25 def take_a_holiday(self, payout: bool, nr_days: int = 1) -> None: if payout: if self.vacation_days < FIXED_VACATION_DAYS_PAYOUT: raise ValueError( f"You don't have enough holidays left over for a payout.\ Remaining holidays: {self.vacation_days}." ) self.vacation_days -= FIXED_VACATION_DAYS_PAYOUT print(f"Paying out a holiday. Holidays left: {self.vacation_days}") else: if self.vacation_days < nr_days: raise ValueError( "You don't have any holidays left. Now back to work, you!" ) self.vacation_days -= nr_days print("Have fun on your holiday. Don't forget to check your emails!") def main() -> None: employee = Employee(name="John Doe", role=Role.ENGINEER) employee.take_a_holiday(True) if __name__ == "__main__": main()
修改后:
from dataclasses import dataclass from enum import StrEnum, auto FIXED_VACATION_DAYS_PAYOUT = 5 class Role(StrEnum): PRESIDENT = auto() VICEPRESIDENT = auto() MANAGER = auto() LEAD = auto() ENGINEER = auto() INTERN = auto() @dataclass class Employee: name: str role: Role vacation_days: int = 25 def payout_holiday(self) -> None: if self.vacation_days < FIXED_VACATION_DAYS_PAYOUT: raise ValueError( f"You don't have enough holidays left over for a payout.\ Remaining holidays: {self.vacation_days}." ) self.vacation_days -= FIXED_VACATION_DAYS_PAYOUT print(f"Paying out a holiday. Holidays left: {self.vacation_days}") def take_holiday(self, nr_days: int = 1) -> None: if self.vacation_days < nr_days: raise ValueError("You don't have any holidays left. Now back to work, you!") self.vacation_days -= nr_days print("Have fun on your holiday. Don't forget to check your emails!") def main() -> None: employee = Employee(name="John Doe", role=Role.ENGINEER) employee.payout_holiday() if __name__ == "__main__": main()
7. 函数也是对象
函数也是对象,因此既可以作为参数传递,也可以作为函数的返回值。下面的例子还用 `functools.partial` 预先绑定了 `payment_handler` 参数,得到一个新函数 `order_food_stripe`。
import logging
from collections.abc import Callable
from functools import partial


def handle_payment_stripe(amount: int) -> None:
    """Log a Stripe charge of ``amount`` cents, formatted as dollars."""
    logging.info(f"Charging ${amount/100:.2f} using Stripe")


# Menu prices in cents.
PRICES = {
    "burger": 10_00,
    "fries": 5_00,
    "drink": 2_00,
    "salad": 15_00,
}

# A payment function takes the amount in cents and returns nothing.
# collections.abc.Callable replaces the deprecated typing.Callable alias.
HandlePaymentFn = Callable[[int], None]


# The payment function itself is passed in as an argument.
def order_food(items: list[str], payment_handler: HandlePaymentFn) -> None:
    """Total up ``items``, log the total and charge it with ``payment_handler``.

    Raises KeyError for an item that is not in ``PRICES`` (before anything is
    charged).
    """
    total = sum(PRICES[item] for item in items)
    logging.info(f"Order total is ${total/100:.2f}.")
    payment_handler(total)
    logging.info("Order completed.")


# partial pre-binds payment_handler, producing a new function that only needs
# the list of items.
order_food_stripe = partial(order_food, payment_handler=handle_payment_stripe)


def main() -> None:
    logging.basicConfig(level=logging.INFO)
    # Equivalent to: order_food(["burger", "salad", "drink"], handle_payment_stripe)
    order_food_stripe(["burger", "salad", "drink"])


if __name__ == "__main__":
    main()