This article was last updated 29 days ago, so some of the information may be out of date. If you find an error, please email big_fw@foxmail.com.
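I. Project Overview
This project is an AI-driven database tool built on Spring Boot and Baidu's ERNIE Bot (Wenxin Yiyan) large language model. It parses the user's intent from natural language, generates SQL that is checked against the table schema, executes it, and returns the result together with a natural-language explanation, so users never have to write SQL by hand. Its core capabilities are: natural language to SQL, automatic table-structure validation, guarded SQL execution, and natural-language explanation of the result. The current implementation handles UPDATE requests only. It is aimed at non-technical users who need to change data quickly and at developers who want to work with data more efficiently.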
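II. Prerequisites
1. Register a Baidu AI Cloud account and configure API access
- Open the Baidu AI Cloud Qianfan LLM platform at https://console.bce.baidu.com/qianfan/overview, register an account, and complete real-name verification (note: a personal account must be verified before it can call the ERNIE Bot API).
- Create an application and obtain its keys:
  - Go to "Qianfan Console → Application Access → Create Application", enter an application name (e.g. "AI-SQL Assistant"), and submit.
  - Once the application is created, copy the API Key and Secret Key from its "Application Details" page. They are configured in the project later; keep them secret.
- Enable model access:
  - Go to "Qianfan Console → Model Square", find the "ERNIE-Bot-4" and "ERNIE-X1-Turbo-32K" models, and click "Apply for Access" (the free quota is enough for testing; for production, buy capacity according to your traffic).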
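2. Development environment
- JDK: JDK 11 is recommended (the POM targets Java 11; sticking to it avoids compatibility problems with newer JDKs).
- Maven: 3.6.0 or later (so that dependencies download correctly).
- Database client: Navicat, DBeaver, or similar (optional; handy for inspecting H2/MySQL data).
- API testing tool: Postman, Apifox, or similar (recommended for testing the HTTP endpoint).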
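III. Project Structure and Core Components
1. Configuration files
- pom.xml: declares the project dependencies, including the Spring Boot core components, the H2 database, the HTTP client, and utility libraries such as Lombok.
- application.properties: the project configuration, including the database connection, the H2 console settings, and the Baidu ERNIE Bot API keys.
- schema.sql: the database initialization script; it creates the products table and inserts sample data.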
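2. Core Java classes
Model classes
- DbOperationType: enum of database operation types (SELECT, INSERT, UPDATE, DELETE).
- DbCondition: a query condition made of a field name, an operator, and a value.
- DbIntent: the parsed user intent, including the operation type, target table, values to update, and conditions.
- DbOperationResult: the result of a database operation: whether it succeeded, the number of affected rows, and any error message.
- UpdateOperationResponse: the complete response to an update request, containing the original query, the parsed intent, the generated SQL, the operation result, and the explanation.
Utility and service classes
- WenxinClient: client for the Baidu ERNIE Bot API; it obtains access tokens and sends requests to get model responses.
- SchemaManager: reads table metadata and checks whether tables and columns exist.
- IntentParser: converts a natural-language request into a structured DbIntent object.
- SqlGenerator: generates the UPDATE SQL statement from a DbIntent.
- DbExecutor: executes the generated SQL and returns the result.
- LLMService: orchestrates the full flow of intent parsing, SQL generation, execution, and result explanation.
- DbController: the REST controller that exposes the HTTP endpoint.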
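The flow that LLMService coordinates can be pictured as four functions chained together. The sketch below is a simplified, framework-free illustration, not the project's code: the class names UpdatePipeline and PipelineResult are hypothetical, and the intent is modeled as a plain string instead of a DbIntent so the example runs without Spring or Lombok.

```java
import java.util.function.Function;

// Hypothetical, framework-free illustration of the LLMService flow (not the project's code).
// The intent is a plain String here instead of a DbIntent.
final class PipelineResult {
    final String intent;
    final String sql;
    final int affectedRows;
    final String explanation;

    PipelineResult(String intent, String sql, int affectedRows, String explanation) {
        this.intent = intent;
        this.sql = sql;
        this.affectedRows = affectedRows;
        this.explanation = explanation;
    }
}

final class UpdatePipeline {
    private final Function<String, String> parseIntent;  // stands in for IntentParser
    private final Function<String, String> generateSql;  // stands in for SqlGenerator
    private final Function<String, Integer> execute;     // stands in for DbExecutor
    private final Function<Integer, String> explain;     // stands in for the explanation step

    UpdatePipeline(Function<String, String> parseIntent,
                   Function<String, String> generateSql,
                   Function<String, Integer> execute,
                   Function<Integer, String> explain) {
        this.parseIntent = parseIntent;
        this.generateSql = generateSql;
        this.execute = execute;
        this.explain = explain;
    }

    PipelineResult run(String naturalLanguageQuery) {
        String intent = parseIntent.apply(naturalLanguageQuery); // 1. natural language -> intent
        String sql = generateSql.apply(intent);                  // 2. intent -> validated SQL
        int rows = execute.apply(sql);                           // 3. execute, count affected rows
        return new PipelineResult(intent, sql, rows, explain.apply(rows)); // 4. explain
    }
}
```

Keeping each stage separate is what lets the real project validate the model's output (in SqlGenerator) before anything reaches the database (in DbExecutor).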
IV. Source Code
1. pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.7.8</version>
<relativePath/>
</parent>
<groupId>com.example</groupId>
<artifactId>aidb</artifactId>
<version>0.0.1-SNAPSHOT</version>
<description>Spring Boot for database updates</description>
<properties>
<java.version>11</java.version>
<!-- Added: manage dependency versions in one place for easier maintenance -->
<lombok.version>1.18.32</lombok.version> <!-- Lombok version that supports JDK 11 -->
<httpclient.version>4.5.14</httpclient.version>
</properties>
<dependencies>
<!-- Spring Boot core dependencies -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-jdbc</artifactId>
</dependency>
<!-- H2 in-memory database -->
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<scope>runtime</scope>
</dependency>
<!-- HTTP client (used to call the ERNIE Bot API) -->
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>${httpclient.version}</version>
</dependency>
<!-- Removed: OpenAI client (only needed if you use the OpenAI API) -->
<!-- <dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>service</artifactId>
<version>0.12.0</version>
</dependency> -->
<!-- Lombok (version pinned to avoid JDK compatibility issues) -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version> <!-- version defined in properties -->
<optional>true</optional>
</dependency>
<!-- Jackson (JSON parsing) -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<!-- Test dependencies -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.33</version> <!-- choose a version that suits your environment -->
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
</plugin>
<!-- Added: configure the compiler plugin explicitly and register the annotation processor -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
<annotationProcessorPaths>
<!-- Register the Lombok annotation processor explicitly so it runs at compile time -->
<path>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
</plugins>
</build>
</project>
2. Application main class
package org.example;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@SpringBootApplication
public class DbApplication {
public static void main(String[] args) {
SpringApplication.run(DbApplication.class, args);
}
}
3. Configuration file (application.properties)
spring.datasource.url=jdbc:h2:mem:testdb
spring.datasource.driverClassName=org.h2.Driver
spring.datasource.username=sa
spring.datasource.password=sa123
spring.h2.console.enabled=true
spring.h2.console.path=/h2-console
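# Leftover OpenAI settings; no class in this project reads them
openai.api.key=your_openai_api_key
openai.model=gpt-3.5-turbo
# Baidu ERNIE Bot settings: replace with your own API key
wenxin.api.key=AAA
# The secret key can simply be set to the same value as the API key.
# Keep comments on their own line: a trailing "# ..." would become part of the value.
wenxin.secret.key=BBB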
spring.sql.init.mode=always
spring.sql.init.schema-locations=classpath:schema.sql
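One detail of the .properties format is easy to trip over: `#` starts a comment only when it is the first non-blank character of a line, so a trailing `# ...` becomes part of the value. The sketch below (the class name PropertiesCommentDemo is hypothetical) demonstrates this with java.util.Properties; as far as I know, Spring Boot's loader for .properties files follows the same rule.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

// Hypothetical demo class: shows how java.util.Properties treats '#' characters.
final class PropertiesCommentDemo {
    static Properties load(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            // Reading from a StringReader should not fail; rethrow unchecked just in case
            throw new IllegalStateException(e);
        }
        return props;
    }
}
```

So a line such as `wenxin.secret.key=BBB #(comment)` injects the whole string `BBB #(comment)`; comments belong on a separate line.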
Create the SQL file schema.sql in the resources directory (src/main/resources):
-- schema.sql
DROP TABLE IF EXISTS products;
CREATE TABLE products (
id INT AUTO_INCREMENT PRIMARY KEY,
product_name VARCHAR(100) NOT NULL,
category VARCHAR(50),
price DECIMAL(10, 2),
stock INT,
description TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Insert sample data
INSERT INTO products (product_name, category, price, stock, description) VALUES
('iPhone 13', 'Electronics', 5999.00, 100, 'Apple iPhone 13 smartphone with A15 Bionic chip'),
('Samsung Galaxy S21', 'Electronics', 4599.00, 85, 'Samsung flagship phone with 5G capability'),
('Macbook Pro', 'Computers', 12999.00, 30, 'Apple laptop with M1 Pro chip'),
('Dell XPS 13', 'Computers', 8299.00, 45, 'Dell premium ultrabook with Intel Core i7'),
('Logitech MX Master 3', 'Accessories', 599.00, 200, 'Advanced wireless mouse'),
('Sony WH-1000XM4', 'Audio', 2299.00, 60, 'Wireless noise-cancelling headphones');
4. Model classes
package org.example;
public enum DbOperationType {
SELECT, INSERT, UPDATE, DELETE
}
package org.example;
import lombok.Data;
@Data
public class DbCondition {
private String field;
private String operator;
private Object value;
}
package org.example;
import lombok.Data;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Data
public class DbIntent {
private DbOperationType type;
private String targetTable;
private Map<String, Object> updateValues = new HashMap<>();
private List<DbCondition> conditions = new ArrayList<>();
}
package org.example;
import lombok.Data;
@Data
public class DbOperationResult {
private boolean success;
private Integer affectedRows;
private String errorMessage;
}
package org.example;
import lombok.Data;
@Data
public class UpdateOperationResponse {
private String originalQuery;
private DbIntent intent;
private String generatedSql;
private DbOperationResult result;
private String explanation;
}
5. ERNIE Bot API client (WenxinClient)
package org.example;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
@Service
public class WenxinClient {
@Value("${wenxin.api.key}")
private String apiKey;
@Value("${wenxin.secret.key}")
private String secretKey;
private static final String TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token";
private static final String API_URL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"; // legacy v1 endpoint, used only by generateCompletionWithSystem()
private static final String API_URL_V2 = "https://qianfan.baidubce.com/v2/chat/completions";
private static final String MODEL = "ernie-bot-4";
private static final String MODEL_new_x1 = "ernie-x1-turbo-32k";
private static final String MODEL_new_4$5 = "ernie-4.5-turbo-128k";
private String accessToken;
private long tokenExpireTime;
// Obtain an access token (OAuth client_credentials flow; used only by the legacy v1 call below)
private String getAccessToken() throws IOException {
// Return the cached token while it is still valid
if (accessToken != null && System.currentTimeMillis() < tokenExpireTime) {
return accessToken;
}
// Otherwise request a new token
try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
HttpPost httpPost = new HttpPost(TOKEN_URL +
"?grant_type=client_credentials&client_id=" + apiKey +
"&client_secret=" + secretKey);
try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
HttpEntity entity = response.getEntity();
String result = EntityUtils.toString(entity);
ObjectMapper mapper = new ObjectMapper();
JsonNode rootNode = mapper.readTree(result);
this.accessToken = rootNode.path("access_token").asText();
int expiresIn = rootNode.path("expires_in").asInt();
this.tokenExpireTime = System.currentTimeMillis() + (expiresIn * 1000L);
return this.accessToken;
}
}
}
// Generate a reply through the v2 chat/completions endpoint, sending the API key as a Bearer token
public String generateCompletion(String prompt) throws IOException {
// String token = getAccessToken();
try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
HttpPost httpPost = new HttpPost(API_URL_V2);
httpPost.setHeader("Content-Type", "application/json");
httpPost.setHeader("Authorization","Bearer "+apiKey);
ObjectMapper mapper = new ObjectMapper();
ObjectNode requestBody = mapper.createObjectNode();
requestBody.put("model", MODEL_new_x1);
ArrayNode messages = mapper.createArrayNode();
ObjectNode userMessage = mapper.createObjectNode();
userMessage.put("role", "user");
userMessage.put("content", prompt);
messages.add(userMessage);
requestBody.set("messages", messages);
StringEntity entity = new StringEntity(requestBody.toString(), StandardCharsets.UTF_8);
httpPost.setEntity(entity);
try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
HttpEntity responseEntity = response.getEntity();
String result = EntityUtils.toString(responseEntity);
JsonNode rootNode = mapper.readTree(result);
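if (rootNode.has("error_code")) {
throw new IOException("API error: " + rootNode.path("error_msg").asText());
}
// The v2 endpoint reports errors as an "error" object with a "message" field
if (rootNode.hasNonNull("error")) {
throw new IOException("API error: " + rootNode.path("error").path("message").asText());
}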
// return rootNode.path("result").asText();
return rootNode.path("choices").path(0).path("message").path("content").asText();
}
}
}
// Send a conversation with a system prompt (legacy v1 endpoint, authenticated with an access token)
public String generateCompletionWithSystem(String systemPrompt, String userPrompt) throws IOException {
String token = getAccessToken();
try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
HttpPost httpPost = new HttpPost(API_URL + "?access_token=" + token);
httpPost.setHeader("Content-Type", "application/json");
ObjectMapper mapper = new ObjectMapper();
ObjectNode requestBody = mapper.createObjectNode();
requestBody.put("model", MODEL);
ArrayNode messages = mapper.createArrayNode();
// The v1 API takes the system prompt as a top-level "system" field;
// its messages list may contain only user/assistant turns
requestBody.put("system", systemPrompt);
// Add the user message
ObjectNode userMessage = mapper.createObjectNode();
userMessage.put("role", "user");
userMessage.put("content", userPrompt);
messages.add(userMessage);
requestBody.set("messages", messages);
StringEntity entity = new StringEntity(requestBody.toString(), StandardCharsets.UTF_8);
httpPost.setEntity(entity);
try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
HttpEntity responseEntity = response.getEntity();
String result = EntityUtils.toString(responseEntity);
JsonNode rootNode = mapper.readTree(result);
if (rootNode.has("error_code")) {
throw new IOException("API error: " + rootNode.path("error_msg").asText());
}
return rootNode.path("result").asText();
}
}
}
}
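getAccessToken() caches the token until its expiry time. A minimal, testable sketch of that caching logic follows, assuming a hypothetical TokenCache class whose token fetcher and clock are injected. It also refreshes slightly before expiry and synchronizes get(), two small hardening steps the original method does not have. Note that in the code above only generateCompletionWithSystem() uses the token; generateCompletion() sends the API key directly as a Bearer token.

```java
import java.util.function.LongSupplier;
import java.util.function.Supplier;

// Hypothetical helper (not part of the project): caches a token until shortly
// before it expires. Fetcher and clock are injected so it can be tested offline.
final class TokenCache {
    // A token and its lifetime in seconds, like access_token / expires_in in the token response
    static final class Token {
        final String value;
        final long expiresInSeconds;

        Token(String value, long expiresInSeconds) {
            this.value = value;
            this.expiresInSeconds = expiresInSeconds;
        }
    }

    private final Supplier<Token> fetcher;
    private final LongSupplier clockMillis;
    private final long refreshMarginMillis;
    private String cachedToken;
    private long expiresAtMillis;

    TokenCache(Supplier<Token> fetcher, LongSupplier clockMillis, long refreshMarginMillis) {
        this.fetcher = fetcher;
        this.clockMillis = clockMillis;
        this.refreshMarginMillis = refreshMarginMillis;
    }

    // synchronized so concurrent requests do not all refresh at the same time
    synchronized String get() {
        long now = clockMillis.getAsLong();
        if (cachedToken != null && now < expiresAtMillis - refreshMarginMillis) {
            return cachedToken; // still valid and outside the refresh margin
        }
        Token fresh = fetcher.get();
        cachedToken = fresh.value;
        expiresAtMillis = now + fresh.expiresInSeconds * 1000L;
        return cachedToken;
    }
}
```

In WenxinClient, the fetcher would wrap the existing HTTP call to TOKEN_URL and the clock would be `System::currentTimeMillis`.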
6. Schema manager (SchemaManager)
package org.example;
import lombok.Data;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.stream.Collectors;
@Service
public class SchemaManager {
@Autowired
private JdbcTemplate jdbcTemplate;
// Get the table structure (simplest possible version)
public TableSchema getTableSchema(String tableName) {
// Check that the table exists (H2 stores unquoted identifiers in upper case)
List<String> tables = jdbcTemplate.queryForList(
"SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = ?",
String.class, tableName.toUpperCase()
);
if (tables.isEmpty()) {
throw new RuntimeException("Table not found: " + tableName);
}
// Read the column metadata
List<ColumnSchema> columns = jdbcTemplate.query(
"SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS " +
"WHERE TABLE_NAME = ?",
(rs, rowNum) -> {
ColumnSchema column = new ColumnSchema();
column.setColumnName(rs.getString("COLUMN_NAME").toLowerCase());
column.setDataType(rs.getString("DATA_TYPE").toLowerCase());
// Simplifying assumption: the column named id is the primary key (a common convention in most applications)
String colName = rs.getString("COLUMN_NAME").toLowerCase();
column.setPrimaryKey(colName.equals("id"));
return column;
},
tableName.toUpperCase()
);
TableSchema schema = new TableSchema();
schema.setTableName(tableName.toLowerCase());
schema.setColumns(columns);
return schema;
}
// List all table names in the PUBLIC schema
public List<String> getAllTableNames() {
return jdbcTemplate.queryForList(
"SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'PUBLIC'",
String.class
).stream().map(String::toLowerCase).collect(Collectors.toList());
}
@Data
public static class TableSchema {
private String tableName;
private List<ColumnSchema> columns;
// Check whether a column exists (case-insensitive)
public boolean hasColumn(String columnName) {
return columns.stream()
.anyMatch(col -> col.getColumnName().equalsIgnoreCase(columnName));
}
// Get the primary key columns
public List<String> getPrimaryKeys() {
return columns.stream()
.filter(ColumnSchema::isPrimaryKey)
.map(ColumnSchema::getColumnName)
.collect(Collectors.toList());
}
}
@Data
public static class ColumnSchema {
private String columnName;
private String dataType;
private boolean primaryKey;
}
}
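hasColumn() only answers yes or no. A slightly stronger variant, sketched below with a hypothetical ColumnResolver helper, maps whatever casing the model returns to the column name exactly as the schema spells it, so the identifiers placed into the generated SQL always come from database metadata rather than from model output.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;

// Hypothetical helper (not part of the project): maps a model-supplied column name
// to the schema's own spelling, or reports that no such column exists.
final class ColumnResolver {
    // Case-insensitive keys, original schema spelling as values
    private final Map<String, String> canonicalNames = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

    ColumnResolver(List<String> schemaColumns) {
        for (String column : schemaColumns) {
            canonicalNames.put(column, column);
        }
    }

    Optional<String> resolve(String candidate) {
        if (candidate == null) {
            return Optional.empty();
        }
        return Optional.ofNullable(canonicalNames.get(candidate.trim()));
    }
}
```

Anything that does not resolve (including injected text such as `stock; DROP TABLE products`) is simply reported as unknown.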
7. Intent parser (IntentParser)
package org.example;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.io.IOException;
import java.util.List;
@Service
public class IntentParser {
/* @Autowired
private OpenAIService openAIService;*/
@Autowired
private WenxinClient wenxinClient;
@Autowired
private SchemaManager schemaManager;
// Parse the natural-language intent of an update request
public DbIntent parseUpdateIntent(String naturalLanguageQuery) {
// First determine the target table
String targetTable = determineTargetTable(naturalLanguageQuery);
// Load the table structure
SchemaManager.TableSchema tableSchema = schemaManager.getTableSchema(targetTable);
// Build the prompt
String prompt = buildUpdatePrompt(tableSchema, naturalLanguageQuery);
// Call the model and parse its JSON response
String response = null;
try {
// Ask the LLM to parse the intent
response = wenxinClient.generateCompletion(prompt);
// Extract the JSON part
response = extractJsonFromResponse(response);
ObjectMapper mapper = new ObjectMapper();
return mapper.readValue(response, DbIntent.class);
} catch (Exception e) {
throw new RuntimeException("Failed to parse model response: " + e.getMessage() + "\nResponse: " + response, e);
}
}
// Ask the LLM which table the request refers to
private String determineTargetTable(String query) {
List<String> tables = schemaManager.getAllTableNames();
String tableListStr = String.join(", ", tables);
String prompt = "Based on this user request, determine which database table is being referenced. " +
"Reply with just the table name, nothing else.\n\n" +
"Available tables: " + tableListStr + "\n\n" +
"User request: " + query;
String response = null;
try {
response = wenxinClient.generateCompletion(prompt);
} catch (IOException e) {
throw new RuntimeException(e);
}
return response.trim().toLowerCase();
}
// Build the prompt for an update operation
private String buildUpdatePrompt(SchemaManager.TableSchema schema, String query) {
StringBuilder prompt = new StringBuilder();
// Instruction
prompt.append("Convert this natural language request into a database update operation.\n\n");
// Table structure
prompt.append("Table: ").append(schema.getTableName()).append("\n");
prompt.append("Columns: \n");
for (SchemaManager.ColumnSchema column : schema.getColumns()) {
prompt.append("- ").append(column.getColumnName())
.append(" (").append(column.getDataType()).append(")");
if (column.isPrimaryKey()) {
prompt.append(" [PRIMARY KEY]");
}
prompt.append("\n");
}
// User request
prompt.append("\nUser request: ").append(query).append("\n\n");
// Required output format (double quotes inside Java string literals must be escaped)
prompt.append("Return a JSON object with this exact structure:\n");
prompt.append("{\n");
prompt.append("  \"type\": \"UPDATE\",\n");
prompt.append("  \"targetTable\": \"").append(schema.getTableName()).append("\",\n");
prompt.append("  \"updateValues\": {\n");
prompt.append("    \"column_name\": \"new_value\"\n");
prompt.append("  },\n");
prompt.append("  \"conditions\": [\n");
prompt.append("    {\"field\": \"column_name\", \"operator\": \"=\", \"value\": \"value\"}\n");
prompt.append("  ]\n");
prompt.append("}\n\n");
prompt.append("Make sure the field names match exactly with the table schema provided above. " +
"Convert all values to appropriate data types.");
return prompt.toString();
}
// Extract the JSON object from the response (from the first '{' to the last '}')
private String extractJsonFromResponse(String response) {
int startIdx = response.indexOf('{');
int endIdx = response.lastIndexOf('}');
if (startIdx != -1 && endIdx != -1 && endIdx > startIdx) {
return response.substring(startIdx, endIdx + 1);
}
return response; // No JSON structure found: return the raw response and let the parser report the error
}
}
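extractJsonFromResponse() takes everything from the first `{` to the last `}`, which breaks when the model appends prose containing braces after the JSON. A stricter alternative, sketched below as a hypothetical JsonExtractor, scans for the first balanced object and ignores braces inside JSON string literals. It throws instead of returning the raw text, so the error message points at the real problem.

```java
// Hypothetical helper (not part of the project): returns the first balanced {...}
// object in the text, skipping braces that appear inside JSON strings.
final class JsonExtractor {
    static String firstJsonObject(String text) {
        int start = text.indexOf('{');
        if (start < 0) {
            throw new IllegalArgumentException("No JSON object found in model response");
        }
        int depth = 0;
        boolean inString = false;
        boolean escaped = false;
        for (int i = start; i < text.length(); i++) {
            char c = text.charAt(i);
            if (inString) {
                if (escaped) {
                    escaped = false;          // character after a backslash is literal
                } else if (c == '\\') {
                    escaped = true;
                } else if (c == '"') {
                    inString = false;         // closing quote of a string literal
                }
            } else if (c == '"') {
                inString = true;
            } else if (c == '{') {
                depth++;
            } else if (c == '}') {
                depth--;
                if (depth == 0) {
                    return text.substring(start, i + 1); // first complete object
                }
            }
        }
        throw new IllegalArgumentException("Unbalanced JSON object in model response");
    }
}
```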
8. SQL generator (SqlGenerator)
package org.example;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Service
public class SqlGenerator {
@Autowired
private SchemaManager schemaManager;
// Generate the UPDATE SQL statement
public String generateUpdateSql(DbIntent intent) {
// Validate the intent
validateUpdateIntent(intent);
StringBuilder sql = new StringBuilder("UPDATE ");
sql.append(intent.getTargetTable()).append(" SET ");
// Add the columns and values to update
List<String> setStatements = new ArrayList<>();
for (Map.Entry<String, Object> entry : intent.getUpdateValues().entrySet()) {
setStatements.add(entry.getKey() + " = " + formatValue(entry.getValue()));
}
sql.append(String.join(", ", setStatements));
// Add the WHERE conditions
if (intent.getConditions() != null && !intent.getConditions().isEmpty()) {
sql.append(" WHERE ");
List<String> conditionStrings = intent.getConditions().stream()
.map(this::formatCondition)
.collect(Collectors.toList());
sql.append(String.join(" AND ", conditionStrings));
} else {
throw new RuntimeException("UPDATE operation must have conditions");
}
return sql.toString();
}
// Format a value as a SQL literal (single quotes are escaped by doubling them)
private String formatValue(Object value) {
if (value == null) {
return "NULL";
} else if (value instanceof String) {
return "'" + ((String)value).replace("'", "''") + "'";
} else if (value instanceof Number) {
return value.toString();
} else if (value instanceof Boolean) {
return ((Boolean)value) ? "1" : "0";
} else {
return "'" + value.toString().replace("'", "''") + "'";
}
}
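// Format a condition. The operator comes from the model and is inserted into the SQL verbatim,
// so only plain comparison operators are accepted.
private static final java.util.Set<String> ALLOWED_OPERATORS = java.util.Set.of("=", "!=", "<>", "<", "<=", ">", ">=");
private String formatCondition(DbCondition condition) {
String op = condition.getOperator() == null ? "" : condition.getOperator().trim();
if (!ALLOWED_OPERATORS.contains(op)) {
throw new RuntimeException("Operator not allowed: " + condition.getOperator());
}
return condition.getField() + " " + op + " " +
formatValue(condition.getValue());
}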
// Validate the update intent
private void validateUpdateIntent(DbIntent intent) {
if (intent.getType() != DbOperationType.UPDATE) {
throw new RuntimeException("Expected UPDATE operation, got: " + intent.getType());
}
if (intent.getUpdateValues() == null || intent.getUpdateValues().isEmpty()) {
throw new RuntimeException("No fields to update");
}
if (intent.getConditions() == null || intent.getConditions().isEmpty()) {
throw new RuntimeException("UPDATE operation must have conditions");
}
// Check that the table and columns exist
SchemaManager.TableSchema schema = schemaManager.getTableSchema(intent.getTargetTable());
// Check the columns being updated
for (String field : intent.getUpdateValues().keySet()) {
if (!schema.hasColumn(field)) {
throw new RuntimeException("Column not found: " + field);
}
}
// Check the condition columns
for (DbCondition condition : intent.getConditions()) {
if (!schema.hasColumn(condition.getField())) {
throw new RuntimeException("Column not found in condition: " + condition.getField());
}
}
}
}
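SqlGenerator inlines values as escaped literals. A common alternative is to emit JDBC `?` placeholders and pass the values separately, which removes manual quoting altogether. The sketch below is a hypothetical ParameterizedUpdate builder; it assumes the table and column names were already checked against the schema (as validateUpdateIntent() does), and it rejects any comparison operator outside a small allowlist, since the operator comes from the model.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical builder (not part of the project): UPDATE with JDBC '?' placeholders.
// Assumes table and column names were already validated against the schema.
final class ParameterizedUpdate {
    // Simplified counterpart of DbCondition
    static final class Condition {
        final String field;
        final String operator;
        final Object value;

        Condition(String field, String operator, Object value) {
            this.field = field;
            this.operator = operator;
            this.value = value;
        }
    }

    // The operator comes from the model, so only plain comparisons are accepted
    static final Set<String> ALLOWED_OPERATORS = Set.of("=", "!=", "<>", "<", "<=", ">", ">=");

    final String sql;
    final List<Object> params;

    private ParameterizedUpdate(String sql, List<Object> params) {
        this.sql = sql;
        this.params = Collections.unmodifiableList(params);
    }

    static ParameterizedUpdate build(String table, Map<String, Object> updates, List<Condition> conditions) {
        if (updates.isEmpty()) {
            throw new IllegalArgumentException("No fields to update");
        }
        if (conditions.isEmpty()) {
            throw new IllegalArgumentException("UPDATE operation must have conditions");
        }
        List<Object> params = new ArrayList<>();
        List<String> assignments = new ArrayList<>();
        for (Map.Entry<String, Object> entry : updates.entrySet()) {
            assignments.add(entry.getKey() + " = ?");
            params.add(entry.getValue()); // bound by the driver, no manual quoting
        }
        List<String> predicates = new ArrayList<>();
        for (Condition c : conditions) {
            String op = c.operator == null ? "" : c.operator.trim();
            if (!ALLOWED_OPERATORS.contains(op)) {
                throw new IllegalArgumentException("Operator not allowed: " + c.operator);
            }
            predicates.add(c.field + " " + op + " ?");
            params.add(c.value);
        }
        String sql = "UPDATE " + table + " SET " + String.join(", ", assignments)
                + " WHERE " + String.join(" AND ", predicates);
        return new ParameterizedUpdate(sql, params);
    }
}
```

The result could then be executed with Spring's JdbcTemplate as `jdbcTemplate.update(u.sql, u.params.toArray())`, which binds each value to its placeholder in order.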
9. Execution engine (DbExecutor)
package org.example;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
@Service
public class DbExecutor {
@Autowired
private JdbcTemplate jdbcTemplate;
// Execute the UPDATE SQL
public DbOperationResult executeUpdate(String sql) {
DbOperationResult result = new DbOperationResult();
try {
// Run the update
int affectedRows = jdbcTemplate.update(sql);
result.setSuccess(true);
result.setAffectedRows(affectedRows);
if (affectedRows == 0) {
result.setErrorMessage("No rows were updated. The condition may not match any records.");
}
} catch (Exception e) {
result.setSuccess(false);
result.setErrorMessage(e.getMessage());
}
return result;
}
}
10. LLM service (LLMService)
package org.example;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.io.IOException;
@Service
@Slf4j
public class LLMService {
@Autowired
private IntentParser intentParser;
@Autowired
private SqlGenerator sqlGenerator;
@Autowired
private DbExecutor dbExecutor;
/*@Autowired
private OpenAIService openAIService;*/
@Autowired
private WenxinClient wenxinClient;
// Handle a natural-language update request
public UpdateOperationResponse processUpdateRequest(String naturalLanguageQuery) {
UpdateOperationResponse response = new UpdateOperationResponse();
response.setOriginalQuery(naturalLanguageQuery);
try {
// 1. Parse the intent
DbIntent intent = intentParser.parseUpdateIntent(naturalLanguageQuery);
response.setIntent(intent);
// 2. Generate the SQL
String sql = sqlGenerator.generateUpdateSql(intent);
response.setGeneratedSql(sql);
// 3. Execute the SQL
DbOperationResult result = dbExecutor.executeUpdate(sql);
response.setResult(result);
// 4. Generate a natural-language explanation
String explanation = generateExplanation(intent, result);
response.setExplanation(explanation);
} catch (Exception e) {
log.error(e.getMessage(), e);
DbOperationResult errorResult = new DbOperationResult();
errorResult.setSuccess(false);
errorResult.setErrorMessage(e.getMessage());
response.setResult(errorResult);
response.setExplanation("Could not complete your request: " + e.getMessage());
}
return response;
}
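// Generate a natural-language explanation
private String generateExplanation(DbIntent intent, DbOperationResult result) {
if (!result.isSuccess()) {
return "Operation failed: " + result.getErrorMessage();
}
// Build the prompt
StringBuilder prompt = new StringBuilder();
prompt.append("Briefly explain the result of the following database operation in plain English:\n\n");
// Add the intent and the result
prompt.append("Table: ").append(intent.getTargetTable()).append("\n");
prompt.append("Updated values: ").append(intent.getUpdateValues()).append("\n");
prompt.append("Conditions: ").append(intent.getConditions()).append("\n");
prompt.append("Affected rows: ").append(result.getAffectedRows()).append("\n\n");
// Ask the LLM for the explanation. The update has already run at this point, so an LLM
// failure falls back to a plain message instead of reporting the whole request as failed.
try {
return wenxinClient.generateCompletion(prompt.toString());
} catch (IOException e) {
log.warn("Failed to generate explanation", e);
return "Update completed; affected rows: " + result.getAffectedRows();
}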
}
}
11. API controller (DbController)
package org.example;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/api/db")
public class DbController {
@Autowired
private LLMService llmService;
// Handle a data update request
@PostMapping("/update")
public ResponseEntity<UpdateOperationResponse> updateData(@RequestParam String query) {
UpdateOperationResponse response = llmService.processUpdateRequest(query);
return ResponseEntity.ok(response);
}
}
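The endpoint reads `query` with @RequestParam, so a client can send it URL-encoded in the query string of a POST request. Besides Postman or Apifox, the request can be built with the JDK 11 HTTP client; the helper below is a hypothetical sketch, and the base URL assumes the default local port 8080.

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets;
import java.time.Duration;

// Hypothetical client helper (not part of the project) for POST /api/db/update.
final class UpdateApiClient {
    static HttpRequest buildUpdateRequest(String baseUrl, String naturalLanguageQuery) {
        // URL-encode the natural-language text so spaces, '=' and non-ASCII characters survive
        String encoded = URLEncoder.encode(naturalLanguageQuery, StandardCharsets.UTF_8);
        return HttpRequest.newBuilder()
                .uri(URI.create(baseUrl + "/api/db/update?query=" + encoded))
                .timeout(Duration.ofSeconds(60)) // several LLM round-trips happen per request
                .POST(HttpRequest.BodyPublishers.noBody())
                .build();
    }
}
```

The request can be sent with `HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString())`; the response body is the UpdateOperationResponse serialized as JSON.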
V. Testing
Check the data changes in the H2 console:
- Open http://localhost:8080/h2-console in a browser.
- Connect with the following settings:
  - JDBC URL: jdbc:h2:mem:testdb
  - User Name: sa
  - Password: sa123
- After connecting, run the following SQL to see the updated rows:
SELECT * FROM products
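VI. Features and Caveats
Features
- Natural-language interaction: users can work with the database in everyday language without knowing SQL syntax.
- Automatic validation: the tool checks that the table and columns exist before running anything, which prevents invalid operations.
- Safety guard: UPDATE operations must include conditions, which prevents accidental full-table updates (a condition that matches every row, such as id > 0, still passes this check).
- Result explanation: the technical result is turned into a natural-language description that is easy to understand.
Caveats
- The Baidu ERNIE Bot API Key and Secret Key must be configured correctly, otherwise the model cannot be called.
- Only UPDATE is implemented so far; other operation types (SELECT, INSERT, DELETE) require extending the code.
- H2 is suitable only for testing; production should use a database such as MySQL. Note that SchemaManager's metadata queries assume H2 conventions (upper-case identifiers and the PUBLIC schema) and need adjusting when you switch.
- LLM output is not deterministic, and the current flow executes the generated SQL immediately; for important data, review the generated SQL (or add a confirmation step) before it runs.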