#[derive(Clone, PartialEq, Eq, Debug, Default)] pub struct Ics20Packet { pub amount: u64, pub denom: String, pub receiver: String, pub sender: String, } impl Ics20Packet { pub fn new<T: Into<String>>(amount: u64, denom: T, sender: &str, receiver: &str) -> Self { Ics20Packet { denom: denom.into(), amount, sender: sender.to_string(), receiver: receiver.to_string(), } } pub fn eth_encode(self) -> Binary { ethabi::encode(&[ Token::Uint(self.amount.into()), Token::String(self.denom), Token::String(self.receiver), Token::String(self.sender), ]) .into() } } impl TryFrom<&Binary> for Ics20Packet { type Error = ContractError; fn try_from(value: &Binary) -> Result<Self, Self::Error> { let values = ethabi::decode( &[ ParamType::Uint(256), ParamType::String, ParamType::String, ParamType::String, ], &value.0, ) .map_err(|_| ContractError::EthAbiDecoding)?; match &values[..] { [Token::Uint(amount), Token::String(denom), Token::String(receiver), Token::String(sender)] => { Ok(Ics20Packet { denom: denom.clone(), amount: (*amount) .try_into() .map_err(|_| ContractError::AmountOverflow {})?, sender: sender.clone(), receiver: receiver.clone(), }) } _ => Err(ContractError::EthAbiDecoding), } } } #[cw_serde] pub enum Ics20Ack { Result(Binary), Error(String), } fn ack_success() -> Binary { let res = Ics20Ack::Result(b"1".into()); to_binary(&res).unwrap() } fn ack_fail(err: String) -> Binary { let res = Ics20Ack::Error(err); to_binary(&res).unwrap() } const RECEIVE_ID: u64 = 1337; const ACK_FAILURE_ID: u64 = 0xfa17; #[cfg_attr(not(feature = "library"), entry_point)] pub fn reply(deps: DepsMut, _env: Env, reply: Reply) -> Result<Response, ContractError> { match reply.id { RECEIVE_ID => match reply.result { SubMsgResult::Ok(_) => Ok(Response::new()), SubMsgResult::Err(err) => { let reply_args = REPLY_ARGS.load(deps.storage)?; undo_reduce_channel_balance( deps.storage, &reply_args.channel, &reply_args.denom, reply_args.amount, )?; Ok(Response::new().set_data(ack_fail(err))) } }, 
ACK_FAILURE_ID => match reply.result { SubMsgResult::Ok(_) => Ok(Response::new()), SubMsgResult::Err(err) => Ok(Response::new().set_data(ack_fail(err))), }, _ => Err(ContractError::UnknownReplyId { id: reply.id }), } } #[cfg_attr(not(feature = "library"), entry_point)] pub fn ibc_channel_open( _deps: DepsMut, _env: Env, msg: IbcChannelOpenMsg, ) -> Result<(), ContractError> { enforce_order_and_version(msg.channel(), msg.counterparty_version())?; Ok(()) } #[cfg_attr(not(feature = "library"), entry_point)] pub fn ibc_channel_connect( deps: DepsMut, _env: Env, msg: IbcChannelConnectMsg, ) -> Result<IbcBasicResponse, ContractError> { enforce_order_and_version(msg.channel(), msg.counterparty_version())?; let channel: IbcChannel = msg.into(); let info = ChannelInfo { id: channel.endpoint.channel_id, counterparty_endpoint: channel.counterparty_endpoint, connection_id: channel.connection_id, }; CHANNEL_INFO.save(deps.storage, &info.id, &info)?; Ok(IbcBasicResponse::default()) } fn enforce_order_and_version( channel: &IbcChannel, counterparty_version: Option<&str>, ) -> Result<(), ContractError> { if channel.version != ICS20_VERSION { return Err(ContractError::InvalidIbcVersion { version: channel.version.clone(), }); } if let Some(version) = counterparty_version { if version != ICS20_VERSION { return Err(ContractError::InvalidIbcVersion { version: version.to_string(), }); } } if channel.order != ICS20_ORDERING { return Err(ContractError::OnlyOrderedChannel {}); } Ok(()) } #[cfg_attr(not(feature = "library"), entry_point)] pub fn ibc_channel_close( _deps: DepsMut, _env: Env, _channel: IbcChannelCloseMsg, ) -> Result<IbcBasicResponse, ContractError> { unimplemented!(); } #[cfg_attr(not(feature = "library"), entry_point)] pub fn ibc_packet_receive( deps: DepsMut, _env: Env, msg: IbcPacketReceiveMsg, ) -> Result<IbcReceiveResponse, Never> { let packet = msg.packet; do_ibc_packet_receive(deps, &packet).or_else(|err| { Ok(IbcReceiveResponse::new() 
.set_ack(ack_fail(err.to_string())) .add_attributes(vec![ attr("action", "receive"), attr("success", "false"), attr("error", err.to_string()), ])) }) } fn parse_voucher_denom<'a>( voucher_denom: &'a str, remote_endpoint: &IbcEndpoint, ) -> Result<&'a str, ContractError> { let split_denom: Vec<&str> = voucher_denom.splitn(3, '/').collect(); if split_denom.len() != 3 { return Err(ContractError::NoForeignTokens {}); } if split_denom[0] != remote_endpoint.port_id { return Err(ContractError::FromOtherPort { port: split_denom[0].into(), }); } if split_denom[1] != remote_endpoint.channel_id { return Err(ContractError::FromOtherChannel { channel: split_denom[1].into(), }); } Ok(split_denom[2]) } fn do_ibc_packet_receive( deps: DepsMut, packet: &IbcPacket, ) -> Result<IbcReceiveResponse, ContractError> { let msg: Ics20Packet = (&packet.data).try_into()?; let channel = packet.dest.channel_id.clone(); let denom = parse_voucher_denom(&msg.denom, &packet.src)?; reduce_channel_balance(deps.storage, &channel, denom, msg.amount)?; let reply_args = ReplyArgs { channel, denom: denom.to_string(), amount: msg.amount, }; REPLY_ARGS.save(deps.storage, &reply_args)?; let to_send = Amount::from_parts(denom.to_string(), msg.amount); let gas_limit = check_gas_limit(deps.as_ref(), &to_send)?; let send = send_amount(to_send, msg.receiver.clone()); let mut submsg = SubMsg::reply_on_error(send, RECEIVE_ID); submsg.gas_limit = gas_limit; let res = IbcReceiveResponse::new() .set_ack(ack_success()) .add_submessage(submsg) .add_attribute("action", "receive") .add_attribute("sender", msg.sender) .add_attribute("receiver", msg.receiver) .add_attribute("denom", denom) .add_attribute("amount", Uint128::from(msg.amount)) .add_attribute("success", "true"); Ok(res) } fn check_gas_limit(deps: Deps, amount: &Amount) -> Result<Option<u64>, ContractError> { match amount { Amount::Cw20(coin) => { let addr = deps.api.addr_validate(&coin.address)?; let allowed = ALLOW_LIST.may_load(deps.storage, &addr)?; match 
allowed { Some(allow) => Ok(allow.gas_limit), None => match CONFIG.load(deps.storage)?.default_gas_limit { Some(base) => Ok(Some(base)), None => Err(ContractError::NotOnAllowList), }, } } _ => Ok(None), } } #[cfg_attr(not(feature = "library"), entry_point)] pub fn ibc_packet_ack( deps: DepsMut, _env: Env, msg: IbcPacketAckMsg, ) -> Result<IbcBasicResponse, ContractError> { let ics20msg: Ics20Ack = from_binary(&msg.acknowledgement.data)?; match ics20msg { Ics20Ack::Result(_) => on_packet_success(deps, msg.original_packet), Ics20Ack::Error(err) => on_packet_failure(deps, msg.original_packet, err), } } #[cfg_attr(not(feature = "library"), entry_point)] pub fn ibc_packet_timeout( deps: DepsMut, _env: Env, msg: IbcPacketTimeoutMsg, ) -> Result<IbcBasicResponse, ContractError> { let packet = msg.packet; on_packet_failure(deps, packet, "timeout".to_string()) } fn on_packet_success(_deps: DepsMut, packet: IbcPacket) -> Result<IbcBasicResponse, ContractError> { let msg: Ics20Packet = (&packet.data).try_into()?; let attributes = vec![ attr("action", "acknowledge"), attr("sender", &msg.sender), attr("receiver", &msg.receiver), attr("denom", &msg.denom), attr("amount", Uint128::from(msg.amount)), attr("success", "true"), ]; Ok(IbcBasicResponse::new().add_attributes(attributes)) } fn on_packet_failure( deps: DepsMut, packet: IbcPacket, err: String, ) -> Result<IbcBasicResponse, ContractError> { let msg: Ics20Packet = (&packet.data).try_into()?; reduce_channel_balance(deps.storage, &packet.src.channel_id, &msg.denom, msg.amount)?; let to_send = Amount::from_parts(msg.denom.clone(), msg.amount); let gas_limit = check_gas_limit(deps.as_ref(), &to_send)?; let send = send_amount(to_send, msg.sender.clone()); let mut submsg = SubMsg::reply_on_error(send, ACK_FAILURE_ID); submsg.gas_limit = gas_limit; let res = IbcBasicResponse::new() .add_submessage(submsg) .add_attribute("action", "acknowledge") .add_attribute("sender", msg.sender) .add_attribute("receiver", msg.receiver) 
.add_attribute("denom", msg.denom) .add_attribute("amount", msg.amount.to_string()) .add_attribute("success", "false") .add_attribute("error", err); Ok(res) } fn send_amount(amount: Amount, recipient: String) -> CosmosMsg { match amount { Amount::Native(coin) => BankMsg::Send { to_address: recipient, amount: vec![coin], } .into(), Amount::Cw20(coin) => { let msg = Cw20ExecuteMsg::Transfer { recipient, amount: coin.amount, }; WasmMsg::Execute { contract_addr: coin.address, msg: to_binary(&msg).unwrap(), funds: vec![], } .into() } } } #[cfg(test)] mod test { use super::*; use crate::contract::{execute, migrate, query_channel}; use crate::msg::{ExecuteMsg, MigrateMsg, TransferMsg}; use crate::test_helpers::*; use cosmwasm_std::testing::{mock_env, mock_info}; use cosmwasm_std::{coins, to_vec, IbcEndpoint, IbcMsg, IbcTimeout, Timestamp}; use cw20::Cw20ReceiveMsg; #[test] fn check_ack_json() { let success = Ics20Ack::Result(b"1".into()); let fail = Ics20Ack::Error("bad coin".into()); let success_json = String::from_utf8(to_vec(&success).unwrap()).unwrap(); assert_eq!(r#"{"result":"MQ=="}"#, success_json.as_str()); let fail_json = String::from_utf8(to_vec(&fail).unwrap()).unwrap(); assert_eq!(r#"{"error":"bad coin"}"#, fail_json.as_str()); } #[test] fn check_encode_decode_iso() { let packet = Ics20Packet::new( 12345, "ucosm", "cosmos1zedxv25ah8fksmg2lzrndrpkvsjqgk4zt5ff7n", "wasm1fucynrfkrt684pm8jrt8la5h2csvs5cnldcgqc", ); assert_eq!(Ok(packet.clone()), (&packet.eth_encode()).try_into()); } fn cw20_payment( amount: u128, address: &str, recipient: &str, gas_limit: Option<u64>, ) -> SubMsg { let msg = Cw20ExecuteMsg::Transfer { recipient: recipient.into(), amount: Uint128::new(amount), }; let exec = WasmMsg::Execute { contract_addr: address.into(), msg: to_binary(&msg).unwrap(), funds: vec![], }; let mut msg = SubMsg::reply_on_error(exec, RECEIVE_ID); msg.gas_limit = gas_limit; msg } fn native_payment(amount: u128, denom: &str, recipient: &str) -> SubMsg { 
SubMsg::reply_on_error( BankMsg::Send { to_address: recipient.into(), amount: coins(amount, denom), }, RECEIVE_ID, ) } fn mock_receive_packet( my_channel: &str, amount: u64, denom: &str, receiver: &str, ) -> IbcPacket { let data = Ics20Packet { denom: format!("{}/{}/{}", REMOTE_PORT, "channel-1234", denom), amount, sender: "remote-sender".to_string(), receiver: receiver.to_string(), }; print!("Packet denom: {}", &data.denom); IbcPacket::new( data.eth_encode(), IbcEndpoint { port_id: REMOTE_PORT.to_string(), channel_id: "channel-1234".to_string(), }, IbcEndpoint { port_id: CONTRACT_PORT.to_string(), channel_id: my_channel.to_string(), }, 3, Timestamp::from_seconds(1665321069).into(), ) } #[test] fn send_receive_cw20() { let send_channel = "channel-9"; let cw20_addr = "token-addr"; let cw20_denom = "cw20:token-addr"; let gas_limit = 1234567; let mut deps = setup( &["channel-1", "channel-7", send_channel], &[(cw20_addr, gas_limit)], ); let recv_packet = mock_receive_packet(send_channel, 876543210, cw20_denom, "local-rcpt"); let recv_high_packet = mock_receive_packet(send_channel, 1876543210, cw20_denom, "local-rcpt"); let msg = IbcPacketReceiveMsg::new(recv_packet.clone()); let res = ibc_packet_receive(deps.as_mut(), mock_env(), msg).unwrap(); assert!(res.messages.is_empty()); let ack: Ics20Ack = from_binary(&res.acknowledgement).unwrap(); let no_funds = Ics20Ack::Error(ContractError::InsufficientFunds {}.to_string()); assert_eq!(ack, no_funds); let transfer = TransferMsg { channel: send_channel.to_string(), remote_address: "remote-rcpt".to_string(), timeout: None, }; let msg = ExecuteMsg::Receive(Cw20ReceiveMsg { sender: "local-sender".to_string(), amount: Uint128::new(987654321), msg: to_binary(&transfer).unwrap(), }); let info = mock_info(cw20_addr, &[]); let res = execute(deps.as_mut(), mock_env(), info, msg).unwrap(); assert_eq!(1, res.messages.len()); let expected = Ics20Packet { denom: cw20_denom.into(), amount: 987654321, sender: "local-sender".to_string(), 
receiver: "remote-rcpt".to_string(), }; let timeout = mock_env().block.time.plus_seconds(DEFAULT_TIMEOUT); assert_eq!( &res.messages[0], &SubMsg::new(IbcMsg::SendPacket { channel_id: send_channel.to_string(), data: expected.eth_encode(), timeout: IbcTimeout::with_timestamp(timeout), }) ); let state = query_channel(deps.as_ref(), send_channel.to_string()).unwrap(); assert_eq!(state.balances, vec![Amount::cw20(987654321, cw20_addr)]); assert_eq!(state.total_sent, vec![Amount::cw20(987654321, cw20_addr)]); let msg = IbcPacketReceiveMsg::new(recv_high_packet); let res = ibc_packet_receive(deps.as_mut(), mock_env(), msg).unwrap(); assert!(res.messages.is_empty()); let ack: Ics20Ack = from_binary(&res.acknowledgement).unwrap(); assert_eq!(ack, no_funds); let msg = IbcPacketReceiveMsg::new(recv_packet); let res = ibc_packet_receive(deps.as_mut(), mock_env(), msg).unwrap(); assert_eq!(1, res.messages.len()); assert_eq!( cw20_payment(876543210, cw20_addr, "local-rcpt", Some(gas_limit)), res.messages[0] ); let ack: Ics20Ack = from_binary(&res.acknowledgement).unwrap(); assert!(matches!(ack, Ics20Ack::Result(_))); let state = query_channel(deps.as_ref(), send_channel.to_string()).unwrap(); assert_eq!(state.balances, vec![Amount::cw20(111111111, cw20_addr)]); assert_eq!(state.total_sent, vec![Amount::cw20(987654321, cw20_addr)]); } #[test] fn send_receive_native() { let send_channel = "channel-9"; let mut deps = setup(&["channel-1", "channel-7", send_channel], &[]); let denom = "uatom"; let recv_packet = mock_receive_packet(send_channel, 876543210, denom, "local-rcpt"); let recv_high_packet = mock_receive_packet(send_channel, 1876543210, denom, "local-rcpt"); let msg = IbcPacketReceiveMsg::new(recv_packet.clone()); let res = ibc_packet_receive(deps.as_mut(), mock_env(), msg).unwrap(); assert!(res.messages.is_empty()); let ack: Ics20Ack = from_binary(&res.acknowledgement).unwrap(); let no_funds = Ics20Ack::Error(ContractError::InsufficientFunds {}.to_string()); assert_eq!(ack, 
no_funds); let msg = ExecuteMsg::Transfer(TransferMsg { channel: send_channel.to_string(), remote_address: "my-remote-address".to_string(), timeout: None, }); let info = mock_info("local-sender", &coins(987654321, denom)); execute(deps.as_mut(), mock_env(), info, msg).unwrap(); let state = query_channel(deps.as_ref(), send_channel.to_string()).unwrap(); assert_eq!(state.balances, vec![Amount::native(987654321, denom)]); assert_eq!(state.total_sent, vec![Amount::native(987654321, denom)]); let msg = IbcPacketReceiveMsg::new(recv_high_packet); let res = ibc_packet_receive(deps.as_mut(), mock_env(), msg).unwrap(); assert!(res.messages.is_empty()); let ack: Ics20Ack = from_binary(&res.acknowledgement).unwrap(); assert_eq!(ack, no_funds); let msg = IbcPacketReceiveMsg::new(recv_packet); let res = ibc_packet_receive(deps.as_mut(), mock_env(), msg).unwrap(); assert_eq!(1, res.messages.len()); assert_eq!( native_payment(876543210, denom, "local-rcpt"), res.messages[0] ); let ack: Ics20Ack = from_binary(&res.acknowledgement).unwrap(); assert!(matches!(ack, Ics20Ack::Result(_))); let state = query_channel(deps.as_ref(), send_channel.to_string()).unwrap(); assert_eq!(state.balances, vec![Amount::native(111111111, denom)]); assert_eq!(state.total_sent, vec![Amount::native(987654321, denom)]); } #[test] fn check_gas_limit_handles_all_cases() { let send_channel = "channel-9"; let allowed = "foobar"; let allowed_gas = 777666; let mut deps = setup(&[send_channel], &[(allowed, allowed_gas)]); let limit = check_gas_limit(deps.as_ref(), &Amount::cw20(500, allowed)).unwrap(); assert_eq!(limit, Some(allowed_gas)); let random = "tokenz"; check_gas_limit(deps.as_ref(), &Amount::cw20(500, random)).unwrap_err(); let def_limit = 54321; migrate( deps.as_mut(), mock_env(), MigrateMsg { default_gas_limit: Some(def_limit), }, ) .unwrap(); let limit = check_gas_limit(deps.as_ref(), &Amount::cw20(500, allowed)).unwrap(); assert_eq!(limit, Some(allowed_gas)); let limit = 
check_gas_limit(deps.as_ref(), &Amount::cw20(500, random)).unwrap(); assert_eq!(limit, Some(def_limit)); } }
package r1cs

