file_cache 使用文件缓存函数结果
更好的 Python 缓存,用于慢速函数调用
原文:https://docs.sweep.dev/blogs/file-cache
作者编写了一个文件缓存 - 它类似于 Python 的lru_cache ,但它将值存储在文件中而不是内存中。
这是链接:https://github.com/sweepai/sweep/blob/main/docs/public/file_cache.py
想使用它,只需将其作为装饰器添加到你的函数中, 例如:
import time from file_cache import file_cache @file_cache() def slow_function(x, y): time.sleep(30) return x + y print(slow_function(1, 2)) # -> 3, takes 30 seconds print(slow_function(1, 2)) # -> 3, takes 0 seconds
背景
作者在一个LLM项目中需要缓存中间结果,便于节省时间。
但内置缓存函数lru_cache
不适合,
lru_cahce
将结果保存在内存中,下次运行程序时缓存失效。
- 并且作者添加了一个ignore_params参数,用于忽略不重要的参数,例如
file_cache(ignore_params=["chat_logger"])
实现
我们的两个主要方法是 recursive_hash
和 file_cache
recursive_hash
我们希望稳定地对对象进行哈希处理,而这在 python 中是原生不支持的。
import hashlib from cache import recursive_hash class Obj: def __init__(self, name): self.name = name obj = Obj("test") print(recursive_hash(obj)) # -> this works fine try: hashlib.md5(obj).hexdigest() except Exception as e: print(e) # -> this doesn't work
原始的 hashlib.md5
不适用于任意对象,它会报错: TypeError: object supporting the buffer API required
。我们使用 recursive_hash
,它适用于任意 python 对象。
def recursive_hash(value, depth=0, ignore_params=[]): """Hash primitives recursively with maximum depth.""" if depth > MAX_DEPTH: return hashlib.md5("max_depth_reached".encode()).hexdigest() if isinstance(value, (int, float, str, bool, bytes)): return hashlib.md5(str(value).encode()).hexdigest() elif isinstance(value, (list, tuple)): return hashlib.md5( "".join( [recursive_hash(item, depth + 1, ignore_params) for item in value] ).encode() ).hexdigest() elif isinstance(value, dict): return hashlib.md5( "".join( [ recursive_hash(key, depth + 1, ignore_params) + recursive_hash(val, depth + 1, ignore_params) for key, val in value.items() if key not in ignore_params ] ).encode() ).hexdigest() elif hasattr(value, "__dict__") and value.__class__.__name__ not in ignore_params: return recursive_hash(value.__dict__, depth + 1, ignore_params) else: return hashlib.md5("unknown".encode()).hexdigest()
file_cache
file_cache
是一个处理缓存逻辑的装饰器。
@file_cache() def search_codebase( cloned_github_repo, query, ): # ... take a long time ... # ... llm agent logic to search through the codebase ... return top_results
Wrapper
首先,我们将缓存存储在 /tmp/file_cache
.这使我们可以通过简单地删除目录(运行 rm -rf /tmp/file_cache
)来删除缓存。
def wrapper(*args, **kwargs): cache_dir = "/tmp/file_cache" os.makedirs(cache_dir, exist_ok=True)
Cache Key Creation / Miss Conditions
我们还有另一个问题 - 我们希望在两种情况下不使用缓存:
- 函数参数更改 - 由
recursive_hash
处理 - 代码更改
- 为了处理 2.我们使用
inspect.getsource(func)
将函数的源代码进行哈希,在代码更改时正确地丢失了缓存。func_source_code_hash = hash_code(inspect.getsource(func))
args_names = func.__code__.co_varnames[: func.__code__.co_argcount] args_dict = dict(zip(args_names, args)) # Remove ignored params kwargs_clone = kwargs.copy() for param in ignore_params: args_dict.pop(param, None) kwargs_clone.pop(param, None) # Create hash based on argument names, argument values, and function source code arg_hash = ( recursive_hash(args_dict, ignore_params=ignore_params) + recursive_hash(kwargs_clone, ignore_params=ignore_params) + func_source_code_hash ) cache_file = os.path.join( cache_dir, f"{func.__module__}_{func.__name__}_{arg_hash}.pickle" )
Cache hits and misses
最后,我们检查缓存键是否存在,并在缓存未命中的情况下写入缓存。
try: # If cache exists, load and return it if os.path.exists(cache_file): if verbose: print("Used cache for function: " + func.__name__) with open(cache_file, "rb") as f: return pickle.load(f) except Exception: logger.info("Unpickling failed") # Otherwise, call the function and save its result to the cache result = func(*args, **kwargs) try: with open(cache_file, "wb") as f: pickle.dump(result, f) except Exception as e: logger.info(f"Pickling failed: {e}") return result