勉強 @ 任意桁と計算精度
先日の、Vanishing Numbers @ GCJ Africa and Arabia 2011について、解き方次第で正解or不正解になってしまうのがいまいち納得できず。
http://d.hatena.ne.jp/nekohand/20110301/1298985871
http://d.hatena.ne.jp/nekohand/20110302/1299073828
モヤモヤしながら、C++でMIPRという多倍長ライブラリを使ってリトライ。
http://www.mpir.org/
どうやら、勝手に任意桁演算=必要桁精度確保と思いこんでいたようだ。
んなわけない。
必要桁って本人しか知らないし。
MIPRにしろ、Pythonにしろ、任意桁演算するときには明示的に設定しなければならない。
さもなくば、きわどい精度問題を扱う場合にデフォルト設定に翻弄されることになる。
ヒジョーに勉強になりました。
ひとり反省会してよかった。
ともかく、必要な精度を確保しておけば、速度は別にして、細かいアルゴリズムの違いによらず正解は得られる。
そうとわかれば、より実装に集中できるはず。。
これで少しはゆっくり眠れるかな〜っと。
Vanishing Numbers(再々挑戦)
C++ with MIPR
#include <iostream> #include <fstream> #include <vector> #include <algorithm> #include <stack> #include <iomanip> #include <list> #include <string> #include <sstream> #include <math.h> #include "mpir.h" using namespace std; #define WITCH (0) //#define WITCH (1) namespace { const int MAX_LEVEL = 100; const mp_bitcnt_t PREC = MAX_LEVEL*log(3.0f)/log(2.0f); } struct Num { Num(){mpf_init(num);level=0;} mpf_t num; string str; int level; }; class Sorter { public: bool operator()(const Num &lhs, const Num &rhs) const { if (lhs.level==rhs.level) return 0>=mpf_cmp(lhs.num, rhs.num); else return lhs.level<rhs.level; } }; void input_a_case(ifstream &in, vector<Num> &nums) { int num_of_nums; in >> num_of_nums; nums.reserve(num_of_nums); for(int i=0; i<num_of_nums; ++i) { Num n; in >> n.str; stringstream ss; ss << n.str; long double a; ss >> a; mpf_init_set_str(n.num, n.str.c_str(), 10); n.level = 0; nums.push_back(n); } sort(nums.begin(), nums.end(), Sorter()); } int determine_level_of_a_number(const Num &src) { int level = 0; #if WITCH mpf_t num, min, max, par; mpf_init_set(num, src.num); mpf_init_set_str(min, "0.0", 10); mpf_init_set_str(max, "1.0", 10); mpf_init_set_str(par, "3.0", 10); mpf_t lower, upper, interval; mpf_init(lower); mpf_init(upper); mpf_init(interval); while(level++<MAX_LEVEL) { mpf_sub(interval, max, min); mpf_div(interval, interval, par); mpf_add(lower, min, interval); mpf_sub(upper, max, interval); if (0>=mpf_cmp(lower, num) && 0>=mpf_cmp(num, upper)) break; if (0>=mpf_cmp(min, num) && 0>=mpf_cmp(num, lower)) mpf_set(max, lower); if (0>=mpf_cmp(num, max) && 0>=mpf_cmp(upper, num)) mpf_set(min, upper); } mpf_clear(num); mpf_clear(min); mpf_clear(max); mpf_clear(par); mpf_clear(lower); mpf_clear(upper); mpf_clear(interval); #else mpf_t num, d1, d2, d3; mpf_init_set(num, src.num); mpf_init_set_str(d1, "1.0", 10); mpf_init_set_str(d2, "2.0", 10); mpf_init_set_str(d3, "3.0", 10); while(level++<MAX_LEVEL) { mpf_mul(num, num, d3); if (0>mpf_cmp(num, d1)) continue; else if (0>=mpf_cmp(num, d2)) break; else mpf_sub(num, num, d2); } mpf_clear(num); mpf_clear(d1); mpf_clear(d2); mpf_clear(d3); #endif return level; } void solve_a_case(ifstream &in, vector<Num> &nums) { input_a_case(in, nums); for(int i=0; i<nums.size(); ++i) { nums[i].level = determine_level_of_a_number(nums[i]); } sort(nums.begin(), nums.end(), Sorter()); } void solve_all_cases(ifstream &in, ofstream &out) { int case_num = 0; in >> case_num; for(int i=0; i<case_num; i++) { vector<Num> order_of_nums; solve_a_case(in, order_of_nums); out << "Case #" << i+1 << ":" << endl; for(int i=0; i<order_of_nums.size(); ++i) { out << order_of_nums[i].str; out << endl; } } } int main() { mpf_set_default_prec(PREC); #if 0 ifstream in("in.txt"); ofstream out("out.txt"); #elif 0 ifstream in("A-small-practice.in"); ofstream out("A-small-practice_mpir.out"); #elif 1 ifstream in("A-large-practice.in"); ofstream out("A-large-practice_mpir.out"); #endif solve_all_cases(in, out); return 0; }
from decimal import Decimal, getcontext MAX_LEVEL = 100 def main(): import math getcontext().prec = int(math.log(3**MAX_LEVEL, 2)+1) src = open("A-large-practice.in", "r") dst = open("A-large-practice.out", "w") case_num = int(src.readline()) for i in range(case_num): val_num = int(src.readline()) rst = [] for j in range(val_num): n = str(src.readline()) #level = vanish1(n) level = vanish2(n) rst.append((level, n)) rst.sort() dst.write("Case #%d:\n" % (i+1)) for l, n in rst: dst.write(n) src.close() dst.close() def vanish1(num): num = Decimal(num) d1 = Decimal('1.0') d2 = Decimal('2.0') d3 = Decimal('3.0') level = 0 while level<MAX_LEVEL: num3 = num*d3 if num3<d1: num = num3 elif num3<=d2: break else: num = num3-d2 level += 1 return level def vanish2(num): num = Decimal(num) min = Decimal('0.0') max = Decimal('1.0') par = Decimal('3.0') level = 0 while level<MAX_LEVEL: interval = (max-min)/par lower = min+interval upper = max-interval if lower<=num and num<=upper: break if min<=num and num<=lower: max = lower if num<=max and upper<=num: min = upper level += 1 return level if __name__=="__main__": main()
本番ではどちらを使おうか。悩む。。