LeetCode刷题Wrapper

杂谈,技术向 2019-05-06

起因

LeetCode上的题目放到本地调试很麻烦, 需要在本地新建一个Main方法来运行, 多TestCase的测试也需要写很多重复代码, 所以想写一个简单Wrapper将很多重复的工作放在里面来执行.

目标

使用Wrapper之后, 只需要将Solution类放入再实现getTestCase()即可.
这里的Solution类可以直接不作任何修改的复制粘贴到LeetCode中, 不需要任何额外的代码即可直接运行这个类.

public class Problem7 extends AbstractSolution{
    /**
     * Given a 32-bit signed integer, reverse digits of an integer.
     * <p>
     * Example 1:
     * Input: 123
     * Output: 321
     * <p>
     * Example 2:
     * Input: -123
     * Output: -321
     * <p>
     * Example 3:
     * Input: 120
     * Output: 21
     * <p>
     * Note:
     * Assume we are dealing with an environment which could only store integers within the 32-bit signed integer range: [−231, 231 − 1]. For the purpose of this problem, assume that your function returns 0 when the reversed integer overflows.
     */
    class Solution{
        @Resource
        public int reverse(int x){
            // 统一调整到负数来处理
            boolean isPositive = x >= 0;
            x = isPositive ? -x : x;

            int result = 0;
            while(x < 0 && result >= Integer.MIN_VALUE / 10){
                int n = x % 10;
                if(result == 0){
                    result = result + n;
                }else{
                    if(result == Integer.MIN_VALUE / 10 && x < -8){
                        break;
                    }
                    result = result * 10 + n;
                }
                x /= 10;
            }
            result = x < 0 ? 0 : result;

            return isPositive ? -result : result;
        }
    }

    @Override
    protected Map<?, ?> getTestCase(){
        return ImmutableMap.builder()
                .put(123, 321)
                .put(-123, -321)
                .put(120, 21)
                .put(1534236469, 0)
                .put(-2147483412, -2143847412)
                .put(1463847412, 2147483641)
                .put(1234, 111)
                .build();
    }
}

由于使用Guava, 所以可能还需要引入它的依赖到工程中.
另外, Solution类的入口方法需要放上@Resource注解, 否则Wrapper无法找到运行的入口方法.
运行的结果如下所示:

TestCase 0: Passed
TestCase 1: Passed
TestCase 2: Passed
TestCase 3: Passed
TestCase 4: Passed
TestCase 5: Passed
TestCase 6: Wrong
    Input: 1234
    Output: 4321
    Expected: 111
Some TestCase Failed!

解析

首先展示一下AbstractSolution的内容:

public abstract class AbstractSolution{
    /**
     * 1. 实例化SolutionWrapper类
     * 2. 找到目标的Solution类并获得实例
     * 3. 找到目标方法(用@Resource注解)
     * 4. 检查参数类型和TestCase是否匹配
     * 5. 使用TestCase运行目标方法
     */
    public static void main(String[] args) throws IllegalAccessException, InstantiationException, InvocationTargetException{
        // fetch solution wrapper
        AbstractSolution solutionWrapper = ReflectionUtils.getLoadedSubClass(AbstractSolution.class).newInstance();

        // find targetClass
        Class targetClass = null;
        Class<?>[] classes = solutionWrapper.getClass().getDeclaredClasses();
        for(Class<?> c : classes){
            if(c.getName().endsWith("$Solution")){
                targetClass = c;
            }
        }

        // get targetClass instance
        Object targetInstance = null;
        if(targetClass != null && targetClass.getDeclaredConstructors().length > 0){
            Constructor<?> targetConstructor = targetClass.getDeclaredConstructors()[0];
            targetInstance = targetConstructor.newInstance(solutionWrapper);
        }

        // find targetMethod
        Method targetMethod = null;
        if(targetClass != null){
            Method[] methods = targetClass.getDeclaredMethods();
            for(Method m : methods){
                if(m.isAnnotationPresent(Resource.class)){
                    targetMethod = m;
                }
            }
        }

        // check parameter and run
        Map<?, ?> testCase = solutionWrapper.getTestCase();
        if(targetMethod != null && targetMethod.getParameterTypes().length > 0 && testCase.size() > 0){
            Class<?> parameterTypes = targetMethod.getParameterTypes()[0];
            Class<?> testCaseType = testCase.keySet().toArray()[0].getClass();
            if(ReflectionUtils.isAssignableFrom(testCaseType, parameterTypes)){
                runCases(targetInstance, targetMethod, testCase);
            }
        }
    }