// Add returns i1 + i2 + in... as a linear expression.
// It records no multiplication constraint itself; the terms are merged by
// builder.add, which may call builder.compress on long results.
func (builder *builder) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
	vars, s := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...)
	return builder.add(vars, false, s, nil)
}

// MulAcc returns a + b*c.
//
// The product b*c is built in the scratch buffer builder.mbuf1. An R1C
// constraint is added only when neither b nor c is constant. The product is
// then summed with a into builder.mbuf2.
//
// NOTE(review): the result is written back into the backing array of
// _a := builder.toVariable(a) whenever it fits in cap(_a). If toVariable returns
// a's own slice, a's storage is overwritten — confirm callers expect this.
func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable {
	// mulBC writes b*c into builder.mbuf1.
	mulBC := func() {
		builder.mbuf1 = builder.mbuf1[:0]
		n1, v1Constant := builder.constantValue(b)
		n2, v2Constant := builder.constantValue(c)
		// Both unknown: the only case that needs a new constraint b*c == res.
		if !v1Constant && !v2Constant {
			res := builder.newInternalVariable()
			builder.cs.AddConstraint(builder.newR1C(b, c, res))
			builder.mbuf1 = append(builder.mbuf1, res...)
			return
		}
		// Both constant: fold into a single constant term on wire 0.
		if v1Constant && v2Constant {
			builder.cs.Mul(&n1, &n2)
			builder.mbuf1 = append(builder.mbuf1, expr.NewTerm(0, n1))
			return
		}
		// One constant: copy the variable operand into mbuf1 and scale it there
		// (in place on the buffer copy, not on the caller's expression).
		if v1Constant {
			builder.mbuf1 = append(builder.mbuf1, builder.toVariable(c)...)
			builder.mulConstant(builder.mbuf1, n1, true)
			return
		}
		builder.mbuf1 = append(builder.mbuf1, builder.toVariable(b)...)
		builder.mulConstant(builder.mbuf1, n2, true)
	}
	mulBC()
	_a := builder.toVariable(a)
	// Merge a and b*c into mbuf2 (passed as res so add appends into it).
	builder.mbuf2 = builder.mbuf2[:0]
	builder.add([]expr.LinearExpression{_a, builder.mbuf1}, false, 0, &builder.mbuf2)
	// Copy the result out of the scratch buffer: reuse _a's storage when it fits,
	// otherwise allocate with 3x spare capacity for future accumulations.
	_a = _a[:0]
	if len(builder.mbuf2) <= cap(_a) {
		_a = append(_a, builder.mbuf2...)
	} else {
		_a = make(expr.LinearExpression, len(builder.mbuf2), len(builder.mbuf2)*3)
		copy(_a, builder.mbuf2)
	}
	return _a
}

