350 lines
12 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "ota.h"
#include "system_info.h"
#include "board.h"
#include "settings.h"
#include <cJSON.h>
#include <esp_log.h>
#include <esp_partition.h>
#include <esp_ota_ops.h>
#include <esp_app_format.h>
#include <algorithm>
#include <cstring>
#include <sstream>
#include <utility>
#include <vector>
#define TAG "Ota"
/// Constructs an Ota object; no work beyond member default-initialization is done here.
Ota::Ota() = default;
/// Destroys the Ota object; nothing is released explicitly here.
Ota::~Ota() = default;
/**
 * Sets the URL that CheckVersion() will query.
 *
 * @param check_version_url Endpoint URL. CheckVersion() rejects values
 *                          shorter than 10 characters.
 */
void Ota::SetCheckVersionUrl(std::string check_version_url) {
    // The parameter is taken by value, so move it into the member instead of
    // copying the string a second time.
    check_version_url_ = std::move(check_version_url);
}
/**
 * Adds or replaces an HTTP header that CheckVersion() sends with its request.
 *
 * @param key   Header name.
 * @param value Header value; overwrites any value previously stored for key.
 */
void Ota::SetHeader(const std::string& key, const std::string& value) {
    auto& slot = headers_[key];
    slot = value;
}
/**
 * Sets the request body used by CheckVersion().
 *
 * @param post_data Body to send. When it is non-empty CheckVersion() issues
 *                  a POST, otherwise a GET.
 */
