Erlo

大语言模型~Ollama本地模型和java一起体验LLM

2025-12-18 18:30:19 发布   61 浏览  
页面报错/反馈
收藏 点赞

说明

  • 用户输入多个“信息”
  • 大语言模型将“信息”进行处理,转成数组;(一维张量,向量)
  • 通过余弦相似度等相关算法,计算两个向量是否相似

Ollama接口步骤

  1. 安装 Ollama: https://ollama.ai/
  2. 下载模型: ollama pull nomic-embed-text
  3. Ollama 默认运行在 http://localhost:11434

推荐的嵌入模型:

  • nomic-embed-text: 768维,效果好,速度快
  • mxbai-embed-large: 1024维,效果更好
  • bge-m3: 多语言支持

图片

springboot中调用本地模型

    @Test
	@Disabled("需要本地运行 Ollama 服务")
	public void testOllamaEmbedding() {
		// Ollama API 地址
		String apiUrl = "http://localhost:11434/api/embeddings";
		String apiKey = ""; // Ollama 本地不需要 key
		String model = "nomic-embed-text"; // 或 mxbai-embed-large

		EmbeddingClient client = new EmbeddingClientImpl(apiUrl, apiKey);

		// 水果库
		List fruits = Arrays.asList(new Fruit("红富士苹果", "红色 甜 脆 苹果 新鲜"), new Fruit("青苹果", "绿色 酸 脆 苹果 清爽"),
				new Fruit("金帅苹果", "黄色 甜 软 苹果"), new Fruit("香蕉", "黄色 甜 软 香蕉 热带水果"), new Fruit("草莓", "红色 甜 小 草莓 多汁 浆果"),
				new Fruit("西瓜", "绿色外皮 红色果肉 甜 大 西瓜 多汁 夏天"), new Fruit("葡萄", "紫色 甜 小 葡萄 多汁 成串"));

		// 为每个水果生成嵌入向量
		for (Fruit fruit : fruits) {
			fruit.embedding = client.getEmbeddingVector(model, fruit.description);
		}

		// 用户搜索
		String query = "红色的甜水果";
		double[] queryVector = client.getEmbeddingVector(model, query);

		System.out.println("搜索: "" + query + """);
		System.out.println("向量维度: " + queryVector.length);
		System.out.println();

		// 按相似度排序
		fruits.sort(Comparator.comparingDouble(f -> -cosineSimilarity(queryVector, f.embedding)));

		// 输出结果
		System.out.println("搜索结果(按相似度排序):");
		for (Fruit f : fruits) {
			double sim = cosineSimilarity(queryVector, f.embedding);
			System.out.printf("  %s (%.4f): %s%n", f.name, sim, f.description);
		}
	}

    /**
	 * 计算两个向量的余弦相似度
	 */
	public static double cosineSimilarity(double[] vectorA, double[] vectorB) {
		if (vectorA.length != vectorB.length) {
			throw new IllegalArgumentException("向量维度必须相同");
		}

		double dotProduct = 0;
		double normA = 0;
		double normB = 0;

		for (int i = 0; i 

核心方法

@Slf4j
public class EmbeddingClientImpl implements EmbeddingClient {

	private final RestTemplate restTemplate;

	private final String address;

	private final String key;

	public EmbeddingClientImpl(String address, String key) {
		PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager();
		connectionManager.setMaxTotal(100);
		connectionManager.setDefaultMaxPerRoute(20);

		// 设置请求配置
		RequestConfig requestConfig = RequestConfig.custom()
				.setConnectionRequestTimeout(Timeout.ofSeconds(30))
				.setResponseTimeout(Timeout.ofSeconds(300)) // 5分钟响应超时
				.build();

		// 使用 HttpClientBuilder 来构建 HttpClient
		HttpClient httpClient = HttpClientBuilder.create()
				.setConnectionManager(connectionManager)
				.setDefaultRequestConfig(requestConfig)
				.build();

		// 创建 HttpComponentsClientHttpRequestFactory
		HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
		requestFactory.setConnectTimeout(30000); // 30秒连接超时
		requestFactory.setConnectionRequestTimeout(30000);

		// 创建 RestTemplate,只使用 StringHttpMessageConverter 避免 Jackson 依赖问题
		this.restTemplate = new RestTemplate(requestFactory);
		// 清除默认的消息转换器,只保留字符串转换器
		this.restTemplate.setMessageConverters(
				Collections.singletonList(new StringHttpMessageConverter(StandardCharsets.UTF_8)));

		this.address = address;
		this.key = key;
	}

	@Override
	public String embedding(String model, String input) {
		long start = System.currentTimeMillis();
		String url = address;

		HttpHeaders headers = new HttpHeaders();
		headers.setContentType(MediaType.APPLICATION_JSON);
		headers.setAcceptCharset(Collections.singletonList(StandardCharsets.UTF_8));
		if (key != null && !key.isEmpty()) {
			headers.add("Authorization", "Bearer " + key);
		}

		// 将 request 转化为 body 字符串
		JSONObject jsonObject = new JSONObject();
		jsonObject.put("input", input);
		jsonObject.put("model", model);
		String body = jsonObject.toString();
		log.debug("Embedding Request Body: {}", body);

		// 请求
		HttpEntity req = new HttpEntity(body, headers);

		ResponseEntity result = restTemplate.postForEntity(url, req, String.class);

		if (!result.getStatusCode().equals(HttpStatus.OK)) {
			throw new RuntimeException("embeddings error, request: " + body + ", response: " + result.getBody());
		}
		log.info("embedding cost {} ms", System.currentTimeMillis() - start);
		return result.getBody();
	}

	/**
	 * 获取文本嵌入向量
	 * 

* 解析 OpenAI 格式的响应,提取 embedding 向量 * * 响应格式示例:

	 * {
	 *   "object": "list",
	 *   "data": [{
	 *     "object": "embedding",
	 *     "index": 0,
	 *     "embedding": [0.0023064255, -0.009327292, ...]
	 *   }],
	 *   "model": "text-embedding-ada-002",
	 *   "usage": {"prompt_tokens": 8, "total_tokens": 8}
	 * }
	 * 
* @param model 模型名称 * @param input 输入文本 * @return 嵌入向量 */ @Override public double[] getEmbeddingVector(String model, String input) { String response = embedding(model, input); return parseEmbeddingVector(response); } /** * 解析嵌入向量响应 * @param response JSON响应字符串 * @return 向量数组 */ private double[] parseEmbeddingVector(String response) { try { JSONObject jsonResponse = JSONObject.parseObject(response); // OpenAI 格式 if (jsonResponse.containsKey("data")) { JSONArray dataArray = jsonResponse.getJSONArray("data"); if (dataArray != null && !dataArray.isEmpty()) { JSONObject firstData = dataArray.getJSONObject(0); JSONArray embeddingArray = firstData.getJSONArray("embedding"); return jsonArrayToDoubleArray(embeddingArray); } } // Ollama 格式 (直接返回 embedding 数组) if (jsonResponse.containsKey("embedding")) { JSONArray embeddingArray = jsonResponse.getJSONArray("embedding"); return jsonArrayToDoubleArray(embeddingArray); } // 阿里通义格式 if (jsonResponse.containsKey("output")) { JSONObject output = jsonResponse.getJSONObject("output"); if (output.containsKey("embeddings")) { JSONArray embeddings = output.getJSONArray("embeddings"); if (!embeddings.isEmpty()) { JSONObject firstEmbedding = embeddings.getJSONObject(0); JSONArray embeddingArray = firstEmbedding.getJSONArray("embedding"); return jsonArrayToDoubleArray(embeddingArray); } } } throw new RuntimeException("无法解析嵌入向量响应: " + response); } catch (Exception e) { log.error("解析嵌入向量失败: {}", response, e); throw new RuntimeException("解析嵌入向量失败", e); } } /** * 将 JSONArray 转换为 double 数组 */ private double[] jsonArrayToDoubleArray(JSONArray jsonArray) { double[] result = new double[jsonArray.size()]; for (int i = 0; i

图片

登录查看全部

参与评论

评论留言

还没有评论留言,赶紧来抢楼吧~~

浏览 2966.75 万次 点击这里给我发消息

手机查看

返回顶部

给这篇文章打个标签吧~

棒极了 糟糕透顶 好文章 PHP JAVA JS 小程序 Python SEO MySql 确认