// Sub returns i1 - i2 - in...: every operand after the first is subtracted
// (see the `sub && lID != 0` branches in builder.add).
func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
	vars, s := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...)
	return builder.add(vars, true, s, nil)
}

// add merges the terms of all vars into one linear expression, combining terms
// that share a VID.
//
// The merge walks the inputs through builder.heap. Each entry tracks the current
// term (tID) of one input (lID), keyed by that term's VID. This presumes
// (TODO confirm) that the heap pops the smallest val first and that every input
// expression is sorted by VID. The loop drains the heap completely, so it is
// empty again on return.
//
// When sub is true, the terms of every input except vars[0] are subtracted.
// Zero-coefficient input terms are skipped, and merged terms that cancel to zero
// are dropped. If res is nil, a new expression with the given capacity is
// allocated; otherwise terms are appended to *res. An empty result becomes a
// single zero term on wire 0, so the expression is never empty. Finally
// builder.compress may shorten the result; if it does, *res is overwritten with
// the compressed form.
func (builder *builder) add(vars []expr.LinearExpression, sub bool, capacity int, res *expr.LinearExpression) frontend.Variable {
	for lID, v := range vars {
		builder.heap = append(builder.heap, linMeta{val: v[0].VID, lID: lID})
	}
	builder.heap.heapify()
	if res == nil {
		t := make(expr.LinearExpression, 0, capacity)
		res = &t
	}
	// curr is the index of the last term written to *res (-1 if none).
	curr := -1
	for len(builder.heap) > 0 {
		lID, tID := builder.heap[0].lID, builder.heap[0].tID
		if tID == len(vars[lID])-1 {
			// Last term of this input: drop the input from the heap.
			builder.heap.popHead()
		} else {
			// Advance this input to its next term and restore heap order.
			builder.heap[0].tID++
			builder.heap[0].val = vars[lID][tID+1].VID
			builder.heap.fix(0)
		}
		t := &vars[lID][tID]
		if t.Coeff.IsZero() {
			continue
		}
		if curr != -1 && t.VID == (*res)[curr].VID {
			// Same variable as the last written term: accumulate.
			if sub && lID != 0 {
				builder.cs.Sub(&(*res)[curr].Coeff, &t.Coeff)
			} else {
				builder.cs.Add(&(*res)[curr].Coeff, &t.Coeff)
			}
			if (*res)[curr].Coeff.IsZero() {
				// Terms cancelled out: remove the accumulated term.
				(*res) = (*res)[:curr]
				curr--
			}
		} else {
			// New variable: append a copy (the input term itself is not mutated).
			(*res) = append((*res), *t)
			curr++
			if sub && lID != 0 {
				builder.cs.Neg(&(*res)[curr].Coeff)
			}
		}
	}
	if len((*res)) == 0 {
		// Keep the expression non-empty: a zero coefficient on wire 0.
		(*res) = append((*res), expr.NewTerm(0, constraint.Coeff{}))
	}
	compressed := builder.compress((*res))
	if len(compressed) != len(*res) {
		*res = (*res)[:0]
		*res = append(*res, compressed...)
	}
	return *res
}