void Ota::SetPostData(const std::string& post_data) {
    post_data_.assign(post_data);
}
bool Ota::CheckVersion() {
current_version_ = esp_app_get_description()->version;
ESP_LOGI(TAG, "Current version: %s", current_version_.c_str());
if (check_version_url_.length() < 10) {
ESP_LOGE(TAG, "Check version URL is not properly set");
return false;
}
auto http = Board::GetInstance().CreateHttp();
for (const auto& header : headers_) {
http->SetHeader(header.first, header.second);
}
http->SetHeader("Content-Type", "application/json");
std::string method = post_data_.length() > 0 ? "POST" : "GET";
if (!http->Open(method, check_version_url_, post_data_)) {
ESP_LOGE(TAG, "Failed to open HTTP connection");
delete http;
return false;
}
auto response = http->GetBody();
http->Close();
delete http;
// Parse the JSON response and check if the version is newer
// If it is, set has_new_version_ to true and store the new version and URL
cJSON *root = cJSON_Parse(response.c_str());
if (root == NULL) {
ESP_LOGE(TAG, "Failed to parse JSON response");
return false;
}
has_activation_code_ = false;
cJSON *activation = cJSON_GetObjectItem(root, "activation");
if (activation != NULL) {
cJSON* message = cJSON_GetObjectItem(activation, "message");
if (message != NULL) {
activation_message_ = message->valuestring;
}
cJSON* code = cJSON_GetObjectItem(activation, "code");
if (code != NULL) {
activation_code_ = code->valuestring;
}
has_activation_code_ = true;
}
has_mqtt_config_ = false;
cJSON *mqtt = cJSON_GetObjectItem(root, "mqtt");
if (mqtt != NULL) {
Settings settings("mqtt", true);
cJSON *item = NULL;
cJSON_ArrayForEach(item, mqtt) {
if (item->type == cJSON_String) {
if (settings.GetString(item->string) != item->valuestring) {
settings.SetString(item->string, item->valuestring);
}
}
}
has_mqtt_config_ = true;
}
has_server_time_ = false;
cJSON *server_time = cJSON_GetObjectItem(root, "server_time");
if (server_time != NULL) {
cJSON *timestamp = cJSON_GetObjectItem(server_time, "timestamp");
cJSON *timezone_offset = cJSON_GetObjectItem(server_time, "timezone_offset");
if (timestamp != NULL) {
// 设置系统时间
struct timeval tv;
double ts = timestamp->valuedouble;
// 如果有时区偏移,计算本地时间
if (timezone_offset != NULL) {
ts += (timezone_offset->valueint * 60 * 1000); // 转换分钟为毫秒
}
tv.tv_sec = (time_t)(ts / 1000); // 转换毫秒为秒
tv.tv_usec = (suseconds_t)((long long)ts % 1000) * 1000; // 剩余的毫秒转换为微秒
settimeofday(&tv, NULL);
has_server_time_ = true;
}
}
cJSON *firmware = cJSON_GetObjectItem(root, "firmware");
if (firmware == NULL) {
ESP_LOGE(TAG, "Failed to get firmware object");
cJSON_Delete(root);
return false;
}
cJSON *version = cJSON_GetObjectItem(firmware, "version");
if (version == NULL) {
ESP_LOGE(TAG, "Failed to get version object");
cJSON_Delete(root);
return false;
}
cJSON *url = cJSON_GetObjectItem(firmware, "url");
if (url == NULL) {
ESP_LOGE(TAG, "Failed to get url object");
cJSON_Delete(root);
return false;
}
firmware_version_ = version->valuestring;
firmware_url_ = url->valuestring;
// 解析设备角色字段 - 严格校验模式
bool role_matched = false;// 角色匹配标志
std::string server_role = "";// 服务端角色
cJSON *role = cJSON_GetObjectItem(firmware, "role");// 获取 服务端角色字段
if (role != NULL && cJSON_IsString(role)) { // 服务端角色字段存在且为字符串类型
server_role = role->valuestring; // 服务端角色赋值
ESP_LOGI(TAG, "Server role: %s, Device role: %s", server_role.c_str(), CONFIG_DEVICE_ROLE);// 日志记录服务端角色和设备角色
if (server_role == CONFIG_DEVICE_ROLE) {// 服务端角色与设备角色匹配
role_matched = true; // 角色匹配标志设为true
ESP_LOGI(TAG, "Role verification passed: %s", CONFIG_DEVICE_ROLE);//角色验证通过!
} else {
ESP_LOGW(TAG, "Role mismatch (Device:%s vs Server:%s), upgrade denied", CONFIG_DEVICE_ROLE, server_role.c_str());//角色不匹配OTA升级被拒绝
}
} else {
ESP_LOGW(TAG, "服务端响应中没有角色字段OTA升级被拒绝");//服务端响应中没有角色字段OTA升级被拒绝
}
// 双重校验:角色匹配 + 版本检查
has_new_version_ = false; // 默认无可用更新
if (role_matched) {// 角色匹配标志位 为真时才进行版本检查
bool version_available = IsNewVersionAvailable(current_version_, firmware_version_);//检查当前版本是否比服务端版本新
if (version_available) {
has_new_version_ = true;
//角色匹配且有新的版本可用
ESP_LOGI(TAG, "✓ Role matched & New version available: %s -> %s", current_version_.c_str(), firmware_version_.c_str());
} else {
ESP_LOGI(TAG, "✓ Role matched but current version is latest: %s", current_version_.c_str());//角色匹配但当前版本已是最新
}
} else {
ESP_LOGW(TAG, "✗ Upgrade conditions not met - Role: %s, Version check: skipped",
role_matched ? "" : "");//升级条件未满足 - 角色:%s版本检查跳过
}
cJSON_Delete(root);
return true;
}
void Ota::MarkCurrentVersionValid() {
auto partition = esp_ota_get_running_partition();
if (strcmp(partition->label, "factory") == 0) {
ESP_LOGI(TAG, "Running from factory partition, skipping");
return;
}
ESP_LOGI(TAG, "Running partition: %s", partition->label);
esp_ota_img_states_t state;
if (esp_ota_get_state_partition(partition, &state) != ESP_OK) {
ESP_LOGE(TAG, "Failed to get state of partition");
return;
}
if (state == ESP_OTA_IMG_PENDING_VERIFY) {
ESP_LOGI(TAG, "Marking firmware as valid");
esp_ota_mark_app_valid_cancel_rollback();
}
}
void Ota::Upgrade(const std::string& firmware_url) {
ESP_LOGI(TAG, "Upgrading firmware from %s", firmware_url.c_str());
esp_ota_handle_t update_handle = 0;
auto update_partition = esp_ota_get_next_update_partition(NULL);
if (update_partition == NULL) {
ESP_LOGE(TAG, "Failed to get update partition");
return;
}
ESP_LOGI(TAG, "Writing to partition %s at offset 0x%lx", update_partition->label, update_partition->address);
bool image_header_checked = false;
std::string image_header;
auto http = Board::GetInstance().CreateHttp();
if (!http->Open("GET", firmware_url)) {
ESP_LOGE(TAG, "Failed to open HTTP connection");
delete http;
return;
}
size_t content_length = http->GetBodyLength();
if (content_length == 0) {
ESP_LOGE(TAG, "Failed to get content length");
delete http;
return;
}
char buffer[512];
size_t total_read = 0, recent_read = 0;
auto last_calc_time = esp_timer_get_time();
while (true) {
int ret = http->Read(buffer, sizeof(buffer));
if (ret < 0) {
ESP_LOGE(TAG, "Failed to read HTTP data: %s", esp_err_to_name(ret));
delete http;
return;
}
// Calculate speed and progress every second
recent_read += ret;
total_read += ret;
if (esp_timer_get_time() - last_calc_time >= 1000000 || ret == 0) {
size_t progress = total_read * 100 / content_length;
ESP_LOGI(TAG, "Progress: %zu%% (%zu/%zu), Speed: %zuB/s", progress, total_read, content_length, recent_read);
if (upgrade_callback_) {
upgrade_callback_(progress, recent_read);
}
last_calc_time = esp_timer_get_time();
recent_read = 0;
}
if (ret == 0) {
break;
}
if (!image_header_checked) {
image_header.append(buffer, ret);
if (image_header.size() >= sizeof(esp_image_header_t) + sizeof(esp_image_segment_header_t) + sizeof(esp_app_desc_t)) {
esp_app_desc_t new_app_info;
memcpy(&new_app_info, image_header.data() + sizeof(esp_image_header_t) + sizeof(esp_image_segment_header_t), sizeof(esp_app_desc_t));
ESP_LOGI(TAG, "New firmware version: %s", new_app_info.version);
auto current_version = esp_app_get_description()->version;
if (memcmp(new_app_info.version, current_version, sizeof(new_app_info.version)) == 0) {
ESP_LOGE(TAG, "Firmware version is the same, skipping upgrade");
delete http;
return;
}
if (esp_ota_begin(update_partition, OTA_WITH_SEQUENTIAL_WRITES, &update_handle)) {
esp_ota_abort(update_handle);
delete http;
ESP_LOGE(TAG, "Failed to begin OTA");
return;
}
image_header_checked = true;
std::string().swap(image_header);
}
}
auto err = esp_ota_write(update_handle, buffer, ret);
if (err != ESP_OK) {
ESP_LOGE(TAG, "Failed to write OTA data: %s", esp_err_to_name(err));
esp_ota_abort(update_handle);
delete http;
return;
}
}
delete http;
esp_err_t err = esp_ota_end(update_handle);
if (err != ESP_OK) {
if (err == ESP_ERR_OTA_VALIDATE_FAILED) {
ESP_LOGE(TAG, "Image validation failed, image is corrupted");
} else {
ESP_LOGE(TAG, "Failed to end OTA: %s", esp_err_to_name(err));
}
return;
}
err = esp_ota_set_boot_partition(update_partition);
if (err != ESP_OK) {
ESP_LOGE(TAG, "Failed to set boot partition: %s", esp_err_to_name(err));
return;
}
ESP_LOGI(TAG, "Firmware upgrade successful, rebooting in 3 seconds...");
vTaskDelay(pdMS_TO_TICKS(3000));
esp_restart();
}
/**
 * Registers a progress callback and starts the upgrade from firmware_url_.
 *
 * @param callback Called by Upgrade() with the progress percentage and the
 *                 number of bytes received since the previous report. May be
 *                 empty, in which case no progress is reported.
 */
