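Preface
This article uses LangChain to build an augmented QA system over our own domain data. It draws on the following data sources:
- Domain questions that need a precise answer are answered by querying our own database;
- Other natural-language domain questions are answered from a vector database holding embeddings of our own knowledge;
- Broader questions in the domain are answered from an open search engine.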
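This article is primarily a proof of concept (POC). The components were chosen as follows:
- LLM: the commercial OpenAI API, which is powerful and easy to use. It serves only as an example; please do not use it for development on business data;
- Vector database: FAISS, which is open source and free. The commercial Pinecone is an alternative;
- Search integration: the commercial SerpAPI, again only as an example.
For the scenario, we chose network security, the topic our team focuses on, and worked with a public dataset. To keep the POC simple, the data was cut down as far as possible, so the whole program is small enough to run directly on Colab.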
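The Alibaba Cloud Security Innovation Laboratory (SIL) is a team focused on solving security problems with cutting-edge algorithms. In recent years it has given talks at venues such as the Botnet Conference on knowledge-graph reasoning (Chaining), behavior-sequence embedding (Embedding), and large-scale SSDeep graph networks. To learn more, follow the team blog, 阿里云安全创新实验室SIL, which hosts a large series of articles on algorithms and engineering practice in security.
Note: all data and techniques in this article come from public datasets and industry practice. No business-sensitive data or technology is involved.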
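Dataset
The data comes from the CVE dataset on Kaggle (https://www.kaggle.com/datasets/andrewkronser/cve-common-vulnerabilities-and-exposures). The CVE (Common Vulnerabilities and Exposures) dataset is a public, centrally maintained database of computer security vulnerabilities. It provides standardized descriptions and unique identifiers for known vulnerabilities. To keep things simple, only the first 100 lines of the CSV (a header row plus 99 records) are used here.
A sample of the data (the first column, which holds the CVE id, has no header name):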
,mod_date,pub_date,cvss,cwe_code,cwe_name,summary,access_authentication,access_complexity,access_vector,impact_availability,impact_confidentiality,impact_integrity
CVE-2008-7273,2019-11-18 22:15:00,2019-11-18 22:15:00,4.6,59, Improper Link Resolution Before File Access ('Link Following'),A symlink issue exists in Iceweasel-firegpg before 0.6 due to insecure tempfile handling.,,,,,,
CVE-2010-4659,2019-11-20 17:48:00,2019-11-20 17:15:00,4.3,79, Improper Neutralization of Input During Web Page Generation ('Cross-site Scripting'),Cross-site scripting (XSS) vulnerability in statusnet through 2010 in error message contents.,,,,,,
Dataset attachment: cve-100.csv
The data is read as follows:
from google.colab import drive
drive.mount('/content/drive')
# Read the whole CSV as text; the with-block closes the file afterwards
with open('/content/drive/My Drive/ColabFiles/cve-100.csv', 'r') as f:
    raw_cve = f.read()
print(raw_cve)
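Connecting to a Database
LangChain is, at its core, a connector. We start by showing how to connect it to a database.
Initializing the Local Database
Fields such as a CVE's modification date, publication date, severity score, weakness enumeration code, and weakness name are exact facts, and a traditional structured database handles them cheaply and efficiently. In addition, when integrating with real business systems, querying key information from structured databases is hard to avoid.
We use sqlite as the example. The database and table are created as follows: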
# initialize the local database
import sqlite3
conn = sqlite3.connect('cve-100.db')
conn.execute('''
CREATE TABLE CVE (
    id NVARCHAR PRIMARY KEY NOT NULL,
    mod_date NVARCHAR NOT NULL, --The date the entry was last modified
    pub_date NVARCHAR NOT NULL, --The date the entry was published
    cvss NVARCHAR NOT NULL, --Common Vulnerability Scoring System (CVSS) score, a measure of the severity of a vulnerability
    cwe_code NVARCHAR NOT NULL, --Common Weakness Enumeration (CWE) code, identifying the type of weakness
    cwe_name NVARCHAR NOT NULL --The name associated with the CWE code
);''')
conn.commit()
print("CVE table created")
Note that I added a comment to each column to help the LLM understand what the column means later on. As a rule, every word in the input given to an LLM matters (every word counts in large language models).
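Whether these comments actually reach the LLM depends on how the schema text is produced. The following is a minimal sketch, assuming an in-memory database and a shortened version of the table (both chosen only for illustration). It shows that SQLite itself keeps the comments in the stored `CREATE TABLE` statement. A schema rebuilt from column metadata may not include them, so it is worth printing the schema text your chain will see.

```python
import sqlite3

# In-memory database used only for this illustration.
conn = sqlite3.connect(":memory:")
conn.execute("""
CREATE TABLE CVE (
    id       NVARCHAR PRIMARY KEY NOT NULL,
    pub_date NVARCHAR NOT NULL, --The date the entry was published
    cvss     NVARCHAR NOT NULL  --Common Vulnerability Scoring System (CVSS) score
);""")

# SQLite stores the original CREATE TABLE statement, comments included.
schema_sql = conn.execute(
    "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = 'CVE'"
).fetchone()[0]
print(schema_sql)
```

If the comments turn out to be missing from the text passed to the model, one option is to put the same descriptions into the prompt or tool description instead.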
Importing the Data
Read the CSV data and load it into the table:
import csv
# Create a cursor object
cursor = conn.cursor()
# Clear all data in the table so the import can be re-run
cursor.execute('DELETE FROM CVE')
# Commit the deletion
conn.commit()
# Open the CSV file
with open('/content/drive/My Drive/ColabFiles/cve-100.csv', 'r', newline='') as csv_file:
    # Create a CSV reader
    csv_reader = csv.reader(csv_file)
    next(csv_reader)  # Skip the header row
    # Iterate over each row in the CSV file and insert it into the table
    for row in csv_reader:
        # Keep only the first 6 columns, which match the table definition
        row_data = row[:6]
        cursor.execute('INSERT INTO CVE VALUES (?, ?, ?, ?, ?, ?)', row_data)
# Commit the inserted rows
conn.commit()
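The same import can be written more compactly with `executemany`. The following is a sketch under stated assumptions: it uses an in-memory database, a simplified table, and two rows taken from the sample shown earlier (truncated to seven columns), with the id column assumed to be unnamed in the header.

```python
import csv
import io
import sqlite3

# Two sample rows; the first header cell is empty because the id column is unnamed.
sample_csv = io.StringIO(
    ",mod_date,pub_date,cvss,cwe_code,cwe_name,summary\n"
    "CVE-2008-7273,2019-11-18 22:15:00,2019-11-18 22:15:00,4.6,59,"
    "Improper Link Resolution Before File Access ('Link Following'),"
    "A symlink issue exists in Iceweasel-firegpg before 0.6 due to insecure tempfile handling.\n"
    "CVE-2010-4659,2019-11-20 17:48:00,2019-11-20 17:15:00,4.3,79,"
    "Improper Neutralization of Input During Web Page Generation ('Cross-site Scripting'),"
    "Cross-site scripting (XSS) vulnerability in statusnet through 2010 in error message contents.\n"
)

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE CVE (id TEXT PRIMARY KEY, mod_date TEXT, pub_date TEXT, "
    "cvss TEXT, cwe_code TEXT, cwe_name TEXT)"
)

reader = csv.reader(sample_csv)
next(reader)  # skip the header row
# Keep only the six columns that exist in the table.
rows = [row[:6] for row in reader]

# Used as a context manager, the connection commits on success and rolls back on error.
with conn:
    conn.executemany("INSERT INTO CVE VALUES (?, ?, ?, ?, ?, ?)", rows)

inserted = conn.execute("SELECT COUNT(*) FROM CVE").fetchone()[0]
print(inserted)
```

Wrapping the insert in `with conn:` means a malformed row leaves the table unchanged instead of half-loaded.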
Checking the Data
A simple query confirms that the data was imported:
# Execute the SQL query to count rows in the CVE table
cursor.execute('SELECT COUNT(*) FROM CVE')
# Retrieve the count of rows
row_count = cursor.fetchone()[0]
# Print the count of rows
print("Number of rows in 'CVE' table:", row_count)
cursor.close()
conn.close()
Output:
Number of rows in 'CVE' table: 99
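Beyond the row count, one more check may be worth running. Because `cvss` is declared as `NVARCHAR`, SQLite compares its values as text, which can matter if later questions involve ranking by severity. The sketch below assumes an in-memory table with one hypothetical row (`CVE-0000-0001`, score 10.0) that is not part of the dataset, added only to make the effect visible.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE CVE (id NVARCHAR PRIMARY KEY NOT NULL, cvss NVARCHAR NOT NULL)")
conn.executemany(
    "INSERT INTO CVE VALUES (?, ?)",
    [
        ("CVE-2008-7273", "4.6"),
        ("CVE-2019-16548", "6.8"),
        ("CVE-0000-0001", "10.0"),  # hypothetical row, not in the dataset
    ],
)

# Text comparison: '6.8' sorts after '10.0' because '6' > '1'.
max_as_text = conn.execute("SELECT MAX(cvss) FROM CVE").fetchone()[0]
# Numeric comparison after an explicit cast.
max_as_number = conn.execute("SELECT MAX(CAST(cvss AS REAL)) FROM CVE").fetchone()[0]

print(max_as_text, max_as_number)
```

Declaring the column as `REAL`, or mentioning the cast in the prompt, are two possible ways to avoid surprises in generated SQL.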
Integrating with LangChain
LangChain provides the SQLDatabase component, which makes it easy to access a database:
!pip install langchain
from langchain import SQLDatabase
# connect to db
db = SQLDatabase.from_uri("sqlite:///cve-100.db")
print(db.table_info)
Output:
CREATE TABLE "CVE" (
    id NVARCHAR NOT NULL,
    mod_date NVARCHAR NOT NULL,
    pub_date NVARCHAR NOT NULL,
    cvss NVARCHAR NOT NULL,
    cwe_code NVARCHAR NOT NULL,
    cwe_name NVARCHAR NOT NULL,
    PRIMARY KEY (id)
)
/*
3 rows from CVE table:
id mod_date pub_date cvss cwe_code cwe_name
CVE-2019-16548 2019-11-21 15:15:00 2019-11-21 15:15:00 6.8 352 Cross-Site Request Forgery (CSRF)
CVE-2019-16547 2019-11-21 15:15:00 2019-11-21 15:15:00 4.0 732 Incorrect Permission Assignment for Critical Resource
CVE-2019-16546 2019-11-21 15:15:00 2019-11-21 15:15:00 4.3 639 Authorization Bypass Through User-Controlled Key
*/
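Building the SQLDatabaseChain
LangChain provides the LLM-based SQLDatabaseChain. It uses the LLM to translate a natural-language query into SQL, runs the SQL against the database, then uses the LLM again to assemble and polish the result into the final answer.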
!pip install openai
from langchain import SQLDatabaseChain,OpenAI
import os
# Get your API keys from openai, you will need to create an account.
# Here is the link to get the keys: https://platform.openai.com/account/billing/overview
os.environ["OPENAI_API_KEY"] = "sk-your openai api key goes here"
# create db chain from llm and db
db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db, verbose=True)
# run a query
db_chain.run("How many cves are there ?")
Output:
> Entering new SQLDatabaseChain chain...
How many cves are there ?
SQLQuery:SELECT COUNT(*) FROM "CVE";
SQLResult: [(99,)]
Answer:There are 99 cves.
> Finished chain.
'There are 99 cves.'
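Here we use the commercial OpenAI model with its temperature set to 0, since querying a database calls for little creativity or diversity. The trace shows the natural-language question being translated into SQL, the query result coming back, and the result being parsed and wrapped into an answer people can read. The LLM mapped "how many" to SELECT COUNT(*), identified the table name correctly, and assembled the right final answer.
Another example: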
db_chain.run("When is the CVE-2008-7273 published?")
Output:
> Entering new SQLDatabaseChain chain...
When is the CVE-2008-7273 published?
SQLQuery:SELECT "pub_date" FROM "CVE" WHERE "id" = 'CVE-2008-7273';
SQLResult: [('2019-11-18 22:15:00',)]
Answer:The CVE-2008-7273 is published on 2019-11-18 22:15:00.
> Finished chain.
'The CVE-2008-7273 is published on 2019-11-18 22:15:00.'
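Here the LLM correctly inferred that "published" maps to the pub_date column, and again returned the answer in readable natural language.
As this shows, LangChain's SQLDatabaseChain makes it easy to connect an LLM to a database: natural language in, natural language out, with the LLM doing the work of unlocking the data.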
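The flow described above (question, SQL, result, answer) can be sketched without LangChain or a real model. This is an illustration only, not SQLDatabaseChain's implementation: `answer_with_sql` and `fake_llm` are hypothetical names, the prompts are made up, and `fake_llm` returns canned completions in place of a real LLM call.

```python
import re
import sqlite3
from typing import Callable


def answer_with_sql(question: str, conn: sqlite3.Connection, llm: Callable[[str], str]) -> str:
    """Question -> SQL (via LLM) -> rows -> natural-language answer (via LLM)."""
    # Step 1: show the model the schema and ask it for a query.
    schema = "\n".join(
        row[0] for row in conn.execute("SELECT sql FROM sqlite_master WHERE type = 'table'")
    )
    sql = llm(f"Schema:\n{schema}\nQuestion: {question}\nSQLQuery:").strip()
    # Step 2: run the generated SQL against the database.
    rows = conn.execute(sql).fetchall()
    # Step 3: ask the model to phrase the raw result as an answer.
    return llm(f"Question: {question}\nSQLQuery: {sql}\nSQLResult: {rows}\nAnswer:").strip()


def fake_llm(prompt: str) -> str:
    """Hypothetical stand-in for a real LLM; it only handles the count question."""
    if prompt.rstrip().endswith("SQLQuery:"):
        return 'SELECT COUNT(*) FROM "CVE";'
    count = re.search(r"SQLResult: \[\((\d+),\)\]", prompt).group(1)
    return f"There are {count} cves."


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE CVE (id TEXT PRIMARY KEY, pub_date TEXT)")
conn.executemany(
    "INSERT INTO CVE VALUES (?, ?)",
    [("CVE-2008-7273", "2019-11-18 22:15:00"), ("CVE-2010-4659", "2019-11-20 17:15:00")],
)

answer = answer_with_sql("How many cves are there ?", conn, fake_llm)
print(answer)
```

Since the SQL is generated by a model, running it over a read-only connection or with a restricted database user is a sensible precaution outside a POC.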
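Connecting to a VectorStore
For the remaining unstructured data in the CVE records, such as summary, a natural fit is to embed the text and store it in a vector database. This section connects LangChain to a vector database.
Data Preprocessing
First, split the data into chunks so that no single piece is long enough to hit the token size limit.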
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from google.colab import drive
drive.mount('/content/drive')
f = open('/content/drive/My Drive/ColabFiles/cve-100.csv', 'r')
raw_text = f.read()
print(raw_text)
# Split the text into smaller chunks so that information retrieval does not hit the token size limit.
text_splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=500,
    chunk_overlap=0,
    length_function=len,
)
texts = text_splitter.split_text(raw_text)
len(texts)
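To make the splitting step concrete, here is a rough sketch of separator-based chunking: split on the separator, then greedily pack pieces into chunks up to `chunk_size` characters. It approximates the idea and is not LangChain's implementation; `split_by_separator` is a hypothetical helper and the rows are synthetic.

```python
from typing import List


def split_by_separator(text: str, separator: str = "\n", chunk_size: int = 500) -> List[str]:
    """Greedily pack separator-delimited pieces into chunks of at most chunk_size characters.

    A single piece longer than chunk_size is kept whole rather than cut.
    """
    chunks: List[str] = []
    current = ""
    for piece in text.split(separator):
        if not piece:
            continue  # ignore empty lines
        candidate = piece if not current else current + separator + piece
        if len(candidate) <= chunk_size or not current:
            current = candidate
        else:
            chunks.append(current)
            current = piece
    if current:
        chunks.append(current)
    return chunks


# Synthetic rows of 120 characters each, standing in for CSV lines.
rows = [f"CVE-0000-{i:04d},row {i} " + "x" * 100 for i in range(10)]
chunks = split_by_separator("\n".join(rows), chunk_size=500)
print(len(chunks), [len(c) for c in chunks])
```

Splitting on the newline keeps each CSV record intact inside a chunk, which is helpful when a chunk is later returned as a search result.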
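Creating the Vector Database Index
Here we use OpenAI embeddings to build the vector database from the split texts. The vector database is FAISS (Facebook AI Similarity Search), which is open source and free.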
!pip install faiss-cpu
!pip install tiktoken
from langchain.vectorstores import FAISS
# using embeddings from OpenAI
embeddings = OpenAIEmbeddings()
# create vector index
vector_db = FAISS.from_texts(texts, embeddings)
vector_db is the resulting vector database index.
Semantic Query
With the index in place, semantic similarity queries are straightforward:
query = "What is the vulnerability in Jenkins Google Compute Engine Plugin?"
docs = vector_db.similarity_search(query)
print(docs)
Output:
[Document(page_content='CVE-2019-16547,2019-11-21 15:15:00,2019-11-21 15:15:00,4.0,732, Incorrect Permission Assignment for Critical Resource,Missing permission checks in various API endpoints in Jenkins Google Compute Engine Plugin 4.1.1 and earlier allow attackers with Overall/Read permission to obtain limited information about the plugin configuration and environment.,,,,,,', metadata={}), Document(page_content=',mod_date,pub_date,cvss,cwe_code,cwe_name,summary,access_authentication,access_complexity,access_vector,impact_availability,impact_confidentiality,impact_integrity\nCVE-2019-16548,2019-11-21 15:15:00,2019-11-21 15:15:00,6.8,352, Cross-Site Request Forgery (CSRF),A cross-site request forgery vulnerability in Jenkins Google Compute Engine Plugin 4.1.1 and earlier in ComputeEngineCloud#doProvision could be used to provision new agents.,,,,,,', metadata={}), Document(page_content='CVE-2019-16546,2019-11-21 15:15:00,2019-11-21 15:15:00,4.3,639, Authorization Bypass Through User-Controlled Key,"Jenkins Google Compute Engine Plugin 4.1.1 and earlier does not verify SSH host keys when connecting agents created by the plugin, enabling man-in-the-middle attacks.",,,,,,', metadata={}), Document(page_content="CVE-2012-4441,2019-11-18 22:15:00,2019-11-18 22:15:00,4.3,79, Improper Neutralization of Input During Web Page Generation ('Cross-site Scripting'),Cross-site Scripting (XSS) in Jenkins main before 1.482 and LTS before 1.466.2 allows remote attackers to inject arbitrary web script or HTML in the CI game plugin.,,,,,,", metadata={})]
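The search returned 4 documents: three describe the Jenkins Google Compute Engine Plugin, and the fourth (CVE-2012-4441) concerns a different Jenkins plugin.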
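The idea behind a similarity search can be shown with a toy example. The sketch below is an assumption-laden illustration: it replaces OpenAI embeddings with simple word counts and replaces the FAISS index with a brute-force cosine-similarity scan, so it shows the ranking principle only. The helper names are hypothetical and the texts are shortened from the dataset.

```python
import math
import re
from collections import Counter
from typing import List, Tuple


def embed(text: str) -> Counter:
    """Toy embedding: lowercase word counts (not what OpenAIEmbeddings produces)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[token] * b[token] for token in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def similarity_search(query: str, texts: List[str], k: int = 4) -> List[Tuple[float, str]]:
    """Score every text against the query and return the k best (score, text) pairs."""
    query_vec = embed(query)
    scored = [(cosine(query_vec, embed(t)), t) for t in texts]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)[:k]


texts = [
    "CVE-2019-16547 Missing permission checks in various API endpoints in "
    "Jenkins Google Compute Engine Plugin 4.1.1 and earlier",
    "CVE-2008-7273 A symlink issue exists in Iceweasel-firegpg before 0.6 "
    "due to insecure tempfile handling.",
    "CVE-2010-4659 Cross-site scripting (XSS) vulnerability in statusnet "
    "through 2010 in error message contents.",
]
query = "What is the vulnerability in Jenkins Google Compute Engine Plugin?"
top = similarity_search(query, texts, k=2)
for score, text in top:
    print(round(score, 2), text[:40])
```

Note that a top-k search always returns k results, so the weakest hits may be only loosely related to the query.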
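Creating the QAChain
The vector store above only returns documents with high similarity; it does not actually answer the user's question. With an LLM, an answer can be assembled from the returned documents.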
# create the local QA chain
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
local_qa_chain = load_qa_chain(OpenAI(), chain_type="stuff")
Here we load LangChain's qa_chain and, with OpenAI, create local_qa_chain, a chain capable of question answering.
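The "stuff" chain type places all retrieved documents into a single prompt together with the question. A minimal sketch of that idea follows; `build_stuff_prompt` is a hypothetical helper and the prompt wording is illustrative, not necessarily the template LangChain uses.

```python
from typing import List


def build_stuff_prompt(documents: List[str], question: str) -> str:
    """Concatenate all retrieved documents into one context block, then append the question."""
    context = "\n\n".join(documents)
    return (
        "Use the following pieces of context to answer the question at the end. "
        "If you don't know the answer, just say that you don't know.\n\n"
        f"{context}\n\n"
        f"Question: {question}\n"
        "Helpful Answer:"
    )


docs = [
    "CVE-2019-16547,...,732, Incorrect Permission Assignment for Critical Resource,...",
    "CVE-2019-16548,...,352, Cross-Site Request Forgery (CSRF),...",
]
prompt = build_stuff_prompt(docs, "What is the type of weakness for CVE-2019-16547?")
print(prompt)
```

Because everything goes into one prompt, the combined length of the documents must fit in the model's context window, which is one reason the chunk size and the number of retrieved documents matter.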
Querying the VectorStore through the QAChain
As before, we first retrieve similar documents from the VectorStore, then pass them to the QA chain together with the query.
query = "what is the type of weakness for vulnerability in Jenkins Google Compute Engine Plugin?"
docs = vector_db.similarity_search(query)
print(docs)
local_qa_chain.run(input_documents=docs, question=query)
Output:
Incorrect permission assignment for critical resource (CVE-2019-16547), Cross-site request forgery (CSRF) (CVE-2019-16548), Authorization Bypass through user-controlled key (CVE-2019-16546), and Improper Neutralization of Input During Web Page Generation ('Cross-site Scripting') (CVE-2012-4441).
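The chain extracted the key information from the similar documents according to the query and assembled it, in order, into an easy-to-read answer.
Connecting to a SearchEngine
Here we use the SERP (Search Engine Result Page) API.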
# initialize the SERP API
!pip install google-search-results
os.environ["SERPAPI_API_KEY"] = "Your Serp API key goes here"
A quick test:
from langchain import SerpAPIWrapper
# create the Search Engine Result Page API wrapper
search = SerpAPIWrapper()
search.run('What Is Cyber Risk?')
Output:
Definition(s): The risk of depending on cyber resources (i.e., the risk of depending on a system or system elements that exist in or intermittently have a presence in cyberspace).
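Creating an Agent from Chain Tools
Now we connect the three chains above as tools:
- First, db_chain, which handles only CVE fields such as mod_date, pub_date, cvss, and cwe_code;
- Second, local_qa_chain, which covers domain knowledge, that is, the CVE and vulnerability information in the Vector Store;
- Last, the Search tool, which performs open search through SerpAPI.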
from langchain import SerpAPIWrapper
from langchain.agents import AgentType, initialize_agent
from langchain.tools import Tool, tool

@tool
def local_qa_tool(query: str) -> str:
    """useful for when you need to answer questions about cve or vulnerabilities"""
    # Retrieve similar documents, then pass them to the QA chain with the plain query string
    docs = vector_db.similarity_search(query)
    return local_qa_chain.run(input_documents=docs, question=query)

tools = [
    Tool.from_function(
        func=db_chain.run,
        name="db Search",
        description="useful for when you need to answer questions only about mod_date, pub_date, cvss, and cwe_code for CVE",
    ),
    local_qa_tool,
    Tool.from_function(
        func=SerpAPIWrapper().run,
        name="Search",
        description="useful for when you need to answer other security related questions.",
    ),
]
# Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use.
agent = initialize_agent(tools, OpenAI(), agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
With these tools connected to an LLM, we get a capable agent. The LLM is again OpenAI.
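The agent works by having the LLM emit an "Action" and an "Action Input", which are then routed to the matching tool. The sketch below illustrates that dispatch step only. It is not LangChain's agent code: `parse_action` and `dispatch` are hypothetical helpers, and the three tools are stubs standing in for db_chain.run, local_qa_tool, and SerpAPIWrapper().run.

```python
import re
from typing import Callable, Dict, Tuple


def parse_action(llm_output: str) -> Tuple[str, str]:
    """Extract the tool name and tool input from a ReAct-style completion."""
    match = re.search(r"Action:\s*(.+?)\s*\n\s*Action Input:\s*(.+)", llm_output)
    if match is None:
        raise ValueError(f"Could not parse an action from: {llm_output!r}")
    return match.group(1).strip(), match.group(2).strip()


def dispatch(llm_output: str, tools: Dict[str, Callable[[str], str]]) -> str:
    """Run the tool named by the LLM and format its result as an observation."""
    name, tool_input = parse_action(llm_output)
    if name not in tools:
        return f"Observation: {name} is not a valid tool"
    return f"Observation: {tools[name](tool_input)}"


# Stub tools; in the article these would be the real chains.
tools = {
    "db Search": lambda q: f"[db] looked up {q}",
    "local_qa_tool": lambda q: f"[vector store] answered {q}",
    "Search": lambda q: f"[web] searched {q}",
}

completion = (
    "I need to find the publication date and severity score for this CVE\n"
    "Action: db Search\n"
    "Action Input: CVE-2008-7273"
)
observation = dispatch(completion, tools)
print(observation)
```

The model picks a tool based on the tool names and descriptions it sees in the prompt, so precise, non-overlapping descriptions are the main lever for steering which tool gets chosen.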
Querying the DB
agent.run("When is the CVE-2008-7273 published and what's its severity score?")
Output:
> Entering new AgentExecutor chain...
I need to find the publication date and severity score for this CVE
Action: db Search
Action Input: CVE-2008-7273
> Entering new SQLDatabaseChain chain...
CVE-2008-7273
SQLQuery:SELECT id, mod_date, pub_date, cvss, cwe_code, cwe_name FROM CVE WHERE id = 'CVE-2008-7273' LIMIT 5;
SQLResult: [('CVE-2008-7273', '2019-11-18 22:15:00', '2019-11-18 22:15:00', '4.6', '59', " Improper Link Resolution Before File Access ('Link Following')")]
Answer:CVE-2008-7273 was published on 2019-11-18 22:15:00 with a CVSS score of 4.6 and CWE code 59, Improper Link Resolution Before File Access ('Link Following').
> Finished chain.
Observation: CVE-2008-7273 was published on 2019-11-18 22:15:00 with a CVSS score of 4.6 and CWE code 59, Improper Link Resolution Before File Access ('Link Following').
Thought: I now know the final answer
Final Answer: CVE-2008-7273 was published on 2019-11-18 22:15:00 with a CVSS score of 4.6.
> Finished chain.
'CVE-2008-7273 was published on 2019-11-18 22:15:00 with a CVSS score of 4.6.'
The agent chose db Search, converted the natural-language question to SQL, and generated the answer from the SQL result.
Querying the VectorStore
agent.run("what are the cve ids that related to Jenkins Google Compute Engine Plugin?")
Output:
> Entering new AgentExecutor chain...
I should find a tool that can answer questions about CVE and vulnerabilities.
Action: local_qa_tool
Action Input: Jenkins Google Compute Engine Plugin
Observation:
CVE-2019-16547 is a vulnerability in Jenkins Google Compute Engine Plugin that affects versions 4.1.1 and earlier. It is rated 4.0 on the CVSS scale and is categorized as 'Incorrect Permission Assignment for Critical Resource'.
CVE-2019-16548 is a vulnerability in Jenkins Google Compute Engine Plugin that affects versions 4.1.1 and earlier. It is rated 6.8 on the CVSS scale and is categorized as 'Cross-Site Request Forgery (CSRF)'.
CVE-2019-16546 is a vulnerability in Jenkins Google Compute Engine Plugin that affects versions 4.1.1 and earlier. It is rated 4.3 on the CVSS scale and is categorized as 'Authorization Bypass Through User-Controlled Key'.
CVE-2012-4441 is not related to Jenkins Google Compute Engine Plugin.
Thought: I now know the final answer.
Final Answer: CVE-2019-16547, CVE-2019-16548, CVE-2019-16546
> Finished chain.
'CVE-2019-16547, CVE-2019-16548, CVE-2019-16546'
The agent chose local_qa_tool, retrieved similar documents from the vector store, and assembled the answer.
Search Engine
We can also ask more general, broader questions:
agent.run("What is Cloud SIEM?")
Output:
> Entering new AgentExecutor chain...
I'm not sure, I should try to find out
Action: Search
Action Input: Cloud SIEM
Observation: With Cloud SIEM, you can augment your existing SIEM investments and deliver better cloud security outcomes. Cloud SIEM analyzes operational and security logs in ...
Thought: I now know the final answer
Final Answer: Cloud SIEM is a technology that provides visibility and security analytics for cloud-based infrastructure and applications. It helps organizations monitor, detect, and respond to threats in real-time by analyzing operational and security logs in the cloud.
> Finished chain.
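'Cloud SIEM is a technology that provides visibility and security analytics for cloud-based infrastructure and applications. It helps organizations monitor, detect, and respond to threats in real-time by analyzing operational and security logs in the cloud.'
For an open question it could not answer from local data, the agent chose the Search action, ran an open search, and assembled the answer from the search engine results.
Summary
Using LangChain as the connector, this article linked a local structured database, a local text vector database, and an open search engine, then used an LLM on top of them to build an augmented QA system that can handle a variety of question types. The article is example-driven, and the code was simplified as much as possible to keep the logic clear. In scenarios closer to real business use, many points still need consideration and optimization, for example:
- Once business-sensitive data is involved, the commercial OpenAI API can no longer be used. Options include deploying an open-source LLM locally or using a trusted LLM inside the company group.
- After switching to another LLM, quality may well drop sharply, so the LLM needs fine-tuning on business data, which is a considerable challenge.
- With large volumes of business data, a more powerful and trusted VectorStore is needed, for example ElasticVectorSearch, Redis, or another vector-capable database that the business purchases on its own cloud.
- Reducing cost and latency. LLMs are powerful, but they remain expensive and usually slow. Adding a cache can cut down on LLM calls: QA pairs can be stored in a VectorStore as a cache, a query similar to a cached one can be answered directly, and only queries with no similar cached entry are passed to the LLM.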
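The caching idea in the last point can be sketched as follows. This is a toy illustration under stated assumptions: word-count vectors stand in for real embeddings, a plain list stands in for the VectorStore, the 0.8 threshold is arbitrary, and `SemanticCache` and the LLM stub are hypothetical.

```python
import math
import re
from collections import Counter
from typing import Callable, List, Optional, Tuple


def embed(text: str) -> Counter:
    """Toy embedding: lowercase word counts (a stand-in for real embeddings)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[token] * b[token] for token in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


class SemanticCache:
    """Answer from cached QA pairs when a similar query exists, otherwise call the LLM."""

    def __init__(self, llm: Callable[[str], str], threshold: float = 0.8) -> None:
        self.llm = llm
        self.threshold = threshold  # arbitrary value chosen for this sketch
        self.entries: List[Tuple[Counter, str]] = []
        self.llm_calls = 0

    def ask(self, query: str) -> str:
        query_vec = embed(query)
        best_score, best_answer = 0.0, None  # type: Tuple[float, Optional[str]]
        for vec, answer in self.entries:
            score = cosine(query_vec, vec)
            if score > best_score:
                best_score, best_answer = score, answer
        if best_answer is not None and best_score >= self.threshold:
            return best_answer  # cache hit: no LLM call
        self.llm_calls += 1
        answer = self.llm(query)
        self.entries.append((query_vec, answer))
        return answer


cache = SemanticCache(llm=lambda q: f"answer to: {q}")  # stub in place of a real LLM
first = cache.ask("What is Cloud SIEM?")
second = cache.ask("what is cloud siem")  # similar query, served from the cache
third = cache.ask("When is CVE-2008-7273 published?")  # different query, goes to the LLM
print(cache.llm_calls)
```

The threshold is the main trade-off: too low and users get answers to a different question, too high and the cache rarely hits. Cached answers over changing data would also need an expiry policy.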
The complete Colab code for this article is available here: https://colab.research.google.com/drive/1bHDb9iROpaeKxqJ7PjlbfruEtCwIop8Q?usp=sharing