// Neg returns -i. A constant folds to a new constant expression; otherwise a
// negated copy of the linear expression is returned (see negateLinExp).
func (builder *builder) Neg(i frontend.Variable) frontend.Variable {
	v := builder.toVariable(i)
	if n, ok := builder.constantValue(v); ok {
		builder.cs.Neg(&n)
		return expr.NewLinearExpression(0, n)
	}
	return builder.negateLinExp(v)
}

// Mul returns i1 * i2 * in..., folding the operands left to right.
//
// A new R1C constraint is added only for a step where both operands are
// non-constant; constant operands are folded into coefficients. The first step
// never scales an input in place (inPlace is false). Later steps scale the
// running intermediate result in place (`!first`).
func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...)
	mul := func(v1, v2 expr.LinearExpression, first bool) expr.LinearExpression {
		n1, v1Constant := builder.constantValue(v1)
		n2, v2Constant := builder.constantValue(v2)
		if !v1Constant && !v2Constant {
			res := builder.newInternalVariable()
			builder.cs.AddConstraint(builder.newR1C(v1, v2, res))
			return res
		}
		if v1Constant && v2Constant {
			builder.cs.Mul(&n1, &n2)
			return expr.NewLinearExpression(0, n1)
		}
		if v1Constant {
			// v2 is always an input operand here, so never scale it in place.
			return builder.mulConstant(v2, n1, false)
		}
		// v1 is the running result after the first step, so it may be scaled in place.
		return builder.mulConstant(v1, n2, !first)
	}
	res := mul(vars[0], vars[1], true)
	for i := 2; i < len(vars); i++ {
		res = mul(res, vars[i], false)
	}
	return res
}

