// tari-rust/src/main.rs (exported from a repository web view; last change 2025-06-25 12:34:00 +00:00)
// Copyright 2024. The Tari Project
//
// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
// following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
// disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
// following disclaimer in the documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
// products derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
use anyhow::{anyhow, Result};
use clap::Parser;
use log::*;
use serde::{Deserialize, Serialize};
use tokio::{sync::Mutex, time::sleep};
use tonic::transport::{Certificate, ClientTlsConfig, Endpoint};
use zmq::{Context, Message, Socket};
use minotari_app_grpc::{
authentication::ClientAuthenticationInterceptor,
conversions::transaction_output::grpc_output_with_payref,
tari_rpc::{
base_node_client::BaseNodeClient, pow_algo::PowAlgos, Block, NewBlockTemplateRequest, PowAlgo,
SubmitBlockResponse,
},
};
use minotari_app_utilities::parse_miner_input::BaseNodeGrpcClient;
use std::str::FromStr;
use tari_common::configuration::Network;
use tari_common::MAX_GRPC_MESSAGE_SIZE;
use tari_common_types::{grpc_authentication::GrpcAuthentication, tari_address::TariAddress};
use tari_core::{
consensus::ConsensusManager,
transactions::{
generate_coinbase,
tari_amount::MicroMinotari,
transaction_components::{
encrypted_data::{PaymentId, TxType},
CoinBaseExtra, RangeProofType,
},
transaction_key_manager::{create_memory_db_key_manager, MemoryDbKeyManager},
},
};
use tari_utilities::hex::Hex;
use tari_utilities::ByteArray;
use jmt::{JellyfishMerkleTree, KeyHash};
use jmt::mock::MockTreeStore;
use tari_core::chain_storage::SmtHasher;
use tari_core::blocks::Block as CoreBlock;
/// Log target used by every log statement in this binary.
const LOG_TARGET: &str = "gbt::main";
/// A unit of mining work published to miners over ZMQ (sent under the `mining_task` topic frame).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MiningTask {
    /// Hex-encoded hash of the coinbase output added to this block.
    pub coinbase_hash: String,
    /// Height of the block being mined.
    pub height: u64,
    /// Target difficulty taken from the base node's miner data.
    pub target: u64,
    /// Output SMT size computed locally for this block.
    pub output_smt_size: u64,
    /// JSON-serialized block to be mined, coinbase included. Despite the field name this is the
    /// complete block returned by the base node's `get_new_block`, not the raw template.
    pub block_template: String,
}
/// A solved-block submission received from a miner over ZMQ.
#[derive(Deserialize, Serialize, Debug)]
pub struct SubmitRequest {
    /// Height the submission claims to be for; in this file it only appears in log messages.
    pub height: u64,
    /// Nonce reported by the miner.
    pub nonce: u64,
    /// Miner-provided solution string; not read anywhere in this file.
    pub solution: String,
    /// JSON-serialized gRPC `Block`, forwarded to the base node unchanged.
    pub block_data: String,
}
/// Runtime configuration for [`GbtClient`].
#[derive(Clone, Debug)]
pub struct GbtConfig {
    /// Base node gRPC address, given as `host:port`.
    pub base_node_grpc_address: String,
    /// Credentials handed to the gRPC authentication interceptor.
    pub base_node_grpc_authentication: GrpcAuthentication,
    /// TLS domain name; TLS is set up only when this and the CA certificate are both present.
    pub base_node_grpc_tls_domain_name: Option<String>,
    /// CA certificate file name, resolved against `config_dir`.
    pub base_node_grpc_ca_cert_filename: Option<String>,
    /// Directory that relative paths (such as the CA certificate) are resolved against.
    pub config_dir: PathBuf,
    /// Tari network whose consensus rules are used.
    pub network: Network,
    /// Payout address for coinbase rewards, as a string (parsed at client start-up).
    pub wallet_payment_address: String,
    /// Extra data embedded in every coinbase.
    pub coinbase_extra: String,
    /// Range proof type used for the coinbase output.
    pub range_proof_type: RangeProofType,
    /// Port the ZMQ PUB socket binds to on all interfaces (mining tasks go out here).
    pub zmq_publisher_port: u16,
    /// Port on localhost the ZMQ SUB socket connects to (submissions come in here).
    pub zmq_subscriber_port: u16,
}
/// Bridges a Tari base node and external miners: builds SHA3x blocks with a locally generated
/// coinbase, publishes them over ZMQ, and forwards solved blocks back to the base node.
pub struct GbtClient {
    /// gRPC client for the base node.
    base_node_client: BaseNodeGrpcClient,
    /// Key manager used when generating coinbase outputs.
    key_manager: MemoryDbKeyManager,
    /// Consensus rules for the configured network.
    consensus_manager: ConsensusManager,
    /// Parsed payout address for coinbase rewards.
    wallet_payment_address: TariAddress,
    /// Configuration the client was built from.
    config: GbtConfig,
    // Held for ownership only: it is never read after construction, hence the lint allowance.
    #[allow(dead_code)]
    zmq_context: Context,
    /// PUB socket on which mining tasks are sent.
    publisher_socket: Socket,
    /// SUB socket (subscribed to `submit`) on which solved blocks arrive.
    subscriber_socket: Socket,
    /// Mining tasks built so far, keyed by coinbase hash.
    mining_tasks: Arc<Mutex<HashMap<String, MiningTask>>>,
}
impl GbtClient {
pub async fn new(config: GbtConfig) -> Result<Self> {
// 创建BaseNode客户端
let base_node_client = Self::connect_base_node(&config).await?;
// 创建密钥管理器
let key_manager = create_memory_db_key_manager().map_err(|e| anyhow!("Key manager error: {}", e))?;
// 创建共识管理器
let consensus_manager = ConsensusManager::builder(config.network)
.build()
.map_err(|e| anyhow!("Consensus manager error: {}", e))?;
// 解析钱包地址
let wallet_payment_address = TariAddress::from_str(&config.wallet_payment_address)
.map_err(|e| anyhow!("Invalid wallet address: {}", e))?;
// 创建ZMQ上下文和套接字
let zmq_context = Context::new();
let publisher_socket = zmq_context
.socket(zmq::PUB)
.map_err(|e| anyhow!("ZMQ publisher error: {}", e))?;
let subscriber_socket = zmq_context
.socket(zmq::SUB)
.map_err(|e| anyhow!("ZMQ subscriber error: {}", e))?;
// 绑定ZMQ套接字
let publisher_addr = format!("tcp://*:{}", config.zmq_publisher_port);
let subscriber_addr = format!("tcp://localhost:{}", config.zmq_subscriber_port);
publisher_socket
.bind(&publisher_addr)
.map_err(|e| anyhow!("ZMQ bind error: {}", e))?;
subscriber_socket
.connect(&subscriber_addr)
.map_err(|e| anyhow!("ZMQ connect error: {}", e))?;
subscriber_socket
.set_subscribe(b"submit")
.map_err(|e| anyhow!("ZMQ subscribe error: {}", e))?;
Ok(Self {
base_node_client,
key_manager,
consensus_manager,
wallet_payment_address,
config,
zmq_context,
publisher_socket,
subscriber_socket,
mining_tasks: Arc::new(Mutex::new(HashMap::new())),
})
}
// 连接BaseNode
async fn connect_base_node(config: &GbtConfig) -> Result<BaseNodeGrpcClient> {
info!(target: LOG_TARGET, "Connecting to base node at {}", config.base_node_grpc_address);
let address = format!("http://{}", config.base_node_grpc_address);
let mut endpoint = Endpoint::new(address)?;
// 配置TLS如果需要
if let Some(domain_name) = config.base_node_grpc_tls_domain_name.as_ref() {
if let Some(cert_filename) = config.base_node_grpc_ca_cert_filename.as_ref() {
let cert_path = config.config_dir.join(cert_filename);
let pem = tokio::fs::read(cert_path)
.await
.map_err(|e| anyhow!("TLS certificate read error: {}", e))?;
let ca = Certificate::from_pem(pem);
let tls = ClientTlsConfig::new().ca_certificate(ca).domain_name(domain_name);
endpoint = endpoint
.tls_config(tls)
.map_err(|e| anyhow!("TLS config error: {}", e))?;
}
}
let channel = endpoint
.connect()
.await
.map_err(|e| anyhow!("Connection error: {}", e))?;
let node_conn = BaseNodeClient::with_interceptor(
channel,
ClientAuthenticationInterceptor::create(&config.base_node_grpc_authentication)
.map_err(|e| anyhow!("Authentication error: {}", e))?,
)
.max_encoding_message_size(MAX_GRPC_MESSAGE_SIZE)
.max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE);
Ok(node_conn)
}
/// 计算output_smt_size
fn calculate_output_smt_size(&self, block: &CoreBlock, prev_output_smt_size: u64) -> Result<u64> {
// 创建JellyfishMerkleTree用于计算
let mock_store = MockTreeStore::new(true);
let output_smt = JellyfishMerkleTree::<_, SmtHasher>::new(&mock_store);
let mut batch = Vec::new();
// 处理所有输出(添加新的叶子节点)
for output in block.body.outputs() {
if !output.is_burned() {
let smt_key = KeyHash(
output.commitment.as_bytes().try_into().expect("commitment is 32 bytes")
);
let smt_value = output.smt_hash(block.header.height);
batch.push((smt_key, Some(smt_value.to_vec())));
}
}
// 处理所有输入(删除叶子节点)
for input in block.body.inputs() {
let smt_key = KeyHash(
input.commitment()?.as_bytes().try_into().expect("Commitment is 32 bytes")
);
batch.push((smt_key, None));
}
// 计算SMT变化
let (_, changes) = output_smt
.put_value_set(batch, block.header.height)
.map_err(|e| anyhow!("SMT calculation error: {}", e))?;
// 计算新的output_smt_size
let mut size = prev_output_smt_size;
size += changes.node_stats.first().map(|s| s.new_leaves).unwrap_or(0) as u64;
size = size.saturating_sub(changes.node_stats.first().map(|s| s.stale_leaves).unwrap_or(0) as u64);
Ok(size)
}
2025-06-25 12:34:00 +00:00
pub async fn get_block_template_and_coinbase(&mut self) -> Result<MiningTask> {
info!(target: LOG_TARGET, "Getting new block template");
// 获取区块模板
let pow_algo = PowAlgo {
pow_algo: PowAlgos::Sha3x.into(),
};
let request = NewBlockTemplateRequest {
algo: Some(pow_algo),
max_weight: 0,
};
let template_response = self
.base_node_client
.get_new_block_template(request)
.await?
.into_inner();
let mut block_template = template_response
.new_block_template
.clone()
.ok_or_else(|| anyhow!("No block template received"))?;
let height = block_template
.header
.as_ref()
.ok_or_else(|| anyhow!("No header in block template"))?
.height;
// 获取挖矿数据
let miner_data = template_response
.miner_data
.ok_or_else(|| anyhow!("No miner data received"))?;
let fee = MicroMinotari::from(miner_data.total_fees);
let reward = MicroMinotari::from(miner_data.reward);
let target_difficulty = miner_data.target_difficulty;
info!(target: LOG_TARGET, "Generating coinbase for height {}", height);
// 生成coinbase
let (coinbase_output, coinbase_kernel) = generate_coinbase(
fee,
reward,
height,
&CoinBaseExtra::try_from(self.config.coinbase_extra.as_bytes().to_vec())?,
&self.key_manager,
&self.wallet_payment_address,
true,
self.consensus_manager.consensus_constants(height),
self.config.range_proof_type,
PaymentId::Open {
user_data: vec![],
tx_type: TxType::Coinbase,
},
)
.await
.map_err(|e| anyhow!("Coinbase generation error: {}", e))?;
// 将coinbase添加到区块模板
let body = block_template
.body
.as_mut()
.ok_or_else(|| anyhow!("No body in block template"))?;
let grpc_output = grpc_output_with_payref(coinbase_output.clone(), None)
.map_err(|e| anyhow!("Output conversion error: {}", e))?;
body.outputs.push(grpc_output);
body.kernels.push(coinbase_kernel.into());
// 获取完整的区块
let block_result = self.base_node_client.get_new_block(block_template.clone()).await?.into_inner();
2025-06-25 12:34:00 +00:00
let block = block_result.block.ok_or_else(|| anyhow!("No block in response"))?;
// 计算coinbase哈希
let coinbase_hash = coinbase_output.hash().to_hex();
// 将gRPC Block转换为CoreBlock以便计算output_smt_size
let core_block: CoreBlock = block.clone().try_into()
.map_err(|e| anyhow!("Block conversion error: {}", e))?;
// 获取前一个区块的output_smt_size从区块模板头中获取
let prev_output_smt_size = block_template
.header
.as_ref()
.ok_or_else(|| anyhow!("No header in block template"))?
.output_smt_size;
// 计算新的output_smt_size
let calculated_output_smt_size = self.calculate_output_smt_size(&core_block, prev_output_smt_size)?;
info!(target: LOG_TARGET, "Calculated output_smt_size: {} (prev: {})",
calculated_output_smt_size, prev_output_smt_size);
// 序列化区块模板
let block_template_json = serde_json::to_string(&block).map_err(|e| anyhow!("Serialization error: {}", e))?;
2025-06-25 12:34:00 +00:00
let mining_task = MiningTask {
coinbase_hash,
height,
target: target_difficulty,
output_smt_size: calculated_output_smt_size, // 使用计算出的值
2025-06-25 12:34:00 +00:00
block_template: block_template_json,
};
// 缓存挖矿任务
{
let mut tasks = self.mining_tasks.lock().await;
tasks.insert(mining_task.coinbase_hash.clone(), mining_task.clone());
}
Ok(mining_task)
}
// 通过ZMQ发送挖矿任务
pub fn send_mining_task(&self, task: &MiningTask) -> Result<()> {
let task_json = serde_json::to_string(task).map_err(|e| anyhow!("Serialization error: {}", e))?;
self.publisher_socket
.send_multipart(&["mining_task".as_bytes(), task_json.as_bytes()], 0)
.map_err(|e| anyhow!("ZMQ send error: {}", e))?;
2025-06-26 07:22:52 +00:00
info!(target: LOG_TARGET, "Sent mining task for height {} with target {} and output_smt_size {}",
task.height, task.target, task.output_smt_size);
2025-06-25 12:34:00 +00:00
Ok(())
}
// 接收外部提交的挖矿结果
pub async fn receive_submit(&mut self) -> Result<Option<SubmitRequest>> {
let mut message = Message::new();
// 非阻塞接收
match self.subscriber_socket.recv(&mut message, zmq::DONTWAIT) {
Ok(_) => {
let message_str = message.as_str().ok_or_else(|| anyhow!("Message decode error"))?;
if message_str.starts_with("submit ") {
let submit_json = &message_str[7..]; // 去掉"submit "前缀
let submit_request: SubmitRequest =
serde_json::from_str(submit_json).map_err(|e| anyhow!("Deserialization error: {}", e))?;
info!(target: LOG_TARGET, "Received submit for height {} with nonce {}",
submit_request.height, submit_request.nonce);
Ok(Some(submit_request))
} else {
Ok(None)
}
},
Err(zmq::Error::EAGAIN) => {
// 没有消息可读
Ok(None)
},
Err(e) => Err(anyhow!("ZMQ receive error: {}", e)),
}
}
// 提交区块到BaseNode
pub async fn submit_block_to_base_node(&mut self, submit_request: &SubmitRequest) -> Result<SubmitBlockResponse> {
// 反序列化区块数据
let block: Block = serde_json::from_str(&submit_request.block_data)
.map_err(|e| anyhow!("Block deserialization error: {}", e))?;
2025-06-25 12:34:00 +00:00
info!(target: LOG_TARGET, "Submitting block to base node for height {}", submit_request.height);
// 提交区块
let response = self.base_node_client.submit_block(block).await?;
info!(target: LOG_TARGET, "Block submitted successfully for height {}", submit_request.height);
Ok(response.into_inner())
}
// 主循环
pub async fn run(&mut self) -> Result<()> {
info!(target: LOG_TARGET, "Starting GBT client");
loop {
// 1. 获取区块模板和构造coinbase
match self.get_block_template_and_coinbase().await {
Ok(mining_task) => {
// 2. 通过ZMQ发送挖矿任务
if let Err(e) = self.send_mining_task(&mining_task) {
error!(target: LOG_TARGET, "Failed to send mining task: {}", e);
}
},
Err(e) => {
error!(target: LOG_TARGET, "Failed to get block template: {}", e);
sleep(Duration::from_secs(5)).await;
continue;
},
}
// 3. 接收外部提交
match self.receive_submit().await {
Ok(Some(submit_request)) => {
// 4. 提交区块到BaseNode
match self.submit_block_to_base_node(&submit_request).await {
Ok(_) => {
info!(target: LOG_TARGET, "Successfully submitted block for height {}", submit_request.height);
},
Err(e) => {
error!(target: LOG_TARGET, "Failed to submit block: {}", e);
},
}
},
Ok(None) => {
// 没有提交,继续循环
},
Err(e) => {
error!(target: LOG_TARGET, "Failed to receive submit: {}", e);
},
}
// 等待一段时间再获取下一个区块模板
sleep(Duration::from_secs(1)).await;
}
}
}
impl Drop for GbtClient {
    /// Logs the shutdown. The ZMQ sockets and context are released afterwards, when the
    /// struct's fields are dropped (Rust drops fields after `Drop::drop` returns).
    fn drop(&mut self) {
        info!(target: LOG_TARGET, "GBT client shutting down");
    }
}
// Command-line interface. The `///` lines on each field double as clap `--help` text, so they are
// runtime output; explanatory notes use plain `//` comments instead.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// BaseNode gRPC address
    #[arg(long, short, default_value = "127.0.0.1:18102")]
    base_node: String,

    // "testnet" is accepted as an alias by the network parsing in `main`.
    /// Network (mainnet, nextnet, testnet)
    #[arg(long, short, default_value = "mainnet")]
    network: String,

    /// Wallet payment address
    #[arg(
        long,
        short,
        default_value = "14H4atSbXqSLFHDvhjx83ASCJDv3iCDu4T6DotCiCVCYq67koEJbgcbmYpeBpRjcZdRYtJ5CDw9gWRNXpe8chfnQSVU"
    )]
    wallet_address: String,

    /// Coinbase extra data
    #[arg(long, short, default_value = "m2pool.com")]
    coinbase_extra: String,

    /// ZMQ publisher port
    #[arg(long, default_value = "5555")]
    zmq_pub_port: u16,

    /// ZMQ subscriber port
    #[arg(long, default_value = "5556")]
    zmq_sub_port: u16,

    // TLS is actually enabled only when both --tls-domain and --tls-ca-cert are given; this flag
    // by itself does not switch it on.
    /// Enable TLS
    #[arg(long)]
    tls: bool,

    /// TLS domain name
    #[arg(long)]
    tls_domain: Option<String>,

    // Resolved against --config-dir.
    /// TLS CA certificate file
    #[arg(long)]
    tls_ca_cert: Option<String>,

    /// Config directory
    #[arg(long, default_value = ".")]
    config_dir: String,
}
#[tokio::main]
async fn main() -> Result<()> {
// 初始化日志
env_logger::init();
2025-06-25 12:34:00 +00:00
let args = Args::parse();
// 解析网络
let network = match args.network.as_str() {
"mainnet" => Network::MainNet,
"nextnet" => Network::NextNet,
"testnet" => Network::NextNet, // 使用NextNet作为testnet
_ => return Err(anyhow!("Invalid network: {}", args.network)),
};
// 创建配置
let config = GbtConfig {
base_node_grpc_address: args.base_node,
base_node_grpc_authentication: GrpcAuthentication::None,
base_node_grpc_tls_domain_name: args.tls_domain,
base_node_grpc_ca_cert_filename: args.tls_ca_cert,
config_dir: PathBuf::from(args.config_dir),
network,
wallet_payment_address: args.wallet_address,
coinbase_extra: args.coinbase_extra,
range_proof_type: RangeProofType::BulletProofPlus,
zmq_publisher_port: args.zmq_pub_port,
zmq_subscriber_port: args.zmq_sub_port,
};
info!(target: LOG_TARGET, "Starting GBT client with network: {:?}", network);
// 创建GBT客户端
let mut client = GbtClient::new(config).await?;
// 运行客户端
client.run().await?;
Ok(())
}