个人主页:Guiat
归属专栏:科学技术变革创新
文章目录
- 1. Function Calling:智能交互的新范式
- 1.1 Function Calling 技术概述
- 1.2 核心优势分析
- 2. 数据分析Agent平台架构设计
- 2.1 系统架构概览
- 2.2 核心组件解析
- 2.2.1 函数注册中心
- 2.2.2 Agent控制器
- 3. 数据分析函数设计与实现
- 3.1 数据获取函数
- 3.1.1 CSV文件加载器
- 3.1.2 数据库连接器
- 3.2 数据处理函数
- 3.2.1 探索性数据分析
- 3.2.2 数据清洗与转换
- 3.3 数据可视化函数
- 3.3.1 图表生成器
- 3.3.2 高级可视化:多图表面板
- 4. 智能数据分析流程实现
- 4.1 函数注册与配置
- 4.2 构建分析工作流
- 5. 高级功能拓展与应用
- 5.1 交互式分析对话能力
- 5.2 自动化报告生成
- 5.3 多数据源集成分析
- 6. 性能优化与最佳实践
- 6.1 性能优化策略
- 6.1.1 LLM调用优化实现
- 6.1.2 数据处理优化实现
- 6.2 系统扩展性设计
- 6.2.1 插件化架构
- 6.2.2 分布式计算支持
- 6.3 安全性与数据隐私保护
- 6.3.1 数据安全策略
- 7. 全栈实现:构建一个完整的数据分析Agent平台
- 7.1 系统架构设计
- 7.1.1 整体架构图
- 7.1.2 关键组件详述
- 7.1.2.1 前端层
- 7.1.2.2 API层
- 7.2 完整实现流程
- 7.2.1 后端服务实现
- 7.2.2 部署与容器化
- 8. 实际案例:电子商务销售数据分析
- 8.1 案例需求与数据集
- 8.1.1 业务背景
- 8.1.2 数据集描述
- 8.2 Agent实现过程
- 8.2.1 函数准备
- 8.2.2 Agent交互示例
- 8.3 分析结果与业务洞察
- 8.3.1 高级可视化面板
- 8.3.2 执行预测分析
- 8.4 实际业务决策支持
- 8.4.1 库存优化建议
- 8.4.2 营销活动优化建议
- 9. 未来发展与技术展望
- 9.1 Function Calling技术演进方向
- 9.1.1 函数自动发现与合成
- 9.2 智能分析平台的未来形态
- 9.2.1 自适应学习分析系统
- 9.2.2 端到端的异构数据处理能力
- 9.3 行业应用前景与挑战
- 9.3.1 垂直行业应用场景
- 9.3.2 技术挑战与解决方案
正文
1. Function Calling:智能交互的新范式
Function Calling 技术作为大型语言模型(LLM)发展中的重要里程碑,为构建复杂智能应用提供了强大框架。本文将系统探讨 Function Calling 机制,并通过实例带领读者构建一个智能数据分析平台。
1.1 Function Calling 技术概述
Function Calling 允许 LLM 识别何时应该调用特定函数,并生成符合函数签名的结构化参数。这种能力极大拓展了 AI 与外部工具和服务的交互方式。
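下面用一个简化示例说明这一机制:向模型提供函数的 JSON Schema 描述,模型在需要时返回函数名与结构化参数(此处采用与 OpenAI 旧版 functions 参数类似的格式,仅作示意,具体字段请以所用模型服务商的文档为准)。
# 示意:提供给 LLM 的函数描述(JSON Schema 格式)
load_csv_function = {
    "name": "load_csv_data",
    "description": "加载CSV文件并返回数据概览",
    "parameters": {
        "type": "object",
        "properties": {
            "file_path": {"type": "string", "description": "CSV文件路径"},
            "sample_rows": {"type": "integer", "description": "预览的行数"}
        },
        "required": ["file_path"]
    }
}

# 模型判断需要调用函数时,通常返回类似下面的结构化结果(示意):
# {"function_call": {"name": "load_csv_data",
#                    "arguments": "{\"file_path\": \"sales_data.csv\", \"sample_rows\": 5}"}}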
1.2 核心优势分析
Function Calling 相比传统提示工程有以下显著优势:
- 结构化输出 - 确保输出符合预期格式
- 工具利用 - 能够访问外部数据源和API
- 多轮交互 - 支持复杂推理链和任务分解
- 可控性增强 - 降低幻觉概率,提高可靠性
2. 数据分析Agent平台架构设计
2.1 系统架构概览
我们的数据分析平台将采用多层架构,集成Function Calling机制与数据处理工具。
2.2 核心组件解析
2.2.1 函数注册中心
函数注册中心管理所有可被Agent调用的函数,每个函数都包含明确的名称、描述、参数定义和返回值说明。
# 函数注册中心实现
class FunctionRegistry:
def __init__(self):
self.functions = {}
def register(self, func_config):
"""注册一个新函数到系统中"""
self.functions[func_config["name"]] = {
"function": func_config["function"],
"description": func_config["description"],
"parameters": func_config["parameters"],
"required": func_config.get("required", [])
}
def get_function_descriptions(self):
"""返回所有函数的描述,用于LLM提示"""
return [
{
"name": name,
"description": info["description"],
"parameters": {
"type": "object",
"properties": info["parameters"],
"required": info["required"]
}
}
for name, info in self.functions.items()
]
def execute(self, name, params):
"""执行指定函数"""
if name not in self.functions:
raise ValueError(f"未找到函数: {name}")
return self.functions[name]["function"](**params)
2.2.2 Agent控制器
Agent控制器负责协调LLM与函数系统的交互,管理上下文并处理多轮对话。
# Agent控制器实现
import json

class AnalysisAgent:
def __init__(self, llm_service, function_registry):
self.llm = llm_service
self.registry = function_registry
self.conversation_history = []
def add_message(self, role, content, function_call=None):
"""添加消息到对话历史"""
message = {"role": role, "content": content}
if function_call:
message["function_call"] = function_call
self.conversation_history.append(message)
def process_query(self, user_query):
"""处理用户查询"""
self.add_message("user", user_query)
# 获取LLM响应
response = self.llm.generate_response(
self.conversation_history,
functions=self.registry.get_function_descriptions()
)
# 处理可能的函数调用
if "function_call" in response:
function_name = response["function_call"]["name"]
function_args = json.loads(response["function_call"]["arguments"])
# 记录函数调用
self.add_message("assistant", None, response["function_call"])
# 执行函数
result = self.registry.execute(function_name, function_args)
            # 记录函数结果(序列化为字符串,便于作为消息内容传回模型)
            self.add_message("function", json.dumps(result, ensure_ascii=False, default=str), {"name": function_name})
# 继续处理,获取最终响应
return self.process_function_result(result)
else:
# 直接回复
self.add_message("assistant", response["content"])
return response["content"]
def process_function_result(self, result):
"""处理函数结果并获取最终回复"""
final_response = self.llm.generate_response(
self.conversation_history,
functions=self.registry.get_function_descriptions()
)
self.add_message("assistant", final_response["content"])
return final_response["content"]
3. 数据分析函数设计与实现
3.1 数据获取函数
3.1.1 CSV文件加载器
def load_csv_data(file_path, sample_rows=5):
"""
加载CSV文件并返回数据预览
参数:
file_path: CSV文件路径
sample_rows: 预览的行数
返回:
数据概览信息字典
"""
import pandas as pd
df = pd.read_csv(file_path)
return {
"columns": list(df.columns),
"shape": df.shape,
"dtypes": df.dtypes.astype(str).to_dict(),
"sample": df.head(sample_rows).to_dict('records'),
"file_path": file_path, # 保存文件路径以供后续分析
"null_counts": df.isna().sum().to_dict()
}
3.1.2 数据库连接器
def query_database(connection_string, query, limit=100):
"""
执行SQL查询并返回结果
参数:
connection_string: 数据库连接字符串
query: SQL查询语句
limit: 最大返回行数
返回:
查询结果和元数据
"""
import pandas as pd
import sqlalchemy
engine = sqlalchemy.create_engine(connection_string)
# 安全检查,防止危险操作
if any(keyword in query.lower() for keyword in ["drop", "delete", "truncate", "update"]):
return {"error": "危险操作被拒绝,仅允许SELECT查询"}
# 添加LIMIT以防止返回过多数据
if "limit" not in query.lower():
query = f"{query} LIMIT {limit}"
df = pd.read_sql(query, engine)
return {
"columns": list(df.columns),
"shape": df.shape,
"data": df.to_dict('records'),
"query": query
}
3.2 数据处理函数
3.2.1 探索性数据分析
def analyze_dataset(file_path, columns=None):
"""
对数据集进行探索性分析
参数:
file_path: 数据文件路径
columns: 要分析的列,默认为全部
返回:
统计分析结果
"""
import pandas as pd
import numpy as np
df = pd.read_csv(file_path)
if columns:
df = df[columns]
numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
analysis = {
"shape": df.shape,
"numeric_analysis": {},
"categorical_analysis": {},
"correlation": None
}
# 数值型列分析
if numeric_cols:
analysis["numeric_analysis"] = df[numeric_cols].describe().to_dict()
analysis["correlation"] = df[numeric_cols].corr().to_dict()
# 分类型列分析
for col in categorical_cols:
analysis["categorical_analysis"][col] = {
"unique_values": df[col].nunique(),
"top_values": df[col].value_counts().head(5).to_dict()
}
return analysis
3.2.2 数据清洗与转换
def process_data(file_path, operations, output_path=None):
"""
执行数据清洗和转换操作
参数:
file_path: 输入数据文件路径
operations: 操作列表,每个操作是一个字典
output_path: 输出文件路径,可选
返回:
处理结果摘要
"""
import pandas as pd
df = pd.read_csv(file_path)
original_shape = df.shape
results = {
"original_shape": original_shape,
"operations_performed": [],
"errors": []
}
for op in operations:
try:
if op["type"] == "drop_na":
df = df.dropna(subset=op.get("columns", None))
results["operations_performed"].append(f"删除了{original_shape[0] - df.shape[0]}行含缺失值的数据")
elif op["type"] == "fill_na":
for col, value in op["values"].items():
df[col] = df[col].fillna(value)
results["operations_performed"].append(f"填充了缺失值: {op['values']}")
elif op["type"] == "normalize":
for col in op["columns"]:
df[col] = (df[col] - df[col].min()) / (df[col].max() - df[col].min())
results["operations_performed"].append(f"标准化了列: {op['columns']}")
elif op["type"] == "one_hot_encode":
df = pd.get_dummies(df, columns=op["columns"], drop_first=op.get("drop_first", False))
results["operations_performed"].append(f"独热编码了列: {op['columns']}")
elif op["type"] == "filter":
before_count = len(df)
df = df.query(op["condition"])
results["operations_performed"].append(f"应用过滤条件({op['condition']}),移除了{before_count - len(df)}行")
except Exception as e:
results["errors"].append(f"操作 {op['type']} 失败: {str(e)}")
results["final_shape"] = df.shape
results["columns"] = list(df.columns)
if output_path:
df.to_csv(output_path, index=False)
results["output_file"] = output_path
return results
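operations 参数是由若干操作字典组成的列表,下面是一个示例调用(列名与过滤条件仅作演示):
# 示例:依次执行缺失值删除、缺失值填充、归一化与行过滤
example_operations = [
    {"type": "drop_na", "columns": ["price"]},        # 删除 price 为空的行
    {"type": "fill_na", "values": {"discount": 0}},   # 将 discount 的缺失值填为 0
    {"type": "normalize", "columns": ["quantity"]},   # 将 quantity 缩放到 [0, 1]
    {"type": "filter", "condition": "price > 0"}      # 仅保留 price 大于 0 的行
]

summary = process_data("sales_data.csv", example_operations, output_path="sales_clean.csv")
print(summary["operations_performed"])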
3.3 数据可视化函数
3.3.1 图表生成器
def create_visualization(data, chart_type, config):
"""
生成数据可视化图表
参数:
data: 数据源(文件路径或数据字典)
chart_type: 图表类型 (bar, line, scatter, etc.)
config: 图表配置
返回:
图表数据和渲染选项
"""
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import base64
from io import BytesIO
# 加载数据
if isinstance(data, str):
df = pd.read_csv(data)
else:
df = pd.DataFrame(data)
plt.figure(figsize=(10, 6))
# 根据图表类型生成可视化
if chart_type == "bar":
sns.barplot(x=config["x"], y=config["y"], data=df, hue=config.get("hue"))
plt.title(config.get("title", f"{config['y']} by {config['x']}"))
elif chart_type == "line":
sns.lineplot(x=config["x"], y=config["y"], data=df, hue=config.get("hue"))
plt.title(config.get("title", f"{config['y']} over {config['x']}"))
elif chart_type == "scatter":
sns.scatterplot(x=config["x"], y=config["y"], data=df, hue=config.get("hue"))
plt.title(config.get("title", f"{config['y']} vs {config['x']}"))
elif chart_type == "histogram":
sns.histplot(df[config["column"]], bins=config.get("bins", 10))
plt.title(config.get("title", f"Distribution of {config['column']}"))
elif chart_type == "heatmap":
corr = df[config["columns"]].corr()
sns.heatmap(corr, annot=True, cmap=config.get("colormap", "coolwarm"))
plt.title(config.get("title", "Correlation Heatmap"))
plt.tight_layout()
# 将图表转换为base64编码
buffer = BytesIO()
plt.savefig(buffer, format='png')
buffer.seek(0)
image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
plt.close()
return {
"image_data": f"data:image/png;base64,{image_base64}",
"chart_type": chart_type,
"config": config
}
3.3.2 高级可视化:多图表面板
def create_dashboard(data_source, visualizations):
"""
创建多图表分析面板
参数:
data_source: 数据源文件路径
visualizations: 可视化配置列表
返回:
包含多个图表的面板数据
"""
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import base64
    from io import BytesIO
df = pd.read_csv(data_source)
# 计算面板布局
num_charts = len(visualizations)
cols = min(2, num_charts)
rows = (num_charts + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(12, 6 * rows))
if num_charts == 1:
axes = np.array([axes])
axes = axes.flatten()
charts_info = []
for i, viz_config in enumerate(visualizations):
if i < len(axes):
chart_type = viz_config["chart_type"]
config = viz_config["config"]
ax = axes[i]
plt.sca(ax)
# 绘制图表
if chart_type == "bar":
sns.barplot(x=config["x"], y=config["y"], data=df, ax=ax)
ax.set_title(config.get("title", f"{config['y']} by {config['x']}"))
elif chart_type == "line":
sns.lineplot(x=config["x"], y=config["y"], data=df, ax=ax)
ax.set_title(config.get("title", f"{config['y']} over {config['x']}"))
elif chart_type == "scatter":
sns.scatterplot(x=config["x"], y=config["y"], data=df, ax=ax)
ax.set_title(config.get("title", f"{config['y']} vs {config['x']}"))
charts_info.append({
"position": i,
"type": chart_type,
"config": config
})
# 隐藏多余的子图
for j in range(i + 1, len(axes)):
axes[j].axis('off')
plt.tight_layout()
# 将图表转换为base64编码
buffer = BytesIO()
plt.savefig(buffer, format='png', dpi=120)
buffer.seek(0)
image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
plt.close()
return {
"dashboard_image": f"data:image/png;base64,{image_base64}",
"charts": charts_info,
"layout": {"rows": rows, "cols": cols},
"data_source": data_source
}
4. 智能数据分析流程实现
4.1 函数注册与配置
首先,我们需要将所有数据分析工具函数正确注册到系统中。
# 示例:配置和注册数据分析函数
registry = FunctionRegistry()
# 注册数据加载函数
registry.register({
"name": "load_csv_data",
"function": load_csv_data,
"description": "加载CSV文件并返回数据概览",
"parameters": {
"file_path": {
"type": "string",
"description": "CSV文件的路径"
},
"sample_rows": {
"type": "integer",
"description": "要预览的数据行数"
}
},
"required": ["file_path"]
})
# 注册数据分析函数
registry.register({
"name": "analyze_dataset",
"function": analyze_dataset,
"description": "对数据集进行探索性分析,包括统计摘要和相关性分析",
"parameters": {
"file_path": {
"type": "string",
"description": "数据文件的路径"
},
"columns": {
"type": "array",
"items": {
"type": "string"
},
"description": "要分析的列名列表,如果为空则分析所有列"
}
},
"required": ["file_path"]
})
# 依此类推注册其他函数...
4.2 构建分析工作流
以下示例展示了一个完整的数据分析工作流程:
# 初始化分析代理
agent = AnalysisAgent(llm_service, registry)
# 用户查询示例
user_query = "我有一个销售数据文件sales_data.csv,帮我分析一下销售趋势并查看有哪些因素影响了销售额"
# 分析流程执行
response = agent.process_query(user_query)
# 输出最终结果
print(response)
在实际运行过程中,Agent 会根据用户查询执行以下步骤:
- 首先可能调用 load_csv_data 了解数据结构
- 然后使用 analyze_dataset 进行探索性分析
- 可能需要 process_data 清洗数据
- 最后用 create_visualization 生成趋势图表和相关性分析
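上述流程假定已有可用的 llm_service。AnalysisAgent 只依赖它的 generate_response(conversation, functions) 接口:需要调用函数时返回带 function_call 的字典,否则返回带 content 的字典。下面是一个仅用于本地演示的最小桩实现(不调用真实模型,回复内容与文件名均为示意):
import json

class DummyLLMService:
    """演示桩:第一轮请求调用 load_csv_data,第二轮直接给出文字回复"""
    def __init__(self):
        self.turn = 0

    def generate_response(self, conversation, functions=None):
        self.turn += 1
        if self.turn == 1:
            # 模拟模型决定调用函数,并给出 JSON 字符串形式的参数
            return {"function_call": {
                "name": "load_csv_data",
                "arguments": json.dumps({"file_path": "sales_data.csv"})
            }}
        # 模拟模型基于函数结果生成的最终回复
        return {"content": "数据已加载,可以继续进行趋势分析。"}

# 假设当前目录存在 sales_data.csv,且 registry 已按 4.1 节注册函数
agent = AnalysisAgent(DummyLLMService(), registry)
print(agent.process_query("帮我看看 sales_data.csv 的结构"))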
5. 高级功能拓展与应用
5.1 交互式分析对话能力
为了增强用户体验,我们可以实现持续对话功能,允许用户基于初始分析结果提出后续问题。
# 交互式分析会话示例
def interactive_analysis_session():
agent = AnalysisAgent(llm_service, registry)
print("智能数据分析助手已启动,请输入您的问题(输入'exit'退出)")
while True:
user_input = input("> ")
if user_input.lower() == 'exit':
print("分析会话结束")
break
response = agent.process_query(user_input)
print("\n分析结果:")
print(response)
print("\n您可以继续提问,或输入'exit'退出")
5.2 自动化报告生成
我们可以为系统添加自动生成完整分析报告的能力:
def generate_analysis_report(data_file, output_format="markdown"):
"""
生成完整的数据分析报告
参数:
data_file: 数据文件路径
output_format: 输出格式(markdown, html, pdf)
返回:
报告内容或文件路径
"""
agent = AnalysisAgent(llm_service, registry)
# 步骤1: 数据概览
overview_query = f"分析{data_file}文件,给我一个数据概览"
overview = agent.process_query(overview_query)
# 步骤2: 探索性分析
eda_query = f"对{data_file}进行探索性数据分析,包括描述性统计和相关性分析"
eda_result = agent.process_query(eda_query)
# 步骤3: 关键见解
insights_query = f"基于{data_file}的分析,提取3-5个关键商业见解"
insights = agent.process_query(insights_query)
# 步骤4: 数据可视化
viz_query = f"为{data_file}创建最关键的3个可视化图表,展示主要趋势和关系"
visualizations = agent.process_query(viz_query)
# 步骤5: 汇总报告
report_template = f"""
# 数据分析报告: {data_file}
## 1. 数据概览
{overview}
## 2. 探索性数据分析
{eda_result}
## 3. 关键发现
{insights}
## 4. 数据可视化
{visualizations}
## 5. 结论与建议
基于以上分析,我们得出以下结论和建议...
"""
conclusion_query = f"基于以上分析结果,为{data_file}生成结论和建议部分"
conclusion = agent.process_query(conclusion_query)
final_report = report_template.replace("基于以上分析,我们得出以下结论和建议...", conclusion)
# 根据需要的格式输出
if output_format == "markdown":
return final_report
elif output_format == "html":
# 将Markdown转换为HTML的代码
import markdown
html = markdown.markdown(final_report)
return html
    elif output_format == "pdf":
        # 生成PDF的代码
        import markdown
        from weasyprint import HTML
        report_file = "analysis_report.pdf"
        HTML(string=markdown.markdown(final_report)).write_pdf(report_file)
        return f"报告已保存为 {report_file}"
5.3 多数据源集成分析
扩展系统以支持多数据源的关联分析,数据集成层的关键代码如下:
def integrate_data_sources(sources, join_config):
"""
整合多个数据源进行分析
参数:
sources: 数据源列表,每个元素包含类型和路径/连接信息
join_config: 数据关联配置
返回:
集成后的数据集
"""
    import pandas as pd
    import sqlalchemy
    import requests
datasets = []
# 加载各数据源
for src in sources:
if src["type"] == "csv":
df = pd.read_csv(src["path"])
datasets.append({"name": src["name"], "data": df})
elif src["type"] == "database":
engine = sqlalchemy.create_engine(src["connection_string"])
df = pd.read_sql(src["query"], engine)
datasets.append({"name": src["name"], "data": df})
elif src["type"] == "api":
# 通过API获取数据并转换为DataFrame
response = requests.get(
src["url"],
headers=src.get("headers", {}),
params=src.get("params", {})
)
data = response.json()
df = pd.json_normalize(data["results"])
datasets.append({"name": src["name"], "data": df})
# 执行数据关联
result_df = datasets[0]["data"]
for i in range(1, len(datasets)):
join_spec = join_config[i-1]
right_df = datasets[i]["data"]
result_df = pd.merge(
result_df,
right_df,
left_on=join_spec["left_on"],
right_on=join_spec["right_on"],
how=join_spec.get("how", "inner")
)
return result_df
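sources 与 join_config 的具体结构可参考下面的调用示例(文件名与字段均为示意):
# 示例:将订单表与客户表按 customer_id 左连接
sources = [
    {"type": "csv", "name": "orders", "path": "orders.csv"},
    {"type": "csv", "name": "customers", "path": "customers.csv"}
]
join_config = [
    {"left_on": "customer_id", "right_on": "customer_id", "how": "left"}
]

merged_df = integrate_data_sources(sources, join_config)
print(merged_df.shape)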
6. 性能优化与最佳实践
6.1 性能优化策略
在实际部署数据分析Agent平台时,需要考虑以下几个关键性能优化点:
6.1.1 LLM调用优化实现
import json
import hashlib

class OptimizedLLMService:
def __init__(self, model_name, cache_enabled=True):
self.model = model_name
self.cache = {}
self.cache_enabled = cache_enabled
self.batch_queue = []
self.batch_size = 5
def generate_response(self, conversation, functions=None):
"""优化的LLM响应生成"""
# 缓存检查
if self.cache_enabled:
cache_key = self._create_cache_key(conversation, functions)
if cache_key in self.cache:
return self.cache[cache_key]
# 实际调用LLM API
response = self._call_llm_api(conversation, functions)
# 更新缓存
if self.cache_enabled:
self.cache[cache_key] = response
return response
def _create_cache_key(self, conversation, functions):
"""创建缓存键"""
conv_str = json.dumps(conversation)
func_str = json.dumps(functions) if functions else ""
return hashlib.md5((conv_str + func_str).encode()).hexdigest()
def _call_llm_api(self, conversation, functions):
"""实际调用LLM API的方法"""
# 实现与特定LLM提供商的API集成
pass
def add_to_batch(self, conversation, functions=None):
"""添加请求到批处理队列"""
self.batch_queue.append((conversation, functions))
if len(self.batch_queue) >= self.batch_size:
return self.process_batch()
return None
def process_batch(self):
"""批量处理LLM请求"""
if not self.batch_queue:
return []
results = []
# 实现批量API调用逻辑
# ...
self.batch_queue = []
return results
6.1.2 数据处理优化实现
def optimized_data_processing(file_path, operations, chunk_size=100000):
"""
优化的大数据集处理函数
参数:
file_path: 数据文件路径
operations: 处理操作列表
chunk_size: 每次处理的行数
返回:
处理结果摘要
"""
    import os
    import pandas as pd
    import dask.dataframe as dd
# 对于大文件使用Dask
if os.path.getsize(file_path) > 500 * 1024 * 1024: # 500MB
# 使用Dask进行分布式处理
ddf = dd.read_csv(file_path)
results = {
"original_shape": (ddf.shape[0].compute(), ddf.shape[1]),
"operations_performed": [],
"errors": []
}
for op in operations:
try:
if op["type"] == "drop_na":
ddf = ddf.dropna(subset=op.get("columns", None))
results["operations_performed"].append(f"删除了含缺失值的数据")
# 实现其他操作...
except Exception as e:
results["errors"].append(f"操作 {op['type']} 失败: {str(e)}")
# 计算最终结果
results["final_shape"] = (ddf.shape[0].compute(), ddf.shape[1])
results["columns"] = ddf.columns.tolist()
return results
else:
# 对于小文件使用分块处理
results = {
"original_shape": None,
"operations_performed": [],
"errors": []
}
# 使用chunked processing
chunks_processed = 0
for chunk in pd.read_csv(file_path, chunksize=chunk_size):
if results["original_shape"] is None:
# 首次获取文件总行数
with open(file_path) as f:
results["original_shape"] = (sum(1 for _ in f) - 1, len(chunk.columns))
# 处理每个数据块
for op in operations:
try:
if op["type"] == "drop_na":
chunk = chunk.dropna(subset=op.get("columns", None))
if chunks_processed == 0:
results["operations_performed"].append(f"删除含缺失值的数据")
# 实现其他操作...
except Exception as e:
if f"操作 {op['type']} 失败" not in str([err for err in results["errors"]]):
results["errors"].append(f"操作 {op['type']} 失败: {str(e)}")
chunks_processed += 1
return results
6.2 系统扩展性设计
6.2.1 插件化架构
构建可插拔的函数模块系统,便于未来扩展:
class AnalysisFunctionPlugin:
"""函数插件基类"""
def __init__(self, name, description):
self.name = name
self.description = description
def get_function_definition(self):
"""返回函数定义"""
raise NotImplementedError
def execute(self, *args, **kwargs):
"""执行函数逻辑"""
raise NotImplementedError
class PluginManager:
"""插件管理器"""
def __init__(self):
self.plugins = {}
def register_plugin(self, plugin):
"""注册新插件"""
self.plugins[plugin.name] = plugin
def get_function_definitions(self):
"""获取所有插件的函数定义"""
return [p.get_function_definition() for p in self.plugins.values()]
def execute_function(self, name, *args, **kwargs):
"""执行指定插件的函数"""
if name not in self.plugins:
raise ValueError(f"未找到插件: {name}")
return self.plugins[name].execute(*args, **kwargs)
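下面给出一个示意性的插件实现,演示如何基于上述基类扩展新的分析能力(插件名称与统计口径均为举例):
class MissingValueReportPlugin(AnalysisFunctionPlugin):
    """示例插件:统计数据集中各列的缺失值数量与占比"""
    def __init__(self):
        super().__init__(
            name="missing_value_report",
            description="统计CSV文件中各列缺失值数量与占比"
        )

    def get_function_definition(self):
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": {
                    "file_path": {"type": "string", "description": "CSV文件路径"}
                },
                "required": ["file_path"]
            }
        }

    def execute(self, file_path):
        import pandas as pd
        df = pd.read_csv(file_path)
        return {
            "null_counts": df.isna().sum().to_dict(),
            "null_ratio": df.isna().mean().round(4).to_dict()
        }

# 注册并通过插件管理器调用
manager = PluginManager()
manager.register_plugin(MissingValueReportPlugin())
# report = manager.execute_function("missing_value_report", file_path="sales_data.csv")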
6.2.2 分布式计算支持
对于大规模数据分析任务,增加分布式计算支持:
def setup_distributed_computing():
"""配置分布式计算环境"""
from dask.distributed import Client, LocalCluster
# 创建本地集群(实际部署时可连接到现有集群)
cluster = LocalCluster(
n_workers=4,
threads_per_worker=2,
memory_limit='2GB'
)
client = Client(cluster)
return client
def distributed_data_analysis(client, file_path, analysis_config):
"""
使用分布式计算执行数据分析
参数:
client: Dask客户端
file_path: 数据文件路径
analysis_config: 分析配置
返回:
分析结果
"""
import dask.dataframe as dd
# 读取数据
ddf = dd.read_csv(file_path)
# 定义分析任务
analysis_tasks = []
for analysis in analysis_config:
if analysis["type"] == "summary_stats":
# 分布式计算描述统计
task = client.submit(
lambda df, cols: df[cols].describe().compute(),
ddf, analysis["columns"]
)
analysis_tasks.append(("summary_stats", task))
elif analysis["type"] == "correlation":
# 分布式计算相关性矩阵
task = client.submit(
lambda df, cols: df[cols].corr().compute(),
ddf, analysis["columns"]
)
analysis_tasks.append(("correlation", task))
elif analysis["type"] == "group_analysis":
# 分布式分组分析
task = client.submit(
lambda df, group_col, agg_dict: df.groupby(group_col).agg(agg_dict).compute(),
ddf, analysis["group_by"], analysis["aggregations"]
)
analysis_tasks.append(("group_analysis", task))
# 收集结果
results = {}
for name, task in analysis_tasks:
results[name] = task.result()
return results
6.3 安全性与数据隐私保护
6.3.1 数据安全策略
class SecureDataProcessor:
"""安全数据处理器"""
def __init__(self, security_level="high"):
self.security_level = security_level
self.sensitive_columns = set()
def register_sensitive_columns(self, columns):
"""注册敏感数据列"""
self.sensitive_columns.update(columns)
def process_data_securely(self, df, operations):
"""安全地处理数据"""
# 检查操作是否涉及敏感列
for op in operations:
if op.get("columns") and any(col in self.sensitive_columns for col in op["columns"]):
self._apply_security_policy(df, op)
# 应用操作
processed_df = self._apply_operations(df, operations)
# 最终安全检查
if self.security_level == "high":
processed_df = self._apply_final_security_check(processed_df)
return processed_df
def _apply_security_policy(self, df, operation):
"""应用安全策略"""
# 实现安全检查和策略应用
pass
def _apply_operations(self, df, operations):
"""应用数据处理操作"""
# 实现安全的数据处理操作
return df
def _apply_final_security_check(self, df):
"""最终安全检查"""
# 对敏感列进行匿名化、掩码或加密
for col in self.sensitive_columns:
if col in df.columns:
df[col] = self._anonymize_data(df[col])
return df
def _anonymize_data(self, series):
"""匿名化数据"""
# 根据数据类型实现不同的匿名化方法
if series.dtype == 'object':
return series.apply(lambda x: self._mask_text(x))
elif series.dtype in ['int64', 'float64']:
return series.apply(lambda x: self._perturb_number(x))
return series
def _mask_text(self, text):
"""掩码文本数据"""
if isinstance(text, str) and len(text) > 4:
return text[:2] + '*' * (len(text) - 4) + text[-2:]
return text
def _perturb_number(self, num):
"""扰动数值数据"""
if isinstance(num, (int, float)):
# 添加小幅度随机扰动
import random
return num * (1 + random.uniform(-0.05, 0.05))
return num
7. 全栈实现:构建一个完整的数据分析Agent平台
7.1 系统架构设计
7.1.1 整体架构图
7.1.2 关键组件详述
7.1.2.1 前端层
前端界面实现需要提供直观的用户交互体验:
// React组件示例 - 数据分析对话界面
function AnalysisChat() {
const [messages, setMessages] = useState([]);
const [input, setInput] = useState('');
const [loading, setLoading] = useState(false);
const [activeVisualizations, setActiveVisualizations] = useState([]);
const sendMessage = async () => {
if (!input.trim()) return;
// 添加用户消息
const userMessage = {
role: 'user',
content: input
};
setMessages(prev => [...prev, userMessage]);
setInput('');
setLoading(true);
try {
// 调用API
const response = await fetch('/api/analyze', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
message: input,
history: messages
})
});
const data = await response.json();
// 处理可视化内容
if (data.visualizations) {
setActiveVisualizations(data.visualizations);
}
// 添加助手回复
setMessages(prev => [...prev, {
role: 'assistant',
content: data.message
}]);
} catch (error) {
console.error('分析错误:', error);
setMessages(prev => [...prev, {
role: 'assistant',
content: '处理您的请求时出错,请重试。'
}]);
} finally {
setLoading(false);
}
};
return (
<div className="analysis-chat">
<div className="message-container">
{messages.map((msg, index) => (
<div key={index} className={`message ${msg.role}`}>
{msg.content}
</div>
))}
{loading && <div className="loading">分析中...</div>}
</div>
{activeVisualizations.length > 0 && (
<div className="visualization-panel">
{activeVisualizations.map((viz, index) => (
<div key={index} className="viz-container">
<h3>{viz.title}</h3>
<img src={viz.image_data} alt={viz.title} />
</div>
))}
</div>
)}
<div className="input-area">
<textarea
value={input}
onChange={(e) => setInput(e.target.value)}
placeholder="请输入数据分析问题..."
/>
<button onClick={sendMessage} disabled={loading}>
发送
</button>
</div>
</div>
);
}
7.1.2.2 API层
实现RESTful API服务,处理前端请求:
# FastAPI实现示例
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from typing import List, Optional
app = FastAPI(title="数据分析Agent API")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# 数据模型
class Message(BaseModel):
role: str
content: str
class AnalysisRequest(BaseModel):
message: str
history: Optional[List[Message]] = []
data_source: Optional[str] = None
class Visualization(BaseModel):
title: str
image_data: str
chart_type: str
class AnalysisResponse(BaseModel):
message: str
visualizations: Optional[List[Visualization]] = None
# 依赖项
def get_current_user(token: str = Depends(oauth2_scheme)):
# 实现用户验证逻辑
return {"username": "demo_user"}
# 路由
@app.post("/api/analyze", response_model=AnalysisResponse)
async def analyze_data(
request: AnalysisRequest,
current_user = Depends(get_current_user)
):
try:
# 初始化分析代理
agent = AnalysisAgent(llm_service, registry)
# 转换历史记录格式
for msg in request.history:
agent.add_message(msg.role, msg.content)
# 处理请求
response = agent.process_query(request.message)
# 提取可能的可视化结果
visualizations = []
        # 简单的可视化提取逻辑(实际实现会更复杂)
        if isinstance(response, dict) and "image_data" in response:
viz_data = json.loads(response["image_data"])
for viz in viz_data:
visualizations.append(
Visualization(
title=viz.get("title", "数据可视化"),
image_data=viz["image"],
chart_type=viz.get("type", "chart")
)
)
return AnalysisResponse(
message=response["content"] if isinstance(response, dict) else response,
visualizations=visualizations
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
7.2 完整实现流程
7.2.1 后端服务实现
# 主应用入口
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
# 导入组件
from app.api.routes import router as api_router
from app.core.config import settings
from app.core.security import create_access_token
from app.services.llm_service import initialize_llm
from app.services.function_registry import setup_registry
# 初始化应用
app = FastAPI(title="智能数据分析平台")
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 挂载API路由
app.include_router(api_router, prefix="/api")
# 挂载静态文件
app.mount("/static", StaticFiles(directory="static"), name="static")
# 应用启动事件
@app.on_event("startup")
async def startup_event():
# 初始化LLM服务
app.state.llm_service = initialize_llm(
model_name=settings.LLM_MODEL,
api_key=settings.LLM_API_KEY,
cache_enabled=settings.ENABLE_CACHE
)
# 设置函数注册中心
app.state.function_registry = setup_registry()
print("数据分析Agent平台已启动,所有服务就绪")
# 应用关闭事件
@app.on_event("shutdown")
async def shutdown_event():
# 清理资源
print("数据分析Agent平台正在关闭")
# 直接运行
if __name__ == "__main__":
uvicorn.run(
"main:app",
host=settings.HOST,
port=settings.PORT,
reload=settings.DEBUG
)
7.2.2 部署与容器化
使用Docker容器化部署应用:
# Dockerfile
FROM python:3.9-slim
WORKDIR /app
# 设置环境变量
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
# 安装依赖
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 复制应用代码
COPY . .
# 暴露端口
EXPOSE 8000
# 运行应用
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
docker-compose配置:
# docker-compose.yml
version: '3.8'
services:
web:
build: .
volumes:
- ./:/app
ports:
- "8000:8000"
environment:
- LLM_API_KEY=${LLM_API_KEY}
- LLM_MODEL=${LLM_MODEL}
- DEBUG=False
depends_on:
- redis
redis:
image: redis:alpine
ports:
- "6379:6379"
volumes:
- redis_data:/data
volumes:
redis_data:
8. 实际案例:电子商务销售数据分析
8.1 案例需求与数据集
8.1.1 业务背景
某电子商务公司需要分析其销售数据,了解产品表现、客户行为和销售趋势,以指导业务决策。
8.1.2 数据集描述
数据集:e_commerce_sales.csv
字段:
- order_id: 订单ID
- customer_id: 客户ID
- order_date: 订单日期
- product_id: 产品ID
- product_category: 产品类别
- price: 价格
- quantity: 数量
- discount: 折扣
- region: 地区
- payment_method: 支付方式
- customer_age: 客户年龄
- customer_gender: 客户性别
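为便于在本地复现实验,可以先生成一份字段与上述描述一致的模拟数据(取值完全随机,仅作测试示意):
import numpy as np
import pandas as pd

def generate_sample_sales_data(n_rows=5000, output_path="e_commerce_sales.csv", seed=42):
    """生成字段与案例数据集一致的模拟销售数据"""
    rng = np.random.default_rng(seed)
    df = pd.DataFrame({
        "order_id": np.arange(1, n_rows + 1),
        "customer_id": rng.integers(1, 1200, n_rows),
        "order_date": pd.to_datetime("2024-01-01") + pd.to_timedelta(rng.integers(0, 180, n_rows), unit="D"),
        "product_id": rng.integers(1, 300, n_rows),
        "product_category": rng.choice(["电子产品", "服装", "家居用品", "美妆", "食品"], n_rows),
        "price": rng.uniform(10, 500, n_rows).round(2),
        "quantity": rng.integers(1, 5, n_rows),
        "discount": rng.choice([0, 0.05, 0.1, 0.2, 0.3], n_rows),
        "region": rng.choice(["华东", "华南", "华北", "西南", "东北"], n_rows),
        "payment_method": rng.choice(["支付宝", "微信支付", "信用卡", "货到付款"], n_rows),
        "customer_age": rng.integers(18, 65, n_rows),
        "customer_gender": rng.choice(["男", "女"], n_rows),
    })
    df.to_csv(output_path, index=False)
    return output_path

# generate_sample_sales_data()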
8.2 Agent实现过程
8.2.1 函数准备
为此案例准备特定分析函数:
def sales_trend_analysis(file_path, time_unit="month", category=None):
"""
分析销售趋势
参数:
file_path: 数据文件路径
time_unit: 时间单位(day/week/month/quarter)
category: 产品类别过滤,可选
返回:
销售趋势数据
"""
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import base64
from io import BytesIO
# 加载数据
df = pd.read_csv(file_path)
# 转换日期
df['order_date'] = pd.to_datetime(df['order_date'])
# 计算销售额
df['sales_amount'] = df['price'] * df['quantity'] * (1 - df['discount'])
# 过滤类别
if category:
df = df[df['product_category'] == category]
# 按时间单位分组
if time_unit == "day":
df['time_period'] = df['order_date'].dt.date
elif time_unit == "week":
df['time_period'] = df['order_date'].dt.isocalendar().week
elif time_unit == "month":
df['time_period'] = df['order_date'].dt.to_period('M')
elif time_unit == "quarter":
df['time_period'] = df['order_date'].dt.to_period('Q')
# 汇总数据
sales_trend = df.groupby('time_period').agg({
'sales_amount': 'sum',
'order_id': 'nunique',
'quantity': 'sum'
}).reset_index()
# 创建可视化
    # 将时间周期转为字符串,便于绘图标注与后续序列化
    sales_trend['time_period'] = sales_trend['time_period'].astype(str)
    fig, ax1 = plt.subplots(figsize=(12, 6))
    ax2 = ax1.twinx()
    sns.lineplot(x='time_period', y='sales_amount', data=sales_trend, color='blue', marker='o', ax=ax1)
    sns.lineplot(x='time_period', y='order_id', data=sales_trend, color='red', marker='x', ax=ax2)
ax1.set_xlabel(f'Time ({time_unit})')
ax1.set_ylabel('Sales Amount', color='blue')
ax2.set_ylabel('Order Count', color='red')
title = f'Sales Trend by {time_unit.capitalize()}'
if category:
title += f' - Category: {category}'
plt.title(title)
plt.tight_layout()
# 转换为base64
buffer = BytesIO()
plt.savefig(buffer, format='png')
buffer.seek(0)
image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
plt.close()
return {
"trend_data": sales_trend.to_dict('records'),
"visualization": f"data:image/png;base64,{image_base64}",
"summary": {
"total_sales": df['sales_amount'].sum(),
"total_orders": df['order_id'].nunique(),
"total_units": df['quantity'].sum(),
"avg_order_value": df['sales_amount'].sum() / df['order_id'].nunique()
}
}
8.2.2 Agent交互示例
以下是用户与Agent的交互示例:
用户: 分析我们的电子商务销售数据,重点关注最近6个月的销售趋势和各产品类别的表现。
Agent: 我将分析电子商务销售数据,特别关注最近6个月的趋势和产品类别表现。首先,让我加载数据以了解其结构。
[函数调用: load_csv_data(file_path="e_commerce_sales.csv")]
非常好,我已经加载了销售数据。根据数据概览,我可以看到这是一个包含订单信息、产品详情、客户信息和销售数据的电子商务数据集。
现在,我将分析最近6个月的销售趋势:
[函数调用: sales_trend_analysis(file_path="e_commerce_sales.csv", time_unit="month")]
基于分析结果,我可以提供以下洞察:
1. 销售趋势:
- 过去6个月总销售额为$1,247,635
- 销售额在3月达到峰值($267,890),比1月增长了32%
- 4-5月出现了轻微下滑,但6月又恢复增长
2. 产品类别表现:
[函数调用: category_performance_analysis(file_path="e_commerce_sales.csv")]
主要发现:
- 电子产品类别贡献了最高销售额(42%)
- 服装类别的订单量最大,但平均订单价值较低
- 家居用品类别的增长率最高,比去年同期增长了24%
3. 关键绩效指标:
- 平均订单价值: $78.45
- 客单价同比增长: 8.3%
- 复购率: 23%
- 转化率: 3.2%
建议行动:
1. 增强电子产品类别的促销活动,特别是在销售低谷期
2. 针对家居用品开展交叉销售策略,利用其增长潜力
3. 优化服装类别的产品组合,提高平均订单价值
需要我深入分析任何特定类别或指标吗?
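上面的对话中还出现了 category_performance_analysis 调用,文中未给出其实现;下面是一个按同样思路编写的参考实现(函数名与返回结构均为假设,读者可按需调整):
def category_performance_analysis(file_path):
    """按产品类别汇总销售额、订单量与平均订单价值(参考实现)"""
    import pandas as pd
    df = pd.read_csv(file_path)
    df['sales_amount'] = df['price'] * df['quantity'] * (1 - df['discount'])
    summary = df.groupby('product_category').agg(
        total_sales=('sales_amount', 'sum'),
        order_count=('order_id', 'nunique'),
        total_quantity=('quantity', 'sum')
    ).reset_index()
    summary['avg_order_value'] = summary['total_sales'] / summary['order_count']
    summary['sales_share'] = summary['total_sales'] / summary['total_sales'].sum()
    return {
        "category_summary": summary.sort_values('total_sales', ascending=False).to_dict('records')
    }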
8.3 分析结果与业务洞察
根据上述交互,系统提供了以下关键洞察:
8.3.1 高级可视化面板
创建一个综合的分析面板,为业务决策者提供直观的数据洞察:
def create_ecommerce_dashboard(file_path):
"""
创建完整的电子商务分析面板
参数:
file_path: 数据文件路径
返回:
包含多个可视化图表的分析面板
"""
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import base64
from io import BytesIO
from matplotlib.gridspec import GridSpec
# 加载数据
df = pd.read_csv(file_path)
df['order_date'] = pd.to_datetime(df['order_date'])
df['sales_amount'] = df['price'] * df['quantity'] * (1 - df['discount'])
# 创建大型画布
plt.figure(figsize=(20, 16))
gs = GridSpec(3, 3, figure=plt.gcf())
# 1. 销售趋势图 - 左上
ax1 = plt.subplot(gs[0, 0:2])
monthly_sales = df.groupby(df['order_date'].dt.to_period('M')).agg({
'sales_amount': 'sum',
'order_id': 'nunique'
}).reset_index()
monthly_sales['order_date'] = monthly_sales['order_date'].astype(str)
ax1_twin = ax1.twinx()
ax1.bar(monthly_sales['order_date'], monthly_sales['sales_amount'], color='cornflowerblue', alpha=0.7)
ax1_twin.plot(monthly_sales['order_date'], monthly_sales['order_id'], color='tomato', marker='o', linewidth=2)
ax1.set_title('Monthly Sales Trend', fontsize=14, fontweight='bold')
ax1.set_xlabel('Month')
ax1.set_ylabel('Sales Amount ($)', color='cornflowerblue')
ax1_twin.set_ylabel('Order Count', color='tomato')
ax1.tick_params(axis='x', rotation=45)
# 2. 类别销售占比 - 右上
ax2 = plt.subplot(gs[0, 2])
category_sales = df.groupby('product_category').agg({
'sales_amount': 'sum'
}).reset_index()
category_sales = category_sales.sort_values('sales_amount', ascending=False)
explode = [0.1 if i == 0 else 0 for i in range(len(category_sales))]
ax2.pie(
category_sales['sales_amount'],
labels=category_sales['product_category'],
autopct='%1.1f%%',
explode=explode,
shadow=True,
startangle=90
)
ax2.set_title('Sales by Category', fontsize=14, fontweight='bold')
# 3. 地区销售热图 - 左中
ax3 = plt.subplot(gs[1, 0:2])
region_category = pd.pivot_table(
df,
values='sales_amount',
index='region',
columns='product_category',
aggfunc='sum'
).fillna(0)
sns.heatmap(
region_category,
annot=True,
fmt='.0f',
cmap='YlGnBu',
linewidths=.5,
ax=ax3
)
ax3.set_title('Sales by Region and Category', fontsize=14, fontweight='bold')
# 4. 支付方式分布 - 右中
ax4 = plt.subplot(gs[1, 2])
payment_counts = df['payment_method'].value_counts()
ax4.bar(
payment_counts.index,
payment_counts.values,
color=sns.color_palette("Set2")
)
ax4.set_title('Payment Methods', fontsize=14, fontweight='bold')
ax4.set_ylabel('Number of Orders')
ax4.tick_params(axis='x', rotation=45)
# 5. 客户年龄分布 - 左下
ax5 = plt.subplot(gs[2, 0])
sns.histplot(
df['customer_age'],
bins=10,
kde=True,
color='skyblue',
ax=ax5
)
ax5.set_title('Customer Age Distribution', fontsize=14, fontweight='bold')
ax5.set_xlabel('Age')
ax5.set_ylabel('Count')
# 6. 性别与类别关系 - 中下
ax6 = plt.subplot(gs[2, 1])
gender_category = df.groupby(['customer_gender', 'product_category']).agg({
'sales_amount': 'sum'
}).reset_index()
sns.barplot(
x='product_category',
y='sales_amount',
hue='customer_gender',
data=gender_category,
ax=ax6
)
ax6.set_title('Sales by Gender and Category', fontsize=14, fontweight='bold')
ax6.set_xlabel('Product Category')
ax6.set_ylabel('Sales Amount ($)')
ax6.tick_params(axis='x', rotation=45)
# 7. 折扣效果 - 右下
ax7 = plt.subplot(gs[2, 2])
# 创建折扣分组
df['discount_group'] = pd.cut(
df['discount'],
bins=[0, 0.1, 0.2, 0.3, 0.4, 1.0],
labels=['0-10%', '10-20%', '20-30%', '30-40%', '40-100%']
)
discount_effect = df.groupby('discount_group').agg({
'quantity': 'sum',
'order_id': 'nunique'
}).reset_index()
ax7_twin = ax7.twinx()
ax7.bar(
discount_effect['discount_group'],
discount_effect['quantity'],
color='lightgreen',
alpha=0.7
)
ax7_twin.plot(
discount_effect['discount_group'],
discount_effect['order_id'],
color='darkgreen',
marker='D',
linewidth=2
)
ax7.set_title('Discount Effect', fontsize=14, fontweight='bold')
ax7.set_xlabel('Discount Range')
ax7.set_ylabel('Total Quantity', color='lightgreen')
ax7_twin.set_ylabel('Order Count', color='darkgreen')
# 添加仪表板标题
plt.suptitle('E-Commerce Sales Analytics Dashboard', fontsize=20, fontweight='bold', y=0.98)
plt.tight_layout(rect=[0, 0, 1, 0.97])
# 转换为base64
buffer = BytesIO()
plt.savefig(buffer, format='png', dpi=150)
buffer.seek(0)
image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
plt.close()
# 提取关键指标
key_metrics = {
"total_sales": df['sales_amount'].sum(),
"total_orders": df['order_id'].nunique(),
"total_customers": df['customer_id'].nunique(),
"avg_order_value": df['sales_amount'].sum() / df['order_id'].nunique(),
"top_category": category_sales['product_category'].iloc[0],
"top_region": df.groupby('region')['sales_amount'].sum().idxmax()
}
return {
"dashboard_image": f"data:image/png;base64,{image_base64}",
"key_metrics": key_metrics
}
8.3.2 执行预测分析
基于历史数据预测未来销售趋势:
def predict_future_sales(file_path, prediction_periods=3, confidence_level=0.95):
"""
预测未来销售趋势
参数:
file_path: 数据文件路径
prediction_periods: 预测未来的期数(月)
confidence_level: 预测区间的置信度
返回:
预测结果和可视化
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
import seaborn as sns
import base64
from io import BytesIO
# 加载数据
df = pd.read_csv(file_path)
df['order_date'] = pd.to_datetime(df['order_date'])
df['sales_amount'] = df['price'] * df['quantity'] * (1 - df['discount'])
# 按月汇总
monthly_sales = df.groupby(df['order_date'].dt.to_period('M')).agg({
'sales_amount': 'sum'
})
monthly_sales.index = pd.PeriodIndex(monthly_sales.index)
# 转换为时间序列
y = monthly_sales['sales_amount']
# 拟合ARIMA模型 - 自动选择最佳参数
import pmdarima as pm
# 寻找最优参数
arima_model = pm.auto_arima(
y,
start_p=0, start_q=0,
max_p=3, max_q=3,
d=None, max_d=2,
seasonal=True, m=12,
information_criterion='aic',
trace=False,
error_action='ignore',
suppress_warnings=True,
stepwise=True
)
# 获取最优参数
order = arima_model.order
seasonal_order = arima_model.seasonal_order
# 使用最优参数创建模型
model = ARIMA(
y,
order=order,
seasonal_order=seasonal_order
)
model_fit = model.fit()
# 进行预测
forecast = model_fit.get_forecast(steps=prediction_periods)
forecast_mean = forecast.predicted_mean
# 获取预测区间
conf_int = forecast.conf_int(alpha=1-confidence_level)
# 可视化
plt.figure(figsize=(12, 6))
# 绘制历史数据
plt.plot(y.index.astype(str), y.values, marker='o', label='Historical Sales')
# 绘制预测数据
forecast_index = pd.period_range(
start=y.index[-1] + 1,
periods=prediction_periods,
freq='M'
)
plt.plot(
forecast_index.astype(str),
forecast_mean.values,
marker='D',
color='red',
label='Forecasted Sales'
)
# 绘制预测区间
plt.fill_between(
forecast_index.astype(str),
conf_int.iloc[:, 0].values,
conf_int.iloc[:, 1].values,
color='red',
alpha=0.2,
label=f'{confidence_level*100}% Confidence Interval'
)
plt.title('Sales Forecast', fontsize=15, fontweight='bold')
plt.xlabel('Month')
plt.ylabel('Sales Amount ($)')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.xticks(rotation=45)
plt.tight_layout()
# 转换为base64
buffer = BytesIO()
plt.savefig(buffer, format='png')
buffer.seek(0)
image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
plt.close()
# 准备预测摘要
last_value = y.iloc[-1]
forecast_dict = {}
for i, period in enumerate(forecast_index):
period_str = period.strftime('%Y-%m')
value = forecast_mean.iloc[i]
percentage_change = ((value - last_value) / last_value) * 100
forecast_dict[period_str] = {
"forecasted_sales": value,
"lower_bound": conf_int.iloc[i, 0],
"upper_bound": conf_int.iloc[i, 1],
"percentage_change": percentage_change
}
last_value = value
return {
"forecast_data": forecast_dict,
"visualization": f"data:image/png;base64,{image_base64}",
"model_info": {
"model_type": "ARIMA",
"parameters": {
"order": order,
"seasonal_order": seasonal_order,
"aic": model_fit.aic
},
"accuracy_metrics": {
"mse": model_fit.mse,
"aic": model_fit.aic,
"bic": model_fit.bic
}
}
}
8.4 实际业务决策支持
基于分析结果,Agent提供了以下业务决策建议:
8.4.1 库存优化建议
根据预测结果,系统可以提供库存管理优化建议:
def generate_inventory_recommendations(file_path, forecast_data):
"""
基于销售预测生成库存管理建议
参数:
file_path: 历史数据文件路径
forecast_data: 销售预测数据
返回:
库存优化建议
"""
import pandas as pd
import numpy as np
# 加载历史数据
df = pd.read_csv(file_path)
# 计算产品库存关键指标
product_metrics = df.groupby('product_id').agg({
'quantity': 'sum',
'price': 'mean',
'discount': 'mean',
'order_id': 'nunique'
}).reset_index()
# 添加销售频率
product_metrics['sales_frequency'] = product_metrics['quantity'] / product_metrics['order_id']
# 计算产品类别增长率
category_growth = {}
for category in df['product_category'].unique():
category_df = df[df['product_category'] == category]
# 简化: 用最近两个月的数据计算增长率
category_growth[category] = 0.10 # 示例值,实际应从forecast_data获取
    # 生成建议
    # 先建立产品ID到类别的映射,便于按类别筛选产品指标
    product_category_map = df.drop_duplicates('product_id').set_index('product_id')['product_category']
    product_metrics['product_category'] = product_metrics['product_id'].map(product_category_map)
    recommendations = []
    for category, growth_rate in category_growth.items():
        category_products = product_metrics[product_metrics['product_category'] == category]
if growth_rate > 0.15:
# 高增长类别
recommendations.append({
"category": category,
"growth_rate": growth_rate,
"action": "增加库存",
"details": f"预计{category}类别销售增长{growth_rate:.1%},建议增加库存20%",
"priority": "高",
"top_products": category_products.sort_values('quantity', ascending=False).head(3)['product_id'].tolist()
})
elif growth_rate > 0.05:
# 中等增长类别
recommendations.append({
"category": category,
"growth_rate": growth_rate,
"action": "适度增加库存",
"details": f"预计{category}类别销售增长{growth_rate:.1%},建议增加库存10%",
"priority": "中",
"top_products": category_products.sort_values('quantity', ascending=False).head(3)['product_id'].tolist()
})
elif growth_rate > -0.05:
# 稳定类别
recommendations.append({
"category": category,
"growth_rate": growth_rate,
"action": "维持当前库存",
"details": f"预计{category}类别销售稳定,增长率{growth_rate:.1%},维持当前库存水平",
"priority": "中",
"top_products": category_products.sort_values('quantity', ascending=False).head(3)['product_id'].tolist()
})
else:
# 下降类别
recommendations.append({
"category": category,
"growth_rate": growth_rate,
"action": "减少库存",
"details": f"预计{category}类别销售下降{abs(growth_rate):.1%},建议减少库存15%,考虑促销清仓",
"priority": "高",
"top_products": category_products.sort_values('quantity', ascending=False).head(3)['product_id'].tolist()
})
return {
"inventory_recommendations": recommendations,
"summary": f"共生成{len(recommendations)}个库存优化建议,包括{sum(1 for r in recommendations if r['priority'] == '高')}个高优先级建议"
}
8.4.2 营销活动优化建议
def generate_marketing_recommendations(file_path):
"""
根据销售数据生成营销活动优化建议
参数:
file_path: 数据文件路径
返回:
营销优化建议
"""
import pandas as pd
import numpy as np
# 加载数据
df = pd.read_csv(file_path)
df['order_date'] = pd.to_datetime(df['order_date'])
df['sales_amount'] = df['price'] * df['quantity'] * (1 - df['discount'])
# 分析客户细分
customer_segments = df.groupby('customer_id').agg({
'order_id': 'nunique',
'sales_amount': 'sum',
'product_category': lambda x: x.mode().iloc[0] if not x.mode().empty else None,
'customer_age': 'first',
'customer_gender': 'first',
'region': 'first'
}).reset_index()
# 定义客户价值
customer_segments['customer_value'] = pd.qcut(
customer_segments['sales_amount'],
q=3,
labels=['Low', 'Medium', 'High']
)
# 定义购买频率
customer_segments['purchase_frequency'] = pd.qcut(
customer_segments['order_id'],
q=3,
labels=['Low', 'Medium', 'High']
)
# 分析折扣效果
discount_effect = df.groupby('product_category').apply(
lambda x: np.corrcoef(x['discount'], x['quantity'])[0, 1]
).reset_index(name='discount_correlation')
# 生成建议
recommendations = []
# 1. 高价值客户建议
high_value = customer_segments[customer_segments['customer_value'] == 'High']
top_categories = high_value['product_category'].value_counts().nlargest(3)
recommendations.append({
"segment": "高价值客户",
"action": "忠诚度计划升级",
"details": f"针对{len(high_value)}名高价值客户提供专属VIP服务,重点推广{', '.join(top_categories.index.tolist())}类别",
"expected_impact": "提高留存率15%,增加复购频率",
"priority": "高"
})
# 2. 折扣敏感类别
discount_sensitive = discount_effect[discount_effect['discount_correlation'] > 0.5]['product_category'].tolist()
if discount_sensitive:
recommendations.append({
"segment": "折扣敏感产品",
"action": "限时促销活动",
"details": f"针对{', '.join(discount_sensitive)}类别产品进行策略性折扣,提高销量",
"expected_impact": "提升目标类别销量25%,带动相关品类增长",
"priority": "中"
})
# 3. 低频购买客户
low_frequency = customer_segments[
(customer_segments['purchase_frequency'] == 'Low') &
(customer_segments['customer_value'] != 'Low')
]
if len(low_frequency) > 0:
recommendations.append({
"segment": "低频高价值客户",
"action": "再营销活动",
"details": f"针对{len(low_frequency)}名低频但高价值客户开展个性化邮件营销,提供首单折扣",
"expected_impact": "激活30%的不活跃高价值客户",
"priority": "高"
})
# 4. 区域定制营销
region_performance = df.groupby('region').agg({
'sales_amount': 'sum',
'order_id': 'nunique'
}).reset_index()
underperforming_regions = region_performance.nsmallest(2, 'sales_amount')['region'].tolist()
recommendations.append({
"segment": "地区定制营销",
"action": "区域促销活动",
"details": f"在{', '.join(underperforming_regions)}等表现欠佳地区开展本地化营销活动",
"expected_impact": "目标地区销售提升20%",
"priority": "中"
})
# 5. 交叉销售机会
# 简化的产品关联分析
category_pairs = []
for customer in customer_segments['customer_id'].unique()[:100]: # 限制计算量
customer_categories = df[df['customer_id'] == customer]['product_category'].unique()
if len(customer_categories) > 1:
for i in range(len(customer_categories)):
for j in range(i+1, len(customer_categories)):
category_pairs.append((customer_categories[i], customer_categories[j]))
if category_pairs:
from collections import Counter
top_pairs = Counter(category_pairs).most_common(1)
recommendations.append({
"segment": "交叉销售机会",
"action": "捆绑销售活动",
"details": f"创建'{top_pairs[0][0][0]}'和'{top_pairs[0][0][1]}'类别的产品捆绑促销",
"expected_impact": "提高购物车价值15%,增加单次购买量",
"priority": "中"
})
return {
"marketing_recommendations": recommendations,
"summary": f"共生成{len(recommendations)}个营销优化建议,包括{sum(1 for r in recommendations if r['priority'] == '高')}个高优先级建议"
}
9. 未来发展与技术展望
9.1 Function Calling技术演进方向
随着Function Calling技术的快速发展,我们可以预见以下几个关键发展方向:
timeline
title Function Calling技术演进路线图
2023 : 基础Function Calling<br>单函数调用能力
2024 : 增强型Function Calling<br>函数链调用<br>多函数协同
2025 : 智能函数编排<br>自动化函数选择<br>自适应调参
2026 : 多模态Function Calling<br>跨媒体函数能力<br>结合视觉听觉输入
2027 : 自主Function创建<br>按需生成专用函数<br>自我优化系统
9.1.1 函数自动发现与合成
未来的Function Calling不仅能调用已定义的函数,还能自动发现和合成新函数:
import json

class AutoFunctionDiscovery:
"""自动函数发现与合成系统"""
def __init__(self, registry, llm_service):
self.registry = registry # 现有函数注册中心
self.llm = llm_service
self.function_templates = {}
def discover_function_needs(self, query_history):
"""分析用户查询历史,发现潜在的函数需求"""
# 提取查询模式和常见需求
prompt = f"""
分析以下用户查询历史,识别出系统缺少但用户经常需要的数据分析功能:
{json.dumps(query_history)}
请输出JSON格式的函数需求列表,包含:
1. 函数名称
2. 功能描述
3. 建议的参数
4. 实现优先级
"""
response = self.llm.generate_response(prompt)
return json.loads(response)
def synthesize_function(self, function_spec):
"""基于规范合成新函数"""
# 使用LLM生成函数代码
prompt = f"""
根据以下函数规范,生成完整的Python函数实现:
{json.dumps(function_spec)}
函数应遵循以下规范:
1. 包含详细的文档字符串
2. 实现所有指定的参数
3. 包含适当的错误处理
4. 返回结构化的结果
仅输出Python代码,不要有任何其他解释。
"""
function_code = self.llm.generate_response(prompt)
# 验证函数代码
validation_result = self.validate_function_code(function_code)
if validation_result["valid"]:
# 动态加载函数
function_object = self.load_function(function_code)
# 注册到系统
self.registry.register({
"name": function_spec["name"],
"function": function_object,
"description": function_spec["description"],
"parameters": function_spec["parameters"],
"required": function_spec.get("required", []),
"auto_generated": True
})
return {
"status": "success",
"message": f"成功合成并注册函数: {function_spec['name']}",
"function_code": function_code
}
else:
return {
"status": "error",
"message": f"函数验证失败: {validation_result['error']}",
"function_code": function_code
}
def validate_function_code(self, code):
"""验证生成的函数代码"""
try:
# 语法检查
compile(code, '<string>', 'exec')
# 更多安全和质量检查可以添加在这里
return {"valid": True}
except Exception as e:
return {"valid": False, "error": str(e)}
def load_function(self, code):
"""动态加载函数代码"""
local_vars = {}
exec(code, globals(), local_vars)
# 提取函数对象
function_name = None
for name, obj in local_vars.items():
if callable(obj):
function_name = name
break
if function_name:
return local_vars[function_name]
else:
raise ValueError("无法在生成的代码中找到函数")
9.2 智能分析平台的未来形态
未来的数据分析平台将演变为多Agent协作系统:
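例如,可以由一个调度组件把任务分派给若干专职 Agent(数据处理、可视化、报告等)。下面是一个极简的协作骨架(类名与路由规则纯属示意,实际可交由 LLM 做意图识别):
class MultiAgentCoordinator:
    """示意:按任务类型把请求路由给不同的专职Agent"""
    def __init__(self, agents):
        # agents 形如 {"analysis": AnalysisAgent实例, "visualization": ..., "report": ...}
        self.agents = agents

    def route(self, user_query):
        """非常简化的关键词路由"""
        if any(kw in user_query for kw in ["图", "可视化", "面板"]):
            return "visualization"
        if any(kw in user_query for kw in ["报告", "总结"]):
            return "report"
        return "analysis"

    def handle(self, user_query):
        agent_key = self.route(user_query)
        return {"handled_by": agent_key,
                "result": self.agents[agent_key].process_query(user_query)}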
9.2.1 自适应学习分析系统
未来的智能分析平台将具备自我学习能力,能够从用户交互中不断优化:
import json
import inspect
from datetime import datetime

class AdaptiveLearningSystem:
"""自适应学习的分析系统"""
def __init__(self, registry, llm_service, storage):
self.registry = registry
self.llm = llm_service
self.storage = storage
self.feedback_store = []
self.usage_patterns = {}
self.performance_metrics = {}
def record_interaction(self, user_query, function_calls, response, execution_time):
"""记录用户交互"""
interaction = {
"timestamp": datetime.now().isoformat(),
"user_query": user_query,
"function_calls": function_calls,
"response": response,
"execution_time": execution_time
}
self.storage.save_interaction(interaction)
# 更新使用模式
for func_call in function_calls:
func_name = func_call["name"]
if func_name not in self.usage_patterns:
self.usage_patterns[func_name] = {"count": 0, "execution_times": []}
self.usage_patterns[func_name]["count"] += 1
self.usage_patterns[func_name]["execution_times"].append(func_call.get("execution_time", 0))
def collect_feedback(self, interaction_id, rating, comments=None):
"""收集用户反馈"""
feedback = {
"interaction_id": interaction_id,
"rating": rating,
"comments": comments,
"timestamp": datetime.now().isoformat()
}
self.feedback_store.append(feedback)
self.storage.save_feedback(feedback)
# 如果评分较低,触发学习过程
if rating < 3:
self.trigger_learning(interaction_id, feedback)
def trigger_learning(self, interaction_id, feedback):
"""触发学习过程"""
# 获取相关交互
interaction = self.storage.get_interaction(interaction_id)
# 分析改进点
improvement_analysis = self.analyze_improvement_needs(interaction, feedback)
# 应用改进
for improvement in improvement_analysis:
if improvement["type"] == "function_optimization":
self.optimize_function(improvement["function_name"], improvement["suggestions"])
elif improvement["type"] == "prompt_enhancement":
self.enhance_prompt_templates(improvement["suggestions"])
elif improvement["type"] == "new_function_need":
self.suggest_new_function(improvement["specifications"])
def analyze_improvement_needs(self, interaction, feedback):
"""分析需要改进的地方"""
prompt = f"""
分析以下用户交互和反馈,识别系统需要改进的地方:
用户查询: {interaction['user_query']}
函数调用: {json.dumps(interaction['function_calls'])}
系统响应: {interaction['response']}
用户评分: {feedback['rating']}
用户评论: {feedback['comments']}
请识别以下类型的改进需求:
1. 函数优化需求
2. 提示词模板增强
3. 新函数需求
输出JSON格式的改进建议列表。
"""
response = self.llm.generate_response(prompt)
return json.loads(response)
def optimize_function(self, function_name, suggestions):
"""优化现有函数"""
# 获取当前函数定义
current_function = self.registry.get_function(function_name)
if not current_function:
return {"status": "error", "message": f"找不到函数: {function_name}"}
# 生成优化版本
prompt = f"""
根据以下建议优化函数:
当前函数:
{inspect.getsource(current_function["function"])}
优化建议:
{json.dumps(suggestions)}
请提供完整的优化后函数代码,保持函数签名不变。
仅输出Python代码,不要有任何其他解释。
"""
optimized_code = self.llm.generate_response(prompt)
# 验证和更新
validation = self.validate_function_code(optimized_code)
if validation["valid"]:
# 动态加载函数
new_function = self.load_function(optimized_code)
# 更新注册
self.registry.update_function(function_name, new_function)
return {
"status": "success",
"message": f"已优化函数: {function_name}"
}
else:
return {
"status": "error",
"message": f"函数优化验证失败: {validation['error']}"
}
def validate_function_code(self, code):
"""验证函数代码"""
try:
# 语法检查
compile(code, '<string>', 'exec')
return {"valid": True}
except Exception as e:
return {"valid": False, "error": str(e)}
def load_function(self, code):
"""动态加载函数代码"""
local_vars = {}
exec(code, globals(), local_vars)
# 提取函数对象
function_name = None
for name, obj in local_vars.items():
if callable(obj):
function_name = name
break
if function_name:
return local_vars[function_name]
else:
raise ValueError("无法在生成的代码中找到函数")
def get_performance_insights(self):
"""获取系统性能洞察"""
# 分析函数使用情况
most_used = sorted(
self.usage_patterns.items(),
key=lambda x: x[1]["count"],
reverse=True
)[:5]
# 分析执行时间
slowest_functions = []
for name, stats in self.usage_patterns.items():
if stats["execution_times"]:
avg_time = sum(stats["execution_times"]) / len(stats["execution_times"])
slowest_functions.append((name, avg_time))
slowest_functions = sorted(slowest_functions, key=lambda x: x[1], reverse=True)[:5]
# 分析用户满意度
avg_rating = sum(f["rating"] for f in self.feedback_store) / len(self.feedback_store) if self.feedback_store else 0
return {
"most_used_functions": most_used,
"slowest_functions": slowest_functions,
"user_satisfaction": avg_rating,
"total_interactions": len(self.storage.get_all_interactions()),
"improvement_opportunities": self.identify_improvement_opportunities()
}
def identify_improvement_opportunities(self):
"""识别系统改进机会"""
# 分析交互历史,识别模式和改进机会
return [] # 实际实现会分析历史数据
9.2.2 端到端的异构数据处理能力
未来的分析系统将无缝处理各种格式的数据:
import os
import json

class HeterogeneousDataProcessor:
"""异构数据处理器"""
def __init__(self):
self.processors = {}
self.register_default_processors()
def register_default_processors(self):
"""注册默认数据处理器"""
# 结构化数据处理器
self.register_processor("csv", self.process_csv)
self.register_processor("json", self.process_json)
self.register_processor("excel", self.process_excel)
self.register_processor("database", self.process_database)
# 非结构化数据处理器
self.register_processor("text", self.process_text)
self.register_processor("pdf", self.process_pdf)
self.register_processor("image", self.process_image)
self.register_processor("audio", self.process_audio)
# 混合数据处理器
self.register_processor("webpage", self.process_webpage)
self.register_processor("api", self.process_api)
def register_processor(self, data_type, processor_func):
"""注册数据处理器"""
self.processors[data_type] = processor_func
def process_data(self, data_source, data_type=None):
"""处理数据源"""
# 如果未指定类型,尝试推断
if data_type is None:
data_type = self.infer_data_type(data_source)
if data_type not in self.processors:
raise ValueError(f"不支持的数据类型: {data_type}")
return self.processors[data_type](data_source)
def infer_data_type(self, data_source):
"""推断数据类型"""
if isinstance(data_source, str):
# 检查是否是文件路径
if os.path.exists(data_source):
extension = os.path.splitext(data_source)[1].lower()
if extension == '.csv':
return "csv"
elif extension in ['.xlsx', '.xls']:
return "excel"
elif extension == '.json':
return "json"
elif extension == '.pdf':
return "pdf"
elif extension in ['.jpg', '.jpeg', '.png']:
return "image"
elif extension in ['.mp3', '.wav']:
return "audio"
elif extension == '.txt':
return "text"
# 检查是否是URL
if data_source.startswith(('http://', 'https://')):
if any(ext in data_source for ext in ['.csv', '.json', '.xlsx']):
return "api"
else:
return "webpage"
# 检查是否是数据库连接字符串
if any(db_type in data_source.lower() for db_type in
['mysql', 'postgresql', 'sqlite', 'mongodb']):
return "database"
# 默认假设为JSON数据
return "json"
# 各种数据类型的处理方法
def process_csv(self, source):
"""处理CSV数据"""
import pandas as pd
return pd.read_csv(source)
def process_json(self, source):
"""处理JSON数据"""
import pandas as pd
if isinstance(source, str):
if os.path.exists(source):
with open(source, 'r') as f:
data = json.load(f)
else:
data = json.loads(source)
else:
data = source
return pd.json_normalize(data)
def process_excel(self, source):
"""处理Excel数据"""
import pandas as pd
return pd.read_excel(source)
def process_database(self, connection_string):
"""处理数据库连接"""
import sqlalchemy
import pandas as pd
engine = sqlalchemy.create_engine(connection_string)
# 返回表信息而不是直接查询数据
from sqlalchemy import inspect
inspector = inspect(engine)
tables = {}
for table_name in inspector.get_table_names():
columns = [column['name'] for column in inspector.get_columns(table_name)]
tables[table_name] = columns
return {
"type": "database_metadata",
"tables": tables,
"connection": connection_string
}
def process_text(self, source):
"""处理文本数据"""
if os.path.exists(source):
with open(source, 'r') as f:
text = f.read()
else:
text = source
# 简单文本分析
return {
"type": "text_analysis",
"character_count": len(text),
"word_count": len(text.split()),
"line_count": len(text.splitlines()),
"sample": text[:200] + "..." if len(text) > 200 else text
}
def process_pdf(self, source):
"""处理PDF文档"""
try:
import PyPDF2
text_content = []
with open(source, 'rb') as file:
reader = PyPDF2.PdfReader(file)
for page_num in range(len(reader.pages)):
text_content.append(reader.pages[page_num].extract_text())
return {
"type": "pdf_extraction",
"page_count": len(reader.pages),
"text_content": text_content
}
except ImportError:
return {"error": "需要安装PyPDF2库以处理PDF文档"}
def process_image(self, source):
"""处理图像数据"""
try:
from PIL import Image
import numpy as np
img = Image.open(source)
return {
"type": "image_metadata",
"format": img.format,
"size": img.size,
"mode": img.mode
}
except ImportError:
return {"error": "需要安装Pillow库以处理图像"}
def process_audio(self, source):
"""处理音频数据"""
# 简单返回元数据,实际应用可能需要音频分析库
return {
"type": "audio_metadata",
"file_path": source,
"file_size_kb": os.path.getsize(source) / 1024
}
def process_webpage(self, url):
"""处理网页内容"""
try:
import requests
from bs4 import BeautifulSoup
response = requests.get(url)
soup = BeautifulSoup(response.text, 'html.parser')
# 提取结构化信息
title = soup.title.string if soup.title else "无标题"
# 提取表格数据
tables = []
for table in soup.find_all('table'):
table_data = []
for row in table.find_all('tr'):
row_data = [cell.get_text(strip=True) for cell in row.find_all(['td', 'th'])]
if row_data:
table_data.append(row_data)
if table_data:
tables.append(table_data)
# 提取文本内容
paragraphs = [p.get_text(strip=True) for p in soup.find_all('p')]
return {
"type": "webpage_extraction",
"url": url,
"title": title,
"tables": tables,
"text_content": paragraphs
}
except ImportError:
return {"error": "需要安装requests和BeautifulSoup库以处理网页"}
def process_api(self, api_url):
"""处理API数据"""
try:
import requests
response = requests.get(api_url)
if response.headers.get('content-type', '').startswith('application/json'):
return {
"type": "api_json",
"data": response.json()
}
else:
return {
"type": "api_response",
"content_type": response.headers.get('content-type'),
"data": response.text[:1000] # 限制大小
}
except ImportError:
return {"error": "需要安装requests库以处理API请求"}
9.3 行业应用前景与挑战
9.3.1 垂直行业应用场景
Function Calling驱动的智能分析系统在各行业有广阔应用前景:
mindmap
root((Function Calling<br>行业应用))
金融行业
投资分析
风险评估
欺诈检测
个人理财顾问
医疗健康
临床数据分析
医疗图像解读
健康风险预测
药物研发数据分析
零售电商
需求预测
库存优化
个性化推荐
消费者行为分析
制造业
产品质量控制
供应链优化
设备维护预测
生产效率分析
能源行业
能源消耗预测
电网负载分析
设备故障预警
资源优化配置
9.3.2 技术挑战与解决方案
实现高级数据分析Agent面临的主要挑战:
- 数据安全与隐私保护
import numpy as np
import pandas as pd

class PrivacyPreservingAnalytics:
"""隐私保护数据分析框架"""
def __init__(self, privacy_level="high"):
self.privacy_level = privacy_level
self.anonymization_methods = {
"differential_privacy": self.apply_differential_privacy,
"k_anonymity": self.apply_k_anonymity,
"data_masking": self.apply_data_masking,
"synthetic_data": self.generate_synthetic_data
}
def process_sensitive_data(self, df, sensitive_columns, method="differential_privacy"):
"""处理敏感数据"""
if method not in self.anonymization_methods:
raise ValueError(f"不支持的隐私保护方法: {method}")
return self.anonymization_methods[method](df, sensitive_columns)
def apply_differential_privacy(self, df, sensitive_columns, epsilon=1.0):
"""应用差分隐私"""
import numpy as np
df_copy = df.copy()
for column in sensitive_columns:
if column in df.columns:
if np.issubdtype(df[column].dtype, np.number):
# 对数值列应用拉普拉斯噪声
sensitivity = (df[column].max() - df[column].min()) / 10 # 估计敏感度
noise_scale = sensitivity / epsilon
noise = np.random.laplace(0, noise_scale, size=len(df))
df_copy[column] = df[column] + noise
else:
# 对非数值列应用数据屏蔽
df_copy[column] = df_copy[column].apply(self.mask_text)
return df_copy
def apply_k_anonymity(self, df, sensitive_columns, k=5):
"""应用k匿名性"""
# 简化的k匿名实现
df_copy = df.copy()
# 对每个敏感列进行泛化
for column in sensitive_columns:
if column in df.columns:
if df[column].dtype == 'object':
# 字符串列 - 保留前两个字符
df_copy[column] = df_copy[column].astype(str).apply(
lambda x: x[:2] + '*' * (len(x) - 2) if len(x) > 2 else x
)
elif np.issubdtype(df[column].dtype, np.number):
# 数值列 - 分箱处理
bins = min(20, df[column].nunique())
df_copy[column] = pd.qcut(
df[column],
q=bins,
duplicates='drop'
).astype(str)
return df_copy
def apply_data_masking(self, df, sensitive_columns):
"""应用数据掩码"""
df_copy = df.copy()
for column in sensitive_columns:
if column in df.columns:
if df[column].dtype == 'object':
df_copy[column] = df_copy[column].astype(str).apply(self.mask_text)
elif np.issubdtype(df[column].dtype, np.number):
df_copy[column] = df_copy[column].apply(self.mask_number)
return df_copy
def generate_synthetic_data(self, df, sensitive_columns):
"""生成合成数据"""
# 简化的合成数据生成 - 实际应用可能使用更复杂的模型
import numpy as np
df_copy = df.copy()
for column in sensitive_columns:
if column in df.columns:
if np.issubdtype(df[column].dtype, np.number):
# 基于原始分布生成合成数值
mean, std = df[column].mean(), df[column].std()
df_copy[column] = np.random.normal(mean, std, size=len(df))
elif df[column].dtype == 'object':
# 对于分类数据,保持分布但随机打乱
df_copy[column] = np.random.permutation(df[column].values)
return df_copy
def mask_text(self, text):
"""文本掩码"""
if not isinstance(text, str):
return text
if len(text) <= 2:
return '*' * len(text)
else:
return text[0] + '*' * (len(text) - 2) + text[-1]
def mask_number(self, value):
"""数值掩码"""
if not isinstance(value, (int, float)) or np.isnan(value):
return value
# 添加小范围随机扰动
import random
perturbation = random.uniform(-0.1, 0.1)
return value * (1 + perturbation)
- 模型偏见与公平性
class FairnessAnalyzer:
"""分析决策公平性的工具"""
def __init__(self):
self.bias_metrics = {
"demographic_parity": self.demographic_parity,
"equal_opportunity": self.equal_opportunity,
"disparate_impact": self.disparate_impact
}
def analyze_fairness(self, df, prediction_column, sensitive_attributes, target_column=None):
"""分析预测结果的公平性"""
results = {
"overall_summary": {},
"metrics": {},
"recommendations": []
}
# 计算各指标
for metric_name, metric_func in self.bias_metrics.items():
if metric_name == "equal_opportunity" and target_column is None:
continue
metric_results = {}
for attribute in sensitive_attributes:
if attribute in df.columns:
args = [df, prediction_column, attribute]
if metric_name == "equal_opportunity":
args.append(target_column)
metric_value = metric_func(*args)
metric_results[attribute] = metric_value
results["metrics"][metric_name] = metric_results
# 生成总体摘要
results["overall_summary"] = self.generate_fairness_summary(results["metrics"])
# 提供改进建议
results["recommendations"] = self.generate_recommendations(results["metrics"])
return results
def demographic_parity(self, df, prediction_column, sensitive_attribute):
"""人口平等 - 不同群体应有相似的正面预测比率"""
if prediction_column not in df.columns or sensitive_attribute not in df.columns:
return None
# 按敏感属性分组计算接受率
group_acceptance = df.groupby(sensitive_attribute)[prediction_column].mean()
# 计算最大差异
max_diff = group_acceptance.max() - group_acceptance.min()
# 计算每组相对于整体的差异
overall_rate = df[prediction_column].mean()
relative_diff = {
group: (rate - overall_rate) / overall_rate
for group, rate in group_acceptance.items()
}
return {
"group_rates": group_acceptance.to_dict(),
"max_difference": max_diff,
"relative_differences": relative_diff
}
def equal_opportunity(self, df, prediction_column, sensitive_attribute, target_column):
"""机会平等 - 在实际正例中,不同群体应有相似的真阳性率"""
if not all(col in df.columns for col in [prediction_column, sensitive_attribute, target_column]):
return None
# 仅考虑实际正例
positive_df = df[df[target_column] == 1]
if len(positive_df) == 0:
return {"error": "没有发现实际正例"}
# 计算各组的真阳性率
tpr_by_group = positive_df.groupby(sensitive_attribute)[prediction_column].mean()
# 计算最大差异
max_diff = tpr_by_group.max() - tpr_by_group.min()
return {
"true_positive_rates": tpr_by_group.to_dict(),
"max_difference": max_diff
}
def disparate_impact(self, df, prediction_column, sensitive_attribute):
"""差别影响 - 不利群体与优势群体的接受率比例应超过阈值(通常为80%)"""
if prediction_column not in df.columns or sensitive_attribute not in df.columns:
return None
# 计算各组接受率
group_rates = df.groupby(sensitive_attribute)[prediction_column].mean()
# 找出接受率最高的组和最低的组
max_group = group_rates.idxmax()
min_group = group_rates.idxmin()
# 计算比率
ratio = group_rates[min_group] / group_rates[max_group]
return {
"group_rates": group_rates.to_dict(),
"min_max_ratio": ratio,
"passes_80_percent_rule": ratio >= 0.8
}
def generate_fairness_summary(self, metrics):
"""生成公平性分析摘要"""
# 实现摘要生成逻辑
return {
"fairness_score": self.calculate_fairness_score(metrics),
"critical_issues": self.identify_critical_issues(metrics),
"overall_assessment": self.generate_overall_assessment(metrics)
}
def calculate_fairness_score(self, metrics):
"""计算总体公平性得分"""
# 简化的得分计算
scores = []
if "demographic_parity" in metrics:
for attr_results in metrics["demographic_parity"].values():
if isinstance(attr_results, dict) and "max_difference" in attr_results:
# 差异越小越好,转换为0-100分
score = max(0, 100 - attr_results["max_difference"] * 100)
scores.append(score)
if "disparate_impact" in metrics:
for attr_results in metrics["disparate_impact"].values():
if isinstance(attr_results, dict) and "min_max_ratio" in attr_results:
# 比率越接近1越好,低于0.8有问题
ratio = attr_results["min_max_ratio"]
score = min(100, ratio * 100)
scores.append(score)
return sum(scores) / len(scores) if scores else 0
def identify_critical_issues(self, metrics):
"""识别关键公平性问题"""
issues = []
# 检查是否有严重的人口平等问题
if "demographic_parity" in metrics:
for attr, results in metrics["demographic_parity"].items():
if isinstance(results, dict) and "max_difference" in results:
if results["max_difference"] > 0.2:
issues.append(f"属性'{attr}'的人口平等差异过大: {results['max_difference']:.2f}")
# 检查是否有严重的差别影响问题
if "disparate_impact" in metrics:
for attr, results in metrics["disparate_impact"].items():
if isinstance(results, dict) and "passes_80_percent_rule" in results:
if not results["passes_80_percent_rule"]:
issues.append(f"属性'{attr}'未通过80%规则测试,比率为: {results['min_max_ratio']:.2f}")
return issues
def generate_overall_assessment(self, metrics):
"""生成总体评估"""
fairness_score = self.calculate_fairness_score(metrics)
if fairness_score >= 90:
return "系统表现出高度公平性,没有显著的偏见问题"
elif fairness_score >= 70:
return "系统整体公平性良好,但存在一些小的差异需要关注"
        elif fairness_score >= 50:
            return "系统存在一定的公平性问题,建议针对受影响群体调整模型或决策规则"
        else:
            return "系统存在显著的公平性问题,需要优先处理已识别出的关键偏见"
结语
感谢您的阅读!期待您的一键三连!欢迎指正!