// mulConstant multiplies every coefficient of v1 by lambda. When inPlace is
// true, v1 itself is modified and returned; otherwise a Clone is scaled.
func (builder *builder) mulConstant(v1 expr.LinearExpression, lambda constraint.Coeff, inPlace bool) expr.LinearExpression {
	var res expr.LinearExpression
	if inPlace {
		res = v1
	} else {
		res = v1.Clone()
	}
	for i := 0; i < len(res); i++ {
		builder.cs.Mul(&res[i].Coeff, &lambda)
	}
	return res
}

// DivUnchecked returns i1 / i2.
//
// For a non-constant divisor it adds the single constraint v2 * res == v1. That
// constraint does not force v2 != 0: when v1 == v2 == 0, res is unconstrained.
// For a constant divisor the quotient is folded with no constraint, and a
// constant zero divisor panics.
func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(i1, i2)
	v1 := vars[0]
	v2 := vars[1]
	n1, v1Constant := builder.constantValue(v1)
	n2, v2Constant := builder.constantValue(v2)
	if !v2Constant {
		res := builder.newInternalVariable()
		debug := builder.newDebugInfo("div", v1, "/", v2, " == ", res)
		builder.cs.AddConstraint(builder.newR1C(v2, res, v1), debug)
		return res
	}
	if n2.IsZero() {
		panic("div by constant(0)")
	}
	builder.cs.Inverse(&n2)
	if v1Constant {
		builder.cs.Mul(&n2, &n1)
		return expr.NewLinearExpression(0, n2)
	}
	return builder.mulConstant(v1, n2, false)
}