    private static boolean runCases(Object targetInstance, Method targetMethod, Map<?, ?> testCase) throws InvocationTargetException, IllegalAccessException{
        int index = 0;
        boolean allPassed = true;
        for(Object input : testCase.keySet()){
            Object result = targetMethod.invoke(targetInstance, input);
            Object expectedResult = testCase.get(input);
            boolean passed = result.toString().equals(expectedResult.toString());
            allPassed = passed && allPassed;

            if(passed){
                System.out.printf("TestCase %d: Passed\n", index++);
            }else{
                System.out.printf("TestCase %d: Wrong\n"
                                + "\tInput: %s\n"
                                + "\tOutput: %s\n"
                                + "\tExpected: %s\n",
                        index++, input, result, expectedResult);
            }
        }
        System.out.println(allPassed ? "All TestCase Passed!" : "Some TestCase Failed!");
        return allPassed;
    }

    protected abstract Map<?, ?> getTestCase();
}

我们在Problem7类中运行的main()方法放在了AbstractSolution里, 首先我们需要获得实际子类的实例, 即Problem7类的实例.
AbstractSolution类里, 我们是没办法知道到底具体是哪一个子类调用的main()方法的, 如果想要知道有两种方法:

  1. 运行的命令行参数里, 会带上实际执行的子类.
    例如在这个例子中, 命令行参数是java algorithm.leetcode.Problem7. 这样从命令行参数就可以知道实际调用main()的类是Problem7了.
  2. 在加载子类之前, 父类一定会被加载. 因此, 我们遍历当前一遍所有加载的类, 看当前类的哪一个子类被加载了, 就可以知道实际调用的类.

在这里, 我使用的是第二种方法, 通过ReflectionUtils.getLoadedSubClass(AbstractSolution.class)来获取, 具体实现如下:

    @SuppressWarnings("unchecked")
    public static <T> Class<? extends T> getLoadedSubClass(Class<T> superClass){
        // Read Loaded Classes
        Collection<Class<?>> classes = getLoadedClass(Thread.currentThread().getContextClassLoader());

        // Find the SubClass of CurrentClass
        for(Class c : classes){
            if(c.getSuperclass() != null && superClass.equals(c.getSuperclass())){
                return c;
            }
        }
        throw new IllegalStateException("No Valid SubClass.");
    }

    /**
     * 获取指定类加载器已加载的类
     */
    @SuppressWarnings("unchecked")
    public static Collection<Class<?>> getLoadedClass(ClassLoader cl){
        try{
            Field f = ClassLoader.class.getDeclaredField("classes");
            f.setAccessible(true);
            return (Collection<Class<?>>)f.get(cl);
        }catch(NoSuchFieldException | IllegalAccessException e){
            e.printStackTrace();
        }
        throw new IllegalStateException("Unexpected Error.");
    }

ClassLoader本身是没有提供方法让你获得当前已加载的类, 但是存在一个private的classes属性包含所有已加载的类, 因此可以通过读取该属性来获得指定类加载已加载的类.

在获得Problem7的实例solutionWrapper后, 我们再从Problem7类中找到其内部类$Solution并实例化它.
需要注意的是, 这里不能直接用newInstance(), 因为内部类没有public的构造器, 只有包含外部类引用的private构造器, 所以需要首先获得这个构造器, 然后使用构造器实例化内部类.