void Ota::StartUpgrade(std::function<void(int progress, size_t speed)> callback) {
    // The parameter is taken by value, so move it into the member instead of
    // copying the callable (and anything it captured) again.
    upgrade_callback_ = std::move(callback);
    Upgrade(firmware_url_);
}
/**
 * Splits a dotted version string such as "1.2.3" into its numeric components.
 *
 * Each '.'-separated segment is parsed as a decimal integer; trailing
 * non-digit characters in a segment are ignored ("3-beta" yields 3).
 * A segment that is empty, does not start with a number, or does not fit in
 * an int yields 0. The version string can come from the server, and
 * std::stoi would throw on such input with nothing here catching it.
 *
 * @param version Version string to parse.
 * @return One integer per segment, in order; empty for an empty string.
 */
std::vector<int> Ota::ParseVersion(const std::string& version) {
    std::vector<int> versionNumbers;
    std::stringstream ss(version);
    std::string segment;

    while (std::getline(ss, segment, '.')) {
        std::istringstream segment_stream(segment);
        int value = 0;
        if (!(segment_stream >> value)) {
            // Unparsable or out-of-range segment: fall back to 0.
            value = 0;
        }
        versionNumbers.push_back(value);
    }

    return versionNumbers;
}
/**
 * Checks whether newVersion is strictly newer than currentVersion.
 *
 * Both strings are split with ParseVersion() and compared component by
 * component from the left. If one version is a prefix of the other, the
 * longer one counts as newer ("1.2.1" is newer than "1.2").
 *
 * @param currentVersion Version currently running.
 * @param newVersion     Candidate version.
 * @return true when newVersion is newer, false when it is equal or older.
 */
bool Ota::IsNewVersionAvailable(const std::string& currentVersion, const std::string& newVersion) {
    const std::vector<int> running_parts = ParseVersion(currentVersion);
    const std::vector<int> candidate_parts = ParseVersion(newVersion);

    // std::vector's operator< is a lexicographic comparison: the first
    // differing component decides, and a proper prefix orders before the
    // longer sequence.
    return running_parts < candidate_parts;
}