// Div returns i1 / i2.
//
// For a non-constant divisor it introduces v2Inv and adds two constraints:
// v2 * v2Inv == 1, which makes v2 == 0 unsatisfiable, and v1 * v2Inv == res.
// Constant divisors behave as in DivUnchecked, including the panic on zero.
func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(i1, i2)
	v1 := vars[0]
	v2 := vars[1]
	n1, v1Constant := builder.constantValue(v1)
	n2, v2Constant := builder.constantValue(v2)
	if !v2Constant {
		res := builder.newInternalVariable()
		debug := builder.newDebugInfo("div", v1, "/", v2, " == ", res)
		v2Inv := builder.newInternalVariable()
		// Costs one extra constraint compared to DivUnchecked, but rules out v2 == 0.
		c1 := builder.cs.AddConstraint(builder.newR1C(v2, v2Inv, builder.cstOne()))
		c2 := builder.cs.AddConstraint(builder.newR1C(v1, v2Inv, res))
		builder.cs.AttachDebugInfo(debug, []int{c1, c2})
		return res
	}
	if n2.IsZero() {
		panic("div by constant(0)")
	}
	builder.cs.Inverse(&n2)
	if v1Constant {
		builder.cs.Mul(&n2, &n1)
		return expr.NewLinearExpression(0, n2)
	}
	return builder.mulConstant(v1, n2, false)
}

// Inverse returns 1 / i1.
// A constant is inverted directly, and a constant zero panics. Otherwise a new
// variable res is constrained by res * i1 == 1, so i1 == 0 is unsatisfiable.
func (builder *builder) Inverse(i1 frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(i1)
	if c, ok := builder.constantValue(vars[0]); ok {
		if c.IsZero() {
			panic("inverse by constant(0)")
		}
		builder.cs.Inverse(&c)
		return expr.NewLinearExpression(0, c)
	}
	res := builder.newInternalVariable()
	debug := builder.newDebugInfo("inverse", vars[0], "*", res, " == 1")
	builder.cs.AddConstraint(builder.newR1C(res, vars[0], builder.cstOne()), debug)
	return res
}

// ToBinary decomposes i1 into bits via bits.ToBinary.
// The bit count defaults to the field bit length. It is overridden only when
// exactly one n is given (extra values are ignored), and a negative n panics.
func (builder *builder) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable {
	nbBits := builder.cs.FieldBitLen()
	if len(n) == 1 {
		nbBits = n[0]
		if nbBits < 0 {
			panic("invalid n")
		}
	}
	return bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits))
}

// FromBinary recomposes a value from its bits via bits.FromBinary.
func (builder *builder) FromBinary(_b ...frontend.Variable) frontend.Variable {
	return bits.FromBinary(builder, _b)
}

// Xor returns a XOR b for boolean inputs (both are asserted boolean).
// It is computed as a*(1-2b) + b == a + b - 2ab, and the result is marked boolean.
func (builder *builder) Xor(_a, _b frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(_a, _b)
	a := vars[0]
	b := vars[1]
	builder.AssertIsBoolean(a)
	builder.AssertIsBoolean(b)
	// Keep the shorter expression in b, since b is the operand scaled by 2 and
	// added back. NOTE(review): presumably a cost optimisation — confirm.
	if len(b) > len(a) {
		a, b = b, a
	}
	t := builder.Sub(builder.cstOne(), builder.Mul(b, 2))
	t = builder.Add(builder.Mul(a, t), b)
	builder.MarkBoolean(t)
	return t
}

// Or returns a OR b for boolean inputs (both are asserted boolean).
// It uses a single constraint a * b == a + b - res, i.e. res = a + b - ab.
func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(_a, _b)
	a := vars[0]
	b := vars[1]
	builder.AssertIsBoolean(a)
	builder.AssertIsBoolean(b)
	res := builder.newInternalVariable()
	builder.MarkBoolean(res)
	// c = -res + a + b, built by concatenation rather than builder.add.
	// NOTE(review): c is not merged, so it may repeat a VID if a and b share
	// terms; presumably newR1C tolerates this — confirm.
	c := builder.Neg(res).(expr.LinearExpression)
	c = append(c, a...)
	c = append(c, b...)
	builder.cs.AddConstraint(builder.newR1C(a, b, c))
	return res
}

// And returns a AND b == a * b for boolean inputs (both are asserted boolean).
// The result is marked boolean.
func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(_a, _b)
	a := vars[0]
	b := vars[1]
	builder.AssertIsBoolean(a)
	builder.AssertIsBoolean(b)
	res := builder.Mul(a, b)
	builder.MarkBoolean(res)
	return res
}