在获得内部类targetClass以后, 再遍历所有的方法找到标记有注解的方法作为入口方法targetMethod.
这里用@Resouce注解的原因只是...我觉得用JDK自带注解就够了...

最后, 检查一下testCase的类型和targetMethod的参数类型是否匹配, 就可以调用targetMethod来实际执行Solution类里的实际方法了.
这里考虑到存在primitive类型和其包装类混用的情况, 所以首先将其都unwrap, 如果都为primitive类型那么一定是可以转型成功的(当然可能损失精度).
如果没有primitive类型, 那么使用isAssignableFrom()方法来判断; 如果一个primitive一个正常对象, 那么必然转型失败.

    /**
     * 检查是否可从target转型到source
     * 考虑两个primitive类型一定可以互转(可能损失精度)
     */
    public static boolean isAssignableFrom(Class<?> source, Class<?> target){
        source = Primitives.unwrap(source);
        target = Primitives.unwrap(target);

        // All Object
        if(source.isPrimitive() && target.isPrimitive()){
            return true;
        }else if(!source.isPrimitive() && !target.isPrimitive()){
            return source.isAssignableFrom(target);
        }else{
            return false;
        }
    }

总的来说, 大体上的思路就是这样了, 我觉得还是很有意思的.

总结

这么一番操作下来, 感觉对框架内部实现的很多东西了解得更深了呢233
之前简单的看过Spring-Core的内部实现, 也是大量类似的检查代码, 以及一吨设计模式→_→
实话说, 感觉设计得有点太过复杂了...不过有可能是我看的源码版本比较老吧233(Spring 3.1.1)

