勉強 @ 任意桁と計算精度

先日の、Vanishing Numbers @ GCJ Africa and Arabia 2011について、解き方次第で正解or不正解になってしまうのがいまいち納得できず。
http://d.hatena.ne.jp/nekohand/20110301/1298985871
http://d.hatena.ne.jp/nekohand/20110302/1299073828

モヤモヤしながら、C++でMIPRという多倍長ライブラリを使ってリトライ。
http://www.mpir.org/


どうやら、勝手に任意桁演算=必要桁精度確保と思いこんでいたようだ。
んなわけない。
必要桁って本人しか知らないし。


MIPRにしろ、Pythonにしろ、任意桁演算するときには明示的に設定しなければならない。
さもなくば、きわどい精度問題を扱う場合にデフォルト設定に翻弄されることになる。


ヒジョーに勉強になりました。
ひとり反省会してよかった。


ともかく、必要な精度を確保しておけば、速度は別にして、細かいアルゴリズムの違いによらず正解は得られる。

そうとわかれば、より実装に集中できるはず。。

これで少しはゆっくり眠れるかな〜っと。


Vanishing Numbers(再々挑戦)

C++ with MIPR

#include <iostream>
#include <fstream>
#include <vector>
#include <algorithm>
#include <stack>
#include <iomanip>
#include <list>
#include <string>
#include <sstream>
#include <math.h>
#include "mpir.h"

using namespace std;

#define WITCH (0)
//#define WITCH (1)

namespace
{
	const int MAX_LEVEL = 100;
	const mp_bitcnt_t PREC = MAX_LEVEL*log(3.0f)/log(2.0f);
}

struct Num
{
	Num(){mpf_init(num);level=0;}

	mpf_t num;
	string str;
	int level;
};

class Sorter
{
public:
	bool operator()(const Num &lhs, const Num &rhs) const
	{
		if (lhs.level==rhs.level)
			return 0>=mpf_cmp(lhs.num, rhs.num);
		else
			return lhs.level<rhs.level;
	}
};

void input_a_case(ifstream &in, vector<Num> &nums)
{
	int num_of_nums;
	in >> num_of_nums;
	nums.reserve(num_of_nums);

	for(int i=0; i<num_of_nums; ++i)
	{
		Num n;
		in >> n.str;
		stringstream ss;
		ss << n.str;
		long double a;
		ss >> a;
		mpf_init_set_str(n.num, n.str.c_str(), 10);
		n.level = 0;

		nums.push_back(n);
	}

	sort(nums.begin(), nums.end(), Sorter());
}

int determine_level_of_a_number(const Num &src)
{
	int level = 0;
#if WITCH
	mpf_t num, min, max, par;
	mpf_init_set(num, src.num);
	mpf_init_set_str(min, "0.0", 10);
	mpf_init_set_str(max, "1.0", 10);
	mpf_init_set_str(par, "3.0", 10);

	mpf_t lower, upper, interval;
	mpf_init(lower);
	mpf_init(upper);
	mpf_init(interval);

	while(level++<MAX_LEVEL)
	{
		mpf_sub(interval, max, min);
		mpf_div(interval, interval, par);
		mpf_add(lower, min, interval);
		mpf_sub(upper, max, interval);

		if (0>=mpf_cmp(lower, num) && 0>=mpf_cmp(num, upper))
			break;

		if (0>=mpf_cmp(min, num) && 0>=mpf_cmp(num, lower))
			mpf_set(max, lower);

		if (0>=mpf_cmp(num, max) && 0>=mpf_cmp(upper, num))
			mpf_set(min, upper);
	}

	mpf_clear(num);
	mpf_clear(min);
	mpf_clear(max);
	mpf_clear(par);
	mpf_clear(lower);
	mpf_clear(upper);
	mpf_clear(interval);
#else
	mpf_t num, d1, d2, d3;
	mpf_init_set(num, src.num);
	mpf_init_set_str(d1, "1.0", 10);
	mpf_init_set_str(d2, "2.0", 10);
	mpf_init_set_str(d3, "3.0", 10);

	while(level++<MAX_LEVEL)
	{
		mpf_mul(num, num, d3);

		if (0>mpf_cmp(num, d1))
			continue;
		else if (0>=mpf_cmp(num, d2))
			break;
		else
			mpf_sub(num, num, d2);
	}

	mpf_clear(num);
	mpf_clear(d1);
	mpf_clear(d2);
	mpf_clear(d3);
#endif
	return level;
}

void solve_a_case(ifstream &in, vector<Num> &nums)
{
	input_a_case(in, nums);

	for(int i=0; i<nums.size(); ++i)
	{
		nums[i].level = determine_level_of_a_number(nums[i]);
	}

	sort(nums.begin(), nums.end(), Sorter());
}

void solve_all_cases(ifstream &in, ofstream &out)
{
	int case_num = 0;

	in >> case_num;
	for(int i=0; i<case_num; i++)
	{
		vector<Num> order_of_nums;
		solve_a_case(in, order_of_nums);

		out << "Case #" << i+1 << ":" << endl;
		for(int i=0; i<order_of_nums.size(); ++i)
		{
			out << order_of_nums[i].str;
			out << endl;
		}
	}
}

int main()
{
    mpf_set_default_prec(PREC);
#if 0
	ifstream in("in.txt");
	ofstream out("out.txt");
#elif 0
	ifstream in("A-small-practice.in");
	ofstream out("A-small-practice_mpir.out");
#elif 1
	ifstream in("A-large-practice.in");
	ofstream out("A-large-practice_mpir.out");
#endif

	solve_all_cases(in, out);

	return 0;
}

Python

from decimal import Decimal, getcontext

MAX_LEVEL = 100

def main():
    import math
    getcontext().prec = int(math.log(3**MAX_LEVEL, 2)+1)
    src = open("A-large-practice.in", "r")
    dst = open("A-large-practice.out", "w")
    case_num = int(src.readline())
    for i in range(case_num):
        val_num = int(src.readline())
        rst = []
        for j in range(val_num):
            n = str(src.readline())
            #level = vanish1(n)
            level = vanish2(n)
            rst.append((level, n))
        rst.sort()
        dst.write("Case #%d:\n" % (i+1))
        for l, n in rst:
            dst.write(n)
    src.close()
    dst.close()

def vanish1(num):
    num = Decimal(num)
    d1 = Decimal('1.0')
    d2 = Decimal('2.0')
    d3 = Decimal('3.0')
    level = 0
    while level<MAX_LEVEL:
        num3 = num*d3
        if num3<d1:
            num = num3
        elif num3<=d2:
            break
        else:
            num = num3-d2
        level += 1
    return level

def vanish2(num):
    num = Decimal(num)
    min = Decimal('0.0')
    max = Decimal('1.0')
    par = Decimal('3.0')
    level = 0
    while level<MAX_LEVEL:
        interval = (max-min)/par
        lower = min+interval
        upper = max-interval

        if lower<=num and num<=upper:
            break
        if min<=num and num<=lower:
            max = lower
        if num<=max and upper<=num:
            min = upper
        level += 1
    return level
    

if __name__=="__main__":
    main()

C++なら実行が速い。
Pythonなら実装が早い。

本番ではどちらを使おうか。悩む。。