// Select returns i1 if the boolean i0 is 1, and i2 otherwise.
// It is computed as i0 * (i1 - i2) + i2, with these shortcuts:
//   - a constant condition returns the chosen branch directly;
//   - constant branches fold (i1 - i2) into a constant, so Mul adds no constraint;
//   - a constant-zero i1 computes (1 - i0) * i2.
func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(i0, i1, i2)
	cond := vars[0]
	builder.AssertIsBoolean(cond)
	if c, ok := builder.constantValue(cond); ok {
		if builder.isCstOne(&c) {
			return vars[1]
		}
		return vars[2]
	}
	n1, ok1 := builder.constantValue(vars[1])
	n2, ok2 := builder.constantValue(vars[2])
	if ok1 && ok2 {
		builder.cs.Sub(&n1, &n2)
		res := builder.Mul(cond, n1)
		res = builder.Add(res, vars[2])
		return res
	}
	if ok1 {
		if n1.IsZero() {
			v := builder.Sub(builder.cstOne(), vars[0])
			return builder.Mul(v, vars[2])
		}
	}
	v := builder.Sub(vars[1], vars[2])
	w := builder.Mul(cond, v)
	return builder.Add(w, vars[2])
}

// Lookup2 returns in[b0 + 2*b1] from (i0, i1, i2, i3), for boolean selectors b0
// and b1 (both asserted boolean).
// Constant selectors return the chosen input directly. Otherwise the result is
// built with three multiplications, as annotated in steps (1)-(3) below.
func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(b0, b1, i0, i1, i2, i3)
	s0, s1 := vars[0], vars[1]
	in0, in1, in2, in3 := vars[2], vars[3], vars[4], vars[5]
	builder.AssertIsBoolean(s0)
	builder.AssertIsBoolean(s1)
	c0, b0IsConstant := builder.constantValue(s0)
	c1, b1IsConstant := builder.constantValue(s1)
	if b0IsConstant && b1IsConstant {
		b0 := builder.isCstOne(&c0)
		b1 := builder.isCstOne(&c1)
		if !b0 && !b1 {
			return in0
		}
		if b0 && !b1 {
			return in1
		}
		if b0 && b1 {
			return in3
		}
		return in2
	}
	tmp1 := builder.Add(in3, in0)
	tmp1 = builder.Sub(tmp1, in2, in1)
	tmp1 = builder.Mul(tmp1, s1)
	tmp1 = builder.Add(tmp1, in1)
	tmp1 = builder.Sub(tmp1, in0) // (1) tmp1 = s1 * (in3 - in2 - in1 + in0) + in1 - in0
	tmp2 := builder.Mul(tmp1, s0) // (2) tmp2 = tmp1 * s0
	res := builder.Sub(in2, in0)
	res = builder.Mul(res, s1)
	res = builder.Add(res, tmp2, in0) // (3) res = (in2 - in0) * s1 + tmp2 + in0
	return res
}

// IsZero returns 1 if i1 == 0 and 0 otherwise. A constant input is folded.
//
// For a variable input a:
//   - x is produced by the hint.InvZero hint (x = 1/a, or 0 when a == 0);
//   - -a * x == m - 1 forces m = 1 when a == 0;
//   - a * m == 0 forces m = 0 when a != 0.
func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(i1)
	a := vars[0]
	if c, ok := builder.constantValue(a); ok {
		if c.IsZero() {
			return builder.cstOne()
		}
		return builder.cstZero()
	}
	debug := builder.newDebugInfo("isZero", a)
	m := builder.newInternalVariable()
	// x = 1/a, computed in a hint (x == 0 if a == 0).
	x, err := builder.NewHint(hint.InvZero, 1, a)
	if err != nil {
		panic(err)
	}
	// m = -a*x + 1: constrains m to be 1 if a == 0.
	c1 := builder.cs.AddConstraint(builder.newR1C(builder.Neg(a), x[0], builder.Sub(m, 1)))
	// a * m = 0: constrains m to be 0 if a != 0.
	c2 := builder.cs.AddConstraint(builder.newR1C(a, m, builder.cstZero()))
	builder.cs.AttachDebugInfo(debug, []int{c1, c2})
	return m
}

// Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1<i2.
//
// Both inputs are decomposed into FieldBitLen bits and scanned from the most
// significant bit down. res stays 0 until the first bit where they differ, then
// latches 1 or -1. Each bit adds several IsZero/And/Select gadgets, so cost grows
// linearly with the field bit length.
func (builder *builder) Cmp(i1, i2 frontend.Variable) frontend.Variable {
	vars, _ := builder.toVariables(i1, i2)
	bi1 := builder.ToBinary(vars[0], builder.cs.FieldBitLen())
	bi2 := builder.ToBinary(vars[1], builder.cs.FieldBitLen())
	res := builder.cstZero()
	for i := builder.cs.FieldBitLen() - 1; i >= 0; i-- {
		iszeroi1 := builder.IsZero(bi1[i])
		iszeroi2 := builder.IsZero(bi2[i])
		// i1i2 == 1 iff bit i is 1 in i1 and 0 in i2 (i2i1 symmetrically).
		i1i2 := builder.And(bi1[i], iszeroi2)
		i2i1 := builder.And(bi2[i], iszeroi1)
		n := builder.Select(i2i1, -1, 0)
		m := builder.Select(i1i2, 1, n)
		// Only update while no higher bit has decided the comparison.
		res = builder.Select(builder.IsZero(res), m, res).(expr.LinearExpression)
	}
	return res
}