懒果然是人类进步的阶梯(并不

Update 2019-05-10:
我把AbstractSolution又改了改, 在方法上做标注还是有一些略麻烦...
我把它改成了标注在内部类上, 然后用lookup属性来指定待执行的方法, 或者选取第一个找到的Public方法

public static void main(String[] args) throws IllegalAccessException, InstantiationException, InvocationTargetException{
        // fetch solution wrapper
        AbstractSolution solutionWrapper = ReflectionUtils.getLoadedSubClass(AbstractSolution.class).newInstance();

        // find targetClass
        Class targetClass = null;
        Class<?>[] classes = solutionWrapper.getClass().getDeclaredClasses();
        for(Class<?> c : classes){
            if(c.isAnnotationPresent(Resource.class)){
                targetClass = c;
            }
        }

        // get targetClass instance
        Object targetInstance = null;
        if(targetClass != null && targetClass.getDeclaredConstructors().length > 0){
            Constructor<?> targetConstructor = targetClass.getDeclaredConstructors()[0];
            targetInstance = targetConstructor.newInstance(solutionWrapper);
        }

        // find targetMethod
        Method targetMethod = null;
        if(targetClass != null){
            Method[] methods = targetClass.getDeclaredMethods();
            Resource annotation = (Resource)targetClass.getAnnotation(Resource.class);
            for(Method m : methods){
                if(!annotation.lookup().isEmpty()){
                    // specify by lookup
                    if(m.getName().equals(annotation.lookup())){
                        targetMethod = m;
                    }
                }else{
                    // first public method
                    if(Modifier.isPublic(m.getModifiers())){
                        targetMethod = m;
                    }
                }
            }
        }

        // check parameter and run
        Map<?, ?> testCase = solutionWrapper.getTestCase();
        if(targetMethod != null && targetMethod.getParameterTypes().length > 0 && testCase.size() > 0){
            Class<?> parameterTypes = targetMethod.getParameterTypes()[0];
            Class<?> testCaseType = testCase.keySet().toArray()[0].getClass();
            if(ReflectionUtils.isAssignableFrom(testCaseType, parameterTypes)){
                runCases(targetInstance, targetMethod, testCase);
            }
        }
    }

Update 2019-09-04:

以前做的居然不支持多参数输入, 醉了醉了...
验证参数类型其实挺麻烦的, 感觉在我这个场景下也不是特别重要就删掉了...只检查参数数目是否一致
然后输出的检测考虑到可能会有比较复杂的类型输出出来干脆就用gson全部转换到字符串然后比较字符串好了...

public abstract class AbstractSolution{
    private final static Gson gson = new GsonBuilder().create();

    /**
     * 1. 实例化SolutionWrapper类
     * 2. 找到目标的Solution类并获得实例
     * 3. 找到目标方法(用@Resource注解)
     * 4. 检查参数类型和TestCase是否匹配
     * 5. 使用TestCase运行目标方法
     */
    public static void main(String[] args) throws IllegalAccessException, InstantiationException, InvocationTargetException{
        // fetch solution wrapper
        AbstractSolution solutionWrapper = ReflectionUtils.getLoadedSubClass(AbstractSolution.class).newInstance();

        // find targetClass
        Class targetClass = null;
        Class<?>[] classes = solutionWrapper.getClass().getDeclaredClasses();
        for (Class<?> c : classes) {
            if (c.isAnnotationPresent(Resource.class)) {
                targetClass = c;
            }
        }

        // get targetClass instance
        Object targetInstance = null;
        if (targetClass != null && targetClass.getDeclaredConstructors().length > 0) {
            Constructor<?> targetConstructor = targetClass.getDeclaredConstructors()[0];
            targetInstance = targetConstructor.newInstance(solutionWrapper);
        }

        // find targetMethod
        Method targetMethod = null;
        if (targetClass != null) {
            Method[] methods = targetClass.getDeclaredMethods();
            Resource annotation = (Resource)targetClass.getAnnotation(Resource.class);
            for (Method m : methods) {
                if (!annotation.lookup().isEmpty()) {
                    // specify by lookup
                    if (m.getName().equals(annotation.lookup())) {
                        targetMethod = m;
                    }
                } else {
                    // first public method
                    if (Modifier.isPublic(m.getModifiers())) {
                        targetMethod = m;
                    }
                }
            }
        }

        // check parameter and run
        Map<?, ?> testCase = solutionWrapper.getTestCase();
        if (targetMethod != null && targetMethod.getParameterTypes().length > 0 && testCase.size() > 0) {
            runCases(targetInstance, targetMethod, testCase);
        }
    }

    private static boolean runCases(Object targetInstance, Method targetMethod, Map<?, ?> testCase) throws InvocationTargetException, IllegalAccessException{
        int index = 0;
        boolean allPassed = true;
        for (Object input : testCase.keySet()) {
            // read params
            Object[] params;
            if (targetMethod.getParameterCount() > 1) {
                params = new Object[Array.getLength(input)];
                for (int i = 0; i < params.length; i++) {
                    params[i] = Array.get(input, i);
                }
            } else {
                params = new Object[]{input};
            }

            // check parameter length
            if (params.length != targetMethod.getParameterCount()) {
                throw new IllegalArgumentException("Incompatible Parameters.");
            }

            // actual invoke
            Object result = targetMethod.invoke(targetInstance, params);
            Object expectedResult = testCase.get(input);

            // compare string
            String inputStr = gson.toJson(input);
            String resultStr = gson.toJson(result);
            String expectedStr = gson.toJson(expectedResult);
            boolean passed = resultStr.equals(expectedStr);
            allPassed = passed && allPassed;

            if (passed) {
                System.out.printf("TestCase %d: Passed\n", index++);
            } else {
                System.out.printf("TestCase %d: Wrong\n"
                                + "\tInput: %s\n"
                                + "\tOutput: %s\n"
                                + "\tExpected: %s\n",
                        index++, inputStr, resultStr, expectedStr);
            }
        }
        System.out.println(allPassed ? "All TestCase Passed!" : "Some TestCase Failed!");
        return allPassed;
    }

    protected abstract Map<?, ?> getTestCase();
}

本文由 SLKun 创作,采用 知识共享署名 3.0,可自由转载、引用,但需署名作者且注明文章出处。

还不快抢沙发

添加新评论