tari-rust/src/main.rs

567 lines
20 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2024. The Tari Project
//
// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
// following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
// disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
// following disclaimer in the documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
// products derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
use anyhow::{anyhow, Result};
use clap::Parser;
use log::*;
use serde::{Deserialize, Serialize};
use tokio::{sync::Mutex, time::sleep};
use tonic::transport::{Certificate, ClientTlsConfig, Endpoint};
use zmq::{Context, Message, Socket};
use minotari_app_grpc::{
authentication::ClientAuthenticationInterceptor,
conversions::transaction_output::grpc_output_with_payref,
tari_rpc::{
base_node_client::BaseNodeClient, pow_algo::PowAlgos, Block, NewBlockTemplateRequest, PowAlgo,
SubmitBlockResponse,
},
};
use minotari_app_utilities::parse_miner_input::BaseNodeGrpcClient;
use std::str::FromStr;
use tari_common::configuration::Network;
use tari_common::MAX_GRPC_MESSAGE_SIZE;
use tari_common_types::{grpc_authentication::GrpcAuthentication, tari_address::TariAddress};
use tari_core::{
consensus::ConsensusManager,
transactions::{
generate_coinbase,
tari_amount::MicroMinotari,
transaction_components::{
encrypted_data::{PaymentId, TxType},
CoinBaseExtra, RangeProofType,
},
transaction_key_manager::{create_memory_db_key_manager, MemoryDbKeyManager},
},
};
use tari_utilities::hex::Hex;
use tari_utilities::ByteArray;
use jmt::{JellyfishMerkleTree, KeyHash};
use jmt::mock::MockTreeStore;
use tari_core::chain_storage::SmtHasher;
use tari_core::blocks::Block as CoreBlock;
/// `log` target attached to every record this binary emits, so its output can be filtered
/// (e.g. `RUST_LOG=gbt::main=debug`).
const LOG_TARGET: &str = "gbt::main";
// ZMQ message structures
/// A mining job published to external miners on the ZMQ PUB socket, encoded as JSON by
/// `GbtClient::send_mining_task`.
///
/// Field order is part of the JSON layout produced by serde; keep it stable.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MiningTask {
/// Hex-encoded hash of the coinbase output added to this block; also the key of the task cache.
pub coinbase_hash: String,
/// Height of the block being mined, taken from the template header.
pub height: u64,
/// Target difficulty reported by the base node in the template's miner data.
pub target: u64,
/// Output SMT size computed locally for this block by `GbtClient::calculate_output_smt_size`.
pub output_smt_size: u64, // newly added output_smt_size
/// JSON serialisation of the complete gRPC `Block` returned by `get_new_block` (coinbase
/// included). Despite the name, this is the finished block, not the raw template.
pub block_template: String, // serialized block template
}
/// A block solution sent by an external miner over the ZMQ SUB socket, decoded from JSON by
/// `GbtClient::receive_submit`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SubmitRequest {
/// Height of the solved block; used for logging.
pub height: u64,
/// Nonce found by the miner; logged on receipt.
/// NOTE(review): it is not written into `block_data` before submission, so the miner is
/// presumably expected to have set it in the header already — confirm with the miner protocol.
pub nonce: u64,
/// NOTE(review): not read by any code visible in this file.
pub solution: String,
/// JSON-serialised gRPC `Block`, deserialised and submitted to the base node as-is.
pub block_data: String, // serialized block data
}
// Configuration structure
/// Runtime configuration for [`GbtClient`], assembled from the command line in `main`.
#[derive(Debug, Clone)]
pub struct GbtConfig {
/// Base node gRPC address (the CLI default is `127.0.0.1:18102`).
pub base_node_grpc_address: String,
/// Credentials handed to the gRPC authentication interceptor.
pub base_node_grpc_authentication: GrpcAuthentication,
/// TLS server name; TLS is only configured when this and the CA certificate are both set.
pub base_node_grpc_tls_domain_name: Option<String>,
/// CA certificate file name, resolved relative to `config_dir`.
pub base_node_grpc_ca_cert_filename: Option<String>,
/// Directory used to resolve the CA certificate path.
pub config_dir: PathBuf,
/// Tari network; selects the consensus rules used for coinbase generation.
pub network: Network,
/// Coinbase payout address, parsed into a `TariAddress` at start-up.
pub wallet_payment_address: String,
/// Text whose raw bytes become the coinbase "extra" field.
pub coinbase_extra: String,
/// Range proof type used for the coinbase output.
pub range_proof_type: RangeProofType,
/// Port the PUB socket binds on all interfaces (`tcp://*:<port>`) to publish mining tasks.
pub zmq_publisher_port: u16,
/// Port the SUB socket connects to on localhost to receive submissions; the submitter must
/// bind this port.
pub zmq_subscriber_port: u16,
}
// GBT client
/// Block-template bridge: fetches templates from a base node, adds a coinbase, publishes the
/// resulting mining tasks over ZMQ and forwards miner submissions back to the base node.
pub struct GbtClient {
/// gRPC client for the base node.
base_node_client: BaseNodeGrpcClient,
/// In-memory key manager passed to `generate_coinbase`.
key_manager: MemoryDbKeyManager,
/// Supplies the consensus constants for each block height.
consensus_manager: ConsensusManager,
/// Recipient of the coinbase outputs.
wallet_payment_address: TariAddress,
/// Configuration the client was created with.
config: GbtConfig,
// ZMQ-related
// Not read after construction; stored so the client owns the context that created its sockets.
#[allow(dead_code)]
zmq_context: Context,
/// PUB socket on which mining tasks are published.
publisher_socket: Socket,
/// SUB socket, subscribed to the "submit" prefix, on which solutions arrive.
subscriber_socket: Socket,
// Mining task cache
/// Published mining tasks keyed by `coinbase_hash`.
mining_tasks: Arc<Mutex<HashMap<String, MiningTask>>>,
}
impl GbtClient {
pub async fn new(config: GbtConfig) -> Result<Self> {
// 创建BaseNode客户端
let base_node_client = Self::connect_base_node(&config).await?;
// 创建密钥管理器
let key_manager = create_memory_db_key_manager().map_err(|e| anyhow!("Key manager error: {}", e))?;
// 创建共识管理器
let consensus_manager = ConsensusManager::builder(config.network)
.build()
.map_err(|e| anyhow!("Consensus manager error: {}", e))?;
// 解析钱包地址
let wallet_payment_address = TariAddress::from_str(&config.wallet_payment_address)
.map_err(|e| anyhow!("Invalid wallet address: {}", e))?;
// 创建ZMQ上下文和套接字
let zmq_context = Context::new();
let publisher_socket = zmq_context
.socket(zmq::PUB)
.map_err(|e| anyhow!("ZMQ publisher error: {}", e))?;
let subscriber_socket = zmq_context
.socket(zmq::SUB)
.map_err(|e| anyhow!("ZMQ subscriber error: {}", e))?;
// 绑定ZMQ套接字
let publisher_addr = format!("tcp://*:{}", config.zmq_publisher_port);
let subscriber_addr = format!("tcp://localhost:{}", config.zmq_subscriber_port);
publisher_socket
.bind(&publisher_addr)
.map_err(|e| anyhow!("ZMQ bind error: {}", e))?;
subscriber_socket
.connect(&subscriber_addr)
.map_err(|e| anyhow!("ZMQ connect error: {}", e))?;
subscriber_socket
.set_subscribe(b"submit")
.map_err(|e| anyhow!("ZMQ subscribe error: {}", e))?;
Ok(Self {
base_node_client,
key_manager,
consensus_manager,
wallet_payment_address,
config,
zmq_context,
publisher_socket,
subscriber_socket,
mining_tasks: Arc::new(Mutex::new(HashMap::new())),
})
}
// 连接BaseNode
async fn connect_base_node(config: &GbtConfig) -> Result<BaseNodeGrpcClient> {
info!(target: LOG_TARGET, "Connecting to base node at {}", config.base_node_grpc_address);
let address = format!("http://{}", config.base_node_grpc_address);
let mut endpoint = Endpoint::new(address)?;
// 配置TLS如果需要
if let Some(domain_name) = config.base_node_grpc_tls_domain_name.as_ref() {
if let Some(cert_filename) = config.base_node_grpc_ca_cert_filename.as_ref() {
let cert_path = config.config_dir.join(cert_filename);
let pem = tokio::fs::read(cert_path)
.await
.map_err(|e| anyhow!("TLS certificate read error: {}", e))?;
let ca = Certificate::from_pem(pem);
let tls = ClientTlsConfig::new().ca_certificate(ca).domain_name(domain_name);
endpoint = endpoint
.tls_config(tls)
.map_err(|e| anyhow!("TLS config error: {}", e))?;
}
}
let channel = endpoint
.connect()
.await
.map_err(|e| anyhow!("Connection error: {}", e))?;
let node_conn = BaseNodeClient::with_interceptor(
channel,
ClientAuthenticationInterceptor::create(&config.base_node_grpc_authentication)
.map_err(|e| anyhow!("Authentication error: {}", e))?,
)
.max_encoding_message_size(MAX_GRPC_MESSAGE_SIZE)
.max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE);
Ok(node_conn)
}
/// 计算output_smt_size
fn calculate_output_smt_size(&self, block: &CoreBlock, prev_output_smt_size: u64) -> Result<u64> {
// 创建JellyfishMerkleTree用于计算
let mock_store = MockTreeStore::new(true);
let output_smt = JellyfishMerkleTree::<_, SmtHasher>::new(&mock_store);
let mut batch = Vec::new();
// 处理所有输出(添加新的叶子节点)
for output in block.body.outputs() {
if !output.is_burned() {
let smt_key = KeyHash(
output.commitment.as_bytes().try_into().expect("commitment is 32 bytes")
);
let smt_value = output.smt_hash(block.header.height);
batch.push((smt_key, Some(smt_value.to_vec())));
}
}
// 处理所有输入(删除叶子节点)
for input in block.body.inputs() {
let smt_key = KeyHash(
input.commitment()?.as_bytes().try_into().expect("Commitment is 32 bytes")
);
batch.push((smt_key, None));
}
// 计算SMT变化
let (_, changes) = output_smt
.put_value_set(batch, block.header.height)
.map_err(|e| anyhow!("SMT calculation error: {}", e))?;
// 计算新的output_smt_size
let mut size = prev_output_smt_size;
size += changes.node_stats.first().map(|s| s.new_leaves).unwrap_or(0) as u64;
size = size.saturating_sub(changes.node_stats.first().map(|s| s.stale_leaves).unwrap_or(0) as u64);
Ok(size)
}
pub async fn get_block_template_and_coinbase(&mut self) -> Result<MiningTask> {
info!(target: LOG_TARGET, "Getting new block template");
// 获取区块模板
let pow_algo = PowAlgo {
pow_algo: PowAlgos::Sha3x.into(),
};
let request = NewBlockTemplateRequest {
algo: Some(pow_algo),
max_weight: 0,
};
let template_response = self
.base_node_client
.get_new_block_template(request)
.await?
.into_inner();
let mut block_template = template_response
.new_block_template
.clone()
.ok_or_else(|| anyhow!("No block template received"))?;
let height = block_template
.header
.as_ref()
.ok_or_else(|| anyhow!("No header in block template"))?
.height;
// 获取挖矿数据
let miner_data = template_response
.miner_data
.ok_or_else(|| anyhow!("No miner data received"))?;
let fee = MicroMinotari::from(miner_data.total_fees);
let reward = MicroMinotari::from(miner_data.reward);
let target_difficulty = miner_data.target_difficulty;
info!(target: LOG_TARGET, "Generating coinbase for height {}", height);
// 生成coinbase
let (coinbase_output, coinbase_kernel) = generate_coinbase(
fee,
reward,
height,
&CoinBaseExtra::try_from(self.config.coinbase_extra.as_bytes().to_vec())?,
&self.key_manager,
&self.wallet_payment_address,
true,
self.consensus_manager.consensus_constants(height),
self.config.range_proof_type,
PaymentId::Open {
user_data: vec![],
tx_type: TxType::Coinbase,
},
)
.await
.map_err(|e| anyhow!("Coinbase generation error: {}", e))?;
// 将coinbase添加到区块模板
let body = block_template
.body
.as_mut()
.ok_or_else(|| anyhow!("No body in block template"))?;
let grpc_output = grpc_output_with_payref(coinbase_output.clone(), None)
.map_err(|e| anyhow!("Output conversion error: {}", e))?;
body.outputs.push(grpc_output);
body.kernels.push(coinbase_kernel.into());
// 获取完整的区块
let block_result = self.base_node_client.get_new_block(block_template.clone()).await?.into_inner();
let block = block_result.block.ok_or_else(|| anyhow!("No block in response"))?;
// 计算coinbase哈希
let coinbase_hash = coinbase_output.hash().to_hex();
// 将gRPC Block转换为CoreBlock以便计算output_smt_size
let core_block: CoreBlock = block.clone().try_into()
.map_err(|e| anyhow!("Block conversion error: {}", e))?;
// 获取前一个区块的output_smt_size从区块模板头中获取
let prev_output_smt_size = block_template
.header
.as_ref()
.ok_or_else(|| anyhow!("No header in block template"))?
.output_smt_size;
// 计算新的output_smt_size
let calculated_output_smt_size = self.calculate_output_smt_size(&core_block, prev_output_smt_size)?;
info!(target: LOG_TARGET, "Calculated output_smt_size: {} (prev: {})",
calculated_output_smt_size, prev_output_smt_size);
// 序列化区块模板
let block_template_json = serde_json::to_string(&block).map_err(|e| anyhow!("Serialization error: {}", e))?;
let mining_task = MiningTask {
coinbase_hash,
height,
target: target_difficulty,
output_smt_size: calculated_output_smt_size, // 使用计算出的值
block_template: block_template_json,
};
// 缓存挖矿任务
{
let mut tasks = self.mining_tasks.lock().await;
tasks.insert(mining_task.coinbase_hash.clone(), mining_task.clone());
}
Ok(mining_task)
}
// 通过ZMQ发送挖矿任务
pub fn send_mining_task(&self, task: &MiningTask) -> Result<()> {
let task_json = serde_json::to_string(task).map_err(|e| anyhow!("Serialization error: {}", e))?;
self.publisher_socket
.send_multipart(&["mining_task".as_bytes(), task_json.as_bytes()], 0)
.map_err(|e| anyhow!("ZMQ send error: {}", e))?;
info!(target: LOG_TARGET, "Sent mining task for height {} with target {} and output_smt_size {}",
task.height, task.target, task.output_smt_size);
Ok(())
}
// 接收外部提交的挖矿结果
pub async fn receive_submit(&mut self) -> Result<Option<SubmitRequest>> {
let mut message = Message::new();
// 非阻塞接收
match self.subscriber_socket.recv(&mut message, zmq::DONTWAIT) {
Ok(_) => {
let message_str = message.as_str().ok_or_else(|| anyhow!("Message decode error"))?;
if message_str.starts_with("submit ") {
let submit_json = &message_str[7..]; // 去掉"submit "前缀
let submit_request: SubmitRequest =
serde_json::from_str(submit_json).map_err(|e| anyhow!("Deserialization error: {}", e))?;
info!(target: LOG_TARGET, "Received submit for height {} with nonce {}",
submit_request.height, submit_request.nonce);
Ok(Some(submit_request))
} else {
Ok(None)
}
},
Err(zmq::Error::EAGAIN) => {
// 没有消息可读
Ok(None)
},
Err(e) => Err(anyhow!("ZMQ receive error: {}", e)),
}
}
// 提交区块到BaseNode
pub async fn submit_block_to_base_node(&mut self, submit_request: &SubmitRequest) -> Result<SubmitBlockResponse> {
// 反序列化区块数据
let block: Block = serde_json::from_str(&submit_request.block_data)
.map_err(|e| anyhow!("Block deserialization error: {}", e))?;
info!(target: LOG_TARGET, "Submitting block to base node for height {}", submit_request.height);
// 提交区块
let response = self.base_node_client.submit_block(block).await?;
info!(target: LOG_TARGET, "Block submitted successfully for height {}", submit_request.height);
Ok(response.into_inner())
}
// 主循环
pub async fn run(&mut self) -> Result<()> {
info!(target: LOG_TARGET, "Starting GBT client");
loop {
// 1. 获取区块模板和构造coinbase
match self.get_block_template_and_coinbase().await {
Ok(mining_task) => {
// 2. 通过ZMQ发送挖矿任务
if let Err(e) = self.send_mining_task(&mining_task) {
error!(target: LOG_TARGET, "Failed to send mining task: {}", e);
}
},
Err(e) => {
error!(target: LOG_TARGET, "Failed to get block template: {}", e);
sleep(Duration::from_secs(5)).await;
continue;
},
}
// 3. 接收外部提交
match self.receive_submit().await {
Ok(Some(submit_request)) => {
// 4. 提交区块到BaseNode
match self.submit_block_to_base_node(&submit_request).await {
Ok(_) => {
info!(target: LOG_TARGET, "Successfully submitted block for height {}", submit_request.height);
},
Err(e) => {
error!(target: LOG_TARGET, "Failed to submit block: {}", e);
},
}
},
Ok(None) => {
// 没有提交,继续循环
},
Err(e) => {
error!(target: LOG_TARGET, "Failed to receive submit: {}", e);
},
}
// 等待一段时间再获取下一个区块模板
sleep(Duration::from_secs(1)).await;
}
}
}
impl Drop for GbtClient {
fn drop(&mut self) {
info!(target: LOG_TARGET, "GBT client shutting down");
// ZMQ套接字会在Context销毁时自动关闭
}
}
// Command-line arguments. clap renders `///` comments on this struct and its fields as
// `--help` text, so reviewer notes here use `//` to avoid changing the CLI output.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// BaseNode gRPC address
#[arg(short, long, default_value = "127.0.0.1:18102")]
base_node: String,
// `testnet` is mapped to NextNet by `main`.
/// Network (mainnet, nextnet, testnet)
#[arg(short, long, default_value = "mainnet")]
network: String,
// NOTE(review): the default is a fixed, hard-coded address. Coinbase outputs pay to it unless
// `--wallet-address` is supplied — confirm this default is intended.
/// Wallet payment address
#[arg(
short,
long,
default_value = "14H4atSbXqSLFHDvhjx83ASCJDv3iCDu4T6DotCiCVCYq67koEJbgcbmYpeBpRjcZdRYtJ5CDw9gWRNXpe8chfnQSVU"
)]
wallet_address: String,
// Its raw bytes become the coinbase "extra" field.
/// Coinbase extra data
#[arg(short, long, default_value = "m2pool.com")]
coinbase_extra: String,
// Port the PUB socket binds to for publishing mining tasks.
/// ZMQ publisher port
#[arg(long, default_value = "5555")]
zmq_pub_port: u16,
// Port on localhost the SUB socket connects to for receiving submissions.
/// ZMQ subscriber port
#[arg(long, default_value = "5556")]
zmq_sub_port: u16,
// TLS itself is configured only through `--tls-domain` and `--tls-ca-cert`.
// NOTE(review): confirm how `main` uses this flag.
/// Enable TLS
#[arg(long)]
tls: bool,
/// TLS domain name
#[arg(long)]
tls_domain: Option<String>,
// Resolved relative to `--config-dir`.
/// TLS CA certificate file
#[arg(long)]
tls_ca_cert: Option<String>,
/// Config directory
#[arg(long, default_value = ".")]
config_dir: String,
}
#[tokio::main]
async fn main() -> Result<()> {
// 初始化日志
env_logger::init();
let args = Args::parse();
// 解析网络
let network = match args.network.as_str() {
"mainnet" => Network::MainNet,
"nextnet" => Network::NextNet,
"testnet" => Network::NextNet, // 使用NextNet作为testnet
_ => return Err(anyhow!("Invalid network: {}", args.network)),
};
// 创建配置
let config = GbtConfig {
base_node_grpc_address: args.base_node,
base_node_grpc_authentication: GrpcAuthentication::None,
base_node_grpc_tls_domain_name: args.tls_domain,
base_node_grpc_ca_cert_filename: args.tls_ca_cert,
config_dir: PathBuf::from(args.config_dir),
network,
wallet_payment_address: args.wallet_address,
coinbase_extra: args.coinbase_extra,
range_proof_type: RangeProofType::BulletProofPlus,
zmq_publisher_port: args.zmq_pub_port,
zmq_subscriber_port: args.zmq_sub_port,
};
info!(target: LOG_TARGET, "Starting GBT client with network: {:?}", network);
// 创建GBT客户端
let mut client = GbtClient::new(config).await?;
// 运行客户端
client.run().await?;
Ok(())
}