// Println records a log entry that is resolved when the circuit is solved.
// Linear-expression arguments become "%s" placeholders queued in log.ToResolve.
// Other arguments are formatted by printArg. Arguments are space-separated, and
// log.Caller records the calling file:line.
func (builder *builder) Println(a ...frontend.Variable) {
	var log constraint.LogEntry
	if _, file, line, ok := runtime.Caller(1); ok {
		log.Caller = fmt.Sprintf("%s:%d", filepath.Base(file), line)
	}
	var sbb strings.Builder
	for i, arg := range a {
		if i > 0 {
			sbb.WriteByte(' ')
		}
		if v, ok := arg.(expr.LinearExpression); ok {
			assertIsSet(v)
			sbb.WriteString("%s")
			log.ToResolve = append(log.ToResolve, builder.getLinearExpression(v))
		} else {
			builder.printArg(&log, &sbb, arg)
		}
	}
	log.Format = sbb.String()
	builder.cs.AddLog(log)
}

// printArg appends one non-linear-expression argument to the log format.
// If a contains no public/secret variable leaves (or the schema walk fails), it
// is printed with fmt.Sprint. Otherwise it is printed as "{name: %s, ...}", with
// each leaf's linear expression queued in log.ToResolve.
func (builder *builder) printArg(log *constraint.LogEntry, sbb *strings.Builder, a frontend.Variable) {
	leafCount, err := schema.Walk(a, tVariable, nil)
	count := leafCount.Public + leafCount.Secret
	if count == 0 || err != nil {
		sbb.WriteString(fmt.Sprint(a))
		return
	}
	sbb.WriteByte('{')
	printer := func(f schema.LeafInfo, tValue reflect.Value) error {
		// count is used to omit the separator after the last leaf.
		count--
		sbb.WriteString(f.FullName())
		sbb.WriteString(": ")
		sbb.WriteString("%s")
		if count != 0 {
			sbb.WriteString(", ")
		}
		v := tValue.Interface().(expr.LinearExpression)
		log.ToResolve = append(log.ToResolve, builder.getLinearExpression(v))
		return nil
	}
	_, _ = schema.Walk(a, tVariable, printer)
	sbb.WriteByte('}')
}

// negateLinExp returns a negated copy of l; l itself is not modified.
func (builder *builder) negateLinExp(l expr.LinearExpression) expr.LinearExpression {
	res := make(expr.LinearExpression, len(l))
	copy(res, l)
	for i := 0; i < len(res); i++ {
		builder.cs.Neg(&res[i].Coeff)
	}
	return res
}

// Compiler returns the builder itself as a frontend.Compiler.
func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error) {
	vars, s := builder.toVariables(v...)
	// Merge all inputs through builder.heap, as in builder.add, to visit their
	// wire IDs in sorted order (assumes each input is sorted by VID — TODO
	// confirm). Duplicates are then detected by comparing with the last ID kept.
	for lID, v := range vars {
		builder.heap = append(builder.heap, linMeta{val: v[0].VID, lID: lID})
	}
	builder.heap.heapify()
	committed := make([]int, 0, s)
	curr := -1
	nbPublicCommitted := 0
	for len(builder.heap) > 0 {
		lID, tID := builder.heap[0].lID, builder.heap[0].tID
		if tID == len(vars[lID])-1 {
			builder.heap.popHead()
		} else {
			builder.heap[0].tID++
			builder.heap[0].val = vars[lID][tID+1].VID
			builder.heap.fix(0)
		}
		t := &vars[lID][tID]
		if t.VID == 0 {
			continue // don't commit to ONE_WIRE
		}
		if curr != -1 && t.VID == committed[curr] {
			continue
		} else {
			// append, it's a new variable ID
			committed = append(committed, t.VID)
			// Wire IDs below GetNbPublicVariables() are counted as public.
			if t.VID < builder.cs.GetNbPublicVariables() {
				nbPublicCommitted++
			}
			curr++
		}
	}
	if len(committed) == 0 {
		return nil, errors.New("must commit to at least one variable")
	}
	// build commitment
	commitment := constraint.NewCommitment(committed, nbPublicCommitted)
	// The commitment value is the output of a placeholder hint over the committed wires.
	hintOut, err := builder.NewHint(bsb22CommitmentComputePlaceholder, 1, builder.getCommittedVariables(&commitment)...)
	if err != nil {
		return nil, err
	}
	cVar := hintOut[0]
	commitment.HintID = hint.UUID(bsb22CommitmentComputePlaceholder) // TODO @gbotrel probably not needed
	commitment.CommitmentIndex = (cVar.(expr.LinearExpression))[0].WireID()
	// TODO @Tabaie: Get rid of this field
	// NOTE(review): append may reuse the backing array of commitment.Committed
	// (committed was allocated with capacity s), so the two slices can alias —
	// safe only while Committed is never appended to afterwards.
	commitment.CommittedAndCommitment = append(commitment.Committed, commitment.CommitmentIndex)
	// The commitment wire must come after every committed wire.
	if commitment.CommitmentIndex <= commitment.Committed[len(commitment.Committed)-1] {
		return nil, fmt.Errorf("commitment variable index smaller than some committed variable indices")
	}
	if err := builder.cs.AddCommitment(commitment); err != nil {
		return nil, err
	}
	return cVar, nil
}

// getCommittedVariables returns one single-term variable (coefficient
// builder.tOne) per committed wire index, in the order of i.Committed.
func (builder *builder) getCommittedVariables(i *constraint.Commitment) []frontend.Variable {
	res := make([]frontend.Variable, len(i.Committed))
	for j, wireIndex := range i.Committed {
		res[j] = expr.NewLinearExpression(wireIndex, builder.tOne)
	}
	return res
}

// bsb22CommitmentComputePlaceholder always returns an error. Per its message, it
// is presumably swapped for the real commitment computation before solving —
// confirm where that replacement happens.
func bsb22CommitmentComputePlaceholder(*big.Int, []*big.Int, []*big.Int) error {
	return fmt.Errorf("placeholder function: to be replaced by commitment computation")
}